Coverage for langsmith/wrappers/_gemini.py: 13%
218 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-12-11 16:15 -0800
1from __future__ import annotations
3import base64
4import functools
5import json
6import logging
7from collections.abc import Mapping
8from typing import (
9 TYPE_CHECKING,
10 Any,
11 Callable,
12 Optional,
13 TypeVar,
14 Union,
15)
17from typing_extensions import TypedDict
19from langsmith import client as ls_client
20from langsmith import run_helpers
21from langsmith._internal._beta_decorator import warn_beta
22from langsmith.schemas import InputTokenDetails, OutputTokenDetails, UsageMetadata
24if TYPE_CHECKING:
25 from google import genai # type: ignore[import-untyped, attr-defined]
# Type variable for the client handed to `wrap_gemini`, which is annotated as
# returning the same type `C` that it receives.
C = TypeVar("C", bound=Union["genai.Client", Any])
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
31def _strip_none(d: dict) -> dict:
32 """Remove `None` values from dictionary."""
33 return {k: v for k, v in d.items() if v is not None}
36def _convert_config_for_tracing(kwargs: dict) -> None:
37 """Convert `GenerateContentConfig` to `dict` for LangSmith compatibility."""
38 if "config" in kwargs and not isinstance(kwargs["config"], dict):
39 kwargs["config"] = vars(kwargs["config"])
42def _process_gemini_inputs(inputs: dict) -> dict:
43 r"""Process Gemini inputs to normalize them for LangSmith tracing.
45 Example:
46 ```txt
47 {"contents": "Hello", "model": "gemini-pro"}
48 → {"messages": [{"role": "user", "content": "Hello"}], "model": "gemini-pro"}
49 {"contents": [{"role": "user", "parts": [{"text": "What is AI?"}]}], "model": "gemini-pro"}
50 → {"messages": [{"role": "user", "content": "What is AI?"}], "model": "gemini-pro"}
51 ```
52 """ # noqa: E501
53 # If contents is not present or not in list format, return as-is
54 contents = inputs.get("contents")
55 if not contents:
56 return inputs
58 # Handle string input (simple case)
59 if isinstance(contents, str):
60 return {
61 "messages": [{"role": "user", "content": contents}],
62 "model": inputs.get("model"),
63 **({k: v for k, v in inputs.items() if k not in ("contents", "model")}),
64 }
66 # Handle list of content objects (multimodal case)
67 if isinstance(contents, list):
68 # Check if it's a simple list of strings
69 if all(isinstance(item, str) for item in contents):
70 # Each string becomes a separate user message (matches Gemini's behavior)
71 return {
72 "messages": [{"role": "user", "content": item} for item in contents],
73 "model": inputs.get("model"),
74 **({k: v for k, v in inputs.items() if k not in ("contents", "model")}),
75 }
76 # Handle complex multimodal case
77 messages = []
78 for content in contents:
79 if isinstance(content, dict):
80 role = content.get("role", "user")
81 parts = content.get("parts", [])
83 # Extract text and other parts
84 text_parts = []
85 content_parts = []
87 for part in parts:
88 if isinstance(part, dict):
89 # Handle text parts
90 if "text" in part and part["text"]:
91 text_parts.append(part["text"])
92 content_parts.append({"type": "text", "text": part["text"]})
93 # Handle inline data (images)
94 elif "inline_data" in part:
95 inline_data = part["inline_data"]
96 mime_type = inline_data.get("mime_type", "image/jpeg")
97 data = inline_data.get("data", b"")
99 # Convert bytes to base64 string if needed
100 if isinstance(data, bytes):
101 data_b64 = base64.b64encode(data).decode("utf-8")
102 else:
103 data_b64 = data # Already a string
105 content_parts.append(
106 {
107 "type": "image_url",
108 "image_url": {
109 "url": f"data:{mime_type};base64,{data_b64}",
110 "detail": "high",
111 },
112 }
113 )
114 # Handle function responses
115 elif "functionResponse" in part:
116 function_response = part["functionResponse"]
117 content_parts.append(
118 {
119 "type": "function_response",
120 "function_response": {
121 "name": function_response.get("name"),
122 "response": function_response.get(
123 "response", {}
124 ),
125 },
126 }
127 )
128 # Handle function calls (for conversation history)
129 elif "function_call" in part or "functionCall" in part:
130 function_call = part.get("function_call") or part.get(
131 "functionCall"
132 )
134 if function_call is not None:
135 # Normalize to dict (FunctionCall is a Pydantic model)
136 if not isinstance(function_call, dict):
137 function_call = function_call.to_dict()
139 content_parts.append(
140 {
141 "type": "function_call",
142 "function_call": {
143 "id": function_call.get("id"),
144 "name": function_call.get("name"),
145 "arguments": function_call.get("args", {}),
146 },
147 }
148 )
149 elif isinstance(part, str):
150 # Handle simple string parts
151 text_parts.append(part)
152 content_parts.append({"type": "text", "text": part})
154 # If only text parts, use simple string format
155 if content_parts and all(
156 p.get("type") == "text" for p in content_parts
157 ):
158 message_content: Union[str, list[dict[str, Any]]] = "\n".join(
159 text_parts
160 )
161 else:
162 message_content = content_parts if content_parts else ""
164 messages.append({"role": role, "content": message_content})
165 return {
166 "messages": messages,
167 "model": inputs.get("model"),
168 **({k: v for k, v in inputs.items() if k not in ("contents", "model")}),
169 }
171 # Fallback: return original inputs
172 return inputs
175def _infer_invocation_params(kwargs: dict) -> dict:
176 """Extract invocation parameters for tracing."""
177 stripped = _strip_none(kwargs)
178 config = stripped.get("config", {})
180 # Handle both dict config and GenerateContentConfig object
181 if hasattr(config, "temperature"):
182 temperature = config.temperature
183 max_tokens = getattr(config, "max_output_tokens", None)
184 stop = getattr(config, "stop_sequences", None)
185 else:
186 temperature = config.get("temperature")
187 max_tokens = config.get("max_output_tokens")
188 stop = config.get("stop_sequences")
190 return {
191 "ls_provider": "google",
192 "ls_model_type": "chat",
193 "ls_model_name": stripped.get("model"),
194 "ls_temperature": temperature,
195 "ls_max_tokens": max_tokens,
196 "ls_stop": stop,
197 }
def _create_usage_metadata(gemini_usage_metadata: dict) -> UsageMetadata:
    """Translate Gemini token counts into a LangSmith ``UsageMetadata``.

    Field mapping:
        - ``prompt_token_count`` → ``input_tokens``
        - ``candidates_token_count`` → ``output_tokens``
        - ``total_token_count`` → ``total_tokens`` (falls back to input +
          output when missing or zero)
        - ``cached_content_token_count`` → ``input_token_details["cache_read"]``
        - ``thoughts_token_count`` → ``output_token_details["reasoning"]``

    Missing or ``None`` counts are treated as ``0``; the two detail entries
    are only included when their count is non-zero.
    """

    def count(field: str) -> int:
        # `or 0` covers both an absent key and an explicit None.
        return gemini_usage_metadata.get(field) or 0

    input_tokens = count("prompt_token_count")
    output_tokens = count("candidates_token_count")
    cache_read = count("cached_content_token_count")
    reasoning = count("thoughts_token_count")
    total_tokens = gemini_usage_metadata.get("total_token_count") or (
        input_tokens + output_tokens
    )

    input_details = {"cache_read": cache_read} if cache_read else {}
    output_details = {"reasoning": reasoning} if reasoning else {}

    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        input_token_details=InputTokenDetails(**input_details),
        output_token_details=OutputTokenDetails(**output_details),
    )
def _process_generate_content_response(response: Any) -> dict:
    """Convert a Gemini ``generate_content`` response into a traced output dict.

    The response is first turned into a dict (``to_dict()``, then
    ``model_dump()``, otherwise its ``text``). Only the first candidate is
    inspected. Its parts are mapped to text, ``image_url`` (inline data) and
    ``function_call`` blocks.

    Args:
        response: The object returned by the wrapped ``generate_content`` call.

    Returns:
        A dict with ``content``, ``role`` (``"assistant"``), ``finish_reason``
        and ``usage_metadata``, plus ``tool_calls`` when the model requested
        function calls. ``content`` is:

        - the concatenated text (or ``None`` if empty) when there are tool
          calls;
        - the list of content blocks for multi-part or non-text responses;
        - a plain string for text-only responses.

        If anything goes wrong while processing, ``{"output": response}`` is
        returned instead and the error is logged at debug level.

    Side effects:
        When usage metadata is present and a run tree is active, the converted
        usage is merged into ``run.extra["metadata"]["usage_metadata"]`` and
        the run is patched.
    """
    try:
        # Convert response to dictionary
        if hasattr(response, "to_dict"):
            rdict = response.to_dict()
        elif hasattr(response, "model_dump"):
            rdict = response.model_dump()
        else:
            rdict = {"text": getattr(response, "text", str(response))}

        content_result = ""
        content_parts: list[dict[str, Any]] = []
        finish_reason: Optional[str] = None

        candidates = rdict.get("candidates")
        if candidates:
            candidate = candidates[0]
            # `content` / `parts` may be present but None (e.g. a candidate
            # that produced no content). Treat that as "no parts" so the rest
            # of the response, including usage, is still recorded.
            content = candidate.get("content") or {}
            for part in content.get("parts") or []:
                if part.get("text"):
                    content_result += part["text"]
                    content_parts.append({"type": "text", "text": part["text"]})
                elif part.get("inline_data") is not None:
                    # Inline data (images) in the response.
                    inline_data = part["inline_data"]
                    mime_type = inline_data.get("mime_type", "image/jpeg")
                    data = inline_data.get("data", b"")
                    # Raw bytes are base64-encoded; other values are assumed
                    # to already be a base64 string.
                    if isinstance(data, bytes):
                        data_b64 = base64.b64encode(data).decode("utf-8")
                    else:
                        data_b64 = data
                    content_parts.append(
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime_type};base64,{data_b64}",
                                "detail": "high",
                            },
                        }
                    )
                else:
                    function_call = part.get("function_call") or part.get(
                        "functionCall"
                    )
                    if function_call is not None:
                        if not isinstance(function_call, dict):
                            function_call = function_call.to_dict()
                        content_parts.append(
                            {
                                "type": "function_call",
                                "function_call": {
                                    "id": function_call.get("id"),
                                    "name": function_call.get("name"),
                                    "arguments": function_call.get("args", {}),
                                },
                            }
                        )
            finish_reason = candidate.get("finish_reason") or None
        elif "text" in rdict:
            content_result = rdict["text"]
            content_parts.append({"type": "text", "text": content_result})

        # Build chat-like response format - use OpenAI-compatible format for
        # tool calls so the LangSmith UI can render them.
        tool_calls = [p for p in content_parts if p.get("type") == "function_call"]
        result: dict[str, Any] = {
            "content": content_result,
            "role": "assistant",
            "finish_reason": finish_reason,
        }
        if tool_calls:
            result["content"] = content_result or None
            result["tool_calls"] = [
                {
                    "id": tc["function_call"].get("id") or f"call_{i}",
                    "type": "function",
                    "index": i,
                    "function": {
                        "name": tc["function_call"]["name"],
                        "arguments": json.dumps(tc["function_call"]["arguments"]),
                    },
                }
                for i, tc in enumerate(tool_calls)
            ]
        elif len(content_parts) > 1 or (
            content_parts and content_parts[0]["type"] != "text"
        ):
            # Structured format for mixed / non-text content.
            result["content"] = content_parts

        # Extract and convert usage metadata
        usage_metadata = rdict.get("usage_metadata")
        usage_dict: UsageMetadata = UsageMetadata(
            input_tokens=0, output_tokens=0, total_tokens=0
        )
        if usage_metadata:
            usage_dict = _create_usage_metadata(usage_metadata)
            # Add usage_metadata to both run.extra AND outputs
            current_run = run_helpers.get_current_run_tree()
            if current_run:
                try:
                    meta = current_run.extra.setdefault("metadata", {}).setdefault(
                        "usage_metadata", {}
                    )
                    meta.update(usage_dict)
                    current_run.patch()
                except Exception as e:
                    # Best effort: a failed patch must not lose the output.
                    logger.warning("Failed to update usage metadata: %s", e)

        # Every response shape shares the same envelope; usage goes last.
        result["usage_metadata"] = usage_dict
        return result
    except Exception as e:
        logger.debug("Error processing Gemini response: %s", e)
        return {"output": response}
def _reduce_generate_content_chunks(all_chunks: list) -> dict:
    """Reduce streaming chunks into a single traced output.

    Args:
        all_chunks: The chunks yielded by a ``generate_content_stream`` call,
            in order.

    Returns:
        ``{"content": <concatenated chunk text>, "usage_metadata": ...}``.
        Usage is taken from the last chunk that was read successfully and is
        all zeros when no usage is available.

    Side effects:
        When usage is found and a run tree is active, it is merged into
        ``run.extra["metadata"]["usage_metadata"]`` and the run is patched.

    Errors while reading a chunk or its usage are logged at debug level and
    otherwise ignored, so a malformed chunk cannot break the trace.
    """
    usage_metadata: UsageMetadata = UsageMetadata(
        input_tokens=0, output_tokens=0, total_tokens=0
    )
    if not all_chunks:
        return {"content": "", "usage_metadata": usage_metadata}

    # Accumulate text from all chunks. `text` is read once per chunk and the
    # pieces are joined at the end instead of growing a string with `+=`.
    text_pieces: list[str] = []
    last_chunk = None
    for chunk in all_chunks:
        try:
            text = getattr(chunk, "text", None)
            if text and isinstance(text, str):
                text_pieces.append(text)
            last_chunk = chunk
        except Exception as e:
            logger.debug("Error processing chunk: %s", e)
    full_text = "".join(text_pieces)

    # Extract usage metadata from the last chunk
    if last_chunk is not None:
        try:
            raw_usage = getattr(last_chunk, "usage_metadata", None)
            if raw_usage:
                if isinstance(raw_usage, dict):
                    # Already a mapping: reading it with getattr would
                    # silently produce zeros for every count.
                    usage_dict = raw_usage
                elif hasattr(raw_usage, "to_dict"):
                    usage_dict = raw_usage.to_dict()
                elif hasattr(raw_usage, "model_dump"):
                    usage_dict = raw_usage.model_dump()
                else:
                    usage_dict = {
                        field: getattr(raw_usage, field, 0)
                        for field in (
                            "prompt_token_count",
                            "candidates_token_count",
                            "cached_content_token_count",
                            "thoughts_token_count",
                            "total_token_count",
                        )
                    }
                # Add usage_metadata to both run.extra AND outputs
                usage_metadata = _create_usage_metadata(usage_dict)
                current_run = run_helpers.get_current_run_tree()
                if current_run:
                    try:
                        meta = current_run.extra.setdefault("metadata", {}).setdefault(
                            "usage_metadata", {}
                        )
                        meta.update(usage_metadata)
                        current_run.patch()
                    except Exception as e:
                        logger.warning("Failed to update usage metadata: %s", e)
        except Exception as e:
            logger.debug("Error extracting metadata from last chunk: %s", e)

    # Return minimal structure with usage_metadata in outputs
    return {
        "content": full_text,
        "usage_metadata": usage_metadata,
    }
def _get_wrapper(
    original_generate: Callable,
    name: str,
    tracing_extra: Optional[TracingExtra] = None,
    is_streaming: bool = False,
) -> Callable:
    """Wrap one of Gemini's ``generate_content`` methods with LangSmith tracing.

    Args:
        original_generate: The client method to wrap (sync or async).
        name: Run name given to the traced call.
        tracing_extra: Extra keyword arguments forwarded to
            ``run_helpers.traceable``.
        is_streaming: When true, chunks are reduced with
            ``_reduce_generate_content_chunks``; otherwise the single response
            goes through ``_process_generate_content_response``.

    Returns:
        An async wrapper if ``run_helpers.is_async(original_generate)`` is
        true, a sync wrapper otherwise. Both carry the original's metadata via
        ``functools.wraps``.
    """
    textra = tracing_extra or {}

    if is_streaming:
        reduce_fn, process_outputs = _reduce_generate_content_chunks, None
    else:
        reduce_fn, process_outputs = None, _process_generate_content_response

    def _traced() -> Callable:
        # The traceable decorator is built on every call, after the config
        # has been normalised by the wrappers below.
        decorator = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=reduce_fn,
            process_inputs=_process_gemini_inputs,
            process_outputs=process_outputs,
            _invocation_params_fn=_infer_invocation_params,
            **textra,
        )
        return decorator(original_generate)

    @functools.wraps(original_generate)
    def generate(*args, **kwargs):
        # Turn a config object into a dict before tracing is set up.
        _convert_config_for_tracing(kwargs)
        return _traced()(*args, **kwargs)

    @functools.wraps(original_generate)
    async def agenerate(*args, **kwargs):
        # Turn a config object into a dict before tracing is set up.
        _convert_config_for_tracing(kwargs)
        return await _traced()(*args, **kwargs)

    if run_helpers.is_async(original_generate):
        return agenerate
    return generate
class TracingExtra(TypedDict, total=False):
    """Optional tracing settings accepted by ``wrap_gemini``.

    ``_get_wrapper`` unpacks this mapping as keyword arguments into
    ``run_helpers.traceable``. Every key may be omitted (``total=False``).
    """

    # Presumably extra metadata attached to each traced run - confirm against
    # `run_helpers.traceable`.
    metadata: Optional[Mapping[str, Any]]
    # Presumably tags attached to each traced run - confirm against
    # `run_helpers.traceable`.
    tags: Optional[list[str]]
    # LangSmith client instance forwarded to `traceable`.
    client: Optional[ls_client.Client]
@warn_beta
def wrap_gemini(
    client: C,
    *,
    tracing_extra: Optional[TracingExtra] = None,
    chat_name: str = "ChatGoogleGenerativeAI",
) -> C:
    """Patch the Google Gen AI client to make it traceable.

    !!! warning

        **BETA**: This wrapper is in beta.

    Supports:
        - `generate_content` and `generate_content_stream` methods
        - Sync and async clients
        - Streaming and non-streaming responses
        - Tool/function calling with proper UI rendering
        - Multimodal inputs (text + images)
        - Image generation with `inline_data` support
        - Token usage tracking including reasoning tokens

    Args:
        client: The Google Gen AI client to patch.
        tracing_extra: Extra tracing information.
        chat_name: The run name for the chat endpoint.

    Returns:
        The patched client (the same object that was passed in).

    Raises:
        ValueError: If `client.models.generate_content` already carries a
            `__wrapped__` attribute, i.e. the client was wrapped before.

    Example:
        ```python
        from google import genai
        from google.genai import types
        from langsmith import wrappers

        # Use Google Gen AI client same as you normally would.
        client = wrappers.wrap_gemini(genai.Client(api_key="your-api-key"))

        # Basic text generation:
        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents="Why is the sky blue?",
        )
        print(response.text)

        # Streaming:
        for chunk in client.models.generate_content_stream(
            model="gemini-2.5-flash",
            contents="Tell me a story",
        ):
            print(chunk.text, end="")

        # Tool/Function calling:
        schedule_meeting_function = {
            "name": "schedule_meeting",
            "description": "Schedules a meeting with specified attendees.",
            "parameters": {
                "type": "object",
                "properties": {
                    "attendees": {"type": "array", "items": {"type": "string"}},
                    "date": {"type": "string"},
                    "time": {"type": "string"},
                    "topic": {"type": "string"},
                },
                "required": ["attendees", "date", "time", "topic"],
            },
        }

        tools = types.Tool(function_declarations=[schedule_meeting_function])
        config = types.GenerateContentConfig(tools=[tools])

        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents="Schedule a meeting with Bob and Alice tomorrow at 2 PM.",
            config=config,
        )

        # Image generation:
        response = client.models.generate_content(
            model="gemini-2.5-flash-image",
            contents=["Create a picture of a futuristic city"],
        )

        # Save generated image
        from io import BytesIO
        from PIL import Image

        for part in response.candidates[0].content.parts:
            if part.inline_data is not None:
                image = Image.open(BytesIO(part.inline_data.data))
                image.save("generated_image.png")
        ```

    !!! version-added "Added in `langsmith` 0.4.33"

        Initial beta release of Google Gen AI wrapper.

    """
    tracing_extra = tracing_extra or {}

    def _namespace(*path: str) -> Any:
        # Walk `client.<path...>`; a missing attribute anywhere yields None.
        target: Any = client
        for attr in path:
            target = getattr(target, attr, None)
        return target

    # Refuse to wrap twice: the wrappers installed below are created with
    # functools.wraps and therefore expose `__wrapped__`.
    if hasattr(_namespace("models", "generate_content"), "__wrapped__"):
        raise ValueError(
            "This Google Gen AI client has already been wrapped. "
            "Wrapping a client multiple times is not supported."
        )

    # (attribute path to the models namespace, method name, is_streaming),
    # synchronous namespace first, then the async `aio` namespace.
    targets = (
        (("models",), "generate_content", False),
        (("models",), "generate_content_stream", True),
        (("aio", "models"), "generate_content", False),
        (("aio", "models"), "generate_content_stream", True),
    )
    for path, method_name, streaming in targets:
        models = _namespace(*path)
        if not hasattr(models, method_name):
            # Namespace or method not available on this client: skip it.
            continue
        setattr(
            models,
            method_name,
            _get_wrapper(
                getattr(models, method_name),
                chat_name,
                tracing_extra=tracing_extra,
                is_streaming=streaming,
            ),
        )

    return client