Coverage for pydantic_ai_slim/pydantic_ai/models/openai.py: 96.23%

156 statements  

coverage.py v7.6.7, created at 2025-01-25 16:43 +0000

  1  from __future__ import annotations as _annotations
  2
  3  from collections.abc import AsyncIterable, AsyncIterator, Iterable
  4  from contextlib import asynccontextmanager
  5  from dataclasses import dataclass, field
  6  from datetime import datetime, timezone
  7  from itertools import chain
  8  from typing import Literal, Union, cast, overload
  9
 10  from httpx import AsyncClient as AsyncHTTPClient
 11  from typing_extensions import assert_never
 12
 13  from .. import UnexpectedModelBehavior, _utils, usage
 14  from .._utils import guard_tool_call_id as _guard_tool_call_id
 15  from ..messages import (
 16      ModelMessage,
 17      ModelRequest,
 18      ModelResponse,
 19      ModelResponsePart,
 20      ModelResponseStreamEvent,
 21      RetryPromptPart,
 22      SystemPromptPart,
 23      TextPart,
 24      ToolCallPart,
 25      ToolReturnPart,
 26      UserPromptPart,
 27  )
 28  from ..settings import ModelSettings
 29  from ..tools import ToolDefinition
 30  from . import (
 31      AgentModel,
 32      Model,
 33      StreamedResponse,
 34      cached_async_http_client,
 35      check_allow_model_requests,
 36  )
 37
 38  try:
 39      from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream
 40      from openai.types import ChatModel, chat
 41      from openai.types.chat import ChatCompletionChunk
 42  except ImportError as _import_error:
 43      raise ImportError(
 44          'Please install `openai` to use the OpenAI model, '
 45          "you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
 46      ) from _import_error
 47
 48  OpenAIModelName = Union[ChatModel, str]
 49  """
 50  Using this broader type for the model name instead of the ChatModel definition
 51  allows this model to be used more easily with other model types (i.e., Ollama).
 52  """
 53
 54  OpenAISystemPromptRole = Literal['system', 'developer', 'user']
 55
 56
 57  class OpenAIModelSettings(ModelSettings):
 58      """Settings used for an OpenAI model request."""
 59
 60      # This class is a placeholder for any future openai-specific settings
 61
 62
 63  @dataclass(init=False)
 64  class OpenAIModel(Model):
 65      """A model that uses the OpenAI API.
 66
 67      Internally, this uses the [OpenAI Python client](https://github.com/openai/openai-python) to interact with the API.
 68
 69      Apart from `__init__`, all methods are private or match those of the base class.
 70      """
 71
 72      model_name: OpenAIModelName
 73      client: AsyncOpenAI = field(repr=False)
 74      system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
 75
 76      def __init__(
 77          self,
 78          model_name: OpenAIModelName,
 79          *,
 80          base_url: str | None = None,
 81          api_key: str | None = None,
 82          openai_client: AsyncOpenAI | None = None,
 83          http_client: AsyncHTTPClient | None = None,
 84          system_prompt_role: OpenAISystemPromptRole | None = None,
 85      ):
 86          """Initialize an OpenAI model.
 87
 88          Args:
 89              model_name: The name of the OpenAI model to use. List of model names available
 90                  [here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
 91                  (Unfortunately, despite being asked to do so, OpenAI do not provide `.inv` files for their API).
 92              base_url: The base url for the OpenAI requests. If not provided, the `OPENAI_BASE_URL` environment variable
 93                  will be used if available. Otherwise, defaults to OpenAI's base url.
 94              api_key: The API key to use for authentication. If not provided, the `OPENAI_API_KEY` environment variable
 95                  will be used if available.
 96              openai_client: An existing
 97                  [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
 98                  client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
 99              http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
100              system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
101                  In the future, this may be inferred from the model name.
102          """
103          self.model_name: OpenAIModelName = model_name
104          if openai_client is not None:
105              assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
106              assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
107              assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
108              self.client = openai_client
109          elif http_client is not None:  [109 ↛ 110: didn't jump to line 110 because the condition on line 109 was never true]
110              self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
111          else:
112              self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
113          self.system_prompt_role = system_prompt_role
114
115      async def agent_model(
116          self,
117          *,
118          function_tools: list[ToolDefinition],
119          allow_text_result: bool,
120          result_tools: list[ToolDefinition],
121      ) -> AgentModel:
122          check_allow_model_requests()
123          tools = [self._map_tool_definition(r) for r in function_tools]
124          if result_tools:
125              tools += [self._map_tool_definition(r) for r in result_tools]
126          return OpenAIAgentModel(
127              self.client,
128              self.model_name,
129              allow_text_result,
130              tools,
131              self.system_prompt_role,
132          )
133
134      def name(self) -> str:
135          return f'openai:{self.model_name}'
136
137      @staticmethod
138      def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
139          return {
140              'type': 'function',
141              'function': {
142                  'name': f.name,
143                  'description': f.description,
144                  'parameters': f.parameters_json_schema,
145              },
146          }
147
148
149  @dataclass
150  class OpenAIAgentModel(AgentModel):
151      """Implementation of `AgentModel` for OpenAI models."""
152
153      client: AsyncOpenAI
154      model_name: OpenAIModelName
155      allow_text_result: bool
156      tools: list[chat.ChatCompletionToolParam]
157      system_prompt_role: OpenAISystemPromptRole | None
158
159      async def request(
160          self, messages: list[ModelMessage], model_settings: ModelSettings | None
161      ) -> tuple[ModelResponse, usage.Usage]:
162          response = await self._completions_create(messages, False, cast(OpenAIModelSettings, model_settings or {}))
163          return self._process_response(response), _map_usage(response)
164
165      @asynccontextmanager
166      async def request_stream(
167          self, messages: list[ModelMessage], model_settings: ModelSettings | None
168      ) -> AsyncIterator[StreamedResponse]:
169          response = await self._completions_create(messages, True, cast(OpenAIModelSettings, model_settings or {}))
170          async with response:
171              yield await self._process_streamed_response(response)
172
173      @overload
174      async def _completions_create(
175          self, messages: list[ModelMessage], stream: Literal[True], model_settings: OpenAIModelSettings
176      ) -> AsyncStream[ChatCompletionChunk]:
177          pass
178
179      @overload
180      async def _completions_create(
181          self, messages: list[ModelMessage], stream: Literal[False], model_settings: OpenAIModelSettings
182      ) -> chat.ChatCompletion:
183          pass
184
185      async def _completions_create(
186          self, messages: list[ModelMessage], stream: bool, model_settings: OpenAIModelSettings
187      ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
188          # standalone function to make it easier to override
189          if not self.tools:
190              tool_choice: Literal['none', 'required', 'auto'] | None = None
191          elif not self.allow_text_result:
192              tool_choice = 'required'
193          else:
194              tool_choice = 'auto'
195
196          openai_messages = list(chain(*(self._map_message(m) for m in messages)))
197
198          return await self.client.chat.completions.create(
199              model=self.model_name,
200              messages=openai_messages,
201              n=1,
202              parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
203              tools=self.tools or NOT_GIVEN,
204              tool_choice=tool_choice or NOT_GIVEN,
205              stream=stream,
206              stream_options={'include_usage': True} if stream else NOT_GIVEN,
207              max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
208              temperature=model_settings.get('temperature', NOT_GIVEN),
209              top_p=model_settings.get('top_p', NOT_GIVEN),
210              timeout=model_settings.get('timeout', NOT_GIVEN),
211              seed=model_settings.get('seed', NOT_GIVEN),
212              presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
213              frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
214              logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
215          )
216
217      def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
218          """Process a non-streamed response, and prepare a message to return."""
219          timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
220          choice = response.choices[0]
221          items: list[ModelResponsePart] = []
222          if choice.message.content is not None:
223              items.append(TextPart(choice.message.content))
224          if choice.message.tool_calls is not None:
225              for c in choice.message.tool_calls:
226                  items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
227          return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)
228
229      async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
230          """Process a streamed response, and prepare a streaming response to return."""
231          peekable_response = _utils.PeekableAsyncStream(response)
232          first_chunk = await peekable_response.peek()
233          if isinstance(first_chunk, _utils.Unset):  [233 ↛ 234: didn't jump to line 234 because the condition on line 233 was never true]
234              raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
235
236          return OpenAIStreamedResponse(
237              _model_name=self.model_name,
238              _response=peekable_response,
239              _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
240          )
241
242      def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
243          """Just maps a `pydantic_ai.Message` to an `openai.types.ChatCompletionMessageParam`."""
244          if isinstance(message, ModelRequest):
245              yield from self._map_user_message(message)
246          elif isinstance(message, ModelResponse):
247              texts: list[str] = []
248              tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
249              for item in message.parts:
250                  if isinstance(item, TextPart):
251                      texts.append(item.content)
252                  elif isinstance(item, ToolCallPart):
253                      tool_calls.append(_map_tool_call(item))
254                  else:
255                      assert_never(item)
256              message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
257              if texts:
258                  # Note: model responses from this model should only have one text item, so the following
259                  # shouldn't merge multiple texts into one unless you switch models between runs:
260                  message_param['content'] = '\n\n'.join(texts)
261              if tool_calls:
262                  message_param['tool_calls'] = tool_calls
263              yield message_param
264          else:
265              assert_never(message)
266
267      def _map_user_message(self, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
268          for part in message.parts:
269              if isinstance(part, SystemPromptPart):
270                  if self.system_prompt_role == 'developer':
271                      yield chat.ChatCompletionDeveloperMessageParam(role='developer', content=part.content)
272                  elif self.system_prompt_role == 'user':
273                      yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
274                  else:
275                      yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
276              elif isinstance(part, UserPromptPart):
277                  yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
278              elif isinstance(part, ToolReturnPart):
279                  yield chat.ChatCompletionToolMessageParam(
280                      role='tool',
281                      tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
282                      content=part.model_response_str(),
283                  )
284              elif isinstance(part, RetryPromptPart):
285                  if part.tool_name is None:  [285 ↛ 286: didn't jump to line 286 because the condition on line 285 was never true]
286                      yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
287                  else:
288                      yield chat.ChatCompletionToolMessageParam(
289                          role='tool',
290                          tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
291                          content=part.model_response(),
292                      )
293              else:
294                  assert_never(part)
295
296
297  @dataclass
298  class OpenAIStreamedResponse(StreamedResponse):
299      """Implementation of `StreamedResponse` for OpenAI models."""
300
301      _response: AsyncIterable[ChatCompletionChunk]
302      _timestamp: datetime
303
304      async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
305          async for chunk in self._response:
306              self._usage += _map_usage(chunk)
307
308              try:
309                  choice = chunk.choices[0]
310              except IndexError:
311                  continue
312
313              # Handle the text part of the response
314              content = choice.delta.content
315              if content is not None:
316                  yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
317
318              for dtc in choice.delta.tool_calls or []:
319                  maybe_event = self._parts_manager.handle_tool_call_delta(
320                      vendor_part_id=dtc.index,
321                      tool_name=dtc.function and dtc.function.name,
322                      args=dtc.function and dtc.function.arguments,
323                      tool_call_id=dtc.id,
324                  )
325                  if maybe_event is not None:
326                      yield maybe_event
327
328      def timestamp(self) -> datetime:
329          return self._timestamp
330
331
332  def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
333      return chat.ChatCompletionMessageToolCallParam(
334          id=_guard_tool_call_id(t=t, model_source='OpenAI'),
335          type='function',
336          function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
337      )
338
339
340  def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> usage.Usage:
341      response_usage = response.usage
342      if response_usage is None:
343          return usage.Usage()
344      else:
345          details: dict[str, int] = {}
346          if response_usage.completion_tokens_details is not None:  [346 ↛ 347: didn't jump to line 347 because the condition on line 346 was never true]
347              details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True))
348          if response_usage.prompt_tokens_details is not None:
349              details.update(response_usage.prompt_tokens_details.model_dump(exclude_none=True))
350          return usage.Usage(
351              request_tokens=response_usage.prompt_tokens,
352              response_tokens=response_usage.completion_tokens,
353              total_tokens=response_usage.total_tokens,
354              details=details,
355          )
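
For orientation, a minimal usage sketch of the module listed above. This is not part of the covered file: it assumes the surrounding package exposes `pydantic_ai.Agent` with a `run_sync` method whose result carries the output in `.data`, and the model name and prompts are purely illustrative.

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel

# Construct the model directly; api_key falls back to the OPENAI_API_KEY
# environment variable when not given (see OpenAIModel.__init__ above).
model = OpenAIModel('gpt-4o', system_prompt_role='system')

# Hypothetical agent wiring: the agent drives requests through OpenAIAgentModel.
agent = Agent(model, system_prompt='Be concise.')
result = agent.run_sync('What is the capital of France?')
print(result.data)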