Coverage for pydantic_ai_slim/pydantic_ai/agent.py: 97.29% (362 statements), coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

from __future__ import annotations as _annotations

import dataclasses
import inspect
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
from copy import deepcopy
from types import FrameType
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload

from opentelemetry.trace import NoOpTracer, use_span
from pydantic.json_schema import GenerateJsonSchema
from typing_extensions import TypeGuard, TypeVar, deprecated

from pydantic_graph import End, Graph, GraphRun, GraphRunContext
from pydantic_graph._utils import get_event_loop

from . import (
    _agent_graph,
    _result,
    _system_prompt,
    _utils,
    exceptions,
    messages as _messages,
    models,
    result,
    usage as _usage,
)
from .models.instrumented import InstrumentationSettings, InstrumentedModel
from .result import FinalResult, ResultDataT, StreamedRunResult
from .settings import ModelSettings, merge_model_settings
from .tools import (
    AgentDepsT,
    DocstringFormat,
    GenerateToolJsonSchema,
    RunContext,
    Tool,
    ToolFuncContext,
    ToolFuncEither,
    ToolFuncPlain,
    ToolParams,
    ToolPrepareFunc,
)

# Re-exporting like this improves auto-import behavior in PyCharm
capture_run_messages = _agent_graph.capture_run_messages
EndStrategy = _agent_graph.EndStrategy
CallToolsNode = _agent_graph.CallToolsNode
ModelRequestNode = _agent_graph.ModelRequestNode
UserPromptNode = _agent_graph.UserPromptNode

if TYPE_CHECKING:
    from pydantic_ai.mcp import MCPServer

__all__ = (
    'Agent',
    'AgentRun',
    'AgentRunResult',
    'capture_run_messages',
    'EndStrategy',
    'CallToolsNode',
    'ModelRequestNode',
    'UserPromptNode',
    'InstrumentationSettings',
)


T = TypeVar('T')
S = TypeVar('S')
NoneType = type(None)
RunResultDataT = TypeVar('RunResultDataT')
"""Type variable for the result data of a run where `result_type` was customized on the run call."""


@final
@dataclasses.dataclass(init=False)
class Agent(Generic[AgentDepsT, ResultDataT]):
    """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.

    Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT]
    and the result data type they return, [`ResultDataT`][pydantic_ai.result.ResultDataT].

    By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.

    Minimal usage example:

    ```python
    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')
    result = agent.run_sync('What is the capital of France?')
    print(result.data)
    #> Paris
    ```
    """

    # we use dataclass fields in order to conveniently know what attributes are available
    model: models.Model | models.KnownModelName | None
    """The default model configured for this agent."""

    name: str | None
    """The name of the agent, used for logging.

    If `None`, we try to infer the agent name from the call frame when the agent is first run.
    """
    end_strategy: EndStrategy
    """Strategy for handling tool calls when a final result is found."""

    model_settings: ModelSettings | None
    """Optional model request settings to use for this agent's runs, by default.

    Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will
    be merged with this value, with the runtime argument taking priority.
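
    A sketch of the merge behavior (`temperature` is just an illustrative setting):

    ```python
    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o', model_settings={'temperature': 0.1})
    # the run-level value takes priority, so this request uses temperature=0.9:
    agent.run_sync('What is the capital of Italy?', model_settings={'temperature': 0.9})
    ```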
    """

    result_type: type[ResultDataT] = dataclasses.field(repr=False)
    """
    The type of the result data, used to validate the result data, defaults to `str`.
    """

    instrument: InstrumentationSettings | bool | None
    """Options to automatically instrument with OpenTelemetry."""

    _instrument_default: ClassVar[InstrumentationSettings | bool] = False

    _deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
    _result_tool_name: str = dataclasses.field(repr=False)
    _result_tool_description: str | None = dataclasses.field(repr=False)
    _result_schema: _result.ResultSchema[ResultDataT] | None = dataclasses.field(repr=False)
    _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] = dataclasses.field(repr=False)
    _system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
    _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
    _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(
        repr=False
    )
    _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False)
    _mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
    _default_retries: int = dataclasses.field(repr=False)
    _max_result_retries: int = dataclasses.field(repr=False)
    _override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
    _override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)

    def __init__(
        self,
        model: models.Model | models.KnownModelName | None = None,
        *,
        result_type: type[ResultDataT] = str,
        system_prompt: str | Sequence[str] = (),
        deps_type: type[AgentDepsT] = NoneType,
        name: str | None = None,
        model_settings: ModelSettings | None = None,
        retries: int = 1,
        result_tool_name: str = 'final_result',
        result_tool_description: str | None = None,
        result_retries: int | None = None,
        tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
        mcp_servers: Sequence[MCPServer] = (),
        defer_model_check: bool = False,
        end_strategy: EndStrategy = 'early',
        instrument: InstrumentationSettings | bool | None = None,
    ):
        """Create an agent.

        Args:
            model: The default model to use for this agent, if not provided,
                you must provide the model when calling it.
            result_type: The type of the result data, used to validate the result data, defaults to `str`.
            system_prompt: Static system prompts to use for this agent, you can also register system
                prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt].
            deps_type: The type used for dependency injection, this parameter exists solely to allow you to fully
                parameterize the agent, and therefore get the best out of static type checking.
                If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright
                or add a type hint `: Agent[None, <return type>]`.
            name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame
                when the agent is first run.
            model_settings: Optional model request settings to use for this agent's runs, by default.
            retries: The default number of retries to allow before raising an error.
            result_tool_name: The name of the tool to use for the final result.
            result_tool_description: The description of the final result tool.
            result_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
            tools: Tools to register with the agent, you can also register tools via the decorators
                [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
            mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer]
                for each server you want the agent to connect to.
            defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
                it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
                which checks for the necessary environment variables. Set this to `False`
                to defer the evaluation until the first run. Useful if you want to
                [override the model][pydantic_ai.Agent.override] for testing.
            end_strategy: Strategy for handling tool calls that are requested alongside a final result.
                See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information.
            instrument: Set to True to automatically instrument with OpenTelemetry,
                which will use Logfire if it's configured.
                Set to an instance of [`InstrumentationSettings`][pydantic_ai.agent.InstrumentationSettings] to customize.
                If this isn't set, then the last value set by
                [`Agent.instrument_all()`][pydantic_ai.Agent.instrument_all]
                will be used, which defaults to False.
                See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
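
        Example (a sketch of typical construction; the `deps_type`, system prompt and tool below
        are illustrative only):

        ```python
        from pydantic_ai import Agent, RunContext

        agent = Agent(
            'openai:gpt-4o',
            deps_type=str,
            system_prompt='Be concise.',
            retries=2,
        )

        @agent.tool
        def greet(ctx: RunContext[str]) -> str:
            return f'hello {ctx.deps}'
        ```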
        """
        if model is None or defer_model_check:
            self.model = model
        else:
            self.model = models.infer_model(model)

        self.end_strategy = end_strategy
        self.name = name
        self.model_settings = model_settings
        self.result_type = result_type
        self.instrument = instrument

        self._deps_type = deps_type

        self._result_tool_name = result_tool_name
        self._result_tool_description = result_tool_description
        self._result_schema: _result.ResultSchema[ResultDataT] | None = _result.ResultSchema[result_type].build(
            result_type, result_tool_name, result_tool_description
        )
        self._result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] = []

        self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
        self._system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = []
        self._system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = {}

        self._function_tools: dict[str, Tool[AgentDepsT]] = {}

        self._default_retries = retries
        self._max_result_retries = result_retries if result_retries is not None else retries
        self._mcp_servers = mcp_servers
        for tool in tools:
            if isinstance(tool, Tool):
                self._register_tool(tool)
            else:
                self._register_tool(Tool(tool))

    @staticmethod
    def instrument_all(instrument: InstrumentationSettings | bool = True) -> None:
        """Set the instrumentation options for all agents where `instrument` is not set.
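
        A sketch of typical use (assumes Logfire or another OpenTelemetry exporter is configured):

        ```python
        from pydantic_ai import Agent

        Agent.instrument_all()  # instrument every agent that doesn't set `instrument` itself
        ```
        """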
        Agent._instrument_default = instrument

    @overload
    async def run(
        self,
        user_prompt: str | Sequence[_messages.UserContent],
        *,
        result_type: None = None,
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
    ) -> AgentRunResult[ResultDataT]: ...

    @overload
    async def run(
        self,
        user_prompt: str | Sequence[_messages.UserContent],
        *,
        result_type: type[RunResultDataT],
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
    ) -> AgentRunResult[RunResultDataT]: ...

    async def run(
        self,
        user_prompt: str | Sequence[_messages.UserContent],
        *,
        result_type: type[RunResultDataT] | None = None,
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
    ) -> AgentRunResult[Any]:
        """Run the agent with a user prompt in async mode.

        This method builds an internal agent graph (using system prompts, tools and result schemas) and then
        runs the graph to completion. The result of the run is returned.

        Example:
        ```python
        from pydantic_ai import Agent

        agent = Agent('openai:gpt-4o')

        async def main():
            agent_run = await agent.run('What is the capital of France?')
            print(agent_run.data)
            #> Paris
        ```

        Args:
            user_prompt: User input to start/continue the conversation.
            result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
                result validators since result validators would expect an argument that matches the agent's result type.
            message_history: History of the conversation so far.
            model: Optional model to use for this run, required if `model` was not set when creating the agent.
            deps: Optional dependencies to use for this run.
            model_settings: Optional settings to use for this model's request.
            usage_limits: Optional limits on model request count or token usage.
            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
            infer_name: Whether to try to infer the agent name from the call frame if it's not set.

        Returns:
            The result of the run.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        async with self.iter(
            user_prompt=user_prompt,
            result_type=result_type,
            message_history=message_history,
            model=model,
            deps=deps,
            model_settings=model_settings,
            usage_limits=usage_limits,
            usage=usage,
        ) as agent_run:
            async for _ in agent_run:
                pass

        assert (final_result := agent_run.result) is not None, 'The graph run did not finish properly'
        return final_result

    @asynccontextmanager
    async def iter(
        self,
        user_prompt: str | Sequence[_messages.UserContent],
        *,
        result_type: type[RunResultDataT] | None = None,
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
    ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]:
        """A context manager which can be used to iterate over the agent graph's nodes as they are executed.

        This method builds an internal agent graph (using system prompts, tools and result schemas) and then returns an
        `AgentRun` object. The `AgentRun` can be used to async-iterate over the nodes of the graph as they are
        executed. This is the API to use if you want to consume the outputs coming from each LLM model response, or the
        stream of events coming from the execution of tools.

        The `AgentRun` also provides methods to access the full message history, new messages, and usage statistics,
        and the final result of the run once it has completed.

        For more details, see the documentation of `AgentRun`.

        Example:
        ```python
        from pydantic_ai import Agent

        agent = Agent('openai:gpt-4o')

        async def main():
            nodes = []
            async with agent.iter('What is the capital of France?') as agent_run:
                async for node in agent_run:
                    nodes.append(node)
            print(nodes)
            '''
            [
                ModelRequestNode(
                    request=ModelRequest(
                        parts=[
                            UserPromptPart(
                                content='What is the capital of France?',
                                timestamp=datetime.datetime(...),
                                part_kind='user-prompt',
                            )
                        ],
                        kind='request',
                    )
                ),
                CallToolsNode(
                    model_response=ModelResponse(
                        parts=[TextPart(content='Paris', part_kind='text')],
                        model_name='gpt-4o',
                        timestamp=datetime.datetime(...),
                        kind='response',
                    )
                ),
                End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
            ]
            '''
            print(agent_run.result.data)
            #> Paris
        ```

        Args:
            user_prompt: User input to start/continue the conversation.
            result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
                result validators since result validators would expect an argument that matches the agent's result type.
            message_history: History of the conversation so far.
            model: Optional model to use for this run, required if `model` was not set when creating the agent.
            deps: Optional dependencies to use for this run.
            model_settings: Optional settings to use for this model's request.
            usage_limits: Optional limits on model request count or token usage.
            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
            infer_name: Whether to try to infer the agent name from the call frame if it's not set.

        Returns:
            An `AgentRun` that can be async-iterated to drive the graph to completion.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        model_used = self._get_model(model)
        del model

        deps = self._get_deps(deps)
        new_message_index = len(message_history) if message_history else 0
        result_schema: _result.ResultSchema[RunResultDataT] | None = self._prepare_result_schema(result_type)

        # Build the graph
        graph = self._build_graph(result_type)

        # Build the initial state
        state = _agent_graph.GraphAgentState(
            message_history=message_history[:] if message_history else [],
            usage=usage or _usage.Usage(),
            retries=0,
            run_step=0,
        )

        # We consider it a user error if a user tries to restrict the result type while having a result validator that
        # may change the result type from the restricted type to something else. Therefore, we consider the following
        # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
        result_validators = cast(list[_result.ResultValidator[AgentDepsT, RunResultDataT]], self._result_validators)

        # TODO: Instead of this, copy the function tools to ensure they don't share current_retry state between agent
        # runs. Requires some changes to `Tool` to make them copyable though.
        for v in self._function_tools.values():
            v.current_retry = 0

        model_settings = merge_model_settings(self.model_settings, model_settings)
        usage_limits = usage_limits or _usage.UsageLimits()

        if isinstance(model_used, InstrumentedModel):
            tracer = model_used.settings.tracer
        else:
            tracer = NoOpTracer()
        agent_name = self.name or 'agent'
        run_span = tracer.start_span(
            'agent run',
            attributes={
                'model_name': model_used.model_name if model_used else 'no-model',
                'agent_name': agent_name,
                'logfire.msg': f'{agent_name} run',
            },
        )

        graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
            user_deps=deps,
            prompt=user_prompt,
            new_message_index=new_message_index,
            model=model_used,
            model_settings=model_settings,
            usage_limits=usage_limits,
            max_result_retries=self._max_result_retries,
            end_strategy=self.end_strategy,
            result_schema=result_schema,
            result_tools=self._result_schema.tool_defs() if self._result_schema else [],
            result_validators=result_validators,
            function_tools=self._function_tools,
            mcp_servers=self._mcp_servers,
            run_span=run_span,
            tracer=tracer,
        )
        start_node = _agent_graph.UserPromptNode[AgentDepsT](
            user_prompt=user_prompt,
            system_prompts=self._system_prompts,
            system_prompt_functions=self._system_prompt_functions,
            system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
        )

        async with graph.iter(
            start_node,
            state=state,
            deps=graph_deps,
            span=use_span(run_span, end_on_exit=True),
            infer_name=False,
        ) as graph_run:
            yield AgentRun(graph_run)

    @overload
    def run_sync(
        self,
        user_prompt: str | Sequence[_messages.UserContent],
        *,
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
    ) -> AgentRunResult[ResultDataT]: ...

    @overload
    def run_sync(
        self,
        user_prompt: str | Sequence[_messages.UserContent],
        *,
        result_type: type[RunResultDataT] | None,
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
    ) -> AgentRunResult[RunResultDataT]: ...

    def run_sync(
        self,
        user_prompt: str | Sequence[_messages.UserContent],
        *,
        result_type: type[RunResultDataT] | None = None,
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
    ) -> AgentRunResult[Any]:
        """Synchronously run the agent with a user prompt.

        This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`.
        You therefore can't use this method inside async code or if there's an active event loop.

        Example:
        ```python
        from pydantic_ai import Agent

        agent = Agent('openai:gpt-4o')

        result_sync = agent.run_sync('What is the capital of Italy?')
        print(result_sync.data)
        #> Rome
        ```

        Args:
            user_prompt: User input to start/continue the conversation.
            result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
                result validators since result validators would expect an argument that matches the agent's result type.
            message_history: History of the conversation so far.
            model: Optional model to use for this run, required if `model` was not set when creating the agent.
            deps: Optional dependencies to use for this run.
            model_settings: Optional settings to use for this model's request.
            usage_limits: Optional limits on model request count or token usage.
            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
            infer_name: Whether to try to infer the agent name from the call frame if it's not set.

        Returns:
            The result of the run.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        return get_event_loop().run_until_complete(
            self.run(
                user_prompt,
                result_type=result_type,
                message_history=message_history,
                model=model,
                deps=deps,
                model_settings=model_settings,
                usage_limits=usage_limits,
                usage=usage,
                infer_name=False,
            )
        )

    @overload
    def run_stream(
        self,
        user_prompt: str | Sequence[_messages.UserContent],
        *,
        result_type: None = None,
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
    ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, ResultDataT]]: ...

    @overload
    def run_stream(
        self,
        user_prompt: str | Sequence[_messages.UserContent],
        *,
        result_type: type[RunResultDataT],
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
    ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunResultDataT]]: ...

    @asynccontextmanager
    async def run_stream(  # noqa C901
        self,
        user_prompt: str | Sequence[_messages.UserContent],
        *,
        result_type: type[RunResultDataT] | None = None,
        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: _usage.UsageLimits | None = None,
        usage: _usage.Usage | None = None,
        infer_name: bool = True,
    ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]:
        """Run the agent with a user prompt in async mode, returning a streamed response.

        Example:
        ```python
        from pydantic_ai import Agent

        agent = Agent('openai:gpt-4o')

        async def main():
            async with agent.run_stream('What is the capital of the UK?') as response:
                print(await response.get_data())
                #> London
        ```

        Args:
            user_prompt: User input to start/continue the conversation.
            result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
                result validators since result validators would expect an argument that matches the agent's result type.
            message_history: History of the conversation so far.
            model: Optional model to use for this run, required if `model` was not set when creating the agent.
            deps: Optional dependencies to use for this run.
            model_settings: Optional settings to use for this model's request.
            usage_limits: Optional limits on model request count or token usage.
            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
            infer_name: Whether to try to infer the agent name from the call frame if it's not set.

        Returns:
            The result of the run.
        """
        # TODO: We need to deprecate this now that we have the `iter` method.
        # Before that, though, we should add an event for when we reach the final result of the stream.
        if infer_name and self.name is None:
            # f_back because `asynccontextmanager` adds one frame
            if frame := inspect.currentframe():  # pragma: no branch
                self._infer_name(frame.f_back)

        yielded = False
        async with self.iter(
            user_prompt,
            result_type=result_type,
            message_history=message_history,
            model=model,
            deps=deps,
            model_settings=model_settings,
            usage_limits=usage_limits,
            usage=usage,
            infer_name=False,
        ) as agent_run:
            first_node = agent_run.next_node  # start with the first node
            assert isinstance(first_node, _agent_graph.UserPromptNode)  # the first node should be a user prompt node
            node = first_node
            while True:
                if self.is_model_request_node(node):
                    graph_ctx = agent_run.ctx
                    async with node._stream(graph_ctx) as streamed_response:  # pyright: ignore[reportPrivateUsage]

                        async def stream_to_final(
                            s: models.StreamedResponse,
                        ) -> FinalResult[models.StreamedResponse] | None:
                            result_schema = graph_ctx.deps.result_schema
                            async for maybe_part_event in streamed_response:
                                if isinstance(maybe_part_event, _messages.PartStartEvent):
                                    new_part = maybe_part_event.part
                                    if isinstance(new_part, _messages.TextPart):
                                        if _agent_graph.allow_text_result(result_schema):
                                            return FinalResult(s, None, None)
                                    elif isinstance(new_part, _messages.ToolCallPart) and result_schema:
                                        for call, _ in result_schema.find_tool([new_part]):
                                            return FinalResult(s, call.tool_name, call.tool_call_id)
                            return None

                        final_result_details = await stream_to_final(streamed_response)
                        if final_result_details is not None:
                            if yielded:
                                raise exceptions.AgentRunError('Agent run produced final results')
                            yielded = True

                            messages = graph_ctx.state.message_history.copy()

                            async def on_complete() -> None:
                                """Called when the stream has completed.

                                The model response will have been added to messages by now
                                by `StreamedRunResult._marked_completed`.
                                """
                                last_message = messages[-1]
                                assert isinstance(last_message, _messages.ModelResponse)
                                tool_calls = [
                                    part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
                                ]

                                parts: list[_messages.ModelRequestPart] = []
                                async for _event in _agent_graph.process_function_tools(
                                    tool_calls,
                                    final_result_details.tool_name,
                                    final_result_details.tool_call_id,
                                    graph_ctx,
                                    parts,
                                ):
                                    pass
                                # TODO: Should we do something here related to the retry count?
                                # Maybe we should move the incrementing of the retry count to where we actually make a request?
                                # if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
                                #     ctx.state.increment_retries(ctx.deps.max_result_retries)
                                if parts:
                                    messages.append(_messages.ModelRequest(parts))

                            yield StreamedRunResult(
                                messages,
                                graph_ctx.deps.new_message_index,
                                graph_ctx.deps.usage_limits,
                                streamed_response,
                                graph_ctx.deps.result_schema,
                                _agent_graph.build_run_context(graph_ctx),
                                graph_ctx.deps.result_validators,
                                final_result_details.tool_name,
                                on_complete,
                            )
                            break
                next_node = await agent_run.next(node)
                if not isinstance(next_node, _agent_graph.AgentNode):
                    raise exceptions.AgentRunError('Should have produced a StreamedRunResult before getting here')
                node = cast(_agent_graph.AgentNode[Any, Any], next_node)

        if not yielded:
            raise exceptions.AgentRunError('Agent run finished without producing a final result')

    @contextmanager
    def override(
        self,
        *,
        deps: AgentDepsT | _utils.Unset = _utils.UNSET,
        model: models.Model | models.KnownModelName | _utils.Unset = _utils.UNSET,
    ) -> Iterator[None]:
        """Context manager to temporarily override agent dependencies and model.

        This is particularly useful when testing.
        You can find an example of this [here](../testing-evals.md#overriding-model-via-pytest-fixtures).

        Args:
            deps: The dependencies to use instead of the dependencies passed to the agent run.
            model: The model to use instead of the model passed to the agent run.
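
        A sketch of overriding the model in a test (uses the built-in `TestModel`, so no real
        API calls are made):

        ```python
        from pydantic_ai import Agent
        from pydantic_ai.models.test import TestModel

        agent = Agent('openai:gpt-4o')

        with agent.override(model=TestModel()):
            result = agent.run_sync('What is the capital of France?')
        ```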
        """
        if _utils.is_set(deps):
            override_deps_before = self._override_deps
            self._override_deps = _utils.Some(deps)
        else:
            override_deps_before = _utils.UNSET

        # noinspection PyTypeChecker
        if _utils.is_set(model):
            override_model_before = self._override_model
            # noinspection PyTypeChecker
            self._override_model = _utils.Some(models.infer_model(model))  # pyright: ignore[reportArgumentType]
        else:
            override_model_before = _utils.UNSET

        try:
            yield
        finally:
            if _utils.is_set(override_deps_before):
                self._override_deps = override_deps_before
            if _utils.is_set(override_model_before):
                self._override_model = override_model_before

    @overload
    def system_prompt(
        self, func: Callable[[RunContext[AgentDepsT]], str], /
    ) -> Callable[[RunContext[AgentDepsT]], str]: ...

    @overload
    def system_prompt(
        self, func: Callable[[RunContext[AgentDepsT]], Awaitable[str]], /
    ) -> Callable[[RunContext[AgentDepsT]], Awaitable[str]]: ...

    @overload
    def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...

    @overload
    def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ...

    @overload
    def system_prompt(
        self, /, *, dynamic: bool = False
    ) -> Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]: ...

    def system_prompt(
        self,
        func: _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
        /,
        *,
        dynamic: bool = False,
    ) -> (
        Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]
        | _system_prompt.SystemPromptFunc[AgentDepsT]
    ):
        """Decorator to register a system prompt function.

        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
        Can decorate sync or async functions.

        The decorator can be used either bare (`agent.system_prompt`) or as a function call
        (`agent.system_prompt(...)`), see the examples below.

        Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
        the type of the function, see `tests/typed_agent.py` for tests.

        Args:
            func: The function to decorate
            dynamic: If True, the system prompt will be reevaluated even when `message_history` is provided,
                see [`SystemPromptPart.dynamic_ref`][pydantic_ai.messages.SystemPromptPart.dynamic_ref]

        Example:
        ```python
        from pydantic_ai import Agent, RunContext

        agent = Agent('test', deps_type=str)

        @agent.system_prompt
        def simple_system_prompt() -> str:
            return 'foobar'

        @agent.system_prompt(dynamic=True)
        async def async_system_prompt(ctx: RunContext[str]) -> str:
            return f'{ctx.deps} is the best'
        ```
        """
        if func is None:

            def decorator(
                func_: _system_prompt.SystemPromptFunc[AgentDepsT],
            ) -> _system_prompt.SystemPromptFunc[AgentDepsT]:
                runner = _system_prompt.SystemPromptRunner[AgentDepsT](func_, dynamic=dynamic)
                self._system_prompt_functions.append(runner)
                if dynamic:
                    self._system_prompt_dynamic_functions[func_.__qualname__] = runner
                return func_

            return decorator
        else:
            assert not dynamic, "dynamic can't be True in this case"
            self._system_prompt_functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](func, dynamic=dynamic))
            return func

    @overload
    def result_validator(
        self, func: Callable[[RunContext[AgentDepsT], ResultDataT], ResultDataT], /
    ) -> Callable[[RunContext[AgentDepsT], ResultDataT], ResultDataT]: ...

    @overload
    def result_validator(
        self, func: Callable[[RunContext[AgentDepsT], ResultDataT], Awaitable[ResultDataT]], /
    ) -> Callable[[RunContext[AgentDepsT], ResultDataT], Awaitable[ResultDataT]]: ...

    @overload
    def result_validator(
        self, func: Callable[[ResultDataT], ResultDataT], /
    ) -> Callable[[ResultDataT], ResultDataT]: ...

    @overload
    def result_validator(
        self, func: Callable[[ResultDataT], Awaitable[ResultDataT]], /
    ) -> Callable[[ResultDataT], Awaitable[ResultDataT]]: ...

    def result_validator(
        self, func: _result.ResultValidatorFunc[AgentDepsT, ResultDataT], /
    ) -> _result.ResultValidatorFunc[AgentDepsT, ResultDataT]:
        """Decorator to register a result validator function.

        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
        Can decorate sync or async functions.

        Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
        the type of the function, see `tests/typed_agent.py` for tests.

        Example:
        ```python
        from pydantic_ai import Agent, ModelRetry, RunContext

        agent = Agent('test', deps_type=str)

        @agent.result_validator
        def result_validator_simple(data: str) -> str:
            if 'wrong' in data:
                raise ModelRetry('wrong response')
            return data

        @agent.result_validator
        async def result_validator_deps(ctx: RunContext[str], data: str) -> str:
            if ctx.deps in data:
                raise ModelRetry('wrong response')
            return data

        result = agent.run_sync('foobar', deps='spam')
        print(result.data)
        #> success (no tool calls)
        ```
        """
        self._result_validators.append(_result.ResultValidator[AgentDepsT, Any](func))
        return func

    @overload
    def tool(self, func: ToolFuncContext[AgentDepsT, ToolParams], /) -> ToolFuncContext[AgentDepsT, ToolParams]: ...

    @overload
    def tool(
        self,
        /,
        *,
        name: str | None = None,
        retries: int | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
        schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
    ) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...

    def tool(
        self,
        func: ToolFuncContext[AgentDepsT, ToolParams] | None = None,
        /,
        *,
        name: str | None = None,
        retries: int | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
        schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
    ) -> Any:
        """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.

        Can decorate sync or async functions.

        The docstring is inspected to extract both the tool description and description of each parameter,
        [learn more](../tools.md#function-tools-and-schema).

        We can't add overloads for every possible signature of tool, since the return type is a recursive union
        so the signature of functions decorated with `@agent.tool` is obscured.

        Example:
        ```python
        from pydantic_ai import Agent, RunContext

        agent = Agent('test', deps_type=int)

        @agent.tool
        def foobar(ctx: RunContext[int], x: int) -> int:
            return ctx.deps + x

        @agent.tool(retries=2)
        async def spam(ctx: RunContext[int], y: float) -> float:
            return ctx.deps + y

        result = agent.run_sync('foobar', deps=1)
        print(result.data)
        #> {"foobar":1,"spam":1.0}
        ```

        Args:
            func: The tool function to register.
            name: The name of the tool, defaults to the function name.
            retries: The number of retries to allow for this tool, defaults to the agent's default retries,
                which defaults to 1.
            prepare: custom method to prepare the tool definition for each step, return `None` to omit this
                tool from a given step. This is useful if you want to customise a tool at call time,
                or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
            require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
            schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
        """
        if func is None:

            def tool_decorator(
                func_: ToolFuncContext[AgentDepsT, ToolParams],
            ) -> ToolFuncContext[AgentDepsT, ToolParams]:
                # noinspection PyTypeChecker
                self._register_function(
                    func_,
                    True,
                    name,
                    retries,
                    prepare,
                    docstring_format,
                    require_parameter_descriptions,
                    schema_generator,
                )
                return func_

            return tool_decorator
        else:
            # noinspection PyTypeChecker
            self._register_function(
                func, True, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator
            )
            return func

    @overload
    def tool_plain(self, func: ToolFuncPlain[ToolParams], /) -> ToolFuncPlain[ToolParams]: ...

    @overload
    def tool_plain(
        self,
        /,
        *,
        name: str | None = None,
        retries: int | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
        schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
    ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...

    def tool_plain(
        self,
        func: ToolFuncPlain[ToolParams] | None = None,
        /,
        *,
        name: str | None = None,
        retries: int | None = None,
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
        schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
    ) -> Any:
        """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.

        Can decorate sync or async functions.

        The docstring is inspected to extract both the tool description and description of each parameter,
        [learn more](../tools.md#function-tools-and-schema).

        We can't add overloads for every possible signature of tool, since the return type is a recursive union
        so the signature of functions decorated with `@agent.tool_plain` is obscured.

        Example:
        ```python
        from pydantic_ai import Agent

        agent = Agent('test')

        @agent.tool_plain
        def foobar() -> int:
            return 123

        @agent.tool_plain(retries=2)
        async def spam() -> float:
            return 3.14

        result = agent.run_sync('foobar')
        print(result.data)
        #> {"foobar":123,"spam":3.14}
        ```

        Args:
            func: The tool function to register.
            name: The name of the tool, defaults to the function name.
            retries: The number of retries to allow for this tool, defaults to the agent's default retries,
                which defaults to 1.
            prepare: custom method to prepare the tool definition for each step, return `None` to omit this
                tool from a given step. This is useful if you want to customise a tool at call time,
                or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
            require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
            schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
        """
        if func is None:

            def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
                # noinspection PyTypeChecker
                self._register_function(
                    func_,
                    False,
                    name,
                    retries,
                    prepare,
                    docstring_format,
                    require_parameter_descriptions,
                    schema_generator,
                )
                return func_

            return tool_decorator
        else:
            self._register_function(
                func, False, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator
            )
            return func

    def _register_function(
        self,
        func: ToolFuncEither[AgentDepsT, ToolParams],
        takes_ctx: bool,
        name: str | None,
        retries: int | None,
        prepare: ToolPrepareFunc[AgentDepsT] | None,
        docstring_format: DocstringFormat,
        require_parameter_descriptions: bool,
        schema_generator: type[GenerateJsonSchema],
    ) -> None:
        """Private utility to register a function as a tool."""
        retries_ = retries if retries is not None else self._default_retries
        tool = Tool[AgentDepsT](
            func,
            takes_ctx=takes_ctx,
            name=name,
            max_retries=retries_,
            prepare=prepare,
            docstring_format=docstring_format,
            require_parameter_descriptions=require_parameter_descriptions,
            schema_generator=schema_generator,
        )
        self._register_tool(tool)

    def _register_tool(self, tool: Tool[AgentDepsT]) -> None:
        """Private utility to register a tool instance."""
        if tool.max_retries is None:
            # noinspection PyTypeChecker
            tool = dataclasses.replace(tool, max_retries=self._default_retries)

        if tool.name in self._function_tools:
            raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}')

        if self._result_schema and tool.name in self._result_schema.tools:
            raise exceptions.UserError(f'Tool name conflicts with result schema name: {tool.name!r}')

        self._function_tools[tool.name] = tool

    def _get_model(self, model: models.Model | models.KnownModelName | None) -> models.Model:
        """Create a model configured for this agent.

        Args:
            model: model to use for this run, required if `model` was not set when creating the agent.

        Returns:
            The model used
        """
        model_: models.Model
        if some_model := self._override_model:
            # we don't want `override()` to cover up errors from the model not being defined, hence this check
            if model is None and self.model is None:
                raise exceptions.UserError(
                    '`model` must be set either when creating the agent or when calling it. '
                    '(Even when `override(model=...)` is customizing the model that will actually be called)'
                )
            model_ = some_model.value
        elif model is not None:
            model_ = models.infer_model(model)
        elif self.model is not None:
            # noinspection PyTypeChecker
            model_ = self.model = models.infer_model(self.model)
        else:
            raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')

        instrument = self.instrument
        if instrument is None:
            instrument = self._instrument_default

        if instrument and not isinstance(model_, InstrumentedModel):
            if instrument is True:
                instrument = InstrumentationSettings()

            model_ = InstrumentedModel(model_, instrument)

        return model_

    def _get_deps(self: Agent[T, ResultDataT], deps: T) -> T:
        """Get deps for a run.

        If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.

        We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope.
        """
        if some_deps := self._override_deps:
            return some_deps.value
        else:
            return deps

    def _infer_name(self, function_frame: FrameType | None) -> None:
        """Infer the agent name from the call frame.

        Usage should be `self._infer_name(inspect.currentframe())`.
        """
        assert self.name is None, 'Name already set'
        if function_frame is not None:  # pragma: no branch
            if parent_frame := function_frame.f_back:  # pragma: no branch
                for name, item in parent_frame.f_locals.items():
                    if item is self:
                        self.name = name
                        return
                if parent_frame.f_locals != parent_frame.f_globals:
                    # if we couldn't find the agent in locals and globals are a different dict, try globals
                    for name, item in parent_frame.f_globals.items():
                        if item is self:
                            self.name = name
                            return

    @property
    @deprecated(
        'The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.', category=None
    )
    def last_run_messages(self) -> list[_messages.ModelMessage]:
        raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')

    def _build_graph(
        self, result_type: type[RunResultDataT] | None
    ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]]:
        return _agent_graph.build_agent_graph(self.name, self._deps_type, result_type or self.result_type)

    def _prepare_result_schema(
        self, result_type: type[RunResultDataT] | None
    ) -> _result.ResultSchema[RunResultDataT] | None:
        if result_type is not None:
            if self._result_validators:
                raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators')
            return _result.ResultSchema[result_type].build(
                result_type, self._result_tool_name, self._result_tool_description
            )
        else:
            return self._result_schema  # pyright: ignore[reportReturnType]

    @staticmethod
    def is_model_request_node(
        node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
    ) -> TypeGuard[_agent_graph.ModelRequestNode[T, S]]:
        """Check if the node is a `ModelRequestNode`, narrowing the type if it is.

        This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
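
        A sketch of using this for type narrowing while iterating a run:

        ```python
        from pydantic_ai import Agent

        agent = Agent('openai:gpt-4o')

        async def main():
            async with agent.iter('What is the capital of France?') as agent_run:
                async for node in agent_run:
                    if Agent.is_model_request_node(node):
                        # type checkers now see `node` as a `ModelRequestNode`
                        print(node.request)
        ```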
        """
        return isinstance(node, _agent_graph.ModelRequestNode)

    @staticmethod
    def is_call_tools_node(
        node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
    ) -> TypeGuard[_agent_graph.CallToolsNode[T, S]]:
        """Check if the node is a `CallToolsNode`, narrowing the type if it is.

        This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
        """
        return isinstance(node, _agent_graph.CallToolsNode)

    @staticmethod
    def is_user_prompt_node(
        node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
    ) -> TypeGuard[_agent_graph.UserPromptNode[T, S]]:
        """Check if the node is a `UserPromptNode`, narrowing the type if it is.

        This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
        """
        return isinstance(node, _agent_graph.UserPromptNode)

    @staticmethod
    def is_end_node(
        node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
    ) -> TypeGuard[End[result.FinalResult[S]]]:
        """Check if the node is an `End`, narrowing the type if it is.

        This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
        """
        return isinstance(node, End)

    @asynccontextmanager
    async def run_mcp_servers(self) -> AsyncIterator[None]:
        """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent.

        Returns: a context manager to start and shut down the servers.
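
        A sketch of typical use (the server command and module name are hypothetical):

        ```python
        from pydantic_ai import Agent
        from pydantic_ai.mcp import MCPServerStdio

        server = MCPServerStdio('python', ['-m', 'my_mcp_server'])
        agent = Agent('openai:gpt-4o', mcp_servers=[server])

        async def main():
            async with agent.run_mcp_servers():
                result = await agent.run('What tools do you have?')
        ```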
        """
        exit_stack = AsyncExitStack()
        try:
            for mcp_server in self._mcp_servers:
                await exit_stack.enter_async_context(mcp_server)
            yield
        finally:
            await exit_stack.aclose()


@dataclasses.dataclass(repr=False)
class AgentRun(Generic[AgentDepsT, ResultDataT]):
    """A stateful, async-iterable run of an [`Agent`][pydantic_ai.agent.Agent].

    You generally obtain an `AgentRun` instance by calling `async with my_agent.iter(...) as agent_run:`.

    Once you have an instance, you can use it to iterate through the run's nodes as they execute. When an
    [`End`][pydantic_graph.nodes.End] is reached, the run finishes and [`result`][pydantic_ai.agent.AgentRun.result]
    becomes available.

    Example:
    ```python
    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')

    async def main():
        nodes = []
        # Iterate through the run, recording each node along the way:
        async with agent.iter('What is the capital of France?') as agent_run:
            async for node in agent_run:
                nodes.append(node)
        print(nodes)
        '''
        [
            ModelRequestNode(
                request=ModelRequest(
                    parts=[
                        UserPromptPart(
                            content='What is the capital of France?',
                            timestamp=datetime.datetime(...),
                            part_kind='user-prompt',
                        )
                    ],
                    kind='request',
                )
            ),
            CallToolsNode(
                model_response=ModelResponse(
                    parts=[TextPart(content='Paris', part_kind='text')],
                    model_name='gpt-4o',
                    timestamp=datetime.datetime(...),
                    kind='response',
                )
            ),
            End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
        ]
        '''
        print(agent_run.result.data)
        #> Paris
    ```

    You can also manually drive the iteration using the [`next`][pydantic_ai.agent.AgentRun.next] method for
    more granular control.
    """

    _graph_run: GraphRun[
        _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT]
    ]

    @property
    def ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]:
        """The current context of the agent run."""
        return GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]](
            self._graph_run.state, self._graph_run.deps
        )

    @property
    def next_node(
        self,
    ) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]:
        """The next node that will be run in the agent graph.

        This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`.
        """
        next_node = self._graph_run.next_node
        if isinstance(next_node, End):
            return next_node
        if _agent_graph.is_agent_node(next_node):
            return next_node
        raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}')  # pragma: no cover

    @property
    def result(self) -> AgentRunResult[ResultDataT] | None:
        """The final result of the run if it has ended, otherwise `None`.

        Once the run returns an [`End`][pydantic_graph.nodes.End] node, `result` is populated
        with an [`AgentRunResult`][pydantic_ai.agent.AgentRunResult].
        """
        graph_run_result = self._graph_run.result
        if graph_run_result is None:
            return None
        return AgentRunResult(
            graph_run_result.output.data,
            graph_run_result.output.tool_name,
            graph_run_result.state,
            self._graph_run.deps.new_message_index,
        )

    def __aiter__(
        self,
    ) -> AsyncIterator[_agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]]:
        """Provide async-iteration over the nodes in the agent run."""
        return self

    async def __anext__(
        self,
    ) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]:
        """Advance to the next node automatically based on the last returned node."""
        next_node = await self._graph_run.__anext__()
        if _agent_graph.is_agent_node(next_node):
            return next_node
        assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
        return next_node

    async def next(
        self,
        node: _agent_graph.AgentNode[AgentDepsT, ResultDataT],
    ) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]:
        """Manually drive the agent run by passing in the node you want to run next.

        This lets you inspect or mutate the node before continuing execution, or skip certain nodes
        under dynamic conditions. The agent run should be stopped when you return an [`End`][pydantic_graph.nodes.End]
        node.

        Example:
        ```python
        from pydantic_ai import Agent
        from pydantic_graph import End

        agent = Agent('openai:gpt-4o')

        async def main():
            async with agent.iter('What is the capital of France?') as agent_run:
                next_node = agent_run.next_node  # start with the first node
                nodes = [next_node]
                while not isinstance(next_node, End):
                    next_node = await agent_run.next(next_node)
                    nodes.append(next_node)
                # Once `next_node` is an End, we've finished:
                print(nodes)
                '''
                [
                    UserPromptNode(
                        user_prompt='What is the capital of France?',
                        system_prompts=(),
                        system_prompt_functions=[],
                        system_prompt_dynamic_functions={},
                    ),
                    ModelRequestNode(
                        request=ModelRequest(
                            parts=[
                                UserPromptPart(
                                    content='What is the capital of France?',
                                    timestamp=datetime.datetime(...),
                                    part_kind='user-prompt',
                                )
                            ],
                            kind='request',
                        )
                    ),
                    CallToolsNode(
                        model_response=ModelResponse(
                            parts=[TextPart(content='Paris', part_kind='text')],
                            model_name='gpt-4o',
                            timestamp=datetime.datetime(...),
                            kind='response',
                        )
                    ),
                    End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
                ]
                '''
                print('Final result:', agent_run.result.data)
                #> Final result: Paris
        ```

        Args:
            node: The node to run next in the graph.

        Returns:
            The next node returned by the graph logic, or an [`End`][pydantic_graph.nodes.End] node if
            the run has completed.
        """
        # Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it
        # on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate.
        next_node = await self._graph_run.next(node)
        if _agent_graph.is_agent_node(next_node):
            return next_node
        assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
        return next_node

    def usage(self) -> _usage.Usage:
        """Get usage statistics for the run so far, including token usage, model requests, and so on."""
        return self._graph_run.state.usage

    def __repr__(self) -> str:
        result = self._graph_run.result
        result_repr = '<run not finished>' if result is None else repr(result.output)
        return f'<{type(self).__name__} result={result_repr} usage={self.usage()}>'


@dataclasses.dataclass
class AgentRunResult(Generic[ResultDataT]):
    """The final result of an agent run."""

    data: ResultDataT  # TODO: rename this to output. I'm putting this off for now mostly to reduce the size of the diff

    _result_tool_name: str | None = dataclasses.field(repr=False)
    _state: _agent_graph.GraphAgentState = dataclasses.field(repr=False)
    _new_message_index: int = dataclasses.field(repr=False)

    def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]:
        """Set return content for the result tool.

        Useful if you want to continue the conversation and want to set the response to the result tool call.
        """
        if not self._result_tool_name:
            raise ValueError('Cannot set result tool return content when the return type is `str`.')
        messages = deepcopy(self._state.message_history)
        last_message = messages[-1]
        for part in last_message.parts:
            if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._result_tool_name:
                part.content = return_content
                return messages
        raise LookupError(f'No tool call found with tool name {self._result_tool_name!r}.')

    def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
        """Return the history of messages.

        Args:
            result_tool_return_content: The return content of the tool call to set in the last message.
                This provides a convenient way to modify the content of the result tool call if you want to continue
                the conversation and want to set the response to the result tool call. If `None`, the last message will
                not be modified.

        Returns:
            List of messages.
        """
        if result_tool_return_content is not None:
            return self._set_result_tool_return(result_tool_return_content)
        else:
            return self._state.message_history

    def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
        """Return all messages from [`all_messages`][pydantic_ai.agent.AgentRunResult.all_messages] as JSON bytes.

        Args:
            result_tool_return_content: The return content of the tool call to set in the last message.
                This provides a convenient way to modify the content of the result tool call if you want to continue
                the conversation and want to set the response to the result tool call. If `None`, the last message will
                not be modified.

        Returns:
            JSON bytes representing the messages.
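
        A sketch of round-tripping the JSON back into message objects (`result` is a hypothetical
        completed run):

        ```python
        from pydantic_ai.messages import ModelMessagesTypeAdapter

        as_json = result.all_messages_json()
        same_messages = ModelMessagesTypeAdapter.validate_json(as_json)
        ```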
        """
        return _messages.ModelMessagesTypeAdapter.dump_json(
            self.all_messages(result_tool_return_content=result_tool_return_content)
        )

    def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
        """Return new messages associated with this run.

        Messages from older runs are excluded.

        Args:
            result_tool_return_content: The return content of the tool call to set in the last message.
                This provides a convenient way to modify the content of the result tool call if you want to continue
                the conversation and want to set the response to the result tool call. If `None`, the last message will
                not be modified.

        Returns:
            List of new messages.
        """
        return self.all_messages(result_tool_return_content=result_tool_return_content)[self._new_message_index :]

    def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
        """Return new messages from [`new_messages`][pydantic_ai.agent.AgentRunResult.new_messages] as JSON bytes.

        Args:
            result_tool_return_content: The return content of the tool call to set in the last message.
                This provides a convenient way to modify the content of the result tool call if you want to continue
                the conversation and want to set the response to the result tool call. If `None`, the last message will
                not be modified.

        Returns:
            JSON bytes representing the new messages.
        """
        return _messages.ModelMessagesTypeAdapter.dump_json(
            self.new_messages(result_tool_return_content=result_tool_return_content)
        )

    def usage(self) -> _usage.Usage:
        """Return the usage of the whole run."""
        return self._state.usage