Coverage for pydantic_ai_slim/pydantic_ai/models/openai.py: 96.23%
156 statements
from __future__ import annotations as _annotations

from collections.abc import AsyncIterable, AsyncIterator, Iterable
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from itertools import chain
from typing import Literal, Union, cast, overload

from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never

from .. import UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
    ModelMessage,
    ModelRequest,
    ModelResponse,
    ModelResponsePart,
    ModelResponseStreamEvent,
    RetryPromptPart,
    SystemPromptPart,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import (
    AgentModel,
    Model,
    StreamedResponse,
    cached_async_http_client,
    check_allow_model_requests,
)

try:
    from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream
    from openai.types import ChatModel, chat
    from openai.types.chat import ChatCompletionChunk
except ImportError as _import_error:
    raise ImportError(
        'Please install `openai` to use the OpenAI model; '
        "you can use the `openai` optional group: `pip install 'pydantic-ai-slim[openai]'`"
    ) from _import_error

OpenAIModelName = Union[ChatModel, str]
"""
Using this broader type for the model name, instead of the `ChatModel` definition alone,
allows this model to be used more easily with other OpenAI-compatible model types (e.g. Ollama).
"""

OpenAISystemPromptRole = Literal['system', 'developer', 'user']
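"""Role to use for the system prompt message; OpenAI's newer reasoning models expect `'developer'` rather than `'system'`."""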


class OpenAIModelSettings(ModelSettings):
    """Settings used for an OpenAI model request."""

    # This class is a placeholder for any future openai-specific settings


@dataclass(init=False)
class OpenAIModel(Model):
    """A model that uses the OpenAI API.

    Internally, this uses the [OpenAI Python client](https://github.com/openai/openai-python) to interact with the API.

    Apart from `__init__`, all methods are private or match those of the base class.
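
    A minimal usage sketch (assumes the `OPENAI_API_KEY` environment variable is set):

    ```python
    from pydantic_ai import Agent
    from pydantic_ai.models.openai import OpenAIModel

    agent = Agent(OpenAIModel('gpt-4o'))
    result = agent.run_sync('What is the capital of France?')
    print(result.data)
    ```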
70 """
72 model_name: OpenAIModelName
73 client: AsyncOpenAI = field(repr=False)
74 system_prompt_role: OpenAISystemPromptRole | None = field(default=None)

    def __init__(
        self,
        model_name: OpenAIModelName,
        *,
        base_url: str | None = None,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
        system_prompt_role: OpenAISystemPromptRole | None = None,
    ):
        """Initialize an OpenAI model.

        Args:
            model_name: The name of the OpenAI model to use. A list of model names is available
                [here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
                (unfortunately, despite being asked to do so, OpenAI does not provide `.inv` files for its API).
            base_url: The base URL for the OpenAI requests. If not provided, the `OPENAI_BASE_URL` environment variable
                will be used if available. Otherwise, defaults to OpenAI's base URL.
            api_key: The API key to use for authentication. If not provided, the `OPENAI_API_KEY` environment variable
                will be used if available.
            openai_client: An existing
                [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
                client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
            system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
                In the future, this may be inferred from the model name.
        """
        self.model_name: OpenAIModelName = model_name
        if openai_client is not None:
            assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
            assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
            assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
            self.client = openai_client
        elif http_client is not None:
            self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
        else:
            self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
        self.system_prompt_role = system_prompt_role

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        check_allow_model_requests()
        tools = [self._map_tool_definition(r) for r in function_tools]
        if result_tools:
            tools += [self._map_tool_definition(r) for r in result_tools]
        return OpenAIAgentModel(
            self.client,
            self.model_name,
            allow_text_result,
            tools,
            self.system_prompt_role,
        )

    def name(self) -> str:
        return f'openai:{self.model_name}'

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
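        # A `ToolDefinition` maps directly onto the OpenAI function-tool param shape; e.g. a
        # hypothetical `get_weather` tool would become
        # {'type': 'function', 'function': {'name': 'get_weather', 'description': ..., 'parameters': {...}}}.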
        return {
            'type': 'function',
            'function': {
                'name': f.name,
                'description': f.description,
                'parameters': f.parameters_json_schema,
            },
        }


@dataclass
class OpenAIAgentModel(AgentModel):
    """Implementation of `AgentModel` for OpenAI models."""

    client: AsyncOpenAI
    model_name: OpenAIModelName
    allow_text_result: bool
    tools: list[chat.ChatCompletionToolParam]
    system_prompt_role: OpenAISystemPromptRole | None

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, usage.Usage]:
        response = await self._completions_create(messages, False, cast(OpenAIModelSettings, model_settings or {}))
        return self._process_response(response), _map_usage(response)

    @asynccontextmanager
    async def request_stream(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> AsyncIterator[StreamedResponse]:
        response = await self._completions_create(messages, True, cast(OpenAIModelSettings, model_settings or {}))
        async with response:
            yield await self._process_streamed_response(response)

    @overload
    async def _completions_create(
        self, messages: list[ModelMessage], stream: Literal[True], model_settings: OpenAIModelSettings
    ) -> AsyncStream[ChatCompletionChunk]:
        pass

    @overload
    async def _completions_create(
        self, messages: list[ModelMessage], stream: Literal[False], model_settings: OpenAIModelSettings
    ) -> chat.ChatCompletion:
        pass

    async def _completions_create(
        self, messages: list[ModelMessage], stream: bool, model_settings: OpenAIModelSettings
    ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
        # standalone function to make it easier to override
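        # `tool_choice` semantics follow the OpenAI API: 'auto' lets the model decide whether
        # to call a tool, 'required' forces a tool call (used when a plain-text result is not
        # allowed), and leaving it unset (NOT_GIVEN below) applies when there are no tools at all.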
        if not self.tools:
            tool_choice: Literal['none', 'required', 'auto'] | None = None
        elif not self.allow_text_result:
            tool_choice = 'required'
        else:
            tool_choice = 'auto'

        openai_messages = list(chain(*(self._map_message(m) for m in messages)))
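
        # Any setting left unset below resolves to NOT_GIVEN, which the OpenAI SDK omits
        # from the request body entirely, so e.g. model_settings={'temperature': 0.0} sends
        # only `temperature` and leaves every other sampling parameter at the server default.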
        return await self.client.chat.completions.create(
            model=self.model_name,
            messages=openai_messages,
            n=1,
            parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
            tools=self.tools or NOT_GIVEN,
            tool_choice=tool_choice or NOT_GIVEN,
            stream=stream,
            stream_options={'include_usage': True} if stream else NOT_GIVEN,
            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
            temperature=model_settings.get('temperature', NOT_GIVEN),
            top_p=model_settings.get('top_p', NOT_GIVEN),
            timeout=model_settings.get('timeout', NOT_GIVEN),
            seed=model_settings.get('seed', NOT_GIVEN),
            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
        )

    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
        choice = response.choices[0]
        items: list[ModelResponsePart] = []
        if choice.message.content is not None:
            items.append(TextPart(choice.message.content))
        if choice.message.tool_calls is not None:
            for c in choice.message.tool_calls:
                items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)

    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        peekable_response = _utils.PeekableAsyncStream(response)
        first_chunk = await peekable_response.peek()
        if isinstance(first_chunk, _utils.Unset):
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        return OpenAIStreamedResponse(
            _model_name=self.model_name,
            _response=peekable_response,
            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
        )

    def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
        """Just maps a `pydantic_ai.Message` to an `openai.types.ChatCompletionMessageParam`."""
        if isinstance(message, ModelRequest):
            yield from self._map_user_message(message)
        elif isinstance(message, ModelResponse):
            texts: list[str] = []
            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
            for item in message.parts:
                if isinstance(item, TextPart):
                    texts.append(item.content)
                elif isinstance(item, ToolCallPart):
                    tool_calls.append(_map_tool_call(item))
                else:
                    assert_never(item)
            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
            if texts:
                # Note: model responses from this model should only have one text item, so the following
                # shouldn't merge multiple texts into one unless you switch models between runs:
                message_param['content'] = '\n\n'.join(texts)
            if tool_calls:
                message_param['tool_calls'] = tool_calls
            yield message_param
        else:
            assert_never(message)

    def _map_user_message(self, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                if self.system_prompt_role == 'developer':
                    yield chat.ChatCompletionDeveloperMessageParam(role='developer', content=part.content)
                elif self.system_prompt_role == 'user':
                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
                else:
                    yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
            elif isinstance(part, ToolReturnPart):
                yield chat.ChatCompletionToolMessageParam(
                    role='tool',
                    tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):
                if part.tool_name is None:
                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
                else:
                    yield chat.ChatCompletionToolMessageParam(
                        role='tool',
                        tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
                        content=part.model_response(),
                    )
            else:
                assert_never(part)


@dataclass
class OpenAIStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for OpenAI models."""

    _response: AsyncIterable[ChatCompletionChunk]
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._response:
            self._usage += _map_usage(chunk)
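
            # When `stream_options={'include_usage': True}` is set, the final chunk carries
            # usage data but an empty `choices` list, so skip it once usage is recorded.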
            try:
                choice = chunk.choices[0]
            except IndexError:
                continue

            # Handle the text part of the response
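            # (all text deltas share the vendor part id 'content', so the parts manager
            # merges them into a single `TextPart` across chunks)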
            content = choice.delta.content
            if content is not None:
                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)

            for dtc in choice.delta.tool_calls or []:
                maybe_event = self._parts_manager.handle_tool_call_delta(
                    vendor_part_id=dtc.index,
                    tool_name=dtc.function and dtc.function.name,
                    args=dtc.function and dtc.function.arguments,
                    tool_call_id=dtc.id,
                )
                if maybe_event is not None:
                    yield maybe_event

    def timestamp(self) -> datetime:
        return self._timestamp


def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
    return chat.ChatCompletionMessageToolCallParam(
        id=_guard_tool_call_id(t=t, model_source='OpenAI'),
        type='function',
        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
    )


def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> usage.Usage:
    response_usage = response.usage
    if response_usage is None:
        return usage.Usage()
    else:
        details: dict[str, int] = {}
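        # `details` collects vendor-specific counters (e.g. `reasoning_tokens` from
        # completion_tokens_details, `cached_tokens` from prompt_tokens_details) when the
        # SDK populates those objects.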
        if response_usage.completion_tokens_details is not None:
            details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True))
        if response_usage.prompt_tokens_details is not None:
            details.update(response_usage.prompt_tokens_details.model_dump(exclude_none=True))
        return usage.Usage(
            request_tokens=response_usage.prompt_tokens,
            response_tokens=response_usage.completion_tokens,
            total_tokens=response_usage.total_tokens,
            details=details,
        )