Coverage for pydantic_ai_slim/pydantic_ai/models/test.py: 99.21%

258 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1from __future__ import annotations as _annotations 

2 

3import re 

4import string 

5from collections.abc import AsyncIterator, Iterable 

6from contextlib import asynccontextmanager 

7from dataclasses import InitVar, dataclass, field 

8from datetime import date, datetime, timedelta 

9from typing import Any, Literal 

10 

11import pydantic_core 

12 

13from .. import _utils 

14from ..messages import ( 

15 ModelMessage, 

16 ModelRequest, 

17 ModelResponse, 

18 ModelResponsePart, 

19 ModelResponseStreamEvent, 

20 RetryPromptPart, 

21 TextPart, 

22 ToolCallPart, 

23 ToolReturnPart, 

24) 

25from ..result import Usage 

26from ..settings import ModelSettings 

27from ..tools import ToolDefinition 

28from . import ( 

29 Model, 

30 ModelRequestParameters, 

31 StreamedResponse, 

32) 

33from .function import _estimate_string_tokens, _estimate_usage # pyright: ignore[reportPrivateUsage] 

34 

35 

@dataclass
class _TextResult:
    """Private marker for a plain-text result.

    `value` holds `custom_result_text` when it was configured, or `None` when the text should be
    synthesized from the conversation instead.
    """

    value: str | None

41 

42 

@dataclass
class _FunctionToolResult:
    """Marker for a result that should be delivered through a result tool call.

    `value` holds the (possibly key-wrapped) `custom_result_args`, or `None` when the tool
    arguments should be generated from the tool's JSON schema.
    """

    value: Any | None

48 

49 

@dataclass
class TestModel(Model):
    """A model specifically for testing purposes.

    This will (by default) call all tools in the agent, then return a tool response if possible,
    otherwise a plain response.

    How useful this model is will vary significantly.

    Apart from `__init__` derived by the `dataclass` decorator, all methods are private or match those
    of the base class.
    """

    # NOTE: Avoid test discovery by pytest.
    __test__ = False

    call_tools: list[str] | Literal['all'] = 'all'
    """List of tools to call. If `'all'`, all tools will be called."""
    custom_result_text: str | None = None
    """If set, this text is returned as the final result."""
    custom_result_args: Any | None = None
    """If set, these args will be passed to the result tool."""
    seed: int = 0
    """Seed for generating random data."""
    last_model_request_parameters: ModelRequestParameters | None = field(default=None, init=False)
    """The last ModelRequestParameters passed to the model in a request.

    The ModelRequestParameters contains information about the function and result tools available during request handling.

    This is set when a request is made, so will reflect the function tools from the last step of the last run.
    """
    _model_name: str = field(default='test', repr=False)
    _system: str = field(default='test', repr=False)

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, Usage]:
        """Build a response via `_request` and estimate usage over the messages plus that response.

        Also records `model_request_parameters` in `last_model_request_parameters`.
        """
        self.last_model_request_parameters = model_request_parameters

        model_response = self._request(messages, model_settings, model_request_parameters)
        usage = _estimate_usage([*messages, model_response])
        return model_response, usage

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        """Build the same response as `request`, but yield it wrapped in a `TestStreamedResponse`.

        Also records `model_request_parameters` in `last_model_request_parameters`.
        """
        self.last_model_request_parameters = model_request_parameters

        model_response = self._request(messages, model_settings, model_request_parameters)
        yield TestStreamedResponse(
            _model_name=self._model_name, _structured_response=model_response, _messages=messages
        )

    @property
    def model_name(self) -> str:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The system / model provider."""
        return self._system

    def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
        """Generate arguments matching `tool_def`'s parameters JSON schema, driven by `seed`."""
        return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()

    def _choose_result_tool(self, result_tools: list[ToolDefinition]) -> ToolDefinition:
        """Pick the result tool to call, deterministically based on `seed`.

        `result_tools` must be non-empty.
        """
        return result_tools[self.seed % len(result_tools)]

    def _get_tool_calls(self, model_request_parameters: ModelRequestParameters) -> list[tuple[str, ToolDefinition]]:
        """Return `(name, definition)` pairs for the function tools this model should call.

        Raises:
            KeyError: if `call_tools` names a tool that is not among the request's function tools.
        """
        if self.call_tools == 'all':
            return [(r.name, r) for r in model_request_parameters.function_tools]
        function_tools_lookup = {t.name: t for t in model_request_parameters.function_tools}
        return [(name, function_tools_lookup[name]) for name in self.call_tools]

    def _get_result(self, model_request_parameters: ModelRequestParameters) -> _TextResult | _FunctionToolResult:
        """Decide whether the final result is plain text or a result tool call, and with what value."""
        if self.custom_result_text is not None:
            assert model_request_parameters.allow_text_result, (
                'Plain response not allowed, but `custom_result_text` is set.'
            )
            assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
            return _TextResult(self.custom_result_text)
        elif self.custom_result_args is not None:
            result_tools = model_request_parameters.result_tools
            # Check truthiness rather than `is not None`: an empty list is just as unusable and would
            # otherwise fail later with an opaque IndexError.
            assert result_tools, 'No result tools provided, but `custom_result_args` is set.'
            # Use the same tool `_request` will call, so the `outer_typed_dict_key` wrapping matches it.
            result_tool = self._choose_result_tool(result_tools)

            if k := result_tool.outer_typed_dict_key:
                return _FunctionToolResult({k: self.custom_result_args})
            else:
                return _FunctionToolResult(self.custom_result_args)
        elif model_request_parameters.allow_text_result:
            return _TextResult(None)
        elif model_request_parameters.result_tools:
            return _FunctionToolResult(None)
        else:
            return _TextResult(None)

    def _request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        """Build the next model response for the conversation so far.

        In order of priority the response is:

        1. calls to every selected function tool, if no `ModelResponse` exists in `messages` yet;
        2. re-issued calls for the tools named by `RetryPromptPart`s in the last request;
        3. the final result: text (custom, or a JSON summary of tool returns) or a result tool call.
        """
        tool_calls = self._get_tool_calls(model_request_parameters)
        result = self._get_result(model_request_parameters)
        result_tools = model_request_parameters.result_tools

        # if there are tools, the first thing we want to do is call all of them
        if tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
            return ModelResponse(
                parts=[ToolCallPart(name, self.gen_tool_args(tool_def)) for name, tool_def in tool_calls],
                model_name=self._model_name,
            )

        if messages:
            last_message = messages[-1]
            assert isinstance(last_message, ModelRequest), 'Expected last message to be a `ModelRequest`.'

            # check if there are any retry prompts, if so retry them
            new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
            if new_retry_names:
                # Handle retries for both function tools and result tools
                # Check function tools first
                retry_parts: list[ModelResponsePart] = [
                    ToolCallPart(name, self.gen_tool_args(tool_def))
                    for name, tool_def in tool_calls
                    if name in new_retry_names
                ]
                # Check result tools
                if result_tools:
                    retry_parts.extend(
                        [
                            ToolCallPart(
                                tool.name,
                                result.value
                                if isinstance(result, _FunctionToolResult) and result.value is not None
                                else self.gen_tool_args(tool),
                            )
                            for tool in result_tools
                            if tool.name in new_retry_names
                        ]
                    )
                return ModelResponse(parts=retry_parts, model_name=self._model_name)

        if isinstance(result, _TextResult):
            if (response_text := result.value) is None:
                # build up details of tool responses
                output: dict[str, Any] = {}
                for message in messages:
                    if isinstance(message, ModelRequest):
                        for part in message.parts:
                            if isinstance(part, ToolReturnPart):
                                output[part.tool_name] = part.content
                if output:
                    return ModelResponse(
                        parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self._model_name
                    )
                else:
                    return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self._model_name)
            else:
                return ModelResponse(parts=[TextPart(response_text)], model_name=self._model_name)
        else:
            assert result_tools, 'No result tools provided'
            custom_result_args = result.value
            result_tool = self._choose_result_tool(result_tools)
            if custom_result_args is not None:
                return ModelResponse(
                    parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self._model_name
                )
            else:
                response_args = self.gen_tool_args(result_tool)
                return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self._model_name)

228 

229 

@dataclass
class TestStreamedResponse(StreamedResponse):
    """Replay a precomputed `ModelResponse` as a stream of response events, for tests."""

    _model_name: str
    _structured_response: ModelResponse
    _messages: InitVar[Iterable[ModelMessage]]
    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

    def __post_init__(self, _messages: Iterable[ModelMessage]):
        # Start from the estimated usage of the request history; streamed text adds to it below.
        self._usage = _estimate_usage(_messages)

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        for index, part in enumerate(self._structured_response.parts):
            if not isinstance(part, TextPart):
                # Non-text parts are emitted whole, as a single tool-call event.
                yield self._parts_manager.handle_tool_call_part(
                    vendor_part_id=index, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
                )
                continue

            content = part.content
            # Split on spaces, keeping the separator on every chunk except the last.
            pieces = content.split(' ')
            chunks = [piece + ' ' for piece in pieces[:-1]]
            chunks.append(pieces[-1])
            # A single word longer than two characters is still streamed in two halves.
            if len(chunks) == 1 and len(content) > 2:
                half = len(content) // 2
                chunks = [content[:half], content[half:]]

            # An initial empty delta opens the text part before any content arrives.
            self._usage += _get_string_usage('')
            yield self._parts_manager.handle_text_delta(vendor_part_id=index, content='')
            for chunk in chunks:
                self._usage += _get_string_usage(chunk)
                yield self._parts_manager.handle_text_delta(vendor_part_id=index, content=chunk)

    @property
    def model_name(self) -> str:
        """Name of the model that produced this streamed response."""
        return self._model_name

    @property
    def timestamp(self) -> datetime:
        """UTC time at which this streamed response object was created."""
        return self._timestamp

271 

272 

# Alphabet used by `_JsonSchemaTestData._char` (94 printable ASCII characters).
_chars = string.ascii_letters + string.digits + string.punctuation


class _JsonSchemaTestData:
    """Generate data that matches a JSON schema.

    This tries to generate the minimal viable data for the schema. All choices (enum member,
    union variant, integer offset, ...) are derived deterministically from `seed`.
    """

    def __init__(self, schema: _utils.ObjectJsonSchema, seed: int = 0):
        self.schema = schema
        # `$ref`s are resolved against the top-level schema's `$defs` only.
        self.defs = schema.get('$defs', {})
        self.seed = seed

    def generate(self) -> Any:
        """Generate data for the JSON schema."""
        return self._gen_any(self.schema)

    def _gen_any(self, schema: dict[str, Any]) -> Any:
        """Generate data for any JSON Schema.

        Keywords are consulted in priority order: `const`, `enum`, `examples`, `$ref`,
        `anyOf`/`oneOf`, then `type`.

        Raises:
            NotImplementedError: if `type` names a type this generator does not know.
        """
        # Membership test rather than truthiness: falsy constants (0, False, '', null) are valid values.
        if 'const' in schema:
            return schema['const']
        elif enum := schema.get('enum'):
            return enum[self.seed % len(enum)]
        elif examples := schema.get('examples'):
            return examples[self.seed % len(examples)]
        elif ref := schema.get('$ref'):
            key = re.sub(r'^#/\$defs/', '', ref)
            js_def = self.defs[key]
            return self._gen_any(js_def)
        elif variants := (schema.get('anyOf') or schema.get('oneOf')):
            # Any variant satisfies `anyOf`; for `oneOf` (e.g. discriminated unions) the variants are
            # expected to be disjoint, so picking one still yields valid data.
            return self._gen_any(variants[self.seed % len(variants)])

        type_ = schema.get('type')
        if isinstance(type_, list):
            # JSON Schema allows a list of types, e.g. `["string", "null"]`; pick one deterministically.
            type_ = type_[self.seed % len(type_)] if type_ else None

        if type_ is None:
            # if there's no type or ref, we can't generate anything
            return self._char()
        elif type_ == 'object':
            return self._object_gen(schema)
        elif type_ == 'string':
            return self._str_gen(schema)
        elif type_ == 'integer':
            return self._int_gen(schema)
        elif type_ == 'number':
            return float(self._int_gen(schema))
        elif type_ == 'boolean':
            return self._bool_gen()
        elif type_ == 'array':
            return self._array_gen(schema)
        elif type_ == 'null':
            return None
        else:
            raise NotImplementedError(f'Unknown type: {type_}, please submit a PR to extend JsonSchemaTestData!')

    def _object_gen(self, schema: dict[str, Any]) -> dict[str, Any]:
        """Generate data for a JSON Schema object.

        Only `required` properties are populated. If `additionalProperties` is set, one extra key
        (`additionalProperty`, suffixed with `_` until it is unique) is added as well.
        """
        required = set(schema.get('required', []))

        data: dict[str, Any] = {}
        if properties := schema.get('properties'):
            for key, value in properties.items():
                if key in required:
                    data[key] = self._gen_any(value)

        if addition_props := schema.get('additionalProperties'):
            add_prop_key = 'additionalProperty'
            while add_prop_key in data:
                add_prop_key += '_'
            if addition_props is True:
                data[add_prop_key] = self._char()
            else:
                data[add_prop_key] = self._gen_any(addition_props)

        return data

    def _str_gen(self, schema: dict[str, Any]) -> str:
        """Generate a string from a JSON Schema string.

        Honours `minLength`, `maxLength == 0` and the `date` format; otherwise returns `_char()`.
        """
        min_len = schema.get('minLength')
        if min_len is not None:
            # `_char()` is several characters long for large seeds, so trim to exactly `minLength`.
            return (self._char() * min_len)[:min_len]

        if schema.get('maxLength') == 0:
            return ''

        if fmt := schema.get('format'):
            if fmt == 'date':
                return (date(2024, 1, 1) + timedelta(days=self.seed)).isoformat()

        return self._char()

    def _int_gen(self, schema: dict[str, Any]) -> int:
        """Generate an integer from a JSON Schema integer.

        Exclusive bounds are converted to inclusive ones by +/- 1 when the inclusive bound is absent.
        Also used (via `float(...)`) for `number` schemas.
        """
        maximum = schema.get('maximum')
        if maximum is None:
            exc_max = schema.get('exclusiveMaximum')
            if exc_max is not None:
                maximum = exc_max - 1

        minimum = schema.get('minimum')
        if minimum is None:
            exc_min = schema.get('exclusiveMinimum')
            if exc_min is not None:
                minimum = exc_min + 1

        if minimum is not None and maximum is not None:
            if maximum <= minimum:
                # Only one admissible value (or an unsatisfiable range): avoid a modulo by zero.
                return minimum
            return minimum + self.seed % (maximum - minimum)
        elif minimum is not None:
            return minimum + self.seed
        elif maximum is not None:
            return maximum - self.seed
        else:
            return self.seed

    def _bool_gen(self) -> bool:
        """Generate a boolean from a JSON Schema boolean (`False` for even seeds, `True` for odd)."""
        return bool(self.seed % 2)

    def _array_gen(self, schema: dict[str, Any]) -> list[Any]:
        """Generate an array from a JSON Schema array.

        Emits one value per `prefixItems` entry, then pads up to `minItems` from `items`. If no
        padding is needed, one `items` value is added unless that would exceed `maxItems`.

        NOTE: with `uniqueItems`, uniqueness is approximated by incrementing `self.seed` after each
        item, which also shifts values generated afterwards by this instance.
        """
        data: list[Any] = []
        unique_items = schema.get('uniqueItems')
        if prefix_items := schema.get('prefixItems'):
            for item in prefix_items:
                data.append(self._gen_any(item))
                if unique_items:
                    self.seed += 1

        items_schema = schema.get('items', {})
        min_items = schema.get('minItems', 0)
        if min_items > len(data):
            for _ in range(min_items - len(data)):
                data.append(self._gen_any(items_schema))
                if unique_items:
                    self.seed += 1
        elif items_schema:
            # if there is an `items` schema, add an item unless it would break `maxItems` rule
            max_items = schema.get('maxItems')
            if max_items is None or max_items > len(data):
                data.append(self._gen_any(items_schema))
                if unique_items:
                    self.seed += 1

        return data

    def _char(self) -> str:
        """Generate a short string from `seed`, in the spirit of spreadsheet column names.

        The alphabet is `_chars`: seeds 0..93 map to single characters (`a` ... `~`), then seed 94
        gives `aa`, 95 gives `ab`, and so on.
        """
        chars = len(_chars)
        s = ''
        rem = self.seed // chars
        while rem > 0:
            s += _chars[(rem - 1) % chars]
            rem //= chars
        s += _chars[self.seed % chars]
        return s

427 

428 

def _get_string_usage(text: str) -> Usage:
    """Estimate the usage incurred by emitting `text` as model output.

    Only response tokens are counted, so `total_tokens` equals `response_tokens`.
    """
    tokens = _estimate_string_tokens(text)
    return Usage(response_tokens=tokens, total_tokens=tokens)