Coverage for tests/models/test_instrumented.py: 98.96%
94 statements
« prev ^ index » next — coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1from __future__ import annotations
3from collections.abc import AsyncIterator
4from contextlib import asynccontextmanager
5from datetime import datetime
7import pytest
8from dirty_equals import IsJson
9from inline_snapshot import snapshot
10from logfire_api import DEFAULT_LOGFIRE_INSTANCE
11from opentelemetry._events import NoOpEventLoggerProvider
12from opentelemetry.trace import NoOpTracerProvider
14from pydantic_ai.messages import (
15 ModelMessage,
16 ModelRequest,
17 ModelResponse,
18 ModelResponseStreamEvent,
19 PartDeltaEvent,
20 PartStartEvent,
21 RetryPromptPart,
22 SystemPromptPart,
23 TextPart,
24 TextPartDelta,
25 ToolCallPart,
26 ToolReturnPart,
27 UserPromptPart,
28)
29from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
30from pydantic_ai.models.instrumented import InstrumentationSettings, InstrumentedModel
31from pydantic_ai.settings import ModelSettings
32from pydantic_ai.usage import Usage
34from ..conftest import try_import
# logfire is an optional test dependency; import its testing helpers guardedly.
with try_import() as imports_successful:
    from logfire.testing import CaptureLogfire

# Skip the whole module when logfire isn't installed; run every test under anyio.
pytestmark = [
    pytest.mark.skipif(not imports_successful(), reason='logfire not installed'),
    pytest.mark.anyio,
]

# Older logfire versions don't expose an event logger provider, so tests that
# rely on OTel events/logs support must be skipped there.
requires_logfire_events = pytest.mark.skipif(
    not hasattr(DEFAULT_LOGFIRE_INSTANCE.config, 'get_event_logger_provider'),
    reason='old logfire without events/logs support',
)
class MyModel(Model):
    """Minimal fake model used to exercise `InstrumentedModel` in these tests.

    Returns fixed, canned data regardless of input so the instrumentation
    output can be pinned with snapshots.
    """

    @property
    def system(self) -> str:
        return 'my_system'

    @property
    def model_name(self) -> str:
        return 'my_model'

    @property
    def base_url(self) -> str:
        return 'https://example.com:8000/foo'

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, Usage]:
        # Canned response, independent of the incoming messages/settings.
        response_parts = [
            TextPart('text1'),
            ToolCallPart('tool1', 'args1', 'tool_call_1'),
            ToolCallPart('tool2', {'args2': 3}, 'tool_call_2'),
            TextPart('text2'),
            {},  # test unexpected parts # type: ignore
        ]
        response = ModelResponse(parts=response_parts, model_name='my_model_123')
        usage = Usage(request_tokens=100, response_tokens=200)
        return response, usage

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        yield MyResponseStream()
class MyResponseStream(StreamedResponse):
    """Fake streamed response emitting two text chunks into a single part."""

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        self._usage = Usage(request_tokens=300, response_tokens=400)
        # Both chunks target the same vendor part id, so the second one
        # arrives as a delta on the part started by the first.
        for chunk in ('text1', 'text2'):
            yield self._parts_manager.handle_text_delta(vendor_part_id=0, content=chunk)

    @property
    def model_name(self) -> str:
        return 'my_model_123'

    @property
    def timestamp(self) -> datetime:
        return datetime(2022, 1, 1)
@requires_logfire_events
async def test_instrumented_model(capfire: CaptureLogfire):
    """Non-streaming request with ``event_mode='logs'``.

    Checks the exported span (GenAI semantic-convention attributes, server
    address/port parsed from the model's ``base_url``, token usage from the
    canned ``Usage``) and the per-message-part OTel log events, including how
    retry prompts are rendered and that the unexpected ``{}`` part is tolerated.
    """
    model = InstrumentedModel(MyModel(), InstrumentationSettings(event_mode='logs'))
    # The wrapper forwards the wrapped model's identity.
    assert model.system == 'my_system'
    assert model.model_name == 'my_model'

    messages = [
        ModelRequest(
            parts=[
                SystemPromptPart('system_prompt'),
                UserPromptPart('user_prompt'),
                ToolReturnPart('tool3', 'tool_return_content', 'tool_call_3'),
                RetryPromptPart('retry_prompt1', tool_name='tool4', tool_call_id='tool_call_4'),
                RetryPromptPart('retry_prompt2'),
                {},  # test unexpected parts # type: ignore
            ]
        ),
        ModelResponse(
            parts=[
                TextPart('text3'),
            ]
        ),
    ]
    await model.request(
        messages,
        model_settings=ModelSettings(temperature=1),
        model_request_parameters=ModelRequestParameters(
            function_tools=[],
            allow_text_result=True,
            result_tools=[],
        ),
    )

    # Exactly one 'chat' span is exported for the request.
    assert capfire.exporter.exported_spans_as_dict() == snapshot(
        [
            {
                'name': 'chat my_model',
                'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
                'parent': None,
                'start_time': 1000000000,
                'end_time': 18000000000,
                'attributes': {
                    'gen_ai.operation.name': 'chat',
                    'gen_ai.system': 'my_system',
                    'gen_ai.request.model': 'my_model',
                    'server.address': 'example.com',
                    'server.port': 8000,
                    'model_request_parameters': '{"function_tools": [], "allow_text_result": true, "result_tools": []}',
                    'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}',
                    'gen_ai.request.temperature': 1,
                    'logfire.msg': 'chat my_model',
                    'logfire.span_type': 'span',
                    'gen_ai.response.model': 'my_model_123',
                    'gen_ai.usage.input_tokens': 100,
                    'gen_ai.usage.output_tokens': 200,
                },
            },
        ]
    )

    # In 'logs' mode every message part becomes its own OTel log record;
    # retry prompts carry the 'Fix the errors and try again.' suffix.
    assert capfire.log_exporter.exported_logs_as_dicts() == snapshot(
        [
            {
                'body': {'content': 'system_prompt', 'role': 'system'},
                'severity_number': 9,
                'severity_text': None,
                'attributes': {
                    'gen_ai.system': 'my_system',
                    'gen_ai.message.index': 0,
                    'event.name': 'gen_ai.system.message',
                },
                'timestamp': 2000000000,
                'observed_timestamp': 3000000000,
                'trace_id': 1,
                'span_id': 1,
                'trace_flags': 1,
            },
            {
                'body': {'content': 'user_prompt', 'role': 'user'},
                'severity_number': 9,
                'severity_text': None,
                'attributes': {
                    'gen_ai.system': 'my_system',
                    'gen_ai.message.index': 0,
                    'event.name': 'gen_ai.user.message',
                },
                'timestamp': 4000000000,
                'observed_timestamp': 5000000000,
                'trace_id': 1,
                'span_id': 1,
                'trace_flags': 1,
            },
            {
                'body': {'content': 'tool_return_content', 'role': 'tool', 'id': 'tool_call_3', 'name': 'tool3'},
                'severity_number': 9,
                'severity_text': None,
                'attributes': {
                    'gen_ai.system': 'my_system',
                    'gen_ai.message.index': 0,
                    'event.name': 'gen_ai.tool.message',
                },
                'timestamp': 6000000000,
                'observed_timestamp': 7000000000,
                'trace_id': 1,
                'span_id': 1,
                'trace_flags': 1,
            },
            {
                'body': {
                    'content': """\
retry_prompt1

Fix the errors and try again.\
""",
                    'role': 'tool',
                    'id': 'tool_call_4',
                    'name': 'tool4',
                },
                'severity_number': 9,
                'severity_text': None,
                'attributes': {
                    'gen_ai.system': 'my_system',
                    'gen_ai.message.index': 0,
                    'event.name': 'gen_ai.tool.message',
                },
                'timestamp': 8000000000,
                'observed_timestamp': 9000000000,
                'trace_id': 1,
                'span_id': 1,
                'trace_flags': 1,
            },
            {
                'body': {
                    'content': """\
retry_prompt2

Fix the errors and try again.\
""",
                    'role': 'user',
                },
                'severity_number': 9,
                'severity_text': None,
                'attributes': {
                    'gen_ai.system': 'my_system',
                    'gen_ai.message.index': 0,
                    'event.name': 'gen_ai.user.message',
                },
                'timestamp': 10000000000,
                'observed_timestamp': 11000000000,
                'trace_id': 1,
                'span_id': 1,
                'trace_flags': 1,
            },
            {
                'body': {'role': 'assistant', 'content': 'text3'},
                'severity_number': 9,
                'severity_text': None,
                'attributes': {
                    'gen_ai.system': 'my_system',
                    'gen_ai.message.index': 1,
                    'event.name': 'gen_ai.assistant.message',
                },
                'timestamp': 12000000000,
                'observed_timestamp': 13000000000,
                'trace_id': 1,
                'span_id': 1,
                'trace_flags': 1,
            },
            {
                'body': {
                    'index': 0,
                    'message': {
                        'role': 'assistant',
                        'content': 'text1',
                        'tool_calls': [
                            {
                                'id': 'tool_call_1',
                                'type': 'function',
                                'function': {'name': 'tool1', 'arguments': 'args1'},
                            },
                            {
                                'id': 'tool_call_2',
                                'type': 'function',
                                'function': {'name': 'tool2', 'arguments': {'args2': 3}},
                            },
                        ],
                    },
                },
                'severity_number': 9,
                'severity_text': None,
                'attributes': {'gen_ai.system': 'my_system', 'event.name': 'gen_ai.choice'},
                'timestamp': 14000000000,
                'observed_timestamp': 15000000000,
                'trace_id': 1,
                'span_id': 1,
                'trace_flags': 1,
            },
            {
                'body': {'index': 0, 'message': {'role': 'assistant', 'content': 'text2'}},
                'severity_number': 9,
                'severity_text': None,
                'attributes': {'gen_ai.system': 'my_system', 'event.name': 'gen_ai.choice'},
                'timestamp': 16000000000,
                'observed_timestamp': 17000000000,
                'trace_id': 1,
                'span_id': 1,
                'trace_flags': 1,
            },
        ]
    )
async def test_instrumented_model_not_recording():
    """A model instrumented with no-op providers must still serve requests without error."""
    settings = InstrumentationSettings(
        tracer_provider=NoOpTracerProvider(),
        event_logger_provider=NoOpEventLoggerProvider(),
    )
    model = InstrumentedModel(MyModel(), settings)

    request_messages: list[ModelMessage] = [ModelRequest(parts=[SystemPromptPart('system_prompt')])]
    params = ModelRequestParameters(
        function_tools=[],
        allow_text_result=True,
        result_tools=[],
    )
    await model.request(
        request_messages,
        model_settings=ModelSettings(temperature=1),
        model_request_parameters=params,
    )
@requires_logfire_events
async def test_instrumented_model_stream(capfire: CaptureLogfire):
    """Streaming request with ``event_mode='logs'``.

    Stream events are passed through to the caller unchanged; afterwards the
    span carries the usage reported by the stream (300/400 tokens) and the
    aggregated text appears as a single ``gen_ai.choice`` log event.
    """
    model = InstrumentedModel(MyModel(), InstrumentationSettings(event_mode='logs'))

    messages: list[ModelMessage] = [
        ModelRequest(
            parts=[
                UserPromptPart('user_prompt'),
            ]
        ),
    ]
    async with model.request_stream(
        messages,
        model_settings=ModelSettings(temperature=1),
        model_request_parameters=ModelRequestParameters(
            function_tools=[],
            allow_text_result=True,
            result_tools=[],
        ),
    ) as response_stream:
        # The two text chunks arrive as a part start followed by a delta.
        assert [event async for event in response_stream] == snapshot(
            [
                PartStartEvent(index=0, part=TextPart(content='text1')),
                PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='text2')),
            ]
        )

    assert capfire.exporter.exported_spans_as_dict() == snapshot(
        [
            {
                'name': 'chat my_model',
                'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
                'parent': None,
                'start_time': 1000000000,
                'end_time': 6000000000,
                'attributes': {
                    'gen_ai.operation.name': 'chat',
                    'gen_ai.system': 'my_system',
                    'gen_ai.request.model': 'my_model',
                    'server.address': 'example.com',
                    'server.port': 8000,
                    'model_request_parameters': '{"function_tools": [], "allow_text_result": true, "result_tools": []}',
                    'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}',
                    'gen_ai.request.temperature': 1,
                    'logfire.msg': 'chat my_model',
                    'logfire.span_type': 'span',
                    'gen_ai.response.model': 'my_model_123',
                    'gen_ai.usage.input_tokens': 300,
                    'gen_ai.usage.output_tokens': 400,
                },
            },
        ]
    )

    # The streamed chunks are concatenated into one choice event.
    assert capfire.log_exporter.exported_logs_as_dicts() == snapshot(
        [
            {
                'body': {'content': 'user_prompt', 'role': 'user'},
                'severity_number': 9,
                'severity_text': None,
                'attributes': {
                    'gen_ai.system': 'my_system',
                    'gen_ai.message.index': 0,
                    'event.name': 'gen_ai.user.message',
                },
                'timestamp': 2000000000,
                'observed_timestamp': 3000000000,
                'trace_id': 1,
                'span_id': 1,
                'trace_flags': 1,
            },
            {
                'body': {'index': 0, 'message': {'role': 'assistant', 'content': 'text1text2'}},
                'severity_number': 9,
                'severity_text': None,
                'attributes': {'gen_ai.system': 'my_system', 'event.name': 'gen_ai.choice'},
                'timestamp': 4000000000,
                'observed_timestamp': 5000000000,
                'trace_id': 1,
                'span_id': 1,
                'trace_flags': 1,
            },
        ]
    )
@requires_logfire_events
async def test_instrumented_model_stream_break(capfire: CaptureLogfire):
    """Aborting a stream by raising: the span must still be closed, be marked
    as an error, record the exception event, and report only the partial
    choice that was consumed before the break.

    (Fix: the ``async for`` line had coverage-report arc-annotation text
    fused into it from an HTML extraction; restored to valid source.)
    """
    model = InstrumentedModel(MyModel(), InstrumentationSettings(event_mode='logs'))

    messages: list[ModelMessage] = [
        ModelRequest(
            parts=[
                UserPromptPart('user_prompt'),
            ]
        ),
    ]

    with pytest.raises(RuntimeError):
        async with model.request_stream(
            messages,
            model_settings=ModelSettings(temperature=1),
            model_request_parameters=ModelRequestParameters(
                function_tools=[],
                allow_text_result=True,
                result_tools=[],
            ),
        ) as response_stream:
            # Consume only the first event, then abort the stream.
            async for event in response_stream:
                assert event == PartStartEvent(index=0, part=TextPart(content='text1'))
                raise RuntimeError

    # The span still ends, at error level (17), with an exception event attached.
    # Usage is 300/400 because MyResponseStream sets it before the first yield.
    assert capfire.exporter.exported_spans_as_dict() == snapshot(
        [
            {
                'name': 'chat my_model',
                'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
                'parent': None,
                'start_time': 1000000000,
                'end_time': 7000000000,
                'attributes': {
                    'gen_ai.operation.name': 'chat',
                    'gen_ai.system': 'my_system',
                    'gen_ai.request.model': 'my_model',
                    'server.address': 'example.com',
                    'server.port': 8000,
                    'model_request_parameters': '{"function_tools": [], "allow_text_result": true, "result_tools": []}',
                    'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}',
                    'gen_ai.request.temperature': 1,
                    'logfire.msg': 'chat my_model',
                    'logfire.span_type': 'span',
                    'gen_ai.response.model': 'my_model_123',
                    'gen_ai.usage.input_tokens': 300,
                    'gen_ai.usage.output_tokens': 400,
                    'logfire.level_num': 17,
                },
                'events': [
                    {
                        'name': 'exception',
                        'timestamp': 6000000000,
                        'attributes': {
                            'exception.type': 'RuntimeError',
                            'exception.message': '',
                            'exception.stacktrace': 'RuntimeError',
                            'exception.escaped': 'False',
                        },
                    }
                ],
            },
        ]
    )

    # Only the consumed chunk ('text1') shows up in the recorded choice.
    assert capfire.log_exporter.exported_logs_as_dicts() == snapshot(
        [
            {
                'body': {'content': 'user_prompt', 'role': 'user'},
                'severity_number': 9,
                'severity_text': None,
                'attributes': {
                    'gen_ai.system': 'my_system',
                    'gen_ai.message.index': 0,
                    'event.name': 'gen_ai.user.message',
                },
                'timestamp': 2000000000,
                'observed_timestamp': 3000000000,
                'trace_id': 1,
                'span_id': 1,
                'trace_flags': 1,
            },
            {
                'body': {'index': 0, 'message': {'role': 'assistant', 'content': 'text1'}},
                'severity_number': 9,
                'severity_text': None,
                'attributes': {'gen_ai.system': 'my_system', 'event.name': 'gen_ai.choice'},
                'timestamp': 4000000000,
                'observed_timestamp': 5000000000,
                'trace_id': 1,
                'span_id': 1,
                'trace_flags': 1,
            },
        ]
    )
async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire):
    """Same request as ``test_instrumented_model`` but with
    ``event_mode='attributes'``: instead of one OTel log record per message
    part, all events are JSON-encoded into the span's ``events`` attribute.
    """
    model = InstrumentedModel(MyModel(), InstrumentationSettings(event_mode='attributes'))
    assert model.system == 'my_system'
    assert model.model_name == 'my_model'

    messages = [
        ModelRequest(
            parts=[
                SystemPromptPart('system_prompt'),
                UserPromptPart('user_prompt'),
                ToolReturnPart('tool3', 'tool_return_content', 'tool_call_3'),
                RetryPromptPart('retry_prompt1', tool_name='tool4', tool_call_id='tool_call_4'),
                RetryPromptPart('retry_prompt2'),
                {},  # test unexpected parts # type: ignore
            ]
        ),
        ModelResponse(
            parts=[
                TextPart('text3'),
            ]
        ),
    ]
    await model.request(
        messages,
        model_settings=ModelSettings(temperature=1),
        model_request_parameters=ModelRequestParameters(
            function_tools=[],
            allow_text_result=True,
            result_tools=[],
        ),
    )

    # A single span: all message events live in the JSON 'events' attribute,
    # and the json_schema advertises it alongside model_request_parameters.
    assert capfire.exporter.exported_spans_as_dict() == snapshot(
        [
            {
                'name': 'chat my_model',
                'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
                'parent': None,
                'start_time': 1000000000,
                'end_time': 2000000000,
                'attributes': {
                    'gen_ai.operation.name': 'chat',
                    'gen_ai.system': 'my_system',
                    'gen_ai.request.model': 'my_model',
                    'server.address': 'example.com',
                    'server.port': 8000,
                    'model_request_parameters': '{"function_tools": [], "allow_text_result": true, "result_tools": []}',
                    'gen_ai.request.temperature': 1,
                    'logfire.msg': 'chat my_model',
                    'logfire.span_type': 'span',
                    'gen_ai.response.model': 'my_model_123',
                    'gen_ai.usage.input_tokens': 100,
                    'gen_ai.usage.output_tokens': 200,
                    'events': IsJson(
                        snapshot(
                            [
                                {
                                    'event.name': 'gen_ai.system.message',
                                    'content': 'system_prompt',
                                    'role': 'system',
                                    'gen_ai.message.index': 0,
                                    'gen_ai.system': 'my_system',
                                },
                                {
                                    'event.name': 'gen_ai.user.message',
                                    'content': 'user_prompt',
                                    'role': 'user',
                                    'gen_ai.message.index': 0,
                                    'gen_ai.system': 'my_system',
                                },
                                {
                                    'event.name': 'gen_ai.tool.message',
                                    'content': 'tool_return_content',
                                    'role': 'tool',
                                    'name': 'tool3',
                                    'id': 'tool_call_3',
                                    'gen_ai.message.index': 0,
                                    'gen_ai.system': 'my_system',
                                },
                                {
                                    'event.name': 'gen_ai.tool.message',
                                    'content': """\
retry_prompt1

Fix the errors and try again.\
""",
                                    'role': 'tool',
                                    'name': 'tool4',
                                    'id': 'tool_call_4',
                                    'gen_ai.message.index': 0,
                                    'gen_ai.system': 'my_system',
                                },
                                {
                                    'event.name': 'gen_ai.user.message',
                                    'content': """\
retry_prompt2

Fix the errors and try again.\
""",
                                    'role': 'user',
                                    'gen_ai.message.index': 0,
                                    'gen_ai.system': 'my_system',
                                },
                                {
                                    'event.name': 'gen_ai.assistant.message',
                                    'role': 'assistant',
                                    'content': 'text3',
                                    'gen_ai.message.index': 1,
                                    'gen_ai.system': 'my_system',
                                },
                                {
                                    'event.name': 'gen_ai.choice',
                                    'index': 0,
                                    'message': {
                                        'role': 'assistant',
                                        'content': 'text1',
                                        'tool_calls': [
                                            {
                                                'id': 'tool_call_1',
                                                'type': 'function',
                                                'function': {'name': 'tool1', 'arguments': 'args1'},
                                            },
                                            {
                                                'id': 'tool_call_2',
                                                'type': 'function',
                                                'function': {'name': 'tool2', 'arguments': {'args2': 3}},
                                            },
                                        ],
                                    },
                                    'gen_ai.system': 'my_system',
                                },
                                {
                                    'event.name': 'gen_ai.choice',
                                    'index': 0,
                                    'message': {'role': 'assistant', 'content': 'text2'},
                                    'gen_ai.system': 'my_system',
                                },
                            ]
                        )
                    ),
                    'logfire.json_schema': '{"type": "object", "properties": {"events": {"type": "array"}, "model_request_parameters": {"type": "object"}}}',
                },
            },
        ]
    )
def test_messages_to_otel_events_serialization_errors():
    """Parts that can't be JSON-serialized fall back to ``repr()``; a raising
    ``__repr__`` yields an 'Unable to serialize' body instead of propagating.
    """

    class Foo:
        def __repr__(self):
            return 'Foo()'

    class Bar:
        def __repr__(self):
            raise ValueError('error!')

    messages = [
        ModelResponse(parts=[ToolCallPart('tool', {'arg': Foo()}, tool_call_id='tool_call_id')]),
        ModelRequest(parts=[ToolReturnPart('tool', Bar(), tool_call_id='return_tool_call_id')]),
    ]

    actual = [
        InstrumentedModel.event_to_dict(event) for event in InstrumentedModel.messages_to_otel_events(messages)
    ]
    expected = [
        {
            'body': "{'role': 'assistant', 'tool_calls': [{'id': 'tool_call_id', 'type': 'function', 'function': {'name': 'tool', 'arguments': {'arg': Foo()}}}]}",
            'gen_ai.message.index': 0,
            'event.name': 'gen_ai.assistant.message',
        },
        {
            'body': 'Unable to serialize: error!',
            'gen_ai.message.index': 1,
            'event.name': 'gen_ai.tool.message',
        },
    ]
    assert actual == expected