Coverage for pydantic_ai_slim/pydantic_ai/tools.py: 95.73%

142 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-01-25 16:43 +0000

1from __future__ import annotations as _annotations 

2 

3import dataclasses 

4import inspect 

5from collections.abc import Awaitable 

6from dataclasses import dataclass, field 

7from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast 

8 

9from pydantic import ValidationError 

10from pydantic_core import SchemaValidator 

11from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar 

12 

13from . import _pydantic, _utils, messages as _messages, models 

14from .exceptions import ModelRetry, UnexpectedModelBehavior 

15 

16if TYPE_CHECKING: 

17 from .result import Usage 

18 

# Names exported as the public API of this module.
__all__ = (
    'AgentDepsT',
    'DocstringFormat',
    'RunContext',
    'SystemPromptFunc',
    'ToolFuncContext',
    'ToolFuncPlain',
    'ToolFuncEither',
    'ToolParams',
    'ToolPrepareFunc',
    'Tool',
    'ObjectJsonSchema',
    'ToolDefinition',
)

33 

AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
"""Type variable for agent dependencies.

Declared contravariant, with a default of `None`.
"""

36 

37 

@dataclasses.dataclass
class RunContext(Generic[AgentDepsT]):
    """State describing the call that is currently in progress."""

    deps: AgentDepsT
    """The dependencies supplied to the agent."""
    model: models.Model
    """The model this run is using."""
    usage: Usage
    """LLM usage recorded for the run."""
    prompt: str
    """The user prompt the run was originally started with."""
    messages: list[_messages.ModelMessage] = field(default_factory=list)
    """The messages exchanged in the conversation up to this point."""
    tool_name: str | None = None
    """The name of the tool currently being called, if any."""
    retry: int = 0
    """How many retries have happened so far."""
    run_step: int = 0
    """The step the run is currently on."""

    def replace_with(
        self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
    ) -> RunContext[AgentDepsT]:
        """Return a copy of this context with a new `retry` value and/or `tool_name`.

        Args:
            retry: New retry count; the current value is kept when `None`.
            tool_name: New tool name (which may be `None`); the current value is kept
                when left as `_utils.UNSET`.

        Returns:
            A new `RunContext` built with `dataclasses.replace`; `self` is not modified.
        """
        overrides: dict[str, Any] = {}
        # `None` is a legitimate tool name value, hence the separate UNSET sentinel.
        if tool_name is not _utils.UNSET:
            overrides['tool_name'] = tool_name
        if retry is not None:
            overrides['retry'] = retry
        return dataclasses.replace(self, **overrides)

69 

70 

ToolParams = ParamSpec('ToolParams', default=...)
"""Parameter specification for tool functions.

Used to parametrise `ToolFuncContext`, `ToolFuncPlain` and `ToolFuncEither`.
"""

73 

SystemPromptFunc = Union[
    Callable[[RunContext[AgentDepsT]], str],
    Callable[[RunContext[AgentDepsT]], Awaitable[str]],
    Callable[[], str],
    Callable[[], Awaitable[str]],
]
"""A function that may or may not take `RunContext` as an argument, and may or may not be async.

Usage `SystemPromptFunc[AgentDepsT]`.
"""

84 

ToolFuncContext = Callable[Concatenate[RunContext[AgentDepsT], ToolParams], Any]
"""A tool function that takes `RunContext` as the first argument.

Usage `ToolFuncContext[AgentDepsT, ToolParams]`.
"""

ToolFuncPlain = Callable[ToolParams, Any]
"""A tool function that does not take `RunContext` as the first argument.

Usage `ToolFuncPlain[ToolParams]`.
"""

ToolFuncEither = Union[ToolFuncContext[AgentDepsT, ToolParams], ToolFuncPlain[ToolParams]]
"""Either kind of tool function, with or without a leading `RunContext` argument.

This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
[`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].

Usage `ToolFuncEither[AgentDepsT, ToolParams]`.
"""

ToolPrepareFunc: TypeAlias = 'Callable[[RunContext[AgentDepsT], ToolDefinition], Awaitable[ToolDefinition | None]]'
"""Definition of a function that can prepare a tool definition at call time.

The function receives the current `RunContext` and the tool's `ToolDefinition`, and returns
(asynchronously) either a `ToolDefinition` or `None`.

See [tool docs](../tools.md#tool-prepare) for more information.

Example — here `only_if_42` is valid as a `ToolPrepareFunc`:

```python {noqa="I001"}
from typing import Union

from pydantic_ai import RunContext, Tool
from pydantic_ai.tools import ToolDefinition

async def only_if_42(
    ctx: RunContext[int], tool_def: ToolDefinition
) -> Union[ToolDefinition, None]:
    if ctx.deps == 42:
        return tool_def

def hitchhiker(ctx: RunContext[int], answer: str) -> str:
    return f'{ctx.deps} {answer}'

hitchhiker = Tool(hitchhiker, prepare=only_if_42)
```

Usage `ToolPrepareFunc[AgentDepsT]`.
"""

130 

DocstringFormat = Literal['google', 'numpy', 'sphinx', 'auto']
"""Supported docstring formats.

* `'google'` — [Google-style](https://google.github.io/styleguide/pyguide.html#381-docstrings) docstrings.
* `'numpy'` — [Numpy-style](https://numpydoc.readthedocs.io/en/latest/format.html) docstrings.
* `'sphinx'` — [Sphinx-style](https://sphinx-rtd-tutorial.readthedocs.io/en/latest/docstrings.html#the-sphinx-docstring-format) docstrings.
* `'auto'` — Automatically infer the format based on the structure of the docstring.
"""

139 

# Generic helper type variable.
# NOTE(review): no reference to `A` is visible in this part of the file — confirm it is still needed.
A = TypeVar('A')

141 

142 

@dataclass(init=False)
class Tool(Generic[AgentDepsT]):
    """A function registered with an agent as a callable tool."""

    function: ToolFuncEither[AgentDepsT]
    takes_ctx: bool
    max_retries: int | None
    name: str
    description: str
    prepare: ToolPrepareFunc[AgentDepsT] | None
    docstring_format: DocstringFormat
    require_parameter_descriptions: bool
    _is_async: bool = field(init=False)
    _single_arg_name: str | None = field(init=False)
    _positional_fields: list[str] = field(init=False)
    _var_positional_field: str | None = field(init=False)
    _validator: SchemaValidator = field(init=False, repr=False)
    _parameters_json_schema: ObjectJsonSchema = field(init=False)
    current_retry: int = field(default=0, init=False)

    def __init__(
        self,
        function: ToolFuncEither[AgentDepsT],
        *,
        takes_ctx: bool | None = None,
        max_retries: int | None = None,
        name: str | None = None,
        description: str | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
    ):
        """Create a new tool instance.

        Example usage:

        ```python {noqa="I001"}
        from pydantic_ai import Agent, RunContext, Tool

        async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
            return f'{ctx.deps} {x} {y}'

        agent = Agent('test', tools=[Tool(my_tool)])
        ```

        or with a custom prepare method:

        ```python {noqa="I001"}
        from typing import Union

        from pydantic_ai import Agent, RunContext, Tool
        from pydantic_ai.tools import ToolDefinition

        async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
            return f'{ctx.deps} {x} {y}'

        async def prep_my_tool(
            ctx: RunContext[int], tool_def: ToolDefinition
        ) -> Union[ToolDefinition, None]:
            # only register the tool if `deps == 42`
            if ctx.deps == 42:
                return tool_def

        agent = Agent('test', tools=[Tool(my_tool, prepare=prep_my_tool)])
        ```


        Args:
            function: The Python function to call as the tool.
            takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] first argument,
                this is inferred if unset.
            max_retries: Maximum number of retries allowed for this tool, set to the agent default if `None`.
            name: Name of the tool, inferred from the function if `None`.
            description: Description of the tool, inferred from the function if `None`.
            prepare: custom method to prepare the tool definition for each step, return `None` to omit this
                tool from a given step. This is useful if you want to customise a tool at call time,
                or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
            require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
        """
        if takes_ctx is None:
            # Not stated by the caller, so let `_pydantic` infer it from the function.
            takes_ctx = _pydantic.takes_ctx(function)

        schema = _pydantic.function_schema(function, takes_ctx, docstring_format, require_parameter_descriptions)

        # Caller-facing configuration.
        self.function = function
        self.takes_ctx = takes_ctx
        self.max_retries = max_retries
        self.name = name or function.__name__
        self.description = description or schema['description']
        self.prepare = prepare
        self.docstring_format = docstring_format
        self.require_parameter_descriptions = require_parameter_descriptions

        # Call machinery derived from the function and its schema.
        self._is_async = inspect.iscoroutinefunction(self.function)
        self._single_arg_name = schema['single_arg_name']
        self._positional_fields = schema['positional_fields']
        self._var_positional_field = schema['var_positional_field']
        self._validator = schema['validator']
        self._parameters_json_schema = schema['json_schema']

    async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
        """Build the tool definition for the given context.

        A `ToolDefinition` is created from this tool's name, description and parameter schema.
        When `self.prepare` is set, that definition is passed through it and its result is
        returned; otherwise the definition is returned as built.

        Returns:
            A `ToolDefinition`, or `None` if the tool should not be registered for this run.
        """
        definition = ToolDefinition(
            name=self.name,
            description=self.description,
            parameters_json_schema=self._parameters_json_schema,
        )
        if self.prepare is None:
            return definition
        return await self.prepare(ctx, definition)

    async def run(
        self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT]
    ) -> _messages.ModelRequestPart:
        """Validate the call's arguments, invoke the tool function and wrap the outcome.

        Args:
            message: The tool call to execute; `message.args` may be a JSON string or a
                Python object.
            run_context: The context of the current run.

        Returns:
            A `ToolReturnPart` holding the function's result, or the `RetryPromptPart`
            produced by `_on_error` if argument validation fails or the function raises
            `ModelRetry`.
        """
        # String arguments are validated as JSON, anything else as a Python object.
        if isinstance(message.args, str):
            validate = self._validator.validate_json
        else:
            validate = self._validator.validate_python
        try:
            validated_args = validate(message.args)
        except ValidationError as validation_error:
            return self._on_error(validation_error, message)

        call_args, call_kwargs = self._call_args(validated_args, message, run_context)
        try:
            if self._is_async:
                async_function = cast(Callable[[Any], Awaitable[str]], self.function)
                result = await async_function(*call_args, **call_kwargs)
            else:
                sync_function = cast(Callable[[Any], str], self.function)
                result = await _utils.run_in_executor(sync_function, *call_args, **call_kwargs)
        except ModelRetry as retry_request:
            return self._on_error(retry_request, message)

        # A successful call clears the retry counter.
        self.current_retry = 0
        return _messages.ToolReturnPart(
            tool_name=message.tool_name,
            content=result,
            tool_call_id=message.tool_call_id,
        )

    def _call_args(
        self,
        args_dict: dict[str, Any],
        message: _messages.ToolCallPart,
        run_context: RunContext[AgentDepsT],
    ) -> tuple[list[Any], dict[str, Any]]:
        """Split validated arguments into positional and keyword arguments for the function.

        Positional and var-positional fields are popped out of `args_dict`; whatever is left
        is returned as the keyword arguments. When the tool takes a context, a copy of
        `run_context` carrying the current retry count and tool name is the first positional
        argument.
        """
        if self._single_arg_name:
            # The validated value is the single argument itself, so key it by that name.
            args_dict = {self._single_arg_name: args_dict}

        call_ctx = dataclasses.replace(run_context, retry=self.current_retry, tool_name=message.tool_name)
        positional: list[Any] = []
        if self.takes_ctx:
            positional.append(call_ctx)
        positional.extend(args_dict.pop(field_name) for field_name in self._positional_fields)
        if self._var_positional_field:
            positional.extend(args_dict.pop(self._var_positional_field))

        return positional, args_dict

    def _on_error(
        self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart
    ) -> _messages.RetryPromptPart:
        """Record a failed attempt and build the retry prompt for the model.

        Raises:
            UnexpectedModelBehavior: If `max_retries` is `None`, or the incremented retry
                count is greater than `max_retries`.
        """
        self.current_retry += 1
        retries_remain = self.max_retries is not None and self.current_retry <= self.max_retries
        if not retries_remain:
            raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc

        # Validation failures report the structured error list; `ModelRetry` reports its message.
        content = exc.errors(include_url=False) if isinstance(exc, ValidationError) else exc.message
        return _messages.RetryPromptPart(
            tool_name=call_message.tool_name,
            content=content,
            tool_call_id=call_message.tool_call_id,
        )

326 

327 

ObjectJsonSchema: TypeAlias = dict[str, Any]
"""Type representing JSON schema of an object, e.g. where `"type": "object"`.

This type is used to define tools parameters (aka arguments) in [ToolDefinition][pydantic_ai.tools.ToolDefinition].

With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_items=Any`
"""

335 

336 

@dataclasses.dataclass
class ToolDefinition:
    """Definition of a tool passed to a model.

    This is used for both function tools and result tools.
    """

    name: str
    """Name of the tool."""

    description: str
    """Description of the tool."""

    parameters_json_schema: ObjectJsonSchema
    """JSON schema describing the tool's parameters."""

    outer_typed_dict_key: str | None = None
    """Key in the outer [TypedDict] that wraps a result tool.

    Only set for result tools whose JSON schema is not an `object` schema.
    """