Coverage for pydantic_ai_slim/pydantic_ai/models/groq.py: 95.54%
148 statements
coverage.py v7.6.7, created at 2025-01-30 19:21 +0000

from __future__ import annotations as _annotations

from collections.abc import AsyncIterable, AsyncIterator, Iterable
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from itertools import chain
from typing import Literal, cast, overload

from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never

from .. import UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
    ModelMessage,
    ModelRequest,
    ModelResponse,
    ModelResponsePart,
    ModelResponseStreamEvent,
    RetryPromptPart,
    SystemPromptPart,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import (
    AgentModel,
    Model,
    StreamedResponse,
    cached_async_http_client,
    check_allow_model_requests,
)

try:
    from groq import NOT_GIVEN, AsyncGroq, AsyncStream
    from groq.types import chat
    from groq.types.chat import ChatCompletion, ChatCompletionChunk
except ImportError as _import_error:
    raise ImportError(
        'Please install `groq` to use the Groq model; '
        "you can use the `groq` optional group — `pip install 'pydantic-ai-slim[groq]'`"
    ) from _import_error

GroqModelName = Literal[
    'llama-3.3-70b-versatile',
    'llama-3.3-70b-specdec',
    'llama-3.1-8b-instant',
    'llama-3.2-1b-preview',
    'llama-3.2-3b-preview',
    'llama-3.2-11b-vision-preview',
    'llama-3.2-90b-vision-preview',
    'llama3-70b-8192',
    'llama3-8b-8192',
    'mixtral-8x7b-32768',
    'gemma2-9b-it',
]
"""Named Groq models.

See [the Groq docs](https://console.groq.com/docs/models) for a full list.
"""


class GroqModelSettings(ModelSettings):
    """Settings used for a Groq model request."""

    # This class is a placeholder for any future Groq-specific settings


@dataclass(init=False)
class GroqModel(Model):
    """A model that uses the Groq API.

    Internally, this uses the [Groq Python client](https://github.com/groq/groq-python) to interact with the API.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    model_name: GroqModelName
    client: AsyncGroq = field(repr=False)

    def __init__(
        self,
        model_name: GroqModelName,
        *,
        api_key: str | None = None,
        groq_client: AsyncGroq | None = None,
        http_client: AsyncHTTPClient | None = None,
    ):
        """Initialize a Groq model.

        Args:
            model_name: The name of the Groq model to use. List of model names available
                [here](https://console.groq.com/docs/models).
            api_key: The API key to use for authentication; if not provided, the `GROQ_API_KEY`
                environment variable will be used if available.
            groq_client: An existing
                [`AsyncGroq`](https://github.com/groq/groq-python?tab=readme-ov-file#async-usage)
                client to use; if provided, `api_key` and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
        """
        self.model_name = model_name
        if groq_client is not None:
            assert http_client is None, 'Cannot provide both `groq_client` and `http_client`'
            assert api_key is None, 'Cannot provide both `groq_client` and `api_key`'
            self.client = groq_client
        elif http_client is not None:
            self.client = AsyncGroq(api_key=api_key, http_client=http_client)
        else:
            self.client = AsyncGroq(api_key=api_key, http_client=cached_async_http_client())
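
    # A minimal construction sketch mirroring the three branches above (the
    # key value is a placeholder; the first form assumes `GROQ_API_KEY` is set):
    #
    #     model = GroqModel('llama-3.3-70b-versatile')  # shared cached httpx client
    #     model = GroqModel('llama-3.1-8b-instant', api_key='gsk_...')
    #     model = GroqModel('gemma2-9b-it', http_client=AsyncHTTPClient())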

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        check_allow_model_requests()
        tools = [self._map_tool_definition(r) for r in function_tools]
        if result_tools:
            tools += [self._map_tool_definition(r) for r in result_tools]
        return GroqAgentModel(
            self.client,
            self.model_name,
            allow_text_result,
            tools,
        )

    def name(self) -> str:
        return f'groq:{self.model_name}'

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
        return {
            'type': 'function',
            'function': {
                'name': f.name,
                'description': f.description,
                'parameters': f.parameters_json_schema,
            },
        }
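
# For illustration, a `ToolDefinition` for a hypothetical `get_weather` tool
# (name, description, and schema invented here) maps to:
#
#     {
#         'type': 'function',
#         'function': {
#             'name': 'get_weather',
#             'description': 'Get the current weather for a city.',
#             'parameters': {'type': 'object', 'properties': {'city': {'type': 'string'}}},
#         },
#     }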


@dataclass
class GroqAgentModel(AgentModel):
    """Implementation of `AgentModel` for Groq models."""

    client: AsyncGroq
    model_name: str
    allow_text_result: bool
    tools: list[chat.ChatCompletionToolParam]

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, usage.Usage]:
        response = await self._completions_create(messages, False, cast(GroqModelSettings, model_settings or {}))
        return self._process_response(response), _map_usage(response)

    @asynccontextmanager
    async def request_stream(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> AsyncIterator[StreamedResponse]:
        response = await self._completions_create(messages, True, cast(GroqModelSettings, model_settings or {}))
        async with response:
            yield await self._process_streamed_response(response)

    @overload
    async def _completions_create(
        self, messages: list[ModelMessage], stream: Literal[True], model_settings: GroqModelSettings
    ) -> AsyncStream[ChatCompletionChunk]:
        pass

    @overload
    async def _completions_create(
        self, messages: list[ModelMessage], stream: Literal[False], model_settings: GroqModelSettings
    ) -> chat.ChatCompletion:
        pass

    async def _completions_create(
        self, messages: list[ModelMessage], stream: bool, model_settings: GroqModelSettings
    ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
        # standalone function to make it easier to override
        if not self.tools:
            tool_choice: Literal['none', 'required', 'auto'] | None = None
        elif not self.allow_text_result:
            tool_choice = 'required'
        else:
            tool_choice = 'auto'

        groq_messages = list(chain(*(self._map_message(m) for m in messages)))

        return await self.client.chat.completions.create(
            model=str(self.model_name),
            messages=groq_messages,
            n=1,
            parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
            tools=self.tools or NOT_GIVEN,
            tool_choice=tool_choice or NOT_GIVEN,
            stream=stream,
            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
            temperature=model_settings.get('temperature', NOT_GIVEN),
            top_p=model_settings.get('top_p', NOT_GIVEN),
            timeout=model_settings.get('timeout', NOT_GIVEN),
            seed=model_settings.get('seed', NOT_GIVEN),
            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
        )
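
    # Settings sketch: the `.get(..., NOT_GIVEN)` calls above treat settings as
    # a mapping, so a plain dict works (values illustrative; `agent_model` is a
    # hypothetical `GroqAgentModel` instance). Any key left unset falls through
    # to `NOT_GIVEN` and the Groq API default applies:
    #
    #     settings: GroqModelSettings = {'temperature': 0.0, 'max_tokens': 512}
    #     response, run_usage = await agent_model.request(messages, settings)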

    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
        choice = response.choices[0]
        items: list[ModelResponsePart] = []
        if choice.message.content is not None:
            items.append(TextPart(content=choice.message.content))
        if choice.message.tool_calls is not None:
            for c in choice.message.tool_calls:
                items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)

    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        peekable_response = _utils.PeekableAsyncStream(response)
        first_chunk = await peekable_response.peek()
        if isinstance(first_chunk, _utils.Unset):
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        return GroqStreamedResponse(
            _response=peekable_response,
            _model_name=self.model_name,
            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
        )
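
    # Consumption sketch (hypothetical driver code; in normal use the agent
    # runtime drives the stream, and iterating a `StreamedResponse` is assumed
    # to yield `ModelResponseStreamEvent`s):
    #
    #     async with agent_model.request_stream(messages, None) as streamed:
    #         async for event in streamed:
    #             ...  # e.g. part-start / part-delta events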

    @classmethod
    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
        """Map a `pydantic_ai.ModelMessage` to a `groq.types.ChatCompletionMessageParam`."""
        if isinstance(message, ModelRequest):
            yield from cls._map_user_message(message)
        elif isinstance(message, ModelResponse):
            texts: list[str] = []
            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
            for item in message.parts:
                if isinstance(item, TextPart):
                    texts.append(item.content)
                elif isinstance(item, ToolCallPart):
                    tool_calls.append(_map_tool_call(item))
                else:
                    assert_never(item)
            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
            if texts:
                # Note: model responses from this model should only have one text item, so the following
                # shouldn't merge multiple texts into one unless you switch models between runs:
                message_param['content'] = '\n\n'.join(texts)
            if tool_calls:
                message_param['tool_calls'] = tool_calls
            yield message_param
        else:
            assert_never(message)

    @classmethod
    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
            elif isinstance(part, ToolReturnPart):
                yield chat.ChatCompletionToolMessageParam(
                    role='tool',
                    tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):
                if part.tool_name is None:
                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
                else:
                    yield chat.ChatCompletionToolMessageParam(
                        role='tool',
                        tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
                        content=part.model_response(),
                    )
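
# Mapping sketch: a `ModelResponse` holding a `TextPart` and one `ToolCallPart`
# becomes a single assistant message (all values invented for illustration):
#
#     {
#         'role': 'assistant',
#         'content': 'Checking the weather.',
#         'tool_calls': [
#             {'id': 'call_1', 'type': 'function',
#              'function': {'name': 'get_weather', 'arguments': '{"city": "Paris"}'}},
#         ],
#     }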


@dataclass
class GroqStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for Groq models."""

    _response: AsyncIterable[ChatCompletionChunk]
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._response:
            self._usage += _map_usage(chunk)

            try:
                choice = chunk.choices[0]
            except IndexError:
                continue

            # Handle the text part of the response
            content = choice.delta.content
            if content is not None:
                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)

            # Handle the tool calls
            for dtc in choice.delta.tool_calls or []:
                maybe_event = self._parts_manager.handle_tool_call_delta(
                    vendor_part_id=dtc.index,
                    tool_name=dtc.function and dtc.function.name,
                    args=dtc.function and dtc.function.arguments,
                    tool_call_id=dtc.id,
                )
                if maybe_event is not None:
                    yield maybe_event

    def timestamp(self) -> datetime:
        return self._timestamp


def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
    return chat.ChatCompletionMessageToolCallParam(
        id=_guard_tool_call_id(t=t, model_source='Groq'),
        type='function',
        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
    )


def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
    response_usage = None
    if isinstance(completion, ChatCompletion):
        response_usage = completion.usage
    elif completion.x_groq is not None:
        response_usage = completion.x_groq.usage

    if response_usage is None:
        return usage.Usage()

    return usage.Usage(
        request_tokens=response_usage.prompt_tokens,
        response_tokens=response_usage.completion_tokens,
        total_tokens=response_usage.total_tokens,
    )
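
# Usage-mapping sketch: a completion whose usage reports prompt_tokens=10,
# completion_tokens=20 and total_tokens=30 yields
# `usage.Usage(request_tokens=10, response_tokens=20, total_tokens=30)`;
# streamed chunks without usage data contribute an empty `usage.Usage()`.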