Coverage for pydantic_ai_slim/pydantic_ai/_parts_manager.py: 99.20%
87 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
"""This module provides functionality to manage and update parts of a model's streamed response.

The manager tracks which parts (in particular, text and tool calls) correspond to which
vendor-specific identifiers (e.g., `index`, `tool_call_id`, etc., as appropriate for a given model),
and produces PydanticAI-format events as appropriate for consumers of the streaming APIs.

The "vendor-specific identifiers" to use depend on the semantics of the responses from the vendor,
and are tightly coupled to the specific model being used and the PydanticAI Model subclass implementation.

This `ModelResponsePartsManager` is used in each of the subclasses of `StreamedResponse` as a way to
consolidate event-emitting logic.
"""
14from __future__ import annotations as _annotations
16from collections.abc import Hashable
17from dataclasses import dataclass, field
18from typing import Any, Union
20from pydantic_ai.exceptions import UnexpectedModelBehavior
21from pydantic_ai.messages import (
22 ModelResponsePart,
23 ModelResponseStreamEvent,
24 PartDeltaEvent,
25 PartStartEvent,
26 TextPart,
27 TextPartDelta,
28 ToolCallPart,
29 ToolCallPartDelta,
30)
# Any hashable value can serve as a vendor's part identifier.
VendorId = Hashable
"""
Type alias for a vendor identifier, which can be any hashable type (e.g., a string, UUID, etc.).
"""

# Incomplete tool calls are stored as deltas until they can be turned into full parts.
ManagedPart = Union[ModelResponsePart, ToolCallPartDelta]
"""
A union of types that are managed by the ModelResponsePartsManager.

Because many vendors have streaming APIs that may produce not-fully-formed tool calls,
this includes `ToolCallPartDelta` objects in addition to the more fully-formed `ModelResponsePart` objects.
"""
@dataclass
class ModelResponsePartsManager:
    """Manages a sequence of parts that make up a model's streamed response.

    Parts are generally added and/or updated by providing deltas, which are tracked by vendor-specific IDs.
    Each `handle_*` method updates the managed parts in place and returns the event (if any) that should
    be emitted to consumers of the streaming API as a result of that update.
    """

    _parts: list[ManagedPart] = field(default_factory=list, init=False)
    """A list of parts (text or tool calls) that make up the current state of the model's response."""
    _vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict, init=False)
    """Maps a vendor's "part" ID (if provided) to the index in `_parts` where that part resides."""

    def get_parts(self) -> list[ModelResponsePart]:
        """Return only model response parts that are complete (i.e., not `ToolCallPartDelta` objects).

        Returns:
            A list of ModelResponsePart objects, in order. ToolCallPartDelta objects are excluded.
        """
        return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)]

    def handle_text_delta(
        self,
        *,
        vendor_part_id: VendorId | None,
        content: str,
    ) -> ModelResponseStreamEvent:
        """Handle incoming text content, creating or updating a `TextPart` in the manager as appropriate.

        When `vendor_part_id` is None, the latest part is updated if it exists and is a `TextPart`;
        otherwise, a new `TextPart` is created. When a non-None ID is specified, the `TextPart`
        corresponding to that vendor ID is either created (and registered under that ID) or updated.

        Args:
            vendor_part_id: The ID the vendor uses to identify this piece of text. If None, a new
                part will be created unless the latest part is already a `TextPart`.
            content: The text content to append to the appropriate `TextPart`.

        Returns:
            A `PartStartEvent` if a new part was created, or a `PartDeltaEvent` if an existing part was updated.

        Raises:
            UnexpectedModelBehavior: If `vendor_part_id` is already registered to a part that is
                not a `TextPart`.
        """
        existing_text_part_and_index: tuple[TextPart, int] | None = None

        if vendor_part_id is None:
            # Without a vendor ID, only the most recent part is a candidate for continuation.
            if self._parts:
                part_index = len(self._parts) - 1
                latest_part = self._parts[part_index]
                if isinstance(latest_part, TextPart):
                    existing_text_part_and_index = latest_part, part_index
        else:
            # With a vendor ID, continue the part previously registered under that ID (if any).
            part_index = self._vendor_id_to_part_index.get(vendor_part_id)
            if part_index is not None:
                existing_part = self._parts[part_index]
                if not isinstance(existing_part, TextPart):
                    raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')
                existing_text_part_and_index = existing_part, part_index

        if existing_text_part_and_index is None:
            # There is no existing text part that should be updated, so start a new one at the end.
            new_part_index = len(self._parts)
            part = TextPart(content=content)
            if vendor_part_id is not None:
                self._vendor_id_to_part_index[vendor_part_id] = new_part_index
            self._parts.append(part)
            return PartStartEvent(index=new_part_index, part=part)

        # Apply the content as a delta to the existing TextPart and store the resulting part.
        existing_text_part, part_index = existing_text_part_and_index
        part_delta = TextPartDelta(content_delta=content)
        self._parts[part_index] = part_delta.apply(existing_text_part)
        return PartDeltaEvent(index=part_index, delta=part_delta)

    def handle_tool_call_delta(
        self,
        *,
        vendor_part_id: VendorId | None,
        tool_name: str | None,
        args: str | dict[str, Any] | None,
        tool_call_id: str | None,
    ) -> ModelResponseStreamEvent | None:
        """Handle or update a tool call, creating or updating a `ToolCallPart` or `ToolCallPartDelta`.

        Managed items remain as `ToolCallPartDelta`s until they have both a tool_name and arguments, at which
        point they are upgraded to `ToolCallPart`s.

        Which part (if any) gets updated:

        - `vendor_part_id` is None and `tool_name` is None: the latest part is updated if it is a
          `ToolCallPart` or `ToolCallPartDelta`; otherwise a new part (or delta) is created.
        - `vendor_part_id` is None and `tool_name` is not None: a new part (or delta) is always created,
          since a tool name is taken to mark the start of a new tool call.
        - `vendor_part_id` is provided: the part registered under that ID is updated; if none is
          registered, a new part (or delta) is created and registered under that ID.

        Args:
            vendor_part_id: The ID the vendor uses for this tool call, or None.
            tool_name: The name of the tool, or None. When `vendor_part_id` is None, a non-None
                value forces a new part to be created rather than updating the latest one.
            args: Arguments for the tool call, either as a string or a dictionary of key-value pairs.
            tool_call_id: An optional string representing an identifier for this tool call.

        Returns:
            - A `PartStartEvent` if a new, fully realized `ToolCallPart` is created, or if an existing
              `ToolCallPartDelta` has just been upgraded to a `ToolCallPart`.
            - A `PartDeltaEvent` if an existing `ToolCallPart` is updated.
            - `None` if no event is emitted (the managed item is still an incomplete `ToolCallPartDelta`).

        Raises:
            UnexpectedModelBehavior: If `vendor_part_id` is already registered to a part that is not
                a `ToolCallPart` or `ToolCallPartDelta`.
        """
        existing_matching_part_and_index: tuple[ToolCallPartDelta | ToolCallPart, int] | None = None

        if vendor_part_id is None:
            # When the vendor_part_id is None, if the tool_name is _not_ None, assume this should be a new part
            # rather than a delta on an existing one. We can change this behavior in the future if necessary
            # for some model.
            if tool_name is None and self._parts:
                part_index = len(self._parts) - 1
                latest_part = self._parts[part_index]
                if isinstance(latest_part, (ToolCallPart, ToolCallPartDelta)):
                    existing_matching_part_and_index = latest_part, part_index
        else:
            # vendor_part_id is provided, so look up the corresponding part or delta.
            part_index = self._vendor_id_to_part_index.get(vendor_part_id)
            if part_index is not None:
                existing_part = self._parts[part_index]
                if not isinstance(existing_part, (ToolCallPartDelta, ToolCallPart)):
                    raise UnexpectedModelBehavior(f'Cannot apply a tool call delta to {existing_part=}')
                existing_matching_part_and_index = existing_part, part_index

        # The same delta is used whether we start a new item or update an existing one.
        delta = ToolCallPartDelta(tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id)

        if existing_matching_part_and_index is None:
            # No matching part/delta was found, so store a new ToolCallPart if the delta is already
            # complete enough, or the delta itself otherwise.
            new_part_index = len(self._parts)
            part = delta.as_part() or delta
            if vendor_part_id is not None:
                self._vendor_id_to_part_index[vendor_part_id] = new_part_index
            self._parts.append(part)
            # Only emit a PartStartEvent if we have enough information to produce a full ToolCallPart.
            if isinstance(part, ToolCallPart):
                return PartStartEvent(index=new_part_index, part=part)
            return None

        # Update the existing part or delta with the new information.
        existing_part, part_index = existing_matching_part_and_index
        updated_part = delta.apply(existing_part)
        self._parts[part_index] = updated_part
        if not isinstance(updated_part, ToolCallPart):
            # Still incomplete, so there is nothing for consumers to see yet.
            return None
        if isinstance(existing_part, ToolCallPartDelta):
            # We just upgraded a delta to a full part, so this is the first consumers hear of it.
            return PartStartEvent(index=part_index, part=updated_part)
        # We updated an already-announced part, so emit the delta.
        return PartDeltaEvent(index=part_index, delta=delta)

    def handle_tool_call_part(
        self,
        *,
        vendor_part_id: VendorId | None,
        tool_name: str,
        args: str | dict[str, Any],
        tool_call_id: str | None = None,
    ) -> ModelResponseStreamEvent:
        """Immediately create or fully overwrite a `ToolCallPart` with the given information.

        This does not apply a delta; it directly sets the tool call part contents.

        Args:
            vendor_part_id: The vendor's ID for this tool call part. If None, the new part is always
                appended and is not registered under any ID. Otherwise, the part already registered
                under that ID is overwritten, or, if none exists, the new part is appended and registered.
            tool_name: The name of the tool being invoked.
            args: The arguments for the tool call, either as a string or a dictionary.
            tool_call_id: An optional string identifier for this tool call.

        Returns:
            A `PartStartEvent` for the new part, whether it was appended or replaced an existing part.
        """
        new_part = ToolCallPart(tool_name=tool_name, args=args, tool_call_id=tool_call_id)
        if vendor_part_id is None:
            # No ID to match against, so unconditionally append a new ToolCallPart.
            new_part_index = len(self._parts)
            self._parts.append(new_part)
        else:
            maybe_part_index = self._vendor_id_to_part_index.get(vendor_part_id)
            if maybe_part_index is not None:
                # NOTE(review): the existing part is replaced regardless of its type, unlike the delta
                # handlers, which raise UnexpectedModelBehavior on a type mismatch. Confirm this is intended.
                new_part_index = maybe_part_index
                self._parts[new_part_index] = new_part
            else:
                new_part_index = len(self._parts)
                self._parts.append(new_part)
                self._vendor_id_to_part_index[vendor_part_id] = new_part_index
        return PartStartEvent(index=new_part_index, part=new_part)