Coverage for pydantic_ai_slim/pydantic_ai/_parts_manager.py: 99.20%

87 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-01-25 16:43 +0000

1"""This module provides functionality to manage and update parts of a model's streamed response. 

2 

3The manager tracks which parts (in particular, text and tool calls) correspond to which 

4vendor-specific identifiers (e.g., `index`, `tool_call_id`, etc., as appropriate for a given model), 

5and produces PydanticAI-format events as appropriate for consumers of the streaming APIs. 

6 

7The "vendor-specific identifiers" to use depend on the semantics of the responses from the vendor, 

8and are tightly coupled to the specific model being used, and the PydanticAI Model subclass implementation. 

9 

10This `ModelResponsePartsManager` is used in each of the subclasses of `StreamedResponse` as a way to consolidate 

11event-emitting logic. 

12""" 

13 

14from __future__ import annotations as _annotations 

15 

16from collections.abc import Hashable 

17from dataclasses import dataclass, field 

18from typing import Any, Union 

19 

20from pydantic_ai.exceptions import UnexpectedModelBehavior 

21from pydantic_ai.messages import ( 

22 ModelResponsePart, 

23 ModelResponseStreamEvent, 

24 PartDeltaEvent, 

25 PartStartEvent, 

26 TextPart, 

27 TextPartDelta, 

28 ToolCallPart, 

29 ToolCallPartDelta, 

30) 

31 

VendorId = Hashable
"""Alias for the key a vendor uses to identify one part of a streamed response.

Any hashable value is accepted, for example a string, an integer index, or a UUID.
"""

ManagedPart = Union[ModelResponsePart, ToolCallPartDelta]
"""Everything `ModelResponsePartsManager` may hold at a given position of a response.

Streaming APIs frequently deliver tool calls piecemeal, so in addition to complete
`ModelResponsePart` objects this admits `ToolCallPartDelta` objects for tool calls
that are not yet fully formed.
"""

43 

44 

@dataclass
class ModelResponsePartsManager:
    """Accumulates the parts of a streamed model response and converts vendor deltas into events.

    Incoming deltas may carry a vendor-specific part ID; the manager remembers which position in the
    response each ID refers to, so later deltas bearing the same ID update that same position.
    """

    _parts: list[ManagedPart] = field(default_factory=list, init=False)
    """The response built so far, in order: text parts, tool-call parts and pending tool-call deltas."""
    _vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict, init=False)
    """Lookup from a vendor-supplied part ID to the position of that part within `_parts`."""

    def get_parts(self) -> list[ModelResponsePart]:
        """Return the parts that are complete, omitting any pending `ToolCallPartDelta` entries.

        Returns:
            A new list with the `ModelResponsePart` objects currently held, in response order.
        """
        complete_parts: list[ModelResponsePart] = []
        for managed in self._parts:
            if isinstance(managed, ToolCallPartDelta):
                # Not yet representable as a ToolCallPart, so it is hidden from callers.
                continue
            complete_parts.append(managed)
        return complete_parts

    def handle_text_delta(
        self,
        *,
        vendor_part_id: Hashable | None,
        content: str,
    ) -> ModelResponseStreamEvent:
        """Add streamed text, either extending an existing `TextPart` or starting a new one.

        With no vendor ID, the text extends the most recent part only when that part is a `TextPart`;
        otherwise a new `TextPart` is started. With a vendor ID, the text extends the part previously
        registered under that ID, or starts a new `TextPart` and registers it under the ID.

        Args:
            vendor_part_id: The vendor's identifier for this text, or None if none is provided.
            content: The text to add.

        Returns:
            A `PartStartEvent` when a new `TextPart` was started, otherwise a `PartDeltaEvent`
            carrying the appended text.

        Raises:
            UnexpectedModelBehavior: If `vendor_part_id` is already registered to a part that is
                not a `TextPart`.
        """
        target: tuple[TextPart, int] | None = None

        if vendor_part_id is not None:
            known_index = self._vendor_id_to_part_index.get(vendor_part_id)
            if known_index is not None:
                # Keep this exact variable name: the `=` f-string specifier puts it in the message.
                existing_part = self._parts[known_index]
                if not isinstance(existing_part, TextPart):
                    raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')
                target = existing_part, known_index
        elif self._parts and isinstance(last := self._parts[-1], TextPart):
            # No vendor ID: consecutive text is merged into the trailing TextPart.
            target = last, len(self._parts) - 1

        if target is None:
            # Nothing suitable to extend, so start a new TextPart at the end of the response.
            new_part = TextPart(content=content)
            new_index = self._append_part(new_part, vendor_part_id)
            return PartStartEvent(index=new_index, part=new_part)

        text_part, index = target
        text_delta = TextPartDelta(content_delta=content)
        self._parts[index] = text_delta.apply(text_part)
        return PartDeltaEvent(index=index, delta=text_delta)

    def handle_tool_call_delta(
        self,
        *,
        vendor_part_id: Hashable | None,
        tool_name: str | None,
        args: str | dict[str, Any] | None,
        tool_call_id: str | None,
    ) -> ModelResponseStreamEvent | None:
        """Merge a (possibly partial) tool call into the response.

        Entries are stored as `ToolCallPartDelta` until enough information has accumulated to form a
        `ToolCallPart` (as decided by `ToolCallPartDelta.as_part` / `ToolCallPartDelta.apply`), at
        which point they are stored as `ToolCallPart`.

        Target selection:
            - No vendor ID and `tool_name` is None: the most recent part is updated if it is a
              `ToolCallPart` or `ToolCallPartDelta`; otherwise a new entry is started.
            - No vendor ID and `tool_name` given: a new entry is always started.
            - Vendor ID given: the entry registered under that ID is updated, or a new entry is
              started and registered under the ID.

        Args:
            vendor_part_id: The vendor's identifier for this tool call, or None.
            tool_name: The tool name (or a fragment of it), if present in this delta.
            args: The arguments (or a fragment of them), as a string or a dict, if present.
            tool_call_id: The tool call identifier, if present in this delta.

        Returns:
            - `PartStartEvent` when an entry becomes a complete `ToolCallPart` for the first time,
              either because it was created complete or because a pending delta was upgraded.
            - `PartDeltaEvent` when an already-complete `ToolCallPart` is updated.
            - `None` when the entry is still an incomplete `ToolCallPartDelta`.

        Raises:
            UnexpectedModelBehavior: If `vendor_part_id` is already registered to a part that is
                neither a `ToolCallPart` nor a `ToolCallPartDelta`.
        """
        target: tuple[ToolCallPartDelta | ToolCallPart, int] | None = None

        if vendor_part_id is not None:
            known_index = self._vendor_id_to_part_index.get(vendor_part_id)
            if known_index is not None:
                # Keep this exact variable name: the `=` f-string specifier puts it in the message.
                existing_part = self._parts[known_index]
                if not isinstance(existing_part, (ToolCallPartDelta, ToolCallPart)):
                    raise UnexpectedModelBehavior(f'Cannot apply a tool call delta to {existing_part=}')
                target = existing_part, known_index
        elif (
            tool_name is None
            and self._parts
            and isinstance(last := self._parts[-1], (ToolCallPart, ToolCallPartDelta))
        ):
            # Without a vendor ID, a nameless delta is taken to continue the trailing tool call;
            # a delta that names a tool is treated as the start of a new call instead.
            target = last, len(self._parts) - 1

        delta = ToolCallPartDelta(tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id)

        if target is None:
            # Nothing to update: store the delta, promoted to a ToolCallPart if it is already complete.
            new_entry = delta.as_part() or delta
            new_index = self._append_part(new_entry, vendor_part_id)
            if isinstance(new_entry, ToolCallPart):
                return PartStartEvent(index=new_index, part=new_entry)
            # Still incomplete: nothing is announced until it can be represented as a ToolCallPart.
            return None

        previous, index = target
        updated = delta.apply(previous)
        self._parts[index] = updated
        if not isinstance(updated, ToolCallPart):
            return None
        if isinstance(previous, ToolCallPartDelta):
            # First time this entry is a complete ToolCallPart, so announce it as a new part.
            return PartStartEvent(index=index, part=updated)
        return PartDeltaEvent(index=index, delta=delta)

    def handle_tool_call_part(
        self,
        *,
        vendor_part_id: Hashable | None,
        tool_name: str,
        args: str | dict[str, Any],
        tool_call_id: str | None = None,
    ) -> ModelResponseStreamEvent:
        """Insert a complete `ToolCallPart`, or replace the part registered under the vendor ID.

        No delta merging takes place; the part built from the arguments is stored as-is.

        Args:
            vendor_part_id: The vendor's identifier for this tool call. When it is already
                registered, whatever part sits at that position is replaced; when it is None,
                the part is appended without being registered.
            tool_name: The name of the tool being called.
            args: The tool call arguments, as a string or a dict.
            tool_call_id: An optional identifier for this tool call.

        Returns:
            A `PartStartEvent` for the position the new part now occupies.
        """
        new_part = ToolCallPart(tool_name=tool_name, args=args, tool_call_id=tool_call_id)

        if vendor_part_id is not None:
            existing_index = self._vendor_id_to_part_index.get(vendor_part_id)
            if existing_index is not None:
                # Overwrite in place; the replaced part's type is not checked.
                self._parts[existing_index] = new_part
                return PartStartEvent(index=existing_index, part=new_part)

        new_index = self._append_part(new_part, vendor_part_id)
        return PartStartEvent(index=new_index, part=new_part)

    def _append_part(self, part: ManagedPart, vendor_part_id: Hashable | None) -> int:
        """Append `part`, register it under `vendor_part_id` when one is given, and return its index."""
        index = len(self._parts)
        self._parts.append(part)
        if vendor_part_id is not None:
            self._vendor_id_to_part_index[vendor_part_id] = index
        return index