Coverage for langsmith/utils.py: 18%

351 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-12-11 16:15 -0800

1"""Generic utility functions.""" 

2 

3from __future__ import annotations 

4 

5import contextlib 

6import contextvars 

7import copy 

8import enum 

9import functools 

10import logging 

11import os 

12import pathlib 

13import socket 

14import subprocess 

15import sys 

16import threading 

17import traceback 

18from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence 

19from concurrent.futures import Future, ThreadPoolExecutor 

20from typing import ( 

21 Any, 

22 Callable, 

23 Literal, 

24 Optional, 

25 TypeVar, 

26 Union, 

27 cast, 

28) 

29from urllib import parse as urllib_parse 

30 

31import httpx 

32import requests 

33from typing_extensions import ParamSpec 

34from urllib3.util import Retry # type: ignore[import-untyped] 

35 

36from langsmith import schemas as ls_schemas 

37 

38_LOGGER = logging.getLogger(__name__) 

39 

40 

class LangSmithError(Exception):
    """Root of the LangSmith exception hierarchy.

    Raised when something goes wrong while communicating with the
    LangSmith API.
    """


class LangSmithAPIError(LangSmithError):
    """The LangSmith server reported an internal error."""


class LangSmithRequestTimeout(LangSmithError):
    """The client took too long to send the request body."""


class LangSmithUserError(LangSmithError):
    """A user mistake caused the LangSmith communication to fail."""


class LangSmithRateLimitError(LangSmithError):
    """The LangSmith API rate limit has been exceeded."""


class LangSmithAuthError(LangSmithError):
    """Authentication with the LangSmith API failed."""


class LangSmithNotFoundError(LangSmithError):
    """The requested resource could not be found."""


class LangSmithConflictError(LangSmithError):
    """The resource already exists."""


class LangSmithConnectionError(LangSmithError):
    """A connection to the LangSmith API could not be established."""


class LangSmithExceptionGroup(LangSmithError):
    """Port of ExceptionGroup for Py < 3.11."""

    def __init__(
        self, *args: Any, exceptions: Sequence[Exception], **kwargs: Any
    ) -> None:
        """Build the group.

        Args:
            *args: Positional arguments forwarded to the base exception.
            exceptions: The exceptions collected in this group; stored
                unchanged on ``self.exceptions``.
            **kwargs: Keyword arguments forwarded to the base exception.
        """
        super().__init__(*args, **kwargs)
        self.exceptions = exceptions

86 

87 

88## Warning classes 

89 

90 

class LangSmithWarning(UserWarning):
    """Common base class for the warnings defined in this module."""


class LangSmithMissingAPIKeyWarning(LangSmithWarning):
    """Warning category used when an API key is missing."""

97 

98 

def tracing_is_enabled(ctx: Optional[dict] = None) -> Union[bool, Literal["local"]]:
    """Report whether tracing is currently enabled.

    The sources are consulted in this order, and the first one that gives
    an answer wins:

    1. the ``"enabled"`` entry of the tracing context (``ctx`` if given and
       truthy, otherwise the current tracing context);
    2. whether a run tree is currently active;
    3. the global fallback stored on the internal context module;
    4. the ``TRACING_V2`` environment setting, falling back to ``TRACING``.

    Args:
        ctx: Optional tracing context mapping to use instead of the
            current one.

    Returns:
        The configured value from the first source that has one; for the
        environment fallback, whether the setting equals ``"true"``.
    """
    # The global flag is read through the module object (not imported by
    # name) so a stale reference is never captured.
    import langsmith._internal._context as _context
    from langsmith.run_helpers import get_current_run_tree, get_tracing_context

    tracing_context = ctx or get_tracing_context()

    # A manual override in the context wins over everything else. Checking
    # it before the run tree makes it possible to disable one branch of an
    # ongoing trace.
    explicit_setting = tracing_context["enabled"]
    if explicit_setting is not None:
        return explicit_setting

    # Already inside a trace.
    if get_current_run_tree():
        return True

    # Globally configured fallback, if any.
    global_setting = _context._GLOBAL_TRACING_ENABLED
    if global_setting is not None:
        return global_setting

    # Last resort: the process environment.
    legacy_value = get_env_var("TRACING", default="")
    return get_env_var("TRACING_V2", default=legacy_value) == "true"

121 

122 

def test_tracking_is_disabled() -> bool:
    """Return True if test tracking has been explicitly disabled.

    Returns:
        True only when the ``TEST_TRACKING`` environment setting (looked up
        through ``get_env_var``) is exactly the string ``"false"``; an
        unset or differently valued setting yields False.
    """
    return get_env_var("TEST_TRACKING", default="") == "false"

126 

127 

def xor_args(*arg_groups: tuple[str, ...]) -> Callable:
    """Build a decorator enforcing mutually exclusive keyword arguments.

    Only keyword arguments are inspected; values passed positionally are
    not seen by the check.

    Args:
        *arg_groups: Tuples of keyword names. In every tuple exactly one
            name must be supplied with a non-``None`` value.

    Returns:
        A decorator that raises ``ValueError`` before calling the wrapped
        function when any group is violated.
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            """Validate exactly one arg in each group is not None."""
            violated_groups = []
            for group in arg_groups:
                supplied = [key for key in group if kwargs.get(key) is not None]
                if len(supplied) != 1:
                    violated_groups.append(", ".join(group))
            if violated_groups:
                raise ValueError(
                    "Exactly one argument in each of the following"
                    " groups must be defined:"
                    f" {', '.join(violated_groups)}"
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator

152 

153 

def raise_for_status_with_text(
    response: Union[requests.Response, httpx.Response],
) -> None:
    """Call ``raise_for_status`` and enrich any HTTP error with the body text.

    Args:
        response: A ``requests`` or ``httpx`` response.

    Raises:
        requests.HTTPError: Re-raised with the response text as an extra
            argument when ``requests`` reports an error status.
        httpx.HTTPStatusError: Re-raised with the response text appended
            to the message when ``httpx`` reports an error status.
    """
    try:
        response.raise_for_status()
    except requests.HTTPError as requests_error:
        summary = str(requests_error)
        raise requests.HTTPError(summary, response.text) from requests_error  # type: ignore[call-arg]
    except httpx.HTTPStatusError as httpx_error:
        detailed_message = f"{str(httpx_error)}: {response.text}"
        raise httpx.HTTPStatusError(
            detailed_message,
            request=response.request,  # type: ignore[arg-type]
            response=response,  # type: ignore[arg-type]
        ) from httpx_error

168 

169 

def get_enum_value(enu: Union[enum.Enum, str]) -> str:
    """Unwrap an enum member to its value; pass anything else through.

    Args:
        enu: An ``enum.Enum`` member or an already-plain string.

    Returns:
        ``enu.value`` for enum members, otherwise ``enu`` unchanged.
    """
    return enu.value if isinstance(enu, enum.Enum) else enu

175 

176 

@functools.lru_cache(maxsize=1)
def log_once(level: int, message: str) -> None:
    """Log ``message`` at ``level`` through the module logger, deduplicated.

    Deduplication comes from the ``lru_cache`` wrapper. With ``maxsize=1``
    only the most recent ``(level, message)`` pair is remembered, so a
    repeated call is suppressed until a different pair evicts it.

    Args:
        level: Numeric logging level.
        message: The text to log.
    """
    _LOGGER.log(level, message)

181 

182 

def _get_message_type(message: Mapping[str, Any]) -> str:
    """Work out the message type from a serialized or stored message.

    A mapping containing ``"lc"`` is treated as a serialized message: the
    type is the last element of ``message["id"]`` with ``"Message"``
    removed and lower-cased. Any other mapping is treated as a stored
    message and must carry the type under ``"type"``.

    Raises:
        ValueError: If the mapping is empty or lacks the required key.
    """
    if not message:
        raise ValueError("Message is empty.")

    is_serialized = "lc" in message
    if is_serialized and "id" not in message:
        raise ValueError(
            f"Unexpected format for serialized message: {message}"
            " Message does not have an id."
        )
    if is_serialized:
        class_name = message["id"][-1]
        return class_name.replace("Message", "").lower()

    if "type" not in message:
        raise ValueError(
            f"Unexpected format for stored message: {message}"
            " Message does not have a type."
        )
    return message["type"]

200 

201 

def _get_message_fields(message: Mapping[str, Any]) -> Mapping[str, Any]:
    """Return the payload of a serialized or stored message.

    A mapping containing ``"lc"`` is treated as a serialized message whose
    payload is under ``"kwargs"``; any other mapping is treated as a stored
    message whose payload is under ``"data"``.

    Raises:
        ValueError: If the mapping is empty or lacks the payload key.
    """
    if not message:
        raise ValueError("Message is empty.")

    is_serialized = "lc" in message
    if is_serialized and "kwargs" not in message:
        raise ValueError(
            f"Unexpected format for serialized message: {message}"
            " Message does not have kwargs."
        )
    if is_serialized:
        return message["kwargs"]

    if "data" not in message:
        raise ValueError(
            f"Unexpected format for stored message: {message}"
            " Message does not have data."
        )
    return message["data"]

219 

220 

def _convert_message(message: Mapping[str, Any]) -> dict[str, Any]:
    """Normalise a message mapping into ``{"type": ..., "data": ...}``.

    The type is resolved first and the payload second, so a mapping that
    is invalid on both counts reports the type problem.
    """
    kind = _get_message_type(message)
    payload = _get_message_fields(message)
    return {"type": kind, "data": payload}

226 

227 

def get_messages_from_inputs(inputs: Mapping[str, Any]) -> list[dict[str, Any]]:
    """Collect the chat messages found in a run's inputs.

    ``inputs["messages"]`` (an iterable of messages) takes precedence over
    ``inputs["message"]`` (a single message).

    Args:
        inputs: The inputs dictionary.

    Returns:
        One converted ``{"type": ..., "data": ...}`` dictionary per message.

    Raises:
        ValueError: If neither ``"messages"`` nor ``"message"`` is present.
    """
    has_message_list = "messages" in inputs
    if not has_message_list and "message" not in inputs:
        raise ValueError(f"Could not find message(s) in run with inputs {inputs}.")
    raw_messages = inputs["messages"] if has_message_list else [inputs["message"]]
    return [_convert_message(raw) for raw in raw_messages]

245 

246 

def get_message_generation_from_outputs(outputs: Mapping[str, Any]) -> dict[str, Any]:
    """Retrieve the message generation from the given outputs.

    Args:
        outputs: The outputs dictionary. It must contain a
            ``"generations"`` sequence holding exactly one generation,
            which in turn must contain a ``"message"``.

    Returns:
        The converted message of that single generation.

    Raises:
        ValueError: If ``"generations"`` is missing, does not hold exactly
            one entry, or the entry has no ``"message"``.
    """
    if "generations" not in outputs:
        # Message previously read "found in in run" (duplicated word).
        raise ValueError(f"No generations found in run with output: {outputs}.")
    generations = outputs["generations"]
    if len(generations) != 1:
        raise ValueError(
            "Chat examples expect exactly one generation."
            f" Found {len(generations)} generations: {generations}."
        )
    first_generation = generations[0]
    if "message" not in first_generation:
        raise ValueError(
            f"Unexpected format for generation: {first_generation}."
            " Generation does not have a message."
        )
    return _convert_message(first_generation["message"])

274 

275 

def get_prompt_from_inputs(inputs: Mapping[str, Any]) -> str:
    """Extract the single prompt from a run's inputs.

    ``inputs["prompt"]`` is returned directly when present; otherwise
    ``inputs["prompts"]`` must hold exactly one entry.

    Args:
        inputs: The inputs dictionary.

    Returns:
        The prompt.

    Raises:
        ValueError: If no prompt key is present, or ``"prompts"`` does not
            contain exactly one entry.
    """
    if "prompt" in inputs:
        return inputs["prompt"]
    if "prompts" not in inputs:
        raise ValueError(f"Could not find prompt in run with inputs {inputs}.")

    candidates = inputs["prompts"]
    if len(candidates) != 1:
        raise ValueError(
            f"Multiple prompts in run with inputs {inputs}."
            " Please create example manually."
        )
    return candidates[0]

299 

300 

def get_llm_generation_from_outputs(outputs: Mapping[str, Any]) -> str:
    """Get the LLM generation text from the outputs.

    Args:
        outputs: The outputs dictionary. It must contain a
            ``"generations"`` sequence holding exactly one generation with
            a ``"text"`` entry.

    Returns:
        The ``"text"`` of that single generation.

    Raises:
        ValueError: If ``"generations"`` is missing, does not hold exactly
            one entry, or the entry has no ``"text"``.
    """
    if "generations" not in outputs:
        # Message previously read "found in in run" (duplicated word).
        raise ValueError(f"No generations found in run with output: {outputs}.")
    generations = outputs["generations"]
    if len(generations) != 1:
        raise ValueError(f"Multiple generations in run: {generations}")
    first_generation = generations[0]
    if "text" not in first_generation:
        raise ValueError(f"No text in generation: {first_generation}")
    return first_generation["text"]

312 

313 

@functools.lru_cache(maxsize=1)
def get_docker_compose_command() -> list[str]:
    """Detect which docker compose invocation works on this machine.

    The ``docker compose`` form is probed first, then the ``docker-compose``
    executable. Each probe runs ``--version`` with output discarded. The
    result is memoized by ``lru_cache``.

    Returns:
        ``["docker", "compose"]`` or ``["docker-compose"]``.

    Raises:
        ValueError: If neither command can be run successfully.
    """
    silence = {"stdout": subprocess.DEVNULL, "stderr": subprocess.DEVNULL}
    probe_failures = (subprocess.CalledProcessError, FileNotFoundError)
    plugin_command = ["docker", "compose"]
    standalone_command = ["docker-compose"]

    try:
        subprocess.check_call([*plugin_command, "--version"], **silence)
    except probe_failures:
        try:
            subprocess.check_call([*standalone_command, "--version"], **silence)
        except probe_failures:
            raise ValueError(
                "Neither 'docker compose' nor 'docker-compose'"
                " commands are available. Please install the Docker"
                " server following the instructions for your operating"
                " system at https://docs.docker.com/engine/install/"
            )
        return standalone_command
    return plugin_command

339 

340 

def convert_langchain_message(message: ls_schemas.BaseMessageLike) -> dict:
    """Turn a LangChain-style message object into an example dictionary.

    Args:
        message: Object exposing ``type``, ``content`` and
            ``additional_kwargs`` attributes.

    Returns:
        ``{"type": ..., "data": {"content": ...}}``; a shallow copy of
        ``additional_kwargs`` is added under ``data`` only when non-empty.
    """
    payload: dict[str, Any] = {"content": message.content}
    converted: dict[str, Any] = {"type": message.type, "data": payload}
    # Only attach the extras when there is at least one entry.
    if message.additional_kwargs and len(message.additional_kwargs) > 0:
        payload["additional_kwargs"] = {**message.additional_kwargs}
    return converted

351 

352 

def is_base_message_like(obj: object) -> bool:
    """Duck-type check for objects shaped like `BaseMessage`.

    Args:
        obj: The object to check.

    Returns:
        True when ``obj.content`` is a ``str``, ``obj.additional_kwargs``
        is a ``dict`` and ``obj.type`` is a ``str``; False otherwise.
    """
    # All three probes are evaluated up front before being combined.
    content_ok = isinstance(getattr(obj, "content", None), str)
    kwargs_ok = isinstance(getattr(obj, "additional_kwargs", None), dict)
    type_ok = hasattr(obj, "type") and isinstance(getattr(obj, "type"), str)
    return content_ok and kwargs_ok and type_ok

369 

370 

def is_env_var_truish(value: Optional[str]) -> bool:
    """Look up an environment setting and report whether it is truish.

    Args:
        value: Name of the setting, passed to ``get_env_var`` as its
            ``name`` argument (so the namespace prefixes apply).

    Returns:
        The result of ``is_truish`` applied to the looked-up value.
    """
    looked_up = get_env_var(value)
    return is_truish(looked_up)

374 

375 

@functools.lru_cache(maxsize=100)
def get_env_var(
    name: str,
    default: Optional[str] = None,
    *,
    namespaces: tuple = ("LANGSMITH", "LANGCHAIN"),
) -> Optional[str]:
    """Retrieve an environment variable from a list of namespaces.

    For each namespace, in order, the variable ``{namespace}_{name}`` is
    looked up in ``os.environ``; the first one that is set wins. Results
    are memoized by ``lru_cache``, so a later change to the environment is
    not observed for arguments that are still cached.

    Args:
        name: The name of the environment variable, without namespace prefix.
        default: The default value to return if the environment variable is
            not found.
        namespaces: A tuple of namespaces to search for the environment
            variable.

            Defaults to `('LANGSMITH', 'LANGCHAIN')`.

    Returns:
        The value of the environment variable if found, otherwise the default value.
    """
    for namespace in namespaces:
        # A dedicated loop variable avoids rebinding the ``name`` parameter.
        value = os.environ.get(f"{namespace}_{name}")
        if value is not None:
            return value
    return default

401 

402 

@functools.lru_cache(maxsize=1)
def get_tracer_project(return_default_value=True) -> Optional[str]:
    """Resolve the project name a LangSmith tracer should use.

    Precedence, highest first: the ``HOSTED_LANGSERVE_PROJECT_NAME``
    environment variable, the namespaced ``PROJECT`` setting, the
    namespaced ``SESSION`` setting, and finally ``"default"`` (or ``None``
    when ``return_default_value`` is false).

    Args:
        return_default_value: Whether to fall back to ``"default"`` when
            nothing is configured.

    Returns:
        The resolved project name, or ``None``.
    """
    last_resort = "default" if return_default_value else None
    # SESSION is the legacy name for PROJECT, hence its lower precedence.
    legacy_session = get_env_var("SESSION", default=last_resort)
    configured_project = get_env_var("PROJECT", default=legacy_session)
    # Hosted LangServe projects get precedence over all other defaults, so
    # the project associated with a hosted deployment is used even if the
    # customer sets some other project name in their environment.
    return os.environ.get("HOSTED_LANGSERVE_PROJECT_NAME", configured_project)

421 

422 

class FilterPoolFullWarning(logging.Filter):
    """Drop `urllib3` "connection pool is full" warnings for one host."""

    def __init__(self, name: str = "", host: str = "") -> None:
        """Create the filter.

        Args:
            name: The name of the filter. Defaults to `""`.
            host: Host whose pool-full warnings are suppressed.
                Defaults to `""`.
        """
        super().__init__(name)
        self._host = host

    def filter(self, record) -> bool:
        """Return False only for pool-full warnings mentioning the host.

        Matches records such as
        ``urllib3.connectionpool:Connection pool is full, discarding connection: ...``.
        """
        text = record.getMessage()
        is_pool_full = "Connection pool is full, discarding connection" in text
        return not (is_pool_full and self._host in text)

442 

443 

class FilterLangSmithRetry(logging.Filter):
    """Suppress log records that mention ``LangSmithRetry``."""

    def filter(self, record) -> bool:
        """Return False for records whose message contains ``LangSmithRetry``."""
        # Per the original note, these are re-raised/logged manually.
        return record.getMessage().find("LangSmithRetry") == -1

452 

453 

class LangSmithRetry(Retry):
    """`Retry` subclass that adds no behaviour of its own.

    It exists to give retries from this library a distinct class name, so
    log messages mentioning ``LangSmithRetry`` can be filtered out.
    """

456 

457 

# Serialises installation and removal of temporary log filters.
_FILTER_LOCK = threading.RLock()


@contextlib.contextmanager
def filter_logs(
    logger: logging.Logger, filters: Sequence[logging.Filter]
) -> Generator[None, None, None]:
    """Attach filters to a logger for the duration of a ``with`` block.

    Args:
        logger: The logger to which the filters will be added.
        filters: The `logging.Filter` objects to add on entry and remove
            again on exit, even if the body raises.
    """
    with _FILTER_LOCK:
        for log_filter in filters:
            logger.addFilter(log_filter)
    # Not actually perfectly thread-safe, but it's only log filters
    try:
        yield
    finally:
        with _FILTER_LOCK:
            for log_filter in filters:
                try:
                    logger.removeFilter(log_filter)
                except BaseException:
                    # Best effort: a failed removal must not mask the
                    # outcome of the with-body.
                    _LOGGER.warning("Failed to remove filter")

485 

486 

def get_cache_dir(cache: Optional[str]) -> Optional[str]:
    """Resolve the testing cache directory.

    Args:
        cache: An explicit cache path, or ``None`` to consult the
            environment.

    Returns:
        ``cache`` when it is not ``None`` (even if empty); otherwise the
        namespaced ``TEST_CACHE`` environment setting, e.g.
        `LANGSMITH_TEST_CACHE`, or ``None`` if unset.
    """
    return cache if cache is not None else get_env_var("TEST_CACHE", default=None)

500 

501 

def filter_request_headers(
    request: Any,
    *,
    ignore_hosts: Optional[Sequence[str]] = None,
    allow_hosts: Optional[Sequence[str]] = None,
) -> Any:
    """Decide whether a request is kept, and blank its headers if so.

    Args:
        request: Object exposing a string ``url`` and a writable
            ``headers`` attribute.
        ignore_hosts: URL prefixes; a request whose URL starts with any of
            them is dropped.
        allow_hosts: When given, only matching requests are kept. An entry
            starting with ``http://`` or ``https://`` is matched as a URL
            prefix; any other entry is matched against the parsed hostname,
            either exactly or as a parent domain.

    Returns:
        ``None`` when the request is dropped, otherwise the same request
        object with ``headers`` replaced by an empty dict.
    """
    # Legacy behavior
    if ignore_hosts:
        for ignored_prefix in ignore_hosts:
            if request.url.startswith(ignored_prefix):
                return None

    if allow_hosts:
        try:
            parsed = urllib_parse.urlparse(request.url)
        except Exception:
            # If URL parsing fails, don't cache to be safe
            return None
        hostname = parsed.hostname or ""

        def _is_allowed(entry: str) -> bool:
            # Full URLs (https://api.openai.com) are matched by prefix,
            # bare hostnames (api.openai.com) by exact or subdomain match.
            if entry.startswith(("http://", "https://")):
                return request.url.startswith(entry)
            return hostname == entry or hostname.endswith(f".{entry}")

        if not any(_is_allowed(entry) for entry in allow_hosts):
            return None

    request.headers = {}
    return request

536 

537 

@contextlib.contextmanager
def with_cache(
    path: Union[str, pathlib.Path],
    ignore_hosts: Optional[Sequence[str]] = None,
    allow_hosts: Optional[Sequence[str]] = None,
) -> Generator[None, None, None]:
    """Record and replay HTTP requests through a vcrpy cassette.

    Args:
        path: Cassette file path. The directory part becomes the cassette
            library directory; a ``.yaml``/``.yml`` file name selects the
            YAML serializer, anything else JSON.
        ignore_hosts: Forwarded to ``filter_request_headers``.
        allow_hosts: Forwarded to ``filter_request_headers``.

    Raises:
        ImportError: If ``vcrpy`` is not installed.
    """
    try:
        import vcr  # type: ignore[import-untyped]
    except ImportError as exc:
        # Trailing space keeps "with:" and the pip command apart in the
        # concatenated message.
        raise ImportError(
            "vcrpy is required to use caching. Install with: "
            'pip install -U "langsmith[vcr]"'
        ) from exc
    # Fix concurrency issue in vcrpy's patching
    from langsmith._internal import _patch as patch_urllib3

    patch_urllib3.patch_urllib3()

    cache_dir, cache_file = os.path.split(path)

    ls_vcr = vcr.VCR(
        serializer="yaml" if cache_file.endswith((".yaml", ".yml")) else "json",
        cassette_library_dir=cache_dir,
        # Replay previous requests, record new ones
        # TODO: Support other modes
        record_mode="new_episodes",
        match_on=["uri", "method", "path", "body"],
        filter_headers=["authorization", "Set-Cookie"],
        before_record_request=lambda request: filter_request_headers(
            request, ignore_hosts=ignore_hosts, allow_hosts=allow_hosts
        ),
    )
    with ls_vcr.use_cassette(cache_file):
        yield

577 

578 

@contextlib.contextmanager
def with_optional_cache(
    path: Optional[Union[str, pathlib.Path]],
    ignore_hosts: Optional[Sequence[str]] = None,
    allow_hosts: Optional[Sequence[str]] = None,
) -> Generator[None, None, None]:
    """Use a request cache when a path is given; otherwise do nothing.

    Args:
        path: Cassette path handed to ``with_cache``, or ``None`` to run
            the body without caching.
        ignore_hosts: Forwarded to ``with_cache``.
        allow_hosts: Forwarded to ``with_cache``.
    """
    if path is None:
        yield
        return
    with with_cache(path, ignore_hosts, allow_hosts):
        yield

591 

592 

def _format_exc() -> str:
    """Format the exception currently being handled, minus library frames.

    Used internally to format exceptions without cluttering the traceback:
    every formatted line containing ``"langsmith/"`` is dropped.
    """
    exc_type, exc_value, exc_tb = sys.exc_info()
    rendered = traceback.format_exception(exc_type, exc_value, exc_tb)
    return "".join(line for line in rendered if "langsmith/" not in line)

598 

599 

T = TypeVar("T")


def _middle_copy(
    val: T, memo: dict[int, Any], max_depth: int = 4, _depth: int = 0
) -> T:
    """Copy as much of ``val`` as possible, sharing what cannot be copied.

    If the value's class defines ``__deepcopy__`` it is tried first.
    Otherwise dicts, lists, tuples and sets are rebuilt with their elements
    copied recursively, down to ``max_depth`` levels. Anything else is
    returned as-is (shared with the original).

    Args:
        val: The value to copy.
        memo: The memo dictionary passed to ``__deepcopy__``.
        max_depth: Maximum container nesting level that is rebuilt.
        _depth: Current nesting level (internal).

    Returns:
        The best-effort copy.
    """
    copier = getattr(type(val), "__deepcopy__", None)
    if copier is not None:
        try:
            # The hook is looked up on the class, so it is unbound and the
            # instance has to be passed explicitly. Previously only
            # ``memo`` was passed, so the call always failed.
            return copier(val, memo)
        except Exception:
            # Best effort: fall back to the structural copy below.
            pass
    if _depth >= max_depth:
        return val

    next_depth = _depth + 1
    if isinstance(val, dict):
        return {  # type: ignore[return-value]
            _middle_copy(k, memo, max_depth, next_depth): _middle_copy(
                v, memo, max_depth, next_depth
            )
            for k, v in val.items()
        }
    if isinstance(val, list):
        return [_middle_copy(item, memo, max_depth, next_depth) for item in val]  # type: ignore[return-value]
    if isinstance(val, tuple):
        return tuple(_middle_copy(item, memo, max_depth, next_depth) for item in val)  # type: ignore[return-value]
    if isinstance(val, set):
        return {_middle_copy(item, memo, max_depth, next_depth) for item in val}  # type: ignore[return-value]

    return val

631 

632 

def deepish_copy(val: T) -> T:
    """Deep copy a value, degrading gracefully for uncopyable objects.

    A regular ``copy.deepcopy`` is attempted first. If it raises, the
    failure is logged at debug level and ``_middle_copy`` is used instead.

    Args:
        val: The value to be deep copied.

    Returns:
        The deep copied value, or the best-effort copy.
    """
    memo: dict[int, Any] = {}
    try:
        copied = copy.deepcopy(val, memo)
    except BaseException as exc:
        # Generators, locks, etc. cannot be copied and raise a TypeError
        # (mentioning pickling, since the dunder methods are re-used for
        # copying). Compromise by copying what we can.
        _LOGGER.debug("Failed to deepcopy input: %s", repr(exc))
        copied = _middle_copy(val, memo)
    return copied

652 

653 

def is_version_greater_or_equal(current_version: str, target_version: str) -> bool:
    """Compare two version strings using ``packaging`` semantics.

    Args:
        current_version: The version to test.
        target_version: The minimum version required.

    Returns:
        True when ``current_version`` is at least ``target_version``.
    """
    from packaging import version

    # The current version is parsed first, then the target.
    current, target = (
        version.parse(raw) for raw in (current_version, target_version)
    )
    return current >= target

661 

662 

def parse_prompt_identifier(identifier: str) -> tuple[str, str, str]:
    """Split a prompt identifier into owner, name and commit hash.

    Accepted forms are `owner/name:hash`, `name:hash`, `owner/name` and
    `name`. A missing owner is reported as ``"-"`` and a missing hash as
    ``"latest"``.

    Args:
        identifier: The prompt identifier to parse.

    Returns:
        A tuple containing `(owner, name, hash)`.

    Raises:
        ValueError: If the identifier doesn't match the expected formats.
    """
    error_message = f"Invalid identifier format: {identifier}"

    if (
        not identifier
        or identifier.count("/") > 1
        or identifier.startswith("/")
        or identifier.endswith("/")
    ):
        raise ValueError(error_message)

    # Everything after the first colon is the commit hash.
    owner_and_name, colon, suffix = identifier.partition(":")
    commit = suffix if colon else "latest"

    if "/" not in owner_and_name:
        if not owner_and_name:
            raise ValueError(error_message)
        return "-", owner_and_name, commit

    owner, _, name = owner_and_name.partition("/")
    if not (owner and name):
        raise ValueError(error_message)
    return owner, name, commit

696 

697 

698P = ParamSpec("P") 

699 

700 

class ContextThreadPoolExecutor(ThreadPoolExecutor):
    """ThreadPoolExecutor that copies the context to the child thread."""

    def submit(  # type: ignore[override]
        self,
        func: Callable[P, T],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Future[T]:
        """Submit a function to the executor.

        The caller's context is copied at submission time and the function
        runs inside that copy.

        Args:
            func: The function to submit.
            *args: The positional arguments to the function.
            **kwargs: The keyword arguments to the function.

        Returns:
            The future for the function.
        """
        caller_context = contextvars.copy_context()
        return super().submit(
            cast(
                "Callable[..., T]",
                functools.partial(caller_context.run, func, *args, **kwargs),
            )
        )

    def map(
        self,
        fn: Callable[..., T],
        *iterables: Iterable[Any],
        timeout: Optional[float] = None,
        chunksize: int = 1,
    ) -> Iterator[T]:
        """Return an iterator equivalent to stdlib map.

        Each function call runs inside its own copy of the context. Any
        iterable is accepted, including generators and other objects
        without a length.

        Args:
            fn: A callable that will take as many arguments as there are
                passed iterables.
            *iterables: The argument iterables, consumed in lockstep.
            timeout: The maximum number of seconds to wait. If None, then there
                is no limit on the wait time.
            chunksize: The size of the chunks the iterable will be broken into
                before being passed to a child process. This argument is only
                used by ProcessPoolExecutor; it is ignored by
                ThreadPoolExecutor.

        Returns:
            An iterator equivalent to: map(func, *iterables) but the calls may
            be evaluated out-of-order.

        Raises:
            TimeoutError: If the entire result iterator could not be generated
                before the given timeout.
            Exception: If fn(*args) raises for any values.
        """
        if not iterables:
            # Nothing to pair contexts with; defer to the base class rather
            # than zipping against the endless context stream below.
            return super().map(fn, timeout=timeout, chunksize=chunksize)

        def _fresh_contexts() -> Iterator[contextvars.Context]:
            # One separate copy per call: a Context object cannot be
            # entered by two threads at the same time.
            while True:
                yield contextvars.copy_context()

        def _run_in_context(ctx: contextvars.Context, *args: Any) -> T:
            return ctx.run(fn, *args)

        # The context stream is zipped with the argument iterables, so it
        # stops as soon as the shortest real iterable is exhausted.
        return super().map(
            _run_in_context,
            _fresh_contexts(),
            *iterables,
            timeout=timeout,
            chunksize=chunksize,
        )

770 

771 

def get_api_url(api_url: Optional[str]) -> str:
    """Resolve and normalise the LangSmith API URL.

    Args:
        api_url: Explicit URL. When falsy, the namespaced ``ENDPOINT``
            environment setting is used, defaulting to
            ``https://api.smith.langchain.com``.

    Returns:
        The URL with surrounding whitespace, surrounding quotes and any
        trailing slashes removed.

    Raises:
        LangSmithUserError: If the resolved URL is blank.
    """
    resolved = api_url
    if not resolved:
        resolved = cast(
            str,
            get_env_var(
                "ENDPOINT",
                default="https://api.smith.langchain.com",
            ),
        )
    trimmed = resolved.strip()
    if not trimmed:
        raise LangSmithUserError("LangSmith API URL cannot be empty")
    return trimmed.strip('"').strip("'").rstrip("/")

784 

785 

def get_api_key(api_key: Optional[str]) -> Optional[str]:
    """Resolve and normalise the API key.

    Args:
        api_key: Explicit key. When ``None``, the namespaced ``API_KEY``
            environment setting is used.

    Returns:
        The key with surrounding whitespace and quotes removed, or ``None``
        when no key is available or it is blank.
    """
    candidate = get_env_var("API_KEY", default=None) if api_key is None else api_key
    if candidate is None:
        return None
    trimmed = candidate.strip()
    if not trimmed:
        return None
    return trimmed.strip('"').strip("'")

792 

793 

def get_workspace_id(workspace_id: Optional[str]) -> Optional[str]:
    """Resolve and normalise the workspace ID.

    Args:
        workspace_id: Explicit ID. When ``None``, the namespaced
            ``WORKSPACE_ID`` environment setting is used.

    Returns:
        The ID with surrounding whitespace and quotes removed, or ``None``
        when no ID is available or it is blank.
    """
    if workspace_id is None:
        candidate = get_env_var("WORKSPACE_ID", default=None)
    else:
        candidate = workspace_id
    if candidate is None:
        return None
    trimmed = candidate.strip()
    if not trimmed:
        return None
    return trimmed.strip('"').strip("'")

804 

805 

def _is_localhost(url: str) -> bool:
    """Check if the URL points at the local machine.

    The host is taken from the parsed URL's ``hostname``, so userinfo,
    ports and IPv6 brackets are stripped before the check. The IPv6
    loopback literal ``::1`` is recognised directly; every other host is
    resolved with ``socket.gethostbyname``.

    Args:
        url: The URL to check.

    Returns:
        True if the URL is localhost, False otherwise (including when the
        host cannot be resolved).
    """
    # An absent host is passed on as "" exactly as before, so URLs without
    # a network location keep their previous result.
    hostname = urllib_parse.urlsplit(url).hostname or ""
    if hostname == "::1":
        # gethostbyname is IPv4-only and cannot resolve this literal.
        return True
    try:
        ip = socket.gethostbyname(hostname)
    except (socket.gaierror, UnicodeError):
        # UnicodeError: host names that cannot be IDNA-encoded.
        return False
    return ip == "127.0.0.1" or ip.startswith("0.0.0.0") or ip.startswith("::")

825 

826 

@functools.lru_cache(maxsize=2)
def get_host_url(web_url: Optional[str], api_url: str):
    """Derive the web host URL from an explicit web URL or the API URL.

    Rules, first match wins:

    1. a truthy ``web_url`` is returned unchanged;
    2. a localhost API URL maps to ``http://localhost``;
    3. an API path ending in ``/api`` or ``/api/v1`` has that suffix
       removed;
    4. an API host starting with ``eu.``, ``dev.`` or ``beta.`` maps to the
       matching ``smith.langchain.com`` subdomain;
    5. anything else maps to ``https://smith.langchain.com``.

    Args:
        web_url: Explicit web URL, if any.
        api_url: The API URL to derive the host from.

    Returns:
        The host URL.
    """
    if web_url:
        return web_url

    parsed_url = urllib_parse.urlparse(api_url)
    if _is_localhost(api_url):
        return "http://localhost"

    api_path = str(parsed_url.path)
    for api_suffix in ("/api", "/api/v1"):
        if api_path.endswith(api_suffix):
            trimmed_path = api_path.rsplit(api_suffix, 1)[0]
            return urllib_parse.urlunparse(parsed_url._replace(path=trimmed_path))

    api_host = str(parsed_url.netloc)
    regional_hosts = (
        ("eu.", "https://eu.smith.langchain.com"),
        ("dev.", "https://dev.smith.langchain.com"),
        ("beta.", "https://beta.smith.langchain.com"),
    )
    for host_prefix, web_host in regional_hosts:
        if api_host.startswith(host_prefix):
            return web_host
    return "https://smith.langchain.com"

850 

851 

def _get_function_name(fn: Callable, depth: int = 0) -> str:
    """Find a readable name for a callable.

    Preference order: the callable's own ``__name__``; for
    ``functools.partial`` objects, the name of the wrapped function; for
    other callable objects, the class name. Non-callables, and lookups
    nested more than two levels deep, fall back to ``str(fn)``.

    Args:
        fn: The object to name.
        depth: Current recursion depth (internal).

    Returns:
        The best available name.
    """
    if depth > 2 or not callable(fn):
        return str(fn)

    if hasattr(fn, "__name__"):
        return fn.__name__

    if isinstance(fn, functools.partial):
        return _get_function_name(fn.func, depth + 1)

    if not hasattr(fn, "__call__"):
        return str(fn)

    if hasattr(fn, "__class__") and hasattr(fn.__class__, "__name__"):
        return fn.__class__.__name__
    return _get_function_name(fn.__call__, depth + 1)

868 

869 

def is_truish(val: Any) -> bool:
    """Interpret a value as a boolean flag.

    Args:
        val: The value to check.

    Returns:
        For strings: True only for ``"1"`` or a case-insensitive
        ``"true"``. For anything else: ``bool(val)``.
    """
    if not isinstance(val, str):
        return bool(val)
    return val == "1" or val.lower() == "true"