Coverage for tests/test_tools.py: 99.42%

324 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1import json 

2from dataclasses import dataclass 

3from typing import Annotated, Any, Callable, Literal, Union 

4 

5import pydantic_core 

6import pytest 

7from _pytest.logging import LogCaptureFixture 

8from inline_snapshot import snapshot 

9from pydantic import BaseModel, Field, WithJsonSchema 

10from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue 

11from pydantic_core import PydanticSerializationError, core_schema 

12 

13from pydantic_ai import Agent, RunContext, Tool, UserError 

14from pydantic_ai.messages import ( 

15 ModelMessage, 

16 ModelRequest, 

17 ModelResponse, 

18 TextPart, 

19 ToolCallPart, 

20 ToolReturnPart, 

21) 

22from pydantic_ai.models.function import AgentInfo, FunctionModel 

23from pydantic_ai.models.test import TestModel 

24from pydantic_ai.tools import ToolDefinition 

25 

26 

def test_tool_no_ctx():
    """`agent.tool` rejects a function whose first parameter is not a RunContext."""
    ctx_agent = Agent(TestModel())

    with pytest.raises(UserError) as raised:

        @ctx_agent.tool  # pyright: ignore[reportArgumentType]
        def invalid_tool(x: int) -> str:  # pragma: no cover
            return 'Hello'

    message = str(raised.value)
    assert message == snapshot(
        'Error generating schema for test_tool_no_ctx.<locals>.invalid_tool:\n'
        '  First parameter of tools that take context must be annotated with RunContext[...]'
    )

40 

41 

def test_tool_plain_with_ctx():
    """`agent.tool_plain` rejects a function that declares a RunContext parameter."""
    plain_agent = Agent(TestModel())

    with pytest.raises(UserError) as raised:

        @plain_agent.tool_plain
        async def invalid_tool(ctx: RunContext[None]) -> str:  # pragma: no cover
            return 'Hello'

    message = str(raised.value)
    assert message == snapshot(
        'Error generating schema for test_tool_plain_with_ctx.<locals>.invalid_tool:\n'
        '  RunContext annotations can only be used with tools that take context'
    )

55 

56 

def test_tool_ctx_second():
    """A RunContext anywhere but the first position yields both schema errors."""
    ctx_agent = Agent(TestModel())

    with pytest.raises(UserError) as raised:

        @ctx_agent.tool  # pyright: ignore[reportArgumentType]
        def invalid_tool(x: int, ctx: RunContext[None]) -> str:  # pragma: no cover
            return 'Hello'

    message = str(raised.value)
    assert message == snapshot(
        'Error generating schema for test_tool_ctx_second.<locals>.invalid_tool:\n'
        '  First parameter of tools that take context must be annotated with RunContext[...]\n'
        '  RunContext annotations can only be used as the first argument'
    )

71 

72 

# Fixture tool: its Google-style docstring is the test input. test_docstring_google asserts
# the summary line and both `Args:` descriptions, so the docstring must stay exactly as written.
async def google_style_docstring(foo: int, bar: str) -> str:  # pragma: no cover
    """Do foobar stuff, a lot.

    Args:
        foo: The foo thing.
        bar: The bar thing.
    """
    return f'{foo} {bar}'

81 

82 

async def get_json_schema(_messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    """FunctionModel handler that replies with the JSON of the tool definitions it received.

    A lone tool definition is serialized by itself; otherwise the whole list is serialized.
    """
    tools = info.function_tools
    payload = tools[0] if len(tools) == 1 else tools
    return ModelResponse(parts=[TextPart(pydantic_core.to_json(payload).decode())])

89 

90 

@pytest.mark.parametrize('docstring_format', ['google', 'auto'])
def test_docstring_google(docstring_format: Literal['google', 'auto']):
    """Google-style docstrings provide the tool description and per-parameter descriptions."""
    agent = Agent(FunctionModel(get_json_schema))
    register = agent.tool_plain(docstring_format=docstring_format)
    register(google_style_docstring)

    schema = json.loads(agent.run_sync('Hello').data)
    assert schema == snapshot(
        {
            'name': 'google_style_docstring',
            'description': 'Do foobar stuff, a lot.',
            'parameters_json_schema': {
                'properties': {
                    'foo': {'description': 'The foo thing.', 'type': 'integer'},
                    'bar': {'description': 'The bar thing.', 'type': 'string'},
                },
                'required': ['foo', 'bar'],
                'type': 'object',
                'additionalProperties': False,
            },
            'outer_typed_dict_key': None,
        }
    )
    # The serialized definition must lead with `name`, then `description`.
    leading_keys = list(schema)[:2]
    assert leading_keys == ['name', 'description']

119 

120 

# Fixture tool with a Sphinx/reST docstring (the docstring is the test input). The `/` makes
# `foo` positional-only; test_docstring_sphinx still expects it as a named schema property.
def sphinx_style_docstring(foo: int, /) -> str:  # pragma: no cover
    """Sphinx style docstring.

    :param foo: The foo thing.
    """
    return str(foo)

127 

128 

@pytest.mark.parametrize('docstring_format', ['sphinx', 'auto'])
def test_docstring_sphinx(docstring_format: Literal['sphinx', 'auto']):
    """Sphinx `:param:` fields become parameter descriptions, explicitly or via auto-detection."""
    agent = Agent(FunctionModel(get_json_schema))
    register = agent.tool_plain(docstring_format=docstring_format)
    register(sphinx_style_docstring)

    schema = json.loads(agent.run_sync('Hello').data)
    assert schema == snapshot(
        {
            'name': 'sphinx_style_docstring',
            'description': 'Sphinx style docstring.',
            'parameters_json_schema': {
                'properties': {'foo': {'description': 'The foo thing.', 'type': 'integer'}},
                'required': ['foo'],
                'type': 'object',
                'additionalProperties': False,
            },
            'outer_typed_dict_key': None,
        }
    )

149 

150 

# Fixture tool with a NumPy-style docstring (the docstring is the test input). Both
# parameters are keyword-only; test_docstring_numpy expects both as required properties.
def numpy_style_docstring(*, foo: int, bar: str) -> str:  # pragma: no cover
    """Numpy style docstring.

    Parameters
    ----------
    foo : int
        The foo thing.
    bar : str
        The bar thing.
    """
    return f'{foo} {bar}'

162 

163 

@pytest.mark.parametrize('docstring_format', ['numpy', 'auto'])
def test_docstring_numpy(docstring_format: Literal['numpy', 'auto']):
    """NumPy `Parameters` sections become parameter descriptions."""
    agent = Agent(FunctionModel(get_json_schema))
    register = agent.tool_plain(docstring_format=docstring_format)
    register(numpy_style_docstring)

    schema = json.loads(agent.run_sync('Hello').data)
    assert schema == snapshot(
        {
            'name': 'numpy_style_docstring',
            'description': 'Numpy style docstring.',
            'parameters_json_schema': {
                'properties': {
                    'foo': {'description': 'The foo thing.', 'type': 'integer'},
                    'bar': {'description': 'The bar thing.', 'type': 'string'},
                },
                'required': ['foo', 'bar'],
                'type': 'object',
                'additionalProperties': False,
            },
            'outer_typed_dict_key': None,
        }
    )

187 

188 

def test_google_style_with_returns():
    """A Google `Returns:` section is folded into the description as tagged text."""
    agent = Agent(FunctionModel(get_json_schema))

    def my_tool(x: int) -> str:  # pragma: no cover
        """A function that does something.

        Args:
            x: The input value.

        Returns:
            str: The result as a string.
        """
        return str(x)

    agent.tool_plain(my_tool)
    schema = json.loads(agent.run_sync('Hello').data)
    assert schema == snapshot(
        {
            'name': 'my_tool',
            'description': """\
<summary>A function that does something.</summary>
<returns>
<type>str</type>
<description>The result as a string.</description>
</returns>\
""",
            'parameters_json_schema': {
                'additionalProperties': False,
                'properties': {'x': {'description': 'The input value.', 'type': 'integer'}},
                'required': ['x'],
                'type': 'object',
            },
            'outer_typed_dict_key': None,
        }
    )

225 

226 

def test_sphinx_style_with_returns():
    """Sphinx `:rtype:` / `:return:` fields are folded into the description as tagged text."""
    agent = Agent(FunctionModel(get_json_schema))

    def my_tool(x: int) -> str:  # pragma: no cover
        """A sphinx function with returns.

        :param x: The input value.
        :rtype: str
        :return: The result as a string with type.
        """
        return str(x)

    register = agent.tool_plain(docstring_format='sphinx')
    register(my_tool)
    schema = json.loads(agent.run_sync('Hello').data)
    assert schema == snapshot(
        {
            'name': 'my_tool',
            'description': """\
<summary>A sphinx function with returns.</summary>
<returns>
<type>str</type>
<description>The result as a string with type.</description>
</returns>\
""",
            'parameters_json_schema': {
                'additionalProperties': False,
                'properties': {'x': {'description': 'The input value.', 'type': 'integer'}},
                'required': ['x'],
                'type': 'object',
            },
            'outer_typed_dict_key': None,
        }
    )

261 

262 

def test_numpy_style_with_returns():
    """A NumPy `Returns` section is folded into the description as tagged text."""
    agent = Agent(FunctionModel(get_json_schema))

    def my_tool(x: int) -> str:  # pragma: no cover
        """A numpy function with returns.

        Parameters
        ----------
        x : int
            The input value.

        Returns
        -------
        str
            The result as a string with type.
        """
        return str(x)

    register = agent.tool_plain(docstring_format='numpy')
    register(my_tool)
    schema = json.loads(agent.run_sync('Hello').data)
    assert schema == snapshot(
        {
            'name': 'my_tool',
            'description': """\
<summary>A numpy function with returns.</summary>
<returns>
<type>str</type>
<description>The result as a string with type.</description>
</returns>\
""",
            'parameters_json_schema': {
                'additionalProperties': False,
                'properties': {'x': {'description': 'The input value.', 'type': 'integer'}},
                'required': ['x'],
                'type': 'object',
            },
            'outer_typed_dict_key': None,
        }
    )

303 

304 

# Fixture tool whose docstring has no summary, only a `Returns:` section. The docstring is
# the test input; test_only_returns_type expects the description to hold just the returns tags.
def only_returns_type() -> str:  # pragma: no cover
    """

    Returns:
        str: The result as a string.
    """
    return 'foo'

312 

313 

def test_only_returns_type():
    """A docstring with only a `Returns:` section yields a returns-only description."""
    agent = Agent(FunctionModel(get_json_schema))
    agent.tool_plain(only_returns_type)

    schema = json.loads(agent.run_sync('Hello').data)
    assert schema == snapshot(
        {
            'name': 'only_returns_type',
            'description': """\
<returns>
<type>str</type>
<description>The result as a string.</description>
</returns>\
""",
            'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'},
            'outer_typed_dict_key': None,
        }
    )

333 

334 

# Fixture tool: a one-line docstring with no recognisable parameter section, plus `**kwargs`.
# test_docstring_unknown expects an empty `properties` map with `additionalProperties: True`.
def unknown_docstring(**kwargs: int) -> str:  # pragma: no cover
    """Unknown style docstring."""
    return str(kwargs)

338 

339 

def test_docstring_unknown():
    """A docstring without parameter sections still supplies the summary as the description."""
    agent = Agent(FunctionModel(get_json_schema))
    agent.tool_plain(unknown_docstring)

    schema = json.loads(agent.run_sync('Hello').data)
    assert schema == snapshot(
        {
            'name': 'unknown_docstring',
            'description': 'Unknown style docstring.',
            'parameters_json_schema': {'properties': {}, 'type': 'object', 'additionalProperties': True},
            'outer_typed_dict_key': None,
        }
    )

354 

355 

# fmt: off
# Fixture tool: a Google-style docstring with no summary line (the docstring is the test input).
# `bar` also carries a Field description; test_docstring_google_no_body expects 'from fields'
# rather than the docstring's "The bar thing.", and an empty tool description.
async def google_style_docstring_no_body(
    foo: int, bar: Annotated[str, Field(description='from fields')]
) -> str:  # pragma: no cover
    """
    Args:
        foo: The foo thing.
        bar: The bar thing.
    """

    return f'{foo} {bar}'
# fmt: on

368 

369 

@pytest.mark.parametrize('docstring_format', ['google', 'auto'])
def test_docstring_google_no_body(docstring_format: Literal['google', 'auto']):
    """A summary-less docstring gives an empty description; Field descriptions take precedence."""
    agent = Agent(FunctionModel(get_json_schema))
    register = agent.tool_plain(docstring_format=docstring_format)
    register(google_style_docstring_no_body)

    schema = json.loads(agent.run_sync('').data)
    assert schema == snapshot(
        {
            'name': 'google_style_docstring_no_body',
            'description': '',
            'parameters_json_schema': {
                'properties': {
                    'foo': {'description': 'The foo thing.', 'type': 'integer'},
                    'bar': {'description': 'from fields', 'type': 'string'},
                },
                'required': ['foo', 'bar'],
                'type': 'object',
                'additionalProperties': False,
            },
            'outer_typed_dict_key': None,
        }
    )

393 

394 

# Small pydantic model used as a tool parameter type and as a tool return type in several tests.
# NOTE(review): pydantic normally emits a model's class docstring as the JSON-schema
# `description`. The schema snapshots for this model show only `title: 'Foo'`, so document
# it with comments rather than a docstring.
class Foo(BaseModel):
    x: int
    y: str

398 

399 

def test_takes_just_model():
    """A tool whose only parameter is a BaseModel exposes that model's schema directly."""
    agent = Agent()

    @agent.tool_plain
    def takes_just_model(model: Foo) -> str:
        return f'{model.x} {model.y}'

    schema_run = agent.run_sync('', model=FunctionModel(get_json_schema))
    assert json.loads(schema_run.data) == snapshot(
        {
            'name': 'takes_just_model',
            'description': None,
            'parameters_json_schema': {
                'properties': {
                    'x': {'type': 'integer'},
                    'y': {'type': 'string'},
                },
                'required': ['x', 'y'],
                'title': 'Foo',
                'type': 'object',
            },
            'outer_typed_dict_key': None,
        }
    )

    call_run = agent.run_sync('', model=TestModel())
    assert call_run.data == snapshot('{"takes_just_model":"0 a"}')

428 

429 

def test_takes_model_and_int():
    """A BaseModel parameter next to a scalar is nested via `$defs` instead of being inlined."""
    agent = Agent()

    # Named after what it takes; the copy-pasted name `takes_just_model` was misleading here.
    @agent.tool_plain
    def takes_model_and_int(model: Foo, z: int) -> str:
        return f'{model.x} {model.y} {z}'

    result = agent.run_sync('', model=FunctionModel(get_json_schema))
    json_schema = json.loads(result.data)
    assert json_schema == snapshot(
        {
            'name': 'takes_model_and_int',
            'description': '',
            'parameters_json_schema': {
                '$defs': {
                    'Foo': {
                        'properties': {
                            'x': {'type': 'integer'},
                            'y': {'type': 'string'},
                        },
                        'required': ['x', 'y'],
                        'title': 'Foo',
                        'type': 'object',
                    }
                },
                'properties': {
                    'model': {'$ref': '#/$defs/Foo'},
                    'z': {'type': 'integer'},
                },
                'required': ['model', 'z'],
                'type': 'object',
                'additionalProperties': False,
            },
            'outer_typed_dict_key': None,
        }
    )

    result = agent.run_sync('', model=TestModel())
    assert result.data == snapshot('{"takes_model_and_int":"0 a 0"}')

469 

470 

# pyright: reportPrivateUsage=false
def test_init_tool_plain():
    """Plain functions register as non-context tools, whether wrapped in `Tool` or passed bare."""
    call_args: list[int] = []

    def plain_tool(x: int) -> int:
        call_args.append(x)
        return x + 1

    # First pass: an explicit Tool wrapper. Second pass: the bare function.
    for runs_so_far, tool_spec in enumerate([Tool(plain_tool), plain_tool], start=1):
        agent = Agent('test', tools=[tool_spec], retries=7)
        assert agent.run_sync('foobar').data == '{"plain_tool":1}'
        # call_args is shared, so it grows by one call per agent run.
        assert call_args == [0] * runs_so_far
        registered = agent._function_tools['plain_tool']
        assert registered.takes_ctx is False
        assert registered.max_retries == 7

def ctx_tool(ctx: RunContext[int], x: int) -> int:
    # Shared context-taking tool: offsets the model-supplied `x` by the run's int dependency.
    offset = ctx.deps
    return x + offset

496 

497 

# pyright: reportPrivateUsage=false
def test_init_tool_ctx():
    """Context-taking tools register via `Tool(..., takes_ctx=True)` or as a bare function."""
    explicit_agent = Agent(
        'test', tools=[Tool(ctx_tool, takes_ctx=True, max_retries=3)], deps_type=int, retries=7
    )
    assert explicit_agent.run_sync('foobar', deps=5).data == snapshot('{"ctx_tool":5}')
    explicit_tool = explicit_agent._function_tools['ctx_tool']
    assert explicit_tool.takes_ctx is True
    # The per-tool max_retries=3 is kept rather than the agent-wide retries=7.
    assert explicit_tool.max_retries == 3

    inferred_agent = Agent('test', tools=[ctx_tool], deps_type=int)
    assert inferred_agent.run_sync('foobar', deps=6).data == snapshot('{"ctx_tool":6}')
    assert inferred_agent._function_tools['ctx_tool'].takes_ctx is True

511 

def test_repeat_tool():
    """Registering two tools with the same name is rejected at agent construction."""
    expected = "Tool name conflicts with existing tool: 'ctx_tool'"
    with pytest.raises(UserError, match=expected):
        Agent('test', tools=[Tool(ctx_tool), ctx_tool], deps_type=int)

515 

516 

def test_tool_return_conflict():
    """A function tool may not share its name with the result tool."""
    # Neither the default result setup nor a non-str result_type clashes with `ctx_tool`.
    for extra in ({}, {'result_type': int}):
        Agent('test', tools=[ctx_tool], deps_type=int, **extra)

    # Explicitly naming the result tool after an existing tool is an error.
    with pytest.raises(UserError, match="Tool name conflicts with result schema name: 'ctx_tool'"):
        Agent('test', tools=[ctx_tool], deps_type=int, result_type=int, result_tool_name='ctx_tool')

525 

526 

def test_init_ctx_tool_invalid():
    """`Tool(..., takes_ctx=True)` requires a RunContext-annotated first parameter."""

    def plain_tool(x: int) -> int:  # pragma: no cover
        return x + 1

    expected_pattern = r'First parameter of tools that take context must be annotated with RunContext\[\.\.\.\]'
    with pytest.raises(UserError, match=expected_pattern):
        Tool(plain_tool, takes_ctx=True)

534 

535 

def test_init_plain_tool_invalid():
    """`Tool(..., takes_ctx=False)` rejects a function that declares a RunContext."""
    expected = 'RunContext annotations can only be used with tools that take context'
    with pytest.raises(UserError, match=expected):
        Tool(ctx_tool, takes_ctx=False)

539 

540 

def test_return_pydantic_model():
    """A BaseModel returned by a tool is serialized as a JSON object in the tool result."""
    agent = Agent('test')

    @agent.tool_plain
    def return_pydantic_model(x: int) -> Foo:
        return Foo(x=x, y='a')

    assert agent.run_sync('').data == snapshot('{"return_pydantic_model":{"x":0,"y":"a"}}')

550 

551 

def test_return_bytes():
    """UTF-8 `bytes` returned by a tool appear as text in the tool result."""
    agent = Agent('test')

    # Renamed from the copy-pasted `return_pydantic_model`; this tool returns bytes.
    @agent.tool_plain
    def return_bytes() -> bytes:
        return '🐈 Hello'.encode()

    result = agent.run_sync('')
    assert result.data == snapshot('{"return_bytes":"🐈 Hello"}')

561 

562 

def test_return_bytes_invalid():
    """Non-UTF-8 `bytes` returned by a tool raise PydanticSerializationError."""
    agent = Agent('test')

    # Renamed from the copy-pasted `return_pydantic_model`; this tool returns invalid bytes.
    @agent.tool_plain
    def return_invalid_bytes() -> bytes:
        # NUL, space, then 0x81 (a lone continuation byte): invalid UTF-8 at index 2.
        return b'\x00 \x81'

    with pytest.raises(PydanticSerializationError, match='invalid utf-8 sequence of 1 bytes from index 2'):
        agent.run_sync('')

572 

573 

def test_return_unknown():
    """Returning an object pydantic cannot serialize raises PydanticSerializationError."""
    agent = Agent('test')

    class Foobar:
        pass

    # Renamed from the copy-pasted `return_pydantic_model`; this returns an arbitrary object.
    @agent.tool_plain
    def return_unknown() -> Foobar:
        return Foobar()

    with pytest.raises(PydanticSerializationError, match='Unable to serialize unknown type:'):
        agent.run_sync('')

586 

587 

def test_dynamic_cls_tool():
    """A `Tool` subclass can override `prepare_tool_def` to withdraw itself for some runs."""

    @dataclass
    class MyTool(Tool[int]):
        spam: int

        def __init__(self, spam: int = 0, **kwargs: Any):
            self.spam = spam
            # Force the bound method as the tool function and mark it as context-free.
            super().__init__(**{**kwargs, 'function': self.tool_function, 'takes_ctx': False})

        def tool_function(self, x: int, y: str) -> str:
            return f'{self.spam} {x} {y}'

        async def prepare_tool_def(self, ctx: RunContext[int]) -> Union[ToolDefinition, None]:
            if ctx.deps == 42:
                # Returning None hides the tool for this run.
                return None
            return await super().prepare_tool_def(ctx)

    agent = Agent('test', tools=[MyTool(spam=777)], deps_type=int)

    visible_run = agent.run_sync('', deps=1)
    assert visible_run.data == snapshot('{"tool_function":"777 0 a"}')

    hidden_run = agent.run_sync('', deps=42)
    assert hidden_run.data == snapshot('success (no tool calls)')

611 

612 

def test_dynamic_plain_tool_decorator():
    """`tool_plain(prepare=...)` can hide a plain tool depending on the run's deps."""
    agent = Agent('test', deps_type=int)

    async def prepare_tool_def(ctx: RunContext[int], tool_def: ToolDefinition) -> Union[ToolDefinition, None]:
        # Returning None withdraws the tool when deps == 42.
        return tool_def if ctx.deps != 42 else None

    @agent.tool_plain(prepare=prepare_tool_def)
    def foobar(x: int, y: str) -> str:
        return f'{x} {y}'

    visible_run = agent.run_sync('', deps=1)
    assert visible_run.data == snapshot('{"foobar":"0 a"}')

    hidden_run = agent.run_sync('', deps=42)
    assert hidden_run.data == snapshot('success (no tool calls)')

629 

630 

def test_dynamic_tool_decorator():
    """`tool(prepare=...)` can hide a context-taking tool depending on the run's deps."""
    agent = Agent('test', deps_type=int)

    async def prepare_tool_def(ctx: RunContext[int], tool_def: ToolDefinition) -> Union[ToolDefinition, None]:
        # Returning None withdraws the tool when deps == 42.
        return tool_def if ctx.deps != 42 else None

    @agent.tool(prepare=prepare_tool_def)
    def foobar(ctx: RunContext[int], x: int, y: str) -> str:
        return f'{ctx.deps} {x} {y}'

    visible_run = agent.run_sync('', deps=1)
    assert visible_run.data == snapshot('{"foobar":"1 0 a"}')

    hidden_run = agent.run_sync('', deps=42)
    assert hidden_run.data == snapshot('success (no tool calls)')

647 

648 

def test_plain_tool_name():
    """`tool_plain(name=...)` overrides the name derived from the function."""
    agent = Agent(FunctionModel(get_json_schema))

    # The stub is never called; without the pragma, coverage reports a partial branch here.
    def my_tool(arg: str) -> str: ...  # pragma: no cover

    agent.tool_plain(name='foo_tool')(my_tool)
    result = agent.run_sync('Hello')
    json_schema = json.loads(result.data)
    assert json_schema['name'] == 'foo_tool'

658 

659 

def test_tool_name():
    """`tool(name=...)` overrides the name derived from the function."""
    agent = Agent(FunctionModel(get_json_schema))

    # The stub is never called; without the pragma, coverage reports a partial branch here.
    # RunContext[None] matches this agent's deps and the other context-tool tests.
    def my_tool(ctx: RunContext[None], arg: str) -> str: ...  # pragma: no cover

    agent.tool(name='foo_tool')(my_tool)
    result = agent.run_sync('Hello')
    json_schema = json.loads(result.data)
    assert json_schema['name'] == 'foo_tool'

669 

670 

def test_dynamic_tool_use_messages():
    """`prepare` can inspect `ctx.messages` to withdraw a tool once the history grows."""

    async def repeat_call_foobar(_messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # Keep calling the first offered tool; once none is offered, finish with text.
        if not info.function_tools:
            return ModelResponse(parts=[TextPart('done')])
        first_tool = info.function_tools[0]
        return ModelResponse(parts=[ToolCallPart(first_tool.name, {'x': 42, 'y': 'a'})])

    agent = Agent(FunctionModel(repeat_call_foobar), deps_type=int)

    async def prepare_tool_def(ctx: RunContext[int], tool_def: ToolDefinition) -> Union[ToolDefinition, None]:
        return tool_def if len(ctx.messages) < 5 else None

    @agent.tool(prepare=prepare_tool_def)
    def foobar(ctx: RunContext[int], x: int, y: str) -> str:
        return f'{ctx.deps} {x} {y}'

    run = agent.run_sync('', deps=1)
    assert run.data == snapshot('done')
    kinds = [(message.kind, [part.part_kind for part in message.parts]) for message in run.all_messages()]
    assert kinds == snapshot(
        [
            ('request', ['user-prompt']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
            ('response', ['text']),
        ]
    )

702 

703 

def test_future_run_context(create_module: Callable[[str], Any]):
    """RunContext is still detected when `from __future__ import annotations` is in effect."""
    module = create_module("""
from __future__ import annotations

from pydantic_ai import Agent, RunContext

def ctx_tool(ctx: RunContext[int], x: int) -> int:
    return x + ctx.deps

agent = Agent('test', tools=[ctx_tool], deps_type=int)
    """)
    run = module.agent.run_sync('foobar', deps=5)
    assert run.data == snapshot('{"ctx_tool":5}')

717 

718 

# Fixture tool (the docstring is the test input). Per test_suppress_griffe_logging, this
# docstring would make griffe emit a warning log if griffe's logging were not suppressed.
async def tool_without_return_annotation_in_docstring() -> str:  # pragma: no cover
    """A tool that documents what it returns but doesn't have a return annotation in the docstring."""

    return ''

723 

724 

def test_suppress_griffe_logging(caplog: LogCaptureFixture):
    """Docstring parsing must not leak griffe warnings into the captured logs."""
    # This would cause griffe to emit a warning log if we didn't suppress the griffe logging.
    agent = Agent(FunctionModel(get_json_schema))
    agent.tool_plain(tool_without_return_annotation_in_docstring)

    schema = json.loads(agent.run_sync('').data)
    assert schema == snapshot(
        {
            'description': "A tool that documents what it returns but doesn't have a return annotation in the docstring.",
            'name': 'tool_without_return_annotation_in_docstring',
            'outer_typed_dict_key': None,
            'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'},
        }
    )

    # Without suppressing griffe logging, we get:
    # assert caplog.messages == snapshot(['<module>:4: No type or annotation for returned value 1'])
    assert caplog.messages == snapshot([])

745 

746 

# Fixture tool: the docstring has a summary but no per-parameter descriptions.
# test_enforce_parameter_descriptions expects registration to fail when descriptions are required.
async def missing_parameter_descriptions_docstring(foo: int, bar: str) -> str:  # pragma: no cover
    """Describes function ops, but missing parameter descriptions."""
    return f'{foo} {bar}'

750 

751 

def test_enforce_parameter_descriptions() -> None:
    """`require_parameter_descriptions=True` rejects a tool with undocumented parameters."""
    agent = Agent(FunctionModel(get_json_schema))

    with pytest.raises(UserError) as exc_info:
        agent.tool_plain(require_parameter_descriptions=True)(missing_parameter_descriptions_docstring)

    error_reason = exc_info.value.args[0]
    error_parts = [
        'Error generating schema for missing_parameter_descriptions_docstring',
        'Missing parameter descriptions for ',
        'foo',
        'bar',
    ]
    # Collect the absent fragments so a failure names them; a bare `assert all(...)`
    # would only report `False` with no hint about which fragment was missing.
    missing = [part for part in error_parts if part not in error_reason]
    assert not missing, f'missing {missing!r} in error message {error_reason!r}'

766 

767 

def test_json_schema_required_parameters(set_event_loop: None):
    """Only parameters without defaults are listed as `required`, for both tool kinds."""
    agent = Agent(FunctionModel(get_json_schema))

    @agent.tool
    def my_tool(ctx: RunContext[None], a: int, b: int = 1) -> int:
        raise NotImplementedError

    @agent.tool_plain
    def my_tool_plain(*, a: int = 1, b: int) -> int:
        raise NotImplementedError

    definitions = json.loads(agent.run_sync('Hello').data)
    assert definitions == snapshot(
        [
            {
                'description': '',
                'name': 'my_tool',
                'outer_typed_dict_key': None,
                'parameters_json_schema': {
                    'additionalProperties': False,
                    'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}},
                    'required': ['a'],
                    'type': 'object',
                },
            },
            {
                'description': '',
                'name': 'my_tool_plain',
                'outer_typed_dict_key': None,
                'parameters_json_schema': {
                    'additionalProperties': False,
                    'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}},
                    'required': ['b'],
                    'type': 'object',
                },
            },
        ]
    )

807 

808 

def test_call_tool_without_unrequired_parameters(set_event_loop: None):
    """Arguments the model omits fall back to the Python defaults when the tool runs."""

    async def call_tools_first(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # Any turn after the first ends the run with plain text.
        if len(messages) != 1:
            return ModelResponse(parts=[TextPart('finished')])
        planned_calls = [
            ('my_tool', {'a': 13}),
            ('my_tool', {'a': 13, 'b': 4}),
            ('my_tool_plain', {'b': 17}),
            ('my_tool_plain', {'a': 4, 'b': 17}),
        ]
        return ModelResponse(parts=[ToolCallPart(tool_name=name, args=args) for name, args in planned_calls])

    agent = Agent(FunctionModel(call_tools_first))

    @agent.tool
    def my_tool(ctx: RunContext[None], a: int, b: int = 2) -> int:
        return a + b

    @agent.tool_plain
    def my_tool_plain(*, a: int = 3, b: int) -> int:
        return a * b

    messages = agent.run_sync('Hello').all_messages()
    first_response, second_request = messages[1], messages[2]
    assert isinstance(first_response, ModelResponse)
    assert isinstance(second_request, ModelRequest)
    sent_args = [part.args for part in first_response.parts if isinstance(part, ToolCallPart)]
    returned = [part.content for part in second_request.parts if isinstance(part, ToolReturnPart)]
    assert sent_args == snapshot(
        [
            {'a': 13},
            {'a': 13, 'b': 4},
            {'b': 17},
            {'a': 4, 'b': 17},
        ]
    )
    # 13+2, 13+4, 3*17, 4*17: the defaults fill in whichever argument was omitted.
    assert returned == snapshot([15, 17, 51, 68])

850 

851 

def test_schema_generator():
    """`schema_generator=` swaps in a custom GenerateJsonSchema for a single tool."""

    class MyGenerateJsonSchema(GenerateJsonSchema):
        def typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> JsonSchemaValue:
            # Add useless property titles just to show we can
            generated = super().typed_dict_schema(schema)
            for prop_schema in generated.get('properties', {}).values():
                prop_schema['title'] = f'{prop_schema.get("title")} title'
            return generated

    agent = Agent(FunctionModel(get_json_schema))

    def my_tool(x: Annotated[Union[str, None], WithJsonSchema({'type': 'string'})] = None, **kwargs: Any):
        return x  # pragma: no cover

    # Register the same function twice: once with the default generator, once with the custom one.
    agent.tool_plain(name='my_tool_1')(my_tool)
    agent.tool_plain(name='my_tool_2', schema_generator=MyGenerateJsonSchema)(my_tool)

    definitions = json.loads(agent.run_sync('Hello').data)
    assert definitions == snapshot(
        [
            {
                'description': '',
                'name': 'my_tool_1',
                'outer_typed_dict_key': None,
                'parameters_json_schema': {
                    'additionalProperties': True,
                    'properties': {'x': {'type': 'string'}},
                    'type': 'object',
                },
            },
            {
                'description': '',
                'name': 'my_tool_2',
                'outer_typed_dict_key': None,
                'parameters_json_schema': {
                    'properties': {'x': {'type': 'string', 'title': 'X title'}},
                    'type': 'object',
                },
            },
        ]
    )