Coverage for tests/test_agent.py: 99.61%
489 statements
coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1 import json
2 import re
3 import sys
4 from datetime import timezone
5 from typing import Any, Callable, Union
7 import httpx
8 import pytest
9 from dirty_equals import IsJson
10 from inline_snapshot import snapshot
11 from pydantic import BaseModel, field_validator
12 from pydantic_core import to_json
14 from pydantic_ai import Agent, ModelRetry, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages
15 from pydantic_ai.messages import (
16 BinaryContent,
17 ModelMessage,
18 ModelRequest,
19 ModelResponse,
20 ModelResponsePart,
21 RetryPromptPart,
22 SystemPromptPart,
23 TextPart,
24 ToolCallPart,
25 ToolReturnPart,
26 UserPromptPart,
27 )
28 from pydantic_ai.models.function import AgentInfo, FunctionModel
29 from pydantic_ai.models.test import TestModel
30 from pydantic_ai.result import Usage
31 from pydantic_ai.tools import ToolDefinition
33 from .conftest import IsNow, IsStr, TestEnv
35 pytestmark = pytest.mark.anyio
38 def test_result_tuple():
39 def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
40 assert info.result_tools is not None
41 args_json = '{"response": ["foo", "bar"]}'
42 return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, args_json)])
44 agent = Agent(FunctionModel(return_tuple), result_type=tuple[str, str])
46 result = agent.run_sync('Hello')
47 assert result.data == ('foo', 'bar')
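# Illustrative sketch (not part of the covered module, reusing the imports above): the
# FunctionModel pattern this suite leans on. The stub receives the message history plus an
# AgentInfo and returns a ModelResponse; when it replies with plain text, no result tool is
# involved and `result.data` is simply that text.
def example_function_model_plain_text():
    def reply(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        return ModelResponse(parts=[TextPart('hello from the stub model')])

    agent = Agent(FunctionModel(reply))
    result = agent.run_sync('Hi')
    assert result.data == 'hello from the stub model'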
50 class Foo(BaseModel):
51 a: int
52 b: str
55 def test_result_pydantic_model():
56 def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
57 assert info.result_tools is not None
58 args_json = '{"a": 1, "b": "foo"}'
59 return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, args_json)])
61 agent = Agent(FunctionModel(return_model), result_type=Foo)
63 result = agent.run_sync('Hello')
64 assert isinstance(result.data, Foo)
65 assert result.data.model_dump() == {'a': 1, 'b': 'foo'}
68 def test_result_pydantic_model_retry():
69 def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
70 assert info.result_tools is not None
71 if len(messages) == 1:
72 args_json = '{"a": "wrong", "b": "foo"}'
73 else:
74 args_json = '{"a": 42, "b": "foo"}'
75 return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, args_json)])
77 agent = Agent(FunctionModel(return_model), result_type=Foo)
79 assert agent.name is None
81 result = agent.run_sync('Hello')
82 assert agent.name == 'agent'
83 assert isinstance(result.data, Foo)
84 assert result.data.model_dump() == {'a': 42, 'b': 'foo'}
85 assert result.all_messages() == snapshot(
86 [
87 ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
88 ModelResponse(
89 parts=[ToolCallPart(tool_name='final_result', args='{"a": "wrong", "b": "foo"}', tool_call_id=IsStr())],
90 model_name='function:return_model:',
91 timestamp=IsNow(tz=timezone.utc),
92 ),
93 ModelRequest(
94 parts=[
95 RetryPromptPart(
96 tool_name='final_result',
97 content=[
98 {
99 'type': 'int_parsing',
100 'loc': ('a',),
101 'msg': 'Input should be a valid integer, unable to parse string as an integer',
102 'input': 'wrong',
103 }
104 ],
105 tool_call_id=IsStr(),
106 timestamp=IsNow(tz=timezone.utc),
107 )
108 ]
109 ),
110 ModelResponse(
111 parts=[ToolCallPart(tool_name='final_result', args='{"a": 42, "b": "foo"}', tool_call_id=IsStr())],
112 model_name='function:return_model:',
113 timestamp=IsNow(tz=timezone.utc),
114 ),
115 ModelRequest(
116 parts=[
117 ToolReturnPart(
118 tool_name='final_result',
119 content='Final result processed.',
120 tool_call_id=IsStr(),
121 timestamp=IsNow(tz=timezone.utc),
122 )
123 ]
124 ),
125 ]
126 )
127 assert result.all_messages_json().startswith(b'[{"parts":[{"content":"Hello",')
130 def test_result_pydantic_model_validation_error():
131 def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
132 assert info.result_tools is not None
133 if len(messages) == 1:
134 args_json = '{"a": 1, "b": "foo"}'
135 else:
136 args_json = '{"a": 1, "b": "bar"}'
137 return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, args_json)])
139 class Bar(BaseModel):
140 a: int
141 b: str
143 @field_validator('b')
144 def check_b(cls, v: str) -> str:
145 if v == 'foo':
146 raise ValueError('must not be foo')
147 return v
149 agent = Agent(FunctionModel(return_model), result_type=Bar)
151 result = agent.run_sync('Hello')
152 assert isinstance(result.data, Bar)
153 assert result.data.model_dump() == snapshot({'a': 1, 'b': 'bar'})
154 messages_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result.all_messages()]
155 assert messages_part_kinds == snapshot(
156 [
157 ('request', ['user-prompt']),
158 ('response', ['tool-call']),
159 ('request', ['retry-prompt']),
160 ('response', ['tool-call']),
161 ('request', ['tool-return']),
162 ]
163 )
165 user_retry = result.all_messages()[2]
166 assert isinstance(user_retry, ModelRequest)
167 retry_prompt = user_retry.parts[0]
168 assert isinstance(retry_prompt, RetryPromptPart)
169 assert retry_prompt.model_response() == snapshot("""\
170 1 validation errors: [
171 {
172 "type": "value_error",
173 "loc": [
174 "b"
175 ],
176 "msg": "Value error, must not be foo",
177 "input": "foo"
178 }
179 ]
181 Fix the errors and try again.""")
184 def test_result_validator():
185 def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
186 assert info.result_tools is not None
187 if len(messages) == 1:
188 args_json = '{"a": 41, "b": "foo"}'
189 else:
190 args_json = '{"a": 42, "b": "foo"}'
191 return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, args_json)])
193 agent = Agent(FunctionModel(return_model), result_type=Foo)
195 @agent.result_validator
196 def validate_result(ctx: RunContext[None], r: Foo) -> Foo:
197 assert ctx.tool_name == 'final_result'
198 if r.a == 42:
199 return r
200 else:
201 raise ModelRetry('"a" should be 42')
203 result = agent.run_sync('Hello')
204 assert isinstance(result.data, Foo)
205 assert result.data.model_dump() == {'a': 42, 'b': 'foo'}
206 assert result.all_messages() == snapshot(
207 [
208 ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
209 ModelResponse(
210 parts=[ToolCallPart(tool_name='final_result', args='{"a": 41, "b": "foo"}', tool_call_id=IsStr())],
211 model_name='function:return_model:',
212 timestamp=IsNow(tz=timezone.utc),
213 ),
214 ModelRequest(
215 parts=[
216 RetryPromptPart(
217 content='"a" should be 42',
218 tool_name='final_result',
219 tool_call_id=IsStr(),
220 timestamp=IsNow(tz=timezone.utc),
221 )
222 ]
223 ),
224 ModelResponse(
225 parts=[ToolCallPart(tool_name='final_result', args='{"a": 42, "b": "foo"}', tool_call_id=IsStr())],
226 model_name='function:return_model:',
227 timestamp=IsNow(tz=timezone.utc),
228 ),
229 ModelRequest(
230 parts=[
231 ToolReturnPart(
232 tool_name='final_result',
233 content='Final result processed.',
234 tool_call_id=IsStr(),
235 timestamp=IsNow(tz=timezone.utc),
236 )
237 ]
238 ),
239 ]
240 )
243 def test_plain_response_then_tuple():
244 call_index = 0
246 def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
247 nonlocal call_index
249 assert info.result_tools is not None
250 call_index += 1
251 if call_index == 1:
252 return ModelResponse(parts=[TextPart('hello')])
253 else:
254 args_json = '{"response": ["foo", "bar"]}'
255 return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, args_json)])
257 agent = Agent(FunctionModel(return_tuple), result_type=tuple[str, str])
259 result = agent.run_sync('Hello')
260 assert result.data == ('foo', 'bar')
261 assert call_index == 2
262 assert result.all_messages() == snapshot(
263 [
264 ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
265 ModelResponse(
266 parts=[TextPart(content='hello')],
267 model_name='function:return_tuple:',
268 timestamp=IsNow(tz=timezone.utc),
269 ),
270 ModelRequest(
271 parts=[
272 RetryPromptPart(
273 content='Plain text responses are not permitted, please call one of the functions instead.',
274 timestamp=IsNow(tz=timezone.utc),
275 tool_call_id=IsStr(),
276 )
277 ]
278 ),
279 ModelResponse(
280 parts=[
281 ToolCallPart(tool_name='final_result', args='{"response": ["foo", "bar"]}', tool_call_id=IsStr())
282 ],
283 model_name='function:return_tuple:',
284 timestamp=IsNow(tz=timezone.utc),
285 ),
286 ModelRequest(
287 parts=[
288 ToolReturnPart(
289 tool_name='final_result',
290 content='Final result processed.',
291 tool_call_id=IsStr(),
292 timestamp=IsNow(tz=timezone.utc),
293 )
294 ]
295 ),
296 ]
297 )
298 assert result._result_tool_name == 'final_result' # pyright: ignore[reportPrivateUsage]
299 assert result.all_messages(result_tool_return_content='foobar')[-1] == snapshot(
300 ModelRequest(
301 parts=[
302 ToolReturnPart(
303 tool_name='final_result', content='foobar', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc)
304 )
305 ]
306 )
307 )
308 assert result.all_messages()[-1] == snapshot(
309 ModelRequest(
310 parts=[
311 ToolReturnPart(
312 tool_name='final_result',
313 content='Final result processed.',
314 tool_call_id=IsStr(),
315 timestamp=IsNow(tz=timezone.utc),
316 )
317 ]
318 )
319 )
322 def test_result_tool_return_content_str_return():
323 agent = Agent('test')
325 result = agent.run_sync('Hello')
326 assert result.data == 'success (no tool calls)'
328 msg = re.escape('Cannot set result tool return content when the return type is `str`.')
329 with pytest.raises(ValueError, match=msg):
330 result.all_messages(result_tool_return_content='foobar')
333 def test_result_tool_return_content_no_tool():
334 agent = Agent('test', result_type=int)
336 result = agent.run_sync('Hello')
337 assert result.data == 0
338 result._result_tool_name = 'wrong' # pyright: ignore[reportPrivateUsage]
339 with pytest.raises(LookupError, match=re.escape("No tool call found with tool name 'wrong'.")):
340 result.all_messages(result_tool_return_content='foobar')
343 def test_response_tuple():
344 m = TestModel()
346 agent = Agent(m, result_type=tuple[str, str])
347 assert agent._result_schema.allow_text_result is False # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess]
349 result = agent.run_sync('Hello')
350 assert result.data == snapshot(('a', 'a'))
352 assert m.last_model_request_parameters is not None
353 assert m.last_model_request_parameters.function_tools == snapshot([])
354 assert m.last_model_request_parameters.allow_text_result is False
356 assert m.last_model_request_parameters.result_tools is not None
357 assert len(m.last_model_request_parameters.result_tools) == 1
358 assert m.last_model_request_parameters.result_tools == snapshot(
359 [
360 ToolDefinition(
361 name='final_result',
362 description='The final response which ends this conversation',
363 parameters_json_schema={
364 'properties': {
365 'response': {
366 'maxItems': 2,
367 'minItems': 2,
368 'prefixItems': [{'type': 'string'}, {'type': 'string'}],
369 'title': 'Response',
370 'type': 'array',
371 }
372 },
373 'required': ['response'],
374 'type': 'object',
375 },
376 outer_typed_dict_key='response',
377 )
378 ]
379 )
382 @pytest.mark.parametrize(
383 'input_union_callable',
384 [lambda: Union[str, Foo], lambda: Union[Foo, str], lambda: str | Foo, lambda: Foo | str],
385 ids=['Union[str, Foo]', 'Union[Foo, str]', 'str | Foo', 'Foo | str'],
386 )
387 def test_response_union_allow_str(input_union_callable: Callable[[], Any]):
388 try:
389 union = input_union_callable()
390 except TypeError:
391 pytest.skip('Python version does not support `|` syntax for unions')
393 m = TestModel()
394 agent: Agent[None, Union[str, Foo]] = Agent(m, result_type=union)
396 got_tool_call_name = 'unset'
398 @agent.result_validator
399 def validate_result(ctx: RunContext[None], r: Any) -> Any:
400 nonlocal got_tool_call_name
401 got_tool_call_name = ctx.tool_name
402 return r
404 assert agent._result_schema.allow_text_result is True # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess]
406 result = agent.run_sync('Hello')
407 assert result.data == snapshot('success (no tool calls)')
408 assert got_tool_call_name == snapshot(None)
410 assert m.last_model_request_parameters is not None
411 assert m.last_model_request_parameters.function_tools == snapshot([])
412 assert m.last_model_request_parameters.allow_text_result is True
414 assert m.last_model_request_parameters.result_tools is not None
415 assert len(m.last_model_request_parameters.result_tools) == 1
417 assert m.last_model_request_parameters.result_tools == snapshot(
418 [
419 ToolDefinition(
420 name='final_result',
421 description='The final response which ends this conversation',
422 parameters_json_schema={
423 'properties': {
424 'a': {'title': 'A', 'type': 'integer'},
425 'b': {'title': 'B', 'type': 'string'},
426 },
427 'required': ['a', 'b'],
428 'title': 'Foo',
429 'type': 'object',
430 },
431 )
432 ]
433 )
436 # pyright: reportUnknownMemberType=false, reportUnknownVariableType=false
437 @pytest.mark.parametrize(
438 'union_code',
439 [
440 pytest.param('ResultType = Union[Foo, Bar]'),
441 pytest.param('ResultType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='3.10+')),
442 pytest.param(
443 'ResultType: TypeAlias = Foo | Bar',
444 marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='Python 3.10+'),
445 ),
446 pytest.param(
447 'type ResultType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 12), reason='3.12+')
448 ),
449 ],
450 )
451 def test_response_multiple_return_tools(create_module: Callable[[str], Any], union_code: str):
452 module_code = f'''
453 from pydantic import BaseModel
454 from typing import Union
455 from typing_extensions import TypeAlias
457 class Foo(BaseModel):
458 a: int
459 b: str
462 class Bar(BaseModel):
463 """This is a bar model."""
465 b: str
467 {union_code}
468 '''
470 mod = create_module(module_code)
472 m = TestModel()
473 agent = Agent(m, result_type=mod.ResultType)
474 got_tool_call_name = 'unset'
476 @agent.result_validator
477 def validate_result(ctx: RunContext[None], r: Any) -> Any:
478 nonlocal got_tool_call_name
479 got_tool_call_name = ctx.tool_name
480 return r
482 result = agent.run_sync('Hello')
483 assert result.data == mod.Foo(a=0, b='a')
484 assert got_tool_call_name == snapshot('final_result_Foo')
486 assert m.last_model_request_parameters is not None
487 assert m.last_model_request_parameters.function_tools == snapshot([])
488 assert m.last_model_request_parameters.allow_text_result is False
490 assert m.last_model_request_parameters.result_tools is not None
491 assert len(m.last_model_request_parameters.result_tools) == 2
493 assert m.last_model_request_parameters.result_tools == snapshot(
494 [
495 ToolDefinition(
496 name='final_result_Foo',
497 description='Foo: The final response which ends this conversation',
498 parameters_json_schema={
499 'properties': {
500 'a': {'title': 'A', 'type': 'integer'},
501 'b': {'title': 'B', 'type': 'string'},
502 },
503 'required': ['a', 'b'],
504 'title': 'Foo',
505 'type': 'object',
506 },
507 ),
508 ToolDefinition(
509 name='final_result_Bar',
510 description='This is a bar model.',
511 parameters_json_schema={
512 'properties': {'b': {'title': 'B', 'type': 'string'}},
513 'required': ['b'],
514 'title': 'Bar',
515 'type': 'object',
516 },
517 ),
518 ]
519 )
521 result = agent.run_sync('Hello', model=TestModel(seed=1))
522 assert result.data == mod.Bar(b='b')
523 assert got_tool_call_name == snapshot('final_result_Bar')
526 def test_run_with_history_new():
527 m = TestModel()
529 agent = Agent(m, system_prompt='Foobar')
531 @agent.tool_plain
532 async def ret_a(x: str) -> str:
533 return f'{x}-apple'
535 result1 = agent.run_sync('Hello')
536 assert result1.new_messages() == snapshot(
537 [
538 ModelRequest(
539 parts=[
540 SystemPromptPart(content='Foobar', timestamp=IsNow(tz=timezone.utc)),
541 UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
542 ]
543 ),
544 ModelResponse(
545 parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())],
546 model_name='test',
547 timestamp=IsNow(tz=timezone.utc),
548 ),
549 ModelRequest(
550 parts=[
551 ToolReturnPart(
552 tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc)
553 )
554 ]
555 ),
556 ModelResponse(
557 parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
558 ),
559 ]
560 )
562 # if we pass new_messages, system prompt is inserted before the message_history messages
563 result2 = agent.run_sync('Hello again', message_history=result1.new_messages())
564 assert result2.all_messages() == snapshot(
565 [
566 ModelRequest(
567 parts=[
568 SystemPromptPart(content='Foobar', timestamp=IsNow(tz=timezone.utc)),
569 UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
570 ]
571 ),
572 ModelResponse(
573 parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())],
574 model_name='test',
575 timestamp=IsNow(tz=timezone.utc),
576 ),
577 ModelRequest(
578 parts=[
579 ToolReturnPart(
580 tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc)
581 )
582 ]
583 ),
584 ModelResponse(
585 parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
586 ),
587 ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]),
588 ModelResponse(
589 parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
590 ),
591 ]
592 )
593 assert result2._new_message_index == snapshot(4) # pyright: ignore[reportPrivateUsage]
594 assert result2.data == snapshot('{"ret_a":"a-apple"}')
595 assert result2._result_tool_name == snapshot(None) # pyright: ignore[reportPrivateUsage]
596 assert result2.usage() == snapshot(
597 Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None)
598 )
599 new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()]
600 assert new_msg_part_kinds == snapshot(
601 [
602 ('request', ['system-prompt', 'user-prompt']),
603 ('response', ['tool-call']),
604 ('request', ['tool-return']),
605 ('response', ['text']),
606 ('request', ['user-prompt']),
607 ('response', ['text']),
608 ]
609 )
610 assert result2.new_messages_json().startswith(b'[{"parts":[{"content":"Hello again",')
612 # if we pass all_messages, system prompt is NOT inserted before the message_history messages,
613 # so only one system prompt
614 result3 = agent.run_sync('Hello again', message_history=result1.all_messages())
615 # same as result2 except for datetimes
616 assert result3.all_messages() == snapshot(
617 [
618 ModelRequest(
619 parts=[
620 SystemPromptPart(content='Foobar', timestamp=IsNow(tz=timezone.utc)),
621 UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
622 ]
623 ),
624 ModelResponse(
625 parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())],
626 model_name='test',
627 timestamp=IsNow(tz=timezone.utc),
628 ),
629 ModelRequest(
630 parts=[
631 ToolReturnPart(
632 tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc)
633 )
634 ]
635 ),
636 ModelResponse(
637 parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
638 ),
639 ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]),
640 ModelResponse(
641 parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc)
642 ),
643 ]
644 )
645 assert result3._new_message_index == snapshot(4) # pyright: ignore[reportPrivateUsage]
646 assert result3.data == snapshot('{"ret_a":"a-apple"}')
647 assert result3._result_tool_name == snapshot(None) # pyright: ignore[reportPrivateUsage]
648 assert result3.usage() == snapshot(
649 Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None)
650 )
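# Illustrative sketch (not part of the covered module), condensing what the assertions above
# pin down: a follow-up run can be seeded with `result.new_messages()` (or `result.all_messages()`)
# via `message_history`, and the system prompt is only rendered once, in the first request of the
# history, rather than being repeated for the follow-up run.
def example_continue_conversation():
    agent = Agent(TestModel(), system_prompt='Foobar')

    first = agent.run_sync('Hello')
    followup = agent.run_sync('Hello again', message_history=first.new_messages())

    system_parts = [
        p for m in followup.all_messages() for p in m.parts if p.part_kind == 'system-prompt'
    ]
    assert len(system_parts) == 1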
653 def test_run_with_history_new_structured():
654 m = TestModel()
656 class Response(BaseModel):
657 a: int
659 agent = Agent(m, system_prompt='Foobar', result_type=Response)
661 @agent.tool_plain
662 async def ret_a(x: str) -> str:
663 return f'{x}-apple'
665 result1 = agent.run_sync('Hello')
666 assert result1.new_messages() == snapshot(
667 [
668 ModelRequest(
669 parts=[
670 SystemPromptPart(content='Foobar', timestamp=IsNow(tz=timezone.utc)),
671 UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
672 ]
673 ),
674 ModelResponse(
675 parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())],
676 model_name='test',
677 timestamp=IsNow(tz=timezone.utc),
678 ),
679 ModelRequest(
680 parts=[
681 ToolReturnPart(
682 tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc)
683 )
684 ]
685 ),
686 ModelResponse(
687 parts=[
688 ToolCallPart(
689 tool_name='final_result',
690 args={'a': 0},
691 tool_call_id=IsStr(),
692 )
693 ],
694 model_name='test',
695 timestamp=IsNow(tz=timezone.utc),
696 ),
697 ModelRequest(
698 parts=[
699 ToolReturnPart(
700 tool_name='final_result',
701 content='Final result processed.',
702 tool_call_id=IsStr(),
703 timestamp=IsNow(tz=timezone.utc),
704 )
705 ]
706 ),
707 ]
708 )
710 result2 = agent.run_sync('Hello again', message_history=result1.new_messages())
711 assert result2.all_messages() == snapshot(
712 [
713 ModelRequest(
714 parts=[
715 SystemPromptPart(content='Foobar', timestamp=IsNow(tz=timezone.utc)),
716 UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
717 ],
718 ),
719 ModelResponse(
720 parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())],
721 model_name='test',
722 timestamp=IsNow(tz=timezone.utc),
723 ),
724 ModelRequest(
725 parts=[
726 ToolReturnPart(
727 tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc)
728 )
729 ],
730 ),
731 ModelResponse(
732 parts=[ToolCallPart(tool_name='final_result', args={'a': 0}, tool_call_id=IsStr())],
733 model_name='test',
734 timestamp=IsNow(tz=timezone.utc),
735 ),
736 ModelRequest(
737 parts=[
738 ToolReturnPart(
739 tool_name='final_result',
740 content='Final result processed.',
741 tool_call_id=IsStr(),
742 timestamp=IsNow(tz=timezone.utc),
743 ),
744 ],
745 ),
746 # second call, notice no repeated system prompt
747 ModelRequest(
748 parts=[
749 UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc)),
750 ],
751 ),
752 ModelResponse(
753 parts=[ToolCallPart(tool_name='final_result', args={'a': 0}, tool_call_id=IsStr())],
754 model_name='test',
755 timestamp=IsNow(tz=timezone.utc),
756 ),
757 ModelRequest(
758 parts=[
759 ToolReturnPart(
760 tool_name='final_result',
761 content='Final result processed.',
762 tool_call_id=IsStr(),
763 timestamp=IsNow(tz=timezone.utc),
764 ),
765 ]
766 ),
767 ]
768 )
769 assert result2.data == snapshot(Response(a=0))
770 assert result2._new_message_index == snapshot(5) # pyright: ignore[reportPrivateUsage]
771 assert result2._result_tool_name == snapshot('final_result') # pyright: ignore[reportPrivateUsage]
772 assert result2.usage() == snapshot(
773 Usage(requests=1, request_tokens=59, response_tokens=13, total_tokens=72, details=None)
774 )
775 new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()]
776 assert new_msg_part_kinds == snapshot(
777 [
778 ('request', ['system-prompt', 'user-prompt']),
779 ('response', ['tool-call']),
780 ('request', ['tool-return']),
781 ('response', ['tool-call']),
782 ('request', ['tool-return']),
783 ('request', ['user-prompt']),
784 ('response', ['tool-call']),
785 ('request', ['tool-return']),
786 ]
787 )
788 assert result2.new_messages_json().startswith(b'[{"parts":[{"content":"Hello again",')
791 def test_empty_tool_calls():
792 def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
793 return ModelResponse(parts=[])
795 agent = Agent(FunctionModel(empty))
797 with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'):
798 agent.run_sync('Hello')
801 def test_unknown_tool():
802 def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
803 return ModelResponse(parts=[ToolCallPart('foobar', '{}')])
805 agent = Agent(FunctionModel(empty))
807 with capture_run_messages() as messages:
808 with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for result validation'):
809 agent.run_sync('Hello')
810 assert messages == snapshot(
811 [
812 ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
813 ModelResponse(
814 parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())],
815 model_name='function:empty:',
816 timestamp=IsNow(tz=timezone.utc),
817 ),
818 ModelRequest(
819 parts=[
820 RetryPromptPart(
821 content="Unknown tool name: 'foobar'. No tools available.",
822 tool_call_id=IsStr(),
823 timestamp=IsNow(tz=timezone.utc),
824 )
825 ]
826 ),
827 ModelResponse(
828 parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())],
829 model_name='function:empty:',
830 timestamp=IsNow(tz=timezone.utc),
831 ),
832 ]
833 )
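# Illustrative sketch (not part of the covered module): the debugging pattern the test above
# exercises. When a run fails, `capture_run_messages()` still exposes the exchange that led to
# the failure, so the offending model response can be inspected.
def example_capture_messages_of_failed_run():
    def call_missing_tool(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        return ModelResponse(parts=[ToolCallPart('missing_tool', '{}')])

    agent = Agent(FunctionModel(call_missing_tool))

    with capture_run_messages() as messages:
        with pytest.raises(UnexpectedModelBehavior):
            agent.run_sync('Hello')

    # the captured history includes the model responses that triggered the retries
    assert any(m.kind == 'response' for m in messages)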
836 def test_unknown_tool_fix():
837 def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
838 if len(m) > 1:
839 return ModelResponse(parts=[TextPart('success')])
840 else:
841 return ModelResponse(parts=[ToolCallPart('foobar', '{}')])
843 agent = Agent(FunctionModel(empty))
845 result = agent.run_sync('Hello')
846 assert result.data == 'success'
847 assert result.all_messages() == snapshot(
848 [
849 ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
850 ModelResponse(
851 parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())],
852 model_name='function:empty:',
853 timestamp=IsNow(tz=timezone.utc),
854 ),
855 ModelRequest(
856 parts=[
857 RetryPromptPart(
858 content="Unknown tool name: 'foobar'. No tools available.",
859 tool_call_id=IsStr(),
860 timestamp=IsNow(tz=timezone.utc),
861 )
862 ]
863 ),
864 ModelResponse(
865 parts=[TextPart(content='success')],
866 model_name='function:empty:',
867 timestamp=IsNow(tz=timezone.utc),
868 ),
869 ]
870 )
873 def test_model_requests_blocked(env: TestEnv):
874 env.set('GEMINI_API_KEY', 'foobar')
875 agent = Agent('google-gla:gemini-1.5-flash', result_type=tuple[str, str], defer_model_check=True)
877 with pytest.raises(RuntimeError, match='Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False'):
878 agent.run_sync('Hello')
881 def test_override_model(env: TestEnv):
882 env.set('GEMINI_API_KEY', 'foobar')
883 agent = Agent('google-gla:gemini-1.5-flash', result_type=tuple[int, str], defer_model_check=True)
885 with agent.override(model='test'):
886 result = agent.run_sync('Hello')
887 assert result.data == snapshot((0, 'a'))
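# Illustrative sketch (not part of the covered module): `Agent.override` as a testing seam, as
# exercised above, so an agent configured for a real provider can be driven by TestModel without
# any network access.
def example_override_for_tests(production_agent: Agent[None, str]) -> str:
    with production_agent.override(model='test'):
        return production_agent.run_sync('ping').data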
890 def test_override_model_no_model():
891 agent = Agent()
893 with pytest.raises(UserError, match=r'`model` must be set either.+Even when `override\(model=...\)` is customiz'):
894 with agent.override(model='test'):
895 agent.run_sync('Hello')
898 def test_run_sync_multiple():
899 agent = Agent('test')
901 @agent.tool_plain
902 async def make_request() -> str:
903 async with httpx.AsyncClient() as client:
904 # use this endpoint because it's probably one of the fastest globally available ones
905 try:
906 response = await client.get('https://cloudflare.com/cdn-cgi/trace')
907 except httpx.ConnectError: # pragma: no cover
908 pytest.skip('offline')
909 else:
910 return str(response.status_code)
912 for _ in range(2):
913 result = agent.run_sync('Hello')
914 assert result.data == '{"make_request":"200"}'
917 async def test_agent_name():
918 my_agent = Agent('test')
920 assert my_agent.name is None
922 await my_agent.run('Hello', infer_name=False)
923 assert my_agent.name is None
925 await my_agent.run('Hello')
926 assert my_agent.name == 'my_agent'
929 async def test_agent_name_already_set():
930 my_agent = Agent('test', name='fig_tree')
932 assert my_agent.name == 'fig_tree'
934 await my_agent.run('Hello')
935 assert my_agent.name == 'fig_tree'
938 async def test_agent_name_changes():
939 my_agent = Agent('test')
941 await my_agent.run('Hello')
942 assert my_agent.name == 'my_agent'
944 new_agent = my_agent
945 del my_agent
947 await new_agent.run('Hello')
948 assert new_agent.name == 'my_agent'
951 def test_name_from_global(create_module: Callable[[str], Any]):
952 module_code = """
953 from pydantic_ai import Agent
955 my_agent = Agent('test')
957 def foo():
958 result = my_agent.run_sync('Hello')
959 return result.data
960"""
962 mod = create_module(module_code)
964 assert mod.my_agent.name is None
965 assert mod.foo() == snapshot('success (no tool calls)')
966 assert mod.my_agent.name == 'my_agent'
969 class TestMultipleToolCalls:
970 """Tests for scenarios where multiple tool calls are made in a single response."""
972 pytestmark = pytest.mark.usefixtures('set_event_loop')
974 class ResultType(BaseModel):
975 """Result type used by all tests."""
977 value: str
979 def test_early_strategy_stops_after_first_final_result(self):
980 """Test that 'early' strategy stops processing regular tools after first final result."""
981 tool_called = []
983 def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
984 assert info.result_tools is not None
985 return ModelResponse(
986 parts=[
987 ToolCallPart('final_result', {'value': 'final'}),
988 ToolCallPart('regular_tool', {'x': 1}),
989 ToolCallPart('another_tool', {'y': 2}),
990 ]
991 )
993 agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='early')
995 @agent.tool_plain
996 def regular_tool(x: int) -> int: # pragma: no cover
997 """A regular tool that should not be called."""
998 tool_called.append('regular_tool')
999 return x
1001 @agent.tool_plain
1002 def another_tool(y: int) -> int: # pragma: no cover
1003 """Another tool that should not be called."""
1004 tool_called.append('another_tool')
1005 return y
1007 result = agent.run_sync('test early strategy')
1008 messages = result.all_messages()
1010 # Verify no tools were called after final result
1011 assert tool_called == []
1013 # Verify we got tool returns for all calls
1014 assert messages[-1].parts == snapshot(
1015 [
1016 ToolReturnPart(
1017 tool_name='final_result',
1018 content='Final result processed.',
1019 tool_call_id=IsStr(),
1020 timestamp=IsNow(tz=timezone.utc),
1021 ),
1022 ToolReturnPart(
1023 tool_name='regular_tool',
1024 content='Tool not executed - a final result was already processed.',
1025 tool_call_id=IsStr(),
1026 timestamp=IsNow(tz=timezone.utc),
1027 ),
1028 ToolReturnPart(
1029 tool_name='another_tool',
1030 content='Tool not executed - a final result was already processed.',
1031 tool_call_id=IsStr(),
1032 timestamp=IsNow(tz=timezone.utc),
1033 ),
1034 ]
1035 )
1037 def test_early_strategy_uses_first_final_result(self):
1038 """Test that 'early' strategy uses the first final result and ignores subsequent ones."""
1040 def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
1041 assert info.result_tools is not None
1042 return ModelResponse(
1043 parts=[
1044 ToolCallPart('final_result', {'value': 'first'}),
1045 ToolCallPart('final_result', {'value': 'second'}),
1046 ]
1047 )
1049 agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='early')
1050 result = agent.run_sync('test multiple final results')
1052 # Verify the result came from the first final tool
1053 assert result.data.value == 'first'
1055 # Verify we got appropriate tool returns
1056 assert result.new_messages()[-1].parts == snapshot(
1057 [
1058 ToolReturnPart(
1059 tool_name='final_result',
1060 content='Final result processed.',
1061 tool_call_id=IsStr(),
1062 timestamp=IsNow(tz=timezone.utc),
1063 ),
1064 ToolReturnPart(
1065 tool_name='final_result',
1066 content='Result tool not used - a final result was already processed.',
1067 tool_call_id=IsStr(),
1068 timestamp=IsNow(tz=timezone.utc),
1069 ),
1070 ]
1071 )
1073 def test_exhaustive_strategy_executes_all_tools(self):
1074 """Test that 'exhaustive' strategy executes all tools while using first final result."""
1075 tool_called: list[str] = []
1077 def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
1078 assert info.result_tools is not None
1079 return ModelResponse(
1080 parts=[
1081 ToolCallPart('regular_tool', {'x': 42}),
1082 ToolCallPart('final_result', {'value': 'first'}),
1083 ToolCallPart('another_tool', {'y': 2}),
1084 ToolCallPart('final_result', {'value': 'second'}),
1085 ToolCallPart('unknown_tool', {'value': '???'}),
1086 ]
1087 )
1089 agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='exhaustive')
1091 @agent.tool_plain
1092 def regular_tool(x: int) -> int:
1093 """A regular tool that should be called."""
1094 tool_called.append('regular_tool')
1095 return x
1097 @agent.tool_plain
1098 def another_tool(y: int) -> int:
1099 """Another tool that should be called."""
1100 tool_called.append('another_tool')
1101 return y
1103 result = agent.run_sync('test exhaustive strategy')
1105 # Verify the result came from the first final tool
1106 assert result.data.value == 'first'
1108 # Verify all regular tools were called
1109 assert sorted(tool_called) == sorted(['regular_tool', 'another_tool'])
1111 # Verify we got tool returns in the correct order
1112 assert result.all_messages() == snapshot(
1113 [
1114 ModelRequest(
1115 parts=[UserPromptPart(content='test exhaustive strategy', timestamp=IsNow(tz=timezone.utc))]
1116 ),
1117 ModelResponse(
1118 parts=[
1119 ToolCallPart(tool_name='regular_tool', args={'x': 42}, tool_call_id=IsStr()),
1120 ToolCallPart(tool_name='final_result', args={'value': 'first'}, tool_call_id=IsStr()),
1121 ToolCallPart(tool_name='another_tool', args={'y': 2}, tool_call_id=IsStr()),
1122 ToolCallPart(tool_name='final_result', args={'value': 'second'}, tool_call_id=IsStr()),
1123 ToolCallPart(tool_name='unknown_tool', args={'value': '???'}, tool_call_id=IsStr()),
1124 ],
1125 model_name='function:return_model:',
1126 timestamp=IsNow(tz=timezone.utc),
1127 ),
1128 ModelRequest(
1129 parts=[
1130 ToolReturnPart(
1131 tool_name='final_result',
1132 content='Final result processed.',
1133 tool_call_id=IsStr(),
1134 timestamp=IsNow(tz=timezone.utc),
1135 ),
1136 ToolReturnPart(
1137 tool_name='final_result',
1138 content='Result tool not used - a final result was already processed.',
1139 tool_call_id=IsStr(),
1140 timestamp=IsNow(tz=timezone.utc),
1141 ),
1142 RetryPromptPart(
1143 content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result",
1144 timestamp=IsNow(tz=timezone.utc),
1145 tool_call_id=IsStr(),
1146 ),
1147 ToolReturnPart(
1148 tool_name='regular_tool',
1149 content=42,
1150 tool_call_id=IsStr(),
1151 timestamp=IsNow(tz=timezone.utc),
1152 ),
1153 ToolReturnPart(
1154 tool_name='another_tool', content=2, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc)
1155 ),
1156 ]
1157 ),
1158 ]
1159 )
1161 def test_early_strategy_with_final_result_in_middle(self):
1162 """Test that 'early' strategy stops at first final result, regardless of position."""
1163 tool_called = []
1165 def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
1166 assert info.result_tools is not None
1167 return ModelResponse(
1168 parts=[
1169 ToolCallPart('regular_tool', {'x': 1}),
1170 ToolCallPart('final_result', {'value': 'final'}),
1171 ToolCallPart('another_tool', {'y': 2}),
1172 ToolCallPart('unknown_tool', {'value': '???'}),
1173 ]
1174 )
1176 agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='early')
1178 @agent.tool_plain
1179 def regular_tool(x: int) -> int: # pragma: no cover
1180 """A regular tool that should not be called."""
1181 tool_called.append('regular_tool')
1182 return x
1184 @agent.tool_plain
1185 def another_tool(y: int) -> int: # pragma: no cover
1186 """A tool that should not be called."""
1187 tool_called.append('another_tool')
1188 return y
1190 result = agent.run_sync('test early strategy with final result in middle')
1192 # Verify no tools were called
1193 assert tool_called == []
1195 # Verify we got appropriate tool returns
1196 assert result.all_messages() == snapshot(
1197 [
1198 ModelRequest(
1199 parts=[
1200 UserPromptPart(
1201 content='test early strategy with final result in middle', timestamp=IsNow(tz=timezone.utc)
1202 )
1203 ]
1204 ),
1205 ModelResponse(
1206 parts=[
1207 ToolCallPart(tool_name='regular_tool', args={'x': 1}, tool_call_id=IsStr()),
1208 ToolCallPart(tool_name='final_result', args={'value': 'final'}, tool_call_id=IsStr()),
1209 ToolCallPart(tool_name='another_tool', args={'y': 2}, tool_call_id=IsStr()),
1210 ToolCallPart(tool_name='unknown_tool', args={'value': '???'}, tool_call_id=IsStr()),
1211 ],
1212 model_name='function:return_model:',
1213 timestamp=IsNow(tz=timezone.utc),
1214 ),
1215 ModelRequest(
1216 parts=[
1217 ToolReturnPart(
1218 tool_name='regular_tool',
1219 content='Tool not executed - a final result was already processed.',
1220 tool_call_id=IsStr(),
1221 timestamp=IsNow(tz=timezone.utc),
1222 ),
1223 ToolReturnPart(
1224 tool_name='final_result',
1225 content='Final result processed.',
1226 tool_call_id=IsStr(),
1227 timestamp=IsNow(tz=timezone.utc),
1228 ),
1229 ToolReturnPart(
1230 tool_name='another_tool',
1231 content='Tool not executed - a final result was already processed.',
1232 tool_call_id=IsStr(),
1233 timestamp=IsNow(tz=timezone.utc),
1234 ),
1235 RetryPromptPart(
1236 content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result",
1237 timestamp=IsNow(tz=timezone.utc),
1238 tool_call_id=IsStr(),
1239 ),
1240 ]
1241 ),
1242 ]
1243 )
1245 def test_early_strategy_does_not_apply_to_tool_calls_without_final_tool(self):
1246 """Test that 'early' strategy does not apply to tool calls without final tool."""
1247 tool_called = []
1248 agent = Agent(TestModel(), result_type=self.ResultType, end_strategy='early')
1250 @agent.tool_plain
1251 def regular_tool(x: int) -> int:
1252 """A regular tool that should be called."""
1253 tool_called.append('regular_tool')
1254 return x
1256 result = agent.run_sync('test early strategy with regular tool calls')
1257 assert tool_called == ['regular_tool']
1259 tool_returns = [m for m in result.all_messages() if isinstance(m, ToolReturnPart)]
1260 assert tool_returns == snapshot([])
1262 def test_multiple_final_result_are_validated_correctly(self):
1263 """Tests that if multiple final results are returned, but one fails validation, the other is used."""
1265 def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
1266 assert info.result_tools is not None
1267 return ModelResponse(
1268 parts=[
1269 ToolCallPart('final_result', {'bad_value': 'first'}, tool_call_id='first'),
1270 ToolCallPart('final_result', {'value': 'second'}, tool_call_id='second'),
1271 ]
1272 )
1274 agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='early')
1275 result = agent.run_sync('test multiple final results')
1277 # Verify the result came from the second final tool
1278 assert result.data.value == 'second'
1280 # Verify we got appropriate tool returns
1281 assert result.new_messages()[-1].parts == snapshot(
1282 [
1283 ToolReturnPart(
1284 tool_name='final_result',
1285 tool_call_id='first',
1286 content='Result tool not used - result failed validation.',
1287 timestamp=IsNow(tz=timezone.utc),
1288 ),
1289 ToolReturnPart(
1290 tool_name='final_result',
1291 content='Final result processed.',
1292 timestamp=IsNow(tz=timezone.utc),
1293 tool_call_id='second',
1294 ),
1295 ]
1296 )
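# Illustrative sketch (not part of the covered module), distilling the `end_strategy` behaviour
# the class above exercises: with 'early', regular tool calls issued alongside a final result are
# not executed, while 'exhaustive' still runs them even though the first final result wins.
def example_end_strategy(strategy: str) -> list[str]:
    called: list[str] = []

    def return_final_and_side_call(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.result_tools is not None
        return ModelResponse(
            parts=[
                ToolCallPart(info.result_tools[0].name, {'response': ['a', 'b']}),
                ToolCallPart('side_tool', {'x': 1}),
            ]
        )

    agent = Agent(FunctionModel(return_final_and_side_call), result_type=tuple[str, str], end_strategy=strategy)

    @agent.tool_plain
    def side_tool(x: int) -> int:
        called.append('side_tool')
        return x

    agent.run_sync('hi')
    return called


# Expected: example_end_strategy('early') returns [], while example_end_strategy('exhaustive')
# returns ['side_tool'].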
1299 async def test_model_settings_override() -> None:
1300 def return_settings(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
1301 return ModelResponse(parts=[TextPart(to_json(info.model_settings).decode())])
1303 my_agent = Agent(FunctionModel(return_settings))
1304 assert (await my_agent.run('Hello')).data == IsJson(None)
1305 assert (await my_agent.run('Hello', model_settings={'temperature': 0.5})).data == IsJson({'temperature': 0.5})
1307 my_agent = Agent(FunctionModel(return_settings), model_settings={'temperature': 0.1})
1308 assert (await my_agent.run('Hello')).data == IsJson({'temperature': 0.1})
1309 assert (await my_agent.run('Hello', model_settings={'temperature': 0.5})).data == IsJson({'temperature': 0.5})
1312 async def test_empty_text_part():
1313 def return_empty_text(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
1314 assert info.result_tools is not None
1315 args_json = '{"response": ["foo", "bar"]}'
1316 return ModelResponse(
1317 parts=[
1318 TextPart(''),
1319 ToolCallPart(info.result_tools[0].name, args_json),
1320 ]
1321 )
1323 agent = Agent(FunctionModel(return_empty_text), result_type=tuple[str, str])
1325 result = await agent.run('Hello')
1326 assert result.data == ('foo', 'bar')
1329 def test_heterogeneous_responses_non_streaming() -> None:
1330 """Indicates that tool calls are prioritized over text in heterogeneous responses."""
1332 def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
1333 assert info.result_tools is not None
1334 parts: list[ModelResponsePart] = []
1335 if len(messages) == 1:
1336 parts = [TextPart(content='foo'), ToolCallPart('get_location', {'loc_name': 'London'})]
1337 else:
1338 parts = [TextPart(content='final response')]
1339 return ModelResponse(parts=parts)
1341 agent = Agent(FunctionModel(return_model))
1343 @agent.tool_plain
1344 async def get_location(loc_name: str) -> str:
1345 if loc_name == 'London':  # coverage: 1345 ↛ 1348, line 1345 didn't jump to line 1348 because the condition was always true
1346 return json.dumps({'lat': 51, 'lng': 0})
1347 else:
1348 raise ModelRetry('Wrong location, please try again')
1350 result = agent.run_sync('Hello')
1351 assert result.data == 'final response'
1352 assert result.all_messages() == snapshot(
1353 [
1354 ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
1355 ModelResponse(
1356 parts=[
1357 TextPart(content='foo'),
1358 ToolCallPart(tool_name='get_location', args={'loc_name': 'London'}, tool_call_id=IsStr()),
1359 ],
1360 model_name='function:return_model:',
1361 timestamp=IsNow(tz=timezone.utc),
1362 ),
1363 ModelRequest(
1364 parts=[
1365 ToolReturnPart(
1366 tool_name='get_location',
1367 content='{"lat": 51, "lng": 0}',
1368 tool_call_id=IsStr(),
1369 timestamp=IsNow(tz=timezone.utc),
1370 )
1371 ]
1372 ),
1373 ModelResponse(
1374 parts=[TextPart(content='final response')],
1375 model_name='function:return_model:',
1376 timestamp=IsNow(tz=timezone.utc),
1377 ),
1378 ]
1379 )
1382 def test_last_run_messages() -> None:
1383 agent = Agent('test')
1385 with pytest.raises(AttributeError, match='The `last_run_messages` attribute has been removed,'):
1386 agent.last_run_messages # pyright: ignore[reportDeprecated]
1389 def test_nested_capture_run_messages() -> None:
1390 agent = Agent('test')
1392 with capture_run_messages() as messages1:
1393 assert messages1 == []
1394 with capture_run_messages() as messages2:
1395 assert messages2 == []
1396 assert messages1 is messages2
1397 result = agent.run_sync('Hello')
1398 assert result.data == 'success (no tool calls)'
1400 assert messages1 == snapshot(
1401 [
1402 ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
1403 ModelResponse(
1404 parts=[TextPart(content='success (no tool calls)')], model_name='test', timestamp=IsNow(tz=timezone.utc)
1405 ),
1406 ]
1407 )
1408 assert messages1 == messages2
1411 def test_double_capture_run_messages() -> None:
1412 agent = Agent('test')
1414 with capture_run_messages() as messages:
1415 assert messages == []
1416 result = agent.run_sync('Hello')
1417 assert result.data == 'success (no tool calls)'
1418 result2 = agent.run_sync('Hello 2')
1419 assert result2.data == 'success (no tool calls)'
1420 assert messages == snapshot(
1421 [
1422 ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
1423 ModelResponse(
1424 parts=[TextPart(content='success (no tool calls)')], model_name='test', timestamp=IsNow(tz=timezone.utc)
1425 ),
1426 ]
1427 )
1430 def test_dynamic_false_no_reevaluate():
1431 """When dynamic is false (default), the system prompt is not reevaluated
1432 i.e. SystemPromptPart(
1433 content="A", <--- Remains the same when `message_history` is passed.
1434 part_kind='system-prompt')
1435 """
1436 agent = Agent('test', system_prompt='Foobar')
1438 dynamic_value = 'A'
1440 @agent.system_prompt
1441 async def func() -> str:
1442 return dynamic_value
1444 res = agent.run_sync('Hello')
1446 assert res.all_messages() == snapshot(
1447 [
1448 ModelRequest(
1449 parts=[
1450 SystemPromptPart(content='Foobar', part_kind='system-prompt', timestamp=IsNow(tz=timezone.utc)),
1451 SystemPromptPart(
1452 content=dynamic_value, part_kind='system-prompt', timestamp=IsNow(tz=timezone.utc)
1453 ),
1454 UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
1455 ],
1456 kind='request',
1457 ),
1458 ModelResponse(
1459 parts=[TextPart(content='success (no tool calls)', part_kind='text')],
1460 model_name='test',
1461 timestamp=IsNow(tz=timezone.utc),
1462 kind='response',
1463 ),
1464 ]
1465 )
1467 dynamic_value = 'B'
1469 res_two = agent.run_sync('World', message_history=res.all_messages())
1471 assert res_two.all_messages() == snapshot(
1472 [
1473 ModelRequest(
1474 parts=[
1475 SystemPromptPart(content='Foobar', part_kind='system-prompt', timestamp=IsNow(tz=timezone.utc)),
1476 SystemPromptPart(
1477 content='A', # Remains the same
1478 part_kind='system-prompt',
1479 timestamp=IsNow(tz=timezone.utc),
1480 ),
1481 UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
1482 ],
1483 kind='request',
1484 ),
1485 ModelResponse(
1486 parts=[TextPart(content='success (no tool calls)', part_kind='text')],
1487 model_name='test',
1488 timestamp=IsNow(tz=timezone.utc),
1489 kind='response',
1490 ),
1491 ModelRequest(
1492 parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')],
1493 kind='request',
1494 ),
1495 ModelResponse(
1496 parts=[TextPart(content='success (no tool calls)', part_kind='text')],
1497 model_name='test',
1498 timestamp=IsNow(tz=timezone.utc),
1499 kind='response',
1500 ),
1501 ]
1502 )
1505 def test_dynamic_true_reevaluate_system_prompt():
1506 """When dynamic is true, the system prompt is reevaluated
1507 i.e. SystemPromptPart(
1508 content="B", <--- Updated value
1509 part_kind='system-prompt')
1510 """
1511 agent = Agent('test', system_prompt='Foobar')
1513 dynamic_value = 'A'
1515 @agent.system_prompt(dynamic=True)
1516 async def func():
1517 return dynamic_value
1519 res = agent.run_sync('Hello')
1521 assert res.all_messages() == snapshot(
1522 [
1523 ModelRequest(
1524 parts=[
1525 SystemPromptPart(content='Foobar', part_kind='system-prompt', timestamp=IsNow(tz=timezone.utc)),
1526 SystemPromptPart(
1527 content=dynamic_value,
1528 part_kind='system-prompt',
1529 dynamic_ref=func.__qualname__,
1530 timestamp=IsNow(tz=timezone.utc),
1531 ),
1532 UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
1533 ],
1534 kind='request',
1535 ),
1536 ModelResponse(
1537 parts=[TextPart(content='success (no tool calls)', part_kind='text')],
1538 model_name='test',
1539 timestamp=IsNow(tz=timezone.utc),
1540 kind='response',
1541 ),
1542 ]
1543 )
1545 dynamic_value = 'B'
1547 res_two = agent.run_sync('World', message_history=res.all_messages())
1549 assert res_two.all_messages() == snapshot(
1550 [
1551 ModelRequest(
1552 parts=[
1553 SystemPromptPart(content='Foobar', part_kind='system-prompt', timestamp=IsNow(tz=timezone.utc)),
1554 SystemPromptPart(
1555 content='B',
1556 part_kind='system-prompt',
1557 dynamic_ref=func.__qualname__,
1558 timestamp=IsNow(tz=timezone.utc),
1559 ),
1560 UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
1561 ],
1562 kind='request',
1563 ),
1564 ModelResponse(
1565 parts=[TextPart(content='success (no tool calls)', part_kind='text')],
1566 model_name='test',
1567 timestamp=IsNow(tz=timezone.utc),
1568 kind='response',
1569 ),
1570 ModelRequest(
1571 parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')],
1572 kind='request',
1573 ),
1574 ModelResponse(
1575 parts=[TextPart(content='success (no tool calls)', part_kind='text')],
1576 model_name='test',
1577 timestamp=IsNow(tz=timezone.utc),
1578 kind='response',
1579 ),
1580 ]
1581 )
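# Illustrative sketch (not part of the covered module), condensing the two tests above: a system
# prompt function registered with `dynamic=True` is re-evaluated when `message_history` is passed,
# so the system prompt part stored in the history reflects the new value, whereas a plain
# `@agent.system_prompt` function keeps the value it produced on the first run.
def example_dynamic_system_prompt():
    agent = Agent('test')
    language = 'English'

    @agent.system_prompt(dynamic=True)
    async def reply_language() -> str:
        return f'Reply in {language}.'

    first = agent.run_sync('Hello')

    language = 'French'
    second = agent.run_sync('Hello again', message_history=first.all_messages())

    first_request = second.all_messages()[0]
    assert isinstance(first_request, ModelRequest)
    assert first_request.parts[0].content == 'Reply in French.'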
1584 def test_capture_run_messages_tool_agent() -> None:
1585 agent_outer = Agent('test')
1586 agent_inner = Agent(TestModel(custom_result_text='inner agent result'))
1588 @agent_outer.tool_plain
1589 async def foobar(x: str) -> str:
1590 result_ = await agent_inner.run(x)
1591 return result_.data
1593 with capture_run_messages() as messages:
1594 result = agent_outer.run_sync('foobar')
1596 assert result.data == snapshot('{"foobar":"inner agent result"}')
1597 assert messages == snapshot(
1598 [
1599 ModelRequest(parts=[UserPromptPart(content='foobar', timestamp=IsNow(tz=timezone.utc))]),
1600 ModelResponse(
1601 parts=[ToolCallPart(tool_name='foobar', args={'x': 'a'}, tool_call_id=IsStr())],
1602 model_name='test',
1603 timestamp=IsNow(tz=timezone.utc),
1604 ),
1605 ModelRequest(
1606 parts=[
1607 ToolReturnPart(
1608 tool_name='foobar',
1609 content='inner agent result',
1610 tool_call_id=IsStr(),
1611 timestamp=IsNow(tz=timezone.utc),
1612 )
1613 ]
1614 ),
1615 ModelResponse(
1616 parts=[TextPart(content='{"foobar":"inner agent result"}')],
1617 model_name='test',
1618 timestamp=IsNow(tz=timezone.utc),
1619 ),
1620 ]
1621 )
1624 class Bar(BaseModel):
1625 c: int
1626 d: str
1629 def test_custom_result_type_sync() -> None:
1630 agent = Agent('test', result_type=Foo)
1632 assert agent.run_sync('Hello').data == snapshot(Foo(a=0, b='a'))
1633 assert agent.run_sync('Hello', result_type=Bar).data == snapshot(Bar(c=0, d='a'))
1634 assert agent.run_sync('Hello', result_type=str).data == snapshot('success (no tool calls)')
1635 assert agent.run_sync('Hello', result_type=int).data == snapshot(0)
1638 async def test_custom_result_type_async() -> None:
1639 agent = Agent('test')
1641 result = await agent.run('Hello')
1642 assert result.data == snapshot('success (no tool calls)')
1644 result = await agent.run('Hello', result_type=Foo)
1645 assert result.data == snapshot(Foo(a=0, b='a'))
1646 result = await agent.run('Hello', result_type=int)
1647 assert result.data == snapshot(0)
1650 def test_custom_result_type_invalid() -> None:
1651 agent = Agent('test')
1653 @agent.result_validator
1654 def validate_result(ctx: RunContext[None], r: Any) -> Any: # pragma: no cover
1655 return r
1657 with pytest.raises(UserError, match='Cannot set a custom run `result_type` when the agent has result validators'):
1658 agent.run_sync('Hello', result_type=int)
1661 def test_binary_content_all_messages_json():
1662 agent = Agent('test')
1664 result = agent.run_sync(['Hello', BinaryContent(data=b'Hello', media_type='text/plain')])
1665 assert json.loads(result.all_messages_json()) == snapshot(
1666 [
1667 {
1668 'parts': [
1669 {
1670 'content': ['Hello', {'data': 'SGVsbG8=', 'media_type': 'text/plain', 'kind': 'binary'}],
1671 'timestamp': IsStr(),
1672 'part_kind': 'user-prompt',
1673 }
1674 ],
1675 'kind': 'request',
1676 },
1677 {
1678 'parts': [{'content': 'success (no tool calls)', 'part_kind': 'text'}],
1679 'model_name': 'test',
1680 'timestamp': IsStr(),
1681 'kind': 'response',
1682 },
1683 ]
1684 )