Coverage for tests/test_agent.py: 99.18%

465 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-01-25 16:43 +0000

1import json 

2import re 

3import sys 

4from datetime import timezone 

5from typing import Any, Callable, Union 

6 

7import httpx 

8import pytest 

9from dirty_equals import IsJson 

10from inline_snapshot import snapshot 

11from pydantic import BaseModel, field_validator 

12from pydantic_core import to_json 

13 

14from pydantic_ai import Agent, ModelRetry, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages 

15from pydantic_ai.messages import ( 

16 ModelMessage, 

17 ModelRequest, 

18 ModelResponse, 

19 ModelResponsePart, 

20 RetryPromptPart, 

21 SystemPromptPart, 

22 TextPart, 

23 ToolCallPart, 

24 ToolReturnPart, 

25 UserPromptPart, 

26) 

27from pydantic_ai.models import cached_async_http_client 

28from pydantic_ai.models.function import AgentInfo, FunctionModel 

29from pydantic_ai.models.test import TestModel 

30from pydantic_ai.result import RunResult, Usage 

31from pydantic_ai.tools import ToolDefinition 

32 

33from .conftest import IsNow, TestEnv 

34 

35pytestmark = pytest.mark.anyio 

36 

37 

def test_result_tuple():
    """A result-tool call whose JSON payload decodes into a two-string tuple."""

    def model_logic(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # The agent has a structured result type, so a result tool must be registered.
        assert info.result_tools is not None
        payload = '{"response": ["foo", "bar"]}'
        return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, payload)])

    agent = Agent(FunctionModel(model_logic), result_type=tuple[str, str])
    run = agent.run_sync('Hello')
    assert run.data == ('foo', 'bar')

48 

49 

class Foo(BaseModel):
    """Shared Pydantic result model used by several tests below."""

    a: int
    b: str

53 

54 

def test_result_pydantic_model():
    """The result tool's JSON arguments are validated into a `Foo` instance."""

    def model_logic(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.result_tools is not None
        return ModelResponse(
            parts=[ToolCallPart(info.result_tools[0].name, '{"a": 1, "b": "foo"}')]
        )

    agent = Agent(FunctionModel(model_logic), result_type=Foo)
    run = agent.run_sync('Hello')

    assert isinstance(run.data, Foo)
    assert run.data.model_dump() == {'a': 1, 'b': 'foo'}

66 

67 

def test_result_pydantic_model_retry():
    """A result payload that fails validation triggers one retry; the second attempt succeeds."""

    def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.result_tools is not None
        # First request: send an invalid payload ("a" is not an int) to force a retry.
        if len(messages) == 1:
            args_json = '{"a": "wrong", "b": "foo"}'
        else:
            args_json = '{"a": 42, "b": "foo"}'
        return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, args_json)])

    agent = Agent(FunctionModel(return_model), result_type=Foo)

    # The agent name is inferred from the caller's variable name on the first run.
    assert agent.name is None

    result = agent.run_sync('Hello')
    assert agent.name == 'agent'
    assert isinstance(result.data, Foo)
    assert result.data.model_dump() == {'a': 42, 'b': 'foo'}
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[ToolCallPart('final_result', '{"a": "wrong", "b": "foo"}')],
                model_name='function:return_model',
                timestamp=IsNow(tz=timezone.utc),
            ),
            # The pydantic validation errors are relayed back to the model as a retry prompt.
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        tool_name='final_result',
                        content=[
                            {
                                'type': 'int_parsing',
                                'loc': ('a',),
                                'msg': 'Input should be a valid integer, unable to parse string as an integer',
                                'input': 'wrong',
                            }
                        ],
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart('final_result', '{"a": 42, "b": "foo"}')],
                model_name='function:return_model',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
        ]
    )
    assert result.all_messages_json().startswith(b'[{"parts":[{"content":"Hello",')

124 

125 

def test_result_pydantic_model_validation_error():
    """A field-validator failure is rendered as formatted JSON in the retry prompt text."""

    def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.result_tools is not None
        # First request: b == 'foo', which the validator below rejects.
        if len(messages) == 1:
            args_json = '{"a": 1, "b": "foo"}'
        else:
            args_json = '{"a": 1, "b": "bar"}'
        return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, args_json)])

    class Bar(BaseModel):
        a: int
        b: str

        @field_validator('b')
        def check_b(cls, v: str) -> str:
            if v == 'foo':
                raise ValueError('must not be foo')
            return v

    agent = Agent(FunctionModel(return_model), result_type=Bar)

    result = agent.run_sync('Hello')
    assert isinstance(result.data, Bar)
    assert result.data.model_dump() == snapshot({'a': 1, 'b': 'bar'})
    # The kinds show exactly one retry round-trip before the accepted result.
    messages_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result.all_messages()]
    assert messages_part_kinds == snapshot(
        [
            ('request', ['user-prompt']),
            ('response', ['tool-call']),
            ('request', ['retry-prompt']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
        ]
    )

    user_retry = result.all_messages()[2]
    assert isinstance(user_retry, ModelRequest)
    retry_prompt = user_retry.parts[0]
    assert isinstance(retry_prompt, RetryPromptPart)
    # NOTE(review): inner indentation of this expected string was reconstructed from the
    # pretty-printed JSON it represents — confirm against the canonical test file.
    assert retry_prompt.model_response() == snapshot("""\
1 validation errors: [
  {
    "type": "value_error",
    "loc": [
      "b"
    ],
    "msg": "Value error, must not be foo",
    "input": "foo"
  }
]

Fix the errors and try again.""")

178 

179 

def test_result_validator():
    """A `@agent.result_validator` can raise ModelRetry to reject an otherwise-valid result."""

    def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.result_tools is not None
        # First request returns a == 41, which the validator below rejects.
        if len(messages) == 1:
            args_json = '{"a": 41, "b": "foo"}'
        else:
            args_json = '{"a": 42, "b": "foo"}'
        return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, args_json)])

    agent = Agent(FunctionModel(return_model), result_type=Foo)

    @agent.result_validator
    def validate_result(ctx: RunContext[None], r: Foo) -> Foo:
        assert ctx.tool_name == 'final_result'
        if r.a == 42:
            return r
        else:
            raise ModelRetry('"a" should be 42')

    result = agent.run_sync('Hello')
    assert isinstance(result.data, Foo)
    assert result.data.model_dump() == {'a': 42, 'b': 'foo'}
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[ToolCallPart('final_result', '{"a": 41, "b": "foo"}')],
                model_name='function:return_model',
                timestamp=IsNow(tz=timezone.utc),
            ),
            # The ModelRetry message is forwarded to the model verbatim.
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        content='"a" should be 42', tool_name='final_result', timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart('final_result', '{"a": 42, "b": "foo"}')],
                model_name='function:return_model',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
        ]
    )

231 

232 

def test_plain_response_then_tuple():
    """Plain text is rejected when only structured results are allowed; the tool call then wins."""
    call_index = 0

    def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        nonlocal call_index

        assert info.result_tools is not None
        call_index += 1
        # First call returns plain text, which is invalid for a tuple result type.
        if call_index == 1:
            return ModelResponse(parts=[TextPart('hello')])
        else:
            args_json = '{"response": ["foo", "bar"]}'
            return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, args_json)])

    agent = Agent(FunctionModel(return_tuple), result_type=tuple[str, str])

    result = agent.run_sync('Hello')
    assert result.data == ('foo', 'bar')
    assert call_index == 2
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='hello')],
                model_name='function:return_tuple',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        content='Plain text responses are not permitted, please call one of the functions instead.',
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='final_result', args='{"response": ["foo", "bar"]}')],
                model_name='function:return_tuple',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
        ]
    )
    assert result._result_tool_name == 'final_result'  # pyright: ignore[reportPrivateUsage]
    # result_tool_return_content replaces the tool-return content in the returned view...
    assert result.all_messages(result_tool_return_content='foobar')[-1] == snapshot(
        ModelRequest(
            parts=[ToolReturnPart(tool_name='final_result', content='foobar', timestamp=IsNow(tz=timezone.utc))]
        )
    )
    # ...without mutating the stored history.
    assert result.all_messages()[-1] == snapshot(
        ModelRequest(
            parts=[
                ToolReturnPart(
                    tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
                )
            ]
        )
    )

297 

298 

def test_result_tool_return_content_str_return():
    """Overriding the result tool return content is rejected for plain-str results."""
    agent = Agent('test')

    result = agent.run_sync('Hello')
    assert result.data == 'success (no tool calls)'

    # A plain-text result has no result tool call whose return could be rewritten.
    expected = re.escape('Cannot set result tool return content when the return type is `str`.')
    with pytest.raises(ValueError, match=expected):
        result.all_messages(result_tool_return_content='foobar')

308 

309 

def test_result_tool_return_content_no_tool():
    """Overriding the result tool return content fails when no matching tool call exists."""
    agent = Agent('test', result_type=int)

    result = agent.run_sync('Hello')
    assert result.data == 0

    # Point the result at a tool name that never appears in the message history.
    result._result_tool_name = 'wrong'  # pyright: ignore[reportPrivateUsage]
    expected = re.escape("No tool call found with tool name 'wrong'.")
    with pytest.raises(LookupError, match=expected):
        result.all_messages(result_tool_return_content='foobar')

318 

319 

def test_response_tuple():
    """A tuple result type registers a single result tool and disallows plain-text results."""
    m = TestModel()

    agent = Agent(m, result_type=tuple[str, str])
    assert agent._result_schema.allow_text_result is False  # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess]

    result = agent.run_sync('Hello')
    assert result.data == snapshot(('a', 'a'))

    assert m.agent_model_function_tools == snapshot([])
    assert m.agent_model_allow_text_result is False

    assert m.agent_model_result_tools is not None
    assert len(m.agent_model_result_tools) == 1

    # The tuple is wrapped in an object schema under the 'response' key
    # (hence outer_typed_dict_key='response').
    assert m.agent_model_result_tools == snapshot(
        [
            ToolDefinition(
                name='final_result',
                description='The final response which ends this conversation',
                parameters_json_schema={
                    'properties': {
                        'response': {
                            'maxItems': 2,
                            'minItems': 2,
                            'prefixItems': [{'type': 'string'}, {'type': 'string'}],
                            'title': 'Response',
                            'type': 'array',
                        }
                    },
                    'required': ['response'],
                    'type': 'object',
                },
                outer_typed_dict_key='response',
            )
        ]
    )

357 

358 

@pytest.mark.parametrize(
    'input_union_callable',
    [lambda: Union[str, Foo], lambda: Union[Foo, str], lambda: str | Foo, lambda: Foo | str],
    ids=['Union[str, Foo]', 'Union[Foo, str]', 'str | Foo', 'Foo | str'],
)
def test_response_union_allow_str(input_union_callable: Callable[[], Any]):
    """A union result type containing `str` permits plain-text results, in either member order."""
    try:
        # Unions are built lazily: `str | Foo` raises TypeError on Python < 3.10.
        union = input_union_callable()
    except TypeError:
        pytest.skip('Python version does not support `|` syntax for unions')

    m = TestModel()
    agent: Agent[None, Union[str, Foo]] = Agent(m, result_type=union)

    got_tool_call_name = 'unset'

    @agent.result_validator
    def validate_result(ctx: RunContext[None], r: Any) -> Any:
        nonlocal got_tool_call_name
        got_tool_call_name = ctx.tool_name
        return r

    assert agent._result_schema.allow_text_result is True  # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess]

    result = agent.run_sync('Hello')
    assert result.data == snapshot('success (no tool calls)')
    # A plain-text result means no result tool was invoked.
    assert got_tool_call_name == snapshot(None)

    assert m.agent_model_function_tools == snapshot([])
    assert m.agent_model_allow_text_result is True

    assert m.agent_model_result_tools is not None
    assert len(m.agent_model_result_tools) == 1

    # Only the non-str member of the union gets a result tool.
    assert m.agent_model_result_tools == snapshot(
        [
            ToolDefinition(
                name='final_result',
                description='The final response which ends this conversation',
                parameters_json_schema={
                    'properties': {
                        'a': {'title': 'A', 'type': 'integer'},
                        'b': {'title': 'B', 'type': 'string'},
                    },
                    'required': ['a', 'b'],
                    'title': 'Foo',
                    'type': 'object',
                },
            )
        ]
    )

410 

411 

# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false
@pytest.mark.parametrize(
    'union_code',
    [
        pytest.param('ResultType = Union[Foo, Bar]'),
        pytest.param('ResultType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='3.10+')),
        pytest.param(
            'ResultType: TypeAlias = Foo | Bar',
            marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='Python 3.10+'),
        ),
        pytest.param(
            'type ResultType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 12), reason='3.12+')
        ),
    ],
)
def test_response_multiple_return_tools(create_module: Callable[[str], Any], union_code: str):
    """A two-model union produces one result tool per member, across all union spellings."""
    # The union is defined in a dynamically created module so that `type X = ...`
    # (3.12 syntax) doesn't break parsing of this file on older Pythons.
    module_code = f'''
from pydantic import BaseModel
from typing import Union
from typing_extensions import TypeAlias

class Foo(BaseModel):
    a: int
    b: str


class Bar(BaseModel):
    """This is a bar model."""

    b: str

{union_code}
    '''

    mod = create_module(module_code)

    m = TestModel()
    agent = Agent(m, result_type=mod.ResultType)
    got_tool_call_name = 'unset'

    @agent.result_validator
    def validate_result(ctx: RunContext[None], r: Any) -> Any:
        nonlocal got_tool_call_name
        got_tool_call_name = ctx.tool_name
        return r

    result = agent.run_sync('Hello')
    assert result.data == mod.Foo(a=0, b='a')
    assert got_tool_call_name == snapshot('final_result_Foo')

    assert m.agent_model_function_tools == snapshot([])
    assert m.agent_model_allow_text_result is False

    assert m.agent_model_result_tools is not None
    assert len(m.agent_model_result_tools) == 2

    # Tool names are disambiguated by suffixing the member model name.
    assert m.agent_model_result_tools == snapshot(
        [
            ToolDefinition(
                name='final_result_Foo',
                description='Foo: The final response which ends this conversation',
                parameters_json_schema={
                    'properties': {
                        'a': {'title': 'A', 'type': 'integer'},
                        'b': {'title': 'B', 'type': 'string'},
                    },
                    'required': ['a', 'b'],
                    'title': 'Foo',
                    'type': 'object',
                },
            ),
            ToolDefinition(
                name='final_result_Bar',
                description='This is a bar model.',
                parameters_json_schema={
                    'properties': {'b': {'title': 'B', 'type': 'string'}},
                    'required': ['b'],
                    'title': 'Bar',
                    'type': 'object',
                },
            ),
        ]
    )

    # A different seed makes TestModel pick the other result tool.
    result = agent.run_sync('Hello', model=TestModel(seed=1))
    assert result.data == mod.Bar(b='b')
    assert got_tool_call_name == snapshot('final_result_Bar')

499 

500 

def test_run_with_history_new():
    """System prompts are re-inserted (or not) depending on which message history is passed."""
    m = TestModel()

    agent = Agent(m, system_prompt='Foobar')

    @agent.tool_plain
    async def ret_a(x: str) -> str:
        return f'{x}-apple'

    result1 = agent.run_sync('Hello')
    assert result1.new_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar'),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))]
            ),
            ModelResponse(
                parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
            ),
        ]
    )

    # if we pass new_messages, system prompt is inserted before the message_history messages
    result2 = agent.run_sync('Hello again', message_history=result1.new_messages())
    assert result2 == snapshot(
        RunResult(
            _all_messages=[
                ModelRequest(
                    parts=[
                        SystemPromptPart(content='Foobar'),
                        UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                    ]
                ),
                ModelResponse(
                    parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})],
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ModelRequest(
                    parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))]
                ),
                ModelResponse(
                    parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
                ),
                ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]),
                ModelResponse(
                    parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
                ),
            ],
            _new_message_index=4,
            data='{"ret_a":"a-apple"}',
            _result_tool_name=None,
            _usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None),
        )
    )
    new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()]
    assert new_msg_part_kinds == snapshot(
        [
            ('request', ['system-prompt', 'user-prompt']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
            ('response', ['text']),
            ('request', ['user-prompt']),
            ('response', ['text']),
        ]
    )
    assert result2.new_messages_json().startswith(b'[{"parts":[{"content":"Hello again",')

    # if we pass all_messages, system prompt is NOT inserted before the message_history messages,
    # so only one system prompt
    result3 = agent.run_sync('Hello again', message_history=result1.all_messages())
    # same as result2 except for datetimes
    assert result3 == snapshot(
        RunResult(
            _all_messages=[
                ModelRequest(
                    parts=[
                        SystemPromptPart(content='Foobar'),
                        UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                    ]
                ),
                ModelResponse(
                    parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})],
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ModelRequest(
                    parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))]
                ),
                ModelResponse(
                    parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
                ),
                ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]),
                ModelResponse(
                    parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
                ),
            ],
            _new_message_index=4,
            data='{"ret_a":"a-apple"}',
            _result_tool_name=None,
            _usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None),
        )
    )

614 

615 

def test_run_with_history_new_structured():
    """Like test_run_with_history_new, but with a structured (result-tool) final result."""
    m = TestModel()

    class Response(BaseModel):
        a: int

    agent = Agent(m, system_prompt='Foobar', result_type=Response)

    @agent.tool_plain
    async def ret_a(x: str) -> str:
        return f'{x}-apple'

    result1 = agent.run_sync('Hello')
    assert result1.new_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar'),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))]
            ),
            ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name='final_result',
                        args={'a': 0},
                        tool_call_id=None,
                    )
                ],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
        ]
    )

    result2 = agent.run_sync('Hello again', message_history=result1.new_messages())
    assert result2 == snapshot(
        RunResult(
            data=Response(a=0),
            _all_messages=[
                ModelRequest(
                    parts=[
                        SystemPromptPart(content='Foobar'),
                        UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                    ],
                ),
                ModelResponse(
                    parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})],
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ModelRequest(
                    parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))],
                ),
                ModelResponse(
                    parts=[ToolCallPart(tool_name='final_result', args={'a': 0})],
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ModelRequest(
                    parts=[
                        ToolReturnPart(
                            tool_name='final_result',
                            content='Final result processed.',
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                    ],
                ),
                # second call, notice no repeated system prompt
                ModelRequest(
                    parts=[
                        UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc)),
                    ],
                ),
                ModelResponse(
                    parts=[ToolCallPart(tool_name='final_result', args={'a': 0})],
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ModelRequest(
                    parts=[
                        ToolReturnPart(
                            tool_name='final_result',
                            content='Final result processed.',
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                    ]
                ),
            ],
            _new_message_index=5,
            _result_tool_name='final_result',
            _usage=Usage(requests=1, request_tokens=59, response_tokens=13, total_tokens=72, details=None),
        )
    )
    new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()]
    assert new_msg_part_kinds == snapshot(
        [
            ('request', ['system-prompt', 'user-prompt']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
            ('request', ['user-prompt']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
        ]
    )
    assert result2.new_messages_json().startswith(b'[{"parts":[{"content":"Hello again",')

739 

740 

def test_empty_tool_calls():
    """A model response with no parts at all is rejected as unexpected model behavior."""

    def no_parts(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        return ModelResponse(parts=[])

    agent = Agent(FunctionModel(no_parts))

    with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'):
        agent.run_sync('Hello')

749 

750 

def test_unknown_tool():
    """Calling a nonexistent tool triggers a retry, then exhausts retries and raises."""

    def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        return ModelResponse(parts=[ToolCallPart('foobar', '{}')])

    agent = Agent(FunctionModel(empty))

    # capture_run_messages exposes the history even though the run itself raises.
    with capture_run_messages() as messages:
        with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for result validation'):
            agent.run_sync('Hello')
    assert messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[ToolCallPart(tool_name='foobar', args='{}')],
                model_name='function:empty',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        content="Unknown tool name: 'foobar'. No tools available.", timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='foobar', args='{}')],
                model_name='function:empty',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )

782 

783 

def test_unknown_tool_fix():
    """An unknown tool call is recoverable if the model responds sensibly to the retry."""

    def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        # After the retry prompt (history longer than one message), return plain text.
        if len(m) > 1:
            return ModelResponse(parts=[TextPart('success')])
        else:
            return ModelResponse(parts=[ToolCallPart('foobar', '{}')])

    agent = Agent(FunctionModel(empty))

    result = agent.run_sync('Hello')
    assert result.data == 'success'
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[ToolCallPart(tool_name='foobar', args='{}')],
                model_name='function:empty',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        content="Unknown tool name: 'foobar'. No tools available.", timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
            ModelResponse(
                parts=[TextPart(content='success')],
                model_name='function:empty',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )

817 

818 

def test_model_requests_blocked(env: TestEnv):
    """Real model requests are blocked while ALLOW_MODEL_REQUESTS is disabled for tests."""
    env.set('GEMINI_API_KEY', 'foobar')
    agent = Agent('google-gla:gemini-1.5-flash', result_type=tuple[str, str], defer_model_check=True)

    expected = 'Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False'
    with pytest.raises(RuntimeError, match=expected):
        agent.run_sync('Hello')

825 

826 

def test_override_model(env: TestEnv):
    """`agent.override(model=...)` swaps in a different model for runs inside the context."""
    env.set('GEMINI_API_KEY', 'foobar')
    agent = Agent('google-gla:gemini-1.5-flash', result_type=tuple[int, str], defer_model_check=True)

    with agent.override(model='test'):
        run = agent.run_sync('Hello')
        assert run.data == snapshot((0, 'a'))

834 

835 

def test_override_model_no_model():
    """Overriding the model does not satisfy the requirement that a model be configured."""
    agent = Agent()

    pattern = r'`model` must be set either.+Even when `override\(model=...\)` is customiz'
    with pytest.raises(UserError, match=pattern):
        with agent.override(model='test'):
            agent.run_sync('Hello')

842 

843 

def test_run_sync_multiple():
    """run_sync can be called repeatedly without 'Event loop is closed' errors."""
    agent = Agent('test')

    @agent.tool_plain
    async def make_request() -> str:
        # raised a `RuntimeError: Event loop is closed` on repeat runs when we used `asyncio.run()`
        client = cached_async_http_client()
        # use this as I suspect it's about the fastest globally available endpoint
        try:
            response = await client.get('https://cloudflare.com/cdn-cgi/trace')
        except httpx.ConnectError:
            # best-effort: skip rather than fail when there is no network
            pytest.skip('offline')
        else:
            return str(response.status_code)

    for _ in range(2):
        result = agent.run_sync('Hello')
        assert result.data == '{"make_request":"200"}'

862 

863 

async def test_agent_name():
    """The agent name is inferred from the caller's variable name, but only when requested."""
    # NOTE: the variable name `my_agent` is significant — inference reads it.
    my_agent = Agent('test')

    assert my_agent.name is None

    # infer_name=False leaves the name unset.
    await my_agent.run('Hello', infer_name=False)
    assert my_agent.name is None

    # By default the name is inferred from the variable the agent is bound to.
    await my_agent.run('Hello')
    assert my_agent.name == 'my_agent'

874 

875 

async def test_agent_name_already_set():
    """An explicitly provided name is never overwritten by inference."""
    my_agent = Agent('test', name='fig_tree')

    assert my_agent.name == 'fig_tree'

    await my_agent.run('Hello')
    assert my_agent.name == 'fig_tree'

883 

884 

async def test_agent_name_changes():
    """Once inferred, the name sticks even after the original variable is rebound and deleted."""
    my_agent = Agent('test')

    await my_agent.run('Hello')
    assert my_agent.name == 'my_agent'

    # Rebind and drop the original name; the already-set name must not change.
    new_agent = my_agent
    del my_agent

    await new_agent.run('Hello')
    assert new_agent.name == 'my_agent'

896 

897 

def test_name_from_global(create_module: Callable[[str], Any]):
    """Name inference also works for agents defined at module (global) scope."""
    module_code = """
from pydantic_ai import Agent

my_agent = Agent('test')

def foo():
    result = my_agent.run_sync('Hello')
    return result.data
"""

    mod = create_module(module_code)

    assert mod.my_agent.name is None
    assert mod.foo() == snapshot('success (no tool calls)')
    # The global variable name is picked up on the first run.
    assert mod.my_agent.name == 'my_agent'

914 

915 

class TestMultipleToolCalls:
    """Tests for scenarios where multiple tool calls are made in a single response."""

    # All tests in this class need a synchronous event loop fixture.
    pytestmark = pytest.mark.usefixtures('set_event_loop')

    class ResultType(BaseModel):
        """Result type used by all tests."""

        value: str

    def test_early_strategy_stops_after_first_final_result(self):
        """Test that 'early' strategy stops processing regular tools after first final result."""
        tool_called = []

        def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
            assert info.result_tools is not None
            # The final result comes first, so the regular tools after it must be skipped.
            return ModelResponse(
                parts=[
                    ToolCallPart('final_result', {'value': 'final'}),
                    ToolCallPart('regular_tool', {'x': 1}),
                    ToolCallPart('another_tool', {'y': 2}),
                ]
            )

        agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='early')

        @agent.tool_plain
        def regular_tool(x: int) -> int:  # pragma: no cover
            """A regular tool that should not be called."""
            tool_called.append('regular_tool')
            return x

        @agent.tool_plain
        def another_tool(y: int) -> int:  # pragma: no cover
            """Another tool that should not be called."""
            tool_called.append('another_tool')
            return y

        result = agent.run_sync('test early strategy')
        messages = result.all_messages()

        # Verify no tools were called after final result
        assert tool_called == []

        # Verify we got tool returns for all calls
        assert messages[-1].parts == snapshot(
            [
                ToolReturnPart(
                    tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
                ),
                ToolReturnPart(
                    tool_name='regular_tool',
                    content='Tool not executed - a final result was already processed.',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ToolReturnPart(
                    tool_name='another_tool',
                    content='Tool not executed - a final result was already processed.',
                    timestamp=IsNow(tz=timezone.utc),
                ),
            ]
        )

978 

    def test_early_strategy_uses_first_final_result(self):
        """Test that 'early' strategy uses the first final result and ignores subsequent ones."""

        def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
            assert info.result_tools is not None
            # Two competing final results in one response; only the first should count.
            return ModelResponse(
                parts=[
                    ToolCallPart('final_result', {'value': 'first'}),
                    ToolCallPart('final_result', {'value': 'second'}),
                ]
            )

        agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='early')
        result = agent.run_sync('test multiple final results')

        # Verify the result came from the first final tool
        assert result.data.value == 'first'

        # Verify we got appropriate tool returns
        assert result.new_messages()[-1].parts == snapshot(
            [
                ToolReturnPart(
                    tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
                ),
                ToolReturnPart(
                    tool_name='final_result',
                    content='Result tool not used - a final result was already processed.',
                    timestamp=IsNow(tz=timezone.utc),
                ),
            ]
        )

1010 

    def test_exhaustive_strategy_executes_all_tools(self):
        """Test that 'exhaustive' strategy executes all tools while using first final result."""
        tool_called: list[str] = []  # records which regular tools actually ran

        def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
            # One response mixing regular tools, two final results, and an unknown tool,
            # so every end-strategy code path is exercised at once.
            assert info.result_tools is not None
            return ModelResponse(
                parts=[
                    ToolCallPart('regular_tool', {'x': 42}),
                    ToolCallPart('final_result', {'value': 'first'}),
                    ToolCallPart('another_tool', {'y': 2}),
                    ToolCallPart('final_result', {'value': 'second'}),
                    ToolCallPart('unknown_tool', {'value': '???'}),
                ]
            )

        agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='exhaustive')

        @agent.tool_plain
        def regular_tool(x: int) -> int:
            """A regular tool that should be called."""
            tool_called.append('regular_tool')
            return x

        @agent.tool_plain
        def another_tool(y: int) -> int:
            """Another tool that should be called."""
            tool_called.append('another_tool')
            return y

        result = agent.run_sync('test exhaustive strategy')

        # Verify the result came from the first final tool
        assert result.data.value == 'first'

        # Verify all regular tools were called
        assert sorted(tool_called) == sorted(['regular_tool', 'another_tool'])

        # Verify we got tool returns in the correct order
        assert result.all_messages() == snapshot(
            [
                ModelRequest(
                    parts=[UserPromptPart(content='test exhaustive strategy', timestamp=IsNow(tz=timezone.utc))]
                ),
                ModelResponse(
                    parts=[
                        ToolCallPart(tool_name='regular_tool', args={'x': 42}),
                        ToolCallPart(tool_name='final_result', args={'value': 'first'}),
                        ToolCallPart(tool_name='another_tool', args={'y': 2}),
                        ToolCallPart(tool_name='final_result', args={'value': 'second'}),
                        ToolCallPart(tool_name='unknown_tool', args={'value': '???'}),
                    ],
                    model_name='function:return_model',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ModelRequest(
                    parts=[
                        # First final_result is processed; the duplicate is declined.
                        ToolReturnPart(
                            tool_name='final_result',
                            content='Final result processed.',
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                        ToolReturnPart(
                            tool_name='final_result',
                            content='Result tool not used - a final result was already processed.',
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                        # The unknown tool produces a retry prompt listing the real tools.
                        RetryPromptPart(
                            content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result",
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                        # Regular tools still executed under 'exhaustive'.
                        ToolReturnPart(tool_name='regular_tool', content=42, timestamp=IsNow(tz=timezone.utc)),
                        ToolReturnPart(tool_name='another_tool', content=2, timestamp=IsNow(tz=timezone.utc)),
                    ]
                ),
            ]
        )

1088 

    def test_early_strategy_with_final_result_in_middle(self):
        """Test that 'early' strategy stops at first final result, regardless of position."""
        tool_called = []  # records which regular tools actually ran (expected: none)

        def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
            # final_result appears in the middle of the call list; 'early' must still
            # suppress every other tool call in the same response.
            assert info.result_tools is not None
            return ModelResponse(
                parts=[
                    ToolCallPart('regular_tool', {'x': 1}),
                    ToolCallPart('final_result', {'value': 'final'}),
                    ToolCallPart('another_tool', {'y': 2}),
                    ToolCallPart('unknown_tool', {'value': '???'}),
                ]
            )

        agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='early')

        @agent.tool_plain
        def regular_tool(x: int) -> int:  # pragma: no cover
            """A regular tool that should not be called."""
            tool_called.append('regular_tool')
            return x

        @agent.tool_plain
        def another_tool(y: int) -> int:  # pragma: no cover
            """A tool that should not be called."""
            tool_called.append('another_tool')
            return y

        result = agent.run_sync('test early strategy with final result in middle')

        # Verify no tools were called
        assert tool_called == []

        # Verify we got appropriate tool returns
        assert result.all_messages() == snapshot(
            [
                ModelRequest(
                    parts=[
                        UserPromptPart(
                            content='test early strategy with final result in middle', timestamp=IsNow(tz=timezone.utc)
                        )
                    ]
                ),
                ModelResponse(
                    parts=[
                        ToolCallPart(tool_name='regular_tool', args={'x': 1}),
                        ToolCallPart(tool_name='final_result', args={'value': 'final'}),
                        ToolCallPart(tool_name='another_tool', args={'y': 2}),
                        ToolCallPart(tool_name='unknown_tool', args={'value': '???'}),
                    ],
                    model_name='function:return_model',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ModelRequest(
                    parts=[
                        # Both regular tools are skipped with placeholder returns,
                        # even the one requested *before* the final result.
                        ToolReturnPart(
                            tool_name='regular_tool',
                            content='Tool not executed - a final result was already processed.',
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                        ToolReturnPart(
                            tool_name='final_result',
                            content='Final result processed.',
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                        ToolReturnPart(
                            tool_name='another_tool',
                            content='Tool not executed - a final result was already processed.',
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                        # Unknown tools still yield a retry prompt rather than a skip message.
                        RetryPromptPart(
                            content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result",
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                    ]
                ),
            ]
        )

1168 

1169 def test_early_strategy_does_not_apply_to_tool_calls_without_final_tool(self): 

1170 """Test that 'early' strategy does not apply to tool calls without final tool.""" 

1171 tool_called = [] 

1172 agent = Agent(TestModel(), result_type=self.ResultType, end_strategy='early') 

1173 

1174 @agent.tool_plain 

1175 def regular_tool(x: int) -> int: 

1176 """A regular tool that should be called.""" 

1177 tool_called.append('regular_tool') 

1178 return x 

1179 

1180 result = agent.run_sync('test early strategy with regular tool calls') 

1181 assert tool_called == ['regular_tool'] 

1182 

1183 tool_returns = [m for m in result.all_messages() if isinstance(m, ToolReturnPart)] 

1184 assert tool_returns == snapshot([]) 

1185 

1186 

async def test_model_settings_override() -> None:
    """Per-run `model_settings` must override agent-level defaults."""

    def return_settings(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # Echo back whatever settings reached the model so the test can inspect them.
        return ModelResponse(parts=[TextPart(to_json(info.model_settings).decode())])

    # Agent with no default settings: the model sees None unless a run overrides it.
    my_agent = Agent(FunctionModel(return_settings))
    no_override = await my_agent.run('Hello')
    assert no_override.data == IsJson(None)
    with_override = await my_agent.run('Hello', model_settings={'temperature': 0.5})
    assert with_override.data == IsJson({'temperature': 0.5})

    # Agent with default settings: defaults apply, but a per-run override wins.
    my_agent = Agent(FunctionModel(return_settings), model_settings={'temperature': 0.1})
    default_applied = await my_agent.run('Hello')
    assert default_applied.data == IsJson({'temperature': 0.1})
    override_wins = await my_agent.run('Hello', model_settings={'temperature': 0.5})
    assert override_wins.data == IsJson({'temperature': 0.5})

1198 

1199 

async def test_empty_text_part():
    """An empty text part alongside a result tool call must not break result extraction."""

    def return_empty_text(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.result_tools is not None
        payload = '{"response": ["foo", "bar"]}'
        # Empty text first, then the result tool call — the tool call should still win.
        response_parts: list[ModelResponsePart] = [
            TextPart(''),
            ToolCallPart(info.result_tools[0].name, payload),
        ]
        return ModelResponse(parts=response_parts)

    agent = Agent(FunctionModel(return_empty_text), result_type=tuple[str, str])

    run_result = await agent.run('Hello')
    assert run_result.data == ('foo', 'bar')

1210 

1211 

def test_heterogeneous_responses_non_streaming() -> None:
    """Indicates that tool calls are prioritized over text in heterogeneous responses."""

    def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # First turn: text AND a tool call together; second turn: plain text only.
        assert info.result_tools is not None
        parts: list[ModelResponsePart] = []
        if len(messages) == 1:
            parts = [
                TextPart(content='foo'),
                ToolCallPart('get_location', {'loc_name': 'London'}),
            ]
        else:
            parts = [TextPart(content='final response')]
        return ModelResponse(parts=parts)

    agent = Agent(FunctionModel(return_model))

    @agent.tool_plain
    async def get_location(loc_name: str) -> str:
        if loc_name == 'London':
            return json.dumps({'lat': 51, 'lng': 0})
        else:
            raise ModelRetry('Wrong location, please try again')

    result = agent.run_sync('Hello')
    # The tool call was executed (run continued) instead of the 'foo' text ending it.
    assert result.data == 'final response'
    assert result.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                ]
            ),
            ModelResponse(
                parts=[
                    TextPart(content='foo'),
                    ToolCallPart(
                        tool_name='get_location',
                        args={'loc_name': 'London'},
                    ),
                ],
                model_name='function:return_model',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='get_location', content='{"lat": 51, "lng": 0}', timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
            ModelResponse(
                parts=[TextPart(content='final response')],
                model_name='function:return_model',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )

1270 

1271 

def test_last_run_messages() -> None:
    """Accessing the removed `last_run_messages` attribute raises a helpful error."""
    agent = Agent('test')

    with pytest.raises(AttributeError, match='The `last_run_messages` attribute has been removed,'):
        _ = agent.last_run_messages  # pyright: ignore[reportDeprecated]

1277 

1278 

def test_nested_capture_run_messages() -> None:
    """Nested `capture_run_messages` contexts share a single underlying message list."""
    agent = Agent('test')

    with capture_run_messages() as messages1:
        assert messages1 == []
        with capture_run_messages() as messages2:
            # The inner context does not start a fresh list — it is the same object.
            assert messages2 == []
            assert messages1 is messages2
            run_result = agent.run_sync('Hello')
            assert run_result.data == 'success (no tool calls)'

    # Both captured lists contain the single run's request/response pair.
    assert messages1 == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)')], model_name='test', timestamp=IsNow(tz=timezone.utc)
            ),
        ]
    )
    assert messages1 == messages2

1299 

1300 

def test_double_capture_run_messages() -> None:
    """Within one capture context, only the first agent run is recorded."""
    agent = Agent('test')

    with capture_run_messages() as messages:
        assert messages == []
        first_run = agent.run_sync('Hello')
        assert first_run.data == 'success (no tool calls)'
        second_run = agent.run_sync('Hello 2')
        assert second_run.data == 'success (no tool calls)'
    # Only the first run's request/response pair was captured; 'Hello 2' is absent.
    assert messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)')], model_name='test', timestamp=IsNow(tz=timezone.utc)
            ),
        ]
    )

1318 

1319 

def test_dynamic_false_no_reevaluate():
    """When dynamic is false (default), the system prompt is not reevaluated
    i.e: SystemPromptPart(
        content="A",        <--- Remains the same when `message_history` is passed.
        part_kind='system-prompt')
    """
    agent = Agent('test', system_prompt='Foobar')

    dynamic_value = 'A'

    @agent.system_prompt
    async def func() -> str:
        # Non-dynamic prompt function: evaluated once on the first run only.
        return dynamic_value

    res = agent.run_sync('Hello')

    assert res.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar', part_kind='system-prompt'),
                    SystemPromptPart(content=dynamic_value, part_kind='system-prompt'),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
                ],
                kind='request',
            ),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)', part_kind='text')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
                kind='response',
            ),
        ]
    )

    # Change the value after the first run; because the prompt function is not
    # dynamic, the second run must keep the original 'A' from the history.
    dynamic_value = 'B'

    res_two = agent.run_sync('World', message_history=res.all_messages())

    assert res_two.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar', part_kind='system-prompt'),
                    SystemPromptPart(
                        content='A',  # Remains the same
                        part_kind='system-prompt',
                    ),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
                ],
                kind='request',
            ),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)', part_kind='text')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
                kind='response',
            ),
            ModelRequest(
                parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')],
                kind='request',
            ),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)', part_kind='text')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
                kind='response',
            ),
        ]
    )

1390 

1391 

def test_dynamic_true_reevaluate_system_prompt():
    """When dynamic is true, the system prompt is reevaluated
    i.e: SystemPromptPart(
        content="B",        <--- Updated value
        part_kind='system-prompt')
    """
    agent = Agent('test', system_prompt='Foobar')

    dynamic_value = 'A'

    @agent.system_prompt(dynamic=True)
    async def func():
        # Dynamic prompt function: re-evaluated on every run, even with history.
        return dynamic_value

    res = agent.run_sync('Hello')

    assert res.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar', part_kind='system-prompt'),
                    SystemPromptPart(
                        content=dynamic_value,
                        part_kind='system-prompt',
                        dynamic_ref=func.__qualname__,
                    ),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
                ],
                kind='request',
            ),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)', part_kind='text')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
                kind='response',
            ),
        ]
    )

    # Change the value after the first run; the dynamic prompt in the replayed
    # history must be refreshed to the new 'B' on the second run.
    dynamic_value = 'B'

    res_two = agent.run_sync('World', message_history=res.all_messages())

    assert res_two.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar', part_kind='system-prompt'),
                    SystemPromptPart(
                        content='B',
                        part_kind='system-prompt',
                        dynamic_ref=func.__qualname__,
                    ),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
                ],
                kind='request',
            ),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)', part_kind='text')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
                kind='response',
            ),
            ModelRequest(
                parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')],
                kind='request',
            ),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)', part_kind='text')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
                kind='response',
            ),
        ]
    )

1467 

1468 

def test_capture_run_messages_tool_agent() -> None:
    """Only the outer agent's messages are captured when a tool runs an inner agent."""
    agent_outer = Agent('test')
    agent_inner = Agent(TestModel(custom_result_text='inner agent result'))

    @agent_outer.tool_plain
    async def foobar(x: str) -> str:
        # Delegates to a second agent inside the tool call.
        result_ = await agent_inner.run(x)
        return result_.data

    with capture_run_messages() as messages:
        result = agent_outer.run_sync('foobar')

    assert result.data == snapshot('{"foobar":"inner agent result"}')
    # The captured exchange is the outer agent's only; the inner run's
    # request/response pair does not appear here.
    assert messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='foobar', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[ToolCallPart(tool_name='foobar', args={'x': 'a'})],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(tool_name='foobar', content='inner agent result', timestamp=IsNow(tz=timezone.utc))
                ]
            ),
            ModelResponse(
                parts=[TextPart(content='{"foobar":"inner agent result"}')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )

1502 

1503 

class Bar(BaseModel):
    """Alternative result model used to exercise per-run `result_type` overrides."""

    c: int
    d: str

1507 

1508 

def test_custom_result_type_sync() -> None:
    """A per-run `result_type` passed to `run_sync` overrides the agent's configured one."""
    agent = Agent('test', result_type=Foo)

    # Default: the agent-level result type (Foo) is used.
    default_run = agent.run_sync('Hello')
    assert default_run.data == snapshot(Foo(a=0, b='a'))

    # Per-run overrides: another model, a plain string, and a bare int.
    bar_run = agent.run_sync('Hello', result_type=Bar)
    assert bar_run.data == snapshot(Bar(c=0, d='a'))
    str_run = agent.run_sync('Hello', result_type=str)
    assert str_run.data == snapshot('success (no tool calls)')
    int_run = agent.run_sync('Hello', result_type=int)
    assert int_run.data == snapshot(0)

1516 

1517 

async def test_custom_result_type_async() -> None:
    """A per-run `result_type` passed to `run` overrides the default string result."""
    agent = Agent('test')

    # Without an override the test model produces plain text.
    text_result = await agent.run('Hello')
    assert text_result.data == snapshot('success (no tool calls)')

    # Override with a pydantic model, then with a bare int.
    model_result = await agent.run('Hello', result_type=Foo)
    assert model_result.data == snapshot(Foo(a=0, b='a'))
    int_result = await agent.run('Hello', result_type=int)
    assert int_result.data == snapshot(0)

1528 

1529 

def test_custom_result_type_invalid() -> None:
    """Overriding `result_type` at run time is rejected once result validators exist."""
    agent = Agent('test')

    @agent.result_validator
    def validate_result(ctx: RunContext[None], r: Any) -> Any:  # pragma: no cover
        # Never reached: the run errors out before any validation happens.
        return r

    expected = 'Cannot set a custom run `result_type` when the agent has result validators'
    with pytest.raises(UserError, match=expected):
        agent.run_sync('Hello', result_type=int)