Coverage for pydantic_ai_slim/pydantic_ai/_parts_manager.py: 99.21%
88 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
"""This module provides functionality to manage and update parts of a model's streamed response.

The manager tracks which parts (in particular, text and tool calls) correspond to which
vendor-specific identifiers (e.g., `index`, `tool_call_id`, etc., as appropriate for a given model),
and produces PydanticAI-format events as appropriate for consumers of the streaming APIs.

The "vendor-specific identifiers" to use depend on the semantics of the responses from the vendor,
and are tightly coupled to the specific model being used and to the PydanticAI Model subclass implementation.

This `ModelResponsePartsManager` is used in each of the subclasses of `StreamedResponse` as a way to consolidate
event-emitting logic.
"""
14from __future__ import annotations as _annotations
16from collections.abc import Hashable
17from dataclasses import dataclass, field
18from typing import Any, Union
20from pydantic_ai.exceptions import UnexpectedModelBehavior
21from pydantic_ai.messages import (
22 ModelResponsePart,
23 ModelResponseStreamEvent,
24 PartDeltaEvent,
25 PartStartEvent,
26 TextPart,
27 TextPartDelta,
28 ToolCallPart,
29 ToolCallPartDelta,
30)
32from ._utils import generate_tool_call_id as _generate_tool_call_id
VendorId = Hashable
"""
Type alias for a vendor identifier, which can be any hashable type (e.g., a string, UUID, etc.).
"""

ManagedPart = Union[ModelResponsePart, ToolCallPartDelta]
"""
A union of the types that are managed by the `ModelResponsePartsManager`.

Because many vendors have streaming APIs that may produce not-fully-formed tool calls,
this includes `ToolCallPartDelta` objects in addition to the more fully-formed `ModelResponsePart` objects.
"""
@dataclass
class ModelResponsePartsManager:
    """Manages a sequence of parts that make up a model's streamed response.

    Parts are generally added and/or updated by providing deltas, which are tracked by vendor-specific IDs.
    Tool calls that are not yet fully formed are stored as `ToolCallPartDelta`s; `get_parts` excludes them.
    """

    _parts: list[ManagedPart] = field(default_factory=list, init=False)
    """A list of parts (text or tool calls) that make up the current state of the model's response."""
    _vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict, init=False)
    """Maps a vendor's "part" ID (if provided) to the index in `_parts` where that part resides."""

    def get_parts(self) -> list[ModelResponsePart]:
        """Return only model response parts that are complete (i.e., not `ToolCallPartDelta` objects).

        Returns:
            A list of ModelResponsePart objects, in the order they were first added.
            ToolCallPartDelta objects are excluded.
        """
        return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)]

    def handle_text_delta(
        self,
        *,
        vendor_part_id: VendorId | None,
        content: str,
    ) -> ModelResponseStreamEvent:
        """Handle incoming text content, creating or updating a TextPart in the manager as appropriate.

        When `vendor_part_id` is None, the latest part is updated if it exists and is a TextPart;
        otherwise, a new TextPart is created. When a non-None ID is specified, the TextPart corresponding
        to that vendor ID is either created (and registered under that ID) or updated.

        Args:
            vendor_part_id: The ID the vendor uses to identify this piece
                of text. If None, a new part will be created unless the latest part is already
                a TextPart.
            content: The text content to append to the appropriate TextPart.

        Returns:
            A `PartStartEvent` if a new part was created, or a `PartDeltaEvent` if an existing part was updated.

        Raises:
            UnexpectedModelBehavior: If `vendor_part_id` is registered to a part that is not a TextPart.
        """
        existing_text_part_and_index: tuple[TextPart, int] | None = None

        if vendor_part_id is None:
            # Without a vendor ID, only the most recent part is a candidate for continuation.
            if self._parts:
                part_index = len(self._parts) - 1
                latest_part = self._parts[part_index]
                if isinstance(latest_part, TextPart):
                    existing_text_part_and_index = latest_part, part_index
        else:
            # Otherwise, attempt to look up an existing TextPart by vendor_part_id.
            part_index = self._vendor_id_to_part_index.get(vendor_part_id)
            if part_index is not None:
                existing_part = self._parts[part_index]
                if not isinstance(existing_part, TextPart):
                    raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')
                existing_text_part_and_index = existing_part, part_index

        if existing_text_part_and_index is None:
            # There is no existing text part that should be updated, so create a new one.
            new_part_index = len(self._parts)
            part = TextPart(content=content)
            if vendor_part_id is not None:
                self._vendor_id_to_part_index[vendor_part_id] = new_part_index
            self._parts.append(part)
            return PartStartEvent(index=new_part_index, part=part)

        # Update the existing TextPart with the new content delta.
        existing_text_part, part_index = existing_text_part_and_index
        part_delta = TextPartDelta(content_delta=content)
        self._parts[part_index] = part_delta.apply(existing_text_part)
        return PartDeltaEvent(index=part_index, delta=part_delta)

    def handle_tool_call_delta(
        self,
        *,
        vendor_part_id: VendorId | None,
        tool_name: str | None,
        args: str | dict[str, Any] | None,
        tool_call_id: str | None,
    ) -> ModelResponseStreamEvent | None:
        """Handle or update a tool call, creating or updating a `ToolCallPart` or `ToolCallPartDelta`.

        Managed items remain as `ToolCallPartDelta`s until they have both a tool_name and arguments, at which
        point they are upgraded to `ToolCallPart`s.

        Matching rules:

        - `vendor_part_id` is None and `tool_name` is None: the delta is applied to the latest part if it is a
          `ToolCallPart` or `ToolCallPartDelta`; otherwise a new part (or delta) is appended.
        - `vendor_part_id` is None and `tool_name` is not None: a new part (or delta) is always appended, on the
          assumption that a named chunk starts a new tool call.
        - `vendor_part_id` is not None: the part registered under that ID is updated; if none is registered,
          a new part (or delta) is appended and registered under that ID.

        Args:
            vendor_part_id: The ID the vendor uses for this tool call. See the matching rules above for None.
            tool_name: The name (or name delta) of the tool, or None if this chunk carries no name.
            args: Arguments (or an arguments delta) for the tool call, either as a string or a dictionary
                of key-value pairs.
            tool_call_id: An optional string representing an identifier for this tool call.

        Returns:
            - A `PartStartEvent` if a fully formed `ToolCallPart` was created, or if an existing
              `ToolCallPartDelta` was upgraded to a `ToolCallPart`.
            - A `PartDeltaEvent` if an existing `ToolCallPart` was updated.
            - `None` if no event is emitted (the part is still an incomplete `ToolCallPartDelta`).

        Raises:
            UnexpectedModelBehavior: If `vendor_part_id` is registered to a part that is not a
                `ToolCallPart` or `ToolCallPartDelta`.
        """
        existing_matching_part_and_index: tuple[ToolCallPartDelta | ToolCallPart, int] | None = None

        if vendor_part_id is None:
            # When the vendor_part_id is None, if the tool_name is _not_ None, assume this should be a new part
            # rather than a delta on an existing one. We can change this behavior in the future if necessary
            # for some model.
            if tool_name is None and self._parts:
                part_index = len(self._parts) - 1
                latest_part = self._parts[part_index]
                # If the latest part is not a tool call (e.g. a TextPart), nothing matches and a new
                # part/delta is appended below.
                if isinstance(latest_part, (ToolCallPart, ToolCallPartDelta)):
                    existing_matching_part_and_index = latest_part, part_index
        else:
            # vendor_part_id is provided, so look up the corresponding part or delta.
            part_index = self._vendor_id_to_part_index.get(vendor_part_id)
            if part_index is not None:
                existing_part = self._parts[part_index]
                if not isinstance(existing_part, (ToolCallPartDelta, ToolCallPart)):
                    raise UnexpectedModelBehavior(f'Cannot apply a tool call delta to {existing_part=}')
                existing_matching_part_and_index = existing_part, part_index

        # Built after the lookup (and its possible raise), exactly as both original branches did.
        delta = ToolCallPartDelta(tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id)

        if existing_matching_part_and_index is None:
            # No matching part/delta was found: store a full ToolCallPart if the delta already carries enough
            # information, otherwise store the delta itself.
            part = delta.as_part() or delta
            new_part_index = len(self._parts)
            if vendor_part_id is not None:
                self._vendor_id_to_part_index[vendor_part_id] = new_part_index
            self._parts.append(part)
            # Only emit a PartStartEvent if we have enough information to produce a full ToolCallPart.
            if isinstance(part, ToolCallPart):
                return PartStartEvent(index=new_part_index, part=part)
            return None

        # Update the existing part or delta with the new information.
        existing_part, part_index = existing_matching_part_and_index
        updated_part = delta.apply(existing_part)
        self._parts[part_index] = updated_part
        if isinstance(updated_part, ToolCallPart):
            if isinstance(existing_part, ToolCallPartDelta):
                # We just upgraded a delta to a full part, so this is the part's first visible event.
                return PartStartEvent(index=part_index, part=updated_part)
            # We updated an already-started part, so emit a PartDeltaEvent.
            return PartDeltaEvent(index=part_index, delta=delta)
        # Still an incomplete delta: nothing to report to consumers yet.
        return None

    def handle_tool_call_part(
        self,
        *,
        vendor_part_id: VendorId | None,
        tool_name: str,
        args: str | dict[str, Any],
        tool_call_id: str | None = None,
    ) -> ModelResponseStreamEvent:
        """Immediately create or fully-overwrite a ToolCallPart with the given information.

        This does not apply a delta; it directly sets the tool call part contents. If `vendor_part_id` is
        already registered, the part at that index is replaced whatever its current type is.

        Args:
            vendor_part_id: The vendor's ID for this tool call part. If None, the part is always appended.
                If not None and an existing part is found, that part is overwritten; otherwise the new part
                is appended and registered under this ID.
            tool_name: The name of the tool being invoked.
            args: The arguments for the tool call, either as a string or a dictionary.
            tool_call_id: An optional string identifier for this tool call. If falsy (None or empty),
                an ID is generated with `_generate_tool_call_id`.

        Returns:
            ModelResponseStreamEvent: A `PartStartEvent` indicating that a new tool call part
            has been added to the manager, or replaced an existing part.
        """
        new_part = ToolCallPart(
            tool_name=tool_name,
            args=args,
            tool_call_id=tool_call_id or _generate_tool_call_id(),
        )
        if vendor_part_id is None:
            # vendor_part_id is None, so we unconditionally append a new ToolCallPart to the end of the list.
            new_part_index = len(self._parts)
            self._parts.append(new_part)
        else:
            # vendor_part_id is provided, so find and overwrite or create a new ToolCallPart.
            maybe_part_index = self._vendor_id_to_part_index.get(vendor_part_id)
            if maybe_part_index is not None:
                new_part_index = maybe_part_index
                self._parts[new_part_index] = new_part
            else:
                new_part_index = len(self._parts)
                self._parts.append(new_part)
                self._vendor_id_to_part_index[vendor_part_id] = new_part_index
        return PartStartEvent(index=new_part_index, part=new_part)