Coverage for pydantic_ai_slim/pydantic_ai/models/openai.py: 98.12%
194 statements
coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

from __future__ import annotations as _annotations

import base64
from collections.abc import AsyncIterable, AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Literal, Union, cast, overload

from typing_extensions import assert_never

from pydantic_ai.providers import Provider, infer_provider

from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
    AudioUrl,
    BinaryContent,
    DocumentUrl,
    ImageUrl,
    ModelMessage,
    ModelRequest,
    ModelResponse,
    ModelResponsePart,
    ModelResponseStreamEvent,
    RetryPromptPart,
    SystemPromptPart,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import (
    Model,
    ModelRequestParameters,
    StreamedResponse,
    cached_async_http_client,
    check_allow_model_requests,
)

try:
    from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream
    from openai.types import ChatModel, chat
    from openai.types.chat import (
        ChatCompletionChunk,
        ChatCompletionContentPartImageParam,
        ChatCompletionContentPartInputAudioParam,
        ChatCompletionContentPartParam,
        ChatCompletionContentPartTextParam,
    )
    from openai.types.chat.chat_completion_content_part_image_param import ImageURL
    from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
except ImportError as _import_error:
    raise ImportError(
        'Please install `openai` to use the OpenAI model, '
        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
    ) from _import_error

OpenAIModelName = Union[str, ChatModel]
"""
Possible OpenAI model names.

Since OpenAI supports a variety of date-stamped models, we explicitly list the latest models but
allow any name in the type hints.
See [the OpenAI docs](https://platform.openai.com/docs/models) for a full list.

Using this broader type for the model name instead of the `ChatModel` definition
allows this model to be used more easily with other model types (e.g. Ollama, DeepSeek).
"""
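
# Illustrative note (not part of the module): because the alias accepts any `str`, model names
# that are not listed in `ChatModel` (e.g. names served by an OpenAI-compatible endpoint) still
# type-check as `OpenAIModelName`.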

OpenAISystemPromptRole = Literal['system', 'developer', 'user']


class OpenAIModelSettings(ModelSettings, total=False):
    """Settings used for an OpenAI model request.

    ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
    """

    openai_reasoning_effort: chat.ChatCompletionReasoningEffort
    """
    Constrains effort on reasoning for [reasoning models](https://platform.openai.com/docs/guides/reasoning).
    Currently supported values are `low`, `medium`, and `high`. Reducing reasoning effort can
    result in faster responses and fewer tokens used on reasoning in a response.
    """

    openai_user: str
    """A unique identifier representing the end-user, which can help OpenAI monitor and detect abuse.

    See [OpenAI's safety best practices](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids) for more details.
    """
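
# A minimal sketch (illustrative, not part of the module): `OpenAIModelSettings` is a `TypedDict`,
# so a settings mapping mixes the base `ModelSettings` keys with the `openai_`-prefixed ones above.
#
#     settings: OpenAIModelSettings = {
#         'temperature': 0.2,                # inherited from ModelSettings
#         'max_tokens': 1024,                # inherited from ModelSettings
#         'openai_reasoning_effort': 'low',
#         'openai_user': 'user-1234',
#     }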


@dataclass(init=False)
class OpenAIModel(Model):
    """A model that uses the OpenAI API.

    Internally, this uses the [OpenAI Python client](https://github.com/openai/openai-python) to interact with the API.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    client: AsyncOpenAI = field(repr=False)
    system_prompt_role: OpenAISystemPromptRole | None = field(default=None)

    _model_name: OpenAIModelName = field(repr=False)
    _system: str = field(default='openai', repr=False)

    def __init__(
        self,
        model_name: OpenAIModelName,
        *,
        provider: Literal['openai', 'deepseek', 'azure'] | Provider[AsyncOpenAI] = 'openai',
        system_prompt_role: OpenAISystemPromptRole | None = None,
    ):
        """Initialize an OpenAI model.

        Args:
            model_name: The name of the OpenAI model to use. List of model names available
                [here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
                (Unfortunately, despite being asked to do so, OpenAI does not provide `.inv` files for their API).
            provider: The provider to use. Defaults to `'openai'`.
            system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
                In the future, this may be inferred from the model name.
        """
        self._model_name = model_name
        if isinstance(provider, str):
            provider = infer_provider(provider)
        self.client = provider.client
        self.system_prompt_role = system_prompt_role
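
    # A minimal usage sketch (illustrative, not part of the module). It assumes an API key is
    # available to the inferred provider (e.g. via the OPENAI_API_KEY environment variable) so
    # that `infer_provider('openai')` can construct an `AsyncOpenAI` client:
    #
    #     model = OpenAIModel('gpt-4o')
    #     # or, targeting an OpenAI-compatible provider:
    #     # model = OpenAIModel('deepseek-chat', provider='deepseek')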

    @property
    def base_url(self) -> str:
        return str(self.client.base_url)

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, usage.Usage]:
        check_allow_model_requests()
        response = await self._completions_create(
            messages, False, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
        )
        return self._process_response(response), _map_usage(response)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        check_allow_model_requests()
        response = await self._completions_create(
            messages, True, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
        )
        async with response:
            yield await self._process_streamed_response(response)

    @property
    def model_name(self) -> OpenAIModelName:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The system / model provider."""
        return self._system

    @overload
    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[True],
        model_settings: OpenAIModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncStream[ChatCompletionChunk]:
        pass

    @overload
    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[False],
        model_settings: OpenAIModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> chat.ChatCompletion:
        pass

    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: bool,
        model_settings: OpenAIModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
        tools = self._get_tools(model_request_parameters)

        # standalone function to make it easier to override
        if not tools:
            tool_choice: Literal['none', 'required', 'auto'] | None = None
        elif not model_request_parameters.allow_text_result:
            tool_choice = 'required'
        else:
            tool_choice = 'auto'

        openai_messages: list[chat.ChatCompletionMessageParam] = []
        for m in messages:
            async for msg in self._map_message(m):
                openai_messages.append(msg)

        try:
            return await self.client.chat.completions.create(
                model=self._model_name,
                messages=openai_messages,
                n=1,
                parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
                tools=tools or NOT_GIVEN,
                tool_choice=tool_choice or NOT_GIVEN,
                stream=stream,
                stream_options={'include_usage': True} if stream else NOT_GIVEN,
                max_completion_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                temperature=model_settings.get('temperature', NOT_GIVEN),
                top_p=model_settings.get('top_p', NOT_GIVEN),
                timeout=model_settings.get('timeout', NOT_GIVEN),
                seed=model_settings.get('seed', NOT_GIVEN),
                presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
                frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
                logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
                reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
                user=model_settings.get('openai_user', NOT_GIVEN),
            )
        except APIStatusError as e:
            if (status_code := e.status_code) >= 400:
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
            raise

    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
        choice = response.choices[0]
        items: list[ModelResponsePart] = []
        if choice.message.content is not None:
            items.append(TextPart(choice.message.content))
        if choice.message.tool_calls is not None:
            for c in choice.message.tool_calls:
                items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
        return ModelResponse(items, model_name=response.model, timestamp=timestamp)

    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        peekable_response = _utils.PeekableAsyncStream(response)
        first_chunk = await peekable_response.peek()
        if isinstance(first_chunk, _utils.Unset):
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        return OpenAIStreamedResponse(
            _model_name=self._model_name,
            _response=peekable_response,
            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
        )

    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
        if model_request_parameters.result_tools:
            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
        return tools

    async def _map_message(self, message: ModelMessage) -> AsyncIterable[chat.ChatCompletionMessageParam]:
        """Just maps a `pydantic_ai.Message` to an `openai.types.ChatCompletionMessageParam`."""
        if isinstance(message, ModelRequest):
            async for item in self._map_user_message(message):
                yield item
        elif isinstance(message, ModelResponse):
            texts: list[str] = []
            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
            for item in message.parts:
                if isinstance(item, TextPart):
                    texts.append(item.content)
                elif isinstance(item, ToolCallPart):
                    tool_calls.append(self._map_tool_call(item))
                else:
                    assert_never(item)
            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
            if texts:
                # Note: model responses from this model should only have one text item, so the following
                # shouldn't merge multiple texts into one unless you switch models between runs:
                message_param['content'] = '\n\n'.join(texts)
            if tool_calls:
                message_param['tool_calls'] = tool_calls
            yield message_param
        else:
            assert_never(message)

    @staticmethod
    def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
        return chat.ChatCompletionMessageToolCallParam(
            id=_guard_tool_call_id(t=t),
            type='function',
            function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
        )

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
        return {
            'type': 'function',
            'function': {
                'name': f.name,
                'description': f.description,
                'parameters': f.parameters_json_schema,
            },
        }
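
    # For illustration only (assumed values, not executed here): a `ToolDefinition` named
    # 'get_weather' whose `parameters_json_schema` describes a single string argument would be
    # mapped by `_map_tool_definition` to the OpenAI tool param shape:
    #
    #     {
    #         'type': 'function',
    #         'function': {
    #             'name': 'get_weather',
    #             'description': 'Get the current weather for a city.',
    #             'parameters': {'type': 'object', 'properties': {'city': {'type': 'string'}}},
    #         },
    #     }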

    async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.ChatCompletionMessageParam]:
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                if self.system_prompt_role == 'developer':
                    yield chat.ChatCompletionDeveloperMessageParam(role='developer', content=part.content)
                elif self.system_prompt_role == 'user':
                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
                else:
                    yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                yield await self._map_user_prompt(part)
            elif isinstance(part, ToolReturnPart):
                yield chat.ChatCompletionToolMessageParam(
                    role='tool',
                    tool_call_id=_guard_tool_call_id(t=part),
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):
                if part.tool_name is None:
                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
                else:
                    yield chat.ChatCompletionToolMessageParam(
                        role='tool',
                        tool_call_id=_guard_tool_call_id(t=part),
                        content=part.model_response(),
                    )
            else:
                assert_never(part)

    @staticmethod
    async def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
        content: str | list[ChatCompletionContentPartParam]
        if isinstance(part.content, str):
            content = part.content
        else:
            content = []
            for item in part.content:
                if isinstance(item, str):
                    content.append(ChatCompletionContentPartTextParam(text=item, type='text'))
                elif isinstance(item, ImageUrl):
                    image_url = ImageURL(url=item.url)
                    content.append(ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
                elif isinstance(item, BinaryContent):
                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
                    if item.is_image:
                        image_url = ImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
                        content.append(ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
                    elif item.is_audio:
                        assert item.format in ('wav', 'mp3')
                        audio = InputAudio(data=base64_encoded, format=item.format)
                        content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio'))
                    else:  # pragma: no cover
                        raise RuntimeError(f'Unsupported binary content type: {item.media_type}')
                elif isinstance(item, AudioUrl):  # pragma: no cover
                    client = cached_async_http_client()
                    response = await client.get(item.url)
                    response.raise_for_status()
                    base64_encoded = base64.b64encode(response.content).decode('utf-8')
                    audio = InputAudio(data=base64_encoded, format=response.headers.get('content-type'))
                    content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio'))
                elif isinstance(item, DocumentUrl):  # pragma: no cover
                    raise NotImplementedError('DocumentUrl is not supported for OpenAI')
                    # The following implementation should have worked, but it seems we have the following error:
                    # pydantic_ai.exceptions.ModelHTTPError: status_code: 400, model_name: gpt-4o, body:
                    # {
                    #     'message': "Unknown parameter: 'messages[1].content[1].file.data'.",
                    #     'type': 'invalid_request_error',
                    #     'param': 'messages[1].content[1].file.data',
                    #     'code': 'unknown_parameter'
                    # }
                    #
                    # client = cached_async_http_client()
                    # response = await client.get(item.url)
                    # response.raise_for_status()
                    # base64_encoded = base64.b64encode(response.content).decode('utf-8')
                    # media_type = response.headers.get('content-type').split(';')[0]
                    # file_data = f'data:{media_type};base64,{base64_encoded}'
                    # file = File(file={'file_data': file_data, 'file_name': item.url, 'file_id': item.url}, type='file')
                    # content.append(file)
                else:
                    assert_never(item)
        return chat.ChatCompletionUserMessageParam(role='user', content=content)


@dataclass
class OpenAIStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for OpenAI models."""

    _model_name: OpenAIModelName
    _response: AsyncIterable[ChatCompletionChunk]
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._response:
            self._usage += _map_usage(chunk)

            try:
                choice = chunk.choices[0]
            except IndexError:
                continue

            # Handle the text part of the response
            content = choice.delta.content
            if content is not None:
                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)

            for dtc in choice.delta.tool_calls or []:
                maybe_event = self._parts_manager.handle_tool_call_delta(
                    vendor_part_id=dtc.index,
                    tool_name=dtc.function and dtc.function.name,
                    args=dtc.function and dtc.function.arguments,
                    tool_call_id=dtc.id,
                )
                if maybe_event is not None:
                    yield maybe_event

    @property
    def model_name(self) -> OpenAIModelName:
        """Get the model name of the response."""
        return self._model_name

    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._timestamp


def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> usage.Usage:
    response_usage = response.usage
    if response_usage is None:
        return usage.Usage()
    else:
        details: dict[str, int] = {}
        if response_usage.completion_tokens_details is not None:
            details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True))
        if response_usage.prompt_tokens_details is not None:
            details.update(response_usage.prompt_tokens_details.model_dump(exclude_none=True))
        return usage.Usage(
            request_tokens=response_usage.prompt_tokens,
            response_tokens=response_usage.completion_tokens,
            total_tokens=response_usage.total_tokens,
            details=details,
        )
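

# For illustration only (assumed numbers, not executed here): a response whose `usage` payload is
# {'prompt_tokens': 10, 'completion_tokens': 5, 'total_tokens': 15}, with no per-token details,
# maps to Usage(request_tokens=10, response_tokens=5, total_tokens=15, details={}).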