Coverage for pydantic_ai_slim/pydantic_ai/tools.py: 95.73%

142 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-01-30 19:21 +0000

1from __future__ import annotations as _annotations 

2 

3import dataclasses 

4import inspect 

5from collections.abc import Awaitable 

6from dataclasses import dataclass, field 

7from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast 

8 

9from pydantic import ValidationError 

10from pydantic_core import SchemaValidator 

11from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar 

12 

13from . import _pydantic, _utils, messages as _messages, models 

14from .exceptions import ModelRetry, UnexpectedModelBehavior 

15 

16if TYPE_CHECKING: 

17 from .result import Usage 

18 

19__all__ = ( 

20 'AgentDepsT', 

21 'DocstringFormat', 

22 'RunContext', 

23 'SystemPromptFunc', 

24 'ToolFuncContext', 

25 'ToolFuncPlain', 

26 'ToolFuncEither', 

27 'ToolParams', 

28 'ToolPrepareFunc', 

29 'Tool', 

30 'ObjectJsonSchema', 

31 'ToolDefinition', 

32) 

33 

34AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True) 

35"""Type variable for agent dependencies.""" 

36 

37 

@dataclasses.dataclass
class RunContext(Generic[AgentDepsT]):
    """Information about the current call."""

    deps: AgentDepsT
    """Dependencies for the agent."""
    model: models.Model
    """The model used in this run."""
    usage: Usage
    """LLM usage associated with the run."""
    prompt: str
    """The original user prompt passed to the run."""
    messages: list[_messages.ModelMessage] = field(default_factory=list)
    """Messages exchanged in the conversation so far."""
    tool_name: str | None = None
    """Name of the tool being called."""
    retry: int = 0
    """Number of retries so far."""
    run_step: int = 0
    """The current step in the run."""

    def replace_with(
        self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
    ) -> RunContext[AgentDepsT]:
        """Return a copy of this context with `retry` and/or `tool_name` overridden.

        Args:
            retry: New retry count for the copy; `None` keeps the current value.
            tool_name: New tool name for the copy; `_utils.UNSET` keeps the current value,
                while an explicit `None` sets the name to `None` on the copy.

        Returns:
            A new `RunContext` built with `dataclasses.replace`; `self` is not modified.
        """
        # Collect only the fields the caller actually asked to change.
        overrides: dict[str, Any] = {}
        if tool_name is not _utils.UNSET:
            overrides['tool_name'] = tool_name
        if retry is not None:
            overrides['retry'] = retry
        return dataclasses.replace(self, **overrides)

69 

70 

71ToolParams = ParamSpec('ToolParams', default=...) 

72"""Tool function param spec."""

73 

74SystemPromptFunc = Union[ 

75 Callable[[RunContext[AgentDepsT]], str], 

76 Callable[[RunContext[AgentDepsT]], Awaitable[str]], 

77 Callable[[], str], 

78 Callable[[], Awaitable[str]], 

79] 

80"""A function that may or may not take `RunContext` as an argument, and may or may not be async.

81 

82Usage `SystemPromptFunc[AgentDepsT]`. 

83""" 

84 

85ToolFuncContext = Callable[Concatenate[RunContext[AgentDepsT], ToolParams], Any] 

86"""A tool function that takes `RunContext` as the first argument. 

87 

88Usage `ToolFuncContext[AgentDepsT, ToolParams]`.

89""" 

90ToolFuncPlain = Callable[ToolParams, Any] 

91"""A tool function that does not take `RunContext` as the first argument. 

92 

93Usage `ToolFuncPlain[ToolParams]`.

94""" 

95ToolFuncEither = Union[ToolFuncContext[AgentDepsT, ToolParams], ToolFuncPlain[ToolParams]] 

96"""Either kind of tool function. 

97 

98This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and 

99[`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain]. 

100 

101Usage `ToolFuncEither[AgentDepsT, ToolParams]`. 

102""" 

103ToolPrepareFunc: TypeAlias = 'Callable[[RunContext[AgentDepsT], ToolDefinition], Awaitable[ToolDefinition | None]]' 

104"""Definition of a function that can prepare a tool definition at call time. 

105 

106See [tool docs](../tools.md#tool-prepare) for more information. 

107 

108Example — here `only_if_42` is valid as a `ToolPrepareFunc`: 

109 

110```python {noqa="I001"} 

111from typing import Union 

112 

113from pydantic_ai import RunContext, Tool 

114from pydantic_ai.tools import ToolDefinition 

115 

116async def only_if_42( 

117 ctx: RunContext[int], tool_def: ToolDefinition 

118) -> Union[ToolDefinition, None]: 

119 if ctx.deps == 42: 

120 return tool_def 

121 

122def hitchhiker(ctx: RunContext[int], answer: str) -> str: 

123 return f'{ctx.deps} {answer}' 

124 

125hitchhiker = Tool(hitchhiker, prepare=only_if_42) 

126``` 

127 

128Usage `ToolPrepareFunc[AgentDepsT]`. 

129""" 

130 

131DocstringFormat = Literal['google', 'numpy', 'sphinx', 'auto'] 

132"""Supported docstring formats. 

133 

134* `'google'` — [Google-style](https://google.github.io/styleguide/pyguide.html#381-docstrings) docstrings. 

135* `'numpy'` — [Numpy-style](https://numpydoc.readthedocs.io/en/latest/format.html) docstrings. 

136* `'sphinx'` — [Sphinx-style](https://sphinx-rtd-tutorial.readthedocs.io/en/latest/docstrings.html#the-sphinx-docstring-format) docstrings. 

137* `'auto'` — Automatically infer the format based on the structure of the docstring. 

138""" 

139 

140A = TypeVar('A') 

141 

142 

@dataclass(init=False)
class Tool(Generic[AgentDepsT]):
    """A tool function for an agent."""

    function: ToolFuncEither[AgentDepsT]
    takes_ctx: bool
    max_retries: int | None
    name: str
    description: str
    prepare: ToolPrepareFunc[AgentDepsT] | None
    docstring_format: DocstringFormat
    require_parameter_descriptions: bool
    _is_async: bool = field(init=False)
    _single_arg_name: str | None = field(init=False)
    _positional_fields: list[str] = field(init=False)
    _var_positional_field: str | None = field(init=False)
    _validator: SchemaValidator = field(init=False, repr=False)
    _parameters_json_schema: ObjectJsonSchema = field(init=False)

    # TODO: Move this state off the Tool class, which is otherwise stateless.
    #   This should be tracked inside a specific agent run, not the tool.
    current_retry: int = field(default=0, init=False)

    def __init__(
        self,
        function: ToolFuncEither[AgentDepsT],
        *,
        takes_ctx: bool | None = None,
        max_retries: int | None = None,
        name: str | None = None,
        description: str | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
    ):
        """Create a new tool instance.

        Example usage:

        ```python {noqa="I001"}
        from pydantic_ai import Agent, RunContext, Tool

        async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
            return f'{ctx.deps} {x} {y}'

        agent = Agent('test', tools=[Tool(my_tool)])
        ```

        or with a custom prepare method:

        ```python {noqa="I001"}
        from typing import Union

        from pydantic_ai import Agent, RunContext, Tool
        from pydantic_ai.tools import ToolDefinition

        async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
            return f'{ctx.deps} {x} {y}'

        async def prep_my_tool(
            ctx: RunContext[int], tool_def: ToolDefinition
        ) -> Union[ToolDefinition, None]:
            # only register the tool if `deps == 42`
            if ctx.deps == 42:
                return tool_def

        agent = Agent('test', tools=[Tool(my_tool, prepare=prep_my_tool)])
        ```


        Args:
            function: The Python function to call as the tool.
            takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] first argument,
                this is inferred if unset.
            max_retries: Maximum number of retries allowed for this tool, set to the agent default if `None`.
            name: Name of the tool, inferred from the function if `None`.
            description: Description of the tool, inferred from the function if `None`.
            prepare: custom method to prepare the tool definition for each step, return `None` to omit this
                tool from a given step. This is useful if you want to customise a tool at call time,
                or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
            require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
        """
        # Infer whether the function expects a `RunContext` only when the caller did not say.
        if takes_ctx is None:
            takes_ctx = _pydantic.takes_ctx(function)

        schema = _pydantic.function_schema(function, takes_ctx, docstring_format, require_parameter_descriptions)

        # Public configuration, with name/description falling back to what the function provides.
        self.function = function
        self.takes_ctx = takes_ctx
        self.max_retries = max_retries
        self.name = name or function.__name__
        self.description = description or schema['description']
        self.prepare = prepare
        self.docstring_format = docstring_format
        self.require_parameter_descriptions = require_parameter_descriptions

        # Private call machinery taken from the generated function schema.
        self._is_async = inspect.iscoroutinefunction(self.function)
        self._single_arg_name = schema['single_arg_name']
        self._positional_fields = schema['positional_fields']
        self._var_positional_field = schema['var_positional_field']
        self._validator = schema['validator']
        self._parameters_json_schema = schema['json_schema']

    async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
        """Get the tool definition.

        A default `ToolDefinition` is built from this tool's name, description and parameter schema.
        When `self.prepare` is set, that definition is handed to it and its result is returned instead.

        Returns:
            A `ToolDefinition`, or `None` if the tool should not be registered for this step.
        """
        default_def = ToolDefinition(
            name=self.name,
            description=self.description,
            parameters_json_schema=self._parameters_json_schema,
        )
        if self.prepare is None:
            return default_def
        return await self.prepare(ctx, default_def)

    async def run(
        self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT]
    ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
        """Run the tool function asynchronously.

        The call arguments are validated first, then the function is invoked (awaited directly if it
        is a coroutine function, otherwise via `_utils.run_in_executor`). A `ValidationError` during
        validation or a `ModelRetry` raised by the function is routed through `_on_error`.
        """
        raw_args = message.args
        try:
            # String arguments are treated as JSON; anything else is validated as Python data.
            if isinstance(raw_args, str):
                validated = self._validator.validate_json(raw_args)
            else:
                validated = self._validator.validate_python(raw_args)
        except ValidationError as e:
            return self._on_error(e, message)

        call_args, call_kwargs = self._call_args(validated, message, run_context)
        try:
            if not self._is_async:
                function = cast(Callable[[Any], str], self.function)
                result = await _utils.run_in_executor(function, *call_args, **call_kwargs)
            else:
                function = cast(Callable[[Any], Awaitable[str]], self.function)
                result = await function(*call_args, **call_kwargs)
        except ModelRetry as e:
            return self._on_error(e, message)

        # A successful call clears the retry counter.
        self.current_retry = 0
        return _messages.ToolReturnPart(
            tool_name=message.tool_name,
            content=result,
            tool_call_id=message.tool_call_id,
        )

    def _call_args(
        self,
        args_dict: dict[str, Any],
        message: _messages.ToolCallPart,
        run_context: RunContext[AgentDepsT],
    ) -> tuple[list[Any], dict[str, Any]]:
        """Split validated arguments into the positional list and keyword dict for the call.

        Positional and var-positional fields are popped out of `args_dict`; whatever remains is
        returned as the keyword arguments.
        """
        single_arg = self._single_arg_name
        if single_arg:
            # The whole validated payload is the value of the function's single argument.
            args_dict = {single_arg: args_dict}

        # The context handed to the function carries this tool's retry count and the called name.
        ctx = dataclasses.replace(run_context, retry=self.current_retry, tool_name=message.tool_name)

        positional: list[Any] = []
        if self.takes_ctx:
            positional.append(ctx)
        positional += [args_dict.pop(field_name) for field_name in self._positional_fields]

        var_positional = self._var_positional_field
        if var_positional:
            positional.extend(args_dict.pop(var_positional))

        return positional, args_dict

    def _on_error(
        self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart
    ) -> _messages.RetryPromptPart:
        """Record a failed attempt and build the retry prompt for the model.

        Raises:
            UnexpectedModelBehavior: If `max_retries` is `None` or the incremented retry count
                exceeds it.
        """
        self.current_retry += 1
        if self.max_retries is None or self.current_retry > self.max_retries:
            raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc

        # Validation failures report the structured error list; `ModelRetry` reports its message.
        if isinstance(exc, ValidationError):
            content = exc.errors(include_url=False)
        else:
            content = exc.message
        return _messages.RetryPromptPart(
            tool_name=call_message.tool_name,
            content=content,
            tool_call_id=call_message.tool_call_id,
        )

329 

330 

331ObjectJsonSchema: TypeAlias = dict[str, Any] 

332"""Type representing JSON schema of an object, e.g. where `"type": "object"`. 

333 

334This type is used to define tools parameters (aka arguments) in [ToolDefinition][pydantic_ai.tools.ToolDefinition]. 

335 

336With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_items=Any`

337""" 

338 

339 

@dataclass
class ToolDefinition:
    """Definition of a tool passed to a model.

    This is used for both function tools and result tools.
    """

    name: str
    """The name of the tool."""

    description: str
    """The description of the tool."""

    parameters_json_schema: ObjectJsonSchema
    """The JSON schema for the tool's parameters."""

    outer_typed_dict_key: str | None = None
    """The key in the outer [TypedDict] that wraps a result tool.

    This will only be set for result tools which don't have an `object` JSON schema.
    """

358 

359 This will only be set for result tools which don't have an `object` JSON schema. 

360 """