Coverage for langsmith/_testing.py: 0%
253 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-06 01:30 -0800
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-06 01:30 -0800
1from __future__ import annotations
3import atexit
4import datetime
5import functools
6import inspect
7import logging
8import threading
9import uuid
10import warnings
11from collections import defaultdict
12from pathlib import Path
13from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, overload
15import orjson
16from typing_extensions import TypedDict
18from langsmith import client as ls_client
19from langsmith import env as ls_env
20from langsmith import run_helpers as rh
21from langsmith import run_trees as rt
22from langsmith import schemas as ls_schemas
23from langsmith import utils as ls_utils
# pytest is optional. When it is installed, reuse its skip exception type;
# otherwise fall back to a local placeholder with the same name.
try:
    import pytest  # type: ignore
except ImportError:

    class SkipException(Exception):  # type: ignore[no-redef]
        """Stand-in skip exception used when pytest is not installed."""

else:
    SkipException = pytest.skip.Exception
35logger = logging.getLogger(__name__)
38T = TypeVar("T")
39U = TypeVar("U")
@overload
def test(
    func: Callable,
) -> Callable: ...


@overload
def test(
    *,
    id: Optional[uuid.UUID] = None,
    output_keys: Optional[Sequence[str]] = None,
    client: Optional[ls_client.Client] = None,
    test_suite_name: Optional[str] = None,
    cache: Optional[str] = None,
) -> Callable[[Callable], Callable]: ...


def test(*args: Any, **kwargs: Any) -> Callable:
    """Create a test case in LangSmith.

    This decorator is used to mark a function as a test case for LangSmith. It ensures
    that the necessary example data is created and associated with the test function.
    The decorated function will be executed as a test case, and the results will be
    recorded and reported by LangSmith.

    Args:
        - id (Optional[uuid.UUID]): A unique identifier for the test case. If not
            provided, an ID will be generated based on the test function's module
            and name.
        - output_keys (Optional[Sequence[str]]): A list of keys to be considered as
            the output keys for the test case. These keys will be extracted from the
            test function's inputs and stored as the expected outputs.
        - client (Optional[ls_client.Client]): An instance of the LangSmith client
            to be used for communication with the LangSmith service. If not provided,
            a default client will be used.
        - test_suite_name (Optional[str]): The name of the test suite to which the
            test case belongs. If not provided, the test suite name will be determined
            based on the environment or the package name.
        - cache (Optional[str]): Cache location for this test. The value is
            resolved through `ls_utils.get_cache_dir`; when the resolved value is
            falsy, no request cache is used.

    Returns:
        Callable: The decorated test function.

    Environment:
        - LANGSMITH_TEST_CACHE: If set, API calls will be cached to disk to
            save time and costs during testing. Recommended to commit the
            cache files to your repository for faster CI/CD runs.
            Requires the 'langsmith[vcr]' package to be installed.
        - LANGSMITH_TEST_TRACKING: Set this variable to 'false' to skip
            LangSmith test tracking. The decorated function is then called
            directly: no example is synced and no results are reported.
            A warning is emitted at decoration time when tracking is disabled.

    Example:
        For basic usage, simply decorate a test function with `@test`:

        >>> @test
        ... def test_addition():
        ...     assert 3 + 4 == 7


        Any code that is traced (such as those traced using `@traceable`
        or `wrap_*` functions) will be traced within the test case for
        improved visibility and debugging.

        >>> from langsmith import traceable
        >>> @traceable
        ... def generate_numbers():
        ...     return 3, 4

        >>> @test
        ... def test_nested():
        ...     # Traced code will be included in the test case
        ...     a, b = generate_numbers()
        ...     assert a + b == 7

        LLM calls are expensive! Cache requests by setting
        `LANGSMITH_TEST_CACHE=path/to/cache`. Check in these files to speed up
        CI/CD pipelines, so your results only change when your prompt or requested
        model changes.

        Note that this will require that you install langsmith with the `vcr` extra:

        `pip install -U "langsmith[vcr]"`

        Caching is faster if you install libyaml. See
        https://vcrpy.readthedocs.io/en/latest/installation.html#speed for more details.

        >>> # os.environ["LANGSMITH_TEST_CACHE"] = "tests/cassettes"
        >>> import openai
        >>> from langsmith.wrappers import wrap_openai
        >>> oai_client = wrap_openai(openai.Client())
        >>> @test
        ... def test_openai_says_hello():
        ...     # Traced code will be included in the test case
        ...     response = oai_client.chat.completions.create(
        ...         model="gpt-3.5-turbo",
        ...         messages=[
        ...             {"role": "system", "content": "You are a helpful assistant."},
        ...             {"role": "user", "content": "Say hello!"},
        ...         ],
        ...     )
        ...     assert "hello" in response.choices[0].message.content.lower()

        LLMs are stochastic. Naive assertions are flaky. You can use langsmith's
        `expect` to score and make approximate assertions on your results.

        >>> from langsmith import expect
        >>> @test
        ... def test_output_semantically_close():
        ...     response = oai_client.chat.completions.create(
        ...         model="gpt-3.5-turbo",
        ...         messages=[
        ...             {"role": "system", "content": "You are a helpful assistant."},
        ...             {"role": "user", "content": "Say hello!"},
        ...         ],
        ...     )
        ...     # The embedding_distance call logs the embedding distance to LangSmith
        ...     expect.embedding_distance(
        ...         prediction=response.choices[0].message.content,
        ...         reference="Hello!",
        ...         # The following optional assertion logs a
        ...         # pass/fail score to LangSmith
        ...         # and raises an AssertionError if the assertion fails.
        ...     ).to_be_less_than(1.0)
        ...     # Compute damerau_levenshtein distance
        ...     expect.edit_distance(
        ...         prediction=response.choices[0].message.content,
        ...         reference="Hello!",
        ...         # And then log a pass/fail score to LangSmith
        ...     ).to_be_less_than(1.0)

        The `@test` decorator works natively with pytest fixtures.
        The values will populate the "inputs" of the corresponding example in LangSmith.

        >>> import pytest
        >>> @pytest.fixture
        ... def some_input():
        ...     return "Some input"
        >>>
        >>> @test
        ... def test_with_fixture(some_input: str):
        ...     assert "input" in some_input
        >>>

        You can still use pytest.parametrize() as usual to run multiple test cases
        using the same test function.

        >>> @test(output_keys=["expected"])
        ... @pytest.mark.parametrize(
        ...     "a, b, expected",
        ...     [
        ...         (1, 2, 3),
        ...         (3, 4, 7),
        ...     ],
        ... )
        ... def test_addition_with_multiple_inputs(a: int, b: int, expected: int):
        ...     assert a + b == expected

        By default, each test case will be assigned a consistent, unique identifier
        based on the function name and module. You can also provide a custom identifier
        using the `id` argument:
        >>> @test(id="1a77e4b5-1d38-4081-b829-b0442cf3f145")
        ... def test_multiplication():
        ...     assert 3 * 4 == 12

        By default, all test inputs are saved as "inputs" to a dataset.
        You can specify the `output_keys` argument to persist those keys
        within the dataset's "outputs" fields.

        >>> @pytest.fixture
        ... def expected_output():
        ...     return "input"
        >>> @test(output_keys=["expected_output"])
        ... def test_with_expected_output(some_input: str, expected_output: str):
        ...     assert expected_output in some_input


        To run these tests, use the pytest CLI. Or directly run the test functions.
        >>> test_output_semantically_close()
        >>> test_addition()
        >>> test_nested()
        >>> test_with_fixture("Some input")
        >>> test_with_expected_output("Some input", "Some")
        >>> test_multiplication()
        >>> test_openai_says_hello()
        >>> test_addition_with_multiple_inputs(1, 2, 3)
    """
    # Collect the recognised options; whatever is left in kwargs is unexpected.
    langtest_extra = _UTExtra(
        id=kwargs.pop("id", None),
        output_keys=kwargs.pop("output_keys", None),
        client=kwargs.pop("client", None),
        test_suite_name=kwargs.pop("test_suite_name", None),
        cache=ls_utils.get_cache_dir(kwargs.pop("cache", None)),
    )
    if kwargs:
        warnings.warn(f"Unexpected keyword arguments: {kwargs.keys()}")
    # Evaluated once, at decoration time, and captured by the wrappers below.
    disable_tracking = ls_utils.test_tracking_is_disabled()
    if disable_tracking:
        warnings.warn(
            "LANGSMITH_TEST_TRACKING is set to 'false'."
            " Skipping LangSmith test tracking."
        )

    def decorator(func: Callable) -> Callable:
        # Coroutine functions need an async wrapper so they can be awaited.
        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*test_args: Any, **test_kwargs: Any):
                if disable_tracking:
                    return await func(*test_args, **test_kwargs)
                await _arun_test(
                    func, *test_args, **test_kwargs, langtest_extra=langtest_extra
                )

            return async_wrapper

        @functools.wraps(func)
        def wrapper(*test_args: Any, **test_kwargs: Any):
            if disable_tracking:
                return func(*test_args, **test_kwargs)
            _run_test(func, *test_args, **test_kwargs, langtest_extra=langtest_extra)

        return wrapper

    # Bare `@test` usage: the decorated function arrives as the sole positional arg.
    if args and callable(args[0]):
        return decorator(args[0])

    # Parametrised `@test(...)` usage: hand back the decorator itself.
    return decorator
270## Private functions
def _get_experiment_name() -> str:
    """Build a name for a new experiment.

    The name is the value of ``ls_utils.get_tracer_project(False)`` (or
    ``"TestSuiteResult"`` when that value is falsy), a colon, and the first
    eight hex characters of a random UUID.
    """
    # TODO Make more easily configurable
    project = ls_utils.get_tracer_project(False)
    if not project:
        project = "TestSuiteResult"
    random_suffix = uuid.uuid4().hex[:8]
    return f"{project}:{random_suffix}"
def _get_test_suite_name(func: Callable) -> str:
    """Determine the test suite name for ``func``.

    The value of ``ls_utils.get_env_var("TEST_SUITE")`` wins when it is set.
    Otherwise the name is ``"<repo_name>.<module name>"``, built from the
    ``"repo_name"`` entry of the git info and the module that defines ``func``.

    Args:
        func: The test function being decorated.

    Returns:
        The test suite name.

    Raises:
        ValueError: If no module can be resolved for ``func``.
    """
    configured_name = ls_utils.get_env_var("TEST_SUITE")
    if configured_name:
        return configured_name
    repo_name = ls_env.get_git_info()["repo_name"]
    try:
        module = inspect.getmodule(func)
        if module:
            return f"{repo_name}.{module.__name__}"
    except Exception:
        # Best-effort lookup only. Catch Exception rather than BaseException so
        # KeyboardInterrupt / SystemExit are never swallowed here.
        logger.debug("Could not determine test suite name from file path.")

    raise ValueError("Please set the LANGSMITH_TEST_SUITE environment variable.")
def _get_test_suite(
    client: ls_client.Client, test_suite_name: str
) -> ls_schemas.Dataset:
    """Return the dataset named ``test_suite_name``, creating it if it is missing.

    A newly created dataset is described as ``"Test suite"``, with
    ``" for <remote_url>"`` appended when the git info has a remote URL.
    """
    if not client.has_dataset(dataset_name=test_suite_name):
        remote_url = ls_env.get_git_info().get("remote_url") or ""
        suffix = f" for {remote_url}" if remote_url else ""
        return client.create_dataset(
            dataset_name=test_suite_name, description="Test suite" + suffix
        )
    return client.read_dataset(dataset_name=test_suite_name)
def _start_experiment(
    client: ls_client.Client,
    test_suite: ls_schemas.Dataset,
) -> ls_schemas.TracerSession:
    """Create the project that records this run's results for ``test_suite``.

    The project references the dataset by id and carries the ``revision_id``
    from the environment metadata. If creation reports a conflict, the
    existing project with the same name is read and returned instead.
    """
    name = _get_experiment_name()
    try:
        env_metadata = ls_env.get_langchain_env_var_metadata()
        project_metadata = {"revision_id": env_metadata.get("revision_id")}
        return client.create_project(
            name,
            reference_dataset_id=test_suite.id,
            description="Test Suite Results.",
            metadata=project_metadata,
        )
    except ls_utils.LangSmithConflictError:
        return client.read_project(project_name=name)
330# Track the number of times a parameter has been used in a test
331# This is to ensure that we can uniquely identify each test case
332# defined using pytest.mark.parametrize
333_param_dict: dict = defaultdict(lambda: defaultdict(int))
336def _get_id(func: Callable, inputs: dict, suite_id: uuid.UUID) -> Tuple[uuid.UUID, str]:
337 global _param_dict
338 try:
339 file_path = str(Path(inspect.getfile(func)).relative_to(Path.cwd()))
340 except ValueError:
341 # Fall back to module name if file path is not available
342 file_path = func.__module__
343 identifier = f"{suite_id}{file_path}::{func.__name__}"
344 input_keys = tuple(sorted(inputs.keys()))
345 arg_indices = []
346 for key in input_keys:
347 _param_dict[identifier][key] += 1
348 arg_indices.append(f"{key}{_param_dict[identifier][key]}")
349 if arg_indices:
350 identifier += f"[{'-'.join(arg_indices)}]"
351 return uuid.uuid5(uuid.NAMESPACE_DNS, identifier), identifier[len(str(suite_id)) :]
def _end_tests(
    test_suite: _LangSmithTestSuite,
):
    """Finalize the experiment of ``test_suite``.

    Waits for the suite's background work, then updates the experiment
    project with an end time and metadata: the git info, the suite's dataset
    version, and the ``revision_id`` from the environment metadata.

    Args:
        test_suite: The suite whose experiment should be closed.
    """
    # Drain the background queue first. Queued example syncs are what advance
    # the suite's version, so reading it earlier could record a stale
    # "dataset_version". This also guarantees the executor is shut down even
    # if the project update below fails.
    test_suite.wait()
    git_info = ls_env.get_git_info() or {}
    test_suite.client.update_project(
        test_suite.experiment_id,
        end_time=datetime.datetime.now(datetime.timezone.utc),
        metadata={
            **git_info,
            "dataset_version": test_suite.get_version(),
            "revision_id": ls_env.get_langchain_env_var_metadata().get("revision_id"),
        },
    )
VT = TypeVar("VT", bound=Optional[dict])


def _serde_example_values(values: VT) -> VT:
    """Round-trip ``values`` through JSON; ``None`` is returned unchanged.

    The value is dumped with ``ls_client._dumps_json`` and parsed back with
    ``orjson.loads``.
    """
    if values is not None:
        return orjson.loads(ls_client._dumps_json(values))
    return values
class _LangSmithTestSuite:
    """A test suite: a dataset plus the experiment recording the current run.

    Instances are created through ``from_test``, which keeps one instance per
    suite name. Feedback submission and example syncing run on a single-worker
    executor, so they execute in the background, one at a time, in submission
    order. Each instance registers ``_end_tests`` with ``atexit``.
    """

    # Suite name -> instance; populated lazily by ``from_test``.
    _instances: Optional[dict] = None
    # Class-level lock guarding ``_instances`` and every instance's ``_version``.
    _lock = threading.RLock()

    def __init__(
        self,
        client: Optional[ls_client.Client],
        experiment: ls_schemas.TracerSession,
        dataset: ls_schemas.Dataset,
    ):
        """Store the suite's collaborators and start its background executor.

        Args:
            client: Client used for API calls; the cached client is used when
                this is falsy.
            experiment: The project that records this run's results.
            dataset: The dataset backing the suite.
        """
        self.client = client or rt.get_cached_client()
        self._experiment = experiment
        self._dataset = dataset
        self._version: Optional[datetime.datetime] = None
        self._executor = ls_utils.ContextThreadPoolExecutor(max_workers=1)
        atexit.register(_end_tests, self)

    @property
    def id(self):
        """The id of the suite's dataset."""
        return self._dataset.id

    @property
    def experiment_id(self):
        """The id of the suite's experiment."""
        return self._experiment.id

    @property
    def experiment(self):
        """The suite's experiment."""
        return self._experiment

    @classmethod
    def from_test(
        cls,
        client: Optional[ls_client.Client],
        func: Callable,
        test_suite_name: Optional[str] = None,
    ) -> _LangSmithTestSuite:
        """Return the suite for ``func``, creating it on first use.

        Args:
            client: Client used for API calls; the cached client is used when
                this is falsy.
            func: The test function, used to derive the suite name when
                ``test_suite_name`` is not given.
            test_suite_name: Explicit suite name.

        Returns:
            The instance registered under the suite name.
        """
        client = client or rt.get_cached_client()
        test_suite_name = test_suite_name or _get_test_suite_name(func)
        with cls._lock:
            if not cls._instances:
                cls._instances = {}
            if test_suite_name not in cls._instances:
                test_suite = _get_test_suite(client, test_suite_name)
                experiment = _start_experiment(client, test_suite)
                cls._instances[test_suite_name] = cls(client, experiment, test_suite)
        return cls._instances[test_suite_name]

    @property
    def name(self):
        """The name of the suite's experiment."""
        return self._experiment.name

    def update_version(self, version: datetime.datetime) -> None:
        """Record ``version`` if it is newer than the one currently stored."""
        with self._lock:
            if self._version is None or version > self._version:
                self._version = version

    def get_version(self) -> Optional[datetime.datetime]:
        """Return the newest version recorded so far, or ``None``."""
        with self._lock:
            return self._version

    @staticmethod
    def _log_task_failure(description: str, future: Any) -> None:
        """Log the exception, if any, raised by a finished background task."""
        if future.cancelled():
            return
        error = future.exception()
        if error is not None:
            logger.warning("Failed to %s: %r", description, error)

    def _submit_background(
        self, description: str, fn: Callable, *args: Any, **kwargs: Any
    ) -> None:
        """Queue ``fn`` on the executor and log its failure instead of dropping it.

        Without the callback an exception raised by ``fn`` would stay inside
        the discarded future and never be reported.
        """
        future = self._executor.submit(fn, *args, **kwargs)
        future.add_done_callback(
            functools.partial(self._log_task_failure, description)
        )

    def submit_result(
        self, run_id: uuid.UUID, error: Optional[str] = None, skipped: bool = False
    ) -> None:
        """Queue the "pass" feedback for ``run_id`` on the background executor.

        Args:
            run_id: The run the feedback belongs to.
            error: Description of the failure, or a falsy value on success.
            skipped: Whether ``error`` describes a skip rather than a failure.
        """
        self._submit_background(
            f"create feedback for run_id {run_id}",
            self._submit_result,
            run_id,
            error,
            skipped=skipped,
        )

    def _submit_result(
        self, run_id: uuid.UUID, error: Optional[str] = None, skipped: bool = False
    ) -> None:
        """Create the "pass" feedback: 1 on success, 0 on error, no score on skip."""
        if not error:
            self.client.create_feedback(
                run_id,
                key="pass",
                score=1,
            )
        elif skipped:
            self.client.create_feedback(
                run_id,
                key="pass",
                # Don't factor into aggregate score
                score=None,
                comment=f"Skipped: {repr(error)}",
            )
        else:
            self.client.create_feedback(
                run_id, key="pass", score=0, comment=f"Error: {repr(error)}"
            )

    def sync_example(
        self, example_id: uuid.UUID, inputs: dict, outputs: dict, metadata: dict
    ) -> None:
        """Queue an example sync on the background executor.

        ``metadata`` is copied before it is handed to the background task.
        """
        self._submit_background(
            f"sync example {example_id}",
            self._sync_example,
            example_id,
            inputs,
            outputs,
            metadata.copy(),
        )

    def _sync_example(
        self, example_id: uuid.UUID, inputs: dict, outputs: dict, metadata: dict
    ) -> None:
        """Make the stored example match the given inputs and outputs.

        An existing example is updated when its inputs, outputs, or dataset
        differ; a missing example is created. The example's ``modified_at``,
        when set, is then recorded as a candidate suite version.
        """
        # Compare JSON round-tripped values.
        inputs_ = _serde_example_values(inputs)
        outputs_ = _serde_example_values(outputs)
        try:
            example = self.client.read_example(example_id=example_id)
            if (
                inputs_ != example.inputs
                or outputs_ != example.outputs
                or str(example.dataset_id) != str(self.id)
            ):
                self.client.update_example(
                    example_id=example.id,
                    inputs=inputs_,
                    outputs=outputs_,
                    metadata=metadata,
                    dataset_id=self.id,
                )
        except ls_utils.LangSmithNotFoundError:
            example = self.client.create_example(
                example_id=example_id,
                inputs=inputs_,
                outputs=outputs_,
                dataset_id=self.id,
                metadata=metadata,
                created_at=self._experiment.start_time,
            )
        if example.modified_at:
            self.update_version(example.modified_at)

    def wait(self):
        """Shut the executor down, blocking until all queued work has finished."""
        self._executor.shutdown(wait=True)
class _UTExtra(TypedDict, total=False):
    """Options gathered by the ``test`` decorator and handed to the test runners."""

    # Explicit example id; when falsy, an id is derived from the function.
    id: Optional[uuid.UUID]
    # Client for API calls; when falsy, the cached client is used.
    client: Optional[ls_client.Client]
    # Explicit test suite name.
    test_suite_name: Optional[str]
    # Input keys that are moved into the example's outputs.
    output_keys: Optional[Sequence[str]]
    # Cache directory; when falsy, no request cache file is used.
    cache: Optional[str]
518def _get_test_repr(func: Callable, sig: inspect.Signature) -> str:
519 name = getattr(func, "__name__", None) or ""
520 description = getattr(func, "__doc__", None) or ""
521 if description:
522 description = f" - {description.strip()}"
523 return f"{name}{sig}{description}"
def _ensure_example(
    func: Callable, *args: Any, langtest_extra: _UTExtra, **kwargs: Any
) -> Tuple[_LangSmithTestSuite, uuid.UUID]:
    """Resolve the suite and example for one invocation of ``func``.

    The call arguments are bound to the signature of ``func`` to form the
    example inputs. Keys listed in ``output_keys`` are moved from the inputs
    to the outputs (``None`` when absent). The example is then queued for
    syncing on the suite.

    Returns:
        The test suite and the example id (the explicit ``id`` from
        ``langtest_extra`` when truthy, otherwise the derived one).
    """
    client = langtest_extra["client"] or rt.get_cached_client()
    output_keys = langtest_extra["output_keys"]
    sig = inspect.signature(func)
    inputs: dict = rh._get_inputs_safe(sig, *args, **kwargs)
    # Expected outputs are taken out of the inputs so they are not stored twice.
    outputs = {key: inputs.pop(key, None) for key in output_keys or ()}
    suite = _LangSmithTestSuite.from_test(
        client, func, langtest_extra.get("test_suite_name")
    )
    derived_id, example_name = _get_id(func, inputs, suite.id)
    example_id = langtest_extra["id"] or derived_id
    example_metadata = {
        "signature": _get_test_repr(func, sig),
        "name": example_name,
    }
    suite.sync_example(example_id, inputs, outputs, metadata=example_metadata)
    return suite, example_id
def _format_run_outputs(result: Any) -> Optional[dict]:
    """Shape a test's return value for ``run_tree.end``.

    ``None`` and dicts pass through; anything else is wrapped under "output".
    """
    if result is None or isinstance(result, dict):
        return result
    return {"output": result}


def _get_trace_kwargs(
    func: Callable,
    test_args: tuple,
    test_kwargs: dict,
    test_suite: _LangSmithTestSuite,
    run_id: uuid.UUID,
    example_id: uuid.UUID,
) -> dict:
    """Build the ``rh.trace`` keyword arguments shared by both test runners."""
    return {
        "name": getattr(func, "__name__", "Test"),
        "run_id": run_id,
        "reference_example_id": example_id,
        "inputs": rh._get_inputs_safe(
            inspect.signature(func), *test_args, **test_kwargs
        ),
        "project_name": test_suite.name,
        "exceptions_to_handle": (SkipException,),
    }


def _get_test_context(
    test_suite: _LangSmithTestSuite,
    example_id: uuid.UUID,
    langtest_extra: _UTExtra,
) -> Tuple[dict, Optional[Path]]:
    """Build the tracing-context kwargs and cache path for one test run.

    Returns:
        A tuple of the current tracing context with the experiment name and
        reference example id merged into its metadata, and the path of the
        suite's cache file (``None`` when no cache directory is configured).
    """
    cache_dir = langtest_extra["cache"]
    cache_path = Path(cache_dir) / f"{test_suite.id}.yaml" if cache_dir else None
    current_context = rh.get_tracing_context()
    metadata = {
        **(current_context["metadata"] or {}),
        "experiment": test_suite.experiment.name,
        "reference_example_id": str(example_id),
    }
    return {**current_context, "metadata": metadata}, cache_path


def _submit_pass(test_suite: _LangSmithTestSuite, run_id: uuid.UUID) -> None:
    """Report a passing result; a failure to do so is logged, not raised."""
    try:
        test_suite.submit_result(run_id, error=None)
    except Exception as e:
        logger.warning(f"Failed to create feedback for run_id {run_id}: {e}")


def _run_test(
    func: Callable, *test_args: Any, langtest_extra: _UTExtra, **test_kwargs: Any
) -> None:
    """Run the synchronous test ``func`` as a traced LangSmith test case.

    The example is synced first. ``func`` is then called inside a trace that
    references that example, within the suite's tracing context and optional
    request cache. A pass, skip, or error result is submitted to the suite.

    Args:
        func: The test function.
        *test_args: Positional arguments for ``func``.
        langtest_extra: Options collected by the ``test`` decorator.
        **test_kwargs: Keyword arguments for ``func``.

    Raises:
        BaseException: Whatever ``func`` raises (including skips) is re-raised
            after the result has been submitted.
    """
    test_suite, example_id = _ensure_example(
        func, *test_args, **test_kwargs, langtest_extra=langtest_extra
    )
    run_id = uuid.uuid4()
    tracing_kwargs, cache_path = _get_test_context(
        test_suite, example_id, langtest_extra
    )
    with rh.tracing_context(**tracing_kwargs), ls_utils.with_optional_cache(
        cache_path, ignore_hosts=[test_suite.client.api_url]
    ):
        with rh.trace(
            **_get_trace_kwargs(
                func, test_args, test_kwargs, test_suite, run_id, example_id
            )
        ) as run_tree:
            try:
                result = func(*test_args, **test_kwargs)
                run_tree.end(outputs=_format_run_outputs(result))
            except SkipException as e:
                test_suite.submit_result(run_id, error=repr(e), skipped=True)
                run_tree.end(
                    outputs={"skipped_reason": repr(e)},
                )
                raise
            except BaseException as e:
                # Record the failure, then let the original exception propagate.
                test_suite.submit_result(run_id, error=repr(e))
                raise
            _submit_pass(test_suite, run_id)


async def _arun_test(
    func: Callable, *test_args: Any, langtest_extra: _UTExtra, **test_kwargs: Any
) -> None:
    """Run the coroutine test ``func`` as a traced LangSmith test case.

    Async counterpart of ``_run_test``: identical except that ``func`` is
    awaited.

    Args:
        func: The coroutine test function.
        *test_args: Positional arguments for ``func``.
        langtest_extra: Options collected by the ``test`` decorator.
        **test_kwargs: Keyword arguments for ``func``.

    Raises:
        BaseException: Whatever ``func`` raises (including skips) is re-raised
            after the result has been submitted.
    """
    test_suite, example_id = _ensure_example(
        func, *test_args, **test_kwargs, langtest_extra=langtest_extra
    )
    run_id = uuid.uuid4()
    tracing_kwargs, cache_path = _get_test_context(
        test_suite, example_id, langtest_extra
    )
    with rh.tracing_context(**tracing_kwargs), ls_utils.with_optional_cache(
        cache_path, ignore_hosts=[test_suite.client.api_url]
    ):
        with rh.trace(
            **_get_trace_kwargs(
                func, test_args, test_kwargs, test_suite, run_id, example_id
            )
        ) as run_tree:
            try:
                result = await func(*test_args, **test_kwargs)
                run_tree.end(outputs=_format_run_outputs(result))
            except SkipException as e:
                test_suite.submit_result(run_id, error=repr(e), skipped=True)
                run_tree.end(
                    outputs={"skipped_reason": repr(e)},
                )
                raise
            except BaseException as e:
                # Record the failure, then let the original exception propagate.
                test_suite.submit_result(run_id, error=repr(e))
                raise
            _submit_pass(test_suite, run_id)
679# For backwards compatibility
680unit = test