Coverage for tests/models/test_gemini.py: 99.70%
323 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1# pyright: reportPrivateUsage=false
2from __future__ import annotations as _annotations
4import datetime
5import json
6from collections.abc import AsyncIterator, Callable, Sequence
7from dataclasses import dataclass
8from datetime import timezone
10import httpx
11import pytest
12from inline_snapshot import snapshot
13from pydantic import BaseModel, Field
14from typing_extensions import Literal, TypeAlias
16from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior, UserError
17from pydantic_ai.exceptions import ModelHTTPError
18from pydantic_ai.messages import (
19 BinaryContent,
20 DocumentUrl,
21 ImageUrl,
22 ModelRequest,
23 ModelResponse,
24 RetryPromptPart,
25 SystemPromptPart,
26 TextPart,
27 ToolCallPart,
28 ToolReturnPart,
29 UserPromptPart,
30)
31from pydantic_ai.models import ModelRequestParameters
32from pydantic_ai.models.gemini import (
33 GeminiModel,
34 GeminiModelSettings,
35 _content_model_response,
36 _gemini_response_ta,
37 _gemini_streamed_response_ta,
38 _GeminiCandidates,
39 _GeminiContent,
40 _GeminiFunction,
41 _GeminiFunctionCallingConfig,
42 _GeminiResponse,
43 _GeminiSafetyRating,
44 _GeminiToolConfig,
45 _GeminiTools,
46 _GeminiUsageMetaData,
47)
48from pydantic_ai.providers.google_gla import GoogleGLAProvider
49from pydantic_ai.result import Usage
50from pydantic_ai.tools import ToolDefinition
52from ..conftest import ClientWithHandler, IsNow, IsStr, TestEnv
# Module-wide pytest mark: every test in this file carries the `anyio` marker.
pytestmark = pytest.mark.anyio
async def test_model_simple(allow_model_requests: None):
    """With no tools requested, the model yields neither tool declarations nor a tool config."""
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))

    # The provider should hand the model an async httpx client carrying the API key header.
    assert isinstance(model.client, httpx.AsyncClient)
    assert model.model_name == 'gemini-1.5-flash'
    assert 'x-goog-api-key' in model.client.headers

    request_params = ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[])
    declared_tools = model._get_tools(request_params)
    calling_config = model._get_tool_config(request_params, declared_tools)
    assert declared_tools is None
    assert calling_config is None
async def test_model_tools(allow_model_requests: None):
    """Function tools and the result tool are all declared; `title` keys are absent from the output."""
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))

    foo_tool = ToolDefinition(
        'foo',
        'This is foo',
        {'type': 'object', 'title': 'Foo', 'properties': {'bar': {'type': 'number', 'title': 'Bar'}}},
    )
    apple_tool = ToolDefinition(
        'apple',
        'This is apple',
        {
            'type': 'object',
            'properties': {
                'banana': {'type': 'array', 'title': 'Banana', 'items': {'type': 'number', 'title': 'Bar'}}
            },
        },
    )
    final_result_tool = ToolDefinition(
        'result',
        'This is the tool for the final Result',
        {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}, 'required': ['spam']},
    )

    request_params = ModelRequestParameters(
        function_tools=[foo_tool, apple_tool], allow_text_result=True, result_tools=[final_result_tool]
    )
    declared_tools = model._get_tools(request_params)
    calling_config = model._get_tool_config(request_params, declared_tools)
    assert declared_tools == snapshot(
        _GeminiTools(
            function_declarations=[
                _GeminiFunction(
                    name='foo',
                    description='This is foo',
                    parameters={'type': 'object', 'properties': {'bar': {'type': 'number'}}},
                ),
                _GeminiFunction(
                    name='apple',
                    description='This is apple',
                    parameters={
                        'type': 'object',
                        'properties': {'banana': {'type': 'array', 'items': {'type': 'number'}}},
                    },
                ),
                _GeminiFunction(
                    name='result',
                    description='This is the tool for the final Result',
                    parameters={
                        'type': 'object',
                        'properties': {'spam': {'type': 'number'}},
                        'required': ['spam'],
                    },
                ),
            ]
        )
    )
    # Text results are allowed here, so no function-calling restriction is expected.
    assert calling_config is None
async def test_require_response_tool(allow_model_requests: None):
    """When text results are disallowed, the tool config restricts calls to the result tool (mode `ANY`)."""
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
    final_result_tool = ToolDefinition(
        'result',
        'This is the tool for the final Result',
        {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}},
    )
    request_params = ModelRequestParameters(
        function_tools=[], allow_text_result=False, result_tools=[final_result_tool]
    )

    declared_tools = model._get_tools(request_params)
    calling_config = model._get_tool_config(request_params, declared_tools)

    assert declared_tools == snapshot(
        _GeminiTools(
            function_declarations=[
                _GeminiFunction(
                    name='result',
                    description='This is the tool for the final Result',
                    parameters={
                        'type': 'object',
                        'properties': {'spam': {'type': 'number'}},
                    },
                ),
            ]
        )
    )
    assert calling_config == snapshot(
        _GeminiToolConfig(
            function_calling_config=_GeminiFunctionCallingConfig(mode='ANY', allowed_function_names=['result'])
        )
    )
async def test_json_def_replaced(allow_model_requests: None):
    """A `$ref` into `$defs` is inlined in the declared parameters; `title` and `default` do not appear."""

    class Location(BaseModel):
        lat: float
        lng: float = 1.1

    class Locations(BaseModel):
        locations: list[Location]

    # Pin the schema pydantic generates so the transformation below has a known input.
    pydantic_schema = Locations.model_json_schema()
    assert pydantic_schema == snapshot(
        {
            '$defs': {
                'Location': {
                    'properties': {
                        'lat': {'title': 'Lat', 'type': 'number'},
                        'lng': {'default': 1.1, 'title': 'Lng', 'type': 'number'},
                    },
                    'required': ['lat'],
                    'title': 'Location',
                    'type': 'object',
                }
            },
            'properties': {'locations': {'items': {'$ref': '#/$defs/Location'}, 'title': 'Locations', 'type': 'array'}},
            'required': ['locations'],
            'title': 'Locations',
            'type': 'object',
        }
    )

    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
    final_result_tool = ToolDefinition(
        'result',
        'This is the tool for the final Result',
        pydantic_schema,
    )
    request_params = ModelRequestParameters(
        function_tools=[], allow_text_result=True, result_tools=[final_result_tool]
    )
    assert model._get_tools(request_params) == snapshot(
        _GeminiTools(
            function_declarations=[
                _GeminiFunction(
                    name='result',
                    description='This is the tool for the final Result',
                    parameters={
                        'properties': {
                            'locations': {
                                'items': {
                                    'properties': {
                                        'lat': {'type': 'number'},
                                        'lng': {'type': 'number'},
                                    },
                                    'required': ['lat'],
                                    'type': 'object',
                                },
                                'type': 'array',
                            }
                        },
                        'required': ['locations'],
                        'type': 'object',
                    },
                )
            ]
        )
    )
async def test_json_def_replaced_any_of(allow_model_requests: None):
    """An optional nested model is declared as an inlined object flagged `'nullable': True`."""

    class Location(BaseModel):
        lat: float
        lng: float

    class Locations(BaseModel):
        op_location: Location | None = None

    pydantic_schema = Locations.model_json_schema()

    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
    final_result_tool = ToolDefinition(
        'result',
        'This is the tool for the final Result',
        pydantic_schema,
    )
    request_params = ModelRequestParameters(
        function_tools=[], allow_text_result=True, result_tools=[final_result_tool]
    )
    assert model._get_tools(request_params) == snapshot(
        _GeminiTools(
            function_declarations=[
                _GeminiFunction(
                    name='result',
                    description='This is the tool for the final Result',
                    parameters={
                        'properties': {
                            'op_location': {
                                'properties': {
                                    'lat': {'type': 'number'},
                                    'lng': {'type': 'number'},
                                },
                                'required': ['lat', 'lng'],
                                'nullable': True,
                                'type': 'object',
                            }
                        },
                        'type': 'object',
                    },
                )
            ]
        )
    )
async def test_json_def_recursive(allow_model_requests: None):
    """A self-referencing schema is rejected with a `UserError` when building the tool declarations."""

    class Location(BaseModel):
        lat: float
        lng: float
        nested_locations: list[Location]

    # Pin the recursive schema pydantic emits (the root itself is a `$ref`).
    pydantic_schema = Location.model_json_schema()
    assert pydantic_schema == snapshot(
        {
            '$defs': {
                'Location': {
                    'properties': {
                        'lat': {'title': 'Lat', 'type': 'number'},
                        'lng': {'title': 'Lng', 'type': 'number'},
                        'nested_locations': {
                            'items': {'$ref': '#/$defs/Location'},
                            'title': 'Nested Locations',
                            'type': 'array',
                        },
                    },
                    'required': ['lat', 'lng', 'nested_locations'],
                    'title': 'Location',
                    'type': 'object',
                }
            },
            '$ref': '#/$defs/Location',
        }
    )

    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
    final_result_tool = ToolDefinition(
        'result',
        'This is the tool for the final Result',
        pydantic_schema,
    )
    request_params = ModelRequestParameters(
        function_tools=[], allow_text_result=True, result_tools=[final_result_tool]
    )
    with pytest.raises(UserError, match=r'Recursive `\$ref`s in JSON Schema are not supported by Gemini'):
        model._get_tools(request_params)
async def test_json_def_date(allow_model_requests: None):
    """String `format` hints end up in the field `description` of the declared parameters."""

    class FormattedStringFields(BaseModel):
        d: datetime.date
        dt: datetime.datetime
        t: datetime.time = Field(description='')
        td: datetime.timedelta = Field(description='my timedelta')

    # Pin the pydantic schema: every field is a string with a `format` key.
    pydantic_schema = FormattedStringFields.model_json_schema()
    assert pydantic_schema == snapshot(
        {
            'properties': {
                'd': {'format': 'date', 'title': 'D', 'type': 'string'},
                'dt': {'format': 'date-time', 'title': 'Dt', 'type': 'string'},
                't': {'format': 'time', 'title': 'T', 'type': 'string', 'description': ''},
                'td': {'format': 'duration', 'title': 'Td', 'type': 'string', 'description': 'my timedelta'},
            },
            'required': ['d', 'dt', 't', 'td'],
            'title': 'FormattedStringFields',
            'type': 'object',
        }
    )

    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
    final_result_tool = ToolDefinition(
        'result',
        'This is the tool for the final Result',
        pydantic_schema,
    )
    request_params = ModelRequestParameters(
        function_tools=[], allow_text_result=True, result_tools=[final_result_tool]
    )
    assert model._get_tools(request_params) == snapshot(
        _GeminiTools(
            function_declarations=[
                _GeminiFunction(
                    description='This is the tool for the final Result',
                    name='result',
                    parameters={
                        'properties': {
                            'd': {'description': 'Format: date', 'type': 'string'},
                            'dt': {'description': 'Format: date-time', 'type': 'string'},
                            't': {'description': 'Format: time', 'type': 'string'},
                            'td': {'description': 'my timedelta (format: duration)', 'type': 'string'},
                        },
                        'required': ['d', 'dt', 't', 'td'],
                        'type': 'object',
                    },
                )
            ]
        )
    )
@dataclass
class AsyncByteStreamList(httpx.AsyncByteStream):
    """An `httpx.AsyncByteStream` that replays a fixed list of byte chunks in order."""

    # The chunks to yield, one per iteration step.
    data: list[bytes]

    async def __aiter__(self) -> AsyncIterator[bytes]:
        """Yield each stored chunk in list order."""
        for piece in self.data:
            yield piece
# What a mocked client may be told to return: one response (a Gemini response dict or a raw byte
# stream), or a sequence of them.
ResOrList: TypeAlias = '_GeminiResponse | httpx.AsyncByteStream | Sequence[_GeminiResponse | httpx.AsyncByteStream]'
# Signature of the factory that turns such canned response(s) into an httpx client.
GetGeminiClient: TypeAlias = 'Callable[[ResOrList], httpx.AsyncClient]'
@pytest.fixture
async def get_gemini_client(
    client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None
) -> GetGeminiClient:
    """Fixture returning a factory that builds an httpx client answering with canned Gemini responses.

    The factory accepts either a single response or a sequence of responses. With a sequence,
    successive requests receive successive items. Each item may be a Gemini response (sent as a
    JSON body) or an `httpx.AsyncByteStream` (sent as a streamed body).
    """
    env.set('GEMINI_API_KEY', 'via-env-var')

    def create_client(response_or_list: ResOrList) -> httpx.AsyncClient:
        requests_served = 0
        json_headers = {'Content-Type': 'application/json'}

        def handler(request: httpx.Request) -> httpx.Response:
            nonlocal requests_served

            # Every outgoing request must identify itself as pydantic-ai.
            user_agent = request.headers.get('User-Agent')
            assert isinstance(user_agent, str) and user_agent.startswith('pydantic-ai')

            if not isinstance(response_or_list, Sequence):
                canned = response_or_list
            else:
                canned = response_or_list[requests_served]
                requests_served += 1

            if isinstance(canned, httpx.AsyncByteStream):
                # Streamed body: pass the byte stream straight through.
                return httpx.Response(200, content=None, stream=canned, headers=json_headers)

            body = _gemini_response_ta.dump_json(canned, by_alias=True)
            return httpx.Response(200, content=body, stream=None, headers=json_headers)

        return client_with_handler(handler)

    return create_client
def gemini_response(content: _GeminiContent, finish_reason: Literal['STOP'] | None = 'STOP') -> _GeminiResponse:
    """Build a single-candidate Gemini response wrapping `content`.

    Args:
        content: The content placed on the only candidate.
        finish_reason: Set on the candidate when truthy; left off otherwise.

    Returns:
        A response with one candidate, the usage from `example_usage()` and model version
        `'gemini-1.5-flash-123'`.
    """
    candidate = _GeminiCandidates(content=content, index=0, safety_ratings=[])
    # Callers in this module always use the default, so only the "never false" branch needs to be
    # excused; `no branch` keeps the executed lines measured, unlike `no cover`.
    if finish_reason:  # pragma: no branch
        candidate['finish_reason'] = finish_reason
    return _GeminiResponse(candidates=[candidate], usage_metadata=example_usage(), model_version='gemini-1.5-flash-123')
def example_usage() -> _GeminiUsageMetaData:
    """Return fixed usage metadata: 1 prompt token, 2 candidate tokens, 3 total."""
    usage = _GeminiUsageMetaData(prompt_token_count=1, candidates_token_count=2, total_token_count=3)
    return usage
async def test_text_success(get_gemini_client: GetGeminiClient):
    """A plain text reply is returned as data, and a follow-up run extends the message history."""
    canned = gemini_response(_content_model_response(ModelResponse(parts=[TextPart('Hello world')])))
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=get_gemini_client(canned)))
    agent = Agent(model)

    first_run = await agent.run('Hello')
    assert first_run.data == 'Hello world'
    assert first_run.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='Hello world')],
                model_name='gemini-1.5-flash-123',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )
    assert first_run.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3))

    # Second run: feed the first exchange back in as history.
    second_run = await agent.run('Hello', message_history=first_run.new_messages())
    assert second_run.data == 'Hello world'
    assert second_run.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='Hello world')],
                model_name='gemini-1.5-flash-123',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='Hello world')],
                model_name='gemini-1.5-flash-123',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )
async def test_request_structured_response(get_gemini_client: GetGeminiClient):
    """A `final_result` tool call is validated into the agent's `list[int]` result."""
    tool_call = ToolCallPart('final_result', {'response': [1, 2, 123]})
    canned = gemini_response(_content_model_response(ModelResponse(parts=[tool_call])))
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=get_gemini_client(canned)))
    agent = Agent(model, result_type=list[int])

    run = await agent.run('Hello')
    assert run.data == [1, 2, 123]
    assert run.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[ToolCallPart(tool_name='final_result', args={'response': [1, 2, 123]}, tool_call_id=IsStr())],
                model_name='gemini-1.5-flash-123',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        timestamp=IsNow(tz=timezone.utc),
                        tool_call_id=IsStr(),
                    )
                ]
            ),
        ]
    )
async def test_request_tool_call(get_gemini_client: GetGeminiClient):
    """Three-step exchange: a rejected tool call (retry), two accepted calls, then a text answer."""
    retry_step = gemini_response(
        _content_model_response(ModelResponse(parts=[ToolCallPart('get_location', {'loc_name': 'San Fransisco'})]))
    )
    lookup_step = gemini_response(
        _content_model_response(
            ModelResponse(
                parts=[
                    ToolCallPart('get_location', {'loc_name': 'London'}),
                    ToolCallPart('get_location', {'loc_name': 'New York'}),
                ]
            )
        )
    )
    final_step = gemini_response(_content_model_response(ModelResponse(parts=[TextPart('final response')])))

    http_client = get_gemini_client([retry_step, lookup_step, final_step])
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=http_client))
    agent = Agent(model, system_prompt='this is the system prompt')

    @agent.tool_plain
    async def get_location(loc_name: str) -> str:
        # Only two places are known; anything else asks the model to retry.
        known_places = {'London': {'lat': 51, 'lng': 0}, 'New York': {'lat': 41, 'lng': -74}}
        if loc_name not in known_places:
            raise ModelRetry('Wrong location, please try again')
        return json.dumps(known_places[loc_name])

    run = await agent.run('Hello')
    assert run.data == 'final response'
    assert run.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='this is the system prompt', timestamp=IsNow(tz=timezone.utc)),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                ]
            ),
            ModelResponse(
                parts=[
                    ToolCallPart(tool_name='get_location', args={'loc_name': 'San Fransisco'}, tool_call_id=IsStr())
                ],
                model_name='gemini-1.5-flash-123',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        content='Wrong location, please try again',
                        tool_name='get_location',
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[
                    ToolCallPart(tool_name='get_location', args={'loc_name': 'London'}, tool_call_id=IsStr()),
                    ToolCallPart(tool_name='get_location', args={'loc_name': 'New York'}, tool_call_id=IsStr()),
                ],
                model_name='gemini-1.5-flash-123',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='get_location',
                        content='{"lat": 51, "lng": 0}',
                        timestamp=IsNow(tz=timezone.utc),
                        tool_call_id=IsStr(),
                    ),
                    ToolReturnPart(
                        tool_name='get_location',
                        content='{"lat": 41, "lng": -74}',
                        timestamp=IsNow(tz=timezone.utc),
                        tool_call_id=IsStr(),
                    ),
                ]
            ),
            ModelResponse(
                parts=[TextPart(content='final response')],
                model_name='gemini-1.5-flash-123',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )
    # Usage accumulates over the three requests (each canned response reports 1/2/3 tokens).
    assert run.usage() == snapshot(Usage(requests=3, request_tokens=3, response_tokens=6, total_tokens=9))
async def test_unexpected_response(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None):
    """A 401 from the API surfaces as `ModelHTTPError` carrying status, model name and body."""
    env.set('GEMINI_API_KEY', 'via-env-var')

    def reject_everything(_: httpx.Request):
        return httpx.Response(401, content='invalid request')

    http_client = client_with_handler(reject_everything)
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=http_client))
    agent = Agent(model, system_prompt='this is the system prompt')

    with pytest.raises(ModelHTTPError) as exc_info:
        await agent.run('Hello')

    message = str(exc_info.value)
    assert message == snapshot('status_code: 401, model_name: gemini-1.5-flash, body: invalid request')
async def test_stream_text(get_gemini_client: GetGeminiClient):
    """Two streamed text responses are exposed both as cumulative text and as deltas."""
    streamed_responses = [
        gemini_response(_content_model_response(ModelResponse(parts=[TextPart('Hello ')]))),
        gemini_response(_content_model_response(ModelResponse(parts=[TextPart('world')]))),
    ]
    payload = _gemini_streamed_response_ta.dump_json(streamed_responses, by_alias=True)
    # Deliver the JSON payload in three slices, cut at byte offsets 100 and 200.
    byte_stream = AsyncByteStreamList([payload[:100], payload[100:200], payload[200:]])
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=get_gemini_client(byte_stream)))
    agent = Agent(model)

    async with agent.run_stream('Hello') as result:
        cumulative = [chunk async for chunk in result.stream(debounce_by=None)]
        assert cumulative == snapshot(
            [
                'Hello ',
                'Hello world',
                # This last value is repeated due to the debounce_by=None combined with the need to emit
                # a final empty chunk to signal the end of the stream
                'Hello world',
            ]
        )
    assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))

    async with agent.run_stream('Hello') as result:
        deltas = [chunk async for chunk in result.stream_text(delta=True, debounce_by=None)]
        assert deltas == snapshot(['Hello ', 'world'])
    assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))
async def test_stream_invalid_unicode_text(get_gemini_client: GetGeminiClient):
    """Streaming still works when a chunk boundary falls inside a multi-byte UTF-8 character."""
    # Probably safe to remove this test once https://github.com/pydantic/pydantic-core/issues/1633 is resolved
    streamed_responses = [
        gemini_response(_content_model_response(ModelResponse(parts=[TextPart('abc')]))),
        gemini_response(_content_model_response(ModelResponse(parts=[TextPart('€def')]))),
    ]
    payload = _gemini_streamed_response_ta.dump_json(streamed_responses, by_alias=True)

    # Search for the first split offset whose leading half does not decode on its own.
    for split_at in range(10, 1000):
        halves = [payload[:split_at], payload[split_at:]]
        try:
            halves[0].decode()
        except UnicodeDecodeError:
            break
    else:  # pragma: no cover
        assert False, 'failed to find a spot in payload that would break unicode parsing'

    with pytest.raises(UnicodeDecodeError):
        # Ensure the first part is _not_ valid unicode
        halves[0].decode()

    byte_stream = AsyncByteStreamList(halves)
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=get_gemini_client(byte_stream)))
    agent = Agent(model)

    async with agent.run_stream('Hello') as result:
        cumulative = [chunk async for chunk in result.stream(debounce_by=None)]
        assert cumulative == snapshot(['abc', 'abc€def', 'abc€def'])
    assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))
async def test_stream_text_no_data(get_gemini_client: GetGeminiClient):
    """A stream whose only response has no candidates raises `UnexpectedModelBehavior`."""
    empty_response = _GeminiResponse(candidates=[], usage_metadata=example_usage())
    payload = _gemini_streamed_response_ta.dump_json([empty_response], by_alias=True)
    byte_stream = AsyncByteStreamList([payload[:100], payload[100:200], payload[200:]])
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=get_gemini_client(byte_stream)))
    agent = Agent(model)

    with pytest.raises(UnexpectedModelBehavior, match='Streamed response ended without con'):
        async with agent.run_stream('Hello'):
            pass
async def test_stream_structured(get_gemini_client: GetGeminiClient):
    """A streamed `final_result` tool call is validated into the `tuple[int, int]` result."""
    final_call = ToolCallPart('final_result', {'response': [1, 2]})
    streamed_responses = [
        gemini_response(_content_model_response(ModelResponse(parts=[final_call]))),
    ]
    payload = _gemini_streamed_response_ta.dump_json(streamed_responses, by_alias=True)
    byte_stream = AsyncByteStreamList([payload[:100], payload[100:200], payload[200:]])
    gemini = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=get_gemini_client(byte_stream)))
    agent = Agent(gemini, result_type=tuple[int, int])

    async with agent.run_stream('Hello') as result:
        emitted = [chunk async for chunk in result.stream(debounce_by=None)]
        assert emitted == snapshot([(1, 2), (1, 2)])
    assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3))
async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient):
    """Streamed tool calls run in order, then a second streamed request delivers the final result."""
    # First request: two tool calls spread over two streamed responses.
    tool_call_responses = [
        gemini_response(
            _content_model_response(ModelResponse(parts=[ToolCallPart('foo', {'x': 'a'})])),
        ),
        gemini_response(
            _content_model_response(ModelResponse(parts=[ToolCallPart('bar', {'y': 'b'})])),
        ),
    ]
    first_payload = _gemini_streamed_response_ta.dump_json(tool_call_responses, by_alias=True)
    first_stream = AsyncByteStreamList(
        [first_payload[:100], first_payload[100:200], first_payload[200:300], first_payload[300:]]
    )

    # Second request: the final result tool call.
    final_responses = [
        gemini_response(
            _content_model_response(ModelResponse(parts=[ToolCallPart('final_result', {'response': [1, 2]})])),
        ),
    ]
    second_payload = _gemini_streamed_response_ta.dump_json(final_responses, by_alias=True)
    second_stream = AsyncByteStreamList([second_payload[:100], second_payload[100:]])

    http_client = get_gemini_client([first_stream, second_stream])
    gemini = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=http_client))
    agent = Agent(gemini, result_type=tuple[int, int])
    tool_calls: list[str] = []

    @agent.tool_plain
    async def foo(x: str) -> str:
        tool_calls.append(f'foo({x=!r})')
        return x

    @agent.tool_plain
    async def bar(y: str) -> str:
        tool_calls.append(f'bar({y=!r})')
        return y

    async with agent.run_stream('Hello') as result:
        final_data = await result.get_data()
        assert final_data == snapshot((1, 2))
    assert result.usage() == snapshot(Usage(requests=2, request_tokens=3, response_tokens=6, total_tokens=9))
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[
                    ToolCallPart(tool_name='foo', args={'x': 'a'}, tool_call_id=IsStr()),
                    ToolCallPart(tool_name='bar', args={'y': 'b'}, tool_call_id=IsStr()),
                ],
                model_name='gemini-1.5-flash',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='foo', content='a', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr()
                    ),
                    ToolReturnPart(
                        tool_name='bar', content='b', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr()
                    ),
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='final_result', args={'response': [1, 2]}, tool_call_id=IsStr())],
                model_name='gemini-1.5-flash',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        timestamp=IsNow(tz=timezone.utc),
                        tool_call_id=IsStr(),
                    )
                ]
            ),
        ]
    )
    assert tool_calls == snapshot(["foo(x='a')", "bar(y='b')"])
async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient):
    """A streamed response mixing text and a function call still yields the concatenated text."""
    text_only = gemini_response(_content_model_response(ModelResponse(parts=[TextPart('Hello ')])))
    text_and_call = gemini_response(
        _GeminiContent(
            role='model',
            parts=[
                {'text': 'foo'},
                {'function_call': {'name': 'get_location', 'args': {'loc_name': 'San Fransisco'}}},
            ],
        )
    )
    payload = _gemini_streamed_response_ta.dump_json([text_only, text_and_call], by_alias=True)
    byte_stream = AsyncByteStreamList([payload[:100], payload[100:200], payload[200:]])
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=get_gemini_client(byte_stream)))
    agent = Agent(model)

    @agent.tool_plain()
    def get_location(loc_name: str) -> str:
        return f'Location for {loc_name}'

    async with agent.run_stream('Hello') as result:
        streamed_text = await result.get_data()

    assert streamed_text == 'Hello foo'
async def test_empty_text_ignored():
    """`_content_model_response` keeps non-empty text parts and drops empty ones."""
    final_call = ToolCallPart('final_result', {'response': [1, 2, 123]})

    with_text = _content_model_response(ModelResponse(parts=[final_call, TextPart(content='xxx')]))
    # text included
    assert with_text == snapshot(
        {
            'role': 'model',
            'parts': [
                {'function_call': {'name': 'final_result', 'args': {'response': [1, 2, 123]}}},
                {'text': 'xxx'},
            ],
        }
    )

    final_call_again = ToolCallPart('final_result', {'response': [1, 2, 123]})
    with_empty_text = _content_model_response(ModelResponse(parts=[final_call_again, TextPart(content='')]))
    # text skipped
    assert with_empty_text == snapshot(
        {
            'role': 'model',
            'parts': [{'function_call': {'name': 'final_result', 'args': {'response': [1, 2, 123]}}}],
        }
    )
async def test_model_settings(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None) -> None:
    """Generic model settings are sent in the request's `generation_config`."""

    def handler(request: httpx.Request) -> httpx.Response:
        sent_config = json.loads(request.content)['generation_config']
        # `max_tokens` is expected on the wire as `max_output_tokens`; the rest keep their names.
        assert sent_config == {
            'max_output_tokens': 1,
            'temperature': 0.1,
            'top_p': 0.2,
            'presence_penalty': 0.3,
            'frequency_penalty': 0.4,
        }
        reply = gemini_response(_content_model_response(ModelResponse(parts=[TextPart('world')])))
        return httpx.Response(
            200,
            content=_gemini_response_ta.dump_json(reply, by_alias=True),
            headers={'Content-Type': 'application/json'},
        )

    http_client = client_with_handler(handler)
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=http_client, api_key='mock'))
    agent = Agent(model)

    run = await agent.run(
        'hello',
        model_settings={
            'max_tokens': 1,
            'temperature': 0.1,
            'top_p': 0.2,
            'presence_penalty': 0.3,
            'frequency_penalty': 0.4,
        },
    )
    assert run.data == 'world'
def gemini_no_content_response(
    safety_ratings: list[_GeminiSafetyRating], finish_reason: Literal['SAFETY'] | None = 'SAFETY'
) -> _GeminiResponse:
    """Build a Gemini response whose single candidate has safety ratings but no content.

    Args:
        safety_ratings: The safety ratings placed on the candidate.
        finish_reason: Set on the candidate when truthy; left off otherwise.

    Returns:
        A response with one content-less candidate and the usage from `example_usage()`.
    """
    candidate = _GeminiCandidates(safety_ratings=safety_ratings)
    # The coverage report flags this condition as never false (callers here always pass 'SAFETY'),
    # so mark the partial branch as intentional instead of leaving it reported as a gap.
    if finish_reason:  # pragma: no branch
        candidate['finish_reason'] = finish_reason
    return _GeminiResponse(candidates=[candidate], usage_metadata=example_usage())
async def test_safety_settings_unsafe(
    client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None
) -> None:
    """A candidate finished for `SAFETY` reasons makes the run raise `UnexpectedModelBehavior`.

    The handler also checks that the Gemini safety settings given in the model settings are sent
    in the request body under `safety_settings`.
    """

    def handler(request: httpx.Request) -> httpx.Response:
        safety_settings = json.loads(request.content)['safety_settings']
        assert safety_settings == [
            {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
            {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
        ]

        return httpx.Response(
            200,
            content=_gemini_response_ta.dump_json(
                gemini_no_content_response(
                    finish_reason='SAFETY',
                    safety_ratings=[
                        {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'MEDIUM', 'blocked': True}
                    ],
                ),
                by_alias=True,
            ),
            headers={'Content-Type': 'application/json'},
        )

    gemini_client = client_with_handler(handler)

    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client, api_key='mock'))
    agent = Agent(m)

    # `pytest.raises` fails the test when nothing is raised; the previous try/except form passed
    # silently in that case, so the assertion on the error could be skipped entirely.
    with pytest.raises(UnexpectedModelBehavior) as exc_info:
        await agent.run(
            'a request for something rude',
            model_settings=GeminiModelSettings(
                gemini_safety_settings=[
                    {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
                    {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
                ]
            ),
        )

    assert repr(exc_info.value) == "UnexpectedModelBehavior('Safety settings triggered')"
async def test_safety_settings_safe(
    client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None
) -> None:
    """Gemini safety settings are sent in the request, and a normal text reply is returned as data."""

    def handler(request: httpx.Request) -> httpx.Response:
        sent_settings = json.loads(request.content)['safety_settings']
        assert sent_settings == [
            {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
            {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
        ]

        reply = gemini_response(_content_model_response(ModelResponse(parts=[TextPart('world')])))
        return httpx.Response(
            200,
            content=_gemini_response_ta.dump_json(reply, by_alias=True),
            headers={'Content-Type': 'application/json'},
        )

    http_client = client_with_handler(handler)
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=http_client, api_key='mock'))
    agent = Agent(model)

    settings = GeminiModelSettings(
        gemini_safety_settings=[
            {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
            {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
        ]
    )
    run = await agent.run('hello', model_settings=settings)
    assert run.data == 'world'
@pytest.mark.vcr()
async def test_image_as_binary_content_input(
    allow_model_requests: None, gemini_api_key: str, image_content: BinaryContent
) -> None:
    """An image passed as `BinaryContent` alongside a text question gets the recorded answer."""
    provider = GoogleGLAProvider(api_key=gemini_api_key)
    agent = Agent(GeminiModel('gemini-2.0-flash', provider=provider))

    prompt = ['What is the name of this fruit?', image_content]
    run = await agent.run(prompt)
    assert run.data == snapshot('The fruit in the image is a kiwi.')
@pytest.mark.vcr()
async def test_image_url_input(allow_model_requests: None, gemini_api_key: str) -> None:
    """An image passed as `ImageUrl` alongside a text question gets the recorded answer."""
    provider = GoogleGLAProvider(api_key=gemini_api_key)
    agent = Agent(GeminiModel('gemini-2.0-flash-exp', provider=provider))

    picture = ImageUrl(url='https://goo.gle/instrument-img')

    run = await agent.run(['What is the name of this fruit?', picture])
    assert run.data == snapshot("This is not a fruit; it's a pipe organ console.")
@pytest.mark.vcr()
async def test_document_url_input(allow_model_requests: None, gemini_api_key: str) -> None:
    """A PDF passed as `DocumentUrl` alongside a text question gets the recorded answer."""
    provider = GoogleGLAProvider(api_key=gemini_api_key)
    agent = Agent(GeminiModel('gemini-2.0-flash-thinking-exp-01-21', provider=provider))

    pdf = DocumentUrl(url='https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf')

    run = await agent.run(['What is the main content on this document?', pdf])
    assert run.data == snapshot('The main content of this document is that it is a **dummy PDF file**.')