Coverage for langsmith/wrappers/_openai.py: 16%
197 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-12-11 16:15 -0800
« prev ^ index » next coverage.py v7.10.1, created at 2025-12-11 16:15 -0800
1from __future__ import annotations
3import functools
4import logging
5from collections import defaultdict
6from collections.abc import Mapping
7from typing import (
8 TYPE_CHECKING,
9 Any,
10 Callable,
11 Optional,
12 TypeVar,
13 Union,
14)
16from typing_extensions import TypedDict
18from langsmith import client as ls_client
19from langsmith import run_helpers
20from langsmith.schemas import InputTokenDetails, OutputTokenDetails, UsageMetadata
22if TYPE_CHECKING:
23 from openai import AsyncOpenAI, OpenAI
24 from openai.types.chat.chat_completion_chunk import (
25 ChatCompletionChunk,
26 Choice,
27 ChoiceDeltaToolCall,
28 )
29 from openai.types.completion import Completion
30 from openai.types.responses import ResponseStreamEvent # type: ignore
# Type variable for the client object that `wrap_openai` accepts and returns.
# Any is used since it may work with Azure or other providers
C = TypeVar("C", bound=Union["OpenAI", "AsyncOpenAI", Any])
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@functools.lru_cache
def _get_omit_types() -> tuple[type, ...]:
    """Return the OpenAI SDK's `NotGiven` and `Omit` sentinel classes.

    The import is attempted lazily; when it fails with `ImportError` an
    empty tuple is returned instead. The result is memoized by `lru_cache`.
    """
    try:
        from openai._types import NotGiven, Omit
    except ImportError:
        # Sentinel classes are unavailable: nothing to strip.
        return ()
    return (NotGiven, Omit)
def _strip_not_given(d: dict) -> dict:
    """Drop entries whose value is an SDK "not given" sentinel.

    An entry is removed when its value is an instance of one of the types
    returned by `_get_omit_types()`, or when its key starts with `extra_`
    and its value is `None`.

    Args:
        d: The keyword arguments to filter.

    Returns:
        A new filtered dict, or `d` itself when no sentinel types are
        available or when filtering raises (the error is logged).
    """
    try:
        sentinel_types = _get_omit_types()
        if not sentinel_types:
            return d
        kept = {}
        for key, value in d.items():
            if isinstance(value, sentinel_types):
                continue
            if key.startswith("extra_") and value is None:
                continue
            kept[key] = value
        return kept
    except Exception as e:
        # Best effort: fall back to the unfiltered input.
        logger.error(f"Error stripping NotGiven: {e}")
        return d
def _process_inputs(d: dict) -> dict:
    """Strip `NotGiven` values and serialize `text_format` to JSON schema.

    Args:
        d: The call's keyword arguments.

    Returns:
        The stripped inputs. When a `text_format` entry exposes
        `model_json_schema`, a copy is returned with that entry replaced by
        the schema; if producing the schema raises, the stripped inputs are
        returned unchanged.
    """
    cleaned = _strip_not_given(d)

    if "text_format" not in cleaned:
        return cleaned

    # Convert text_format (Pydantic model) to JSON schema if possible
    fmt = cleaned["text_format"]
    if not hasattr(fmt, "model_json_schema"):
        return cleaned

    try:
        schema = fmt.model_json_schema()
        return {**cleaned, "text_format": schema}
    except Exception:
        return cleaned
def _infer_invocation_params(model_type: str, provider: str, kwargs: dict):
    """Build the `ls_*` invocation-parameter dict for a call.

    Args:
        model_type: Value reported as `ls_model_type`.
        provider: Value reported as `ls_provider`.
        kwargs: The keyword arguments of the wrapped call.

    Returns:
        A dict with the provider, model type, model name, temperature,
        max-token limit, stop sequences, and an `ls_invocation_params`
        sub-dict restricted to an allowlist of keys.
    """
    params = _strip_not_given(kwargs)

    # A single stop string is normalized to a one-element list.
    stop_sequences = params.get("stop")
    if isinstance(stop_sequences, str) and stop_sequences:
        stop_sequences = [stop_sequences]

    # First truthy limit wins; otherwise the last key's value is kept.
    max_tokens = None
    for limit_key in ("max_tokens", "max_completion_tokens", "max_output_tokens"):
        max_tokens = params.get(limit_key)
        if max_tokens:
            break

    # Allowlist of safe invocation parameters to include
    # Only include known, non-sensitive parameters
    allowed = {
        "frequency_penalty",
        "n",
        "logit_bias",
        "logprobs",
        "modalities",
        "parallel_tool_calls",
        "prediction",
        "presence_penalty",
        "prompt_cache_key",
        "reasoning",
        "reasoning_effort",
        "response_format",
        "seed",
        "service_tier",
        "stream_options",
        "top_logprobs",
        "top_p",
        "truncation",
        "user",
        "verbosity",
        "web_search_options",
    }
    invocation_params = {}
    for key, value in params.items():
        if key in allowed:
            invocation_params[key] = value

    return {
        "ls_provider": provider,
        "ls_model_type": model_type,
        "ls_model_name": params.get("model"),
        "ls_temperature": params.get("temperature"),
        "ls_max_tokens": max_tokens,
        "ls_stop": stop_sequences,
        "ls_invocation_params": invocation_params,
    }
def _reduce_choices(choices: list[Choice]) -> dict:
    """Merge the streamed deltas of one choice into a single choice dict.

    Args:
        choices: The chunks of a single choice, in arrival order.

    Returns:
        A dict with `index` (taken from the first chunk, default 0),
        `finish_reason` (the last truthy one, else `None`) and `message`.
        The message carries the last truthy role (default `"assistant"`),
        the concatenated content, and, when present, the concatenated
        `function_call` and the per-index merged `tool_calls`.
    """

    def _append_function_parts(target: dict, fn: Any) -> None:
        # Concatenate whichever of name/arguments this fragment carries.
        for field in ("name", "arguments"):
            piece = getattr(fn, field, None)
            if piece:
                target[field] += piece

    newest_first = list(reversed(choices))

    # The most recent chunk that names a role decides the role.
    role = next(
        (
            c.delta.role
            for c in newest_first
            if hasattr(c, "delta") and getattr(c.delta, "role", None)
        ),
        "assistant",
    )
    message: dict[str, Any] = {"role": role, "content": ""}

    fragments_by_index: defaultdict[int, list[ChoiceDeltaToolCall]] = defaultdict(
        list
    )
    for choice in choices:
        if not hasattr(choice, "delta"):
            continue
        delta = choice.delta

        text = getattr(delta, "content", None)
        if text:
            message["content"] += text

        legacy_call = getattr(delta, "function_call", None)
        if legacy_call:
            merged_call = message.setdefault(
                "function_call", {"name": "", "arguments": ""}
            )
            _append_function_parts(merged_call, legacy_call)

        for fragment in getattr(delta, "tool_calls", None) or ():
            fragments_by_index[fragment.index].append(fragment)

    if fragments_by_index:
        # Slots for indices that never appear stay None.
        merged_calls: list = [None] * (max(fragments_by_index) + 1)
        for index, fragments in fragments_by_index.items():
            entry = {
                "index": index,
                "id": next((f.id for f in fragments if f.id), None),
                "type": next((f.type for f in fragments if f.type), None),
                "function": {"name": "", "arguments": ""},
            }
            for fragment in fragments:
                fn = getattr(fragment, "function", None)
                if fn:
                    _append_function_parts(entry["function"], fn)
            merged_calls[index] = entry
        message["tool_calls"] = merged_calls

    finish_reason = next(
        (c.finish_reason for c in newest_first if getattr(c, "finish_reason", None)),
        None,
    )
    first_index = getattr(choices[0], "index", 0) if choices else 0
    return {
        "index": first_index,
        "finish_reason": finish_reason,
        "message": message,
    }
def _reduce_chat(all_chunks: list[ChatCompletionChunk]) -> dict:
    """Fold streamed chat-completion chunks into one completion-shaped dict.

    Args:
        all_chunks: The stream's chunks in arrival order.

    Returns:
        The last chunk's `model_dump()` with `choices` replaced by one
        merged entry per choice index, and `usage` replaced by
        `usage_metadata` (`None` when the stream carried no usage). For an
        empty stream, a placeholder with a single empty assistant message.
    """
    choices_by_index: defaultdict[int, list[Choice]] = defaultdict(list)
    for chunk in all_chunks:
        # Tolerate chunks whose `choices` is None or empty.
        for choice in chunk.choices or ():
            choices_by_index[choice.index].append(choice)
    if all_chunks:
        d = all_chunks[-1].model_dump()
        d["choices"] = [
            _reduce_choices(choices) for choices in choices_by_index.values()
        ]
    else:
        d = {"choices": [{"message": {"role": "assistant", "content": ""}}]}
    # streamed outputs don't go through `process_outputs`
    # so we need to flatten metadata here
    oai_token_usage = d.pop("usage", None)
    # Pass the service tier, as the non-streaming path does, so both paths
    # attribute priority/flex usage to the same buckets.
    d["usage_metadata"] = (
        _create_usage_metadata(oai_token_usage, d.get("service_tier"))
        if oai_token_usage
        else None
    )
    return d
def _reduce_completions(all_chunks: list[Completion]) -> dict:
    """Fold streamed legacy-completion chunks into one completion dict.

    Only the first choice of each chunk contributes text. Chunks without
    any choices (for example a trailing usage-only chunk) are skipped
    rather than raising `IndexError`.

    Args:
        all_chunks: The stream's chunks in arrival order.

    Returns:
        The last chunk's `model_dump()` with `choices` replaced by a single
        entry holding the concatenated text; for an empty stream, a dict
        containing only that `choices` entry with empty text.
    """
    text_parts = []
    for chunk in all_chunks:
        choices = getattr(chunk, "choices", None)
        if not choices:
            continue
        text = choices[0].text
        if text is not None:
            text_parts.append(text)

    d = all_chunks[-1].model_dump() if all_chunks else {}
    d["choices"] = [{"text": "".join(text_parts)}]
    return d
def _create_usage_metadata(
    oai_token_usage: dict, service_tier: Optional[str] = None
) -> UsageMetadata:
    """Convert an OpenAI usage dict into a `UsageMetadata`.

    Both the `prompt_*`/`completion_*` and the `input_*`/`output_*` key
    spellings are read, with the first truthy value winning.

    Args:
        oai_token_usage: The usage payload as a plain dict.
        service_tier: Tier of the request. Only `"priority"` and `"flex"`
            are recognized; they prefix the cache-read/reasoning detail
            keys and add a per-tier detail entry.

    Returns:
        Token counts plus input/output detail dicts with `None` entries
        dropped.
    """
    tier = service_tier if service_tier in ("priority", "flex") else None
    prefix = f"{tier}_" if tier else ""

    prompt_details = (
        oai_token_usage.get("prompt_tokens_details")
        or oai_token_usage.get("input_tokens_details")
        or {}
    )
    completion_details = (
        oai_token_usage.get("completion_tokens_details")
        or oai_token_usage.get("output_tokens_details")
        or {}
    )

    input_tokens = (
        oai_token_usage.get("prompt_tokens") or oai_token_usage.get("input_tokens") or 0
    )
    output_tokens = (
        oai_token_usage.get("completion_tokens")
        or oai_token_usage.get("output_tokens")
        or 0
    )
    total_tokens = oai_token_usage.get("total_tokens") or input_tokens + output_tokens

    cache_read_key = f"{prefix}cache_read"
    reasoning_key = f"{prefix}reasoning"
    input_details: dict = {
        "audio": prompt_details.get("audio_tokens"),
        cache_read_key: prompt_details.get("cached_tokens"),
    }
    output_details: dict = {
        "audio": completion_details.get("audio_tokens"),
        reasoning_key: completion_details.get("reasoning_tokens"),
    }

    if tier:
        # Avoid counting cache read and reasoning tokens towards the
        # service tier token count since service tier tokens are already
        # priced differently
        input_details[tier] = input_tokens - (input_details.get(cache_read_key) or 0)
        output_details[tier] = output_tokens - (
            output_details.get(reasoning_key) or 0
        )

    def _present(details: dict) -> dict:
        # Keep only the details that were actually reported.
        return {k: v for k, v in details.items() if v is not None}

    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        input_token_details=InputTokenDetails(**_present(input_details)),
        output_token_details=OutputTokenDetails(**_present(output_details)),
    )
def _process_chat_completion(outputs: Any):
    """Convert a chat-completion response into a traceable dict.

    Args:
        outputs: The response object; expected to expose `model_dump()`.

    Returns:
        The dumped response with `usage` replaced by `usage_metadata`
        (`None` when no usage was reported). If conversion raises an
        `Exception`, `{"output": outputs}` is returned instead.
    """
    try:
        rdict = outputs.model_dump()
        oai_token_usage = rdict.pop("usage", None)
        rdict["usage_metadata"] = (
            _create_usage_metadata(oai_token_usage, rdict.get("service_tier"))
            if oai_token_usage
            else None
        )
        return rdict
    except Exception as e:
        # Catch `Exception`, not `BaseException`, so KeyboardInterrupt and
        # SystemExit still propagate.
        logger.debug("Error processing chat completion: %s", e)
        return {"output": outputs}
def _get_wrapper(
    original_create: Callable,
    name: str,
    reduce_fn: Callable,
    tracing_extra: Optional[TracingExtra] = None,
    invocation_params_fn: Optional[Callable] = None,
    process_outputs: Optional[Callable] = None,
) -> Callable:
    """Return a traced replacement for a `create` method.

    Args:
        original_create: The method being wrapped.
        name: Run name passed to `run_helpers.traceable`.
        reduce_fn: Reducer passed to `traceable` only when the call has
            `stream=True`.
        tracing_extra: Extra keyword arguments forwarded to `traceable`.
        invocation_params_fn: Forwarded as `_invocation_params_fn`.
        process_outputs: Forwarded as `process_outputs`.

    Returns:
        An async wrapper when `run_helpers.is_async(original_create)` is
        true, otherwise a sync wrapper.
    """
    extra_kwargs = tracing_extra or {}

    def _traced(call_kwargs: dict) -> Callable:
        # The decorator is built per call because `reduce_fn` depends on
        # whether this particular call streams.
        streaming = call_kwargs.get("stream") is True
        return run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=reduce_fn if streaming else None,
            process_inputs=_process_inputs,
            _invocation_params_fn=invocation_params_fn,
            process_outputs=process_outputs,
            **extra_kwargs,
        )(original_create)

    if run_helpers.is_async(original_create):

        @functools.wraps(original_create)
        async def acreate(*args, **kwargs):
            return await _traced(kwargs)(*args, **kwargs)

        return acreate

    @functools.wraps(original_create)
    def create(*args, **kwargs):
        return _traced(kwargs)(*args, **kwargs)

    return create
def _get_parse_wrapper(
    original_parse: Callable,
    name: str,
    process_outputs: Callable,
    tracing_extra: Optional[TracingExtra] = None,
    invocation_params_fn: Optional[Callable] = None,
) -> Callable:
    """Return a traced replacement for a `parse` method.

    Args:
        original_parse: The method being wrapped.
        name: Run name passed to `run_helpers.traceable`.
        process_outputs: Forwarded as `process_outputs`.
        tracing_extra: Extra keyword arguments forwarded to `traceable`.
        invocation_params_fn: Forwarded as `_invocation_params_fn`.

    Returns:
        An async wrapper when `run_helpers.is_async(original_parse)` is
        true, otherwise a sync wrapper. No reducer is installed.
    """
    extra_kwargs = tracing_extra or {}

    def _traced() -> Callable:
        # Built on every call, mirroring the `create` wrappers.
        return run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=None,
            process_inputs=_process_inputs,
            _invocation_params_fn=invocation_params_fn,
            process_outputs=process_outputs,
            **extra_kwargs,
        )(original_parse)

    if run_helpers.is_async(original_parse):

        @functools.wraps(original_parse)
        async def aparse(*args, **kwargs):
            return await _traced()(*args, **kwargs)

        return aparse

    @functools.wraps(original_parse)
    def parse(*args, **kwargs):
        return _traced()(*args, **kwargs)

    return parse
def _reduce_response_events(events: list[ResponseStreamEvent]) -> dict:
    """Reduce a Responses API event stream to its completed response.

    Returns the processed `response` of the first event whose `type` is
    `"response.completed"`, or an empty dict when there is no such event.
    """
    completed = next(
        (event for event in events if event.type == "response.completed"), None
    )
    if completed is None:
        return {}
    return _process_responses_api_output(completed.response)
class TracingExtra(TypedDict, total=False):
    """Optional tracing settings accepted by `wrap_openai`.

    Every key may be omitted (`total=False`). The wrappers in this module
    unpack the mapping as keyword arguments to `run_helpers.traceable`.
    """

    # Metadata mapping forwarded to `traceable`.
    metadata: Optional[Mapping[str, Any]]
    # Tags forwarded to `traceable`.
    tags: Optional[list[str]]
    # LangSmith client instance forwarded to `traceable`.
    client: Optional[ls_client.Client]
def wrap_openai(
    client: C,
    *,
    tracing_extra: Optional[TracingExtra] = None,
    chat_name: str = "ChatOpenAI",
    completions_name: str = "OpenAI",
) -> C:
    """Patch the OpenAI client to make it traceable.

    Supports:
        - Chat and Responses APIs
        - Sync and async OpenAI clients
        - `create` and `parse` methods
        - With and without streaming

    Args:
        client: The client to patch. It is modified in place.
        tracing_extra: Extra tracing information.
        chat_name: The run name for the chat completions endpoint. For
            Azure clients this is replaced by `"AzureChatOpenAI"`.
        completions_name: The run name for the completions endpoint. For
            Azure clients this is replaced by `"AzureOpenAI"`.

    Returns:
        The patched client.

    Example:
        ```python
        import openai
        from langsmith import wrappers

        # Use OpenAI client same as you normally would.
        client = wrappers.wrap_openai(openai.OpenAI())

        # Chat API:
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {
                "role": "user",
                "content": "What physics breakthroughs do you predict will happen by 2300?",
            },
        ]
        completion = client.chat.completions.create(
            model="gpt-4o-mini", messages=messages
        )
        print(completion.choices[0].message.content)

        # Responses API (takes `input`, not `messages`):
        response = client.responses.create(
            model="gpt-4o-mini",
            input=messages,
        )
        print(response.output_text)
        ```

    !!! warning "Behavior changed in `langsmith` 0.3.16"

        Support for Responses API added.
    """  # noqa: E501
    tracing_extra = tracing_extra or {}

    ls_provider = "openai"
    try:
        from openai import AsyncAzureOpenAI, AzureOpenAI

        if isinstance(client, (AzureOpenAI, AsyncAzureOpenAI)):
            ls_provider = "azure"
            chat_name = "AzureChatOpenAI"
            completions_name = "AzureOpenAI"
    except ImportError:
        pass

    # Every chat-style endpoint reports the same model type and provider.
    chat_params_fn = functools.partial(_infer_invocation_params, "chat", ls_provider)

    # Wrap the create methods. The wrapper installs the reducer only for
    # calls made with `stream=True`.
    client.chat.completions.create = _get_wrapper(  # type: ignore[method-assign]
        client.chat.completions.create,
        chat_name,
        _reduce_chat,
        tracing_extra=tracing_extra,
        invocation_params_fn=chat_params_fn,
        process_outputs=_process_chat_completion,
    )

    client.completions.create = _get_wrapper(  # type: ignore[method-assign]
        client.completions.create,
        completions_name,
        _reduce_completions,
        tracing_extra=tracing_extra,
        invocation_params_fn=functools.partial(
            _infer_invocation_params, "llm", ls_provider
        ),
    )

    # Wrap beta.chat.completions.parse if it exists
    if (
        hasattr(client, "beta")
        and hasattr(client.beta, "chat")
        and hasattr(client.beta.chat, "completions")
        and hasattr(client.beta.chat.completions, "parse")
    ):
        client.beta.chat.completions.parse = _get_parse_wrapper(  # type: ignore[method-assign]
            client.beta.chat.completions.parse,  # type: ignore
            chat_name,
            _process_chat_completion,
            tracing_extra=tracing_extra,
            invocation_params_fn=chat_params_fn,
        )

    # Wrap chat.completions.parse if it exists
    if (
        hasattr(client, "chat")
        and hasattr(client.chat, "completions")
        and hasattr(client.chat.completions, "parse")
    ):
        client.chat.completions.parse = _get_parse_wrapper(  # type: ignore[method-assign]
            client.chat.completions.parse,  # type: ignore
            chat_name,
            _process_chat_completion,
            tracing_extra=tracing_extra,
            invocation_params_fn=chat_params_fn,
        )

    # For the responses API: "client.responses.create(**kwargs)"
    if hasattr(client, "responses"):
        if hasattr(client.responses, "create"):
            client.responses.create = _get_wrapper(  # type: ignore[method-assign]
                client.responses.create,
                chat_name,
                _reduce_response_events,
                process_outputs=_process_responses_api_output,
                tracing_extra=tracing_extra,
                invocation_params_fn=chat_params_fn,
            )
        if hasattr(client.responses, "parse"):
            client.responses.parse = _get_parse_wrapper(  # type: ignore[method-assign]
                client.responses.parse,
                chat_name,
                _process_responses_api_output,
                tracing_extra=tracing_extra,
                invocation_params_fn=chat_params_fn,
            )

    return client
def _process_responses_api_output(response: Any) -> dict:
    """Convert a Responses API result into a traceable dict.

    Args:
        response: The response object; expected to expose `model_dump()`.

    Returns:
        An empty dict for a falsy `response`. Otherwise the JSON-mode dump
        (with `None` fields excluded) in which `usage` is removed and,
        when truthy, replaced by `usage_metadata`. If anything raises an
        `Exception`, `{"output": response}` is returned.
    """
    if not response:
        return {}
    try:
        output = response.model_dump(exclude_none=True, mode="json")
        usage = output.pop("usage", None)
        if usage:
            output["usage_metadata"] = _create_usage_metadata(
                usage, output.get("service_tier")
            )
    except Exception:
        return {"output": response}
    return output