Coverage for pydantic_ai_slim/pydantic_ai/_result.py: 95.65%

149 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-01-25 16:43 +0000

1from __future__ import annotations as _annotations 

2 

3import inspect 

4import sys 

5import types 

6from collections.abc import Awaitable, Iterable 

7from dataclasses import dataclass, field 

8from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin 

9 

10from pydantic import TypeAdapter, ValidationError 

11from typing_extensions import TypeAliasType, TypedDict, TypeVar 

12 

13from . import _utils, messages as _messages 

14from .exceptions import ModelRetry 

15from .result import ResultDataT, ResultDataT_inv, ResultValidatorFunc 

16from .tools import AgentDepsT, RunContext, ToolDefinition 

17 

T = TypeVar('T')
"""An invariant TypeVar.

Used by helpers in this module whose output type must match their input type exactly.
"""

20 

21 

@dataclass
class ResultValidator(Generic[AgentDepsT, ResultDataT_inv]):
    """Wrap a user-supplied result validator function so it can be awaited uniformly.

    On construction the wrapped function is inspected once to record whether it
    expects a run context as its first argument and whether it is a coroutine function.
    """

    function: ResultValidatorFunc[AgentDepsT, ResultDataT_inv]
    _takes_ctx: bool = field(init=False)
    _is_async: bool = field(init=False)

    def __post_init__(self):
        # A validator declaring more than one parameter is treated as `(ctx, result)`.
        parameter_count = len(inspect.signature(self.function).parameters)
        self._takes_ctx = parameter_count > 1
        self._is_async = inspect.iscoroutinefunction(self.function)

    async def validate(
        self,
        result: T,
        tool_call: _messages.ToolCallPart | None,
        run_context: RunContext[AgentDepsT],
    ) -> T:
        """Validate a result by calling the wrapped function.

        Args:
            result: The result data after Pydantic validation of the message content.
            tool_call: The original tool call message, `None` if there was no tool call.
            run_context: The current run context.

        Returns:
            Whatever the wrapped validator function returns.

        Raises:
            ToolRetryError: If the validator raises `ModelRetry`; the error carries a
                `RetryPromptPart` built from the retry message and, when available,
                the tool call's name and id.
        """
        if not self._takes_ctx:
            call_args = (result,)
        else:
            tool_name = tool_call.tool_name if tool_call else None
            call_args = (run_context.replace_with(tool_name=tool_name), result)

        try:
            if self._is_async:
                async_validator = cast(Callable[[Any], Awaitable[T]], self.function)
                validated = await async_validator(*call_args)
            else:
                # Sync validators are dispatched through `_utils.run_in_executor`.
                sync_validator = cast(Callable[[Any], T], self.function)
                validated = await _utils.run_in_executor(sync_validator, *call_args)
        except ModelRetry as exc:
            retry_part = _messages.RetryPromptPart(content=exc.message)
            if tool_call is not None:
                retry_part.tool_name = tool_call.tool_name
                retry_part.tool_call_id = tool_call.tool_call_id
            raise ToolRetryError(retry_part) from exc

        return validated

69 

70 

class ToolRetryError(Exception):
    """Internal exception signalling that a retry prompt should be returned to the LLM.

    The retry prompt is stored on `tool_retry`; the exception itself carries no message.
    """

    def __init__(self, tool_retry: _messages.RetryPromptPart):
        super().__init__()
        self.tool_retry = tool_retry

77 

78 

@dataclass
class ResultSchema(Generic[ResultDataT]):
    """Model the final response from an agent run.

    Similar to `Tool` but for the final result of running an agent.

    Attributes:
        tools: The result tools, keyed by tool name.
        allow_text_result: Whether `str` was a member of the response type union.
    """

    tools: dict[str, ResultTool[ResultDataT]]
    allow_text_result: bool

    @classmethod
    def build(
        cls: type[ResultSchema[T]], response_type: type[T], name: str, description: str | None
    ) -> ResultSchema[T] | None:
        """Build a ResultSchema dataclass from a response type.

        Returns `None` when `response_type` is exactly `str`. Otherwise one tool is
        created per union member (or a single tool named `name` for non-union types),
        after `str` has been removed from the union if it was present.
        """
        if response_type is str:
            return None

        without_str = extract_str_from_union(response_type)
        if without_str:
            response_type = without_str.value
        allow_text_result = bool(without_str)

        def make_tool(tp: Any, tool_name: str, *, multiple: bool) -> ResultTool[T]:
            return cast(ResultTool[T], ResultTool(tp, tool_name, description, multiple))

        union_members = get_union_args(response_type)
        if not union_members:
            single = make_tool(response_type, name, multiple=False)
            return cls(tools={name: single}, allow_text_result=allow_text_result)

        tools: dict[str, ResultTool[T]] = {}
        for position, member in enumerate(union_members, start=1):
            candidate = union_tool_name(name, member)
            # On a name clash, keep appending this member's position until the name is unused.
            while candidate in tools:
                candidate = f'{candidate}_{position}'
            tools[candidate] = make_tool(member, candidate, multiple=True)

        return cls(tools=tools, allow_text_result=allow_text_result)

    def find_named_tool(
        self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
    ) -> tuple[_messages.ToolCallPart, ResultTool[ResultDataT]] | None:
        """Find a tool that matches one of the calls, with a specific name.

        Returns the first tool call part named `tool_name` together with the tool
        registered under that name, or `None` if no part matches.
        """
        for part in parts:
            if isinstance(part, _messages.ToolCallPart) and part.tool_name == tool_name:
                return part, self.tools[tool_name]
        return None

    def find_tool(
        self,
        parts: Iterable[_messages.ModelResponsePart],
    ) -> tuple[_messages.ToolCallPart, ResultTool[ResultDataT]] | None:
        """Find a tool that matches one of the calls.

        Returns the first tool call part whose name is a registered result tool,
        paired with that tool, or `None` if there is no such part.
        """
        for part in parts:
            if not isinstance(part, _messages.ToolCallPart):
                continue
            matched = self.tools.get(part.tool_name)
            if matched:
                return part, matched
        return None

    def tool_names(self) -> list[str]:
        """Return the names of the tools."""
        return list(self.tools)

    def tool_defs(self) -> list[ToolDefinition]:
        """Get tool definitions to register with the model."""
        return [tool.tool_def for tool in self.tools.values()]

144 

145 

# Fallback tool description, used by `ResultTool` when neither the caller nor the
# type's JSON schema provides one.
DEFAULT_DESCRIPTION: str = 'The final response which ends this conversation'

147 

148 

@dataclass(init=False)
class ResultTool(Generic[ResultDataT]):
    """A tool definition plus the type adapter used to validate the arguments of calls to it."""

    tool_def: ToolDefinition
    type_adapter: TypeAdapter[Any]

    def __init__(self, response_type: type[ResultDataT], name: str, description: str | None, multiple: bool):
        """Build a ResultTool dataclass from a response type.

        Args:
            response_type: The type the tool call arguments are validated against; must not be `str`.
            name: The tool name.
            description: Optional description; combined with or replaced by the JSON schema's own description.
            multiple: Whether this tool is one of several built from a union; if so, and the
                schema has no description, the description is prefixed with the type's name.
        """
        assert response_type is not str, 'ResultTool does not support str as a response type'

        # Types that are not model-like are wrapped in a single-key typed dict.
        wrap_response = not _utils.is_model_like(response_type)
        if wrap_response:
            response_data_typed_dict = TypedDict('response_data_typed_dict', {'response': response_type})  # noqa
            self.type_adapter = TypeAdapter(response_data_typed_dict)
        else:
            self.type_adapter = TypeAdapter(response_type)
        outer_typed_dict_key: str | None = 'response' if wrap_response else None

        # noinspection PyArgumentList
        parameters_json_schema = _utils.check_object_json_schema(self.type_adapter.json_schema())
        if wrap_response:
            # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM
            parameters_json_schema.pop('title')

        schema_description = parameters_json_schema.pop('description', None)
        if not schema_description:
            tool_description = description or DEFAULT_DESCRIPTION
            if multiple:
                tool_description = f'{union_arg_name(response_type)}: {tool_description}'
        elif description is None:
            tool_description = schema_description
        else:
            tool_description = f'{description}. {schema_description}'

        self.tool_def = ToolDefinition(
            name=name,
            description=tool_description,
            parameters_json_schema=parameters_json_schema,
            outer_typed_dict_key=outer_typed_dict_key,
        )

    def validate(
        self, tool_call: _messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
    ) -> ResultDataT:
        """Validate a result message.

        Args:
            tool_call: The tool call from the LLM to validate.
            allow_partial: If true, allow partial validation.
            wrap_validation_errors: If true, wrap the validation errors in a retry message.

        Returns:
            The validated result data, unwrapped from the outer typed dict when one was used.

        Raises:
            ToolRetryError: If validation fails and `wrap_validation_errors` is true.
            ValidationError: If validation fails and `wrap_validation_errors` is false.
        """
        try:
            pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
            raw_args = tool_call.args
            # String arguments are parsed as JSON; anything else is validated as a Python object.
            if isinstance(raw_args, str):
                result = self.type_adapter.validate_json(raw_args, experimental_allow_partial=pyd_allow_partial)
            else:
                result = self.type_adapter.validate_python(raw_args, experimental_allow_partial=pyd_allow_partial)
        except ValidationError as e:
            if not wrap_validation_errors:
                raise
            retry_part = _messages.RetryPromptPart(
                tool_name=tool_call.tool_name,
                content=e.errors(include_url=False),
                tool_call_id=tool_call.tool_call_id,
            )
            raise ToolRetryError(retry_part) from e

        outer_key = self.tool_def.outer_typed_dict_key
        if outer_key:
            result = result[outer_key]
        return result

222 

223 

def union_tool_name(base_name: str, union_arg: Any) -> str:
    """Return the tool name for one union member: `base_name`, an underscore, then the member's name."""
    suffix = union_arg_name(union_arg)
    return f'{base_name}_{suffix}'

226 

227 

def union_arg_name(union_arg: Any) -> str:
    """Return the `__name__` of a union member."""
    arg_name: str = union_arg.__name__
    return arg_name

230 

231 

def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
    """Extract the string type from a Union, return the remaining union or remaining type.

    Args:
        response_type: The type to inspect.

    Returns:
        `None` if `response_type` is not a union or the union has no `str` member.
        Otherwise a `Some` wrapping what is left once `str` is removed: the sole
        remaining member, or a `Union` of the remaining members.
    """
    union_args = get_union_args(response_type)
    remain_args: list[Any] = [arg for arg in union_args if arg is not str]

    # Nothing was filtered out: either not a union at all, or a union without `str`.
    if len(remain_args) == len(union_args):
        return None

    if len(remain_args) == 1:
        return _utils.Some(remain_args[0])
    return _utils.Some(Union[tuple(remain_args)])

248 

249 

def get_union_args(tp: Any) -> tuple[Any, ...]:
    """Return the member types of `tp` if it is a union, otherwise an empty tuple.

    If `tp` is a `TypeAliasType`, its `__value__` is inspected instead.
    """
    resolved = tp.__value__ if isinstance(tp, TypeAliasType) else tp
    if not origin_is_union(get_origin(resolved)):
        return ()
    return get_args(resolved)

260 

261 

if sys.version_info >= (3, 10):

    def origin_is_union(tp: type[Any] | None) -> bool:
        """Return whether `tp` is `typing.Union` or `types.UnionType` (the origin of `X | Y`)."""
        return tp is types.UnionType or tp is Union

else:

    def origin_is_union(tp: type[Any] | None) -> bool:
        """Return whether `tp` is `typing.Union`; `types.UnionType` is only checked on Python 3.10+."""
        return tp is Union