Coverage for pydantic_ai_slim/pydantic_ai/_parts_manager.py: 99.21%

88 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1"""This module provides functionality to manage and update parts of a model's streamed response. 

2 

3The manager tracks which parts (in particular, text and tool calls) correspond to which 

4vendor-specific identifiers (e.g., `index`, `tool_call_id`, etc., as appropriate for a given model), 

5and produces PydanticAI-format events as appropriate for consumers of the streaming APIs. 

6 

7The "vendor-specific identifiers" to use depend on the semantics of the responses from the vendor, 

8and are tightly coupled to the specific model being used, and the PydanticAI Model subclass implementation. 

9 

10This `ModelResponsePartsManager` is used in each of the subclasses of `StreamedResponse` as a way to consolidate 

11event-emitting logic. 

12""" 

13 

14from __future__ import annotations as _annotations 

15 

16from collections.abc import Hashable 

17from dataclasses import dataclass, field 

18from typing import Any, Union 

19 

20from pydantic_ai.exceptions import UnexpectedModelBehavior 

21from pydantic_ai.messages import ( 

22 ModelResponsePart, 

23 ModelResponseStreamEvent, 

24 PartDeltaEvent, 

25 PartStartEvent, 

26 TextPart, 

27 TextPartDelta, 

28 ToolCallPart, 

29 ToolCallPartDelta, 

30) 

31 

32from ._utils import generate_tool_call_id as _generate_tool_call_id 

33 

# Vendors identify streamed parts in different ways (an integer index, a string ID, ...),
# so any hashable value is accepted as a key.
VendorId = Hashable
"""
Type alias for a vendor identifier: any hashable value (e.g., an integer index, a string, a UUID).
"""

# While a tool call is still streaming it may be held as a ToolCallPartDelta rather than a
# complete part, so the manager's storage type must admit both.
ManagedPart = Union[ModelResponsePart, ToolCallPartDelta]
"""
A union of the types held by the ModelResponsePartsManager.

Many vendors have streaming APIs that can emit tool calls which are not yet fully formed, so in
addition to complete ModelResponsePart values this also includes ToolCallPartDelta values.
"""

45 

46 

@dataclass
class ModelResponsePartsManager:
    """Manages a sequence of parts that make up a model's streamed response.

    Parts are generally added and/or updated by providing deltas, which are tracked by vendor-specific IDs.
    """

    _parts: list[ManagedPart] = field(default_factory=list, init=False)
    """A list of parts (text or tool calls) that make up the current state of the model's response."""
    _vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict, init=False)
    """Maps a vendor's "part" ID (if provided) to the index in `_parts` where that part resides."""

    def get_parts(self) -> list[ModelResponsePart]:
        """Return only model response parts that are complete (i.e., not `ToolCallPartDelta`s).

        Returns:
            A list of ModelResponsePart objects, in order. Pending `ToolCallPartDelta` objects are excluded.
        """
        return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)]

    def handle_text_delta(
        self,
        *,
        vendor_part_id: VendorId | None,
        content: str,
    ) -> ModelResponseStreamEvent:
        """Handle incoming text content, creating or updating a TextPart in the manager as appropriate.

        When `vendor_part_id` is None, the latest part is updated if it exists and is a TextPart;
        otherwise, a new TextPart is appended. When a non-None ID is specified, the TextPart registered
        under that vendor ID is updated, or a new TextPart is appended and registered under that ID.

        Args:
            vendor_part_id: The ID the vendor uses to identify this piece of text. If None, a new part
                will be created unless the latest part is already a TextPart.
            content: The text content to append to the appropriate TextPart.

        Returns:
            A `PartStartEvent` if a new part was created, or a `PartDeltaEvent` if an existing part was updated.

        Raises:
            UnexpectedModelBehavior: If `vendor_part_id` is already registered to a part that is not a TextPart.
        """
        existing_text_part_and_index: tuple[TextPart, int] | None = None

        if vendor_part_id is None:
            # Without a vendor ID, text can only be merged into the most recent part, and only if it is text.
            if self._parts:
                part_index = len(self._parts) - 1
                latest_part = self._parts[part_index]
                if isinstance(latest_part, TextPart):
                    existing_text_part_and_index = latest_part, part_index
        else:
            # With a vendor ID, look up the part previously registered under that ID (if any).
            part_index = self._vendor_id_to_part_index.get(vendor_part_id)
            if part_index is not None:
                existing_part = self._parts[part_index]
                if not isinstance(existing_part, TextPart):
                    raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')
                existing_text_part_and_index = existing_part, part_index

        if existing_text_part_and_index is None:
            # Nothing to update, so start a new TextPart at the end of the list.
            new_part_index = len(self._parts)
            part = TextPart(content=content)
            if vendor_part_id is not None:
                self._vendor_id_to_part_index[vendor_part_id] = new_part_index
            self._parts.append(part)
            return PartStartEvent(index=new_part_index, part=part)

        # Apply the new content as a delta to the existing TextPart and report that delta.
        existing_text_part, part_index = existing_text_part_and_index
        part_delta = TextPartDelta(content_delta=content)
        self._parts[part_index] = part_delta.apply(existing_text_part)
        return PartDeltaEvent(index=part_index, delta=part_delta)

    def handle_tool_call_delta(
        self,
        *,
        vendor_part_id: VendorId | None,
        tool_name: str | None,
        args: str | dict[str, Any] | None,
        tool_call_id: str | None,
    ) -> ModelResponseStreamEvent | None:
        """Handle or update a tool call, creating or updating a `ToolCallPart` or `ToolCallPartDelta`.

        Managed items remain as `ToolCallPartDelta`s until the delta can be converted into a full part
        (via `ToolCallPartDelta.as_part()` / `apply()`), at which point they are upgraded to `ToolCallPart`s.

        Which part is updated:
            - `vendor_part_id` is None and `tool_name` is None: the latest part is updated if it is a
              `ToolCallPart` or `ToolCallPartDelta`; otherwise a new item is appended.
            - `vendor_part_id` is None and `tool_name` is not None: a new item is always appended (a named
              delta without a vendor ID is treated as the start of a new tool call).
            - `vendor_part_id` is not None: the item registered under that ID is updated, or a new item is
              appended and registered under that ID.

        Args:
            vendor_part_id: The ID the vendor uses for this tool call, or None (see above).
            tool_name: The name of the tool, or None if this delta does not carry (more of) the name.
            args: Arguments for the tool call, either as a string or a dictionary of key-value pairs.
            tool_call_id: An optional string representing an identifier for this tool call.

        Returns:
            - A `PartStartEvent` if a new, complete `ToolCallPart` was created, or if an existing
              `ToolCallPartDelta` was just upgraded to a `ToolCallPart`.
            - A `PartDeltaEvent` if an existing `ToolCallPart` was updated.
            - `None` if the managed item is still an incomplete `ToolCallPartDelta`.

        Raises:
            UnexpectedModelBehavior: If `vendor_part_id` is already registered to a part that is not a
                `ToolCallPart` or `ToolCallPartDelta`.
        """
        existing_matching_part_and_index: tuple[ToolCallPartDelta | ToolCallPart, int] | None = None

        if vendor_part_id is None:
            # Only a nameless delta is merged into the latest part; a delta carrying a tool name is assumed
            # to begin a new tool call. This heuristic can be revisited if some model needs otherwise.
            if tool_name is None and self._parts:
                part_index = len(self._parts) - 1
                latest_part = self._parts[part_index]
                if isinstance(latest_part, (ToolCallPart, ToolCallPartDelta)):
                    existing_matching_part_and_index = latest_part, part_index
        else:
            # With a vendor ID, look up the part or delta previously registered under that ID (if any).
            part_index = self._vendor_id_to_part_index.get(vendor_part_id)
            if part_index is not None:
                existing_part = self._parts[part_index]
                if not isinstance(existing_part, (ToolCallPartDelta, ToolCallPart)):
                    raise UnexpectedModelBehavior(f'Cannot apply a tool call delta to {existing_part=}')
                existing_matching_part_and_index = existing_part, part_index

        delta = ToolCallPartDelta(tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id)

        if existing_matching_part_and_index is None:
            # Nothing to update: store a full ToolCallPart if the delta is already complete, otherwise keep
            # the raw delta until later deltas supply the missing pieces.
            part = delta.as_part() or delta
            new_part_index = len(self._parts)
            if vendor_part_id is not None:
                self._vendor_id_to_part_index[vendor_part_id] = new_part_index
            self._parts.append(part)
            # Consumers are only told about a tool call once a full ToolCallPart exists.
            if isinstance(part, ToolCallPart):
                return PartStartEvent(index=new_part_index, part=part)
            return None

        # Merge the new information into the existing part or delta.
        existing_part, part_index = existing_matching_part_and_index
        updated_part = delta.apply(existing_part)
        self._parts[part_index] = updated_part
        if not isinstance(updated_part, ToolCallPart):
            # Still incomplete, so there is nothing to emit yet.
            return None
        if isinstance(existing_part, ToolCallPartDelta):
            # A pending delta just became a full part; this is the first event consumers see for it.
            return PartStartEvent(index=part_index, part=updated_part)
        # A part that consumers have already seen was extended.
        return PartDeltaEvent(index=part_index, delta=delta)

    def handle_tool_call_part(
        self,
        *,
        vendor_part_id: VendorId | None,
        tool_name: str,
        args: str | dict[str, Any],
        tool_call_id: str | None = None,
    ) -> ModelResponseStreamEvent:
        """Immediately create or fully overwrite a ToolCallPart with the given information.

        This does not apply a delta; it directly sets the tool call part contents.

        Args:
            vendor_part_id: The vendor's ID for this tool call part. If None, the new part is always
                appended. Otherwise, any part already registered under this ID is replaced in place,
                whatever its type. If no part is registered, the new part is appended and registered.
            tool_name: The name of the tool being invoked.
            args: The arguments for the tool call, either as a string or a dictionary.
            tool_call_id: An optional string identifier for this tool call. If falsy, a new ID is
                generated via `_generate_tool_call_id()`.

        Returns:
            ModelResponseStreamEvent: A `PartStartEvent` for the new part. It is emitted both when the part
                was appended and when it replaced an existing part.
        """
        new_part = ToolCallPart(
            tool_name=tool_name,
            args=args,
            tool_call_id=tool_call_id or _generate_tool_call_id(),
        )
        if vendor_part_id is None:
            # Without a vendor ID there is nothing to overwrite, so always append.
            new_part_index = len(self._parts)
            self._parts.append(new_part)
        else:
            maybe_part_index = self._vendor_id_to_part_index.get(vendor_part_id)
            if maybe_part_index is not None:
                # Overwrite the previously registered part in place (its index stays the same).
                new_part_index = maybe_part_index
                self._parts[new_part_index] = new_part
            else:
                new_part_index = len(self._parts)
                self._parts.append(new_part)
                self._vendor_id_to_part_index[vendor_part_id] = new_part_index
        return PartStartEvent(index=new_part_index, part=new_part)