Coverage for pydantic_ai_slim/pydantic_ai/models/groq.py: 95.54%

148 statements  

coverage.py v7.6.7, created at 2025-01-25 16:43 +0000

from __future__ import annotations as _annotations

from collections.abc import AsyncIterable, AsyncIterator, Iterable
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from itertools import chain
from typing import Literal, cast, overload

from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never

from .. import UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
    ModelMessage,
    ModelRequest,
    ModelResponse,
    ModelResponsePart,
    ModelResponseStreamEvent,
    RetryPromptPart,
    SystemPromptPart,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import (
    AgentModel,
    Model,
    StreamedResponse,
    cached_async_http_client,
    check_allow_model_requests,
)

try:
    from groq import NOT_GIVEN, AsyncGroq, AsyncStream
    from groq.types import chat
    from groq.types.chat import ChatCompletion, ChatCompletionChunk
except ImportError as _import_error:
    raise ImportError(
        'Please install `groq` to use the Groq model, '
        "you can use the `groq` optional group — `pip install 'pydantic-ai-slim[groq]'`"
    ) from _import_error

GroqModelName = Literal[
    'llama-3.3-70b-versatile',
    'llama-3.1-70b-versatile',
    'llama3-groq-70b-8192-tool-use-preview',
    'llama3-groq-8b-8192-tool-use-preview',
    'llama-3.1-70b-specdec',
    'llama-3.1-8b-instant',
    'llama-3.2-1b-preview',
    'llama-3.2-3b-preview',
    'llama-3.2-11b-vision-preview',
    'llama-3.2-90b-vision-preview',
    'llama3-70b-8192',
    'llama3-8b-8192',
    'mixtral-8x7b-32768',
    'gemma2-9b-it',
    'gemma-7b-it',
]
"""Named Groq models.

See [the Groq docs](https://console.groq.com/docs/models) for a full list.
"""

class GroqModelSettings(ModelSettings):
    """Settings used for a Groq model request."""

    # This class is a placeholder for any future groq-specific settings
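# `ModelSettings` is consumed as a plain mapping below (see the `model_settings.get(...)`
# calls in `_completions_create`), so a per-request override is just keyword data. A
# minimal sketch; the specific values are illustrative:
#
#     settings = GroqModelSettings(temperature=0.0, max_tokens=512)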

@dataclass(init=False)
class GroqModel(Model):
    """A model that uses the Groq API.

    Internally, this uses the [Groq Python client](https://github.com/groq/groq-python) to interact with the API.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    model_name: GroqModelName
    client: AsyncGroq = field(repr=False)

    def __init__(
        self,
        model_name: GroqModelName,
        *,
        api_key: str | None = None,
        groq_client: AsyncGroq | None = None,
        http_client: AsyncHTTPClient | None = None,
    ):
        """Initialize a Groq model.

        Args:
            model_name: The name of the Groq model to use. List of model names available
                [here](https://console.groq.com/docs/models).
            api_key: The API key to use for authentication; if not provided, the `GROQ_API_KEY` environment
                variable will be used if available.
            groq_client: An existing
                [`AsyncGroq`](https://github.com/groq/groq-python?tab=readme-ov-file#async-usage)
                client to use; if provided, `api_key` and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
        """
        self.model_name = model_name
        if groq_client is not None:
            assert http_client is None, 'Cannot provide both `groq_client` and `http_client`'
            assert api_key is None, 'Cannot provide both `groq_client` and `api_key`'
            self.client = groq_client
        elif http_client is not None:  # coverage: branch never taken in tests
            self.client = AsyncGroq(api_key=api_key, http_client=http_client)
        else:
            self.client = AsyncGroq(api_key=api_key, http_client=cached_async_http_client())

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        check_allow_model_requests()
        tools = [self._map_tool_definition(r) for r in function_tools]
        if result_tools:
            tools += [self._map_tool_definition(r) for r in result_tools]
        return GroqAgentModel(
            self.client,
            self.model_name,
            allow_text_result,
            tools,
        )

    def name(self) -> str:
        return f'groq:{self.model_name}'

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
        return {
            'type': 'function',
            'function': {
                'name': f.name,
                'description': f.description,
                'parameters': f.parameters_json_schema,
            },
        }
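# A minimal construction sketch for the branches in `__init__` above; the key and
# client values are placeholders:
#
#     GroqModel('llama-3.3-70b-versatile')                     # key from GROQ_API_KEY
#     GroqModel('llama-3.3-70b-versatile', api_key='gsk_...')  # explicit key
#     GroqModel('llama-3.3-70b-versatile', groq_client=AsyncGroq(api_key='gsk_...'))
#
# Combining `groq_client` with `api_key` or `http_client` trips the asserts above.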

@dataclass
class GroqAgentModel(AgentModel):
    """Implementation of `AgentModel` for Groq models."""

    client: AsyncGroq
    model_name: str
    allow_text_result: bool
    tools: list[chat.ChatCompletionToolParam]

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, usage.Usage]:
        response = await self._completions_create(messages, False, cast(GroqModelSettings, model_settings or {}))
        return self._process_response(response), _map_usage(response)

    @asynccontextmanager
    async def request_stream(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> AsyncIterator[StreamedResponse]:
        response = await self._completions_create(messages, True, cast(GroqModelSettings, model_settings or {}))
        async with response:
            yield await self._process_streamed_response(response)

    @overload
    async def _completions_create(
        self, messages: list[ModelMessage], stream: Literal[True], model_settings: GroqModelSettings
    ) -> AsyncStream[ChatCompletionChunk]:
        pass

    @overload
    async def _completions_create(
        self, messages: list[ModelMessage], stream: Literal[False], model_settings: GroqModelSettings
    ) -> chat.ChatCompletion:
        pass

    async def _completions_create(
        self, messages: list[ModelMessage], stream: bool, model_settings: GroqModelSettings
    ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
        # standalone method to make it easier to override
        if not self.tools:
            tool_choice: Literal['none', 'required', 'auto'] | None = None
        elif not self.allow_text_result:
            tool_choice = 'required'
        else:
            tool_choice = 'auto'

        groq_messages = list(chain(*(self._map_message(m) for m in messages)))

        return await self.client.chat.completions.create(
            model=str(self.model_name),
            messages=groq_messages,
            n=1,
            parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
            tools=self.tools or NOT_GIVEN,
            tool_choice=tool_choice or NOT_GIVEN,
            stream=stream,
            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
            temperature=model_settings.get('temperature', NOT_GIVEN),
            top_p=model_settings.get('top_p', NOT_GIVEN),
            timeout=model_settings.get('timeout', NOT_GIVEN),
            seed=model_settings.get('seed', NOT_GIVEN),
            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
        )
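    # A worked example of the mapping above, derived directly from the code: with one
    # tool registered and `allow_text_result=False`, `tool_choice` is sent as 'required'
    # and `parallel_tool_calls` defaults to True; with no tools, `tool_choice` stays
    # None and both are sent as NOT_GIVEN, deferring to the API's defaults. Any setting
    # absent from `model_settings` (e.g. `temperature`) is likewise sent as NOT_GIVEN
    # rather than a guessed value.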

    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
        choice = response.choices[0]
        items: list[ModelResponsePart] = []
        if choice.message.content is not None:
            items.append(TextPart(content=choice.message.content))
        if choice.message.tool_calls is not None:
            for c in choice.message.tool_calls:
                items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)

    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        peekable_response = _utils.PeekableAsyncStream(response)
        first_chunk = await peekable_response.peek()
        if isinstance(first_chunk, _utils.Unset):  # coverage: branch never taken in tests
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        return GroqStreamedResponse(
            _response=peekable_response,
            _model_name=self.model_name,
            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
        )

    @classmethod
    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
        """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
        if isinstance(message, ModelRequest):
            yield from cls._map_user_message(message)
        elif isinstance(message, ModelResponse):
            texts: list[str] = []
            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
            for item in message.parts:
                if isinstance(item, TextPart):
                    texts.append(item.content)
                elif isinstance(item, ToolCallPart):
                    tool_calls.append(_map_tool_call(item))
                else:
                    assert_never(item)
            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
            if texts:
                # Note: model responses from this model should only have one text item, so the following
                # shouldn't merge multiple texts into one unless you switch models between runs:
                message_param['content'] = '\n\n'.join(texts)
            if tool_calls:
                message_param['tool_calls'] = tool_calls
            yield message_param
        else:
            assert_never(message)

    @classmethod
    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
            elif isinstance(part, ToolReturnPart):
                yield chat.ChatCompletionToolMessageParam(
                    role='tool',
                    tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):  # coverage: condition always true in tests
                if part.tool_name is None:  # coverage: branch never taken in tests
                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
                else:
                    yield chat.ChatCompletionToolMessageParam(
                        role='tool',
                        tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
                        content=part.model_response(),
                    )
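# A message-mapping sketch, assuming the usual `pydantic_ai` message constructors; the
# chat params are TypedDicts, so the result is plain dicts:
#
#     request = ModelRequest(parts=[
#         SystemPromptPart(content='be terse'),
#         UserPromptPart(content='hi'),
#     ])
#     list(GroqAgentModel._map_message(request))
#     # -> [{'role': 'system', 'content': 'be terse'},
#     #     {'role': 'user', 'content': 'hi'}]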

@dataclass
class GroqStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for Groq models."""

    _response: AsyncIterable[ChatCompletionChunk]
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._response:
            self._usage += _map_usage(chunk)

            try:
                choice = chunk.choices[0]
            except IndexError:
                continue

            # Handle the text part of the response
            content = choice.delta.content
            if content is not None:
                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)

            # Handle the tool calls
            for dtc in choice.delta.tool_calls or []:
                maybe_event = self._parts_manager.handle_tool_call_delta(
                    vendor_part_id=dtc.index,
                    tool_name=dtc.function and dtc.function.name,
                    args=dtc.function and dtc.function.arguments,
                    tool_call_id=dtc.id,
                )
                if maybe_event is not None:
                    yield maybe_event

    def timestamp(self) -> datetime:
        return self._timestamp
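# Streaming behaviour in brief: every text delta reuses `vendor_part_id='content'`, so
# the parts manager accumulates them into a single `TextPart`, while tool-call deltas
# are keyed by their chunk `index` so interleaved calls stay separate. Assuming the
# `StreamedResponse` base class exposes `_get_event_iterator` via async iteration, a
# consumer sketch looks like:
#
#     async with agent_model.request_stream(messages, None) as response:
#         async for event in response:  # ModelResponseStreamEvent
#             ...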

def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
    return chat.ChatCompletionMessageToolCallParam(
        id=_guard_tool_call_id(t=t, model_source='Groq'),
        type='function',
        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
    )


def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
    response_usage = None
    if isinstance(completion, ChatCompletion):
        response_usage = completion.usage
    elif completion.x_groq is not None:  # coverage: branch never taken in tests
        response_usage = completion.x_groq.usage

    if response_usage is None:
        return usage.Usage()

    return usage.Usage(
        request_tokens=response_usage.prompt_tokens,
        response_tokens=response_usage.completion_tokens,
        total_tokens=response_usage.total_tokens,
    )
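# A worked example of the mapping above; the token counts are illustrative: a
# `ChatCompletion` whose `usage` reports `prompt_tokens=10, completion_tokens=5,
# total_tokens=15` becomes `usage.Usage(request_tokens=10, response_tokens=5,
# total_tokens=15)`, while a streamed chunk carrying no `x_groq` usage data becomes an
# empty `usage.Usage()`.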