Coverage for pydantic_ai_slim/pydantic_ai/models/instrumented.py: 96.51%

130 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1from __future__ import annotations 

2 

3import json 

4from collections.abc import AsyncIterator, Iterator, Mapping 

5from contextlib import asynccontextmanager, contextmanager 

6from dataclasses import dataclass, field 

7from typing import Any, Callable, Literal 

8from urllib.parse import urlparse 

9 

10from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider 

11from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provider 

12from opentelemetry.util.types import AttributeValue 

13from pydantic import TypeAdapter 

14 

15from ..messages import ( 

16 ModelMessage, 

17 ModelRequest, 

18 ModelResponse, 

19) 

20from ..settings import ModelSettings 

21from ..usage import Usage 

22from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse 

23from .wrapper import WrapperModel 

24 

# Keys from `ModelSettings` that are mirrored onto the request span as
# `gen_ai.request.<key>` attributes (only numeric values are copied; see
# `InstrumentedModel._instrument`).
MODEL_SETTING_ATTRIBUTES: tuple[
    Literal[
        'max_tokens',
        'top_p',
        'seed',
        'temperature',
        'presence_penalty',
        'frequency_penalty',
    ],
    ...,
] = (
    'max_tokens',
    'top_p',
    'seed',
    'temperature',
    'presence_penalty',
    'frequency_penalty',
)

# Adapter for best-effort serialization of arbitrary values to JSON-compatible
# structures; used by `InstrumentedModel.serialize_any`.
ANY_ADAPTER = TypeAdapter[Any](Any)

45 

46 

@dataclass(init=False)
class InstrumentationSettings:
    """Options for instrumenting models and agents with OpenTelemetry.

    Used in:

    - `Agent(instrument=...)`
    - [`Agent.instrument_all()`][pydantic_ai.agent.Agent.instrument_all]
    - [`InstrumentedModel`][pydantic_ai.models.instrumented.InstrumentedModel]

    See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
    """

    tracer: Tracer = field(repr=False)
    event_logger: EventLogger = field(repr=False)
    event_mode: Literal['attributes', 'logs'] = 'attributes'

    def __init__(
        self,
        *,
        event_mode: Literal['attributes', 'logs'] = 'attributes',
        tracer_provider: TracerProvider | None = None,
        event_logger_provider: EventLoggerProvider | None = None,
    ):
        """Create instrumentation options.

        Args:
            event_mode: The mode for emitting events. If `'attributes'`, events are attached to the span as attributes.
                If `'logs'`, events are emitted as OpenTelemetry log-based events.
            tracer_provider: The OpenTelemetry tracer provider to use.
                If not provided, the global tracer provider is used.
                Calling `logfire.configure()` sets the global tracer provider, so most users don't need this.
            event_logger_provider: The OpenTelemetry event logger provider to use.
                If not provided, the global event logger provider is used.
                Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
                This is only used if `event_mode='logs'`.
        """
        # Imported lazily to avoid a circular import at module load time.
        from pydantic_ai import __version__

        # Fall back to the globally-configured providers when none were given.
        active_tracer_provider = tracer_provider or get_tracer_provider()
        active_event_logger_provider = event_logger_provider or get_event_logger_provider()

        scope_name = 'pydantic-ai'
        self.tracer = active_tracer_provider.get_tracer(scope_name, __version__)
        self.event_logger = active_event_logger_provider.get_event_logger(scope_name, __version__)
        self.event_mode = event_mode

91 

92 

# Attribute names from the OpenTelemetry GenAI semantic conventions,
# shared between span attributes and per-message events below.
GEN_AI_SYSTEM_ATTRIBUTE = 'gen_ai.system'
GEN_AI_REQUEST_MODEL_ATTRIBUTE = 'gen_ai.request.model'

95 

96 

@dataclass
class InstrumentedModel(WrapperModel):
    """Model which wraps another model so that requests are instrumented with OpenTelemetry.

    See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
    """

    settings: InstrumentationSettings
    """Configuration for instrumenting requests."""

    def __init__(
        self,
        wrapped: Model | KnownModelName,
        options: InstrumentationSettings | None = None,
    ) -> None:
        super().__init__(wrapped)
        # Default settings use the global tracer/event-logger providers.
        self.settings = options or InstrumentationSettings()

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, Usage]:
        """Make a request to the wrapped model inside an instrumentation span."""
        with self._instrument(messages, model_settings, model_request_parameters) as finish:
            response, usage = await super().request(messages, model_settings, model_request_parameters)
            finish(response, usage)
            return response, usage

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streamed request to the wrapped model inside an instrumentation span.

        The span is finished (response events and usage recorded) once the stream
        context exits, even if the consumer raised while iterating.
        """
        with self._instrument(messages, model_settings, model_request_parameters) as finish:
            response_stream: StreamedResponse | None = None
            try:
                async with super().request_stream(
                    messages, model_settings, model_request_parameters
                ) as response_stream:
                    yield response_stream
            finally:
                # Explicit None-check: truthiness would silently skip `finish` if
                # StreamedResponse ever defined __bool__/__len__ returning falsy.
                if response_stream is not None:
                    finish(response_stream.get(), response_stream.usage())

    @contextmanager
    def _instrument(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> Iterator[Callable[[ModelResponse, Usage], None]]:
        """Open a `chat` span and yield a `finish(response, usage)` callback.

        The callback records response/usage attributes and emits per-message
        OTel events; it is a no-op if the span is not recording.
        """
        operation = 'chat'
        span_name = f'{operation} {self.model_name}'
        # TODO Missing attributes:
        # - error.type: unclear if we should do something here or just always rely on span exceptions
        # - gen_ai.request.stop_sequences/top_k: model_settings doesn't include these
        attributes: dict[str, AttributeValue] = {
            'gen_ai.operation.name': operation,
            **self.model_attributes(self.wrapped),
            'model_request_parameters': json.dumps(InstrumentedModel.serialize_any(model_request_parameters)),
            'logfire.json_schema': json.dumps(
                {
                    'type': 'object',
                    'properties': {'model_request_parameters': {'type': 'object'}},
                }
            ),
        }

        if model_settings:
            for key in MODEL_SETTING_ATTRIBUTES:
                # Only copy numeric settings; skip None and non-numeric values.
                if isinstance(value := model_settings.get(key), (float, int)):
                    attributes[f'gen_ai.request.{key}'] = value

        with self.settings.tracer.start_as_current_span(span_name, attributes=attributes) as span:

            def finish(response: ModelResponse, usage: Usage):
                if not span.is_recording():
                    return

                # Request messages first, then each response message wrapped in
                # a `gen_ai.choice` event per the GenAI semantic conventions.
                events = self.messages_to_otel_events(messages)
                for event in self.messages_to_otel_events([response]):
                    events.append(
                        Event(
                            'gen_ai.choice',
                            body={
                                # TODO finish_reason
                                'index': 0,
                                'message': event.body,
                            },
                        )
                    )
                new_attributes: dict[str, AttributeValue] = usage.opentelemetry_attributes()  # type: ignore
                # ReadableSpan exposes `.attributes`; the base Span API doesn't,
                # hence the defensive getattr.
                attributes.update(getattr(span, 'attributes', {}))
                request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
                new_attributes['gen_ai.response.model'] = response.model_name or request_model
                span.set_attributes(new_attributes)
                # The wrapped model may only resolve its real name at request
                # time, so refresh the span name from the recorded attribute.
                span.update_name(f'{operation} {request_model}')
                for event in events:
                    event.attributes = {
                        GEN_AI_SYSTEM_ATTRIBUTE: attributes[GEN_AI_SYSTEM_ATTRIBUTE],
                        **(event.attributes or {}),
                    }
                self._emit_events(span, events)

            yield finish

    def _emit_events(self, span: Span, events: list[Event]) -> None:
        """Emit `events` either as OTel log events or as a span attribute, per settings."""
        if self.settings.event_mode == 'logs':
            for event in events:
                self.settings.event_logger.emit(event)
        else:
            attr_name = 'events'
            span.set_attributes(
                {
                    attr_name: json.dumps([self.event_to_dict(event) for event in events]),
                    'logfire.json_schema': json.dumps(
                        {
                            'type': 'object',
                            'properties': {
                                attr_name: {'type': 'array'},
                                'model_request_parameters': {'type': 'object'},
                            },
                        }
                    ),
                }
            )

    @staticmethod
    def model_attributes(model: Model):
        """Return GenAI system/model span attributes for `model`, plus server address/port when the base URL parses."""
        attributes: dict[str, AttributeValue] = {
            GEN_AI_SYSTEM_ATTRIBUTE: model.system,
            GEN_AI_REQUEST_MODEL_ATTRIBUTE: model.model_name,
        }
        if base_url := model.base_url:
            try:
                parsed = urlparse(base_url)
            except Exception:  # pragma: no cover
                pass
            else:
                if parsed.hostname:
                    attributes['server.address'] = parsed.hostname
                if parsed.port:
                    attributes['server.port'] = parsed.port

        return attributes

    @staticmethod
    def event_to_dict(event: Event) -> dict[str, Any]:
        """Flatten an OTel event into a single JSON-serializable dict (attributes override body keys)."""
        if not event.body:
            body = {}
        elif isinstance(event.body, Mapping):
            body = event.body  # type: ignore
        else:
            body = {'body': event.body}
        return {**body, **(event.attributes or {})}

    @staticmethod
    def messages_to_otel_events(messages: list[ModelMessage]) -> list[Event]:
        """Convert model messages to OTel events, tagging each with its message index.

        Request parts lacking an `otel_event` method are skipped; all event
        bodies are made JSON-compatible via `serialize_any`.
        """
        result: list[Event] = []
        for message_index, message in enumerate(messages):
            message_events: list[Event] = []
            if isinstance(message, ModelRequest):
                for part in message.parts:
                    if hasattr(part, 'otel_event'):
                        message_events.append(part.otel_event())
            elif isinstance(message, ModelResponse):
                message_events = message.otel_events()
            for event in message_events:
                event.attributes = {
                    'gen_ai.message.index': message_index,
                    **(event.attributes or {}),
                }
            result.extend(message_events)
        for event in result:
            event.body = InstrumentedModel.serialize_any(event.body)
        return result

    @staticmethod
    def serialize_any(value: Any) -> Any:
        """Best-effort serialization of `value` to a JSON-compatible structure.

        Returns whatever `TypeAdapter.dump_python(mode='json')` produces
        (dict/list/str/number/...), falling back to `str(value)` and finally
        to an error string. (Previously annotated `-> str`, which was
        incorrect for e.g. dataclass inputs that serialize to dicts.)
        """
        try:
            return ANY_ADAPTER.dump_python(value, mode='json')
        except Exception:
            try:
                return str(value)
            except Exception as e:
                return f'Unable to serialize: {e}'