Coverage for pydantic_ai_slim/pydantic_ai/_agent_graph.py: 98.57%

386 statements  

coverage.py v7.6.7, created at 2025-01-30 19:21 +0000

1from __future__ import annotations as _annotations 

2 

3import asyncio 

4import dataclasses 

5from abc import ABC 

6from collections.abc import AsyncIterator, Iterator, Sequence 

7from contextlib import asynccontextmanager, contextmanager 

8from contextvars import ContextVar 

9from dataclasses import field 

10from typing import Any, Generic, Literal, Union, cast 

11 

12import logfire_api 

13from typing_extensions import TypeVar, assert_never 

14 

15from pydantic_graph import BaseNode, Graph, GraphRunContext 

16from pydantic_graph.nodes import End, NodeRunEndT 

17 

18from . import ( 

19 _result, 

20 _system_prompt, 

21 exceptions, 

22 messages as _messages, 

23 models, 

24 result, 

25 usage as _usage, 

26) 

27from .result import ResultDataT 

28from .settings import ModelSettings, merge_model_settings 

29from .tools import ( 

30 RunContext, 

31 Tool, 

32 ToolDefinition, 

33) 

34 

35_logfire = logfire_api.Logfire(otel_scope='pydantic-ai') 

36 

37# while waiting for https://github.com/pydantic/logfire/issues/745 

38try: 

39 import logfire._internal.stack_info 

40except ImportError: 

41 pass 

42else: 

43 from pathlib import Path 

44 

45 logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),) 

46 

47T = TypeVar('T') 

48NoneType = type(None) 

49EndStrategy = Literal['early', 'exhaustive'] 

50"""The strategy for handling multiple tool calls when a final result is found. 

51 

52- `'early'`: Stop processing other tool calls once a final result is found 

53- `'exhaustive'`: Process all tool calls even after finding a final result 

54""" 

55DepsT = TypeVar('DepsT') 

56ResultT = TypeVar('ResultT') 

57 

58 

59@dataclasses.dataclass 

60class MarkFinalResult(Generic[ResultDataT]): 

61 """Marker class to indicate that the result is the final result. 

62 

63 This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly. 

64 

65 It also avoids ambiguity in the case where the result data is itself `None` but a final result has still been set.

66 """ 

67 

68 data: ResultDataT 

69 """The final result data.""" 

70 tool_name: str | None 

71 """Name of the final result tool, None if the result is a string.""" 

72 

73 

74@dataclasses.dataclass 

75class GraphAgentState: 

76 """State kept across the execution of the agent graph.""" 

77 

78 message_history: list[_messages.ModelMessage] 

79 usage: _usage.Usage 

80 retries: int 

81 run_step: int 

82 

83 def increment_retries(self, max_result_retries: int) -> None: 

84 self.retries += 1 

85 if self.retries > max_result_retries: 

86 raise exceptions.UnexpectedModelBehavior( 

87 f'Exceeded maximum retries ({max_result_retries}) for result validation' 

88 ) 
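
A small sketch of the retry bookkeeping above, with placeholder values for the state fields:

```python
state = GraphAgentState(message_history=[], usage=_usage.Usage(), retries=0, run_step=0)
state.increment_retries(max_result_retries=1)      # retries == 1, still within the limit
try:
    state.increment_retries(max_result_retries=1)  # retries == 2, limit exceeded
except exceptions.UnexpectedModelBehavior as exc:
    print(exc)  # Exceeded maximum retries (1) for result validation
```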

89 

90 

91@dataclasses.dataclass 

92class GraphAgentDeps(Generic[DepsT, ResultDataT]): 

93 """Dependencies/config passed to the agent graph.""" 

94 

95 user_deps: DepsT 

96 

97 prompt: str 

98 new_message_index: int 

99 

100 model: models.Model 

101 model_settings: ModelSettings | None 

102 usage_limits: _usage.UsageLimits 

103 max_result_retries: int 

104 end_strategy: EndStrategy 

105 

106 result_schema: _result.ResultSchema[ResultDataT] | None 

107 result_tools: list[ToolDefinition] 

108 result_validators: list[_result.ResultValidator[DepsT, ResultDataT]] 

109 

110 function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False) 

111 

112 run_span: logfire_api.LogfireSpan 

113 

114 

115@dataclasses.dataclass 

116class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC): 

117 user_prompt: str 

118 

119 system_prompts: tuple[str, ...] 

120 system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]] 

121 system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]] 

122 

123 async def _get_first_message( 

124 self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] 

125 ) -> _messages.ModelRequest: 

126 run_context = _build_run_context(ctx) 

127 history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context) 

128 ctx.state.message_history = history 

129 run_context.messages = history 

130 

131 # TODO: We need to make it so that function_tools are not shared between runs 

132 # See comment on the current_retry field of `Tool` for more details. 

133 for tool in ctx.deps.function_tools.values(): 

134 tool.current_retry = 0 

135 return next_message 

136 

137 async def _prepare_messages( 

138 self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[DepsT] 

139 ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]: 

140 try: 

141 ctx_messages = get_captured_run_messages() 

142 except LookupError: 

143 messages: list[_messages.ModelMessage] = [] 

144 else: 

145 if ctx_messages.used: 

146 messages = [] 

147 else: 

148 messages = ctx_messages.messages 

149 ctx_messages.used = True 

150 

151 if message_history: 

152 # Shallow copy messages 

153 messages.extend(message_history) 

154 # Reevaluate any dynamic system prompt parts 

155 await self._reevaluate_dynamic_prompts(messages, run_context) 

156 return messages, _messages.ModelRequest([_messages.UserPromptPart(user_prompt)]) 

157 else: 

158 parts = await self._sys_parts(run_context) 

159 parts.append(_messages.UserPromptPart(user_prompt)) 

160 return messages, _messages.ModelRequest(parts) 

161 

162 async def _reevaluate_dynamic_prompts( 

163 self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT] 

164 ) -> None: 

165 """Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function.""" 

166 # Only proceed if there's at least one dynamic runner. 

167 if self.system_prompt_dynamic_functions: 

168 for msg in messages: 

169 if isinstance(msg, _messages.ModelRequest): 

170 for i, part in enumerate(msg.parts): 

171 if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref: 

172 # Look up the runner by its ref 

173 if runner := self.system_prompt_dynamic_functions.get(part.dynamic_ref):  173 ↛ 170: line 173 didn't jump to line 170 because the condition on line 173 was always true

174 updated_part_content = await runner.run(run_context) 

175 msg.parts[i] = _messages.SystemPromptPart( 

176 updated_part_content, dynamic_ref=part.dynamic_ref 

177 ) 

178 

179 async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.ModelRequestPart]: 

180 """Build the initial messages for the conversation.""" 

181 messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self.system_prompts] 

182 for sys_prompt_runner in self.system_prompt_functions: 

183 prompt = await sys_prompt_runner.run(run_context) 

184 if sys_prompt_runner.dynamic: 

185 messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__)) 

186 else: 

187 messages.append(_messages.SystemPromptPart(prompt)) 

188 return messages 
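
The static and dynamic runners collected here are registered at the agent level; a sketch of what that registration might look like, assuming the public `system_prompt` decorator and its `dynamic` flag (the flag is what populates `dynamic_ref` and `system_prompt_dynamic_functions` above):

```python
from pydantic_ai import Agent, RunContext

agent = Agent('test', deps_type=str)

@agent.system_prompt  # static: evaluated once when the first request is built
async def base_prompt(ctx: RunContext[str]) -> str:
    return f'You are helping {ctx.deps}.'

@agent.system_prompt(dynamic=True)  # dynamic: re-evaluated by _reevaluate_dynamic_prompts
async def extra_instructions(ctx: RunContext[str]) -> str:
    return 'Keep answers short.'
```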

189 

190 

191@dataclasses.dataclass 

192class UserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]): 

193 async def run( 

194 self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] 

195 ) -> ModelRequestNode[DepsT, NodeRunEndT]: 

196 return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx)) 

197 

198 

199@dataclasses.dataclass 

200class StreamUserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]): 

201 async def run( 

202 self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] 

203 ) -> StreamModelRequestNode[DepsT, NodeRunEndT]: 

204 return StreamModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx)) 

205 

206 

207async def _prepare_model( 

208 ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], 

209) -> models.AgentModel: 

210 """Build tools and create an agent model.""" 

211 function_tool_defs: list[ToolDefinition] = [] 

212 

213 run_context = _build_run_context(ctx) 

214 

215 async def add_tool(tool: Tool[DepsT]) -> None: 

216 ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name) 

217 if tool_def := await tool.prepare_tool_def(ctx): 

218 function_tool_defs.append(tool_def) 

219 

220 await asyncio.gather(*map(add_tool, ctx.deps.function_tools.values())) 

221 

222 result_schema = ctx.deps.result_schema 

223 return await run_context.model.agent_model( 

224 function_tools=function_tool_defs, 

225 allow_text_result=_allow_text_result(result_schema), 

226 result_tools=result_schema.tool_defs() if result_schema is not None else [], 

227 ) 

228 

229 

230@dataclasses.dataclass 

231class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]): 

232 """Make a request to the model using the last message in state.message_history.""" 

233 

234 request: _messages.ModelRequest 

235 

236 async def run( 

237 self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] 

238 ) -> HandleResponseNode[DepsT, NodeRunEndT]: 

239 ctx.state.message_history.append(self.request) 

240 

241 # Check usage 

242 if ctx.deps.usage_limits:  242 ↛ 246: line 242 didn't jump to line 246 because the condition on line 242 was always true

243 ctx.deps.usage_limits.check_before_request(ctx.state.usage) 

244 

245 # Increment run_step 

246 ctx.state.run_step += 1 

247 

248 with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step): 

249 agent_model = await _prepare_model(ctx) 

250 

251 # Actually make the model request 

252 model_settings = merge_model_settings(ctx.deps.model_settings, None) 

253 with _logfire.span('model request') as span: 

254 model_response, request_usage = await agent_model.request(ctx.state.message_history, model_settings) 

255 span.set_attribute('response', model_response) 

256 span.set_attribute('usage', request_usage) 

257 

258 # Update usage 

259 ctx.state.usage.incr(request_usage, requests=1) 

260 if ctx.deps.usage_limits:  260 ↛ 264: line 260 didn't jump to line 264 because the condition on line 260 was always true

261 ctx.deps.usage_limits.check_tokens(ctx.state.usage) 

262 

263 # Append the model response to state.message_history 

264 ctx.state.message_history.append(model_response) 

265 return HandleResponseNode(model_response) 
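
The usage limits checked before and after the request above are supplied by the caller of the public run methods; a minimal sketch, assuming `UsageLimits` from `pydantic_ai.usage` and the `usage_limits` keyword on `run_sync`:

```python
from pydantic_ai import Agent
from pydantic_ai.usage import UsageLimits

agent = Agent('test')

# check_before_request() raises once a fourth request would be made, and
# check_tokens() enforces any token limits after each response is counted.
run_result = agent.run_sync('hello', usage_limits=UsageLimits(request_limit=3))
print(run_result.usage())
```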

266 

267 

268@dataclasses.dataclass 

269class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]): 

270 """Process e response from a model, decide whether to end the run or make a new request.""" 

271 

272 model_response: _messages.ModelResponse 

273 

274 async def run( 

275 self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] 

276 ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], FinalResultNode[DepsT, NodeRunEndT]]: # noqa UP007 

277 with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span: 

278 texts: list[str] = [] 

279 tool_calls: list[_messages.ToolCallPart] = [] 

280 for part in self.model_response.parts: 

281 if isinstance(part, _messages.TextPart): 

282 # ignore empty content for text parts, see #437 

283 if part.content: 

284 texts.append(part.content) 

285 elif isinstance(part, _messages.ToolCallPart): 

286 tool_calls.append(part) 

287 else: 

288 assert_never(part) 

289 

290 # At the moment, we prioritize at least executing tool calls if they are present. 

291 # In the future, we'd consider making this configurable at the agent or run level. 

292 # This accounts for cases like Anthropic responses that contain both a text part

293 # and a tool call part, where the text just indicates that the tool call will happen.

294 if tool_calls: 

295 return await self._handle_tool_calls_response(ctx, tool_calls, handle_span) 

296 elif texts: 

297 return await self._handle_text_response(ctx, texts, handle_span) 

298 else: 

299 raise exceptions.UnexpectedModelBehavior('Received empty model response') 

300 

301 async def _handle_tool_calls_response( 

302 self, 

303 ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], 

304 tool_calls: list[_messages.ToolCallPart], 

305 handle_span: logfire_api.LogfireSpan, 

306 ): 

307 result_schema = ctx.deps.result_schema 

308 

309 # first look for the result tool call 

310 final_result: MarkFinalResult[NodeRunEndT] | None = None 

311 parts: list[_messages.ModelRequestPart] = [] 

312 if result_schema is not None: 

313 if match := result_schema.find_tool(tool_calls): 

314 call, result_tool = match 

315 try: 

316 result_data = result_tool.validate(call) 

317 result_data = await _validate_result(result_data, ctx, call) 

318 except _result.ToolRetryError as e: 

319 # TODO: Should only increment retry stuff once per node execution, not for each tool call 

320 # Also, should increment the tool-specific retry count rather than the run retry count 

321 ctx.state.increment_retries(ctx.deps.max_result_retries) 

322 parts.append(e.tool_retry) 

323 else: 

324 final_result = MarkFinalResult(result_data, call.tool_name) 

325 

326 # Then build the other request parts based on end strategy 

327 tool_responses = await _process_function_tools(tool_calls, final_result and final_result.tool_name, ctx) 

328 

329 if final_result: 

330 handle_span.set_attribute('result', final_result.data) 

331 handle_span.message = 'handle model response -> final result' 

332 return FinalResultNode[DepsT, NodeRunEndT](final_result, tool_responses) 

333 else: 

334 if tool_responses: 

335 handle_span.set_attribute('tool_responses', tool_responses) 

336 tool_responses_str = ' '.join(r.part_kind for r in tool_responses) 

337 handle_span.message = f'handle model response -> {tool_responses_str}' 

338 parts.extend(tool_responses) 

339 return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts)) 

340 

341 async def _handle_text_response( 

342 self, 

343 ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], 

344 texts: list[str], 

345 handle_span: logfire_api.LogfireSpan, 

346 ): 

347 result_schema = ctx.deps.result_schema 

348 

349 text = '\n\n'.join(texts) 

350 if _allow_text_result(result_schema): 

351 result_data_input = cast(NodeRunEndT, text) 

352 try: 

353 result_data = await _validate_result(result_data_input, ctx, None) 

354 except _result.ToolRetryError as e: 

355 ctx.state.increment_retries(ctx.deps.max_result_retries) 

356 return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry])) 

357 else: 

358 handle_span.set_attribute('result', result_data) 

359 handle_span.message = 'handle model response -> final result' 

360 return FinalResultNode[DepsT, NodeRunEndT](MarkFinalResult(result_data, None)) 

361 else: 

362 ctx.state.increment_retries(ctx.deps.max_result_retries) 

363 return ModelRequestNode[DepsT, NodeRunEndT]( 

364 _messages.ModelRequest( 

365 parts=[ 

366 _messages.RetryPromptPart( 

367 content='Plain text responses are not permitted, please call one of the functions instead.', 

368 ) 

369 ] 

370 ) 

371 ) 

372 

373 

374@dataclasses.dataclass 

375class StreamModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]): 

376 """Make a request to the model using the last message in state.message_history (or a specified request).""" 

377 

378 request: _messages.ModelRequest 

379 _result: StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]] | None = ( 

380 field(default=None, repr=False) 

381 ) 

382 

383 async def run( 

384 self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] 

385 ) -> Union[StreamModelRequestNode[DepsT, NodeRunEndT], End[result.StreamedRunResult[DepsT, NodeRunEndT]]]: # noqa UP007 

386 if self._result is not None:  386 ↛ 389: line 386 didn't jump to line 389 because the condition on line 386 was always true

387 return self._result 

388 

389 async with self.run_to_result(ctx) as final_node: 

390 return final_node 

391 

392 @asynccontextmanager 

393 async def run_to_result( 

394 self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] 

395 ) -> AsyncIterator[StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]]]: 

396 result_schema = ctx.deps.result_schema 

397 

398 ctx.state.message_history.append(self.request) 

399 

400 # Check usage 

401 if ctx.deps.usage_limits:  401 ↛ 405: line 401 didn't jump to line 405 because the condition on line 401 was always true

402 ctx.deps.usage_limits.check_before_request(ctx.state.usage) 

403 

404 # Increment run_step 

405 ctx.state.run_step += 1 

406 

407 with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step): 

408 agent_model = await _prepare_model(ctx) 

409 

410 # Actually make the model request 

411 model_settings = merge_model_settings(ctx.deps.model_settings, None) 

412 with _logfire.span('model request {run_step=}', run_step=ctx.state.run_step) as model_req_span: 

413 async with agent_model.request_stream(ctx.state.message_history, model_settings) as streamed_response: 

414 ctx.state.usage.requests += 1 

415 model_req_span.set_attribute('response_type', streamed_response.__class__.__name__) 

416 # We want to end the "model request" span here, but we can't exit the context manager 

417 # in the traditional way 

418 model_req_span.__exit__(None, None, None) 

419 

420 with _logfire.span('handle model response') as handle_span: 

421 received_text = False 

422 

423 async for maybe_part_event in streamed_response: 

424 if isinstance(maybe_part_event, _messages.PartStartEvent): 

425 new_part = maybe_part_event.part 

426 if isinstance(new_part, _messages.TextPart): 

427 received_text = True 

428 if _allow_text_result(result_schema): 

429 handle_span.message = 'handle model response -> final result' 

430 streamed_run_result = _build_streamed_run_result(streamed_response, None, ctx) 

431 self._result = End(streamed_run_result) 

432 yield self._result 

433 return 

434 elif isinstance(new_part, _messages.ToolCallPart): 

435 if result_schema is not None and (match := result_schema.find_tool([new_part])): 

436 call, _ = match 

437 handle_span.message = 'handle model response -> final result' 

438 streamed_run_result = _build_streamed_run_result( 

439 streamed_response, call.tool_name, ctx 

440 ) 

441 self._result = End(streamed_run_result) 

442 yield self._result 

443 return 

444 else: 

445 assert_never(new_part) 

446 

447 tasks: list[asyncio.Task[_messages.ModelRequestPart]] = [] 

448 parts: list[_messages.ModelRequestPart] = [] 

449 model_response = streamed_response.get() 

450 if not model_response.parts: 

451 raise exceptions.UnexpectedModelBehavior('Received empty model response') 

452 ctx.state.message_history.append(model_response) 

453 

454 run_context = _build_run_context(ctx) 

455 for p in model_response.parts: 

456 if isinstance(p, _messages.ToolCallPart): 

457 if tool := ctx.deps.function_tools.get(p.tool_name): 

458 tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name)) 

459 else: 

460 parts.append(_unknown_tool(p.tool_name, ctx)) 

461 

462 if received_text and not tasks and not parts: 

463 # Can only get here if `_allow_text_result` returns `False` for the provided result_schema

464 ctx.state.increment_retries(ctx.deps.max_result_retries) 

465 self._result = StreamModelRequestNode[DepsT, NodeRunEndT]( 

466 _messages.ModelRequest( 

467 parts=[ 

468 _messages.RetryPromptPart( 

469 content='Plain text responses are not permitted, please call one of the functions instead.', 

470 ) 

471 ] 

472 ) 

473 ) 

474 yield self._result 

475 return 

476 

477 with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): 

478 task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks) 

479 parts.extend(task_results) 

480 

481 next_request = _messages.ModelRequest(parts=parts) 

482 if any(isinstance(part, _messages.RetryPromptPart) for part in parts): 

483 try: 

484 ctx.state.increment_retries(ctx.deps.max_result_retries) 

485 except: 

486 # TODO: This is janky, so I think we should probably change it, but how? 

487 ctx.state.message_history.append(next_request) 

488 raise 

489 

490 handle_span.set_attribute('tool_responses', parts) 

491 tool_responses_str = ' '.join(r.part_kind for r in parts) 

492 handle_span.message = f'handle model response -> {tool_responses_str}' 

493 # the model_response should have been fully streamed by now, we can add its usage 

494 streamed_response_usage = streamed_response.usage() 

495 run_context.usage.incr(streamed_response_usage) 

496 ctx.deps.usage_limits.check_tokens(run_context.usage) 

497 self._result = StreamModelRequestNode[DepsT, NodeRunEndT](next_request) 

498 yield self._result 

499 return 

500 

501 

502@dataclasses.dataclass 

503class FinalResultNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[NodeRunEndT]]): 

504 """Produce the final result of the run.""" 

505 

506 data: MarkFinalResult[NodeRunEndT] 

507 """The final result data.""" 

508 extra_parts: list[_messages.ModelRequestPart] = dataclasses.field(default_factory=list) 

509 

510 async def run( 

511 self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] 

512 ) -> End[MarkFinalResult[NodeRunEndT]]: 

513 run_span = ctx.deps.run_span 

514 usage = ctx.state.usage 

515 messages = ctx.state.message_history 

516 

517 # TODO: For backwards compatibility, append a new ModelRequest using the tool returns and retries 

518 if self.extra_parts: 

519 messages.append(_messages.ModelRequest(parts=self.extra_parts)) 

520 

521 # TODO: Set this attribute somewhere 

522 # handle_span = self.handle_model_response_span 

523 # handle_span.set_attribute('final_data', self.data) 

524 run_span.set_attribute('usage', usage) 

525 run_span.set_attribute('all_messages', messages) 

526 

527 # End the run with self.data 

528 return End(self.data) 

529 

530 

531def _build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]: 

532 return RunContext[DepsT]( 

533 deps=ctx.deps.user_deps, 

534 model=ctx.deps.model, 

535 usage=ctx.state.usage, 

536 prompt=ctx.deps.prompt, 

537 messages=ctx.state.message_history, 

538 run_step=ctx.state.run_step, 

539 ) 

540 

541 

542def _build_streamed_run_result( 

543 result_stream: models.StreamedResponse, 

544 result_tool_name: str | None, 

545 ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], 

546) -> result.StreamedRunResult[DepsT, NodeRunEndT]: 

547 new_message_index = ctx.deps.new_message_index 

548 result_schema = ctx.deps.result_schema 

549 run_span = ctx.deps.run_span 

550 usage_limits = ctx.deps.usage_limits 

551 messages = ctx.state.message_history 

552 run_context = _build_run_context(ctx) 

553 

554 async def on_complete(): 

555 """Called when the stream has completed. 

556 

557 The model response will have been added to messages by now 

558 by `StreamedRunResult._marked_completed`. 

559 """ 

560 last_message = messages[-1] 

561 assert isinstance(last_message, _messages.ModelResponse) 

562 tool_calls = [part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)] 

563 parts = await _process_function_tools( 

564 tool_calls, 

565 result_tool_name, 

566 ctx, 

567 ) 

568 # TODO: Should we do something here related to the retry count? 

569 # Maybe we should move the incrementing of the retry count to where we actually make a request? 

570 # if any(isinstance(part, _messages.RetryPromptPart) for part in parts): 

571 # ctx.state.increment_retries(ctx.deps.max_result_retries) 

572 if parts: 

573 messages.append(_messages.ModelRequest(parts)) 

574 run_span.set_attribute('all_messages', messages) 

575 

576 return result.StreamedRunResult[DepsT, NodeRunEndT]( 

577 messages, 

578 new_message_index, 

579 usage_limits, 

580 result_stream, 

581 result_schema, 

582 run_context, 

583 ctx.deps.result_validators, 

584 result_tool_name, 

585 on_complete, 

586 ) 

587 

588 

589async def _process_function_tools( 

590 tool_calls: list[_messages.ToolCallPart], 

591 result_tool_name: str | None, 

592 ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], 

593) -> list[_messages.ModelRequestPart]: 

594 """Process function (non-result) tool calls in parallel. 

595 

596 Also add stub return parts for any other tools that need it. 

597 """ 

598 parts: list[_messages.ModelRequestPart] = [] 

599 tasks: list[asyncio.Task[_messages.ToolReturnPart | _messages.RetryPromptPart]] = [] 

600 

601 stub_function_tools = bool(result_tool_name) and ctx.deps.end_strategy == 'early' 

602 result_schema = ctx.deps.result_schema 

603 

604 # we rely on the fact that if we found a result, it's the first result tool call in the last model response

605 found_used_result_tool = False 

606 run_context = _build_run_context(ctx) 

607 

608 for call in tool_calls: 

609 if call.tool_name == result_tool_name and not found_used_result_tool: 

610 found_used_result_tool = True 

611 parts.append( 

612 _messages.ToolReturnPart( 

613 tool_name=call.tool_name, 

614 content='Final result processed.', 

615 tool_call_id=call.tool_call_id, 

616 ) 

617 ) 

618 elif tool := ctx.deps.function_tools.get(call.tool_name): 

619 if stub_function_tools: 

620 parts.append( 

621 _messages.ToolReturnPart( 

622 tool_name=call.tool_name, 

623 content='Tool not executed - a final result was already processed.', 

624 tool_call_id=call.tool_call_id, 

625 ) 

626 ) 

627 else: 

628 tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name)) 

629 elif result_schema is not None and call.tool_name in result_schema.tools: 

630 # if tool_name is in the result schema, it means we found a result tool but an error occurred during

631 # validation; we don't add another part here

632 if result_tool_name is not None: 

633 parts.append( 

634 _messages.ToolReturnPart( 

635 tool_name=call.tool_name, 

636 content='Result tool not used - a final result was already processed.', 

637 tool_call_id=call.tool_call_id, 

638 ) 

639 ) 

640 else: 

641 parts.append(_unknown_tool(call.tool_name, ctx)) 

642 

643 # Run all tool tasks in parallel 

644 if tasks: 

645 with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): 

646 task_results: Sequence[_messages.ToolReturnPart | _messages.RetryPromptPart] = await asyncio.gather(*tasks) 

647 for result in task_results: 

648 if isinstance(result, _messages.ToolReturnPart): 

649 parts.append(result) 

650 elif isinstance(result, _messages.RetryPromptPart): 

651 parts.append(result) 

652 else: 

653 assert_never(result) 

654 return parts 

655 

656 

657def _unknown_tool( 

658 tool_name: str, 

659 ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], 

660) -> _messages.RetryPromptPart: 

661 ctx.state.increment_retries(ctx.deps.max_result_retries) 

662 tool_names = list(ctx.deps.function_tools.keys()) 

663 if result_schema := ctx.deps.result_schema: 

664 tool_names.extend(result_schema.tool_names()) 

665 

666 if tool_names: 

667 msg = f'Available tools: {", ".join(tool_names)}' 

668 else: 

669 msg = 'No tools available.' 

670 

671 return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}') 

672 

673 

674async def _validate_result( 

675 result_data: T, 

676 ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]], 

677 tool_call: _messages.ToolCallPart | None, 

678) -> T: 

679 for validator in ctx.deps.result_validators: 

680 run_context = _build_run_context(ctx) 

681 result_data = await validator.validate(result_data, tool_call, run_context) 

682 return result_data 
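
The validators looped over here are registered on the agent; a short sketch, assuming the public `result_validator` decorator and `ModelRetry` (raising `ModelRetry` inside a validator is what ultimately surfaces as the `ToolRetryError` handled by the nodes above):

```python
from pydantic_ai import Agent, ModelRetry

agent = Agent('test', result_type=int)

@agent.result_validator  # registered validators end up in ctx.deps.result_validators
async def check_positive(value: int) -> int:
    if value <= 0:
        raise ModelRetry('Please return a positive number.')
    return value
```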

683 

684 

685def _allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool: 

686 return result_schema is None or result_schema.allow_text_result 

687 

688 

689@dataclasses.dataclass 

690class _RunMessages: 

691 messages: list[_messages.ModelMessage] 

692 used: bool = False 

693 

694 

695_messages_ctx_var: ContextVar[_RunMessages] = ContextVar('var') 

696 

697 

698@contextmanager 

699def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]: 

700 """Context manager to access the messages used in a [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] call. 

701 

702 Useful when a run may raise an exception, see [model errors](../agents.md#model-errors) for more information. 

703 

704 Examples: 

705 ```python 

706 from pydantic_ai import Agent, capture_run_messages 

707 

708 agent = Agent('test') 

709 

710 with capture_run_messages() as messages: 

711 try: 

712 result = agent.run_sync('foobar') 

713 except Exception: 

714 print(messages) 

715 raise 

716 ``` 

717 

718 !!! note 

719 If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context, 

720 `messages` will represent the messages exchanged during the first call only. 

721 """ 

722 try: 

723 yield _messages_ctx_var.get().messages 

724 except LookupError: 

725 messages: list[_messages.ModelMessage] = [] 

726 token = _messages_ctx_var.set(_RunMessages(messages)) 

727 try: 

728 yield messages 

729 finally: 

730 _messages_ctx_var.reset(token) 

731 

732 

733def get_captured_run_messages() -> _RunMessages: 

734 return _messages_ctx_var.get() 
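
A sketch of the "first call only" behaviour documented on `capture_run_messages`, poking at the private helpers above purely for illustration (not a supported API):

```python
with capture_run_messages() as messages:
    captured = get_captured_run_messages()
    assert captured.messages is messages and not captured.used
    captured.used = True                       # what _prepare_messages does on the first run
    assert get_captured_run_messages().used    # a second run would now start a fresh message list
```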

735 

736 

737def build_agent_graph( 

738 name: str | None, deps_type: type[DepsT], result_type: type[ResultT] 

739) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[ResultT]]: 

740 # We'll define the known node classes: 

741 nodes = ( 

742 UserPromptNode[DepsT], 

743 ModelRequestNode[DepsT], 

744 HandleResponseNode[DepsT], 

745 FinalResultNode[DepsT, ResultT], 

746 ) 

747 graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[ResultT]]( 

748 nodes=nodes, 

749 name=name or 'Agent', 

750 state_type=GraphAgentState, 

751 run_end_type=MarkFinalResult[result_type], 

752 auto_instrument=False, 

753 ) 

754 return graph 

755 

756 

757def build_agent_stream_graph( 

758 name: str | None, deps_type: type[DepsT], result_type: type[ResultT] | None 

759) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]]: 

760 nodes = [ 

761 StreamUserPromptNode[DepsT, result.StreamedRunResult[DepsT, ResultT]], 

762 StreamModelRequestNode[DepsT, result.StreamedRunResult[DepsT, ResultT]], 

763 ] 

764 graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]]( 

765 nodes=nodes, 

766 name=name or 'Agent', 

767 state_type=GraphAgentState, 

768 run_end_type=result.StreamedRunResult[DepsT, result_type], 

769 ) 

770 return graph