Coverage for tests/test_tools.py: 99.42%
324 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1import json
2from dataclasses import dataclass
3from typing import Annotated, Any, Callable, Literal, Union
5import pydantic_core
6import pytest
7from _pytest.logging import LogCaptureFixture
8from inline_snapshot import snapshot
9from pydantic import BaseModel, Field, WithJsonSchema
10from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
11from pydantic_core import PydanticSerializationError, core_schema
13from pydantic_ai import Agent, RunContext, Tool, UserError
14from pydantic_ai.messages import (
15 ModelMessage,
16 ModelRequest,
17 ModelResponse,
18 TextPart,
19 ToolCallPart,
20 ToolReturnPart,
21)
22from pydantic_ai.models.function import AgentInfo, FunctionModel
23from pydantic_ai.models.test import TestModel
24from pydantic_ai.tools import ToolDefinition
def test_tool_no_ctx():
    """`@agent.tool` rejects a function whose first parameter is not a `RunContext`."""
    agent = Agent(TestModel())

    def invalid_tool(x: int) -> str:  # pragma: no cover
        return 'Hello'

    # Registration (not a run) is what raises, so only the decorator call sits inside `raises`.
    with pytest.raises(UserError) as raised:
        agent.tool(invalid_tool)  # pyright: ignore[reportArgumentType]

    message = str(raised.value)
    assert message == snapshot(
        'Error generating schema for test_tool_no_ctx.<locals>.invalid_tool:\n'
        '  First parameter of tools that take context must be annotated with RunContext[...]'
    )
def test_tool_plain_with_ctx():
    """`@agent.tool_plain` rejects a function that asks for a `RunContext`."""
    agent = Agent(TestModel())

    async def invalid_tool(ctx: RunContext[None]) -> str:  # pragma: no cover
        return 'Hello'

    # Registration (not a run) is what raises, so only the decorator call sits inside `raises`.
    with pytest.raises(UserError) as raised:
        agent.tool_plain(invalid_tool)

    message = str(raised.value)
    assert message == snapshot(
        'Error generating schema for test_tool_plain_with_ctx.<locals>.invalid_tool:\n'
        '  RunContext annotations can only be used with tools that take context'
    )
def test_tool_ctx_second():
    """A `RunContext` parameter in second position is reported with both applicable errors."""
    agent = Agent(TestModel())

    def invalid_tool(x: int, ctx: RunContext[None]) -> str:  # pragma: no cover
        return 'Hello'

    # Registration (not a run) is what raises, so only the decorator call sits inside `raises`.
    with pytest.raises(UserError) as raised:
        agent.tool(invalid_tool)  # pyright: ignore[reportArgumentType]

    message = str(raised.value)
    assert message == snapshot(
        'Error generating schema for test_tool_ctx_second.<locals>.invalid_tool:\n'
        '  First parameter of tools that take context must be annotated with RunContext[...]\n'
        '  RunContext annotations can only be used as the first argument'
    )
# Tool fixture with a Google-style docstring, registered by `test_docstring_google`.
# The docstring below is parser input: its summary and `Args:` entries are what the
# test expects to see as the tool/parameter descriptions, so do not edit it casually.
async def google_style_docstring(foo: int, bar: str) -> str:  # pragma: no cover
    """Do foobar stuff, a lot.

    Args:
        foo: The foo thing.
        bar: The bar thing.
    """
    return f'{foo} {bar}'
async def get_json_schema(_messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    """`FunctionModel` handler that answers with the JSON of the tool definitions it was offered.

    When exactly one function tool is registered, that single definition is serialised;
    otherwise the whole list is serialised. The message history is ignored.
    """
    tools = info.function_tools
    payload = tools[0] if len(tools) == 1 else tools
    return ModelResponse(parts=[TextPart(pydantic_core.to_json(payload).decode())])
@pytest.mark.parametrize('docstring_format', ['google', 'auto'])
def test_docstring_google(docstring_format: Literal['google', 'auto']):
    """Google-style docstrings are parsed both when requested explicitly and via auto-detection."""
    agent = Agent(FunctionModel(get_json_schema))
    register = agent.tool_plain(docstring_format=docstring_format)
    register(google_style_docstring)

    tool_def = json.loads(agent.run_sync('Hello').data)
    assert tool_def == snapshot(
        {
            'name': 'google_style_docstring',
            'description': 'Do foobar stuff, a lot.',
            'parameters_json_schema': {
                'properties': {
                    'foo': {'description': 'The foo thing.', 'type': 'integer'},
                    'bar': {'description': 'The bar thing.', 'type': 'string'},
                },
                'required': ['foo', 'bar'],
                'type': 'object',
                'additionalProperties': False,
            },
            'outer_typed_dict_key': None,
        }
    )
    # Dict equality ignores order, so check separately that `name` then `description` lead.
    assert list(tool_def)[:2] == ['name', 'description']
# Tool fixture with a Sphinx-style (`:param ...:`) docstring, registered by
# `test_docstring_sphinx`. `foo` is positional-only. The docstring is parser input.
def sphinx_style_docstring(foo: int, /) -> str:  # pragma: no cover
    """Sphinx style docstring.

    :param foo: The foo thing.
    """
    return str(foo)
@pytest.mark.parametrize('docstring_format', ['sphinx', 'auto'])
def test_docstring_sphinx(docstring_format: Literal['sphinx', 'auto']):
    """Sphinx-style docstrings are parsed both when requested explicitly and via auto-detection."""
    agent = Agent(FunctionModel(get_json_schema))
    register = agent.tool_plain(docstring_format=docstring_format)
    register(sphinx_style_docstring)

    tool_def = json.loads(agent.run_sync('Hello').data)
    assert tool_def == snapshot(
        {
            'name': 'sphinx_style_docstring',
            'description': 'Sphinx style docstring.',
            'parameters_json_schema': {
                'properties': {'foo': {'description': 'The foo thing.', 'type': 'integer'}},
                'required': ['foo'],
                'type': 'object',
                'additionalProperties': False,
            },
            'outer_typed_dict_key': None,
        }
    )
# Tool fixture with a NumPy-style (`Parameters` / `----------`) docstring, registered by
# `test_docstring_numpy`. Both parameters are keyword-only. The docstring is parser input.
def numpy_style_docstring(*, foo: int, bar: str) -> str:  # pragma: no cover
    """Numpy style docstring.

    Parameters
    ----------
    foo : int
        The foo thing.
    bar : str
        The bar thing.
    """
    return f'{foo} {bar}'
@pytest.mark.parametrize('docstring_format', ['numpy', 'auto'])
def test_docstring_numpy(docstring_format: Literal['numpy', 'auto']):
    """NumPy-style docstrings are parsed both when requested explicitly and via auto-detection."""
    agent = Agent(FunctionModel(get_json_schema))
    register = agent.tool_plain(docstring_format=docstring_format)
    register(numpy_style_docstring)

    tool_def = json.loads(agent.run_sync('Hello').data)
    assert tool_def == snapshot(
        {
            'name': 'numpy_style_docstring',
            'description': 'Numpy style docstring.',
            'parameters_json_schema': {
                'properties': {
                    'foo': {'description': 'The foo thing.', 'type': 'integer'},
                    'bar': {'description': 'The bar thing.', 'type': 'string'},
                },
                'required': ['foo', 'bar'],
                'type': 'object',
                'additionalProperties': False,
            },
            'outer_typed_dict_key': None,
        }
    )
def test_google_style_with_returns():
    """A Google `Returns:` section is folded into the description as `<summary>`/`<returns>` markup."""
    agent = Agent(FunctionModel(get_json_schema))

    # The docstring below is parser input; the expected description is derived from it.
    def my_tool(x: int) -> str:  # pragma: no cover
        """A function that does something.

        Args:
            x: The input value.

        Returns:
            str: The result as a string.
        """
        return str(x)

    agent.tool_plain(my_tool)
    tool_def = json.loads(agent.run_sync('Hello').data)
    assert tool_def == snapshot(
        {
            'name': 'my_tool',
            'description': """\
<summary>A function that does something.</summary>
<returns>
<type>str</type>
<description>The result as a string.</description>
</returns>\
""",
            'parameters_json_schema': {
                'additionalProperties': False,
                'properties': {'x': {'description': 'The input value.', 'type': 'integer'}},
                'required': ['x'],
                'type': 'object',
            },
            'outer_typed_dict_key': None,
        }
    )
def test_sphinx_style_with_returns():
    """Sphinx `:rtype:`/`:return:` fields are folded into the description as `<returns>` markup."""
    agent = Agent(FunctionModel(get_json_schema))

    # The docstring below is parser input; the expected description is derived from it.
    def my_tool(x: int) -> str:  # pragma: no cover
        """A sphinx function with returns.

        :param x: The input value.
        :rtype: str
        :return: The result as a string with type.
        """
        return str(x)

    agent.tool_plain(docstring_format='sphinx')(my_tool)
    tool_def = json.loads(agent.run_sync('Hello').data)
    assert tool_def == snapshot(
        {
            'name': 'my_tool',
            'description': """\
<summary>A sphinx function with returns.</summary>
<returns>
<type>str</type>
<description>The result as a string with type.</description>
</returns>\
""",
            'parameters_json_schema': {
                'additionalProperties': False,
                'properties': {'x': {'description': 'The input value.', 'type': 'integer'}},
                'required': ['x'],
                'type': 'object',
            },
            'outer_typed_dict_key': None,
        }
    )
def test_numpy_style_with_returns():
    """A NumPy `Returns` section is folded into the description as `<returns>` markup."""
    agent = Agent(FunctionModel(get_json_schema))

    # The docstring below is parser input; the expected description is derived from it.
    def my_tool(x: int) -> str:  # pragma: no cover
        """A numpy function with returns.

        Parameters
        ----------
        x : int
            The input value.

        Returns
        -------
        str
            The result as a string with type.
        """
        return str(x)

    agent.tool_plain(docstring_format='numpy')(my_tool)
    tool_def = json.loads(agent.run_sync('Hello').data)
    assert tool_def == snapshot(
        {
            'name': 'my_tool',
            'description': """\
<summary>A numpy function with returns.</summary>
<returns>
<type>str</type>
<description>The result as a string with type.</description>
</returns>\
""",
            'parameters_json_schema': {
                'additionalProperties': False,
                'properties': {'x': {'description': 'The input value.', 'type': 'integer'}},
                'required': ['x'],
                'type': 'object',
            },
            'outer_typed_dict_key': None,
        }
    )
# Tool fixture whose docstring has no summary line, only a Google `Returns:` section;
# used by `test_only_returns_type`. The docstring is parser input.
def only_returns_type() -> str:  # pragma: no cover
    """

    Returns:
        str: The result as a string.
    """
    return 'foo'
def test_only_returns_type():
    """A docstring with only a `Returns:` section yields a description holding just the `<returns>` block."""
    agent = Agent(FunctionModel(get_json_schema))
    agent.tool_plain(only_returns_type)

    tool_def = json.loads(agent.run_sync('Hello').data)
    assert tool_def == snapshot(
        {
            'name': 'only_returns_type',
            'description': """\
<returns>
<type>str</type>
<description>The result as a string.</description>
</returns>\
""",
            'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'},
            'outer_typed_dict_key': None,
        }
    )
# Tool fixture with a summary-only docstring and only `**kwargs`; used by
# `test_docstring_unknown`. The docstring is parser input.
def unknown_docstring(**kwargs: int) -> str:  # pragma: no cover
    """Unknown style docstring."""
    return str(kwargs)
def test_docstring_unknown():
    """A summary-only docstring is used as-is, and `**kwargs` allows additional properties."""
    agent = Agent(FunctionModel(get_json_schema))
    agent.tool_plain(unknown_docstring)

    tool_def = json.loads(agent.run_sync('Hello').data)
    assert tool_def == snapshot(
        {
            'name': 'unknown_docstring',
            'description': 'Unknown style docstring.',
            'parameters_json_schema': {'properties': {}, 'type': 'object', 'additionalProperties': True},
            'outer_typed_dict_key': None,
        }
    )
# fmt: off
# Tool fixture: Google `Args:` section with no summary line. `bar` also carries a
# `Field(description=...)`, which `test_docstring_google_no_body` expects to win over the
# docstring entry. The docstring is parser input.
async def google_style_docstring_no_body(
    foo: int, bar: Annotated[str, Field(description='from fields')]
) -> str:  # pragma: no cover
    """
    Args:
        foo: The foo thing.
        bar: The bar thing.
    """

    return f'{foo} {bar}'
# fmt: on
@pytest.mark.parametrize('docstring_format', ['google', 'auto'])
def test_docstring_google_no_body(docstring_format: Literal['google', 'auto']):
    """Without a summary the description is empty, and a `Field` description beats the docstring."""
    agent = Agent(FunctionModel(get_json_schema))
    register = agent.tool_plain(docstring_format=docstring_format)
    register(google_style_docstring_no_body)

    tool_def = json.loads(agent.run_sync('').data)
    assert tool_def == snapshot(
        {
            'name': 'google_style_docstring_no_body',
            'description': '',
            'parameters_json_schema': {
                'properties': {
                    'foo': {'description': 'The foo thing.', 'type': 'integer'},
                    'bar': {'description': 'from fields', 'type': 'string'},
                },
                'required': ['foo', 'bar'],
                'type': 'object',
                'additionalProperties': False,
            },
            'outer_typed_dict_key': None,
        }
    )
# Two-field model used as a tool argument/return type in the tests below.
# Deliberately has no class docstring: pydantic would emit it as a `description`
# in the model's JSON schema, which the schema snapshots compare against.
class Foo(BaseModel):
    x: int
    y: str
def test_takes_just_model():
    """A tool whose sole parameter is a model exposes that model's schema as its parameters."""
    agent = Agent()

    # No docstring on purpose: the expected description is None.
    @agent.tool_plain
    def takes_just_model(model: Foo) -> str:
        return f'{model.x} {model.y}'

    schema_run = agent.run_sync('', model=FunctionModel(get_json_schema))
    assert json.loads(schema_run.data) == snapshot(
        {
            'name': 'takes_just_model',
            'description': None,
            'parameters_json_schema': {
                'properties': {
                    'x': {'type': 'integer'},
                    'y': {'type': 'string'},
                },
                'required': ['x', 'y'],
                'title': 'Foo',
                'type': 'object',
            },
            'outer_typed_dict_key': None,
        }
    )

    call_run = agent.run_sync('', model=TestModel())
    assert call_run.data == snapshot('{"takes_just_model":"0 a"}')
def test_takes_model_and_int():
    """A tool taking a model plus a scalar nests the model under `$defs` and references it."""
    agent = Agent()

    # Named after what it takes; it was previously a copy-paste of `takes_just_model`.
    @agent.tool_plain
    def takes_model_and_int(model: Foo, z: int) -> str:
        return f'{model.x} {model.y} {z}'

    result = agent.run_sync('', model=FunctionModel(get_json_schema))
    json_schema = json.loads(result.data)
    assert json_schema == snapshot(
        {
            'name': 'takes_model_and_int',
            'description': '',
            'parameters_json_schema': {
                '$defs': {
                    'Foo': {
                        'properties': {
                            'x': {'type': 'integer'},
                            'y': {'type': 'string'},
                        },
                        'required': ['x', 'y'],
                        'title': 'Foo',
                        'type': 'object',
                    }
                },
                'properties': {
                    'model': {'$ref': '#/$defs/Foo'},
                    'z': {'type': 'integer'},
                },
                'required': ['model', 'z'],
                'type': 'object',
                'additionalProperties': False,
            },
            'outer_typed_dict_key': None,
        }
    )

    result = agent.run_sync('', model=TestModel())
    assert result.data == snapshot('{"takes_model_and_int":"0 a 0"}')
# pyright: reportPrivateUsage=false
def test_init_tool_plain():
    """Plain functions passed via `Tool(...)` or bare inherit the agent's `retries` and don't take ctx."""
    recorded: list[int] = []

    def plain_tool(x: int) -> int:
        recorded.append(x)
        return x + 1

    explicit = Agent('test', tools=[Tool(plain_tool)], retries=7)
    assert explicit.run_sync('foobar').data == snapshot('{"plain_tool":1}')
    assert recorded == snapshot([0])
    explicit_tool = explicit._function_tools['plain_tool']
    assert explicit_tool.takes_ctx is False
    assert explicit_tool.max_retries == 7

    # The same function again, this time letting the agent wrap it; `recorded` keeps accumulating.
    inferred = Agent('test', tools=[plain_tool], retries=7)
    assert inferred.run_sync('foobar').data == snapshot('{"plain_tool":1}')
    assert recorded == snapshot([0, 0])
    inferred_tool = inferred._function_tools['plain_tool']
    assert inferred_tool.takes_ctx is False
    assert inferred_tool.max_retries == 7
# Shared context-taking tool for the tests below: returns `x` plus the run's `deps`.
# (No docstring on purpose; a docstring would become the tool description.)
def ctx_tool(ctx: RunContext[int], x: int) -> int:
    offset = ctx.deps
    return x + offset
# pyright: reportPrivateUsage=false
def test_init_tool_ctx():
    """An explicit `max_retries` on `Tool` overrides the agent's, and `takes_ctx` is inferred from annotations."""
    explicit = Agent('test', tools=[Tool(ctx_tool, takes_ctx=True, max_retries=3)], deps_type=int, retries=7)
    assert explicit.run_sync('foobar', deps=5).data == snapshot('{"ctx_tool":5}')
    registered = explicit._function_tools['ctx_tool']
    assert registered.takes_ctx is True
    assert registered.max_retries == 3

    inferred = Agent('test', tools=[ctx_tool], deps_type=int)
    assert inferred.run_sync('foobar', deps=6).data == snapshot('{"ctx_tool":6}')
    assert inferred._function_tools['ctx_tool'].takes_ctx is True
def test_repeat_tool():
    """Registering two tools with the same name is rejected when the agent is built."""
    conflict = "Tool name conflicts with existing tool: 'ctx_tool'"
    with pytest.raises(UserError, match=conflict):
        Agent('test', tools=[Tool(ctx_tool), ctx_tool], deps_type=int)
def test_tool_return_conflict():
    """A function tool may not share its name with the result tool."""
    # No result type: nothing to clash with.
    Agent('test', tools=[ctx_tool], deps_type=int)
    # Result type set, `result_tool_name` left alone: accepted.
    Agent('test', tools=[ctx_tool], deps_type=int, result_type=int)
    # Result tool explicitly named after the function tool: rejected.
    conflict = "Tool name conflicts with result schema name: 'ctx_tool'"
    with pytest.raises(UserError, match=conflict):
        Agent('test', tools=[ctx_tool], deps_type=int, result_type=int, result_tool_name='ctx_tool')
def test_init_ctx_tool_invalid():
    """`Tool(..., takes_ctx=True)` requires the first parameter to be annotated with `RunContext`."""

    def plain_tool(x: int) -> int:  # pragma: no cover
        return x + 1

    # `match` is a regex, hence the escaped brackets and dots.
    expected = r'First parameter of tools that take context must be annotated with RunContext\[\.\.\.\]'
    with pytest.raises(UserError, match=expected):
        Tool(plain_tool, takes_ctx=True)
def test_init_plain_tool_invalid():
    """`Tool(..., takes_ctx=False)` rejects a function annotated with `RunContext`."""
    expected = 'RunContext annotations can only be used with tools that take context'
    with pytest.raises(UserError, match=expected):
        Tool(ctx_tool, takes_ctx=False)
def test_return_pydantic_model():
    """A pydantic model returned by a tool is serialised as a JSON object."""
    agent = Agent('test')

    @agent.tool_plain
    def return_pydantic_model(x: int) -> Foo:
        return Foo(x=x, y='a')

    output = agent.run_sync('').data
    assert output == snapshot('{"return_pydantic_model":{"x":0,"y":"a"}}')
def test_return_bytes():
    """UTF-8 `bytes` returned by a tool come back as text in the serialised result."""
    agent = Agent('test')

    # Named for what it returns; it was previously a copy-paste of `return_pydantic_model`.
    @agent.tool_plain
    def return_bytes() -> bytes:
        return '🐈 Hello'.encode()

    result = agent.run_sync('')
    assert result.data == snapshot('{"return_bytes":"🐈 Hello"}')
def test_return_bytes_invalid():
    """Non-UTF-8 `bytes` returned by a tool fail serialisation with `PydanticSerializationError`."""
    agent = Agent('test')

    # Named for what it returns; it was previously a copy-paste of `return_pydantic_model`.
    @agent.tool_plain
    def return_invalid_bytes() -> bytes:
        # Bytes 0x00 0x20 0x81: 0x81 at index 2 is a lone continuation byte, i.e. invalid UTF-8.
        return b'\00 \x81'

    with pytest.raises(PydanticSerializationError, match='invalid utf-8 sequence of 1 bytes from index 2'):
        agent.run_sync('')
def test_return_unknown():
    """A tool returning an object pydantic cannot serialise fails with `PydanticSerializationError`."""
    agent = Agent('test')

    class Foobar:
        pass

    # Named for what it returns; it was previously a copy-paste of `return_pydantic_model`.
    @agent.tool_plain
    def return_unknown() -> Foobar:
        return Foobar()

    with pytest.raises(PydanticSerializationError, match='Unable to serialize unknown type:'):
        agent.run_sync('')
def test_dynamic_cls_tool():
    """A `Tool` subclass can override `prepare_tool_def` to withhold itself based on the run's deps."""

    @dataclass
    class MyTool(Tool[int]):
        spam: int

        def __init__(self, spam: int = 0, **kwargs: Any):
            self.spam = spam
            kwargs.update(function=self.tool_function, takes_ctx=False)
            super().__init__(**kwargs)

        def tool_function(self, x: int, y: str) -> str:
            return f'{self.spam} {x} {y}'

        async def prepare_tool_def(self, ctx: RunContext[int]) -> Union[ToolDefinition, None]:
            # deps == 42 hides the tool, so the model then makes no tool calls.
            if ctx.deps == 42:
                return None
            return await super().prepare_tool_def(ctx)

    agent = Agent('test', tools=[MyTool(spam=777)], deps_type=int)
    assert agent.run_sync('', deps=1).data == snapshot('{"tool_function":"777 0 a"}')
    assert agent.run_sync('', deps=42).data == snapshot('success (no tool calls)')
def test_dynamic_plain_tool_decorator():
    """`tool_plain(prepare=...)` can withhold a plain tool depending on the run's deps."""
    agent = Agent('test', deps_type=int)

    async def prepare_tool_def(ctx: RunContext[int], tool_def: ToolDefinition) -> Union[ToolDefinition, None]:
        # Returning None hides the tool for this run.
        return None if ctx.deps == 42 else tool_def

    @agent.tool_plain(prepare=prepare_tool_def)
    def foobar(x: int, y: str) -> str:
        return f'{x} {y}'

    assert agent.run_sync('', deps=1).data == snapshot('{"foobar":"0 a"}')
    assert agent.run_sync('', deps=42).data == snapshot('success (no tool calls)')
def test_dynamic_tool_decorator():
    """`tool(prepare=...)` can withhold a context-taking tool depending on the run's deps."""
    agent = Agent('test', deps_type=int)

    async def prepare_tool_def(ctx: RunContext[int], tool_def: ToolDefinition) -> Union[ToolDefinition, None]:
        # Returning None hides the tool for this run.
        return None if ctx.deps == 42 else tool_def

    @agent.tool(prepare=prepare_tool_def)
    def foobar(ctx: RunContext[int], x: int, y: str) -> str:
        return f'{ctx.deps} {x} {y}'

    assert agent.run_sync('', deps=1).data == snapshot('{"foobar":"1 0 a"}')
    assert agent.run_sync('', deps=42).data == snapshot('success (no tool calls)')
def test_plain_tool_name():
    """`tool_plain(name=...)` replaces the function's name in the tool definition."""
    agent = Agent(FunctionModel(get_json_schema))

    # Never called; excluded from coverage like the other stub tools in this module.
    def my_tool(arg: str) -> str: ...  # pragma: no cover

    agent.tool_plain(name='foo_tool')(my_tool)
    result = agent.run_sync('Hello')
    json_schema = json.loads(result.data)
    assert json_schema['name'] == 'foo_tool'
def test_tool_name():
    """`tool(name=...)` replaces the function's name in the tool definition."""
    agent = Agent(FunctionModel(get_json_schema))

    # Never called; excluded from coverage like the other stub tools in this module.
    def my_tool(ctx: RunContext[None], arg: str) -> str: ...  # pragma: no cover

    agent.tool(name='foo_tool')(my_tool)
    result = agent.run_sync('Hello')
    json_schema = json.loads(result.data)
    assert json_schema['name'] == 'foo_tool'
def test_dynamic_tool_use_messages():
    """`prepare` can inspect `ctx.messages`; here it withdraws the tool once the history grows."""

    async def repeat_call_foobar(_messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # Keep calling the first offered tool; once no tool is offered, finish with text.
        if not info.function_tools:
            return ModelResponse(parts=[TextPart('done')])
        first_tool = info.function_tools[0]
        return ModelResponse(parts=[ToolCallPart(first_tool.name, {'x': 42, 'y': 'a'})])

    agent = Agent(FunctionModel(repeat_call_foobar), deps_type=int)

    async def prepare_tool_def(ctx: RunContext[int], tool_def: ToolDefinition) -> Union[ToolDefinition, None]:
        # Offer the tool only while the history holds fewer than 5 messages.
        return tool_def if len(ctx.messages) < 5 else None

    @agent.tool(prepare=prepare_tool_def)
    def foobar(ctx: RunContext[int], x: int, y: str) -> str:
        return f'{ctx.deps} {x} {y}'

    r = agent.run_sync('', deps=1)
    assert r.data == snapshot('done')
    kinds = [(message.kind, [part.part_kind for part in message.parts]) for message in r.all_messages()]
    assert kinds == snapshot(
        [
            ('request', ['user-prompt']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
            ('response', ['text']),
        ]
    )
def test_future_run_context(create_module: Callable[[str], Any]):
    """`RunContext` is still detected when annotations are postponed via `from __future__ import annotations`."""
    source = """
from __future__ import annotations

from pydantic_ai import Agent, RunContext

def ctx_tool(ctx: RunContext[int], x: int) -> int:
    return x + ctx.deps

agent = Agent('test', tools=[ctx_tool], deps_type=int)
    """
    module = create_module(source)
    assert module.agent.run_sync('foobar', deps=5).data == snapshot('{"ctx_tool":5}')
# Tool fixture for `test_suppress_griffe_logging`: its docstring mentions what it returns
# but has no returns section, which the test notes would make griffe emit a warning log
# unless griffe logging is suppressed. The docstring is parser input.
async def tool_without_return_annotation_in_docstring() -> str:  # pragma: no cover
    """A tool that documents what it returns but doesn't have a return annotation in the docstring."""

    return ''
def test_suppress_griffe_logging(caplog: LogCaptureFixture):
    """Parsing a docstring that would make griffe warn must not leave any log records behind."""
    # This would cause griffe to emit a warning log if we didn't suppress the griffe logging.
    agent = Agent(FunctionModel(get_json_schema))
    agent.tool_plain(tool_without_return_annotation_in_docstring)

    tool_def = json.loads(agent.run_sync('').data)
    assert tool_def == snapshot(
        {
            'description': "A tool that documents what it returns but doesn't have a return annotation in the docstring.",
            'name': 'tool_without_return_annotation_in_docstring',
            'outer_typed_dict_key': None,
            'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'},
        }
    )

    # Without suppressing griffe logging, we get:
    # assert caplog.messages == snapshot(['<module>:4: No type or annotation for returned value 1'])
    assert caplog.messages == snapshot([])
# Tool fixture for `test_enforce_parameter_descriptions`: a summary-only docstring, so
# neither `foo` nor `bar` gets a description. The docstring is parser input.
async def missing_parameter_descriptions_docstring(foo: int, bar: str) -> str:  # pragma: no cover
    """Describes function ops, but missing parameter descriptions."""
    return f'{foo} {bar}'
def test_enforce_parameter_descriptions() -> None:
    """`require_parameter_descriptions=True` rejects a tool whose parameters lack descriptions."""
    agent = Agent(FunctionModel(get_json_schema))

    with pytest.raises(UserError) as exc_info:
        agent.tool_plain(require_parameter_descriptions=True)(missing_parameter_descriptions_docstring)

    error_reason = exc_info.value.args[0]
    expected_parts = [
        'Error generating schema for missing_parameter_descriptions_docstring',
        'Missing parameter descriptions for ',
        'foo',
        'bar',
    ]
    # Assert each fragment on its own so a failure names the missing fragment and shows the
    # actual message, instead of an opaque `assert all(...)` failure.
    for part in expected_parts:
        assert part in error_reason, f'{part!r} not found in error message {error_reason!r}'
def test_json_schema_required_parameters(set_event_loop: None):
    """Only parameters without defaults appear in `required`, for context and plain tools alike."""
    agent = Agent(FunctionModel(get_json_schema))

    # No docstrings on these tools on purpose: the expected descriptions are ''.
    @agent.tool
    def my_tool(ctx: RunContext[None], a: int, b: int = 1) -> int:
        raise NotImplementedError

    @agent.tool_plain
    def my_tool_plain(*, a: int = 1, b: int) -> int:
        raise NotImplementedError

    tool_defs = json.loads(agent.run_sync('Hello').data)
    assert tool_defs == snapshot(
        [
            {
                'description': '',
                'name': 'my_tool',
                'outer_typed_dict_key': None,
                'parameters_json_schema': {
                    'additionalProperties': False,
                    'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}},
                    'required': ['a'],
                    'type': 'object',
                },
            },
            {
                'description': '',
                'name': 'my_tool_plain',
                'outer_typed_dict_key': None,
                'parameters_json_schema': {
                    'additionalProperties': False,
                    'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}},
                    'required': ['b'],
                    'type': 'object',
                },
            },
        ]
    )
def test_call_tool_without_unrequired_parameters(set_event_loop: None):
    """Arguments the model omits fall back to the tool function's Python defaults."""

    async def call_tools_first(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # After the first turn, just finish.
        if len(messages) != 1:
            return ModelResponse(parts=[TextPart('finished')])
        # First turn: call each tool once without and once with its optional argument.
        return ModelResponse(
            parts=[
                ToolCallPart(tool_name='my_tool', args={'a': 13}),
                ToolCallPart(tool_name='my_tool', args={'a': 13, 'b': 4}),
                ToolCallPart(tool_name='my_tool_plain', args={'b': 17}),
                ToolCallPart(tool_name='my_tool_plain', args={'a': 4, 'b': 17}),
            ]
        )

    agent = Agent(FunctionModel(call_tools_first))

    @agent.tool
    def my_tool(ctx: RunContext[None], a: int, b: int = 2) -> int:
        return a + b

    @agent.tool_plain
    def my_tool_plain(*, a: int = 3, b: int) -> int:
        return a * b

    history = agent.run_sync('Hello').all_messages()
    first_response, second_request = history[1], history[2]
    assert isinstance(first_response, ModelResponse)
    assert isinstance(second_request, ModelRequest)
    call_args = [part.args for part in first_response.parts if isinstance(part, ToolCallPart)]
    returns = [part.content for part in second_request.parts if isinstance(part, ToolReturnPart)]
    assert call_args == snapshot(
        [
            {'a': 13},
            {'a': 13, 'b': 4},
            {'b': 17},
            {'a': 4, 'b': 17},
        ]
    )
    # 13+2, 13+4, 3*17, 4*17: the omitted arguments took their defaults.
    assert returns == snapshot([15, 17, 51, 68])
def test_schema_generator():
    """A custom `schema_generator` passed to `tool_plain` is used for that tool's parameter schema."""

    class MyGenerateJsonSchema(GenerateJsonSchema):
        def typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> JsonSchemaValue:
            # Add useless property titles just to show we can
            generated = super().typed_dict_schema(schema)
            # Mutate each property schema in place rather than re-indexing by key;
            # a property without a title ends up titled 'None title'.
            for prop_schema in generated.get('properties', {}).values():
                prop_schema['title'] = f'{prop_schema.get("title")} title'
            return generated

    agent = Agent(FunctionModel(get_json_schema))

    def my_tool(x: Annotated[Union[str, None], WithJsonSchema({'type': 'string'})] = None, **kwargs: Any):
        return x  # pragma: no cover

    agent.tool_plain(name='my_tool_1')(my_tool)
    agent.tool_plain(name='my_tool_2', schema_generator=MyGenerateJsonSchema)(my_tool)

    result = agent.run_sync('Hello')
    json_schema = json.loads(result.data)
    assert json_schema == snapshot(
        [
            {
                'description': '',
                'name': 'my_tool_1',
                'outer_typed_dict_key': None,
                'parameters_json_schema': {
                    'additionalProperties': True,
                    'properties': {'x': {'type': 'string'}},
                    'type': 'object',
                },
            },
            {
                'description': '',
                'name': 'my_tool_2',
                'outer_typed_dict_key': None,
                'parameters_json_schema': {
                    'properties': {'x': {'type': 'string', 'title': 'X title'}},
                    'type': 'object',
                },
            },
        ]
    )