Coverage for pydantic_ai_slim/pydantic_ai/models/groq.py: 95.54% (148 statements)

from __future__ import annotations as _annotations

from collections.abc import AsyncIterable, AsyncIterator, Iterable
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from itertools import chain
from typing import Literal, cast, overload

from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never

from .. import UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
    ModelMessage,
    ModelRequest,
    ModelResponse,
    ModelResponsePart,
    ModelResponseStreamEvent,
    RetryPromptPart,
    SystemPromptPart,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import (
    AgentModel,
    Model,
    StreamedResponse,
    cached_async_http_client,
    check_allow_model_requests,
)

try:
    from groq import NOT_GIVEN, AsyncGroq, AsyncStream
    from groq.types import chat
    from groq.types.chat import ChatCompletion, ChatCompletionChunk
except ImportError as _import_error:
    raise ImportError(
        'Please install `groq` to use the Groq model, '
        "you can use the `groq` optional group — `pip install 'pydantic-ai-slim[groq]'`"
    ) from _import_error

GroqModelName = Literal[
    'llama-3.3-70b-versatile',
    'llama-3.1-70b-versatile',
    'llama3-groq-70b-8192-tool-use-preview',
    'llama3-groq-8b-8192-tool-use-preview',
    'llama-3.1-70b-specdec',
    'llama-3.1-8b-instant',
    'llama-3.2-1b-preview',
    'llama-3.2-3b-preview',
    'llama-3.2-11b-vision-preview',
    'llama-3.2-90b-vision-preview',
    'llama3-70b-8192',
    'llama3-8b-8192',
    'mixtral-8x7b-32768',
    'gemma2-9b-it',
    'gemma-7b-it',
]
"""Named Groq models.

See [the Groq docs](https://console.groq.com/docs/models) for a full list.
"""


class GroqModelSettings(ModelSettings):
    """Settings used for a Groq model request."""

    # This class is a placeholder for any future Groq-specific settings


@dataclass(init=False)
class GroqModel(Model):
    """A model that uses the Groq API.

    Internally, this uses the [Groq Python client](https://github.com/groq/groq-python) to interact with the API.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    model_name: GroqModelName
    client: AsyncGroq = field(repr=False)

    def __init__(
        self,
        model_name: GroqModelName,
        *,
        api_key: str | None = None,
        groq_client: AsyncGroq | None = None,
        http_client: AsyncHTTPClient | None = None,
    ):
97 """Initialize a Groq model.
99 Args:
100 model_name: The name of the Groq model to use. List of model names available
101 [here](https://console.groq.com/docs/models).
102 api_key: The API key to use for authentication, if not provided, the `GROQ_API_KEY` environment variable
103 will be used if available.
104 groq_client: An existing
105 [`AsyncGroq`](https://github.com/groq/groq-python?tab=readme-ov-file#async-usage)
106 client to use, if provided, `api_key` and `http_client` must be `None`.
107 http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
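
        Example:
            A minimal construction sketch (the model name is illustrative; assumes the
            `GROQ_API_KEY` environment variable is set):

                model = GroqModel('llama-3.3-70b-versatile')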
108 """
109 self.model_name = model_name
110 if groq_client is not None:
111 assert http_client is None, 'Cannot provide both `groq_client` and `http_client`'
112 assert api_key is None, 'Cannot provide both `groq_client` and `api_key`'
113 self.client = groq_client
114 elif http_client is not None: 114 ↛ 115line 114 didn't jump to line 115 because the condition on line 114 was never true
115 self.client = AsyncGroq(api_key=api_key, http_client=http_client)
116 else:
117 self.client = AsyncGroq(api_key=api_key, http_client=cached_async_http_client())

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        check_allow_model_requests()
        tools = [self._map_tool_definition(r) for r in function_tools]
        if result_tools:
            tools += [self._map_tool_definition(r) for r in result_tools]
        return GroqAgentModel(
            self.client,
            self.model_name,
            allow_text_result,
            tools,
        )

    def name(self) -> str:
        return f'groq:{self.model_name}'

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
        return {
            'type': 'function',
            'function': {
                'name': f.name,
                'description': f.description,
                'parameters': f.parameters_json_schema,
            },
        }
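
    # For illustration only (the tool is hypothetical, not part of this module):
    # a `ToolDefinition` named 'get_weather' taking a single `city` string would
    # map to the following OpenAI-style tool param:
    #
    #     {
    #         'type': 'function',
    #         'function': {
    #             'name': 'get_weather',
    #             'description': 'Get the current weather for a city.',
    #             'parameters': {
    #                 'type': 'object',
    #                 'properties': {'city': {'type': 'string'}},
    #             },
    #         },
    #     }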


@dataclass
class GroqAgentModel(AgentModel):
    """Implementation of `AgentModel` for Groq models."""

    client: AsyncGroq
    model_name: str
    allow_text_result: bool
    tools: list[chat.ChatCompletionToolParam]

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, usage.Usage]:
        response = await self._completions_create(messages, False, cast(GroqModelSettings, model_settings or {}))
        return self._process_response(response), _map_usage(response)

    @asynccontextmanager
    async def request_stream(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> AsyncIterator[StreamedResponse]:
        response = await self._completions_create(messages, True, cast(GroqModelSettings, model_settings or {}))
        async with response:
            yield await self._process_streamed_response(response)

    @overload
    async def _completions_create(
        self, messages: list[ModelMessage], stream: Literal[True], model_settings: GroqModelSettings
    ) -> AsyncStream[ChatCompletionChunk]:
        pass

    @overload
    async def _completions_create(
        self, messages: list[ModelMessage], stream: Literal[False], model_settings: GroqModelSettings
    ) -> chat.ChatCompletion:
        pass

    async def _completions_create(
        self, messages: list[ModelMessage], stream: bool, model_settings: GroqModelSettings
    ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
        # standalone function to make it easier to override
        if not self.tools:
            tool_choice: Literal['none', 'required', 'auto'] | None = None
        elif not self.allow_text_result:
            tool_choice = 'required'
        else:
            tool_choice = 'auto'

        groq_messages = list(chain(*(self._map_message(m) for m in messages)))
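
        # `NOT_GIVEN` is the Groq SDK's sentinel for omitting a field from the
        # request entirely, so settings the caller did not supply are simply not sent.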
        return await self.client.chat.completions.create(
            model=str(self.model_name),
            messages=groq_messages,
            n=1,
            parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
            tools=self.tools or NOT_GIVEN,
            tool_choice=tool_choice or NOT_GIVEN,
            stream=stream,
            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
            temperature=model_settings.get('temperature', NOT_GIVEN),
            top_p=model_settings.get('top_p', NOT_GIVEN),
            timeout=model_settings.get('timeout', NOT_GIVEN),
            seed=model_settings.get('seed', NOT_GIVEN),
            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
        )

    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
        choice = response.choices[0]
        items: list[ModelResponsePart] = []
        if choice.message.content is not None:
            items.append(TextPart(content=choice.message.content))
        if choice.message.tool_calls is not None:
            for c in choice.message.tool_calls:
                items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)
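
    # For illustration only (part shapes assumed): a completion whose message
    # carries both text and one tool call would produce parts in order, e.g.
    # `[TextPart(content='...'), ToolCallPart(tool_name='...', args='{"x": 1}', tool_call_id='call_1')]`.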

    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        # Peek at the first chunk without consuming it, so its `created` timestamp
        # can be used for the whole streamed response.
        peekable_response = _utils.PeekableAsyncStream(response)
        first_chunk = await peekable_response.peek()
        if isinstance(first_chunk, _utils.Unset):  # coverage: condition never true in tests
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        return GroqStreamedResponse(
            _response=peekable_response,
            _model_name=self.model_name,
            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
        )

    @classmethod
    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
        """Map a `pydantic_ai.ModelMessage` to one or more `groq.types.ChatCompletionMessageParam`s."""
        if isinstance(message, ModelRequest):
            yield from cls._map_user_message(message)
        elif isinstance(message, ModelResponse):
            texts: list[str] = []
            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
            for item in message.parts:
                if isinstance(item, TextPart):
                    texts.append(item.content)
                elif isinstance(item, ToolCallPart):
                    tool_calls.append(_map_tool_call(item))
                else:
                    assert_never(item)
            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
            if texts:
                # Note: model responses from this model should only have one text item, so the following
                # shouldn't merge multiple texts into one unless you switch models between runs:
                message_param['content'] = '\n\n'.join(texts)
            if tool_calls:
                message_param['tool_calls'] = tool_calls
            yield message_param
        else:
            assert_never(message)

    @classmethod
    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
            elif isinstance(part, ToolReturnPart):
                yield chat.ChatCompletionToolMessageParam(
                    role='tool',
                    tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):  # coverage: condition always true in tests
                if part.tool_name is None:  # coverage: condition never true in tests
                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
                else:
                    yield chat.ChatCompletionToolMessageParam(
                        role='tool',
                        tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
                        content=part.model_response(),
                    )


@dataclass
class GroqStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for Groq models."""

    _response: AsyncIterable[ChatCompletionChunk]
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._response:
            self._usage += _map_usage(chunk)

            try:
                choice = chunk.choices[0]
            except IndexError:
                continue

            # Handle the text part of the response
            content = choice.delta.content
            if content is not None:
                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)

            # Handle the tool calls
            for dtc in choice.delta.tool_calls or []:
                maybe_event = self._parts_manager.handle_tool_call_delta(
                    vendor_part_id=dtc.index,
                    tool_name=dtc.function and dtc.function.name,
                    args=dtc.function and dtc.function.arguments,
                    tool_call_id=dtc.id,
                )
                if maybe_event is not None:
                    yield maybe_event

    def timestamp(self) -> datetime:
        return self._timestamp


def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
    return chat.ChatCompletionMessageToolCallParam(
        id=_guard_tool_call_id(t=t, model_source='Groq'),
        type='function',
        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
    )


def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
    response_usage = None
    if isinstance(completion, ChatCompletion):
        response_usage = completion.usage
    elif completion.x_groq is not None:  # coverage: condition never true in tests
        response_usage = completion.x_groq.usage

    if response_usage is None:
        return usage.Usage()

    return usage.Usage(
        request_tokens=response_usage.prompt_tokens,
        response_tokens=response_usage.completion_tokens,
        total_tokens=response_usage.total_tokens,
    )
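

# A minimal end-to-end usage sketch (illustrative, not part of this module;
# assumes a valid `GROQ_API_KEY` and the `pydantic_ai.Agent` API):
#
#     from pydantic_ai import Agent
#
#     agent = Agent(GroqModel('llama-3.3-70b-versatile'))
#     result = agent.run_sync('What is the capital of France?')
#     print(result.data)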