Coverage for pydantic_ai_slim/pydantic_ai/result.py: 94.34%
160 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
1from __future__ import annotations as _annotations
3from abc import ABC, abstractmethod
4from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
5from copy import deepcopy
6from dataclasses import dataclass, field
7from datetime import datetime
8from typing import Generic, Union, cast
10import logfire_api
11from typing_extensions import TypeVar
13from . import _result, _utils, exceptions, messages as _messages, models
14from .tools import AgentDepsT, RunContext
15from .usage import Usage, UsageLimits
__all__ = 'ResultDataT', 'ResultDataT_inv', 'ResultValidatorFunc', 'RunResult', 'StreamedRunResult'


T = TypeVar('T')
"""An invariant TypeVar."""
ResultDataT_inv = TypeVar('ResultDataT_inv', default=str)
"""
An invariant type variable for the result data of a model.

We need to use an invariant typevar for `ResultValidator` and `ResultValidatorFunc` because the result data type is used
in both the input and output of a `ResultValidatorFunc`. This can theoretically lead to some issues assuming that types
possessing ResultValidator's are covariant in the result data type, but in practice this is rarely an issue, and
changing it would have negative consequences for the ergonomics of the library.

At some point, it may make sense to change the input to ResultValidatorFunc to be `Any` or `object` as doing that would
resolve these potential variance issues.
"""
ResultDataT = TypeVar('ResultDataT', default=str, covariant=True)
"""Covariant type variable for the result data type of a run."""

# All four call shapes a result validator may take: with/without RunContext, sync or async.
ResultValidatorFunc = Union[
    Callable[[RunContext[AgentDepsT], ResultDataT_inv], ResultDataT_inv],
    Callable[[RunContext[AgentDepsT], ResultDataT_inv], Awaitable[ResultDataT_inv]],
    Callable[[ResultDataT_inv], ResultDataT_inv],
    Callable[[ResultDataT_inv], Awaitable[ResultDataT_inv]],
]
"""
A function that always takes and returns the same type of data (which is the result type of an agent run), and:

* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument
* may or may not be async

Usage `ResultValidatorFunc[AgentDepsT, T]`.
"""

# Single logfire instance scoped to this package so spans are attributed to pydantic-ai.
_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
@dataclass
class _BaseRunResult(ABC, Generic[ResultDataT]):
    """Shared message-history machinery for run results.

    Not part of the public API — use the subclasses `RunResult` and `StreamedRunResult` instead.
    """

    _all_messages: list[_messages.ModelMessage]
    _new_message_index: int

    def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
        """Return the history of _messages.

        Args:
            result_tool_return_content: The return content of the tool call to set in the last message.
                This provides a convenient way to modify the content of the result tool call if you want to continue
                the conversation and want to set the response to the result tool call. If `None`, the last message will
                not be modified.

        Returns:
            List of messages.
        """
        # a method (rather than a property) for consistency with the other accessors
        if result_tool_return_content is None:
            return self._all_messages
        raise NotImplementedError('Setting result tool return content is not supported for this result type.')

    def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
        """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes.

        Args:
            result_tool_return_content: The return content of the tool call to set in the last message.
                This provides a convenient way to modify the content of the result tool call if you want to continue
                the conversation and want to set the response to the result tool call. If `None`, the last message will
                not be modified.

        Returns:
            JSON bytes representing the messages.
        """
        history = self.all_messages(result_tool_return_content=result_tool_return_content)
        return _messages.ModelMessagesTypeAdapter.dump_json(history)

    def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
        """Return new messages associated with this run.

        Messages from older runs are excluded.

        Args:
            result_tool_return_content: The return content of the tool call to set in the last message.
                This provides a convenient way to modify the content of the result tool call if you want to continue
                the conversation and want to set the response to the result tool call. If `None`, the last message will
                not be modified.

        Returns:
            List of new messages.
        """
        history = self.all_messages(result_tool_return_content=result_tool_return_content)
        # everything before `_new_message_index` belongs to earlier runs
        return history[self._new_message_index :]

    def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
        """Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes.

        Args:
            result_tool_return_content: The return content of the tool call to set in the last message.
                This provides a convenient way to modify the content of the result tool call if you want to continue
                the conversation and want to set the response to the result tool call. If `None`, the last message will
                not be modified.

        Returns:
            JSON bytes representing the new messages.
        """
        recent = self.new_messages(result_tool_return_content=result_tool_return_content)
        return _messages.ModelMessagesTypeAdapter.dump_json(recent)

    @abstractmethod
    def usage(self) -> Usage:
        raise NotImplementedError()
@dataclass
class RunResult(_BaseRunResult[ResultDataT]):
    """Result of a non-streamed run."""

    data: ResultDataT
    """Data from the final response in the run."""
    _result_tool_name: str | None
    _usage: Usage

    def usage(self) -> Usage:
        """Return the usage of the whole run."""
        return self._usage

    def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
        """Return the history of _messages.

        Args:
            result_tool_return_content: The return content of the tool call to set in the last message.
                This provides a convenient way to modify the content of the result tool call if you want to continue
                the conversation and want to set the response to the result tool call. If `None`, the last message will
                not be modified.

        Returns:
            List of messages.
        """
        if result_tool_return_content is None:
            return self._all_messages
        return self._set_result_tool_return(result_tool_return_content)

    def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]:
        """Set return content for the result tool.

        Useful if you want to continue the conversation and want to set the response to the result tool call.
        Returns a deep copy of the history; the stored messages are never mutated.
        """
        if not self._result_tool_name:
            raise ValueError('Cannot set result tool return content when the return type is `str`.')
        messages = deepcopy(self._all_messages)
        # only the final message can carry the result tool's return part
        final_message = messages[-1]
        for part in final_message.parts:
            matches = isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._result_tool_name
            if matches:
                part.content = return_content
                return messages
        raise LookupError(f'No tool call found with tool name {self._result_tool_name!r}.')
@dataclass
class StreamedRunResult(_BaseRunResult[ResultDataT], Generic[AgentDepsT, ResultDataT]):
    """Result of a streamed run that returns structured data via a tool call."""

    _usage_limits: UsageLimits | None
    _stream_response: models.StreamedResponse
    _result_schema: _result.ResultSchema[ResultDataT] | None
    _run_ctx: RunContext[AgentDepsT]
    _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]]
    _result_tool_name: str | None
    _on_complete: Callable[[], Awaitable[None]]
    is_complete: bool = field(default=False, init=False)
    """Whether the stream has all been received.

    This is set to `True` when one of
    [`stream`][pydantic_ai.result.StreamedRunResult.stream],
    [`stream_text`][pydantic_ai.result.StreamedRunResult.stream_text],
    [`stream_structured`][pydantic_ai.result.StreamedRunResult.stream_structured] or
    [`get_data`][pydantic_ai.result.StreamedRunResult.get_data] completes.
    """

    async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]:
        """Stream the response as an async iterable.

        The pydantic validator for structured data will be called in
        [partial mode](https://docs.pydantic.dev/dev/concepts/experimental/#partial-validation)
        on each iteration.

        Args:
            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
                Debouncing is particularly important for long structured responses to reduce the overhead of
                performing validation as each token is received.

        Returns:
            An async iterable of the response data.
        """
        async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by):
            # only the final message must validate fully; intermediates may be partial
            result = await self.validate_structured_result(structured_message, allow_partial=not is_last)
            yield result

    async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
        """Stream the text result as an async iterable.

        !!! note
            Result validators will NOT be called on the text result if `delta=True`.

        Args:
            delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text
                up to the current point.
            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
                Debouncing is particularly important for long structured responses to reduce the overhead of
                performing validation as each token is received.
        """
        if self._result_schema and not self._result_schema.allow_text_result:
            raise exceptions.UserError('stream_text() can only be used with text responses')

        usage_checking_stream = _get_usage_checking_stream_response(
            self._stream_response, self._usage_limits, self.usage
        )

        # Define a "merged" version of the iterator that will yield items that have already been retrieved
        # and items that we receive while streaming. We define a dedicated async iterator for this so we can
        # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
        async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
            # if the response currently has any parts with content, yield those before streaming
            msg = self._stream_response.get()
            for i, part in enumerate(msg.parts):
                if isinstance(part, _messages.TextPart) and part.content:
                    yield part.content, i

            async for event in usage_checking_stream:
                # a new text part starting with non-empty content
                if (
                    isinstance(event, _messages.PartStartEvent)
                    and isinstance(event.part, _messages.TextPart)
                    and event.part.content
                ):
                    yield event.part.content, event.index
                # an incremental addition to an existing text part
                elif (
                    isinstance(event, _messages.PartDeltaEvent)
                    and isinstance(event.delta, _messages.TextPartDelta)
                    and event.delta.content_delta
                ):
                    yield event.delta.content_delta, event.index

        async def _stream_text_deltas() -> AsyncIterator[str]:
            # debounce/group the raw deltas, then flatten each group into one string
            async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
                async for items in group_iter:
                    yield ''.join([content for content, _ in items])

        with _logfire.span('response stream text') as lf_span:
            if delta:
                # raw deltas: no validation, no completion bookkeeping on this path's yields
                async for text in _stream_text_deltas():
                    yield text
            else:
                # a quick benchmark shows it's faster to build up a string with concat when we're
                # yielding at each step
                deltas: list[str] = []
                combined_validated_text = ''
                async for text in _stream_text_deltas():
                    deltas.append(text)
                    combined_text = ''.join(deltas)
                    combined_validated_text = await self._validate_text_result(combined_text)
                    yield combined_validated_text

                lf_span.set_attribute('combined_text', combined_validated_text)
                # record the final text as a ModelResponse and flip `is_complete`
                await self._marked_completed(
                    _messages.ModelResponse(
                        parts=[_messages.TextPart(combined_validated_text)],
                        model_name=self._stream_response.model_name(),
                    )
                )

    async def stream_structured(
        self, *, debounce_by: float | None = 0.1
    ) -> AsyncIterator[tuple[_messages.ModelResponse, bool]]:
        """Stream the response as an async iterable of Structured LLM Messages.

        Args:
            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
                Debouncing is particularly important for long structured responses to reduce the overhead of
                performing validation as each token is received.

        Returns:
            An async iterable of the structured response message and whether that is the last message.
        """
        usage_checking_stream = _get_usage_checking_stream_response(
            self._stream_response, self._usage_limits, self.usage
        )

        with _logfire.span('response stream structured') as lf_span:
            # if the message currently has any parts with content, yield before streaming
            msg = self._stream_response.get()
            for part in msg.parts:
                if part.has_content():
                    yield msg, False
                    break

            async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
                async for _events in group_iter:
                    # re-read the accumulated message after each debounced batch of events
                    msg = self._stream_response.get()
                    yield msg, False
            # final state of the message, flagged as last
            msg = self._stream_response.get()
            yield msg, True
            # TODO: Should this now be `final_response` instead of `structured_response`?
            lf_span.set_attribute('structured_response', msg)
            await self._marked_completed(msg)

    async def get_data(self) -> ResultDataT:
        """Stream the whole response, validate and return it."""
        usage_checking_stream = _get_usage_checking_stream_response(
            self._stream_response, self._usage_limits, self.usage
        )

        # drain the stream; events accumulate state inside `_stream_response`
        async for _ in usage_checking_stream:
            pass
        message = self._stream_response.get()
        await self._marked_completed(message)
        return await self.validate_structured_result(message)

    def usage(self) -> Usage:
        """Return the usage of the whole run.

        !!! note
            This won't return the full usage until the stream is finished.
        """
        # usage accrued before streaming started plus whatever the stream has reported so far
        return self._run_ctx.usage + self._stream_response.usage()

    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._stream_response.timestamp()

    async def validate_structured_result(
        self, message: _messages.ModelResponse, *, allow_partial: bool = False
    ) -> ResultDataT:
        """Validate a structured result message.

        Args:
            message: the model response to validate.
            allow_partial: if `True`, validate in pydantic partial mode (for incomplete streamed data).

        Returns:
            The validated result data.

        Raises:
            exceptions.UnexpectedModelBehavior: if a result tool is expected but not found in the message.
        """
        if self._result_schema is not None and self._result_tool_name is not None:
            # structured path: find the named result tool call and validate its arguments
            match = self._result_schema.find_named_tool(message.parts, self._result_tool_name)
            if match is None:
                raise exceptions.UnexpectedModelBehavior(
                    f'Invalid response, unable to find tool: {self._result_schema.tool_names()}'
                )

            call, result_tool = match
            result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)

            for validator in self._result_validators:
                result_data = await validator.validate(result_data, call, self._run_ctx)
            return result_data
        else:
            # text path: concatenate all text parts, then run validators with no tool call
            text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
            for validator in self._result_validators:
                text = await validator.validate(
                    text,
                    None,
                    self._run_ctx,
                )
            # Since there is no result tool, we can assume that str is compatible with ResultDataT
            return cast(ResultDataT, text)

    async def _validate_text_result(self, text: str) -> str:
        # run every registered validator over the (possibly partial) text, in order
        for validator in self._result_validators:
            text = await validator.validate(
                text,
                None,
                self._run_ctx,
            )
        return text

    async def _marked_completed(self, message: _messages.ModelResponse) -> None:
        # record the final message, flip the completion flag, and fire the completion callback
        self.is_complete = True
        self._all_messages.append(message)
        await self._on_complete()
395def _get_usage_checking_stream_response(
396 stream_response: AsyncIterable[_messages.ModelResponseStreamEvent],
397 limits: UsageLimits | None,
398 get_usage: Callable[[], Usage],
399) -> AsyncIterable[_messages.ModelResponseStreamEvent]:
400 if limits is not None and limits.has_token_limits():
402 async def _usage_checking_iterator():
403 async for item in stream_response: 403 ↛ exitline 403 didn't return from function '_usage_checking_iterator' because the loop on line 403 didn't complete
404 limits.check_tokens(get_usage())
405 yield item
407 return _usage_checking_iterator()
408 else:
409 return stream_response