Coverage for pydantic_ai_slim/pydantic_ai/_result.py: 95.65%
149 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
1from __future__ import annotations as _annotations
3import inspect
4import sys
5import types
6from collections.abc import Awaitable, Iterable
7from dataclasses import dataclass, field
8from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin
10from pydantic import TypeAdapter, ValidationError
11from typing_extensions import TypeAliasType, TypedDict, TypeVar
13from . import _utils, messages as _messages
14from .exceptions import ModelRetry
15from .result import ResultDataT, ResultDataT_inv, ResultValidatorFunc
16from .tools import AgentDepsT, RunContext, ToolDefinition
# Generic type variable declared with neither `covariant` nor `contravariant`, i.e. invariant.
T = TypeVar('T')
"""An invariant TypeVar."""
@dataclass
class ResultValidator(Generic[AgentDepsT, ResultDataT_inv]):
    """Wrap a user-supplied result validator function and apply it to validated result data.

    The wrapped function may be sync or async. It may optionally take a `RunContext` as its
    first argument; this is detected by the function accepting more than one parameter.
    """

    function: ResultValidatorFunc[AgentDepsT, ResultDataT_inv]
    # Set in `__post_init__`: whether `function` takes `(ctx, result)` rather than `(result)`.
    _takes_ctx: bool = field(init=False)
    # Set in `__post_init__`: whether `function` is a coroutine function that must be awaited.
    _is_async: bool = field(init=False)

    def __post_init__(self):
        # More than one parameter means the signature is `(ctx, result)`.
        self._takes_ctx = len(inspect.signature(self.function).parameters) > 1
        self._is_async = inspect.iscoroutinefunction(self.function)

    async def validate(
        self,
        result: T,
        tool_call: _messages.ToolCallPart | None,
        run_context: RunContext[AgentDepsT],
    ) -> T:
        """Validate a result by calling the wrapped validator function.

        Args:
            result: The result data after Pydantic has validated the message content.
            tool_call: The original tool call message, `None` if there was no tool call.
            run_context: The current run context.

        Returns:
            The value returned by the validator function (the validated, possibly transformed, result data).

        Raises:
            ToolRetryError: If the validator function raises `ModelRetry`. The wrapped
                `RetryPromptPart` carries the tool name and tool call ID when `tool_call` is given.
        """
        if self._takes_ctx:
            # Expose the result tool's name (if any) to the validator via the run context.
            ctx = run_context.replace_with(tool_name=tool_call.tool_name if tool_call else None)
            args = ctx, result
        else:
            args = (result,)

        try:
            if self._is_async:
                # `args` may hold one or two values, so the callable is typed with `...`.
                function = cast(Callable[..., Awaitable[T]], self.function)
                result_data = await function(*args)
            else:
                function = cast(Callable[..., T], self.function)
                # Sync validators are dispatched via `_utils.run_in_executor`
                # (presumably so they don't block the event loop).
                result_data = await _utils.run_in_executor(function, *args)
        except ModelRetry as r:
            m = _messages.RetryPromptPart(content=r.message)
            if tool_call is not None:
                m.tool_name = tool_call.tool_name
                m.tool_call_id = tool_call.tool_call_id
            raise ToolRetryError(m) from r
        else:
            return result_data
class ToolRetryError(Exception):
    """Internal exception carrying a `RetryPromptPart` that should be sent back to the LLM.

    Raised when a result fails validation and the model should be asked to try again.
    """

    def __init__(self, tool_retry: _messages.RetryPromptPart):
        super().__init__()
        # The retry prompt to return to the model.
        self.tool_retry = tool_retry
@dataclass
class ResultSchema(Generic[ResultDataT]):
    """Model the final response from an agent run.

    Similar to `Tool` but for the final result of running an agent. It holds one `ResultTool`
    for a plain response type, or one per member when the response type is a union.
    """

    # Result tools keyed by tool name.
    tools: dict[str, ResultTool[ResultDataT]]
    # Whether a plain-text response is also acceptable (the response type was a union including `str`).
    allow_text_result: bool

    @classmethod
    def build(
        cls: type[ResultSchema[T]], response_type: type[T], name: str, description: str | None
    ) -> ResultSchema[T] | None:
        """Build a ResultSchema dataclass from a response type.

        Args:
            response_type: The type of the final result. `str` members of a union are removed
                and instead enable `allow_text_result`.
            name: The tool name, or the base name for per-union-member tools.
            description: Optional description passed to every `ResultTool`.

        Returns:
            `None` if `response_type` is exactly `str`, otherwise the built schema.
        """
        if response_type is str:
            return None

        if response_type_option := extract_str_from_union(response_type):
            response_type = response_type_option.value
            allow_text_result = True
        else:
            allow_text_result = False

        def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[T]:
            return cast(ResultTool[T], ResultTool(a, tool_name_, description, multiple))

        tools: dict[str, ResultTool[T]] = {}
        if args := get_union_args(response_type):
            # One tool per union member, named `{name}_{member.__name__}`.
            for i, arg in enumerate(args, start=1):
                tool_name = union_tool_name(name, arg)
                # Disambiguate members that share the same `__name__`.
                while tool_name in tools:
                    tool_name = f'{tool_name}_{i}'
                tools[tool_name] = _build_tool(arg, tool_name, True)
        else:
            tools[name] = _build_tool(response_type, name, False)

        return cls(tools=tools, allow_text_result=allow_text_result)

    def find_named_tool(
        self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
    ) -> tuple[_messages.ToolCallPart, ResultTool[ResultDataT]] | None:
        """Find the first tool call in `parts` whose name is `tool_name`.

        Returns:
            The matching call and its result tool, or `None` if no call has that name.

        Raises:
            KeyError: If a matching call is found but `tool_name` is not one of `self.tools`.
        """
        for part in parts:
            if isinstance(part, _messages.ToolCallPart) and part.tool_name == tool_name:
                return part, self.tools[tool_name]
        return None

    def find_tool(
        self,
        parts: Iterable[_messages.ModelResponsePart],
    ) -> tuple[_messages.ToolCallPart, ResultTool[ResultDataT]] | None:
        """Find the first tool call in `parts` that targets one of this schema's result tools.

        Returns:
            The matching call and its result tool, or `None` if no call matches.
        """
        for part in parts:
            if isinstance(part, _messages.ToolCallPart):
                if result := self.tools.get(part.tool_name):
                    return part, result
        return None

    def tool_names(self) -> list[str]:
        """Return the names of the tools."""
        return list(self.tools)

    def tool_defs(self) -> list[ToolDefinition]:
        """Get tool definitions to register with the model."""
        return [t.tool_def for t in self.tools.values()]
# Fallback result-tool description, used when the JSON schema has no description and the
# caller's `description` is empty/None.
DEFAULT_DESCRIPTION = 'The final response which ends this conversation'
@dataclass(init=False)
class ResultTool(Generic[ResultDataT]):
    """A single tool the model can call to return (one variant of) the final result."""

    # Definition (name, description, parameters JSON schema) registered with the model.
    tool_def: ToolDefinition
    # Adapter used to validate the tool call's arguments.
    type_adapter: TypeAdapter[Any]

    def __init__(self, response_type: type[ResultDataT], name: str, description: str | None, multiple: bool):
        """Build a ResultTool dataclass from a response type.

        Args:
            response_type: The type tool arguments are validated against; must not be `str`.
            name: The tool name.
            description: Optional caller-supplied description, combined with any description
                found in the JSON schema.
            multiple: Whether this is one of several tools (one per union member); if so the
                description is prefixed with the member's type name.
        """
        assert response_type is not str, 'ResultTool does not support str as a response type'

        if _utils.is_model_like(response_type):
            self.type_adapter = TypeAdapter(response_type)
            outer_typed_dict_key: str | None = None
            # noinspection PyArgumentList
            parameters_json_schema = _utils.check_object_json_schema(self.type_adapter.json_schema())
        else:
            # Wrap non-model-like types in a single-key TypedDict (presumably so the parameters
            # schema is an object, as `check_object_json_schema` suggests); `validate` unwraps
            # the `response` key again.
            response_data_typed_dict = TypedDict('response_data_typed_dict', {'response': response_type})  # noqa
            self.type_adapter = TypeAdapter(response_data_typed_dict)
            outer_typed_dict_key = 'response'
            # noinspection PyArgumentList
            parameters_json_schema = _utils.check_object_json_schema(self.type_adapter.json_schema())
            # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM
            parameters_json_schema.pop('title', None)

        if json_schema_description := parameters_json_schema.pop('description', None):
            if description is None:
                tool_description = json_schema_description
            else:
                tool_description = f'{description}. {json_schema_description}'
        else:
            tool_description = description or DEFAULT_DESCRIPTION
        if multiple:
            tool_description = f'{union_arg_name(response_type)}: {tool_description}'

        self.tool_def = ToolDefinition(
            name=name,
            description=tool_description,
            parameters_json_schema=parameters_json_schema,
            outer_typed_dict_key=outer_typed_dict_key,
        )

    def validate(
        self, tool_call: _messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
    ) -> ResultDataT:
        """Validate the arguments of a result tool call.

        Args:
            tool_call: The tool call from the LLM to validate.
            allow_partial: If true, allow partial validation (Pydantic's `'trailing-strings'` mode).
            wrap_validation_errors: If true, wrap the validation errors in a retry message.

        Returns:
            The validated result data, unwrapped from the `response` key when the response
            type was wrapped in a TypedDict.

        Raises:
            ToolRetryError: If validation fails and `wrap_validation_errors` is true.
            ValidationError: If validation fails and `wrap_validation_errors` is false.
        """
        try:
            pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
            if isinstance(tool_call.args, str):
                # Arguments arrived as a JSON string.
                result = self.type_adapter.validate_json(tool_call.args, experimental_allow_partial=pyd_allow_partial)
            else:
                # Arguments are already decoded Python data (presumably a dict).
                result = self.type_adapter.validate_python(tool_call.args, experimental_allow_partial=pyd_allow_partial)
        except ValidationError as e:
            if wrap_validation_errors:
                m = _messages.RetryPromptPart(
                    tool_name=tool_call.tool_name,
                    content=e.errors(include_url=False),
                    tool_call_id=tool_call.tool_call_id,
                )
                raise ToolRetryError(m) from e
            else:
                raise
        else:
            if k := self.tool_def.outer_typed_dict_key:
                result = result[k]
            return result
def union_tool_name(base_name: str, union_arg: Any) -> str:
    """Return the tool name for one union member: `{base_name}_{member name}`."""
    suffix = union_arg_name(union_arg)
    return f'{base_name}_{suffix}'
def union_arg_name(union_arg: Any) -> str:
    """Return the `__name__` of a union member type, as used in tool names and descriptions."""
    member_name: str = union_arg.__name__
    return member_name
def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
    """Extract `str` from a union response type.

    If `response_type` is a union that includes `str`, return `Some` of the remaining type:
    the single remaining member, or a `Union` of the remaining members. Otherwise return `None`.
    """
    union_args = get_union_args(response_type)
    if not any(arg is str for arg in union_args):
        return None

    remain_args = [arg for arg in union_args if arg is not str]
    if len(remain_args) == 1:
        return _utils.Some(remain_args[0])
    return _utils.Some(Union[tuple(remain_args)])
def get_union_args(tp: Any) -> tuple[Any, ...]:
    """Return the member types of `tp` if it is a union, otherwise an empty tuple.

    One level of `TypeAliasType` (e.g. `type Foo = int | str`) is unwrapped before checking.
    Whether `X | Y` unions are recognised depends on `origin_is_union`.
    """
    if isinstance(tp, TypeAliasType):
        tp = tp.__value__

    if not origin_is_union(get_origin(tp)):
        return ()
    return get_args(tp)
262if sys.version_info < (3, 10):
264 def origin_is_union(tp: type[Any] | None) -> bool:
265 return tp is Union
267else:
269 def origin_is_union(tp: type[Any] | None) -> bool:
270 return tp is Union or tp is types.UnionType