Coverage for pydantic_ai_slim/pydantic_ai/result.py: 94.34% (160 statements) — report generated by coverage.py v7.6.7 at 2025-01-25 16:43 +0000.

1from __future__ import annotations as _annotations 

2 

3from abc import ABC, abstractmethod 

4from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable 

5from copy import deepcopy 

6from dataclasses import dataclass, field 

7from datetime import datetime 

8from typing import Generic, Union, cast 

9 

10import logfire_api 

11from typing_extensions import TypeVar 

12 

13from . import _result, _utils, exceptions, messages as _messages, models 

14from .tools import AgentDepsT, RunContext 

15from .usage import Usage, UsageLimits 

16 

__all__ = 'ResultDataT', 'ResultDataT_inv', 'ResultValidatorFunc', 'RunResult', 'StreamedRunResult'


T = TypeVar('T')
"""An invariant TypeVar."""
ResultDataT_inv = TypeVar('ResultDataT_inv', default=str)
"""
An invariant type variable for the result data of a model.

We need to use an invariant typevar for `ResultValidator` and `ResultValidatorFunc` because the result data type is used
in both the input and output of a `ResultValidatorFunc`. This can theoretically lead to some issues assuming that types
possessing ResultValidator's are covariant in the result data type, but in practice this is rarely an issue, and
changing it would have negative consequences for the ergonomics of the library.

At some point, it may make sense to change the input to ResultValidatorFunc to be `Any` or `object` as doing that would
resolve these potential variance issues.
"""
ResultDataT = TypeVar('ResultDataT', default=str, covariant=True)
"""Covariant type variable for the result data type of a run."""

# The four union members cover the cross product of (sync | async) x (with | without a RunContext first argument).
ResultValidatorFunc = Union[
    Callable[[RunContext[AgentDepsT], ResultDataT_inv], ResultDataT_inv],
    Callable[[RunContext[AgentDepsT], ResultDataT_inv], Awaitable[ResultDataT_inv]],
    Callable[[ResultDataT_inv], ResultDataT_inv],
    Callable[[ResultDataT_inv], Awaitable[ResultDataT_inv]],
]
"""
A function that always takes and returns the same type of data (which is the result type of an agent run), and:

* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument
* may or may not be async

Usage `ResultValidatorFunc[AgentDepsT, T]`.
"""

# Module-level tracing handle scoped to this library; used for the streaming spans below.
_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

53 

54 

@dataclass
class _BaseRunResult(ABC, Generic[ResultDataT]):
    """Common base class for run results.

    You should not import or use this type directly, instead use its subclasses `RunResult` and `StreamedRunResult`.
    """

    _all_messages: list[_messages.ModelMessage]
    _new_message_index: int

    def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
        """Return the full message history for the conversation.

        Args:
            result_tool_return_content: Replacement return content for the result tool call in the last
                message; a convenient way to modify the result tool response if you want to continue the
                conversation. Subclasses may support this; here any non-`None` value raises. `None` leaves
                the history untouched.

        Returns:
            List of messages.
        """
        # Kept as a method (not a property) so the signature matches the other message accessors.
        if result_tool_return_content is None:
            return self._all_messages
        raise NotImplementedError('Setting result tool return content is not supported for this result type.')

    def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
        """Serialize [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] to JSON bytes.

        Args:
            result_tool_return_content: Replacement return content for the result tool call in the last
                message, forwarded to `all_messages`. `None` leaves the history untouched.

        Returns:
            JSON bytes representing the messages.
        """
        history = self.all_messages(result_tool_return_content=result_tool_return_content)
        return _messages.ModelMessagesTypeAdapter.dump_json(history)

    def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
        """Return only the messages associated with this run.

        Messages from older runs are excluded.

        Args:
            result_tool_return_content: Replacement return content for the result tool call in the last
                message, forwarded to `all_messages`. `None` leaves the history untouched.

        Returns:
            List of new messages.
        """
        history = self.all_messages(result_tool_return_content=result_tool_return_content)
        return history[self._new_message_index :]

    def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
        """Serialize [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] to JSON bytes.

        Args:
            result_tool_return_content: Replacement return content for the result tool call in the last
                message, forwarded to `new_messages`. `None` leaves the history untouched.

        Returns:
            JSON bytes representing the new messages.
        """
        fresh = self.new_messages(result_tool_return_content=result_tool_return_content)
        return _messages.ModelMessagesTypeAdapter.dump_json(fresh)

    @abstractmethod
    def usage(self) -> Usage:
        raise NotImplementedError()

134 

@dataclass
class RunResult(_BaseRunResult[ResultDataT]):
    """Result of a non-streamed run."""

    data: ResultDataT
    """Data from the final response in the run."""
    _result_tool_name: str | None
    _usage: Usage

    def usage(self) -> Usage:
        """Return the usage of the whole run."""
        return self._usage

    def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
        """Return the full message history for the conversation.

        Args:
            result_tool_return_content: Replacement return content for the result tool call in the last
                message; a convenient way to modify the result tool response if you want to continue the
                conversation. `None` leaves the history untouched.

        Returns:
            List of messages.
        """
        if result_tool_return_content is None:
            return self._all_messages
        return self._set_result_tool_return(result_tool_return_content)

    def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]:
        """Return a deep-copied history with the result tool's return content replaced.

        Useful if you want to continue the conversation and want to set the response to the result tool call.
        """
        # No result tool name means the run produced plain text, so there is nothing to rewrite.
        if not self._result_tool_name:
            raise ValueError('Cannot set result tool return content when the return type is `str`.')
        # Deep-copy so the stored history is never mutated by the caller's edit.
        copied = deepcopy(self._all_messages)
        matches = (
            p
            for p in copied[-1].parts
            if isinstance(p, _messages.ToolReturnPart) and p.tool_name == self._result_tool_name
        )
        target = next(matches, None)
        if target is None:
            raise LookupError(f'No tool call found with tool name {self._result_tool_name!r}.')
        target.content = return_content
        return copied

179 

180 

@dataclass
class StreamedRunResult(_BaseRunResult[ResultDataT], Generic[AgentDepsT, ResultDataT]):
    """Result of a streamed run that returns structured data via a tool call."""

    _usage_limits: UsageLimits | None
    _stream_response: models.StreamedResponse
    _result_schema: _result.ResultSchema[ResultDataT] | None
    _run_ctx: RunContext[AgentDepsT]
    _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]]
    _result_tool_name: str | None
    _on_complete: Callable[[], Awaitable[None]]
    is_complete: bool = field(default=False, init=False)
    """Whether the stream has all been received.

    This is set to `True` when one of
    [`stream`][pydantic_ai.result.StreamedRunResult.stream],
    [`stream_text`][pydantic_ai.result.StreamedRunResult.stream_text],
    [`stream_structured`][pydantic_ai.result.StreamedRunResult.stream_structured] or
    [`get_data`][pydantic_ai.result.StreamedRunResult.get_data] completes.
    """

    async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]:
        """Stream the response as an async iterable.

        The pydantic validator for structured data will be called in
        [partial mode](https://docs.pydantic.dev/dev/concepts/experimental/#partial-validation)
        on each iteration.

        Args:
            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
                Debouncing is particularly important for long structured responses to reduce the overhead of
                performing validation as each token is received.

        Returns:
            An async iterable of the response data.
        """
        # Delegates to stream_structured and validates each partial message; only the final
        # message (is_last=True) is validated in full (allow_partial=False).
        async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by):
            result = await self.validate_structured_result(structured_message, allow_partial=not is_last)
            yield result

    async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
        """Stream the text result as an async iterable.

        !!! note
            Result validators will NOT be called on the text result if `delta=True`.

        Args:
            delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text
                up to the current point.
            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
                Debouncing is particularly important for long structured responses to reduce the overhead of
                performing validation as each token is received.
        """
        if self._result_schema and not self._result_schema.allow_text_result:
            raise exceptions.UserError('stream_text() can only be used with text responses')

        usage_checking_stream = _get_usage_checking_stream_response(
            self._stream_response, self._usage_limits, self.usage
        )

        # Define a "merged" version of the iterator that will yield items that have already been retrieved
        # and items that we receive while streaming. We define a dedicated async iterator for this so we can
        # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
        async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
            # if the response currently has any parts with content, yield those before streaming
            msg = self._stream_response.get()
            for i, part in enumerate(msg.parts):
                if isinstance(part, _messages.TextPart) and part.content:
                    yield part.content, i

            # Each yielded tuple is (text, part index); the index is carried so grouped items
            # can be distinguished per part even though it is dropped by `_stream_text_deltas`.
            async for event in usage_checking_stream:
                if (
                    isinstance(event, _messages.PartStartEvent)
                    and isinstance(event.part, _messages.TextPart)
                    and event.part.content
                ):
                    yield event.part.content, event.index
                elif (
                    isinstance(event, _messages.PartDeltaEvent)
                    and isinstance(event.delta, _messages.TextPartDelta)
                    and event.delta.content_delta
                ):
                    yield event.delta.content_delta, event.index

        async def _stream_text_deltas() -> AsyncIterator[str]:
            # Group chunks that arrive within `debounce_by` seconds into a single yielded string.
            # NOTE(review): exact grouping semantics depend on `_utils.group_by_temporal` — not visible here.
            async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
                async for items in group_iter:
                    yield ''.join([content for content, _ in items])

        with _logfire.span('response stream text') as lf_span:
            if delta:
                # NOTE(review): in the delta path the stream is never marked complete via
                # `_marked_completed`, and result validators are not applied — confirm intended.
                async for text in _stream_text_deltas():
                    yield text
            else:
                # a quick benchmark shows it's faster to build up a string with concat when we're
                # yielding at each step
                deltas: list[str] = []
                combined_validated_text = ''
                async for text in _stream_text_deltas():
                    deltas.append(text)
                    combined_text = ''.join(deltas)
                    # re-validate the whole accumulated text on every yield, not just the new chunk
                    combined_validated_text = await self._validate_text_result(combined_text)
                    yield combined_validated_text

                lf_span.set_attribute('combined_text', combined_validated_text)
                # Record the final text as a ModelResponse and fire the completion callback.
                await self._marked_completed(
                    _messages.ModelResponse(
                        parts=[_messages.TextPart(combined_validated_text)],
                        model_name=self._stream_response.model_name(),
                    )
                )

    async def stream_structured(
        self, *, debounce_by: float | None = 0.1
    ) -> AsyncIterator[tuple[_messages.ModelResponse, bool]]:
        """Stream the response as an async iterable of Structured LLM Messages.

        Args:
            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
                Debouncing is particularly important for long structured responses to reduce the overhead of
                performing validation as each token is received.

        Returns:
            An async iterable of the structured response message and whether that is the last message.
        """
        usage_checking_stream = _get_usage_checking_stream_response(
            self._stream_response, self._usage_limits, self.usage
        )

        with _logfire.span('response stream structured') as lf_span:
            # if the message currently has any parts with content, yield before streaming
            msg = self._stream_response.get()
            for part in msg.parts:
                if part.has_content():
                    # yield at most once for pre-existing content, then move on to the live stream
                    yield msg, False
                    break

            async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
                async for _events in group_iter:
                    # the events themselves are ignored; `get()` returns the current snapshot message
                    msg = self._stream_response.get()
                    yield msg, False
            # final snapshot after the stream is exhausted, flagged as the last message
            msg = self._stream_response.get()
            yield msg, True
            # TODO: Should this now be `final_response` instead of `structured_response`?
            lf_span.set_attribute('structured_response', msg)
            await self._marked_completed(msg)

    async def get_data(self) -> ResultDataT:
        """Stream the whole response, validate and return it."""
        usage_checking_stream = _get_usage_checking_stream_response(
            self._stream_response, self._usage_limits, self.usage
        )

        # Drain the stream entirely (usage limits are still enforced per event), then
        # validate the final accumulated message.
        async for _ in usage_checking_stream:
            pass
        message = self._stream_response.get()
        await self._marked_completed(message)
        return await self.validate_structured_result(message)

    def usage(self) -> Usage:
        """Return the usage of the whole run.

        !!! note
            This won't return the full usage until the stream is finished.
        """
        # usage accumulated before streaming started, plus whatever the stream has seen so far
        return self._run_ctx.usage + self._stream_response.usage()

    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._stream_response.timestamp()

    async def validate_structured_result(
        self, message: _messages.ModelResponse, *, allow_partial: bool = False
    ) -> ResultDataT:
        """Validate a structured result message."""
        if self._result_schema is not None and self._result_tool_name is not None:
            # Structured output path: locate the result tool call in the message parts.
            match = self._result_schema.find_named_tool(message.parts, self._result_tool_name)
            if match is None:
                raise exceptions.UnexpectedModelBehavior(
                    f'Invalid response, unable to find tool: {self._result_schema.tool_names()}'
                )

            call, result_tool = match
            # allow_partial lets pydantic accept incomplete JSON while the stream is in flight
            result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)

            for validator in self._result_validators:
                result_data = await validator.validate(result_data, call, self._run_ctx)
            return result_data
        else:
            # Plain-text path: concatenate all text parts, then run validators with no tool call.
            text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
            for validator in self._result_validators:
                text = await validator.validate(
                    text,
                    None,
                    self._run_ctx,
                )
            # Since there is no result tool, we can assume that str is compatible with ResultDataT
            return cast(ResultDataT, text)

    async def _validate_text_result(self, text: str) -> str:
        # Run every result validator over the text, threading the (possibly replaced) value through.
        for validator in self._result_validators:
            text = await validator.validate(
                text,
                None,
                self._run_ctx,
            )
        return text

    async def _marked_completed(self, message: _messages.ModelResponse) -> None:
        # Flip the completion flag, append the final message to history, then notify the owner.
        self.is_complete = True
        self._all_messages.append(message)
        await self._on_complete()

393 

394 

def _get_usage_checking_stream_response(
    stream_response: AsyncIterable[_messages.ModelResponseStreamEvent],
    limits: UsageLimits | None,
    get_usage: Callable[[], Usage],
) -> AsyncIterable[_messages.ModelResponseStreamEvent]:
    """Wrap `stream_response` so token limits are checked as each event arrives.

    When `limits` is absent or carries no token limits, the original stream is returned unchanged.
    """
    if limits is None or not limits.has_token_limits():
        # nothing to enforce — hand back the original stream untouched
        return stream_response

    async def _checked_stream():
        # check limits before forwarding each event, so we stop as soon as a limit is exceeded
        async for event in stream_response:
            limits.check_tokens(get_usage())
            yield event

    return _checked_stream()