Coverage for pydantic_ai_slim/pydantic_ai/models/groq.py: 95.54%

148 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-01-30 19:21 +0000

1from __future__ import annotations as _annotations 

2 

3from collections.abc import AsyncIterable, AsyncIterator, Iterable 

4from contextlib import asynccontextmanager 

5from dataclasses import dataclass, field 

6from datetime import datetime, timezone 

7from itertools import chain 

8from typing import Literal, cast, overload 

9 

10from httpx import AsyncClient as AsyncHTTPClient 

11from typing_extensions import assert_never 

12 

13from .. import UnexpectedModelBehavior, _utils, usage 

14from .._utils import guard_tool_call_id as _guard_tool_call_id 

15from ..messages import ( 

16 ModelMessage, 

17 ModelRequest, 

18 ModelResponse, 

19 ModelResponsePart, 

20 ModelResponseStreamEvent, 

21 RetryPromptPart, 

22 SystemPromptPart, 

23 TextPart, 

24 ToolCallPart, 

25 ToolReturnPart, 

26 UserPromptPart, 

27) 

28from ..settings import ModelSettings 

29from ..tools import ToolDefinition 

30from . import ( 

31 AgentModel, 

32 Model, 

33 StreamedResponse, 

34 cached_async_http_client, 

35 check_allow_model_requests, 

36) 

37 

38try: 

39 from groq import NOT_GIVEN, AsyncGroq, AsyncStream 

40 from groq.types import chat 

41 from groq.types.chat import ChatCompletion, ChatCompletionChunk 

42except ImportError as _import_error: 

43 raise ImportError( 

44 'Please install `groq` to use the Groq model, ' 

45 "you can use the `groq` optional group — `pip install 'pydantic-ai-slim[groq]'`" 

46 ) from _import_error 

47 

# Known Groq model identifiers at the time of writing; the hosted model list
# may change over time — see the docs link below for the current set.
GroqModelName = Literal[
    'llama-3.3-70b-versatile',
    'llama-3.3-70b-specdec',
    'llama-3.1-8b-instant',
    'llama-3.2-1b-preview',
    'llama-3.2-3b-preview',
    'llama-3.2-11b-vision-preview',
    'llama-3.2-90b-vision-preview',
    'llama3-70b-8192',
    'llama3-8b-8192',
    'mixtral-8x7b-32768',
    'gemma2-9b-it',
]
"""Named Groq models.

See [the Groq docs](https://console.groq.com/docs/models) for a full list.
"""

65 

66 

class GroqModelSettings(ModelSettings):
    """Settings used for a Groq model request.

    Currently identical to the base `ModelSettings`; exists so that
    Groq-specific options can be added later without changing call sites.
    """

    # This class is a placeholder for any future groq-specific settings

71 

72 

@dataclass(init=False)
class GroqModel(Model):
    """A model that uses the Groq API.

    Internally, this uses the [Groq Python client](https://github.com/groq/groq-python) to interact with the API.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    model_name: GroqModelName
    client: AsyncGroq = field(repr=False)

    def __init__(
        self,
        model_name: GroqModelName,
        *,
        api_key: str | None = None,
        groq_client: AsyncGroq | None = None,
        http_client: AsyncHTTPClient | None = None,
    ):
        """Initialize a Groq model.

        Args:
            model_name: The name of the Groq model to use. List of model names available
                [here](https://console.groq.com/docs/models).
            api_key: The API key to use for authentication, if not provided, the `GROQ_API_KEY` environment variable
                will be used if available.
            groq_client: An existing
                [`AsyncGroq`](https://github.com/groq/groq-python?tab=readme-ov-file#async-usage)
                client to use, if provided, `api_key` and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
        """
        self.model_name = model_name
        if groq_client is not None:
            # A pre-built client is mutually exclusive with the other connection options.
            assert http_client is None, 'Cannot provide both `groq_client` and `http_client`'
            assert api_key is None, 'Cannot provide both `groq_client` and `api_key`'
            self.client = groq_client
            return
        # No pre-built client: construct one, reusing the shared cached HTTP
        # client unless the caller supplied their own.
        if http_client is None:
            http_client = cached_async_http_client()
        self.client = AsyncGroq(api_key=api_key, http_client=http_client)

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        """Build a `GroqAgentModel` with all tool definitions mapped to Groq's schema."""
        check_allow_model_requests()
        # Function tools and result tools are sent to Groq as one combined list,
        # function tools first.
        all_tools = [self._map_tool_definition(tool) for tool in chain(function_tools, result_tools)]
        return GroqAgentModel(
            self.client,
            self.model_name,
            allow_text_result,
            all_tools,
        )

    def name(self) -> str:
        """Return the qualified model name, e.g. `groq:llama3-70b-8192`."""
        return f'groq:{self.model_name}'

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
        """Translate a `ToolDefinition` into Groq's function-tool parameter shape."""
        function_spec = {
            'name': f.name,
            'description': f.description,
            'parameters': f.parameters_json_schema,
        }
        return {'type': 'function', 'function': function_spec}

146 

147 

@dataclass
class GroqAgentModel(AgentModel):
    """Implementation of `AgentModel` for Groq models."""

    # Async Groq client shared with the parent `GroqModel`.
    client: AsyncGroq
    # Model identifier passed to the Groq chat completions API.
    model_name: str
    # If False, the model is forced to reply with a tool call (`tool_choice='required'`).
    allow_text_result: bool
    # Tool definitions (function tools plus result tools) in Groq's parameter format.
    tools: list[chat.ChatCompletionToolParam]

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, usage.Usage]:
        """Make a non-streaming request and return the response plus its token usage."""
        response = await self._completions_create(messages, False, cast(GroqModelSettings, model_settings or {}))
        return self._process_response(response), _map_usage(response)

    @asynccontextmanager
    async def request_stream(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming request, yielding a `StreamedResponse`.

        The underlying HTTP stream is closed when the async context exits.
        """
        response = await self._completions_create(messages, True, cast(GroqModelSettings, model_settings or {}))
        async with response:
            yield await self._process_streamed_response(response)

    # The overloads below give precise return types for stream=True vs stream=False.
    @overload
    async def _completions_create(
        self, messages: list[ModelMessage], stream: Literal[True], model_settings: GroqModelSettings
    ) -> AsyncStream[ChatCompletionChunk]:
        pass

    @overload
    async def _completions_create(
        self, messages: list[ModelMessage], stream: Literal[False], model_settings: GroqModelSettings
    ) -> chat.ChatCompletion:
        pass

    async def _completions_create(
        self, messages: list[ModelMessage], stream: bool, model_settings: GroqModelSettings
    ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
        """Call Groq's chat completions endpoint with the mapped messages and settings."""
        # standalone function to make it easier to override
        # Decide how the model may use tools: no tools registered -> omit the
        # parameter entirely; text results disallowed -> a tool call is mandatory.
        if not self.tools:
            tool_choice: Literal['none', 'required', 'auto'] | None = None
        elif not self.allow_text_result:
            tool_choice = 'required'
        else:
            tool_choice = 'auto'

        # Each pydantic-ai message can map to multiple Groq messages; flatten them.
        groq_messages = list(chain(*(self._map_message(m) for m in messages)))

        return await self.client.chat.completions.create(
            model=str(self.model_name),
            messages=groq_messages,
            n=1,
            parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
            tools=self.tools or NOT_GIVEN,
            tool_choice=tool_choice or NOT_GIVEN,
            stream=stream,
            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
            temperature=model_settings.get('temperature', NOT_GIVEN),
            top_p=model_settings.get('top_p', NOT_GIVEN),
            timeout=model_settings.get('timeout', NOT_GIVEN),
            seed=model_settings.get('seed', NOT_GIVEN),
            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
        )

    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
        # Only the first choice is used; requests are made with n=1.
        choice = response.choices[0]
        items: list[ModelResponsePart] = []
        if choice.message.content is not None:
            items.append(TextPart(content=choice.message.content))
        if choice.message.tool_calls is not None:
            for c in choice.message.tool_calls:
                items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)

    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        # Peek at the first chunk so we can fail fast on an empty stream and use
        # its `created` time as the response timestamp, without consuming it.
        peekable_response = _utils.PeekableAsyncStream(response)
        first_chunk = await peekable_response.peek()
        if isinstance(first_chunk, _utils.Unset):
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        return GroqStreamedResponse(
            _response=peekable_response,
            _model_name=self.model_name,
            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
        )

    @classmethod
    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
        """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
        if isinstance(message, ModelRequest):
            yield from cls._map_user_message(message)
        elif isinstance(message, ModelResponse):
            # A model response becomes a single assistant message, with all text
            # parts merged and all tool calls attached.
            texts: list[str] = []
            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
            for item in message.parts:
                if isinstance(item, TextPart):
                    texts.append(item.content)
                elif isinstance(item, ToolCallPart):
                    tool_calls.append(_map_tool_call(item))
                else:
                    assert_never(item)
            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
            if texts:
                # Note: model responses from this model should only have one text item, so the following
                # shouldn't merge multiple texts into one unless you switch models between runs:
                message_param['content'] = '\n\n'.join(texts)
            if tool_calls:
                message_param['tool_calls'] = tool_calls
            yield message_param
        else:
            assert_never(message)

    @classmethod
    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
        """Map each part of a `ModelRequest` to the corresponding Groq message param."""
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
            elif isinstance(part, ToolReturnPart):
                yield chat.ChatCompletionToolMessageParam(
                    role='tool',
                    tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):
                # A retry tied to a specific tool goes back as a tool message;
                # otherwise the retry prompt is sent as a plain user message.
                if part.tool_name is None:
                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
                else:
                    yield chat.ChatCompletionToolMessageParam(
                        role='tool',
                        tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
                        content=part.model_response(),
                    )

287 

288 

@dataclass
class GroqStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for Groq models."""

    _response: AsyncIterable[ChatCompletionChunk]
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Translate raw Groq stream chunks into pydantic-ai response-stream events."""
        async for chunk in self._response:
            # Fold token usage from every chunk into the running total, even
            # chunks that carry no choices.
            self._usage += _map_usage(chunk)

            try:
                delta = chunk.choices[0].delta
            except IndexError:
                # Chunks without any choices contribute nothing beyond usage.
                continue

            # Text content arrives as incremental deltas under one fixed part id.
            if (text := delta.content) is not None:
                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=text)

            # Tool-call name/arguments also stream incrementally, keyed by index;
            # the parts manager may or may not produce an event for each delta.
            for tool_delta in delta.tool_calls or []:
                event = self._parts_manager.handle_tool_call_delta(
                    vendor_part_id=tool_delta.index,
                    tool_name=tool_delta.function and tool_delta.function.name,
                    args=tool_delta.function and tool_delta.function.arguments,
                    tool_call_id=tool_delta.id,
                )
                if event is not None:
                    yield event

    def timestamp(self) -> datetime:
        """Return the timestamp taken from the first chunk of the stream."""
        return self._timestamp

323 

324 

def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
    """Convert a pydantic-ai tool-call part into Groq's tool-call message param."""
    # Arguments are serialized to a JSON string, as the Groq API expects.
    function_payload = {'name': t.tool_name, 'arguments': t.args_as_json_str()}
    return chat.ChatCompletionMessageToolCallParam(
        id=_guard_tool_call_id(t=t, model_source='Groq'),
        type='function',
        function=function_payload,
    )

331 

332 

def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
    """Extract token usage from a full completion or a stream chunk.

    Returns an empty `usage.Usage()` when no usage information is available.
    """
    if isinstance(completion, ChatCompletion):
        # Non-streamed responses report usage directly on the completion object.
        response_usage = completion.usage
    else:
        # Stream chunks expose usage via the vendor-specific `x_groq` field, when present.
        response_usage = completion.x_groq.usage if completion.x_groq is not None else None

    if response_usage is None:
        return usage.Usage()
    return usage.Usage(
        request_tokens=response_usage.prompt_tokens,
        response_tokens=response_usage.completion_tokens,
        total_tokens=response_usage.total_tokens,
    )