Coverage for pydantic_ai_slim/pydantic_ai/_agent_graph.py: 98.20%
343 statements · coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

from __future__ import annotations as _annotations

import asyncio
import dataclasses
import json
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
from dataclasses import field
from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast

from opentelemetry.trace import Span, Tracer
from typing_extensions import TypeGuard, TypeVar, assert_never

from pydantic_graph import BaseNode, Graph, GraphRunContext
from pydantic_graph.nodes import End, NodeRunEndT

from . import (
    _result,
    _system_prompt,
    exceptions,
    messages as _messages,
    models,
    result,
    usage as _usage,
)
from .models.instrumented import InstrumentedModel
from .result import ResultDataT
from .settings import ModelSettings, merge_model_settings
from .tools import RunContext, Tool, ToolDefinition

if TYPE_CHECKING:
    from .mcp import MCPServer

__all__ = (
    'GraphAgentState',
    'GraphAgentDeps',
    'UserPromptNode',
    'ModelRequestNode',
    'CallToolsNode',
    'build_run_context',
    'capture_run_messages',
)


T = TypeVar('T')
S = TypeVar('S')
NoneType = type(None)
EndStrategy = Literal['early', 'exhaustive']
"""The strategy for handling multiple tool calls when a final result is found.

- `'early'`: Stop processing other tool calls once a final result is found
- `'exhaustive'`: Process all tool calls even after finding a final result
"""
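# For example: with `end_strategy='early'`, once a result tool call has produced a final
# result, any remaining tool calls in the same response are answered with stub
# 'Tool not executed' return parts instead of being run; with `'exhaustive'` they all run
# (see `process_function_tools` below).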

DepsT = TypeVar('DepsT')
ResultT = TypeVar('ResultT')


@dataclasses.dataclass
class GraphAgentState:
    """State kept across the execution of the agent graph."""

    message_history: list[_messages.ModelMessage]
    usage: _usage.Usage
    retries: int
    run_step: int

    def increment_retries(self, max_result_retries: int) -> None:
        self.retries += 1
        if self.retries > max_result_retries:
            raise exceptions.UnexpectedModelBehavior(
                f'Exceeded maximum retries ({max_result_retries}) for result validation'
            )


@dataclasses.dataclass
class GraphAgentDeps(Generic[DepsT, ResultDataT]):
    """Dependencies/config passed to the agent graph."""

    user_deps: DepsT

    prompt: str | Sequence[_messages.UserContent]
    new_message_index: int
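    # (`new_message_index` is the index into `message_history` at which this run's new
    # messages start; everything before it is history passed in by the caller.)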

    model: models.Model
    model_settings: ModelSettings | None
    usage_limits: _usage.UsageLimits
    max_result_retries: int
    end_strategy: EndStrategy

    result_schema: _result.ResultSchema[ResultDataT] | None
    result_tools: list[ToolDefinition]
    result_validators: list[_result.ResultValidator[DepsT, ResultDataT]]

    function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
    mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)

    run_span: Span
    tracer: Tracer

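# A run cycles UserPromptNode -> ModelRequestNode -> CallToolsNode; CallToolsNode either
# loops back with another ModelRequestNode or ends the run with an End(FinalResult).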

class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
    """The base class for all agent nodes.

    Using a subclass of `BaseNode` for all nodes reduces the amount of generics boilerplate everywhere.
    """


def is_agent_node(
    node: BaseNode[GraphAgentState, GraphAgentDeps[T, Any], result.FinalResult[S]] | End[result.FinalResult[S]],
) -> TypeGuard[AgentNode[T, S]]:
    """Check if the provided node is an instance of `AgentNode`.

    Usage:

        if is_agent_node(node):
            # `node` is an AgentNode
            ...

    This function preserves the generic parameters on the narrowed type, unlike `isinstance(node, AgentNode)`.
    """
    return isinstance(node, AgentNode)


@dataclasses.dataclass
class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
    user_prompt: str | Sequence[_messages.UserContent]

    system_prompts: tuple[str, ...]
    system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
    system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]

    async def run(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))

    async def _get_first_message(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> _messages.ModelRequest:
        run_context = build_run_context(ctx)
        history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
        ctx.state.message_history = history
        run_context.messages = history

        # TODO: We need to make it so that function_tools are not shared between runs.
        # See the comment on the `current_retry` field of `Tool` for more details.
        for tool in ctx.deps.function_tools.values():
            tool.current_retry = 0
        return next_message

    async def _prepare_messages(
        self,
        user_prompt: str | Sequence[_messages.UserContent],
        message_history: list[_messages.ModelMessage] | None,
        run_context: RunContext[DepsT],
    ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
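        # If a `capture_run_messages` block is active, reuse its (not yet used) message list
        # so the caller can observe this run's messages; otherwise start with a fresh list.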

        try:
            ctx_messages = get_captured_run_messages()
        except LookupError:
            messages: list[_messages.ModelMessage] = []
        else:
            if ctx_messages.used:
                messages = []
            else:
                messages = ctx_messages.messages
                ctx_messages.used = True

        if message_history:
            # Shallow copy messages
            messages.extend(message_history)
            # Reevaluate any dynamic system prompt parts
            await self._reevaluate_dynamic_prompts(messages, run_context)
            return messages, _messages.ModelRequest([_messages.UserPromptPart(user_prompt)])
        else:
            parts = await self._sys_parts(run_context)
            parts.append(_messages.UserPromptPart(user_prompt))
            return messages, _messages.ModelRequest(parts)

    async def _reevaluate_dynamic_prompts(
        self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
    ) -> None:
        """Reevaluate any `SystemPromptPart` with a `dynamic_ref` in the provided messages by running the associated runner function."""
        # Only proceed if there's at least one dynamic runner.
        if self.system_prompt_dynamic_functions:
            for msg in messages:
                if isinstance(msg, _messages.ModelRequest):
                    for i, part in enumerate(msg.parts):
                        if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
                            # Look up the runner by its ref
                            if runner := self.system_prompt_dynamic_functions.get(part.dynamic_ref):
                                updated_part_content = await runner.run(run_context)
                                msg.parts[i] = _messages.SystemPromptPart(
                                    updated_part_content, dynamic_ref=part.dynamic_ref
                                )

    async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.ModelRequestPart]:
        """Build the initial messages for the conversation."""
        messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self.system_prompts]
        for sys_prompt_runner in self.system_prompt_functions:
            prompt = await sys_prompt_runner.run(run_context)
            if sys_prompt_runner.dynamic:
                messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__))
            else:
                messages.append(_messages.SystemPromptPart(prompt))
        return messages


async def _prepare_request_parameters(
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
) -> models.ModelRequestParameters:
    """Build tools and create an agent model."""
    function_tool_defs: list[ToolDefinition] = []

    run_context = build_run_context(ctx)

    async def add_tool(tool: Tool[DepsT]) -> None:
        ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
        if tool_def := await tool.prepare_tool_def(ctx):
            function_tool_defs.append(tool_def)

    async def add_mcp_server_tools(server: MCPServer) -> None:
        if not server.is_running:
            raise exceptions.UserError(f'MCP server is not running: {server}')
        tool_defs = await server.list_tools()
        # TODO(Marcelo): We should check if the tool names are unique. If not, we should raise an error.
        function_tool_defs.extend(tool_defs)

    await asyncio.gather(
        *map(add_tool, ctx.deps.function_tools.values()),
        *map(add_mcp_server_tools, ctx.deps.mcp_servers),
    )

    result_schema = ctx.deps.result_schema
    return models.ModelRequestParameters(
        function_tools=function_tool_defs,
        allow_text_result=allow_text_result(result_schema),
        result_tools=result_schema.tool_defs() if result_schema is not None else [],
    )


@dataclasses.dataclass
class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
    """Make a request to the model using the last message in state.message_history."""

    request: _messages.ModelRequest

    _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
    _did_stream: bool = field(default=False, repr=False)
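    # (`_result` caches the node built from the model response so `run()` can return it
    # directly; `_did_stream` guards against calling `run()` while `stream()` is open.)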

    async def run(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> CallToolsNode[DepsT, NodeRunEndT]:
        if self._result is not None:
            return self._result

        if self._did_stream:
            # `self._result` gets set when exiting the `stream` contextmanager, so hitting this
            # means that the stream was started but not finished before `run()` was called
            raise exceptions.AgentRunError('You must finish streaming before calling run()')

        return await self._make_request(ctx)

    @asynccontextmanager
    async def stream(
        self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
    ) -> AsyncIterator[result.AgentStream[DepsT, T]]:
        async with self._stream(ctx) as streamed_response:
            agent_stream = result.AgentStream[DepsT, T](
                streamed_response,
                ctx.deps.result_schema,
                ctx.deps.result_validators,
                build_run_context(ctx),
                ctx.deps.usage_limits,
            )
            yield agent_stream
            # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
            # otherwise usage won't be properly counted:
            async for _ in agent_stream:
                pass

    @asynccontextmanager
    async def _stream(
        self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
    ) -> AsyncIterator[models.StreamedResponse]:
        assert not self._did_stream, 'stream() should only be called once per node'

        model_settings, model_request_parameters = await self._prepare_request(ctx)
        async with ctx.deps.model.request_stream(
            ctx.state.message_history, model_settings, model_request_parameters
        ) as streamed_response:
            self._did_stream = True
            ctx.state.usage.incr(_usage.Usage(), requests=1)
            yield streamed_response
            # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
            # otherwise usage won't be properly counted:
            async for _ in streamed_response:
                pass
            model_response = streamed_response.get()
            request_usage = streamed_response.usage()

        self._finish_handling(ctx, model_response, request_usage)
        assert self._result is not None  # this should be set by the previous line

    async def _make_request(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> CallToolsNode[DepsT, NodeRunEndT]:
        if self._result is not None:
            return self._result

        model_settings, model_request_parameters = await self._prepare_request(ctx)
        model_response, request_usage = await ctx.deps.model.request(
            ctx.state.message_history, model_settings, model_request_parameters
        )
        ctx.state.usage.incr(_usage.Usage(), requests=1)

        return self._finish_handling(ctx, model_response, request_usage)

    async def _prepare_request(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> tuple[ModelSettings | None, models.ModelRequestParameters]:
        ctx.state.message_history.append(self.request)

        # Check usage
        if ctx.deps.usage_limits:
            ctx.deps.usage_limits.check_before_request(ctx.state.usage)

        # Increment run_step
        ctx.state.run_step += 1

        model_settings = merge_model_settings(ctx.deps.model_settings, None)
        model_request_parameters = await _prepare_request_parameters(ctx)
        return model_settings, model_request_parameters

    def _finish_handling(
        self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
        response: _messages.ModelResponse,
        usage: _usage.Usage,
    ) -> CallToolsNode[DepsT, NodeRunEndT]:
        # Update usage
        ctx.state.usage.incr(usage, requests=0)
        if ctx.deps.usage_limits:
            ctx.deps.usage_limits.check_tokens(ctx.state.usage)

        # Append the model response to state.message_history
        ctx.state.message_history.append(response)

        # Set the `_result` attribute since we can't use `return` in an async iterator
        self._result = CallToolsNode(response)

        return self._result


@dataclasses.dataclass
class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
    """Process a model response, and decide whether to end the run or make a new request."""

    model_response: _messages.ModelResponse

    _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False)
    _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
        default=None, repr=False
    )
    _tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False)

    async def run(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], End[result.FinalResult[NodeRunEndT]]]:  # noqa UP007
        async with self.stream(ctx):
            pass

        assert (next_node := self._next_node) is not None, 'the stream should set `self._next_node` before it ends'
        return next_node

    @asynccontextmanager
    async def stream(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]:
        """Process the model response and yield events for the start and end of each function tool call."""
        stream = self._run_stream(ctx)
        yield stream

        # Run the stream to completion if it was not finished:
        async for _event in stream:
            pass

    async def _run_stream(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> AsyncIterator[_messages.HandleResponseEvent]:
        if self._events_iterator is None:
            # Ensure that the stream is only run once

            async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
                texts: list[str] = []
                tool_calls: list[_messages.ToolCallPart] = []
                for part in self.model_response.parts:
                    if isinstance(part, _messages.TextPart):
                        # ignore empty content for text parts, see #437
                        if part.content:
                            texts.append(part.content)
                    elif isinstance(part, _messages.ToolCallPart):
                        tool_calls.append(part)
                    else:
                        assert_never(part)

                # At the moment, we prioritize at least executing tool calls if they are present.
                # In the future, we'd consider making this configurable at the agent or run level.
                # This accounts for cases like Anthropic responses that contain both a text part and a
                # tool call part, where the text just indicates that the tool call will happen.
                if tool_calls:
                    async for event in self._handle_tool_calls(ctx, tool_calls):
                        yield event
                elif texts:
                    # No events are emitted during the handling of text responses, so we don't need to yield anything
                    self._next_node = await self._handle_text_response(ctx, texts)
                else:
                    raise exceptions.UnexpectedModelBehavior('Received empty model response')

            self._events_iterator = _run_stream()

        async for event in self._events_iterator:
            yield event

    async def _handle_tool_calls(
        self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
        tool_calls: list[_messages.ToolCallPart],
    ) -> AsyncIterator[_messages.HandleResponseEvent]:
        result_schema = ctx.deps.result_schema

        # first look for the result tool call
        final_result: result.FinalResult[NodeRunEndT] | None = None
        parts: list[_messages.ModelRequestPart] = []
        if result_schema is not None:
            for call, result_tool in result_schema.find_tool(tool_calls):
                try:
                    result_data = result_tool.validate(call)
                    result_data = await _validate_result(result_data, ctx, call)
                except _result.ToolRetryError as e:
                    # TODO: Should only increment retry stuff once per node execution, not for each tool call
                    # Also, should increment the tool-specific retry count rather than the run retry count
                    ctx.state.increment_retries(ctx.deps.max_result_retries)
                    parts.append(e.tool_retry)
                else:
                    final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
                    break

        # Then build the other request parts based on end strategy
        tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
        async for event in process_function_tools(
            tool_calls,
            final_result and final_result.tool_name,
            final_result and final_result.tool_call_id,
            ctx,
            tool_responses,
        ):
            yield event

        if final_result:
            self._next_node = self._handle_final_result(ctx, final_result, tool_responses)
        else:
            if tool_responses:
                parts.extend(tool_responses)
            self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts))

    def _handle_final_result(
        self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
        final_result: result.FinalResult[NodeRunEndT],
        tool_responses: list[_messages.ModelRequestPart],
    ) -> End[result.FinalResult[NodeRunEndT]]:
        run_span = ctx.deps.run_span
        usage = ctx.state.usage
        messages = ctx.state.message_history

        # For backwards compatibility, append a new ModelRequest using the tool returns and retries
        if tool_responses:
            messages.append(_messages.ModelRequest(parts=tool_responses))

        run_span.set_attributes(
            {
                **usage.opentelemetry_attributes(),
                'all_messages_events': json.dumps(
                    [InstrumentedModel.event_to_dict(e) for e in InstrumentedModel.messages_to_otel_events(messages)]
                ),
                'final_result': final_result.data
                if isinstance(final_result.data, str)
                else json.dumps(InstrumentedModel.serialize_any(final_result.data)),
            }
        )
        run_span.set_attributes(
            {
                'logfire.json_schema': json.dumps(
                    {
                        'type': 'object',
                        'properties': {
                            'all_messages_events': {'type': 'array'},
                            'final_result': {'type': 'object'},
                        },
                    }
                ),
            }
        )

        # End the run with the final result
        return End(final_result)

    async def _handle_text_response(
        self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
        texts: list[str],
    ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
        result_schema = ctx.deps.result_schema

        text = '\n\n'.join(texts)
        if allow_text_result(result_schema):
            result_data_input = cast(NodeRunEndT, text)
            try:
                result_data = await _validate_result(result_data_input, ctx, None)
            except _result.ToolRetryError as e:
                ctx.state.increment_retries(ctx.deps.max_result_retries)
                return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
            else:
                # The cast above is safe because we know `str` is an allowed result type
                return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
        else:
            ctx.state.increment_retries(ctx.deps.max_result_retries)
            return ModelRequestNode[DepsT, NodeRunEndT](
                _messages.ModelRequest(
                    parts=[
                        _messages.RetryPromptPart(
                            content='Plain text responses are not permitted, please call one of the functions instead.',
                        )
                    ]
                )
            )


543 """Build a `RunContext` object from the current agent graph run context.""" 

544 return RunContext[DepsT]( 

545 deps=ctx.deps.user_deps, 

546 model=ctx.deps.model, 

547 usage=ctx.state.usage, 

548 prompt=ctx.deps.prompt, 

549 messages=ctx.state.message_history, 

550 run_step=ctx.state.run_step, 

551 ) 

552 

553 

async def process_function_tools(
    tool_calls: list[_messages.ToolCallPart],
    result_tool_name: str | None,
    result_tool_call_id: str | None,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
    output_parts: list[_messages.ModelRequestPart],
) -> AsyncIterator[_messages.HandleResponseEvent]:
    """Process function (i.e., non-result) tool calls in parallel.

    Also add stub return parts for any other tools that need it.

    Because async iterators can't have return values, we use `output_parts` as an output argument.
    """
    stub_function_tools = bool(result_tool_name) and ctx.deps.end_strategy == 'early'
    result_schema = ctx.deps.result_schema

    # we rely on the fact that if we found a result, it's the first result tool in the last model response
    found_used_result_tool = False
    run_context = build_run_context(ctx)

    calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
    call_index_to_event_id: dict[int, str] = {}
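    # (`call_index_to_event_id` maps an index into `calls_to_run` to the call_id of the
    # `FunctionToolCallEvent` that announced the call, so the matching
    # `FunctionToolResultEvent` can reference it.)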

    for call in tool_calls:
        if (
            call.tool_name == result_tool_name
            and call.tool_call_id == result_tool_call_id
            and not found_used_result_tool
        ):
            found_used_result_tool = True
            output_parts.append(
                _messages.ToolReturnPart(
                    tool_name=call.tool_name,
                    content='Final result processed.',
                    tool_call_id=call.tool_call_id,
                )
            )
        elif tool := ctx.deps.function_tools.get(call.tool_name):
            if stub_function_tools:
                output_parts.append(
                    _messages.ToolReturnPart(
                        tool_name=call.tool_name,
                        content='Tool not executed - a final result was already processed.',
                        tool_call_id=call.tool_call_id,
                    )
                )
            else:
                event = _messages.FunctionToolCallEvent(call)
                yield event
                call_index_to_event_id[len(calls_to_run)] = event.call_id
                calls_to_run.append((tool, call))
        elif mcp_tool := await _tool_from_mcp_server(call.tool_name, ctx):
            if stub_function_tools:
                # TODO(Marcelo): We should add coverage for this part of the code.
                output_parts.append(  # pragma: no cover
                    _messages.ToolReturnPart(
                        tool_name=call.tool_name,
                        content='Tool not executed - a final result was already processed.',
                        tool_call_id=call.tool_call_id,
                    )
                )
            else:
                event = _messages.FunctionToolCallEvent(call)
                yield event
                call_index_to_event_id[len(calls_to_run)] = event.call_id
                calls_to_run.append((mcp_tool, call))
        elif result_schema is not None and call.tool_name in result_schema.tools:
            # if tool_name is in result_schema, it means we found a result tool but an error occurred during
            # validation, so we don't add another part here
            if result_tool_name is not None:
                if found_used_result_tool:
                    content = 'Result tool not used - a final result was already processed.'
                else:
                    # TODO: Include information about the validation failure, and/or merge this with the ModelRetry part
                    content = 'Result tool not used - result failed validation.'
                part = _messages.ToolReturnPart(
                    tool_name=call.tool_name,
                    content=content,
                    tool_call_id=call.tool_call_id,
                )
                output_parts.append(part)
        else:
            output_parts.append(_unknown_tool(call.tool_name, ctx))

    if not calls_to_run:
        return

    # Run all tool tasks in parallel
    results_by_index: dict[int, _messages.ModelRequestPart] = {}
    with ctx.deps.tracer.start_as_current_span(
        'running tools',
        attributes={
            'tools': [call.tool_name for _, call in calls_to_run],
            'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}',
        },
    ):
        tasks = [
            asyncio.create_task(tool.run(call, run_context, ctx.deps.tracer), name=call.tool_name)
            for tool, call in calls_to_run
        ]
        pending = tasks
        while pending:
            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                index = tasks.index(task)
                result = task.result()
                yield _messages.FunctionToolResultEvent(result, tool_call_id=call_index_to_event_id[index])
                if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):
                    results_by_index[index] = result
                else:
                    assert_never(result)

    # We append the results at the end, rather than as they are received, to retain a consistent ordering
    # This is mostly just to simplify testing
    for k in sorted(results_by_index):
        output_parts.append(results_by_index[k])


async def _tool_from_mcp_server(
    tool_name: str,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
) -> Tool[DepsT] | None:
    """Call each MCP server to find the tool with the given name.

    Args:
        tool_name: The name of the tool to find.
        ctx: The current run context.

    Returns:
        The tool with the given name, or `None` if no tool with the given name is found.
    """

    async def run_tool(ctx: RunContext[DepsT], **args: Any) -> Any:
        # There's no normal situation where the server will not be running at this point, we check just in case
        # some weird edge case occurs.
        if not server.is_running:  # pragma: no cover
            raise exceptions.UserError(f'MCP server is not running: {server}')
        result = await server.call_tool(tool_name, args)
        return result

    for server in ctx.deps.mcp_servers:
        tools = await server.list_tools()
        if tool_name in {tool.name for tool in tools}:
            return Tool(name=tool_name, function=run_tool, takes_ctx=True)
    return None


def _unknown_tool(
    tool_name: str,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
) -> _messages.RetryPromptPart:
    ctx.state.increment_retries(ctx.deps.max_result_retries)
    tool_names = list(ctx.deps.function_tools.keys())
    if result_schema := ctx.deps.result_schema:
        tool_names.extend(result_schema.tool_names())

    if tool_names:
        msg = f'Available tools: {", ".join(tool_names)}'
    else:
        msg = 'No tools available.'

    return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')


async def _validate_result(
    result_data: T,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
    tool_call: _messages.ToolCallPart | None,
) -> T:
    for validator in ctx.deps.result_validators:
        run_context = build_run_context(ctx)
        result_data = await validator.validate(result_data, tool_call, run_context)
    return result_data


def allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool:
    """Check if the result schema allows text results."""
    return result_schema is None or result_schema.allow_text_result


@dataclasses.dataclass
class _RunMessages:
    messages: list[_messages.ModelMessage]
    used: bool = False


_messages_ctx_var: ContextVar[_RunMessages] = ContextVar('var')
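# (Holds the message list for the innermost active `capture_run_messages` block; set and
# reset by `capture_run_messages` below, read via `get_captured_run_messages`.)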

@contextmanager
def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
    """Context manager to access the messages used in a [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] call.

    Useful when a run may raise an exception, see [model errors](../agents.md#model-errors) for more information.

    Examples:
    ```python
    from pydantic_ai import Agent, capture_run_messages

    agent = Agent('test')

    with capture_run_messages() as messages:
        try:
            result = agent.run_sync('foobar')
        except Exception:
            print(messages)
            raise
    ```

    !!! note
        If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
        `messages` will represent the messages exchanged during the first call only.
    """
    try:
        yield _messages_ctx_var.get().messages
    except LookupError:
        messages: list[_messages.ModelMessage] = []
        token = _messages_ctx_var.set(_RunMessages(messages))
        try:
            yield messages
        finally:
            _messages_ctx_var.reset(token)


def get_captured_run_messages() -> _RunMessages:
    return _messages_ctx_var.get()


def build_agent_graph(
    name: str | None, deps_type: type[DepsT], result_type: type[ResultT]
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[ResultT]], result.FinalResult[ResultT]]:
    """Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
    nodes = (
        UserPromptNode[DepsT],
        ModelRequestNode[DepsT],
        CallToolsNode[DepsT],
    )
    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]](
        nodes=nodes,
        name=name or 'Agent',
        state_type=GraphAgentState,
        run_end_type=result.FinalResult[result_type],
        auto_instrument=False,
    )
    return graph
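
# Note: agent runs normally drive this graph via `Agent.run` / `run_sync` / `run_stream`
# (see the `capture_run_messages` docstring above); the node classes defined here are
# what you'll see if you iterate a run node by node.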