Coverage for pydantic_ai_slim/pydantic_ai/models/gemini.py: 93.66%
355 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
1from __future__ import annotations as _annotations
3import os
4import re
5from collections.abc import AsyncIterator, Sequence
6from contextlib import asynccontextmanager
7from copy import deepcopy
8from dataclasses import dataclass, field
9from datetime import datetime
10from typing import Annotated, Any, Literal, Protocol, Union, cast
11from uuid import uuid4
13import pydantic
14from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
15from typing_extensions import NotRequired, TypedDict, assert_never
17from .. import UnexpectedModelBehavior, _utils, exceptions, usage
18from ..messages import (
19 ModelMessage,
20 ModelRequest,
21 ModelResponse,
22 ModelResponsePart,
23 ModelResponseStreamEvent,
24 RetryPromptPart,
25 SystemPromptPart,
26 TextPart,
27 ToolCallPart,
28 ToolReturnPart,
29 UserPromptPart,
30)
31from ..settings import ModelSettings
32from ..tools import ToolDefinition
33from . import (
34 AgentModel,
35 Model,
36 StreamedResponse,
37 cached_async_http_client,
38 check_allow_model_requests,
39 get_user_agent,
40)
GeminiModelName = Literal[
    'gemini-1.5-flash',
    'gemini-1.5-flash-8b',
    'gemini-1.5-pro',
    'gemini-1.0-pro',
    'gemini-2.0-flash-exp',
]
"""Named Gemini models.

See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
"""
class GeminiModelSettings(ModelSettings):
    """Settings used for a Gemini model request."""

    # Placeholder: no Gemini-specific settings exist yet; they will be added here as needed.
@dataclass(init=False)
class GeminiModel(Model):
    """A model that uses Gemini via the `generativelanguage.googleapis.com` API.

    This is implemented from scratch rather than using a dedicated SDK; good API documentation is
    available [here](https://ai.google.dev/api).

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    model_name: GeminiModelName
    auth: AuthProtocol
    http_client: AsyncHTTPClient
    url: str

    def __init__(
        self,
        model_name: GeminiModelName,
        *,
        api_key: str | None = None,
        http_client: AsyncHTTPClient | None = None,
        url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
    ):
        """Initialize a Gemini model.

        Args:
            model_name: The name of the model to use.
            api_key: The API key to use for authentication, if not provided, the `GEMINI_API_KEY` environment variable
                will be used if available.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
            url_template: The URL template to use for making requests, you shouldn't need to change this,
                docs [here](https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request),
                `model` is substituted with the model name, and `function` is added to the end of the URL.
        """
        self.model_name = model_name
        if api_key is None:
            # Fall back to the environment; an empty value is treated the same as unset.
            env_api_key = os.getenv('GEMINI_API_KEY')
            if not env_api_key:
                raise exceptions.UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
            api_key = env_api_key
        self.auth = ApiKeyAuth(api_key)
        self.http_client = http_client or cached_async_http_client()
        self.url = url_template.format(model=model_name)

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> GeminiAgentModel:
        # Guard against accidental real requests (e.g. in tests) before building the agent model.
        check_allow_model_requests()
        return GeminiAgentModel(
            http_client=self.http_client,
            model_name=self.model_name,
            auth=self.auth,
            url=self.url,
            function_tools=function_tools,
            allow_text_result=allow_text_result,
            result_tools=result_tools,
        )

    def name(self) -> str:
        return f'google-gla:{self.model_name}'
class AuthProtocol(Protocol):
    """Structural interface for Gemini authentication: anything with an async `headers()` method."""

    async def headers(self) -> dict[str, str]: ...
@dataclass
class ApiKeyAuth:
    """Authentication using an API key for the `X-Goog-Api-Key` header."""

    api_key: str

    async def headers(self) -> dict[str, str]:
        # Key-based auth per https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
        header_name = 'X-Goog-Api-Key'
        return {header_name: self.api_key}
@dataclass(init=False)
class GeminiAgentModel(AgentModel):
    """Implementation of `AgentModel` for Gemini models."""

    http_client: AsyncHTTPClient
    model_name: GeminiModelName
    auth: AuthProtocol
    tools: _GeminiTools | None
    tool_config: _GeminiToolConfig | None
    url: str

    def __init__(
        self,
        http_client: AsyncHTTPClient,
        model_name: GeminiModelName,
        auth: AuthProtocol,
        url: str,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ):
        # Both regular and result tools become Gemini function declarations.
        function_declarations = [_function_from_abstract_tool(t) for t in function_tools]
        function_declarations.extend(_function_from_abstract_tool(t) for t in result_tools)

        # When plain-text output is not allowed, force the model to call one of the declared tools.
        if allow_text_result:
            tool_config = None
        else:
            tool_config = _tool_config([f['name'] for f in function_declarations])

        self.http_client = http_client
        self.model_name = model_name
        self.auth = auth
        self.tools = _GeminiTools(function_declarations=function_declarations) if function_declarations else None
        self.tool_config = tool_config
        self.url = url

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, usage.Usage]:
        settings = cast(GeminiModelSettings, model_settings or {})
        async with self._make_request(messages, False, settings) as http_response:
            response = _gemini_response_ta.validate_json(await http_response.aread())
            return self._process_response(response), _metadata_as_usage(response)

    @asynccontextmanager
    async def request_stream(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> AsyncIterator[StreamedResponse]:
        settings = cast(GeminiModelSettings, model_settings or {})
        async with self._make_request(messages, True, settings) as http_response:
            yield await self._process_streamed_response(http_response)

    @asynccontextmanager
    async def _make_request(
        self, messages: list[ModelMessage], streamed: bool, model_settings: GeminiModelSettings
    ) -> AsyncIterator[HTTPResponse]:
        """Build the request body from messages/settings and yield the (open) HTTP response."""
        sys_prompt_parts, contents = self._message_to_gemini_content(messages)

        request_data = _GeminiRequest(contents=contents)
        if sys_prompt_parts:
            request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
        if self.tools is not None:
            request_data['tools'] = self.tools
        if self.tool_config is not None:
            request_data['tool_config'] = self.tool_config

        # Map pydantic-ai setting names onto Gemini generation-config field names,
        # copying only values that were explicitly provided.
        generation_config: _GeminiGenerationConfig = {}
        for settings_key, config_key in (
            ('max_tokens', 'max_output_tokens'),
            ('temperature', 'temperature'),
            ('top_p', 'top_p'),
            ('presence_penalty', 'presence_penalty'),
            ('frequency_penalty', 'frequency_penalty'),
        ):
            value = model_settings.get(settings_key)
            if value is not None:
                generation_config[config_key] = value  # type: ignore[literal-required]
        if generation_config:
            request_data['generation_config'] = generation_config

        # The API function name is appended to the templated model URL.
        url = self.url + ('streamGenerateContent' if streamed else 'generateContent')

        headers = {
            'Content-Type': 'application/json',
            'User-Agent': get_user_agent(),
            **await self.auth.headers(),
        }

        request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)

        async with self.http_client.stream(
            'POST',
            url,
            content=request_json,
            headers=headers,
            timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
        ) as r:
            if r.status_code != 200:
                # Read the body so the error message can include it.
                await r.aread()
                raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
            yield r

    def _process_response(self, response: _GeminiResponse) -> ModelResponse:
        """Convert a fully-read (non-streamed) response into a `ModelResponse`."""
        candidates = response['candidates']
        if len(candidates) != 1:
            raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
        return _process_response_from_parts(candidates[0]['content']['parts'], model_name=self.model_name)

    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        byte_iterator = http_response.aiter_bytes()
        buffer = bytearray()
        first_complete: _GeminiResponse | None = None

        # Accumulate chunks until partial validation yields a response that has actual parts.
        async for chunk in byte_iterator:
            buffer.extend(chunk)
            parsed = _gemini_streamed_response_ta.validate_json(
                buffer,
                experimental_allow_partial='trailing-strings',
            )
            if parsed:
                latest = parsed[-1]
                if latest['candidates'] and latest['candidates'][0]['content']['parts']:
                    first_complete = latest
                    break

        if first_complete is None:
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        # The remaining chunks are consumed lazily by GeminiStreamedResponse.
        return GeminiStreamedResponse(_model_name=self.model_name, _content=buffer, _stream=byte_iterator)

    @classmethod
    def _message_to_gemini_content(
        cls, messages: list[ModelMessage]
    ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
        """Split messages into system-prompt text parts and Gemini `contents` entries."""
        sys_prompt_parts: list[_GeminiTextPart] = []
        contents: list[_GeminiContent] = []
        for message in messages:
            if isinstance(message, ModelResponse):
                contents.append(_content_model_response(message))
                continue
            if not isinstance(message, ModelRequest):
                assert_never(message)

            message_parts: list[_GeminiPartUnion] = []
            for part in message.parts:
                if isinstance(part, SystemPromptPart):
                    # System prompts are collected separately and sent as `system_instruction`.
                    sys_prompt_parts.append(_GeminiTextPart(text=part.content))
                elif isinstance(part, UserPromptPart):
                    message_parts.append(_GeminiTextPart(text=part.content))
                elif isinstance(part, ToolReturnPart):
                    message_parts.append(_response_part_from_response(part.tool_name, part.model_response_object()))
                elif isinstance(part, RetryPromptPart):
                    if part.tool_name is None:
                        message_parts.append(_GeminiTextPart(text=part.model_response()))
                    else:
                        # Tool-specific retries are reported back as a function response.
                        message_parts.append(
                            _response_part_from_response(part.tool_name, {'call_error': part.model_response()})
                        )
                else:
                    assert_never(part)

            if message_parts:
                contents.append(_GeminiContent(role='user', parts=message_parts))

        return sys_prompt_parts, contents
@dataclass
class GeminiStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for the Gemini model."""

    _content: bytearray
    _stream: AsyncIterator[bytes]
    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for gemini_response in self._get_gemini_responses():
            candidate = gemini_response['candidates'][0]
            part: _GeminiPartUnion
            for part in candidate['content']['parts']:
                if 'text' in part:
                    # vendor_part_id=None allows multiple text parts when their deltas are
                    # interleaved with tool-call deltas.
                    yield self._parts_manager.handle_text_delta(vendor_part_id=None, content=part['text'])
                elif 'function_call' in part:
                    # Here, we assume all function_call parts are complete and don't have deltas.
                    # We do this by assigning a unique randomly generated "vendor_part_id".
                    # We need to confirm whether this is actually true, but if it isn't, we can still handle it
                    # properly, it would just be a bit more complicated. And we'd need to confirm the semantics.
                    maybe_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=uuid4(),
                        tool_name=part['function_call']['name'],
                        args=part['function_call']['args'],
                        tool_call_id=None,
                    )
                    if maybe_event is not None:
                        yield maybe_event
                else:
                    assert 'function_response' in part, f'Unexpected part: {part}'

    async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]:
        # This method exists to ensure we only yield completed items, so we don't need to worry
        # about partial gemini responses, which would make everything more complicated.
        gemini_responses: list[_GeminiResponse] = []
        yielded_count = 0
        # Right now, there are some circumstances where we will have information that could be
        # yielded sooner than it is, but changing that would make things a lot more complicated.
        async for chunk in self._stream:
            self._content.extend(chunk)

            gemini_responses = _gemini_streamed_response_ta.validate_json(
                self._content,
                experimental_allow_partial='trailing-strings',
            )

            # Yield only up to (but not including) the latest response, which might still be
            # partial; there's no good allow_partial API to tell whether the last item is complete.
            complete_responses = gemini_responses[:-1]
            while yielded_count < len(complete_responses):
                response = complete_responses[yielded_count]
                yielded_count += 1
                self._usage += _metadata_as_usage(response)
                yield response

        # The stream has ended, so the final response should now be complete.
        if gemini_responses:
            final = gemini_responses[-1]
            self._usage += _metadata_as_usage(final)
            yield final

    def timestamp(self) -> datetime:
        return self._timestamp
378# We use typed dicts to define the Gemini API response schema
# once Pydantic partial validation supports dataclasses, we could revert to using them
380# TypeAdapters take care of validation and serialization
@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
class _GeminiRequest(TypedDict):
    """Schema for an API request to the Gemini API.

    See <https://ai.google.dev/api/generate-content#request-body> for API docs.
    """

    contents: list[_GeminiContent]
    tools: NotRequired[_GeminiTools]
    tool_config: NotRequired[_GeminiToolConfig]
    system_instruction: NotRequired[_GeminiTextContent]
    """
    Developer generated system instructions, see
    <https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest>
    """
    # Optional model generation settings built from `GeminiModelSettings` in `_make_request`.
    generation_config: NotRequired[_GeminiGenerationConfig]
402class _GeminiGenerationConfig(TypedDict, total=False):
403 """Schema for an API request to the Gemini API.
405 Note there are many additional fields available that have not been added yet.
407 See <https://ai.google.dev/api/generate-content#generationconfig> for API docs.
408 """
410 max_output_tokens: int
411 temperature: float
412 top_p: float
413 presence_penalty: float
414 frequency_penalty: float
417class _GeminiContent(TypedDict):
418 role: Literal['user', 'model']
419 parts: list[_GeminiPartUnion]
def _content_model_response(m: ModelResponse) -> _GeminiContent:
    """Convert a `ModelResponse` into a Gemini content entry with role `model`."""
    parts: list[_GeminiPartUnion] = []
    for item in m.parts:
        if isinstance(item, ToolCallPart):
            parts.append(_function_call_part_from_call(item))
        elif isinstance(item, TextPart):
            # Text parts with empty content are dropped.
            if item.content:
                parts.append(_GeminiTextPart(text=item.content))
        else:
            assert_never(item)
    return _GeminiContent(role='model', parts=parts)
435class _GeminiTextPart(TypedDict):
436 text: str
class _GeminiFunctionCallPart(TypedDict):
    # Wire format uses camelCase `functionCall`; the alias handles serialization.
    function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]
def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:
    """Convert a pydantic-ai tool call into a Gemini `functionCall` part."""
    call = _GeminiFunctionCall(name=tool.tool_name, args=tool.args_as_dict())
    return _GeminiFunctionCallPart(function_call=call)
def _process_response_from_parts(
    parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, timestamp: datetime | None = None
) -> ModelResponse:
    """Map Gemini response parts onto pydantic-ai response parts.

    `function_response` parts are not valid in a model response and raise
    `UnexpectedModelBehavior`.
    """
    items: list[ModelResponsePart] = []
    for part in parts:
        if 'text' in part:
            items.append(TextPart(content=part['text']))
        elif 'function_call' in part:
            call = part['function_call']
            items.append(ToolCallPart(tool_name=call['name'], args=call['args']))
        elif 'function_response' in part:
            raise exceptions.UnexpectedModelBehavior(
                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
            )
    return ModelResponse(parts=items, model_name=model_name, timestamp=timestamp or _utils.now_utc())
468class _GeminiFunctionCall(TypedDict):
469 """See <https://ai.google.dev/api/caching#FunctionCall>."""
471 name: str
472 args: dict[str, Any]
class _GeminiFunctionResponsePart(TypedDict):
    # Wire format uses camelCase `functionResponse`; the alias handles serialization.
    function_response: Annotated[_GeminiFunctionResponse, pydantic.Field(alias='functionResponse')]
def _response_part_from_response(name: str, response: dict[str, Any]) -> _GeminiFunctionResponsePart:
    """Wrap a tool result dict as a Gemini `functionResponse` part for tool `name`."""
    inner = _GeminiFunctionResponse(name=name, response=response)
    return _GeminiFunctionResponsePart(function_response=inner)
483class _GeminiFunctionResponse(TypedDict):
484 """See <https://ai.google.dev/api/caching#FunctionResponse>."""
486 name: str
487 response: dict[str, Any]
490def _part_discriminator(v: Any) -> str:
491 if isinstance(v, dict): 491 ↛ 498line 491 didn't jump to line 498 because the condition on line 491 was always true
492 if 'text' in v:
493 return 'text'
494 elif 'functionCall' in v or 'function_call' in v:
495 return 'function_call'
496 elif 'functionResponse' in v or 'function_response' in v:
497 return 'function_response'
498 return 'text'
# See <https://ai.google.dev/api/caching#Part>
# we don't currently support other part types
_GeminiPartUnion = Annotated[
    Union[
        Annotated[_GeminiTextPart, pydantic.Tag('text')],
        Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
        Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
    ],
    # Callable discriminator: parts may arrive with camelCase or snake_case keys,
    # so `_part_discriminator` inspects the keys rather than a fixed field.
    pydantic.Discriminator(_part_discriminator),
]
514class _GeminiTextContent(TypedDict):
515 role: Literal['user', 'model']
516 parts: list[_GeminiTextPart]
519class _GeminiTools(TypedDict):
520 function_declarations: list[Annotated[_GeminiFunction, pydantic.Field(alias='functionDeclarations')]]
523class _GeminiFunction(TypedDict):
524 name: str
525 description: str
526 parameters: NotRequired[dict[str, Any]]
527 """
528 ObjectJsonSchema isn't really true since Gemini only accepts a subset of JSON Schema
529 <https://ai.google.dev/gemini-api/docs/function-calling#function_declarations>
530 and
531 <https://ai.google.dev/api/caching#FunctionDeclaration>
532 """
def _function_from_abstract_tool(tool: ToolDefinition) -> _GeminiFunction:
    """Build a Gemini function declaration from a pydantic-ai tool definition."""
    json_schema = _GeminiJsonSchema(tool.parameters_json_schema).simplify()
    function = _GeminiFunction(
        name=tool.name,
        description=tool.description,
    )
    # Only include `parameters` when the simplified schema actually declares properties.
    if json_schema.get('properties'):
        function['parameters'] = json_schema
    return function
546class _GeminiToolConfig(TypedDict):
547 function_calling_config: _GeminiFunctionCallingConfig
def _tool_config(function_names: list[str]) -> _GeminiToolConfig:
    """Build a tool config that forces the model to call one of `function_names`."""
    calling_config = _GeminiFunctionCallingConfig(mode='ANY', allowed_function_names=function_names)
    return _GeminiToolConfig(function_calling_config=calling_config)
556class _GeminiFunctionCallingConfig(TypedDict):
557 mode: Literal['ANY', 'AUTO']
558 allowed_function_names: list[str]
@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
class _GeminiResponse(TypedDict):
    """Schema for the response from the Gemini API.

    See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>
    and <https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerateContentResponse>
    """

    candidates: list[_GeminiCandidates]
    # usageMetadata appears to be required by both APIs but is omitted when streaming responses
    # until the last response
    usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
    prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
class _GeminiCandidates(TypedDict):
    """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""

    content: _GeminiContent
    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS'], pydantic.Field(alias='finishReason')]]
    """
    See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
    but let's wait until we see them and know what they mean to add them here.
    """
    avg_log_probs: NotRequired[Annotated[float, pydantic.Field(alias='avgLogProbs')]]
    index: NotRequired[int]
    safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]]
class _GeminiUsageMetaData(TypedDict, total=False):
    """See <https://ai.google.dev/api/generate-content#UsageMetadata>.

    The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
    """

    prompt_token_count: Annotated[int, pydantic.Field(alias='promptTokenCount')]
    candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]]
    total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]
    cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
    """Extract token usage from a response, returning empty usage when metadata is absent."""
    metadata = response.get('usage_metadata')
    if metadata is None:
        return usage.Usage()
    details: dict[str, int] = {}
    # A zero/missing cached count is omitted from details entirely.
    cached = metadata.get('cached_content_token_count')
    if cached:
        details['cached_content_token_count'] = cached
    return usage.Usage(
        request_tokens=metadata.get('prompt_token_count', 0),
        response_tokens=metadata.get('candidates_token_count', 0),
        total_tokens=metadata.get('total_token_count', 0),
        details=details,
    )
616class _GeminiSafetyRating(TypedDict):
617 """See <https://ai.google.dev/gemini-api/docs/safety-settings#safety-filters>."""
619 category: Literal[
620 'HARM_CATEGORY_HARASSMENT',
621 'HARM_CATEGORY_HATE_SPEECH',
622 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
623 'HARM_CATEGORY_DANGEROUS_CONTENT',
624 'HARM_CATEGORY_CIVIC_INTEGRITY',
625 ]
626 probability: Literal['NEGLIGIBLE', 'LOW', 'MEDIUM', 'HIGH']
class _GeminiPromptFeedback(TypedDict):
    """See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>."""

    block_reason: Annotated[str, pydantic.Field(alias='blockReason')]
    safety_ratings: Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]
# TypeAdapters for request/response (de)serialization; response schemas use defer_build
# so the TypedDicts above can be declared in any order.
_gemini_request_ta = pydantic.TypeAdapter(_GeminiRequest)
_gemini_response_ta = pydantic.TypeAdapter(_GeminiResponse)

# stream requests return a list of https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
_gemini_streamed_response_ta = pydantic.TypeAdapter(list[_GeminiResponse], config=pydantic.ConfigDict(defer_build=True))
643class _GeminiJsonSchema:
644 """Transforms the JSON Schema from Pydantic to be suitable for Gemini.
646 Gemini which [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations)
647 a subset of OpenAPI v3.0.3.
649 Specifically:
650 * gemini doesn't allow the `title` keyword to be set
651 * gemini doesn't allow `$defs` — we need to inline the definitions where possible
652 """
654 def __init__(self, schema: _utils.ObjectJsonSchema):
655 self.schema = deepcopy(schema)
656 self.defs = self.schema.pop('$defs', {})
658 def simplify(self) -> dict[str, Any]:
659 self._simplify(self.schema, refs_stack=())
660 return self.schema
662 def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
663 schema.pop('title', None)
664 schema.pop('default', None)
665 if ref := schema.pop('$ref', None):
666 # noinspection PyTypeChecker
667 key = re.sub(r'^#/\$defs/', '', ref)
668 if key in refs_stack:
669 raise exceptions.UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
670 refs_stack += (key,)
671 schema_def = self.defs[key]
672 self._simplify(schema_def, refs_stack)
673 schema.update(schema_def)
674 return
676 if any_of := schema.get('anyOf'):
677 for item_schema in any_of:
678 self._simplify(item_schema, refs_stack)
679 if len(any_of) == 2 and {'type': 'null'} in any_of: 679 ↛ 687line 679 didn't jump to line 687 because the condition on line 679 was always true
680 for item_schema in any_of: 680 ↛ 687line 680 didn't jump to line 687 because the loop on line 680 didn't complete
681 if item_schema != {'type': 'null'}: 681 ↛ 680line 681 didn't jump to line 680 because the condition on line 681 was always true
682 schema.clear()
683 schema.update(item_schema)
684 schema['nullable'] = True
685 return
687 type_ = schema.get('type')
689 if type_ == 'object':
690 self._object(schema, refs_stack)
691 elif type_ == 'array':
692 return self._array(schema, refs_stack)
693 elif type_ == 'string' and (fmt := schema.pop('format', None)):
694 description = schema.get('description')
695 if description:
696 schema['description'] = f'{description} (format: {fmt})'
697 else:
698 schema['description'] = f'Format: {fmt}'
700 def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
701 ad_props = schema.pop('additionalProperties', None)
702 if ad_props: 702 ↛ 703line 702 didn't jump to line 703 because the condition on line 702 was never true
703 raise exceptions.UserError('Additional properties in JSON Schema are not supported by Gemini')
705 if properties := schema.get('properties'): # pragma: no branch
706 for value in properties.values():
707 self._simplify(value, refs_stack)
709 def _array(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
710 if prefix_items := schema.get('prefixItems'):
711 # TODO I think this not is supported by Gemini, maybe we should raise an error?
712 for prefix_item in prefix_items:
713 self._simplify(prefix_item, refs_stack)
715 if items_schema := schema.get('items'): # pragma: no branch
716 self._simplify(items_schema, refs_stack)