Coverage for pydantic_ai_slim/pydantic_ai/result.py: 89.04%
213 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1from __future__ import annotations as _annotations
3from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
4from copy import copy
5from dataclasses import dataclass, field
6from datetime import datetime
7from typing import Generic, Union, cast
9from typing_extensions import TypeVar, assert_type
11from . import _result, _utils, exceptions, messages as _messages, models
12from .messages import AgentStreamEvent, FinalResultEvent
13from .tools import AgentDepsT, RunContext
14from .usage import Usage, UsageLimits
__all__ = 'ResultDataT', 'ResultDataT_inv', 'ResultValidatorFunc'


T = TypeVar('T')
"""An invariant TypeVar."""
ResultDataT_inv = TypeVar('ResultDataT_inv', default=str)
"""
An invariant type variable for the result data of a model.

We need to use an invariant typevar for `ResultValidator` and `ResultValidatorFunc` because the result data type is used
in both the input and output of a `ResultValidatorFunc`. This can theoretically lead to some issues assuming that types
possessing ResultValidator's are covariant in the result data type, but in practice this is rarely an issue, and
changing it would have negative consequences for the ergonomics of the library.

At some point, it may make sense to change the input to ResultValidatorFunc to be `Any` or `object` as doing that would
resolve these potential variance issues.
"""
ResultDataT = TypeVar('ResultDataT', default=str, covariant=True)
"""Covariant type variable for the result data type of a run."""

# The four accepted validator shapes: with/without a leading RunContext, sync/async.
ResultValidatorFunc = Union[
    Callable[[RunContext[AgentDepsT], ResultDataT_inv], ResultDataT_inv],
    Callable[[RunContext[AgentDepsT], ResultDataT_inv], Awaitable[ResultDataT_inv]],
    Callable[[ResultDataT_inv], ResultDataT_inv],
    Callable[[ResultDataT_inv], Awaitable[ResultDataT_inv]],
]
"""
A function that always takes and returns the same type of data (which is the result type of an agent run), and:

* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument
* may or may not be async

Usage `ResultValidatorFunc[AgentDepsT, T]`.
"""
@dataclass
class AgentStream(Generic[AgentDepsT, ResultDataT]):
    """Async-iterable wrapper around a streamed model response.

    Iterating the instance yields the raw stream events plus a
    [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] once a part that will produce the final result
    has started. `stream_responses` and `stream_output` build on that iteration to expose the accumulated
    response and the validated output respectively.
    """

    _raw_stream_response: models.StreamedResponse
    _result_schema: _result.ResultSchema[ResultDataT] | None
    _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]]
    _run_ctx: RunContext[AgentDepsT]
    _usage_limits: UsageLimits | None

    # Cached so that repeated `__aiter__` calls share one underlying iterator.
    _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
    # Set by the event iterator the first time a final-result part is detected.
    _final_result_event: FinalResultEvent | None = field(default=None, init=False)
    # Shallow copy of the run context's usage, taken in `__post_init__`.
    _initial_run_ctx_usage: Usage = field(init=False)

    def __post_init__(self):
        """Snapshot the run context's usage so `usage()` can add the streamed usage on top of it."""
        self._initial_run_ctx_usage = copy(self._run_ctx.usage)

    async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]:
        """Asynchronously stream the (validated) agent outputs.

        Responses received before a final result event has been seen are skipped. While streaming, outputs are
        validated with `allow_partial=True`; once the stream is exhausted the complete response is validated
        one more time with `allow_partial=False`.

        Args:
            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no
                debouncing.
        """
        async for response in self.stream_responses(debounce_by=debounce_by):
            if self._final_result_event is not None:
                yield await self._validate_response(response, self._final_result_event.tool_name, allow_partial=True)
        if self._final_result_event is not None:
            yield await self._validate_response(
                self._raw_stream_response.get(), self._final_result_event.tool_name, allow_partial=False
            )

    async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
        """Asynchronously stream the (unvalidated) model responses for the agent.

        Args:
            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no
                debouncing.
        """
        # if the message currently has any parts with content, yield before streaming
        msg = self._raw_stream_response.get()
        if any(part.has_content() for part in msg.parts):
            yield msg

        # Iterating `self` (not the raw response) ensures the final-result detection in `__aiter__` runs.
        async with _utils.group_by_temporal(self, debounce_by) as group_iter:
            async for _items in group_iter:
                yield self._raw_stream_response.get()  # current state of the response

    def usage(self) -> Usage:
        """Return the usage of the whole run.

        !!! note
            This won't return the full usage until the stream is finished.
        """
        return self._initial_run_ctx_usage + self._raw_stream_response.usage()

    async def _validate_response(
        self, message: _messages.ModelResponse, result_tool_name: str | None, *, allow_partial: bool = False
    ) -> ResultDataT:
        """Validate a structured result message.

        Args:
            message: the model response to validate.
            result_tool_name: name of the result tool whose call carries the data; when `None` (or when there
                is no result schema) the text parts of `message` are joined and treated as the result.
            allow_partial: forwarded to the result tool's `validate` call.

        Raises:
            UnexpectedModelBehavior: if `result_tool_name` is given but no matching tool call is found.
        """
        if self._result_schema is not None and result_tool_name is not None:
            match = self._result_schema.find_named_tool(message.parts, result_tool_name)
            if match is None:
                raise exceptions.UnexpectedModelBehavior(
                    f'Invalid response, unable to find tool: {self._result_schema.tool_names()}'
                )

            call, result_tool = match
            result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)

            for validator in self._result_validators:
                result_data = await validator.validate(result_data, call, self._run_ctx)
            return result_data
        else:
            text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
            for validator in self._result_validators:
                text = await validator.validate(
                    text,
                    None,
                    self._run_ctx,
                )
            # Since there is no result tool, we can assume that str is compatible with ResultDataT
            return cast(ResultDataT, text)

    def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
        """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.

        This proxies the _raw_stream_response and sends all events to the agent stream, while also checking for matches
        on the result schema and emitting a [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] if/when the
        first match is found.
        """
        if self._agent_stream_iterator is not None:
            return self._agent_stream_iterator

        async def _event_stream():
            result_schema = self._result_schema
            allow_text_result = result_schema is None or result_schema.allow_text_result

            def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages.FinalResultEvent | None:
                """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result."""
                if isinstance(e, _messages.PartStartEvent):
                    new_part = e.part
                    if isinstance(new_part, _messages.ToolCallPart):
                        if result_schema:
                            # Only the first matching result tool call matters.
                            for call, _ in result_schema.find_tool([new_part]):
                                return _messages.FinalResultEvent(
                                    tool_name=call.tool_name, tool_call_id=call.tool_call_id
                                )
                    elif allow_text_result:
                        assert_type(e, _messages.PartStartEvent)
                        return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)
                return None

            usage_checking_stream = _get_usage_checking_stream_response(
                self._raw_stream_response, self._usage_limits, self.usage
            )
            async for event in usage_checking_stream:
                yield event
                if (final_result_event := _get_final_result_event(event)) is not None:
                    self._final_result_event = final_result_event
                    yield final_result_event
                    break

            # If we broke out of the above loop, we need to yield the rest of the events
            # If we didn't, this will just be a no-op
            async for event in usage_checking_stream:
                yield event

        self._agent_stream_iterator = _event_stream()
        return self._agent_stream_iterator
@dataclass
class StreamedRunResult(Generic[AgentDepsT, ResultDataT]):
    """Result of a streamed run that returns structured data via a tool call."""

    _all_messages: list[_messages.ModelMessage]
    _new_message_index: int

    _usage_limits: UsageLimits | None
    _stream_response: models.StreamedResponse
    _result_schema: _result.ResultSchema[ResultDataT] | None
    _run_ctx: RunContext[AgentDepsT]
    _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]]
    _result_tool_name: str | None
    _on_complete: Callable[[], Awaitable[None]]

    _initial_run_ctx_usage: Usage = field(init=False)
    is_complete: bool = field(default=False, init=False)
    """Whether the stream has all been received.

    This is set to `True` when one of
    [`stream`][pydantic_ai.result.StreamedRunResult.stream],
    [`stream_text`][pydantic_ai.result.StreamedRunResult.stream_text],
    [`stream_structured`][pydantic_ai.result.StreamedRunResult.stream_structured] or
    [`get_data`][pydantic_ai.result.StreamedRunResult.get_data] completes.
    """

    def __post_init__(self):
        """Snapshot the run context's usage so `usage()` can add the streamed usage on top of it."""
        self._initial_run_ctx_usage = copy(self._run_ctx.usage)

    def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
        """Return the history of messages.

        Args:
            result_tool_return_content: Not supported for this result type; must be left as `None`.

        Returns:
            List of messages.

        Raises:
            NotImplementedError: if `result_tool_return_content` is not `None`.
        """
        # this is a method to be consistent with the other methods
        if result_tool_return_content is not None:
            raise NotImplementedError('Setting result tool return content is not supported for this result type.')
        return self._all_messages

    def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
        """Return all messages from [`all_messages`][pydantic_ai.result.StreamedRunResult.all_messages] as JSON bytes.

        Args:
            result_tool_return_content: Not supported for this result type; must be left as `None`.

        Returns:
            JSON bytes representing the messages.

        Raises:
            NotImplementedError: if `result_tool_return_content` is not `None`.
        """
        return _messages.ModelMessagesTypeAdapter.dump_json(
            self.all_messages(result_tool_return_content=result_tool_return_content)
        )

    def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
        """Return new messages associated with this run.

        Messages from older runs are excluded.

        Args:
            result_tool_return_content: Not supported for this result type; must be left as `None`.

        Returns:
            List of new messages.

        Raises:
            NotImplementedError: if `result_tool_return_content` is not `None`.
        """
        return self.all_messages(result_tool_return_content=result_tool_return_content)[self._new_message_index :]

    def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
        """Return new messages from [`new_messages`][pydantic_ai.result.StreamedRunResult.new_messages] as JSON bytes.

        Args:
            result_tool_return_content: Not supported for this result type; must be left as `None`.

        Returns:
            JSON bytes representing the new messages.

        Raises:
            NotImplementedError: if `result_tool_return_content` is not `None`.
        """
        return _messages.ModelMessagesTypeAdapter.dump_json(
            self.new_messages(result_tool_return_content=result_tool_return_content)
        )

    async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]:
        """Stream the response as an async iterable.

        The pydantic validator for structured data will be called in
        [partial mode](https://docs.pydantic.dev/dev/concepts/experimental/#partial-validation)
        on each iteration.

        Args:
            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
                Debouncing is particularly important for long structured responses to reduce the overhead of
                performing validation as each token is received.

        Returns:
            An async iterable of the response data.
        """
        async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by):
            # Only the last message is validated in full; earlier ones may be incomplete.
            result = await self.validate_structured_result(structured_message, allow_partial=not is_last)
            yield result

    async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
        """Stream the text result as an async iterable.

        !!! note
            Result validators will NOT be called on the text result if `delta=True`.

        Args:
            delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text
                up to the current point.
            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
                Debouncing is particularly important for long structured responses to reduce the overhead of
                performing validation as each token is received.

        Raises:
            UserError: if the result schema does not allow a text result.
        """
        if self._result_schema and not self._result_schema.allow_text_result:
            raise exceptions.UserError('stream_text() can only be used with text responses')

        async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by):
            if not delta:
                # Validators only make sense on the accumulated text, not on isolated chunks.
                text = await self._validate_text_result(text)
            yield text
        await self._marked_completed(self._stream_response.get())

    async def stream_structured(
        self, *, debounce_by: float | None = 0.1
    ) -> AsyncIterator[tuple[_messages.ModelResponse, bool]]:
        """Stream the response as an async iterable of Structured LLM Messages.

        Args:
            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
                Debouncing is particularly important for long structured responses to reduce the overhead of
                performing validation as each token is received.

        Returns:
            An async iterable of the structured response message and whether that is the last message.
        """
        # if the message currently has any parts with content, yield before streaming
        msg = self._stream_response.get()
        if any(part.has_content() for part in msg.parts):
            yield msg, False

        async for msg in self._stream_response_structured(debounce_by=debounce_by):
            yield msg, False

        # The stream is exhausted: emit the final state once more, flagged as last.
        msg = self._stream_response.get()
        yield msg, True

        await self._marked_completed(msg)

    async def get_data(self) -> ResultDataT:
        """Stream the whole response, validate and return it."""
        usage_checking_stream = _get_usage_checking_stream_response(
            self._stream_response, self._usage_limits, self.usage
        )

        # Drain the stream; the events themselves are not needed, only the accumulated response.
        async for _ in usage_checking_stream:
            pass
        message = self._stream_response.get()
        await self._marked_completed(message)
        return await self.validate_structured_result(message)

    def usage(self) -> Usage:
        """Return the usage of the whole run.

        !!! note
            This won't return the full usage until the stream is finished.
        """
        return self._initial_run_ctx_usage + self._stream_response.usage()

    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._stream_response.timestamp

    async def validate_structured_result(
        self, message: _messages.ModelResponse, *, allow_partial: bool = False
    ) -> ResultDataT:
        """Validate a structured result message.

        Args:
            message: the model response to validate.
            allow_partial: forwarded to the result tool's `validate` call.

        Raises:
            UnexpectedModelBehavior: if a result tool name is set but no matching tool call is found in `message`.
        """
        if self._result_schema is not None and self._result_tool_name is not None:
            match = self._result_schema.find_named_tool(message.parts, self._result_tool_name)
            if match is None:
                raise exceptions.UnexpectedModelBehavior(
                    f'Invalid response, unable to find tool: {self._result_schema.tool_names()}'
                )

            call, result_tool = match
            result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)

            for validator in self._result_validators:
                result_data = await validator.validate(result_data, call, self._run_ctx)
            return result_data
        else:
            text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
            text = await self._validate_text_result(text)
            # Since there is no result tool, we can assume that str is compatible with ResultDataT
            return cast(ResultDataT, text)

    async def _validate_text_result(self, text: str) -> str:
        """Run every result validator over `text` (with no tool call) and return the final value."""
        for validator in self._result_validators:
            text = await validator.validate(
                text,
                None,
                self._run_ctx,
            )
        return text

    async def _marked_completed(self, message: _messages.ModelResponse) -> None:
        """Flag the result as complete, append `message` to the history and await the completion callback."""
        self.is_complete = True
        self._all_messages.append(message)
        await self._on_complete()

    async def _stream_response_structured(
        self, *, debounce_by: float | None = 0.1
    ) -> AsyncIterator[_messages.ModelResponse]:
        """Yield the current state of the response after each (debounced) group of stream events."""
        async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
            async for _items in group_iter:
                yield self._stream_response.get()

    async def _stream_response_text(
        self, *, delta: bool = False, debounce_by: float | None = 0.1
    ) -> AsyncIterator[str]:
        """Stream the response as an async iterable of text."""

        # Define a "merged" version of the iterator that will yield items that have already been retrieved
        # and items that we receive while streaming. We define a dedicated async iterator for this so we can
        # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
        async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
            # yields tuples of (text_content, part_index)
            # we don't currently make use of the part_index, but in principle this may be useful
            # so we retain it here for now to make possible future refactors simpler
            msg = self._stream_response.get()
            for i, part in enumerate(msg.parts):
                if isinstance(part, _messages.TextPart) and part.content:
                    yield part.content, i

            async for event in self._stream_response:
                if (
                    isinstance(event, _messages.PartStartEvent)
                    and isinstance(event.part, _messages.TextPart)
                    and event.part.content
                ):
                    yield event.part.content, event.index
                elif (
                    isinstance(event, _messages.PartDeltaEvent)
                    and isinstance(event.delta, _messages.TextPartDelta)
                    and event.delta.content_delta
                ):
                    yield event.delta.content_delta, event.index

        async def _stream_text_deltas() -> AsyncIterator[str]:
            async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
                async for items in group_iter:
                    # Note: we are currently just dropping the part index on the group here
                    yield ''.join([content for content, _ in items])

        if delta:
            async for text in _stream_text_deltas():
                yield text
        else:
            # Accumulate the grouped chunks and yield the full text received so far after each one.
            # NOTE(review): the original comment cited a benchmark favouring string concatenation, yet the
            # code joins a list on every step - confirm which form is intended.
            deltas: list[str] = []
            async for text in _stream_text_deltas():
                deltas.append(text)
                yield ''.join(deltas)
@dataclass
class FinalResult(Generic[ResultDataT]):
    """Marker class storing the final result of an agent run and associated metadata."""

    data: ResultDataT
    """The final result data."""
    tool_name: str | None
    """Name of the final result tool; `None` if the result came from unstructured text content."""
    tool_call_id: str | None
    """ID of the tool call that produced the final result; `None` if the result came from unstructured text content."""
def _get_usage_checking_stream_response(
    stream_response: AsyncIterable[_messages.ModelResponseStreamEvent],
    limits: UsageLimits | None,
    get_usage: Callable[[], Usage],
) -> AsyncIterable[_messages.ModelResponseStreamEvent]:
    """Wrap `stream_response` so that token limits are checked as each event passes through.

    Args:
        stream_response: the event stream to proxy.
        limits: usage limits to enforce. When this is `None`, or `limits.has_token_limits()` is falsy,
            there is nothing to check and `stream_response` is returned unchanged.
        get_usage: called once per received event; its result is passed to `limits.check_tokens`.

    Returns:
        `stream_response` itself when no token limits apply, otherwise an async generator that calls
        `limits.check_tokens(get_usage())` after receiving each event and before yielding it.
    """
    if limits is None or not limits.has_token_limits():
        return stream_response

    async def _usage_checking_iterator():
        async for item in stream_response:
            limits.check_tokens(get_usage())
            yield item

    return _usage_checking_iterator()