Coverage for pydantic_ai_slim/pydantic_ai/models/openai.py: 98.12%


from __future__ import annotations as _annotations

import base64
from collections.abc import AsyncIterable, AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Literal, Union, cast, overload

from typing_extensions import assert_never

from pydantic_ai.providers import Provider, infer_provider

from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
    AudioUrl,
    BinaryContent,
    DocumentUrl,
    ImageUrl,
    ModelMessage,
    ModelRequest,
    ModelResponse,
    ModelResponsePart,
    ModelResponseStreamEvent,
    RetryPromptPart,
    SystemPromptPart,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import (
    Model,
    ModelRequestParameters,
    StreamedResponse,
    cached_async_http_client,
    check_allow_model_requests,
)

try:
    from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream
    from openai.types import ChatModel, chat
    from openai.types.chat import (
        ChatCompletionChunk,
        ChatCompletionContentPartImageParam,
        ChatCompletionContentPartInputAudioParam,
        ChatCompletionContentPartParam,
        ChatCompletionContentPartTextParam,
    )
    from openai.types.chat.chat_completion_content_part_image_param import ImageURL
    from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
except ImportError as _import_error:
    raise ImportError(
        'Please install `openai` to use the OpenAI model, '
        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
    ) from _import_error


OpenAIModelName = Union[str, ChatModel]
"""
Possible OpenAI model names.

Since OpenAI supports a variety of date-stamped models, we explicitly list the latest models but
allow any name in the type hints.
See [the OpenAI docs](https://platform.openai.com/docs/models) for a full list.

Using this broader type for the model name, instead of the `ChatModel` definition alone,
allows this model class to be used more easily with other OpenAI-compatible providers (e.g. Ollama, DeepSeek).
"""


OpenAISystemPromptRole = Literal['system', 'developer', 'user']


class OpenAIModelSettings(ModelSettings, total=False):
    """Settings used for an OpenAI model request.

    ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
    """

    openai_reasoning_effort: chat.ChatCompletionReasoningEffort
    """
    Constrains effort on reasoning for [reasoning models](https://platform.openai.com/docs/guides/reasoning).
    Currently supported values are `low`, `medium`, and `high`. Reducing reasoning effort can
    result in faster responses and fewer tokens used on reasoning in a response.
    """

    openai_user: str
    """A unique identifier representing the end-user, which can help OpenAI monitor and detect abuse.

    See [OpenAI's safety best practices](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids) for more details.
    """
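
# A minimal usage sketch (not part of the library source) for the settings above; the
# `openai_`-prefixed keys come from `OpenAIModelSettings`, while the agent wiring and
# values are assumptions for illustration only:
#
#     settings = OpenAIModelSettings(openai_reasoning_effort='low', openai_user='user-1234')
#     # merged like any other ModelSettings, e.g. agent.run_sync(..., model_settings=settings)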


@dataclass(init=False)
class OpenAIModel(Model):
    """A model that uses the OpenAI API.

    Internally, this uses the [OpenAI Python client](https://github.com/openai/openai-python) to interact with the API.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    client: AsyncOpenAI = field(repr=False)
    system_prompt_role: OpenAISystemPromptRole | None = field(default=None)

    _model_name: OpenAIModelName = field(repr=False)
    _system: str = field(default='openai', repr=False)

    def __init__(
        self,
        model_name: OpenAIModelName,
        *,
        provider: Literal['openai', 'deepseek', 'azure'] | Provider[AsyncOpenAI] = 'openai',
        system_prompt_role: OpenAISystemPromptRole | None = None,
    ):
        """Initialize an OpenAI model.

        Args:
            model_name: The name of the OpenAI model to use. A list of model names is available
                [here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
                (unfortunately, despite being asked to do so, OpenAI does not provide `.inv` files for their API).
            provider: The provider to use. Defaults to `'openai'`.
            system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
                In the future, this may be inferred from the model name.
        """
        self._model_name = model_name
        if isinstance(provider, str):
            provider = infer_provider(provider)
        self.client = provider.client
        self.system_prompt_role = system_prompt_role
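
    # Illustrative construction (a sketch, not part of this module; the model names and
    # provider values below are assumptions — any `OpenAIModelName` and supported provider work):
    #
    #     model = OpenAIModel('gpt-4o')                                   # default 'openai' provider
    #     model = OpenAIModel('deepseek-chat', provider='deepseek')       # OpenAI-compatible provider
    #     model = OpenAIModel('o3-mini', system_prompt_role='developer')  # send system prompts as 'developer'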

    @property
    def base_url(self) -> str:
        return str(self.client.base_url)

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, usage.Usage]:
        check_allow_model_requests()
        response = await self._completions_create(
            messages, False, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
        )
        return self._process_response(response), _map_usage(response)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        check_allow_model_requests()
        response = await self._completions_create(
            messages, True, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
        )
        async with response:
            yield await self._process_streamed_response(response)

    @property
    def model_name(self) -> OpenAIModelName:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The system / model provider."""
        return self._system

    @overload
    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[True],
        model_settings: OpenAIModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncStream[ChatCompletionChunk]:
        pass

    @overload
    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[False],
        model_settings: OpenAIModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> chat.ChatCompletion:
        pass

    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: bool,
        model_settings: OpenAIModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
        tools = self._get_tools(model_request_parameters)

        # standalone function to make it easier to override
        if not tools:
            tool_choice: Literal['none', 'required', 'auto'] | None = None
        elif not model_request_parameters.allow_text_result:
            tool_choice = 'required'
        else:
            tool_choice = 'auto'

        openai_messages: list[chat.ChatCompletionMessageParam] = []
        for m in messages:
            async for msg in self._map_message(m):
                openai_messages.append(msg)

        try:
            return await self.client.chat.completions.create(
                model=self._model_name,
                messages=openai_messages,
                n=1,
                parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
                tools=tools or NOT_GIVEN,
                tool_choice=tool_choice or NOT_GIVEN,
                stream=stream,
                stream_options={'include_usage': True} if stream else NOT_GIVEN,
                max_completion_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                temperature=model_settings.get('temperature', NOT_GIVEN),
                top_p=model_settings.get('top_p', NOT_GIVEN),
                timeout=model_settings.get('timeout', NOT_GIVEN),
                seed=model_settings.get('seed', NOT_GIVEN),
                presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
                frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
                logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
                reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
                user=model_settings.get('openai_user', NOT_GIVEN),
            )
        except APIStatusError as e:
            if (status_code := e.status_code) >= 400:  # coverage: condition was always true in tests, so the bare re-raise below is never reached
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
            raise


    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
        choice = response.choices[0]
        items: list[ModelResponsePart] = []
        if choice.message.content is not None:
            items.append(TextPart(choice.message.content))
        if choice.message.tool_calls is not None:
            for c in choice.message.tool_calls:
                items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
        return ModelResponse(items, model_name=response.model, timestamp=timestamp)

    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        peekable_response = _utils.PeekableAsyncStream(response)
        first_chunk = await peekable_response.peek()
        if isinstance(first_chunk, _utils.Unset):  # coverage: never true in tests, so the raise below is not exercised
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        return OpenAIStreamedResponse(
            _model_name=self._model_name,
            _response=peekable_response,
            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
        )


    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
        if model_request_parameters.result_tools:
            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
        return tools

    async def _map_message(self, message: ModelMessage) -> AsyncIterable[chat.ChatCompletionMessageParam]:
        """Just maps a `pydantic_ai.Message` to an `openai.types.ChatCompletionMessageParam`."""
        if isinstance(message, ModelRequest):
            async for item in self._map_user_message(message):
                yield item
        elif isinstance(message, ModelResponse):
            texts: list[str] = []
            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
            for item in message.parts:
                if isinstance(item, TextPart):
                    texts.append(item.content)
                elif isinstance(item, ToolCallPart):
                    tool_calls.append(self._map_tool_call(item))
                else:
                    assert_never(item)
            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
            if texts:
                # Note: model responses from this model should only have one text item, so the following
                # shouldn't merge multiple texts into one unless you switch models between runs:
                message_param['content'] = '\n\n'.join(texts)
            if tool_calls:
                message_param['tool_calls'] = tool_calls
            yield message_param
        else:
            assert_never(message)

    @staticmethod
    def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
        return chat.ChatCompletionMessageToolCallParam(
            id=_guard_tool_call_id(t=t),
            type='function',
            function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
        )

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
        return {
            'type': 'function',
            'function': {
                'name': f.name,
                'description': f.description,
                'parameters': f.parameters_json_schema,
            },
        }

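    # For illustration only (not part of the module): a hypothetical ToolDefinition with
    # name='get_weather' and a simple JSON schema would map, via `_map_tool_definition`
    # above, to a Chat Completions tool param roughly like:
    #
    #     {
    #         'type': 'function',
    #         'function': {
    #             'name': 'get_weather',
    #             'description': 'Get the weather for a city.',
    #             'parameters': {'type': 'object', 'properties': {'city': {'type': 'string'}}},
    #         },
    #     }
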

    async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.ChatCompletionMessageParam]:
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                if self.system_prompt_role == 'developer':
                    yield chat.ChatCompletionDeveloperMessageParam(role='developer', content=part.content)
                elif self.system_prompt_role == 'user':
                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
                else:
                    yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                yield await self._map_user_prompt(part)
            elif isinstance(part, ToolReturnPart):
                yield chat.ChatCompletionToolMessageParam(
                    role='tool',
                    tool_call_id=_guard_tool_call_id(t=part),
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):
                if part.tool_name is None:  # coverage: never true in tests, so the plain user-message branch is not exercised
                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
                else:
                    yield chat.ChatCompletionToolMessageParam(
                        role='tool',
                        tool_call_id=_guard_tool_call_id(t=part),
                        content=part.model_response(),
                    )
            else:
                assert_never(part)


    @staticmethod
    async def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
        content: str | list[ChatCompletionContentPartParam]
        if isinstance(part.content, str):
            content = part.content
        else:
            content = []
            for item in part.content:
                if isinstance(item, str):
                    content.append(ChatCompletionContentPartTextParam(text=item, type='text'))
                elif isinstance(item, ImageUrl):
                    image_url = ImageURL(url=item.url)
                    content.append(ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
                elif isinstance(item, BinaryContent):
                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
                    if item.is_image:
                        image_url = ImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
                        content.append(ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
                    elif item.is_audio:
                        assert item.format in ('wav', 'mp3')
                        audio = InputAudio(data=base64_encoded, format=item.format)
                        content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio'))
                    else:  # pragma: no cover
                        raise RuntimeError(f'Unsupported binary content type: {item.media_type}')
                elif isinstance(item, AudioUrl):  # pragma: no cover
                    client = cached_async_http_client()
                    response = await client.get(item.url)
                    response.raise_for_status()
                    base64_encoded = base64.b64encode(response.content).decode('utf-8')
                    audio = InputAudio(data=base64_encoded, format=response.headers.get('content-type'))
                    content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio'))
                elif isinstance(item, DocumentUrl):  # pragma: no cover
                    raise NotImplementedError('DocumentUrl is not supported for OpenAI')
                    # The following implementation should have worked, but it seems we have the following error:
                    # pydantic_ai.exceptions.ModelHTTPError: status_code: 400, model_name: gpt-4o, body:
                    # {
                    #     'message': "Unknown parameter: 'messages[1].content[1].file.data'.",
                    #     'type': 'invalid_request_error',
                    #     'param': 'messages[1].content[1].file.data',
                    #     'code': 'unknown_parameter'
                    # }
                    #
                    # client = cached_async_http_client()
                    # response = await client.get(item.url)
                    # response.raise_for_status()
                    # base64_encoded = base64.b64encode(response.content).decode('utf-8')
                    # media_type = response.headers.get('content-type').split(';')[0]
                    # file_data = f'data:{media_type};base64,{base64_encoded}'
                    # file = File(file={'file_data': file_data, 'file_name': item.url, 'file_id': item.url}, type='file')
                    # content.append(file)
                else:
                    assert_never(item)
        return chat.ChatCompletionUserMessageParam(role='user', content=content)


@dataclass
class OpenAIStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for OpenAI models."""

    _model_name: OpenAIModelName
    _response: AsyncIterable[ChatCompletionChunk]
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._response:
            self._usage += _map_usage(chunk)

            try:
                choice = chunk.choices[0]
            except IndexError:
                continue

            # Handle the text part of the response
            content = choice.delta.content
            if content is not None:
                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)

            for dtc in choice.delta.tool_calls or []:
                maybe_event = self._parts_manager.handle_tool_call_delta(
                    vendor_part_id=dtc.index,
                    tool_name=dtc.function and dtc.function.name,
                    args=dtc.function and dtc.function.arguments,
                    tool_call_id=dtc.id,
                )
                if maybe_event is not None:
                    yield maybe_event

    @property
    def model_name(self) -> OpenAIModelName:
        """Get the model name of the response."""
        return self._model_name

    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._timestamp


def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> usage.Usage:
    response_usage = response.usage
    if response_usage is None:
        return usage.Usage()
    else:
        details: dict[str, int] = {}
        if response_usage.completion_tokens_details is not None:
            details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True))
        if response_usage.prompt_tokens_details is not None:
            details.update(response_usage.prompt_tokens_details.model_dump(exclude_none=True))
        return usage.Usage(
            request_tokens=response_usage.prompt_tokens,
            response_tokens=response_usage.completion_tokens,
            total_tokens=response_usage.total_tokens,
            details=details,
        )
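

# Illustrative shape of the mapping above (a sketch, not asserted by this module): given a
# ChatCompletion whose usage reports prompt_tokens=12, completion_tokens=7, total_tokens=19,
# `_map_usage` would return roughly
#
#     usage.Usage(request_tokens=12, response_tokens=7, total_tokens=19, details={...})
#
# with `details` holding whatever keys the OpenAI SDK exposes on `completion_tokens_details`
# and `prompt_tokens_details` (e.g. reasoning or cached token counts), merged into one dict.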