Coverage for pydantic_ai_slim/pydantic_ai/tools.py: 95.77%
163 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1from __future__ import annotations as _annotations
3import dataclasses
4import inspect
5import json
6from collections.abc import Awaitable, Sequence
7from dataclasses import dataclass, field
8from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
10from opentelemetry.trace import Tracer
11from pydantic import ValidationError
12from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
13from pydantic_core import SchemaValidator, core_schema
14from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar
16from . import _pydantic, _utils, messages as _messages, models
17from .exceptions import ModelRetry, UnexpectedModelBehavior
19if TYPE_CHECKING:
20 from .result import Usage
# Names exported by `from pydantic_ai.tools import *`.
__all__ = (
    'AgentDepsT',
    'DocstringFormat',
    'RunContext',
    'SystemPromptFunc',
    'ToolFuncContext',
    'ToolFuncPlain',
    'ToolFuncEither',
    'ToolParams',
    'ToolPrepareFunc',
    'Tool',
    'ObjectJsonSchema',
    'ToolDefinition',
)

# Contravariant type variable with a default of `None`.
AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
"""Type variable for agent dependencies."""
@dataclasses.dataclass
class RunContext(Generic[AgentDepsT]):
    """Information about the current call."""

    deps: AgentDepsT
    """Dependencies for the agent."""
    model: models.Model
    """The model used in this run."""
    usage: Usage
    """LLM usage associated with the run."""
    prompt: str | Sequence[_messages.UserContent]
    """The original user prompt passed to the run."""
    messages: list[_messages.ModelMessage] = field(default_factory=list)
    """Messages exchanged in the conversation so far."""
    tool_call_id: str | None = None
    """The ID of the tool call."""
    tool_name: str | None = None
    """Name of the tool being called."""
    retry: int = 0
    """Number of retries so far."""
    run_step: int = 0
    """The current step in the run."""

    def replace_with(
        self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
    ) -> RunContext[AgentDepsT]:
        """Return a copy of this context with `retry` and/or `tool_name` overridden.

        Args:
            retry: New retry count; the current value is kept when this is `None`.
            tool_name: New tool name; the current value is kept when this is left as `UNSET`.
                Passing `None` explicitly sets the tool name to `None`.

        Returns:
            A new `RunContext`; `self` is not modified.
        """
        # Create a new `RunContext` with a new `retry` value and `tool_name`.
        kwargs: dict[str, Any] = {}
        if retry is not None:
            kwargs['retry'] = retry
        # `UNSET` (rather than `None`) marks "not provided", since `None` is a valid tool name value.
        if tool_name is not _utils.UNSET:
            kwargs['tool_name'] = tool_name
        return dataclasses.replace(self, **kwargs)
ToolParams = ParamSpec('ToolParams', default=...)
"""Retrieval function param spec."""

SystemPromptFunc = Union[
    Callable[[RunContext[AgentDepsT]], str],
    Callable[[RunContext[AgentDepsT]], Awaitable[str]],
    Callable[[], str],
    Callable[[], Awaitable[str]],
]
"""A function that may or may not take `RunContext` as an argument, and may or may not be async.

Usage `SystemPromptFunc[AgentDepsT]`.
"""
ToolFuncContext = Callable[Concatenate[RunContext[AgentDepsT], ToolParams], Any]
"""A tool function that takes `RunContext` as the first argument.

Usage `ToolFuncContext[AgentDepsT, ToolParams]`.
"""
ToolFuncPlain = Callable[ToolParams, Any]
"""A tool function that does not take `RunContext` as the first argument.

Usage `ToolFuncPlain[ToolParams]`.
"""
ToolFuncEither = Union[ToolFuncContext[AgentDepsT, ToolParams], ToolFuncPlain[ToolParams]]
"""Either kind of tool function.

This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
[`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].

Usage `ToolFuncEither[AgentDepsT, ToolParams]`.
"""
108ToolPrepareFunc: TypeAlias = 'Callable[[RunContext[AgentDepsT], ToolDefinition], Awaitable[ToolDefinition | None]]'
109"""Definition of a function that can prepare a tool definition at call time.
111See [tool docs](../tools.md#tool-prepare) for more information.
113Example — here `only_if_42` is valid as a `ToolPrepareFunc`:
115```python {noqa="I001"}
116from typing import Union
118from pydantic_ai import RunContext, Tool
119from pydantic_ai.tools import ToolDefinition
121async def only_if_42(
122 ctx: RunContext[int], tool_def: ToolDefinition
123) -> Union[ToolDefinition, None]:
124 if ctx.deps == 42:
125 return tool_def
127def hitchhiker(ctx: RunContext[int], answer: str) -> str:
128 return f'{ctx.deps} {answer}'
130hitchhiker = Tool(hitchhiker, prepare=only_if_42)
131```
133Usage `ToolPrepareFunc[AgentDepsT]`.
134"""
DocstringFormat = Literal['google', 'numpy', 'sphinx', 'auto']
"""Supported docstring formats.

* `'google'` — [Google-style](https://google.github.io/styleguide/pyguide.html#381-docstrings) docstrings.
* `'numpy'` — [Numpy-style](https://numpydoc.readthedocs.io/en/latest/format.html) docstrings.
* `'sphinx'` — [Sphinx-style](https://sphinx-rtd-tutorial.readthedocs.io/en/latest/docstrings.html#the-sphinx-docstring-format) docstrings.
* `'auto'` — Automatically infer the format based on the structure of the docstring.
"""

# Unconstrained generic type variable; not referenced in the visible part of this module.
A = TypeVar('A')
class GenerateToolJsonSchema(GenerateJsonSchema):
    """JSON schema generator used by default for tool parameters.

    It adjusts the output of `GenerateJsonSchema` in two ways: typed-dict schemas get an explicit
    `additionalProperties` value derived from `total`, and `title` entries are removed from properties.
    """

    def typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> JsonSchemaValue:
        """Generate the typed-dict schema, setting `additionalProperties` to `not total` when `total` is set."""
        s = super().typed_dict_schema(schema)
        total = schema.get('total')
        if total is not None:
            s['additionalProperties'] = not total
        return s

    def _named_required_fields_schema(self, named_required_fields: Sequence[tuple[str, bool, Any]]) -> JsonSchemaValue:
        """Generate the fields schema, then drop the `title` of every property."""
        # Remove largely-useless property titles
        s = super()._named_required_fields_schema(named_required_fields)
        for property_schema in s.get('properties', {}).values():
            property_schema.pop('title', None)
        return s
@dataclass(init=False)
class Tool(Generic[AgentDepsT]):
    """A tool function for an agent."""

    function: ToolFuncEither[AgentDepsT]
    takes_ctx: bool
    max_retries: int | None
    name: str
    description: str
    prepare: ToolPrepareFunc[AgentDepsT] | None
    docstring_format: DocstringFormat
    require_parameter_descriptions: bool
    # The fields below are derived from `function` in `__init__`.
    _is_async: bool = field(init=False)
    _single_arg_name: str | None = field(init=False)
    _positional_fields: list[str] = field(init=False)
    _var_positional_field: str | None = field(init=False)
    _validator: SchemaValidator = field(init=False, repr=False)
    _parameters_json_schema: ObjectJsonSchema = field(init=False)

    # TODO: Move this state off the Tool class, which is otherwise stateless.
    # This should be tracked inside a specific agent run, not the tool.
    current_retry: int = field(default=0, init=False)

    def __init__(
        self,
        function: ToolFuncEither[AgentDepsT],
        *,
        takes_ctx: bool | None = None,
        max_retries: int | None = None,
        name: str | None = None,
        description: str | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
        schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
    ):
        """Create a new tool instance.

        Example usage:

        ```python {noqa="I001"}
        from pydantic_ai import Agent, RunContext, Tool

        async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
            return f'{ctx.deps} {x} {y}'

        agent = Agent('test', tools=[Tool(my_tool)])
        ```

        or with a custom prepare method:

        ```python {noqa="I001"}
        from typing import Union

        from pydantic_ai import Agent, RunContext, Tool
        from pydantic_ai.tools import ToolDefinition

        async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
            return f'{ctx.deps} {x} {y}'

        async def prep_my_tool(
            ctx: RunContext[int], tool_def: ToolDefinition
        ) -> Union[ToolDefinition, None]:
            # only register the tool if `deps == 42`
            if ctx.deps == 42:
                return tool_def

        agent = Agent('test', tools=[Tool(my_tool, prepare=prep_my_tool)])
        ```

        Args:
            function: The Python function to call as the tool.
            takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] first argument,
                this is inferred if unset.
            max_retries: Maximum number of retries allowed for this tool, set to the agent default if `None`.
            name: Name of the tool, inferred from the function if `None`.
            description: Description of the tool, inferred from the function if `None`.
            prepare: custom method to prepare the tool definition for each step, return `None` to omit this
                tool from a given step. This is useful if you want to customise a tool at call time,
                or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
            require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
            schema_generator: The JSON schema generator class to use. Defaults to `GenerateToolJsonSchema`.
        """
        if takes_ctx is None:
            takes_ctx = _pydantic.takes_ctx(function)

        f = _pydantic.function_schema(
            function, takes_ctx, docstring_format, require_parameter_descriptions, schema_generator
        )
        self.function = function
        self.takes_ctx = takes_ctx
        self.max_retries = max_retries
        self.name = name or function.__name__
        self.description = description or f['description']
        self.prepare = prepare
        self.docstring_format = docstring_format
        self.require_parameter_descriptions = require_parameter_descriptions
        self._is_async = inspect.iscoroutinefunction(self.function)
        self._single_arg_name = f['single_arg_name']
        self._positional_fields = f['positional_fields']
        self._var_positional_field = f['var_positional_field']
        self._validator = f['validator']
        self._parameters_json_schema = f['json_schema']

    async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
        """Get the tool definition.

        By default, this method creates a tool definition, then either returns it, or calls `self.prepare`
        if it's set.

        Returns:
            return a `ToolDefinition` or `None` if the tools should not be registered for this run.
        """
        tool_def = ToolDefinition(
            name=self.name,
            description=self.description,
            parameters_json_schema=self._parameters_json_schema,
        )
        if self.prepare is not None:
            return await self.prepare(ctx, tool_def)
        else:
            return tool_def

    async def run(
        self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT], tracer: Tracer
    ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
        """Run the tool function asynchronously.

        This method wraps `_run` in an OpenTelemetry span.

        See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>.
        """
        span_attributes = {
            'gen_ai.tool.name': self.name,
            # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai
            'gen_ai.tool.call.id': message.tool_call_id,
            'tool_arguments': message.args_as_json_str(),
            'logfire.msg': f'running tool: {self.name}',
            # add the JSON schema so these attributes are formatted nicely in Logfire
            'logfire.json_schema': json.dumps(
                {
                    'type': 'object',
                    'properties': {
                        'tool_arguments': {'type': 'object'},
                        'gen_ai.tool.name': {},
                        'gen_ai.tool.call.id': {},
                    },
                }
            ),
        }
        with tracer.start_as_current_span('running tool', attributes=span_attributes):
            return await self._run(message, run_context)

    async def _run(
        self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT]
    ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
        """Validate the call arguments, invoke the tool function and wrap the outcome in a message part.

        Args:
            message: The tool call to execute; `message.args` may be a JSON string or a Python object.
            run_context: The context of the current run, passed to the function if `takes_ctx` is set.

        Returns:
            A `ToolReturnPart` holding the function's result, or a `RetryPromptPart` if argument validation
            failed or the function raised `ModelRetry`.

        Raises:
            UnexpectedModelBehavior: Raised by `_on_error` once the retry limit is exceeded.
        """
        try:
            # String arguments are validated as JSON, anything else as Python data.
            if isinstance(message.args, str):
                args_dict = self._validator.validate_json(message.args)
            else:
                args_dict = self._validator.validate_python(message.args)
        except ValidationError as e:
            return self._on_error(e, message)

        args, kwargs = self._call_args(args_dict, message, run_context)
        try:
            if self._is_async:
                function = cast(Callable[[Any], Awaitable[str]], self.function)
                response_content = await function(*args, **kwargs)
            else:
                # Non-coroutine functions are awaited through `_utils.run_in_executor`.
                function = cast(Callable[[Any], str], self.function)
                response_content = await _utils.run_in_executor(function, *args, **kwargs)
        except ModelRetry as e:
            return self._on_error(e, message)

        # A successful call resets the retry counter.
        self.current_retry = 0
        return _messages.ToolReturnPart(
            tool_name=message.tool_name,
            content=response_content,
            tool_call_id=message.tool_call_id,
        )

    def _call_args(
        self,
        args_dict: dict[str, Any],
        message: _messages.ToolCallPart,
        run_context: RunContext[AgentDepsT],
    ) -> tuple[list[Any], dict[str, Any]]:
        """Build the positional and keyword arguments to call the tool function with.

        Args:
            args_dict: The validated arguments; positional entries are popped from it.
            message: The tool call, used for the `tool_name` and `tool_call_id` set on the context.
            run_context: The context of the current run; a modified copy is passed on, not this object.

        Returns:
            A tuple of positional arguments (starting with the context if `takes_ctx` is set) and the
            remaining keyword arguments.
        """
        if self._single_arg_name:
            # The validated value is the single argument itself, so wrap it under its parameter name.
            args_dict = {self._single_arg_name: args_dict}

        ctx = dataclasses.replace(
            run_context,
            retry=self.current_retry,
            tool_name=message.tool_name,
            tool_call_id=message.tool_call_id,
        )
        args = [ctx] if self.takes_ctx else []
        for positional_field in self._positional_fields:
            args.append(args_dict.pop(positional_field))
        if self._var_positional_field:
            args.extend(args_dict.pop(self._var_positional_field))

        return args, args_dict

    def _on_error(
        self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart
    ) -> _messages.RetryPromptPart:
        """Record a failed call and build the retry prompt to return for it.

        Args:
            exc: The validation error or `ModelRetry` that caused the failure.
            call_message: The tool call that failed.

        Returns:
            A `RetryPromptPart` containing the validation errors or the `ModelRetry` message.

        Raises:
            UnexpectedModelBehavior: If `max_retries` is `None` or the incremented retry count exceeds it.
        """
        self.current_retry += 1
        if self.max_retries is None or self.current_retry > self.max_retries:
            raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
        else:
            if isinstance(exc, ValidationError):
                content = exc.errors(include_url=False)
            else:
                content = exc.message
            return _messages.RetryPromptPart(
                tool_name=call_message.tool_name,
                content=content,
                tool_call_id=call_message.tool_call_id,
            )
390ObjectJsonSchema: TypeAlias = dict[str, Any]
391"""Type representing JSON schema of an object, e.g. where `"type": "object"`.
393This type is used to define tools parameters (aka arguments) in [ToolDefinition][pydantic_ai.tools.ToolDefinition].
395With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_parts=Any`
396"""
@dataclass
class ToolDefinition:
    """Definition of a tool passed to a model.

    This is used for both function tools and result tools.
    """

    name: str
    """The name of the tool."""

    description: str
    """The description of the tool."""

    parameters_json_schema: ObjectJsonSchema
    """The JSON schema for the tool's parameters."""

    outer_typed_dict_key: str | None = None
    """The key in the outer [TypedDict] that wraps a result tool.

    This will only be set for result tools which don't have an `object` JSON schema.
    """