Coverage for pydantic_ai_slim/pydantic_ai/_agent_graph.py: 98.20%
343 statements
from __future__ import annotations as _annotations

import asyncio
import dataclasses
import json
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
from dataclasses import field
from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast

from opentelemetry.trace import Span, Tracer
from typing_extensions import TypeGuard, TypeVar, assert_never

from pydantic_graph import BaseNode, Graph, GraphRunContext
from pydantic_graph.nodes import End, NodeRunEndT

from . import (
    _result,
    _system_prompt,
    exceptions,
    messages as _messages,
    models,
    result,
    usage as _usage,
)
from .models.instrumented import InstrumentedModel
from .result import ResultDataT
from .settings import ModelSettings, merge_model_settings
from .tools import RunContext, Tool, ToolDefinition

if TYPE_CHECKING:
    from .mcp import MCPServer

__all__ = (
    'GraphAgentState',
    'GraphAgentDeps',
    'UserPromptNode',
    'ModelRequestNode',
    'CallToolsNode',
    'build_run_context',
    'capture_run_messages',
)


T = TypeVar('T')
S = TypeVar('S')
NoneType = type(None)
EndStrategy = Literal['early', 'exhaustive']
"""The strategy for handling multiple tool calls when a final result is found.

- `'early'`: Stop processing other tool calls once a final result is found
- `'exhaustive'`: Process all tool calls even after finding a final result
"""
DepsT = TypeVar('DepsT')
ResultT = TypeVar('ResultT')


@dataclasses.dataclass
class GraphAgentState:
    """State kept across the execution of the agent graph."""

    message_history: list[_messages.ModelMessage]
    usage: _usage.Usage
    retries: int
    run_step: int

    def increment_retries(self, max_result_retries: int) -> None:
        self.retries += 1
        if self.retries > max_result_retries:
            raise exceptions.UnexpectedModelBehavior(
                f'Exceeded maximum retries ({max_result_retries}) for result validation'
            )
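

# A minimal runnable sketch (added for illustration; not part of the original
# module) of the retry bookkeeping above: each call bumps `retries`, and going
# past `max_result_retries` raises `UnexpectedModelBehavior`.
def _example_increment_retries() -> None:
    state = GraphAgentState(message_history=[], usage=_usage.Usage(), retries=0, run_step=0)
    state.increment_retries(max_result_retries=1)  # retries == 1: still allowed
    try:
        state.increment_retries(max_result_retries=1)  # retries == 2 > 1: raises
    except exceptions.UnexpectedModelBehavior as exc:
        print(exc)  # Exceeded maximum retries (1) for result validation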


@dataclasses.dataclass
class GraphAgentDeps(Generic[DepsT, ResultDataT]):
    """Dependencies/config passed to the agent graph."""

    user_deps: DepsT

    prompt: str | Sequence[_messages.UserContent]
    new_message_index: int

    model: models.Model
    model_settings: ModelSettings | None
    usage_limits: _usage.UsageLimits
    max_result_retries: int
    end_strategy: EndStrategy

    result_schema: _result.ResultSchema[ResultDataT] | None
    result_tools: list[ToolDefinition]
    result_validators: list[_result.ResultValidator[DepsT, ResultDataT]]

    function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
    mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)

    run_span: Span
    tracer: Tracer


class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
    """The base class for all agent nodes.

    Using a subclass of `BaseNode` for all nodes reduces the boilerplate needed to spell out the
    generic parameters everywhere.
    """


def is_agent_node(
    node: BaseNode[GraphAgentState, GraphAgentDeps[T, Any], result.FinalResult[S]] | End[result.FinalResult[S]],
) -> TypeGuard[AgentNode[T, S]]:
    """Check if the provided node is an instance of `AgentNode`.

    Usage:

        if is_agent_node(node):
            # `node` is an AgentNode
            ...

    This method preserves the generic parameters on the narrowed type, unlike `isinstance(node, AgentNode)`.
    """
    return isinstance(node, AgentNode)


@dataclasses.dataclass
class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
    user_prompt: str | Sequence[_messages.UserContent]

    system_prompts: tuple[str, ...]
    system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
    system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]

    async def run(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))

    async def _get_first_message(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> _messages.ModelRequest:
        run_context = build_run_context(ctx)
        history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
        ctx.state.message_history = history
        run_context.messages = history

        # TODO: We need to make it so that function_tools are not shared between runs
        # See comment on the current_retry field of `Tool` for more details.
        for tool in ctx.deps.function_tools.values():
            tool.current_retry = 0
        return next_message

    async def _prepare_messages(
        self,
        user_prompt: str | Sequence[_messages.UserContent],
        message_history: list[_messages.ModelMessage] | None,
        run_context: RunContext[DepsT],
    ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
        try:
            ctx_messages = get_captured_run_messages()
        except LookupError:
            messages: list[_messages.ModelMessage] = []
        else:
            if ctx_messages.used:
                messages = []
            else:
                messages = ctx_messages.messages
                ctx_messages.used = True

        if message_history:
            # Shallow copy messages
            messages.extend(message_history)
            # Reevaluate any dynamic system prompt parts
            await self._reevaluate_dynamic_prompts(messages, run_context)
            return messages, _messages.ModelRequest([_messages.UserPromptPart(user_prompt)])
        else:
            parts = await self._sys_parts(run_context)
            parts.append(_messages.UserPromptPart(user_prompt))
            return messages, _messages.ModelRequest(parts)

    async def _reevaluate_dynamic_prompts(
        self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
    ) -> None:
        """Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
        # Only proceed if there's at least one dynamic runner.
        if self.system_prompt_dynamic_functions:
            for msg in messages:
                if isinstance(msg, _messages.ModelRequest):
                    for i, part in enumerate(msg.parts):
                        if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
                            # Look up the runner by its ref
                            if runner := self.system_prompt_dynamic_functions.get(part.dynamic_ref):  # coverage: always true in tests (191 ↛ 188)
                                updated_part_content = await runner.run(run_context)
                                msg.parts[i] = _messages.SystemPromptPart(
                                    updated_part_content, dynamic_ref=part.dynamic_ref
                                )
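
    # Illustrative sketch (added; not part of the original module): dynamic system
    # prompts are registered via the public `Agent` API and re-run by
    # `_reevaluate_dynamic_prompts` above whenever message history is reused:
    #
    #     agent = Agent('test')
    #
    #     @agent.system_prompt(dynamic=True)
    #     def current_user(ctx: RunContext[str]) -> str:
    #         return f'The user is {ctx.deps}.'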

    async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.ModelRequestPart]:
        """Build the initial messages for the conversation."""
        messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self.system_prompts]
        for sys_prompt_runner in self.system_prompt_functions:
            prompt = await sys_prompt_runner.run(run_context)
            if sys_prompt_runner.dynamic:
                messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__))
            else:
                messages.append(_messages.SystemPromptPart(prompt))
        return messages


async def _prepare_request_parameters(
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
) -> models.ModelRequestParameters:
    """Build tools and create an agent model."""
    function_tool_defs: list[ToolDefinition] = []

    run_context = build_run_context(ctx)

    async def add_tool(tool: Tool[DepsT]) -> None:
        ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
        if tool_def := await tool.prepare_tool_def(ctx):
            function_tool_defs.append(tool_def)

    async def add_mcp_server_tools(server: MCPServer) -> None:
        if not server.is_running:
            raise exceptions.UserError(f'MCP server is not running: {server}')
        tool_defs = await server.list_tools()
        # TODO(Marcelo): We should check if the tool names are unique. If not, we should raise an error.
        function_tool_defs.extend(tool_defs)

    await asyncio.gather(
        *map(add_tool, ctx.deps.function_tools.values()),
        *map(add_mcp_server_tools, ctx.deps.mcp_servers),
    )

    result_schema = ctx.deps.result_schema
    return models.ModelRequestParameters(
        function_tools=function_tool_defs,
        allow_text_result=allow_text_result(result_schema),
        result_tools=result_schema.tool_defs() if result_schema is not None else [],
    )
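
# Illustrative sketch (added; not part of the original module): the
# `prepare_tool_def` call above dispatches to the optional `prepare` hook a
# `Tool` can be constructed with; returning `None` omits the tool for that step.
# `my_func` here is a hypothetical tool function:
#
#     async def hide_after_first_step(
#         ctx: RunContext[None], tool_def: ToolDefinition
#     ) -> ToolDefinition | None:
#         return tool_def if ctx.run_step <= 1 else None
#
#     tool = Tool(my_func, prepare=hide_after_first_step)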


@dataclasses.dataclass
class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
    """Make a request to the model using the last message in state.message_history."""

    request: _messages.ModelRequest

    _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
    _did_stream: bool = field(default=False, repr=False)

    async def run(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> CallToolsNode[DepsT, NodeRunEndT]:
        if self._result is not None:
            return self._result

        if self._did_stream:  # coverage: never true in tests (257 ↛ 260)
            # `self._result` gets set when exiting the `stream` contextmanager, so hitting this
            # means that the stream was started but not finished before `run()` was called
            raise exceptions.AgentRunError('You must finish streaming before calling run()')

        return await self._make_request(ctx)

    @asynccontextmanager
    async def stream(
        self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
    ) -> AsyncIterator[result.AgentStream[DepsT, T]]:
        async with self._stream(ctx) as streamed_response:
            agent_stream = result.AgentStream[DepsT, T](
                streamed_response,
                ctx.deps.result_schema,
                ctx.deps.result_validators,
                build_run_context(ctx),
                ctx.deps.usage_limits,
            )
            yield agent_stream
            # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
            # otherwise usage won't be properly counted:
            async for _ in agent_stream:
                pass
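
    # Illustrative sketch (added; not part of the original module): this is the
    # node-level streaming entry point used via the public `Agent.iter` API,
    # roughly:
    #
    #     async with agent.iter('What is the capital of France?') as agent_run:
    #         async for node in agent_run:
    #             if isinstance(node, ModelRequestNode):
    #                 async with node.stream(agent_run.ctx) as agent_stream:
    #                     async for event in agent_stream:
    #                         print(event)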

    @asynccontextmanager
    async def _stream(
        self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
    ) -> AsyncIterator[models.StreamedResponse]:
        assert not self._did_stream, 'stream() should only be called once per node'

        model_settings, model_request_parameters = await self._prepare_request(ctx)
        async with ctx.deps.model.request_stream(
            ctx.state.message_history, model_settings, model_request_parameters
        ) as streamed_response:
            self._did_stream = True
            ctx.state.usage.incr(_usage.Usage(), requests=1)
            yield streamed_response
            # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
            # otherwise usage won't be properly counted:
            async for _ in streamed_response:
                pass
        model_response = streamed_response.get()
        request_usage = streamed_response.usage()

        self._finish_handling(ctx, model_response, request_usage)
        assert self._result is not None  # this should be set by the previous line

    async def _make_request(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> CallToolsNode[DepsT, NodeRunEndT]:
        if self._result is not None:  # coverage: never true in tests (310 ↛ 311)
            return self._result

        model_settings, model_request_parameters = await self._prepare_request(ctx)
        model_response, request_usage = await ctx.deps.model.request(
            ctx.state.message_history, model_settings, model_request_parameters
        )
        ctx.state.usage.incr(_usage.Usage(), requests=1)

        return self._finish_handling(ctx, model_response, request_usage)

    async def _prepare_request(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> tuple[ModelSettings | None, models.ModelRequestParameters]:
        ctx.state.message_history.append(self.request)

        # Check usage
        if ctx.deps.usage_limits:  # coverage: always true in tests (327 ↛ 331)
            ctx.deps.usage_limits.check_before_request(ctx.state.usage)

        # Increment run_step
        ctx.state.run_step += 1

        model_settings = merge_model_settings(ctx.deps.model_settings, None)
        model_request_parameters = await _prepare_request_parameters(ctx)
        return model_settings, model_request_parameters

    def _finish_handling(
        self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
        response: _messages.ModelResponse,
        usage: _usage.Usage,
    ) -> CallToolsNode[DepsT, NodeRunEndT]:
        # Update usage
        ctx.state.usage.incr(usage, requests=0)
        if ctx.deps.usage_limits:  # coverage: always true in tests (345 ↛ 349)
            ctx.deps.usage_limits.check_tokens(ctx.state.usage)

        # Append the model response to state.message_history
        ctx.state.message_history.append(response)

        # Set the `_result` attribute since we can't use `return` in an async iterator
        self._result = CallToolsNode(response)

        return self._result


@dataclasses.dataclass
class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
    """Process a model response, and decide whether to end the run or make a new request."""

    model_response: _messages.ModelResponse

    _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False)
    _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
        default=None, repr=False
    )
    _tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False)

    async def run(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], End[result.FinalResult[NodeRunEndT]]]:  # noqa UP007
        async with self.stream(ctx):
            pass

        assert (next_node := self._next_node) is not None, 'the stream should set `self._next_node` before it ends'
        return next_node

    @asynccontextmanager
    async def stream(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]:
        """Process the model response and yield events for the start and end of each function tool call."""
        stream = self._run_stream(ctx)
        yield stream

        # Run the stream to completion if it was not finished:
        async for _event in stream:
            pass

    async def _run_stream(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> AsyncIterator[_messages.HandleResponseEvent]:
        if self._events_iterator is None:
            # Ensure that the stream is only run once

            async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
                texts: list[str] = []
                tool_calls: list[_messages.ToolCallPart] = []
                for part in self.model_response.parts:
                    if isinstance(part, _messages.TextPart):
                        # ignore empty content for text parts, see #437
                        if part.content:
                            texts.append(part.content)
                    elif isinstance(part, _messages.ToolCallPart):
                        tool_calls.append(part)
                    else:
                        assert_never(part)

                # At the moment, we prioritize at least executing tool calls if they are present.
                # In the future, we'd consider making this configurable at the agent or run level.
                # This accounts for cases like Anthropic responses that might contain both a text part
                # and a tool call part, where the text part just indicates that the tool call will happen.
                if tool_calls:
                    async for event in self._handle_tool_calls(ctx, tool_calls):
                        yield event
                elif texts:
                    # No events are emitted during the handling of text responses, so we don't need to yield anything
                    self._next_node = await self._handle_text_response(ctx, texts)
                else:
                    raise exceptions.UnexpectedModelBehavior('Received empty model response')

            self._events_iterator = _run_stream()

        async for event in self._events_iterator:
            yield event
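
    # Illustrative note (added; not part of the original module): given a response
    # with both a text part and a tool call, e.g.
    #
    #     ModelResponse(parts=[
    #         TextPart(content="I'll look that up."),
    #         ToolCallPart(tool_name='search', args={'query': 'x'}, tool_call_id='call_1'),
    #     ])
    #
    # `_run_stream` above takes the `tool_calls` branch, so the text is not treated
    # as a final result.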

    async def _handle_tool_calls(
        self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
        tool_calls: list[_messages.ToolCallPart],
    ) -> AsyncIterator[_messages.HandleResponseEvent]:
        result_schema = ctx.deps.result_schema

        # first look for the result tool call
        final_result: result.FinalResult[NodeRunEndT] | None = None
        parts: list[_messages.ModelRequestPart] = []
        if result_schema is not None:
            for call, result_tool in result_schema.find_tool(tool_calls):
                try:
                    result_data = result_tool.validate(call)
                    result_data = await _validate_result(result_data, ctx, call)
                except _result.ToolRetryError as e:
                    # TODO: Should only increment retry stuff once per node execution, not for each tool call
                    # Also, should increment the tool-specific retry count rather than the run retry count
                    ctx.state.increment_retries(ctx.deps.max_result_retries)
                    parts.append(e.tool_retry)
                else:
                    final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
                    break

        # Then build the other request parts based on end strategy
        tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
        async for event in process_function_tools(
            tool_calls,
            final_result and final_result.tool_name,
            final_result and final_result.tool_call_id,
            ctx,
            tool_responses,
        ):
            yield event

        if final_result:
            self._next_node = self._handle_final_result(ctx, final_result, tool_responses)
        else:
            if tool_responses:
                parts.extend(tool_responses)
            self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts))

    def _handle_final_result(
        self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
        final_result: result.FinalResult[NodeRunEndT],
        tool_responses: list[_messages.ModelRequestPart],
    ) -> End[result.FinalResult[NodeRunEndT]]:
        run_span = ctx.deps.run_span
        usage = ctx.state.usage
        messages = ctx.state.message_history

        # For backwards compatibility, append a new ModelRequest using the tool returns and retries
        if tool_responses:
            messages.append(_messages.ModelRequest(parts=tool_responses))

        run_span.set_attributes(
            {
                **usage.opentelemetry_attributes(),
                'all_messages_events': json.dumps(
                    [InstrumentedModel.event_to_dict(e) for e in InstrumentedModel.messages_to_otel_events(messages)]
                ),
                'final_result': final_result.data
                if isinstance(final_result.data, str)
                else json.dumps(InstrumentedModel.serialize_any(final_result.data)),
            }
        )
        run_span.set_attributes(
            {
                'logfire.json_schema': json.dumps(
                    {
                        'type': 'object',
                        'properties': {
                            'all_messages_events': {'type': 'array'},
                            'final_result': {'type': 'object'},
                        },
                    }
                ),
            }
        )

        # End the run with self.data
        return End(final_result)

    async def _handle_text_response(
        self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
        texts: list[str],
    ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
        result_schema = ctx.deps.result_schema

        text = '\n\n'.join(texts)
        if allow_text_result(result_schema):
            result_data_input = cast(NodeRunEndT, text)
            try:
                result_data = await _validate_result(result_data_input, ctx, None)
            except _result.ToolRetryError as e:
                ctx.state.increment_retries(ctx.deps.max_result_retries)
                return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
            else:
                # The following cast is safe because we know `str` is an allowed result type
                return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
        else:
            ctx.state.increment_retries(ctx.deps.max_result_retries)
            return ModelRequestNode[DepsT, NodeRunEndT](
                _messages.ModelRequest(
                    parts=[
                        _messages.RetryPromptPart(
                            content='Plain text responses are not permitted, please call one of the functions instead.',
                        )
                    ]
                )
            )


def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
    """Build a `RunContext` object from the current agent graph run context."""
    return RunContext[DepsT](
        deps=ctx.deps.user_deps,
        model=ctx.deps.model,
        usage=ctx.state.usage,
        prompt=ctx.deps.prompt,
        messages=ctx.state.message_history,
        run_step=ctx.state.run_step,
    )


async def process_function_tools(
    tool_calls: list[_messages.ToolCallPart],
    result_tool_name: str | None,
    result_tool_call_id: str | None,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
    output_parts: list[_messages.ModelRequestPart],
) -> AsyncIterator[_messages.HandleResponseEvent]:
    """Process function (i.e., non-result) tool calls in parallel.

    Also add stub return parts for any other tools that need it.

    Because async iterators can't have return values, we use `output_parts` as an output argument.
    """
    stub_function_tools = bool(result_tool_name) and ctx.deps.end_strategy == 'early'
    result_schema = ctx.deps.result_schema

    # we rely on the fact that if we found a result, it's the first result tool in the list
    found_used_result_tool = False
    run_context = build_run_context(ctx)

    calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
    call_index_to_event_id: dict[int, str] = {}
    for call in tool_calls:
        if (
            call.tool_name == result_tool_name
            and call.tool_call_id == result_tool_call_id
            and not found_used_result_tool
        ):
            found_used_result_tool = True
            output_parts.append(
                _messages.ToolReturnPart(
                    tool_name=call.tool_name,
                    content='Final result processed.',
                    tool_call_id=call.tool_call_id,
                )
            )
        elif tool := ctx.deps.function_tools.get(call.tool_name):
            if stub_function_tools:
                output_parts.append(
                    _messages.ToolReturnPart(
                        tool_name=call.tool_name,
                        content='Tool not executed - a final result was already processed.',
                        tool_call_id=call.tool_call_id,
                    )
                )
            else:
                event = _messages.FunctionToolCallEvent(call)
                yield event
                call_index_to_event_id[len(calls_to_run)] = event.call_id
                calls_to_run.append((tool, call))
        elif mcp_tool := await _tool_from_mcp_server(call.tool_name, ctx):
            if stub_function_tools:
                # TODO(Marcelo): We should add coverage for this part of the code.
                output_parts.append(  # pragma: no cover
                    _messages.ToolReturnPart(
                        tool_name=call.tool_name,
                        content='Tool not executed - a final result was already processed.',
                        tool_call_id=call.tool_call_id,
                    )
                )
            else:
                event = _messages.FunctionToolCallEvent(call)
                yield event
                call_index_to_event_id[len(calls_to_run)] = event.call_id
                calls_to_run.append((mcp_tool, call))
        elif result_schema is not None and call.tool_name in result_schema.tools:
            # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
            # validation, we don't add another part here
            if result_tool_name is not None:
                if found_used_result_tool:
                    content = 'Result tool not used - a final result was already processed.'
                else:
                    # TODO: Include information about the validation failure, and/or merge this with the ModelRetry part
                    content = 'Result tool not used - result failed validation.'
                part = _messages.ToolReturnPart(
                    tool_name=call.tool_name,
                    content=content,
                    tool_call_id=call.tool_call_id,
                )
                output_parts.append(part)
        else:
            output_parts.append(_unknown_tool(call.tool_name, ctx))

    if not calls_to_run:
        return

    # Run all tool tasks in parallel
    results_by_index: dict[int, _messages.ModelRequestPart] = {}
    with ctx.deps.tracer.start_as_current_span(
        'running tools',
        attributes={
            'tools': [call.tool_name for _, call in calls_to_run],
            'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}',
        },
    ):
        tasks = [
            asyncio.create_task(tool.run(call, run_context, ctx.deps.tracer), name=call.tool_name)
            for tool, call in calls_to_run
        ]
        pending = tasks
        while pending:
            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                index = tasks.index(task)
                result = task.result()
                yield _messages.FunctionToolResultEvent(result, tool_call_id=call_index_to_event_id[index])
                if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):
                    results_by_index[index] = result
                else:
                    assert_never(result)

    # We append the results at the end, rather than as they are received, to retain a consistent ordering
    # This is mostly just to simplify testing
    for k in sorted(results_by_index):
        output_parts.append(results_by_index[k])
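

# A minimal runnable sketch (added for illustration; not part of the original
# module) of the scheduling pattern used above: tasks run concurrently, an event
# is yielded as each one finishes, but the collected results are re-ordered by
# their original index at the end so output ordering stays deterministic.
async def _example_first_completed_ordering() -> None:
    import random

    async def work(i: int) -> int:
        await asyncio.sleep(random.random() / 100)  # finish in arbitrary order
        return i

    tasks = [asyncio.create_task(work(i)) for i in range(3)]
    results_by_index: dict[int, int] = {}
    pending: set[asyncio.Task[int]] = set(tasks)
    while pending:
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            results_by_index[tasks.index(task)] = task.result()
    print([results_by_index[k] for k in sorted(results_by_index)])  # [0, 1, 2]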


async def _tool_from_mcp_server(
    tool_name: str,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
) -> Tool[DepsT] | None:
    """Call each MCP server to find the tool with the given name.

    Args:
        tool_name: The name of the tool to find.
        ctx: The current run context.

    Returns:
        The tool with the given name, or `None` if no tool with the given name is found.
    """

    async def run_tool(ctx: RunContext[DepsT], **args: Any) -> Any:
        # There's no normal situation where the server will not be running at this point, we check just in case
        # some weird edge case occurs.
        if not server.is_running:  # pragma: no cover
            raise exceptions.UserError(f'MCP server is not running: {server}')
        result = await server.call_tool(tool_name, args)
        return result

    for server in ctx.deps.mcp_servers:
        tools = await server.list_tools()
        if tool_name in {tool.name for tool in tools}:  # coverage: always true in tests (695 ↛ 693)
            return Tool(name=tool_name, function=run_tool, takes_ctx=True)
    return None


def _unknown_tool(
    tool_name: str,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
) -> _messages.RetryPromptPart:
    ctx.state.increment_retries(ctx.deps.max_result_retries)
    tool_names = list(ctx.deps.function_tools.keys())
    if result_schema := ctx.deps.result_schema:
        tool_names.extend(result_schema.tool_names())

    if tool_names:
        msg = f'Available tools: {", ".join(tool_names)}'
    else:
        msg = 'No tools available.'

    return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')


async def _validate_result(
    result_data: T,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
    tool_call: _messages.ToolCallPart | None,
) -> T:
    for validator in ctx.deps.result_validators:
        run_context = build_run_context(ctx)
        result_data = await validator.validate(result_data, tool_call, run_context)
    return result_data


def allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool:
    """Check if the result schema allows text results."""
    return result_schema is None or result_schema.allow_text_result


@dataclasses.dataclass
class _RunMessages:
    messages: list[_messages.ModelMessage]
    used: bool = False


_messages_ctx_var: ContextVar[_RunMessages] = ContextVar('var')


@contextmanager
def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
    """Context manager to access the messages used in a [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] call.

    Useful when a run may raise an exception, see [model errors](../agents.md#model-errors) for more information.

    Examples:
    ```python
    from pydantic_ai import Agent, capture_run_messages

    agent = Agent('test')

    with capture_run_messages() as messages:
        try:
            result = agent.run_sync('foobar')
        except Exception:
            print(messages)
            raise
    ```

    !!! note
        If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
        `messages` will represent the messages exchanged during the first call only.
    """
    try:
        yield _messages_ctx_var.get().messages
    except LookupError:
        messages: list[_messages.ModelMessage] = []
        token = _messages_ctx_var.set(_RunMessages(messages))
        try:
            yield messages
        finally:
            _messages_ctx_var.reset(token)


def get_captured_run_messages() -> _RunMessages:
    return _messages_ctx_var.get()


def build_agent_graph(
    name: str | None, deps_type: type[DepsT], result_type: type[ResultT]
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[ResultT]], result.FinalResult[ResultT]]:
    """Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
    nodes = (
        UserPromptNode[DepsT],
        ModelRequestNode[DepsT],
        CallToolsNode[DepsT],
    )
    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]](
        nodes=nodes,
        name=name or 'Agent',
        state_type=GraphAgentState,
        run_end_type=result.FinalResult[result_type],
        auto_instrument=False,
    )
    return graph