Coverage for pydantic_ai_slim/pydantic_ai/messages.py: 95.38%

210 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-01-25 16:43 +0000

1from __future__ import annotations as _annotations 

2 

3from dataclasses import dataclass, field, replace 

4from datetime import datetime 

5from typing import Annotated, Any, Literal, Union, cast, overload 

6 

7import pydantic 

8import pydantic_core 

9 

10from ._utils import now_utc as _now_utc 

11from .exceptions import UnexpectedModelBehavior 

12 

13 

@dataclass
class SystemPromptPart:
    """Instructions supplied by the application developer rather than the end user.

    A system prompt sets the context for the model and steers how it should answer.
    """

    content: str
    """Text of the system prompt."""

    dynamic_ref: str | None = None
    """Reference to the dynamic system prompt function that produced this part.

    Left as `None` unless the system prompt is dynamic; see
    [`system_prompt`][pydantic_ai.Agent.system_prompt] for details.
    """

    part_kind: Literal['system-prompt'] = 'system-prompt'
    """Discriminator identifying the part type; every part carries one."""

32 

33 

@dataclass
class UserPromptPart:
    """Prompt text written by the end user.

    Populated from the `user_prompt` argument passed to [`Agent.run`][pydantic_ai.Agent.run],
    [`Agent.run_sync`][pydantic_ai.Agent.run_sync], or [`Agent.run_stream`][pydantic_ai.Agent.run_stream].
    """

    content: str
    """Text of the user's prompt."""

    timestamp: datetime = field(default_factory=_now_utc)
    """When the prompt was created; defaults to the result of `_now_utc()` at construction."""

    part_kind: Literal['user-prompt'] = 'user-prompt'
    """Discriminator identifying the part type; every part carries one."""

50 

51 

# Shared adapter used to serialize arbitrary tool return values;
# `defer_build=True` postpones building the core schema until the adapter is first used.
tool_return_ta: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter(
    Any,
    config=pydantic.ConfigDict(defer_build=True),
)

53 

54 

@dataclass
class ToolReturnPart:
    """A tool return message, this encodes the result of running a tool."""

    tool_name: str
    """The name of the tool that was called."""

    content: Any
    """The return value."""

    tool_call_id: str | None = None
    """Optional tool call identifier, this is used by some models including OpenAI."""

    timestamp: datetime = field(default_factory=_now_utc)
    """When the tool returned."""

    part_kind: Literal['tool-return'] = 'tool-return'
    """Part type identifier, this is available on all parts as a discriminator."""

    def model_response_str(self) -> str:
        """Return a string representation of the content for the model.

        String content is returned unchanged; any other value is serialized to JSON
        with `tool_return_ta` and decoded to `str`.
        """
        if isinstance(self.content, str):
            return self.content
        return tool_return_ta.dump_json(self.content).decode()

    def model_response_object(self) -> dict[str, Any]:
        """Return a dictionary representation of the content, wrapping non-dict types appropriately.

        Dict content is dumped in JSON mode and returned as-is; any other value is dumped
        in JSON mode and wrapped as `{'return_value': <dumped value>}`.
        """
        # gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
        if isinstance(self.content, dict):
            return tool_return_ta.dump_python(self.content, mode='json')  # pyright: ignore[reportUnknownMemberType]
        return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}

88 

89 

# Adapter used to render lists of validation error details as JSON;
# `defer_build=True` postpones building the core schema until first use.
error_details_ta = pydantic.TypeAdapter(
    list[pydantic_core.ErrorDetails],
    config=pydantic.ConfigDict(defer_build=True),
)

91 

92 

@dataclass
class RetryPromptPart:
    """A message sent back to a model asking it to try again.

    Reasons this may be sent include:

    * Pydantic validation of tool arguments failed; `content` is then derived from a Pydantic
      [`ValidationError`][pydantic_core.ValidationError]
    * a tool raised a [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] exception
    * no tool was found for the tool name
    * the model returned plain text when a structured response was expected
    * Pydantic validation of a structured response failed; `content` is then derived from a Pydantic
      [`ValidationError`][pydantic_core.ValidationError]
    * a result validator raised a [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] exception
    """

    content: list[pydantic_core.ErrorDetails] | str
    """Explanation of why and how the model should retry.

    Either a plain string or, when a [`ValidationError`][pydantic_core.ValidationError]
    triggered the retry, a list of error details.
    """

    tool_name: str | None = None
    """Name of the tool that was called, if there was one."""

    tool_call_id: str | None = None
    """Optional tool call identifier, this is used by some models including OpenAI."""

    timestamp: datetime = field(default_factory=_now_utc)
    """When the retry was triggered."""

    part_kind: Literal['retry-prompt'] = 'retry-prompt'
    """Discriminator identifying the part type; every part carries one."""

    def model_response(self) -> str:
        """Build the retry message shown to the model.

        String content is used verbatim. A list of error details is rendered as indented
        JSON (with each error's `ctx` key omitted) and prefixed by the error count.
        In both cases the instruction to fix the errors is appended after a blank line.
        """
        if isinstance(self.content, str):
            return f'{self.content}\n\nFix the errors and try again.'
        # Drop `ctx` from every entry before serializing.
        errors_json = error_details_ta.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
        summary = f'{len(self.content)} validation errors: {errors_json.decode()}'
        return f'{summary}\n\nFix the errors and try again.'

136 

137 

ModelRequestPart = Annotated[
    Union[
        SystemPromptPart,
        UserPromptPart,
        ToolReturnPart,
        RetryPromptPart,
    ],
    pydantic.Discriminator('part_kind'),
]
"""A message part sent by PydanticAI to a model; the concrete type is selected by `part_kind`."""

142 

143 

@dataclass
class ModelRequest:
    """A request built by PydanticAI and sent to a model, i.e. a message from the PydanticAI app to the model."""

    parts: list[ModelRequestPart]
    """The parts making up this request."""

    kind: Literal['request'] = 'request'
    """Message type identifier, used as the discriminator between message kinds."""

153 

154 

@dataclass
class TextPart:
    """Plain text returned by a model."""

    content: str
    """The text produced by the model."""

    part_kind: Literal['text'] = 'text'
    """Discriminator identifying the part type; every part carries one."""

    def has_content(self) -> bool:
        """Report whether this part carries any text (i.e. `content` is truthy)."""
        non_empty = bool(self.content)
        return non_empty

168 

169 

@dataclass
class ToolCallPart:
    """A tool call from a model."""

    tool_name: str
    """The name of the tool to call."""

    args: str | dict[str, Any]
    """The arguments to pass to the tool.

    This is stored either as a JSON string or a Python dictionary depending on how data was received.
    An empty string is treated as "no arguments".
    """

    tool_call_id: str | None = None
    """Optional tool call identifier, this is used by some models including OpenAI."""

    part_kind: Literal['tool-call'] = 'tool-call'
    """Part type identifier, this is available on all parts as a discriminator."""

    def args_as_dict(self) -> dict[str, Any]:
        """Return the arguments as a Python dictionary.

        This is just for convenience with models that require dicts as input.

        Returns:
            `args` itself when it is already a dict, `{}` when `args` is an empty string,
            otherwise the JSON object decoded from the `args` string.

        Raises:
            Whatever `pydantic_core.from_json` raises if `args` is a non-empty string that is
            not valid JSON.
        """
        if isinstance(self.args, dict):
            return self.args
        if not self.args:
            # An empty string cannot be parsed as JSON; treat it as "no arguments",
            # consistent with `has_content()` returning False for it.
            return {}
        args = pydantic_core.from_json(self.args)
        assert isinstance(args, dict), 'args should be a dict'
        return cast(dict[str, Any], args)

    def args_as_json_str(self) -> str:
        """Return the arguments as a JSON string.

        This is just for convenience with models that require JSON strings as input.

        Returns:
            `args` unchanged when it is a non-empty string, `'{}'` when it is an empty
            string, otherwise the JSON encoding of the `args` dict.
        """
        if isinstance(self.args, str):
            # An empty string is not valid JSON; report it as an empty JSON object instead.
            return self.args or '{}'
        return pydantic_core.to_json(self.args).decode()

    def has_content(self) -> bool:
        """Return `True` if the arguments contain any data.

        For dict arguments this is `True` only if at least one value is truthy; for string
        arguments it is `True` when the string is non-empty.
        """
        if isinstance(self.args, dict):
            # TODO: This should probably return True if you have the value False, or 0, etc.
            # It makes sense to me to ignore empty strings, but not sure about empty lists or dicts
            return any(self.args.values())
        else:
            return bool(self.args)

217 

218 

ModelResponsePart = Annotated[
    Union[TextPart, ToolCallPart],
    pydantic.Discriminator('part_kind'),
]
"""A message part returned by a model; the concrete type is selected by `part_kind`."""

221 

222 

@dataclass
class ModelResponse:
    """A response produced by a model, i.e. a message from the model back to the PydanticAI app."""

    parts: list[ModelResponsePart]
    """The parts making up the model's message."""

    model_name: str | None = None
    """Name of the model that produced the response, if known."""

    timestamp: datetime = field(default_factory=_now_utc)
    """When the response was produced.

    If the model supplies its own timestamp in the response (as OpenAI does), that value is used.
    """

    kind: Literal['response'] = 'response'
    """Message type identifier, used as the discriminator between message kinds."""

241 

242 

ModelMessage = Annotated[
    Union[ModelRequest, ModelResponse],
    pydantic.Discriminator('kind'),
]
"""Any message sent to or returned by a model, discriminated on `kind`."""

ModelMessagesTypeAdapter = pydantic.TypeAdapter(
    list[ModelMessage],
    config=pydantic.ConfigDict(defer_build=True),
)
"""Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages."""

248 

249 

@dataclass
class TextPartDelta:
    """An incremental update that appends text to a `TextPart`."""

    content_delta: str
    """Text to append to the existing `TextPart` content."""

    part_delta_kind: Literal['text'] = 'text'
    """Discriminator identifying the delta type."""

    def apply(self, part: ModelResponsePart) -> TextPart:
        """Append this delta's text to `part`, producing a new `TextPart`.

        Args:
            part: The existing model response part; it must be a `TextPart`.

        Returns:
            A copy of `part` (made with `dataclasses.replace`) whose `content` has
            `content_delta` appended; `part` itself is left untouched.

        Raises:
            ValueError: If `part` is not a `TextPart`.
        """
        if isinstance(part, TextPart):
            combined = part.content + self.content_delta
            return replace(part, content=combined)
        raise ValueError('Cannot apply TextPartDeltas to non-TextParts')

275 

276 

@dataclass
class ToolCallPartDelta:
    """A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID."""

    tool_name_delta: str | None = None
    """Incremental text to add to the existing tool name, if any."""

    args_delta: str | dict[str, Any] | None = None
    """Incremental data to add to the tool arguments.

    If this is a string, it will be appended to existing JSON arguments.
    If this is a dict, it will be merged with existing dict arguments.
    """

    tool_call_id: str | None = None
    """Optional tool call identifier, this is used by some models including OpenAI.

    Note this is never treated as a delta — it can replace None, but otherwise if a
    non-matching value is provided an error will be raised."""

    part_delta_kind: Literal['tool_call'] = 'tool_call'
    """Part delta type identifier, used as a discriminator."""

    def as_part(self) -> ToolCallPart | None:
        """Convert this delta to a fully formed `ToolCallPart` if possible, otherwise return `None`.

        Returns:
            A `ToolCallPart` if both `tool_name_delta` and `args_delta` are set, otherwise `None`.
        """
        if self.tool_name_delta is None or self.args_delta is None:
            return None

        return ToolCallPart(
            self.tool_name_delta,
            self.args_delta,
            self.tool_call_id,
        )

    @overload
    def apply(self, part: ModelResponsePart) -> ToolCallPart: ...

    @overload
    def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta: ...

    def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta:
        """Apply this delta to a part or delta, returning a new part or delta with the changes applied.

        Args:
            part: The existing model response part or delta to update.

        Returns:
            Either a new `ToolCallPart` or an updated `ToolCallPartDelta`.

        Raises:
            ValueError: If `part` is neither a `ToolCallPart` nor a `ToolCallPartDelta`.
            UnexpectedModelBehavior: If applying JSON deltas to dict arguments or vice versa,
                or if the tool_call_id conflicts with one already present.
        """
        if isinstance(part, ToolCallPart):
            return self._apply_to_part(part)

        if isinstance(part, ToolCallPartDelta):
            return self._apply_to_delta(part)

        raise ValueError(f'Can only apply ToolCallPartDeltas to ToolCallParts or ToolCallPartDeltas, not {part}')

    def _apply_to_delta(self, delta: ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta:
        """Internal helper to apply this delta to another delta.

        Returns a `ToolCallPart` once both the tool name and arguments are known,
        otherwise the accumulated `ToolCallPartDelta`.
        """
        if self.tool_name_delta:
            # Append incremental text to the existing tool_name_delta
            updated_tool_name_delta = (delta.tool_name_delta or '') + self.tool_name_delta
            delta = replace(delta, tool_name_delta=updated_tool_name_delta)

        if isinstance(self.args_delta, str):
            if isinstance(delta.args_delta, dict):
                raise UnexpectedModelBehavior(
                    f'Cannot apply JSON deltas to non-JSON tool arguments ({delta=}, {self=})'
                )
            updated_args_delta = (delta.args_delta or '') + self.args_delta
            delta = replace(delta, args_delta=updated_args_delta)
        elif isinstance(self.args_delta, dict):
            if isinstance(delta.args_delta, str):
                raise UnexpectedModelBehavior(
                    f'Cannot apply dict deltas to non-dict tool arguments ({delta=}, {self=})'
                )
            updated_args_delta = {**(delta.args_delta or {}), **self.args_delta}
            delta = replace(delta, args_delta=updated_args_delta)

        if self.tool_call_id:
            # Set the tool_call_id if it wasn't present, otherwise error if it has changed
            if delta.tool_call_id is not None and delta.tool_call_id != self.tool_call_id:
                raise UnexpectedModelBehavior(
                    f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({delta=}, {self=})'
                )
            delta = replace(delta, tool_call_id=self.tool_call_id)

        # If we now have enough data to create a full ToolCallPart, do so
        if delta.tool_name_delta is not None and delta.args_delta is not None:
            return ToolCallPart(
                delta.tool_name_delta,
                delta.args_delta,
                delta.tool_call_id,
            )

        return delta

    def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
        """Internal helper to apply this delta directly to a `ToolCallPart`."""
        if self.tool_name_delta:
            # Append incremental text to the existing tool_name
            tool_name = part.tool_name + self.tool_name_delta
            part = replace(part, tool_name=tool_name)

        if isinstance(self.args_delta, str):
            if not isinstance(part.args, str):
                raise UnexpectedModelBehavior(f'Cannot apply JSON deltas to non-JSON tool arguments ({part=}, {self=})')
            updated_json = part.args + self.args_delta
            part = replace(part, args=updated_json)
        elif isinstance(self.args_delta, dict):
            if not isinstance(part.args, dict):
                raise UnexpectedModelBehavior(f'Cannot apply dict deltas to non-dict tool arguments ({part=}, {self=})')
            # part.args is guaranteed to be a dict here, so it can be merged directly.
            updated_dict = {**part.args, **self.args_delta}
            part = replace(part, args=updated_dict)

        if self.tool_call_id:
            # The tool_call_id is never treated as a delta: it may fill in a missing id,
            # but a conflicting id is an error.
            if part.tool_call_id is not None and part.tool_call_id != self.tool_call_id:
                raise UnexpectedModelBehavior(
                    f'Cannot apply a new tool_call_id to a ToolCallPart that already has one ({part=}, {self=})'
                )
            part = replace(part, tool_call_id=self.tool_call_id)
        return part

408 

409 

ModelResponsePartDelta = Annotated[
    Union[TextPartDelta, ToolCallPartDelta],
    pydantic.Discriminator('part_delta_kind'),
]
"""A partial update (delta) for any model response part, discriminated on `part_delta_kind`."""

412 

413 

@dataclass
class PartStartEvent:
    """Signals that a new response part has begun.

    When several `PartStartEvent`s arrive with the same index, the latest one
    completely supersedes the earlier ones.
    """

    index: int
    """Position of the part within the response's overall list of parts."""

    part: ModelResponsePart
    """The `ModelResponsePart` that has just started."""

    event_kind: Literal['part_start'] = 'part_start'
    """Discriminator identifying the event type."""

430 

431 

@dataclass
class PartDeltaEvent:
    """Signals an incremental update to a response part that has already started."""

    index: int
    """Position of the part within the response's overall list of parts."""

    delta: ModelResponsePartDelta
    """The delta to apply to the part at `index`."""

    event_kind: Literal['part_delta'] = 'part_delta'
    """Discriminator identifying the event type."""

444 

445 

ModelResponseStreamEvent = Annotated[
    Union[PartStartEvent, PartDeltaEvent],
    pydantic.Discriminator('event_kind'),
]
"""An event in the model response stream: either the start of a new part or a delta for an existing one."""