Coverage for pydantic_ai_slim/pydantic_ai/agent.py: 99.01%
232 statements
from __future__ import annotations as _annotations

import asyncio
import dataclasses
import inspect
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
from types import FrameType
from typing import Any, Callable, Generic, cast, final, overload

import logfire_api
from typing_extensions import TypeVar, deprecated

from pydantic_graph import Graph, GraphRunContext, HistoryStep
from pydantic_graph.nodes import End
from . import (
    _agent_graph,
    _result,
    _system_prompt,
    _utils,
    exceptions,
    messages as _messages,
    models,
    result,
    usage as _usage,
)
from ._agent_graph import EndStrategy, capture_run_messages  # imported for re-export
from .result import ResultDataT
from .settings import ModelSettings, merge_model_settings
from .tools import (
    AgentDepsT,
    DocstringFormat,
    RunContext,
    Tool,
    ToolFuncContext,
    ToolFuncEither,
    ToolFuncPlain,
    ToolParams,
    ToolPrepareFunc,
)

__all__ = 'Agent', 'capture_run_messages', 'EndStrategy'

_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

# while waiting for https://github.com/pydantic/logfire/issues/745
try:
    import logfire._internal.stack_info
except ImportError:
    pass
else:
    from pathlib import Path

    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)

T = TypeVar('T')
NoneType = type(None)
RunResultDataT = TypeVar('RunResultDataT')
"""Type variable for the result data of a run where `result_type` was customized on the run call."""


@final
@dataclasses.dataclass(init=False)
class Agent(Generic[AgentDepsT, ResultDataT]):
    """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.

    Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT]
    and the result data type they return, [`ResultDataT`][pydantic_ai.result.ResultDataT].

    By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.

    Minimal usage example:

    ```python
    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')
    result = agent.run_sync('What is the capital of France?')
    print(result.data)
    #> Paris
    ```
    """

    # we use dataclass fields in order to conveniently know what attributes are available
    model: models.Model | models.KnownModelName | None
    """The default model configured for this agent."""

    name: str | None
    """The name of the agent, used for logging.

    If `None`, we try to infer the agent name from the call frame when the agent is first run.
    """
    end_strategy: EndStrategy
    """Strategy for handling tool calls when a final result is found."""

    model_settings: ModelSettings | None
    """Optional model request settings to use for this agent's runs, by default.

    Note: if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will
    be merged with this value, with the runtime argument taking priority.
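
    For example (a minimal sketch; assumes the chosen model honours the standard `temperature` setting):

    ```python
    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o', model_settings={'temperature': 0.1})
    # the run-time value takes priority when the two settings dicts are merged
    result_sync = agent.run_sync('What is the capital of Spain?', model_settings={'temperature': 0.9})
    ```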
102 """

    result_type: type[ResultDataT] = dataclasses.field(repr=False)
    """
    The type of the result data, used to validate the result data, defaults to `str`.
    """

    _deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
    _result_tool_name: str = dataclasses.field(repr=False)
    _result_tool_description: str | None = dataclasses.field(repr=False)
    _result_schema: _result.ResultSchema[ResultDataT] | None = dataclasses.field(repr=False)
    _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] = dataclasses.field(repr=False)
    _system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
    _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
    _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(
        repr=False
    )
    _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False)
    _default_retries: int = dataclasses.field(repr=False)
    _max_result_retries: int = dataclasses.field(repr=False)
    _override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
    _override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)

    def __init__(
        self,
        model: models.Model | models.KnownModelName | None = None,
        *,
        result_type: type[ResultDataT] = str,
        system_prompt: str | Sequence[str] = (),
        deps_type: type[AgentDepsT] = NoneType,
        name: str | None = None,
        model_settings: ModelSettings | None = None,
        retries: int = 1,
        result_tool_name: str = 'final_result',
        result_tool_description: str | None = None,
        result_retries: int | None = None,
        tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
        defer_model_check: bool = False,
        end_strategy: EndStrategy = 'early',
    ):
142 """Create an agent.
144 Args:
145 model: The default model to use for this agent, if not provide,
146 you must provide the model when calling it.
147 result_type: The type of the result data, used to validate the result data, defaults to `str`.
148 system_prompt: Static system prompts to use for this agent, you can also register system
149 prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt].
150 deps_type: The type used for dependency injection, this parameter exists solely to allow you to fully
151 parameterize the agent, and therefore get the best out of static type checking.
152 If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright
153 or add a type hint `: Agent[None, <return type>]`.
154 name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame
155 when the agent is first run.
156 model_settings: Optional model request settings to use for this agent's runs, by default.
157 retries: The default number of retries to allow before raising an error.
158 result_tool_name: The name of the tool to use for the final result.
159 result_tool_description: The description of the final result tool.
160 result_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
161 tools: Tools to register with the agent, you can also register tools via the decorators
162 [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
163 defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
164 it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
165 which checks for the necessary environment variables. Set this to `false`
166 to defer the evaluation until the first run. Useful if you want to
167 [override the model][pydantic_ai.Agent.override] for testing.
168 end_strategy: Strategy for handling tool calls that are requested alongside a final result.
169 See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information.
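
        Example (a minimal sketch of typical construction; `'test'` is the built-in test model used
        throughout these docstrings):

        ```python
        from pydantic_ai import Agent

        agent = Agent(
            'test',
            deps_type=str,
            system_prompt='Be concise.',
            retries=2,
        )
        ```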
170 """
        if model is None or defer_model_check:
            self.model = model
        else:
            self.model = models.infer_model(model)

        self.end_strategy = end_strategy
        self.name = name
        self.model_settings = model_settings
        self.result_type = result_type

        self._deps_type = deps_type

        self._result_tool_name = result_tool_name
        self._result_tool_description = result_tool_description
        self._result_schema: _result.ResultSchema[ResultDataT] | None = _result.ResultSchema[result_type].build(
            result_type, result_tool_name, result_tool_description
        )
        self._result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] = []

        self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
        self._system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = []
        self._system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = {}

        self._function_tools: dict[str, Tool[AgentDepsT]] = {}

        self._default_retries = retries
        self._max_result_retries = result_retries if result_retries is not None else retries
        for tool in tools:
            if isinstance(tool, Tool):
                self._register_tool(tool)
            else:
                self._register_tool(Tool(tool))

    @overload
    async def run(
        self,
        user_prompt: str,
        *,
        result_type: None = None,
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
    ) -> result.RunResult[ResultDataT]: ...

    @overload
    async def run(
        self,
        user_prompt: str,
        *,
        result_type: type[RunResultDataT],
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
    ) -> result.RunResult[RunResultDataT]: ...

    async def run(
        self,
        user_prompt: str,
        *,
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        result_type: type[RunResultDataT] | None = None,
        infer_name: bool = True,
    ) -> result.RunResult[Any]:
247 """Run the agent with a user prompt in async mode.
249 Example:
250 ```python
251 from pydantic_ai import Agent
253 agent = Agent('openai:gpt-4o')
255 async def main():
256 result = await agent.run('What is the capital of France?')
257 print(result.data)
258 #> Paris
259 ```
261 Args:
262 result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
263 result validators since result validators would expect an argument that matches the agent's result type.
264 user_prompt: User input to start/continue the conversation.
265 message_history: History of the conversation so far.
266 model: Optional model to use for this run, required if `model` was not set when creating the agent.
267 deps: Optional dependencies to use for this run.
268 model_settings: Optional settings to use for this model's request.
269 usage_limits: Optional limits on model request count or token usage.
270 usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
271 infer_name: Whether to try to infer the agent name from the call frame if it's not set.
273 Returns:
274 The result of the run.
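
        A custom `result_type` can also be passed per run (a minimal sketch; only permitted when the agent
        has no result validators, per the `result_type` note above):

        ```python
        from pydantic_ai import Agent

        agent = Agent('openai:gpt-4o')

        async def main():
            result = await agent.run('What is the capital of France?', result_type=str)
            print(result.data)
            #> Paris
        ```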
275 """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        model_used = await self._get_model(model)

        deps = self._get_deps(deps)
        new_message_index = len(message_history) if message_history else 0
        result_schema: _result.ResultSchema[RunResultDataT] | None = self._prepare_result_schema(result_type)

        # Build the graph
        graph = _agent_graph.build_agent_graph(self.name, self._deps_type, result_type or self.result_type)

        # Build the initial state
        state = _agent_graph.GraphAgentState(
            message_history=message_history[:] if message_history else [],
            usage=usage or _usage.Usage(),
            retries=0,
            run_step=0,
        )

        # We consider it a user error if a user tries to restrict the result type while having a result validator that
        # may change the result type from the restricted type to something else. Therefore, we consider the following
        # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
        result_validators = cast(list[_result.ResultValidator[AgentDepsT, RunResultDataT]], self._result_validators)

        # TODO: Instead of this, copy the function tools to ensure they don't share current_retry state between agent
        # runs. Requires some changes to `Tool` to make them copyable though.
        for v in self._function_tools.values():
            v.current_retry = 0

        model_settings = merge_model_settings(self.model_settings, model_settings)
        usage_limits = usage_limits or _usage.UsageLimits()

        with _logfire.span(
            '{agent_name} run {prompt=}',
            prompt=user_prompt,
            agent=self,
            model_name=model_used.name() if model_used else 'no-model',
            agent_name=self.name or 'agent',
        ) as run_span:
            # Build the deps object for the graph
            graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
                user_deps=deps,
                prompt=user_prompt,
                new_message_index=new_message_index,
                model=model_used,
                model_settings=model_settings,
                usage_limits=usage_limits,
                max_result_retries=self._max_result_retries,
                end_strategy=self.end_strategy,
                result_schema=result_schema,
                result_tools=self._result_schema.tool_defs() if self._result_schema else [],
                result_validators=result_validators,
                function_tools=self._function_tools,
                run_span=run_span,
            )

            start_node = _agent_graph.UserPromptNode[AgentDepsT](
                user_prompt=user_prompt,
                system_prompts=self._system_prompts,
                system_prompt_functions=self._system_prompt_functions,
                system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
            )

            # Actually run
            end_result, _ = await graph.run(
                start_node,
                state=state,
                deps=graph_deps,
                infer_name=False,
            )

        # Build final run result
        # We don't do any advanced checking if the data is actually from a final result or not
        return result.RunResult(
            state.message_history,
            new_message_index,
            end_result.data,
            end_result.tool_name,
            state.usage,
        )

    @overload
    def run_sync(
        self,
        user_prompt: str,
        *,
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
    ) -> result.RunResult[ResultDataT]: ...

    @overload
    def run_sync(
        self,
        user_prompt: str,
        *,
        result_type: type[RunResultDataT] | None,
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
    ) -> result.RunResult[RunResultDataT]: ...

    def run_sync(
        self,
        user_prompt: str,
        *,
        result_type: type[RunResultDataT] | None = None,
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
    ) -> result.RunResult[Any]:
399 """Run the agent with a user prompt synchronously.
401 This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`.
402 You therefore can't use this method inside async code or if there's an active event loop.
404 Example:
405 ```python
406 from pydantic_ai import Agent
408 agent = Agent('openai:gpt-4o')
410 result_sync = agent.run_sync('What is the capital of Italy?')
411 print(result_sync.data)
412 #> Rome
413 ```
415 Args:
416 result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
417 result validators since result validators would expect an argument that matches the agent's result type.
418 user_prompt: User input to start/continue the conversation.
419 message_history: History of the conversation so far.
420 model: Optional model to use for this run, required if `model` was not set when creating the agent.
421 deps: Optional dependencies to use for this run.
422 model_settings: Optional settings to use for this model's request.
423 usage_limits: Optional limits on model request count or token usage.
424 usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
425 infer_name: Whether to try to infer the agent name from the call frame if it's not set.
427 Returns:
428 The result of the run.
429 """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        return asyncio.get_event_loop().run_until_complete(
            self.run(
                user_prompt,
                result_type=result_type,
                message_history=message_history,
                model=model,
                deps=deps,
                model_settings=model_settings,
                usage_limits=usage_limits,
                usage=usage,
                infer_name=False,
            )
        )

    @overload
    def run_stream(
        self,
        user_prompt: str,
        *,
        result_type: None = None,
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
    ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, ResultDataT]]: ...

    @overload
    def run_stream(
        self,
        user_prompt: str,
        *,
        result_type: type[RunResultDataT],
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
    ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunResultDataT]]: ...

    @asynccontextmanager
    async def run_stream(
        self,
        user_prompt: str,
        *,
        result_type: type[RunResultDataT] | None = None,
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
    ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]:
490 """Run the agent with a user prompt in async mode, returning a streamed response.
492 Example:
493 ```python
494 from pydantic_ai import Agent
496 agent = Agent('openai:gpt-4o')
498 async def main():
499 async with agent.run_stream('What is the capital of the UK?') as response:
500 print(await response.get_data())
501 #> London
502 ```
504 Args:
505 result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
506 result validators since result validators would expect an argument that matches the agent's result type.
507 user_prompt: User input to start/continue the conversation.
508 message_history: History of the conversation so far.
509 model: Optional model to use for this run, required if `model` was not set when creating the agent.
510 deps: Optional dependencies to use for this run.
511 model_settings: Optional settings to use for this model's request.
512 usage_limits: Optional limits on model request count or token usage.
513 usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
514 infer_name: Whether to try to infer the agent name from the call frame if it's not set.
516 Returns:
517 The result of the run.
518 """
        if infer_name and self.name is None:
            # f_back because `asynccontextmanager` adds one frame
            if frame := inspect.currentframe():  # pragma: no branch
                self._infer_name(frame.f_back)
        model_used = await self._get_model(model)

        deps = self._get_deps(deps)
        new_message_index = len(message_history) if message_history else 0
        result_schema: _result.ResultSchema[RunResultDataT] | None = self._prepare_result_schema(result_type)

        # Build the graph
        graph = self._build_stream_graph(result_type)

        # Build the initial state
        graph_state = _agent_graph.GraphAgentState(
            message_history=message_history[:] if message_history else [],
            usage=usage or _usage.Usage(),
            retries=0,
            run_step=0,
        )

        # We consider it a user error if a user tries to restrict the result type while having a result validator that
        # may change the result type from the restricted type to something else. Therefore, we consider the following
        # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
        result_validators = cast(list[_result.ResultValidator[AgentDepsT, RunResultDataT]], self._result_validators)

        # TODO: Instead of this, copy the function tools to ensure they don't share current_retry state between agent
        # runs. Requires some changes to `Tool` to make them copyable though.
        for v in self._function_tools.values():
            v.current_retry = 0

        model_settings = merge_model_settings(self.model_settings, model_settings)
        usage_limits = usage_limits or _usage.UsageLimits()

        with _logfire.span(
            '{agent_name} run stream {prompt=}',
            prompt=user_prompt,
            agent=self,
            model_name=model_used.name(),
            agent_name=self.name or 'agent',
        ) as run_span:
            # Build the deps object for the graph
            graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
                user_deps=deps,
                prompt=user_prompt,
                new_message_index=new_message_index,
                model=model_used,
                model_settings=model_settings,
                usage_limits=usage_limits,
                max_result_retries=self._max_result_retries,
                end_strategy=self.end_strategy,
                result_schema=result_schema,
                result_tools=self._result_schema.tool_defs() if self._result_schema else [],
                result_validators=result_validators,
                function_tools=self._function_tools,
                run_span=run_span,
            )

            start_node = _agent_graph.StreamUserPromptNode[AgentDepsT](
                user_prompt=user_prompt,
                system_prompts=self._system_prompts,
                system_prompt_functions=self._system_prompt_functions,
                system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
            )

            # Actually run
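            # Drive the graph by hand: a StreamModelRequestNode is entered via `run_to_result`, which
            # produces an `End` wrapping the `StreamedRunResult` we yield to the caller; every other
            # node is advanced with `graph.next(...)` until that happens.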
            node = start_node
            history: list[HistoryStep[_agent_graph.GraphAgentState, RunResultDataT]] = []
            while True:
                if isinstance(node, _agent_graph.StreamModelRequestNode):
                    node = cast(
                        _agent_graph.StreamModelRequestNode[
                            AgentDepsT, result.StreamedRunResult[AgentDepsT, RunResultDataT]
                        ],
                        node,
                    )
                    async with node.run_to_result(GraphRunContext(graph_state, graph_deps)) as r:
                        if isinstance(r, End):
                            yield r.data
                            break
                assert not isinstance(node, End)  # the previous line should be hit first
                node = await graph.next(
                    node,
                    history,
                    state=graph_state,
                    deps=graph_deps,
                    infer_name=False,
                )

    @contextmanager
    def override(
        self,
        *,
        deps: AgentDepsT | _utils.Unset = _utils.UNSET,
        model: models.Model | models.KnownModelName | _utils.Unset = _utils.UNSET,
    ) -> Iterator[None]:
615 """Context manager to temporarily override agent dependencies and model.
617 This is particularly useful when testing.
618 You can find an example of this [here](../testing-evals.md#overriding-model-via-pytest-fixtures).
620 Args:
621 deps: The dependencies to use instead of the dependencies passed to the agent run.
622 model: The model to use instead of the model passed to the agent run.
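
        Example (a minimal sketch; uses the built-in `TestModel` as the stand-in model):

        ```python
        from pydantic_ai import Agent
        from pydantic_ai.models.test import TestModel

        agent = Agent('openai:gpt-4o')

        with agent.override(model=TestModel()):
            # any run inside this block uses TestModel instead of gpt-4o
            result = agent.run_sync('What is the capital of France?')
        ```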
623 """
        if _utils.is_set(deps):
            override_deps_before = self._override_deps
            self._override_deps = _utils.Some(deps)
        else:
            override_deps_before = _utils.UNSET

        # noinspection PyTypeChecker
        if _utils.is_set(model):
            override_model_before = self._override_model
            # noinspection PyTypeChecker
            self._override_model = _utils.Some(models.infer_model(model))  # pyright: ignore[reportArgumentType]
        else:
            override_model_before = _utils.UNSET

        try:
            yield
        finally:
            if _utils.is_set(override_deps_before):
                self._override_deps = override_deps_before
            if _utils.is_set(override_model_before):
                self._override_model = override_model_before

    @overload
    def system_prompt(
        self, func: Callable[[RunContext[AgentDepsT]], str], /
    ) -> Callable[[RunContext[AgentDepsT]], str]: ...

    @overload
    def system_prompt(
        self, func: Callable[[RunContext[AgentDepsT]], Awaitable[str]], /
    ) -> Callable[[RunContext[AgentDepsT]], Awaitable[str]]: ...

    @overload
    def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...

    @overload
    def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ...

    @overload
    def system_prompt(
        self, /, *, dynamic: bool = False
    ) -> Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]: ...

    def system_prompt(
        self,
        func: _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
        /,
        *,
        dynamic: bool = False,
    ) -> (
        Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]
        | _system_prompt.SystemPromptFunc[AgentDepsT]
    ):
677 """Decorator to register a system prompt function.
679 Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
680 Can decorate a sync or async functions.
682 The decorator can be used either bare (`agent.system_prompt`) or as a function call
683 (`agent.system_prompt(...)`), see the examples below.
685 Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
686 the type of the function, see `tests/typed_agent.py` for tests.
688 Args:
689 func: The function to decorate
690 dynamic: If True, the system prompt will be reevaluated even when `messages_history` is provided,
691 see [`SystemPromptPart.dynamic_ref`][pydantic_ai.messages.SystemPromptPart.dynamic_ref]
693 Example:
694 ```python
695 from pydantic_ai import Agent, RunContext
697 agent = Agent('test', deps_type=str)
699 @agent.system_prompt
700 def simple_system_prompt() -> str:
701 return 'foobar'
703 @agent.system_prompt(dynamic=True)
704 async def async_system_prompt(ctx: RunContext[str]) -> str:
705 return f'{ctx.deps} is the best'
706 ```
707 """
        if func is None:

            def decorator(
                func_: _system_prompt.SystemPromptFunc[AgentDepsT],
            ) -> _system_prompt.SystemPromptFunc[AgentDepsT]:
                runner = _system_prompt.SystemPromptRunner[AgentDepsT](func_, dynamic=dynamic)
                self._system_prompt_functions.append(runner)
                if dynamic:
                    self._system_prompt_dynamic_functions[func_.__qualname__] = runner
                return func_

            return decorator
        else:
            assert not dynamic, "dynamic can't be True in this case"
            self._system_prompt_functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](func, dynamic=dynamic))
            return func

    @overload
    def result_validator(
        self, func: Callable[[RunContext[AgentDepsT], ResultDataT], ResultDataT], /
    ) -> Callable[[RunContext[AgentDepsT], ResultDataT], ResultDataT]: ...

    @overload
    def result_validator(
        self, func: Callable[[RunContext[AgentDepsT], ResultDataT], Awaitable[ResultDataT]], /
    ) -> Callable[[RunContext[AgentDepsT], ResultDataT], Awaitable[ResultDataT]]: ...

    @overload
    def result_validator(
        self, func: Callable[[ResultDataT], ResultDataT], /
    ) -> Callable[[ResultDataT], ResultDataT]: ...

    @overload
    def result_validator(
        self, func: Callable[[ResultDataT], Awaitable[ResultDataT]], /
    ) -> Callable[[ResultDataT], Awaitable[ResultDataT]]: ...

    def result_validator(
        self, func: _result.ResultValidatorFunc[AgentDepsT, ResultDataT], /
    ) -> _result.ResultValidatorFunc[AgentDepsT, ResultDataT]:
748 """Decorator to register a result validator function.
750 Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
751 Can decorate a sync or async functions.
753 Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
754 the type of the function, see `tests/typed_agent.py` for tests.
756 Example:
757 ```python
758 from pydantic_ai import Agent, ModelRetry, RunContext
760 agent = Agent('test', deps_type=str)
762 @agent.result_validator
763 def result_validator_simple(data: str) -> str:
764 if 'wrong' in data:
765 raise ModelRetry('wrong response')
766 return data
768 @agent.result_validator
769 async def result_validator_deps(ctx: RunContext[str], data: str) -> str:
770 if ctx.deps in data:
771 raise ModelRetry('wrong response')
772 return data
774 result = agent.run_sync('foobar', deps='spam')
775 print(result.data)
776 #> success (no tool calls)
777 ```
778 """
        self._result_validators.append(_result.ResultValidator[AgentDepsT, Any](func))
        return func

    @overload
    def tool(self, func: ToolFuncContext[AgentDepsT, ToolParams], /) -> ToolFuncContext[AgentDepsT, ToolParams]: ...

    @overload
    def tool(
        self,
        /,
        *,
        retries: int | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
    ) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...

    def tool(
        self,
        func: ToolFuncContext[AgentDepsT, ToolParams] | None = None,
        /,
        *,
        retries: int | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
    ) -> Any:
806 """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
808 Can decorate a sync or async functions.
810 The docstring is inspected to extract both the tool description and description of each parameter,
811 [learn more](../tools.md#function-tools-and-schema).
813 We can't add overloads for every possible signature of tool, since the return type is a recursive union
814 so the signature of functions decorated with `@agent.tool` is obscured.
816 Example:
817 ```python
818 from pydantic_ai import Agent, RunContext
820 agent = Agent('test', deps_type=int)
822 @agent.tool
823 def foobar(ctx: RunContext[int], x: int) -> int:
824 return ctx.deps + x
826 @agent.tool(retries=2)
827 async def spam(ctx: RunContext[str], y: float) -> float:
828 return ctx.deps + y
830 result = agent.run_sync('foobar', deps=1)
831 print(result.data)
832 #> {"foobar":1,"spam":1.0}
833 ```
835 Args:
836 func: The tool function to register.
837 retries: The number of retries to allow for this tool, defaults to the agent's default retries,
838 which defaults to 1.
839 prepare: custom method to prepare the tool definition for each step, return `None` to omit this
840 tool from a given step. This is useful if you want to customise a tool at call time,
841 or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
842 docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
843 Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
844 require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
845 """
        if func is None:

            def tool_decorator(
                func_: ToolFuncContext[AgentDepsT, ToolParams],
            ) -> ToolFuncContext[AgentDepsT, ToolParams]:
                # noinspection PyTypeChecker
                self._register_function(func_, True, retries, prepare, docstring_format, require_parameter_descriptions)
                return func_

            return tool_decorator
        else:
            # noinspection PyTypeChecker
            self._register_function(func, True, retries, prepare, docstring_format, require_parameter_descriptions)
            return func

    @overload
    def tool_plain(self, func: ToolFuncPlain[ToolParams], /) -> ToolFuncPlain[ToolParams]: ...

    @overload
    def tool_plain(
        self,
        /,
        *,
        retries: int | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
    ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...

    def tool_plain(
        self,
        func: ToolFuncPlain[ToolParams] | None = None,
        /,
        *,
        retries: int | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
    ) -> Any:
885 """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
887 Can decorate a sync or async functions.
889 The docstring is inspected to extract both the tool description and description of each parameter,
890 [learn more](../tools.md#function-tools-and-schema).
892 We can't add overloads for every possible signature of tool, since the return type is a recursive union
893 so the signature of functions decorated with `@agent.tool` is obscured.
895 Example:
896 ```python
897 from pydantic_ai import Agent, RunContext
899 agent = Agent('test')
901 @agent.tool
902 def foobar(ctx: RunContext[int]) -> int:
903 return 123
905 @agent.tool(retries=2)
906 async def spam(ctx: RunContext[str]) -> float:
907 return 3.14
909 result = agent.run_sync('foobar', deps=1)
910 print(result.data)
911 #> {"foobar":123,"spam":3.14}
912 ```
914 Args:
915 func: The tool function to register.
916 retries: The number of retries to allow for this tool, defaults to the agent's default retries,
917 which defaults to 1.
918 prepare: custom method to prepare the tool definition for each step, return `None` to omit this
919 tool from a given step. This is useful if you want to customise a tool at call time,
920 or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
921 docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
922 Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
923 require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
924 """
        if func is None:

            def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
                # noinspection PyTypeChecker
                self._register_function(
                    func_, False, retries, prepare, docstring_format, require_parameter_descriptions
                )
                return func_

            return tool_decorator
        else:
            self._register_function(func, False, retries, prepare, docstring_format, require_parameter_descriptions)
            return func

    def _register_function(
        self,
        func: ToolFuncEither[AgentDepsT, ToolParams],
        takes_ctx: bool,
        retries: int | None,
        prepare: ToolPrepareFunc[AgentDepsT] | None,
        docstring_format: DocstringFormat,
        require_parameter_descriptions: bool,
    ) -> None:
        """Private utility to register a function as a tool."""
        retries_ = retries if retries is not None else self._default_retries
        tool = Tool[AgentDepsT](
            func,
            takes_ctx=takes_ctx,
            max_retries=retries_,
            prepare=prepare,
            docstring_format=docstring_format,
            require_parameter_descriptions=require_parameter_descriptions,
        )
        self._register_tool(tool)

    def _register_tool(self, tool: Tool[AgentDepsT]) -> None:
        """Private utility to register a tool instance."""
        if tool.max_retries is None:
            # noinspection PyTypeChecker
            tool = dataclasses.replace(tool, max_retries=self._default_retries)

        if tool.name in self._function_tools:
            raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}')

        if self._result_schema and tool.name in self._result_schema.tools:
            raise exceptions.UserError(f'Tool name conflicts with result schema name: {tool.name!r}')

        self._function_tools[tool.name] = tool

    async def _get_model(self, model: models.Model | models.KnownModelName | None) -> models.Model:
        """Create a model configured for this agent.

        Args:
            model: Model to use for this run, required if `model` was not set when creating the agent.

        Returns:
            The model used.
        """
        model_: models.Model
        if some_model := self._override_model:
            # we don't want `override()` to cover up errors from the model not being defined, hence this check
            if model is None and self.model is None:
                raise exceptions.UserError(
                    '`model` must be set either when creating the agent or when calling it. '
                    '(Even when `override(model=...)` is customizing the model that will actually be called)'
                )
            model_ = some_model.value
        elif model is not None:
            model_ = models.infer_model(model)
        elif self.model is not None:
            # noinspection PyTypeChecker
            model_ = self.model = models.infer_model(self.model)
        else:
            raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')

        return model_

    def _get_deps(self: Agent[T, ResultDataT], deps: T) -> T:
        """Get deps for a run.

        If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.

        We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope.
        """
        if some_deps := self._override_deps:
            return some_deps.value
        else:
            return deps

    def _infer_name(self, function_frame: FrameType | None) -> None:
        """Infer the agent name from the call frame.

        Usage should be `self._infer_name(inspect.currentframe())`.
        """
        assert self.name is None, 'Name already set'
        if function_frame is not None:  # pragma: no branch
            if parent_frame := function_frame.f_back:  # pragma: no branch
                for name, item in parent_frame.f_locals.items():
                    if item is self:
                        self.name = name
                        return
                if parent_frame.f_locals != parent_frame.f_globals:
                    # if we couldn't find the agent in locals and globals are a different dict, try globals
                    for name, item in parent_frame.f_globals.items():
                        if item is self:
                            self.name = name
                            return

    @property
    @deprecated(
        'The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.', category=None
    )
    def last_run_messages(self) -> list[_messages.ModelMessage]:
        raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')

    def _build_graph(
        self, result_type: type[RunResultDataT] | None
    ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], Any]:
        return _agent_graph.build_agent_graph(self.name, self._deps_type, result_type or self.result_type)

    def _build_stream_graph(
        self, result_type: type[RunResultDataT] | None
    ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], Any]:
        return _agent_graph.build_agent_stream_graph(self.name, self._deps_type, result_type or self.result_type)

    def _prepare_result_schema(
        self, result_type: type[RunResultDataT] | None
    ) -> _result.ResultSchema[RunResultDataT] | None:
        if result_type is not None:
            if self._result_validators:
                raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators')
            return _result.ResultSchema[result_type].build(
                result_type, self._result_tool_name, self._result_tool_description
            )
        else:
            return self._result_schema  # pyright: ignore[reportReturnType]