Coverage for pydantic_ai_slim/pydantic_ai/tools.py: 95.73%
142 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
1from __future__ import annotations as _annotations
3import dataclasses
4import inspect
5from collections.abc import Awaitable
6from dataclasses import dataclass, field
7from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
9from pydantic import ValidationError
10from pydantic_core import SchemaValidator
11from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar
13from . import _pydantic, _utils, messages as _messages, models
14from .exceptions import ModelRetry, UnexpectedModelBehavior
16if TYPE_CHECKING:
17 from .result import Usage
# Public API of this module; order is preserved so `__all__` is the same tuple as before.
__all__ = (
    'AgentDepsT', 'DocstringFormat', 'RunContext', 'SystemPromptFunc',
    'ToolFuncContext', 'ToolFuncPlain', 'ToolFuncEither', 'ToolParams',
    'ToolPrepareFunc', 'Tool', 'ObjectJsonSchema', 'ToolDefinition',
)

# Contravariant type parameter for the agent's dependencies; `default=None` is the
# PEP 696 default used when a generic class is referenced without a type argument.
AgentDepsT = TypeVar('AgentDepsT', contravariant=True, default=None)
"""Type variable for agent dependencies."""
@dataclasses.dataclass
class RunContext(Generic[AgentDepsT]):
    """Information about the current call."""

    deps: AgentDepsT
    """Dependencies for the agent."""
    model: models.Model
    """The model used in this run."""
    usage: Usage
    """LLM usage associated with the run."""
    prompt: str
    """The original user prompt passed to the run."""
    messages: list[_messages.ModelMessage] = field(default_factory=list)
    """Messages exchanged in the conversation so far."""
    tool_name: str | None = None
    """Name of the tool being called."""
    retry: int = 0
    """Number of retries so far."""
    run_step: int = 0
    """The current step in the run."""

    def replace_with(
        self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
    ) -> RunContext[AgentDepsT]:
        """Return a copy of this context with a new `retry` value and/or `tool_name`.

        Args:
            retry: The new retry count; `None` keeps the current value.
            tool_name: The new tool name; leave as `UNSET` to keep the current value.
                `None` is accepted and sets `tool_name` to `None`.

        Returns:
            A new `RunContext` built with `dataclasses.replace`; `self` is left unchanged.
        """
        kwargs: dict[str, Any] = {}
        if retry is not None:
            kwargs['retry'] = retry
        # `None` is a legitimate value for `tool_name`, so a sentinel marks "not passed".
        if tool_name is not _utils.UNSET:
            kwargs['tool_name'] = tool_name
        return dataclasses.replace(self, **kwargs)
ToolParams = ParamSpec('ToolParams', default=...)
"""Tool function param spec."""

SystemPromptFunc = Union[
    Callable[[RunContext[AgentDepsT]], str],
    Callable[[RunContext[AgentDepsT]], Awaitable[str]],
    Callable[[], str],
    Callable[[], Awaitable[str]],
]
"""A function that may or may not take `RunContext` as an argument, and may or may not be async.

Usage `SystemPromptFunc[AgentDepsT]`.
"""

ToolFuncContext = Callable[Concatenate[RunContext[AgentDepsT], ToolParams], Any]
"""A tool function that takes `RunContext` as the first argument.

Usage `ToolFuncContext[AgentDepsT, ToolParams]`.
"""
ToolFuncPlain = Callable[ToolParams, Any]
"""A tool function that does not take `RunContext` as the first argument.

Usage `ToolFuncPlain[ToolParams]`.
"""
ToolFuncEither = Union[ToolFuncContext[AgentDepsT, ToolParams], ToolFuncPlain[ToolParams]]
"""Either kind of tool function.

This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
[`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].

Usage `ToolFuncEither[AgentDepsT, ToolParams]`.
"""
ToolPrepareFunc: TypeAlias = 'Callable[[RunContext[AgentDepsT], ToolDefinition], Awaitable[ToolDefinition | None]]'
"""Definition of a function that can prepare a tool definition at call time.

See [tool docs](../tools.md#tool-prepare) for more information.

Example — here `only_if_42` is valid as a `ToolPrepareFunc`:

```python {noqa="I001"}
from typing import Union

from pydantic_ai import RunContext, Tool
from pydantic_ai.tools import ToolDefinition

async def only_if_42(
    ctx: RunContext[int], tool_def: ToolDefinition
) -> Union[ToolDefinition, None]:
    if ctx.deps == 42:
        return tool_def

def hitchhiker(ctx: RunContext[int], answer: str) -> str:
    return f'{ctx.deps} {answer}'

hitchhiker = Tool(hitchhiker, prepare=only_if_42)
```

Usage `ToolPrepareFunc[AgentDepsT]`.
"""

DocstringFormat = Literal['google', 'numpy', 'sphinx', 'auto']
"""Supported docstring formats.

* `'google'` — [Google-style](https://google.github.io/styleguide/pyguide.html#381-docstrings) docstrings.
* `'numpy'` — [Numpy-style](https://numpydoc.readthedocs.io/en/latest/format.html) docstrings.
* `'sphinx'` — [Sphinx-style](https://sphinx-rtd-tutorial.readthedocs.io/en/latest/docstrings.html#the-sphinx-docstring-format) docstrings.
* `'auto'` — Automatically infer the format based on the structure of the docstring.
"""

A = TypeVar('A')
@dataclass(init=False)
class Tool(Generic[AgentDepsT]):
    """A tool function for an agent."""

    function: ToolFuncEither[AgentDepsT]
    takes_ctx: bool
    max_retries: int | None
    name: str
    description: str
    prepare: ToolPrepareFunc[AgentDepsT] | None
    docstring_format: DocstringFormat
    require_parameter_descriptions: bool
    _is_async: bool = field(init=False)
    _single_arg_name: str | None = field(init=False)
    _positional_fields: list[str] = field(init=False)
    _var_positional_field: str | None = field(init=False)
    _validator: SchemaValidator = field(init=False, repr=False)
    _parameters_json_schema: ObjectJsonSchema = field(init=False)
    # Count of consecutive failed calls; incremented in `_on_error`, reset to 0 after a successful `run`.
    current_retry: int = field(default=0, init=False)

    def __init__(
        self,
        function: ToolFuncEither[AgentDepsT],
        *,
        takes_ctx: bool | None = None,
        max_retries: int | None = None,
        name: str | None = None,
        description: str | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
    ):
        """Create a new tool instance.

        Example usage:

        ```python {noqa="I001"}
        from pydantic_ai import Agent, RunContext, Tool

        async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
            return f'{ctx.deps} {x} {y}'

        agent = Agent('test', tools=[Tool(my_tool)])
        ```

        or with a custom prepare method:

        ```python {noqa="I001"}
        from typing import Union

        from pydantic_ai import Agent, RunContext, Tool
        from pydantic_ai.tools import ToolDefinition

        async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
            return f'{ctx.deps} {x} {y}'

        async def prep_my_tool(
            ctx: RunContext[int], tool_def: ToolDefinition
        ) -> Union[ToolDefinition, None]:
            # only register the tool if `deps == 42`
            if ctx.deps == 42:
                return tool_def

        agent = Agent('test', tools=[Tool(my_tool, prepare=prep_my_tool)])
        ```


        Args:
            function: The Python function to call as the tool.
            takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] first argument,
                this is inferred if unset.
            max_retries: Maximum number of retries allowed for this tool, set to the agent default if `None`.
            name: Name of the tool, inferred from the function if `None`.
            description: Description of the tool, inferred from the function if `None`.
            prepare: custom method to prepare the tool definition for each step, return `None` to omit this
                tool from a given step. This is useful if you want to customise a tool at call time,
                or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
            require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
        """
        if takes_ctx is None:
            takes_ctx = _pydantic.takes_ctx(function)

        # Build the argument schema/validator once, up front; `run` reuses it for every call.
        f = _pydantic.function_schema(function, takes_ctx, docstring_format, require_parameter_descriptions)
        self.function = function
        self.takes_ctx = takes_ctx
        self.max_retries = max_retries
        self.name = name or function.__name__
        self.description = description or f['description']
        self.prepare = prepare
        self.docstring_format = docstring_format
        self.require_parameter_descriptions = require_parameter_descriptions
        self._is_async = inspect.iscoroutinefunction(self.function)
        self._single_arg_name = f['single_arg_name']
        self._positional_fields = f['positional_fields']
        self._var_positional_field = f['var_positional_field']
        self._validator = f['validator']
        self._parameters_json_schema = f['json_schema']

    async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
        """Get the tool definition.

        By default, this method creates a tool definition, then either returns it, or calls `self.prepare`
        if it's set.

        Returns:
            A `ToolDefinition`, or `None` if the tool should not be registered for this step.
        """
        tool_def = ToolDefinition(
            name=self.name,
            description=self.description,
            parameters_json_schema=self._parameters_json_schema,
        )
        if self.prepare is not None:
            return await self.prepare(ctx, tool_def)
        else:
            return tool_def

    async def run(
        self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT]
    ) -> _messages.ModelRequestPart:
        """Run the tool function asynchronously.

        The call arguments in `message.args` are validated (as JSON when they are a string, otherwise as
        Python data). Coroutine functions are awaited directly; plain functions are run via
        `_utils.run_in_executor`.

        Returns:
            A `ToolReturnPart` holding the function's result, or a `RetryPromptPart` when argument
            validation fails or the function raises `ModelRetry`.

        Raises:
            UnexpectedModelBehavior: When a failure pushes `current_retry` past the retry limit
                (see `_on_error`).
        """
        try:
            if isinstance(message.args, str):
                args_dict = self._validator.validate_json(message.args)
            else:
                args_dict = self._validator.validate_python(message.args)
        except ValidationError as e:
            return self._on_error(e, message)

        args, kwargs = self._call_args(args_dict, message, run_context)
        try:
            if self._is_async:
                function = cast(Callable[[Any], Awaitable[str]], self.function)
                response_content = await function(*args, **kwargs)
            else:
                function = cast(Callable[[Any], str], self.function)
                response_content = await _utils.run_in_executor(function, *args, **kwargs)
        except ModelRetry as e:
            return self._on_error(e, message)

        # A successful call clears the failure streak.
        self.current_retry = 0
        return _messages.ToolReturnPart(
            tool_name=message.tool_name,
            content=response_content,
            tool_call_id=message.tool_call_id,
        )

    def _call_args(
        self,
        args_dict: dict[str, Any],
        message: _messages.ToolCallPart,
        run_context: RunContext[AgentDepsT],
    ) -> tuple[list[Any], dict[str, Any]]:
        """Split validated arguments into positional args and keyword args for the tool function.

        The `RunContext` (with this tool's retry count and name) is prepended when `takes_ctx` is set.
        Positional and var-positional fields are popped from `args_dict`; what remains is returned
        as keyword arguments.
        """
        if self._single_arg_name:
            # The validator produced the value of the sole argument; wrap it under that argument's name.
            args_dict = {self._single_arg_name: args_dict}

        ctx = run_context.replace_with(retry=self.current_retry, tool_name=message.tool_name)
        args = [ctx] if self.takes_ctx else []
        for positional_field in self._positional_fields:
            args.append(args_dict.pop(positional_field))
        if self._var_positional_field:
            args.extend(args_dict.pop(self._var_positional_field))

        return args, args_dict

    def _on_error(
        self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart
    ) -> _messages.RetryPromptPart:
        """Record a failed call and build a retry prompt for the model.

        Raises:
            UnexpectedModelBehavior: If `current_retry` now exceeds `max_retries`. A `max_retries`
                of `None` allows no retries at all, so the first failure raises.
        """
        self.current_retry += 1
        if self.max_retries is None or self.current_retry > self.max_retries:
            raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
        else:
            if isinstance(exc, ValidationError):
                content = exc.errors(include_url=False)
            else:
                content = exc.message
            return _messages.RetryPromptPart(
                tool_name=call_message.tool_name,
                content=content,
                tool_call_id=call_message.tool_call_id,
            )
328ObjectJsonSchema: TypeAlias = dict[str, Any]
329"""Type representing JSON schema of an object, e.g. where `"type": "object"`.
331This type is used to define tools parameters (aka arguments) in [ToolDefinition][pydantic_ai.tools.ToolDefinition].
333With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_parts=Any`
334"""
@dataclass
class ToolDefinition:
    """Definition of a tool passed to a model.

    This is used for both function tools and result tools.
    """

    name: str
    """The name of the tool."""

    description: str
    """The description of the tool."""

    parameters_json_schema: ObjectJsonSchema
    """The JSON schema for the tool's parameters."""

    outer_typed_dict_key: str | None = None
    """The key in the outer [TypedDict] that wraps a result tool.

    This will only be set for result tools which don't have an `object` JSON schema.
    """