Coverage for pydantic_ai_slim/pydantic_ai/models/groq.py: 96.68% (177 statements)
from __future__ import annotations as _annotations

import base64
from collections.abc import AsyncIterable, AsyncIterator, Iterable
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from itertools import chain
from typing import Literal, Union, cast, overload

from typing_extensions import assert_never

from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
    BinaryContent,
    DocumentUrl,
    ImageUrl,
    ModelMessage,
    ModelRequest,
    ModelResponse,
    ModelResponsePart,
    ModelResponseStreamEvent,
    RetryPromptPart,
    SystemPromptPart,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)
from ..providers import Provider, infer_provider
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests

try:
    from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
    from groq.types import chat
    from groq.types.chat.chat_completion_content_part_image_param import ImageURL
except ImportError as _import_error:
    raise ImportError(
        'Please install `groq` to use the Groq model; '
        'you can use the `groq` optional group: `pip install "pydantic-ai-slim[groq]"`'
    ) from _import_error

LatestGroqModelNames = Literal[
    'llama-3.3-70b-versatile',
    'llama-3.3-70b-specdec',
    'llama-3.1-8b-instant',
    'llama-3.2-1b-preview',
    'llama-3.2-3b-preview',
    'llama-3.2-11b-vision-preview',
    'llama-3.2-90b-vision-preview',
    'llama3-70b-8192',
    'llama3-8b-8192',
    'mixtral-8x7b-32768',
    'gemma2-9b-it',
]
"""Latest Groq models."""

GroqModelName = Union[str, LatestGroqModelNames]
"""
Possible Groq model names.

Since Groq supports a variety of date-stamped models, we explicitly list the latest models but
allow any name in the type hints.
See [the Groq docs](https://console.groq.com/docs/models) for a full list.
"""


class GroqModelSettings(ModelSettings):
    """Settings used for a Groq model request.

    All fields must be prefixed with `groq_` so they can be merged with settings for other models.
    """

    # This class is a placeholder for any future groq-specific settings
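

# A minimal hedged sketch (not part of the original module): settings are plain
# typed dicts, so base `ModelSettings` keys such as `temperature` pass straight
# through to the Groq request, and settings dicts for different providers can be
# merged without key collisions. The values below are illustrative, not defaults.
def _example_settings() -> GroqModelSettings:  # pragma: no cover - illustration only
    return GroqModelSettings(max_tokens=512, temperature=0.2)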


@dataclass(init=False)
class GroqModel(Model):
    """A model that uses the Groq API.

    Internally, this uses the [Groq Python client](https://github.com/groq/groq-python) to interact with the API.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    client: AsyncGroq = field(repr=False)

    _model_name: GroqModelName = field(repr=False)
    _system: str = field(default='groq', repr=False)

    def __init__(self, model_name: GroqModelName, *, provider: Literal['groq'] | Provider[AsyncGroq] = 'groq'):
        """Initialize a Groq model.

        Args:
            model_name: The name of the Groq model to use. List of model names available
                [here](https://console.groq.com/docs/models).
            provider: The provider to use for authentication and API access. Can be either the string
                'groq' (the default) or an instance of `Provider[AsyncGroq]`. If a string is given,
                a provider instance is inferred from it.
        """
        self._model_name = model_name

        if isinstance(provider, str):
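            # Assumption worth noting for readers: infer_provider('groq') builds the
            # default Groq provider, which typically authenticates via the
            # GROQ_API_KEY environment variable.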
            provider = infer_provider(provider)
        self.client = provider.client

    @property
    def base_url(self) -> str:
        return str(self.client.base_url)

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, usage.Usage]:
        check_allow_model_requests()
        response = await self._completions_create(
            messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
        )
        return self._process_response(response), _map_usage(response)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        check_allow_model_requests()
        response = await self._completions_create(
            messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters
        )
        async with response:
            yield await self._process_streamed_response(response)

    @property
    def model_name(self) -> GroqModelName:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The system / model provider."""
        return self._system

    @overload
    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[True],
        model_settings: GroqModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncStream[chat.ChatCompletionChunk]:
        pass

    @overload
    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[False],
        model_settings: GroqModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> chat.ChatCompletion:
        pass

    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: bool,
        model_settings: GroqModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> chat.ChatCompletion | AsyncStream[chat.ChatCompletionChunk]:
        tools = self._get_tools(model_request_parameters)
        # standalone function to make it easier to override
        if not tools:
            tool_choice: Literal['none', 'required', 'auto'] | None = None
        elif not model_request_parameters.allow_text_result:
            tool_choice = 'required'
        else:
            tool_choice = 'auto'

        groq_messages = list(chain(*(self._map_message(m) for m in messages)))
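
        # Note (assumed from the groq client's conventions): NOT_GIVEN is a
        # sentinel meaning "omit this field from the request", so only settings
        # the caller actually supplied reach the API.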
        try:
            return await self.client.chat.completions.create(
                model=str(self._model_name),
                messages=groq_messages,
                n=1,
                parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
                tools=tools or NOT_GIVEN,
                tool_choice=tool_choice or NOT_GIVEN,
                stream=stream,
                max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                temperature=model_settings.get('temperature', NOT_GIVEN),
                top_p=model_settings.get('top_p', NOT_GIVEN),
                timeout=model_settings.get('timeout', NOT_GIVEN),
                seed=model_settings.get('seed', NOT_GIVEN),
                presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
                frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
                logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
            )
        except APIStatusError as e:
            if (status_code := e.status_code) >= 400:  # coverage: always true in tests, so the bare raise below is unreached
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
            raise

    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
        choice = response.choices[0]
        items: list[ModelResponsePart] = []
        if choice.message.content is not None:
            items.append(TextPart(content=choice.message.content))
        if choice.message.tool_calls is not None:
            for c in choice.message.tool_calls:
                items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
        return ModelResponse(items, model_name=response.model, timestamp=timestamp)

    async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        peekable_response = _utils.PeekableAsyncStream(response)
        first_chunk = await peekable_response.peek()
        if isinstance(first_chunk, _utils.Unset):  # coverage: never true in tests
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        return GroqStreamedResponse(
            _response=peekable_response,
            _model_name=self._model_name,
            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
        )

    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
        if model_request_parameters.result_tools:
            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
        return tools

    def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
        """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
        if isinstance(message, ModelRequest):
            yield from self._map_user_message(message)
        elif isinstance(message, ModelResponse):
            texts: list[str] = []
            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
            for item in message.parts:
                if isinstance(item, TextPart):
                    texts.append(item.content)
                elif isinstance(item, ToolCallPart):
                    tool_calls.append(self._map_tool_call(item))
                else:
                    assert_never(item)
            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
            if texts:
                # Note: Groq responses should only contain one text item, so the join below
                # shouldn't merge multiple texts into one unless you switch models between runs:
                message_param['content'] = '\n\n'.join(texts)
            if tool_calls:
                message_param['tool_calls'] = tool_calls
            yield message_param
        else:
            assert_never(message)

    @staticmethod
    def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
        return chat.ChatCompletionMessageToolCallParam(
            id=_guard_tool_call_id(t=t),
            type='function',
            function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
        )

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
        return {
            'type': 'function',
            'function': {
                'name': f.name,
                'description': f.description,
                'parameters': f.parameters_json_schema,
            },
        }

    @classmethod
    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                yield cls._map_user_prompt(part)
            elif isinstance(part, ToolReturnPart):
                yield chat.ChatCompletionToolMessageParam(
                    role='tool',
                    tool_call_id=_guard_tool_call_id(t=part),
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):  # coverage: always true in tests
                if part.tool_name is None:  # coverage: never true in tests
                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
                else:
                    yield chat.ChatCompletionToolMessageParam(
                        role='tool',
                        tool_call_id=_guard_tool_call_id(t=part),
                        content=part.model_response(),
                    )

    @staticmethod
    def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
        content: str | list[chat.ChatCompletionContentPartParam]
        if isinstance(part.content, str):
            content = part.content
        else:
            content = []
            for item in part.content:
                if isinstance(item, str):
                    content.append(chat.ChatCompletionContentPartTextParam(text=item, type='text'))
                elif isinstance(item, ImageUrl):
                    image_url = ImageURL(url=item.url)
                    content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
                elif isinstance(item, BinaryContent):
                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
                    if item.is_image:
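                        # Inline binary images are sent as data URLs, i.e.
                        # `data:<media-type>;base64,<payload>`, the OpenAI-style
                        # image content-part format the groq client mirrors.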
                        image_url = ImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
                        content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
                    else:
                        raise RuntimeError('Only images are supported for binary content in Groq.')
                elif isinstance(item, DocumentUrl):  # pragma: no cover
                    raise RuntimeError('DocumentUrl is not supported in Groq.')
                else:  # pragma: no cover
                    raise RuntimeError(f'Unsupported content type: {type(item)}')

        return chat.ChatCompletionUserMessageParam(role='user', content=content)
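

# A minimal hedged usage sketch (not part of the original module): drive a single
# non-streamed request through the model interface above. The model name and
# prompt are illustrative, and the inferred 'groq' provider is assumed to
# authenticate via the GROQ_API_KEY environment variable.
async def _example_request() -> None:  # pragma: no cover - illustration only
    model = GroqModel('llama-3.3-70b-versatile')
    params = ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[])
    request = ModelRequest(parts=[UserPromptPart(content='Say hello.')])
    response, request_usage = await model.request([request], GroqModelSettings(temperature=0.2), params)
    print(response.parts, request_usage.total_tokens)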


@dataclass
class GroqStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for Groq models."""

    _model_name: GroqModelName
    _response: AsyncIterable[chat.ChatCompletionChunk]
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._response:
            self._usage += _map_usage(chunk)

            try:
                choice = chunk.choices[0]
            except IndexError:
                continue
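
            # Chunks without choices (an assumption: e.g. ones carrying only
            # metadata or usage) still contribute to usage above but yield no events.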
            # Handle the text part of the response
            content = choice.delta.content
            if content is not None:
                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)

            # Handle the tool calls
            for dtc in choice.delta.tool_calls or []:
                maybe_event = self._parts_manager.handle_tool_call_delta(
                    vendor_part_id=dtc.index,
                    tool_name=dtc.function and dtc.function.name,
                    args=dtc.function and dtc.function.arguments,
                    tool_call_id=dtc.id,
                )
                if maybe_event is not None:
                    yield maybe_event

    @property
    def model_name(self) -> GroqModelName:
        """Get the model name of the response."""
        return self._model_name

    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._timestamp
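

# A minimal hedged streaming sketch (not part of the original module): consume
# events from a streamed response via the lower-level model interface. The
# argument values are assumptions for illustration.
async def _example_stream(model: GroqModel, messages: list[ModelMessage]) -> None:  # pragma: no cover
    params = ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[])
    async with model.request_stream(messages, None, params) as streamed_response:
        async for event in streamed_response:  # yields ModelResponseStreamEvent items
            print(event)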


def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.Usage:
    response_usage = None
    if isinstance(completion, chat.ChatCompletion):
        response_usage = completion.usage
    elif completion.x_groq is not None:  # coverage: never true in tests
        response_usage = completion.x_groq.usage

    if response_usage is None:
        return usage.Usage()

    return usage.Usage(
        request_tokens=response_usage.prompt_tokens,
        response_tokens=response_usage.completion_tokens,
        total_tokens=response_usage.total_tokens,
    )
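

# Hedged note: `usage.Usage` supports addition, which is how the streaming path
# accumulates per-chunk usage (`self._usage += _map_usage(chunk)` above). The
# numbers here are made up for illustration.
def _example_usage_accumulation() -> usage.Usage:  # pragma: no cover - illustration only
    return usage.Usage() + usage.Usage(request_tokens=3, response_tokens=5, total_tokens=8)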