Coverage for pydantic_ai_slim/pydantic_ai/models/test.py: 99.21%
258 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1from __future__ import annotations as _annotations
3import re
4import string
5from collections.abc import AsyncIterator, Iterable
6from contextlib import asynccontextmanager
7from dataclasses import InitVar, dataclass, field
8from datetime import date, datetime, timedelta
9from typing import Any, Literal
11import pydantic_core
13from .. import _utils
14from ..messages import (
15 ModelMessage,
16 ModelRequest,
17 ModelResponse,
18 ModelResponsePart,
19 ModelResponseStreamEvent,
20 RetryPromptPart,
21 TextPart,
22 ToolCallPart,
23 ToolReturnPart,
24)
25from ..result import Usage
26from ..settings import ModelSettings
27from ..tools import ToolDefinition
28from . import (
29 Model,
30 ModelRequestParameters,
31 StreamedResponse,
32)
33from .function import _estimate_string_tokens, _estimate_usage # pyright: ignore[reportPrivateUsage]
@dataclass
class _TextResult:
    """Private marker for a plain-text final result.

    Wraps `TestModel.custom_result_text`. `value` is `None` when no custom text was configured,
    in which case `TestModel` builds its own text response.
    """

    value: str | None  # custom text returned verbatim, or None to let the model build one
@dataclass
class _FunctionToolResult:
    """Private marker for a final result delivered by calling a result tool.

    Wraps the (possibly key-wrapped) `TestModel.custom_result_args`. `value` is `None` when no
    custom args were configured, in which case the tool arguments are generated from the tool's
    JSON schema.
    """

    value: Any | None  # arguments sent to the result tool as-is, or None to generate them
@dataclass
class TestModel(Model):
    """A model specifically for testing purposes.

    This will (by default) call all tools in the agent, then return a tool response if possible,
    otherwise a plain response.

    How useful this model is will vary significantly.

    Apart from `__init__` derived by the `dataclass` decorator, all methods are private or match those
    of the base class.
    """

    # NOTE: Avoid test discovery by pytest.
    __test__ = False

    call_tools: list[str] | Literal['all'] = 'all'
    """List of tools to call. If `'all'`, all tools will be called."""
    custom_result_text: str | None = None
    """If set, this text is returned as the final result."""
    custom_result_args: Any | None = None
    """If set, these args will be passed to the result tool."""
    seed: int = 0
    """Seed for generating test data; it also selects which result tool is called."""
    last_model_request_parameters: ModelRequestParameters | None = field(default=None, init=False)
    """The last ModelRequestParameters passed to the model in a request.

    The ModelRequestParameters contains information about the function and result tools available during request handling.

    This is set when a request is made, so will reflect the function tools from the last step of the last run.
    """
    _model_name: str = field(default='test', repr=False)
    _system: str = field(default='test', repr=False)

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, Usage]:
        """Build a canned response for `messages` and estimate usage over the whole exchange."""
        self.last_model_request_parameters = model_request_parameters

        model_response = self._request(messages, model_settings, model_request_parameters)
        usage = _estimate_usage([*messages, model_response])
        return model_response, usage

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        """Build the same canned response as `request` and replay it as a `TestStreamedResponse`."""
        self.last_model_request_parameters = model_request_parameters

        model_response = self._request(messages, model_settings, model_request_parameters)
        yield TestStreamedResponse(
            _model_name=self._model_name, _structured_response=model_response, _messages=messages
        )

    @property
    def model_name(self) -> str:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The system / model provider."""
        return self._system

    def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
        """Generate arguments for `tool_def` from its parameters JSON schema, driven by `self.seed`."""
        return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()

    def _get_tool_calls(self, model_request_parameters: ModelRequestParameters) -> list[tuple[str, ToolDefinition]]:
        """Return `(name, definition)` pairs for the function tools this model should call.

        Raises:
            KeyError: if `call_tools` names a tool that is not among the request's function tools.
        """
        if self.call_tools == 'all':
            return [(r.name, r) for r in model_request_parameters.function_tools]
        else:
            function_tools_lookup = {t.name: t for t in model_request_parameters.function_tools}
            tools_to_call = (function_tools_lookup[name] for name in self.call_tools)
            return [(r.name, r) for r in tools_to_call]

    def _select_result_tool(self, result_tools: list[ToolDefinition]) -> ToolDefinition:
        """Pick the result tool to call, deterministically from `seed`.

        Both `_get_result` and `_request` use this helper. Any custom args are therefore shaped
        (e.g. wrapped in `outer_typed_dict_key`) for the same tool that is actually called.
        """
        return result_tools[self.seed % len(result_tools)]

    def _get_result(self, model_request_parameters: ModelRequestParameters) -> _TextResult | _FunctionToolResult:
        """Decide what kind of final result this model will produce for the request.

        Raises:
            AssertionError: if the configured custom result conflicts with what the request allows.
        """
        if self.custom_result_text is not None:
            assert model_request_parameters.allow_text_result, (
                'Plain response not allowed, but `custom_result_text` is set.'
            )
            assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
            return _TextResult(self.custom_result_text)
        elif self.custom_result_args is not None:
            # Check truthiness, not `is not None`. An empty list of result tools must be rejected
            # here with a clear message, not surface later as an IndexError when a tool is selected.
            assert model_request_parameters.result_tools, (
                'No result tools provided, but `custom_result_args` is set.'
            )
            result_tool = self._select_result_tool(model_request_parameters.result_tools)

            # When the tool declares an outer key, the args must be nested under it.
            if k := result_tool.outer_typed_dict_key:
                return _FunctionToolResult({k: self.custom_result_args})
            else:
                return _FunctionToolResult(self.custom_result_args)
        elif model_request_parameters.allow_text_result:
            return _TextResult(None)
        elif model_request_parameters.result_tools:
            return _FunctionToolResult(None)
        else:
            return _TextResult(None)

    def _request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        """Build the next canned `ModelResponse` for the conversation in `messages`.

        In order of precedence:

        1. If there is no `ModelResponse` in `messages` yet, call every selected function tool.
        2. If the last request contains retry prompts, re-issue calls to the tools being retried.
        3. Otherwise, return the final result. This is either text (custom, or a JSON summary of
           the tool returns seen so far) or a call to a result tool (with custom or generated args).
        """
        tool_calls = self._get_tool_calls(model_request_parameters)
        result = self._get_result(model_request_parameters)
        result_tools = model_request_parameters.result_tools

        # if there are tools, the first thing we want to do is call all of them
        if tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
            return ModelResponse(
                parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls],
                model_name=self._model_name,
            )

        if messages:
            last_message = messages[-1]
            assert isinstance(last_message, ModelRequest), 'Expected last message to be a `ModelRequest`.'

            # check if there are any retry prompts, if so retry them
            new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
            if new_retry_names:
                # Handle retries for both function tools and result tools
                # Check function tools first
                retry_parts: list[ModelResponsePart] = [
                    ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls if name in new_retry_names
                ]
                # Check result tools
                if result_tools:
                    retry_parts.extend(
                        [
                            ToolCallPart(
                                tool.name,
                                result.value
                                if isinstance(result, _FunctionToolResult) and result.value is not None
                                else self.gen_tool_args(tool),
                            )
                            for tool in result_tools
                            if tool.name in new_retry_names
                        ]
                    )
                return ModelResponse(parts=retry_parts, model_name=self._model_name)

        if isinstance(result, _TextResult):
            if (response_text := result.value) is None:
                # build up details of tool responses; a later return for the same tool name wins
                output: dict[str, Any] = {}
                for message in messages:
                    if isinstance(message, ModelRequest):
                        for part in message.parts:
                            if isinstance(part, ToolReturnPart):
                                output[part.tool_name] = part.content
                if output:
                    return ModelResponse(
                        parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self._model_name
                    )
                else:
                    return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self._model_name)
            else:
                return ModelResponse(parts=[TextPart(response_text)], model_name=self._model_name)
        else:
            assert result_tools, 'No result tools provided'
            custom_result_args = result.value
            result_tool = self._select_result_tool(result_tools)
            if custom_result_args is not None:
                return ModelResponse(
                    parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self._model_name
                )
            else:
                response_args = self.gen_tool_args(result_tool)
                return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self._model_name)
@dataclass
class TestStreamedResponse(StreamedResponse):
    """A structured response that streams test data.

    The whole response is built up-front (`_structured_response`) and replayed part by part.
    Text parts are split into word-sized text deltas. Every other part is treated as a tool call
    and emitted as a single event.
    """

    # NOTE: Avoid test discovery by pytest, as `TestModel` does. The `Test` name prefix would
    # otherwise make pytest consider this class when it is imported into a test module.
    __test__ = False

    _model_name: str
    _structured_response: ModelResponse
    _messages: InitVar[Iterable[ModelMessage]]
    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

    def __post_init__(self, _messages: Iterable[ModelMessage]):
        # Start from the usage estimated for the request messages; streamed text adds to it below.
        self._usage = _estimate_usage(_messages)

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        for i, part in enumerate(self._structured_response.parts):
            if isinstance(part, TextPart):
                text = part.content
                # Split on single spaces and keep the separator attached to each word, so that
                # concatenating the deltas reproduces `text` exactly.
                *words, last_word = text.split(' ')
                words = [f'{word} ' for word in words]
                words.append(last_word)
                # A single "word" longer than two characters is cut in half, so the text still
                # arrives in more than one delta.
                if len(words) == 1 and len(text) > 2:
                    mid = len(text) // 2
                    words = [text[:mid], text[mid:]]
                # An empty delta is emitted first for the part, then one delta per word.
                self._usage += _get_string_usage('')
                yield self._parts_manager.handle_text_delta(vendor_part_id=i, content='')
                for word in words:
                    self._usage += _get_string_usage(word)
                    yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=word)
            else:
                yield self._parts_manager.handle_tool_call_part(
                    vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
                )

    @property
    def model_name(self) -> str:
        """Get the model name of the response."""
        return self._model_name

    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._timestamp
# Alphabet for `_JsonSchemaTestData._char`: ASCII letters, then digits, then ASCII punctuation.
_chars = ''.join((string.ascii_letters, string.digits, string.punctuation))
class _JsonSchemaTestData:
    """Generate data that matches a JSON schema.

    This tries to generate the minimal viable data for the schema. Every choice is derived from
    `seed`: which `enum` member, `examples` entry, `anyOf` branch or list-valued `type` to use,
    and the scalar values themselves. A given schema and seed therefore always yield the same data.
    """

    def __init__(self, schema: _utils.ObjectJsonSchema, seed: int = 0):
        self.schema = schema
        # Definitions that `$ref` pointers of the form `#/$defs/<name>` are resolved against.
        self.defs = schema.get('$defs', {})
        # NOTE: `_array_gen` advances this when `uniqueItems` is set. That also affects any data
        # this instance generates afterwards.
        self.seed = seed

    def generate(self) -> Any:
        """Generate data for the JSON schema."""
        return self._gen_any(self.schema)

    def _gen_any(self, schema: dict[str, Any]) -> Any:
        """Generate data for any JSON Schema.

        Raises:
            NotImplementedError: if the schema's `type` is not a recognised JSON Schema type.
        """
        # Test for the key, not the value. `const` may be falsy (`0`, `False`, `''`, `None`),
        # e.g. for `Literal[0]`, and must still be returned as-is.
        if 'const' in schema:
            return schema['const']
        elif enum := schema.get('enum'):
            return enum[self.seed % len(enum)]
        elif examples := schema.get('examples'):
            return examples[self.seed % len(examples)]
        elif ref := schema.get('$ref'):
            key = re.sub(r'^#/\$defs/', '', ref)
            js_def = self.defs[key]
            return self._gen_any(js_def)
        elif any_of := schema.get('anyOf'):
            return self._gen_any(any_of[self.seed % len(any_of)])

        type_ = schema.get('type')
        if isinstance(type_, list):
            # JSON Schema allows `type` to be a list of type names; pick one the same way `anyOf` does.
            type_ = type_[self.seed % len(type_)] if type_ else None

        if type_ is None:
            # with no type or ref there is nothing to go on, so fall back to a short string
            return self._char()
        elif type_ == 'object':
            return self._object_gen(schema)
        elif type_ == 'string':
            return self._str_gen(schema)
        elif type_ == 'integer':
            return self._int_gen(schema)
        elif type_ == 'number':
            return float(self._int_gen(schema))
        elif type_ == 'boolean':
            return self._bool_gen()
        elif type_ == 'array':
            return self._array_gen(schema)
        elif type_ == 'null':
            return None
        else:
            raise NotImplementedError(f'Unknown type: {type_}, please submit a PR to extend JsonSchemaTestData!')

    def _object_gen(self, schema: dict[str, Any]) -> dict[str, Any]:
        """Generate data for a JSON Schema object.

        Only `required` properties are filled in. When `additionalProperties` is truthy, one extra
        key is added as well.
        """
        required = set(schema.get('required', []))

        data: dict[str, Any] = {}
        if properties := schema.get('properties'):
            for key, value in properties.items():
                if key in required:
                    data[key] = self._gen_any(value)

        if addition_props := schema.get('additionalProperties'):
            # pick a key name that cannot clash with a declared property
            add_prop_key = 'additionalProperty'
            while add_prop_key in data:
                add_prop_key += '_'
            if addition_props is True:
                data[add_prop_key] = self._char()
            else:
                data[add_prop_key] = self._gen_any(addition_props)

        return data

    def _str_gen(self, schema: dict[str, Any]) -> str:
        """Generate a string from a JSON Schema string."""
        min_len = schema.get('minLength')
        max_len = schema.get('maxLength')
        if min_len is not None:
            value = self._char() * min_len
        elif max_len == 0:
            return ''
        elif schema.get('format') == 'date':
            return (date(2024, 1, 1) + timedelta(days=self.seed)).isoformat()
        else:
            value = self._char()

        # `_char()` yields more than one character for larger seeds, so clip to `maxLength`.
        # For a consistent schema (minLength <= maxLength) the result still meets `minLength`.
        if max_len is not None:
            value = value[:max_len]
        return value

    def _int_gen(self, schema: dict[str, Any]) -> int:
        """Generate an integer from a JSON Schema integer.

        This is also used (via `float(...)`) for `number` schemas, so the bounds may be floats.
        """
        maximum = schema.get('maximum')
        if maximum is None:
            exc_max = schema.get('exclusiveMaximum')
            if exc_max is not None:
                maximum = exc_max - 1

        minimum = schema.get('minimum')
        if minimum is None:
            exc_min = schema.get('exclusiveMinimum')
            if exc_min is not None:
                minimum = exc_min + 1

        if minimum is not None and maximum is not None:
            span = maximum - minimum
            if span <= 0:
                # A single allowed value (e.g. `ge=5, le=5`) or contradictory bounds:
                # `seed % 0` would raise ZeroDivisionError.
                return minimum
            return minimum + self.seed % span
        elif minimum is not None:
            return minimum + self.seed
        elif maximum is not None:
            return maximum - self.seed
        else:
            return self.seed

    def _bool_gen(self) -> bool:
        """Generate a boolean from a JSON Schema boolean."""
        return bool(self.seed % 2)

    def _array_gen(self, schema: dict[str, Any]) -> list[Any]:
        """Generate an array from a JSON Schema array.

        Emits every `prefixItems` entry, then pads with `items` up to `minItems`. If no padding is
        needed, one `items` element is added, unless that would exceed `maxItems`.
        """
        data: list[Any] = []
        unique_items = schema.get('uniqueItems')
        if prefix_items := schema.get('prefixItems'):
            for item in prefix_items:
                data.append(self._gen_any(item))
                if unique_items:
                    # advance the seed so the next item is generated from a different seed
                    self.seed += 1

        items_schema = schema.get('items', {})
        min_items = schema.get('minItems', 0)
        if min_items > len(data):
            for _ in range(min_items - len(data)):
                data.append(self._gen_any(items_schema))
                if unique_items:
                    self.seed += 1
        elif items_schema:
            # if there is an `items` schema, add an item unless it would break `maxItems` rule
            max_items = schema.get('maxItems')
            if max_items is None or max_items > len(data):
                data.append(self._gen_any(items_schema))
                if unique_items:
                    self.seed += 1

        return data

    def _char(self) -> str:
        """Generate a short string from `self.seed`.

        Small seeds map to a single character of `_chars`. Seeds >= `len(_chars)` get extra
        leading characters, loosely like spreadsheet column names (a, b, ..., aa, ab, ...).
        """
        chars = len(_chars)
        s = ''
        rem = self.seed // chars
        while rem > 0:
            s += _chars[(rem - 1) % chars]
            rem //= chars
        s += _chars[self.seed % chars]
        return s
def _get_string_usage(text: str) -> Usage:
    """Return a `Usage` that counts the estimated tokens of `text` as response tokens only."""
    token_count = _estimate_string_tokens(text)
    return Usage(total_tokens=token_count, response_tokens=token_count)