Coverage for tests/test_agent.py: 99.18%
465 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
1import json
2import re
3import sys
4from datetime import timezone
5from typing import Any, Callable, Union
7import httpx
8import pytest
9from dirty_equals import IsJson
10from inline_snapshot import snapshot
11from pydantic import BaseModel, field_validator
12from pydantic_core import to_json
14from pydantic_ai import Agent, ModelRetry, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages
15from pydantic_ai.messages import (
16 ModelMessage,
17 ModelRequest,
18 ModelResponse,
19 ModelResponsePart,
20 RetryPromptPart,
21 SystemPromptPart,
22 TextPart,
23 ToolCallPart,
24 ToolReturnPart,
25 UserPromptPart,
26)
27from pydantic_ai.models import cached_async_http_client
28from pydantic_ai.models.function import AgentInfo, FunctionModel
29from pydantic_ai.models.test import TestModel
30from pydantic_ai.result import RunResult, Usage
31from pydantic_ai.tools import ToolDefinition
33from .conftest import IsNow, TestEnv
35pytestmark = pytest.mark.anyio
def test_result_tuple():
    """A result-tool call with a JSON array payload is validated into a tuple."""

    def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.result_tools is not None
        payload = '{"response": ["foo", "bar"]}'
        return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, payload)])

    agent = Agent(FunctionModel(return_tuple), result_type=tuple[str, str])
    run = agent.run_sync('Hello')
    assert run.data == ('foo', 'bar')
# Simple structured result type shared by several tests below.
# NOTE: deliberately no class docstring -- pydantic surfaces a docstring as the
# model description, which would leak into the ToolDefinition snapshots below.
class Foo(BaseModel):
    a: int
    b: str
def test_result_pydantic_model():
    """A result-tool call with a JSON object payload is validated into `Foo`."""

    def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.result_tools is not None
        return ModelResponse(
            parts=[ToolCallPart(info.result_tools[0].name, '{"a": 1, "b": "foo"}')]
        )

    agent = Agent(FunctionModel(return_model), result_type=Foo)
    run = agent.run_sync('Hello')
    assert isinstance(run.data, Foo)
    assert run.data.model_dump() == {'a': 1, 'b': 'foo'}
def test_result_pydantic_model_retry():
    """An invalid result-tool payload produces a retry prompt; the retry succeeds.

    Also checks that the agent name is inferred from the local variable name
    (`agent`) on the first run.
    """

    def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.result_tools is not None
        # First call: "a" is a string, which fails Foo's int validation and
        # forces a retry. Second call: valid payload.
        if len(messages) == 1:
            args_json = '{"a": "wrong", "b": "foo"}'
        else:
            args_json = '{"a": 42, "b": "foo"}'
        return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, args_json)])

    agent = Agent(FunctionModel(return_model), result_type=Foo)

    # Name is only inferred lazily, on the first run.
    assert agent.name is None

    result = agent.run_sync('Hello')
    assert agent.name == 'agent'
    assert isinstance(result.data, Foo)
    assert result.data.model_dump() == {'a': 42, 'b': 'foo'}
    # Full message history: request -> bad tool call -> retry prompt (carrying
    # the pydantic validation error) -> good tool call -> tool return.
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[ToolCallPart('final_result', '{"a": "wrong", "b": "foo"}')],
                model_name='function:return_model',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        tool_name='final_result',
                        content=[
                            {
                                'type': 'int_parsing',
                                'loc': ('a',),
                                'msg': 'Input should be a valid integer, unable to parse string as an integer',
                                'input': 'wrong',
                            }
                        ],
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart('final_result', '{"a": 42, "b": "foo"}')],
                model_name='function:return_model',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
        ]
    )
    assert result.all_messages_json().startswith(b'[{"parts":[{"content":"Hello",')
def test_result_pydantic_model_validation_error():
    """A custom field validator failure triggers a retry with a readable error."""

    def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.result_tools is not None
        # First payload fails Bar's check_b validator (b == 'foo'); the second passes.
        if len(messages) == 1:
            args_json = '{"a": 1, "b": "foo"}'
        else:
            args_json = '{"a": 1, "b": "bar"}'
        return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, args_json)])

    class Bar(BaseModel):
        a: int
        b: str

        @field_validator('b')
        def check_b(cls, v: str) -> str:
            if v == 'foo':
                raise ValueError('must not be foo')
            return v

    agent = Agent(FunctionModel(return_model), result_type=Bar)

    result = agent.run_sync('Hello')
    assert isinstance(result.data, Bar)
    assert result.data.model_dump() == snapshot({'a': 1, 'b': 'bar'})
    # Check only the shape of the conversation here; the retry content is
    # inspected in detail below.
    messages_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result.all_messages()]
    assert messages_part_kinds == snapshot(
        [
            ('request', ['user-prompt']),
            ('response', ['tool-call']),
            ('request', ['retry-prompt']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
        ]
    )

    # The retry prompt renders the pydantic error list as JSON for the model.
    user_retry = result.all_messages()[2]
    assert isinstance(user_retry, ModelRequest)
    retry_prompt = user_retry.parts[0]
    assert isinstance(retry_prompt, RetryPromptPart)
    assert retry_prompt.model_response() == snapshot("""\
1 validation errors: [
  {
    "type": "value_error",
    "loc": [
      "b"
    ],
    "msg": "Value error, must not be foo",
    "input": "foo"
  }
]

Fix the errors and try again.""")
def test_result_validator():
    """A `@agent.result_validator` raising `ModelRetry` drives another model call."""

    def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.result_tools is not None
        # First payload is structurally valid but rejected by the result
        # validator below; the second satisfies it.
        if len(messages) == 1:
            args_json = '{"a": 41, "b": "foo"}'
        else:
            args_json = '{"a": 42, "b": "foo"}'
        return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, args_json)])

    agent = Agent(FunctionModel(return_model), result_type=Foo)

    @agent.result_validator
    def validate_result(ctx: RunContext[None], r: Foo) -> Foo:
        assert ctx.tool_name == 'final_result'
        if r.a == 42:
            return r
        else:
            raise ModelRetry('"a" should be 42')

    result = agent.run_sync('Hello')
    assert isinstance(result.data, Foo)
    assert result.data.model_dump() == {'a': 42, 'b': 'foo'}
    # The ModelRetry message surfaces verbatim in the retry prompt.
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[ToolCallPart('final_result', '{"a": 41, "b": "foo"}')],
                model_name='function:return_model',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        content='"a" should be 42', tool_name='final_result', timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart('final_result', '{"a": 42, "b": "foo"}')],
                model_name='function:return_model',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
        ]
    )
def test_plain_response_then_tuple():
    """Plain-text responses are rejected when a structured result type is required.

    The first model call returns text, which triggers a retry prompt; the second
    returns a valid result-tool call. Also exercises `result_tool_return_content`
    overriding on `all_messages`.
    """
    call_index = 0

    def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        nonlocal call_index

        assert info.result_tools is not None
        call_index += 1
        if call_index == 1:
            # Text is not allowed: result_type is tuple[str, str], so there's no
            # text result path and the agent must ask for a tool call instead.
            return ModelResponse(parts=[TextPart('hello')])
        else:
            args_json = '{"response": ["foo", "bar"]}'
            return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, args_json)])

    agent = Agent(FunctionModel(return_tuple), result_type=tuple[str, str])

    result = agent.run_sync('Hello')
    assert result.data == ('foo', 'bar')
    assert call_index == 2
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='hello')],
                model_name='function:return_tuple',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        content='Plain text responses are not permitted, please call one of the functions instead.',
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='final_result', args='{"response": ["foo", "bar"]}')],
                model_name='function:return_tuple',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
        ]
    )
    assert result._result_tool_name == 'final_result'  # pyright: ignore[reportPrivateUsage]
    # `result_tool_return_content` replaces the final tool-return content...
    assert result.all_messages(result_tool_return_content='foobar')[-1] == snapshot(
        ModelRequest(
            parts=[ToolReturnPart(tool_name='final_result', content='foobar', timestamp=IsNow(tz=timezone.utc))]
        )
    )
    # ...but only for that call: the stored history itself is unchanged.
    assert result.all_messages()[-1] == snapshot(
        ModelRequest(
            parts=[
                ToolReturnPart(
                    tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
                )
            ]
        )
    )
def test_result_tool_return_content_str_return():
    """Overriding the result tool return content is rejected for plain-`str` results."""
    agent = Agent('test')

    result = agent.run_sync('Hello')
    assert result.data == 'success (no tool calls)'

    expected_error = re.escape('Cannot set result tool return content when the return type is `str`.')
    with pytest.raises(ValueError, match=expected_error):
        result.all_messages(result_tool_return_content='foobar')
def test_result_tool_return_content_no_tool():
    """Overriding the result tool return content fails when the tool name doesn't match."""
    agent = Agent('test', result_type=int)

    result = agent.run_sync('Hello')
    assert result.data == 0

    # Corrupt the recorded result tool name so the lookup must fail.
    result._result_tool_name = 'wrong'  # pyright: ignore[reportPrivateUsage]
    expected_error = re.escape("No tool call found with tool name 'wrong'.")
    with pytest.raises(LookupError, match=expected_error):
        result.all_messages(result_tool_return_content='foobar')
def test_response_tuple():
    """A tuple result type produces a single `final_result` tool and forbids text."""
    m = TestModel()

    agent = Agent(m, result_type=tuple[str, str])
    assert agent._result_schema.allow_text_result is False  # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess]

    result = agent.run_sync('Hello')
    assert result.data == snapshot(('a', 'a'))

    assert m.agent_model_function_tools == snapshot([])
    assert m.agent_model_allow_text_result is False

    assert m.agent_model_result_tools is not None
    assert len(m.agent_model_result_tools) == 1

    # The tuple is wrapped in an object schema under the 'response' key
    # (outer_typed_dict_key), since tool args must be a JSON object.
    assert m.agent_model_result_tools == snapshot(
        [
            ToolDefinition(
                name='final_result',
                description='The final response which ends this conversation',
                parameters_json_schema={
                    'properties': {
                        'response': {
                            'maxItems': 2,
                            'minItems': 2,
                            'prefixItems': [{'type': 'string'}, {'type': 'string'}],
                            'title': 'Response',
                            'type': 'array',
                        }
                    },
                    'required': ['response'],
                    'type': 'object',
                },
                outer_typed_dict_key='response',
            )
        ]
    )
@pytest.mark.parametrize(
    'input_union_callable',
    [lambda: Union[str, Foo], lambda: Union[Foo, str], lambda: str | Foo, lambda: Foo | str],
    ids=['Union[str, Foo]', 'Union[Foo, str]', 'str | Foo', 'Foo | str'],
)
def test_response_union_allow_str(input_union_callable: Callable[[], Any]):
    """A union result type containing `str` allows plain-text results.

    The union is built inside a callable so that `X | Y` syntax variants can be
    skipped on Python versions that don't support them.
    """
    try:
        union = input_union_callable()
    except TypeError:
        pytest.skip('Python version does not support `|` syntax for unions')

    m = TestModel()
    agent: Agent[None, Union[str, Foo]] = Agent(m, result_type=union)

    got_tool_call_name = 'unset'

    @agent.result_validator
    def validate_result(ctx: RunContext[None], r: Any) -> Any:
        nonlocal got_tool_call_name
        got_tool_call_name = ctx.tool_name
        return r

    assert agent._result_schema.allow_text_result is True  # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess]

    result = agent.run_sync('Hello')
    assert result.data == snapshot('success (no tool calls)')
    # Text result => no result tool was called, so tool_name is None.
    assert got_tool_call_name == snapshot(None)

    assert m.agent_model_function_tools == snapshot([])
    assert m.agent_model_allow_text_result is True

    # Only the non-str member of the union needs a result tool.
    assert m.agent_model_result_tools is not None
    assert len(m.agent_model_result_tools) == 1

    assert m.agent_model_result_tools == snapshot(
        [
            ToolDefinition(
                name='final_result',
                description='The final response which ends this conversation',
                parameters_json_schema={
                    'properties': {
                        'a': {'title': 'A', 'type': 'integer'},
                        'b': {'title': 'B', 'type': 'string'},
                    },
                    'required': ['a', 'b'],
                    'title': 'Foo',
                    'type': 'object',
                },
            )
        ]
    )
# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false
@pytest.mark.parametrize(
    'union_code',
    [
        pytest.param('ResultType = Union[Foo, Bar]'),
        pytest.param('ResultType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='3.10+')),
        pytest.param(
            'ResultType: TypeAlias = Foo | Bar',
            marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='Python 3.10+'),
        ),
        pytest.param(
            'type ResultType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 12), reason='3.12+')
        ),
    ],
)
def test_response_multiple_return_tools(create_module: Callable[[str], Any], union_code: str):
    """A union of two models produces one suffixed result tool per member.

    The union is declared in a dynamically created module so each spelling
    (`Union[...]`, `X | Y`, `TypeAlias`, PEP 695 `type`) can be tested under the
    Python versions that support it.
    """
    module_code = f'''
from pydantic import BaseModel
from typing import Union
from typing_extensions import TypeAlias

class Foo(BaseModel):
    a: int
    b: str


class Bar(BaseModel):
    """This is a bar model."""

    b: str

{union_code}
    '''

    mod = create_module(module_code)

    m = TestModel()
    agent = Agent(m, result_type=mod.ResultType)
    got_tool_call_name = 'unset'

    @agent.result_validator
    def validate_result(ctx: RunContext[None], r: Any) -> Any:
        nonlocal got_tool_call_name
        got_tool_call_name = ctx.tool_name
        return r

    result = agent.run_sync('Hello')
    # TestModel's default seed picks the first result tool (Foo).
    assert result.data == mod.Foo(a=0, b='a')
    assert got_tool_call_name == snapshot('final_result_Foo')

    assert m.agent_model_function_tools == snapshot([])
    assert m.agent_model_allow_text_result is False

    assert m.agent_model_result_tools is not None
    assert len(m.agent_model_result_tools) == 2

    # Tool names are suffixed with the member class name; descriptions come from
    # the member docstring when present, else the generic description.
    assert m.agent_model_result_tools == snapshot(
        [
            ToolDefinition(
                name='final_result_Foo',
                description='Foo: The final response which ends this conversation',
                parameters_json_schema={
                    'properties': {
                        'a': {'title': 'A', 'type': 'integer'},
                        'b': {'title': 'B', 'type': 'string'},
                    },
                    'required': ['a', 'b'],
                    'title': 'Foo',
                    'type': 'object',
                },
            ),
            ToolDefinition(
                name='final_result_Bar',
                description='This is a bar model.',
                parameters_json_schema={
                    'properties': {'b': {'title': 'B', 'type': 'string'}},
                    'required': ['b'],
                    'title': 'Bar',
                    'type': 'object',
                },
            ),
        ]
    )

    # A different seed makes TestModel pick the other result tool (Bar).
    result = agent.run_sync('Hello', model=TestModel(seed=1))
    assert result.data == mod.Bar(b='b')
    assert got_tool_call_name == snapshot('final_result_Bar')
def test_run_with_history_new():
    """System prompt handling when re-running with `message_history`.

    Passing `new_messages()` (which excludes nothing here, since it's the first
    run) keeps the system prompt; passing `all_messages()` must not duplicate it.
    """
    m = TestModel()

    agent = Agent(m, system_prompt='Foobar')

    @agent.tool_plain
    async def ret_a(x: str) -> str:
        return f'{x}-apple'

    result1 = agent.run_sync('Hello')
    assert result1.new_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar'),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))]
            ),
            ModelResponse(
                parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
            ),
        ]
    )

    # if we pass new_messages, system prompt is inserted before the message_history messages
    result2 = agent.run_sync('Hello again', message_history=result1.new_messages())
    assert result2 == snapshot(
        RunResult(
            _all_messages=[
                ModelRequest(
                    parts=[
                        SystemPromptPart(content='Foobar'),
                        UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                    ]
                ),
                ModelResponse(
                    parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})],
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ModelRequest(
                    parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))]
                ),
                ModelResponse(
                    parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
                ),
                ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]),
                ModelResponse(
                    parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
                ),
            ],
            # The first 4 messages came from message_history, so they are not "new".
            _new_message_index=4,
            data='{"ret_a":"a-apple"}',
            _result_tool_name=None,
            _usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None),
        )
    )
    new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()]
    assert new_msg_part_kinds == snapshot(
        [
            ('request', ['system-prompt', 'user-prompt']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
            ('response', ['text']),
            ('request', ['user-prompt']),
            ('response', ['text']),
        ]
    )
    assert result2.new_messages_json().startswith(b'[{"parts":[{"content":"Hello again",')

    # if we pass all_messages, system prompt is NOT inserted before the message_history messages,
    # so only one system prompt
    result3 = agent.run_sync('Hello again', message_history=result1.all_messages())
    # same as result2 except for datetimes
    assert result3 == snapshot(
        RunResult(
            _all_messages=[
                ModelRequest(
                    parts=[
                        SystemPromptPart(content='Foobar'),
                        UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                    ]
                ),
                ModelResponse(
                    parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})],
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ModelRequest(
                    parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))]
                ),
                ModelResponse(
                    parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
                ),
                ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]),
                ModelResponse(
                    parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
                ),
            ],
            _new_message_index=4,
            data='{"ret_a":"a-apple"}',
            _result_tool_name=None,
            _usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None),
        )
    )
def test_run_with_history_new_structured():
    """Like test_run_with_history_new, but with a structured result type.

    Each run therefore ends with a final_result tool call + tool return instead
    of a text response.
    """
    m = TestModel()

    class Response(BaseModel):
        a: int

    agent = Agent(m, system_prompt='Foobar', result_type=Response)

    @agent.tool_plain
    async def ret_a(x: str) -> str:
        return f'{x}-apple'

    result1 = agent.run_sync('Hello')
    assert result1.new_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar'),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))]
            ),
            ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name='final_result',
                        args={'a': 0},
                        tool_call_id=None,
                    )
                ],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
        ]
    )

    result2 = agent.run_sync('Hello again', message_history=result1.new_messages())
    assert result2 == snapshot(
        RunResult(
            data=Response(a=0),
            _all_messages=[
                ModelRequest(
                    parts=[
                        SystemPromptPart(content='Foobar'),
                        UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                    ],
                ),
                ModelResponse(
                    parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})],
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ModelRequest(
                    parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))],
                ),
                ModelResponse(
                    parts=[ToolCallPart(tool_name='final_result', args={'a': 0})],
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ModelRequest(
                    parts=[
                        ToolReturnPart(
                            tool_name='final_result',
                            content='Final result processed.',
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                    ],
                ),
                # second call, notice no repeated system prompt
                ModelRequest(
                    parts=[
                        UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc)),
                    ],
                ),
                ModelResponse(
                    parts=[ToolCallPart(tool_name='final_result', args={'a': 0})],
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ModelRequest(
                    parts=[
                        ToolReturnPart(
                            tool_name='final_result',
                            content='Final result processed.',
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                    ]
                ),
            ],
            # The first 5 messages came from message_history.
            _new_message_index=5,
            _result_tool_name='final_result',
            _usage=Usage(requests=1, request_tokens=59, response_tokens=13, total_tokens=72, details=None),
        )
    )
    new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()]
    assert new_msg_part_kinds == snapshot(
        [
            ('request', ['system-prompt', 'user-prompt']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
            ('request', ['user-prompt']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
        ]
    )
    assert result2.new_messages_json().startswith(b'[{"parts":[{"content":"Hello again",')
def test_empty_tool_calls():
    """A model response containing no parts at all is an unexpected behavior error."""

    def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        return ModelResponse(parts=[])

    no_parts_agent = Agent(FunctionModel(empty))
    with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'):
        no_parts_agent.run_sync('Hello')
def test_unknown_tool():
    """Calling a non-existent tool repeatedly exhausts retries and raises.

    `capture_run_messages` lets us inspect the partial history even though the
    run itself fails.
    """

    def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        # Always call a tool that was never registered.
        return ModelResponse(parts=[ToolCallPart('foobar', '{}')])

    agent = Agent(FunctionModel(empty))

    with capture_run_messages() as messages:
        with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for result validation'):
            agent.run_sync('Hello')
    assert messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[ToolCallPart(tool_name='foobar', args='{}')],
                model_name='function:empty',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        content="Unknown tool name: 'foobar'. No tools available.", timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='foobar', args='{}')],
                model_name='function:empty',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )
def test_unknown_tool_fix():
    """An unknown tool call succeeds if the model recovers on the retry."""

    def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        # First call: bogus tool; after the retry prompt arrives, return text.
        if len(m) > 1:
            return ModelResponse(parts=[TextPart('success')])
        else:
            return ModelResponse(parts=[ToolCallPart('foobar', '{}')])

    agent = Agent(FunctionModel(empty))

    result = agent.run_sync('Hello')
    assert result.data == 'success'
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[ToolCallPart(tool_name='foobar', args='{}')],
                model_name='function:empty',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        content="Unknown tool name: 'foobar'. No tools available.", timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
            ModelResponse(
                parts=[TextPart(content='success')],
                model_name='function:empty',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )
def test_model_requests_blocked(env: TestEnv):
    """Real model requests must fail while ALLOW_MODEL_REQUESTS is disabled."""
    env.set('GEMINI_API_KEY', 'foobar')
    agent = Agent('google-gla:gemini-1.5-flash', result_type=tuple[str, str], defer_model_check=True)

    expected = 'Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False'
    with pytest.raises(RuntimeError, match=expected):
        agent.run_sync('Hello')
def test_override_model(env: TestEnv):
    """`agent.override(model=...)` swaps in a different model for the run."""
    env.set('GEMINI_API_KEY', 'foobar')
    agent = Agent('google-gla:gemini-1.5-flash', result_type=tuple[int, str], defer_model_check=True)

    with agent.override(model='test'):
        run = agent.run_sync('Hello')
    assert run.data == snapshot((0, 'a'))
def test_override_model_no_model():
    """Overriding the model cannot compensate for an agent created with no model."""
    agent = Agent()

    expected = r'`model` must be set either.+Even when `override\(model=...\)` is customiz'
    with pytest.raises(UserError, match=expected):
        with agent.override(model='test'):
            agent.run_sync('Hello')
def test_run_sync_multiple():
    """`run_sync` can be called repeatedly with an async tool using the shared HTTP client.

    Regression test: the second run must not hit a closed event loop.
    """
    agent = Agent('test')

    @agent.tool_plain
    async def make_request() -> str:
        # raised a `RuntimeError: Event loop is closed` on repeat runs when we used `asyncio.run()`
        client = cached_async_http_client()
        # use this as I suspect it's about the fastest globally available endpoint
        try:
            response = await client.get('https://cloudflare.com/cdn-cgi/trace')
        except httpx.ConnectError:
            pytest.skip('offline')
        else:
            return str(response.status_code)

    for _ in range(2):
        result = agent.run_sync('Hello')
        assert result.data == '{"make_request":"200"}'
async def test_agent_name():
    """The agent name is inferred from the caller's variable name on first run.

    NOTE: the local variable name `my_agent` is load-bearing here -- inference
    reads it from the calling frame.
    """
    my_agent = Agent('test')

    assert my_agent.name is None

    # infer_name=False must leave the name unset.
    await my_agent.run('Hello', infer_name=False)
    assert my_agent.name is None

    await my_agent.run('Hello')
    assert my_agent.name == 'my_agent'
async def test_agent_name_already_set():
    """An explicitly supplied agent name is never overwritten by inference."""
    named_agent = Agent('test', name='fig_tree')
    assert named_agent.name == 'fig_tree'

    await named_agent.run('Hello')
    assert named_agent.name == 'fig_tree'
async def test_agent_name_changes():
    """Once inferred, the name sticks even if the original variable is deleted.

    NOTE: the variable name `my_agent` is load-bearing for the first inference.
    """
    my_agent = Agent('test')

    await my_agent.run('Hello')
    assert my_agent.name == 'my_agent'

    # Rebind and delete the original name; a later run must not re-infer.
    new_agent = my_agent
    del my_agent

    await new_agent.run('Hello')
    assert new_agent.name == 'my_agent'
def test_name_from_global(create_module: Callable[[str], Any]):
    """The agent name can be inferred from a module-level global variable."""
    module_code = """
from pydantic_ai import Agent

my_agent = Agent('test')

def foo():
    result = my_agent.run_sync('Hello')
    return result.data
"""

    mod = create_module(module_code)

    assert mod.my_agent.name is None
    assert mod.foo() == snapshot('success (no tool calls)')
    # Inference found the module-global name, not a local in foo().
    assert mod.my_agent.name == 'my_agent'
916class TestMultipleToolCalls:
917 """Tests for scenarios where multiple tool calls are made in a single response."""
919 pytestmark = pytest.mark.usefixtures('set_event_loop')
    class ResultType(BaseModel):
        """Result type used by all tests."""

        # Single payload field; the tests pass {'value': ...} as tool-call args.
        value: str
    def test_early_strategy_stops_after_first_final_result(self):
        """Test that 'early' strategy stops processing regular tools after first final result."""
        tool_called: list[str] = []

        def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
            assert info.result_tools is not None
            # Final result first, then two regular tool calls that must be skipped.
            return ModelResponse(
                parts=[
                    ToolCallPart('final_result', {'value': 'final'}),
                    ToolCallPart('regular_tool', {'x': 1}),
                    ToolCallPart('another_tool', {'y': 2}),
                ]
            )

        agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='early')

        @agent.tool_plain
        def regular_tool(x: int) -> int:  # pragma: no cover
            """A regular tool that should not be called."""
            tool_called.append('regular_tool')
            return x

        @agent.tool_plain
        def another_tool(y: int) -> int:  # pragma: no cover
            """Another tool that should not be called."""
            tool_called.append('another_tool')
            return y

        result = agent.run_sync('test early strategy')
        messages = result.all_messages()

        # Verify no tools were called after final result
        assert tool_called == []

        # Verify we got tool returns for all calls
        assert messages[-1].parts == snapshot(
            [
                ToolReturnPart(
                    tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
                ),
                ToolReturnPart(
                    tool_name='regular_tool',
                    content='Tool not executed - a final result was already processed.',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ToolReturnPart(
                    tool_name='another_tool',
                    content='Tool not executed - a final result was already processed.',
                    timestamp=IsNow(tz=timezone.utc),
                ),
            ]
        )
    def test_early_strategy_uses_first_final_result(self):
        """Test that 'early' strategy uses the first final result and ignores subsequent ones."""

        def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
            assert info.result_tools is not None
            # Two final_result calls in one response; only the first should count.
            return ModelResponse(
                parts=[
                    ToolCallPart('final_result', {'value': 'first'}),
                    ToolCallPart('final_result', {'value': 'second'}),
                ]
            )

        agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='early')
        result = agent.run_sync('test multiple final results')

        # Verify the result came from the first final tool
        assert result.data.value == 'first'

        # Verify we got appropriate tool returns
        assert result.new_messages()[-1].parts == snapshot(
            [
                ToolReturnPart(
                    tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
                ),
                ToolReturnPart(
                    tool_name='final_result',
                    content='Result tool not used - a final result was already processed.',
                    timestamp=IsNow(tz=timezone.utc),
                ),
            ]
        )
    def test_exhaustive_strategy_executes_all_tools(self):
        """Test that 'exhaustive' strategy executes all tools while using first final result."""
        tool_called: list[str] = []

        def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
            assert info.result_tools is not None
            # Mix of regular tools, duplicate final results, and an unknown tool.
            return ModelResponse(
                parts=[
                    ToolCallPart('regular_tool', {'x': 42}),
                    ToolCallPart('final_result', {'value': 'first'}),
                    ToolCallPart('another_tool', {'y': 2}),
                    ToolCallPart('final_result', {'value': 'second'}),
                    ToolCallPart('unknown_tool', {'value': '???'}),
                ]
            )

        agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='exhaustive')

        @agent.tool_plain
        def regular_tool(x: int) -> int:
            """A regular tool that should be called."""
            tool_called.append('regular_tool')
            return x

        @agent.tool_plain
        def another_tool(y: int) -> int:
            """Another tool that should be called."""
            tool_called.append('another_tool')
            return y

        result = agent.run_sync('test exhaustive strategy')

        # Verify the result came from the first final tool
        assert result.data.value == 'first'

        # Verify all regular tools were called
        assert sorted(tool_called) == sorted(['regular_tool', 'another_tool'])

        # Verify we got tool returns in the correct order
        assert result.all_messages() == snapshot(
            [
                ModelRequest(
                    parts=[UserPromptPart(content='test exhaustive strategy', timestamp=IsNow(tz=timezone.utc))]
                ),
                ModelResponse(
                    parts=[
                        ToolCallPart(tool_name='regular_tool', args={'x': 42}),
                        ToolCallPart(tool_name='final_result', args={'value': 'first'}),
                        ToolCallPart(tool_name='another_tool', args={'y': 2}),
                        ToolCallPart(tool_name='final_result', args={'value': 'second'}),
                        ToolCallPart(tool_name='unknown_tool', args={'value': '???'}),
                    ],
                    model_name='function:return_model',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ModelRequest(
                    parts=[
                        ToolReturnPart(
                            tool_name='final_result',
                            content='Final result processed.',
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                        ToolReturnPart(
                            tool_name='final_result',
                            content='Result tool not used - a final result was already processed.',
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                        RetryPromptPart(
                            content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result",
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                        ToolReturnPart(tool_name='regular_tool', content=42, timestamp=IsNow(tz=timezone.utc)),
                        ToolReturnPart(tool_name='another_tool', content=2, timestamp=IsNow(tz=timezone.utc)),
                    ]
                ),
            ]
        )
    def test_early_strategy_with_final_result_in_middle(self):
        """Test that 'early' strategy stops at first final result, regardless of position.

        With ``end_strategy='early'`` no regular tool runs once a final result appears
        anywhere in the response: tools before and after the final result get a
        "Tool not executed" return, and the unknown tool still yields a retry prompt.
        """
        tool_called = []  # should remain empty: no regular tool may execute

        def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
            # final_result deliberately placed in the middle of the call list.
            assert info.result_tools is not None
            return ModelResponse(
                parts=[
                    ToolCallPart('regular_tool', {'x': 1}),
                    ToolCallPart('final_result', {'value': 'final'}),
                    ToolCallPart('another_tool', {'y': 2}),
                    ToolCallPart('unknown_tool', {'value': '???'}),
                ]
            )

        agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='early')

        @agent.tool_plain
        def regular_tool(x: int) -> int:  # pragma: no cover
            """A regular tool that should not be called."""
            tool_called.append('regular_tool')
            return x

        @agent.tool_plain
        def another_tool(y: int) -> int:  # pragma: no cover
            """A tool that should not be called."""
            tool_called.append('another_tool')
            return y

        result = agent.run_sync('test early strategy with final result in middle')

        # Verify no tools were called
        assert tool_called == []

        # Verify we got appropriate tool returns
        assert result.all_messages() == snapshot(
            [
                ModelRequest(
                    parts=[
                        UserPromptPart(
                            content='test early strategy with final result in middle', timestamp=IsNow(tz=timezone.utc)
                        )
                    ]
                ),
                ModelResponse(
                    parts=[
                        ToolCallPart(tool_name='regular_tool', args={'x': 1}),
                        ToolCallPart(tool_name='final_result', args={'value': 'final'}),
                        ToolCallPart(tool_name='another_tool', args={'y': 2}),
                        ToolCallPart(tool_name='unknown_tool', args={'value': '???'}),
                    ],
                    model_name='function:return_model',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                ModelRequest(
                    parts=[
                        ToolReturnPart(
                            tool_name='regular_tool',
                            content='Tool not executed - a final result was already processed.',
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                        ToolReturnPart(
                            tool_name='final_result',
                            content='Final result processed.',
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                        ToolReturnPart(
                            tool_name='another_tool',
                            content='Tool not executed - a final result was already processed.',
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                        RetryPromptPart(
                            content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result",
                            timestamp=IsNow(tz=timezone.utc),
                        ),
                    ]
                ),
            ]
        )
1169 def test_early_strategy_does_not_apply_to_tool_calls_without_final_tool(self):
1170 """Test that 'early' strategy does not apply to tool calls without final tool."""
1171 tool_called = []
1172 agent = Agent(TestModel(), result_type=self.ResultType, end_strategy='early')
1174 @agent.tool_plain
1175 def regular_tool(x: int) -> int:
1176 """A regular tool that should be called."""
1177 tool_called.append('regular_tool')
1178 return x
1180 result = agent.run_sync('test early strategy with regular tool calls')
1181 assert tool_called == ['regular_tool']
1183 tool_returns = [m for m in result.all_messages() if isinstance(m, ToolReturnPart)]
1184 assert tool_returns == snapshot([])
async def test_model_settings_override() -> None:
    """Per-run `model_settings` must override agent-level defaults."""

    def return_settings(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # Echo back the settings the model actually received, serialized as JSON.
        return ModelResponse(parts=[TextPart(to_json(info.model_settings).decode())])

    # Agent without defaults: settings are absent unless passed to run().
    my_agent = Agent(FunctionModel(return_settings))
    no_settings = await my_agent.run('Hello')
    assert no_settings.data == IsJson(None)
    per_run = await my_agent.run('Hello', model_settings={'temperature': 0.5})
    assert per_run.data == IsJson({'temperature': 0.5})

    # Agent with defaults: applied when run() passes nothing, overridden otherwise.
    my_agent = Agent(FunctionModel(return_settings), model_settings={'temperature': 0.1})
    defaulted = await my_agent.run('Hello')
    assert defaulted.data == IsJson({'temperature': 0.1})
    overridden = await my_agent.run('Hello', model_settings={'temperature': 0.5})
    assert overridden.data == IsJson({'temperature': 0.5})
async def test_empty_text_part():
    """An empty text part alongside the result tool call must not break result parsing."""

    def return_empty_text(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.result_tools is not None
        # Empty TextPart first, followed by the structured result tool call.
        payload = '{"response": ["foo", "bar"]}'
        return ModelResponse(parts=[TextPart(''), ToolCallPart(info.result_tools[0].name, payload)])

    agent = Agent(FunctionModel(return_empty_text), result_type=tuple[str, str])

    run_result = await agent.run('Hello')
    assert run_result.data == ('foo', 'bar')
def test_heterogeneous_responses_non_streaming() -> None:
    """Indicates that tool calls are prioritized over text in heterogeneous responses.

    First model turn mixes a text part with a tool call; the tool call wins, the
    tool runs, and the second (text-only) turn becomes the final result.
    """

    def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.result_tools is not None
        parts: list[ModelResponsePart] = []
        # First request (only the user prompt so far): mixed text + tool call.
        # Subsequent request (after the tool return): plain-text final answer.
        if len(messages) == 1:
            parts = [
                TextPart(content='foo'),
                ToolCallPart('get_location', {'loc_name': 'London'}),
            ]
        else:
            parts = [TextPart(content='final response')]
        return ModelResponse(parts=parts)

    agent = Agent(FunctionModel(return_model))

    @agent.tool_plain
    async def get_location(loc_name: str) -> str:
        if loc_name == 'London':
            return json.dumps({'lat': 51, 'lng': 0})
        else:
            raise ModelRetry('Wrong location, please try again')

    result = agent.run_sync('Hello')
    assert result.data == 'final response'
    assert result.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                ]
            ),
            ModelResponse(
                parts=[
                    TextPart(content='foo'),
                    ToolCallPart(
                        tool_name='get_location',
                        args={'loc_name': 'London'},
                    ),
                ],
                model_name='function:return_model',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='get_location', content='{"lat": 51, "lng": 0}', timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
            ModelResponse(
                parts=[TextPart(content='final response')],
                model_name='function:return_model',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )
def test_last_run_messages() -> None:
    """The removed `last_run_messages` attribute must raise a helpful AttributeError."""
    agent = Agent('test')

    # Match only the start of the error message; the full text points at the replacement API.
    with pytest.raises(AttributeError, match='The `last_run_messages` attribute has been removed,'):
        agent.last_run_messages  # pyright: ignore[reportDeprecated]
def test_nested_capture_run_messages() -> None:
    """Nested `capture_run_messages` contexts share a single underlying message list."""
    agent = Agent('test')

    with capture_run_messages() as outer_messages:
        assert outer_messages == []
        with capture_run_messages() as inner_messages:
            assert inner_messages == []
            # The inner context reuses the outer list rather than creating a new one.
            assert outer_messages is inner_messages
            result = agent.run_sync('Hello')
            assert result.data == 'success (no tool calls)'

    # Both names see the messages of the run performed inside the contexts.
    assert outer_messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)')], model_name='test', timestamp=IsNow(tz=timezone.utc)
            ),
        ]
    )
    assert outer_messages == inner_messages
def test_double_capture_run_messages() -> None:
    """Only the FIRST run inside a `capture_run_messages` context is captured."""
    agent = Agent('test')

    with capture_run_messages() as captured:
        assert captured == []
        first = agent.run_sync('Hello')
        assert first.data == 'success (no tool calls)'
        second = agent.run_sync('Hello 2')
        assert second.data == 'success (no tool calls)'

    # The second run's messages are deliberately absent from the captured list.
    assert captured == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)')], model_name='test', timestamp=IsNow(tz=timezone.utc)
            ),
        ]
    )
def test_dynamic_false_no_reevaluate():
    """When dynamic is false (default), the system prompt is not reevaluated
    i.e: SystemPromptPart(
        content="A", <--- Remains the same when `message_history` is passed.
        part_kind='system-prompt')
    """
    agent = Agent('test', system_prompt='Foobar')

    dynamic_value = 'A'  # captured by the prompt function; mutated below to 'B'

    @agent.system_prompt
    async def func() -> str:
        return dynamic_value

    res = agent.run_sync('Hello')

    assert res.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar', part_kind='system-prompt'),
                    SystemPromptPart(content=dynamic_value, part_kind='system-prompt'),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
                ],
                kind='request',
            ),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)', part_kind='text')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
                kind='response',
            ),
        ]
    )

    # Change the value the prompt function would return; with dynamic=False the
    # stored history must NOT be re-evaluated on the follow-up run.
    dynamic_value = 'B'

    res_two = agent.run_sync('World', message_history=res.all_messages())

    assert res_two.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar', part_kind='system-prompt'),
                    SystemPromptPart(
                        content='A',  # Remains the same
                        part_kind='system-prompt',
                    ),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
                ],
                kind='request',
            ),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)', part_kind='text')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
                kind='response',
            ),
            ModelRequest(
                parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')],
                kind='request',
            ),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)', part_kind='text')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
                kind='response',
            ),
        ]
    )
def test_dynamic_true_reevaluate_system_prompt():
    """When dynamic is true, the system prompt is reevaluated
    i.e: SystemPromptPart(
        content="B", <--- Updated value
        part_kind='system-prompt')
    """
    agent = Agent('test', system_prompt='Foobar')

    dynamic_value = 'A'  # captured by the prompt function; mutated below to 'B'

    @agent.system_prompt(dynamic=True)
    async def func() -> str:
        return dynamic_value

    res = agent.run_sync('Hello')

    assert res.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar', part_kind='system-prompt'),
                    SystemPromptPart(
                        content=dynamic_value,
                        part_kind='system-prompt',
                        # dynamic prompts record which function produced them
                        dynamic_ref=func.__qualname__,
                    ),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
                ],
                kind='request',
            ),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)', part_kind='text')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
                kind='response',
            ),
        ]
    )

    # Change the value; with dynamic=True the system prompt in the passed
    # history IS re-evaluated and reflects the new value.
    dynamic_value = 'B'

    res_two = agent.run_sync('World', message_history=res.all_messages())

    assert res_two.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='Foobar', part_kind='system-prompt'),
                    SystemPromptPart(
                        content='B',
                        part_kind='system-prompt',
                        dynamic_ref=func.__qualname__,
                    ),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
                ],
                kind='request',
            ),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)', part_kind='text')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
                kind='response',
            ),
            ModelRequest(
                parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')],
                kind='request',
            ),
            ModelResponse(
                parts=[TextPart(content='success (no tool calls)', part_kind='text')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
                kind='response',
            ),
        ]
    )
def test_capture_run_messages_tool_agent() -> None:
    """`capture_run_messages` captures only the OUTER agent's run, even when a tool
    internally runs a second (inner) agent."""
    agent_outer = Agent('test')
    agent_inner = Agent(TestModel(custom_result_text='inner agent result'))

    @agent_outer.tool_plain
    async def foobar(x: str) -> str:
        # Delegate to the inner agent; its messages must NOT leak into `messages`.
        result_ = await agent_inner.run(x)
        return result_.data

    with capture_run_messages() as messages:
        result = agent_outer.run_sync('foobar')

    assert result.data == snapshot('{"foobar":"inner agent result"}')
    assert messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='foobar', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[ToolCallPart(tool_name='foobar', args={'x': 'a'})],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(tool_name='foobar', content='inner agent result', timestamp=IsNow(tz=timezone.utc))
                ]
            ),
            ModelResponse(
                parts=[TextPart(content='{"foobar":"inner agent result"}')],
                model_name='test',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )
class Bar(BaseModel):
    """Alternative result model used to exercise per-run `result_type` overrides."""

    c: int
    d: str
def test_custom_result_type_sync() -> None:
    """`run_sync(result_type=...)` overrides the agent's configured result type per call."""
    agent = Agent('test', result_type=Foo)

    # Without an override the agent's own result type is used.
    assert agent.run_sync('Hello').data == snapshot(Foo(a=0, b='a'))
    # Overrides: another BaseModel, plain text, and a bare int.
    assert agent.run_sync('Hello', result_type=Bar).data == snapshot(Bar(c=0, d='a'))
    assert agent.run_sync('Hello', result_type=str).data == snapshot('success (no tool calls)')
    assert agent.run_sync('Hello', result_type=int).data == snapshot(0)
async def test_custom_result_type_async() -> None:
    """`run(result_type=...)` overrides the default `str` result type per call."""
    agent = Agent('test')

    # Default result type is plain text.
    assert (await agent.run('Hello')).data == snapshot('success (no tool calls)')
    # Per-run overrides: a BaseModel and a bare int.
    assert (await agent.run('Hello', result_type=Foo)).data == snapshot(Foo(a=0, b='a'))
    assert (await agent.run('Hello', result_type=int)).data == snapshot(0)
def test_custom_result_type_invalid() -> None:
    """A per-run `result_type` is rejected when the agent has result validators."""
    agent = Agent('test')

    @agent.result_validator
    def validate_result(ctx: RunContext[None], r: Any) -> Any:  # pragma: no cover
        # Never invoked: the override is rejected before any run happens.
        return r

    with pytest.raises(UserError, match='Cannot set a custom run `result_type` when the agent has result validators'):
        agent.run_sync('Hello', result_type=int)