Coverage for tests/test_agent.py: 99.61%

489 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1import json 

2import re 

3import sys 

4from datetime import timezone 

5from typing import Any, Callable, Union 

6 

7import httpx 

8import pytest 

9from dirty_equals import IsJson 

10from inline_snapshot import snapshot 

11from pydantic import BaseModel, field_validator 

12from pydantic_core import to_json 

13 

14from pydantic_ai import Agent, ModelRetry, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages 

15from pydantic_ai.messages import ( 

16 BinaryContent, 

17 ModelMessage, 

18 ModelRequest, 

19 ModelResponse, 

20 ModelResponsePart, 

21 RetryPromptPart, 

22 SystemPromptPart, 

23 TextPart, 

24 ToolCallPart, 

25 ToolReturnPart, 

26 UserPromptPart, 

27) 

28from pydantic_ai.models.function import AgentInfo, FunctionModel 

29from pydantic_ai.models.test import TestModel 

30from pydantic_ai.result import Usage 

31from pydantic_ai.tools import ToolDefinition 

32 

33from .conftest import IsNow, IsStr, TestEnv 

34 

35pytestmark = pytest.mark.anyio 

36 

37 

def test_result_tuple():
    """A result-tool call carrying a two-item list is returned as a `tuple[str, str]`."""

    def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # The agent is expected to expose a result tool for the tuple result type.
        assert info.result_tools is not None
        result_tool = info.result_tools[0]
        return ModelResponse(parts=[ToolCallPart(result_tool.name, '{"response": ["foo", "bar"]}')])

    agent = Agent(FunctionModel(return_tuple), result_type=tuple[str, str])

    run_result = agent.run_sync('Hello')
    assert run_result.data == ('foo', 'bar')

48 

49 

# Structured result model shared by several tests in this module.
# NOTE(review): documented with `#` comments on purpose. A class docstring appears to be used as
# the result tool description (see `Bar` in `test_response_multiple_return_tools`), which would
# change the tool definitions asserted for `Foo` elsewhere in this file.
class Foo(BaseModel):
    a: int
    b: str

53 

54 

def test_result_pydantic_model():
    """A result-tool call with valid JSON args is parsed into a `Foo` instance."""

    def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.result_tools is not None
        tool_name = info.result_tools[0].name
        call = ToolCallPart(tool_name, '{"a": 1, "b": "foo"}')
        return ModelResponse(parts=[call])

    agent = Agent(FunctionModel(return_model), result_type=Foo)

    outcome = agent.run_sync('Hello')
    data = outcome.data
    assert isinstance(data, Foo)
    assert data.model_dump() == {'a': 1, 'b': 'foo'}

66 

67 

def test_result_pydantic_model_retry():
    """Invalid result-tool args produce a retry prompt; the follow-up valid call yields `Foo`.

    Also checks that the agent's name is unset before the first run and equals the local
    variable name (`'agent'`) afterwards.
    """

    def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.result_tools is not None
        # First call (history holds a single message): send a non-integer `a` to force a retry.
        if len(messages) == 1:
            args_json = '{"a": "wrong", "b": "foo"}'
        else:
            args_json = '{"a": 42, "b": "foo"}'
        return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, args_json)])

    agent = Agent(FunctionModel(return_model), result_type=Foo)

    # No name was passed to the constructor, so it is unset until a run happens.
    assert agent.name is None

    result = agent.run_sync('Hello')
    # The variable must stay named `agent` for this assertion to hold.
    assert agent.name == 'agent'
    assert isinstance(result.data, Foo)
    assert result.data.model_dump() == {'a': 42, 'b': 'foo'}
    # Expected history: prompt -> bad tool call -> retry prompt -> good tool call -> tool return.
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[ToolCallPart(tool_name='final_result', args='{"a": "wrong", "b": "foo"}', tool_call_id=IsStr())],
                model_name='function:return_model:',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        tool_name='final_result',
                        content=[
                            {
                                'type': 'int_parsing',
                                'loc': ('a',),
                                'msg': 'Input should be a valid integer, unable to parse string as an integer',
                                'input': 'wrong',
                            }
                        ],
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='final_result', args='{"a": 42, "b": "foo"}', tool_call_id=IsStr())],
                model_name='function:return_model:',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
        ]
    )
    # The JSON form is bytes and starts with the first request's user prompt.
    assert result.all_messages_json().startswith(b'[{"parts":[{"content":"Hello",')

128 

129 

def test_result_pydantic_model_validation_error():
    """A `field_validator` failure on the result model leads to a retry prompt, then success."""

    def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.result_tools is not None
        # First call sends b='foo', which `Bar.check_b` rejects; later calls send b='bar'.
        if len(messages) == 1:
            args_json = '{"a": 1, "b": "foo"}'
        else:
            args_json = '{"a": 1, "b": "bar"}'
        return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, args_json)])

    # Result model whose validator rejects the literal value 'foo' for `b`.
    class Bar(BaseModel):
        a: int
        b: str

        @field_validator('b')
        def check_b(cls, v: str) -> str:
            if v == 'foo':
                raise ValueError('must not be foo')
            return v

    agent = Agent(FunctionModel(return_model), result_type=Bar)

    result = agent.run_sync('Hello')
    assert isinstance(result.data, Bar)
    assert result.data.model_dump() == snapshot({'a': 1, 'b': 'bar'})
    # Reduce the history to (message kind, [part kinds]) pairs to check its overall shape.
    messages_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result.all_messages()]
    assert messages_part_kinds == snapshot(
        [
            ('request', ['user-prompt']),
            ('response', ['tool-call']),
            ('request', ['retry-prompt']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
        ]
    )

    # Index 2 is the retry request; check the text that is rendered for the model.
    user_retry = result.all_messages()[2]
    assert isinstance(user_retry, ModelRequest)
    retry_prompt = user_retry.parts[0]
    assert isinstance(retry_prompt, RetryPromptPart)
    assert retry_prompt.model_response() == snapshot("""\
1 validation errors: [
  {
    "type": "value_error",
    "loc": [
      "b"
    ],
    "msg": "Value error, must not be foo",
    "input": "foo"
  }
]

Fix the errors and try again.""")

182 

183 

def test_result_validator():
    """A result validator raising `ModelRetry` causes a retry prompt, then the valid result is kept."""

    def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.result_tools is not None
        # First call returns a=41 (rejected by the validator below); later calls return a=42.
        if len(messages) == 1:
            args_json = '{"a": 41, "b": "foo"}'
        else:
            args_json = '{"a": 42, "b": "foo"}'
        return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, args_json)])

    agent = Agent(FunctionModel(return_model), result_type=Foo)

    @agent.result_validator
    def validate_result(ctx: RunContext[None], r: Foo) -> Foo:
        # The validator sees the name of the result tool that produced `r`.
        assert ctx.tool_name == 'final_result'
        if r.a == 42:
            return r
        else:
            raise ModelRetry('"a" should be 42')

    result = agent.run_sync('Hello')
    assert isinstance(result.data, Foo)
    assert result.data.model_dump() == {'a': 42, 'b': 'foo'}
    # The `ModelRetry` message appears verbatim as the retry prompt content.
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[ToolCallPart(tool_name='final_result', args='{"a": 41, "b": "foo"}', tool_call_id=IsStr())],
                model_name='function:return_model:',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        content='"a" should be 42',
                        tool_name='final_result',
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='final_result', args='{"a": 42, "b": "foo"}', tool_call_id=IsStr())],
                model_name='function:return_model:',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
        ]
    )

241 

242 

def test_plain_response_then_tuple():
    """A plain-text reply for a tuple result type is retried; the second reply uses the result tool.

    Also checks that `all_messages(result_tool_return_content=...)` replaces the content of the
    final tool return in the returned list only; a later plain `all_messages()` call still shows
    the default content.
    """
    call_index = 0

    def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        nonlocal call_index

        assert info.result_tools is not None
        call_index += 1
        # First call answers with plain text; every later call uses the result tool.
        if call_index == 1:
            return ModelResponse(parts=[TextPart('hello')])
        else:
            args_json = '{"response": ["foo", "bar"]}'
            return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, args_json)])

    agent = Agent(FunctionModel(return_tuple), result_type=tuple[str, str])

    result = agent.run_sync('Hello')
    assert result.data == ('foo', 'bar')
    # Exactly one retry happened: the model function was invoked twice.
    assert call_index == 2
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='hello')],
                model_name='function:return_tuple:',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        content='Plain text responses are not permitted, please call one of the functions instead.',
                        timestamp=IsNow(tz=timezone.utc),
                        tool_call_id=IsStr(),
                    )
                ]
            ),
            ModelResponse(
                parts=[
                    ToolCallPart(tool_name='final_result', args='{"response": ["foo", "bar"]}', tool_call_id=IsStr())
                ],
                model_name='function:return_tuple:',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
        ]
    )
    assert result._result_tool_name == 'final_result'  # pyright: ignore[reportPrivateUsage]
    # Custom return content shows up in the last message of the returned history...
    assert result.all_messages(result_tool_return_content='foobar')[-1] == snapshot(
        ModelRequest(
            parts=[
                ToolReturnPart(
                    tool_name='final_result', content='foobar', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc)
                )
            ]
        )
    )
    # ...while a subsequent call without the override still reports the default content.
    assert result.all_messages()[-1] == snapshot(
        ModelRequest(
            parts=[
                ToolReturnPart(
                    tool_name='final_result',
                    content='Final result processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
                )
            ]
        )
    )

320 

321 

def test_result_tool_return_content_str_return():
    """Overriding the result tool return content raises `ValueError` for a plain `str` result."""
    agent = Agent('test')

    run = agent.run_sync('Hello')
    assert run.data == 'success (no tool calls)'

    expected_error = re.escape('Cannot set result tool return content when the return type is `str`.')
    with pytest.raises(ValueError, match=expected_error):
        run.all_messages(result_tool_return_content='foobar')

331 

332 

def test_result_tool_return_content_no_tool():
    """Overriding return content raises `LookupError` when the recorded result tool name is not in the history."""
    agent = Agent('test', result_type=int)

    run = agent.run_sync('Hello')
    assert run.data == 0

    # Point the result at a tool name that does not match any tool call of the run.
    run._result_tool_name = 'wrong'  # pyright: ignore[reportPrivateUsage]
    expected_error = re.escape("No tool call found with tool name 'wrong'.")
    with pytest.raises(LookupError, match=expected_error):
        run.all_messages(result_tool_return_content='foobar')

341 

342 

def test_response_tuple():
    """A `tuple[str, str]` result type disallows text results and exposes one `final_result` tool."""
    m = TestModel()

    agent = Agent(m, result_type=tuple[str, str])
    assert agent._result_schema.allow_text_result is False  # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess]

    result = agent.run_sync('Hello')
    assert result.data == snapshot(('a', 'a'))

    # Inspect the request parameters the test model received on its last request.
    assert m.last_model_request_parameters is not None
    assert m.last_model_request_parameters.function_tools == snapshot([])
    assert m.last_model_request_parameters.allow_text_result is False

    assert m.last_model_request_parameters.result_tools is not None
    assert len(m.last_model_request_parameters.result_tools) == 1
    # The tuple is wrapped in an object under the key 'response' (see `outer_typed_dict_key`).
    assert m.last_model_request_parameters.result_tools == snapshot(
        [
            ToolDefinition(
                name='final_result',
                description='The final response which ends this conversation',
                parameters_json_schema={
                    'properties': {
                        'response': {
                            'maxItems': 2,
                            'minItems': 2,
                            'prefixItems': [{'type': 'string'}, {'type': 'string'}],
                            'title': 'Response',
                            'type': 'array',
                        }
                    },
                    'required': ['response'],
                    'type': 'object',
                },
                outer_typed_dict_key='response',
            )
        ]
    )

380 

381 

# Each union is wrapped in a lambda so that `str | Foo` is only evaluated inside the test,
# where a `TypeError` can be turned into a skip.
@pytest.mark.parametrize(
    'input_union_callable',
    [lambda: Union[str, Foo], lambda: Union[Foo, str], lambda: str | Foo, lambda: Foo | str],
    ids=['Union[str, Foo]', 'Union[Foo, str]', 'str | Foo', 'Foo | str'],
)
def test_response_union_allow_str(input_union_callable: Callable[[], Any]):
    """A union containing `str` allows plain-text results while still offering a `Foo` result tool."""
    try:
        union = input_union_callable()
    except TypeError:
        pytest.skip('Python version does not support `|` syntax for unions')

    m = TestModel()
    agent: Agent[None, Union[str, Foo]] = Agent(m, result_type=union)

    # Sentinel so we can tell "validator never ran" apart from "ran with tool_name=None".
    got_tool_call_name = 'unset'

    @agent.result_validator
    def validate_result(ctx: RunContext[None], r: Any) -> Any:
        nonlocal got_tool_call_name
        got_tool_call_name = ctx.tool_name
        return r

    assert agent._result_schema.allow_text_result is True  # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess]

    result = agent.run_sync('Hello')
    assert result.data == snapshot('success (no tool calls)')
    # The validator ran for the text result, with no tool name attached.
    assert got_tool_call_name == snapshot(None)

    assert m.last_model_request_parameters is not None
    assert m.last_model_request_parameters.function_tools == snapshot([])
    assert m.last_model_request_parameters.allow_text_result is True

    assert m.last_model_request_parameters.result_tools is not None
    assert len(m.last_model_request_parameters.result_tools) == 1

    # Only the `Foo` member of the union gets a result tool.
    assert m.last_model_request_parameters.result_tools == snapshot(
        [
            ToolDefinition(
                name='final_result',
                description='The final response which ends this conversation',
                parameters_json_schema={
                    'properties': {
                        'a': {'title': 'A', 'type': 'integer'},
                        'b': {'title': 'B', 'type': 'string'},
                    },
                    'required': ['a', 'b'],
                    'title': 'Foo',
                    'type': 'object',
                },
            )
        ]
    )

434 

435 

436# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false 

# The same union is spelled four ways; the spellings needing newer syntax are skipped on
# interpreters older than the version named in each `skipif`.
@pytest.mark.parametrize(
    'union_code',
    [
        pytest.param('ResultType = Union[Foo, Bar]'),
        pytest.param('ResultType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='3.10+')),
        pytest.param(
            'ResultType: TypeAlias = Foo | Bar',
            marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='Python 3.10+'),
        ),
        pytest.param(
            'type ResultType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 12), reason='3.12+')
        ),
    ],
)
def test_response_multiple_return_tools(create_module: Callable[[str], Any], union_code: str):
    """A union of two models yields one result tool per member, named `final_result_<Model>`."""
    # The models live in a generated module so that `union_code` can use syntax that would not
    # parse in this file on older interpreters.
    module_code = f'''
from pydantic import BaseModel
from typing import Union
from typing_extensions import TypeAlias

class Foo(BaseModel):
    a: int
    b: str


class Bar(BaseModel):
    """This is a bar model."""

    b: str

{union_code}
    '''

    mod = create_module(module_code)

    m = TestModel()
    agent = Agent(m, result_type=mod.ResultType)
    # Sentinel, overwritten by the validator with the name of the tool that produced the result.
    got_tool_call_name = 'unset'

    @agent.result_validator
    def validate_result(ctx: RunContext[None], r: Any) -> Any:
        nonlocal got_tool_call_name
        got_tool_call_name = ctx.tool_name
        return r

    # With the default `TestModel()` the `Foo` result tool is the one called.
    result = agent.run_sync('Hello')
    assert result.data == mod.Foo(a=0, b='a')
    assert got_tool_call_name == snapshot('final_result_Foo')

    assert m.last_model_request_parameters is not None
    assert m.last_model_request_parameters.function_tools == snapshot([])
    assert m.last_model_request_parameters.allow_text_result is False

    assert m.last_model_request_parameters.result_tools is not None
    assert len(m.last_model_request_parameters.result_tools) == 2

    # `Foo` has no docstring and gets a generated description; `Bar`'s docstring is used as-is.
    assert m.last_model_request_parameters.result_tools == snapshot(
        [
            ToolDefinition(
                name='final_result_Foo',
                description='Foo: The final response which ends this conversation',
                parameters_json_schema={
                    'properties': {
                        'a': {'title': 'A', 'type': 'integer'},
                        'b': {'title': 'B', 'type': 'string'},
                    },
                    'required': ['a', 'b'],
                    'title': 'Foo',
                    'type': 'object',
                },
            ),
            ToolDefinition(
                name='final_result_Bar',
                description='This is a bar model.',
                parameters_json_schema={
                    'properties': {'b': {'title': 'B', 'type': 'string'}},
                    'required': ['b'],
                    'title': 'Bar',
                    'type': 'object',
                },
            ),
        ]
    )

    # With `seed=1` the `Bar` result tool is the one called instead.
    result = agent.run_sync('Hello', model=TestModel(seed=1))
    assert result.data == mod.Bar(b='b')
    assert got_tool_call_name == snapshot('final_result_Bar')

524 

525 

def test_run_with_history_new():
    """Continuing a text-result conversation via `message_history` keeps a single system prompt.

    The second run is made twice, once seeded with `result1.new_messages()` and once with
    `result1.all_messages()`; both are expected to produce the same history, new-message index,
    data and usage.
    """
    m = TestModel()

    agent = Agent(m, system_prompt='Foobar')

    @agent.tool_plain
    async def ret_a(x: str) -> str:
        return f'{x}-apple'

    result1 = agent.run_sync('Hello')
    assert result1.new_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar', timestamp=IsNow(tz=timezone.utc)),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
            ModelResponse(
                parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
            ),
        ]
    )

    # if we pass new_messages as message_history, no additional system prompt is inserted:
    # the only SystemPromptPart below is the one already contained in the history
    result2 = agent.run_sync('Hello again', message_history=result1.new_messages())
    assert result2.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar', timestamp=IsNow(tz=timezone.utc)),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
            ModelResponse(
                parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
            ),
            ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
            ),
        ]
    )
    # The four history messages come first, so the new messages start at index 4.
    assert result2._new_message_index == snapshot(4)  # pyright: ignore[reportPrivateUsage]
    assert result2.data == snapshot('{"ret_a":"a-apple"}')
    # Text result: no result tool was involved.
    assert result2._result_tool_name == snapshot(None)  # pyright: ignore[reportPrivateUsage]
    assert result2.usage() == snapshot(
        Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None)
    )
    new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()]
    assert new_msg_part_kinds == snapshot(
        [
            ('request', ['system-prompt', 'user-prompt']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
            ('response', ['text']),
            ('request', ['user-prompt']),
            ('response', ['text']),
        ]
    )
    # `new_messages_json` starts at the second run's user prompt, i.e. it excludes the history.
    assert result2.new_messages_json().startswith(b'[{"parts":[{"content":"Hello again",')

    # if we pass all_messages, system prompt is NOT inserted before the message_history messages,
    # so only one system prompt
    result3 = agent.run_sync('Hello again', message_history=result1.all_messages())
    # same as result2 except for datetimes
    assert result3.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar', timestamp=IsNow(tz=timezone.utc)),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
            ModelResponse(
                parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
            ),
            ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
            ),
        ]
    )
    assert result3._new_message_index == snapshot(4)  # pyright: ignore[reportPrivateUsage]
    assert result3.data == snapshot('{"ret_a":"a-apple"}')
    assert result3._result_tool_name == snapshot(None)  # pyright: ignore[reportPrivateUsage]
    assert result3.usage() == snapshot(
        Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None)
    )

651 

652 

def test_run_with_history_new_structured():
    """Continuing a structured-result conversation via `message_history` keeps a single system prompt.

    In the second run the test model calls `final_result` directly, without calling `ret_a` again.
    """
    m = TestModel()

    # Structured result type; kept docstring-free so the generated result tool is unchanged.
    class Response(BaseModel):
        a: int

    agent = Agent(m, system_prompt='Foobar', result_type=Response)

    @agent.tool_plain
    async def ret_a(x: str) -> str:
        return f'{x}-apple'

    result1 = agent.run_sync('Hello')
    # First run: function tool call, then the result tool call, each followed by its tool return.
    assert result1.new_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar', timestamp=IsNow(tz=timezone.utc)),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
            ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name='final_result',
                        args={'a': 0},
                        tool_call_id=IsStr(),
                    )
                ],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
        ]
    )

    result2 = agent.run_sync('Hello again', message_history=result1.new_messages())
    assert result2.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar', timestamp=IsNow(tz=timezone.utc)),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                ],
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc)
                    )
                ],
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='final_result', args={'a': 0}, tool_call_id=IsStr())],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    ),
                ],
            ),
            # second call, notice no repeated system prompt
            ModelRequest(
                parts=[
                    UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc)),
                ],
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='final_result', args={'a': 0}, tool_call_id=IsStr())],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    ),
                ]
            ),
        ]
    )
    assert result2.data == snapshot(Response(a=0))
    # The five history messages come first, so the new messages start at index 5.
    assert result2._new_message_index == snapshot(5)  # pyright: ignore[reportPrivateUsage]
    assert result2._result_tool_name == snapshot('final_result')  # pyright: ignore[reportPrivateUsage]
    assert result2.usage() == snapshot(
        Usage(requests=1, request_tokens=59, response_tokens=13, total_tokens=72, details=None)
    )
    new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()]
    assert new_msg_part_kinds == snapshot(
        [
            ('request', ['system-prompt', 'user-prompt']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
            ('request', ['user-prompt']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
        ]
    )
    # `new_messages_json` starts at the second run's user prompt, i.e. it excludes the history.
    assert result2.new_messages_json().startswith(b'[{"parts":[{"content":"Hello again",')

789 

790 

def test_empty_tool_calls():
    """A model response without any parts is reported as `UnexpectedModelBehavior`."""

    def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        no_parts: list[ModelResponsePart] = []
        return ModelResponse(parts=no_parts)

    agent = Agent(FunctionModel(empty))

    expect_failure = pytest.raises(UnexpectedModelBehavior, match='Received empty model response')
    with expect_failure:
        agent.run_sync('Hello')

799 

800 

def test_unknown_tool():
    """Repeated calls to an unregistered tool exhaust the retry budget and raise."""

    # NOTE: the function name is part of the asserted `model_name` ('function:empty:').
    def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        # Always call a tool the agent does not know about.
        return ModelResponse(parts=[ToolCallPart('foobar', '{}')])

    agent = Agent(FunctionModel(empty))

    # `capture_run_messages` exposes the exchanged messages even though the run raises.
    with capture_run_messages() as messages:
        with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for result validation'):
            agent.run_sync('Hello')
    # One retry prompt was sent; the second unknown-tool call ended the run.
    assert messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())],
                model_name='function:empty:',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        content="Unknown tool name: 'foobar'. No tools available.",
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())],
                model_name='function:empty:',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )

834 

835 

def test_unknown_tool_fix():
    """After one unknown-tool call and its retry prompt, a plain-text reply completes the run."""

    # NOTE: the function name is part of the asserted `model_name` ('function:empty:').
    def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        # Once the history holds more than the initial request, answer with text;
        # on the first call, request a tool the agent does not know about.
        if len(m) > 1:
            return ModelResponse(parts=[TextPart('success')])
        else:
            return ModelResponse(parts=[ToolCallPart('foobar', '{}')])

    agent = Agent(FunctionModel(empty))

    result = agent.run_sync('Hello')
    assert result.data == 'success'
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())],
                model_name='function:empty:',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        content="Unknown tool name: 'foobar'. No tools available.",
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[TextPart(content='success')],
                model_name='function:empty:',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )

871 

872 

def test_model_requests_blocked(env: TestEnv):
    """Running against a non-test model raises `RuntimeError` saying model requests are not allowed."""
    env.set('GEMINI_API_KEY', 'foobar')
    agent = Agent('google-gla:gemini-1.5-flash', result_type=tuple[str, str], defer_model_check=True)

    blocked_message = 'Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False'
    with pytest.raises(RuntimeError, match=blocked_message):
        agent.run_sync('Hello')

879 

880 

def test_override_model(env: TestEnv):
    """Inside `agent.override(model='test')` the run uses the test model instead of the configured one."""
    env.set('GEMINI_API_KEY', 'foobar')
    agent = Agent('google-gla:gemini-1.5-flash', result_type=tuple[int, str], defer_model_check=True)

    with agent.override(model='test'):
        overridden_run = agent.run_sync('Hello')
        assert overridden_run.data == snapshot((0, 'a'))

888 

889 

def test_override_model_no_model():
    """An agent created without a model raises `UserError` even when the model is overridden."""
    agent = Agent()

    expected_error = r'`model` must be set either.+Even when `override\(model=...\)` is customiz'
    with pytest.raises(UserError, match=expected_error), agent.override(model='test'):
        agent.run_sync('Hello')

896 

897 

def test_run_sync_multiple():
    """`run_sync` can be called repeatedly on one agent whose tool performs real async HTTP I/O.

    The test is skipped when the network is unusable.
    """
    agent = Agent('test')

    # NOTE: no docstring on the tool and no rename; its name is part of the asserted output.
    @agent.tool_plain
    async def make_request() -> str:
        async with httpx.AsyncClient() as client:
            # use this as I suspect it's about the fastest globally available endpoint
            try:
                response = await client.get('https://cloudflare.com/cdn-cgi/trace')
            except httpx.TransportError:  # pragma: no cover
                # Any transport-level failure (refused connection, DNS/connect timeout, read
                # error, ...) means the network is unusable here, so skip rather than fail.
                # `ConnectError` alone misses timeouts, which are not `ConnectError` subclasses.
                pytest.skip('offline')
            else:
                return str(response.status_code)

    # Two consecutive synchronous runs on the same agent must both succeed.
    for _ in range(2):
        result = agent.run_sync('Hello')
        assert result.data == '{"make_request":"200"}'

915 

916 

async def test_agent_name():
    """The agent name stays unset with `infer_name=False` and becomes the variable name otherwise."""
    # NOTE: the variable must stay named `my_agent`; the final assertion checks that exact name.
    my_agent = Agent('test')

    assert my_agent.name is None

    # With inference disabled, running does not set a name.
    await my_agent.run('Hello', infer_name=False)
    assert my_agent.name is None

    # A default run sets the name to the local variable's name.
    await my_agent.run('Hello')
    assert my_agent.name == 'my_agent'

927 

928 

async def test_agent_name_already_set():
    """A name passed to the constructor is kept after a run; it is not replaced by the variable name."""
    my_agent = Agent('test', name='fig_tree')

    assert my_agent.name == 'fig_tree'

    await my_agent.run('Hello')
    assert my_agent.name == 'fig_tree'

936 

937 

async def test_agent_name_changes():
    """Once a name has been set by a run, rebinding the agent to another variable does not change it."""
    my_agent = Agent('test')

    await my_agent.run('Hello')
    assert my_agent.name == 'my_agent'

    # Rebind the same agent object under a new name and drop the old binding.
    new_agent = my_agent
    del my_agent

    # The name set by the first run is kept.
    await new_agent.run('Hello')
    assert new_agent.name == 'my_agent'

949 

950 

def test_name_from_global(create_module: Callable[[str], Any]):
    """The agent name is picked up when the agent is a module-level global run from inside a function."""
    # In the generated module, `my_agent` is a global and `run_sync` is called from within `foo()`.
    module_code = """
from pydantic_ai import Agent

my_agent = Agent('test')

def foo():
    result = my_agent.run_sync('Hello')
    return result.data
"""

    mod = create_module(module_code)

    # Unset until `foo()` runs the agent.
    assert mod.my_agent.name is None
    assert mod.foo() == snapshot('success (no tool calls)')
    assert mod.my_agent.name == 'my_agent'

967 

968 

class TestMultipleToolCalls:
    """Tests for scenarios where multiple tool calls are made in a single response."""

    pytestmark = pytest.mark.usefixtures('set_event_loop')

    class ResultType(BaseModel):
        """Result type used by all tests."""

        value: str

    def test_early_strategy_stops_after_first_final_result(self):
        """Test that 'early' strategy stops processing regular tools after first final result."""
        tool_called: list[str] = []

        def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
            assert info.result_tools is not None
            # The final result comes first, followed by two regular tool calls.
            return ModelResponse(
                parts=[
                    ToolCallPart('final_result', {'value': 'final'}),
                    ToolCallPart('regular_tool', {'x': 1}),
                    ToolCallPart('another_tool', {'y': 2}),
                ]
            )

        agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='early')

        @agent.tool_plain
        def regular_tool(x: int) -> int:  # pragma: no cover
            """A regular tool that should not be called."""
            tool_called.append('regular_tool')
            return x

        @agent.tool_plain
        def another_tool(y: int) -> int:  # pragma: no cover
            """Another tool that should not be called."""
            tool_called.append('another_tool')
            return y

        result = agent.run_sync('test early strategy')
        messages = result.all_messages()

        # Verify no tools were called after final result
        assert tool_called == []

        # Verify we got tool returns for all calls
        assert messages[-1].parts == snapshot(
            [
                ToolReturnPart(
                    tool_name='final_result',
                    content='Final result processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ToolReturnPart(
                    tool_name='regular_tool',
                    content='Tool not executed - a final result was already processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ToolReturnPart(
                    tool_name='another_tool',
                    content='Tool not executed - a final result was already processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
                ),
            ]
        )

    def test_early_strategy_uses_first_final_result(self):
        """Test that 'early' strategy uses the first final result and ignores subsequent ones."""

        def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
            assert info.result_tools is not None
            return ModelResponse(
                parts=[
                    ToolCallPart('final_result', {'value': 'first'}),
                    ToolCallPart('final_result', {'value': 'second'}),
                ]
            )

        agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='early')
        result = agent.run_sync('test multiple final results')

        # Verify the result came from the first final tool
        assert result.data.value == 'first'

        # Verify we got appropriate tool returns
        assert result.new_messages()[-1].parts == snapshot(
            [
                ToolReturnPart(
                    tool_name='final_result',
                    content='Final result processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ToolReturnPart(
                    tool_name='final_result',
                    content='Result tool not used - a final result was already processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
                ),
            ]
        )

    def test_exhaustive_strategy_executes_all_tools(self):
        """Test that 'exhaustive' strategy executes all tools while using first final result."""
        tool_called: list[str] = []

        def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
            assert info.result_tools is not None
            # Mix regular tools, two final results and one tool name that is not registered.
            return ModelResponse(
                parts=[
                    ToolCallPart('regular_tool', {'x': 42}),
                    ToolCallPart('final_result', {'value': 'first'}),
                    ToolCallPart('another_tool', {'y': 2}),
                    ToolCallPart('final_result', {'value': 'second'}),
                    ToolCallPart('unknown_tool', {'value': '???'}),
                ]
            )

        agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='exhaustive')

        @agent.tool_plain
        def regular_tool(x: int) -> int:
            """A regular tool that should be called."""
            tool_called.append('regular_tool')
            return x

        @agent.tool_plain
        def another_tool(y: int) -> int:
            """Another tool that should be called."""
            tool_called.append('another_tool')
            return y

        result = agent.run_sync('test exhaustive strategy')

        # Verify the result came from the first final tool
        assert result.data.value == 'first'

        # Verify all regular tools were called (compared order-insensitively)
        assert sorted(tool_called) == sorted(['regular_tool', 'another_tool'])

        # Verify we got tool returns in the correct order
        assert result.all_messages() == snapshot(
            [
                ModelRequest(
                    parts=[UserPromptPart(content='test exhaustive strategy', timestamp=IsNow(tz=timezone.utc))]
                ),
                ModelResponse(
                    parts=[
                        ToolCallPart(tool_name='regular_tool', args={'x': 42}, tool_call_id=IsStr()),
                        ToolCallPart(tool_name='final_result', args={'value': 'first'}, tool_call_id=IsStr()),
                        ToolCallPart(tool_name='another_tool', args={'y': 2}, tool_call_id=IsStr()),
                        ToolCallPart(tool_name='final_result', args={'value': 'second'}, tool_call_id=IsStr()),
                        ToolCallPart(tool_name='unknown_tool', args={'value': '???'}, tool_call_id=IsStr()),
                    ],
                    model_name='function:return_model:',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ModelRequest(
                    parts=[
                        ToolReturnPart(
                            tool_name='final_result',
                            content='Final result processed.',
                            tool_call_id=IsStr(),
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                        ToolReturnPart(
                            tool_name='final_result',
                            content='Result tool not used - a final result was already processed.',
                            tool_call_id=IsStr(),
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                        RetryPromptPart(
                            content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result",
                            timestamp=IsNow(tz=timezone.utc),
                            tool_call_id=IsStr(),
                        ),
                        ToolReturnPart(
                            tool_name='regular_tool',
                            content=42,
                            tool_call_id=IsStr(),
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                        ToolReturnPart(
                            tool_name='another_tool', content=2, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc)
                        ),
                    ]
                ),
            ]
        )

    def test_early_strategy_with_final_result_in_middle(self):
        """Test that 'early' strategy stops at first final result, regardless of position."""
        tool_called: list[str] = []

        def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
            assert info.result_tools is not None
            # The final result sits between regular tool calls; an unregistered tool name comes last.
            return ModelResponse(
                parts=[
                    ToolCallPart('regular_tool', {'x': 1}),
                    ToolCallPart('final_result', {'value': 'final'}),
                    ToolCallPart('another_tool', {'y': 2}),
                    ToolCallPart('unknown_tool', {'value': '???'}),
                ]
            )

        agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='early')

        @agent.tool_plain
        def regular_tool(x: int) -> int:  # pragma: no cover
            """A regular tool that should not be called."""
            tool_called.append('regular_tool')
            return x

        @agent.tool_plain
        def another_tool(y: int) -> int:  # pragma: no cover
            """A tool that should not be called."""
            tool_called.append('another_tool')
            return y

        result = agent.run_sync('test early strategy with final result in middle')

        # Verify no tools were called
        assert tool_called == []

        # Verify we got appropriate tool returns
        assert result.all_messages() == snapshot(
            [
                ModelRequest(
                    parts=[
                        UserPromptPart(
                            content='test early strategy with final result in middle', timestamp=IsNow(tz=timezone.utc)
                        )
                    ]
                ),
                ModelResponse(
                    parts=[
                        ToolCallPart(tool_name='regular_tool', args={'x': 1}, tool_call_id=IsStr()),
                        ToolCallPart(tool_name='final_result', args={'value': 'final'}, tool_call_id=IsStr()),
                        ToolCallPart(tool_name='another_tool', args={'y': 2}, tool_call_id=IsStr()),
                        ToolCallPart(tool_name='unknown_tool', args={'value': '???'}, tool_call_id=IsStr()),
                    ],
                    model_name='function:return_model:',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ModelRequest(
                    parts=[
                        ToolReturnPart(
                            tool_name='regular_tool',
                            content='Tool not executed - a final result was already processed.',
                            tool_call_id=IsStr(),
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                        ToolReturnPart(
                            tool_name='final_result',
                            content='Final result processed.',
                            tool_call_id=IsStr(),
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                        ToolReturnPart(
                            tool_name='another_tool',
                            content='Tool not executed - a final result was already processed.',
                            tool_call_id=IsStr(),
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                        RetryPromptPart(
                            content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result",
                            timestamp=IsNow(tz=timezone.utc),
                            tool_call_id=IsStr(),
                        ),
                    ]
                ),
            ]
        )

    def test_early_strategy_does_not_apply_to_tool_calls_without_final_tool(self):
        """Test that 'early' strategy does not apply to tool calls without final tool."""
        tool_called: list[str] = []
        agent = Agent(TestModel(), result_type=self.ResultType, end_strategy='early')

        @agent.tool_plain
        def regular_tool(x: int) -> int:
            """A regular tool that should be called."""
            tool_called.append('regular_tool')
            return x

        result = agent.run_sync('test early strategy with regular tool calls')
        assert tool_called == ['regular_tool']

        # Tool returns live inside each message's `parts`; the messages themselves are never
        # `ToolReturnPart` instances, so the parts have to be inspected to check anything.
        regular_tool_returns = [
            part
            for message in result.all_messages()
            for part in message.parts
            if isinstance(part, ToolReturnPart) and part.tool_name == 'regular_tool'
        ]
        # The tool ran once, and its return carries a real value rather than the placeholder
        # that the 'early' strategy emits for skipped tools.
        assert len(regular_tool_returns) == 1
        assert regular_tool_returns[0].content != 'Tool not executed - a final result was already processed.'

    def test_multiple_final_result_are_validated_correctly(self):
        """Tests that if multiple final results are returned, but one fails validation, the other is used."""

        def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
            assert info.result_tools is not None
            # The first call uses a field name (`bad_value`) that `ResultType` does not define.
            return ModelResponse(
                parts=[
                    ToolCallPart('final_result', {'bad_value': 'first'}, tool_call_id='first'),
                    ToolCallPart('final_result', {'value': 'second'}, tool_call_id='second'),
                ]
            )

        agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='early')
        result = agent.run_sync('test multiple final results')

        # Verify the result came from the second final tool
        assert result.data.value == 'second'

        # Verify we got appropriate tool returns
        assert result.new_messages()[-1].parts == snapshot(
            [
                ToolReturnPart(
                    tool_name='final_result',
                    tool_call_id='first',
                    content='Result tool not used - result failed validation.',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ToolReturnPart(
                    tool_name='final_result',
                    content='Final result processed.',
                    timestamp=IsNow(tz=timezone.utc),
                    tool_call_id='second',
                ),
            ]
        )

1297 

1298 

async def test_model_settings_override() -> None:
    """Check which `model_settings` reach the model: none, agent-level, or run-level overriding agent-level."""

    def return_settings(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # Echo the settings the model received back as JSON text so the test can inspect them.
        settings_json = to_json(info.model_settings).decode()
        return ModelResponse(parts=[TextPart(settings_json)])

    # Agent without default settings.
    my_agent = Agent(FunctionModel(return_settings))
    plain_run = await my_agent.run('Hello')
    assert plain_run.data == IsJson(None)
    run_with_settings = await my_agent.run('Hello', model_settings={'temperature': 0.5})
    assert run_with_settings.data == IsJson({'temperature': 0.5})

    # Agent with default settings: used as-is, or replaced by the value given to the run.
    my_agent = Agent(FunctionModel(return_settings), model_settings={'temperature': 0.1})
    default_run = await my_agent.run('Hello')
    assert default_run.data == IsJson({'temperature': 0.1})
    overridden_run = await my_agent.run('Hello', model_settings={'temperature': 0.5})
    assert overridden_run.data == IsJson({'temperature': 0.5})

1310 

1311 

async def test_empty_text_part():
    """Check that an empty `TextPart` before the result tool call still yields the tool's result."""

    def return_empty_text(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.result_tools is not None
        result_tool_name = info.result_tools[0].name
        response_parts: list[ModelResponsePart] = [
            TextPart(''),
            ToolCallPart(result_tool_name, '{"response": ["foo", "bar"]}'),
        ]
        return ModelResponse(parts=response_parts)

    agent = Agent(FunctionModel(return_empty_text), result_type=tuple[str, str])

    run = await agent.run('Hello')
    assert run.data == ('foo', 'bar')

1327 

1328 

def test_heterogeneous_responses_non_streaming() -> None:
    """Indicates that tool calls are prioritized over text in heterogeneous responses."""

    def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.result_tools is not None
        parts: list[ModelResponsePart] = []
        if len(messages) == 1:
            # First request: answer with text AND a tool call in the same response.
            parts = [TextPart(content='foo'), ToolCallPart('get_location', {'loc_name': 'London'})]
        else:
            # Later requests (after the tool return): answer with plain text only.
            parts = [TextPart(content='final response')]
        return ModelResponse(parts=parts)

    agent = Agent(FunctionModel(return_model))

    @agent.tool_plain
    async def get_location(loc_name: str) -> str:
        # Only 'London' is ever requested by `return_model`, so the retry branch is not exercised.
        if loc_name == 'London':
            return json.dumps({'lat': 51, 'lng': 0})
        else:
            raise ModelRetry('Wrong location, please try again')

    result = agent.run_sync('Hello')
    # The text from the first response ('foo') is not the final result; the run continues
    # through the tool call and ends with the second response.
    assert result.data == 'final response'
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[
                    TextPart(content='foo'),
                    ToolCallPart(tool_name='get_location', args={'loc_name': 'London'}, tool_call_id=IsStr()),
                ],
                model_name='function:return_model:',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='get_location',
                        content='{"lat": 51, "lng": 0}',
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[TextPart(content='final response')],
                model_name='function:return_model:',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )

1380 

1381 

def test_last_run_messages() -> None:
    """Check that reading `last_run_messages` raises `AttributeError` with the removal notice."""
    agent = Agent('test')
    expected_error = 'The `last_run_messages` attribute has been removed,'

    with pytest.raises(AttributeError, match=expected_error):
        agent.last_run_messages  # pyright: ignore[reportDeprecated]

1387 

1388 

def test_nested_capture_run_messages() -> None:
    """Check that nested `capture_run_messages()` contexts hand out the same message list."""
    agent = Agent('test')

    with capture_run_messages() as messages1:
        assert messages1 == []
        with capture_run_messages() as messages2:
            assert messages2 == []
            # The inner context yields the very same list object as the outer one.
            assert messages1 is messages2
            result = agent.run_sync('Hello')
            assert result.data == 'success (no tool calls)'

    # The run's messages are still available after both contexts have exited.
    assert messages1 == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)')], model_name='test', timestamp=IsNow(tz=timezone.utc)
            ),
        ]
    )
    assert messages1 == messages2

1409 

1410 

def test_double_capture_run_messages() -> None:
    """Check that only the first of two runs inside one `capture_run_messages()` context is captured."""
    agent = Agent('test')

    with capture_run_messages() as messages:
        assert messages == []
        result = agent.run_sync('Hello')
        assert result.data == 'success (no tool calls)'
        result2 = agent.run_sync('Hello 2')
        assert result2.data == 'success (no tool calls)'
    # The captured list holds the 'Hello' exchange only; nothing from the 'Hello 2' run appears.
    assert messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)')], model_name='test', timestamp=IsNow(tz=timezone.utc)
            ),
        ]
    )

1428 

1429 

def test_dynamic_false_no_reevaluate():
    """When dynamic is false (default), the system prompt is not reevaluated.

    The `SystemPromptPart` produced by the decorated function keeps `content='A'`
    when the first run's messages are passed back in as `message_history`, even
    though the function would return 'B' by then.
    """
    agent = Agent('test', system_prompt='Foobar')

    dynamic_value = 'A'

    @agent.system_prompt
    async def func() -> str:
        # Reads the enclosing variable at call time, so rebinding it below changes what this returns.
        return dynamic_value

    res = agent.run_sync('Hello')

    assert res.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar', part_kind='system-prompt', timestamp=IsNow(tz=timezone.utc)),
                    SystemPromptPart(
                        content=dynamic_value, part_kind='system-prompt', timestamp=IsNow(tz=timezone.utc)
                    ),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
                ],
                kind='request',
            ),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)', part_kind='text')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
                kind='response',
            ),
        ]
    )

    # Change what `func` would return before the second run.
    dynamic_value = 'B'

    res_two = agent.run_sync('World', message_history=res.all_messages())

    assert res_two.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar', part_kind='system-prompt', timestamp=IsNow(tz=timezone.utc)),
                    SystemPromptPart(
                        content='A',  # Remains the same
                        part_kind='system-prompt',
                        timestamp=IsNow(tz=timezone.utc),
                    ),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
                ],
                kind='request',
            ),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)', part_kind='text')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
                kind='response',
            ),
            ModelRequest(
                parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')],
                kind='request',
            ),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)', part_kind='text')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
                kind='response',
            ),
        ]
    )

1503 

1504 

def test_dynamic_true_reevaluate_system_prompt():
    """When dynamic is true, the system prompt is reevaluated.

    The `SystemPromptPart` produced by the decorated function has `content='B'`
    (the updated value) in the second run's messages, although the history
    passed in was recorded with 'A'.
    """
    agent = Agent('test', system_prompt='Foobar')

    dynamic_value = 'A'

    @agent.system_prompt(dynamic=True)
    async def func():
        # Reads the enclosing variable at call time, so rebinding it below changes what this returns.
        return dynamic_value

    res = agent.run_sync('Hello')

    assert res.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar', part_kind='system-prompt', timestamp=IsNow(tz=timezone.utc)),
                    SystemPromptPart(
                        content=dynamic_value,
                        part_kind='system-prompt',
                        dynamic_ref=func.__qualname__,
                        timestamp=IsNow(tz=timezone.utc),
                    ),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
                ],
                kind='request',
            ),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)', part_kind='text')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
                kind='response',
            ),
        ]
    )

    # Change what `func` returns before the second run.
    dynamic_value = 'B'

    res_two = agent.run_sync('World', message_history=res.all_messages())

    assert res_two.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar', part_kind='system-prompt', timestamp=IsNow(tz=timezone.utc)),
                    SystemPromptPart(
                        content='B',
                        part_kind='system-prompt',
                        dynamic_ref=func.__qualname__,
                        timestamp=IsNow(tz=timezone.utc),
                    ),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
                ],
                kind='request',
            ),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)', part_kind='text')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
                kind='response',
            ),
            ModelRequest(
                parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')],
                kind='request',
            ),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)', part_kind='text')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
                kind='response',
            ),
        ]
    )

1582 

1583 

def test_capture_run_messages_tool_agent() -> None:
    """Check that `capture_run_messages()` records the outer agent's messages when a tool runs another agent."""
    agent_outer = Agent('test')
    agent_inner = Agent(TestModel(custom_result_text='inner agent result'))

    @agent_outer.tool_plain
    async def foobar(x: str) -> str:
        # Run the inner agent from within the outer agent's tool call.
        result_ = await agent_inner.run(x)
        return result_.data

    with capture_run_messages() as messages:
        result = agent_outer.run_sync('foobar')

    assert result.data == snapshot('{"foobar":"inner agent result"}')
    # Only the outer run's exchange is captured; the inner agent's own request/response pair
    # (prompt 'a') does not appear in the list.
    assert messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='foobar', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[ToolCallPart(tool_name='foobar', args={'x': 'a'}, tool_call_id=IsStr())],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='foobar',
                        content='inner agent result',
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[TextPart(content='{"foobar":"inner agent result"}')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )

1622 

1623 

# Second result model, with different field names than `Foo`; used as a run-level `result_type`.
# Kept as a `#` comment: no docstring is added so the model definition itself stays unchanged.
class Bar(BaseModel):
    c: int
    d: str

1627 

1628 

def test_custom_result_type_sync() -> None:
    """Check that `run_sync(..., result_type=...)` overrides the agent's own `result_type` for that run."""
    agent = Agent('test', result_type=Foo)

    # Without an override the agent-level type (`Foo`) is used.
    assert agent.run_sync('Hello').data == snapshot(Foo(a=0, b='a'))
    # Each override below applies to its own run only.
    assert agent.run_sync('Hello', result_type=Bar).data == snapshot(Bar(c=0, d='a'))
    assert agent.run_sync('Hello', result_type=str).data == snapshot('success (no tool calls)')
    assert agent.run_sync('Hello', result_type=int).data == snapshot(0)

1636 

1637 

async def test_custom_result_type_async() -> None:
    """Check that `run(..., result_type=...)` sets the result type for that run on an agent created without one."""
    agent = Agent('test')

    # No `result_type` anywhere: the run returns text.
    result = await agent.run('Hello')
    assert result.data == snapshot('success (no tool calls)')

    # Run-level overrides with a model type and with a plain `int`.
    result = await agent.run('Hello', result_type=Foo)
    assert result.data == snapshot(Foo(a=0, b='a'))
    result = await agent.run('Hello', result_type=int)
    assert result.data == snapshot(0)

1648 

1649 

def test_custom_result_type_invalid() -> None:
    """Check that a run-level `result_type` raises `UserError` when the agent has a result validator."""
    agent = Agent('test')

    @agent.result_validator
    def validate_result(ctx: RunContext[None], r: Any) -> Any:  # pragma: no cover
        return r

    expected_error = 'Cannot set a custom run `result_type` when the agent has result validators'
    with pytest.raises(UserError, match=expected_error):
        agent.run_sync('Hello', result_type=int)

1659 

1660 

def test_binary_content_all_messages_json():
    """Check how a prompt containing `BinaryContent` appears in `all_messages_json()`."""
    agent = Agent('test')

    result = agent.run_sync(['Hello', BinaryContent(data=b'Hello', media_type='text/plain')])
    # In the JSON output the binary payload is the base64 text 'SGVsbG8=' (b'Hello' encoded),
    # tagged with `kind: 'binary'`.
    assert json.loads(result.all_messages_json()) == snapshot(
        [
            {
                'parts': [
                    {
                        'content': ['Hello', {'data': 'SGVsbG8=', 'media_type': 'text/plain', 'kind': 'binary'}],
                        'timestamp': IsStr(),
                        'part_kind': 'user-prompt',
                    }
                ],
                'kind': 'request',
            },
            {
                'parts': [{'content': 'success (no tool calls)', 'part_kind': 'text'}],
                'model_name': 'test',
                'timestamp': IsStr(),
                'kind': 'response',
            },
        ]
    )