Coverage for pydantic_ai_slim/pydantic_ai/models/test.py: 99.21%
257 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
1from __future__ import annotations as _annotations
3import re
4import string
5from collections.abc import AsyncIterator, Iterable
6from contextlib import asynccontextmanager
7from dataclasses import InitVar, dataclass, field
8from datetime import date, datetime, timedelta
9from typing import Any, Literal
11import pydantic_core
13from .. import _utils
14from ..messages import (
15 ModelMessage,
16 ModelRequest,
17 ModelResponse,
18 ModelResponsePart,
19 ModelResponseStreamEvent,
20 RetryPromptPart,
21 TextPart,
22 ToolCallPart,
23 ToolReturnPart,
24)
25from ..result import Usage
26from ..settings import ModelSettings
27from ..tools import ToolDefinition
28from . import (
29 AgentModel,
30 Model,
31 StreamedResponse,
32)
33from .function import _estimate_string_tokens, _estimate_usage # pyright: ignore[reportPrivateUsage]
36@dataclass
37class _TextResult:
38 """A private wrapper class to tag a result that came from the custom_result_text field."""
40 value: str | None
43@dataclass
44class _FunctionToolResult:
45 """A wrapper class to tag a result that came from the custom_result_args field."""
47 value: Any | None
@dataclass
class TestModel(Model):
    """A model intended purely for writing tests against.

    By default it calls every available function tool, then returns a result-tool
    call if one is required, falling back to a plain-text response otherwise.

    How useful this model is will vary significantly.

    Apart from `__init__` derived by the `dataclass` decorator, all methods are private or match those
    of the base class.
    """

    # NOTE: Avoid test discovery by pytest.
    __test__ = False

    call_tools: list[str] | Literal['all'] = 'all'
    """Names of the tools to call; `'all'` means every tool is called."""
    custom_result_text: str | None = None
    """When set, exactly this text is returned as the final result."""
    custom_result_args: Any | None = None
    """When set, the result tool is invoked with these arguments."""
    seed: int = 0
    """Seed used when generating pseudo-random data."""
    agent_model_function_tools: list[ToolDefinition] | None = field(default=None, init=False)
    """Definition of function tools passed to the model.

    Captured on each call, so reflects the function tools from the last step of the last run.
    """
    agent_model_allow_text_result: bool | None = field(default=None, init=False)
    """Whether plain text responses from the model are allowed.

    Captured on each call, so reflects the value from the last step of the last run.
    """
    agent_model_result_tools: list[ToolDefinition] | None = field(default=None, init=False)
    """Definition of result tools passed to the model.

    Captured on each call, so reflects the result tools from the last step of the last run.
    """

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        # Record the call arguments so tests can inspect them afterwards.
        self.agent_model_function_tools = function_tools
        self.agent_model_allow_text_result = allow_text_result
        self.agent_model_result_tools = result_tools

        if self.call_tools == 'all':
            tool_calls = [(t.name, t) for t in function_tools]
        else:
            # A missing name raises KeyError, surfacing a bad `call_tools` configuration.
            by_name = {t.name: t for t in function_tools}
            tool_calls = [(name, by_name[name]) for name in self.call_tools]

        # Decide how the final result will be produced: custom text, custom
        # result-tool args, or generated data.
        result: _TextResult | _FunctionToolResult
        if self.custom_result_text is not None:
            assert allow_text_result, 'Plain response not allowed, but `custom_result_text` is set.'
            assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
            result = _TextResult(self.custom_result_text)
        elif self.custom_result_args is not None:
            assert result_tools is not None, 'No result tools provided, but `custom_result_args` is set.'
            result_tool = result_tools[0]

            # Wrap the args when the result tool expects an outer TypedDict key.
            if k := result_tool.outer_typed_dict_key:
                result = _FunctionToolResult({k: self.custom_result_args})
            else:
                result = _FunctionToolResult(self.custom_result_args)
        elif allow_text_result:
            result = _TextResult(None)
        elif result_tools:
            result = _FunctionToolResult(None)
        else:
            result = _TextResult(None)

        return TestAgentModel(tool_calls, result, result_tools, self.seed)

    def name(self) -> str:
        return 'test-model'
@dataclass
class TestAgentModel(AgentModel):
    """`AgentModel` implementation backing `TestModel`."""

    # NOTE: Avoid test discovery by pytest.
    __test__ = False

    tool_calls: list[tuple[str, ToolDefinition]]
    # left means the text is plain text; right means it's a function call
    result: _TextResult | _FunctionToolResult
    result_tools: list[ToolDefinition]
    seed: int
    model_name: str = 'test'

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, Usage]:
        response = self._request(messages, model_settings)
        return response, _estimate_usage([*messages, response])

    @asynccontextmanager
    async def request_stream(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> AsyncIterator[StreamedResponse]:
        response = self._request(messages, model_settings)
        yield TestStreamedResponse(_model_name=self.model_name, _structured_response=response, _messages=messages)

    def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
        """Generate arguments matching the tool's parameter JSON schema."""
        return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()

    def _request(self, messages: list[ModelMessage], model_settings: ModelSettings | None) -> ModelResponse:
        # First request of a run (no model responses yet): call every configured tool.
        if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
            parts = [ToolCallPart(name, self.gen_tool_args(tool)) for name, tool in self.tool_calls]
            return ModelResponse(parts=parts, model_name=self.model_name)

        if messages:
            last_message = messages[-1]
            assert isinstance(last_message, ModelRequest), 'Expected last message to be a `ModelRequest`.'

            # Any retry prompts in the last request mean those tools must be called again.
            retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
            if retry_names:
                # Handle retries for both function tools and result tools.
                # Function tools first:
                retry_parts: list[ModelResponsePart] = [
                    ToolCallPart(name, self.gen_tool_args(tool))
                    for name, tool in self.tool_calls
                    if name in retry_names
                ]
                # Then result tools, reusing custom args when available:
                if self.result_tools:
                    for tool in self.result_tools:
                        if tool.name not in retry_names:
                            continue
                        if isinstance(self.result, _FunctionToolResult) and self.result.value is not None:
                            args = self.result.value
                        else:
                            args = self.gen_tool_args(tool)
                        retry_parts.append(ToolCallPart(tool.name, args))
                return ModelResponse(parts=retry_parts, model_name=self.model_name)

        if isinstance(self.result, _TextResult):
            if (response_text := self.result.value) is None:
                # No custom text: summarize every tool return value as a JSON object.
                output: dict[str, Any] = {
                    part.tool_name: part.content
                    for message in messages
                    if isinstance(message, ModelRequest)
                    for part in message.parts
                    if isinstance(part, ToolReturnPart)
                }
                if output:
                    return ModelResponse(
                        parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self.model_name
                    )
                else:
                    return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self.model_name)
            else:
                return ModelResponse(parts=[TextPart(response_text)], model_name=self.model_name)
        else:
            assert self.result_tools, 'No result tools provided'
            custom_result_args = self.result.value
            # The seed picks which result tool to call when several are available.
            result_tool = self.result_tools[self.seed % len(self.result_tools)]
            if custom_result_args is not None:
                return ModelResponse(
                    parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self.model_name
                )
            else:
                generated = self.gen_tool_args(result_tool)
                return ModelResponse(parts=[ToolCallPart(result_tool.name, generated)], model_name=self.model_name)
@dataclass
class TestStreamedResponse(StreamedResponse):
    """A streamed response that replays a pre-built `ModelResponse` chunk by chunk."""

    _structured_response: ModelResponse
    _messages: InitVar[Iterable[ModelMessage]]

    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

    def __post_init__(self, _messages: Iterable[ModelMessage]):
        self._usage = _estimate_usage(_messages)

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        for i, part in enumerate(self._structured_response.parts):
            if not isinstance(part, TextPart):
                # Non-text parts (tool calls) are emitted whole.
                yield self._parts_manager.handle_tool_call_part(
                    vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
                )
                continue

            text = part.content
            # Stream word by word, keeping each separating space attached to the word before it.
            *leading, final = text.split(' ')
            chunks = [f'{w} ' for w in leading]
            chunks.append(final)
            if len(chunks) == 1 and len(text) > 2:
                # A single word: split it in half so at least two deltas are streamed.
                half = len(text) // 2
                chunks = [text[:half], text[half:]]
            self._usage += _get_string_usage('')
            yield self._parts_manager.handle_text_delta(vendor_part_id=i, content='')
            for chunk in chunks:
                self._usage += _get_string_usage(chunk)
                yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=chunk)

    def timestamp(self) -> datetime:
        return self._timestamp
268_chars = string.ascii_letters + string.digits + string.punctuation
271class _JsonSchemaTestData:
272 """Generate data that matches a JSON schema.
274 This tries to generate the minimal viable data for the schema.
275 """
277 def __init__(self, schema: _utils.ObjectJsonSchema, seed: int = 0):
278 self.schema = schema
279 self.defs = schema.get('$defs', {})
280 self.seed = seed
282 def generate(self) -> Any:
283 """Generate data for the JSON schema."""
284 return self._gen_any(self.schema)
286 def _gen_any(self, schema: dict[str, Any]) -> Any:
287 """Generate data for any JSON Schema."""
288 if const := schema.get('const'):
289 return const
290 elif enum := schema.get('enum'):
291 return enum[self.seed % len(enum)]
292 elif examples := schema.get('examples'):
293 return examples[self.seed % len(examples)]
294 elif ref := schema.get('$ref'):
295 key = re.sub(r'^#/\$defs/', '', ref)
296 js_def = self.defs[key]
297 return self._gen_any(js_def)
298 elif any_of := schema.get('anyOf'):
299 return self._gen_any(any_of[self.seed % len(any_of)])
301 type_ = schema.get('type')
302 if type_ is None:
303 # if there's no type or ref, we can't generate anything
304 return self._char()
305 elif type_ == 'object':
306 return self._object_gen(schema)
307 elif type_ == 'string':
308 return self._str_gen(schema)
309 elif type_ == 'integer':
310 return self._int_gen(schema)
311 elif type_ == 'number':
312 return float(self._int_gen(schema))
313 elif type_ == 'boolean':
314 return self._bool_gen()
315 elif type_ == 'array':
316 return self._array_gen(schema)
317 elif type_ == 'null':
318 return None
319 else:
320 raise NotImplementedError(f'Unknown type: {type_}, please submit a PR to extend JsonSchemaTestData!')
322 def _object_gen(self, schema: dict[str, Any]) -> dict[str, Any]:
323 """Generate data for a JSON Schema object."""
324 required = set(schema.get('required', []))
326 data: dict[str, Any] = {}
327 if properties := schema.get('properties'):
328 for key, value in properties.items():
329 if key in required:
330 data[key] = self._gen_any(value)
332 if addition_props := schema.get('additionalProperties'):
333 add_prop_key = 'additionalProperty'
334 while add_prop_key in data:
335 add_prop_key += '_'
336 if addition_props is True:
337 data[add_prop_key] = self._char()
338 else:
339 data[add_prop_key] = self._gen_any(addition_props)
341 return data
343 def _str_gen(self, schema: dict[str, Any]) -> str:
344 """Generate a string from a JSON Schema string."""
345 min_len = schema.get('minLength')
346 if min_len is not None:
347 return self._char() * min_len
349 if schema.get('maxLength') == 0:
350 return ''
352 if fmt := schema.get('format'):
353 if fmt == 'date':
354 return (date(2024, 1, 1) + timedelta(days=self.seed)).isoformat()
356 return self._char()
358 def _int_gen(self, schema: dict[str, Any]) -> int:
359 """Generate an integer from a JSON Schema integer."""
360 maximum = schema.get('maximum')
361 if maximum is None:
362 exc_max = schema.get('exclusiveMaximum')
363 if exc_max is not None:
364 maximum = exc_max - 1
366 minimum = schema.get('minimum')
367 if minimum is None:
368 exc_min = schema.get('exclusiveMinimum')
369 if exc_min is not None:
370 minimum = exc_min + 1
372 if minimum is not None and maximum is not None:
373 return minimum + self.seed % (maximum - minimum)
374 elif minimum is not None:
375 return minimum + self.seed
376 elif maximum is not None:
377 return maximum - self.seed
378 else:
379 return self.seed
381 def _bool_gen(self) -> bool:
382 """Generate a boolean from a JSON Schema boolean."""
383 return bool(self.seed % 2)
385 def _array_gen(self, schema: dict[str, Any]) -> list[Any]:
386 """Generate an array from a JSON Schema array."""
387 data: list[Any] = []
388 unique_items = schema.get('uniqueItems')
389 if prefix_items := schema.get('prefixItems'):
390 for item in prefix_items:
391 data.append(self._gen_any(item))
392 if unique_items:
393 self.seed += 1
395 items_schema = schema.get('items', {})
396 min_items = schema.get('minItems', 0)
397 if min_items > len(data):
398 for _ in range(min_items - len(data)):
399 data.append(self._gen_any(items_schema))
400 if unique_items:
401 self.seed += 1
402 elif items_schema:
403 # if there is an `items` schema, add an item unless it would break `maxItems` rule
404 max_items = schema.get('maxItems')
405 if max_items is None or max_items > len(data):
406 data.append(self._gen_any(items_schema))
407 if unique_items:
408 self.seed += 1
410 return data
412 def _char(self) -> str:
413 """Generate a character on the same principle as Excel columns, e.g. a-z, aa-az..."""
414 chars = len(_chars)
415 s = ''
416 rem = self.seed // chars
417 while rem > 0:
418 s += _chars[(rem - 1) % chars]
419 rem //= chars
420 s += _chars[self.seed % chars]
421 return s
def _get_string_usage(text: str) -> Usage:
    """Build a `Usage` whose response and total token counts are the estimated tokens in `text`."""
    tokens = _estimate_string_tokens(text)
    return Usage(response_tokens=tokens, total_tokens=tokens)