Coverage for tests/test_examples.py: 96.04%

179 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1from __future__ import annotations as _annotations 

2 

3import json 

4import os 

5import re 

6import sys 

7from collections.abc import AsyncIterator, Iterable, Sequence 

8from dataclasses import dataclass 

9from inspect import FrameInfo 

10from io import StringIO 

11from pathlib import Path 

12from types import ModuleType 

13from typing import Any 

14 

15import httpx 

16import pytest 

17from _pytest.mark import ParameterSet 

18from devtools import debug 

19from pytest_examples import CodeExample, EvalExample, find_examples 

20from pytest_mock import MockerFixture 

21from rich.console import Console 

22 

23from pydantic_ai import ModelHTTPError 

24from pydantic_ai._utils import group_by_temporal 

25from pydantic_ai.exceptions import UnexpectedModelBehavior 

26from pydantic_ai.messages import ( 

27 ModelMessage, 

28 ModelResponse, 

29 RetryPromptPart, 

30 TextPart, 

31 ToolCallPart, 

32 ToolReturnPart, 

33 UserPromptPart, 

34) 

35from pydantic_ai.models import KnownModelName, Model, infer_model 

36from pydantic_ai.models.fallback import FallbackModel 

37from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel 

38from pydantic_ai.models.test import TestModel 

39 

40from .conftest import ClientWithHandler, TestEnv, try_import 

41 

42try: 

43 from pydantic_ai.providers.google_vertex import GoogleVertexProvider 

44except ImportError: 

45 GoogleVertexProvider = None 

46 

47 

48try: 

49 import logfire 

50except ImportError: 

51 logfire = None 

52 

53 

54with try_import() as imports_successful: 

55 from pydantic_evals.reporting import EvaluationReport 

56 

57 

# Skip the whole module unless the optional extras the docs examples rely on are importable:
# pydantic_evals (checked via `try_import`), the Google Vertex provider and logfire.
pytestmark = [
    pytest.mark.skipif(not imports_successful(), reason='extras not installed'),
    pytest.mark.skipif(
        GoogleVertexProvider is None or logfire is None,
        reason='google-auth or logfire not installed',
    ),
]

62 

63 

def find_filter_examples() -> Iterable[ParameterSet]:
    """Yield one `pytest.param` per documentation code example, used to parametrize `test_docs_examples`.

    Side effect: changes the working directory to the package root (the parent of this file's
    directory) so the relative search roots below resolve the same way however pytest is invoked.

    Examples found in files named ``_utils.py`` are excluded. Each param id is ``str(example)``,
    suffixed with ``:<title>`` when the example declares a ``title`` prefix setting.
    """
    os.chdir(Path(__file__).parent.parent)

    # TODO: add 'pydantic_evals' to the search roots below once its broken examples are fixed.
    for ex in find_examples('docs', 'pydantic_ai_slim', 'pydantic_graph'):
        if ex.path.name == '_utils.py':
            continue
        title = ex.prefix_settings().get('title')
        yield pytest.param(ex, id=f'{ex}:{title}' if title else str(ex))

77 

78 

@pytest.mark.parametrize('example', find_filter_examples())
def test_docs_examples(  # noqa: C901
    example: CodeExample,
    eval_example: EvalExample,
    mocker: MockerFixture,
    client_with_handler: ClientWithHandler,
    allow_model_requests: None,
    env: TestEnv,
    tmp_path: Path,
):
    """Lint and run one documentation example with models, HTTP and interactive prompts mocked out.

    Behaviour is driven by the example's prefix settings:

    - ``test`` / ``lint``: a value starting with ``skip`` disables running / linting the code.
    - ``title``: selects per-example setup (some examples run inside ``tmp_path``).
    - ``noqa``: extra ruff rule codes to ignore.
    - ``py``: minimum Python version (e.g. ``3.10``); the test is skipped on older interpreters.
    - ``line_length`` and ``call_name``: passed on to the pytest-examples lint/run calls.

    After a successful run, an example whose title ends in ``.py`` is registered in ``sys.modules``
    under the title's stem, with the executed module's globals.
    """
    # Route model construction, streaming, network access and randomness through offline fakes.
    mocker.patch('pydantic_ai.agent.models.infer_model', side_effect=mock_infer_model)
    mocker.patch('pydantic_ai._utils.group_by_temporal', side_effect=mock_group_by_temporal)

    mocker.patch('httpx.Client.get', side_effect=http_request)
    mocker.patch('httpx.Client.post', side_effect=http_request)
    mocker.patch('httpx.AsyncClient.get', side_effect=async_http_request)
    mocker.patch('httpx.AsyncClient.post', side_effect=async_http_request)
    mocker.patch('random.randint', return_value=4)
    mocker.patch('rich.prompt.Prompt.ask', side_effect=rich_prompt_ask)

    class CustomEvaluationReport(EvaluationReport):
        # Render reports at a fixed console width so printed tables don't depend on the terminal.
        def print(self, *args: Any, **kwargs: Any) -> None:
            if 'width' in kwargs:  # pragma: no cover
                raise ValueError('width should not be passed to CustomEvaluationReport')
            table = self.console_table(*args, **kwargs)
            io_file = StringIO()
            Console(file=io_file, width=150).print(table)
            print(io_file.getvalue())

    mocker.patch('pydantic_evals.dataset.EvaluationReport', side_effect=CustomEvaluationReport)

    if sys.version_info >= (3, 10):
        mocker.patch('pydantic_ai.mcp.MCPServerHTTP', return_value=MockMCPServer())
        mocker.patch('mcp.server.fastmcp.FastMCP')

    env.set('OPENAI_API_KEY', 'testing')
    env.set('GEMINI_API_KEY', 'testing')
    env.set('GROQ_API_KEY', 'testing')
    env.set('CO_API_KEY', 'testing')
    env.set('MISTRAL_API_KEY', 'testing')
    env.set('ANTHROPIC_API_KEY', 'testing')

    # Relative path, resolved against the working directory (the package root chosen by
    # `find_filter_examples`). Only add it once rather than once per parametrized test.
    if 'tests/example_modules' not in sys.path:
        sys.path.append('tests/example_modules')

    prefix_settings = example.prefix_settings()
    opt_title = prefix_settings.get('title')
    opt_test = prefix_settings.get('test', '')
    opt_lint = prefix_settings.get('lint', '')
    noqa = prefix_settings.get('noqa', '')
    python_version = prefix_settings.get('py', None)

    if python_version:
        python_version_info = tuple(int(v) for v in python_version.split('.'))
        if sys.version_info < python_version_info:
            pytest.skip(f'Python version {python_version} required')

    cwd = Path.cwd()

    if opt_test.startswith('skip') and opt_lint.startswith('skip'):
        pytest.skip('both running code and lint skipped')

    try:
        # Examples that read or write files run inside this test's temporary directory.
        if opt_title == 'sql_app_evals.py':
            os.chdir(tmp_path)
            examples = [{'request': f'sql prompt {i}', 'sql': f'SELECT {i}'} for i in range(15)]
            with (tmp_path / 'examples.json').open('w') as f:
                json.dump(examples, f)
        elif opt_title in {
            'ai_q_and_a_run.py',
            'count_down_from_persistence.py',
            'generate_dataset_example.py',
            'generate_dataset_example_json.py',
            'save_load_dataset_example.py',
        }:
            os.chdir(tmp_path)

        ruff_ignore: list[str] = ['D', 'Q001']
        # `from bank_database import DatabaseConn` wrongly sorted in imports
        # waiting for https://github.com/pydantic/pytest-examples/issues/43
        # and https://github.com/pydantic/pytest-examples/issues/46
        if 'import DatabaseConn' in example.source:
            ruff_ignore.append('I001')

        if noqa:
            ruff_ignore.extend(noqa.upper().split())

        line_length = int(prefix_settings.get('line_length', '88'))

        eval_example.set_config(ruff_ignore=ruff_ignore, target_version='py39', line_length=line_length)
        eval_example.print_callback = print_callback
        eval_example.include_print = custom_include_print

        call_name = prefix_settings.get('call_name', 'main')

        if not opt_lint.startswith('skip'):
            if eval_example.update_examples:  # pragma: no cover
                eval_example.format(example)
            else:
                eval_example.lint(example)

        if opt_test.startswith('skip'):
            # Print whatever follows 'skip' (minus leading spaces/dashes) as the reason, or a default.
            print(opt_test[4:].lstrip(' -') or 'running code skipped')
        else:
            test_globals: dict[str, str] = {}
            if opt_title == 'mcp_client.py':
                test_globals['__name__'] = '__test__'
            if eval_example.update_examples:  # pragma: no cover
                module_dict = eval_example.run_print_update(example, call=call_name, module_globals=test_globals)
            else:
                module_dict = eval_example.run_print_check(example, call=call_name, module_globals=test_globals)

            if title := opt_title:
                if title.endswith('.py'):
                    module_name = title[:-3]
                    sys.modules[module_name] = module = ModuleType(module_name)
                    module.__dict__.update(module_dict)
    finally:
        # Always restore the working directory -- also when the example fails or its run is
        # skipped -- so a test that moved into `tmp_path` cannot leak that cwd into later tests.
        os.chdir(cwd)

195 

196 

def print_callback(s: str) -> str:
    """Normalize non-deterministic values in printed example output before it is compared.

    - ``datetime.datetime(...)`` reprs are collapsed to the literal ``datetime.datetime(...)``.
    - Small floats in scientific notation (e.g. ``1.23456e-05``) become ``0.0...``.
    - Durations below 100ms (any ``µs`` value, or ``N.Nms`` / ``NN.Nms``) become ``123µs``.
    - Durations of 100ms or more written with a decimal part become ``101.0ms``.
    - ``datetime.date(`` is shortened to ``date(``.

    Args:
        s: Text printed by a documentation example.

    Returns:
        The normalized text.
    """
    s = re.sub(r'datetime\.datetime\(.+?\)', 'datetime.datetime(...)', s, flags=re.DOTALL)
    s = re.sub(r'\d\.\d{4,}e-0\d', '0.0...', s)

    # Replace durations below 100ms with 123µs. `(?:...)` is a non-capturing group, so a
    # preceding ':' (as in 'elapsed:45µs') is never consumed by the match.
    s = re.sub(r'\b(?:\d+µs|\d{1,2}\.\d+ms)', '123µs', s)
    # Replace durations of 100ms and above with 101.0ms
    s = re.sub(r'\b(?:\d{3,}\.\d+ms)', '101.0ms', s)
    # The dot is escaped so only a literal `datetime.date(` is shortened.
    return re.sub(r'datetime\.date\(', 'date(', s)

206 

207 

def custom_include_print(path: Path, frame: FrameInfo, args: Sequence[Any]) -> bool:
    """Decide whether a ``print`` call made from ``frame`` should be captured for comparison.

    Prints are captured when they come from the example file itself (``path``) or from a file
    whose name ends in ``test_examples.py``. ``args`` (the printed values) is not consulted.
    """
    if path.samefile(frame.filename):
        return True
    return frame.filename.endswith('test_examples.py')

210 

211 

def http_request(url: str, **kwargs: Any) -> httpx.Response:
    """Offline stand-in for the patched ``httpx.Client.get`` / ``httpx.Client.post``.

    Always answers with status ``202`` and an empty body. The attached request is built as a
    ``GET`` from ``url`` and ``kwargs`` whichever client method was called.
    """
    return httpx.Response(
        status_code=202,
        content='',
        request=httpx.Request('GET', url, **kwargs),
    )

216 

217 

async def async_http_request(url: str, **kwargs: Any) -> httpx.Response:
    """Async variant of `http_request`, patched over ``httpx.AsyncClient.get`` / ``post``.

    Awaits nothing; it returns the same canned response as the synchronous version.
    """
    response = http_request(url, **kwargs)
    return response

220 

221 

# Canned answers for the interactive prompts used by the docs examples, checked in order.
_RICH_PROMPT_ANSWERS: tuple[tuple[str, str], ...] = (
    ('Where would you like to fly from and to?', 'SFO to ANC'),
    ('What seat would you like?', 'window seat with leg room'),
    ('Insert coins', '1'),
    ('Select product', 'crisps'),
    ('What is the capital of France?', 'Vichy'),
    ('what is 1 + 1?', '2'),
)


def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str:
    """Replacement for ``rich.prompt.Prompt.ask`` that answers known prompts without user input.

    Extra positional and keyword arguments are accepted and ignored.

    Raises:
        ValueError: if ``prompt`` is not one of the known prompts.
    """
    for known_prompt, answer in _RICH_PROMPT_ANSWERS:
        if prompt == known_prompt:
            return answer
    raise ValueError(f'Unexpected prompt: {prompt}')  # pragma: no cover

237 

238 

class MockMCPServer:
    """Stand-in returned by the patched ``pydantic_ai.mcp.MCPServerHTTP`` in the docs examples.

    Usable as an async context manager, always reports itself as running, and exposes no tools.
    """

    is_running = True

    async def __aenter__(self) -> MockMCPServer:
        # Nothing to start: `async with server as s` simply yields the same object.
        return self

    async def __aexit__(self, *args: Any) -> None:
        # Nothing to shut down; returning None lets any in-flight exception propagate.
        return None

    @staticmethod
    async def list_tools() -> list[None]:
        """Return an empty tool list."""
        return list()

251 

252 

# Canned model replies keyed by the exact user prompt (looked up with `text_responses.get(...)`),
# so every key -- including the whitespace inside the XML-formatted keys -- must match the prompt
# text exactly. A `str` value is returned as a text part, a `ToolCallPart` as a tool call.
text_responses: dict[str, str | ToolCallPart] = {
    'How many days between 2000-01-01 and 2025-03-18?': 'There are 9,208 days between January 1, 2000, and March 18, 2025.',
    'What is the weather like in West London and in Wiltshire?': (
        'The weather in West London is raining, while in Wiltshire it is sunny.'
    ),
    'What will the weather be like in Paris on Tuesday?': ToolCallPart(
        tool_name='weather_forecast', args={'location': 'Paris', 'forecast_date': '2030-01-01'}, tool_call_id='0001'
    ),
    'Tell me a joke.': 'Did you hear about the toothpaste scandal? They called it Colgate.',
    'Tell me a different joke.': 'No.',
    'Explain?': 'This is an excellent joke invented by Samuel Colvin, it needs no explanation.',
    'What is the capital of France?': 'Paris',
    'What is the capital of Italy?': 'Rome',
    'What is the capital of the UK?': 'London',
    'Who was Albert Einstein?': 'Albert Einstein was a German-born theoretical physicist.',
    'What was his most famous equation?': "Albert Einstein's most famous equation is (E = mc^2).",
    'What is the date?': 'Hello Frank, the date today is 2032-01-02.',
    'Put my money on square eighteen': ToolCallPart(
        tool_name='roulette_wheel', args={'square': 18}, tool_call_id='pyd_ai_tool_call_id'
    ),
    'I bet five is the winner': ToolCallPart(
        tool_name='roulette_wheel', args={'square': 5}, tool_call_id='pyd_ai_tool_call_id'
    ),
    'My guess is 6': ToolCallPart(tool_name='roll_die', args={}, tool_call_id='pyd_ai_tool_call_id'),
    'My guess is 4': ToolCallPart(tool_name='roll_die', args={}, tool_call_id='pyd_ai_tool_call_id'),
    'Send a message to John Doe asking for coffee next week': ToolCallPart(
        tool_name='get_user_by_name', args={'name': 'John'}
    ),
    'Please get me the volume of a box with size 6.': ToolCallPart(
        tool_name='calc_volume', args={'size': 6}, tool_call_id='pyd_ai_tool_call_id'
    ),
    'Where does "hello world" come from?': (
        'The first known use of "hello, world" was in a 1974 textbook about the C programming language.'
    ),
    'What is my balance?': ToolCallPart(tool_name='customer_balance', args={'include_pending': True}),
    'I just lost my card!': ToolCallPart(
        tool_name='final_result',
        args={
            'support_advice': (
                "I'm sorry to hear that, John. "
                'We are temporarily blocking your card to prevent unauthorized transactions.'
            ),
            'block_card': True,
            'risk': 8,
        },
    ),
    'Where were the olympics held in 2012?': ToolCallPart(
        tool_name='final_result',
        args={'city': 'London', 'country': 'United Kingdom'},
    ),
    'The box is 10x20x30': 'Please provide the units for the dimensions (e.g., cm, in, m).',
    'The box is 10x20x30 cm': ToolCallPart(
        tool_name='final_result',
        args={'width': 10, 'height': 20, 'depth': 30, 'units': 'cm'},
    ),
    'red square, blue circle, green triangle': ToolCallPart(
        tool_name='final_result_list',
        args={'response': ['red', 'blue', 'green']},
    ),
    'square size 10, circle size 20, triangle size 30': ToolCallPart(
        tool_name='final_result_list_2',
        args={'response': [10, 20, 30]},
    ),
    'get me users who were last active yesterday.': ToolCallPart(
        tool_name='final_result_Success',
        args={'sql_query': 'SELECT * FROM users WHERE last_active::date = today() - interval 1 day'},
    ),
    'My name is Ben, I was born on January 28th 1990, I like the chain the dog and the pyramid.': ToolCallPart(
        tool_name='final_result',
        args={
            'name': 'Ben',
            'dob': '1990-01-28',
            'bio': 'Likes the chain the dog and the pyramid',
        },
        tool_call_id='pyd_ai_tool_call_id',
    ),
    'What is the capital of Italy? Answer with just the city.': 'Rome',
    'What is the capital of Italy? Answer with a paragraph.': (
        # The trailing space keeps the two sentences separated after implicit string concatenation.
        'The capital of Italy is Rome (Roma, in Italian), which has been a cultural and political center for centuries. '
        'Rome is known for its rich history, stunning architecture, and delicious cuisine.'
    ),
    'Begin infinite retry loop!': ToolCallPart(
        tool_name='infinite_retry_tool', args={}, tool_call_id='pyd_ai_tool_call_id'
    ),
    'Please generate 5 jokes.': ToolCallPart(
        tool_name='final_result',
        args={'response': []},
        tool_call_id='pyd_ai_tool_call_id',
    ),
    'SFO to ANC': ToolCallPart(
        tool_name='flight_search',
        args={'origin': 'SFO', 'destination': 'ANC'},
        tool_call_id='pyd_ai_tool_call_id',
    ),
    'window seat with leg room': ToolCallPart(
        tool_name='final_result_SeatPreference',
        args={'row': 1, 'seat': 'A'},
        tool_call_id='pyd_ai_tool_call_id',
    ),
    'Ask a simple question with a single correct answer.': 'What is the capital of France?',
    '<examples>\n <question>What is the capital of France?</question>\n <answer>Vichy</answer>\n</examples>': ToolCallPart(
        tool_name='final_result',
        args={'correct': False, 'comment': 'Vichy is no longer the capital of France.'},
        tool_call_id='pyd_ai_tool_call_id',
    ),
    '<examples>\n <question>what is 1 + 1?</question>\n <answer>2</answer>\n</examples>': ToolCallPart(
        tool_name='final_result',
        args={'correct': True, 'comment': 'Well done, 1 + 1 = 2'},
        tool_call_id='pyd_ai_tool_call_id',
    ),
    (
        '<examples>\n'
        ' <dish_name>Spaghetti Bolognese</dish_name>\n'
        ' <dietary_restriction>vegetarian</dietary_restriction>\n'
        '</examples>'
    ): ToolCallPart(
        tool_name='final_result',
        args={
            'ingredients': ['spaghetti', 'tomato sauce', 'vegetarian mince', 'onions', 'garlic'],
            'steps': ['Cook the spaghetti in boiling water', '...'],
        },
    ),
    (
        '<examples>\n'
        ' <dish_name>Chocolate Cake</dish_name>\n'
        ' <dietary_restriction>gluten-free</dietary_restriction>\n'
        '</examples>'
    ): ToolCallPart(
        tool_name='final_result',
        args={
            'ingredients': ['gluten-free flour', 'cocoa powder', 'sugar', 'eggs'],
            'steps': ['Mix the ingredients', 'Bake at 350°F for 30 minutes'],
        },
    ),
}

388 

# Canned streamed replies to tool results, keyed by (tool name, exact tool return content).
tool_responses: dict[tuple[str, str], str] = {
    ('weather_forecast', 'The forecast in Paris on 2030-01-01 is 24°C and sunny.'): (
        'It will be warm and sunny in Paris on Tuesday.'
    ),
}

395 

396 

async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:  # pragma: no cover # noqa: C901
    """Canned, non-streaming model behaviour for the docs examples (the `FunctionModel` function).

    Only the last part of the last message is inspected (plus ``len(messages)`` in one branch);
    known prompts and tool results get a hard-coded `ModelResponse`.

    Args:
        messages: The conversation so far.
        info: Agent info; ``info.function_tools`` is checked to see which tools are registered.

    Returns:
        A `ModelResponse` holding a single text or tool-call part.

    Raises:
        UnexpectedModelBehavior: for the "5 very rude things" prompt (a simulated safety block).
        RuntimeError: when nothing matches; the messages are dumped to stdout first.
    """
    m = messages[-1].parts[-1]
    if isinstance(m, UserPromptPart):
        assert isinstance(m.content, str)
        # Prompt-specific cases come before the generic `text_responses` lookup at the end of this
        # chain, so they take precedence for prompts that also appear in that table.
        if m.content == 'Tell me a joke.' and any(t.name == 'joke_factory' for t in info.function_tools):
            return ModelResponse(
                parts=[ToolCallPart(tool_name='joke_factory', args={'count': 5}, tool_call_id='pyd_ai_tool_call_id')]
            )
        elif m.content == 'Please generate 5 jokes.' and any(t.name == 'get_jokes' for t in info.function_tools):
            return ModelResponse(
                parts=[ToolCallPart(tool_name='get_jokes', args={'count': 5}, tool_call_id='pyd_ai_tool_call_id')]
            )
        elif re.fullmatch(r'sql prompt \d+', m.content):
            return ModelResponse(parts=[TextPart('SELECT 1')])
        elif m.content.startswith('Write a welcome email for the user:'):
            return ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name='final_result',
                        args={
                            'subject': 'Welcome to our tech blog!',
                            'body': 'Hello John, Welcome to our tech blog! ...',
                        },
                        tool_call_id='pyd_ai_tool_call_id',
                    )
                ]
            )
        elif m.content.startswith('Write a list of 5 very rude things that I might say'):
            # Simulates the provider refusing the request.
            raise UnexpectedModelBehavior('Safety settings triggered', body='<safety settings details>')
        elif m.content.startswith('<examples>\n <user>'):
            return ModelResponse(
                parts=[ToolCallPart(tool_name='final_result_EmailOk', args={}, tool_call_id='pyd_ai_tool_call_id')]
            )
        elif m.content == 'Ask a simple question with a single correct answer.' and len(messages) > 2:
            # Later in the conversation, ask a different question from the first-turn answer that
            # `text_responses` holds for this prompt.
            return ModelResponse(parts=[TextPart('what is 1 + 1?')])
        elif '<Rubric>\n' in m.content:
            # Any prompt containing a rubric is graded as a pass with full score.
            return ModelResponse(
                parts=[ToolCallPart(tool_name='final_result', args={'reason': '-', 'pass': True, 'score': 1.0})]
            )
        elif 'Generate question-answer pairs about world capitals and landmarks.' in m.content:
            # Dataset-generation prompt: reply with a fixed two-case dataset serialized as JSON text.
            return ModelResponse(
                parts=[
                    TextPart(
                        content=json.dumps(
                            {
                                'cases': [
                                    {
                                        'name': 'Easy Capital Question',
                                        'inputs': {'question': 'What is the capital of France?'},
                                        'metadata': {'difficulty': 'easy', 'category': 'Geography'},
                                        'expected_output': {'answer': 'Paris', 'confidence': 0.95},
                                        'evaluators': ['EqualsExpected'],
                                    },
                                    {
                                        'name': 'Challenging Landmark Question',
                                        'inputs': {
                                            'question': 'Which world-famous landmark is located on the banks of the Seine River?',
                                        },
                                        'metadata': {'difficulty': 'hard', 'category': 'Landmarks'},
                                        'expected_output': {'answer': 'Eiffel Tower', 'confidence': 0.9},
                                        'evaluators': ['EqualsExpected'],
                                    },
                                ],
                                'evaluators': [],
                            }
                        )
                    )
                ]
            )
        elif response := text_responses.get(m.content):
            if isinstance(response, str):
                return ModelResponse(parts=[TextPart(response)])
            else:
                return ModelResponse(parts=[response])

    elif isinstance(m, ToolReturnPart) and m.tool_name == 'roulette_wheel':
        win = m.content == 'winner'
        return ModelResponse(
            parts=[ToolCallPart(tool_name='final_result', args={'response': win}, tool_call_id='pyd_ai_tool_call_id')]
        )
    elif isinstance(m, ToolReturnPart) and m.tool_name == 'roll_die':
        return ModelResponse(
            parts=[ToolCallPart(tool_name='get_player_name', args={}, tool_call_id='pyd_ai_tool_call_id')]
        )
    elif isinstance(m, ToolReturnPart) and m.tool_name == 'get_player_name':
        if 'Anne' in m.content:
            return ModelResponse(parts=[TextPart("Congratulations Anne, you guessed correctly! You're a winner!")])
        elif 'Yashar' in m.content:
            return ModelResponse(parts=[TextPart('Tough luck, Yashar, you rolled a 4. Better luck next time.')])
    # A new `if` chain rather than `elif`: any case above that did not return (an unmatched user
    # prompt, an unknown player name) falls through to here and ends in the final `else`, which raises.
    if (
        isinstance(m, RetryPromptPart)
        and isinstance(m.content, str)
        and m.content.startswith("No user found with name 'Joh")
    ):
        # Retry the user lookup with the full name after the tool rejected the short one.
        return ModelResponse(
            parts=[
                ToolCallPart(
                    tool_name='get_user_by_name', args={'name': 'John Doe'}, tool_call_id='pyd_ai_tool_call_id'
                )
            ]
        )
    elif isinstance(m, RetryPromptPart) and m.tool_name == 'infinite_retry_tool':
        # Keep calling the same tool on every retry.
        return ModelResponse(
            parts=[ToolCallPart(tool_name='infinite_retry_tool', args={}, tool_call_id='pyd_ai_tool_call_id')]
        )
    elif isinstance(m, ToolReturnPart) and m.tool_name == 'get_user_by_name':
        args: dict[str, Any] = {
            'message': 'Hello John, would you be free for coffee sometime next week? Let me know what works for you!',
            'user_id': 123,
        }
        return ModelResponse(
            parts=[ToolCallPart(tool_name='final_result', args=args, tool_call_id='pyd_ai_tool_call_id')]
        )
    elif isinstance(m, RetryPromptPart) and m.tool_name == 'calc_volume':
        return ModelResponse(
            parts=[ToolCallPart(tool_name='calc_volume', args={'size': 6}, tool_call_id='pyd_ai_tool_call_id')]
        )
    elif isinstance(m, ToolReturnPart) and m.tool_name == 'customer_balance':
        args = {
            'support_advice': 'Hello John, your current account balance, including pending transactions, is $123.45.',
            'block_card': False,
            'risk': 1,
        }
        return ModelResponse(
            parts=[ToolCallPart(tool_name='final_result', args=args, tool_call_id='pyd_ai_tool_call_id')]
        )
    elif isinstance(m, ToolReturnPart) and m.tool_name == 'joke_factory':
        return ModelResponse(parts=[TextPart('Did you hear about the toothpaste scandal? They called it Colgate.')])
    elif isinstance(m, ToolReturnPart) and m.tool_name == 'get_jokes':
        args = {'response': []}
        return ModelResponse(
            parts=[ToolCallPart(tool_name='final_result', args=args, tool_call_id='pyd_ai_tool_call_id')]
        )
    elif isinstance(m, ToolReturnPart) and m.tool_name == 'flight_search':
        # The tool return content is an object here (its `flight_number` attribute is read), not a str.
        args = {'flight_number': m.content.flight_number}  # type: ignore
        return ModelResponse(
            parts=[ToolCallPart(tool_name='final_result_FlightDetails', args=args, tool_call_id='pyd_ai_tool_call_id')]
        )
    else:
        # Dump the whole conversation to make it obvious which canned response is missing.
        sys.stdout.write(str(debug.format(messages, info)))
        raise RuntimeError(f'Unexpected message: {m}')

538 

539 

async def stream_model_logic(  # noqa: C901
    messages: list[ModelMessage], info: AgentInfo
) -> AsyncIterator[str | DeltaToolCalls]:  # pragma: no cover
    """Canned streaming model behaviour for the docs examples (the `FunctionModel` stream function).

    Only the last part of the last message is inspected:

    - a user prompt is answered from `text_responses`;
    - a tool return is answered from `tool_responses`, keyed by ``(tool_name, content)``.

    Text is streamed three words at a time. A tool call is streamed as one delta carrying the tool
    name and call id, followed by its JSON-encoded arguments in 15-character slices.

    Raises:
        RuntimeError: if no canned response matches; the messages are dumped to stdout first.
    """

    async def stream_text_response(r: str) -> AsyncIterator[str]:
        # Three space-separated words per chunk. Every chunk except the last keeps the space that
        # separates it from the next word, so the chunks concatenate back to exactly `r`.
        words = r.split(' ')
        for start in range(0, len(words), 3):
            chunk = ' '.join(words[start : start + 3])
            if start + 3 < len(words):
                chunk += ' '
            yield chunk

    async def stream_tool_call_response(r: ToolCallPart) -> AsyncIterator[DeltaToolCalls]:
        json_text = r.args_as_json_str()

        # Every delta uses the same key (1), presumably so they accumulate into a single tool call.
        yield {1: DeltaToolCall(name=r.tool_name, tool_call_id=r.tool_call_id)}
        for chunk_index in range(0, len(json_text), 15):
            text_chunk = json_text[chunk_index : chunk_index + 15]
            yield {1: DeltaToolCall(json_args=text_chunk)}

    async def stream_part_response(r: str | ToolCallPart) -> AsyncIterator[str | DeltaToolCalls]:
        if isinstance(r, str):
            async for chunk in stream_text_response(r):
                yield chunk
        else:
            async for chunk in stream_tool_call_response(r):
                yield chunk

    last_part = messages[-1].parts[-1]
    if isinstance(last_part, UserPromptPart):
        assert isinstance(last_part.content, str)
        if response := text_responses.get(last_part.content):
            async for chunk in stream_part_response(response):
                yield chunk
            return
    elif isinstance(last_part, ToolReturnPart):
        assert isinstance(last_part.content, str)
        if response := tool_responses.get((last_part.tool_name, last_part.content)):
            async for chunk in stream_part_response(response):
                yield chunk
            return

    sys.stdout.write(str(debug.format(messages, info)))
    raise RuntimeError(f'Unexpected message: {last_part}')

587 

588 

def mock_infer_model(model: Model | KnownModelName) -> Model:
    """Replacement for `infer_model` that swaps real models for offline fakes.

    - ``'test'`` becomes a fresh `TestModel`.
    - Other model names are first resolved with the real `infer_model`, so the fake reports the
      same ``model_name`` a user would see.
    - In a `FallbackModel`, every `OpenAIModel` member is replaced by a `FunctionModel` that raises a
      401 `ModelHTTPError`; every other member is mocked recursively with this function.
    - `FunctionModel` and `TestModel` instances are returned unchanged.
    - Any other model becomes a `FunctionModel` driven by `model_logic` / `stream_model_logic`.
    """
    if model == 'test':
        return TestModel()

    if isinstance(model, str):
        # Use the non-mocked model inference to ensure we get the same model name the user would
        model = infer_model(model)

    if isinstance(model, FallbackModel):
        # Import once, outside the loop. If it is not importable, no member can be an OpenAIModel,
        # so the `isinstance` check below matches nothing.
        try:
            from pydantic_ai.models.openai import OpenAIModel
        except ImportError:
            OpenAIModel = type(None)

        def raise_http_error(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:  # pragma: no cover
            raise ModelHTTPError(401, 'Invalid API Key')

        # Failing OpenAI members force the example through the fallback path.
        mock_fallback_models: list[Model] = [
            FunctionModel(raise_http_error, model_name=m.model_name)
            if isinstance(m, OpenAIModel)
            else mock_infer_model(m)
            for m in model.models
        ]
        return FallbackModel(*mock_fallback_models)

    if isinstance(model, (FunctionModel, TestModel)):
        return model
    # `model` is always a `Model` instance here: any string was resolved by `infer_model` above.
    return FunctionModel(model_logic, stream_function=stream_model_logic, model_name=model.model_name)

621 

622 

def mock_group_by_temporal(aiter: Any, soft_max_interval: float | None) -> Any:
    """Stand-in for ``pydantic_ai._utils.group_by_temporal`` with debouncing turned off.

    The mocked model iterators in this module have no delay, so there is nothing to debounce:
    ``soft_max_interval`` is accepted for signature compatibility but ignored, and the real
    implementation is always called with ``None``.
    """
    no_interval: float | None = None
    return group_by_temporal(aiter, no_interval)

626 

627 

628@dataclass 

629class MockCredentials: 

630 project_id = 'foobar'