Coverage for pydantic_ai_slim/pydantic_ai/agent.py: 99.36%
458 statements
coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
1from __future__ import annotations as _annotations
3import asyncio
4import dataclasses
5import inspect
6from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
7from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
8from contextvars import ContextVar
9from types import FrameType
10from typing import Any, Callable, Generic, Literal, cast, final, overload
12import logfire_api
13from typing_extensions import TypeVar, assert_never, deprecated
15from . import (
16 _result,
17 _system_prompt,
18 _utils,
19 exceptions,
20 messages as _messages,
21 models,
22 result,
23 usage as _usage,
24)
25from .result import ResultDataT
26from .settings import ModelSettings, merge_model_settings
27from .tools import (
28 AgentDepsT,
29 DocstringFormat,
30 RunContext,
31 Tool,
32 ToolDefinition,
33 ToolFuncContext,
34 ToolFuncEither,
35 ToolFuncPlain,
36 ToolParams,
37 ToolPrepareFunc,
38)
40__all__ = 'Agent', 'capture_run_messages', 'EndStrategy'
42_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
44# while waiting for https://github.com/pydantic/logfire/issues/745
45try:
46 import logfire._internal.stack_info
47except ImportError:
48 pass
49else:
50 from pathlib import Path
52 logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
54T = TypeVar('T')
55"""An invariant TypeVar."""
56NoneType = type(None)
57EndStrategy = Literal['early', 'exhaustive']
58"""The strategy for handling multiple tool calls when a final result is found.
60- `'early'`: Stop processing other tool calls once a final result is found
61- `'exhaustive'`: Process all tool calls even after finding a final result
62"""
63RunResultData = TypeVar('RunResultData')
64"""Type variable for the result data of a run where `result_type` was customized on the run call."""
67@final
68@dataclasses.dataclass(init=False)
69class Agent(Generic[AgentDepsT, ResultDataT]):
70 """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
72 Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT]
73 and the result data type they return, [`ResultDataT`][pydantic_ai.result.ResultDataT].
75 By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.
77 Minimal usage example:
79 ```python
80 from pydantic_ai import Agent
82 agent = Agent('openai:gpt-4o')
83 result = agent.run_sync('What is the capital of France?')
84 print(result.data)
85 #> Paris
86 ```
87 """
89 # we use dataclass fields in order to conveniently know what attributes are available
90 model: models.Model | models.KnownModelName | None
91 """The default model configured for this agent."""
93 name: str | None
94 """The name of the agent, used for logging.
96 If `None`, we try to infer the agent name from the call frame when the agent is first run.
97 """
98 end_strategy: EndStrategy
99 """Strategy for handling tool calls when a final result is found."""
101 model_settings: ModelSettings | None
102 Optional model request settings to use for this agent's runs, by default.
104 Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will
105 be merged with this value, with the runtime argument taking priority.
106 """
107 _result_tool_name: str = dataclasses.field(repr=False)
108 _result_tool_description: str | None = dataclasses.field(repr=False)
109 _result_schema: _result.ResultSchema[ResultDataT] | None = dataclasses.field(repr=False)
110 _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] = dataclasses.field(repr=False)
111 _system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
112 _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False)
113 _default_retries: int = dataclasses.field(repr=False)
114 _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
115 _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(
116 repr=False
117 )
118 _deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
119 _max_result_retries: int = dataclasses.field(repr=False)
120 _override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
121 _override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)
123 def __init__(
124 self,
125 model: models.Model | models.KnownModelName | None = None,
126 *,
127 result_type: type[ResultDataT] = str,
128 system_prompt: str | Sequence[str] = (),
129 deps_type: type[AgentDepsT] = NoneType,
130 name: str | None = None,
131 model_settings: ModelSettings | None = None,
132 retries: int = 1,
133 result_tool_name: str = 'final_result',
134 result_tool_description: str | None = None,
135 result_retries: int | None = None,
136 tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
137 defer_model_check: bool = False,
138 end_strategy: EndStrategy = 'early',
139 ):
140 """Create an agent.
142 Args:
143 model: The default model to use for this agent, if not provided,
144 you must provide the model when calling it.
145 result_type: The type of the result data, used to validate the result data, defaults to `str`.
146 system_prompt: Static system prompts to use for this agent, you can also register system
147 prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt].
148 deps_type: The type used for dependency injection, this parameter exists solely to allow you to fully
149 parameterize the agent, and therefore get the best out of static type checking.
150 If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright
151 or add a type hint `: Agent[None, <return type>]`.
152 name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame
153 when the agent is first run.
154 model_settings: Optional model request settings to use for this agent's runs, by default.
155 retries: The default number of retries to allow before raising an error.
156 result_tool_name: The name of the tool to use for the final result.
157 result_tool_description: The description of the final result tool.
158 result_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
159 tools: Tools to register with the agent, you can also register tools via the decorators
160 [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
161 defer_model_check: By default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
162 it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
163 which checks for the necessary environment variables. Set this to `False`
164 to defer the evaluation until the first run. Useful if you want to
165 [override the model][pydantic_ai.Agent.override] for testing.
166 end_strategy: Strategy for handling tool calls that are requested alongside a final result.
167 See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information.
168 """
169 if model is None or defer_model_check:
170 self.model = model
171 else:
172 self.model = models.infer_model(model)
174 self.end_strategy = end_strategy
175 self.name = name
176 self.model_settings = model_settings
177 self._result_tool_name = result_tool_name
178 self._result_tool_description = result_tool_description
179 self._result_schema = _result.ResultSchema[result_type].build(
180 result_type, result_tool_name, result_tool_description
181 )
183 self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
184 self._function_tools = {}
185 self._default_retries = retries
186 for tool in tools:
187 if isinstance(tool, Tool):
188 self._register_tool(tool)
189 else:
190 self._register_tool(Tool(tool))
191 self._deps_type = deps_type
192 self._system_prompt_functions = []
193 self._system_prompt_dynamic_functions = {}
194 self._max_result_retries = result_retries if result_retries is not None else retries
195 self._result_validators = []
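# Illustrative sketch (not part of this module): constructing a fully parameterized
# agent with the arguments documented above. The model name and types here are
# examples chosen for illustration, not requirements.
#
#     example_agent = Agent(
#         'openai:gpt-4o',
#         deps_type=int,
#         result_type=bool,
#         system_prompt='Answer with a boolean.',
#         retries=2,
#         end_strategy='exhaustive',
#     )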
197 @overload
198 async def run(
199 self,
200 user_prompt: str,
201 *,
202 result_type: None = None,
203 message_history: list[_messages.ModelMessage] | None = None,
204 model: models.Model | models.KnownModelName | None = None,
205 deps: AgentDepsT = None,
206 model_settings: ModelSettings | None = None,
207 usage_limits: _usage.UsageLimits | None = None,
208 usage: _usage.Usage | None = None,
209 infer_name: bool = True,
210 ) -> result.RunResult[ResultDataT]: ...
212 @overload
213 async def run(
214 self,
215 user_prompt: str,
216 *,
217 result_type: type[RunResultData],
218 message_history: list[_messages.ModelMessage] | None = None,
219 model: models.Model | models.KnownModelName | None = None,
220 deps: AgentDepsT = None,
221 model_settings: ModelSettings | None = None,
222 usage_limits: _usage.UsageLimits | None = None,
223 usage: _usage.Usage | None = None,
224 infer_name: bool = True,
225 ) -> result.RunResult[RunResultData]: ...
227 async def run(
228 self,
229 user_prompt: str,
230 *,
231 message_history: list[_messages.ModelMessage] | None = None,
232 model: models.Model | models.KnownModelName | None = None,
233 deps: AgentDepsT = None,
234 model_settings: ModelSettings | None = None,
235 usage_limits: _usage.UsageLimits | None = None,
236 usage: _usage.Usage | None = None,
237 result_type: type[RunResultData] | None = None,
238 infer_name: bool = True,
239 ) -> result.RunResult[Any]:
240 """Run the agent with a user prompt in async mode.
242 Example:
243 ```python
244 from pydantic_ai import Agent
246 agent = Agent('openai:gpt-4o')
248 async def main():
249 result = await agent.run('What is the capital of France?')
250 print(result.data)
251 #> Paris
252 ```
254 Args:
255 result_type: Custom result type to use for this run; `result_type` may only be used if the agent has no
256 result validators, since result validators would expect an argument that matches the agent's result type.
257 user_prompt: User input to start/continue the conversation.
258 message_history: History of the conversation so far.
259 model: Optional model to use for this run, required if `model` was not set when creating the agent.
260 deps: Optional dependencies to use for this run.
261 model_settings: Optional settings to use for this model's request.
262 usage_limits: Optional limits on model request count or token usage.
263 usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
264 infer_name: Whether to try to infer the agent name from the call frame if it's not set.
266 Returns:
267 The result of the run.
268 """
269 if infer_name and self.name is None:
270 self._infer_name(inspect.currentframe())
271 model_used = await self._get_model(model)
273 deps = self._get_deps(deps)
274 new_message_index = len(message_history) if message_history else 0
275 result_schema = self._prepare_result_schema(result_type)
277 with _logfire.span(
278 '{agent_name} run {prompt=}',
279 prompt=user_prompt,
280 agent=self,
281 model_name=model_used.name(),
282 agent_name=self.name or 'agent',
283 ) as run_span:
284 run_context = RunContext(deps, model_used, usage or _usage.Usage(), user_prompt)
285 messages = await self._prepare_messages(user_prompt, message_history, run_context)
286 run_context.messages = messages
288 for tool in self._function_tools.values():
289 tool.current_retry = 0
291 model_settings = merge_model_settings(self.model_settings, model_settings)
292 usage_limits = usage_limits or _usage.UsageLimits()
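# Descriptive note (added for readability): each iteration of the loop below issues
# one model request, records its usage, then either returns a final result or appends
# the tool responses as a new request message and asks the model again.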
294 while True:
295 usage_limits.check_before_request(run_context.usage)
297 run_context.run_step += 1
298 with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
299 agent_model = await self._prepare_model(run_context, result_schema)
301 with _logfire.span('model request', run_step=run_context.run_step) as model_req_span:
302 model_response, request_usage = await agent_model.request(messages, model_settings)
303 model_req_span.set_attribute('response', model_response)
304 model_req_span.set_attribute('usage', request_usage)
306 messages.append(model_response)
307 run_context.usage.incr(request_usage, requests=1)
308 usage_limits.check_tokens(run_context.usage)
310 with _logfire.span('handle model response', run_step=run_context.run_step) as handle_span:
311 final_result, tool_responses = await self._handle_model_response(
312 model_response, run_context, result_schema
313 )
315 if tool_responses:
316 # Add parts to the conversation as a new message
317 messages.append(_messages.ModelRequest(tool_responses))
319 # Check if we got a final result
320 if final_result is not None:
321 result_data = final_result.data
322 result_tool_name = final_result.tool_name
323 run_span.set_attribute('all_messages', messages)
324 run_span.set_attribute('usage', run_context.usage)
325 handle_span.set_attribute('result', result_data)
326 handle_span.message = 'handle model response -> final result'
327 return result.RunResult(
328 messages, new_message_index, result_data, result_tool_name, run_context.usage
329 )
330 else:
331 # continue the conversation
332 handle_span.set_attribute('tool_responses', tool_responses)
333 tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
334 handle_span.message = f'handle model response -> {tool_responses_str}'
336 @overload
337 def run_sync(
338 self,
339 user_prompt: str,
340 *,
341 message_history: list[_messages.ModelMessage] | None = None,
342 model: models.Model | models.KnownModelName | None = None,
343 deps: AgentDepsT = None,
344 model_settings: ModelSettings | None = None,
345 usage_limits: _usage.UsageLimits | None = None,
346 usage: _usage.Usage | None = None,
347 infer_name: bool = True,
348 ) -> result.RunResult[ResultDataT]: ...
350 @overload
351 def run_sync(
352 self,
353 user_prompt: str,
354 *,
355 result_type: type[RunResultData] | None,
356 message_history: list[_messages.ModelMessage] | None = None,
357 model: models.Model | models.KnownModelName | None = None,
358 deps: AgentDepsT = None,
359 model_settings: ModelSettings | None = None,
360 usage_limits: _usage.UsageLimits | None = None,
361 usage: _usage.Usage | None = None,
362 infer_name: bool = True,
363 ) -> result.RunResult[RunResultData]: ...
365 def run_sync(
366 self,
367 user_prompt: str,
368 *,
369 result_type: type[RunResultData] | None = None,
370 message_history: list[_messages.ModelMessage] | None = None,
371 model: models.Model | models.KnownModelName | None = None,
372 deps: AgentDepsT = None,
373 model_settings: ModelSettings | None = None,
374 usage_limits: _usage.UsageLimits | None = None,
375 usage: _usage.Usage | None = None,
376 infer_name: bool = True,
377 ) -> result.RunResult[Any]:
378 """Run the agent with a user prompt synchronously.
380 This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`.
381 You therefore can't use this method inside async code or if there's an active event loop.
383 Example:
384 ```python
385 from pydantic_ai import Agent
387 agent = Agent('openai:gpt-4o')
389 result_sync = agent.run_sync('What is the capital of Italy?')
390 print(result_sync.data)
391 #> Rome
392 ```
394 Args:
395 result_type: Custom result type to use for this run; `result_type` may only be used if the agent has no
396 result validators, since result validators would expect an argument that matches the agent's result type.
397 user_prompt: User input to start/continue the conversation.
398 message_history: History of the conversation so far.
399 model: Optional model to use for this run, required if `model` was not set when creating the agent.
400 deps: Optional dependencies to use for this run.
401 model_settings: Optional settings to use for this model's request.
402 usage_limits: Optional limits on model request count or token usage.
403 usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
404 infer_name: Whether to try to infer the agent name from the call frame if it's not set.
406 Returns:
407 The result of the run.
408 """
409 if infer_name and self.name is None:
410 self._infer_name(inspect.currentframe())
411 return asyncio.get_event_loop().run_until_complete(
412 self.run(
413 user_prompt,
414 result_type=result_type,
415 message_history=message_history,
416 model=model,
417 deps=deps,
418 model_settings=model_settings,
419 usage_limits=usage_limits,
420 usage=usage,
421 infer_name=False,
422 )
423 )
425 @overload
426 def run_stream(
427 self,
428 user_prompt: str,
429 *,
430 result_type: None = None,
431 message_history: list[_messages.ModelMessage] | None = None,
432 model: models.Model | models.KnownModelName | None = None,
433 deps: AgentDepsT = None,
434 model_settings: ModelSettings | None = None,
435 usage_limits: _usage.UsageLimits | None = None,
436 usage: _usage.Usage | None = None,
437 infer_name: bool = True,
438 ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, ResultDataT]]: ...
440 @overload
441 def run_stream(
442 self,
443 user_prompt: str,
444 *,
445 result_type: type[RunResultData],
446 message_history: list[_messages.ModelMessage] | None = None,
447 model: models.Model | models.KnownModelName | None = None,
448 deps: AgentDepsT = None,
449 model_settings: ModelSettings | None = None,
450 usage_limits: _usage.UsageLimits | None = None,
451 usage: _usage.Usage | None = None,
452 infer_name: bool = True,
453 ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunResultData]]: ...
455 @asynccontextmanager
456 async def run_stream(
457 self,
458 user_prompt: str,
459 *,
460 result_type: type[RunResultData] | None = None,
461 message_history: list[_messages.ModelMessage] | None = None,
462 model: models.Model | models.KnownModelName | None = None,
463 deps: AgentDepsT = None,
464 model_settings: ModelSettings | None = None,
465 usage_limits: _usage.UsageLimits | None = None,
466 usage: _usage.Usage | None = None,
467 infer_name: bool = True,
468 ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]:
469 """Run the agent with a user prompt in async mode, returning a streamed response.
471 Example:
472 ```python
473 from pydantic_ai import Agent
475 agent = Agent('openai:gpt-4o')
477 async def main():
478 async with agent.run_stream('What is the capital of the UK?') as response:
479 print(await response.get_data())
480 #> London
481 ```
483 Args:
484 result_type: Custom result type to use for this run; `result_type` may only be used if the agent has no
485 result validators, since result validators would expect an argument that matches the agent's result type.
486 user_prompt: User input to start/continue the conversation.
487 message_history: History of the conversation so far.
488 model: Optional model to use for this run, required if `model` was not set when creating the agent.
489 deps: Optional dependencies to use for this run.
490 model_settings: Optional settings to use for this model's request.
491 usage_limits: Optional limits on model request count or token usage.
492 usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
493 infer_name: Whether to try to infer the agent name from the call frame if it's not set.
495 Returns:
496 The result of the run.
497 """
498 if infer_name and self.name is None:
499 # f_back because `asynccontextmanager` adds one frame
500 if frame := inspect.currentframe(): # pragma: no branch
501 self._infer_name(frame.f_back)
502 model_used = await self._get_model(model)
504 deps = self._get_deps(deps)
505 new_message_index = len(message_history) if message_history else 0
506 result_schema = self._prepare_result_schema(result_type)
508 with _logfire.span(
509 '{agent_name} run stream {prompt=}',
510 prompt=user_prompt,
511 agent=self,
512 model_name=model_used.name(),
513 agent_name=self.name or 'agent',
514 ) as run_span:
515 run_context = RunContext(deps, model_used, usage or _usage.Usage(), user_prompt)
516 messages = await self._prepare_messages(user_prompt, message_history, run_context)
517 run_context.messages = messages
519 for tool in self._function_tools.values():
520 tool.current_retry = 0
522 model_settings = merge_model_settings(self.model_settings, model_settings)
523 usage_limits = usage_limits or _usage.UsageLimits()
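# Descriptive note (added for readability): the streaming loop below mirrors `run`,
# except that when a final result is detected the partially-consumed stream is handed
# to the caller as a `StreamedRunResult`, and the remaining tool calls are processed
# in the `on_complete` callback once the stream finishes.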
525 while True:
526 run_context.run_step += 1
527 usage_limits.check_before_request(run_context.usage)
529 with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
530 agent_model = await self._prepare_model(run_context, result_schema)
532 with _logfire.span('model request {run_step=}', run_step=run_context.run_step) as model_req_span:
533 async with agent_model.request_stream(messages, model_settings) as model_response:
534 run_context.usage.requests += 1
535 model_req_span.set_attribute('response_type', model_response.__class__.__name__)
536 # We want to end the "model request" span here, but we can't exit the context manager
537 # in the traditional way
538 model_req_span.__exit__(None, None, None)
540 with _logfire.span('handle model response') as handle_span:
541 maybe_final_result = await self._handle_streamed_response(
542 model_response, run_context, result_schema
543 )
545 # Check if we got a final result
546 if isinstance(maybe_final_result, _MarkFinalResult):
547 result_stream = maybe_final_result.data
548 result_tool_name = maybe_final_result.tool_name
549 handle_span.message = 'handle model response -> final result'
551 async def on_complete():
552 """Called when the stream has completed.
554 The model response will have been added to messages by now
555 by `StreamedRunResult._marked_completed`.
556 """
557 last_message = messages[-1]
558 assert isinstance(last_message, _messages.ModelResponse)
559 tool_calls = [
560 part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
561 ]
562 parts = await self._process_function_tools(
563 tool_calls, result_tool_name, run_context, result_schema
564 )
565 if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
566 self._incr_result_retry(run_context)
567 if parts:
568 messages.append(_messages.ModelRequest(parts))
569 run_span.set_attribute('all_messages', messages)
571 # The following is not guaranteed to be true, but we consider it a user error if
572 # there are result validators that might convert the result data from an overridden
573 # `result_type` to a type that is not valid as such.
574 result_validators = cast(
575 list[_result.ResultValidator[AgentDepsT, RunResultData]], self._result_validators
576 )
578 yield result.StreamedRunResult(
579 messages,
580 new_message_index,
581 usage_limits,
582 result_stream,
583 result_schema,
584 run_context,
585 result_validators,
586 result_tool_name,
587 on_complete,
588 )
589 return
590 else:
591 # continue the conversation
592 model_response_msg, tool_responses = maybe_final_result
593 # if we got a model response add that to messages
594 messages.append(model_response_msg)
595 if tool_responses:
596 # if we got one or more tool response parts, add a model request message
597 messages.append(_messages.ModelRequest(tool_responses))
599 handle_span.set_attribute('tool_responses', tool_responses)
600 tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
601 handle_span.message = f'handle model response -> {tool_responses_str}'
602 # the model_response should have been fully streamed by now, so we can add its usage
603 model_response_usage = model_response.usage()
604 run_context.usage.incr(model_response_usage)
605 usage_limits.check_tokens(run_context.usage)
607 @contextmanager
608 def override(
609 self,
610 *,
611 deps: AgentDepsT | _utils.Unset = _utils.UNSET,
612 model: models.Model | models.KnownModelName | _utils.Unset = _utils.UNSET,
613 ) -> Iterator[None]:
614 """Context manager to temporarily override agent dependencies and model.
616 This is particularly useful when testing.
617 You can find an example of this [here](../testing-evals.md#overriding-model-via-pytest-fixtures).
619 Args:
620 deps: The dependencies to use instead of the dependencies passed to the agent run.
621 model: The model to use instead of the model passed to the agent run.
622 """
623 if _utils.is_set(deps):
624 override_deps_before = self._override_deps
625 self._override_deps = _utils.Some(deps)
626 else:
627 override_deps_before = _utils.UNSET
629 # noinspection PyTypeChecker
630 if _utils.is_set(model):
631 override_model_before = self._override_model
632 # noinspection PyTypeChecker
633 self._override_model = _utils.Some(models.infer_model(model)) # pyright: ignore[reportArgumentType]
634 else:
635 override_model_before = _utils.UNSET
637 try:
638 yield
639 finally:
640 if _utils.is_set(override_deps_before):
641 self._override_deps = override_deps_before
642 if _utils.is_set(override_model_before):
643 self._override_model = override_model_before
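# Illustrative sketch (not part of this module): overriding deps and model in a test,
# assuming an agent whose `deps_type` is `int`; `'test'` refers to the built-in test model.
#
#     with agent.override(deps=42, model='test'):
#         result = agent.run_sync('...')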
645 @overload
646 def system_prompt(
647 self, func: Callable[[RunContext[AgentDepsT]], str], /
648 ) -> Callable[[RunContext[AgentDepsT]], str]: ...
650 @overload
651 def system_prompt(
652 self, func: Callable[[RunContext[AgentDepsT]], Awaitable[str]], /
653 ) -> Callable[[RunContext[AgentDepsT]], Awaitable[str]]: ...
655 @overload
656 def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...
658 @overload
659 def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ...
661 @overload
662 def system_prompt(
663 self, /, *, dynamic: bool = False
664 ) -> Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]: ...
666 def system_prompt(
667 self,
668 func: _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
669 /,
670 *,
671 dynamic: bool = False,
672 ) -> (
673 Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]
674 | _system_prompt.SystemPromptFunc[AgentDepsT]
675 ):
676 """Decorator to register a system prompt function.
678 Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
679 Can decorate sync or async functions.
681 The decorator can be used either bare (`agent.system_prompt`) or as a function call
682 (`agent.system_prompt(...)`), see the examples below.
684 Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
685 the type of the function, see `tests/typed_agent.py` for tests.
687 Args:
688 func: The function to decorate
689 dynamic: If True, the system prompt will be reevaluated even when `message_history` is provided,
690 see [`SystemPromptPart.dynamic_ref`][pydantic_ai.messages.SystemPromptPart.dynamic_ref]
692 Example:
693 ```python
694 from pydantic_ai import Agent, RunContext
696 agent = Agent('test', deps_type=str)
698 @agent.system_prompt
699 def simple_system_prompt() -> str:
700 return 'foobar'
702 @agent.system_prompt(dynamic=True)
703 async def async_system_prompt(ctx: RunContext[str]) -> str:
704 return f'{ctx.deps} is the best'
705 ```
706 """
707 if func is None:
709 def decorator(
710 func_: _system_prompt.SystemPromptFunc[AgentDepsT],
711 ) -> _system_prompt.SystemPromptFunc[AgentDepsT]:
712 runner = _system_prompt.SystemPromptRunner[AgentDepsT](func_, dynamic=dynamic)
713 self._system_prompt_functions.append(runner)
714 if dynamic:
715 self._system_prompt_dynamic_functions[func_.__qualname__] = runner
716 return func_
718 return decorator
719 else:
720 assert not dynamic, "dynamic can't be True in this case"
721 self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func, dynamic=dynamic))
722 return func
724 @overload
725 def result_validator(
726 self, func: Callable[[RunContext[AgentDepsT], ResultDataT], ResultDataT], /
727 ) -> Callable[[RunContext[AgentDepsT], ResultDataT], ResultDataT]: ...
729 @overload
730 def result_validator(
731 self, func: Callable[[RunContext[AgentDepsT], ResultDataT], Awaitable[ResultDataT]], /
732 ) -> Callable[[RunContext[AgentDepsT], ResultDataT], Awaitable[ResultDataT]]: ...
734 @overload
735 def result_validator(
736 self, func: Callable[[ResultDataT], ResultDataT], /
737 ) -> Callable[[ResultDataT], ResultDataT]: ...
739 @overload
740 def result_validator(
741 self, func: Callable[[ResultDataT], Awaitable[ResultDataT]], /
742 ) -> Callable[[ResultDataT], Awaitable[ResultDataT]]: ...
744 def result_validator(
745 self, func: _result.ResultValidatorFunc[AgentDepsT, ResultDataT], /
746 ) -> _result.ResultValidatorFunc[AgentDepsT, ResultDataT]:
747 """Decorator to register a result validator function.
749 Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
750 Can decorate sync or async functions.
752 Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
753 the type of the function, see `tests/typed_agent.py` for tests.
755 Example:
756 ```python
757 from pydantic_ai import Agent, ModelRetry, RunContext
759 agent = Agent('test', deps_type=str)
761 @agent.result_validator
762 def result_validator_simple(data: str) -> str:
763 if 'wrong' in data:
764 raise ModelRetry('wrong response')
765 return data
767 @agent.result_validator
768 async def result_validator_deps(ctx: RunContext[str], data: str) -> str:
769 if ctx.deps in data:
770 raise ModelRetry('wrong response')
771 return data
773 result = agent.run_sync('foobar', deps='spam')
774 print(result.data)
775 #> success (no tool calls)
776 ```
777 """
778 self._result_validators.append(_result.ResultValidator[AgentDepsT, Any](func))
779 return func
781 @overload
782 def tool(self, func: ToolFuncContext[AgentDepsT, ToolParams], /) -> ToolFuncContext[AgentDepsT, ToolParams]: ...
784 @overload
785 def tool(
786 self,
787 /,
788 *,
789 retries: int | None = None,
790 prepare: ToolPrepareFunc[AgentDepsT] | None = None,
791 docstring_format: DocstringFormat = 'auto',
792 require_parameter_descriptions: bool = False,
793 ) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...
795 def tool(
796 self,
797 func: ToolFuncContext[AgentDepsT, ToolParams] | None = None,
798 /,
799 *,
800 retries: int | None = None,
801 prepare: ToolPrepareFunc[AgentDepsT] | None = None,
802 docstring_format: DocstringFormat = 'auto',
803 require_parameter_descriptions: bool = False,
804 ) -> Any:
805 """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
807 Can decorate sync or async functions.
809 The docstring is inspected to extract both the tool description and description of each parameter,
810 [learn more](../tools.md#function-tools-and-schema).
812 We can't add overloads for every possible signature of tool, since the return type is a recursive union
813 so the signature of functions decorated with `@agent.tool` is obscured.
815 Example:
816 ```python
817 from pydantic_ai import Agent, RunContext
819 agent = Agent('test', deps_type=int)
821 @agent.tool
822 def foobar(ctx: RunContext[int], x: int) -> int:
823 return ctx.deps + x
825 @agent.tool(retries=2)
826 async def spam(ctx: RunContext[int], y: float) -> float:
827 return ctx.deps + y
829 result = agent.run_sync('foobar', deps=1)
830 print(result.data)
831 #> {"foobar":1,"spam":1.0}
832 ```
834 Args:
835 func: The tool function to register.
836 retries: The number of retries to allow for this tool, defaults to the agent's default retries,
837 which defaults to 1.
838 prepare: custom method to prepare the tool definition for each step, return `None` to omit this
839 tool from a given step. This is useful if you want to customise a tool at call time,
840 or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
841 docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
842 Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
843 require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
844 """
845 if func is None:
847 def tool_decorator(
848 func_: ToolFuncContext[AgentDepsT, ToolParams],
849 ) -> ToolFuncContext[AgentDepsT, ToolParams]:
850 # noinspection PyTypeChecker
851 self._register_function(func_, True, retries, prepare, docstring_format, require_parameter_descriptions)
852 return func_
854 return tool_decorator
855 else:
856 # noinspection PyTypeChecker
857 self._register_function(func, True, retries, prepare, docstring_format, require_parameter_descriptions)
858 return func
860 @overload
861 def tool_plain(self, func: ToolFuncPlain[ToolParams], /) -> ToolFuncPlain[ToolParams]: ...
863 @overload
864 def tool_plain(
865 self,
866 /,
867 *,
868 retries: int | None = None,
869 prepare: ToolPrepareFunc[AgentDepsT] | None = None,
870 docstring_format: DocstringFormat = 'auto',
871 require_parameter_descriptions: bool = False,
872 ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
874 def tool_plain(
875 self,
876 func: ToolFuncPlain[ToolParams] | None = None,
877 /,
878 *,
879 retries: int | None = None,
880 prepare: ToolPrepareFunc[AgentDepsT] | None = None,
881 docstring_format: DocstringFormat = 'auto',
882 require_parameter_descriptions: bool = False,
883 ) -> Any:
884 """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
886 Can decorate sync or async functions.
888 The docstring is inspected to extract both the tool description and description of each parameter,
889 [learn more](../tools.md#function-tools-and-schema).
891 We can't add overloads for every possible signature of tool, since the return type is a recursive union
892 so the signature of functions decorated with `@agent.tool` is obscured.
894 Example:
895 ```python
896 from pydantic_ai import Agent
898 agent = Agent('test')
900 @agent.tool_plain
901 def foobar() -> int:
902 return 123
904 @agent.tool_plain(retries=2)
905 async def spam() -> float:
906 return 3.14
908 result = agent.run_sync('foobar')
909 print(result.data)
910 #> {"foobar":123,"spam":3.14}
911 ```
913 Args:
914 func: The tool function to register.
915 retries: The number of retries to allow for this tool, defaults to the agent's default retries,
916 which defaults to 1.
917 prepare: custom method to prepare the tool definition for each step, return `None` to omit this
918 tool from a given step. This is useful if you want to customise a tool at call time,
919 or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
920 docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
921 Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
922 require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
923 """
924 if func is None:
926 def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
927 # noinspection PyTypeChecker
928 self._register_function(
929 func_, False, retries, prepare, docstring_format, require_parameter_descriptions
930 )
931 return func_
933 return tool_decorator
934 else:
935 self._register_function(func, False, retries, prepare, docstring_format, require_parameter_descriptions)
936 return func
938 def _register_function(
939 self,
940 func: ToolFuncEither[AgentDepsT, ToolParams],
941 takes_ctx: bool,
942 retries: int | None,
943 prepare: ToolPrepareFunc[AgentDepsT] | None,
944 docstring_format: DocstringFormat,
945 require_parameter_descriptions: bool,
946 ) -> None:
947 """Private utility to register a function as a tool."""
948 retries_ = retries if retries is not None else self._default_retries
949 tool = Tool[AgentDepsT](
950 func,
951 takes_ctx=takes_ctx,
952 max_retries=retries_,
953 prepare=prepare,
954 docstring_format=docstring_format,
955 require_parameter_descriptions=require_parameter_descriptions,
956 )
957 self._register_tool(tool)
959 def _register_tool(self, tool: Tool[AgentDepsT]) -> None:
960 """Private utility to register a tool instance."""
961 if tool.max_retries is None:
962 # noinspection PyTypeChecker
963 tool = dataclasses.replace(tool, max_retries=self._default_retries)
965 if tool.name in self._function_tools:
966 raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}')
968 if self._result_schema and tool.name in self._result_schema.tools:
969 raise exceptions.UserError(f'Tool name conflicts with result schema name: {tool.name!r}')
971 self._function_tools[tool.name] = tool
973 async def _get_model(self, model: models.Model | models.KnownModelName | None) -> models.Model:
974 """Create a model configured for this agent.
976 Args:
977 model: model to use for this run, required if `model` was not set when creating the agent.
979 Returns:
980 The model used
981 """
982 model_: models.Model
983 if some_model := self._override_model:
984 # we don't want `override()` to cover up errors from the model not being defined, hence this check
985 if model is None and self.model is None:
986 raise exceptions.UserError(
987 '`model` must be set either when creating the agent or when calling it. '
988 '(Even when `override(model=...)` is customizing the model that will actually be called)'
989 )
990 model_ = some_model.value
991 elif model is not None:
992 model_ = models.infer_model(model)
993 elif self.model is not None:
994 # noinspection PyTypeChecker
995 model_ = self.model = models.infer_model(self.model)
996 else:
997 raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
999 return model_
1001 async def _prepare_model(
1002 self, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultData] | None
1003 ) -> models.AgentModel:
1004 """Build tools and create an agent model."""
1005 function_tools: list[ToolDefinition] = []
1007 async def add_tool(tool: Tool[AgentDepsT]) -> None:
1008 ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
1009 if tool_def := await tool.prepare_tool_def(ctx):
1010 function_tools.append(tool_def)
1012 await asyncio.gather(*map(add_tool, self._function_tools.values()))
1014 return await run_context.model.agent_model(
1015 function_tools=function_tools,
1016 allow_text_result=self._allow_text_result(result_schema),
1017 result_tools=result_schema.tool_defs() if result_schema is not None else [],
1018 )
1020 async def _reevaluate_dynamic_prompts(
1021 self, messages: list[_messages.ModelMessage], run_context: RunContext[AgentDepsT]
1022 ) -> None:
1023 """Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
1024 # Only proceed if there's at least one dynamic runner.
1025 if self._system_prompt_dynamic_functions:
1026 for msg in messages:
1027 if isinstance(msg, _messages.ModelRequest):
1028 for i, part in enumerate(msg.parts):
1029 if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
1030 # Look up the runner by its ref
1031 if runner := self._system_prompt_dynamic_functions.get(part.dynamic_ref):
1032 updated_part_content = await runner.run(run_context)
1033 msg.parts[i] = _messages.SystemPromptPart(
1034 updated_part_content, dynamic_ref=part.dynamic_ref
1035 )
1037 def _prepare_result_schema(
1038 self, result_type: type[RunResultData] | None
1039 ) -> _result.ResultSchema[RunResultData] | None:
1040 if result_type is not None:
1041 if self._result_validators:
1042 raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators')
1043 return _result.ResultSchema[result_type].build(
1044 result_type, self._result_tool_name, self._result_tool_description
1045 )
1046 else:
1047 return self._result_schema # pyright: ignore[reportReturnType]
1049 async def _prepare_messages(
1050 self,
1051 user_prompt: str,
1052 message_history: list[_messages.ModelMessage] | None,
1053 run_context: RunContext[AgentDepsT],
1054 ) -> list[_messages.ModelMessage]:
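# If we're inside a `capture_run_messages` context and its message list hasn't been
# used by a previous run, reuse that list so the caller can observe this run's
# messages; otherwise start from a fresh list.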
1055 try:
1056 ctx_messages = _messages_ctx_var.get()
1057 except LookupError:
1058 messages: list[_messages.ModelMessage] = []
1059 else:
1060 if ctx_messages.used:
1061 messages = []
1062 else:
1063 messages = ctx_messages.messages
1064 ctx_messages.used = True
1066 if message_history:
1067 # Shallow copy messages
1068 messages.extend(message_history)
1069 # Reevaluate any dynamic system prompt parts
1070 await self._reevaluate_dynamic_prompts(messages, run_context)
1071 messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
1072 else:
1073 parts = await self._sys_parts(run_context)
1074 parts.append(_messages.UserPromptPart(user_prompt))
1075 messages.append(_messages.ModelRequest(parts))
1077 return messages
1079 async def _handle_model_response(
1080 self,
1081 model_response: _messages.ModelResponse,
1082 run_context: RunContext[AgentDepsT],
1083 result_schema: _result.ResultSchema[RunResultData] | None,
1084 ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
1085 """Process a non-streamed response from the model.
1087 Returns:
1088 A tuple of `(final_result, request parts)`. If `final_result` is not `None`, the conversation should end.
1089 """
1090 texts: list[str] = []
1091 tool_calls: list[_messages.ToolCallPart] = []
1092 for part in model_response.parts:
1093 if isinstance(part, _messages.TextPart):
1094 # ignore empty content for text parts, see #437
1095 if part.content:
1096 texts.append(part.content)
1097 else:
1098 tool_calls.append(part)
1100 # At the moment, we prioritize at least executing tool calls if they are present.
1101 # In the future, we'd consider making this configurable at the agent or run level.
1102 # This accounts for cases like Anthropic responses that might contain both a text part
1103 # and a tool call part, where the text just indicates that the tool call will happen.
1104 if tool_calls:
1105 return await self._handle_structured_response(tool_calls, run_context, result_schema)
1106 elif texts:
1107 text = '\n\n'.join(texts)
1108 return await self._handle_text_response(text, run_context, result_schema)
1109 else:
1110 raise exceptions.UnexpectedModelBehavior('Received empty model response')
1112 async def _handle_text_response(
1113 self, text: str, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultData] | None
1114 ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
1115 """Handle a plain text response from the model for non-streaming responses."""
1116 if self._allow_text_result(result_schema):
1117 result_data_input = cast(RunResultData, text)
1118 try:
1119 result_data = await self._validate_result(result_data_input, run_context, None)
1120 except _result.ToolRetryError as e:
1121 self._incr_result_retry(run_context)
1122 return None, [e.tool_retry]
1123 else:
1124 return _MarkFinalResult(result_data, None), []
1125 else:
1126 self._incr_result_retry(run_context)
1127 response = _messages.RetryPromptPart(
1128 content='Plain text responses are not permitted, please call one of the functions instead.',
1129 )
1130 return None, [response]
1132 async def _handle_structured_response(
1133 self,
1134 tool_calls: list[_messages.ToolCallPart],
1135 run_context: RunContext[AgentDepsT],
1136 result_schema: _result.ResultSchema[RunResultData] | None,
1137 ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
1138 """Handle a structured response containing tool calls from the model for non-streaming responses."""
1139 assert tool_calls, 'Expected at least one tool call'
1141 # first look for the result tool call
1142 final_result: _MarkFinalResult[RunResultData] | None = None
1144 parts: list[_messages.ModelRequestPart] = []
1145 if result_schema is not None:
1146 if match := result_schema.find_tool(tool_calls):
1147 call, result_tool = match
1148 try:
1149 result_data = result_tool.validate(call)
1150 result_data = await self._validate_result(result_data, run_context, call)
1151 except _result.ToolRetryError as e:
1152 parts.append(e.tool_retry)
1153 else:
1154 final_result = _MarkFinalResult(result_data, call.tool_name)
1156 # Then build the other request parts based on end strategy
1157 parts += await self._process_function_tools(
1158 tool_calls, final_result and final_result.tool_name, run_context, result_schema
1159 )
1161 if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
1162 self._incr_result_retry(run_context)
1164 return final_result, parts
1166 async def _process_function_tools(
1167 self,
1168 tool_calls: list[_messages.ToolCallPart],
1169 result_tool_name: str | None,
1170 run_context: RunContext[AgentDepsT],
1171 result_schema: _result.ResultSchema[RunResultData] | None,
1172 ) -> list[_messages.ModelRequestPart]:
1173 """Process function (non-result) tool calls in parallel.
1175 Also add stub return parts for any other tools that need it.
1176 """
1177 parts: list[_messages.ModelRequestPart] = []
1178 tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
1180 stub_function_tools = bool(result_tool_name) and self.end_strategy == 'early'
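# Descriptive note (added for readability): with the default 'early' end strategy,
# once a final result has been found the remaining function tool calls are answered
# with stub "Tool not executed" return parts instead of being run; with 'exhaustive'
# they are still executed.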
1182 # we rely on the fact that if we found a result, it's the first result tool in the last model response
1183 found_used_result_tool = False
1184 for call in tool_calls:
1185 if call.tool_name == result_tool_name and not found_used_result_tool:
1186 found_used_result_tool = True
1187 parts.append(
1188 _messages.ToolReturnPart(
1189 tool_name=call.tool_name,
1190 content='Final result processed.',
1191 tool_call_id=call.tool_call_id,
1192 )
1193 )
1194 elif tool := self._function_tools.get(call.tool_name):
1195 if stub_function_tools:
1196 parts.append(
1197 _messages.ToolReturnPart(
1198 tool_name=call.tool_name,
1199 content='Tool not executed - a final result was already processed.',
1200 tool_call_id=call.tool_call_id,
1201 )
1202 )
1203 else:
1204 tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
1205 elif result_schema is not None and call.tool_name in result_schema.tools:
1206 # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
1207 # validation, we don't add another part here
1208 if result_tool_name is not None:
1209 parts.append(
1210 _messages.ToolReturnPart(
1211 tool_name=call.tool_name,
1212 content='Result tool not used - a final result was already processed.',
1213 tool_call_id=call.tool_call_id,
1214 )
1215 )
1216 else:
1217 parts.append(self._unknown_tool(call.tool_name, result_schema))
1219 # Run all tool tasks in parallel
1220 if tasks:
1221 with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
1222 task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
1223 parts.extend(task_results)
1224 return parts
1226 async def _handle_streamed_response(
1227 self,
1228 streamed_response: models.StreamedResponse,
1229 run_context: RunContext[AgentDepsT],
1230 result_schema: _result.ResultSchema[RunResultData] | None,
1231 ) -> _MarkFinalResult[models.StreamedResponse] | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]:
1232 """Process a streamed response from the model.
1234 Returns:
1235 Either a final result or a tuple of the model response and the tool responses for the next request.
1236 If a final result is returned, the conversation should end.
1237 """
1238 received_text = False
1240 async for maybe_part_event in streamed_response:
1241 if isinstance(maybe_part_event, _messages.PartStartEvent):
1242 new_part = maybe_part_event.part
1243 if isinstance(new_part, _messages.TextPart):
1244 received_text = True
1245 if self._allow_text_result(result_schema):
1246 return _MarkFinalResult(streamed_response, None)
1247 elif isinstance(new_part, _messages.ToolCallPart):
1248 if result_schema is not None and (match := result_schema.find_tool([new_part])):
1249 call, _ = match
1250 return _MarkFinalResult(streamed_response, call.tool_name)
1251 else:
1252 assert_never(new_part)
1254 tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
1255 parts: list[_messages.ModelRequestPart] = []
1256 model_response = streamed_response.get()
1257 if not model_response.parts:
1258 raise exceptions.UnexpectedModelBehavior('Received empty model response')
1259 for p in model_response.parts:
1260 if isinstance(p, _messages.ToolCallPart):
1261 if tool := self._function_tools.get(p.tool_name):
1262 tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name))
1263 else:
1264 parts.append(self._unknown_tool(p.tool_name, result_schema))
1266 if received_text and not tasks and not parts:
1267 # Can only get here if self._allow_text_result returns `False` for the provided result_schema
1268 self._incr_result_retry(run_context)
1269 model_response = _messages.RetryPromptPart(
1270 content='Plain text responses are not permitted, please call one of the functions instead.',
1271 )
1272 return streamed_response.get(), [model_response]
1274 with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
1275 task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
1276 parts.extend(task_results)
1278 if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
1279 self._incr_result_retry(run_context)
1281 return model_response, parts
1283 async def _validate_result(
1284 self,
1285 result_data: RunResultData,
1286 run_context: RunContext[AgentDepsT],
1287 tool_call: _messages.ToolCallPart | None,
1288 ) -> RunResultData:
1289 if self._result_validators:
1290 agent_result_data = cast(ResultDataT, result_data)
1291 for validator in self._result_validators:
1292 agent_result_data = await validator.validate(agent_result_data, tool_call, run_context)
1293 return cast(RunResultData, agent_result_data)
1294 else:
1295 return result_data
1297 def _incr_result_retry(self, run_context: RunContext[AgentDepsT]) -> None:
1298 run_context.retry += 1
1299 if run_context.retry > self._max_result_retries:
1300 raise exceptions.UnexpectedModelBehavior(
1301 f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
1302 )
1304 async def _sys_parts(self, run_context: RunContext[AgentDepsT]) -> list[_messages.ModelRequestPart]:
1305 """Build the initial messages for the conversation."""
1306 messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
1307 for sys_prompt_runner in self._system_prompt_functions:
1308 prompt = await sys_prompt_runner.run(run_context)
1309 if sys_prompt_runner.dynamic:
1310 messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__))
1311 else:
1312 messages.append(_messages.SystemPromptPart(prompt))
1313 return messages
1315 def _unknown_tool(
1316 self,
1317 tool_name: str,
1318 result_schema: _result.ResultSchema[RunResultData] | None,
1319 ) -> _messages.RetryPromptPart:
1320 names = list(self._function_tools.keys())
1321 if result_schema:
1322 names.extend(result_schema.tool_names())
1323 if names:
1324 msg = f'Available tools: {", ".join(names)}'
1325 else:
1326 msg = 'No tools available.'
1327 return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')
1329 def _get_deps(self: Agent[T, Any], deps: T) -> T:
1330 """Get deps for a run.
1332 If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.
1334 We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope.
1335 """
1336 if some_deps := self._override_deps:
1337 return some_deps.value
1338 else:
1339 return deps
1341 def _infer_name(self, function_frame: FrameType | None) -> None:
1342 """Infer the agent name from the call frame.
1344 Usage should be `self._infer_name(inspect.currentframe())`.
1345 """
1346 assert self.name is None, 'Name already set'
1347 if function_frame is not None: # pragma: no branch
1348 if parent_frame := function_frame.f_back: # pragma: no branch
1349 for name, item in parent_frame.f_locals.items():
1350 if item is self:
1351 self.name = name
1352 return
1353 if parent_frame.f_locals != parent_frame.f_globals:
1354 # if we couldn't find the agent in locals and globals are a different dict, try globals
1355 for name, item in parent_frame.f_globals.items():
1356 if item is self:
1357 self.name = name
1358 return
1360 @staticmethod
1361 def _allow_text_result(result_schema: _result.ResultSchema[RunResultData] | None) -> bool:
1362 return result_schema is None or result_schema.allow_text_result
1364 @property
1365 @deprecated(
1366 'The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.', category=None
1367 )
1368 def last_run_messages(self) -> list[_messages.ModelMessage]:
1369 raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')
1372@dataclasses.dataclass
1373class _RunMessages:
1374 messages: list[_messages.ModelMessage]
1375 used: bool = False
1378_messages_ctx_var: ContextVar[_RunMessages] = ContextVar('var')
1381@contextmanager
1382def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
1383 """Context manager to access the messages used in a [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] call.
1385 Useful when a run may raise an exception, see [model errors](../agents.md#model-errors) for more information.
1387 Examples:
1388 ```python
1389 from pydantic_ai import Agent, capture_run_messages
1391 agent = Agent('test')
1393 with capture_run_messages() as messages:
1394 try:
1395 result = agent.run_sync('foobar')
1396 except Exception:
1397 print(messages)
1398 raise
1399 ```
1401 !!! note
1402 If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
1403 `messages` will represent the messages exchanged during the first call only.
1404 """
1405 try:
1406 yield _messages_ctx_var.get().messages
1407 except LookupError:
1408 messages: list[_messages.ModelMessage] = []
1409 token = _messages_ctx_var.set(_RunMessages(messages))
1410 try:
1411 yield messages
1412 finally:
1413 _messages_ctx_var.reset(token)
1416@dataclasses.dataclass
1417class _MarkFinalResult(Generic[ResultDataT]):
1418 """Marker class to indicate that the result is the final result.
1420 This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly.
1422 It also avoids problems in the case where the result data is itself `None` but has been set.
1423 """
1425 data: ResultDataT
1426 """The final result data."""
1427 tool_name: str | None
1428 """Name of the final result tool, None if the result is a string."""