Coverage for langsmith/wrappers/_openai.py: 16%

197 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-12-11 16:15 -0800

1from __future__ import annotations 

2 

3import functools 

4import logging 

5from collections import defaultdict 

6from collections.abc import Mapping 

7from typing import ( 

8 TYPE_CHECKING, 

9 Any, 

10 Callable, 

11 Optional, 

12 TypeVar, 

13 Union, 

14) 

15 

16from typing_extensions import TypedDict 

17 

18from langsmith import client as ls_client 

19from langsmith import run_helpers 

20from langsmith.schemas import InputTokenDetails, OutputTokenDetails, UsageMetadata 

21 

22if TYPE_CHECKING: 

23 from openai import AsyncOpenAI, OpenAI 

24 from openai.types.chat.chat_completion_chunk import ( 

25 ChatCompletionChunk, 

26 Choice, 

27 ChoiceDeltaToolCall, 

28 ) 

29 from openai.types.completion import Completion 

30 from openai.types.responses import ResponseStreamEvent # type: ignore 

31 

# Module-level logger used by the helpers in this file.
logger = logging.getLogger(__name__)

# The bound includes Any because the client may be an Azure client or a
# client from another OpenAI-compatible provider.
C = TypeVar("C", bound=Union["OpenAI", "AsyncOpenAI", Any])

35 

36 

37@functools.lru_cache 

38def _get_omit_types() -> tuple[type, ...]: 

39 """Get NotGiven/Omit sentinel types used by OpenAI SDK.""" 

40 types: list[type[Any]] = [] 

41 try: 

42 from openai._types import NotGiven, Omit 

43 

44 types.append(NotGiven) 

45 types.append(Omit) 

46 except ImportError: 

47 pass 

48 

49 return tuple(types) 

50 

51 

def _strip_not_given(d: dict) -> dict:
    """Return ``d`` without entries that stand for "argument not provided".

    An entry is dropped when its value is an instance of one of the types
    returned by ``_get_omit_types()``, or when its key starts with ``"extra_"``
    and its value is ``None``.

    If no sentinel types are available, ``d`` itself is returned. If filtering
    raises, the error is logged and ``d`` itself is returned.
    """
    try:
        sentinel_types = _get_omit_types()
        if not sentinel_types:
            return d

        kept = {}
        for key, value in d.items():
            if isinstance(value, sentinel_types):
                continue
            if key.startswith("extra_") and value is None:
                continue
            kept[key] = value
        return kept
    except Exception as e:
        logger.error(f"Error stripping NotGiven: {e}")
        return d

65 

66 

def _process_inputs(d: dict) -> dict:
    """Strip `NotGiven` values and serialize `text_format` to JSON schema.

    Args:
        d: The inputs of the traced call.

    Returns:
        The inputs after ``_strip_not_given``. When a ``text_format`` entry is
        present and exposes ``model_json_schema`` (e.g. a Pydantic model), a
        new dict is returned in which ``text_format`` is replaced by the
        result of calling it. If that call fails, the stripped inputs are
        returned with ``text_format`` left as it was.
    """
    stripped = _strip_not_given(d)

    if "text_format" not in stripped:
        return stripped

    text_format = stripped["text_format"]
    if not hasattr(text_format, "model_json_schema"):
        return stripped

    try:
        return {**stripped, "text_format": text_format.model_json_schema()}
    except Exception as e:
        # Best-effort: a schema failure must not break the traced call, but it
        # should leave a trace in the logs instead of vanishing silently.
        logger.debug("Error serializing text_format to JSON schema: %s", e)
        return stripped

83 

84 

def _infer_invocation_params(model_type: str, provider: str, kwargs: dict):
    """Build the ``ls_*`` invocation metadata for a traced call.

    Args:
        model_type: Value reported as ``ls_model_type``.
        provider: Value reported as ``ls_provider``.
        kwargs: Keyword arguments of the call; they are passed through
            ``_strip_not_given`` before being read.

    Returns:
        A dict with ``ls_provider``, ``ls_model_type``, ``ls_model_name``,
        ``ls_temperature``, ``ls_max_tokens``, ``ls_stop`` and
        ``ls_invocation_params``. The last one only contains the allowlisted
        keys that are present in the call.
    """
    params = _strip_not_given(kwargs)

    # A single stop string is reported as a one-element list.
    stop = params.get("stop")
    if stop and isinstance(stop, str):
        stop = [stop]

    # Allowlist of safe invocation parameters to include.
    # Only known, non-sensitive parameters are copied over.
    allowed_invocation_keys = {
        "frequency_penalty",
        "n",
        "logit_bias",
        "logprobs",
        "modalities",
        "parallel_tool_calls",
        "prediction",
        "presence_penalty",
        "prompt_cache_key",
        "reasoning",
        "reasoning_effort",
        "response_format",
        "seed",
        "service_tier",
        "stream_options",
        "top_logprobs",
        "top_p",
        "truncation",
        "user",
        "verbosity",
        "web_search_options",
    }
    invocation_params = {}
    for key, value in params.items():
        if key in allowed_invocation_keys:
            invocation_params[key] = value

    # The first truthy value among the three token-limit spellings wins.
    max_tokens = (
        params.get("max_tokens")
        or params.get("max_completion_tokens")
        or params.get("max_output_tokens")
    )

    return {
        "ls_provider": provider,
        "ls_model_type": model_type,
        "ls_model_name": params.get("model"),
        "ls_temperature": params.get("temperature"),
        "ls_max_tokens": max_tokens,
        "ls_stop": stop,
        "ls_invocation_params": invocation_params,
    }

134 

135 

136def _reduce_choices(choices: list[Choice]) -> dict: 

137 reversed_choices = list(reversed(choices)) 

138 message: dict[str, Any] = { 

139 "role": "assistant", 

140 "content": "", 

141 } 

142 for c in reversed_choices: 

143 if hasattr(c, "delta") and getattr(c.delta, "role", None): 

144 message["role"] = c.delta.role 

145 break 

146 tool_calls: defaultdict[int, list[ChoiceDeltaToolCall]] = defaultdict(list) 

147 for c in choices: 

148 if hasattr(c, "delta"): 

149 if getattr(c.delta, "content", None): 

150 message["content"] += c.delta.content 

151 if getattr(c.delta, "function_call", None): 

152 if not message.get("function_call"): 

153 message["function_call"] = {"name": "", "arguments": ""} 

154 name_ = getattr(c.delta.function_call, "name", None) 

155 if name_: 

156 message["function_call"]["name"] += name_ 

157 arguments_ = getattr(c.delta.function_call, "arguments", None) 

158 if arguments_: 

159 message["function_call"]["arguments"] += arguments_ 

160 if getattr(c.delta, "tool_calls", None): 

161 tool_calls_list = c.delta.tool_calls 

162 if tool_calls_list is not None: 

163 for tool_call in tool_calls_list: 

164 tool_calls[tool_call.index].append(tool_call) 

165 if tool_calls: 

166 message["tool_calls"] = [None for _ in range(max(tool_calls.keys()) + 1)] 

167 for index, tool_call_chunks in tool_calls.items(): 

168 message["tool_calls"][index] = { 

169 "index": index, 

170 "id": next((c.id for c in tool_call_chunks if c.id), None), 

171 "type": next((c.type for c in tool_call_chunks if c.type), None), 

172 "function": {"name": "", "arguments": ""}, 

173 } 

174 for chunk in tool_call_chunks: 

175 if getattr(chunk, "function", None): 

176 name_ = getattr(chunk.function, "name", None) 

177 if name_: 

178 message["tool_calls"][index]["function"]["name"] += name_ 

179 arguments_ = getattr(chunk.function, "arguments", None) 

180 if arguments_: 

181 message["tool_calls"][index]["function"]["arguments"] += ( 

182 arguments_ 

183 ) 

184 return { 

185 "index": getattr(choices[0], "index", 0) if choices else 0, 

186 "finish_reason": next( 

187 ( 

188 c.finish_reason 

189 for c in reversed_choices 

190 if getattr(c, "finish_reason", None) 

191 ), 

192 None, 

193 ), 

194 "message": message, 

195 } 

196 

197 

198def _reduce_chat(all_chunks: list[ChatCompletionChunk]) -> dict: 

199 choices_by_index: defaultdict[int, list[Choice]] = defaultdict(list) 

200 for chunk in all_chunks: 

201 for choice in chunk.choices: 

202 choices_by_index[choice.index].append(choice) 

203 if all_chunks: 

204 d = all_chunks[-1].model_dump() 

205 d["choices"] = [ 

206 _reduce_choices(choices) for choices in choices_by_index.values() 

207 ] 

208 else: 

209 d = {"choices": [{"message": {"role": "assistant", "content": ""}}]} 

210 # streamed outputs don't go through `process_outputs` 

211 # so we need to flatten metadata here 

212 oai_token_usage = d.pop("usage", None) 

213 d["usage_metadata"] = ( 

214 _create_usage_metadata(oai_token_usage) if oai_token_usage else None 

215 ) 

216 return d 

217 

218 

219def _reduce_completions(all_chunks: list[Completion]) -> dict: 

220 all_content = [] 

221 for chunk in all_chunks: 

222 content = chunk.choices[0].text 

223 if content is not None: 

224 all_content.append(content) 

225 content = "".join(all_content) 

226 if all_chunks: 

227 d = all_chunks[-1].model_dump() 

228 d["choices"] = [{"text": content}] 

229 else: 

230 d = {"choices": [{"text": content}]} 

231 

232 return d 

233 

234 

def _create_usage_metadata(
    oai_token_usage: dict, service_tier: Optional[str] = None
) -> UsageMetadata:
    """Convert an OpenAI usage dict into ``UsageMetadata``.

    Both spellings of the usage keys are read: ``prompt_tokens`` /
    ``completion_tokens`` and ``input_tokens`` / ``output_tokens``, with the
    matching ``*_tokens_details`` sub-dicts.

    Args:
        oai_token_usage: The usage dict.
        service_tier: Service tier of the response. Only ``"priority"`` and
            ``"flex"`` are recognized; any other value is treated as no tier.

    Returns:
        Usage metadata with input, output and total token counts plus the
        non-``None`` token details. With a recognized tier, the cache-read and
        reasoning detail keys are prefixed with ``"<tier>_"`` and a ``<tier>``
        detail is added on each side.
    """
    tier = service_tier if service_tier in ("priority", "flex") else None
    prefix = f"{tier}_" if tier else ""
    cache_read_key = f"{prefix}cache_read"
    reasoning_key = f"{prefix}reasoning"

    input_tokens = (
        oai_token_usage.get("prompt_tokens") or oai_token_usage.get("input_tokens") or 0
    )
    output_tokens = (
        oai_token_usage.get("completion_tokens")
        or oai_token_usage.get("output_tokens")
        or 0
    )
    total_tokens = oai_token_usage.get("total_tokens") or input_tokens + output_tokens

    # Resolve each details sub-dict once; both key spellings are accepted.
    prompt_details = (
        oai_token_usage.get("prompt_tokens_details")
        or oai_token_usage.get("input_tokens_details")
        or {}
    )
    completion_details = (
        oai_token_usage.get("completion_tokens_details")
        or oai_token_usage.get("output_tokens_details")
        or {}
    )

    input_token_details: dict = {
        "audio": prompt_details.get("audio_tokens"),
        cache_read_key: prompt_details.get("cached_tokens"),
    }
    output_token_details: dict = {
        "audio": completion_details.get("audio_tokens"),
        reasoning_key: completion_details.get("reasoning_tokens"),
    }

    if tier:
        # Avoid counting cache read and reasoning tokens towards the
        # service tier token count since service tier tokens are already
        # priced differently
        input_token_details[tier] = input_tokens - (
            input_token_details.get(cache_read_key) or 0
        )
        output_token_details[tier] = output_tokens - (
            output_token_details.get(reasoning_key) or 0
        )

    # Details that were not reported (None) are left out entirely.
    reported_input = {k: v for k, v in input_token_details.items() if v is not None}
    reported_output = {k: v for k, v in output_token_details.items() if v is not None}

    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        input_token_details=InputTokenDetails(**reported_input),
        output_token_details=OutputTokenDetails(**reported_output),
    )

301 

302 

303def _process_chat_completion(outputs: Any): 

304 try: 

305 rdict = outputs.model_dump() 

306 oai_token_usage = rdict.pop("usage", None) 

307 rdict["usage_metadata"] = ( 

308 _create_usage_metadata(oai_token_usage, rdict.get("service_tier")) 

309 if oai_token_usage 

310 else None 

311 ) 

312 return rdict 

313 except BaseException as e: 

314 logger.debug(f"Error processing chat completion: {e}") 

315 return {"output": outputs} 

316 

317 

def _get_wrapper(
    original_create: Callable,
    name: str,
    reduce_fn: Callable,
    tracing_extra: Optional[TracingExtra] = None,
    invocation_params_fn: Optional[Callable] = None,
    process_outputs: Optional[Callable] = None,
) -> Callable:
    """Return a traced replacement for a ``create`` method.

    Args:
        original_create: The method to wrap.
        name: Run name passed to ``run_helpers.traceable``.
        reduce_fn: Reducer for streamed results; only handed to ``traceable``
            when the call passes ``stream=True``.
        tracing_extra: Extra keyword arguments forwarded to ``traceable``.
        invocation_params_fn: Forwarded as ``_invocation_params_fn``.
        process_outputs: Forwarded as ``process_outputs``.

    Returns:
        An async wrapper when ``run_helpers.is_async(original_create)`` is
        true, otherwise a sync wrapper.
    """
    textra = tracing_extra or {}

    def _traced(call_kwargs: dict) -> Callable:
        # Built per call because the reducer depends on the `stream` argument.
        stream_reducer = reduce_fn if call_kwargs.get("stream") is True else None
        decorator = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=stream_reducer,
            process_inputs=_process_inputs,
            _invocation_params_fn=invocation_params_fn,
            process_outputs=process_outputs,
            **textra,
        )
        return decorator(original_create)

    @functools.wraps(original_create)
    def create(*args, **kwargs):
        return _traced(kwargs)(*args, **kwargs)

    @functools.wraps(original_create)
    async def acreate(*args, **kwargs):
        return await _traced(kwargs)(*args, **kwargs)

    if run_helpers.is_async(original_create):
        return acreate
    return create

356 

357 

def _get_parse_wrapper(
    original_parse: Callable,
    name: str,
    process_outputs: Callable,
    tracing_extra: Optional[TracingExtra] = None,
    invocation_params_fn: Optional[Callable] = None,
) -> Callable:
    """Return a traced replacement for a ``parse`` method.

    Args:
        original_parse: The method to wrap.
        name: Run name passed to ``run_helpers.traceable``.
        process_outputs: Forwarded as ``process_outputs``.
        tracing_extra: Extra keyword arguments forwarded to ``traceable``.
        invocation_params_fn: Forwarded as ``_invocation_params_fn``.

    Returns:
        An async wrapper when ``run_helpers.is_async(original_parse)`` is
        true, otherwise a sync wrapper. No reducer is ever passed.
    """
    textra = tracing_extra or {}

    def _traced() -> Callable:
        # The decorator is created on every call, not once at wrap time.
        decorator = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=None,
            process_inputs=_process_inputs,
            _invocation_params_fn=invocation_params_fn,
            process_outputs=process_outputs,
            **textra,
        )
        return decorator(original_parse)

    @functools.wraps(original_parse)
    def parse(*args, **kwargs):
        return _traced()(*args, **kwargs)

    @functools.wraps(original_parse)
    async def aparse(*args, **kwargs):
        return await _traced()(*args, **kwargs)

    if run_helpers.is_async(original_parse):
        return aparse
    return parse

394 

395 

396def _reduce_response_events(events: list[ResponseStreamEvent]) -> dict: 

397 for event in events: 

398 if event.type == "response.completed": 

399 return _process_responses_api_output(event.response) 

400 return {} 

401 

402 

class TracingExtra(TypedDict, total=False):
    """Optional extras that the wrappers forward to ``run_helpers.traceable``.

    All keys are optional (``total=False``).
    """

    # Passed to `traceable` as `metadata`.
    metadata: Optional[Mapping[str, Any]]
    # Passed to `traceable` as `tags`.
    tags: Optional[list[str]]
    # Passed to `traceable` as `client`.
    client: Optional[ls_client.Client]

407 

408 

def wrap_openai(
    client: C,
    *,
    tracing_extra: Optional[TracingExtra] = None,
    chat_name: str = "ChatOpenAI",
    completions_name: str = "OpenAI",
) -> C:
    """Patch the OpenAI client to make it traceable.

    Supports:
        - Chat and Responses API's
        - Sync and async OpenAI clients
        - `create` and `parse` methods
        - With and without streaming

    Args:
        client: The client to patch. It is modified in place.
        tracing_extra: Extra tracing information.
        chat_name: The run name for the chat completions endpoint. Also used
            for the Responses API. Replaced by `"AzureChatOpenAI"` when the
            client is an Azure client.
        completions_name: The run name for the completions endpoint. Replaced
            by `"AzureOpenAI"` when the client is an Azure client.

    Returns:
        The patched client.

    Example:
        ```python
        import openai
        from langsmith import wrappers

        # Use OpenAI client same as you normally would.
        client = wrappers.wrap_openai(openai.OpenAI())

        # Chat API:
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {
                "role": "user",
                "content": "What physics breakthroughs do you predict will happen by 2300?",
            },
        ]
        completion = client.chat.completions.create(
            model="gpt-4o-mini", messages=messages
        )
        print(completion.choices[0].message.content)

        # Responses API:
        response = client.responses.create(
            model="gpt-4o-mini",
            input=messages,
        )
        print(response.output_text)
        ```

    !!! warning "Behavior changed in `langsmith` 0.3.16"

        Support for Responses API added.
    """  # noqa: E501
    tracing_extra = tracing_extra or {}

    ls_provider = "openai"
    try:
        from openai import AsyncAzureOpenAI, AzureOpenAI
    except ImportError:
        pass
    else:
        if isinstance(client, (AzureOpenAI, AsyncAzureOpenAI)):
            # NOTE(review): this also replaces run names the caller passed
            # explicitly - confirm that is intended.
            ls_provider = "azure"
            chat_name = "AzureChatOpenAI"
            completions_name = "AzureOpenAI"

    # Invocation-params builders shared by the wrappers below.
    chat_params_fn = functools.partial(_infer_invocation_params, "chat", ls_provider)
    llm_params_fn = functools.partial(_infer_invocation_params, "llm", ls_provider)

    # Wrap the `create` methods. They cover both plain and streamed calls:
    # the reducer is applied only when the call passes `stream=True`.
    client.chat.completions.create = _get_wrapper(  # type: ignore[method-assign]
        client.chat.completions.create,
        chat_name,
        _reduce_chat,
        tracing_extra=tracing_extra,
        invocation_params_fn=chat_params_fn,
        process_outputs=_process_chat_completion,
    )

    client.completions.create = _get_wrapper(  # type: ignore[method-assign]
        client.completions.create,
        completions_name,
        _reduce_completions,
        tracing_extra=tracing_extra,
        invocation_params_fn=llm_params_fn,
    )

    # Wrap beta.chat.completions.parse if it exists
    if (
        hasattr(client, "beta")
        and hasattr(client.beta, "chat")
        and hasattr(client.beta.chat, "completions")
        and hasattr(client.beta.chat.completions, "parse")
    ):
        client.beta.chat.completions.parse = _get_parse_wrapper(  # type: ignore[method-assign]
            client.beta.chat.completions.parse,  # type: ignore
            chat_name,
            _process_chat_completion,
            tracing_extra=tracing_extra,
            invocation_params_fn=chat_params_fn,
        )

    # Wrap chat.completions.parse if it exists
    if (
        hasattr(client, "chat")
        and hasattr(client.chat, "completions")
        and hasattr(client.chat.completions, "parse")
    ):
        client.chat.completions.parse = _get_parse_wrapper(  # type: ignore[method-assign]
            client.chat.completions.parse,  # type: ignore
            chat_name,
            _process_chat_completion,
            tracing_extra=tracing_extra,
            invocation_params_fn=chat_params_fn,
        )

    # For the responses API: "client.responses.create(**kwargs)"
    if hasattr(client, "responses"):
        if hasattr(client.responses, "create"):
            client.responses.create = _get_wrapper(  # type: ignore[method-assign]
                client.responses.create,
                chat_name,
                _reduce_response_events,
                process_outputs=_process_responses_api_output,
                tracing_extra=tracing_extra,
                invocation_params_fn=chat_params_fn,
            )
        if hasattr(client.responses, "parse"):
            client.responses.parse = _get_parse_wrapper(  # type: ignore[method-assign]
                client.responses.parse,
                chat_name,
                _process_responses_api_output,
                tracing_extra=tracing_extra,
                invocation_params_fn=chat_params_fn,
            )

    return client

559 

560 

561def _process_responses_api_output(response: Any) -> dict: 

562 if response: 

563 try: 

564 output = response.model_dump(exclude_none=True, mode="json") 

565 if usage := output.pop("usage", None): 

566 output["usage_metadata"] = _create_usage_metadata( 

567 usage, output.get("service_tier") 

568 ) 

569 return output 

570 except Exception: 

571 return {"output": response} 

572 return {}