Coverage for pydantic_ai_slim/pydantic_ai/models/groq.py: 96.68%

177 statements  

coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

from __future__ import annotations as _annotations

import base64
from collections.abc import AsyncIterable, AsyncIterator, Iterable
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from itertools import chain
from typing import Literal, Union, cast, overload

from typing_extensions import assert_never

from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
    BinaryContent,
    DocumentUrl,
    ImageUrl,
    ModelMessage,
    ModelRequest,
    ModelResponse,
    ModelResponsePart,
    ModelResponseStreamEvent,
    RetryPromptPart,
    SystemPromptPart,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)
from ..providers import Provider, infer_provider
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests

try:
    from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
    from groq.types import chat
    from groq.types.chat.chat_completion_content_part_image_param import ImageURL
except ImportError as _import_error:
    raise ImportError(
        'Please install `groq` to use the Groq model, '
        'you can use the `groq` optional group — `pip install "pydantic-ai-slim[groq]"`'
    ) from _import_error


LatestGroqModelNames = Literal[
    'llama-3.3-70b-versatile',
    'llama-3.3-70b-specdec',
    'llama-3.1-8b-instant',
    'llama-3.2-1b-preview',
    'llama-3.2-3b-preview',
    'llama-3.2-11b-vision-preview',
    'llama-3.2-90b-vision-preview',
    'llama3-70b-8192',
    'llama3-8b-8192',
    'mixtral-8x7b-32768',
    'gemma2-9b-it',
]
"""Latest Groq models."""

GroqModelName = Union[str, LatestGroqModelNames]
"""
Possible Groq model names.

Since Groq supports a variety of date-stamped models, we explicitly list the latest models but
allow any name in the type hints.
See [the Groq docs](https://console.groq.com/docs/models) for a full list.
"""


class GroqModelSettings(ModelSettings):
    """Settings used for a Groq model request.

    ALL FIELDS MUST BE `groq_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
    """

    # This class is a placeholder for any future groq-specific settings
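    # For example (hypothetical, not a real field): a Groq-only option would be added
    # here as e.g. `groq_some_option: str`; the `groq_` prefix keeps it from colliding
    # with settings merged in from other models.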


@dataclass(init=False)
class GroqModel(Model):
    """A model that uses the Groq API.

    Internally, this uses the [Groq Python client](https://github.com/groq/groq-python) to interact with the API.

    Apart from `__init__`, all methods are private or match those of the base class.
    """
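
    # A minimal usage sketch (assumes a valid `GROQ_API_KEY` in the environment):
    #
    #     from pydantic_ai import Agent
    #     agent = Agent(GroqModel('llama-3.3-70b-versatile'))
    #     result = agent.run_sync('What is the capital of France?')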

    client: AsyncGroq = field(repr=False)

    _model_name: GroqModelName = field(repr=False)
    _system: str = field(default='groq', repr=False)

    def __init__(self, model_name: GroqModelName, *, provider: Literal['groq'] | Provider[AsyncGroq] = 'groq'):
        """Initialize a Groq model.

        Args:
            model_name: The name of the Groq model to use. List of model names available
                [here](https://console.groq.com/docs/models).
            provider: The provider to use for authentication and API access. Can be either the string
                'groq' or an instance of `Provider[AsyncGroq]`. If a string is given (the default),
                a provider is inferred from it and configured from the environment.
        """
        self._model_name = model_name
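
        # `provider` may be the string 'groq' or a ready-made `Provider[AsyncGroq]`;
        # `infer_provider` resolves the string to a provider configured from the
        # environment (e.g. a `GROQ_API_KEY` API key).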

        if isinstance(provider, str):
            provider = infer_provider(provider)
        self.client = provider.client

    @property
    def base_url(self) -> str:
        return str(self.client.base_url)

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, usage.Usage]:
        check_allow_model_requests()
        response = await self._completions_create(
            messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
        )
        return self._process_response(response), _map_usage(response)
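
    # Note: `request_stream` below is an async context manager; the underlying HTTP
    # stream is only open inside the `async with response:` block, so the yielded
    # `StreamedResponse` must be consumed before the context exits.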

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        check_allow_model_requests()
        response = await self._completions_create(
            messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters
        )
        async with response:
            yield await self._process_streamed_response(response)

    @property
    def model_name(self) -> GroqModelName:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The system / model provider."""
        return self._system

    @overload
    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[True],
        model_settings: GroqModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncStream[chat.ChatCompletionChunk]:
        pass

    @overload
    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[False],
        model_settings: GroqModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> chat.ChatCompletion:
        pass

    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: bool,
        model_settings: GroqModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> chat.ChatCompletion | AsyncStream[chat.ChatCompletionChunk]:
        tools = self._get_tools(model_request_parameters)
        # standalone function to make it easier to override
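        # Tool-choice mapping: with no tools the parameter is omitted; when a plain-text
        # result is not allowed the model must call a tool ('required'); otherwise the
        # model may choose freely ('auto').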

        if not tools:
            tool_choice: Literal['none', 'required', 'auto'] | None = None
        elif not model_request_parameters.allow_text_result:
            tool_choice = 'required'
        else:
            tool_choice = 'auto'

        groq_messages = list(chain(*(self._map_message(m) for m in messages)))

        try:
            return await self.client.chat.completions.create(
                model=str(self._model_name),
                messages=groq_messages,
                n=1,
                parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
                tools=tools or NOT_GIVEN,
                tool_choice=tool_choice or NOT_GIVEN,
                stream=stream,
                max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                temperature=model_settings.get('temperature', NOT_GIVEN),
                top_p=model_settings.get('top_p', NOT_GIVEN),
                timeout=model_settings.get('timeout', NOT_GIVEN),
                seed=model_settings.get('seed', NOT_GIVEN),
                presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
                frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
                logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
            )
        except APIStatusError as e:
            if (status_code := e.status_code) >= 400:  # coverage: always true; the bare `raise` below was never reached
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
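            # Defensive: `APIStatusError` should always carry a >= 400 status, so this
            # fall-through re-raise is not expected to be reachable (see coverage note above).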

            raise

    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
        choice = response.choices[0]
        items: list[ModelResponsePart] = []
        if choice.message.content is not None:
            items.append(TextPart(content=choice.message.content))
        if choice.message.tool_calls is not None:
            for c in choice.message.tool_calls:
                items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
        return ModelResponse(items, model_name=response.model, timestamp=timestamp)

    async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
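        # Peek at the first chunk without consuming the stream: it supplies the response
        # timestamp (`created`) and lets us fail fast on an empty stream.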

        peekable_response = _utils.PeekableAsyncStream(response)
        first_chunk = await peekable_response.peek()
        if isinstance(first_chunk, _utils.Unset):  # coverage: never true in tests
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        return GroqStreamedResponse(
            _response=peekable_response,
            _model_name=self._model_name,
            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
        )

    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
        if model_request_parameters.result_tools:
            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
        return tools

    def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
        """Map a `pydantic_ai.Message` to one or more `groq.types.ChatCompletionMessageParam`s."""
        if isinstance(message, ModelRequest):
            yield from self._map_user_message(message)
        elif isinstance(message, ModelResponse):
            texts: list[str] = []
            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
            for item in message.parts:
                if isinstance(item, TextPart):
                    texts.append(item.content)
                elif isinstance(item, ToolCallPart):
                    tool_calls.append(self._map_tool_call(item))
                else:
                    assert_never(item)
            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
            if texts:
                # Note: model responses from this model should only have one text item, so the following
                # shouldn't merge multiple texts into one unless you switch models between runs:
                message_param['content'] = '\n\n'.join(texts)
            if tool_calls:
                message_param['tool_calls'] = tool_calls
            yield message_param
        else:
            assert_never(message)

    @staticmethod
    def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
        return chat.ChatCompletionMessageToolCallParam(
            id=_guard_tool_call_id(t=t),
            type='function',
            function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
        )

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
        return {
            'type': 'function',
            'function': {
                'name': f.name,
                'description': f.description,
                'parameters': f.parameters_json_schema,
            },
        }

    @classmethod
    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                yield cls._map_user_prompt(part)
            elif isinstance(part, ToolReturnPart):
                yield chat.ChatCompletionToolMessageParam(
                    role='tool',
                    tool_call_id=_guard_tool_call_id(t=part),
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):  # coverage: always true in tests
                if part.tool_name is None:  # coverage: never true in tests
                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
                else:
                    yield chat.ChatCompletionToolMessageParam(
                        role='tool',
                        tool_call_id=_guard_tool_call_id(t=part),
                        content=part.model_response(),
                    )

    @staticmethod
    def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
        content: str | list[chat.ChatCompletionContentPartParam]
        if isinstance(part.content, str):
            content = part.content
        else:
            content = []
            for item in part.content:
                if isinstance(item, str):
                    content.append(chat.ChatCompletionContentPartTextParam(text=item, type='text'))
                elif isinstance(item, ImageUrl):
                    image_url = ImageURL(url=item.url)
                    content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
                elif isinstance(item, BinaryContent):
                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
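                    # Inline images are sent as data URLs,
                    # e.g. 'data:image/png;base64,iVBORw0...'.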

                    if item.is_image:
                        image_url = ImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
                        content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
                    else:
                        raise RuntimeError('Only images are supported for binary content in Groq.')
                elif isinstance(item, DocumentUrl):  # pragma: no cover
                    raise RuntimeError('DocumentUrl is not supported in Groq.')
                else:  # pragma: no cover
                    raise RuntimeError(f'Unsupported content type: {type(item)}')

        return chat.ChatCompletionUserMessageParam(role='user', content=content)


@dataclass
class GroqStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for Groq models."""

    _model_name: GroqModelName
    _response: AsyncIterable[chat.ChatCompletionChunk]
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._response:
            self._usage += _map_usage(chunk)

            try:
                choice = chunk.choices[0]
            except IndexError:
                continue

            # Handle the text part of the response
            content = choice.delta.content
            if content is not None:
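                # All text deltas share one vendor_part_id ('content'), so the parts
                # manager accumulates them into a single growing text part.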

                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)

            # Handle the tool calls
            for dtc in choice.delta.tool_calls or []:
                maybe_event = self._parts_manager.handle_tool_call_delta(
                    vendor_part_id=dtc.index,
                    tool_name=dtc.function and dtc.function.name,
                    args=dtc.function and dtc.function.arguments,
                    tool_call_id=dtc.id,
                )
                if maybe_event is not None:
                    yield maybe_event

    @property
    def model_name(self) -> GroqModelName:
        """Get the model name of the response."""
        return self._model_name

    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._timestamp
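

# Usage accounting: a full `ChatCompletion` reports usage on `.usage`, while streamed
# chunks carry it (when present) on the Groq-specific `.x_groq` envelope.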

def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.Usage:
    response_usage = None
    if isinstance(completion, chat.ChatCompletion):
        response_usage = completion.usage
    elif completion.x_groq is not None:  # coverage: never true in tests
        response_usage = completion.x_groq.usage

    if response_usage is None:
        return usage.Usage()

    return usage.Usage(
        request_tokens=response_usage.prompt_tokens,
        response_tokens=response_usage.completion_tokens,
        total_tokens=response_usage.total_tokens,
    )