Coverage for langsmith/wrappers/_anthropic.py: 12%

240 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-12-11 16:15 -0800

1from __future__ import annotations 

2 

3import functools 

4import logging 

5from collections.abc import AsyncIterator, Mapping, Sequence 

6from typing import ( 

7 TYPE_CHECKING, 

8 Any, 

9 Callable, 

10 Optional, 

11 TypeVar, 

12 Union, 

13) 

14 

15from pydantic import TypeAdapter 

16from typing_extensions import Self, TypedDict 

17 

18from langsmith import client as ls_client 

19from langsmith import run_helpers 

20from langsmith.schemas import InputTokenDetails, UsageMetadata 

21 

22if TYPE_CHECKING: 

23 import httpx 

24 from anthropic import Anthropic, AsyncAnthropic 

25 from anthropic.types import Completion, Message, MessageStreamEvent 

26 

27C = TypeVar("C", bound=Union["Anthropic", "AsyncAnthropic", Any]) 

28logger = logging.getLogger(__name__) 

29 

30 

31@functools.lru_cache 

32def _get_not_given() -> Optional[tuple[type, ...]]: 

33 try: 

34 from anthropic._types import NotGiven, Omit 

35 

36 return (NotGiven, Omit) 

37 except ImportError: 

38 return None 

39 

40 

def _strip_not_given(d: dict) -> dict:
    """Normalize Anthropic call kwargs into the inputs recorded for a run.

    Steps, in order:

    1. Drop values that are instances of the SDK sentinel classes returned by
       ``_get_not_given`` (skipped when that helper returns ``None``).
    2. Fold a top-level ``system`` value into ``messages`` as a leading
       ``{"role": "system", "content": ...}`` entry and remove ``system``.
    3. Drop every key whose value is ``None``.

    The caller's dict is never modified: all work happens on a shallow copy,
    so the kwargs later passed to the real SDK call keep their ``system`` key.

    Args:
        d: Keyword arguments of the wrapped SDK call.

    Returns:
        A new dict containing the normalized inputs.
    """
    # Shallow copy first so no code path below can touch the caller's kwargs.
    cleaned = dict(d)
    try:
        if not_given := _get_not_given():
            cleaned = {
                k: v for k, v in cleaned.items() if not isinstance(v, not_given)
            }
    except Exception as e:
        # Best effort: tracing must not break the wrapped call.
        logger.error(f"Error stripping NotGiven: {e}")

    if "system" in cleaned:
        existing = cleaned.get("messages")
        # Only merge when ``messages`` is absent/None or a concrete list/tuple.
        # Any other iterable (e.g. a generator) is left alone so that it is not
        # consumed before the SDK sees it; ``system`` then stays a separate key.
        if existing is None or isinstance(existing, (list, tuple)):
            system_message = {"role": "system", "content": cleaned.pop("system")}
            cleaned["messages"] = [system_message, *(existing or ())]
    return {k: v for k, v in cleaned.items() if v is not None}

58 

59 

def _infer_ls_params(kwargs: dict):
    """Derive the ``ls_*`` invocation parameters from call kwargs.

    Args:
        kwargs: Keyword arguments of the wrapped SDK call. They are first
            normalized with ``_strip_not_given``.

    Returns:
        A dict with ``ls_provider``, ``ls_model_type``, ``ls_model_name``,
        ``ls_temperature``, ``ls_max_tokens``, ``ls_stop`` and
        ``ls_invocation_params``. Values that were not supplied are ``None``.
    """
    stripped = _strip_not_given(kwargs)

    # Anthropic names this parameter ``stop_sequences``. ``stop`` is still
    # consulted as a fallback so callers that relied on it keep working.
    stop = stripped.get("stop_sequences")
    if stop is None:
        stop = stripped.get("stop")
    if stop and isinstance(stop, str):
        stop = [stop]

    # Allowlist of safe invocation parameters to include
    # Only include known, non-sensitive parameters
    allowed_invocation_keys = {
        "mcp_servers",
        "service_tier",
        "top_k",
        "top_p",
        "stream",
        "thinking",
    }

    # Only include allowlisted parameters
    invocation_params = {
        k: v for k, v in stripped.items() if k in allowed_invocation_keys
    }

    return {
        "ls_provider": "anthropic",
        "ls_model_type": "chat",
        "ls_model_name": stripped.get("model", None),
        "ls_temperature": stripped.get("temperature", None),
        "ls_max_tokens": stripped.get("max_tokens", None),
        "ls_stop": stop,
        "ls_invocation_params": invocation_params,
    }

92 

93 

def _accumulate_event(
    *, event: MessageStreamEvent, current_snapshot: Message | None
) -> Message | None:
    """Fold one streaming event into the running message snapshot.

    Args:
        event: The stream event to apply.
        current_snapshot: The message accumulated so far, or ``None`` before
            the first event.

    Returns:
        The updated snapshot. When ``anthropic.types.ContentBlock`` cannot be
        imported, ``current_snapshot`` is returned unchanged.

    Raises:
        RuntimeError: If the first event seen is not ``message_start``.
    """
    try:
        from anthropic.types import ContentBlock
    except ImportError:
        logger.debug("Error importing ContentBlock")
        return current_snapshot

    if current_snapshot is None:
        if event.type == "message_start":
            # Work on a deep copy: the event object has already been handed to
            # the user, and the branches below mutate the snapshot in place.
            return event.message.model_copy(deep=True)

        raise RuntimeError(
            f'Unexpected event order, got {event.type} before "message_start"'
        )

    if event.type == "content_block_start":
        # TODO: check index <-- from anthropic SDK :)
        adapter: TypeAdapter = TypeAdapter(ContentBlock)
        # Re-validating the dumped block yields a fresh instance, so the block
        # stored in the snapshot is not shared with the event.
        content_block_instance = adapter.validate_python(
            event.content_block.model_dump()
        )
        current_snapshot.content.append(
            content_block_instance,  # type: ignore[attr-defined]
        )
    elif event.type == "content_block_delta":
        content = current_snapshot.content[event.index]
        delta = event.delta
        if delta.type == "text_delta":
            if content.type == "text":
                content.text += delta.text
        elif delta.type == "thinking_delta":
            # Extended-thinking text arrives incrementally, like text deltas.
            if content.type == "thinking":
                content.thinking += delta.thinking
        elif delta.type == "signature_delta":
            # The signature is delivered whole, so it replaces the old value.
            if content.type == "thinking":
                content.signature = delta.signature
        # Other delta types are intentionally left unaccumulated here.
    elif event.type == "message_delta":
        current_snapshot.stop_reason = event.delta.stop_reason
        current_snapshot.stop_sequence = event.delta.stop_sequence
        current_snapshot.usage.output_tokens = event.usage.output_tokens

    return current_snapshot

130 

131 

132def _reduce_chat_chunks(all_chunks: Sequence) -> dict: 

133 full_message = None 

134 for chunk in all_chunks: 

135 try: 

136 full_message = _accumulate_event(event=chunk, current_snapshot=full_message) 

137 except RuntimeError as e: 

138 logger.debug(f"Error accumulating event in Anthropic Wrapper: {e}") 

139 return {"output": all_chunks} 

140 if full_message is None: 

141 return {"output": all_chunks} 

142 d = full_message.model_dump() 

143 d["usage_metadata"] = _create_usage_metadata(d.pop("usage", {})) 

144 d.pop("type", None) 

145 return {"message": d} 

146 

147 

def _create_usage_metadata(anthropic_token_usage: dict) -> UsageMetadata:
    """Convert an Anthropic usage dict into ``UsageMetadata``.

    Every counter is read with ``or 0`` so that a key that is missing *or*
    present with a ``None`` value counts as zero. A plain ``.get(key, 0)``
    would return ``None`` for the latter and fail on addition.

    Args:
        anthropic_token_usage: The ``usage`` mapping taken from a dumped
            response.

    Returns:
        ``UsageMetadata`` with input, output and total token counts, plus
        ``input_token_details["cache_read"]``.
    """
    input_tokens = anthropic_token_usage.get("input_tokens") or 0
    output_tokens = anthropic_token_usage.get("output_tokens") or 0
    total_tokens = input_tokens + output_tokens
    cache_creation = anthropic_token_usage.get("cache_creation_input_tokens") or 0
    cache_read = anthropic_token_usage.get("cache_read_input_tokens") or 0
    # NOTE(review): cache-creation tokens are reported under "cache_read",
    # exactly as before; confirm whether they deserve a separate bucket.
    input_token_details: dict = {"cache_read": cache_creation + cache_read}
    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        input_token_details=InputTokenDetails(
            **{k: v for k, v in input_token_details.items() if v is not None}
        ),
    )

164 

165 

166def _reduce_completions(all_chunks: list[Completion]) -> dict: 

167 all_content = [] 

168 for chunk in all_chunks: 

169 content = chunk.completion 

170 if content is not None: 

171 all_content.append(content) 

172 content = "".join(all_content) 

173 if all_chunks: 

174 d = all_chunks[-1].model_dump() 

175 d["choices"] = [{"text": content}] 

176 else: 

177 d = {"choices": [{"text": content}]} 

178 

179 return d 

180 

181 

182def _process_chat_completion(outputs: Any): 

183 try: 

184 rdict = outputs.model_dump() 

185 anthropic_token_usage = rdict.pop("usage", None) 

186 rdict["usage_metadata"] = ( 

187 _create_usage_metadata(anthropic_token_usage) 

188 if anthropic_token_usage 

189 else None 

190 ) 

191 rdict.pop("type", None) 

192 return {"message": rdict} 

193 except BaseException as e: 

194 logger.debug(f"Error processing chat completion: {e}") 

195 return {"output": outputs} 

196 

197 

def _get_wrapper(
    original_create: Callable,
    name: str,
    reduce_fn: Callable,
    tracing_extra: TracingExtra,
) -> Callable:
    """Build a traced replacement for a ``create`` method.

    Args:
        original_create: The SDK method being wrapped.
        name: Run name passed to ``run_helpers.traceable``.
        reduce_fn: Reducer for streamed chunks; only used for calls made with
            a truthy ``stream`` keyword argument.
        tracing_extra: Extra keyword arguments forwarded to ``traceable``.

    Returns:
        An async wrapper when ``run_helpers.is_async(original_create)`` is
        true, otherwise a sync wrapper.
    """

    def _traced(call_kwargs: dict) -> Callable:
        # Built per call because the reducer depends on the ``stream`` kwarg.
        streaming = call_kwargs.get("stream")
        trace = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=reduce_fn if streaming else None,
            process_inputs=_strip_not_given,
            process_outputs=_process_chat_completion,
            _invocation_params_fn=_infer_ls_params,
            **tracing_extra,
        )
        return trace(original_create)

    @functools.wraps(original_create)
    def create(*args, **kwargs):
        return _traced(kwargs)(*args, **kwargs)

    @functools.wraps(original_create)
    async def acreate(*args, **kwargs):
        return await _traced(kwargs)(*args, **kwargs)

    if run_helpers.is_async(original_create):
        return acreate
    return create

236 

237 

def _get_stream_wrapper(
    original_stream: Callable,
    name: str,
    tracing_extra: TracingExtra,
) -> Callable:
    """Create a wrapper for Anthropic's streaming context manager.

    Args:
        original_stream: The SDK ``stream`` callable being wrapped.
        name: Run name passed to ``run_helpers.traceable``.
        tracing_extra: Extra keyword arguments forwarded to ``traceable``.

    Returns:
        A manager class (sync or async). Calling it with the stream kwargs
        gives a context manager whose entered value proxies the SDK stream
        and traces iteration over events and over ``text_stream``.
    """
    import anthropic

    # The async flavour is inferred from the callable's string representation
    # containing "async" (case-insensitive).
    is_async = "async" in str(original_stream).lower()
    # Traces iteration over raw events; chunks are reduced into one message.
    configured_traceable = run_helpers.traceable(
        name=name,
        reduce_fn=_reduce_chat_chunks,
        run_type="llm",
        process_inputs=_strip_not_given,
        _invocation_params_fn=_infer_ls_params,
        **tracing_extra,
    )
    # Traces iteration over ``text_stream``; outputs are set from the final
    # message once the text stream is exhausted.
    configured_traceable_text = run_helpers.traceable(
        name=name,
        run_type="llm",
        process_inputs=_strip_not_given,
        process_outputs=_process_chat_completion,
        _invocation_params_fn=_infer_ls_params,
        **tracing_extra,
    )

    if is_async:

        class AsyncMessageStreamWrapper:
            """Proxy around the SDK's async message stream that adds tracing."""

            def __init__(
                self,
                wrapped: anthropic.lib.streaming._messages.AsyncMessageStream,
                **kwargs,
            ) -> None:
                self._wrapped = wrapped
                self._kwargs = kwargs
                # Lazily created iterator shared by successive __anext__ calls.
                self._event_iter = None

            @property
            def text_stream(self):
                @configured_traceable_text
                async def _text_stream(**_):
                    async for chunk in self._wrapped.text_stream:
                        yield chunk
                    run_tree = run_helpers.get_current_run_tree()
                    # No current run tree means there is nothing to attach
                    # outputs to; skip instead of failing on ``None``.
                    if run_tree is not None:
                        final_message = await self._wrapped.get_final_message()
                        run_tree.outputs = _process_chat_completion(final_message)

                return _text_stream(**self._kwargs)

            @property
            def response(self) -> httpx.Response:
                return self._wrapped.response

            @property
            def request_id(self) -> str | None:
                return self._wrapped.request_id

            async def __anext__(self) -> MessageStreamEvent:
                # Reuse one traced iterator: building a new one per call
                # would start a separate traced run for every single event.
                if self._event_iter is None:
                    self._event_iter = self.__aiter__()
                return await self._event_iter.__anext__()

            async def __aiter__(self) -> AsyncIterator[MessageStreamEvent]:
                @configured_traceable
                def traced_iter(**_):
                    return self._wrapped.__aiter__()

                async for chunk in traced_iter(**self._kwargs):
                    yield chunk

            async def __aenter__(self) -> Self:
                await self._wrapped.__aenter__()
                return self

            async def __aexit__(self, *exc) -> None:
                await self._wrapped.__aexit__(*exc)

            async def close(self) -> None:
                await self._wrapped.close()

            async def get_final_message(self) -> Message:
                return await self._wrapped.get_final_message()

            async def get_final_text(self) -> str:
                return await self._wrapped.get_final_text()

            async def until_done(self) -> None:
                await self._wrapped.until_done()

            @property
            def current_message_snapshot(self) -> Message:
                return self._wrapped.current_message_snapshot

        class AsyncMessagesStreamManagerWrapper:
            """Async context manager that yields a traced stream wrapper."""

            def __init__(self, **kwargs):
                self._kwargs = kwargs

            async def __aenter__(self):
                # The SDK manager is created on entry, not on construction.
                self._manager = original_stream(**self._kwargs)
                stream = await self._manager.__aenter__()
                return AsyncMessageStreamWrapper(stream, **self._kwargs)

            async def __aexit__(self, *exc):
                await self._manager.__aexit__(*exc)

        return AsyncMessagesStreamManagerWrapper
    else:

        class MessageStreamWrapper:
            """Proxy around the SDK's sync message stream that adds tracing."""

            def __init__(
                self,
                wrapped: anthropic.lib.streaming._messages.MessageStream,
                **kwargs,
            ) -> None:
                self._wrapped = wrapped
                self._kwargs = kwargs
                # Lazily created iterator shared by successive __next__ calls.
                self._event_iter = None

            @property
            def response(self) -> Any:
                return self._wrapped.response

            @property
            def request_id(self) -> str | None:
                return self._wrapped.request_id  # type: ignore[no-any-return]

            @property
            def text_stream(self):
                @configured_traceable_text
                def _text_stream(**_):
                    yield from self._wrapped.text_stream
                    run_tree = run_helpers.get_current_run_tree()
                    # No current run tree means there is nothing to attach
                    # outputs to; skip instead of failing on ``None``.
                    if run_tree is not None:
                        final_message = self._wrapped.get_final_message()
                        run_tree.outputs = _process_chat_completion(final_message)

                return _text_stream(**self._kwargs)

            def __next__(self) -> MessageStreamEvent:
                # Reuse one traced iterator: building a new one per call
                # would start a separate traced run for every single event.
                if self._event_iter is None:
                    self._event_iter = self.__iter__()
                return self._event_iter.__next__()

            def __iter__(self):
                @configured_traceable
                def traced_iter(**_):
                    return self._wrapped.__iter__()

                return traced_iter(**self._kwargs)

            def __enter__(self) -> Self:
                self._wrapped.__enter__()
                return self

            def __exit__(self, *exc) -> None:
                self._wrapped.__exit__(*exc)

            def close(self) -> None:
                self._wrapped.close()

            def get_final_message(self) -> Message:
                return self._wrapped.get_final_message()

            def get_final_text(self) -> str:
                return self._wrapped.get_final_text()

            def until_done(self) -> None:
                return self._wrapped.until_done()

            @property
            def current_message_snapshot(self) -> Message:
                return self._wrapped.current_message_snapshot

        class MessagesStreamManagerWrapper:
            """Context manager that yields a traced stream wrapper."""

            def __init__(self, **kwargs):
                self._kwargs = kwargs

            def __enter__(self):
                # The SDK manager is created on entry, not on construction.
                self._manager = original_stream(**self._kwargs)
                return MessageStreamWrapper(self._manager.__enter__(), **self._kwargs)

            def __exit__(self, *exc):
                self._manager.__exit__(*exc)

        return MessagesStreamManagerWrapper

418 

419 

class TracingExtra(TypedDict, total=False):
    """Optional tracing settings accepted by ``wrap_anthropic``.

    Every key may be omitted (``total=False``). Within this module the mapping
    is unpacked as keyword arguments into ``run_helpers.traceable``.
    """

    # Metadata mapping forwarded to ``traceable``.
    metadata: Optional[Mapping[str, Any]]
    # Tags forwarded to ``traceable``.
    tags: Optional[list[str]]
    # LangSmith client forwarded to ``traceable``.
    client: Optional[ls_client.Client]

424 

425 

def wrap_anthropic(client: C, *, tracing_extra: Optional[TracingExtra] = None) -> C:
    """Patch the Anthropic client to make it traceable.

    ``messages.create`` and ``messages.stream`` are always patched.
    ``completions.create`` and ``beta.messages.create`` are patched only when
    the client exposes them, so clients without those resources can still be
    wrapped.

    Args:
        client: The client to patch. It is modified in place.
        tracing_extra: Extra tracing information.

    Returns:
        The patched client (the same object that was passed in).

    Example:
        ```python
        import anthropic
        from langsmith import wrappers

        client = wrappers.wrap_anthropic(anthropic.Anthropic())

        # Use Anthropic client same as you normally would:
        system = "You are a helpful assistant."
        messages = [
            {
                "role": "user",
                "content": "What physics breakthroughs do you predict will happen by 2300?",
            }
        ]
        completion = client.messages.create(
            model="claude-3-5-sonnet-latest",
            messages=messages,
            max_tokens=1000,
            system=system,
        )
        print(completion.content)

        # You can also use the streaming context manager:
        with client.messages.stream(
            model="claude-3-5-sonnet-latest",
            messages=messages,
            max_tokens=1000,
            system=system,
        ) as stream:
            for text in stream.text_stream:
                print(text, end="", flush=True)
            message = stream.get_final_message()
        ```
    """  # noqa: E501
    tracing_extra = tracing_extra or {}
    client.messages.create = _get_wrapper(  # type: ignore[method-assign]
        client.messages.create,
        "ChatAnthropic",
        _reduce_chat_chunks,
        tracing_extra,
    )
    client.messages.stream = _get_stream_wrapper(  # type: ignore[method-assign]
        client.messages.stream,
        "ChatAnthropic",
        tracing_extra,
    )

    # Guarded the same way as ``beta`` below: a client that lacks the legacy
    # completions resource would otherwise fail here with AttributeError.
    if hasattr(client, "completions") and hasattr(client.completions, "create"):
        client.completions.create = _get_wrapper(  # type: ignore[method-assign]
            client.completions.create,
            "Anthropic",
            _reduce_completions,
            tracing_extra,
        )

    if (
        hasattr(client, "beta")
        and hasattr(client.beta, "messages")
        and hasattr(client.beta.messages, "create")
    ):
        client.beta.messages.create = _get_wrapper(  # type: ignore[method-assign]
            client.beta.messages.create,  # type: ignore
            "ChatAnthropic",
            _reduce_chat_chunks,
            tracing_extra,
        )
    return client