Coverage for tests/models/test_instrumented.py: 98.96%

94 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1from __future__ import annotations 

2 

3from collections.abc import AsyncIterator 

4from contextlib import asynccontextmanager 

5from datetime import datetime 

6 

7import pytest 

8from dirty_equals import IsJson 

9from inline_snapshot import snapshot 

10from logfire_api import DEFAULT_LOGFIRE_INSTANCE 

11from opentelemetry._events import NoOpEventLoggerProvider 

12from opentelemetry.trace import NoOpTracerProvider 

13 

14from pydantic_ai.messages import ( 

15 ModelMessage, 

16 ModelRequest, 

17 ModelResponse, 

18 ModelResponseStreamEvent, 

19 PartDeltaEvent, 

20 PartStartEvent, 

21 RetryPromptPart, 

22 SystemPromptPart, 

23 TextPart, 

24 TextPartDelta, 

25 ToolCallPart, 

26 ToolReturnPart, 

27 UserPromptPart, 

28) 

29from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse 

30from pydantic_ai.models.instrumented import InstrumentationSettings, InstrumentedModel 

31from pydantic_ai.settings import ModelSettings 

32from pydantic_ai.usage import Usage 

33 

34from ..conftest import try_import 

35 

36with try_import() as imports_successful: 

37 from logfire.testing import CaptureLogfire 

38 

39pytestmark = [ 

40 pytest.mark.skipif(not imports_successful(), reason='logfire not installed'), 

41 pytest.mark.anyio, 

42] 

43 

44requires_logfire_events = pytest.mark.skipif( 

45 not hasattr(DEFAULT_LOGFIRE_INSTANCE.config, 'get_event_logger_provider'), 

46 reason='old logfire without events/logs support', 

47) 

48 

49 

50class MyModel(Model): 

51 @property 

52 def system(self) -> str: 

53 return 'my_system' 

54 

55 @property 

56 def model_name(self) -> str: 

57 return 'my_model' 

58 

59 @property 

60 def base_url(self) -> str: 

61 return 'https://example.com:8000/foo' 

62 

63 async def request( 

64 self, 

65 messages: list[ModelMessage], 

66 model_settings: ModelSettings | None, 

67 model_request_parameters: ModelRequestParameters, 

68 ) -> tuple[ModelResponse, Usage]: 

69 return ( 

70 ModelResponse( 

71 parts=[ 

72 TextPart('text1'), 

73 ToolCallPart('tool1', 'args1', 'tool_call_1'), 

74 ToolCallPart('tool2', {'args2': 3}, 'tool_call_2'), 

75 TextPart('text2'), 

76 {}, # test unexpected parts # type: ignore 

77 ], 

78 model_name='my_model_123', 

79 ), 

80 Usage(request_tokens=100, response_tokens=200), 

81 ) 

82 

83 @asynccontextmanager 

84 async def request_stream( 

85 self, 

86 messages: list[ModelMessage], 

87 model_settings: ModelSettings | None, 

88 model_request_parameters: ModelRequestParameters, 

89 ) -> AsyncIterator[StreamedResponse]: 

90 yield MyResponseStream() 

91 

92 

93class MyResponseStream(StreamedResponse): 

94 async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: 

95 self._usage = Usage(request_tokens=300, response_tokens=400) 

96 yield self._parts_manager.handle_text_delta(vendor_part_id=0, content='text1') 

97 yield self._parts_manager.handle_text_delta(vendor_part_id=0, content='text2') 

98 

99 @property 

100 def model_name(self) -> str: 

101 return 'my_model_123' 

102 

103 @property 

104 def timestamp(self) -> datetime: 

105 return datetime(2022, 1, 1) 

106 

107 

108@requires_logfire_events 

109async def test_instrumented_model(capfire: CaptureLogfire): 

110 model = InstrumentedModel(MyModel(), InstrumentationSettings(event_mode='logs')) 

111 assert model.system == 'my_system' 

112 assert model.model_name == 'my_model' 

113 

114 messages = [ 

115 ModelRequest( 

116 parts=[ 

117 SystemPromptPart('system_prompt'), 

118 UserPromptPart('user_prompt'), 

119 ToolReturnPart('tool3', 'tool_return_content', 'tool_call_3'), 

120 RetryPromptPart('retry_prompt1', tool_name='tool4', tool_call_id='tool_call_4'), 

121 RetryPromptPart('retry_prompt2'), 

122 {}, # test unexpected parts # type: ignore 

123 ] 

124 ), 

125 ModelResponse( 

126 parts=[ 

127 TextPart('text3'), 

128 ] 

129 ), 

130 ] 

131 await model.request( 

132 messages, 

133 model_settings=ModelSettings(temperature=1), 

134 model_request_parameters=ModelRequestParameters( 

135 function_tools=[], 

136 allow_text_result=True, 

137 result_tools=[], 

138 ), 

139 ) 

140 

141 assert capfire.exporter.exported_spans_as_dict() == snapshot( 

142 [ 

143 { 

144 'name': 'chat my_model', 

145 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, 

146 'parent': None, 

147 'start_time': 1000000000, 

148 'end_time': 18000000000, 

149 'attributes': { 

150 'gen_ai.operation.name': 'chat', 

151 'gen_ai.system': 'my_system', 

152 'gen_ai.request.model': 'my_model', 

153 'server.address': 'example.com', 

154 'server.port': 8000, 

155 'model_request_parameters': '{"function_tools": [], "allow_text_result": true, "result_tools": []}', 

156 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 

157 'gen_ai.request.temperature': 1, 

158 'logfire.msg': 'chat my_model', 

159 'logfire.span_type': 'span', 

160 'gen_ai.response.model': 'my_model_123', 

161 'gen_ai.usage.input_tokens': 100, 

162 'gen_ai.usage.output_tokens': 200, 

163 }, 

164 }, 

165 ] 

166 ) 

167 

168 assert capfire.log_exporter.exported_logs_as_dicts() == snapshot( 

169 [ 

170 { 

171 'body': {'content': 'system_prompt', 'role': 'system'}, 

172 'severity_number': 9, 

173 'severity_text': None, 

174 'attributes': { 

175 'gen_ai.system': 'my_system', 

176 'gen_ai.message.index': 0, 

177 'event.name': 'gen_ai.system.message', 

178 }, 

179 'timestamp': 2000000000, 

180 'observed_timestamp': 3000000000, 

181 'trace_id': 1, 

182 'span_id': 1, 

183 'trace_flags': 1, 

184 }, 

185 { 

186 'body': {'content': 'user_prompt', 'role': 'user'}, 

187 'severity_number': 9, 

188 'severity_text': None, 

189 'attributes': { 

190 'gen_ai.system': 'my_system', 

191 'gen_ai.message.index': 0, 

192 'event.name': 'gen_ai.user.message', 

193 }, 

194 'timestamp': 4000000000, 

195 'observed_timestamp': 5000000000, 

196 'trace_id': 1, 

197 'span_id': 1, 

198 'trace_flags': 1, 

199 }, 

200 { 

201 'body': {'content': 'tool_return_content', 'role': 'tool', 'id': 'tool_call_3', 'name': 'tool3'}, 

202 'severity_number': 9, 

203 'severity_text': None, 

204 'attributes': { 

205 'gen_ai.system': 'my_system', 

206 'gen_ai.message.index': 0, 

207 'event.name': 'gen_ai.tool.message', 

208 }, 

209 'timestamp': 6000000000, 

210 'observed_timestamp': 7000000000, 

211 'trace_id': 1, 

212 'span_id': 1, 

213 'trace_flags': 1, 

214 }, 

215 { 

216 'body': { 

217 'content': """\ 

218retry_prompt1 

219 

220Fix the errors and try again.\ 

221""", 

222 'role': 'tool', 

223 'id': 'tool_call_4', 

224 'name': 'tool4', 

225 }, 

226 'severity_number': 9, 

227 'severity_text': None, 

228 'attributes': { 

229 'gen_ai.system': 'my_system', 

230 'gen_ai.message.index': 0, 

231 'event.name': 'gen_ai.tool.message', 

232 }, 

233 'timestamp': 8000000000, 

234 'observed_timestamp': 9000000000, 

235 'trace_id': 1, 

236 'span_id': 1, 

237 'trace_flags': 1, 

238 }, 

239 { 

240 'body': { 

241 'content': """\ 

242retry_prompt2 

243 

244Fix the errors and try again.\ 

245""", 

246 'role': 'user', 

247 }, 

248 'severity_number': 9, 

249 'severity_text': None, 

250 'attributes': { 

251 'gen_ai.system': 'my_system', 

252 'gen_ai.message.index': 0, 

253 'event.name': 'gen_ai.user.message', 

254 }, 

255 'timestamp': 10000000000, 

256 'observed_timestamp': 11000000000, 

257 'trace_id': 1, 

258 'span_id': 1, 

259 'trace_flags': 1, 

260 }, 

261 { 

262 'body': {'role': 'assistant', 'content': 'text3'}, 

263 'severity_number': 9, 

264 'severity_text': None, 

265 'attributes': { 

266 'gen_ai.system': 'my_system', 

267 'gen_ai.message.index': 1, 

268 'event.name': 'gen_ai.assistant.message', 

269 }, 

270 'timestamp': 12000000000, 

271 'observed_timestamp': 13000000000, 

272 'trace_id': 1, 

273 'span_id': 1, 

274 'trace_flags': 1, 

275 }, 

276 { 

277 'body': { 

278 'index': 0, 

279 'message': { 

280 'role': 'assistant', 

281 'content': 'text1', 

282 'tool_calls': [ 

283 { 

284 'id': 'tool_call_1', 

285 'type': 'function', 

286 'function': {'name': 'tool1', 'arguments': 'args1'}, 

287 }, 

288 { 

289 'id': 'tool_call_2', 

290 'type': 'function', 

291 'function': {'name': 'tool2', 'arguments': {'args2': 3}}, 

292 }, 

293 ], 

294 }, 

295 }, 

296 'severity_number': 9, 

297 'severity_text': None, 

298 'attributes': {'gen_ai.system': 'my_system', 'event.name': 'gen_ai.choice'}, 

299 'timestamp': 14000000000, 

300 'observed_timestamp': 15000000000, 

301 'trace_id': 1, 

302 'span_id': 1, 

303 'trace_flags': 1, 

304 }, 

305 { 

306 'body': {'index': 0, 'message': {'role': 'assistant', 'content': 'text2'}}, 

307 'severity_number': 9, 

308 'severity_text': None, 

309 'attributes': {'gen_ai.system': 'my_system', 'event.name': 'gen_ai.choice'}, 

310 'timestamp': 16000000000, 

311 'observed_timestamp': 17000000000, 

312 'trace_id': 1, 

313 'span_id': 1, 

314 'trace_flags': 1, 

315 }, 

316 ] 

317 ) 

318 

319 

320async def test_instrumented_model_not_recording(): 

321 model = InstrumentedModel( 

322 MyModel(), 

323 InstrumentationSettings(tracer_provider=NoOpTracerProvider(), event_logger_provider=NoOpEventLoggerProvider()), 

324 ) 

325 

326 messages: list[ModelMessage] = [ModelRequest(parts=[SystemPromptPart('system_prompt')])] 

327 await model.request( 

328 messages, 

329 model_settings=ModelSettings(temperature=1), 

330 model_request_parameters=ModelRequestParameters( 

331 function_tools=[], 

332 allow_text_result=True, 

333 result_tools=[], 

334 ), 

335 ) 

336 

337 

338@requires_logfire_events 

339async def test_instrumented_model_stream(capfire: CaptureLogfire): 

340 model = InstrumentedModel(MyModel(), InstrumentationSettings(event_mode='logs')) 

341 

342 messages: list[ModelMessage] = [ 

343 ModelRequest( 

344 parts=[ 

345 UserPromptPart('user_prompt'), 

346 ] 

347 ), 

348 ] 

349 async with model.request_stream( 

350 messages, 

351 model_settings=ModelSettings(temperature=1), 

352 model_request_parameters=ModelRequestParameters( 

353 function_tools=[], 

354 allow_text_result=True, 

355 result_tools=[], 

356 ), 

357 ) as response_stream: 

358 assert [event async for event in response_stream] == snapshot( 

359 [ 

360 PartStartEvent(index=0, part=TextPart(content='text1')), 

361 PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='text2')), 

362 ] 

363 ) 

364 

365 assert capfire.exporter.exported_spans_as_dict() == snapshot( 

366 [ 

367 { 

368 'name': 'chat my_model', 

369 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, 

370 'parent': None, 

371 'start_time': 1000000000, 

372 'end_time': 6000000000, 

373 'attributes': { 

374 'gen_ai.operation.name': 'chat', 

375 'gen_ai.system': 'my_system', 

376 'gen_ai.request.model': 'my_model', 

377 'server.address': 'example.com', 

378 'server.port': 8000, 

379 'model_request_parameters': '{"function_tools": [], "allow_text_result": true, "result_tools": []}', 

380 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 

381 'gen_ai.request.temperature': 1, 

382 'logfire.msg': 'chat my_model', 

383 'logfire.span_type': 'span', 

384 'gen_ai.response.model': 'my_model_123', 

385 'gen_ai.usage.input_tokens': 300, 

386 'gen_ai.usage.output_tokens': 400, 

387 }, 

388 }, 

389 ] 

390 ) 

391 

392 assert capfire.log_exporter.exported_logs_as_dicts() == snapshot( 

393 [ 

394 { 

395 'body': {'content': 'user_prompt', 'role': 'user'}, 

396 'severity_number': 9, 

397 'severity_text': None, 

398 'attributes': { 

399 'gen_ai.system': 'my_system', 

400 'gen_ai.message.index': 0, 

401 'event.name': 'gen_ai.user.message', 

402 }, 

403 'timestamp': 2000000000, 

404 'observed_timestamp': 3000000000, 

405 'trace_id': 1, 

406 'span_id': 1, 

407 'trace_flags': 1, 

408 }, 

409 { 

410 'body': {'index': 0, 'message': {'role': 'assistant', 'content': 'text1text2'}}, 

411 'severity_number': 9, 

412 'severity_text': None, 

413 'attributes': {'gen_ai.system': 'my_system', 'event.name': 'gen_ai.choice'}, 

414 'timestamp': 4000000000, 

415 'observed_timestamp': 5000000000, 

416 'trace_id': 1, 

417 'span_id': 1, 

418 'trace_flags': 1, 

419 }, 

420 ] 

421 ) 

422 

423 

424@requires_logfire_events 

425async def test_instrumented_model_stream_break(capfire: CaptureLogfire): 

426 model = InstrumentedModel(MyModel(), InstrumentationSettings(event_mode='logs')) 

427 

428 messages: list[ModelMessage] = [ 

429 ModelRequest( 

430 parts=[ 

431 UserPromptPart('user_prompt'), 

432 ] 

433 ), 

434 ] 

435 

436 with pytest.raises(RuntimeError): 

437 async with model.request_stream( 

438 messages, 

439 model_settings=ModelSettings(temperature=1), 

440 model_request_parameters=ModelRequestParameters( 

441 function_tools=[], 

442 allow_text_result=True, 

443 result_tools=[], 

444 ), 

445 ) as response_stream: 

446 async for event in response_stream: 446 ↛ 450line 446 didn't jump to line 450

447 assert event == PartStartEvent(index=0, part=TextPart(content='text1')) 

448 raise RuntimeError 

449 

450 assert capfire.exporter.exported_spans_as_dict() == snapshot( 

451 [ 

452 { 

453 'name': 'chat my_model', 

454 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, 

455 'parent': None, 

456 'start_time': 1000000000, 

457 'end_time': 7000000000, 

458 'attributes': { 

459 'gen_ai.operation.name': 'chat', 

460 'gen_ai.system': 'my_system', 

461 'gen_ai.request.model': 'my_model', 

462 'server.address': 'example.com', 

463 'server.port': 8000, 

464 'model_request_parameters': '{"function_tools": [], "allow_text_result": true, "result_tools": []}', 

465 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 

466 'gen_ai.request.temperature': 1, 

467 'logfire.msg': 'chat my_model', 

468 'logfire.span_type': 'span', 

469 'gen_ai.response.model': 'my_model_123', 

470 'gen_ai.usage.input_tokens': 300, 

471 'gen_ai.usage.output_tokens': 400, 

472 'logfire.level_num': 17, 

473 }, 

474 'events': [ 

475 { 

476 'name': 'exception', 

477 'timestamp': 6000000000, 

478 'attributes': { 

479 'exception.type': 'RuntimeError', 

480 'exception.message': '', 

481 'exception.stacktrace': 'RuntimeError', 

482 'exception.escaped': 'False', 

483 }, 

484 } 

485 ], 

486 }, 

487 ] 

488 ) 

489 

490 assert capfire.log_exporter.exported_logs_as_dicts() == snapshot( 

491 [ 

492 { 

493 'body': {'content': 'user_prompt', 'role': 'user'}, 

494 'severity_number': 9, 

495 'severity_text': None, 

496 'attributes': { 

497 'gen_ai.system': 'my_system', 

498 'gen_ai.message.index': 0, 

499 'event.name': 'gen_ai.user.message', 

500 }, 

501 'timestamp': 2000000000, 

502 'observed_timestamp': 3000000000, 

503 'trace_id': 1, 

504 'span_id': 1, 

505 'trace_flags': 1, 

506 }, 

507 { 

508 'body': {'index': 0, 'message': {'role': 'assistant', 'content': 'text1'}}, 

509 'severity_number': 9, 

510 'severity_text': None, 

511 'attributes': {'gen_ai.system': 'my_system', 'event.name': 'gen_ai.choice'}, 

512 'timestamp': 4000000000, 

513 'observed_timestamp': 5000000000, 

514 'trace_id': 1, 

515 'span_id': 1, 

516 'trace_flags': 1, 

517 }, 

518 ] 

519 ) 

520 

521 

522async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire): 

523 model = InstrumentedModel(MyModel(), InstrumentationSettings(event_mode='attributes')) 

524 assert model.system == 'my_system' 

525 assert model.model_name == 'my_model' 

526 

527 messages = [ 

528 ModelRequest( 

529 parts=[ 

530 SystemPromptPart('system_prompt'), 

531 UserPromptPart('user_prompt'), 

532 ToolReturnPart('tool3', 'tool_return_content', 'tool_call_3'), 

533 RetryPromptPart('retry_prompt1', tool_name='tool4', tool_call_id='tool_call_4'), 

534 RetryPromptPart('retry_prompt2'), 

535 {}, # test unexpected parts # type: ignore 

536 ] 

537 ), 

538 ModelResponse( 

539 parts=[ 

540 TextPart('text3'), 

541 ] 

542 ), 

543 ] 

544 await model.request( 

545 messages, 

546 model_settings=ModelSettings(temperature=1), 

547 model_request_parameters=ModelRequestParameters( 

548 function_tools=[], 

549 allow_text_result=True, 

550 result_tools=[], 

551 ), 

552 ) 

553 

554 assert capfire.exporter.exported_spans_as_dict() == snapshot( 

555 [ 

556 { 

557 'name': 'chat my_model', 

558 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, 

559 'parent': None, 

560 'start_time': 1000000000, 

561 'end_time': 2000000000, 

562 'attributes': { 

563 'gen_ai.operation.name': 'chat', 

564 'gen_ai.system': 'my_system', 

565 'gen_ai.request.model': 'my_model', 

566 'server.address': 'example.com', 

567 'server.port': 8000, 

568 'model_request_parameters': '{"function_tools": [], "allow_text_result": true, "result_tools": []}', 

569 'gen_ai.request.temperature': 1, 

570 'logfire.msg': 'chat my_model', 

571 'logfire.span_type': 'span', 

572 'gen_ai.response.model': 'my_model_123', 

573 'gen_ai.usage.input_tokens': 100, 

574 'gen_ai.usage.output_tokens': 200, 

575 'events': IsJson( 

576 snapshot( 

577 [ 

578 { 

579 'event.name': 'gen_ai.system.message', 

580 'content': 'system_prompt', 

581 'role': 'system', 

582 'gen_ai.message.index': 0, 

583 'gen_ai.system': 'my_system', 

584 }, 

585 { 

586 'event.name': 'gen_ai.user.message', 

587 'content': 'user_prompt', 

588 'role': 'user', 

589 'gen_ai.message.index': 0, 

590 'gen_ai.system': 'my_system', 

591 }, 

592 { 

593 'event.name': 'gen_ai.tool.message', 

594 'content': 'tool_return_content', 

595 'role': 'tool', 

596 'name': 'tool3', 

597 'id': 'tool_call_3', 

598 'gen_ai.message.index': 0, 

599 'gen_ai.system': 'my_system', 

600 }, 

601 { 

602 'event.name': 'gen_ai.tool.message', 

603 'content': """\ 

604retry_prompt1 

605 

606Fix the errors and try again.\ 

607""", 

608 'role': 'tool', 

609 'name': 'tool4', 

610 'id': 'tool_call_4', 

611 'gen_ai.message.index': 0, 

612 'gen_ai.system': 'my_system', 

613 }, 

614 { 

615 'event.name': 'gen_ai.user.message', 

616 'content': """\ 

617retry_prompt2 

618 

619Fix the errors and try again.\ 

620""", 

621 'role': 'user', 

622 'gen_ai.message.index': 0, 

623 'gen_ai.system': 'my_system', 

624 }, 

625 { 

626 'event.name': 'gen_ai.assistant.message', 

627 'role': 'assistant', 

628 'content': 'text3', 

629 'gen_ai.message.index': 1, 

630 'gen_ai.system': 'my_system', 

631 }, 

632 { 

633 'event.name': 'gen_ai.choice', 

634 'index': 0, 

635 'message': { 

636 'role': 'assistant', 

637 'content': 'text1', 

638 'tool_calls': [ 

639 { 

640 'id': 'tool_call_1', 

641 'type': 'function', 

642 'function': {'name': 'tool1', 'arguments': 'args1'}, 

643 }, 

644 { 

645 'id': 'tool_call_2', 

646 'type': 'function', 

647 'function': {'name': 'tool2', 'arguments': {'args2': 3}}, 

648 }, 

649 ], 

650 }, 

651 'gen_ai.system': 'my_system', 

652 }, 

653 { 

654 'event.name': 'gen_ai.choice', 

655 'index': 0, 

656 'message': {'role': 'assistant', 'content': 'text2'}, 

657 'gen_ai.system': 'my_system', 

658 }, 

659 ] 

660 ) 

661 ), 

662 'logfire.json_schema': '{"type": "object", "properties": {"events": {"type": "array"}, "model_request_parameters": {"type": "object"}}}', 

663 }, 

664 }, 

665 ] 

666 ) 

667 

668 

669def test_messages_to_otel_events_serialization_errors(): 

670 class Foo: 

671 def __repr__(self): 

672 return 'Foo()' 

673 

674 class Bar: 

675 def __repr__(self): 

676 raise ValueError('error!') 

677 

678 messages = [ 

679 ModelResponse(parts=[ToolCallPart('tool', {'arg': Foo()}, tool_call_id='tool_call_id')]), 

680 ModelRequest(parts=[ToolReturnPart('tool', Bar(), tool_call_id='return_tool_call_id')]), 

681 ] 

682 

683 assert [InstrumentedModel.event_to_dict(e) for e in InstrumentedModel.messages_to_otel_events(messages)] == [ 

684 { 

685 'body': "{'role': 'assistant', 'tool_calls': [{'id': 'tool_call_id', 'type': 'function', 'function': {'name': 'tool', 'arguments': {'arg': Foo()}}}]}", 

686 'gen_ai.message.index': 0, 

687 'event.name': 'gen_ai.assistant.message', 

688 }, 

689 { 

690 'body': 'Unable to serialize: error!', 

691 'gen_ai.message.index': 1, 

692 'event.name': 'gen_ai.tool.message', 

693 }, 

694 ]