Coverage for langsmith/wrappers/_anthropic.py: 12%
240 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-12-11 16:15 -0800
« prev ^ index » next coverage.py v7.10.1, created at 2025-12-11 16:15 -0800
1from __future__ import annotations
3import functools
4import logging
5from collections.abc import AsyncIterator, Mapping, Sequence
6from typing import (
7 TYPE_CHECKING,
8 Any,
9 Callable,
10 Optional,
11 TypeVar,
12 Union,
13)
15from pydantic import TypeAdapter
16from typing_extensions import Self, TypedDict
18from langsmith import client as ls_client
19from langsmith import run_helpers
20from langsmith.schemas import InputTokenDetails, UsageMetadata
22if TYPE_CHECKING:
23 import httpx
24 from anthropic import Anthropic, AsyncAnthropic
25 from anthropic.types import Completion, Message, MessageStreamEvent
# Module logger; the wrapper only emits debug/error diagnostics through it.
logger = logging.getLogger(__name__)

# Type of the client handed to ``wrap_anthropic``; the same object is returned.
C = TypeVar("C", bound=Union["Anthropic", "AsyncAnthropic", Any])
31@functools.lru_cache
32def _get_not_given() -> Optional[tuple[type, ...]]:
33 try:
34 from anthropic._types import NotGiven, Omit
36 return (NotGiven, Omit)
37 except ImportError:
38 return None
def _strip_not_given(d: dict) -> dict:
    """Return a cleaned copy of call kwargs for recording as trace inputs.

    - Drops arguments whose value is an Anthropic "not given" sentinel
      (``NotGiven`` / ``Omit``) when the SDK exposes them.
    - Folds a top-level ``system`` prompt into ``messages`` as a leading
      ``{"role": "system", "content": ...}`` entry.
    - Drops any remaining ``None`` values.

    The dict passed in is never mutated.

    Args:
        d: Keyword arguments of an Anthropic ``create``/``stream`` call.

    Returns:
        A new dict with the transformations above applied.
    """
    # Always work on a copy. Without it, the fallback path (no sentinels
    # available, or stripping raised) would mutate the caller's live kwargs.
    d = dict(d)
    try:
        if not_given := _get_not_given():
            d = {k: v for k, v in d.items() if not isinstance(v, not_given)}
    except Exception as e:
        logger.error("Error stripping NotGiven: %s", e)

    if "system" in d:
        system = d.pop("system")
        # An explicit ``system=None`` means "no system prompt". Do not inject a
        # message with ``None`` content.
        if system is not None:
            # Tolerate ``messages`` being absent, ``None``, or a non-list sequence.
            messages = d.get("messages") or []
            d["messages"] = [{"role": "system", "content": system}, *messages]
    return {k: v for k, v in d.items() if v is not None}
def _infer_ls_params(kwargs: dict):
    """Derive LangSmith invocation metadata from Anthropic call kwargs.

    Args:
        kwargs: Keyword arguments of an Anthropic ``create``/``stream`` call.

    Returns:
        A dict of ``ls_*`` fields (provider, model type/name, temperature,
        max tokens, stop sequences, and an allowlisted subset of the
        remaining invocation parameters).
    """
    stripped = _strip_not_given(kwargs)

    # Anthropic's Messages and legacy Completions APIs call this argument
    # ``stop_sequences``. ``stop`` is still honoured for backward compatibility.
    stop = stripped.get("stop_sequences")
    if stop is None:
        stop = stripped.get("stop")
    if stop and isinstance(stop, str):
        stop = [stop]

    # The legacy Completions API names the token limit ``max_tokens_to_sample``.
    max_tokens = stripped.get("max_tokens")
    if max_tokens is None:
        max_tokens = stripped.get("max_tokens_to_sample")

    # Allowlist of safe invocation parameters to include.
    # Only known, non-sensitive parameters are recorded.
    allowed_invocation_keys = frozenset(
        {
            "mcp_servers",
            "service_tier",
            "top_k",
            "top_p",
            "stream",
            "thinking",
        }
    )
    invocation_params = {
        k: v for k, v in stripped.items() if k in allowed_invocation_keys
    }

    return {
        "ls_provider": "anthropic",
        "ls_model_type": "chat",
        "ls_model_name": stripped.get("model", None),
        "ls_temperature": stripped.get("temperature", None),
        "ls_max_tokens": max_tokens,
        "ls_stop": stop,
        "ls_invocation_params": invocation_params,
    }
def _accumulate_event(
    *, event: MessageStreamEvent, current_snapshot: Message | None
) -> Message | None:
    """Fold one streaming event into the running ``Message`` snapshot.

    Args:
        event: The next event from an Anthropic Messages stream.
        current_snapshot: The message accumulated so far, or ``None`` before
            a ``message_start`` event has been seen.

    Returns:
        The updated snapshot. Once a snapshot exists it is mutated in place
        and returned. If ``anthropic.types.ContentBlock`` cannot be imported,
        ``current_snapshot`` is returned unchanged.

    Raises:
        RuntimeError: If the first event is not ``message_start``.
    """
    try:
        from anthropic.types import ContentBlock
    except ImportError:
        logger.debug("Error importing ContentBlock")
        return current_snapshot

    if current_snapshot is None:
        if event.type == "message_start":
            return event.message

        raise RuntimeError(
            f'Unexpected event order, got {event.type} before "message_start"'
        )

    if event.type == "content_block_start":
        # TODO: check index <-- from anthropic SDK :)
        # Re-validate through the ContentBlock union so the snapshot holds a
        # fresh instance of the matching concrete block type.
        adapter: TypeAdapter = TypeAdapter(ContentBlock)
        content_block_instance = adapter.validate_python(
            event.content_block.model_dump()
        )
        current_snapshot.content.append(
            content_block_instance,  # type: ignore[attr-defined]
        )
    elif event.type == "content_block_delta":
        content = current_snapshot.content[event.index]
        delta = event.delta
        if content.type == "text" and delta.type == "text_delta":
            content.text += delta.text
        elif content.type == "thinking" and delta.type == "thinking_delta":
            # Extended-thinking text streams in incrementally, like regular
            # text. Without this the traced thinking block stays empty.
            content.thinking += delta.thinking
        elif content.type == "thinking" and delta.type == "signature_delta":
            content.signature = delta.signature
        # Other delta kinds (e.g. tool-input JSON) are not accumulated here.
    elif event.type == "message_delta":
        current_snapshot.stop_reason = event.delta.stop_reason
        current_snapshot.stop_sequence = event.delta.stop_sequence
        current_snapshot.usage.output_tokens = event.usage.output_tokens

    return current_snapshot
def _reduce_chat_chunks(all_chunks: Sequence) -> dict:
    """Collapse streamed Messages events into a single traced output.

    Args:
        all_chunks: Every ``MessageStreamEvent`` yielded by the stream.

    Returns:
        ``{"message": {...}}`` holding the rebuilt message, with ``usage``
        replaced by ``usage_metadata`` and ``type`` removed. If the events
        cannot be accumulated, returns ``{"output": all_chunks}`` instead.
    """
    snapshot = None
    for event in all_chunks:
        try:
            snapshot = _accumulate_event(event=event, current_snapshot=snapshot)
        except RuntimeError as err:
            logger.debug(f"Error accumulating event in Anthropic Wrapper: {err}")
            snapshot = None
            break

    if snapshot is None:
        return {"output": all_chunks}

    message = snapshot.model_dump()
    raw_usage = message.pop("usage", {})
    message.pop("type", None)
    message["usage_metadata"] = _create_usage_metadata(raw_usage)
    return {"message": message}
def _create_usage_metadata(anthropic_token_usage: dict) -> UsageMetadata:
    """Convert an Anthropic ``usage`` dict into LangSmith ``UsageMetadata``.

    Anthropic reports prompt-cache activity separately from ``input_tokens``.
    Cache writes go in ``cache_creation_input_tokens`` and cache hits in
    ``cache_read_input_tokens``. Both are added back into ``input_tokens`` so
    it counts all prompt tokens, and are broken out in
    ``input_token_details``.

    Args:
        anthropic_token_usage: A dumped Anthropic ``Usage`` model. Any field
            may be missing or ``None``.

    Returns:
        The corresponding ``UsageMetadata``.
    """
    # ``model_dump()`` keeps unset optional fields as ``None``, so
    # ``.get(key, 0)`` is not enough. Coerce ``None`` to 0 before arithmetic.
    uncached_input = anthropic_token_usage.get("input_tokens") or 0
    output_tokens = anthropic_token_usage.get("output_tokens") or 0
    cache_creation = anthropic_token_usage.get("cache_creation_input_tokens")
    cache_read = anthropic_token_usage.get("cache_read_input_tokens")

    input_tokens = uncached_input + (cache_creation or 0) + (cache_read or 0)
    input_token_details: dict = {
        # Cache writes and cache hits are distinct and are reported separately.
        "cache_creation": cache_creation,
        "cache_read": cache_read,
    }
    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=input_tokens + output_tokens,
        input_token_details=InputTokenDetails(
            **{k: v for k, v in input_token_details.items() if v is not None}
        ),
    )
166def _reduce_completions(all_chunks: list[Completion]) -> dict:
167 all_content = []
168 for chunk in all_chunks:
169 content = chunk.completion
170 if content is not None:
171 all_content.append(content)
172 content = "".join(all_content)
173 if all_chunks:
174 d = all_chunks[-1].model_dump()
175 d["choices"] = [{"text": content}]
176 else:
177 d = {"choices": [{"text": content}]}
179 return d
182def _process_chat_completion(outputs: Any):
183 try:
184 rdict = outputs.model_dump()
185 anthropic_token_usage = rdict.pop("usage", None)
186 rdict["usage_metadata"] = (
187 _create_usage_metadata(anthropic_token_usage)
188 if anthropic_token_usage
189 else None
190 )
191 rdict.pop("type", None)
192 return {"message": rdict}
193 except BaseException as e:
194 logger.debug(f"Error processing chat completion: {e}")
195 return {"output": outputs}
def _get_wrapper(
    original_create: Callable,
    name: str,
    reduce_fn: Callable,
    tracing_extra: TracingExtra,
) -> Callable:
    """Wrap an Anthropic ``create`` method so each call is traced as an LLM run.

    Args:
        original_create: The client's original ``create`` method (sync or async).
        name: Run name recorded in LangSmith.
        reduce_fn: Combines streamed chunks into one output. It is used only
            when the call passes a truthy ``stream`` kwarg.
        tracing_extra: Extra keyword arguments forwarded to ``traceable``.

    Returns:
        An async wrapper if ``original_create`` is async, else a sync one.
    """

    def _traced(stream: Any) -> Callable:
        # A fresh traceable decorator per call, because ``reduce_fn`` depends
        # on whether this particular call streams.
        return run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=reduce_fn if stream else None,
            process_inputs=_strip_not_given,
            process_outputs=_process_chat_completion,
            _invocation_params_fn=_infer_ls_params,
            **tracing_extra,
        )(original_create)

    if run_helpers.is_async(original_create):

        @functools.wraps(original_create)
        async def acreate(*args, **kwargs):
            return await _traced(kwargs.get("stream"))(*args, **kwargs)

        return acreate

    @functools.wraps(original_create)
    def create(*args, **kwargs):
        return _traced(kwargs.get("stream"))(*args, **kwargs)

    return create
def _get_stream_wrapper(
    original_stream: Callable,
    name: str,
    tracing_extra: TracingExtra,
) -> Callable:
    """Create a wrapper for Anthropic's streaming context manager.

    The returned class replaces ``client.messages.stream``. Called with the
    same keyword arguments, it produces a (sync or async) context manager.
    That manager yields a proxy around the SDK's message stream which records
    a LangSmith ``llm`` run. Two paths are traced:

    * Iterating the proxy itself (raw stream events) uses
      ``_reduce_chat_chunks`` to rebuild the final message.
    * Iterating ``text_stream`` is traced as a generator whose outputs are
      set from ``get_final_message()`` once the text is exhausted.

    Args:
        original_stream: The client's original ``messages.stream`` method.
        name: Run name recorded in LangSmith.
        tracing_extra: Extra keyword arguments forwarded to ``traceable``.

    Returns:
        A context-manager class to assign in place of ``messages.stream``.
    """
    import anthropic

    # Heuristic: relies on the bound method's repr mentioning "async",
    # presumably via the async client / resource class names.
    is_async = "async" in str(original_stream).lower()
    configured_traceable = run_helpers.traceable(
        name=name,
        reduce_fn=_reduce_chat_chunks,
        run_type="llm",
        process_inputs=_strip_not_given,
        _invocation_params_fn=_infer_ls_params,
        **tracing_extra,
    )
    configured_traceable_text = run_helpers.traceable(
        name=name,
        run_type="llm",
        process_inputs=_strip_not_given,
        process_outputs=_process_chat_completion,
        _invocation_params_fn=_infer_ls_params,
        **tracing_extra,
    )

    if is_async:

        class AsyncMessageStreamWrapper:
            """Proxy for the SDK's async message stream that traces iteration."""

            def __init__(
                self,
                wrapped: anthropic.lib.streaming._messages.AsyncMessageStream,
                **kwargs,
            ) -> None:
                self._wrapped = wrapped
                self._kwargs = kwargs
                # Traced iterator reused by successive ``__anext__`` calls, so
                # stepping manually yields one run instead of one per event.
                self._traced_aiter = None

            @property
            def text_stream(self):
                @configured_traceable_text
                async def _text_stream(**_):
                    async for chunk in self._wrapped.text_stream:
                        yield chunk
                    run_tree = run_helpers.get_current_run_tree()
                    final_message = await self._wrapped.get_final_message()
                    # ``get_current_run_tree`` may return None (no active run);
                    # guard it rather than raise AttributeError in user code.
                    if run_tree is not None:
                        run_tree.outputs = _process_chat_completion(final_message)

                return _text_stream(**self._kwargs)

            @property
            def response(self) -> httpx.Response:
                return self._wrapped.response

            @property
            def request_id(self) -> str | None:
                return self._wrapped.request_id

            async def __anext__(self) -> MessageStreamEvent:
                if self._traced_aiter is None:
                    self._traced_aiter = self.__aiter__()
                return await self._traced_aiter.__anext__()

            async def __aiter__(self) -> AsyncIterator[MessageStreamEvent]:
                @configured_traceable
                def traced_iter(**_):
                    return self._wrapped.__aiter__()

                async for chunk in traced_iter(**self._kwargs):
                    yield chunk

            async def __aenter__(self) -> Self:
                await self._wrapped.__aenter__()
                return self

            async def __aexit__(self, *exc) -> None:
                await self._wrapped.__aexit__(*exc)

            async def close(self) -> None:
                await self._wrapped.close()

            async def get_final_message(self) -> Message:
                return await self._wrapped.get_final_message()

            async def get_final_text(self) -> str:
                return await self._wrapped.get_final_text()

            async def until_done(self) -> None:
                await self._wrapped.until_done()

            @property
            def current_message_snapshot(self) -> Message:
                return self._wrapped.current_message_snapshot

        class AsyncMessagesStreamManagerWrapper:
            """Async context manager returned in place of ``messages.stream(...)``."""

            def __init__(self, **kwargs):
                self._kwargs = kwargs

            async def __aenter__(self):
                # The original stream manager is only created on enter.
                self._manager = original_stream(**self._kwargs)
                stream = await self._manager.__aenter__()
                return AsyncMessageStreamWrapper(stream, **self._kwargs)

            async def __aexit__(self, *exc):
                await self._manager.__aexit__(*exc)

        return AsyncMessagesStreamManagerWrapper
    else:

        class MessageStreamWrapper:
            """Proxy for the SDK's sync message stream that traces iteration."""

            def __init__(
                self,
                wrapped: anthropic.lib.streaming._messages.MessageStream,
                **kwargs,
            ) -> None:
                self._wrapped = wrapped
                self._kwargs = kwargs
                # Traced iterator reused by successive ``__next__`` calls, so
                # stepping manually yields one run instead of one per event.
                self._traced_iter = None

            @property
            def response(self) -> Any:
                return self._wrapped.response

            @property
            def request_id(self) -> str | None:
                return self._wrapped.request_id  # type: ignore[no-any-return]

            @property
            def text_stream(self):
                @configured_traceable_text
                def _text_stream(**_):
                    yield from self._wrapped.text_stream
                    run_tree = run_helpers.get_current_run_tree()
                    final_message = self._wrapped.get_final_message()
                    # ``get_current_run_tree`` may return None (no active run);
                    # guard it rather than raise AttributeError in user code.
                    if run_tree is not None:
                        run_tree.outputs = _process_chat_completion(final_message)

                return _text_stream(**self._kwargs)

            def __next__(self) -> MessageStreamEvent:
                if self._traced_iter is None:
                    self._traced_iter = self.__iter__()
                return self._traced_iter.__next__()

            def __iter__(self):
                @configured_traceable
                def traced_iter(**_):
                    return self._wrapped.__iter__()

                return traced_iter(**self._kwargs)

            def __enter__(self) -> Self:
                self._wrapped.__enter__()
                return self

            def __exit__(self, *exc) -> None:
                self._wrapped.__exit__(*exc)

            def close(self) -> None:
                self._wrapped.close()

            def get_final_message(self) -> Message:
                return self._wrapped.get_final_message()

            def get_final_text(self) -> str:
                return self._wrapped.get_final_text()

            def until_done(self) -> None:
                return self._wrapped.until_done()

            @property
            def current_message_snapshot(self) -> Message:
                return self._wrapped.current_message_snapshot

        class MessagesStreamManagerWrapper:
            """Context manager returned in place of ``messages.stream(...)``."""

            def __init__(self, **kwargs):
                self._kwargs = kwargs

            def __enter__(self):
                # The original stream manager is only created on enter.
                self._manager = original_stream(**self._kwargs)
                return MessageStreamWrapper(self._manager.__enter__(), **self._kwargs)

            def __exit__(self, *exc):
                self._manager.__exit__(*exc)

        return MessagesStreamManagerWrapper
class TracingExtra(TypedDict, total=False):
    """Optional tracing settings applied to every wrapped Anthropic call.

    Every key is optional (``total=False``). Whatever is supplied is passed
    as keyword arguments to ``run_helpers.traceable`` by the wrappers.
    """

    # Metadata attached to each traced run.
    metadata: Optional[Mapping[str, Any]]
    # Tags attached to each traced run.
    tags: Optional[list[str]]
    # LangSmith client used to submit the traced runs.
    client: Optional[ls_client.Client]
def wrap_anthropic(client: C, *, tracing_extra: Optional[TracingExtra] = None) -> C:
    """Patch the Anthropic client to make it traceable.

    Replaces ``messages.create``, ``messages.stream`` and, when present,
    ``completions.create`` and ``beta.messages.create`` on ``client`` with
    tracing wrappers. The client is modified in place and also returned.

    Args:
        client: The client to patch.
        tracing_extra: Extra tracing information (metadata, tags, LangSmith
            client) applied to every traced call.

    Returns:
        The patched client.

    Example:
        ```python
        import anthropic
        from langsmith import wrappers

        client = wrappers.wrap_anthropic(anthropic.Anthropic())

        # Use Anthropic client same as you normally would:
        system = "You are a helpful assistant."
        messages = [
            {
                "role": "user",
                "content": "What physics breakthroughs do you predict will happen by 2300?",
            }
        ]
        completion = client.messages.create(
            model="claude-3-5-sonnet-latest",
            messages=messages,
            max_tokens=1000,
            system=system,
        )
        print(completion.content)

        # You can also use the streaming context manager:
        with client.messages.stream(
            model="claude-3-5-sonnet-latest",
            messages=messages,
            max_tokens=1000,
            system=system,
        ) as stream:
            for text in stream.text_stream:
                print(text, end="", flush=True)
            message = stream.get_final_message()
        ```
    """  # noqa: E501
    tracing_extra = tracing_extra or {}
    client.messages.create = _get_wrapper(  # type: ignore[method-assign]
        client.messages.create,
        "ChatAnthropic",
        _reduce_chat_chunks,
        tracing_extra,
    )
    client.messages.stream = _get_stream_wrapper(  # type: ignore[method-assign]
        client.messages.stream,
        "ChatAnthropic",
        tracing_extra,
    )
    # Legacy Text Completions resource: guard it like ``beta`` below, so a
    # client that lacks it is still patched instead of raising AttributeError.
    if hasattr(client, "completions") and hasattr(client.completions, "create"):
        client.completions.create = _get_wrapper(  # type: ignore[method-assign]
            client.completions.create,
            "Anthropic",
            _reduce_completions,
            tracing_extra,
        )

    if (
        hasattr(client, "beta")
        and hasattr(client.beta, "messages")
        and hasattr(client.beta.messages, "create")
    ):
        client.beta.messages.create = _get_wrapper(  # type: ignore[method-assign]
            client.beta.messages.create,  # type: ignore
            "ChatAnthropic",
            _reduce_chat_chunks,
            tracing_extra,
        )
    return client