Coverage for pydantic_ai_slim/pydantic_ai/result.py: 89.04% (213 statements)

coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

  1  from __future__ import annotations as _annotations
  2
  3  from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
  4  from copy import copy
  5  from dataclasses import dataclass, field
  6  from datetime import datetime
  7  from typing import Generic, Union, cast
  8
  9  from typing_extensions import TypeVar, assert_type
 10
 11  from . import _result, _utils, exceptions, messages as _messages, models
 12  from .messages import AgentStreamEvent, FinalResultEvent
 13  from .tools import AgentDepsT, RunContext
 14  from .usage import Usage, UsageLimits
 15
 16  __all__ = 'ResultDataT', 'ResultDataT_inv', 'ResultValidatorFunc'
 17
 18
 19  T = TypeVar('T')
 20  """An invariant TypeVar."""
 21  ResultDataT_inv = TypeVar('ResultDataT_inv', default=str)
 22  """
 23  An invariant type variable for the result data of a model.
 24
 25  We need to use an invariant typevar for `ResultValidator` and `ResultValidatorFunc` because the result data type is used
 26  in both the input and output of a `ResultValidatorFunc`. This can theoretically lead to some issues assuming that types
 27  possessing `ResultValidator`s are covariant in the result data type, but in practice this is rarely an issue, and
 28  changing it would have negative consequences for the ergonomics of the library.
 29
 30  At some point, it may make sense to change the input to `ResultValidatorFunc` to be `Any` or `object` as doing that would
 31  resolve these potential variance issues.
 32  """
 33  ResultDataT = TypeVar('ResultDataT', default=str, covariant=True)
 34  """Covariant type variable for the result data type of a run."""
 35
 36  ResultValidatorFunc = Union[
 37      Callable[[RunContext[AgentDepsT], ResultDataT_inv], ResultDataT_inv],
 38      Callable[[RunContext[AgentDepsT], ResultDataT_inv], Awaitable[ResultDataT_inv]],
 39      Callable[[ResultDataT_inv], ResultDataT_inv],
 40      Callable[[ResultDataT_inv], Awaitable[ResultDataT_inv]],
 41  ]
 42  """
 43  A function that always takes and returns the same type of data (which is the result type of an agent run), and:
 44
 45  * may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument
 46  * may or may not be async
 47
 48  Usage: `ResultValidatorFunc[AgentDepsT, T]`.
 49  """
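
# Illustrative sketch (assumed usage, not part of this module): registering a `ResultValidatorFunc`
# via the public `Agent` API and its `result_validator` decorator. The model name and validation
# rule below are invented for illustration.
from pydantic_ai import Agent, ModelRetry, RunContext

example_agent = Agent('openai:gpt-4o', result_type=int)

@example_agent.result_validator
async def check_positive(ctx: RunContext[None], value: int) -> int:
    # takes and returns the same type; may be sync or async, with or without the RunContext argument
    if value <= 0:
        raise ModelRetry('expected a positive integer')
    return value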

 50
 51
 52  @dataclass
 53  class AgentStream(Generic[AgentDepsT, ResultDataT]):
 54      _raw_stream_response: models.StreamedResponse
 55      _result_schema: _result.ResultSchema[ResultDataT] | None
 56      _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]]
 57      _run_ctx: RunContext[AgentDepsT]
 58      _usage_limits: UsageLimits | None
 59
 60      _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
 61      _final_result_event: FinalResultEvent | None = field(default=None, init=False)
 62      _initial_run_ctx_usage: Usage = field(init=False)
 63
 64      def __post_init__(self):
 65          self._initial_run_ctx_usage = copy(self._run_ctx.usage)
 66
 67      async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]:
 68          """Asynchronously stream the (validated) agent outputs."""
 69          async for response in self.stream_responses(debounce_by=debounce_by):
 70              if self._final_result_event is not None:
 71                  yield await self._validate_response(response, self._final_result_event.tool_name, allow_partial=True)
 72          if self._final_result_event is not None:  # coverage: 72 ↛ exit (the condition was always true)
 73              yield await self._validate_response(
 74                  self._raw_stream_response.get(), self._final_result_event.tool_name, allow_partial=False
 75              )
 76
 77      async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
 78          """Asynchronously stream the (unvalidated) model responses for the agent."""
 79          # if the message currently has any parts with content, yield before streaming
 80          msg = self._raw_stream_response.get()
 81          for part in msg.parts:  # coverage: 81 ↛ 82 (the loop never started)
 82              if part.has_content():
 83                  yield msg
 84                  break
 85
 86          async with _utils.group_by_temporal(self, debounce_by) as group_iter:
 87              async for _items in group_iter:
 88                  yield self._raw_stream_response.get()  # current state of the response
 89
 90      def usage(self) -> Usage:
 91          """Return the usage of the whole run.
 92
 93          !!! note
 94              This won't return the full usage until the stream is finished.
 95          """
 96          return self._initial_run_ctx_usage + self._raw_stream_response.usage()
 97
 98      async def _validate_response(
 99          self, message: _messages.ModelResponse, result_tool_name: str | None, *, allow_partial: bool = False
100      ) -> ResultDataT:
101          """Validate a structured result message."""
102          if self._result_schema is not None and result_tool_name is not None:
103              match = self._result_schema.find_named_tool(message.parts, result_tool_name)
104              if match is None:  # coverage: 104 ↛ 105 (the condition was never true)
105                  raise exceptions.UnexpectedModelBehavior(
106                      f'Invalid response, unable to find tool: {self._result_schema.tool_names()}'
107                  )
108
109              call, result_tool = match
110              result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)
111
112              for validator in self._result_validators:
113                  result_data = await validator.validate(result_data, call, self._run_ctx)
114              return result_data
115          else:
116              text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
117              for validator in self._result_validators:
118                  text = await validator.validate(
119                      text,
120                      None,
121                      self._run_ctx,
122                  )
123              # Since there is no result tool, we can assume that str is compatible with ResultDataT
124              return cast(ResultDataT, text)
125
126      def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
127          """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
128
129          This proxies the _raw_stream_response and sends all events to the agent stream, while also checking for matches
130          on the result schema and emitting a [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] if/when the
131          first match is found.
132          """
133          if self._agent_stream_iterator is not None:
134              return self._agent_stream_iterator
135
136          async def aiter():
137              result_schema = self._result_schema
138              allow_text_result = result_schema is None or result_schema.allow_text_result
139
140              def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages.FinalResultEvent | None:
141                  """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result."""
142                  if isinstance(e, _messages.PartStartEvent):
143                      new_part = e.part
144                      if isinstance(new_part, _messages.ToolCallPart):
145                          if result_schema:
146                              for call, _ in result_schema.find_tool([new_part]):  # coverage: 146 ↛ exit (the loop never ran to completion)
147                                  return _messages.FinalResultEvent(
148                                      tool_name=call.tool_name, tool_call_id=call.tool_call_id
149                                  )
150                      elif allow_text_result:  # coverage: 150 ↛ exit (the condition was always true)
151                          assert_type(e, _messages.PartStartEvent)
152                          return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)
153
154              usage_checking_stream = _get_usage_checking_stream_response(
155                  self._raw_stream_response, self._usage_limits, self.usage
156              )
157              async for event in usage_checking_stream:
158                  yield event
159                  if (final_result_event := _get_final_result_event(event)) is not None:
160                      self._final_result_event = final_result_event
161                      yield final_result_event
162                      break
163
164              # If we broke out of the above loop, we need to yield the rest of the events
165              # If we didn't, this will just be a no-op
166              async for event in usage_checking_stream:
167                  yield event
168
169          self._agent_stream_iterator = aiter()
170          return self._agent_stream_iterator
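
# Illustrative sketch (assumed usage, not part of this module): `AgentStream` is normally reached
# through the `Agent.iter` node-streaming API rather than constructed directly. The agent, model
# name and prompt below are invented for illustration.
from pydantic_ai import Agent

iter_agent = Agent('openai:gpt-4o')

async def stream_events_example() -> None:
    async with iter_agent.iter('Tell me a joke.') as run:
        async for node in run:
            if Agent.is_model_request_node(node):
                # `node.stream(...)` yields the `AgentStream` defined above
                async with node.stream(run.ctx) as agent_stream:
                    async for event in agent_stream:  # AgentStreamEvents, including FinalResultEvent
                        print(event)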

171
172
173  @dataclass
174  class StreamedRunResult(Generic[AgentDepsT, ResultDataT]):
175      """Result of a streamed run that returns structured data via a tool call."""
176
177      _all_messages: list[_messages.ModelMessage]
178      _new_message_index: int
179
180      _usage_limits: UsageLimits | None
181      _stream_response: models.StreamedResponse
182      _result_schema: _result.ResultSchema[ResultDataT] | None
183      _run_ctx: RunContext[AgentDepsT]
184      _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]]
185      _result_tool_name: str | None
186      _on_complete: Callable[[], Awaitable[None]]
187
188      _initial_run_ctx_usage: Usage = field(init=False)
189      is_complete: bool = field(default=False, init=False)
190      """Whether the stream has been fully received.
191
192      This is set to `True` when one of
193      [`stream`][pydantic_ai.result.StreamedRunResult.stream],
194      [`stream_text`][pydantic_ai.result.StreamedRunResult.stream_text],
195      [`stream_structured`][pydantic_ai.result.StreamedRunResult.stream_structured] or
196      [`get_data`][pydantic_ai.result.StreamedRunResult.get_data] completes.
197      """
198
199      def __post_init__(self):
200          self._initial_run_ctx_usage = copy(self._run_ctx.usage)
201
202      def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
203          """Return the history of messages.
204
205          Args:
206              result_tool_return_content: The return content of the tool call to set in the last message.
207                  This provides a convenient way to modify the content of the result tool call if you want to continue
208                  the conversation and want to set the response to the result tool call. If `None`, the last message will
209                  not be modified.
210
211          Returns:
212              List of messages.
213          """
214          # this is a method to be consistent with the other methods
215          if result_tool_return_content is not None:
216              raise NotImplementedError('Setting result tool return content is not supported for this result type.')
217          return self._all_messages
218
219      def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
220          """Return all messages from [`all_messages`][pydantic_ai.result.StreamedRunResult.all_messages] as JSON bytes.
221
222          Args:
223              result_tool_return_content: The return content of the tool call to set in the last message.
224                  This provides a convenient way to modify the content of the result tool call if you want to continue
225                  the conversation and want to set the response to the result tool call. If `None`, the last message will
226                  not be modified.
227
228          Returns:
229              JSON bytes representing the messages.
230          """
231          return _messages.ModelMessagesTypeAdapter.dump_json(
232              self.all_messages(result_tool_return_content=result_tool_return_content)
233          )
234
235      def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
236          """Return new messages associated with this run.
237
238          Messages from older runs are excluded.
239
240          Args:
241              result_tool_return_content: The return content of the tool call to set in the last message.
242                  This provides a convenient way to modify the content of the result tool call if you want to continue
243                  the conversation and want to set the response to the result tool call. If `None`, the last message will
244                  not be modified.
245
246          Returns:
247              List of new messages.
248          """
249          return self.all_messages(result_tool_return_content=result_tool_return_content)[self._new_message_index :]
250
251      def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
252          """Return new messages from [`new_messages`][pydantic_ai.result.StreamedRunResult.new_messages] as JSON bytes.
253
254          Args:
255              result_tool_return_content: The return content of the tool call to set in the last message.
256                  This provides a convenient way to modify the content of the result tool call if you want to continue
257                  the conversation and want to set the response to the result tool call. If `None`, the last message will
258                  not be modified.
259
260          Returns:
261              JSON bytes representing the new messages.
262          """
263          return _messages.ModelMessagesTypeAdapter.dump_json(
264              self.new_messages(result_tool_return_content=result_tool_return_content)
265          )
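
# Illustrative sketch (assumed usage, not part of this module): capturing message history from a
# streamed run, assuming `Agent.run_stream` returns this `StreamedRunResult`. Model name and
# prompt are invented for illustration.
from pydantic_ai import Agent

history_agent = Agent('openai:gpt-4o')

async def history_example() -> None:
    async with history_agent.run_stream('What is the capital of France?') as result:
        await result.get_data()                 # drain the stream so the final response is recorded
    history_bytes = result.all_messages_json()  # full history as JSON bytes
    new_only = result.new_messages()            # only this run's messages
    print(len(new_only), history_bytes[:80])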

266
267      async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]:
268          """Stream the response as an async iterable.
269
270          The pydantic validator for structured data will be called in
271          [partial mode](https://docs.pydantic.dev/dev/concepts/experimental/#partial-validation)
272          on each iteration.
273
274          Args:
275              debounce_by: by how much (if at all) to debounce/group the response chunks. `None` means no debouncing.
276                  Debouncing is particularly important for long structured responses to reduce the overhead of
277                  performing validation as each token is received.
278
279          Returns:
280              An async iterable of the response data.
281          """
282          async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by):
283              result = await self.validate_structured_result(structured_message, allow_partial=not is_last)
284              yield result
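
# Illustrative sketch (assumed usage, not part of this module): consuming `stream()` for
# progressively validated structured results. The Pydantic model, model name and prompt are
# invented for illustration.
from pydantic import BaseModel
from pydantic_ai import Agent

class CityLocation(BaseModel):
    city: str
    country: str

city_agent = Agent('openai:gpt-4o', result_type=CityLocation)

async def stream_example() -> None:
    async with city_agent.run_stream('Where were the 2012 Olympics held?') as result:
        async for partial in result.stream(debounce_by=0.2):
            print(partial)  # partially validated snapshots; the final item is fully validated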

285
286      async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
287          """Stream the text result as an async iterable.
288
289          !!! note
290              Result validators will NOT be called on the text result if `delta=True`.
291
292          Args:
293              delta: if `True`, yield each chunk of text as it is received; if `False` (default), yield the full text
294                  up to the current point.
295              debounce_by: by how much (if at all) to debounce/group the response chunks. `None` means no debouncing.
296                  Debouncing is particularly important for long structured responses to reduce the overhead of
297                  performing validation as each token is received.
298          """
299          if self._result_schema and not self._result_schema.allow_text_result:
300              raise exceptions.UserError('stream_text() can only be used with text responses')
301
302          if delta:
303              async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by):
304                  yield text
305          else:
306              async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by):
307                  combined_validated_text = await self._validate_text_result(text)
308                  yield combined_validated_text
309          await self._marked_completed(self._stream_response.get())
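
# Illustrative sketch (assumed usage, not part of this module): streaming plain text, either as the
# accumulated text so far (the default) or as raw deltas. Model name and prompt are invented.
from pydantic_ai import Agent

text_agent = Agent('openai:gpt-4o')

async def stream_text_example() -> None:
    async with text_agent.run_stream('Write a haiku about rivers.') as result:
        async for text in result.stream_text():  # full text so far, passed through result validators
            print(text)
    # `result.stream_text(delta=True)` would instead yield only the new chunks and skip result validators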

310
311      async def stream_structured(
312          self, *, debounce_by: float | None = 0.1
313      ) -> AsyncIterator[tuple[_messages.ModelResponse, bool]]:
314          """Stream the response as an async iterable of structured LLM messages.
315
316          Args:
317              debounce_by: by how much (if at all) to debounce/group the response chunks. `None` means no debouncing.
318                  Debouncing is particularly important for long structured responses to reduce the overhead of
319                  performing validation as each token is received.
320
321          Returns:
322              An async iterable of the structured response message and whether that is the last message.
323          """
324          # if the message currently has any parts with content, yield before streaming
325          msg = self._stream_response.get()
326          for part in msg.parts:
327              if part.has_content():
328                  yield msg, False
329                  break
330
331          async for msg in self._stream_response_structured(debounce_by=debounce_by):
332              yield msg, False
333
334          msg = self._stream_response.get()
335          yield msg, True
336
337          await self._marked_completed(msg)
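
# Illustrative sketch (assumed usage, not part of this module): working with the raw `ModelResponse`
# messages and validating them yourself. Model name, result type and prompt are invented.
from pydantic_ai import Agent

colours_agent = Agent('openai:gpt-4o', result_type=list[str])

async def stream_structured_example() -> None:
    async with colours_agent.run_stream('Name three colours.') as result:
        async for message, is_last in result.stream_structured(debounce_by=None):
            data = await result.validate_structured_result(message, allow_partial=not is_last)
            print(is_last, data)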

338
339      async def get_data(self) -> ResultDataT:
340          """Stream the whole response, validate and return it."""
341          usage_checking_stream = _get_usage_checking_stream_response(
342              self._stream_response, self._usage_limits, self.usage
343          )
344
345          async for _ in usage_checking_stream:
346              pass
347          message = self._stream_response.get()
348          await self._marked_completed(message)
349          return await self.validate_structured_result(message)
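
# Illustrative sketch (assumed usage, not part of this module): draining the stream with `get_data`
# and then reading the final usage totals. Model name and prompt are invented.
from pydantic_ai import Agent

math_agent = Agent('openai:gpt-4o', result_type=int)

async def get_data_example() -> None:
    async with math_agent.run_stream('What is 6 * 7?') as result:
        answer = await result.get_data()  # consumes the whole stream and marks the run complete
    print(answer, result.usage())         # usage totals are only final once the stream has finished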

350
351      def usage(self) -> Usage:
352          """Return the usage of the whole run.
353
354          !!! note
355              This won't return the full usage until the stream is finished.
356          """
357          return self._initial_run_ctx_usage + self._stream_response.usage()
358
359      def timestamp(self) -> datetime:
360          """Get the timestamp of the response."""
361          return self._stream_response.timestamp
362
363      async def validate_structured_result(
364          self, message: _messages.ModelResponse, *, allow_partial: bool = False
365      ) -> ResultDataT:
366          """Validate a structured result message."""
367          if self._result_schema is not None and self._result_tool_name is not None:
368              match = self._result_schema.find_named_tool(message.parts, self._result_tool_name)
369              if match is None:  # coverage: 369 ↛ 370 (the condition was never true)
370                  raise exceptions.UnexpectedModelBehavior(
371                      f'Invalid response, unable to find tool: {self._result_schema.tool_names()}'
372                  )
373
374              call, result_tool = match
375              result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)
376
377              for validator in self._result_validators:  # coverage: 377 ↛ 378 (the loop never started)
378                  result_data = await validator.validate(result_data, call, self._run_ctx)
379              return result_data
380          else:
381              text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
382              for validator in self._result_validators:  # coverage: 382 ↛ 383 (the loop never started)
383                  text = await validator.validate(
384                      text,
385                      None,
386                      self._run_ctx,
387                  )
388              # Since there is no result tool, we can assume that str is compatible with ResultDataT
389              return cast(ResultDataT, text)
390
391      async def _validate_text_result(self, text: str) -> str:
392          for validator in self._result_validators:  # coverage: 392 ↛ 393 (the loop never started)
393              text = await validator.validate(
394                  text,
395                  None,
396                  self._run_ctx,
397              )
398          return text
399
400      async def _marked_completed(self, message: _messages.ModelResponse) -> None:
401          self.is_complete = True
402          self._all_messages.append(message)
403          await self._on_complete()
404
405      async def _stream_response_structured(
406          self, *, debounce_by: float | None = 0.1
407      ) -> AsyncIterator[_messages.ModelResponse]:
408          async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
409              async for _items in group_iter:
410                  yield self._stream_response.get()
411
412      async def _stream_response_text(
413          self, *, delta: bool = False, debounce_by: float | None = 0.1
414      ) -> AsyncIterator[str]:
415          """Stream the response as an async iterable of text."""
416
417          # Define a "merged" version of the iterator that will yield items that have already been retrieved
418          # and items that we receive while streaming. We define a dedicated async iterator for this so we can
419          # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
420          async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
421              # yields tuples of (text_content, part_index)
422              # we don't currently make use of the part_index, but in principle this may be useful
423              # so we retain it here for now to make possible future refactors simpler
424              msg = self._stream_response.get()
425              for i, part in enumerate(msg.parts):
426                  if isinstance(part, _messages.TextPart) and part.content:
427                      yield part.content, i
428
429              async for event in self._stream_response:
430                  if (  # coverage: 430 ↛ 435 (the condition was never true)
431                      isinstance(event, _messages.PartStartEvent)
432                      and isinstance(event.part, _messages.TextPart)
433                      and event.part.content
434                  ):
435                      yield event.part.content, event.index
436                  elif (  # coverage: 436 ↛ 429 (the condition was always true)
437                      isinstance(event, _messages.PartDeltaEvent)
438                      and isinstance(event.delta, _messages.TextPartDelta)
439                      and event.delta.content_delta
440                  ):
441                      yield event.delta.content_delta, event.index
442
443          async def _stream_text_deltas() -> AsyncIterator[str]:
444              async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
445                  async for items in group_iter:
446                      # Note: we are currently just dropping the part index on the group here
447                      yield ''.join([content for content, _ in items])
448
449          if delta:
450              async for text in _stream_text_deltas():
451                  yield text
452          else:
453              # a quick benchmark shows it's faster to build up a string with concat when we're
454              # yielding at each step
455              deltas: list[str] = []
456              async for text in _stream_text_deltas():
457                  deltas.append(text)
458                  yield ''.join(deltas)
459
460
461  @dataclass
462  class FinalResult(Generic[ResultDataT]):
463      """Marker class storing the final result of an agent run and associated metadata."""
464
465      data: ResultDataT
466      """The final result data."""
467      tool_name: str | None
468      """Name of the final result tool; `None` if the result came from unstructured text content."""
469      tool_call_id: str | None
470      """ID of the tool call that produced the final result; `None` if the result came from unstructured text content."""
471
472
473  def _get_usage_checking_stream_response(
474      stream_response: AsyncIterable[_messages.ModelResponseStreamEvent],
475      limits: UsageLimits | None,
476      get_usage: Callable[[], Usage],
477  ) -> AsyncIterable[_messages.ModelResponseStreamEvent]:
478      if limits is not None and limits.has_token_limits():  # coverage: 478 ↛ 480 (the condition was never true)
479
480          async def _usage_checking_iterator():
481              async for item in stream_response:
482                  limits.check_tokens(get_usage())
483                  yield item
484
485          return _usage_checking_iterator()
486      else:
487          return stream_response