Coverage for pydantic_ai_slim/pydantic_ai/_agent_graph.py: 98.57%
386 statements
coverage.py v7.6.7, created at 2025-01-30 19:21 +0000
from __future__ import annotations as _annotations

import asyncio
import dataclasses
from abc import ABC
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
from dataclasses import field
from typing import Any, Generic, Literal, Union, cast

import logfire_api
from typing_extensions import TypeVar, assert_never

from pydantic_graph import BaseNode, Graph, GraphRunContext
from pydantic_graph.nodes import End, NodeRunEndT

from . import (
    _result,
    _system_prompt,
    exceptions,
    messages as _messages,
    models,
    result,
    usage as _usage,
)
from .result import ResultDataT
from .settings import ModelSettings, merge_model_settings
from .tools import (
    RunContext,
    Tool,
    ToolDefinition,
)
_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

# while waiting for https://github.com/pydantic/logfire/issues/745
try:
    import logfire._internal.stack_info
except ImportError:
    pass
else:
    from pathlib import Path

    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
T = TypeVar('T')
NoneType = type(None)
EndStrategy = Literal['early', 'exhaustive']
"""The strategy for handling multiple tool calls when a final result is found.

- `'early'`: Stop processing other tool calls once a final result is found
- `'exhaustive'`: Process all tool calls even after finding a final result
"""
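# Illustrative sketch (not part of this module; the `Agent(...)` call shape is an assumption
# about the public API): how the two strategies differ when a model response contains a result
# tool call plus additional function tool calls.
#
#   agent = Agent('test', end_strategy='early')       # remaining tool calls are stubbed out
#   agent = Agent('test', end_strategy='exhaustive')  # remaining tool calls are still executed
#
# With 'early', `_process_function_tools` below replaces the unexecuted calls with
# 'Tool not executed - a final result was already processed.' stub returns.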
DepsT = TypeVar('DepsT')
ResultT = TypeVar('ResultT')
@dataclasses.dataclass
class MarkFinalResult(Generic[ResultDataT]):
    """Marker class to indicate that the result is the final result.

    This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly.

    It also avoids problems in the case where the result type is itself `None`, but is set.
    """

    data: ResultDataT
    """The final result data."""
    tool_name: str | None
    """Name of the final result tool, None if the result is a string."""
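# Illustrative sketch (assumption): because the final result is wrapped rather than returned
# directly, callers can distinguish "the run finished" from "the result data happens to be None":
#
#   if isinstance(node_output, MarkFinalResult):
#       final_data = node_output.data  # may legitimately be None
#
# `node_output` is a hypothetical name used for illustration only.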
@dataclasses.dataclass
class GraphAgentState:
    """State kept across the execution of the agent graph."""

    message_history: list[_messages.ModelMessage]
    usage: _usage.Usage
    retries: int
    run_step: int

    def increment_retries(self, max_result_retries: int) -> None:
        self.retries += 1
        if self.retries > max_result_retries:
            raise exceptions.UnexpectedModelBehavior(
                f'Exceeded maximum retries ({max_result_retries}) for result validation'
            )
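# Illustrative note (derived from the method above): with max_result_retries=1, the second
# failed result validation raises
# UnexpectedModelBehavior('Exceeded maximum retries (1) for result validation').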
@dataclasses.dataclass
class GraphAgentDeps(Generic[DepsT, ResultDataT]):
    """Dependencies/config passed to the agent graph."""

    user_deps: DepsT

    prompt: str
    new_message_index: int

    model: models.Model
    model_settings: ModelSettings | None
    usage_limits: _usage.UsageLimits
    max_result_retries: int
    end_strategy: EndStrategy

    result_schema: _result.ResultSchema[ResultDataT] | None
    result_tools: list[ToolDefinition]
    result_validators: list[_result.ResultValidator[DepsT, ResultDataT]]

    function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)

    run_span: logfire_api.LogfireSpan
@dataclasses.dataclass
class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
    user_prompt: str

    system_prompts: tuple[str, ...]
    system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
    system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]

    async def _get_first_message(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
    ) -> _messages.ModelRequest:
        run_context = _build_run_context(ctx)
        history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
        ctx.state.message_history = history
        run_context.messages = history

        # TODO: We need to make it so that function_tools are not shared between runs
        # See comment on the current_retry field of `Tool` for more details.
        for tool in ctx.deps.function_tools.values():
            tool.current_retry = 0
        return next_message
    async def _prepare_messages(
        self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[DepsT]
    ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
        try:
            ctx_messages = get_captured_run_messages()
        except LookupError:
            messages: list[_messages.ModelMessage] = []
        else:
            if ctx_messages.used:
                messages = []
            else:
                messages = ctx_messages.messages
                ctx_messages.used = True

        if message_history:
            # Shallow copy messages
            messages.extend(message_history)
            # Reevaluate any dynamic system prompt parts
            await self._reevaluate_dynamic_prompts(messages, run_context)
            return messages, _messages.ModelRequest([_messages.UserPromptPart(user_prompt)])
        else:
            parts = await self._sys_parts(run_context)
            parts.append(_messages.UserPromptPart(user_prompt))
            return messages, _messages.ModelRequest(parts)
    async def _reevaluate_dynamic_prompts(
        self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
    ) -> None:
        """Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
        # Only proceed if there's at least one dynamic runner.
        if self.system_prompt_dynamic_functions:
            for msg in messages:
                if isinstance(msg, _messages.ModelRequest):
                    for i, part in enumerate(msg.parts):
                        if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
                            # Look up the runner by its ref
                            if runner := self.system_prompt_dynamic_functions.get(part.dynamic_ref):  # coverage: condition was always true in tests
                                updated_part_content = await runner.run(run_context)
                                msg.parts[i] = _messages.SystemPromptPart(
                                    updated_part_content, dynamic_ref=part.dynamic_ref
                                )
    async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.ModelRequestPart]:
        """Build the initial messages for the conversation."""
        messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self.system_prompts]
        for sys_prompt_runner in self.system_prompt_functions:
            prompt = await sys_prompt_runner.run(run_context)
            if sys_prompt_runner.dynamic:
                messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__))
            else:
                messages.append(_messages.SystemPromptPart(prompt))
        return messages
@dataclasses.dataclass
class UserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
    async def run(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))


@dataclasses.dataclass
class StreamUserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
    async def run(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
    ) -> StreamModelRequestNode[DepsT, NodeRunEndT]:
        return StreamModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
async def _prepare_model(
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
) -> models.AgentModel:
    """Build tools and create an agent model."""
    function_tool_defs: list[ToolDefinition] = []

    run_context = _build_run_context(ctx)

    async def add_tool(tool: Tool[DepsT]) -> None:
        ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
        if tool_def := await tool.prepare_tool_def(ctx):
            function_tool_defs.append(tool_def)

    await asyncio.gather(*map(add_tool, ctx.deps.function_tools.values()))

    result_schema = ctx.deps.result_schema
    return await run_context.model.agent_model(
        function_tools=function_tool_defs,
        allow_text_result=_allow_text_result(result_schema),
        result_tools=result_schema.tool_defs() if result_schema is not None else [],
    )
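# Illustrative sketch (an assumption about the shape of a tool `prepare` hook): a prepare
# function may return None to omit its tool from this request, which is why `add_tool`
# above only appends truthy tool definitions.
#
#   async def only_if_admin(ctx: RunContext[MyDeps], tool_def: ToolDefinition) -> ToolDefinition | None:
#       return tool_def if ctx.deps.is_admin else None
#
# `MyDeps` and `only_if_admin` are hypothetical names used for illustration only.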
@dataclasses.dataclass
class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
    """Make a request to the model using the last message in state.message_history."""

    request: _messages.ModelRequest

    async def run(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> HandleResponseNode[DepsT, NodeRunEndT]:
        ctx.state.message_history.append(self.request)

        # Check usage
        if ctx.deps.usage_limits:  # coverage: condition was always true in tests
            ctx.deps.usage_limits.check_before_request(ctx.state.usage)

        # Increment run_step
        ctx.state.run_step += 1

        with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step):
            agent_model = await _prepare_model(ctx)

        # Actually make the model request
        model_settings = merge_model_settings(ctx.deps.model_settings, None)
        with _logfire.span('model request') as span:
            model_response, request_usage = await agent_model.request(ctx.state.message_history, model_settings)
            span.set_attribute('response', model_response)
            span.set_attribute('usage', request_usage)

        # Update usage
        ctx.state.usage.incr(request_usage, requests=1)
        if ctx.deps.usage_limits:  # coverage: condition was always true in tests
            ctx.deps.usage_limits.check_tokens(ctx.state.usage)

        # Append the model response to state.message_history
        ctx.state.message_history.append(model_response)
        return HandleResponseNode(model_response)
@dataclasses.dataclass
class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
    """Process a response from the model, and decide whether to end the run or make a new request."""

    model_response: _messages.ModelResponse

    async def run(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], FinalResultNode[DepsT, NodeRunEndT]]:  # noqa UP007
        with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span:
            texts: list[str] = []
            tool_calls: list[_messages.ToolCallPart] = []
            for part in self.model_response.parts:
                if isinstance(part, _messages.TextPart):
                    # ignore empty content for text parts, see #437
                    if part.content:
                        texts.append(part.content)
                elif isinstance(part, _messages.ToolCallPart):
                    tool_calls.append(part)
                else:
                    assert_never(part)

            # At the moment, we prioritize at least executing tool calls if they are present.
            # In the future, we'd consider making this configurable at the agent or run level.
            # This accounts for cases like Anthropic responses that contain both a text part and a
            # tool call part, where the text part just indicates that the tool call will happen.
            if tool_calls:
                return await self._handle_tool_calls_response(ctx, tool_calls, handle_span)
            elif texts:
                return await self._handle_text_response(ctx, texts, handle_span)
            else:
                raise exceptions.UnexpectedModelBehavior('Received empty model response')
    async def _handle_tool_calls_response(
        self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
        tool_calls: list[_messages.ToolCallPart],
        handle_span: logfire_api.LogfireSpan,
    ):
        result_schema = ctx.deps.result_schema

        # first look for the result tool call
        final_result: MarkFinalResult[NodeRunEndT] | None = None
        parts: list[_messages.ModelRequestPart] = []
        if result_schema is not None:
            if match := result_schema.find_tool(tool_calls):
                call, result_tool = match
                try:
                    result_data = result_tool.validate(call)
                    result_data = await _validate_result(result_data, ctx, call)
                except _result.ToolRetryError as e:
                    # TODO: Should only increment retry stuff once per node execution, not for each tool call
                    # Also, should increment the tool-specific retry count rather than the run retry count
                    ctx.state.increment_retries(ctx.deps.max_result_retries)
                    parts.append(e.tool_retry)
                else:
                    final_result = MarkFinalResult(result_data, call.tool_name)

        # Then build the other request parts based on end strategy
        tool_responses = await _process_function_tools(tool_calls, final_result and final_result.tool_name, ctx)

        if final_result:
            handle_span.set_attribute('result', final_result.data)
            handle_span.message = 'handle model response -> final result'
            return FinalResultNode[DepsT, NodeRunEndT](final_result, tool_responses)
        else:
            if tool_responses:
                handle_span.set_attribute('tool_responses', tool_responses)
                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
                handle_span.message = f'handle model response -> {tool_responses_str}'
            parts.extend(tool_responses)
            return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts))
    async def _handle_text_response(
        self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
        texts: list[str],
        handle_span: logfire_api.LogfireSpan,
    ):
        result_schema = ctx.deps.result_schema

        text = '\n\n'.join(texts)
        if _allow_text_result(result_schema):
            result_data_input = cast(NodeRunEndT, text)
            try:
                result_data = await _validate_result(result_data_input, ctx, None)
            except _result.ToolRetryError as e:
                ctx.state.increment_retries(ctx.deps.max_result_retries)
                return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
            else:
                handle_span.set_attribute('result', result_data)
                handle_span.message = 'handle model response -> final result'
                return FinalResultNode[DepsT, NodeRunEndT](MarkFinalResult(result_data, None))
        else:
            ctx.state.increment_retries(ctx.deps.max_result_retries)
            return ModelRequestNode[DepsT, NodeRunEndT](
                _messages.ModelRequest(
                    parts=[
                        _messages.RetryPromptPart(
                            content='Plain text responses are not permitted, please call one of the functions instead.',
                        )
                    ]
                )
            )
@dataclasses.dataclass
class StreamModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
    """Make a request to the model using the last message in state.message_history (or a specified request)."""

    request: _messages.ModelRequest
    _result: StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]] | None = (
        field(default=None, repr=False)
    )

    async def run(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> Union[StreamModelRequestNode[DepsT, NodeRunEndT], End[result.StreamedRunResult[DepsT, NodeRunEndT]]]:  # noqa UP007
        if self._result is not None:  # coverage: condition was always true in tests
            return self._result

        async with self.run_to_result(ctx) as final_node:
            return final_node
    @asynccontextmanager
    async def run_to_result(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> AsyncIterator[StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]]]:
        result_schema = ctx.deps.result_schema

        ctx.state.message_history.append(self.request)

        # Check usage
        if ctx.deps.usage_limits:  # coverage: condition was always true in tests
            ctx.deps.usage_limits.check_before_request(ctx.state.usage)

        # Increment run_step
        ctx.state.run_step += 1

        with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step):
            agent_model = await _prepare_model(ctx)

        # Actually make the model request
        model_settings = merge_model_settings(ctx.deps.model_settings, None)
        with _logfire.span('model request {run_step=}', run_step=ctx.state.run_step) as model_req_span:
            async with agent_model.request_stream(ctx.state.message_history, model_settings) as streamed_response:
                ctx.state.usage.requests += 1
                model_req_span.set_attribute('response_type', streamed_response.__class__.__name__)
                # We want to end the "model request" span here, but we can't exit the context manager
                # in the traditional way
                model_req_span.__exit__(None, None, None)
                with _logfire.span('handle model response') as handle_span:
                    received_text = False

                    async for maybe_part_event in streamed_response:
                        if isinstance(maybe_part_event, _messages.PartStartEvent):
                            new_part = maybe_part_event.part
                            if isinstance(new_part, _messages.TextPart):
                                received_text = True
                                if _allow_text_result(result_schema):
                                    handle_span.message = 'handle model response -> final result'
                                    streamed_run_result = _build_streamed_run_result(streamed_response, None, ctx)
                                    self._result = End(streamed_run_result)
                                    yield self._result
                                    return
                            elif isinstance(new_part, _messages.ToolCallPart):
                                if result_schema is not None and (match := result_schema.find_tool([new_part])):
                                    call, _ = match
                                    handle_span.message = 'handle model response -> final result'
                                    streamed_run_result = _build_streamed_run_result(
                                        streamed_response, call.tool_name, ctx
                                    )
                                    self._result = End(streamed_run_result)
                                    yield self._result
                                    return
                            else:
                                assert_never(new_part)
                    tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
                    parts: list[_messages.ModelRequestPart] = []
                    model_response = streamed_response.get()
                    if not model_response.parts:
                        raise exceptions.UnexpectedModelBehavior('Received empty model response')
                    ctx.state.message_history.append(model_response)

                    run_context = _build_run_context(ctx)
                    for p in model_response.parts:
                        if isinstance(p, _messages.ToolCallPart):
                            if tool := ctx.deps.function_tools.get(p.tool_name):
                                tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name))
                            else:
                                parts.append(_unknown_tool(p.tool_name, ctx))

                    if received_text and not tasks and not parts:
                        # Can only get here if `_allow_text_result` returned `False` for the provided result_schema
                        ctx.state.increment_retries(ctx.deps.max_result_retries)
                        self._result = StreamModelRequestNode[DepsT, NodeRunEndT](
                            _messages.ModelRequest(
                                parts=[
                                    _messages.RetryPromptPart(
                                        content='Plain text responses are not permitted, please call one of the functions instead.',
                                    )
                                ]
                            )
                        )
                        yield self._result
                        return
                    with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
                        task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
                        parts.extend(task_results)

                    next_request = _messages.ModelRequest(parts=parts)
                    if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
                        try:
                            ctx.state.increment_retries(ctx.deps.max_result_retries)
                        except:
                            # TODO: This is janky, so I think we should probably change it, but how?
                            ctx.state.message_history.append(next_request)
                            raise

                    handle_span.set_attribute('tool_responses', parts)
                    tool_responses_str = ' '.join(r.part_kind for r in parts)
                    handle_span.message = f'handle model response -> {tool_responses_str}'
                    # the model_response should have been fully streamed by now, we can add its usage
                    streamed_response_usage = streamed_response.usage()
                    run_context.usage.incr(streamed_response_usage)
                    ctx.deps.usage_limits.check_tokens(run_context.usage)
                    self._result = StreamModelRequestNode[DepsT, NodeRunEndT](next_request)
                    yield self._result
                    return
@dataclasses.dataclass
class FinalResultNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[NodeRunEndT]]):
    """Produce the final result of the run."""

    data: MarkFinalResult[NodeRunEndT]
    """The final result data."""
    extra_parts: list[_messages.ModelRequestPart] = dataclasses.field(default_factory=list)

    async def run(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> End[MarkFinalResult[NodeRunEndT]]:
        run_span = ctx.deps.run_span
        usage = ctx.state.usage
        messages = ctx.state.message_history

        # TODO: For backwards compatibility, append a new ModelRequest using the tool returns and retries
        if self.extra_parts:
            messages.append(_messages.ModelRequest(parts=self.extra_parts))

        # TODO: Set this attribute somewhere
        # handle_span = self.handle_model_response_span
        # handle_span.set_attribute('final_data', self.data)
        run_span.set_attribute('usage', usage)
        run_span.set_attribute('all_messages', messages)

        # End the run with self.data
        return End(self.data)
def _build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
    return RunContext[DepsT](
        deps=ctx.deps.user_deps,
        model=ctx.deps.model,
        usage=ctx.state.usage,
        prompt=ctx.deps.prompt,
        messages=ctx.state.message_history,
        run_step=ctx.state.run_step,
    )
def _build_streamed_run_result(
    result_stream: models.StreamedResponse,
    result_tool_name: str | None,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
) -> result.StreamedRunResult[DepsT, NodeRunEndT]:
    new_message_index = ctx.deps.new_message_index
    result_schema = ctx.deps.result_schema
    run_span = ctx.deps.run_span
    usage_limits = ctx.deps.usage_limits
    messages = ctx.state.message_history
    run_context = _build_run_context(ctx)

    async def on_complete():
        """Called when the stream has completed.

        The model response will have been added to messages by now
        by `StreamedRunResult._marked_completed`.
        """
        last_message = messages[-1]
        assert isinstance(last_message, _messages.ModelResponse)
        tool_calls = [part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)]
        parts = await _process_function_tools(
            tool_calls,
            result_tool_name,
            ctx,
        )
        # TODO: Should we do something here related to the retry count?
        # Maybe we should move the incrementing of the retry count to where we actually make a request?
        # if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
        #     ctx.state.increment_retries(ctx.deps.max_result_retries)
        if parts:
            messages.append(_messages.ModelRequest(parts))
        run_span.set_attribute('all_messages', messages)

    return result.StreamedRunResult[DepsT, NodeRunEndT](
        messages,
        new_message_index,
        usage_limits,
        result_stream,
        result_schema,
        run_context,
        ctx.deps.result_validators,
        result_tool_name,
        on_complete,
    )
async def _process_function_tools(
    tool_calls: list[_messages.ToolCallPart],
    result_tool_name: str | None,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
) -> list[_messages.ModelRequestPart]:
    """Process function (non-result) tool calls in parallel.

    Also add stub return parts for any other tools that need it.
    """
    parts: list[_messages.ModelRequestPart] = []
    tasks: list[asyncio.Task[_messages.ToolReturnPart | _messages.RetryPromptPart]] = []

    stub_function_tools = bool(result_tool_name) and ctx.deps.end_strategy == 'early'
    result_schema = ctx.deps.result_schema

    # we rely on the fact that if we found a result, it's the first result tool call in the last message
    found_used_result_tool = False
    run_context = _build_run_context(ctx)

    for call in tool_calls:
        if call.tool_name == result_tool_name and not found_used_result_tool:
            found_used_result_tool = True
            parts.append(
                _messages.ToolReturnPart(
                    tool_name=call.tool_name,
                    content='Final result processed.',
                    tool_call_id=call.tool_call_id,
                )
            )
        elif tool := ctx.deps.function_tools.get(call.tool_name):
            if stub_function_tools:
                parts.append(
                    _messages.ToolReturnPart(
                        tool_name=call.tool_name,
                        content='Tool not executed - a final result was already processed.',
                        tool_call_id=call.tool_call_id,
                    )
                )
            else:
                tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
        elif result_schema is not None and call.tool_name in result_schema.tools:
            # if tool_name is in result_schema.tools, it means we found a result tool but an error
            # occurred in validation, so we don't add another part here
            if result_tool_name is not None:
                parts.append(
                    _messages.ToolReturnPart(
                        tool_name=call.tool_name,
                        content='Result tool not used - a final result was already processed.',
                        tool_call_id=call.tool_call_id,
                    )
                )
        else:
            parts.append(_unknown_tool(call.tool_name, ctx))

    # Run all tool tasks in parallel
    if tasks:
        with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
            task_results: Sequence[_messages.ToolReturnPart | _messages.RetryPromptPart] = await asyncio.gather(*tasks)
            for result in task_results:
                if isinstance(result, _messages.ToolReturnPart):
                    parts.append(result)
                elif isinstance(result, _messages.RetryPromptPart):
                    parts.append(result)
                else:
                    assert_never(result)
    return parts
def _unknown_tool(
    tool_name: str,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
) -> _messages.RetryPromptPart:
    ctx.state.increment_retries(ctx.deps.max_result_retries)
    tool_names = list(ctx.deps.function_tools.keys())
    if result_schema := ctx.deps.result_schema:
        tool_names.extend(result_schema.tool_names())

    if tool_names:
        msg = f'Available tools: {", ".join(tool_names)}'
    else:
        msg = 'No tools available.'

    return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')
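# Illustrative example (hypothetical tool names): with function tools 'get_weather' and
# 'get_time' registered, a call to a misspelled tool yields a retry prompt roughly like:
#   Unknown tool name: 'get_wether'. Available tools: get_weather, get_time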
async def _validate_result(
    result_data: T,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
    tool_call: _messages.ToolCallPart | None,
) -> T:
    for validator in ctx.deps.result_validators:
        run_context = _build_run_context(ctx)
        result_data = await validator.validate(result_data, tool_call, run_context)
    return result_data


def _allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool:
    return result_schema is None or result_schema.allow_text_result
@dataclasses.dataclass
class _RunMessages:
    messages: list[_messages.ModelMessage]
    used: bool = False


_messages_ctx_var: ContextVar[_RunMessages] = ContextVar('var')
@contextmanager
def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
    """Context manager to access the messages used in a [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] call.

    Useful when a run may raise an exception, see [model errors](../agents.md#model-errors) for more information.

    Examples:
    ```python
    from pydantic_ai import Agent, capture_run_messages

    agent = Agent('test')

    with capture_run_messages() as messages:
        try:
            result = agent.run_sync('foobar')
        except Exception:
            print(messages)
            raise
    ```

    !!! note
        If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
        `messages` will represent the messages exchanged during the first call only.
    """
    try:
        yield _messages_ctx_var.get().messages
    except LookupError:
        messages: list[_messages.ModelMessage] = []
        token = _messages_ctx_var.set(_RunMessages(messages))
        try:
            yield messages
        finally:
            _messages_ctx_var.reset(token)
def get_captured_run_messages() -> _RunMessages:
    return _messages_ctx_var.get()
def build_agent_graph(
    name: str | None, deps_type: type[DepsT], result_type: type[ResultT]
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[ResultT]]:
    # We'll define the known node classes:
    nodes = (
        UserPromptNode[DepsT],
        ModelRequestNode[DepsT],
        HandleResponseNode[DepsT],
        FinalResultNode[DepsT, ResultT],
    )
    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[ResultT]](
        nodes=nodes,
        name=name or 'Agent',
        state_type=GraphAgentState,
        run_end_type=MarkFinalResult[result_type],
        auto_instrument=False,
    )
    return graph
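# Rough shape of the non-streaming run graph built above (a sketch inferred from the node
# return types in this module, not an authoritative diagram):
#
#   UserPromptNode -> ModelRequestNode -> HandleResponseNode
#   HandleResponseNode -> ModelRequestNode          (tool calls to run, or a retry)
#   HandleResponseNode -> FinalResultNode -> End    (final result found)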
def build_agent_stream_graph(
    name: str | None, deps_type: type[DepsT], result_type: type[ResultT] | None
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]]:
    nodes = [
        StreamUserPromptNode[DepsT, result.StreamedRunResult[DepsT, ResultT]],
        StreamModelRequestNode[DepsT, result.StreamedRunResult[DepsT, ResultT]],
    ]
    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]](
        nodes=nodes,
        name=name or 'Agent',
        state_type=GraphAgentState,
        run_end_type=result.StreamedRunResult[DepsT, result_type],
    )
    return graph
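# Minimal usage sketch (assumption: these builders are invoked internally by `Agent`; the
# calls below only illustrate the signatures):
#
#   graph = build_agent_graph(name='my_agent', deps_type=MyDeps, result_type=str)
#   stream_graph = build_agent_stream_graph(name='my_agent', deps_type=MyDeps, result_type=str)
#
# `MyDeps` is a placeholder dependencies type defined by the caller.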