Coverage for pydantic_ai_slim/pydantic_ai/models/instrumented.py: 96.51% (130 statements)
coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1from __future__ import annotations
3import json
4from collections.abc import AsyncIterator, Iterator, Mapping
5from contextlib import asynccontextmanager, contextmanager
6from dataclasses import dataclass, field
7from typing import Any, Callable, Literal
8from urllib.parse import urlparse
10from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
11from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provider
12from opentelemetry.util.types import AttributeValue
13from pydantic import TypeAdapter
15from ..messages import (
16 ModelMessage,
17 ModelRequest,
18 ModelResponse,
19)
20from ..settings import ModelSettings
21from ..usage import Usage
22from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
23from .wrapper import WrapperModel
# Keys from `ModelSettings` that are mirrored onto the request span as
# `gen_ai.request.<key>` attributes (only when present and numeric — see
# `InstrumentedModel._instrument`). The `Literal[...]` element type keeps the
# tuple in sync with the `ModelSettings` TypedDict keys at type-check time.
MODEL_SETTING_ATTRIBUTES: tuple[
    Literal[
        'max_tokens',
        'top_p',
        'seed',
        'temperature',
        'presence_penalty',
        'frequency_penalty',
    ],
    ...,
] = (
    'max_tokens',
    'top_p',
    'seed',
    'temperature',
    'presence_penalty',
    'frequency_penalty',
)
# Catch-all adapter used by `InstrumentedModel.serialize_any` to dump arbitrary
# values to JSON-compatible Python objects; built once at import time since
# TypeAdapter construction is not free.
ANY_ADAPTER = TypeAdapter[Any](Any)
@dataclass(init=False)
class InstrumentationSettings:
    """Options for instrumenting models and agents with OpenTelemetry.

    Used in:

    - `Agent(instrument=...)`
    - [`Agent.instrument_all()`][pydantic_ai.agent.Agent.instrument_all]
    - [`InstrumentedModel`][pydantic_ai.models.instrumented.InstrumentedModel]

    See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
    """

    tracer: Tracer = field(repr=False)
    event_logger: EventLogger = field(repr=False)
    event_mode: Literal['attributes', 'logs'] = 'attributes'

    def __init__(
        self,
        *,
        event_mode: Literal['attributes', 'logs'] = 'attributes',
        tracer_provider: TracerProvider | None = None,
        event_logger_provider: EventLoggerProvider | None = None,
    ):
        """Create instrumentation options.

        Args:
            event_mode: The mode for emitting events. If `'attributes'`, events are attached to the span as attributes.
                If `'logs'`, events are emitted as OpenTelemetry log-based events.
            tracer_provider: The OpenTelemetry tracer provider to use.
                If not provided, the global tracer provider is used.
                Calling `logfire.configure()` sets the global tracer provider, so most users don't need this.
            event_logger_provider: The OpenTelemetry event logger provider to use.
                If not provided, the global event logger provider is used.
                Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
                This is only used if `event_mode='logs'`.
        """
        # Imported here rather than at module level, presumably to avoid a
        # circular import with the package root — TODO confirm.
        from pydantic_ai import __version__

        # Fall back to the globally-configured OTel providers when none given.
        tracer_provider = tracer_provider or get_tracer_provider()
        event_logger_provider = event_logger_provider or get_event_logger_provider()
        self.tracer = tracer_provider.get_tracer('pydantic-ai', __version__)
        self.event_logger = event_logger_provider.get_event_logger('pydantic-ai', __version__)
        self.event_mode = event_mode
# OpenTelemetry GenAI semantic-convention attribute names used throughout
# `InstrumentedModel` below.
GEN_AI_SYSTEM_ATTRIBUTE = 'gen_ai.system'
GEN_AI_REQUEST_MODEL_ATTRIBUTE = 'gen_ai.request.model'
@dataclass
class InstrumentedModel(WrapperModel):
    """Model which wraps another model so that requests are instrumented with OpenTelemetry.

    See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
    """

    settings: InstrumentationSettings
    """Configuration for instrumenting requests."""

    def __init__(
        self,
        wrapped: Model | KnownModelName,
        options: InstrumentationSettings | None = None,
    ) -> None:
        """Wrap `wrapped`, instrumenting it with `options` (defaults apply when `None`)."""
        super().__init__(wrapped)
        self.settings = options or InstrumentationSettings()

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, Usage]:
        """Make a (non-streaming) request to the wrapped model inside an OTel span."""
        with self._instrument(messages, model_settings, model_request_parameters) as finish:
            response, usage = await super().request(messages, model_settings, model_request_parameters)
            finish(response, usage)
            return response, usage

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming request to the wrapped model inside an OTel span.

        The span is finished in `finally` so response/usage attributes are
        recorded even if the consumer of the stream raises part-way through.
        """
        with self._instrument(messages, model_settings, model_request_parameters) as finish:
            response_stream: StreamedResponse | None = None
            try:
                async with super().request_stream(
                    messages, model_settings, model_request_parameters
                ) as response_stream:
                    yield response_stream
            finally:
                # Only record if the inner context manager actually produced a
                # stream (it may have raised before binding `response_stream`).
                if response_stream:
                    finish(response_stream.get(), response_stream.usage())

    @contextmanager
    def _instrument(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> Iterator[Callable[[ModelResponse, Usage], None]]:
        """Open the request span and yield a `finish(response, usage)` callback.

        The callback records response attributes and emits the message events;
        callers invoke it once the model response (and usage) is known.
        """
        operation = 'chat'
        span_name = f'{operation} {self.model_name}'
        # TODO Missing attributes:
        #  - error.type: unclear if we should do something here or just always rely on span exceptions
        #  - gen_ai.request.stop_sequences/top_k: model_settings doesn't include these
        attributes: dict[str, AttributeValue] = {
            'gen_ai.operation.name': operation,
            **self.model_attributes(self.wrapped),
            'model_request_parameters': json.dumps(InstrumentedModel.serialize_any(model_request_parameters)),
            'logfire.json_schema': json.dumps(
                {
                    'type': 'object',
                    'properties': {'model_request_parameters': {'type': 'object'}},
                }
            ),
        }

        if model_settings:
            # Only numeric settings are recorded; anything else (or a missing
            # key) is silently skipped.
            for key in MODEL_SETTING_ATTRIBUTES:
                if isinstance(value := model_settings.get(key), (float, int)):
                    attributes[f'gen_ai.request.{key}'] = value

        with self.settings.tracer.start_as_current_span(span_name, attributes=attributes) as span:

            def finish(response: ModelResponse, usage: Usage):
                if not span.is_recording():
                    return

                # Request messages first, then the response wrapped in a
                # `gen_ai.choice` event per the GenAI event conventions.
                events = self.messages_to_otel_events(messages)
                for event in self.messages_to_otel_events([response]):
                    events.append(
                        Event(
                            'gen_ai.choice',
                            body={
                                # TODO finish_reason
                                'index': 0,
                                'message': event.body,
                            },
                        )
                    )
                new_attributes: dict[str, AttributeValue] = usage.opentelemetry_attributes()  # type: ignore
                # Merge in any attributes set on the span since creation (the
                # SDK span exposes them via `.attributes`; default to {} for
                # span implementations that don't).
                attributes.update(getattr(span, 'attributes', {}))
                request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
                new_attributes['gen_ai.response.model'] = response.model_name or request_model
                span.set_attributes(new_attributes)
                # Re-title the span with the (possibly span-overridden) model.
                span.update_name(f'{operation} {request_model}')
                for event in events:
                    event.attributes = {
                        GEN_AI_SYSTEM_ATTRIBUTE: attributes[GEN_AI_SYSTEM_ATTRIBUTE],
                        **(event.attributes or {}),
                    }
                self._emit_events(span, events)

            yield finish

    def _emit_events(self, span: Span, events: list[Event]) -> None:
        """Emit `events` either as OTel log events or as a JSON span attribute,
        depending on `settings.event_mode`."""
        if self.settings.event_mode == 'logs':
            for event in events:
                self.settings.event_logger.emit(event)
        else:
            attr_name = 'events'
            span.set_attributes(
                {
                    attr_name: json.dumps([self.event_to_dict(event) for event in events]),
                    'logfire.json_schema': json.dumps(
                        {
                            'type': 'object',
                            'properties': {
                                attr_name: {'type': 'array'},
                                'model_request_parameters': {'type': 'object'},
                            },
                        }
                    ),
                }
            )

    @staticmethod
    def model_attributes(model: Model):
        """Build the GenAI system/model/server attributes for `model`."""
        attributes: dict[str, AttributeValue] = {
            GEN_AI_SYSTEM_ATTRIBUTE: model.system,
            GEN_AI_REQUEST_MODEL_ATTRIBUTE: model.model_name,
        }
        if base_url := model.base_url:
            try:
                parsed = urlparse(base_url)
            except Exception:  # pragma: no cover
                # Best-effort: a malformed base URL just means no server.* attrs.
                pass
            else:
                if parsed.hostname:
                    attributes['server.address'] = parsed.hostname
                if parsed.port:
                    attributes['server.port'] = parsed.port

        return attributes

    @staticmethod
    def event_to_dict(event: Event) -> dict[str, Any]:
        """Flatten an OTel event's body and attributes into a single dict
        (attributes win on key collisions)."""
        if not event.body:
            body = {}
        elif isinstance(event.body, Mapping):
            body = event.body  # type: ignore
        else:
            body = {'body': event.body}
        return {**body, **(event.attributes or {})}

    @staticmethod
    def messages_to_otel_events(messages: list[ModelMessage]) -> list[Event]:
        """Convert request/response messages into OTel events.

        Each event is tagged with the index of the message it came from via the
        `gen_ai.message.index` attribute, and its body is made JSON-safe.
        """
        result: list[Event] = []
        for message_index, message in enumerate(messages):
            message_events: list[Event] = []
            if isinstance(message, ModelRequest):
                # Not every request part knows how to describe itself as an
                # event; skip those that don't.
                for part in message.parts:
                    if hasattr(part, 'otel_event'):
                        message_events.append(part.otel_event())
            elif isinstance(message, ModelResponse):
                message_events = message.otel_events()
            for event in message_events:
                event.attributes = {
                    'gen_ai.message.index': message_index,
                    **(event.attributes or {}),
                }
            result.extend(message_events)
        for event in result:
            event.body = InstrumentedModel.serialize_any(event.body)
        return result

    @staticmethod
    def serialize_any(value: Any) -> str:
        """Best-effort serialization of `value` to something JSON-compatible.

        NOTE(review): `dump_python(mode='json')` returns JSON-safe Python
        objects, not necessarily a `str` — the `-> str` annotation looks
        optimistic; kept as-is to avoid an interface change.
        """
        try:
            return ANY_ADAPTER.dump_python(value, mode='json')
        except Exception:
            try:
                return str(value)
            except Exception as e:
                return f'Unable to serialize: {e}'