Coverage for pydantic_ai_slim/pydantic_ai/models/test.py: 99.21%

257 statements  

coverage.py v7.6.7, created at 2025-01-25 16:43 +0000

from __future__ import annotations as _annotations

import re
import string
from collections.abc import AsyncIterator, Iterable
from contextlib import asynccontextmanager
from dataclasses import InitVar, dataclass, field
from datetime import date, datetime, timedelta
from typing import Any, Literal

import pydantic_core

from .. import _utils
from ..messages import (
    ModelMessage,
    ModelRequest,
    ModelResponse,
    ModelResponsePart,
    ModelResponseStreamEvent,
    RetryPromptPart,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
)
from ..result import Usage
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import (
    AgentModel,
    Model,
    StreamedResponse,
)
from .function import _estimate_string_tokens, _estimate_usage  # pyright: ignore[reportPrivateUsage]


@dataclass
class _TextResult:
    """A private wrapper class to tag a result that came from the `custom_result_text` field."""

    value: str | None


@dataclass
class _FunctionToolResult:
    """A private wrapper class to tag a result that came from the `custom_result_args` field."""

    value: Any | None


@dataclass
class TestModel(Model):
    """A model specifically for testing purposes.

    By default this calls every tool registered on the agent, then returns a result-tool
    response if possible, otherwise a plain-text response.

    How useful this model is will vary significantly.

    Apart from `__init__`, which is derived by the `dataclass` decorator, all methods are
    private or match those of the base class.
    """

    # NOTE: Avoid test discovery by pytest.
    __test__ = False

    call_tools: list[str] | Literal['all'] = 'all'
    """List of tools to call. If `'all'`, all tools will be called."""
    custom_result_text: str | None = None
    """If set, this text is returned as the final result."""
    custom_result_args: Any | None = None
    """If set, these args will be passed to the result tool."""
    seed: int = 0
    """Seed for generating random data."""
    agent_model_function_tools: list[ToolDefinition] | None = field(default=None, init=False)
    """Definition of function tools passed to the model.

    This is set when the model is called, so will reflect the function tools from the last step of the last run.
    """
    agent_model_allow_text_result: bool | None = field(default=None, init=False)
    """Whether plain text responses from the model are allowed.

    This is set when the model is called, so will reflect the value from the last step of the last run.
    """
    agent_model_result_tools: list[ToolDefinition] | None = field(default=None, init=False)
    """Definition of result tools passed to the model.

    This is set when the model is called, so will reflect the result tools from the last step of the last run.
    """

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        self.agent_model_function_tools = function_tools
        self.agent_model_allow_text_result = allow_text_result
        self.agent_model_result_tools = result_tools
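        # Editorial summary of the branches below: `custom_result_text` takes
        # precedence, then `custom_result_args`, then a plain-text result when
        # `allow_text_result` is true, then a result-tool call, else plain text.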

        if self.call_tools == 'all':
            tool_calls = [(r.name, r) for r in function_tools]
        else:
            function_tools_lookup = {t.name: t for t in function_tools}
            tools_to_call = (function_tools_lookup[name] for name in self.call_tools)
            tool_calls = [(r.name, r) for r in tools_to_call]

        if self.custom_result_text is not None:
            assert allow_text_result, 'Plain response not allowed, but `custom_result_text` is set.'
            assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
            result: _TextResult | _FunctionToolResult = _TextResult(self.custom_result_text)
        elif self.custom_result_args is not None:
            assert result_tools is not None, 'No result tools provided, but `custom_result_args` is set.'
            result_tool = result_tools[0]

            if k := result_tool.outer_typed_dict_key:
                result = _FunctionToolResult({k: self.custom_result_args})
            else:
                result = _FunctionToolResult(self.custom_result_args)
        elif allow_text_result:
            result = _TextResult(None)

        elif result_tools:  # coverage: this condition was always true, so the final `else` never ran

            result = _FunctionToolResult(None)
        else:
            result = _TextResult(None)

        return TestAgentModel(tool_calls, result, result_tools, self.seed)

    def name(self) -> str:
        return 'test-model'


@dataclass
class TestAgentModel(AgentModel):
    """Implementation of `AgentModel` for testing purposes."""

    # NOTE: Avoid test discovery by pytest.
    __test__ = False

    tool_calls: list[tuple[str, ToolDefinition]]

    # `_TextResult` means the response is plain text; `_FunctionToolResult` means it's a call to a result tool

    result: _TextResult | _FunctionToolResult
    result_tools: list[ToolDefinition]
    seed: int
    model_name: str = 'test'

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, Usage]:
        model_response = self._request(messages, model_settings)
        usage = _estimate_usage([*messages, model_response])
        return model_response, usage

    @asynccontextmanager
    async def request_stream(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> AsyncIterator[StreamedResponse]:
        model_response = self._request(messages, model_settings)
        yield TestStreamedResponse(_model_name=self.model_name, _structured_response=model_response, _messages=messages)

    def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
        return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()

    def _request(self, messages: list[ModelMessage], model_settings: ModelSettings | None) -> ModelResponse:
        # if there are tools, the first thing we want to do is call all of them
        if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
            return ModelResponse(
                parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in self.tool_calls],
                model_name=self.model_name,
            )

        if messages:  # coverage: always true in the test suite, `messages` is never empty here

            last_message = messages[-1]
            assert isinstance(last_message, ModelRequest), 'Expected last message to be a `ModelRequest`.'

            # check if there are any retry prompts, if so retry them
            new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
            if new_retry_names:
                # Handle retries for both function tools and result tools
                # Check function tools first
                retry_parts: list[ModelResponsePart] = [
                    ToolCallPart(name, self.gen_tool_args(args))
                    for name, args in self.tool_calls
                    if name in new_retry_names
                ]
                # Check result tools
                if self.result_tools:
                    retry_parts.extend(
                        [
                            ToolCallPart(
                                tool.name,
                                self.result.value
                                if isinstance(self.result, _FunctionToolResult) and self.result.value is not None
                                else self.gen_tool_args(tool),
                            )
                            for tool in self.result_tools
                            if tool.name in new_retry_names
                        ]
                    )
                return ModelResponse(parts=retry_parts, model_name=self.model_name)

        if isinstance(self.result, _TextResult):
            if (response_text := self.result.value) is None:
                # build up details of tool responses
                output: dict[str, Any] = {}
                for message in messages:
                    if isinstance(message, ModelRequest):
                        for part in message.parts:
                            if isinstance(part, ToolReturnPart):
                                output[part.tool_name] = part.content
                if output:
                    return ModelResponse(
                        parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self.model_name
                    )
                else:
                    return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self.model_name)
            else:
                return ModelResponse(parts=[TextPart(response_text)], model_name=self.model_name)
        else:
            assert self.result_tools, 'No result tools provided'
            custom_result_args = self.result.value
            result_tool = self.result_tools[self.seed % len(self.result_tools)]
            if custom_result_args is not None:
                return ModelResponse(
                    parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self.model_name
                )
            else:
                response_args = self.gen_tool_args(result_tool)
                return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self.model_name)
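    # A hedged trace of `_request` in the simplest case (no tools, no custom
    # result): the tool-call block is skipped, `self.result` is
    # `_TextResult(None)`, no `ToolReturnPart`s have accumulated, so the model
    # answers `ModelResponse(parts=[TextPart('success (no tool calls)')], ...)`.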

@dataclass
class TestStreamedResponse(StreamedResponse):
    """A structured response that streams test data."""

    _structured_response: ModelResponse
    _messages: InitVar[Iterable[ModelMessage]]

    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

    def __post_init__(self, _messages: Iterable[ModelMessage]):
        self._usage = _estimate_usage(_messages)

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        for i, part in enumerate(self._structured_response.parts):
            if isinstance(part, TextPart):
                text = part.content
                *words, last_word = text.split(' ')
                words = [f'{word} ' for word in words]
                words.append(last_word)
                if len(words) == 1 and len(text) > 2:
                    mid = len(text) // 2
                    words = [text[:mid], text[mid:]]
                self._usage += _get_string_usage('')
                yield self._parts_manager.handle_text_delta(vendor_part_id=i, content='')
                for word in words:
                    self._usage += _get_string_usage(word)
                    yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=word)
            else:
                yield self._parts_manager.handle_tool_call_part(
                    vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
                )
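    # Worked examples of the chunking above (derived from the code, a reading
    # aid rather than a spec):
    #   'hello world' -> ['hello ', 'world']  (split on spaces, spaces kept)
    #   'abcd'        -> ['ab', 'cd']         (one word longer than 2 chars is halved)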

    def timestamp(self) -> datetime:
        return self._timestamp


_chars = string.ascii_letters + string.digits + string.punctuation

class _JsonSchemaTestData:
    """Generate data that matches a JSON schema.

    This tries to generate the minimal viable data for the schema.
    """
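    # A traced example (seed=0; editorial, derived from the methods below):
    #   _JsonSchemaTestData({
    #       'type': 'object',
    #       'properties': {'n': {'type': 'integer'}, 'label': {'type': 'string'}},
    #       'required': ['n'],
    #   }).generate()  # -> {'n': 0}; only required properties are generated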

    def __init__(self, schema: _utils.ObjectJsonSchema, seed: int = 0):
        self.schema = schema
        self.defs = schema.get('$defs', {})
        self.seed = seed

    def generate(self) -> Any:
        """Generate data for the JSON schema."""
        return self._gen_any(self.schema)

    def _gen_any(self, schema: dict[str, Any]) -> Any:
        """Generate data for any JSON Schema."""
        if const := schema.get('const'):
            return const
        elif enum := schema.get('enum'):
            return enum[self.seed % len(enum)]
        elif examples := schema.get('examples'):
            return examples[self.seed % len(examples)]
        elif ref := schema.get('$ref'):
            key = re.sub(r'^#/\$defs/', '', ref)
            js_def = self.defs[key]
            return self._gen_any(js_def)
        elif any_of := schema.get('anyOf'):
            return self._gen_any(any_of[self.seed % len(any_of)])

        type_ = schema.get('type')
        if type_ is None:
            # if there's no type or ref, we can't generate anything
            return self._char()
        elif type_ == 'object':
            return self._object_gen(schema)
        elif type_ == 'string':
            return self._str_gen(schema)
        elif type_ == 'integer':
            return self._int_gen(schema)
        elif type_ == 'number':
            return float(self._int_gen(schema))
        elif type_ == 'boolean':
            return self._bool_gen()
        elif type_ == 'array':
            return self._array_gen(schema)
        elif type_ == 'null':
            return None
        else:
            raise NotImplementedError(f'Unknown type: {type_}, please submit a PR to extend JsonSchemaTestData!')

    def _object_gen(self, schema: dict[str, Any]) -> dict[str, Any]:
        """Generate data for a JSON Schema object."""
        required = set(schema.get('required', []))

        data: dict[str, Any] = {}
        if properties := schema.get('properties'):
            for key, value in properties.items():
                if key in required:
                    data[key] = self._gen_any(value)

        if addition_props := schema.get('additionalProperties'):
            add_prop_key = 'additionalProperty'
            while add_prop_key in data:
                add_prop_key += '_'
            if addition_props is True:
                data[add_prop_key] = self._char()
            else:
                data[add_prop_key] = self._gen_any(addition_props)

        return data
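    # e.g. (seed=0, derived from the code):
    #   {'type': 'object', 'additionalProperties': {'type': 'string'}}
    #   -> {'additionalProperty': 'a'}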

    def _str_gen(self, schema: dict[str, Any]) -> str:
        """Generate a string from a JSON Schema string."""
        min_len = schema.get('minLength')
        if min_len is not None:
            return self._char() * min_len

        if schema.get('maxLength') == 0:
            return ''

        if fmt := schema.get('format'):
            if fmt == 'date':
                return (date(2024, 1, 1) + timedelta(days=self.seed)).isoformat()

        return self._char()
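    # e.g. (seed=0, derived from the code):
    #   {'minLength': 3} -> 'aaa'; {'format': 'date'} -> '2024-01-01'; {} -> 'a'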

    def _int_gen(self, schema: dict[str, Any]) -> int:
        """Generate an integer from a JSON Schema integer."""
        maximum = schema.get('maximum')
        if maximum is None:
            exc_max = schema.get('exclusiveMaximum')
            if exc_max is not None:
                maximum = exc_max - 1

        minimum = schema.get('minimum')
        if minimum is None:
            exc_min = schema.get('exclusiveMinimum')
            if exc_min is not None:
                minimum = exc_min + 1

        if minimum is not None and maximum is not None:
            return minimum + self.seed % (maximum - minimum)
        elif minimum is not None:
            return minimum + self.seed
        elif maximum is not None:
            return maximum - self.seed
        else:
            return self.seed
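    # e.g. (derived from the code): {'minimum': 3, 'maximum': 10} with seed=5
    # gives 3 + 5 % 7 == 8; note the result stays below `maximum` unless only
    # a `maximum` is given.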

    def _bool_gen(self) -> bool:
        """Generate a boolean from a JSON Schema boolean."""
        return bool(self.seed % 2)

    def _array_gen(self, schema: dict[str, Any]) -> list[Any]:
        """Generate an array from a JSON Schema array."""
        data: list[Any] = []
        unique_items = schema.get('uniqueItems')
        if prefix_items := schema.get('prefixItems'):
            for item in prefix_items:
                data.append(self._gen_any(item))
                if unique_items:
                    self.seed += 1

        items_schema = schema.get('items', {})
        min_items = schema.get('minItems', 0)
        if min_items > len(data):
            for _ in range(min_items - len(data)):
                data.append(self._gen_any(items_schema))
                if unique_items:
                    self.seed += 1
        elif items_schema:
            # if there is an `items` schema, add an item unless it would break the `maxItems` rule
            max_items = schema.get('maxItems')
            if max_items is None or max_items > len(data):
                data.append(self._gen_any(items_schema))
                if unique_items:
                    self.seed += 1

        return data
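    # e.g. (seed=0, derived from the code):
    #   {'items': {'type': 'integer'}, 'minItems': 2} -> [0, 0]
    #   {'items': {'type': 'integer'}, 'minItems': 2, 'uniqueItems': True} -> [0, 1]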

    def _char(self) -> str:
        """Generate a character on the same principle as Excel columns, e.g. a-z, aa-az..."""
        chars = len(_chars)
        s = ''
        rem = self.seed // chars
        while rem > 0:
            s += _chars[(rem - 1) % chars]
            rem //= chars
        s += _chars[self.seed % chars]
        return s
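    # e.g. (derived from the code; len(_chars) == 94):
    #   seed 0 -> 'a', seed 52 -> '0', seed 94 -> 'aa'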

def _get_string_usage(text: str) -> Usage:
    response_tokens = _estimate_string_tokens(text)
    return Usage(response_tokens=response_tokens, total_tokens=response_tokens)