Coverage for pydantic_ai_slim/pydantic_ai/tools.py: 95.77%

163 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1from __future__ import annotations as _annotations 

2 

3import dataclasses 

4import inspect 

5import json 

6from collections.abc import Awaitable, Sequence 

7from dataclasses import dataclass, field 

8from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast 

9 

10from opentelemetry.trace import Tracer 

11from pydantic import ValidationError 

12from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue 

13from pydantic_core import SchemaValidator, core_schema 

14from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar 

15 

16from . import _pydantic, _utils, messages as _messages, models 

17from .exceptions import ModelRetry, UnexpectedModelBehavior 

18 

19if TYPE_CHECKING: 

20 from .result import Usage 

21 

# Public API of this module. The order is kept stable so that `__all__` compares equal across releases.
__all__ = (
    'AgentDepsT',
    'DocstringFormat',
    'RunContext',
    'SystemPromptFunc',
    'ToolFuncContext',
    'ToolFuncPlain',
    'ToolFuncEither',
    'ToolParams',
    'ToolPrepareFunc',
    'Tool',
    'ObjectJsonSchema',
    'ToolDefinition',
)

AgentDepsT = TypeVar('AgentDepsT', contravariant=True, default=None)
"""Type variable for agent dependencies."""

39 

40 

@dataclasses.dataclass
class RunContext(Generic[AgentDepsT]):
    """Information about the current call."""

    deps: AgentDepsT
    """Dependencies for the agent."""
    model: models.Model
    """The model used in this run."""
    usage: Usage
    """LLM usage associated with the run."""
    prompt: str | Sequence[_messages.UserContent]
    """The original user prompt passed to the run."""
    messages: list[_messages.ModelMessage] = field(default_factory=list)
    """Messages exchanged in the conversation so far."""
    tool_call_id: str | None = None
    """The ID of the tool call."""
    tool_name: str | None = None
    """Name of the tool being called."""
    retry: int = 0
    """Number of retries so far."""
    run_step: int = 0
    """The current step in the run."""

    def replace_with(
        self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
    ) -> RunContext[AgentDepsT]:
        """Return a copy of this context with `retry` and/or `tool_name` swapped out.

        The original context is left untouched; the copy is built with `dataclasses.replace`.

        Args:
            retry: New retry count. `None` keeps the current value.
            tool_name: New tool name. `_utils.UNSET` keeps the current value,
                while an explicit `None` clears it.

        Returns:
            A new `RunContext` carrying the requested overrides.
        """
        changes: dict[str, Any] = {}
        if tool_name is not _utils.UNSET:
            changes['tool_name'] = tool_name
        if retry is not None:
            changes['retry'] = retry
        return dataclasses.replace(self, **changes)

74 

75 

ToolParams = ParamSpec('ToolParams', default=...)
"""Tool function param spec."""

SystemPromptFunc = Union[
    Callable[[RunContext[AgentDepsT]], str],
    Callable[[RunContext[AgentDepsT]], Awaitable[str]],
    Callable[[], str],
    Callable[[], Awaitable[str]],
]
"""A function that may or may not take `RunContext` as an argument, and may or may not be async.

Usage `SystemPromptFunc[AgentDepsT]`.
"""

ToolFuncContext = Callable[Concatenate[RunContext[AgentDepsT], ToolParams], Any]
"""A tool function that takes `RunContext` as the first argument.

Usage `ToolFuncContext[AgentDepsT, ToolParams]`.
"""
ToolFuncPlain = Callable[ToolParams, Any]
"""A tool function that does not take `RunContext` as the first argument.

Usage `ToolFuncPlain[ToolParams]`.
"""
ToolFuncEither = Union[ToolFuncContext[AgentDepsT, ToolParams], ToolFuncPlain[ToolParams]]
"""Either kind of tool function.

This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
[`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].

Usage `ToolFuncEither[AgentDepsT, ToolParams]`.
"""
ToolPrepareFunc: TypeAlias = 'Callable[[RunContext[AgentDepsT], ToolDefinition], Awaitable[ToolDefinition | None]]'
"""Definition of a function that can prepare a tool definition at call time.

See [tool docs](../tools.md#tool-prepare) for more information.

Example — here `only_if_42` is valid as a `ToolPrepareFunc`:

```python {noqa="I001"}
from typing import Union

from pydantic_ai import RunContext, Tool
from pydantic_ai.tools import ToolDefinition

async def only_if_42(
    ctx: RunContext[int], tool_def: ToolDefinition
) -> Union[ToolDefinition, None]:
    if ctx.deps == 42:
        return tool_def

def hitchhiker(ctx: RunContext[int], answer: str) -> str:
    return f'{ctx.deps} {answer}'

hitchhiker = Tool(hitchhiker, prepare=only_if_42)
```

Usage `ToolPrepareFunc[AgentDepsT]`.
"""

DocstringFormat = Literal['google', 'numpy', 'sphinx', 'auto']
"""Supported docstring formats.

* `'google'` — [Google-style](https://google.github.io/styleguide/pyguide.html#381-docstrings) docstrings.
* `'numpy'` — [Numpy-style](https://numpydoc.readthedocs.io/en/latest/format.html) docstrings.
* `'sphinx'` — [Sphinx-style](https://sphinx-rtd-tutorial.readthedocs.io/en/latest/docstrings.html#the-sphinx-docstring-format) docstrings.
* `'auto'` — Automatically infer the format based on the structure of the docstring.
"""

A = TypeVar('A')

146 

147 

class GenerateToolJsonSchema(GenerateJsonSchema):
    """JSON schema generator used for tool parameter schemas.

    It differs from pydantic's default generator in two ways:
    * For a `TypedDict` whose `total` flag is set, `additionalProperties` is set to `not total`.
    * Per-property `title` entries are dropped from object schemas.
    """

    def typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> JsonSchemaValue:
        json_schema = super().typed_dict_schema(schema)
        # Only touch `additionalProperties` when the core schema carries an explicit `total`.
        if (total := schema.get('total')) is not None:
            json_schema['additionalProperties'] = not total
        return json_schema

    def _named_required_fields_schema(self, named_required_fields: Sequence[tuple[str, bool, Any]]) -> JsonSchemaValue:
        json_schema = super()._named_required_fields_schema(named_required_fields)
        # Property titles mostly just echo the property name, so strip them to keep the schema lean.
        for property_schema in json_schema.get('properties', {}).values():
            property_schema.pop('title', None)
        return json_schema

162 

163 

@dataclass(init=False)
class Tool(Generic[AgentDepsT]):
    """A tool function for an agent.

    Bundles a Python callable with its validated argument schema. Tool calls from the model
    are validated against that schema, dispatched to the callable, and turned into
    `ToolReturnPart` / `RetryPromptPart` messages.
    """

    function: ToolFuncEither[AgentDepsT]
    takes_ctx: bool
    max_retries: int | None
    name: str
    description: str
    prepare: ToolPrepareFunc[AgentDepsT] | None
    docstring_format: DocstringFormat
    require_parameter_descriptions: bool
    _is_async: bool = field(init=False)
    _single_arg_name: str | None = field(init=False)
    _positional_fields: list[str] = field(init=False)
    _var_positional_field: str | None = field(init=False)
    _validator: SchemaValidator = field(init=False, repr=False)
    _parameters_json_schema: ObjectJsonSchema = field(init=False)

    # TODO: Move this state off the Tool class, which is otherwise stateless.
    # This should be tracked inside a specific agent run, not the tool.
    current_retry: int = field(default=0, init=False)

    def __init__(
        self,
        function: ToolFuncEither[AgentDepsT],
        *,
        takes_ctx: bool | None = None,
        max_retries: int | None = None,
        name: str | None = None,
        description: str | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
        schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
    ):
        """Create a new tool instance.

        Example usage:

        ```python {noqa="I001"}
        from pydantic_ai import Agent, RunContext, Tool

        async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
            return f'{ctx.deps} {x} {y}'

        agent = Agent('test', tools=[Tool(my_tool)])
        ```

        or with a custom prepare method:

        ```python {noqa="I001"}
        from typing import Union

        from pydantic_ai import Agent, RunContext, Tool
        from pydantic_ai.tools import ToolDefinition

        async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
            return f'{ctx.deps} {x} {y}'

        async def prep_my_tool(
            ctx: RunContext[int], tool_def: ToolDefinition
        ) -> Union[ToolDefinition, None]:
            # only register the tool if `deps == 42`
            if ctx.deps == 42:
                return tool_def

        agent = Agent('test', tools=[Tool(my_tool, prepare=prep_my_tool)])
        ```


        Args:
            function: The Python function to call as the tool.
            takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] first argument;
                inferred from the function when left unset.
            max_retries: Maximum number of retries allowed for this tool; set to the agent default if `None`.
            name: Name of the tool; taken from the function's `__name__` if `None` or empty.
            description: Description of the tool; taken from the generated function schema if `None` or empty.
            prepare: Custom method to prepare the tool definition for each step. Return `None` to omit this
                tool from a given step. This is useful if you want to customise a tool at call time,
                or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
            require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
            schema_generator: The JSON schema generator class to use. Defaults to `GenerateToolJsonSchema`.
        """
        uses_ctx = _pydantic.takes_ctx(function) if takes_ctx is None else takes_ctx
        schema = _pydantic.function_schema(
            function, uses_ctx, docstring_format, require_parameter_descriptions, schema_generator
        )

        # Public configuration, falling back to values derived from the function itself.
        self.function = function
        self.takes_ctx = uses_ctx
        self.max_retries = max_retries
        self.name = name or function.__name__
        self.description = description or schema['description']
        self.prepare = prepare
        self.docstring_format = docstring_format
        self.require_parameter_descriptions = require_parameter_descriptions

        # Private dispatch state derived from the function signature and its schema.
        self._is_async = inspect.iscoroutinefunction(function)
        self._single_arg_name = schema['single_arg_name']
        self._positional_fields = schema['positional_fields']
        self._var_positional_field = schema['var_positional_field']
        self._validator = schema['validator']
        self._parameters_json_schema = schema['json_schema']

    async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
        """Build the tool definition for the current step.

        A `ToolDefinition` is built from this tool's name, description and parameter schema.
        If `self.prepare` is set, that definition is passed through it, and its result is returned.

        Returns:
            A `ToolDefinition`, or `None` if the tool should not be registered for this step.
        """
        base_definition = ToolDefinition(
            name=self.name,
            description=self.description,
            parameters_json_schema=self._parameters_json_schema,
        )
        if self.prepare is None:
            return base_definition
        return await self.prepare(ctx, base_definition)

    async def run(
        self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT], tracer: Tracer
    ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
        """Run the tool function asynchronously inside an OpenTelemetry span.

        The actual work is delegated to `_run`.

        See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>.
        """
        # Attribute schema for Logfire, so these span attributes are formatted nicely there.
        logfire_schema = {
            'type': 'object',
            'properties': {
                'tool_arguments': {'type': 'object'},
                'gen_ai.tool.name': {},
                'gen_ai.tool.call.id': {},
            },
        }
        span_attributes = {
            'gen_ai.tool.name': self.name,
            # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai
            'gen_ai.tool.call.id': message.tool_call_id,
            'tool_arguments': message.args_as_json_str(),
            'logfire.msg': f'running tool: {self.name}',
            'logfire.json_schema': json.dumps(logfire_schema),
        }
        with tracer.start_as_current_span('running tool', attributes=span_attributes):
            return await self._run(message, run_context)

    async def _run(
        self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT]
    ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
        # Validate the raw arguments. A JSON string and an already-decoded mapping take different paths.
        validator = self._validator
        try:
            args_dict = (
                validator.validate_json(message.args)
                if isinstance(message.args, str)
                else validator.validate_python(message.args)
            )
        except ValidationError as e:
            return self._on_error(e, message)

        args, kwargs = self._call_args(args_dict, message, run_context)
        try:
            if self._is_async:
                async_function = cast(Callable[[Any], Awaitable[str]], self.function)
                response_content = await async_function(*args, **kwargs)
            else:
                # Sync tools are offloaded so they don't block the event loop.
                sync_function = cast(Callable[[Any], str], self.function)
                response_content = await _utils.run_in_executor(sync_function, *args, **kwargs)
        except ModelRetry as e:
            return self._on_error(e, message)

        # A successful call resets the retry counter.
        self.current_retry = 0
        return _messages.ToolReturnPart(
            tool_name=message.tool_name,
            content=response_content,
            tool_call_id=message.tool_call_id,
        )

    def _call_args(
        self,
        args_dict: dict[str, Any],
        message: _messages.ToolCallPart,
        run_context: RunContext[AgentDepsT],
    ) -> tuple[list[Any], dict[str, Any]]:
        # Split validated arguments into positional args and kwargs for the tool function.
        if self._single_arg_name:
            args_dict = {self._single_arg_name: args_dict}

        call_ctx = dataclasses.replace(
            run_context,
            retry=self.current_retry,
            tool_name=message.tool_name,
            tool_call_id=message.tool_call_id,
        )
        positional: list[Any] = [call_ctx] if self.takes_ctx else []
        # Positional-only fields are removed from the dict so they aren't also passed as kwargs.
        positional.extend(args_dict.pop(field_name) for field_name in self._positional_fields)
        if self._var_positional_field:
            positional.extend(args_dict.pop(self._var_positional_field))

        return positional, args_dict

    def _on_error(
        self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart
    ) -> _messages.RetryPromptPart:
        # Count this failure. Ask the model to retry while within budget; otherwise give up.
        self.current_retry += 1
        if self.max_retries is not None and self.current_retry <= self.max_retries:
            content = exc.errors(include_url=False) if isinstance(exc, ValidationError) else exc.message
            return _messages.RetryPromptPart(
                tool_name=call_message.tool_name,
                content=content,
                tool_call_id=call_message.tool_call_id,
            )
        raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc

388 

389 

ObjectJsonSchema: TypeAlias = dict[str, Any]
"""Type representing JSON schema of an object, e.g. where `"type": "object"`.

This type is used to define tool parameters (aka arguments) in [ToolDefinition][pydantic_ai.tools.ToolDefinition].

With [PEP 728](https://peps.python.org/pep-0728/) this should be a `TypedDict` with `type: Literal['object']`
and `extra_items=Any`.
"""

397 

398 

@dataclass
class ToolDefinition:
    """Definition of a tool passed to a model.

    This is used for both function tools and result tools.
    """

    name: str
    """The name of the tool."""

    description: str
    """The description of the tool."""

    parameters_json_schema: ObjectJsonSchema
    """The JSON schema for the tool's parameters."""

    outer_typed_dict_key: str | None = None
    """The key in the outer [TypedDict] that wraps a result tool.

    This will only be set for result tools which don't have an `object` JSON schema.
    """