Coverage for langsmith/_testing.py: 0%

253 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-06 01:30 -0800

1from __future__ import annotations 

2 

3import atexit 

4import datetime 

5import functools 

6import inspect 

7import logging 

8import threading 

9import uuid 

10import warnings 

11from collections import defaultdict 

12from pathlib import Path 

13from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, overload 

14 

15import orjson 

16from typing_extensions import TypedDict 

17 

18from langsmith import client as ls_client 

19from langsmith import env as ls_env 

20from langsmith import run_helpers as rh 

21from langsmith import run_trees as rt 

22from langsmith import schemas as ls_schemas 

23from langsmith import utils as ls_utils 

24 

try:
    import pytest  # type: ignore
except ImportError:

    class SkipException(Exception):  # type: ignore[no-redef]
        """Stand-in for ``pytest.skip.Exception`` when pytest is not installed.

        It gives the ``except SkipException`` clauses in this module a class to
        match even without pytest.
        """

else:
    # The exception type raised by ``pytest.skip()``; catching it lets skipped
    # tests be recorded as skipped rather than failed.
    SkipException = pytest.skip.Exception


logger = logging.getLogger(__name__)


# Module-level type variables. NOTE(review): not referenced elsewhere in this
# module; presumably kept for importers.
T = TypeVar("T")
U = TypeVar("U")

40 

41 

@overload
def test(
    func: Callable,
) -> Callable: ...


@overload
def test(
    *,
    id: Optional[uuid.UUID] = None,
    output_keys: Optional[Sequence[str]] = None,
    client: Optional[ls_client.Client] = None,
    test_suite_name: Optional[str] = None,
) -> Callable[[Callable], Callable]: ...


def test(*args: Any, **kwargs: Any) -> Callable:
    """Create a test case in LangSmith.

    This decorator is used to mark a function as a test case for LangSmith. It ensures
    that the necessary example data is created and associated with the test function.
    The decorated function will be executed as a test case, and the results will be
    recorded and reported by LangSmith.

    Args:
        - id (Optional[uuid.UUID]): A unique identifier for the test case. If not
            provided, an ID is derived from the test's file path (or module),
            function name, and input names.
        - output_keys (Optional[Sequence[str]]): A list of keys to be considered as
            the output keys for the test case. These keys will be extracted from the
            test function's inputs and stored as the expected outputs.
        - client (Optional[ls_client.Client]): An instance of the LangSmith client
            to be used for communication with the LangSmith service. If not provided,
            a default client will be used.
        - test_suite_name (Optional[str]): The name of the test suite to which the
            test case belongs. If not provided, it is read from the
            LANGSMITH_TEST_SUITE environment variable, falling back to
            "<git repo name>.<module name>".

    Returns:
        Callable: The decorated test function when used as a bare ``@test``, or a
        decorator when called with keyword arguments only.

    Environment:
        - LANGSMITH_TEST_CACHE: If set, API calls will be cached to disk to
            save time and costs during testing. Recommended to commit the
            cache files to your repository for faster CI/CD runs.
            Requires the 'langsmith[vcr]' package to be installed.
        - LANGSMITH_TEST_TRACKING: Set this variable to 'false' to disable
            LangSmith test tracking. The decorated function then runs as a plain
            test, and no examples, experiments, or feedback are created.

    Example:
        For basic usage, simply decorate a test function with `@test`:

        >>> @test
        ... def test_addition():
        ...     assert 3 + 4 == 7


        Any code that is traced (such as those traced using `@traceable`
        or `wrap_*` functions) will be traced within the test case for
        improved visibility and debugging.

        >>> from langsmith import traceable
        >>> @traceable
        ... def generate_numbers():
        ...     return 3, 4

        >>> @test
        ... def test_nested():
        ...     # Traced code will be included in the test case
        ...     a, b = generate_numbers()
        ...     assert a + b == 7

        LLM calls are expensive! Cache requests by setting
        `LANGSMITH_TEST_CACHE=path/to/cache`. Check in these files to speed up
        CI/CD pipelines, so your results only change when your prompt or requested
        model changes.

        Note that this will require that you install langsmith with the `vcr` extra:

        `pip install -U "langsmith[vcr]"`

        Caching is faster if you install libyaml. See
        https://vcrpy.readthedocs.io/en/latest/installation.html#speed for more details.

        >>> # os.environ["LANGSMITH_TEST_CACHE"] = "tests/cassettes"
        >>> import openai
        >>> from langsmith.wrappers import wrap_openai
        >>> oai_client = wrap_openai(openai.Client())
        >>> @test
        ... def test_openai_says_hello():
        ...     # Traced code will be included in the test case
        ...     response = oai_client.chat.completions.create(
        ...         model="gpt-3.5-turbo",
        ...         messages=[
        ...             {"role": "system", "content": "You are a helpful assistant."},
        ...             {"role": "user", "content": "Say hello!"},
        ...         ],
        ...     )
        ...     assert "hello" in response.choices[0].message.content.lower()

        LLMs are stochastic. Naive assertions are flaky. You can use langsmith's
        `expect` to score and make approximate assertions on your results.

        >>> from langsmith import expect
        >>> @test
        ... def test_output_semantically_close():
        ...     response = oai_client.chat.completions.create(
        ...         model="gpt-3.5-turbo",
        ...         messages=[
        ...             {"role": "system", "content": "You are a helpful assistant."},
        ...             {"role": "user", "content": "Say hello!"},
        ...         ],
        ...     )
        ...     # The embedding_distance call logs the embedding distance to LangSmith
        ...     expect.embedding_distance(
        ...         prediction=response.choices[0].message.content,
        ...         reference="Hello!",
        ...         # The following optional assertion logs a
        ...         # pass/fail score to LangSmith
        ...         # and raises an AssertionError if the assertion fails.
        ...     ).to_be_less_than(1.0)
        ...     # Compute damerau_levenshtein distance
        ...     expect.edit_distance(
        ...         prediction=response.choices[0].message.content,
        ...         reference="Hello!",
        ...         # And then log a pass/fail score to LangSmith
        ...     ).to_be_less_than(1.0)

        The `@test` decorator works natively with pytest fixtures.
        The values will populate the "inputs" of the corresponding example in LangSmith.

        >>> import pytest
        >>> @pytest.fixture
        ... def some_input():
        ...     return "Some input"
        >>>
        >>> @test
        ... def test_with_fixture(some_input: str):
        ...     assert "input" in some_input
        >>>

        You can still use pytest.mark.parametrize() as usual to run multiple test
        cases using the same test function.

        >>> @test(output_keys=["expected"])
        ... @pytest.mark.parametrize(
        ...     "a, b, expected",
        ...     [
        ...         (1, 2, 3),
        ...         (3, 4, 7),
        ...     ],
        ... )
        ... def test_addition_with_multiple_inputs(a: int, b: int, expected: int):
        ...     assert a + b == expected

        By default, each test case will be assigned a consistent, unique identifier
        based on the function name and module. You can also provide a custom identifier
        using the `id` argument:
        >>> @test(id="1a77e4b5-1d38-4081-b829-b0442cf3f145")
        ... def test_multiplication():
        ...     assert 3 * 4 == 12

        By default, all test inputs are saved as "inputs" to a dataset.
        You can specify the `output_keys` argument to persist those keys
        within the dataset's "outputs" fields.

        >>> @pytest.fixture
        ... def expected_output():
        ...     return "input"
        >>> @test(output_keys=["expected_output"])
        ... def test_with_expected_output(some_input: str, expected_output: str):
        ...     assert expected_output in some_input


        To run these tests, use the pytest CLI. Or directly run the test functions.
        >>> test_output_semantically_close()
        >>> test_addition()
        >>> test_nested()
        >>> test_with_fixture("Some input")
        >>> test_with_expected_output("Some input", "Some")
        >>> test_multiplication()
        >>> test_openai_says_hello()
        >>> test_addition_with_multiple_inputs(1, 2, 3)
    """
    langtest_extra = _UTExtra(
        id=kwargs.pop("id", None),
        output_keys=kwargs.pop("output_keys", None),
        client=kwargs.pop("client", None),
        test_suite_name=kwargs.pop("test_suite_name", None),
        cache=ls_utils.get_cache_dir(kwargs.pop("cache", None)),
    )
    if kwargs:
        # List the names, not the raw ``dict_keys([...])`` repr, and attribute the
        # warning to the decorated code (stacklevel=2) rather than this module.
        warnings.warn(f"Unexpected keyword arguments: {list(kwargs)}", stacklevel=2)
    disable_tracking = ls_utils.test_tracking_is_disabled()
    if disable_tracking:
        warnings.warn(
            "LANGSMITH_TEST_TRACKING is set to 'false'."
            " Skipping LangSmith test tracking.",
            stacklevel=2,
        )

    def decorator(func: Callable) -> Callable:
        # Coroutine tests need an async wrapper so pytest (or the caller) can await them.
        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*test_args: Any, **test_kwargs: Any):
                if disable_tracking:
                    return await func(*test_args, **test_kwargs)
                # NOTE: with tracking on, the test's return value is not propagated.
                await _arun_test(
                    func, *test_args, **test_kwargs, langtest_extra=langtest_extra
                )

            return async_wrapper

        @functools.wraps(func)
        def wrapper(*test_args: Any, **test_kwargs: Any):
            if disable_tracking:
                return func(*test_args, **test_kwargs)
            # NOTE: with tracking on, the test's return value is not propagated.
            _run_test(func, *test_args, **test_kwargs, langtest_extra=langtest_extra)

        return wrapper

    # Bare ``@test`` passes the function positionally; ``@test(...)`` passes only
    # keyword arguments and expects a decorator back.
    if args and callable(args[0]):
        return decorator(args[0])

    return decorator

268 

269 

270## Private functions 

271 

272 

def _get_experiment_name() -> str:
    """Build a fresh experiment name of the form ``"<prefix>:<8 hex chars>"``.

    The prefix is the tracer project reported by
    ``ls_utils.get_tracer_project(False)``, or ``"TestSuiteResult"`` when that
    is empty. The suffix comes from a random UUID4.
    """
    # TODO Make more easily configurable
    project = ls_utils.get_tracer_project(False)
    prefix = project if project else "TestSuiteResult"
    return f"{prefix}:{uuid.uuid4().hex[:8]}"

278 

279 

def _get_test_suite_name(func: Callable) -> str:
    """Resolve the LangSmith test-suite (dataset) name for ``func``.

    Resolution order:
      1. The ``TEST_SUITE`` setting read via ``ls_utils.get_env_var`` (the error
         below refers to it as ``LANGSMITH_TEST_SUITE``).
      2. ``"<repo_name>.<module name>"``, using the git repo name from
         ``ls_env.get_git_info()`` and the module that defines ``func``.

    Raises:
        ValueError: if neither source yields a name.
    """
    test_suite_name = ls_utils.get_env_var("TEST_SUITE")
    if test_suite_name:
        return test_suite_name
    # NOTE(review): if get_git_info() reports no repo name, this renders as "None".
    repo_name = ls_env.get_git_info()["repo_name"]
    try:
        mod = inspect.getmodule(func)
        if mod:
            return f"{repo_name}.{mod.__name__}"
    except Exception:
        # Best-effort lookup only; KeyboardInterrupt/SystemExit must still propagate.
        logger.debug("Could not determine test suite name from file path.")

    raise ValueError("Please set the LANGSMITH_TEST_SUITE environment variable.")

293 

294 

def _get_test_suite(
    client: ls_client.Client, test_suite_name: str
) -> ls_schemas.Dataset:
    """Return the dataset named ``test_suite_name``, creating it when missing.

    A newly created dataset is described as ``"Test suite"``, with
    ``" for <remote_url>"`` appended when git reports a remote URL.
    """
    if not client.has_dataset(dataset_name=test_suite_name):
        remote_url = ls_env.get_git_info().get("remote_url") or ""
        suffix = f" for {remote_url}" if remote_url else ""
        return client.create_dataset(
            dataset_name=test_suite_name, description="Test suite" + suffix
        )
    return client.read_dataset(dataset_name=test_suite_name)

308 

309 

def _start_experiment(
    client: ls_client.Client,
    test_suite: ls_schemas.Dataset,
) -> ls_schemas.TracerSession:
    """Create the experiment project that test runs are logged to.

    The project gets a fresh name from ``_get_experiment_name``, references the
    ``test_suite`` dataset, and carries the current ``revision_id`` as metadata.
    If creation reports a conflict, the existing project of that name is read
    back and returned instead.
    """
    name = _get_experiment_name()
    try:
        revision_id = ls_env.get_langchain_env_var_metadata().get("revision_id")
        return client.create_project(
            name,
            reference_dataset_id=test_suite.id,
            description="Test Suite Results.",
            metadata={"revision_id": revision_id},
        )
    except ls_utils.LangSmithConflictError:
        return client.read_project(project_name=name)

328 

329 

# How many times each input key has been seen per test identifier. Repeated
# invocations of the same test (e.g. cases generated by
# pytest.mark.parametrize) get distinct counters, and therefore distinct example
# IDs, that are deterministic for a given invocation order.
_param_dict: dict = defaultdict(lambda: defaultdict(int))


def _get_id(func: Callable, inputs: dict, suite_id: uuid.UUID) -> Tuple[uuid.UUID, str]:
    """Derive the example ID and display name for one invocation of ``func``.

    The identifier is ``"<suite_id><location>::<func name>"``. ``location`` is
    the function's file path relative to the working directory, or its module
    name when the file lies outside it. When there are inputs, it is followed by
    ``"[<key><count>-...]"`` over the sorted input keys, where each count is
    incremented on every call.

    Returns:
        ``(uuid5(NAMESPACE_DNS, identifier), identifier without the suite_id prefix)``.
    """
    try:
        location = str(Path(inspect.getfile(func)).relative_to(Path.cwd()))
    except ValueError:
        # The source file lies outside the working directory; use the module name.
        location = func.__module__
    base = f"{suite_id}{location}::{func.__name__}"
    parts = []
    for key in sorted(inputs):
        counters = _param_dict[base]
        counters[key] += 1
        parts.append(f"{key}{counters[key]}")
    full = f"{base}[{'-'.join(parts)}]" if parts else base
    return uuid.uuid5(uuid.NAMESPACE_DNS, full), full[len(str(suite_id)) :]

352 

353 

def _end_tests(
    test_suite: _LangSmithTestSuite,
):
    """Finalize a test suite's experiment; registered via ``atexit`` per suite.

    First drains the suite's background executor (``test_suite.wait()``). The
    dataset version is advanced only by the queued example-sync tasks, so
    reading it before they finish could report a stale version. Then it closes
    the experiment project with an end time, the git info, the final dataset
    version, and the current ``revision_id``.
    """
    git_info = ls_env.get_git_info() or {}
    # Wait for queued example syncs / feedback before reading the version.
    test_suite.wait()
    test_suite.client.update_project(
        test_suite.experiment_id,
        end_time=datetime.datetime.now(datetime.timezone.utc),
        metadata={
            **git_info,
            "dataset_version": test_suite.get_version(),
            "revision_id": ls_env.get_langchain_env_var_metadata().get("revision_id"),
        },
    )

368 

369 

# Example inputs/outputs: a dict, or None when absent.
VT = TypeVar("VT", bound=Optional[dict])


def _serde_example_values(values: VT) -> VT:
    """Normalize example values by round-tripping them through JSON.

    Non-None values are encoded with ``ls_client._dumps_json`` and decoded with
    ``orjson.loads``, leaving only plain JSON types. ``None`` passes through.
    """
    if values is not None:
        return orjson.loads(ls_client._dumps_json(values))
    return values

378 

379 

class _LangSmithTestSuite:
    """Shared state for every ``@test`` case belonging to one test suite.

    One instance is created per test-suite name (see ``from_test``) and reused
    afterwards. It holds the suite's dataset and the experiment project that
    test runs are logged to. It also tracks the newest example ``modified_at``
    seen, which is reported as the dataset version. Example syncing and feedback
    creation run on a single-worker background executor. ``_end_tests`` is
    registered with ``atexit`` to finalize the experiment.
    """

    # Suites keyed by test-suite name; created lazily by from_test.
    _instances: Optional[dict] = None
    # Class-level re-entrant lock shared by all instances; guards _instances and
    # each instance's _version.
    _lock = threading.RLock()

    def __init__(
        self,
        client: Optional[ls_client.Client],
        experiment: ls_schemas.TracerSession,
        dataset: ls_schemas.Dataset,
    ):
        self.client = client or rt.get_cached_client()
        self._experiment = experiment
        self._dataset = dataset
        # Newest example modification time seen; None until an example syncs.
        self._version: Optional[datetime.datetime] = None
        # max_workers=1 serializes the background API calls.
        self._executor = ls_utils.ContextThreadPoolExecutor(max_workers=1)
        atexit.register(_end_tests, self)

    @property
    def id(self):
        """ID of the suite's dataset."""
        return self._dataset.id

    @property
    def experiment_id(self):
        """ID of the experiment project test runs are logged to."""
        return self._experiment.id

    @property
    def experiment(self):
        """The experiment project test runs are logged to."""
        return self._experiment

    @classmethod
    def from_test(
        cls,
        client: Optional[ls_client.Client],
        func: Callable,
        test_suite_name: Optional[str] = None,
    ) -> _LangSmithTestSuite:
        """Return the suite for ``func``, creating dataset + experiment on first use.

        The suite name is ``test_suite_name`` if given, else
        ``_get_test_suite_name(func)``. Creation happens under the class lock, so
        concurrent callers for the same name share one instance.
        """
        client = client or rt.get_cached_client()
        test_suite_name = test_suite_name or _get_test_suite_name(func)
        with cls._lock:
            if not cls._instances:
                cls._instances = {}
            if test_suite_name not in cls._instances:
                test_suite = _get_test_suite(client, test_suite_name)
                experiment = _start_experiment(client, test_suite)
                cls._instances[test_suite_name] = cls(client, experiment, test_suite)
        return cls._instances[test_suite_name]

    @property
    def name(self):
        """Name of the experiment project (used as the tracing project name)."""
        return self._experiment.name

    def update_version(self, version: datetime.datetime) -> None:
        """Advance the recorded dataset version if ``version`` is newer."""
        with self._lock:
            if self._version is None or version > self._version:
                self._version = version

    def get_version(self) -> Optional[datetime.datetime]:
        """Return the newest example modification time seen, or None."""
        with self._lock:
            return self._version

    def submit_result(
        self, run_id: uuid.UUID, error: Optional[str] = None, skipped: bool = False
    ) -> None:
        """Queue creation of the ``pass`` feedback for ``run_id``."""
        future = self._executor.submit(
            self._submit_result, run_id, error, skipped=skipped
        )
        # Without this, a failing feedback call would vanish inside the future.
        future.add_done_callback(self._log_background_failure)

    def _submit_result(
        self, run_id: uuid.UUID, error: Optional[str] = None, skipped: bool = False
    ) -> None:
        """Create ``pass`` feedback: 1 on success, 0 on error, None when skipped."""
        # NOTE(review): callers already pass repr(e), so the comment is repr'd twice.
        if error:
            if skipped:
                self.client.create_feedback(
                    run_id,
                    key="pass",
                    # Don't factor into aggregate score
                    score=None,
                    comment=f"Skipped: {repr(error)}",
                )
            else:
                self.client.create_feedback(
                    run_id, key="pass", score=0, comment=f"Error: {repr(error)}"
                )
        else:
            self.client.create_feedback(
                run_id,
                key="pass",
                score=1,
            )

    def sync_example(
        self, example_id: uuid.UUID, inputs: dict, outputs: dict, metadata: dict
    ) -> None:
        """Queue creating/updating example ``example_id`` in the suite's dataset."""
        future = self._executor.submit(
            self._sync_example, example_id, inputs, outputs, metadata.copy()
        )
        # Without this, a failing example sync would vanish inside the future.
        future.add_done_callback(self._log_background_failure)

    def _sync_example(
        self, example_id: uuid.UUID, inputs: dict, outputs: dict, metadata: dict
    ) -> None:
        """Create or update example ``example_id``; runs on the background executor.

        Values are first normalized with ``_serde_example_values``, presumably so
        they compare like-for-like with the stored example. An existing example
        is updated only when its inputs, outputs, or dataset differ. A missing
        one is created with ``created_at`` set to the experiment's start time.
        The example's ``modified_at`` then feeds ``update_version``.
        """
        inputs_ = _serde_example_values(inputs)
        outputs_ = _serde_example_values(outputs)
        try:
            example = self.client.read_example(example_id=example_id)
            if (
                inputs_ != example.inputs
                or outputs_ != example.outputs
                or str(example.dataset_id) != str(self.id)
            ):
                self.client.update_example(
                    example_id=example.id,
                    inputs=inputs_,
                    outputs=outputs_,
                    metadata=metadata,
                    dataset_id=self.id,
                )
        except ls_utils.LangSmithNotFoundError:
            example = self.client.create_example(
                example_id=example_id,
                inputs=inputs_,
                outputs=outputs_,
                dataset_id=self.id,
                metadata=metadata,
                created_at=self._experiment.start_time,
            )
        if example.modified_at:
            self.update_version(example.modified_at)

    def wait(self):
        """Block until queued background work finishes, then shut the executor down."""
        self._executor.shutdown(wait=True)

    @staticmethod
    def _log_background_failure(future: Any) -> None:
        """Done-callback: log an exception raised by a background task."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            logger.warning("LangSmith test tracking background task failed: %r", exc)

508 

509 

class _UTExtra(TypedDict, total=False):
    """Options captured by the ``@test`` decorator and passed to each test run.

    Declared ``total=False``, although ``test`` always populates every key.
    """

    # Client for LangSmith API calls; when falsy, rt.get_cached_client() is used.
    client: Optional[ls_client.Client]
    # Explicit example ID; when falsy, _get_id derives one from the test.
    id: Optional[uuid.UUID]
    # Names of test inputs to move out of the example inputs into its outputs.
    output_keys: Optional[Sequence[str]]
    # Test-suite (dataset) name; when falsy, _get_test_suite_name resolves it.
    test_suite_name: Optional[str]
    # Cache directory from ls_utils.get_cache_dir; when falsy, no cassette path
    # is passed to ls_utils.with_optional_cache.
    cache: Optional[str]

516 

517 

def _get_test_repr(func: Callable, sig: inspect.Signature) -> str:
    """Describe a test as ``"<name><signature>"``, plus ``" - <docstring>"``.

    The docstring part is included only when ``func`` has a non-empty
    docstring, and is stripped of surrounding whitespace.
    """
    name = getattr(func, "__name__", None) or ""
    doc = getattr(func, "__doc__", None)
    suffix = f" - {doc.strip()}" if doc else ""
    return f"{name}{sig}{suffix}"

524 

525 

def _ensure_example(
    func: Callable, *args: Any, langtest_extra: _UTExtra, **kwargs: Any
) -> Tuple[_LangSmithTestSuite, uuid.UUID]:
    """Resolve the suite and example for one invocation of ``func``.

    The call arguments are bound to ``func``'s signature to build the example
    inputs. Any ``output_keys`` are moved from the inputs into the outputs
    (missing keys become ``None``). The example ID comes from ``_get_id``
    unless ``langtest_extra["id"]`` overrides it. The example is then queued
    for syncing via ``sync_example``.

    Returns:
        ``(suite, example_id)`` for this invocation.
    """
    resolved_client = langtest_extra["client"] or rt.get_cached_client()
    keys_to_move = langtest_extra["output_keys"] or ()
    sig = inspect.signature(func)
    example_inputs: dict = rh._get_inputs_safe(sig, *args, **kwargs)
    example_outputs = {key: example_inputs.pop(key, None) for key in keys_to_move}
    suite = _LangSmithTestSuite.from_test(
        resolved_client, func, langtest_extra.get("test_suite_name")
    )
    derived_id, example_name = _get_id(func, example_inputs, suite.id)
    # An explicit id from the decorator wins over the derived one.
    chosen_id = langtest_extra["id"] or derived_id
    suite.sync_example(
        chosen_id,
        example_inputs,
        example_outputs,
        metadata={"signature": _get_test_repr(func, sig), "name": example_name},
    )
    return suite, chosen_id

549 

550 

def _get_cache_path(
    test_suite: _LangSmithTestSuite, langtest_extra: _UTExtra
) -> Optional[Path]:
    """Return the per-suite cassette path ``<cache>/<suite id>.yaml``, or None."""
    cache_dir = langtest_extra["cache"]
    return Path(cache_dir) / f"{test_suite.id}.yaml" if cache_dir else None


def _get_test_tracing_context(
    test_suite: _LangSmithTestSuite, example_id: uuid.UUID
) -> dict:
    """Copy the current tracing context, tagging its metadata with this test.

    The experiment name and reference example ID are merged over any metadata
    already present in the context.
    """
    current_context = rh.get_tracing_context()
    metadata = {
        **(current_context["metadata"] or {}),
        "experiment": test_suite.experiment.name,
        "reference_example_id": str(example_id),
    }
    return {**current_context, "metadata": metadata}


def _trace_test_run(
    func: Callable,
    test_args: tuple,
    test_kwargs: dict,
    *,
    run_id: uuid.UUID,
    example_id: uuid.UUID,
    test_suite: _LangSmithTestSuite,
):
    """Open the trace for one test invocation, linked to its reference example.

    Skip exceptions are passed as ``exceptions_to_handle`` so skips are not
    traced as errors.
    """
    return rh.trace(
        name=getattr(func, "__name__", "Test"),
        run_id=run_id,
        reference_example_id=example_id,
        inputs=rh._get_inputs_safe(inspect.signature(func), *test_args, **test_kwargs),
        project_name=test_suite.name,
        exceptions_to_handle=(SkipException,),
    )


def _to_run_outputs(result: Any) -> Optional[dict]:
    """Pass None/dict results through; wrap anything else as ``{"output": result}``."""
    if result is None or isinstance(result, dict):
        return result
    return {"output": result}


def _submit_pass(test_suite: _LangSmithTestSuite, run_id: uuid.UUID) -> None:
    """Queue passing feedback; a failure to queue it is logged, not raised."""
    try:
        test_suite.submit_result(run_id, error=None)
    except Exception as e:
        logger.warning("Failed to create feedback for run_id %s: %s", run_id, e)


def _run_test(
    func: Callable, *test_args: Any, langtest_extra: _UTExtra, **test_kwargs: Any
) -> None:
    """Run a synchronous test as a traced LangSmith run and record its outcome.

    The suite's example is ensured first. The call is then traced under a new
    run ID, inside a tracing context tagged with the experiment and an optional
    API-call cache. A skip is recorded with ``skipped=True``, any other
    exception as an error, and a normal return as a pass. Exceptions are always
    re-raised so the test runner sees them.
    """
    test_suite, example_id = _ensure_example(
        func, *test_args, **test_kwargs, langtest_extra=langtest_extra
    )
    run_id = uuid.uuid4()
    cache_path = _get_cache_path(test_suite, langtest_extra)
    context = _get_test_tracing_context(test_suite, example_id)
    with rh.tracing_context(**context), ls_utils.with_optional_cache(
        cache_path, ignore_hosts=[test_suite.client.api_url]
    ):
        with _trace_test_run(
            func,
            test_args,
            test_kwargs,
            run_id=run_id,
            example_id=example_id,
            test_suite=test_suite,
        ) as run_tree:
            try:
                result = func(*test_args, **test_kwargs)
                run_tree.end(outputs=_to_run_outputs(result))
            except SkipException as e:
                test_suite.submit_result(run_id, error=repr(e), skipped=True)
                run_tree.end(outputs={"skipped_reason": repr(e)})
                raise
            except BaseException as e:
                # BaseException on purpose: pytest's fail/outcome exceptions
                # derive from it and must still be recorded as failures.
                test_suite.submit_result(run_id, error=repr(e))
                raise
        _submit_pass(test_suite, run_id)


async def _arun_test(
    func: Callable, *test_args: Any, langtest_extra: _UTExtra, **test_kwargs: Any
) -> None:
    """Async counterpart of ``_run_test``: awaits ``func`` inside the same tracing."""
    test_suite, example_id = _ensure_example(
        func, *test_args, **test_kwargs, langtest_extra=langtest_extra
    )
    run_id = uuid.uuid4()
    cache_path = _get_cache_path(test_suite, langtest_extra)
    context = _get_test_tracing_context(test_suite, example_id)
    with rh.tracing_context(**context), ls_utils.with_optional_cache(
        cache_path, ignore_hosts=[test_suite.client.api_url]
    ):
        with _trace_test_run(
            func,
            test_args,
            test_kwargs,
            run_id=run_id,
            example_id=example_id,
            test_suite=test_suite,
        ) as run_tree:
            try:
                result = await func(*test_args, **test_kwargs)
                run_tree.end(outputs=_to_run_outputs(result))
            except SkipException as e:
                test_suite.submit_result(run_id, error=repr(e), skipped=True)
                run_tree.end(outputs={"skipped_reason": repr(e)})
                raise
            except BaseException as e:
                # BaseException on purpose: pytest's fail/outcome exceptions
                # derive from it and must still be recorded as failures.
                test_suite.submit_result(run_id, error=repr(e))
                raise
        _submit_pass(test_suite, run_id)

677 

678 

# For backwards compatibility: ``unit`` is an alias of the ``test`` decorator,
# kept for code that still imports it under that name.
unit = test