Coverage for langsmith/pytest_plugin.py: 4%

181 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-12-11 16:15 -0800

1"""LangSmith Pytest hooks.""" 

2 

3import importlib.util 

4import json 

5import logging 

6import os 

7import time 

8from collections import defaultdict 

9from threading import Lock 

10from typing import Any 

11 

12import pytest 

13 

14from langsmith import utils as ls_utils 

15from langsmith.testing._internal import test as ls_test 

16 

# Module-level logger; used to warn when ``--langsmith-output`` cannot be
# registered because the option already exists.
logger = logging.getLogger(name=__name__)

18 

19 

def pytest_addoption(parser):
    """Register the boolean ``--langsmith-output`` command-line flag.

    The flag is placed in a "langsmith" option group. If the option has already
    been registered (pytest raises ``ValueError`` for duplicate option names),
    a warning is logged and the existing definition is kept.
    """
    try:
        langsmith_group = parser.getgroup("langsmith", "LangSmith")
        # addoption raises ValueError when the option name is already taken.
        langsmith_group.addoption(
            "--langsmith-output",
            help="Use LangSmith output (requires 'rich').",
            default=False,
            action="store_true",
        )
    except ValueError:
        logger.warning(
            "LangSmith output flag cannot be added because it's already defined."
        )

39 

40 

41def _handle_output_args(args): 

42 """Handle output arguments.""" 

43 if any(opt in args for opt in ["--langsmith-output"]): 

44 # Only add --quiet if it's not already there 

45 if not any(a in args for a in ["-qq"]): 

46 args.insert(0, "-qq") 

47 # Disable built-in output capturing 

48 if not any(a in args for a in ["-s", "--capture=no"]): 

49 args.insert(0, "-s") 

50 

51 

# Choose the hook used to rewrite the raw command-line ``args`` list based on
# the installed pytest version string: only versions starting with "7." define
# ``pytest_cmdline_preparse``; every other version (8+, but also anything older
# than 7) defines ``pytest_load_initial_conftests`` instead.
# NOTE(review): presumably exactly one hook is defined per version because
# ``pytest_cmdline_preparse`` no longer exists in pytest 8 — confirm.
if pytest.__version__.startswith("7."):

    def pytest_cmdline_preparse(config, args):
        """Rewrite ``args`` before pytest parses command-line options (pytest v7).

        ``args`` is modified in place by ``_handle_output_args``.
        """
        _handle_output_args(args)

else:

    def pytest_load_initial_conftests(args):
        """Handle args in pytest v8+.

        Runs ahead of command-line option parsing; ``args`` is modified in
        place by ``_handle_output_args``.
        """
        _handle_output_args(args)

63 

64 

@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(item):
    """Apply LangSmith tracking to tests marked with @pytest.mark.langsmith.

    Runs as a hook wrapper around the test call. For a marked test, ``item.obj``
    is replaced with the function wrapped by LangSmith's ``test`` decorator,
    forwarding the marker's keyword arguments verbatim (e.g.
    ``@pytest.mark.langsmith(output_keys=["expected"])``).

    If the item carries a ``_request`` object, it is also registered as the
    ``request`` funcarg and appended to the item's fixture argnames,
    presumably so the wrapped function receives it — confirm against
    ``langsmith.testing._internal.test``.
    """
    marker = item.get_closest_marker("langsmith")
    if marker is not None:
        item.obj = ls_test(**marker.kwargs)(item.obj)
        request_obj = getattr(item, "_request", None)
        if request_obj is not None:
            if "request" not in item.funcargs:
                item.funcargs["request"] = request_obj
            fixtureinfo = item._fixtureinfo
            if "request" not in fixtureinfo.argnames:
                # The fixture-info object is rebuilt (not mutated) with the
                # extra "request" argname; all other fields are carried over.
                item._fixtureinfo = type(fixtureinfo)(
                    argnames=fixtureinfo.argnames + ("request",),
                    initialnames=fixtureinfo.initialnames,
                    names_closure=fixtureinfo.names_closure,
                    name2fixturedefs=fixtureinfo.name2fixturedefs,
                )
    # A hookwrapper must yield exactly once, whether or not the test is marked.
    yield

88 

89 

@pytest.hookimpl
def pytest_report_teststatus(report, config):
    """Hide pytest's per-test status characters ("./F") under LangSmith output.

    pytest normally expects a ``(short_letter, verbose_word, color)`` tuple;
    three empty strings suppress the progress markers. Without the
    ``--langsmith-output`` flag this returns ``None``, leaving the decision to
    other implementations of the hook.
    """
    langsmith_output_enabled = config.getoption("--langsmith-output")
    return ("", "", "") if langsmith_output_enabled else None

97 

98 

class LangSmithPlugin:
    """Plugin for rendering LangSmith results.

    Renders one Rich table per test suite inside a ``rich.live.Live`` display
    that is started on construction and stopped at session end. Processes are
    grouped into suites via ``add_process_to_test_suite`` and their state is
    fed in through ``update_process_status``.
    """

    # Rich colour per known status value; any other value renders white.
    _STATUS_COLORS = {
        "running": "yellow",
        "passed": "green",
        "failed": "red",
        "skipped": "cyan",
    }

    def __init__(self):
        """Initialize state and start the live display (requires ``rich``)."""
        from rich.console import Console  # type: ignore[import-not-found]
        from rich.live import Live  # type: ignore[import-not-found]

        # Suite name -> process ids registered for that suite, in order.
        self.test_suites = defaultdict(list)
        # Suite name -> LangSmith URL. Not written anywhere in this class;
        # presumably populated by LangSmith's testing internals — confirm.
        self.test_suite_urls = {}

        self.process_status = {}  # Process id -> merged status dict.
        self.status_lock = Lock()  # Guards process_status and table redraws.
        self.console = Console()

        self.live = Live(
            self.generate_tables(), console=self.console, refresh_per_second=10
        )
        self.live.start()
        self.live.console.print("Collecting tests...")

    def pytest_collection_finish(self, session):
        """Record the node ids of every collected test.

        Called after the collection phase, once ``session.items`` is populated.
        """
        self.collected_nodeids = {item.nodeid for item in session.items}

    def add_process_to_test_suite(self, test_suite, process_id):
        """Group a test case with its test suite."""
        self.test_suites[test_suite].append(process_id)

    def update_process_status(self, process_id, status):
        """Merge ``status`` into the stored status of ``process_id`` and redraw.

        For the ``feedback``, ``inputs``, ``reference_outputs`` and ``outputs``
        entries, old and new values are merged key-by-key when both are dicts
        (see ``_merge_statuses``); other keys are overwritten.
        """
        with self.status_lock:
            # Announce the run phase on the very first update. The check is
            # done under the lock that guards process_status so concurrent
            # first updates cannot both print the message.
            if not self.process_status:
                self.live.console.print("Running tests...")
            current_status = self.process_status.get(process_id, {})
            self.process_status[process_id] = _merge_statuses(
                status,
                current_status,
                unpack=["feedback", "inputs", "reference_outputs", "outputs"],
            )
            self.live.update(self.generate_tables())

    def pytest_runtest_logstart(self, nodeid):
        """Mark a test as running when it starts."""
        self.update_process_status(nodeid, {"status": "running"})

    def generate_tables(self):
        """Generate a collection of tables—one per suite.

        Returns a Rich ``Group`` so all tables render together in the Live view.
        """
        from rich.console import Group

        return Group(*(self._generate_table(name) for name in self.test_suites))

    def _generate_table(self, suite_name: str):
        """Build the results table for a single suite.

        Emits one row per registered process, then a blank separator row and
        an "Averages" footer with the pass rate, the mean of each numeric
        feedback key, and the mean duration.
        """
        from rich.table import Table  # type: ignore[import-not-found]

        process_ids = self.test_suites[suite_name]
        # A suite may be drawn before its URL is recorded; render an empty URL
        # rather than raising KeyError in the middle of the live display.
        suite_url = self.test_suite_urls.get(suite_name, "")
        title = (
            f"Test Suite: [bold]{suite_name}[/bold]\n"
            f"LangSmith URL: [bright_cyan]{suite_url}[/bright_cyan]"
        )
        table = Table(title=title, title_justify="left")
        for column in (
            "Test",
            "Inputs",
            "Ref outputs",
            "Outputs",
            "Status",
            "Feedback",
            "Duration",
        ):
            table.add_column(column)

        now = time.time()
        # Processes registered for the suite but without any status yet are
        # shown as "queued" (zero duration) instead of raising KeyError.
        suite_statuses = {
            pid: self.process_status.get(pid, {}) for pid in process_ids
        }

        # First pass: durations, numeric feedback, and Status/Duration widths.
        max_status = len("status")
        max_duration = len("duration")
        durations = {}
        numeric_feedbacks = defaultdict(list)
        for pid, status in suite_statuses.items():
            # Unfinished (or unstarted) processes are measured up to `now`.
            duration = status.get("end_time", now) - status.get("start_time", now)
            durations[pid] = duration
            for key, value in status.get("feedback", {}).items():
                # bools are ints in Python, so they are averaged as 0/1.
                if isinstance(value, (float, int, bool)):
                    numeric_feedbacks[key].append(value)
            max_duration = max(len(f"{duration:.2f}s"), max_duration)
            max_status = max(len(status.get("status", "queued")), max_status)

        passed_count = sum(s.get("status") == "passed" for s in suite_statuses.values())
        failed_count = sum(s.get("status") == "failed" for s in suite_statuses.values())
        finished_count = passed_count + failed_count

        if finished_count:
            rate = passed_count / finished_count
            color = "green" if rate == 1 else "red"
            aggregate_status = f"[{color}]{rate:.0%}[/{color}]"
        else:
            aggregate_status = "Passed: --"
        if durations:
            aggregate_duration = f"{sum(durations.values()) / len(durations):.2f}s"
        else:
            aggregate_duration = "--s"
        if numeric_feedbacks:
            aggregate_feedback = "\n".join(
                f"{k}: {sum(v) / len(v)}" for k, v in numeric_feedbacks.items()
            )
        else:
            aggregate_feedback = "--"

        # Split the console width left after Status/Duration across the other
        # five columns, never going below 8 characters per column.
        max_duration = max(max_duration, len(aggregate_duration))
        max_dynamic_col_width = max(
            (self.console.width - (max_status + max_duration)) // 5, 8
        )

        for pid, status in suite_statuses.items():
            state = status.get("status", "queued")
            status_color = self._STATUS_COLORS.get(state, "white")
            feedback = "\n".join(
                f"{_abbreviate(k, max_len=max_dynamic_col_width)}: "
                f"{int(v) if isinstance(v, bool) else v}"
                for k, v in status.get("feedback", {}).items()
            )
            inputs = _dumps_with_fallback(status.get("inputs", {}))
            reference_outputs = _dumps_with_fallback(
                status.get("reference_outputs", {})
            )
            outputs = _dumps_with_fallback(status.get("outputs", {}))
            table.add_row(
                _abbreviate_test_name(str(pid), max_len=max_dynamic_col_width),
                _abbreviate(inputs, max_len=max_dynamic_col_width),
                _abbreviate(reference_outputs, max_len=max_dynamic_col_width),
                _abbreviate(outputs, max_len=max_dynamic_col_width),
                f"[{status_color}]{state}[/{status_color}]",
                feedback,
                f"{durations[pid]:.2f}s",
            )

        # Blank separator row, then the aggregate "footer" row.
        table.add_row("", "", "", "", "", "", "")
        table.add_row(
            "[bold]Averages[/bold]",
            "",
            "",
            "",
            aggregate_status,
            aggregate_feedback,
            aggregate_duration,
        )
        return table

    def pytest_configure(self, config):
        """Disable warning reporting and show no warnings in output."""
        config.option.showwarnings = False
        # Replace the warnings plugin's summary with a no-op, if it is loaded.
        reporter = config.pluginmanager.get_plugin("warnings-plugin")
        if reporter:
            reporter.warning_summary = lambda *args, **kwargs: None

    def pytest_sessionfinish(self, session):
        """Stop Rich Live rendering at the end of the session."""
        self.live.stop()
        self.live.console.print("\nFinishing up...")

281 

282 

def pytest_configure(config):
    """Register the 'langsmith' marker and, if requested, the output plugin.

    When ``--langsmith-output`` is set, the preconditions for the Rich display
    are validated, a ``LangSmithPlugin`` is registered as
    ``"langsmith_output_plugin"``, and the warnings summary is suppressed.

    Raises:
        ValueError: if ``--langsmith-output`` is set and ``rich`` is not
            installed, ``PYTEST_XDIST_TESTRUNUID`` is set (pytest-xdist run),
            or LangSmith test tracking is disabled.
    """
    config.addinivalue_line(
        "markers", "langsmith: mark test to be tracked in LangSmith"
    )
    if not config.getoption("--langsmith-output"):
        return

    if not importlib.util.find_spec("rich"):
        raise ValueError(
            "Must have 'rich' installed to use --langsmith-output. "
            "Please install with: `pip install -U 'langsmith[pytest]'`"
        )
    if os.environ.get("PYTEST_XDIST_TESTRUNUID"):
        raise ValueError(
            "--langsmith-output not supported with pytest-xdist. "
            "Please remove the '--langsmith-output' option or '-n' option."
        )
    if ls_utils.test_tracking_is_disabled():
        # Each fragment ends with a space so the concatenated message reads
        # correctly (previously "env varLANGSMITH_..." and "the'--...").
        raise ValueError(
            "--langsmith-output not supported when env var "
            "LANGSMITH_TEST_TRACKING='false'. Please remove the "
            "'--langsmith-output' option "
            "or enable test tracking."
        )
    config.pluginmanager.register(LangSmithPlugin(), "langsmith_output_plugin")
    # Suppress warnings summary
    config.option.showwarnings = False

312 

313 

314def _abbreviate(x: str, max_len: int) -> str: 

315 if len(x) > max_len: 

316 return x[: max_len - 3] + "..." 

317 else: 

318 return x 

319 

320 

321def _abbreviate_test_name(test_name: str, max_len: int) -> str: 

322 if len(test_name) > max_len: 

323 file, test = test_name.split("::") 

324 if len(".py::" + test) > max_len: 

325 return "..." + test[-(max_len - 3) :] 

326 file_len = max_len - len("...::" + test) 

327 return "..." + file[-file_len:] + "::" + test 

328 else: 

329 return test_name 

330 

331 

332def _merge_statuses(update: dict, current: dict, *, unpack: list[str]) -> dict: 

333 for path in unpack: 

334 if path_update := update.pop(path, None): 

335 path_current = current.get(path, {}) 

336 if isinstance(path_update, dict) and isinstance(path_current, dict): 

337 current[path] = {**path_current, **path_update} 

338 else: 

339 current[path] = path_update 

340 return {**current, **update} 

341 

342 

343def _dumps_with_fallback(obj: Any) -> str: 

344 try: 

345 return json.dumps(obj) 

346 except Exception: 

347 return "unserializable"