Coverage for tests/test_examples.py: 96.04%
179 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1from __future__ import annotations as _annotations
3import json
4import os
5import re
6import sys
7from collections.abc import AsyncIterator, Iterable, Sequence
8from dataclasses import dataclass
9from inspect import FrameInfo
10from io import StringIO
11from pathlib import Path
12from types import ModuleType
13from typing import Any
15import httpx
16import pytest
17from _pytest.mark import ParameterSet
18from devtools import debug
19from pytest_examples import CodeExample, EvalExample, find_examples
20from pytest_mock import MockerFixture
21from rich.console import Console
23from pydantic_ai import ModelHTTPError
24from pydantic_ai._utils import group_by_temporal
25from pydantic_ai.exceptions import UnexpectedModelBehavior
26from pydantic_ai.messages import (
27 ModelMessage,
28 ModelResponse,
29 RetryPromptPart,
30 TextPart,
31 ToolCallPart,
32 ToolReturnPart,
33 UserPromptPart,
34)
35from pydantic_ai.models import KnownModelName, Model, infer_model
36from pydantic_ai.models.fallback import FallbackModel
37from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
38from pydantic_ai.models.test import TestModel
40from .conftest import ClientWithHandler, TestEnv, try_import
# Optional dependency: the Google Vertex provider needs `google-auth`; the
# module-level skipif below disables these tests when it is missing.
try:
    from pydantic_ai.providers.google_vertex import GoogleVertexProvider
except ImportError:
    GoogleVertexProvider = None

# Optional dependency: some docs examples use logfire instrumentation.
try:
    import logfire
except ImportError:
    logfire = None

# `try_import` (from conftest) records whether the pydantic_evals extra imported cleanly.
with try_import() as imports_successful:
    from pydantic_evals.reporting import EvaluationReport

# Skip the entire module unless all optional extras are installed.
pytestmark = [
    pytest.mark.skipif(not imports_successful(), reason='extras not installed'),
    pytest.mark.skipif(GoogleVertexProvider is None or logfire is None, reason='google-auth or logfire not installed'),
]
def find_filter_examples() -> Iterable[ParameterSet]:
    """Yield each docs/package code example as a pytest param, skipping `_utils.py`."""
    # Ensure this runs from the package root regardless of where/how the tests are invoked.
    os.chdir(Path(__file__).parent.parent)

    # TODO: pydantic_evals should eventually be searched as well, but some of its examples are broken.
    for example in find_examples('docs', 'pydantic_ai_slim', 'pydantic_graph'):
        if example.path.name == '_utils.py':
            continue
        settings = example.prefix_settings()
        example_id = str(example)
        title = settings.get('title')
        if title:
            example_id += f':{title}'
        yield pytest.param(example, id=example_id)
@pytest.mark.parametrize('example', find_filter_examples())
def test_docs_examples(  # noqa: C901
    example: CodeExample,
    eval_example: EvalExample,
    mocker: MockerFixture,
    client_with_handler: ClientWithHandler,
    allow_model_requests: None,
    env: TestEnv,
    tmp_path: Path,
):
    """Lint and execute one documentation code example with all external calls mocked."""
    # Route all model inference and output debouncing through the deterministic mocks below.
    mocker.patch('pydantic_ai.agent.models.infer_model', side_effect=mock_infer_model)
    mocker.patch('pydantic_ai._utils.group_by_temporal', side_effect=mock_group_by_temporal)

    # Stub network access, randomness, and interactive prompts so examples run offline.
    mocker.patch('httpx.Client.get', side_effect=http_request)
    mocker.patch('httpx.Client.post', side_effect=http_request)
    mocker.patch('httpx.AsyncClient.get', side_effect=async_http_request)
    mocker.patch('httpx.AsyncClient.post', side_effect=async_http_request)
    mocker.patch('random.randint', return_value=4)
    mocker.patch('rich.prompt.Prompt.ask', side_effect=rich_prompt_ask)

    # Render evaluation reports at a fixed console width for stable, comparable output.
    class CustomEvaluationReport(EvaluationReport):
        def print(self, *args: Any, **kwargs: Any) -> None:
            if 'width' in kwargs:  # pragma: no cover
                raise ValueError('width should not be passed to CustomEvaluationReport')
            table = self.console_table(*args, **kwargs)
            io_file = StringIO()
            Console(file=io_file, width=150).print(table)
            print(io_file.getvalue())

    mocker.patch('pydantic_evals.dataset.EvaluationReport', side_effect=CustomEvaluationReport)

    # MCP pieces are only patched on Python 3.10+ (where they are importable).
    if sys.version_info >= (3, 10):
        mocker.patch('pydantic_ai.mcp.MCPServerHTTP', return_value=MockMCPServer())
        mocker.patch('mcp.server.fastmcp.FastMCP')

    # Dummy API keys so provider construction inside examples succeeds without credentials.
    env.set('OPENAI_API_KEY', 'testing')
    env.set('GEMINI_API_KEY', 'testing')
    env.set('GROQ_API_KEY', 'testing')
    env.set('CO_API_KEY', 'testing')
    env.set('MISTRAL_API_KEY', 'testing')
    env.set('ANTHROPIC_API_KEY', 'testing')

    # Make helper modules referenced by the examples importable.
    sys.path.append('tests/example_modules')

    # Per-example options come from the fenced-code prefix (title, test/lint skips, etc.).
    prefix_settings = example.prefix_settings()
    opt_title = prefix_settings.get('title')
    opt_test = prefix_settings.get('test', '')
    opt_lint = prefix_settings.get('lint', '')
    noqa = prefix_settings.get('noqa', '')
    python_version = prefix_settings.get('py', None)

    if python_version:
        python_version_info = tuple(int(v) for v in python_version.split('.'))
        if sys.version_info < python_version_info:
            pytest.skip(f'Python version {python_version} required')

    cwd = Path.cwd()

    if opt_test.startswith('skip') and opt_lint.startswith('skip'):
        pytest.skip('both running code and lint skipped')

    # Examples that read/write files are run from an isolated tmp directory.
    if opt_title == 'sql_app_evals.py':
        os.chdir(tmp_path)
        examples = [{'request': f'sql prompt {i}', 'sql': f'SELECT {i}'} for i in range(15)]
        with (tmp_path / 'examples.json').open('w') as f:
            json.dump(examples, f)
    elif opt_title in {
        'ai_q_and_a_run.py',
        'count_down_from_persistence.py',
        'generate_dataset_example.py',
        'generate_dataset_example_json.py',
        'save_load_dataset_example.py',
    }:
        os.chdir(tmp_path)

    ruff_ignore: list[str] = ['D', 'Q001']
    # `from bank_database import DatabaseConn` wrongly sorted in imports
    # waiting for https://github.com/pydantic/pytest-examples/issues/43
    # and https://github.com/pydantic/pytest-examples/issues/46
    if 'import DatabaseConn' in example.source:
        ruff_ignore.append('I001')

    if noqa:
        ruff_ignore.extend(noqa.upper().split())

    line_length = int(prefix_settings.get('line_length', '88'))

    eval_example.set_config(ruff_ignore=ruff_ignore, target_version='py39', line_length=line_length)
    eval_example.print_callback = print_callback
    eval_example.include_print = custom_include_print

    call_name = prefix_settings.get('call_name', 'main')

    if not opt_lint.startswith('skip'):
        if eval_example.update_examples:  # pragma: no cover
            eval_example.format(example)
        else:
            eval_example.lint(example)

    if opt_test.startswith('skip'):
        print(opt_test[4:].lstrip(' -') or 'running code skipped')
    else:
        test_globals: dict[str, str] = {}
        if opt_title == 'mcp_client.py':
            test_globals['__name__'] = '__test__'
        if eval_example.update_examples:  # pragma: no cover
            module_dict = eval_example.run_print_update(example, call=call_name, module_globals=test_globals)
        else:
            module_dict = eval_example.run_print_check(example, call=call_name, module_globals=test_globals)

        os.chdir(cwd)
        # Register the executed example as an importable module so later examples can import it.
        if title := opt_title:
            if title.endswith('.py'):
                module_name = title[:-3]
                sys.modules[module_name] = module = ModuleType(module_name)
                module.__dict__.update(module_dict)
def print_callback(s: str) -> str:
    """Normalize non-deterministic example output before comparison.

    Docs examples print values that vary run to run (timestamps, timing jitter,
    durations); each is replaced with a stable placeholder so
    `eval_example.run_print_check` can compare output exactly.
    """
    # Collapse datetime reprs, which embed the current time (DOTALL: reprs may span lines).
    s = re.sub(r'datetime\.datetime\(.+?\)', 'datetime.datetime(...)', s, flags=re.DOTALL)
    # Collapse tiny scientific-notation floats (e.g. timing jitter).
    s = re.sub(r'\d\.\d{4,}e-0\d', '0.0...', s)

    # Replace durations below 100ms with 123µs.
    # Fix: was `(:?` — a typo for the non-capturing group `(?:`. With `(:?`, an
    # optional literal ':' preceding the digits was absorbed into the match and
    # swallowed by the replacement (e.g. 'x:45µs' became 'x123µs').
    s = re.sub(r'\b(?:\d+µs|\d{1,2}\.\d+ms)', '123µs', s)
    # Replace durations above 100ms with 101.0ms.
    s = re.sub(r'\b(?:\d{3,}\.\d+ms)', '101.0ms', s)
    return re.sub(r'datetime.date\(', 'date(', s)
def custom_include_print(path: Path, frame: FrameInfo, args: Sequence[Any]) -> bool:
    """Capture a `print` call only if it comes from the example file itself or from this test module."""
    source_file = frame.filename
    if path.samefile(source_file):
        return True
    return source_file.endswith('test_examples.py')
def http_request(url: str, **kwargs: Any) -> httpx.Response:
    """Return a canned 202 response for any mocked `httpx` GET/POST call."""
    mock_request = httpx.Request('GET', url, **kwargs)
    response = httpx.Response(status_code=202, content='', request=mock_request)
    return response
async def async_http_request(url: str, **kwargs: Any) -> httpx.Response:
    """Async wrapper that delegates to the synchronous `http_request` mock."""
    response = http_request(url, **kwargs)
    return response
def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str:
    """Scripted replacement for `rich.prompt.Prompt.ask` used by the docs examples."""
    canned_answers = {
        'Where would you like to fly from and to?': 'SFO to ANC',
        'What seat would you like?': 'window seat with leg room',
        'Insert coins': '1',
        'Select product': 'crisps',
        'What is the capital of France?': 'Vichy',
        'what is 1 + 1?': '2',
    }
    answer = canned_answers.get(prompt)
    if answer is None:  # pragma: no cover
        raise ValueError(f'Unexpected prompt: {prompt}')
    return answer
class MockMCPServer:
    """Stand-in for the MCP server used by docs examples: an async context manager exposing no tools."""

    is_running = True

    async def __aenter__(self) -> MockMCPServer:
        # Nothing to start up; the mock is always "running".
        return self

    async def __aexit__(self, *args: Any) -> None:
        # Nothing to tear down.
        return None

    @staticmethod
    async def list_tools() -> list[None]:
        # The mocked server exposes no tools.
        return []
# Canned model replies keyed by the exact user-prompt text that appears in the docs
# examples. A `str` value makes the mocked model answer with plain text; a
# `ToolCallPart` value makes it call that tool instead. Consumed by `model_logic`
# and `stream_model_logic` below.
text_responses: dict[str, str | ToolCallPart] = {
    'How many days between 2000-01-01 and 2025-03-18?': 'There are 9,208 days between January 1, 2000, and March 18, 2025.',
    'What is the weather like in West London and in Wiltshire?': (
        'The weather in West London is raining, while in Wiltshire it is sunny.'
    ),
    'What will the weather be like in Paris on Tuesday?': ToolCallPart(
        tool_name='weather_forecast', args={'location': 'Paris', 'forecast_date': '2030-01-01'}, tool_call_id='0001'
    ),
    'Tell me a joke.': 'Did you hear about the toothpaste scandal? They called it Colgate.',
    'Tell me a different joke.': 'No.',
    'Explain?': 'This is an excellent joke invented by Samuel Colvin, it needs no explanation.',
    'What is the capital of France?': 'Paris',
    'What is the capital of Italy?': 'Rome',
    'What is the capital of the UK?': 'London',
    'Who was Albert Einstein?': 'Albert Einstein was a German-born theoretical physicist.',
    'What was his most famous equation?': "Albert Einstein's most famous equation is (E = mc^2).",
    'What is the date?': 'Hello Frank, the date today is 2032-01-02.',
    'Put my money on square eighteen': ToolCallPart(
        tool_name='roulette_wheel', args={'square': 18}, tool_call_id='pyd_ai_tool_call_id'
    ),
    'I bet five is the winner': ToolCallPart(
        tool_name='roulette_wheel', args={'square': 5}, tool_call_id='pyd_ai_tool_call_id'
    ),
    'My guess is 6': ToolCallPart(tool_name='roll_die', args={}, tool_call_id='pyd_ai_tool_call_id'),
    'My guess is 4': ToolCallPart(tool_name='roll_die', args={}, tool_call_id='pyd_ai_tool_call_id'),
    'Send a message to John Doe asking for coffee next week': ToolCallPart(
        tool_name='get_user_by_name', args={'name': 'John'}
    ),
    'Please get me the volume of a box with size 6.': ToolCallPart(
        tool_name='calc_volume', args={'size': 6}, tool_call_id='pyd_ai_tool_call_id'
    ),
    'Where does "hello world" come from?': (
        'The first known use of "hello, world" was in a 1974 textbook about the C programming language.'
    ),
    'What is my balance?': ToolCallPart(tool_name='customer_balance', args={'include_pending': True}),
    'I just lost my card!': ToolCallPart(
        tool_name='final_result',
        args={
            'support_advice': (
                "I'm sorry to hear that, John. "
                'We are temporarily blocking your card to prevent unauthorized transactions.'
            ),
            'block_card': True,
            'risk': 8,
        },
    ),
    'Where were the olympics held in 2012?': ToolCallPart(
        tool_name='final_result',
        args={'city': 'London', 'country': 'United Kingdom'},
    ),
    'The box is 10x20x30': 'Please provide the units for the dimensions (e.g., cm, in, m).',
    'The box is 10x20x30 cm': ToolCallPart(
        tool_name='final_result',
        args={'width': 10, 'height': 20, 'depth': 30, 'units': 'cm'},
    ),
    'red square, blue circle, green triangle': ToolCallPart(
        tool_name='final_result_list',
        args={'response': ['red', 'blue', 'green']},
    ),
    'square size 10, circle size 20, triangle size 30': ToolCallPart(
        tool_name='final_result_list_2',
        args={'response': [10, 20, 30]},
    ),
    'get me users who were last active yesterday.': ToolCallPart(
        tool_name='final_result_Success',
        args={'sql_query': 'SELECT * FROM users WHERE last_active::date = today() - interval 1 day'},
    ),
    'My name is Ben, I was born on January 28th 1990, I like the chain the dog and the pyramid.': ToolCallPart(
        tool_name='final_result',
        args={
            'name': 'Ben',
            'dob': '1990-01-28',
            'bio': 'Likes the chain the dog and the pyramid',
        },
        tool_call_id='pyd_ai_tool_call_id',
    ),
    'What is the capital of Italy? Answer with just the city.': 'Rome',
    'What is the capital of Italy? Answer with a paragraph.': (
        'The capital of Italy is Rome (Roma, in Italian), which has been a cultural and political center for centuries.'
        'Rome is known for its rich history, stunning architecture, and delicious cuisine.'
    ),
    'Begin infinite retry loop!': ToolCallPart(
        tool_name='infinite_retry_tool', args={}, tool_call_id='pyd_ai_tool_call_id'
    ),
    'Please generate 5 jokes.': ToolCallPart(
        tool_name='final_result',
        args={'response': []},
        tool_call_id='pyd_ai_tool_call_id',
    ),
    'SFO to ANC': ToolCallPart(
        tool_name='flight_search',
        args={'origin': 'SFO', 'destination': 'ANC'},
        tool_call_id='pyd_ai_tool_call_id',
    ),
    'window seat with leg room': ToolCallPart(
        tool_name='final_result_SeatPreference',
        args={'row': 1, 'seat': 'A'},
        tool_call_id='pyd_ai_tool_call_id',
    ),
    'Ask a simple question with a single correct answer.': 'What is the capital of France?',
    '<examples>\n <question>What is the capital of France?</question>\n <answer>Vichy</answer>\n</examples>': ToolCallPart(
        tool_name='final_result',
        args={'correct': False, 'comment': 'Vichy is no longer the capital of France.'},
        tool_call_id='pyd_ai_tool_call_id',
    ),
    '<examples>\n <question>what is 1 + 1?</question>\n <answer>2</answer>\n</examples>': ToolCallPart(
        tool_name='final_result',
        args={'correct': True, 'comment': 'Well done, 1 + 1 = 2'},
        tool_call_id='pyd_ai_tool_call_id',
    ),
    (
        '<examples>\n'
        ' <dish_name>Spaghetti Bolognese</dish_name>\n'
        ' <dietary_restriction>vegetarian</dietary_restriction>\n'
        '</examples>'
    ): ToolCallPart(
        tool_name='final_result',
        args={
            'ingredients': ['spaghetti', 'tomato sauce', 'vegetarian mince', 'onions', 'garlic'],
            'steps': ['Cook the spaghetti in boiling water', '...'],
        },
    ),
    (
        '<examples>\n'
        ' <dish_name>Chocolate Cake</dish_name>\n'
        ' <dietary_restriction>gluten-free</dietary_restriction>\n'
        '</examples>'
    ): ToolCallPart(
        tool_name='final_result',
        args={
            'ingredients': ['gluten-free flour', 'cocoa powder', 'sugar', 'eggs'],
            'steps': ['Mix the ingredients', 'Bake at 350°F for 30 minutes'],
        },
    ),
}
# Canned follow-up replies keyed by `(tool_name, tool return content)`: after the named
# tool returns the given string, the mocked model answers with the mapped text.
tool_responses: dict[tuple[str, str], str] = {
    (
        'weather_forecast',
        'The forecast in Paris on 2030-01-01 is 24°C and sunny.',
    ): 'It will be warm and sunny in Paris on Tuesday.',
}
async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:  # pragma: no cover # noqa: C901
    """Deterministic model implementation backing the `FunctionModel` used in docs tests.

    Inspects the last part of the last message and returns a scripted `ModelResponse`:
    special-cased prompts first, then the `text_responses` table, then tool-return and
    retry handling. Raises `RuntimeError` on any message it does not recognize.
    """
    # Only the most recent part of the most recent message drives the scripted reply.
    m = messages[-1].parts[-1]
    if isinstance(m, UserPromptPart):
        assert isinstance(m.content, str)
        # Special-cased prompts that need context (available tools, message count, …)
        # are handled before the plain `text_responses` lookup below.
        if m.content == 'Tell me a joke.' and any(t.name == 'joke_factory' for t in info.function_tools):
            return ModelResponse(
                parts=[ToolCallPart(tool_name='joke_factory', args={'count': 5}, tool_call_id='pyd_ai_tool_call_id')]
            )
        elif m.content == 'Please generate 5 jokes.' and any(t.name == 'get_jokes' for t in info.function_tools):
            return ModelResponse(
                parts=[ToolCallPart(tool_name='get_jokes', args={'count': 5}, tool_call_id='pyd_ai_tool_call_id')]
            )
        elif re.fullmatch(r'sql prompt \d+', m.content):
            return ModelResponse(parts=[TextPart('SELECT 1')])
        elif m.content.startswith('Write a welcome email for the user:'):
            return ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name='final_result',
                        args={
                            'subject': 'Welcome to our tech blog!',
                            'body': 'Hello John, Welcome to our tech blog! ...',
                        },
                        tool_call_id='pyd_ai_tool_call_id',
                    )
                ]
            )
        elif m.content.startswith('Write a list of 5 very rude things that I might say'):
            # Simulates a provider-side safety refusal.
            raise UnexpectedModelBehavior('Safety settings triggered', body='<safety settings details>')
        elif m.content.startswith('<examples>\n <user>'):
            return ModelResponse(
                parts=[ToolCallPart(tool_name='final_result_EmailOk', args={}, tool_call_id='pyd_ai_tool_call_id')]
            )
        elif m.content == 'Ask a simple question with a single correct answer.' and len(messages) > 2:
            # Later turns of the question-asking example get a different question.
            return ModelResponse(parts=[TextPart('what is 1 + 1?')])
        elif '<Rubric>\n' in m.content:
            return ModelResponse(
                parts=[ToolCallPart(tool_name='final_result', args={'reason': '-', 'pass': True, 'score': 1.0})]
            )
        elif 'Generate question-answer pairs about world capitals and landmarks.' in m.content:
            # Scripted dataset-generation output, serialized as JSON text.
            return ModelResponse(
                parts=[
                    TextPart(
                        content=json.dumps(
                            {
                                'cases': [
                                    {
                                        'name': 'Easy Capital Question',
                                        'inputs': {'question': 'What is the capital of France?'},
                                        'metadata': {'difficulty': 'easy', 'category': 'Geography'},
                                        'expected_output': {'answer': 'Paris', 'confidence': 0.95},
                                        'evaluators': ['EqualsExpected'],
                                    },
                                    {
                                        'name': 'Challenging Landmark Question',
                                        'inputs': {
                                            'question': 'Which world-famous landmark is located on the banks of the Seine River?',
                                        },
                                        'metadata': {'difficulty': 'hard', 'category': 'Landmarks'},
                                        'expected_output': {'answer': 'Eiffel Tower', 'confidence': 0.9},
                                        'evaluators': ['EqualsExpected'],
                                    },
                                ],
                                'evaluators': [],
                            }
                        )
                    )
                ]
            )
        elif response := text_responses.get(m.content):
            # Fallback: plain lookup in the canned-responses table.
            if isinstance(response, str):
                return ModelResponse(parts=[TextPart(response)])
            else:
                return ModelResponse(parts=[response])

    elif isinstance(m, ToolReturnPart) and m.tool_name == 'roulette_wheel':
        win = m.content == 'winner'
        return ModelResponse(
            parts=[ToolCallPart(tool_name='final_result', args={'response': win}, tool_call_id='pyd_ai_tool_call_id')]
        )
    elif isinstance(m, ToolReturnPart) and m.tool_name == 'roll_die':
        return ModelResponse(
            parts=[ToolCallPart(tool_name='get_player_name', args={}, tool_call_id='pyd_ai_tool_call_id')]
        )
    elif isinstance(m, ToolReturnPart) and m.tool_name == 'get_player_name':
        # NOTE: this branch can fall through (no matching name) into the chain below.
        if 'Anne' in m.content:
            return ModelResponse(parts=[TextPart("Congratulations Anne, you guessed correctly! You're a winner!")])
        elif 'Yashar' in m.content:
            return ModelResponse(parts=[TextPart('Tough luck, Yashar, you rolled a 4. Better luck next time.')])
    # A fresh `if` (not `elif`): the branches above all return, so behavior is the
    # same as continuing the chain, and unmatched cases above fall through to here.
    if (
        isinstance(m, RetryPromptPart)
        and isinstance(m.content, str)
        and m.content.startswith("No user found with name 'Joh")
    ):
        return ModelResponse(
            parts=[
                ToolCallPart(
                    tool_name='get_user_by_name', args={'name': 'John Doe'}, tool_call_id='pyd_ai_tool_call_id'
                )
            ]
        )
    elif isinstance(m, RetryPromptPart) and m.tool_name == 'infinite_retry_tool':
        return ModelResponse(
            parts=[ToolCallPart(tool_name='infinite_retry_tool', args={}, tool_call_id='pyd_ai_tool_call_id')]
        )
    elif isinstance(m, ToolReturnPart) and m.tool_name == 'get_user_by_name':
        args: dict[str, Any] = {
            'message': 'Hello John, would you be free for coffee sometime next week? Let me know what works for you!',
            'user_id': 123,
        }
        return ModelResponse(
            parts=[ToolCallPart(tool_name='final_result', args=args, tool_call_id='pyd_ai_tool_call_id')]
        )
    elif isinstance(m, RetryPromptPart) and m.tool_name == 'calc_volume':
        return ModelResponse(
            parts=[ToolCallPart(tool_name='calc_volume', args={'size': 6}, tool_call_id='pyd_ai_tool_call_id')]
        )
    elif isinstance(m, ToolReturnPart) and m.tool_name == 'customer_balance':
        args = {
            'support_advice': 'Hello John, your current account balance, including pending transactions, is $123.45.',
            'block_card': False,
            'risk': 1,
        }
        return ModelResponse(
            parts=[ToolCallPart(tool_name='final_result', args=args, tool_call_id='pyd_ai_tool_call_id')]
        )
    elif isinstance(m, ToolReturnPart) and m.tool_name == 'joke_factory':
        return ModelResponse(parts=[TextPart('Did you hear about the toothpaste scandal? They called it Colgate.')])
    elif isinstance(m, ToolReturnPart) and m.tool_name == 'get_jokes':
        args = {'response': []}
        return ModelResponse(
            parts=[ToolCallPart(tool_name='final_result', args=args, tool_call_id='pyd_ai_tool_call_id')]
        )
    elif isinstance(m, ToolReturnPart) and m.tool_name == 'flight_search':
        args = {'flight_number': m.content.flight_number}  # type: ignore
        return ModelResponse(
            parts=[ToolCallPart(tool_name='final_result_FlightDetails', args=args, tool_call_id='pyd_ai_tool_call_id')]
        )
    else:
        # Unrecognized message: dump it for debugging and fail the example loudly.
        sys.stdout.write(str(debug.format(messages, info)))
        raise RuntimeError(f'Unexpected message: {m}')
async def stream_model_logic(  # noqa C901
    messages: list[ModelMessage], info: AgentInfo
) -> AsyncIterator[str | DeltaToolCalls]:  # pragma: no cover
    """Streaming counterpart of `model_logic`: yields canned responses in small chunks.

    Text responses stream three words at a time; tool calls stream their JSON args in
    15-character slices. Raises `RuntimeError` for messages with no canned response.
    """

    async def stream_text_response(r: str) -> AsyncIterator[str]:
        # Emit the text three words per chunk, then any remainder.
        if isinstance(r, str):
            words = r.split(' ')
            chunk: list[str] = []
            for word in words:
                chunk.append(word)
                if len(chunk) == 3:
                    yield ' '.join(chunk) + ' '
                    chunk.clear()
            if chunk:
                yield ' '.join(chunk)

    async def stream_tool_call_response(r: ToolCallPart) -> AsyncIterator[DeltaToolCalls]:
        json_text = r.args_as_json_str()

        # First delta carries the tool name/id, then the JSON args in 15-char slices.
        yield {1: DeltaToolCall(name=r.tool_name, tool_call_id=r.tool_call_id)}
        for chunk_index in range(0, len(json_text), 15):
            text_chunk = json_text[chunk_index : chunk_index + 15]
            yield {1: DeltaToolCall(json_args=text_chunk)}

    async def stream_part_response(r: str | ToolCallPart) -> AsyncIterator[str | DeltaToolCalls]:
        # Dispatch on the canned-response type: plain text vs. tool call.
        if isinstance(r, str):
            async for chunk in stream_text_response(r):
                yield chunk
        else:
            async for chunk in stream_tool_call_response(r):
                yield chunk

    # As in `model_logic`, only the last part of the last message is inspected.
    last_part = messages[-1].parts[-1]
    if isinstance(last_part, UserPromptPart):
        assert isinstance(last_part.content, str)
        if response := text_responses.get(last_part.content):
            async for chunk in stream_part_response(response):
                yield chunk
            return
    elif isinstance(last_part, ToolReturnPart):
        assert isinstance(last_part.content, str)
        if response := tool_responses.get((last_part.tool_name, last_part.content)):
            async for chunk in stream_part_response(response):
                yield chunk
            return

    # No canned response matched: dump for debugging and fail loudly.
    sys.stdout.write(str(debug.format(messages, info)))
    raise RuntimeError(f'Unexpected message: {last_part}')
def mock_infer_model(model: Model | KnownModelName) -> Model:
    """Replacement for `pydantic_ai.models.infer_model` used by the docs tests.

    Maps any requested model to a deterministic stand-in: `TestModel` for 'test',
    a `FunctionModel` driven by `model_logic`/`stream_model_logic` otherwise, and a
    `FallbackModel` whose OpenAI members always raise `ModelHTTPError`.
    """
    if model == 'test':
        return TestModel()

    if isinstance(model, str):
        # Use the non-mocked model inference to ensure we get the same model name the user would
        model = infer_model(model)

    if isinstance(model, FallbackModel):
        # When a fallback model is encountered, replace any OpenAIModel with a model that will raise a ModelHTTPError.
        # Otherwise, do the usual inference.
        def raise_http_error(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:  # pragma: no cover
            raise ModelHTTPError(401, 'Invalid API Key')

        mock_fallback_models: list[Model] = []
        for m in model.models:
            # OpenAIModel is an optional extra; fall back to a never-matching type when absent.
            try:
                from pydantic_ai.models.openai import OpenAIModel
            except ImportError:
                OpenAIModel = type(None)

            if isinstance(m, OpenAIModel):
                # Raise an HTTP error for OpenAIModel
                mock_fallback_models.append(FunctionModel(raise_http_error, model_name=m.model_name))
            else:
                # Recurse so nested members get the same mocking treatment.
                mock_fallback_models.append(mock_infer_model(m))
        return FallbackModel(*mock_fallback_models)
    if isinstance(model, (FunctionModel, TestModel)):
        # Already a deterministic model: pass through unchanged.
        return model
    else:
        model_name = model if isinstance(model, str) else model.model_name
        return FunctionModel(model_logic, stream_function=stream_model_logic, model_name=model_name)
def mock_group_by_temporal(aiter: Any, soft_max_interval: float | None) -> Any:
    """Mock group_by_temporal to avoid debouncing, since the iterators above have no delay."""
    # Passing None disables time-based grouping; the requested `soft_max_interval` is ignored.
    return group_by_temporal(aiter, None)
@dataclass
class MockCredentials:
    """Minimal stand-in for Google credentials used when mocking the Vertex provider."""

    # NOTE: no type annotation, so this is a plain class attribute rather than a
    # dataclass field — `MockCredentials()` takes no constructor arguments.
    project_id = 'foobar'