Coverage for pydantic_ai_slim/pydantic_ai/tools.py: 95.73%
142 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-30 19:21 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-30 19:21 +0000
1from __future__ import annotations as _annotations
3import dataclasses
4import inspect
5from collections.abc import Awaitable
6from dataclasses import dataclass, field
7from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
9from pydantic import ValidationError
10from pydantic_core import SchemaValidator
11from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar
13from . import _pydantic, _utils, messages as _messages, models
14from .exceptions import ModelRetry, UnexpectedModelBehavior
16if TYPE_CHECKING:
17 from .result import Usage
19__all__ = (
20 'AgentDepsT',
21 'DocstringFormat',
22 'RunContext',
23 'SystemPromptFunc',
24 'ToolFuncContext',
25 'ToolFuncPlain',
26 'ToolFuncEither',
27 'ToolParams',
28 'ToolPrepareFunc',
29 'Tool',
30 'ObjectJsonSchema',
31 'ToolDefinition',
32)
34AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
35"""Type variable for agent dependencies."""
@dataclass
class RunContext(Generic[AgentDepsT]):
    """Information about the current call."""

    deps: AgentDepsT
    """Dependencies for the agent."""
    model: models.Model
    """The model used in this run."""
    usage: Usage
    """LLM usage associated with the run."""
    prompt: str
    """The original user prompt passed to the run."""
    messages: list[_messages.ModelMessage] = field(default_factory=list)
    """Messages exchanged in the conversation so far."""
    tool_name: str | None = None
    """Name of the tool being called."""
    retry: int = 0
    """Number of retries so far."""
    run_step: int = 0
    """The current step in the run."""

    def replace_with(
        self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
    ) -> RunContext[AgentDepsT]:
        """Return a copy of this context with `retry` and/or `tool_name` overridden.

        Args:
            retry: New retry count; the current value is kept when this is `None`.
            tool_name: New tool name (which may itself be `None`); the current value is kept
                when this is left at the `_utils.UNSET` default.

        Returns:
            A new `RunContext` built with `dataclasses.replace`; `self` is not modified.
        """
        overrides: dict[str, Any] = {}
        # Identity check against the sentinel so that an explicit `None` still counts as an override.
        if tool_name is not _utils.UNSET:
            overrides['tool_name'] = tool_name
        if retry is not None:
            overrides['retry'] = retry
        return dataclasses.replace(self, **overrides)
71ToolParams = ParamSpec('ToolParams', default=...)
72"""Retrieval function param spec."""
74SystemPromptFunc = Union[
75 Callable[[RunContext[AgentDepsT]], str],
76 Callable[[RunContext[AgentDepsT]], Awaitable[str]],
77 Callable[[], str],
78 Callable[[], Awaitable[str]],
79]
80"""A function that may or may not take `RunContext` as an argument, and may or may not be async.
82Usage `SystemPromptFunc[AgentDepsT]`.
83"""
85ToolFuncContext = Callable[Concatenate[RunContext[AgentDepsT], ToolParams], Any]
86"""A tool function that takes `RunContext` as the first argument.
88Usage `ToolFuncContext[AgentDepsT, ToolParams]`.
89"""
90ToolFuncPlain = Callable[ToolParams, Any]
91"""A tool function that does not take `RunContext` as the first argument.
93Usage `ToolFuncPlain[ToolParams]`.
94"""
95ToolFuncEither = Union[ToolFuncContext[AgentDepsT, ToolParams], ToolFuncPlain[ToolParams]]
96"""Either kind of tool function.
98This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
99[`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].
101Usage `ToolFuncEither[AgentDepsT, ToolParams]`.
102"""
103ToolPrepareFunc: TypeAlias = 'Callable[[RunContext[AgentDepsT], ToolDefinition], Awaitable[ToolDefinition | None]]'
104"""Definition of a function that can prepare a tool definition at call time.
106See [tool docs](../tools.md#tool-prepare) for more information.
108Example — here `only_if_42` is valid as a `ToolPrepareFunc`:
110```python {noqa="I001"}
111from typing import Union
113from pydantic_ai import RunContext, Tool
114from pydantic_ai.tools import ToolDefinition
116async def only_if_42(
117 ctx: RunContext[int], tool_def: ToolDefinition
118) -> Union[ToolDefinition, None]:
119 if ctx.deps == 42:
120 return tool_def
122def hitchhiker(ctx: RunContext[int], answer: str) -> str:
123 return f'{ctx.deps} {answer}'
125hitchhiker = Tool(hitchhiker, prepare=only_if_42)
126```
128Usage `ToolPrepareFunc[AgentDepsT]`.
129"""
131DocstringFormat = Literal['google', 'numpy', 'sphinx', 'auto']
132"""Supported docstring formats.
134* `'google'` — [Google-style](https://google.github.io/styleguide/pyguide.html#381-docstrings) docstrings.
135* `'numpy'` — [Numpy-style](https://numpydoc.readthedocs.io/en/latest/format.html) docstrings.
136* `'sphinx'` — [Sphinx-style](https://sphinx-rtd-tutorial.readthedocs.io/en/latest/docstrings.html#the-sphinx-docstring-format) docstrings.
137* `'auto'` — Automatically infer the format based on the structure of the docstring.
138"""
140A = TypeVar('A')
@dataclass(init=False)
class Tool(Generic[AgentDepsT]):
    """A tool function for an agent."""

    function: ToolFuncEither[AgentDepsT]
    takes_ctx: bool
    max_retries: int | None
    name: str
    description: str
    prepare: ToolPrepareFunc[AgentDepsT] | None
    docstring_format: DocstringFormat
    require_parameter_descriptions: bool
    _is_async: bool = field(init=False)
    _single_arg_name: str | None = field(init=False)
    _positional_fields: list[str] = field(init=False)
    _var_positional_field: str | None = field(init=False)
    _validator: SchemaValidator = field(init=False, repr=False)
    _parameters_json_schema: ObjectJsonSchema = field(init=False)

    # TODO: Move this state off the Tool class, which is otherwise stateless.
    #   This should be tracked inside a specific agent run, not the tool.
    current_retry: int = field(default=0, init=False)

    def __init__(
        self,
        function: ToolFuncEither[AgentDepsT],
        *,
        takes_ctx: bool | None = None,
        max_retries: int | None = None,
        name: str | None = None,
        description: str | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
    ):
        """Create a new tool instance.

        Example usage:

        ```python {noqa="I001"}
        from pydantic_ai import Agent, RunContext, Tool

        async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
            return f'{ctx.deps} {x} {y}'

        agent = Agent('test', tools=[Tool(my_tool)])
        ```

        or with a custom prepare method:

        ```python {noqa="I001"}
        from typing import Union

        from pydantic_ai import Agent, RunContext, Tool
        from pydantic_ai.tools import ToolDefinition

        async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
            return f'{ctx.deps} {x} {y}'

        async def prep_my_tool(
            ctx: RunContext[int], tool_def: ToolDefinition
        ) -> Union[ToolDefinition, None]:
            # only register the tool if `deps == 42`
            if ctx.deps == 42:
                return tool_def

        agent = Agent('test', tools=[Tool(my_tool, prepare=prep_my_tool)])
        ```

        Args:
            function: The Python function to call as the tool.
            takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] first argument,
                this is inferred if unset.
            max_retries: Maximum number of retries allowed for this tool, set to the agent default if `None`.
            name: Name of the tool, inferred from the function if `None`.
            description: Description of the tool, inferred from the function if `None`.
            prepare: Custom method to prepare the tool definition for each step, return `None` to omit this
                tool from a given step. This is useful if you want to customise a tool at call time,
                or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
            require_parameter_descriptions: If True, raise an error if a parameter description is missing.
                Defaults to False.
        """
        if takes_ctx is None:
            takes_ctx = _pydantic.takes_ctx(function)

        schema = _pydantic.function_schema(function, takes_ctx, docstring_format, require_parameter_descriptions)

        # Caller-supplied configuration; empty/None name and description fall back to the function's own.
        self.function = function
        self.takes_ctx = takes_ctx
        self.max_retries = max_retries
        self.name = name or function.__name__
        self.description = description or schema['description']
        self.prepare = prepare
        self.docstring_format = docstring_format
        self.require_parameter_descriptions = require_parameter_descriptions

        # Derived call metadata, taken from the function itself and its generated schema.
        self._is_async = inspect.iscoroutinefunction(function)
        self._single_arg_name = schema['single_arg_name']
        self._positional_fields = schema['positional_fields']
        self._var_positional_field = schema['var_positional_field']
        self._validator = schema['validator']
        self._parameters_json_schema = schema['json_schema']

    async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
        """Get the tool definition.

        Builds a `ToolDefinition` from this tool's name, description and parameter schema. When
        `self.prepare` is set, the definition is passed through it and its result is returned instead.

        Returns:
            A `ToolDefinition`, or `None` if the tool should not be registered for this run.
        """
        definition = ToolDefinition(
            name=self.name,
            description=self.description,
            parameters_json_schema=self._parameters_json_schema,
        )
        if self.prepare is None:
            return definition
        return await self.prepare(ctx, definition)

    async def run(
        self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT]
    ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
        """Run the tool function asynchronously.

        The call arguments are validated first (as JSON when `message.args` is a string, otherwise as
        Python data). A `ValidationError` from validation, or a `ModelRetry` raised by the tool
        function, is handed to `_on_error`. On success `current_retry` is reset to zero.

        Args:
            message: The tool call to execute.
            run_context: The context of the current run.

        Returns:
            A `ToolReturnPart` holding the function's result, or the `RetryPromptPart` produced by
            `_on_error`.
        """
        try:
            raw_args = message.args
            if isinstance(raw_args, str):
                validate = self._validator.validate_json
            else:
                validate = self._validator.validate_python
            validated = validate(raw_args)
        except ValidationError as validation_error:
            return self._on_error(validation_error, message)

        positional, keyword = self._call_args(validated, message, run_context)
        try:
            if not self._is_async:
                # Plain functions are dispatched through `_utils.run_in_executor` and awaited.
                sync_function = cast(Callable[[Any], str], self.function)
                result = await _utils.run_in_executor(sync_function, *positional, **keyword)
            else:
                async_function = cast(Callable[[Any], Awaitable[str]], self.function)
                result = await async_function(*positional, **keyword)
        except ModelRetry as retry_request:
            return self._on_error(retry_request, message)

        self.current_retry = 0
        return _messages.ToolReturnPart(
            tool_name=message.tool_name,
            content=result,
            tool_call_id=message.tool_call_id,
        )

    def _call_args(
        self,
        args_dict: dict[str, Any],
        message: _messages.ToolCallPart,
        run_context: RunContext[AgentDepsT],
    ) -> tuple[list[Any], dict[str, Any]]:
        """Split validated arguments into the positional and keyword arguments for the function.

        Returns:
            `(args, kwargs)`: `args` starts with a per-call `RunContext` when `takes_ctx` is set,
            followed by the positional fields and any var-positional values; `kwargs` is whatever
            remains in the argument dict.
        """
        if self._single_arg_name:
            # The validated value is the single argument itself, so wrap it under that argument's name.
            args_dict = {self._single_arg_name: args_dict}

        call_ctx = dataclasses.replace(run_context, retry=self.current_retry, tool_name=message.tool_name)

        positional: list[Any] = []
        if self.takes_ctx:
            positional.append(call_ctx)
        # Popping moves each positional value out of the dict so it is not passed again as a keyword.
        positional.extend(args_dict.pop(field_name) for field_name in self._positional_fields)
        if self._var_positional_field:
            positional.extend(args_dict.pop(self._var_positional_field))

        return positional, args_dict

    def _on_error(
        self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart
    ) -> _messages.RetryPromptPart:
        """Record a failed attempt and build the retry prompt for the model.

        Increments `current_retry` before checking it against `max_retries`.

        Raises:
            UnexpectedModelBehavior: If `max_retries` is `None` or the incremented retry count exceeds it.
        """
        self.current_retry += 1
        if self.max_retries is None or self.current_retry > self.max_retries:
            raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc

        # Validation failures report the error list (without URLs); `ModelRetry` carries its own message.
        content = exc.errors(include_url=False) if isinstance(exc, ValidationError) else exc.message
        return _messages.RetryPromptPart(
            tool_name=call_message.tool_name,
            content=content,
            tool_call_id=call_message.tool_call_id,
        )
331ObjectJsonSchema: TypeAlias = dict[str, Any]
332"""Type representing JSON schema of an object, e.g. where `"type": "object"`.
334This type is used to define tools parameters (aka arguments) in [ToolDefinition][pydantic_ai.tools.ToolDefinition].
336With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_items=Any`.
337"""
@dataclasses.dataclass
class ToolDefinition:
    """Definition of a tool passed to a model.

    This is used for both function tools and result tools.
    """

    name: str
    """The name of the tool."""

    description: str
    """The description of the tool."""

    parameters_json_schema: ObjectJsonSchema
    """The JSON schema for the tool's parameters."""

    outer_typed_dict_key: str | None = None
    """The key in the outer [TypedDict] that wraps a result tool.

    This will only be set for result tools which don't have an `object` JSON schema.
    """