Coverage for pydantic_ai_slim/pydantic_ai/agent.py: 99.01%

232 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-01-30 19:21 +0000

1from __future__ import annotations as _annotations 

2 

3import asyncio 

4import dataclasses 

5import inspect 

6from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence 

7from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager 

8from types import FrameType 

9from typing import Any, Callable, Generic, cast, final, overload 

10 

11import logfire_api 

12from typing_extensions import TypeVar, deprecated 

13 

14from pydantic_graph import Graph, GraphRunContext, HistoryStep 

15from pydantic_graph.nodes import End 

16 

17from . import ( 

18 _agent_graph, 

19 _result, 

20 _system_prompt, 

21 _utils, 

22 exceptions, 

23 messages as _messages, 

24 models, 

25 result, 

26 usage as _usage, 

27) 

28from ._agent_graph import EndStrategy, capture_run_messages # imported for re-export 

29from .result import ResultDataT 

30from .settings import ModelSettings, merge_model_settings 

31from .tools import ( 

32 AgentDepsT, 

33 DocstringFormat, 

34 RunContext, 

35 Tool, 

36 ToolFuncContext, 

37 ToolFuncEither, 

38 ToolFuncPlain, 

39 ToolParams, 

40 ToolPrepareFunc, 

41) 

42 

# Public API of this module; `capture_run_messages` and `EndStrategy` are imported from `_agent_graph` for re-export.
__all__ = 'Agent', 'capture_run_messages', 'EndStrategy'

# Logfire handle used for this module's spans, created with the 'pydantic-ai' OpenTelemetry scope.
_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

# while waiting for https://github.com/pydantic/logfire/issues/745
# When the real `logfire` package is importable, append this package's directory to logfire's
# NON_USER_CODE_PREFIXES (presumably so logfire does not attribute spans to frames inside pydantic-ai).
# If logfire isn't installed the ImportError is deliberately ignored and nothing is patched.
try:
    import logfire._internal.stack_info
except ImportError:
    pass
else:
    from pathlib import Path

    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)

T = TypeVar('T')
# The type of `None`; used as the default `deps_type` for agents that take no dependencies.
NoneType = type(None)
RunResultDataT = TypeVar('RunResultDataT')
"""Type variable for the result data of a run where `result_type` was customized on the run call."""

61 

62 

63@final 

64@dataclasses.dataclass(init=False) 

65class Agent(Generic[AgentDepsT, ResultDataT]): 

66 """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM. 

67 

68 Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT] 

69 and the result data type they return, [`ResultDataT`][pydantic_ai.result.ResultDataT]. 

70 

71 By default, if neither generic parameter is customised, agents have type `Agent[None, str]`. 

72 

73 Minimal usage example: 

74 

75 ```python 

76 from pydantic_ai import Agent 

77 

78 agent = Agent('openai:gpt-4o') 

79 result = agent.run_sync('What is the capital of France?') 

80 print(result.data) 

81 #> Paris 

82 ``` 

83 """ 

84 

85 # we use dataclass fields in order to conveniently know what attributes are available 

86 model: models.Model | models.KnownModelName | None 

87 """The default model configured for this agent.""" 

88 

89 name: str | None 

90 """The name of the agent, used for logging. 

91 

92 If `None`, we try to infer the agent name from the call frame when the agent is first run. 

93 """ 

94 end_strategy: EndStrategy 

95 """Strategy for handling tool calls when a final result is found.""" 

96 

97 model_settings: ModelSettings | None 

98 """Optional model request settings to use for this agents's runs, by default. 

99 

100 Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will 

101 be merged with this value, with the runtime argument taking priority. 

102 """ 

103 

104 result_type: type[ResultDataT] = dataclasses.field(repr=False) 

105 """ 

106 The type of the result data, used to validate the result data, defaults to `str`. 

107 """ 

108 

109 _deps_type: type[AgentDepsT] = dataclasses.field(repr=False) 

110 _result_tool_name: str = dataclasses.field(repr=False) 

111 _result_tool_description: str | None = dataclasses.field(repr=False) 

112 _result_schema: _result.ResultSchema[ResultDataT] | None = dataclasses.field(repr=False) 

113 _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] = dataclasses.field(repr=False) 

114 _system_prompts: tuple[str, ...] = dataclasses.field(repr=False) 

115 _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False) 

116 _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field( 

117 repr=False 

118 ) 

119 _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False) 

120 _default_retries: int = dataclasses.field(repr=False) 

121 _max_result_retries: int = dataclasses.field(repr=False) 

122 _override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False) 

123 _override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False) 

124 

def __init__(
    self,
    model: models.Model | models.KnownModelName | None = None,
    *,
    result_type: type[ResultDataT] = str,
    system_prompt: str | Sequence[str] = (),
    deps_type: type[AgentDepsT] = NoneType,
    name: str | None = None,
    model_settings: ModelSettings | None = None,
    retries: int = 1,
    result_tool_name: str = 'final_result',
    result_tool_description: str | None = None,
    result_retries: int | None = None,
    tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
    defer_model_check: bool = False,
    end_strategy: EndStrategy = 'early',
):
    """Create an agent.

    Args:
        model: The default model to use for this agent, if not provided,
            you must provide the model when calling it.
        result_type: The type of the result data, used to validate the result data, defaults to `str`.
        system_prompt: Static system prompts to use for this agent, you can also register system
            prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt].
        deps_type: The type used for dependency injection, this parameter exists solely to allow you to fully
            parameterize the agent, and therefore get the best out of static type checking.
            If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright
            or add a type hint `: Agent[None, <return type>]`.
        name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame
            when the agent is first run.
        model_settings: Optional model request settings to use for this agent's runs, by default.
        retries: The default number of retries to allow before raising an error.
        result_tool_name: The name of the tool to use for the final result.
        result_tool_description: The description of the final result tool.
        result_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
        tools: Tools to register with the agent, you can also register tools via the decorators
            [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
        defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
            it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
            which checks for the necessary environment variables. Set this to `True`
            to defer the evaluation until the first run. Useful if you want to
            [override the model][pydantic_ai.Agent.override] for testing.
        end_strategy: Strategy for handling tool calls that are requested alongside a final result.
            See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information.
    """
    # Resolve the model eagerly (which checks its environment) unless there is no model yet or the
    # caller asked for the check to be deferred until the first run.
    if model is None or defer_model_check:
        self.model = model
    else:
        self.model = models.infer_model(model)

    self.end_strategy = end_strategy
    self.name = name
    self.model_settings = model_settings
    self.result_type = result_type

    self._deps_type = deps_type

    self._result_tool_name = result_tool_name
    self._result_tool_description = result_tool_description
    self._result_schema: _result.ResultSchema[ResultDataT] | None = _result.ResultSchema[result_type].build(
        result_type, result_tool_name, result_tool_description
    )
    self._result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] = []

    # A single string is treated as one prompt; any other sequence is frozen into a tuple.
    self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
    self._system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = []
    self._system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = {}

    self._function_tools: dict[str, Tool[AgentDepsT]] = {}

    self._default_retries = retries
    # Result validation falls back to the general `retries` budget when no explicit value is given.
    self._max_result_retries = result_retries if result_retries is not None else retries
    # Plain functions are wrapped in `Tool` so every registered entry is a `Tool` instance.
    for tool in tools:
        if isinstance(tool, Tool):
            self._register_tool(tool)
        else:
            self._register_tool(Tool(tool))

203 

# Overload: no per-run `result_type`, so the run yields the agent's own `ResultDataT`.
@overload
async def run(
    self,
    user_prompt: str,
    *,
    result_type: None = None,
    message_history: list[_messages.ModelMessage] | None = None,
    model: models.Model | models.KnownModelName | None = None,
    deps: AgentDepsT = None,
    model_settings: ModelSettings | None = None,
    usage_limits: _usage.UsageLimits | None = None,
    usage: _usage.Usage | None = None,
    infer_name: bool = True,
) -> result.RunResult[ResultDataT]: ...


# Overload: an explicit `result_type` replaces the agent's result type for this run only.
@overload
async def run(
    self,
    user_prompt: str,
    *,
    result_type: type[RunResultDataT],
    message_history: list[_messages.ModelMessage] | None = None,
    model: models.Model | models.KnownModelName | None = None,
    deps: AgentDepsT = None,
    model_settings: ModelSettings | None = None,
    usage_limits: _usage.UsageLimits | None = None,
    usage: _usage.Usage | None = None,
    infer_name: bool = True,
) -> result.RunResult[RunResultDataT]: ...

233 

async def run(
    self,
    user_prompt: str,
    *,
    message_history: list[_messages.ModelMessage] | None = None,
    model: models.Model | models.KnownModelName | None = None,
    deps: AgentDepsT = None,
    model_settings: ModelSettings | None = None,
    usage_limits: _usage.UsageLimits | None = None,
    usage: _usage.Usage | None = None,
    result_type: type[RunResultDataT] | None = None,
    infer_name: bool = True,
) -> result.RunResult[Any]:
    """Run the agent with a user prompt in async mode.

    Example:
    ```python
    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')

    async def main():
        result = await agent.run('What is the capital of France?')
        print(result.data)
        #> Paris
    ```

    Args:
        result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
            result validators since result validators would expect an argument that matches the agent's result type.
        user_prompt: User input to start/continue the conversation.
        message_history: History of the conversation so far.
        model: Optional model to use for this run, required if `model` was not set when creating the agent.
        deps: Optional dependencies to use for this run.
        model_settings: Optional settings to use for this model's request.
        usage_limits: Optional limits on model request count or token usage.
        usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
        infer_name: Whether to try to infer the agent name from the call frame if it's not set.

    Returns:
        The result of the run.
    """
    if infer_name and self.name is None:
        self._infer_name(inspect.currentframe())
    model_used = await self._get_model(model)

    deps = self._get_deps(deps)
    new_message_index = len(message_history) if message_history else 0
    # Schema in force for *this* run: the per-run `result_type` override if one was given.
    result_schema: _result.ResultSchema[RunResultDataT] | None = self._prepare_result_schema(result_type)

    # Build the graph
    graph = _agent_graph.build_agent_graph(self.name, self._deps_type, result_type or self.result_type)

    # Build the initial state; the history is copied so the caller's list is not mutated by the run.
    state = _agent_graph.GraphAgentState(
        message_history=message_history[:] if message_history else [],
        usage=usage or _usage.Usage(),
        retries=0,
        run_step=0,
    )

    # We consider it a user error if a user tries to restrict the result type while having a result validator that
    # may change the result type from the restricted type to something else. Therefore, we consider the following
    # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
    result_validators = cast(list[_result.ResultValidator[AgentDepsT, RunResultDataT]], self._result_validators)

    # TODO: Instead of this, copy the function tools to ensure they don't share current_retry state between agent
    #   runs. Requires some changes to `Tool` to make them copyable though.
    for v in self._function_tools.values():
        v.current_retry = 0

    # Run-time settings take priority over the agent-level defaults.
    model_settings = merge_model_settings(self.model_settings, model_settings)
    usage_limits = usage_limits or _usage.UsageLimits()

    with _logfire.span(
        '{agent_name} run {prompt=}',
        prompt=user_prompt,
        agent=self,
        model_name=model_used.name() if model_used else 'no-model',
        agent_name=self.name or 'agent',
    ) as run_span:
        # Build the deps object for the graph
        graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
            user_deps=deps,
            prompt=user_prompt,
            new_message_index=new_message_index,
            model=model_used,
            model_settings=model_settings,
            usage_limits=usage_limits,
            max_result_retries=self._max_result_retries,
            end_strategy=self.end_strategy,
            result_schema=result_schema,
            # Derive the result tools from the same per-run schema as `result_schema`, so a `result_type`
            # override is reflected here too instead of advertising the agent-default schema's tools.
            result_tools=result_schema.tool_defs() if result_schema else [],
            result_validators=result_validators,
            function_tools=self._function_tools,
            run_span=run_span,
        )

        start_node = _agent_graph.UserPromptNode[AgentDepsT](
            user_prompt=user_prompt,
            system_prompts=self._system_prompts,
            system_prompt_functions=self._system_prompt_functions,
            system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
        )

        # Actually run
        end_result, _ = await graph.run(
            start_node,
            state=state,
            deps=graph_deps,
            infer_name=False,
        )

    # Build final run result
    # We don't do any advanced checking if the data is actually from a final result or not
    return result.RunResult(
        state.message_history,
        new_message_index,
        end_result.data,
        end_result.tool_name,
        state.usage,
    )

356 

# Overload: no per-run `result_type` (omitted or explicitly `None`), so the run yields the agent's `ResultDataT`.
# `result_type: None = None` mirrors the first overloads of `run` and `run_stream`; without it an explicit
# `result_type=None` fell through to the second overload and left `RunResultDataT` unsolved.
@overload
def run_sync(
    self,
    user_prompt: str,
    *,
    result_type: None = None,
    message_history: list[_messages.ModelMessage] | None = None,
    model: models.Model | models.KnownModelName | None = None,
    deps: AgentDepsT = None,
    model_settings: ModelSettings | None = None,
    usage_limits: _usage.UsageLimits | None = None,
    usage: _usage.Usage | None = None,
    infer_name: bool = True,
) -> result.RunResult[ResultDataT]: ...


# Overload: an explicit `result_type` replaces the agent's result type for this run only.
# (`| None` is kept so existing callers passing an optional type keep type-checking.)
@overload
def run_sync(
    self,
    user_prompt: str,
    *,
    result_type: type[RunResultDataT] | None,
    message_history: list[_messages.ModelMessage] | None = None,
    model: models.Model | models.KnownModelName | None = None,
    deps: AgentDepsT = None,
    model_settings: ModelSettings | None = None,
    usage_limits: _usage.UsageLimits | None = None,
    usage: _usage.Usage | None = None,
    infer_name: bool = True,
) -> result.RunResult[RunResultDataT]: ...

385 

def run_sync(
    self,
    user_prompt: str,
    *,
    result_type: type[RunResultDataT] | None = None,
    message_history: list[_messages.ModelMessage] | None = None,
    model: models.Model | models.KnownModelName | None = None,
    deps: AgentDepsT = None,
    model_settings: ModelSettings | None = None,
    usage_limits: _usage.UsageLimits | None = None,
    usage: _usage.Usage | None = None,
    infer_name: bool = True,
) -> result.RunResult[Any]:
    """Synchronously run the agent on a user prompt.

    A convenience wrapper that drives [`self.run`][pydantic_ai.Agent.run] to completion with
    `loop.run_until_complete(...)`. Consequently it cannot be used from async code, or while an event loop
    is already running.

    Example:
    ```python
    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')

    result_sync = agent.run_sync('What is the capital of Italy?')
    print(result_sync.data)
    #> Rome
    ```

    Args:
        result_type: Custom result type for this run only. Only allowed when the agent has no result
            validators, because those validators expect arguments of the agent's own result type.
        user_prompt: User input that starts or continues the conversation.
        message_history: The conversation history so far.
        model: Model to use for this run; required when the agent was created without a `model`.
        deps: Dependencies to use for this run, if any.
        model_settings: Settings for this model request, if any.
        usage_limits: Limits on model request count or token usage, if any.
        usage: Usage to start from, handy when resuming a conversation or running agents inside tools.
        infer_name: Whether to infer the agent name from the call frame when it is not set.

    Returns:
        The result of the run.
    """
    if infer_name and self.name is None:
        # Name inference is done here, against this frame, so `self.run` below is told not to repeat it.
        self._infer_name(inspect.currentframe())
    event_loop = asyncio.get_event_loop()
    run_coroutine = self.run(
        user_prompt,
        result_type=result_type,
        message_history=message_history,
        model=model,
        deps=deps,
        model_settings=model_settings,
        usage_limits=usage_limits,
        usage=usage,
        infer_name=False,
    )
    return event_loop.run_until_complete(run_coroutine)

445 

# Overload: no per-run `result_type`; the streamed result carries the agent's own `ResultDataT`.
@overload
def run_stream(
    self,
    user_prompt: str,
    *,
    result_type: None = None,
    message_history: list[_messages.ModelMessage] | None = None,
    model: models.Model | models.KnownModelName | None = None,
    deps: AgentDepsT = None,
    model_settings: ModelSettings | None = None,
    usage_limits: _usage.UsageLimits | None = None,
    usage: _usage.Usage | None = None,
    infer_name: bool = True,
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, ResultDataT]]: ...


# Overload: an explicit `result_type` replaces the agent's result type for this streamed run only.
@overload
def run_stream(
    self,
    user_prompt: str,
    *,
    result_type: type[RunResultDataT],
    message_history: list[_messages.ModelMessage] | None = None,
    model: models.Model | models.KnownModelName | None = None,
    deps: AgentDepsT = None,
    model_settings: ModelSettings | None = None,
    usage_limits: _usage.UsageLimits | None = None,
    usage: _usage.Usage | None = None,
    infer_name: bool = True,
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunResultDataT]]: ...

475 

@asynccontextmanager
async def run_stream(
    self,
    user_prompt: str,
    *,
    result_type: type[RunResultDataT] | None = None,
    message_history: list[_messages.ModelMessage] | None = None,
    model: models.Model | models.KnownModelName | None = None,
    deps: AgentDepsT = None,
    model_settings: ModelSettings | None = None,
    usage_limits: _usage.UsageLimits | None = None,
    usage: _usage.Usage | None = None,
    infer_name: bool = True,
) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]:
    """Run the agent with a user prompt in async mode, returning a streamed response.

    Example:
    ```python
    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')

    async def main():
        async with agent.run_stream('What is the capital of the UK?') as response:
            print(await response.get_data())
            #> London
    ```

    Args:
        result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
            result validators since result validators would expect an argument that matches the agent's result type.
        user_prompt: User input to start/continue the conversation.
        message_history: History of the conversation so far.
        model: Optional model to use for this run, required if `model` was not set when creating the agent.
        deps: Optional dependencies to use for this run.
        model_settings: Optional settings to use for this model's request.
        usage_limits: Optional limits on model request count or token usage.
        usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
        infer_name: Whether to try to infer the agent name from the call frame if it's not set.

    Returns:
        An async context manager yielding the
        [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] of the run.
    """
    if infer_name and self.name is None:
        # f_back because `asynccontextmanager` adds one frame
        if frame := inspect.currentframe():  # pragma: no branch
            self._infer_name(frame.f_back)
    model_used = await self._get_model(model)

    deps = self._get_deps(deps)
    new_message_index = len(message_history) if message_history else 0
    # Schema in force for *this* run: the per-run `result_type` override if one was given.
    result_schema: _result.ResultSchema[RunResultDataT] | None = self._prepare_result_schema(result_type)

    # Build the graph
    graph = self._build_stream_graph(result_type)

    # Build the initial state; the history is copied so the caller's list is not mutated by the run.
    graph_state = _agent_graph.GraphAgentState(
        message_history=message_history[:] if message_history else [],
        usage=usage or _usage.Usage(),
        retries=0,
        run_step=0,
    )

    # We consider it a user error if a user tries to restrict the result type while having a result validator that
    # may change the result type from the restricted type to something else. Therefore, we consider the following
    # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
    result_validators = cast(list[_result.ResultValidator[AgentDepsT, RunResultDataT]], self._result_validators)

    # TODO: Instead of this, copy the function tools to ensure they don't share current_retry state between agent
    #   runs. Requires some changes to `Tool` to make them copyable though.
    for v in self._function_tools.values():
        v.current_retry = 0

    # Run-time settings take priority over the agent-level defaults.
    model_settings = merge_model_settings(self.model_settings, model_settings)
    usage_limits = usage_limits or _usage.UsageLimits()

    with _logfire.span(
        '{agent_name} run stream {prompt=}',
        prompt=user_prompt,
        agent=self,
        # Same guard as in `run`, so both entry points build the span attributes identically.
        model_name=model_used.name() if model_used else 'no-model',
        agent_name=self.name or 'agent',
    ) as run_span:
        # Build the deps object for the graph
        graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
            user_deps=deps,
            prompt=user_prompt,
            new_message_index=new_message_index,
            model=model_used,
            model_settings=model_settings,
            usage_limits=usage_limits,
            max_result_retries=self._max_result_retries,
            end_strategy=self.end_strategy,
            result_schema=result_schema,
            # Derive the result tools from the same per-run schema as `result_schema`, so a `result_type`
            # override is reflected here too instead of advertising the agent-default schema's tools.
            result_tools=result_schema.tool_defs() if result_schema else [],
            result_validators=result_validators,
            function_tools=self._function_tools,
            run_span=run_span,
        )

        start_node = _agent_graph.StreamUserPromptNode[AgentDepsT](
            user_prompt=user_prompt,
            system_prompts=self._system_prompts,
            system_prompt_functions=self._system_prompt_functions,
            system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
        )

        # Actually run: step through the graph manually; when a streaming model-request node produces an
        # `End`, hand its streamed result to the caller and stop.
        node = start_node
        history: list[HistoryStep[_agent_graph.GraphAgentState, RunResultDataT]] = []
        while True:
            if isinstance(node, _agent_graph.StreamModelRequestNode):
                node = cast(
                    _agent_graph.StreamModelRequestNode[
                        AgentDepsT, result.StreamedRunResult[AgentDepsT, RunResultDataT]
                    ],
                    node,
                )
                async with node.run_to_result(GraphRunContext(graph_state, graph_deps)) as r:
                    if isinstance(r, End):
                        yield r.data
                        break
            assert not isinstance(node, End)  # the previous line should be hit first
            node = await graph.next(
                node,
                history,
                state=graph_state,
                deps=graph_deps,
                infer_name=False,
            )

607 

@contextmanager
def override(
    self,
    *,
    deps: AgentDepsT | _utils.Unset = _utils.UNSET,
    model: models.Model | models.KnownModelName | _utils.Unset = _utils.UNSET,
) -> Iterator[None]:
    """Context manager to temporarily override agent dependencies and model.

    This is particularly useful when testing.
    You can find an example of this [here](../testing-evals.md#overriding-model-via-pytest-fixtures).

    Only the arguments that are actually passed are overridden. The previous overrides are restored when
    the context exits, even if the body raises. If resolving `model` fails, no override is applied at all.

    Args:
        deps: The dependencies to use instead of the dependencies passed to the agent run.
        model: The model to use instead of the model passed to the agent run.
    """
    # Resolve the model before touching any state. `models.infer_model` can raise (e.g. when a named model's
    # environment isn't configured). That happens before the `try` below, so a deps override applied first
    # would otherwise never be rolled back.
    model_override: _utils.Some[models.Model] | _utils.Unset = _utils.UNSET
    if _utils.is_set(model):
        # noinspection PyTypeChecker
        model_override = _utils.Some(models.infer_model(model))  # pyright: ignore[reportArgumentType]

    if _utils.is_set(deps):
        override_deps_before = self._override_deps
        self._override_deps = _utils.Some(deps)
    else:
        override_deps_before = _utils.UNSET

    if _utils.is_set(model_override):
        override_model_before = self._override_model
        self._override_model = model_override
    else:
        override_model_before = _utils.UNSET

    try:
        yield
    finally:
        # Only restore what this call actually overrode; UNSET marks "left untouched".
        if _utils.is_set(override_deps_before):
            self._override_deps = override_deps_before
        if _utils.is_set(override_model_before):
            self._override_model = override_model_before

645 

# Overload: bare decorator on a sync function taking the run context.
@overload
def system_prompt(
    self, func: Callable[[RunContext[AgentDepsT]], str], /
) -> Callable[[RunContext[AgentDepsT]], str]: ...


# Overload: bare decorator on an async function taking the run context.
@overload
def system_prompt(
    self, func: Callable[[RunContext[AgentDepsT]], Awaitable[str]], /
) -> Callable[[RunContext[AgentDepsT]], Awaitable[str]]: ...


# Overload: bare decorator on a sync function with no arguments.
@overload
def system_prompt(
    self, func: Callable[[], str], /
) -> Callable[[], str]: ...


# Overload: bare decorator on an async function with no arguments.
@overload
def system_prompt(
    self, func: Callable[[], Awaitable[str]], /
) -> Callable[[], Awaitable[str]]: ...


# Overload: call form, `@agent.system_prompt(dynamic=...)`, which returns the real decorator.
@overload
def system_prompt(
    self, /, *, dynamic: bool = False
) -> Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]: ...

666 

def system_prompt(
    self,
    func: _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
    /,
    *,
    dynamic: bool = False,
) -> (
    Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]
    | _system_prompt.SystemPromptFunc[AgentDepsT]
):
    """Register a system prompt function; usable as a decorator.

    The function may optionally take [`RunContext`][pydantic_ai.tools.RunContext] as its only argument,
    and may be either sync or async.

    Use it bare (`agent.system_prompt`) or called (`agent.system_prompt(...)`); both forms appear in the
    example below.

    Each supported signature has its own overload so decorating a function doesn't hide its type;
    see `tests/typed_agent.py` for the typing tests.

    Args:
        func: The function to decorate
        dynamic: If True, the system prompt will be reevaluated even when `messages_history` is provided,
            see [`SystemPromptPart.dynamic_ref`][pydantic_ai.messages.SystemPromptPart.dynamic_ref]

    Example:
    ```python
    from pydantic_ai import Agent, RunContext

    agent = Agent('test', deps_type=str)

    @agent.system_prompt
    def simple_system_prompt() -> str:
        return 'foobar'

    @agent.system_prompt(dynamic=True)
    async def async_system_prompt(ctx: RunContext[str]) -> str:
        return f'{ctx.deps} is the best'
    ```
    """
    if func is not None:
        # Bare form: register immediately. `dynamic` is only reachable through the call form, so it must be False.
        assert not dynamic, "dynamic can't be True in this case"
        self._system_prompt_functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](func, dynamic=dynamic))
        return func

    # Call form: hand back a decorator that performs the registration.
    def register(
        func_: _system_prompt.SystemPromptFunc[AgentDepsT],
    ) -> _system_prompt.SystemPromptFunc[AgentDepsT]:
        prompt_runner = _system_prompt.SystemPromptRunner[AgentDepsT](func_, dynamic=dynamic)
        self._system_prompt_functions.append(prompt_runner)
        # NOTE(review): the coverage report showed this condition was always true under test,
        # i.e. the call form with `dynamic=False` is not exercised.
        if dynamic:
            # Dynamic prompts are additionally recorded by qualified name.
            self._system_prompt_dynamic_functions[func_.__qualname__] = prompt_runner
        return func_

    return register

724 

# Overload: sync validator that also receives the run context.
@overload
def result_validator(
    self, func: Callable[[RunContext[AgentDepsT], ResultDataT], ResultDataT], /
) -> Callable[[RunContext[AgentDepsT], ResultDataT], ResultDataT]: ...


# Overload: async validator that also receives the run context.
@overload
def result_validator(
    self, func: Callable[[RunContext[AgentDepsT], ResultDataT], Awaitable[ResultDataT]], /
) -> Callable[[RunContext[AgentDepsT], ResultDataT], Awaitable[ResultDataT]]: ...


# Overload: sync validator that receives only the result data.
@overload
def result_validator(
    self, func: Callable[[ResultDataT], ResultDataT], /
) -> Callable[[ResultDataT], ResultDataT]: ...


# Overload: async validator that receives only the result data.
@overload
def result_validator(
    self, func: Callable[[ResultDataT], Awaitable[ResultDataT]], /
) -> Callable[[ResultDataT], Awaitable[ResultDataT]]: ...

744 

def result_validator(
    self, func: _result.ResultValidatorFunc[AgentDepsT, ResultDataT], /
) -> _result.ResultValidatorFunc[AgentDepsT, ResultDataT]:
    """Register a result validator function; usable as a decorator.

    The function may optionally take [`RunContext`][pydantic_ai.tools.RunContext] as its first argument,
    and may be either sync or async.

    Each supported signature has its own overload so decorating a function doesn't hide its type;
    see `tests/typed_agent.py` for the typing tests.

    Example:
    ```python
    from pydantic_ai import Agent, ModelRetry, RunContext

    agent = Agent('test', deps_type=str)

    @agent.result_validator
    def result_validator_simple(data: str) -> str:
        if 'wrong' in data:
            raise ModelRetry('wrong response')
        return data

    @agent.result_validator
    async def result_validator_deps(ctx: RunContext[str], data: str) -> str:
        if ctx.deps in data:
            raise ModelRetry('wrong response')
        return data

    result = agent.run_sync('foobar', deps='spam')
    print(result.data)
    #> success (no tool calls)
    ```
    """
    validator = _result.ResultValidator[AgentDepsT, Any](func)
    self._result_validators.append(validator)
    # Hand the original function back so the decorated name still refers to it.
    return func

781 

# Overload: bare decorator form, `@agent.tool`.
@overload
def tool(
    self, func: ToolFuncContext[AgentDepsT, ToolParams], /
) -> ToolFuncContext[AgentDepsT, ToolParams]: ...


# Overload: call form, `@agent.tool(...)`, which returns the actual decorator.
@overload
def tool(
    self,
    /,
    *,
    retries: int | None = None,
    prepare: ToolPrepareFunc[AgentDepsT] | None = None,
    docstring_format: DocstringFormat = 'auto',
    require_parameter_descriptions: bool = False,
) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...

795 

796 def tool( 

797 self, 

798 func: ToolFuncContext[AgentDepsT, ToolParams] | None = None, 

799 /, 

800 *, 

801 retries: int | None = None, 

802 prepare: ToolPrepareFunc[AgentDepsT] | None = None, 

803 docstring_format: DocstringFormat = 'auto', 

804 require_parameter_descriptions: bool = False, 

805 ) -> Any: 

806 """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. 

807 

808 Can decorate a sync or async functions. 

809 

810 The docstring is inspected to extract both the tool description and description of each parameter, 

811 [learn more](../tools.md#function-tools-and-schema). 

812 

813 We can't add overloads for every possible signature of tool, since the return type is a recursive union 

814 so the signature of functions decorated with `@agent.tool` is obscured. 

815 

816 Example: 

817 ```python 

818 from pydantic_ai import Agent, RunContext 

819 

820 agent = Agent('test', deps_type=int) 

821 

822 @agent.tool 

823 def foobar(ctx: RunContext[int], x: int) -> int: 

824 return ctx.deps + x 

825 

826 @agent.tool(retries=2) 

827 async def spam(ctx: RunContext[str], y: float) -> float: 

828 return ctx.deps + y 

829 

830 result = agent.run_sync('foobar', deps=1) 

831 print(result.data) 

832 #> {"foobar":1,"spam":1.0} 

833 ``` 

834 

835 Args: 

836 func: The tool function to register. 

837 retries: The number of retries to allow for this tool, defaults to the agent's default retries, 

838 which defaults to 1. 

839 prepare: custom method to prepare the tool definition for each step, return `None` to omit this 

840 tool from a given step. This is useful if you want to customise a tool at call time, 

841 or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc]. 

842 docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat]. 

843 Defaults to `'auto'`, such that the format is inferred from the structure of the docstring. 

844 require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False. 

845 """ 

846 if func is None: 

847 

848 def tool_decorator( 

849 func_: ToolFuncContext[AgentDepsT, ToolParams], 

850 ) -> ToolFuncContext[AgentDepsT, ToolParams]: 

851 # noinspection PyTypeChecker 

852 self._register_function(func_, True, retries, prepare, docstring_format, require_parameter_descriptions) 

853 return func_ 

854 

855 return tool_decorator 

856 else: 

857 # noinspection PyTypeChecker 

858 self._register_function(func, True, retries, prepare, docstring_format, require_parameter_descriptions) 

859 return func 

860 

861 @overload 

862 def tool_plain(self, func: ToolFuncPlain[ToolParams], /) -> ToolFuncPlain[ToolParams]: ... 

863 

864 @overload 

865 def tool_plain( 

866 self, 

867 /, 

868 *, 

869 retries: int | None = None, 

870 prepare: ToolPrepareFunc[AgentDepsT] | None = None, 

871 docstring_format: DocstringFormat = 'auto', 

872 require_parameter_descriptions: bool = False, 

873 ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ... 

874 

875 def tool_plain( 

876 self, 

877 func: ToolFuncPlain[ToolParams] | None = None, 

878 /, 

879 *, 

880 retries: int | None = None, 

881 prepare: ToolPrepareFunc[AgentDepsT] | None = None, 

882 docstring_format: DocstringFormat = 'auto', 

883 require_parameter_descriptions: bool = False, 

884 ) -> Any: 

885 """Decorator to register a tool function which DOES NOT take `RunContext` as an argument. 

886 

887 Can decorate a sync or async functions. 

888 

889 The docstring is inspected to extract both the tool description and description of each parameter, 

890 [learn more](../tools.md#function-tools-and-schema). 

891 

892 We can't add overloads for every possible signature of tool, since the return type is a recursive union 

893 so the signature of functions decorated with `@agent.tool` is obscured. 

894 

895 Example: 

896 ```python 

897 from pydantic_ai import Agent, RunContext 

898 

899 agent = Agent('test') 

900 

901 @agent.tool 

902 def foobar(ctx: RunContext[int]) -> int: 

903 return 123 

904 

905 @agent.tool(retries=2) 

906 async def spam(ctx: RunContext[str]) -> float: 

907 return 3.14 

908 

909 result = agent.run_sync('foobar', deps=1) 

910 print(result.data) 

911 #> {"foobar":123,"spam":3.14} 

912 ``` 

913 

914 Args: 

915 func: The tool function to register. 

916 retries: The number of retries to allow for this tool, defaults to the agent's default retries, 

917 which defaults to 1. 

918 prepare: custom method to prepare the tool definition for each step, return `None` to omit this 

919 tool from a given step. This is useful if you want to customise a tool at call time, 

920 or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc]. 

921 docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat]. 

922 Defaults to `'auto'`, such that the format is inferred from the structure of the docstring. 

923 require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False. 

924 """ 

925 if func is None: 

926 

927 def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]: 

928 # noinspection PyTypeChecker 

929 self._register_function( 

930 func_, False, retries, prepare, docstring_format, require_parameter_descriptions 

931 ) 

932 return func_ 

933 

934 return tool_decorator 

935 else: 

936 self._register_function(func, False, retries, prepare, docstring_format, require_parameter_descriptions) 

937 return func 

938 

939 def _register_function( 

940 self, 

941 func: ToolFuncEither[AgentDepsT, ToolParams], 

942 takes_ctx: bool, 

943 retries: int | None, 

944 prepare: ToolPrepareFunc[AgentDepsT] | None, 

945 docstring_format: DocstringFormat, 

946 require_parameter_descriptions: bool, 

947 ) -> None: 

948 """Private utility to register a function as a tool.""" 

949 retries_ = retries if retries is not None else self._default_retries 

950 tool = Tool[AgentDepsT]( 

951 func, 

952 takes_ctx=takes_ctx, 

953 max_retries=retries_, 

954 prepare=prepare, 

955 docstring_format=docstring_format, 

956 require_parameter_descriptions=require_parameter_descriptions, 

957 ) 

958 self._register_tool(tool) 

959 

960 def _register_tool(self, tool: Tool[AgentDepsT]) -> None: 

961 """Private utility to register a tool instance.""" 

962 if tool.max_retries is None: 

963 # noinspection PyTypeChecker 

964 tool = dataclasses.replace(tool, max_retries=self._default_retries) 

965 

966 if tool.name in self._function_tools: 

967 raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}') 

968 

969 if self._result_schema and tool.name in self._result_schema.tools: 

970 raise exceptions.UserError(f'Tool name conflicts with result schema name: {tool.name!r}') 

971 

972 self._function_tools[tool.name] = tool 

973 

974 async def _get_model(self, model: models.Model | models.KnownModelName | None) -> models.Model: 

975 """Create a model configured for this agent. 

976 

977 Args: 

978 model: model to use for this run, required if `model` was not set when creating the agent. 

979 

980 Returns: 

981 The model used 

982 """ 

983 model_: models.Model 

984 if some_model := self._override_model: 

985 # we don't want `override()` to cover up errors from the model not being defined, hence this check 

986 if model is None and self.model is None: 

987 raise exceptions.UserError( 

988 '`model` must be set either when creating the agent or when calling it. ' 

989 '(Even when `override(model=...)` is customizing the model that will actually be called)' 

990 ) 

991 model_ = some_model.value 

992 elif model is not None: 

993 model_ = models.infer_model(model) 

994 elif self.model is not None: 

995 # noinspection PyTypeChecker 

996 model_ = self.model = models.infer_model(self.model) 

997 else: 

998 raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.') 

999 

1000 return model_ 

1001 

1002 def _get_deps(self: Agent[T, ResultDataT], deps: T) -> T: 

1003 """Get deps for a run. 

1004 

1005 If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call. 

1006 

1007 We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope. 

1008 """ 

1009 if some_deps := self._override_deps: 

1010 return some_deps.value 

1011 else: 

1012 return deps 

1013 

1014 def _infer_name(self, function_frame: FrameType | None) -> None: 

1015 """Infer the agent name from the call frame. 

1016 

1017 Usage should be `self._infer_name(inspect.currentframe())`. 

1018 """ 

1019 assert self.name is None, 'Name already set' 

1020 if function_frame is not None: # pragma: no branch 

1021 if parent_frame := function_frame.f_back: # pragma: no branch 

1022 for name, item in parent_frame.f_locals.items(): 

1023 if item is self: 

1024 self.name = name 

1025 return 

1026 if parent_frame.f_locals != parent_frame.f_globals: 1026 ↛ exitline 1026 didn't return from function '_infer_name' because the condition on line 1026 was always true

1027 # if we couldn't find the agent in locals and globals are a different dict, try globals 

1028 for name, item in parent_frame.f_globals.items(): 

1029 if item is self: 

1030 self.name = name 

1031 return 

1032 

1033 @property 

1034 @deprecated( 

1035 'The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.', category=None 

1036 ) 

1037 def last_run_messages(self) -> list[_messages.ModelMessage]: 

1038 raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.') 

1039 

1040 def _build_graph( 

1041 self, result_type: type[RunResultDataT] | None 

1042 ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], Any]: 

1043 return _agent_graph.build_agent_graph(self.name, self._deps_type, result_type or self.result_type) 

1044 

1045 def _build_stream_graph( 

1046 self, result_type: type[RunResultDataT] | None 

1047 ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], Any]: 

1048 return _agent_graph.build_agent_stream_graph(self.name, self._deps_type, result_type or self.result_type) 

1049 

1050 def _prepare_result_schema( 

1051 self, result_type: type[RunResultDataT] | None 

1052 ) -> _result.ResultSchema[RunResultDataT] | None: 

1053 if result_type is not None: 

1054 if self._result_validators: 

1055 raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators') 

1056 return _result.ResultSchema[result_type].build( 

1057 result_type, self._result_tool_name, self._result_tool_description 

1058 ) 

1059 else: 

1060 return self._result_schema # pyright: ignore[reportReturnType]