Coverage for tests/models/test_groq.py: 99.09%

211 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1from __future__ import annotations as _annotations 

2 

3import json 

4import os 

5from collections.abc import Sequence 

6from dataclasses import dataclass 

7from datetime import datetime, timezone 

8from functools import cached_property 

9from typing import Any, Literal, Union, cast 

10from unittest.mock import patch 

11 

12import httpx 

13import pytest 

14from inline_snapshot import snapshot 

15from typing_extensions import TypedDict 

16 

17from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior 

18from pydantic_ai.messages import ( 

19 BinaryContent, 

20 ImageUrl, 

21 ModelRequest, 

22 ModelResponse, 

23 RetryPromptPart, 

24 SystemPromptPart, 

25 TextPart, 

26 ToolCallPart, 

27 ToolReturnPart, 

28 UserPromptPart, 

29) 

30from pydantic_ai.usage import Usage 

31 

32from ..conftest import IsNow, IsStr, raise_if_exception, try_import 

33from .mock_async_stream import MockAsyncStream 

34 

35with try_import() as imports_successful: 

36 from groq import APIStatusError, AsyncGroq 

37 from groq.types import chat 

38 from groq.types.chat.chat_completion import Choice 

39 from groq.types.chat.chat_completion_chunk import ( 

40 Choice as ChunkChoice, 

41 ChoiceDelta, 

42 ChoiceDeltaToolCall, 

43 ChoiceDeltaToolCallFunction, 

44 ) 

45 from groq.types.chat.chat_completion_message import ChatCompletionMessage 

46 from groq.types.chat.chat_completion_message_tool_call import Function 

47 from groq.types.completion_usage import CompletionUsage 

48 

49 from pydantic_ai.models.groq import GroqModel 

50 from pydantic_ai.providers.groq import GroqProvider 

51 

52 # note: we use Union here so that casting works with Python 3.9 

53 MockChatCompletion = Union[chat.ChatCompletion, Exception] 

54 MockChatCompletionChunk = Union[chat.ChatCompletionChunk, Exception] 

55 

# Skip the whole module when the `groq` SDK is not installed, and run every
# async test under anyio (the async test functions below have no explicit marker).
pytestmark = [
    pytest.mark.skipif(not imports_successful(), reason='groq not installed'),
    pytest.mark.anyio,
]

60 

61 

def test_init():
    """Constructing a `GroqModel` with an API-key provider exposes the expected attributes."""
    model = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(api_key='foobar'))
    assert model.model_name == 'llama-3.3-70b-versatile'
    assert model.client.api_key == 'foobar'
    assert model.base_url == 'https://api.groq.com'
    assert model.system == 'groq'

68 

69 

@dataclass
class MockGroq:
    """Stand-in for `AsyncGroq` that replays canned completions or stream chunks.

    Exactly one of `completions` (non-streaming) or `stream` (streaming) should be
    provided; `index` tracks which canned response to serve next across calls.
    """

    # Canned non-streaming response(s); a single item is replayed for every call.
    completions: MockChatCompletion | Sequence[MockChatCompletion] | None = None
    # Canned stream chunks; a sequence-of-sequences provides one stream per call.
    stream: Sequence[MockChatCompletionChunk] | Sequence[Sequence[MockChatCompletionChunk]] | None = None
    # Position of the next canned response/stream to serve.
    index: int = 0

    @cached_property
    def chat(self) -> Any:
        """Mimic the `client.chat.completions.create` attribute chain of the real SDK."""
        chat_completions = type('Completions', (), {'create': self.chat_completions_create})
        return type('Chat', (), {'completions': chat_completions})

    @classmethod
    def create_mock(cls, completions: MockChatCompletion | Sequence[MockChatCompletion]) -> AsyncGroq:
        """Build a mock client for non-streaming requests, typed as `AsyncGroq`."""
        return cast(AsyncGroq, cls(completions=completions))

    @classmethod
    def create_mock_stream(
        cls,
        stream: Sequence[MockChatCompletionChunk] | Sequence[Sequence[MockChatCompletionChunk]],
    ) -> AsyncGroq:
        """Build a mock client for streaming requests, typed as `AsyncGroq`."""
        return cast(AsyncGroq, cls(stream=stream))

    async def chat_completions_create(
        self, *_args: Any, stream: bool = False, **_kwargs: Any
    ) -> chat.ChatCompletion | MockAsyncStream[MockChatCompletionChunk]:
        """Serve the next canned response, raising it instead if it is an `Exception`."""
        if stream:
            # Fixed grammar in the assertion messages: 'can only used' -> 'can only use'.
            assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided'
            if isinstance(self.stream[0], Sequence):
                # Sequence-of-sequences: each call consumes one whole stream.
                response = MockAsyncStream(iter(cast(list[MockChatCompletionChunk], self.stream[self.index])))
            else:
                response = MockAsyncStream(iter(cast(list[MockChatCompletionChunk], self.stream)))
        else:
            assert self.completions is not None, 'you can only use `stream=False` if `completions` are provided'
            if isinstance(self.completions, Sequence):
                raise_if_exception(self.completions[self.index])
                response = cast(chat.ChatCompletion, self.completions[self.index])
            else:
                raise_if_exception(self.completions)
                response = cast(chat.ChatCompletion, self.completions)
        self.index += 1
        return response

111 

112 

def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage | None = None) -> chat.ChatCompletion:
    """Wrap *message* in a minimal `chat.ChatCompletion` with fixed id/model/timestamp."""
    choice = Choice(finish_reason='stop', index=0, message=message)
    return chat.ChatCompletion(
        id='123',
        choices=[choice],
        created=1704067200,  # 2024-01-01
        model='llama-3.3-70b-versatile-123',
        object='chat.completion',
        usage=usage,
    )

122 

123 

async def test_request_simple_success(allow_model_requests: None):
    """A plain text completion maps to a `TextPart` response and counts one request per run."""
    c = completion_message(ChatCompletionMessage(content='world', role='assistant'))
    mock_client = MockGroq.create_mock(c)
    m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client))
    agent = Agent(m)

    result = await agent.run('hello')
    assert result.data == 'world'
    assert result.usage() == snapshot(Usage(requests=1))

    # reset the index so we get the same response again
    mock_client.index = 0  # type: ignore

    # Second run replays the first exchange as message history.
    result = await agent.run('hello', message_history=result.new_messages())
    assert result.data == 'world'
    assert result.usage() == snapshot(Usage(requests=1))
    # Both exchanges (history + new) appear in the full message log.
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='world')],
                model_name='llama-3.3-70b-versatile-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
            ),
            ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='world')],
                model_name='llama-3.3-70b-versatile-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
            ),
        ]
    )

156 

157 

async def test_request_simple_usage(allow_model_requests: None):
    """A completion carrying token usage still resolves to the plain text result."""
    completion = completion_message(
        ChatCompletionMessage(content='world', role='assistant'),
        usage=CompletionUsage(completion_tokens=1, prompt_tokens=2, total_tokens=3),
    )
    model = GroqModel(
        'llama-3.3-70b-versatile', provider=GroqProvider(groq_client=MockGroq.create_mock(completion))
    )
    agent = Agent(model)

    result = await agent.run('Hello')
    assert result.data == 'world'

170 

async def test_request_structured_response(allow_model_requests: None):
    """A `final_result` tool call is parsed into the agent's structured `list[int]` result."""
    c = completion_message(
        ChatCompletionMessage(
            content=None,
            role='assistant',
            tool_calls=[
                chat.ChatCompletionMessageToolCall(
                    id='123',
                    function=Function(arguments='{"response": [1, 2, 123]}', name='final_result'),
                    type='function',
                )
            ],
        )
    )
    mock_client = MockGroq.create_mock(c)
    m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client))
    agent = Agent(m, result_type=list[int])

    result = await agent.run('Hello')
    assert result.data == [1, 2, 123]
    # The tool call and its synthesized return part are both recorded.
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name='final_result',
                        args='{"response": [1, 2, 123]}',
                        tool_call_id='123',
                    )
                ],
                model_name='llama-3.3-70b-versatile-123',
                timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        tool_call_id='123',
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
        ]
    )

217 

218 

async def test_request_tool_call(allow_model_requests: None):
    """Full tool loop: a failing call (retried via `ModelRetry`), a successful call, then text."""
    responses = [
        # First call: wrong location -> the tool raises ModelRetry.
        completion_message(
            ChatCompletionMessage(
                content=None,
                role='assistant',
                tool_calls=[
                    chat.ChatCompletionMessageToolCall(
                        id='1',
                        function=Function(arguments='{"loc_name": "San Fransisco"}', name='get_location'),
                        type='function',
                    )
                ],
            ),
            usage=CompletionUsage(
                completion_tokens=1,
                prompt_tokens=2,
                total_tokens=3,
            ),
        ),
        # Second call: 'London' succeeds and returns coordinates.
        completion_message(
            ChatCompletionMessage(
                content=None,
                role='assistant',
                tool_calls=[
                    chat.ChatCompletionMessageToolCall(
                        id='2',
                        function=Function(arguments='{"loc_name": "London"}', name='get_location'),
                        type='function',
                    )
                ],
            ),
            usage=CompletionUsage(
                completion_tokens=2,
                prompt_tokens=3,
                total_tokens=6,
            ),
        ),
        # Final plain-text answer ends the run.
        completion_message(ChatCompletionMessage(content='final response', role='assistant')),
    ]
    mock_client = MockGroq.create_mock(responses)
    m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client))
    agent = Agent(m, system_prompt='this is the system prompt')

    @agent.tool_plain
    async def get_location(loc_name: str) -> str:
        # Only 'London' succeeds; any other location triggers a retry prompt to the model.
        if loc_name == 'London':
            return json.dumps({'lat': 51, 'lng': 0})
        else:
            raise ModelRetry('Wrong location, please try again')

    result = await agent.run('Hello')
    assert result.data == 'final response'
    assert result.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='this is the system prompt', timestamp=IsNow(tz=timezone.utc)),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                ]
            ),
            ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name='get_location',
                        args='{"loc_name": "San Fransisco"}',
                        tool_call_id='1',
                    )
                ],
                model_name='llama-3.3-70b-versatile-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        tool_name='get_location',
                        content='Wrong location, please try again',
                        tool_call_id='1',
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name='get_location',
                        args='{"loc_name": "London"}',
                        tool_call_id='2',
                    )
                ],
                model_name='llama-3.3-70b-versatile-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='get_location',
                        content='{"lat": 51, "lng": 0}',
                        tool_call_id='2',
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[TextPart(content='final response')],
                model_name='llama-3.3-70b-versatile-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
            ),
        ]
    )

329 

330 

# Finish reasons the Groq chat API can emit; used to type the chunk-builder helpers below.
FinishReason = Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call']

332 

333 

def chunk(delta: list[ChoiceDelta], finish_reason: FinishReason | None = None) -> chat.ChatCompletionChunk:
    """Build a `chat.ChatCompletionChunk` with one `ChunkChoice` per entry of *delta*.

    Every chunk carries the same fixed id/model/timestamp and token usage so
    tests can focus on the delta contents.
    """
    return chat.ChatCompletionChunk(
        id='x',
        # Distinct loop names (`i`, `d`) avoid shadowing the `delta` parameter
        # inside the comprehension, which the original code did.
        choices=[ChunkChoice(index=i, delta=d, finish_reason=finish_reason) for i, d in enumerate(delta)],
        created=1704067200,  # 2024-01-01
        x_groq=None,
        model='llama-3.3-70b-versatile',
        object='chat.completion.chunk',
        usage=CompletionUsage(completion_tokens=1, prompt_tokens=2, total_tokens=3),
    )

346 

347 

def text_chunk(text: str, finish_reason: FinishReason | None = None) -> chat.ChatCompletionChunk:
    """Convenience wrapper: a chunk carrying a single assistant text delta."""
    delta = ChoiceDelta(content=text, role='assistant')
    return chunk([delta], finish_reason=finish_reason)

350 

351 

async def test_stream_text(allow_model_requests: None):
    """Streamed text deltas accumulate into the full response text."""
    chunks = [text_chunk('hello '), text_chunk('world'), chunk([])]
    model = GroqModel(
        'llama-3.3-70b-versatile', provider=GroqProvider(groq_client=MockGroq.create_mock_stream(chunks))
    )
    agent = Agent(model)

    async with agent.run_stream('') as result:
        assert not result.is_complete
        collected = [text async for text in result.stream(debounce_by=None)]
        assert collected == snapshot(['hello ', 'hello world', 'hello world'])
        assert result.is_complete

362 

363 

async def test_stream_text_finish_reason(allow_model_requests: None):
    """A terminal chunk with `finish_reason='stop'` cleanly ends the text stream."""
    chunks = [text_chunk('hello '), text_chunk('world'), text_chunk('.', finish_reason='stop')]
    model = GroqModel(
        'llama-3.3-70b-versatile', provider=GroqProvider(groq_client=MockGroq.create_mock_stream(chunks))
    )
    agent = Agent(model)

    async with agent.run_stream('') as result:
        assert not result.is_complete
        collected = [text async for text in result.stream(debounce_by=None)]
        assert collected == snapshot(['hello ', 'hello world', 'hello world.', 'hello world.'])
        assert result.is_complete

376 

377 

def struc_chunk(
    tool_name: str | None, tool_arguments: str | None, finish_reason: FinishReason | None = None
) -> chat.ChatCompletionChunk:
    """Build a streaming chunk carrying a (possibly partial) tool-call delta."""
    tool_call = ChoiceDeltaToolCall(
        index=0, function=ChoiceDeltaToolCallFunction(name=tool_name, arguments=tool_arguments)
    )
    return chunk([ChoiceDelta(tool_calls=[tool_call])], finish_reason=finish_reason)

393 

394 

# total=False: streamed partial results may contain only some of the keys.
class MyTypedDict(TypedDict, total=False):
    first: str
    second: str

398 

399 

async def test_stream_structured(allow_model_requests: None):
    """Tool-call argument deltas accumulate into the structured result; empty/None deltas are tolerated."""
    stream = (
        chunk([ChoiceDelta()]),
        chunk([ChoiceDelta(tool_calls=[])]),
        chunk([ChoiceDelta(tool_calls=[ChoiceDeltaToolCall(index=0, function=None)])]),
        chunk([ChoiceDelta(tool_calls=[ChoiceDeltaToolCall(index=0, function=None)])]),
        # Tool name arrives first, then the JSON arguments in fragments.
        struc_chunk('final_result', None),
        chunk([ChoiceDelta(tool_calls=[ChoiceDeltaToolCall(index=0, function=None)])]),
        struc_chunk(None, '{"first": "One'),
        struc_chunk(None, '", "second": "Two"'),
        struc_chunk(None, '}'),
        chunk([]),
    )
    mock_client = MockGroq.create_mock_stream(stream)
    m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client))
    agent = Agent(m, result_type=MyTypedDict)

    async with agent.run_stream('') as result:
        assert not result.is_complete
        assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot(
            [
                {'first': 'One'},
                {'first': 'One', 'second': 'Two'},
                {'first': 'One', 'second': 'Two'},
                {'first': 'One', 'second': 'Two'},
            ]
        )
        assert result.is_complete

    assert result.usage() == snapshot(Usage(requests=1))
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name='final_result',
                        args='{"first": "One", "second": "Two"}',
                        tool_call_id=IsStr(),
                    )
                ],
                model_name='llama-3.3-70b-versatile',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
        ]
    )

456 

457 

async def test_stream_structured_finish_reason(allow_model_requests: None):
    """A terminal `finish_reason='stop'` chunk cleanly ends a structured stream."""
    stream = (
        struc_chunk('final_result', None),
        struc_chunk(None, '{"first": "One'),
        struc_chunk(None, '", "second": "Two"'),
        struc_chunk(None, '}'),
        struc_chunk(None, None, finish_reason='stop'),
    )
    mock_client = MockGroq.create_mock_stream(stream)
    m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client))
    agent = Agent(m, result_type=MyTypedDict)

    async with agent.run_stream('') as result:
        assert not result.is_complete
        assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot(
            [
                {'first': 'One'},
                {'first': 'One', 'second': 'Two'},
                {'first': 'One', 'second': 'Two'},
                {'first': 'One', 'second': 'Two'},
                {'first': 'One', 'second': 'Two'},
            ]
        )
        assert result.is_complete

482 

483 

async def test_no_content(allow_model_requests: None):
    """A stream made only of empty deltas raises `UnexpectedModelBehavior`."""
    empty_stream = chunk([ChoiceDelta()]), chunk([ChoiceDelta()])
    model = GroqModel(
        'llama-3.3-70b-versatile', provider=GroqProvider(groq_client=MockGroq.create_mock_stream(empty_stream))
    )
    agent = Agent(model, result_type=MyTypedDict)

    with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'):
        async with agent.run_stream(''):
            pass  # pragma: no cover

493 

494 

async def test_no_delta(allow_model_requests: None):
    """A leading chunk with no choices is skipped and the remaining text still streams."""
    chunks = [chunk([]), text_chunk('hello '), text_chunk('world')]
    model = GroqModel(
        'llama-3.3-70b-versatile', provider=GroqProvider(groq_client=MockGroq.create_mock_stream(chunks))
    )
    agent = Agent(model)

    async with agent.run_stream('') as result:
        assert not result.is_complete
        collected = [text async for text in result.stream(debounce_by=None)]
        assert collected == snapshot(['hello ', 'hello world', 'hello world'])
        assert result.is_complete

505 

506 

@pytest.mark.vcr()
async def test_image_url_input(allow_model_requests: None, groq_api_key: str):
    """Recorded (VCR) request: an image URL input to a vision-capable Groq model."""
    m = GroqModel('llama-3.2-11b-vision-preview', provider=GroqProvider(api_key=groq_api_key))
    agent = Agent(m)

    result = await agent.run(
        [
            'What is the name of this fruit?',
            ImageUrl(url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg'),
        ]
    )
    assert result.data == snapshot("""\
The image you provided appears to be a potato. It is a root vegetable that belongs to the nightshade family. Potatoes are a popular and versatile crop, widely cultivated and consumed around the world.

**Characteristics and Uses:**

Potatoes are known for their starchy, slightly sweet flavor and soft, white interior. They come in various shapes, sizes, and colors including white, yellow, red, and purple. Some popular types of potatoes include:

* Russet potatoes (also known as Idaho potatoes)
* Red potatoes
* Yukon gold potatoes
* Sweet potatoes

Potatoes are a versatile food that can be prepared in many different ways, such as baked, mashed, boiled, fried, or used in soups and stews. They are an excellent source of dietary fiber, potassium, and several key vitamins and minerals.\
""")

532 

533 

@pytest.mark.parametrize('media_type', ['audio/wav', 'audio/mpeg'])
async def test_audio_as_binary_content_input(allow_model_requests: None, media_type: str):
    """Binary audio content is rejected: Groq only supports images as binary input."""
    completion = completion_message(ChatCompletionMessage(content='world', role='assistant'))
    model = GroqModel(
        'llama-3.3-70b-versatile', provider=GroqProvider(groq_client=MockGroq.create_mock(completion))
    )
    agent = Agent(model)

    audio_data = b'//uQZ'

    with pytest.raises(RuntimeError, match='Only images are supported for binary content in Groq.'):
        await agent.run(['hello', BinaryContent(data=audio_data, media_type=media_type)])

545 

546 

@pytest.mark.vcr()
async def test_image_as_binary_content_input(
    allow_model_requests: None, groq_api_key: str, image_content: BinaryContent
) -> None:
    """Recorded (VCR) request: binary image content is accepted by the vision model."""
    m = GroqModel('llama-3.2-11b-vision-preview', provider=GroqProvider(api_key=groq_api_key))
    agent = Agent(m)

    result = await agent.run(['What is the name of this fruit?', image_content])
    assert result.data == snapshot(
        "This is a kiwi, also known as a Chinese gooseberry. It's a small, green fruit with a hairy, brown skin and a bright green, juicy flesh inside. Kiwis are native to China and are often eaten raw, either on their own or added to salads, smoothies, and desserts. They're also a good source of vitamin C, vitamin K, and other nutrients."
    )

558 

559 

def test_model_status_error(allow_model_requests: None) -> None:
    """An `APIStatusError` raised by the client surfaces as `ModelHTTPError`."""
    error = APIStatusError(
        'test error',
        response=httpx.Response(status_code=500, request=httpx.Request('POST', 'https://example.com/v1')),
        body={'error': 'test error'},
    )
    model = GroqModel(
        'llama-3.3-70b-versatile', provider=GroqProvider(groq_client=MockGroq.create_mock(error))
    )
    agent = Agent(model)

    with pytest.raises(ModelHTTPError) as exc_info:
        agent.run_sync('hello')
    assert str(exc_info.value) == snapshot(
        "status_code: 500, model_name: llama-3.3-70b-versatile, body: {'error': 'test error'}"
    )

575 

576 

async def test_init_with_provider():
    """A `GroqProvider` instance passed to the model supplies its client."""
    groq_provider = GroqProvider(api_key='api-key')
    model = GroqModel('llama3-8b-8192', provider=groq_provider)
    assert model.client == groq_provider.client
    assert model.model_name == 'llama3-8b-8192'

582 

583 

async def test_init_with_provider_string():
    """The provider string 'groq' builds a client from the GROQ_API_KEY env var."""
    env = {'GROQ_API_KEY': 'env-api-key'}
    with patch.dict(os.environ, env, clear=False):
        model = GroqModel('llama3-8b-8192', provider='groq')
        assert model.model_name == 'llama3-8b-8192'
        assert model.client is not None