Coverage for pydantic_ai_slim/pydantic_ai/models/gemini.py: 92.75%
423 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1from __future__ import annotations as _annotations
3import base64
4import re
5from collections.abc import AsyncIterator, Sequence
6from contextlib import asynccontextmanager
7from copy import deepcopy
8from dataclasses import dataclass, field
9from datetime import datetime
10from typing import Annotated, Any, Literal, Protocol, Union, cast
11from uuid import uuid4
13import httpx
14import pydantic
15from httpx import USE_CLIENT_DEFAULT, Response as HTTPResponse
16from typing_extensions import NotRequired, TypedDict, assert_never
18from pydantic_ai.providers import Provider, infer_provider
20from .. import ModelHTTPError, UnexpectedModelBehavior, UserError, _utils, usage
21from ..messages import (
22 AudioUrl,
23 BinaryContent,
24 DocumentUrl,
25 ImageUrl,
26 ModelMessage,
27 ModelRequest,
28 ModelResponse,
29 ModelResponsePart,
30 ModelResponseStreamEvent,
31 RetryPromptPart,
32 SystemPromptPart,
33 TextPart,
34 ToolCallPart,
35 ToolReturnPart,
36 UserPromptPart,
37)
38from ..settings import ModelSettings
39from ..tools import ToolDefinition
40from . import (
41 Model,
42 ModelRequestParameters,
43 StreamedResponse,
44 cached_async_http_client,
45 check_allow_model_requests,
46 get_user_agent,
47)
LatestGeminiModelNames = Literal[
    'gemini-1.5-flash',
    'gemini-1.5-flash-8b',
    'gemini-1.5-pro',
    'gemini-1.0-pro',
    'gemini-2.0-flash-exp',
    'gemini-2.0-flash-thinking-exp-01-21',
    'gemini-exp-1206',
    'gemini-2.0-flash',
    'gemini-2.0-flash-lite-preview-02-05',
    'gemini-2.0-pro-exp-02-05',
]
"""Latest Gemini models."""

# `str` comes first in the union so arbitrary (e.g. date-stamped) model names type-check too.
GeminiModelName = Union[str, LatestGeminiModelNames]
"""Possible Gemini model names.

Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
allow any name in the type hints.
See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
"""
class GeminiModelSettings(ModelSettings):
    """Settings used for a Gemini model request.

    ALL FIELDS MUST BE `gemini_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
    """

    # Forwarded as the request's `safety_settings` field; see `GeminiSafetySettings` below.
    gemini_safety_settings: list[GeminiSafetySettings]
@dataclass(init=False)
class GeminiModel(Model):
    """A model that uses Gemini via `generativelanguage.googleapis.com` API.

    This is implemented from scratch rather than using a dedicated SDK, good API documentation is
    available [here](https://ai.google.dev/api).

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    client: httpx.AsyncClient = field(repr=False)

    _model_name: GeminiModelName = field(repr=False)
    _provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] | None = field(repr=False)
    _auth: AuthProtocol | None = field(repr=False)
    _url: str | None = field(repr=False)
    _system: str = field(default='gemini', repr=False)

    def __init__(
        self,
        model_name: GeminiModelName,
        *,
        provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] = 'google-gla',
    ):
        """Initialize a Gemini model.

        Args:
            model_name: The name of the model to use.
            provider: The provider to use for authentication and API access. Can be either the string
                'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`.
                If not provided, a new provider will be created using the other parameters.
        """
        self._model_name = model_name
        self._provider = provider

        if isinstance(provider, str):
            provider = infer_provider(provider)
        self._system = provider.name
        self.client = provider.client
        self._url = str(self.client.base_url)

    @property
    def base_url(self) -> str:
        """The base URL of the underlying HTTP client."""
        assert self._url is not None, 'URL not initialized'
        return self._url

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, usage.Usage]:
        """Make a non-streaming request, returning the model response and its token usage."""
        check_allow_model_requests()
        async with self._make_request(
            messages, False, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
        ) as http_response:
            response = _gemini_response_ta.validate_json(await http_response.aread())
            return self._process_response(response), _metadata_as_usage(response)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming request, yielding a `StreamedResponse` while the HTTP stream is open."""
        check_allow_model_requests()
        async with self._make_request(
            messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
        ) as http_response:
            yield await self._process_streamed_response(http_response)

    @property
    def model_name(self) -> GeminiModelName:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The system / model provider."""
        return self._system

    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
        """Build the Gemini tools declaration from function and result tools, or `None` if there are none."""
        tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools]
        if model_request_parameters.result_tools:
            tools += [_function_from_abstract_tool(t) for t in model_request_parameters.result_tools]
        return _GeminiTools(function_declarations=tools) if tools else None

    def _get_tool_config(
        self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None
    ) -> _GeminiToolConfig | None:
        """Return a tool config forcing a tool call when plain-text results are not allowed, else `None`."""
        if model_request_parameters.allow_text_result:
            return None
        elif tools:
            return _tool_config([t['name'] for t in tools['function_declarations']])
        else:
            return _tool_config([])

    @asynccontextmanager
    async def _make_request(
        self,
        messages: list[ModelMessage],
        streamed: bool,
        model_settings: GeminiModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[HTTPResponse]:
        """Send a generate-content request (streamed or not) and yield the open HTTP response.

        Raises:
            ModelHTTPError: If the API responds with a 4xx/5xx status code.
            UnexpectedModelBehavior: If the API responds with any other non-200 status code.
        """
        tools = self._get_tools(model_request_parameters)
        tool_config = self._get_tool_config(model_request_parameters, tools)
        sys_prompt_parts, contents = await self._message_to_gemini_content(messages)

        request_data = _GeminiRequest(contents=contents)
        if sys_prompt_parts:
            request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
        if tools is not None:
            request_data['tools'] = tools
        if tool_config is not None:
            request_data['tool_config'] = tool_config

        generation_config: _GeminiGenerationConfig = {}
        if model_settings:
            if (max_tokens := model_settings.get('max_tokens')) is not None:
                generation_config['max_output_tokens'] = max_tokens
            if (temperature := model_settings.get('temperature')) is not None:
                generation_config['temperature'] = temperature
            if (top_p := model_settings.get('top_p')) is not None:
                generation_config['top_p'] = top_p
            if (presence_penalty := model_settings.get('presence_penalty')) is not None:
                generation_config['presence_penalty'] = presence_penalty
            if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
                generation_config['frequency_penalty'] = frequency_penalty
            # Fix: the previous `... != []` check was also true when the key was absent (`get` -> None),
            # which set `safety_settings` to None and serialized `"safety_settings": null` in the request.
            # A truthiness check only includes the field when safety settings were actually provided.
            if gemini_safety_settings := model_settings.get('gemini_safety_settings'):
                request_data['safety_settings'] = gemini_safety_settings
        if generation_config:
            request_data['generation_config'] = generation_config

        headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
        url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'

        request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)
        async with self.client.stream(
            'POST',
            url,
            content=request_json,
            headers=headers,
            timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
        ) as r:
            if (status_code := r.status_code) != 200:
                # the body must be read before `.text` is available on a streamed response
                await r.aread()
                if status_code >= 400:
                    raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=r.text)
                raise UnexpectedModelBehavior(f'Unexpected response from gemini {status_code}', r.text)
            yield r

    def _process_response(self, response: _GeminiResponse) -> ModelResponse:
        """Convert a validated non-streaming Gemini response into a `ModelResponse`.

        Raises:
            UnexpectedModelBehavior: If there isn't exactly one candidate, or the candidate has no content
                (including when generation was blocked by safety settings).
        """
        if len(response['candidates']) != 1:
            raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
        if 'content' not in response['candidates'][0]:
            if response['candidates'][0].get('finish_reason') == 'SAFETY':
                raise UnexpectedModelBehavior('Safety settings triggered', str(response))
            else:
                raise UnexpectedModelBehavior('Content field missing from Gemini response', str(response))
        parts = response['candidates'][0]['content']['parts']
        return _process_response_from_parts(parts, model_name=response.get('model_version', self._model_name))

    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        aiter_bytes = http_response.aiter_bytes()
        start_response: _GeminiResponse | None = None
        content = bytearray()

        # Buffer chunks until the first response with actual content parts arrives, so the returned
        # `GeminiStreamedResponse` is guaranteed to have something to yield.
        async for chunk in aiter_bytes:
            content.extend(chunk)
            responses = _gemini_streamed_response_ta.validate_json(
                _ensure_decodeable(content),
                experimental_allow_partial='trailing-strings',
            )
            if responses:
                last = responses[-1]
                if last['candidates'] and last['candidates'][0].get('content', {}).get('parts'):
                    start_response = last
                    break

        if start_response is None:
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        # `content` (bytes buffered so far) is handed over so no data is lost to the stream wrapper.
        return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)

    @classmethod
    async def _message_to_gemini_content(
        cls, messages: list[ModelMessage]
    ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
        """Split messages into system-prompt text parts and the Gemini `contents` list."""
        sys_prompt_parts: list[_GeminiTextPart] = []
        contents: list[_GeminiContent] = []
        for m in messages:
            if isinstance(m, ModelRequest):
                message_parts: list[_GeminiPartUnion] = []

                for part in m.parts:
                    if isinstance(part, SystemPromptPart):
                        # system prompts are collected separately and sent as `system_instruction`
                        sys_prompt_parts.append(_GeminiTextPart(text=part.content))
                    elif isinstance(part, UserPromptPart):
                        message_parts.extend(await cls._map_user_prompt(part))
                    elif isinstance(part, ToolReturnPart):
                        message_parts.append(_response_part_from_response(part.tool_name, part.model_response_object()))
                    elif isinstance(part, RetryPromptPart):
                        if part.tool_name is None:
                            message_parts.append(_GeminiTextPart(text=part.model_response()))
                        else:
                            # report tool-call errors back as a function response
                            response = {'call_error': part.model_response()}
                            message_parts.append(_response_part_from_response(part.tool_name, response))
                    else:
                        assert_never(part)

                if message_parts:
                    contents.append(_GeminiContent(role='user', parts=message_parts))
            elif isinstance(m, ModelResponse):
                contents.append(_content_model_response(m))
            else:
                assert_never(m)

        return sys_prompt_parts, contents

    @staticmethod
    async def _map_user_prompt(part: UserPromptPart) -> list[_GeminiPartUnion]:
        """Convert a user prompt (text and/or media items) into Gemini parts, downloading any URLs."""
        if isinstance(part.content, str):
            return [{'text': part.content}]
        else:
            content: list[_GeminiPartUnion] = []
            for item in part.content:
                if isinstance(item, str):
                    content.append({'text': item})
                elif isinstance(item, BinaryContent):
                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
                    content.append(
                        _GeminiInlineDataPart(inline_data={'data': base64_encoded, 'mime_type': item.media_type})
                    )
                elif isinstance(item, (AudioUrl, ImageUrl, DocumentUrl)):
                    # URL media is fetched here and sent inline, base64-encoded
                    client = cached_async_http_client()
                    response = await client.get(item.url, follow_redirects=True)
                    response.raise_for_status()
                    mime_type = response.headers['Content-Type'].split(';')[0]
                    inline_data = _GeminiInlineDataPart(
                        inline_data={'data': base64.b64encode(response.content).decode('utf-8'), 'mime_type': mime_type}
                    )
                    content.append(inline_data)
                else:
                    assert_never(item)
            return content
class AuthProtocol(Protocol):
    """Abstract definition for Gemini authentication."""

    async def headers(self) -> dict[str, str]: ...
@dataclass
class ApiKeyAuth:
    """Authentication using an API key for the `X-Goog-Api-Key` header."""

    api_key: str

    async def headers(self) -> dict[str, str]:
        """Return the request headers carrying the API key.

        See https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
        """
        auth_headers = {'X-Goog-Api-Key': self.api_key}
        return auth_headers
@dataclass
class GeminiStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for the Gemini model."""

    # the model used for the request
    _model_name: GeminiModelName
    # bytes received so far; pre-seeded with the chunks already read by `GeminiModel._process_streamed_response`
    _content: bytearray
    # the remaining (not yet consumed) byte stream from the HTTP response
    _stream: AsyncIterator[bytes]
    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Yield part-delta events built from each Gemini response produced by `_get_gemini_responses`."""
        async for gemini_response in self._get_gemini_responses():
            candidate = gemini_response['candidates'][0]
            if 'content' not in candidate:
                raise UnexpectedModelBehavior('Streamed response has no content field')
            gemini_part: _GeminiPartUnion
            for gemini_part in candidate['content']['parts']:
                if 'text' in gemini_part:
                    # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
                    # amongst the tool call deltas
                    yield self._parts_manager.handle_text_delta(vendor_part_id=None, content=gemini_part['text'])

                elif 'function_call' in gemini_part:
                    # Here, we assume all function_call parts are complete and don't have deltas.
                    # We do this by assigning a unique randomly generated "vendor_part_id".
                    # We need to confirm whether this is actually true, but if it isn't, we can still handle it properly
                    # it would just be a bit more complicated. And we'd need to confirm the intended semantics.
                    maybe_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=uuid4(),
                        tool_name=gemini_part['function_call']['name'],
                        args=gemini_part['function_call']['args'],
                        tool_call_id=None,
                    )
                    if maybe_event is not None:
                        yield maybe_event
                else:
                    # the only remaining supported part type; should not appear in model output
                    assert 'function_response' in gemini_part, f'Unexpected part: {gemini_part}'

    async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]:
        # This method exists to ensure we only yield completed items, so we don't need to worry about
        # partial gemini responses, which would make everything more complicated

        gemini_responses: list[_GeminiResponse] = []
        current_gemini_response_index = 0
        # Right now, there are some circumstances where we will have information that could be yielded sooner than it is
        # But changing that would make things a lot more complicated.
        async for chunk in self._stream:
            self._content.extend(chunk)

            # re-validate the whole accumulated buffer each time, allowing a partial trailing item
            gemini_responses = _gemini_streamed_response_ta.validate_json(
                _ensure_decodeable(self._content),
                experimental_allow_partial='trailing-strings',
            )

            # The idea: yield only up to the latest response, which might still be partial.
            # Note that if the latest response is complete, we could yield it immediately, but there's not a good
            # allow_partial API to determine if the last item in the list is complete.
            responses_to_yield = gemini_responses[:-1]
            for r in responses_to_yield[current_gemini_response_index:]:
                current_gemini_response_index += 1
                self._usage += _metadata_as_usage(r)
                yield r

        # Now yield the final response, which should be complete
        if gemini_responses:
            r = gemini_responses[-1]
            self._usage += _metadata_as_usage(r)
            yield r

    @property
    def model_name(self) -> GeminiModelName:
        """Get the model name of the response."""
        return self._model_name

    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._timestamp
427# We use typed dicts to define the Gemini API response schema
428# once Pydantic partial validation supports, dataclasses, we could revert to using them
429# TypeAdapters take care of validation and serialization
@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
class _GeminiRequest(TypedDict):
    """Schema for an API request to the Gemini API.

    See <https://ai.google.dev/api/generate-content#request-body> for API docs.
    """

    contents: list[_GeminiContent]
    tools: NotRequired[_GeminiTools]
    tool_config: NotRequired[_GeminiToolConfig]
    safety_settings: NotRequired[list[GeminiSafetySettings]]
    # only the subset of `GenerationConfig` in `_GeminiGenerationConfig` below is supported
    system_instruction: NotRequired[_GeminiTextContent]
    """
    Developer generated system instructions, see
    <https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest>
    """
    generation_config: NotRequired[_GeminiGenerationConfig]
class GeminiSafetySettings(TypedDict):
    """Safety settings options for Gemini model request.

    See [Gemini API docs](https://ai.google.dev/gemini-api/docs/safety-settings) for safety category and threshold descriptions.
    For an example on how to use `GeminiSafetySettings`, see [here](../../agents.md#model-specific-settings).
    """

    category: Literal[
        'HARM_CATEGORY_UNSPECIFIED',
        'HARM_CATEGORY_HARASSMENT',
        'HARM_CATEGORY_HATE_SPEECH',
        'HARM_CATEGORY_SEXUALLY_EXPLICIT',
        'HARM_CATEGORY_DANGEROUS_CONTENT',
        'HARM_CATEGORY_CIVIC_INTEGRITY',
    ]
    """
    Safety settings category.
    """

    threshold: Literal[
        'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
        'BLOCK_LOW_AND_ABOVE',
        'BLOCK_MEDIUM_AND_ABOVE',
        'BLOCK_ONLY_HIGH',
        'BLOCK_NONE',
        'OFF',
    ]
    """
    Safety settings threshold.
    """
class _GeminiGenerationConfig(TypedDict, total=False):
    """Schema for the `generationConfig` field of a Gemini API request.

    Note there are many additional fields available that have not been added yet.

    See <https://ai.google.dev/api/generate-content#generationconfig> for API docs.
    """

    max_output_tokens: int
    temperature: float
    top_p: float
    presence_penalty: float
    frequency_penalty: float
class _GeminiContent(TypedDict):
    """A single turn of content, see <https://ai.google.dev/api/caching#Content>."""

    role: Literal['user', 'model']
    parts: list[_GeminiPartUnion]
def _content_model_response(m: ModelResponse) -> _GeminiContent:
    """Convert a `ModelResponse` into a Gemini content object with role `'model'`."""
    converted: list[_GeminiPartUnion] = []
    for response_part in m.parts:
        if isinstance(response_part, TextPart):
            # empty text parts are dropped
            if response_part.content:
                converted.append(_GeminiTextPart(text=response_part.content))
        elif isinstance(response_part, ToolCallPart):
            converted.append(_function_call_part_from_call(response_part))
        else:
            assert_never(response_part)
    return _GeminiContent(role='model', parts=converted)
class _GeminiTextPart(TypedDict):
    """A plain-text part, see <https://ai.google.dev/api/caching#Part>."""

    text: str
class _GeminiInlineData(TypedDict):
    """Base64-encoded media payload, see <https://ai.google.dev/api/caching#Blob>."""

    data: str
    mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
class _GeminiInlineDataPart(TypedDict):
    """See <https://ai.google.dev/api/caching#Blob>."""

    inline_data: Annotated[_GeminiInlineData, pydantic.Field(alias='inlineData')]
class _GeminiFileData(TypedDict):
    """See <https://ai.google.dev/api/caching#FileData>."""

    file_uri: Annotated[str, pydantic.Field(alias='fileUri')]
    mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
class _GeminiFileDataPart(TypedDict):
    """A part referencing uploaded file data, see <https://ai.google.dev/api/caching#FileData>."""

    file_data: Annotated[_GeminiFileData, pydantic.Field(alias='fileData')]
class _GeminiFunctionCallPart(TypedDict):
    """A part carrying a model-initiated function call, see <https://ai.google.dev/api/caching#FunctionCall>."""

    function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]
def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:
    """Convert a `ToolCallPart` into the Gemini wire-format function-call part."""
    call = _GeminiFunctionCall(name=tool.tool_name, args=tool.args_as_dict())
    return _GeminiFunctionCallPart(function_call=call)
def _process_response_from_parts(
    parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, timestamp: datetime | None = None
) -> ModelResponse:
    """Build a `ModelResponse` from Gemini response parts.

    Only text and function-call parts are expected; a function-response part raises
    `UnexpectedModelBehavior`.
    """
    collected: list[ModelResponsePart] = []
    for part in parts:
        if 'text' in part:
            collected.append(TextPart(content=part['text']))
        elif 'function_call' in part:
            call = part['function_call']
            collected.append(ToolCallPart(tool_name=call['name'], args=call['args']))
        elif 'function_response' in part:
            raise UnexpectedModelBehavior(
                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
            )
    return ModelResponse(parts=collected, model_name=model_name, timestamp=timestamp or _utils.now_utc())
class _GeminiFunctionCall(TypedDict):
    """See <https://ai.google.dev/api/caching#FunctionCall>."""

    name: str
    args: dict[str, Any]
class _GeminiFunctionResponsePart(TypedDict):
    """A part carrying a tool-call result, see <https://ai.google.dev/api/caching#FunctionResponse>."""

    function_response: Annotated[_GeminiFunctionResponse, pydantic.Field(alias='functionResponse')]
def _response_part_from_response(name: str, response: dict[str, Any]) -> _GeminiFunctionResponsePart:
    """Wrap a tool result dict as a Gemini function-response part."""
    function_response = _GeminiFunctionResponse(name=name, response=response)
    return _GeminiFunctionResponsePart(function_response=function_response)
class _GeminiFunctionResponse(TypedDict):
    """See <https://ai.google.dev/api/caching#FunctionResponse>."""

    name: str
    response: dict[str, Any]
589def _part_discriminator(v: Any) -> str:
590 if isinstance(v, dict): 590 ↛ 601line 590 didn't jump to line 601 because the condition on line 590 was always true
591 if 'text' in v:
592 return 'text'
593 elif 'inlineData' in v: 593 ↛ 594line 593 didn't jump to line 594 because the condition on line 593 was never true
594 return 'inline_data'
595 elif 'fileData' in v: 595 ↛ 596line 595 didn't jump to line 596 because the condition on line 595 was never true
596 return 'file_data'
597 elif 'functionCall' in v or 'function_call' in v:
598 return 'function_call'
599 elif 'functionResponse' in v or 'function_response' in v:
600 return 'function_response'
601 return 'text'
# See <https://ai.google.dev/api/caching#Part>
# we don't currently support other part types
# TODO discriminator
# The callable discriminator (`_part_discriminator`) routes both camelCase wire-format keys and
# snake_case python keys to the right member.
_GeminiPartUnion = Annotated[
    Union[
        Annotated[_GeminiTextPart, pydantic.Tag('text')],
        Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
        Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
        Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')],
        Annotated[_GeminiFileDataPart, pydantic.Tag('file_data')],
    ],
    pydantic.Discriminator(_part_discriminator),
]
class _GeminiTextContent(TypedDict):
    """Content restricted to text parts, used for `system_instruction`."""

    role: Literal['user', 'model']
    parts: list[_GeminiTextPart]
class _GeminiTools(TypedDict):
    """The `tools` field of a request, see <https://ai.google.dev/api/caching#Tool>."""

    function_declarations: list[Annotated[_GeminiFunction, pydantic.Field(alias='functionDeclarations')]]
class _GeminiFunction(TypedDict):
    """A function declaration, see <https://ai.google.dev/api/caching#FunctionDeclaration>."""

    name: str
    description: str
    parameters: NotRequired[dict[str, Any]]
    """
    ObjectJsonSchema isn't really true since Gemini only accepts a subset of JSON Schema
    <https://ai.google.dev/gemini-api/docs/function-calling#function_declarations>
    and
    <https://ai.google.dev/api/caching#FunctionDeclaration>
    """
def _function_from_abstract_tool(tool: ToolDefinition) -> _GeminiFunction:
    """Convert a `ToolDefinition` into a Gemini function declaration."""
    schema = _GeminiJsonSchema(tool.parameters_json_schema).simplify()
    function = _GeminiFunction(
        name=tool.name,
        description=tool.description,
    )
    # only attach `parameters` when the simplified schema actually declares properties
    if schema.get('properties'):
        function['parameters'] = schema
    return function
class _GeminiToolConfig(TypedDict):
    """The `tool_config` field of a request, see <https://ai.google.dev/api/caching#ToolConfig>."""

    function_calling_config: _GeminiFunctionCallingConfig
def _tool_config(function_names: list[str]) -> _GeminiToolConfig:
    """Build a tool config forcing the model (mode `'ANY'`) to call one of `function_names`."""
    calling_config = _GeminiFunctionCallingConfig(mode='ANY', allowed_function_names=function_names)
    return _GeminiToolConfig(function_calling_config=calling_config)
class _GeminiFunctionCallingConfig(TypedDict):
    """See <https://ai.google.dev/api/caching#FunctionCallingConfig>."""

    mode: Literal['ANY', 'AUTO']
    allowed_function_names: list[str]
@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
class _GeminiResponse(TypedDict):
    """Schema for the response from the Gemini API.

    See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>
    and <https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerateContentResponse>
    """

    candidates: list[_GeminiCandidates]
    # usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
    usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
    prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
    model_version: NotRequired[Annotated[str, pydantic.Field(alias='modelVersion')]]
class _GeminiCandidates(TypedDict):
    """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""

    content: NotRequired[_GeminiContent]
    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS', 'SAFETY'], pydantic.Field(alias='finishReason')]]
    """
    See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
    but let's wait until we see them and know what they mean to add them here.
    """
    avg_log_probs: NotRequired[Annotated[float, pydantic.Field(alias='avgLogProbs')]]
    index: NotRequired[int]
    safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]]
class _GeminiUsageMetaData(TypedDict, total=False):
    """See <https://ai.google.dev/api/generate-content#UsageMetadata>.

    The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
    """

    prompt_token_count: Annotated[int, pydantic.Field(alias='promptTokenCount')]
    candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]]
    total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]
    cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
    """Extract token usage from a response's `usage_metadata`, defaulting to an empty `Usage`."""
    metadata = response.get('usage_metadata')
    if metadata is None:
        return usage.Usage()
    details: dict[str, int] = {}
    cached = metadata.get('cached_content_token_count')
    if cached:
        # only recorded when non-zero, matching the original behaviour
        details['cached_content_token_count'] = cached
    return usage.Usage(
        request_tokens=metadata.get('prompt_token_count', 0),
        response_tokens=metadata.get('candidates_token_count', 0),
        total_tokens=metadata.get('total_token_count', 0),
        details=details,
    )
class _GeminiSafetyRating(TypedDict):
    """See <https://ai.google.dev/gemini-api/docs/safety-settings#safety-filters>."""

    category: Literal[
        'HARM_CATEGORY_HARASSMENT',
        'HARM_CATEGORY_HATE_SPEECH',
        'HARM_CATEGORY_SEXUALLY_EXPLICIT',
        'HARM_CATEGORY_DANGEROUS_CONTENT',
        'HARM_CATEGORY_CIVIC_INTEGRITY',
    ]
    probability: Literal['NEGLIGIBLE', 'LOW', 'MEDIUM', 'HIGH']
    blocked: NotRequired[bool]
class _GeminiPromptFeedback(TypedDict):
    """See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>."""

    block_reason: Annotated[str, pydantic.Field(alias='blockReason')]
    safety_ratings: Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]
# TypeAdapters take care of validation and (by-alias) serialization for the wire format
_gemini_request_ta = pydantic.TypeAdapter(_GeminiRequest)
_gemini_response_ta = pydantic.TypeAdapter(_GeminiResponse)

# stream requests return a list of https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
_gemini_streamed_response_ta = pydantic.TypeAdapter(list[_GeminiResponse], config=pydantic.ConfigDict(defer_build=True))
class _GeminiJsonSchema:
    """Transforms the JSON Schema from Pydantic to be suitable for Gemini.

    Gemini which [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations)
    a subset of OpenAPI v3.0.3.

    Specifically:
    * gemini doesn't allow the `title` keyword to be set
    * gemini doesn't allow `$defs` — we need to inline the definitions where possible
    """

    def __init__(self, schema: _utils.ObjectJsonSchema):
        # deep-copy so that simplification never mutates the caller's schema
        self.schema = deepcopy(schema)
        # `$defs` are removed from the schema and inlined at each `$ref` site instead
        self.defs = self.schema.pop('$defs', {})

    def simplify(self) -> dict[str, Any]:
        """Simplify the schema in place and return it."""
        self._simplify(self.schema, refs_stack=())
        return self.schema

    def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
        # `title` is not allowed by Gemini (see class docstring); `default` is stripped as well
        schema.pop('title', None)
        schema.pop('default', None)
        if ref := schema.pop('$ref', None):
            # noinspection PyTypeChecker
            key = re.sub(r'^#/\$defs/', '', ref)
            # `refs_stack` tracks the chain of definitions being inlined to detect cycles
            if key in refs_stack:
                raise UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
            refs_stack += (key,)
            schema_def = self.defs[key]
            self._simplify(schema_def, refs_stack)
            # inline the (simplified) definition in place of the `$ref`
            schema.update(schema_def)
            return

        if any_of := schema.get('anyOf'):
            for item_schema in any_of:
                self._simplify(item_schema, refs_stack)
            # collapse `anyOf: [X, null]` to `X` plus `nullable: true` (OpenAPI 3.0 style)
            if len(any_of) == 2 and {'type': 'null'} in any_of:
                for item_schema in any_of:
                    if item_schema != {'type': 'null'}:
                        schema.clear()
                        schema.update(item_schema)
                        schema['nullable'] = True
                        return

        type_ = schema.get('type')

        if type_ == 'object':
            self._object(schema, refs_stack)
        elif type_ == 'array':
            return self._array(schema, refs_stack)
        elif type_ == 'string' and (fmt := schema.pop('format', None)):
            # fold the `format` keyword into the description (presumably unsupported by Gemini;
            # this preserves the information for the model either way)
            description = schema.get('description')
            if description:
                schema['description'] = f'{description} (format: {fmt})'
            else:
                schema['description'] = f'Format: {fmt}'

    def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
        # recursively simplify each property; `additionalProperties` is rejected outright
        ad_props = schema.pop('additionalProperties', None)
        if ad_props:
            raise UserError('Additional properties in JSON Schema are not supported by Gemini')

        if properties := schema.get('properties'):  # pragma: no branch
            for value in properties.values():
                self._simplify(value, refs_stack)

    def _array(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
        if prefix_items := schema.get('prefixItems'):
            # TODO I think this is not supported by Gemini, maybe we should raise an error?
            for prefix_item in prefix_items:
                self._simplify(prefix_item, refs_stack)

        if items_schema := schema.get('items'):  # pragma: no branch
            self._simplify(items_schema, refs_stack)
826def _ensure_decodeable(content: bytearray) -> bytearray:
827 """Trim any invalid unicode point bytes off the end of a bytearray.
829 This is necessary before attempting to parse streaming JSON bytes.
831 This is a temporary workaround until https://github.com/pydantic/pydantic-core/issues/1633 is resolved
832 """
833 while True:
834 try:
835 content.decode()
836 except UnicodeDecodeError:
837 content = content[:-1] # this will definitely succeed before we run out of bytes
838 else:
839 return content