Coverage for langsmith/wrappers/_gemini.py: 13%

218 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-12-11 16:15 -0800

1from __future__ import annotations 

2 

3import base64 

4import functools 

5import json 

6import logging 

7from collections.abc import Mapping 

8from typing import ( 

9 TYPE_CHECKING, 

10 Any, 

11 Callable, 

12 Optional, 

13 TypeVar, 

14 Union, 

15) 

16 

17from typing_extensions import TypedDict 

18 

19from langsmith import client as ls_client 

20from langsmith import run_helpers 

21from langsmith._internal._beta_decorator import warn_beta 

22from langsmith.schemas import InputTokenDetails, OutputTokenDetails, UsageMetadata 

23 

24if TYPE_CHECKING: 

25 from google import genai # type: ignore[import-untyped, attr-defined] 

26 

# Type of the client accepted and returned by `wrap_gemini`. `genai` is only
# imported under TYPE_CHECKING, so the bound uses a string forward reference
# (unioned with Any) and this module does not import google-genai at runtime.
C = TypeVar("C", bound=Union["genai.Client", Any])
# Module-level logger for the Gemini wrapper.
logger = logging.getLogger(__name__)

29 

30 

31def _strip_none(d: dict) -> dict: 

32 """Remove `None` values from dictionary.""" 

33 return {k: v for k, v in d.items() if v is not None} 

34 

35 

36def _convert_config_for_tracing(kwargs: dict) -> None: 

37 """Convert `GenerateContentConfig` to `dict` for LangSmith compatibility.""" 

38 if "config" in kwargs and not isinstance(kwargs["config"], dict): 

39 kwargs["config"] = vars(kwargs["config"]) 

40 

41 

42def _process_gemini_inputs(inputs: dict) -> dict: 

43 r"""Process Gemini inputs to normalize them for LangSmith tracing. 

44 

45 Example: 

46 ```txt 

47 {"contents": "Hello", "model": "gemini-pro"} 

48 → {"messages": [{"role": "user", "content": "Hello"}], "model": "gemini-pro"} 

49 {"contents": [{"role": "user", "parts": [{"text": "What is AI?"}]}], "model": "gemini-pro"} 

50 → {"messages": [{"role": "user", "content": "What is AI?"}], "model": "gemini-pro"} 

51 ``` 

52 """ # noqa: E501 

53 # If contents is not present or not in list format, return as-is 

54 contents = inputs.get("contents") 

55 if not contents: 

56 return inputs 

57 

58 # Handle string input (simple case) 

59 if isinstance(contents, str): 

60 return { 

61 "messages": [{"role": "user", "content": contents}], 

62 "model": inputs.get("model"), 

63 **({k: v for k, v in inputs.items() if k not in ("contents", "model")}), 

64 } 

65 

66 # Handle list of content objects (multimodal case) 

67 if isinstance(contents, list): 

68 # Check if it's a simple list of strings 

69 if all(isinstance(item, str) for item in contents): 

70 # Each string becomes a separate user message (matches Gemini's behavior) 

71 return { 

72 "messages": [{"role": "user", "content": item} for item in contents], 

73 "model": inputs.get("model"), 

74 **({k: v for k, v in inputs.items() if k not in ("contents", "model")}), 

75 } 

76 # Handle complex multimodal case 

77 messages = [] 

78 for content in contents: 

79 if isinstance(content, dict): 

80 role = content.get("role", "user") 

81 parts = content.get("parts", []) 

82 

83 # Extract text and other parts 

84 text_parts = [] 

85 content_parts = [] 

86 

87 for part in parts: 

88 if isinstance(part, dict): 

89 # Handle text parts 

90 if "text" in part and part["text"]: 

91 text_parts.append(part["text"]) 

92 content_parts.append({"type": "text", "text": part["text"]}) 

93 # Handle inline data (images) 

94 elif "inline_data" in part: 

95 inline_data = part["inline_data"] 

96 mime_type = inline_data.get("mime_type", "image/jpeg") 

97 data = inline_data.get("data", b"") 

98 

99 # Convert bytes to base64 string if needed 

100 if isinstance(data, bytes): 

101 data_b64 = base64.b64encode(data).decode("utf-8") 

102 else: 

103 data_b64 = data # Already a string 

104 

105 content_parts.append( 

106 { 

107 "type": "image_url", 

108 "image_url": { 

109 "url": f"data:{mime_type};base64,{data_b64}", 

110 "detail": "high", 

111 }, 

112 } 

113 ) 

114 # Handle function responses 

115 elif "functionResponse" in part: 

116 function_response = part["functionResponse"] 

117 content_parts.append( 

118 { 

119 "type": "function_response", 

120 "function_response": { 

121 "name": function_response.get("name"), 

122 "response": function_response.get( 

123 "response", {} 

124 ), 

125 }, 

126 } 

127 ) 

128 # Handle function calls (for conversation history) 

129 elif "function_call" in part or "functionCall" in part: 

130 function_call = part.get("function_call") or part.get( 

131 "functionCall" 

132 ) 

133 

134 if function_call is not None: 

135 # Normalize to dict (FunctionCall is a Pydantic model) 

136 if not isinstance(function_call, dict): 

137 function_call = function_call.to_dict() 

138 

139 content_parts.append( 

140 { 

141 "type": "function_call", 

142 "function_call": { 

143 "id": function_call.get("id"), 

144 "name": function_call.get("name"), 

145 "arguments": function_call.get("args", {}), 

146 }, 

147 } 

148 ) 

149 elif isinstance(part, str): 

150 # Handle simple string parts 

151 text_parts.append(part) 

152 content_parts.append({"type": "text", "text": part}) 

153 

154 # If only text parts, use simple string format 

155 if content_parts and all( 

156 p.get("type") == "text" for p in content_parts 

157 ): 

158 message_content: Union[str, list[dict[str, Any]]] = "\n".join( 

159 text_parts 

160 ) 

161 else: 

162 message_content = content_parts if content_parts else "" 

163 

164 messages.append({"role": role, "content": message_content}) 

165 return { 

166 "messages": messages, 

167 "model": inputs.get("model"), 

168 **({k: v for k, v in inputs.items() if k not in ("contents", "model")}), 

169 } 

170 

171 # Fallback: return original inputs 

172 return inputs 

173 

174 

175def _infer_invocation_params(kwargs: dict) -> dict: 

176 """Extract invocation parameters for tracing.""" 

177 stripped = _strip_none(kwargs) 

178 config = stripped.get("config", {}) 

179 

180 # Handle both dict config and GenerateContentConfig object 

181 if hasattr(config, "temperature"): 

182 temperature = config.temperature 

183 max_tokens = getattr(config, "max_output_tokens", None) 

184 stop = getattr(config, "stop_sequences", None) 

185 else: 

186 temperature = config.get("temperature") 

187 max_tokens = config.get("max_output_tokens") 

188 stop = config.get("stop_sequences") 

189 

190 return { 

191 "ls_provider": "google", 

192 "ls_model_type": "chat", 

193 "ls_model_name": stripped.get("model"), 

194 "ls_temperature": temperature, 

195 "ls_max_tokens": max_tokens, 

196 "ls_stop": stop, 

197 } 

198 

199 

def _create_usage_metadata(gemini_usage_metadata: dict) -> UsageMetadata:
    """Convert Gemini usage metadata to LangSmith's ``UsageMetadata`` format.

    Gemini reports thinking tokens (``thoughts_token_count``) separately from
    ``candidates_token_count``, whereas LangSmith's ``output_token_details`` is
    a breakdown *of* ``output_tokens``. Thinking tokens are therefore counted in
    ``output_tokens`` and also itemized as ``output_token_details["reasoning"]``.
    Cached prompt tokens are itemized as ``input_token_details["cache_read"]``.

    Args:
        gemini_usage_metadata: Gemini usage dict with snake_case keys such as
            ``prompt_token_count``. Missing keys and ``None`` values count as 0.

    Returns:
        The equivalent ``UsageMetadata``. ``total_tokens`` is Gemini's
        ``total_token_count`` when present and non-zero, otherwise
        ``input_tokens + output_tokens``.
    """

    def _count(key: str) -> int:
        # Missing keys and explicit `None` values both count as zero.
        return gemini_usage_metadata.get(key) or 0

    prompt_tokens = _count("prompt_token_count")
    candidates_tokens = _count("candidates_token_count")
    cached_tokens = _count("cached_content_token_count")
    thoughts_tokens = _count("thoughts_token_count")

    output_tokens = candidates_tokens + thoughts_tokens
    total_tokens = _count("total_token_count") or prompt_tokens + output_tokens

    input_token_details = InputTokenDetails()
    if cached_tokens:
        input_token_details["cache_read"] = cached_tokens

    output_token_details = OutputTokenDetails()
    if thoughts_tokens:
        output_token_details["reasoning"] = thoughts_tokens

    return UsageMetadata(
        input_tokens=prompt_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        input_token_details=input_token_details,
        output_token_details=output_token_details,
    )

232 

233 

def _gemini_response_part_to_content_part(part: dict) -> Optional[dict[str, Any]]:
    """Convert one response part (as a dict) into a LangSmith content block.

    Text, inline data (rendered as an ``image_url`` data URL) and function calls
    are supported; keys present with a ``None`` value are treated as absent.
    Returns ``None`` for parts that are not rendered.
    """
    if part.get("text"):
        return {"type": "text", "text": part["text"]}

    inline_data = part.get("inline_data")
    if inline_data is not None:
        mime_type = inline_data.get("mime_type") or "image/jpeg"
        data = inline_data.get("data") or b""
        # Raw bytes must be base64-encoded for a data URL; strings are assumed
        # to already be base64.
        if isinstance(data, bytes):
            data = base64.b64encode(data).decode("utf-8")
        return {
            "type": "image_url",
            "image_url": {
                "url": f"data:{mime_type};base64,{data}",
                "detail": "high",
            },
        }

    function_call = part.get("function_call") or part.get("functionCall")
    if function_call is not None:
        if not isinstance(function_call, dict):
            # Normalize to dict (FunctionCall is a Pydantic model).
            function_call = function_call.to_dict()
        return {
            "type": "function_call",
            "function_call": {
                "id": function_call.get("id"),
                "name": function_call.get("name"),
                # Argument-less calls serialize as "{}" rather than "null".
                "arguments": function_call.get("args") or {},
            },
        }
    return None


def _process_generate_content_response(response: Any) -> dict:
    """Process a Gemini ``generate_content`` response for LangSmith tracing.

    The response is converted to a dict (``to_dict()``, else ``model_dump()``,
    else its ``text``) and the first candidate is flattened into an assistant
    message:

    - with function calls: ``content`` is the concatenated text (or ``None``)
      and ``tool_calls`` uses the OpenAI-compatible shape;
    - with several parts or any non-text part: ``content`` is a list of blocks;
    - otherwise: ``content`` is the concatenated text.

    Usage metadata is converted with ``_create_usage_metadata``, included in the
    output, and merged into the current run's ``extra["metadata"]``.

    Args:
        response: The object returned by ``generate_content``.

    Returns:
        ``{"content", "role", "finish_reason", ["tool_calls"], "usage_metadata"}``,
        or ``{"output": response}`` if the response could not be processed.
    """
    try:
        if hasattr(response, "to_dict"):
            rdict = response.to_dict()
        elif hasattr(response, "model_dump"):
            rdict = response.model_dump()
        else:
            rdict = {"text": getattr(response, "text", str(response))}

        content_result = ""
        content_parts: list[dict[str, Any]] = []
        finish_reason: Optional[str] = None
        candidates = rdict.get("candidates")
        if candidates:
            candidate = candidates[0]
            # `content` and `parts` may be present but `None` (e.g. for a
            # blocked candidate); treat that as "no parts" instead of failing
            # the whole conversion.
            content = candidate.get("content") or {}
            for part in content.get("parts") or []:
                block = _gemini_response_part_to_content_part(part)
                if block is None:
                    continue
                content_parts.append(block)
                if block["type"] == "text":
                    content_result += block["text"]
            if candidate.get("finish_reason"):
                finish_reason = candidate["finish_reason"]
        elif "text" in rdict:
            content_result = rdict["text"]
            content_parts.append({"type": "text", "text": content_result})

        function_calls = [
            b["function_call"] for b in content_parts if b["type"] == "function_call"
        ]
        if function_calls:
            message_content: Any = content_result or None
        elif len(content_parts) > 1 or (
            content_parts and content_parts[0]["type"] != "text"
        ):
            message_content = content_parts
        else:
            message_content = content_result

        result: dict[str, Any] = {
            "content": message_content,
            "role": "assistant",
            "finish_reason": finish_reason,
        }
        if function_calls:
            # OpenAI-compatible format so the LangSmith UI renders tool calls.
            result["tool_calls"] = [
                {
                    "id": call.get("id") or f"call_{i}",
                    "type": "function",
                    "index": i,
                    "function": {
                        "name": call["name"],
                        "arguments": json.dumps(call["arguments"]),
                    },
                }
                for i, call in enumerate(function_calls)
            ]

        usage_dict: UsageMetadata = UsageMetadata(
            input_tokens=0, output_tokens=0, total_tokens=0
        )
        gemini_usage = rdict.get("usage_metadata")
        if gemini_usage:
            usage_dict = _create_usage_metadata(gemini_usage)
            # Add usage_metadata to both run.extra AND outputs.
            current_run = run_helpers.get_current_run_tree()
            if current_run:
                try:
                    meta = current_run.extra.setdefault("metadata", {}).setdefault(
                        "usage_metadata", {}
                    )
                    meta.update(usage_dict)
                    current_run.patch()
                except Exception as e:
                    logger.warning("Failed to update usage metadata: %s", e)

        result["usage_metadata"] = usage_dict
        return result
    except Exception as e:
        logger.debug("Error processing Gemini response: %s", e)
        return {"output": response}

394 

395 

def _reduce_generate_content_chunks(all_chunks: list) -> dict:
    """Reduce ``generate_content_stream`` chunks into a single traced output.

    Text is concatenated across chunks. Function calls reported by any chunk
    (via its ``function_calls`` attribute) are collected into OpenAI-compatible
    ``tool_calls``, matching the non-streaming output. Usage metadata is taken
    from the most recent chunk that reports it, converted with
    ``_create_usage_metadata``, included in the output, and merged into the
    current run's ``extra["metadata"]``.

    Args:
        all_chunks: The chunks yielded by the stream, in order.

    Returns:
        ``{"content": str, "usage_metadata": UsageMetadata}``. When function
        calls were streamed, ``role`` and ``tool_calls`` are added and
        ``content`` is ``None`` if no text was streamed.
    """
    if not all_chunks:
        return {
            "content": "",
            "usage_metadata": UsageMetadata(
                input_tokens=0, output_tokens=0, total_tokens=0
            ),
        }

    full_text = ""
    tool_calls: list[dict[str, Any]] = []
    latest_usage: Any = None

    for chunk in all_chunks:
        try:
            # Read `.text` once per chunk.
            text = getattr(chunk, "text", None)
            if text:
                full_text += text
        except Exception as e:
            logger.debug("Error processing chunk: %s", e)
        try:
            chunk_usage = getattr(chunk, "usage_metadata", None)
            if chunk_usage:
                # Keep the latest report even if the final chunk omits usage.
                latest_usage = chunk_usage
            for call in getattr(chunk, "function_calls", None) or []:
                index = len(tool_calls)
                tool_calls.append(
                    {
                        "id": getattr(call, "id", None) or f"call_{index}",
                        "type": "function",
                        "index": index,
                        "function": {
                            "name": getattr(call, "name", None),
                            "arguments": json.dumps(getattr(call, "args", None) or {}),
                        },
                    }
                )
        except Exception as e:
            logger.debug("Error processing chunk: %s", e)

    usage_metadata: UsageMetadata = UsageMetadata(
        input_tokens=0, output_tokens=0, total_tokens=0
    )
    if latest_usage is not None:
        try:
            if hasattr(latest_usage, "to_dict"):
                usage_dict = latest_usage.to_dict()
            elif hasattr(latest_usage, "model_dump"):
                usage_dict = latest_usage.model_dump()
            else:
                usage_dict = {
                    field: getattr(latest_usage, field, 0)
                    for field in (
                        "prompt_token_count",
                        "candidates_token_count",
                        "cached_content_token_count",
                        "thoughts_token_count",
                        "total_token_count",
                    )
                }
            usage_metadata = _create_usage_metadata(usage_dict)
            # Add usage_metadata to both run.extra AND outputs.
            current_run = run_helpers.get_current_run_tree()
            if current_run:
                try:
                    meta = current_run.extra.setdefault("metadata", {}).setdefault(
                        "usage_metadata", {}
                    )
                    meta.update(usage_metadata)
                    current_run.patch()
                except Exception as e:
                    logger.warning("Failed to update usage metadata: %s", e)
        except Exception as e:
            logger.debug("Error extracting metadata from last chunk: %s", e)

    if tool_calls:
        return {
            "content": full_text or None,
            "role": "assistant",
            "tool_calls": tool_calls,
            "usage_metadata": usage_metadata,
        }
    return {
        "content": full_text,
        "usage_metadata": usage_metadata,
    }

467 

468 

def _get_wrapper(
    original_generate: Callable,
    name: str,
    tracing_extra: Optional[TracingExtra] = None,
    is_streaming: bool = False,
) -> Callable:
    """Wrap a Gemini ``generate_content``-style method with LangSmith tracing.

    Args:
        original_generate: The SDK method to wrap (sync or async).
        name: Run name used for the traced calls.
        tracing_extra: Extra keyword arguments forwarded to
            ``run_helpers.traceable``.
        is_streaming: Whether ``original_generate`` streams chunks. Streaming
            calls are reduced with ``_reduce_generate_content_chunks``;
            non-streaming outputs are processed with
            ``_process_generate_content_response``.

    Returns:
        An ``async`` wrapper if ``original_generate`` is async, otherwise a sync
        wrapper. Both carry ``__wrapped__`` (via ``functools.wraps``), which
        ``wrap_gemini`` checks to detect double wrapping.
    """
    textra = tracing_extra or {}

    # The tracing configuration never changes between calls, so build the
    # traced callable once instead of re-creating the decorator per call.
    traced_generate = run_helpers.traceable(
        name=name,
        run_type="llm",
        reduce_fn=_reduce_generate_content_chunks if is_streaming else None,
        process_inputs=_process_gemini_inputs,
        process_outputs=(
            None if is_streaming else _process_generate_content_response
        ),
        _invocation_params_fn=_infer_invocation_params,
        **textra,
    )(original_generate)

    @functools.wraps(original_generate)
    def generate(*args, **kwargs):
        # Handle config object before tracing captures the inputs.
        _convert_config_for_tracing(kwargs)
        return traced_generate(*args, **kwargs)

    @functools.wraps(original_generate)
    async def agenerate(*args, **kwargs):
        # Handle config object before tracing captures the inputs.
        _convert_config_for_tracing(kwargs)
        return await traced_generate(*args, **kwargs)

    return agenerate if run_helpers.is_async(original_generate) else generate

517 

518 

class TracingExtra(TypedDict, total=False):
    """Optional tracing settings for `wrap_gemini`.

    Every key is optional (``total=False``). The dict is unpacked into the
    ``run_helpers.traceable`` call that instruments each wrapped method.
    """

    # Forwarded to `traceable(metadata=...)`.
    metadata: Optional[Mapping[str, Any]]
    # Forwarded to `traceable(tags=...)`.
    tags: Optional[list[str]]
    # Forwarded to `traceable(client=...)`: the LangSmith client to use.
    client: Optional[ls_client.Client]

523 

524 

@warn_beta
def wrap_gemini(
    client: C,
    *,
    tracing_extra: Optional[TracingExtra] = None,
    chat_name: str = "ChatGoogleGenerativeAI",
) -> C:
    """Instrument a Google Gen AI client so its generation calls are traced.

    !!! warning

        **BETA**: This wrapper is in beta.

    The client is patched in place: `generate_content` and
    `generate_content_stream` on `client.models` and, when present, on
    `client.aio.models` are replaced with LangSmith-traced versions.

    Supports:
        - `generate_content` and `generate_content_stream` methods
        - Sync and async clients
        - Streaming and non-streaming responses
        - Tool/function calling with proper UI rendering
        - Multimodal inputs (text + images)
        - Image generation with `inline_data` support
        - Token usage tracking including reasoning tokens

    Args:
        client: The Google Gen AI client to patch.
        tracing_extra: Extra tracing information.
        chat_name: The run name for the chat endpoint.

    Returns:
        The same client object, patched.

    Raises:
        ValueError: If `client.models.generate_content` has already been wrapped.

    Example:
        ```python
        from google import genai
        from google.genai import types
        from langsmith import wrappers

        # Use Google Gen AI client same as you normally would.
        client = wrappers.wrap_gemini(genai.Client(api_key="your-api-key"))

        # Basic text generation:
        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents="Why is the sky blue?",
        )
        print(response.text)

        # Streaming:
        for chunk in client.models.generate_content_stream(
            model="gemini-2.5-flash",
            contents="Tell me a story",
        ):
            print(chunk.text, end="")

        # Tool/Function calling:
        schedule_meeting_function = {
            "name": "schedule_meeting",
            "description": "Schedules a meeting with specified attendees.",
            "parameters": {
                "type": "object",
                "properties": {
                    "attendees": {"type": "array", "items": {"type": "string"}},
                    "date": {"type": "string"},
                    "time": {"type": "string"},
                    "topic": {"type": "string"},
                },
                "required": ["attendees", "date", "time", "topic"],
            },
        }

        tools = types.Tool(function_declarations=[schedule_meeting_function])
        config = types.GenerateContentConfig(tools=[tools])

        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents="Schedule a meeting with Bob and Alice tomorrow at 2 PM.",
            config=config,
        )

        # Image generation:
        response = client.models.generate_content(
            model="gemini-2.5-flash-image",
            contents=["Create a picture of a futuristic city"],
        )

        # Save generated image
        from io import BytesIO
        from PIL import Image

        for part in response.candidates[0].content.parts:
            if part.inline_data is not None:
                image = Image.open(BytesIO(part.inline_data.data))
                image.save("generated_image.png")
        ```

    !!! version-added "Added in `langsmith` 0.4.33"

        Initial beta release of Google Gemini wrapper.

    """
    tracing_extra = tracing_extra or {}

    def _namespace(*path: str) -> Any:
        """Follow attribute ``path`` from ``client``; ``None`` if a hop is missing."""
        node: Any = client
        for attr in path:
            if not hasattr(node, attr):
                return None
            node = getattr(node, attr)
        return node

    # Refuse to double-wrap: wrapped methods carry `__wrapped__`.
    existing = getattr(_namespace("models"), "generate_content", None)
    if existing is not None and hasattr(existing, "__wrapped__"):
        raise ValueError(
            "This Google Gen AI client has already been wrapped. "
            "Wrapping a client multiple times is not supported."
        )

    # (namespace path, method name, is_streaming) in the order they are patched:
    # sync methods first, then the async `aio` namespace.
    targets = (
        (("models",), "generate_content", False),
        (("models",), "generate_content_stream", True),
        (("aio", "models"), "generate_content", False),
        (("aio", "models"), "generate_content_stream", True),
    )
    for path, method_name, streaming in targets:
        namespace = _namespace(*path)
        if not hasattr(namespace, method_name):
            continue
        setattr(
            namespace,
            method_name,
            _get_wrapper(
                getattr(namespace, method_name),
                chat_name,
                tracing_extra=tracing_extra,
                is_streaming=streaming,
            ),
        )

    return client