Coverage for pydantic_ai_slim/pydantic_ai/messages.py: 95.38%
210 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
1from __future__ import annotations as _annotations
3from dataclasses import dataclass, field, replace
4from datetime import datetime
5from typing import Annotated, Any, Literal, Union, cast, overload
7import pydantic
8import pydantic_core
10from ._utils import now_utc as _now_utc
11from .exceptions import UnexpectedModelBehavior
14@dataclass
15class SystemPromptPart:
16 """A system prompt, generally written by the application developer.
18 This gives the model context and guidance on how to respond.
19 """
21 content: str
22 """The content of the prompt."""
24 dynamic_ref: str | None = None
25 """The ref of the dynamic system prompt function that generated this part.
27 Only set if system prompt is dynamic, see [`system_prompt`][pydantic_ai.Agent.system_prompt] for more information.
28 """
30 part_kind: Literal['system-prompt'] = 'system-prompt'
31 """Part type identifier, this is available on all parts as a discriminator."""
@dataclass
class UserPromptPart:
    """A prompt authored by the end user of the application.

    Its content is taken from the `user_prompt` argument passed to [`Agent.run`][pydantic_ai.Agent.run],
    [`Agent.run_sync`][pydantic_ai.Agent.run_sync], or [`Agent.run_stream`][pydantic_ai.Agent.run_stream].
    """

    content: str
    """Text of the user's prompt."""

    timestamp: datetime = field(default_factory=_now_utc)
    """When the prompt was created; defaults to the value returned by `_now_utc()` at construction."""

    part_kind: Literal['user-prompt'] = 'user-prompt'
    """Discriminator identifying this part type; every part class defines one."""
# Type adapter for arbitrary tool return values. `defer_build=True` postpones building the
# validation/serialization schema until the adapter is first used.
tool_return_ta: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter(
    Any,
    config=pydantic.ConfigDict(defer_build=True),
)
@dataclass
class ToolReturnPart:
    """A tool return message, this encodes the result of running a tool."""

    tool_name: str
    """The name of the tool that was called."""

    content: Any
    """The return value."""

    tool_call_id: str | None = None
    """Optional tool call identifier, this is used by some models including OpenAI."""

    timestamp: datetime = field(default_factory=_now_utc)
    """The timestamp, when the tool returned."""

    part_kind: Literal['tool-return'] = 'tool-return'
    """Part type identifier, this is available on all parts as a discriminator."""

    def model_response_str(self) -> str:
        """Return a string representation of the content for the model.

        String content is returned unchanged; any other content is serialized to a JSON
        string via `tool_return_ta`.
        """
        if isinstance(self.content, str):
            return self.content
        else:
            return tool_return_ta.dump_json(self.content).decode()

    def model_response_object(self) -> dict[str, Any]:
        """Return a dictionary representation of the content, wrapping non-dict types appropriately.

        Dict content is converted to JSON-compatible Python values and returned directly;
        any other content is converted the same way and wrapped as `{'return_value': ...}`.
        """
        # gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
        if isinstance(self.content, dict):
            return tool_return_ta.dump_python(self.content, mode='json')  # pyright: ignore[reportUnknownMemberType]
        else:
            return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
# Type adapter used to render Pydantic validation error details as JSON (see `RetryPromptPart`).
# `defer_build=True` postpones schema construction until the adapter is first used.
error_details_ta = pydantic.TypeAdapter(
    list[pydantic_core.ErrorDetails],
    config=pydantic.ConfigDict(defer_build=True),
)
@dataclass
class RetryPromptPart:
    """A message sent back to the model asking it to make another attempt.

    Reasons a retry may be requested include:

    * a tool's arguments failed Pydantic validation — `content` is then built from a Pydantic
      [`ValidationError`][pydantic_core.ValidationError]
    * a tool raised a [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] exception
    * the model named a tool that does not exist
    * the model answered with plain text where a structured response was expected
    * a structured response failed Pydantic validation — `content` is then built from a Pydantic
      [`ValidationError`][pydantic_core.ValidationError]
    * a result validator raised a [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] exception
    """

    content: list[pydantic_core.ErrorDetails] | str
    """Explanation of why, and how, the model should retry.

    When the retry stems from a [`ValidationError`][pydantic_core.ValidationError] this holds the list of
    error details; otherwise it is a plain string.
    """

    tool_name: str | None = None
    """Name of the tool the model called, if the retry relates to a tool call."""

    tool_call_id: str | None = None
    """Optional identifier of the tool call, used by some models including OpenAI."""

    timestamp: datetime = field(default_factory=_now_utc)
    """When the retry was triggered."""

    part_kind: Literal['retry-prompt'] = 'retry-prompt'
    """Discriminator identifying this part type; every part class defines one."""

    def model_response(self) -> str:
        """Build the text sent to the model explaining why a retry is needed.

        String content is used verbatim. A list of error details is rendered as indented JSON
        (with each error's `ctx` entry omitted), prefixed by the number of errors. In both cases
        the text ends with an instruction to fix the errors and try again.
        """
        if isinstance(self.content, str):
            return f'{self.content}\n\nFix the errors and try again.'
        rendered_errors = error_details_ta.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2).decode()
        return f'{len(self.content)} validation errors: {rendered_errors}\n\nFix the errors and try again.'
ModelRequestPart = Annotated[
    Union[
        SystemPromptPart,
        UserPromptPart,
        ToolReturnPart,
        RetryPromptPart,
    ],
    pydantic.Discriminator('part_kind'),
]
"""Any part PydanticAI may send to a model; the concrete type is selected by the `part_kind` field."""
@dataclass
class ModelRequest:
    """A request generated by PydanticAI and sent to a model, e.g. a message from the PydanticAI app to the model."""

    parts: list[ModelRequestPart]
    """The parts of the request, e.g. system prompts, user prompts, tool returns and retry prompts."""

    kind: Literal['request'] = 'request'
    """Message type identifier, this is available on all messages as a discriminator."""
@dataclass
class TextPart:
    """Plain text returned by a model."""

    content: str
    """The text itself."""

    part_kind: Literal['text'] = 'text'
    """Discriminator identifying this part type; every part class defines one."""

    def has_content(self) -> bool:
        """Report whether this part carries any text (i.e. `content` is truthy)."""
        if self.content:
            return True
        return False
170@dataclass
171class ToolCallPart:
172 """A tool call from a model."""
174 tool_name: str
175 """The name of the tool to call."""
177 args: str | dict[str, Any]
178 """The arguments to pass to the tool.
180 This is stored either as a JSON string or a Python dictionary depending on how data was received.
181 """
183 tool_call_id: str | None = None
184 """Optional tool call identifier, this is used by some models including OpenAI."""
186 part_kind: Literal['tool-call'] = 'tool-call'
187 """Part type identifier, this is available on all parts as a discriminator."""
189 def args_as_dict(self) -> dict[str, Any]:
190 """Return the arguments as a Python dictionary.
192 This is just for convenience with models that require dicts as input.
193 """
194 if isinstance(self.args, dict): 194 ↛ 196line 194 didn't jump to line 196 because the condition on line 194 was always true
195 return self.args
196 args = pydantic_core.from_json(self.args)
197 assert isinstance(args, dict), 'args should be a dict'
198 return cast(dict[str, Any], args)
200 def args_as_json_str(self) -> str:
201 """Return the arguments as a JSON string.
203 This is just for convenience with models that require JSON strings as input.
204 """
205 if isinstance(self.args, str):
206 return self.args
207 return pydantic_core.to_json(self.args).decode()
209 def has_content(self) -> bool:
210 """Return `True` if the arguments contain any data."""
211 if isinstance(self.args, dict):
212 # TODO: This should probably return True if you have the value False, or 0, etc.
213 # It makes sense to me to ignore empty strings, but not sure about empty lists or dicts
214 return any(self.args.values())
215 else:
216 return bool(self.args)
ModelResponsePart = Annotated[
    Union[TextPart, ToolCallPart],
    pydantic.Discriminator('part_kind'),
]
"""Any part a model may return; the concrete type is selected by the `part_kind` field."""
@dataclass
class ModelResponse:
    """A response from a model, e.g. a message from the model to the PydanticAI app."""

    parts: list[ModelResponsePart]
    """The parts of the model message, e.g. text and tool calls."""

    model_name: str | None = None
    """The name of the model that generated the response."""

    timestamp: datetime = field(default_factory=_now_utc)
    """The timestamp of the response.

    If the model provides a timestamp in the response (as OpenAI does) that will be used.
    """

    kind: Literal['response'] = 'response'
    """Message type identifier, this is available on all messages as a discriminator."""
ModelMessage = Annotated[
    Union[ModelRequest, ModelResponse],
    pydantic.Discriminator('kind'),
]
"""A message exchanged with a model in either direction; the concrete type is selected by the `kind` field."""

ModelMessagesTypeAdapter = pydantic.TypeAdapter(
    list[ModelMessage],
    config=pydantic.ConfigDict(defer_build=True),
)
"""Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for serializing and deserializing lists of messages."""
@dataclass
class TextPartDelta:
    """An incremental update that appends text to an existing `TextPart`."""

    content_delta: str
    """Text to append to the current `TextPart` content."""

    part_delta_kind: Literal['text'] = 'text'
    """Discriminator identifying this delta type."""

    def apply(self, part: ModelResponsePart) -> TextPart:
        """Produce a copy of `part` with this delta's text appended.

        Args:
            part: The part to extend; only a `TextPart` is accepted.

        Returns:
            A new `TextPart` whose content is the original content followed by `content_delta`.

        Raises:
            ValueError: When `part` is anything other than a `TextPart`.
        """
        if isinstance(part, TextPart):
            return replace(part, content=part.content + self.content_delta)
        raise ValueError('Cannot apply TextPartDeltas to non-TextParts')
@dataclass
class ToolCallPartDelta:
    """A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID."""

    tool_name_delta: str | None = None
    """Incremental text to add to the existing tool name, if any."""

    args_delta: str | dict[str, Any] | None = None
    """Incremental data to add to the tool arguments.

    If this is a string, it will be appended to existing JSON arguments.
    If this is a dict, it will be merged with existing dict arguments.
    """

    tool_call_id: str | None = None
    """Optional tool call identifier, this is used by some models including OpenAI.

    Note this is never treated as a delta — it can replace None, but otherwise if a
    non-matching value is provided an error will be raised."""

    part_delta_kind: Literal['tool_call'] = 'tool_call'
    """Part delta type identifier, used as a discriminator."""

    def as_part(self) -> ToolCallPart | None:
        """Convert this delta to a fully formed `ToolCallPart` if possible, otherwise return `None`.

        Returns:
            A `ToolCallPart` if both `tool_name_delta` and `args_delta` are set, otherwise `None`.
        """
        if self.tool_name_delta is None or self.args_delta is None:
            return None

        return ToolCallPart(
            self.tool_name_delta,
            self.args_delta,
            self.tool_call_id,
        )

    @overload
    def apply(self, part: ModelResponsePart) -> ToolCallPart: ...

    @overload
    def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta: ...

    def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta:
        """Apply this delta to a part or delta, returning a new part or delta with the changes applied.

        Args:
            part: The existing model response part or delta to update.

        Returns:
            Either a new `ToolCallPart` or an updated `ToolCallPartDelta`.

        Raises:
            ValueError: If `part` is neither a `ToolCallPart` nor a `ToolCallPartDelta`.
            UnexpectedModelBehavior: If applying JSON deltas to dict arguments or vice versa,
                or if the tool_call_id conflicts with one already present.
        """
        if isinstance(part, ToolCallPart):
            return self._apply_to_part(part)

        if isinstance(part, ToolCallPartDelta):
            return self._apply_to_delta(part)

        raise ValueError(f'Can only apply ToolCallPartDeltas to ToolCallParts or ToolCallPartDeltas, not {part}')

    def _apply_to_delta(self, delta: ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta:
        """Internal helper to apply this delta to another delta.

        Returns a `ToolCallPart` as soon as the combined delta has both a tool name and
        arguments, otherwise the accumulated `ToolCallPartDelta`.
        """
        if self.tool_name_delta:
            # Append incremental text to the existing tool_name_delta
            updated_tool_name_delta = (delta.tool_name_delta or '') + self.tool_name_delta
            delta = replace(delta, tool_name_delta=updated_tool_name_delta)

        if isinstance(self.args_delta, str):
            # String deltas are JSON fragments: they can only be concatenated onto string (or absent) args.
            if isinstance(delta.args_delta, dict):
                raise UnexpectedModelBehavior(
                    f'Cannot apply JSON deltas to non-JSON tool arguments ({delta=}, {self=})'
                )
            updated_args_delta = (delta.args_delta or '') + self.args_delta
            delta = replace(delta, args_delta=updated_args_delta)
        elif isinstance(self.args_delta, dict):
            # Dict deltas are merged key-by-key, with this delta's keys taking precedence.
            if isinstance(delta.args_delta, str):
                raise UnexpectedModelBehavior(
                    f'Cannot apply dict deltas to non-dict tool arguments ({delta=}, {self=})'
                )
            updated_args_delta = {**(delta.args_delta or {}), **self.args_delta}
            delta = replace(delta, args_delta=updated_args_delta)

        if self.tool_call_id:
            # Set the tool_call_id if it wasn't present, otherwise error if it has changed
            if delta.tool_call_id is not None and delta.tool_call_id != self.tool_call_id:
                raise UnexpectedModelBehavior(
                    f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({delta=}, {self=})'
                )
            delta = replace(delta, tool_call_id=self.tool_call_id)

        # If we now have enough data to create a full ToolCallPart, do so
        if delta.tool_name_delta is not None and delta.args_delta is not None:
            return ToolCallPart(
                delta.tool_name_delta,
                delta.args_delta,
                delta.tool_call_id,
            )

        return delta

    def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
        """Internal helper to apply this delta directly to a `ToolCallPart`."""
        if self.tool_name_delta:
            # Append incremental text to the existing tool_name
            tool_name = part.tool_name + self.tool_name_delta
            part = replace(part, tool_name=tool_name)

        if isinstance(self.args_delta, str):
            if not isinstance(part.args, str):
                raise UnexpectedModelBehavior(f'Cannot apply JSON deltas to non-JSON tool arguments ({part=}, {self=})')
            updated_json = part.args + self.args_delta
            part = replace(part, args=updated_json)
        elif isinstance(self.args_delta, dict):
            if not isinstance(part.args, dict):
                raise UnexpectedModelBehavior(f'Cannot apply dict deltas to non-dict tool arguments ({part=}, {self=})')
            # part.args is known to be a dict here; this delta's keys take precedence.
            updated_dict = {**part.args, **self.args_delta}
            part = replace(part, args=updated_dict)

        if self.tool_call_id:
            # tool_call_id is never accumulated: it may fill in a missing id, but a conflicting id is an error.
            if part.tool_call_id is not None and part.tool_call_id != self.tool_call_id:
                raise UnexpectedModelBehavior(
                    f'Cannot apply a new tool_call_id to a ToolCallPart that already has one ({part=}, {self=})'
                )
            part = replace(part, tool_call_id=self.tool_call_id)
        return part
ModelResponsePartDelta = Annotated[
    Union[TextPartDelta, ToolCallPartDelta],
    pydantic.Discriminator('part_delta_kind'),
]
"""An incremental update to any model response part; the concrete type is selected by `part_delta_kind`."""
@dataclass
class PartStartEvent:
    """Signals that a new part has begun in a model response stream.

    When several `PartStartEvent`s arrive for the same index, the most recent one
    supersedes the earlier ones completely.
    """

    index: int
    """Position of the part within the response's list of parts."""

    part: ModelResponsePart
    """The `ModelResponsePart` that has just started."""

    event_kind: Literal['part_start'] = 'part_start'
    """Discriminator value identifying this event type."""
@dataclass
class PartDeltaEvent:
    """Signals an incremental update to a part that has already started."""

    index: int
    """Position of the part within the response's list of parts."""

    delta: ModelResponsePartDelta
    """The update to apply to the part at `index`."""

    event_kind: Literal['part_delta'] = 'part_delta'
    """Discriminator value identifying this event type."""
ModelResponseStreamEvent = Annotated[
    Union[PartStartEvent, PartDeltaEvent],
    pydantic.Discriminator('event_kind'),
]
"""A single event in a streamed model response: either the start of a new part or a delta for an existing one."""