Coverage for tests/models/test_groq.py: 99.09%
211 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1from __future__ import annotations as _annotations
3import json
4import os
5from collections.abc import Sequence
6from dataclasses import dataclass
7from datetime import datetime, timezone
8from functools import cached_property
9from typing import Any, Literal, Union, cast
10from unittest.mock import patch
12import httpx
13import pytest
14from inline_snapshot import snapshot
15from typing_extensions import TypedDict
17from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior
18from pydantic_ai.messages import (
19 BinaryContent,
20 ImageUrl,
21 ModelRequest,
22 ModelResponse,
23 RetryPromptPart,
24 SystemPromptPart,
25 TextPart,
26 ToolCallPart,
27 ToolReturnPart,
28 UserPromptPart,
29)
30from pydantic_ai.usage import Usage
32from ..conftest import IsNow, IsStr, raise_if_exception, try_import
33from .mock_async_stream import MockAsyncStream
35with try_import() as imports_successful:
36 from groq import APIStatusError, AsyncGroq
37 from groq.types import chat
38 from groq.types.chat.chat_completion import Choice
39 from groq.types.chat.chat_completion_chunk import (
40 Choice as ChunkChoice,
41 ChoiceDelta,
42 ChoiceDeltaToolCall,
43 ChoiceDeltaToolCallFunction,
44 )
45 from groq.types.chat.chat_completion_message import ChatCompletionMessage
46 from groq.types.chat.chat_completion_message_tool_call import Function
47 from groq.types.completion_usage import CompletionUsage
49 from pydantic_ai.models.groq import GroqModel
50 from pydantic_ai.providers.groq import GroqProvider
52 # note: we use Union here so that casting works with Python 3.9
53 MockChatCompletion = Union[chat.ChatCompletion, Exception]
54 MockChatCompletionChunk = Union[chat.ChatCompletionChunk, Exception]
# Skip every test in this module when the optional `groq` package isn't
# installed, and run all (async) tests under the anyio test runner.
pytestmark = [
    pytest.mark.skipif(not imports_successful(), reason='groq not installed'),
    pytest.mark.anyio,
]
def test_init():
    """Constructing a model with an explicit API key wires up client, name, system and base URL."""
    model = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(api_key='foobar'))
    assert model.model_name == 'llama-3.3-70b-versatile'
    assert model.client.api_key == 'foobar'
    assert model.base_url == 'https://api.groq.com'
    assert model.system == 'groq'
@dataclass
class MockGroq:
    """Stand-in for `AsyncGroq` that replays canned responses.

    Provide `completions` for non-streaming calls or `stream` for streaming
    calls; `index` tracks which canned response is returned next so one mock
    can serve several sequential requests.
    """

    completions: MockChatCompletion | Sequence[MockChatCompletion] | None = None
    stream: Sequence[MockChatCompletionChunk] | Sequence[Sequence[MockChatCompletionChunk]] | None = None
    index: int = 0

    @cached_property
    def chat(self) -> Any:
        # Build just enough of the real client surface
        # (`client.chat.completions.create`) for GroqModel to call.
        chat_completions = type('Completions', (), {'create': self.chat_completions_create})
        return type('Chat', (), {'completions': chat_completions})

    @classmethod
    def create_mock(cls, completions: MockChatCompletion | Sequence[MockChatCompletion]) -> AsyncGroq:
        """Create a mock that serves non-streaming completions (cast to `AsyncGroq`)."""
        return cast(AsyncGroq, cls(completions=completions))

    @classmethod
    def create_mock_stream(
        cls,
        stream: Sequence[MockChatCompletionChunk] | Sequence[Sequence[MockChatCompletionChunk]],
    ) -> AsyncGroq:
        """Create a mock that serves streaming chunk sequences (cast to `AsyncGroq`)."""
        return cast(AsyncGroq, cls(stream=stream))

    async def chat_completions_create(
        self, *_args: Any, stream: bool = False, **_kwargs: Any
    ) -> chat.ChatCompletion | MockAsyncStream[MockChatCompletionChunk]:
        """Return the next canned response; raises if it is an Exception instance."""
        if stream:
            # fix: message read 'you can only used' — corrected grammar
            assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided'
            # A sequence-of-sequences means one chunk list per request.
            if isinstance(self.stream[0], Sequence):
                response = MockAsyncStream(iter(cast(list[MockChatCompletionChunk], self.stream[self.index])))
            else:
                response = MockAsyncStream(iter(cast(list[MockChatCompletionChunk], self.stream)))
        else:
            assert self.completions is not None, 'you can only use `stream=False` if `completions` are provided'
            if isinstance(self.completions, Sequence):
                raise_if_exception(self.completions[self.index])
                response = cast(chat.ChatCompletion, self.completions[self.index])
            else:
                raise_if_exception(self.completions)
                response = cast(chat.ChatCompletion, self.completions)
        self.index += 1
        return response
def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage | None = None) -> chat.ChatCompletion:
    """Wrap a single assistant message in a minimal non-streaming `ChatCompletion`."""
    only_choice = Choice(finish_reason='stop', index=0, message=message)
    return chat.ChatCompletion(
        id='123',
        object='chat.completion',
        model='llama-3.3-70b-versatile-123',
        created=1704067200,  # 2024-01-01T00:00:00Z
        choices=[only_choice],
        usage=usage,
    )
async def test_request_simple_success(allow_model_requests: None):
    """Run the agent twice against one canned completion and pin messages/usage."""
    c = completion_message(ChatCompletionMessage(content='world', role='assistant'))
    mock_client = MockGroq.create_mock(c)
    m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client))
    agent = Agent(m)

    result = await agent.run('hello')
    assert result.data == 'world'
    assert result.usage() == snapshot(Usage(requests=1))

    # reset the index so we get the same response again
    mock_client.index = 0  # type: ignore

    # Second run replays the history plus the same canned completion.
    result = await agent.run('hello', message_history=result.new_messages())
    assert result.data == 'world'
    assert result.usage() == snapshot(Usage(requests=1))
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='world')],
                model_name='llama-3.3-70b-versatile-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
            ),
            ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='world')],
                model_name='llama-3.3-70b-versatile-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
            ),
        ]
    )
async def test_request_simple_usage(allow_model_requests: None):
    """Usage data attached to the completion must not affect the text result."""
    canned = completion_message(
        ChatCompletionMessage(content='world', role='assistant'),
        usage=CompletionUsage(completion_tokens=1, prompt_tokens=2, total_tokens=3),
    )
    client = MockGroq.create_mock(canned)
    model = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=client))
    agent = Agent(model)

    result = await agent.run('Hello')
    assert result.data == 'world'
async def test_request_structured_response(allow_model_requests: None):
    """A tool-call response to the `final_result` tool becomes structured result data."""
    c = completion_message(
        ChatCompletionMessage(
            content=None,
            role='assistant',
            tool_calls=[
                chat.ChatCompletionMessageToolCall(
                    id='123',
                    function=Function(arguments='{"response": [1, 2, 123]}', name='final_result'),
                    type='function',
                )
            ],
        )
    )
    mock_client = MockGroq.create_mock(c)
    m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client))
    agent = Agent(m, result_type=list[int])

    result = await agent.run('Hello')
    assert result.data == [1, 2, 123]
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name='final_result',
                        args='{"response": [1, 2, 123]}',
                        tool_call_id='123',
                    )
                ],
                model_name='llama-3.3-70b-versatile-123',
                timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        tool_call_id='123',
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
        ]
    )
async def test_request_tool_call(allow_model_requests: None):
    """Exercise a retry -> success tool-call loop ending in a plain text response."""
    responses = [
        completion_message(
            ChatCompletionMessage(
                content=None,
                role='assistant',
                tool_calls=[
                    chat.ChatCompletionMessageToolCall(
                        id='1',
                        function=Function(arguments='{"loc_name": "San Fransisco"}', name='get_location'),
                        type='function',
                    )
                ],
            ),
            usage=CompletionUsage(
                completion_tokens=1,
                prompt_tokens=2,
                total_tokens=3,
            ),
        ),
        completion_message(
            ChatCompletionMessage(
                content=None,
                role='assistant',
                tool_calls=[
                    chat.ChatCompletionMessageToolCall(
                        id='2',
                        function=Function(arguments='{"loc_name": "London"}', name='get_location'),
                        type='function',
                    )
                ],
            ),
            usage=CompletionUsage(
                completion_tokens=2,
                prompt_tokens=3,
                total_tokens=6,
            ),
        ),
        completion_message(ChatCompletionMessage(content='final response', role='assistant')),
    ]
    mock_client = MockGroq.create_mock(responses)
    m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client))
    agent = Agent(m, system_prompt='this is the system prompt')

    @agent.tool_plain
    async def get_location(loc_name: str) -> str:
        # Only 'London' succeeds; the first canned call ('San Fransisco', sic)
        # deliberately triggers a ModelRetry so the retry path is covered.
        if loc_name == 'London':
            return json.dumps({'lat': 51, 'lng': 0})
        else:
            raise ModelRetry('Wrong location, please try again')

    result = await agent.run('Hello')
    assert result.data == 'final response'
    assert result.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='this is the system prompt', timestamp=IsNow(tz=timezone.utc)),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                ]
            ),
            ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name='get_location',
                        args='{"loc_name": "San Fransisco"}',
                        tool_call_id='1',
                    )
                ],
                model_name='llama-3.3-70b-versatile-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        tool_name='get_location',
                        content='Wrong location, please try again',
                        tool_call_id='1',
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name='get_location',
                        args='{"loc_name": "London"}',
                        tool_call_id='2',
                    )
                ],
                model_name='llama-3.3-70b-versatile-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='get_location',
                        content='{"lat": 51, "lng": 0}',
                        tool_call_id='2',
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[TextPart(content='final response')],
                model_name='llama-3.3-70b-versatile-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
            ),
        ]
    )
# Finish reasons the Groq chat API reports on a choice, as used by the chunk helpers below.
FinishReason = Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call']
def chunk(delta: list[ChoiceDelta], finish_reason: FinishReason | None = None) -> chat.ChatCompletionChunk:
    """Build a streaming chunk with one `ChunkChoice` per entry of *delta*."""
    choices = [ChunkChoice(index=i, delta=d, finish_reason=finish_reason) for i, d in enumerate(delta)]
    return chat.ChatCompletionChunk(
        id='x',
        object='chat.completion.chunk',
        model='llama-3.3-70b-versatile',
        created=1704067200,  # 2024-01-01T00:00:00Z
        x_groq=None,
        choices=choices,
        usage=CompletionUsage(completion_tokens=1, prompt_tokens=2, total_tokens=3),
    )
def text_chunk(text: str, finish_reason: FinishReason | None = None) -> chat.ChatCompletionChunk:
    """Build a single-choice chunk carrying assistant text content."""
    text_delta = ChoiceDelta(content=text, role='assistant')
    return chunk([text_delta], finish_reason=finish_reason)
async def test_stream_text(allow_model_requests: None):
    """Stream plain text deltas and check the accumulated text at each step."""
    chunks = text_chunk('hello '), text_chunk('world'), chunk([])
    client = MockGroq.create_mock_stream(chunks)
    model = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=client))
    agent = Agent(model)

    async with agent.run_stream('') as result:
        assert not result.is_complete
        streamed = [c async for c in result.stream(debounce_by=None)]
        assert streamed == snapshot(['hello ', 'hello world', 'hello world'])
        assert result.is_complete
async def test_stream_text_finish_reason(allow_model_requests: None):
    """An explicit `finish_reason='stop'` on the final chunk completes the stream."""
    chunks = text_chunk('hello '), text_chunk('world'), text_chunk('.', finish_reason='stop')
    client = MockGroq.create_mock_stream(chunks)
    model = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=client))
    agent = Agent(model)

    async with agent.run_stream('') as result:
        assert not result.is_complete
        streamed = [c async for c in result.stream(debounce_by=None)]
        assert streamed == snapshot(['hello ', 'hello world', 'hello world.', 'hello world.'])
        assert result.is_complete
def struc_chunk(
    tool_name: str | None, tool_arguments: str | None, finish_reason: FinishReason | None = None
) -> chat.ChatCompletionChunk:
    """Build a chunk whose single delta carries a (possibly partial) tool call."""
    tool_call = ChoiceDeltaToolCall(
        index=0, function=ChoiceDeltaToolCallFunction(name=tool_name, arguments=tool_arguments)
    )
    return chunk([ChoiceDelta(tool_calls=[tool_call])], finish_reason=finish_reason)
class MyTypedDict(TypedDict, total=False):
    # total=False: both keys are optional, so partially-streamed results validate.
    first: str
    second: str
async def test_stream_structured(allow_model_requests: None):
    """Stream a tool call in fragments (including empty/None deltas) into a TypedDict result."""
    stream = (
        chunk([ChoiceDelta()]),
        chunk([ChoiceDelta(tool_calls=[])]),
        chunk([ChoiceDelta(tool_calls=[ChoiceDeltaToolCall(index=0, function=None)])]),
        chunk([ChoiceDelta(tool_calls=[ChoiceDeltaToolCall(index=0, function=None)])]),
        struc_chunk('final_result', None),
        chunk([ChoiceDelta(tool_calls=[ChoiceDeltaToolCall(index=0, function=None)])]),
        struc_chunk(None, '{"first": "One'),
        struc_chunk(None, '", "second": "Two"'),
        struc_chunk(None, '}'),
        chunk([]),
    )
    mock_client = MockGroq.create_mock_stream(stream)
    m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client))
    agent = Agent(m, result_type=MyTypedDict)

    async with agent.run_stream('') as result:
        assert not result.is_complete
        # Each snapshot entry is the partially-validated dict after one more chunk.
        assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot(
            [
                {'first': 'One'},
                {'first': 'One', 'second': 'Two'},
                {'first': 'One', 'second': 'Two'},
                {'first': 'One', 'second': 'Two'},
            ]
        )
        assert result.is_complete

    assert result.usage() == snapshot(Usage(requests=1))
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name='final_result',
                        args='{"first": "One", "second": "Two"}',
                        tool_call_id=IsStr(),
                    )
                ],
                model_name='llama-3.3-70b-versatile',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
        ]
    )
async def test_stream_structured_finish_reason(allow_model_requests: None):
    """Same structured stream as above, but terminated by `finish_reason='stop'`."""
    stream = (
        struc_chunk('final_result', None),
        struc_chunk(None, '{"first": "One'),
        struc_chunk(None, '", "second": "Two"'),
        struc_chunk(None, '}'),
        struc_chunk(None, None, finish_reason='stop'),
    )
    mock_client = MockGroq.create_mock_stream(stream)
    m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client))
    agent = Agent(m, result_type=MyTypedDict)

    async with agent.run_stream('') as result:
        assert not result.is_complete
        assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot(
            [
                {'first': 'One'},
                {'first': 'One', 'second': 'Two'},
                {'first': 'One', 'second': 'Two'},
                {'first': 'One', 'second': 'Two'},
                {'first': 'One', 'second': 'Two'},
            ]
        )
        assert result.is_complete
async def test_no_content(allow_model_requests: None):
    """A stream containing only empty deltas must raise `UnexpectedModelBehavior`."""
    chunks = chunk([ChoiceDelta()]), chunk([ChoiceDelta()])
    client = MockGroq.create_mock_stream(chunks)
    model = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=client))
    agent = Agent(model, result_type=MyTypedDict)

    with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'):
        async with agent.run_stream(''):
            pass  # pragma: no cover
async def test_no_delta(allow_model_requests: None):
    """A leading chunk with no choices is skipped; the text still streams through."""
    chunks = chunk([]), text_chunk('hello '), text_chunk('world')
    client = MockGroq.create_mock_stream(chunks)
    model = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=client))
    agent = Agent(model)

    async with agent.run_stream('') as result:
        assert not result.is_complete
        streamed = [c async for c in result.stream(debounce_by=None)]
        assert streamed == snapshot(['hello ', 'hello world', 'hello world'])
        assert result.is_complete
@pytest.mark.vcr()
async def test_image_url_input(allow_model_requests: None, groq_api_key: str):
    """Send an image URL to a vision model (replayed via a VCR cassette) and pin the reply."""
    m = GroqModel('llama-3.2-11b-vision-preview', provider=GroqProvider(api_key=groq_api_key))
    agent = Agent(m)

    result = await agent.run(
        [
            'What is the name of this fruit?',
            ImageUrl(url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg'),
        ]
    )
    assert result.data == snapshot("""\
The image you provided appears to be a potato. It is a root vegetable that belongs to the nightshade family. Potatoes are a popular and versatile crop, widely cultivated and consumed around the world.

**Characteristics and Uses:**

Potatoes are known for their starchy, slightly sweet flavor and soft, white interior. They come in various shapes, sizes, and colors including white, yellow, red, and purple. Some popular types of potatoes include:

* Russet potatoes (also known as Idaho potatoes)
* Red potatoes
* Yukon gold potatoes
* Sweet potatoes

Potatoes are a versatile food that can be prepared in many different ways, such as baked, mashed, boiled, fried, or used in soups and stews. They are an excellent source of dietary fiber, potassium, and several key vitamins and minerals.\
""")
@pytest.mark.parametrize('media_type', ['audio/wav', 'audio/mpeg'])
async def test_audio_as_binary_content_input(allow_model_requests: None, media_type: str):
    """Groq rejects non-image binary content with a RuntimeError before any request is made."""
    canned = completion_message(ChatCompletionMessage(content='world', role='assistant'))
    client = MockGroq.create_mock(canned)
    model = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=client))
    agent = Agent(model)

    audio_bytes = b'//uQZ'

    with pytest.raises(RuntimeError, match='Only images are supported for binary content in Groq.'):
        await agent.run(['hello', BinaryContent(data=audio_bytes, media_type=media_type)])
@pytest.mark.vcr()
async def test_image_as_binary_content_input(
    allow_model_requests: None, groq_api_key: str, image_content: BinaryContent
) -> None:
    """Send binary image content to a vision model (VCR cassette) and pin the reply."""
    m = GroqModel('llama-3.2-11b-vision-preview', provider=GroqProvider(api_key=groq_api_key))
    agent = Agent(m)

    result = await agent.run(['What is the name of this fruit?', image_content])
    assert result.data == snapshot(
        "This is a kiwi, also known as a Chinese gooseberry. It's a small, green fruit with a hairy, brown skin and a bright green, juicy flesh inside. Kiwis are native to China and are often eaten raw, either on their own or added to salads, smoothies, and desserts. They're also a good source of vitamin C, vitamin K, and other nutrients."
    )
def test_model_status_error(allow_model_requests: None) -> None:
    """An `APIStatusError` raised by the client surfaces as `ModelHTTPError` with details."""
    status_error = APIStatusError(
        'test error',
        response=httpx.Response(status_code=500, request=httpx.Request('POST', 'https://example.com/v1')),
        body={'error': 'test error'},
    )
    client = MockGroq.create_mock(status_error)
    model = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=client))
    agent = Agent(model)

    with pytest.raises(ModelHTTPError) as exc_info:
        agent.run_sync('hello')
    assert str(exc_info.value) == snapshot(
        "status_code: 500, model_name: llama-3.3-70b-versatile, body: {'error': 'test error'}"
    )
async def test_init_with_provider():
    """A provider instance passed explicitly is used as-is by the model."""
    groq_provider = GroqProvider(api_key='api-key')
    model = GroqModel('llama3-8b-8192', provider=groq_provider)
    assert model.client == groq_provider.client
    assert model.model_name == 'llama3-8b-8192'
async def test_init_with_provider_string():
    """The provider string 'groq' builds a provider from the GROQ_API_KEY env var."""
    with patch.dict(os.environ, {'GROQ_API_KEY': 'env-api-key'}, clear=False):
        model = GroqModel('llama3-8b-8192', provider='groq')
        assert model.client is not None
        assert model.model_name == 'llama3-8b-8192'