Coverage for pydantic_ai_slim/pydantic_ai/models/gemini.py: 93.66%
355 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-30 19:21 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-30 19:21 +0000
1from __future__ import annotations as _annotations
3import os
4import re
5from collections.abc import AsyncIterator, Sequence
6from contextlib import asynccontextmanager
7from copy import deepcopy
8from dataclasses import dataclass, field
9from datetime import datetime
10from typing import Annotated, Any, Literal, Protocol, Union, cast
11from uuid import uuid4
13import pydantic
14from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
15from typing_extensions import NotRequired, TypedDict, assert_never
17from .. import UnexpectedModelBehavior, _utils, exceptions, usage
18from ..messages import (
19 ModelMessage,
20 ModelRequest,
21 ModelResponse,
22 ModelResponsePart,
23 ModelResponseStreamEvent,
24 RetryPromptPart,
25 SystemPromptPart,
26 TextPart,
27 ToolCallPart,
28 ToolReturnPart,
29 UserPromptPart,
30)
31from ..settings import ModelSettings
32from ..tools import ToolDefinition
33from . import (
34 AgentModel,
35 Model,
36 StreamedResponse,
37 cached_async_http_client,
38 check_allow_model_requests,
39 get_user_agent,
40)
# Closed set of known model identifiers; using a Literal lets type checkers catch typos
# in model names at call sites.
GeminiModelName = Literal[
    'gemini-1.5-flash',
    'gemini-1.5-flash-8b',
    'gemini-1.5-pro',
    'gemini-1.0-pro',
    'gemini-2.0-flash-exp',
    'gemini-2.0-flash-thinking-exp-01-21',
    'gemini-exp-1206',
]
"""Named Gemini models.

See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
"""
class GeminiModelSettings(ModelSettings):
    """Settings used for a Gemini model request.

    Currently identical to the base `ModelSettings`; see `GeminiAgentModel._make_request`
    for which settings are mapped onto the request's generation config.
    """

    # This class is a placeholder for any future gemini-specific settings
@dataclass(init=False)
class GeminiModel(Model):
    """A model backed by the Gemini `generativelanguage.googleapis.com` HTTP API.

    Rather than wrapping a dedicated SDK this talks to the REST API directly; good API
    documentation is available [here](https://ai.google.dev/api).

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    model_name: GeminiModelName
    auth: AuthProtocol
    http_client: AsyncHTTPClient
    url: str

    def __init__(
        self,
        model_name: GeminiModelName,
        *,
        api_key: str | None = None,
        http_client: AsyncHTTPClient | None = None,
        url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
    ):
        """Initialize a Gemini model.

        Args:
            model_name: The name of the model to use.
            api_key: The API key for authentication; when omitted, the `GEMINI_API_KEY`
                environment variable is used if available.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
            url_template: The URL template used for requests, you shouldn't need to change this,
                docs [here](https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request);
                `model` is substituted with the model name, and the API `function` name is
                appended to the end of the URL.

        Raises:
            UserError: If no API key is given and `GEMINI_API_KEY` is not set.
        """
        self.model_name = model_name
        if api_key is None:
            # fall back to the environment; an empty value counts as missing
            api_key = os.getenv('GEMINI_API_KEY')
            if not api_key:
                raise exceptions.UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
        self.auth = ApiKeyAuth(api_key)
        self.http_client = http_client if http_client is not None else cached_async_http_client()
        self.url = url_template.format(model=model_name)

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> GeminiAgentModel:
        """Build a `GeminiAgentModel` configured with the given tools."""
        check_allow_model_requests()
        return GeminiAgentModel(
            http_client=self.http_client,
            model_name=self.model_name,
            auth=self.auth,
            url=self.url,
            function_tools=function_tools,
            allow_text_result=allow_text_result,
            result_tools=result_tools,
        )

    def name(self) -> str:
        """Return the qualified model name, e.g. `google-gla:gemini-1.5-flash`."""
        return f'google-gla:{self.model_name}'
class AuthProtocol(Protocol):
    """Abstract definition for Gemini authentication."""

    async def headers(self) -> dict[str, str]:
        """Return the HTTP headers that authenticate a request."""
        ...
@dataclass
class ApiKeyAuth:
    """Authentication using an API key for the `X-Goog-Api-Key` header."""

    api_key: str

    async def headers(self) -> dict[str, str]:
        # The key travels in a dedicated header rather than a query parameter, see
        # https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
        header_name = 'X-Goog-Api-Key'
        return {header_name: self.api_key}
@dataclass(init=False)
class GeminiAgentModel(AgentModel):
    """Implementation of `AgentModel` for Gemini models."""

    http_client: AsyncHTTPClient
    model_name: GeminiModelName
    auth: AuthProtocol
    tools: _GeminiTools | None
    tool_config: _GeminiToolConfig | None
    url: str

    def __init__(
        self,
        http_client: AsyncHTTPClient,
        model_name: GeminiModelName,
        auth: AuthProtocol,
        url: str,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ):
        # Function tools and result tools are presented to Gemini as one flat list of
        # function declarations.
        tools = [_function_from_abstract_tool(t) for t in function_tools]
        if result_tools:
            tools += [_function_from_abstract_tool(t) for t in result_tools]

        # When plain-text results are not allowed, force the model to call one of the
        # declared functions instead of answering with text (mode='ANY' in _tool_config).
        if allow_text_result:
            tool_config = None
        else:
            tool_config = _tool_config([t['name'] for t in tools])

        self.http_client = http_client
        self.model_name = model_name
        self.auth = auth
        self.tools = _GeminiTools(function_declarations=tools) if tools else None
        self.tool_config = tool_config
        self.url = url

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, usage.Usage]:
        """Make a non-streaming request and return the parsed response plus token usage."""
        async with self._make_request(
            messages, False, cast(GeminiModelSettings, model_settings or {})
        ) as http_response:
            response = _gemini_response_ta.validate_json(await http_response.aread())
            return self._process_response(response), _metadata_as_usage(response)

    @asynccontextmanager
    async def request_stream(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming request, yielding a `StreamedResponse` once the first parts arrive."""
        async with self._make_request(messages, True, cast(GeminiModelSettings, model_settings or {})) as http_response:
            yield await self._process_streamed_response(http_response)

    @asynccontextmanager
    async def _make_request(
        self, messages: list[ModelMessage], streamed: bool, model_settings: GeminiModelSettings
    ) -> AsyncIterator[HTTPResponse]:
        """Build the request body from `messages` and settings, POST it, and yield the raw HTTP response.

        Raises `UnexpectedModelBehavior` for any non-200 status.
        """
        sys_prompt_parts, contents = self._message_to_gemini_content(messages)

        request_data = _GeminiRequest(contents=contents)
        if sys_prompt_parts:
            request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
        if self.tools is not None:
            request_data['tools'] = self.tools
        if self.tool_config is not None:
            request_data['tool_config'] = self.tool_config

        # Only include a generation config if at least one supported setting was provided.
        generation_config: _GeminiGenerationConfig = {}
        if model_settings:
            if (max_tokens := model_settings.get('max_tokens')) is not None:
                generation_config['max_output_tokens'] = max_tokens
            if (temperature := model_settings.get('temperature')) is not None:
                generation_config['temperature'] = temperature
            if (top_p := model_settings.get('top_p')) is not None:
                generation_config['top_p'] = top_p
            if (presence_penalty := model_settings.get('presence_penalty')) is not None:
                generation_config['presence_penalty'] = presence_penalty
            if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
                generation_config['frequency_penalty'] = frequency_penalty
        if generation_config:
            request_data['generation_config'] = generation_config

        # The API function name is appended to the model URL, see `url_template` on `GeminiModel`.
        url = self.url + ('streamGenerateContent' if streamed else 'generateContent')

        headers = {
            'Content-Type': 'application/json',
            'User-Agent': get_user_agent(),
            **await self.auth.headers(),
        }

        request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)

        async with self.http_client.stream(
            'POST',
            url,
            content=request_json,
            headers=headers,
            timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
        ) as r:
            if r.status_code != 200:
                # read the body so the error text is available on the exception
                await r.aread()
                raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
            yield r

    def _process_response(self, response: _GeminiResponse) -> ModelResponse:
        """Convert a complete (non-streamed) Gemini response into a `ModelResponse`."""
        if len(response['candidates']) != 1:
            raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
        parts = response['candidates'][0]['content']['parts']
        return _process_response_from_parts(parts, model_name=self.model_name)

    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        aiter_bytes = http_response.aiter_bytes()
        start_response: _GeminiResponse | None = None
        content = bytearray()

        # Buffer chunks until partial-JSON validation produces at least one response whose
        # first candidate already has content parts — that's enough to start streaming.
        async for chunk in aiter_bytes:
            content.extend(chunk)
            responses = _gemini_streamed_response_ta.validate_json(
                content,
                experimental_allow_partial='trailing-strings',
            )
            if responses:
                last = responses[-1]
                if last['candidates'] and last['candidates'][0]['content']['parts']:
                    start_response = last
                    break

        if start_response is None:
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        # Hand the buffered bytes plus the still-live iterator over to the streamed response.
        return GeminiStreamedResponse(_model_name=self.model_name, _content=content, _stream=aiter_bytes)

    @classmethod
    def _message_to_gemini_content(
        cls, messages: list[ModelMessage]
    ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
        """Split messages into system-instruction text parts and regular Gemini contents."""
        sys_prompt_parts: list[_GeminiTextPart] = []
        contents: list[_GeminiContent] = []
        for m in messages:
            if isinstance(m, ModelRequest):
                message_parts: list[_GeminiPartUnion] = []

                for part in m.parts:
                    if isinstance(part, SystemPromptPart):
                        # system prompts go into `system_instruction`, not the contents list
                        sys_prompt_parts.append(_GeminiTextPart(text=part.content))
                    elif isinstance(part, UserPromptPart):
                        message_parts.append(_GeminiTextPart(text=part.content))
                    elif isinstance(part, ToolReturnPart):
                        message_parts.append(_response_part_from_response(part.tool_name, part.model_response_object()))
                    elif isinstance(part, RetryPromptPart):
                        if part.tool_name is None:
                            message_parts.append(_GeminiTextPart(text=part.model_response()))
                        else:
                            # tool-specific retries are reported back as an error function response
                            response = {'call_error': part.model_response()}
                            message_parts.append(_response_part_from_response(part.tool_name, response))
                    else:
                        assert_never(part)

                if message_parts:
                    contents.append(_GeminiContent(role='user', parts=message_parts))
            elif isinstance(m, ModelResponse):
                contents.append(_content_model_response(m))
            else:
                assert_never(m)

        return sys_prompt_parts, contents
@dataclass
class GeminiStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for the Gemini model."""

    # Raw response bytes accumulated so far; seeded with the bytes already read while
    # waiting for the first complete part, then extended from `_stream` as chunks arrive.
    _content: bytearray
    # The live byte iterator from the HTTP response.
    _stream: AsyncIterator[bytes]
    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Translate completed Gemini responses into part-delta events via the parts manager."""
        async for gemini_response in self._get_gemini_responses():
            candidate = gemini_response['candidates'][0]
            gemini_part: _GeminiPartUnion
            for gemini_part in candidate['content']['parts']:
                if 'text' in gemini_part:
                    # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
                    # amongst the tool call deltas
                    yield self._parts_manager.handle_text_delta(vendor_part_id=None, content=gemini_part['text'])

                elif 'function_call' in gemini_part:
                    # Here, we assume all function_call parts are complete and don't have deltas.
                    # We do this by assigning a unique randomly generated "vendor_part_id".
                    # We need to confirm whether this is actually true, but if it isn't, we can still handle it properly
                    # it would just be a bit more complicated. And we'd need to confirm the intended semantics.
                    maybe_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=uuid4(),
                        tool_name=gemini_part['function_call']['name'],
                        args=gemini_part['function_call']['args'],
                        tool_call_id=None,
                    )
                    if maybe_event is not None:
                        yield maybe_event
                else:
                    # function_response parts should never appear in model output
                    assert 'function_response' in gemini_part, f'Unexpected part: {gemini_part}'

    async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]:
        # This method exists to ensure we only yield completed items, so we don't need to worry about
        # partial gemini responses, which would make everything more complicated
        gemini_responses: list[_GeminiResponse] = []
        current_gemini_response_index = 0
        # Right now, there are some circumstances where we will have information that could be yielded sooner than it is
        # But changing that would make things a lot more complicated.
        async for chunk in self._stream:
            self._content.extend(chunk)

            # re-validate the entire accumulated buffer; partial trailing JSON is tolerated
            gemini_responses = _gemini_streamed_response_ta.validate_json(
                self._content,
                experimental_allow_partial='trailing-strings',
            )

            # The idea: yield only up to the latest response, which might still be partial.
            # Note that if the latest response is complete, we could yield it immediately, but there's not a good
            # allow_partial API to determine if the last item in the list is complete.
            responses_to_yield = gemini_responses[:-1]
            for r in responses_to_yield[current_gemini_response_index:]:
                current_gemini_response_index += 1
                self._usage += _metadata_as_usage(r)
                yield r

        # Now yield the final response, which should be complete
        if gemini_responses:
            r = gemini_responses[-1]
            self._usage += _metadata_as_usage(r)
            yield r

    def timestamp(self) -> datetime:
        """Return when this streamed-response object was created (UTC)."""
        return self._timestamp
# We use typed dicts to define the Gemini API response schema;
# once Pydantic partial validation supports dataclasses, we could revert to using them.
# TypeAdapters take care of validation and serialization
@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
class _GeminiRequest(TypedDict):
    """Schema for an API request to the Gemini API.

    See <https://ai.google.dev/api/generate-content#request-body> for API docs.
    """

    contents: list[_GeminiContent]
    tools: NotRequired[_GeminiTools]
    tool_config: NotRequired[_GeminiToolConfig]
    system_instruction: NotRequired[_GeminiTextContent]
    """
    Developer generated system instructions, see
    <https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest>
    """
    # populated from `GeminiModelSettings` in `GeminiAgentModel._make_request`
    generation_config: NotRequired[_GeminiGenerationConfig]
class _GeminiGenerationConfig(TypedDict, total=False):
    """Schema for the generation config section of a Gemini API request.

    Note there are many additional fields available that have not been added yet.

    See <https://ai.google.dev/api/generate-content#generationconfig> for API docs.
    """

    max_output_tokens: int
    temperature: float
    top_p: float
    presence_penalty: float
    frequency_penalty: float
class _GeminiContent(TypedDict):
    """A Gemini content object: the producing role plus an ordered list of parts."""

    role: Literal['user', 'model']
    parts: list[_GeminiPartUnion]
def _content_model_response(m: ModelResponse) -> _GeminiContent:
    """Convert a `ModelResponse` into a Gemini `model`-role content object.

    Text parts with empty content are dropped entirely.
    """
    converted: list[_GeminiPartUnion] = []
    for part in m.parts:
        if isinstance(part, TextPart):
            if part.content:
                converted.append(_GeminiTextPart(text=part.content))
        elif isinstance(part, ToolCallPart):
            converted.append(_function_call_part_from_call(part))
        else:
            assert_never(part)
    return _GeminiContent(role='model', parts=converted)
class _GeminiTextPart(TypedDict):
    """A plain-text part of a Gemini content object."""

    text: str
class _GeminiFunctionCallPart(TypedDict):
    """A part holding a function call requested by the model (serialized as `functionCall`)."""

    function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]
def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:
    """Build the Gemini `functionCall` part for an outgoing tool call."""
    call = _GeminiFunctionCall(name=tool.tool_name, args=tool.args_as_dict())
    return _GeminiFunctionCallPart(function_call=call)
def _process_response_from_parts(
    parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, timestamp: datetime | None = None
) -> ModelResponse:
    """Convert Gemini response parts into a `ModelResponse`.

    Text parts become `TextPart`s and function calls become `ToolCallPart`s; a
    `function_response` part in model output is invalid and raises
    `UnexpectedModelBehavior`.
    """
    response_parts: list[ModelResponsePart] = []
    for part in parts:
        if 'text' in part:
            response_parts.append(TextPart(content=part['text']))
        elif 'function_call' in part:
            call = part['function_call']
            response_parts.append(
                ToolCallPart(
                    tool_name=call['name'],
                    args=call['args'],
                )
            )
        elif 'function_response' in part:
            raise exceptions.UnexpectedModelBehavior(
                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
            )
    return ModelResponse(parts=response_parts, model_name=model_name, timestamp=timestamp or _utils.now_utc())
class _GeminiFunctionCall(TypedDict):
    """A function call requested by the model: tool name plus JSON-object arguments.

    See <https://ai.google.dev/api/caching#FunctionCall>.
    """

    name: str
    args: dict[str, Any]
class _GeminiFunctionResponsePart(TypedDict):
    """A part carrying a tool result back to the model (serialized as `functionResponse`)."""

    function_response: Annotated[_GeminiFunctionResponse, pydantic.Field(alias='functionResponse')]
def _response_part_from_response(name: str, response: dict[str, Any]) -> _GeminiFunctionResponsePart:
    """Wrap a tool result dict as a Gemini `functionResponse` part."""
    function_response = _GeminiFunctionResponse(name=name, response=response)
    return _GeminiFunctionResponsePart(function_response=function_response)
class _GeminiFunctionResponse(TypedDict):
    """The result of a function call: the function name and its response object.

    See <https://ai.google.dev/api/caching#FunctionResponse>.
    """

    name: str
    response: dict[str, Any]
496def _part_discriminator(v: Any) -> str:
497 if isinstance(v, dict): 497 ↛ 504line 497 didn't jump to line 504 because the condition on line 497 was always true
498 if 'text' in v:
499 return 'text'
500 elif 'functionCall' in v or 'function_call' in v:
501 return 'function_call'
502 elif 'functionResponse' in v or 'function_response' in v:
503 return 'function_response'
504 return 'text'
# See <https://ai.google.dev/api/caching#Part>
# we don't currently support other part types
# TODO discriminator
# The callable discriminator handles both camelCase wire keys and snake_case keys;
# shapes with none of the known keys are tagged 'text' (see `_part_discriminator`).
_GeminiPartUnion = Annotated[
    Union[
        Annotated[_GeminiTextPart, pydantic.Tag('text')],
        Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
        Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
    ],
    pydantic.Discriminator(_part_discriminator),
]
class _GeminiTextContent(TypedDict):
    """A content object restricted to text parts; used for `system_instruction`."""

    role: Literal['user', 'model']
    parts: list[_GeminiTextPart]
class _GeminiTools(TypedDict):
    """The function declarations offered to the model (serialized as `functionDeclarations`)."""

    function_declarations: list[Annotated[_GeminiFunction, pydantic.Field(alias='functionDeclarations')]]
class _GeminiFunction(TypedDict):
    """A single function declaration: name, description and optional JSON-schema parameters."""

    name: str
    description: str
    parameters: NotRequired[dict[str, Any]]
    """
    ObjectJsonSchema isn't really true since Gemini only accepts a subset of JSON Schema
    <https://ai.google.dev/gemini-api/docs/function-calling#function_declarations>
    and
    <https://ai.google.dev/api/caching#FunctionDeclaration>
    """
def _function_from_abstract_tool(tool: ToolDefinition) -> _GeminiFunction:
    """Convert a `ToolDefinition` into a Gemini function declaration.

    The parameter schema is simplified to the subset Gemini accepts; `parameters`
    is omitted entirely when the simplified schema has no properties.
    """
    parameters_schema = _GeminiJsonSchema(tool.parameters_json_schema).simplify()
    function = _GeminiFunction(
        name=tool.name,
        description=tool.description,
    )
    if parameters_schema.get('properties'):
        function['parameters'] = parameters_schema
    return function
class _GeminiToolConfig(TypedDict):
    """Wrapper for the function-calling configuration sent with a request."""

    function_calling_config: _GeminiFunctionCallingConfig
def _tool_config(function_names: list[str]) -> _GeminiToolConfig:
    """Build a tool config forcing the model to call one of the named functions (mode 'ANY')."""
    calling_config = _GeminiFunctionCallingConfig(mode='ANY', allowed_function_names=function_names)
    return _GeminiToolConfig(function_calling_config=calling_config)
class _GeminiFunctionCallingConfig(TypedDict):
    """Function-calling mode and the function names the model may call.

    `_tool_config` uses mode 'ANY' to force a function call when plain-text results
    are not allowed.
    """

    mode: Literal['ANY', 'AUTO']
    allowed_function_names: list[str]
@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
class _GeminiResponse(TypedDict):
    """Schema for the response from the Gemini API.

    See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>
    and <https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerateContentResponse>
    """

    candidates: list[_GeminiCandidates]
    # usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
    usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
    # safety feedback about the prompt, when provided by the API
    prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
class _GeminiCandidates(TypedDict):
    """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""

    content: _GeminiContent
    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS'], pydantic.Field(alias='finishReason')]]
    """
    See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
    but let's wait until we see them and know what they mean to add them here.
    """
    # average log-probability of the candidate, when the API reports it
    avg_log_probs: NotRequired[Annotated[float, pydantic.Field(alias='avgLogProbs')]]
    index: NotRequired[int]
    safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]]
class _GeminiUsageMetaData(TypedDict, total=False):
    """See <https://ai.google.dev/api/generate-content#UsageMetadata>.

    The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
    """

    prompt_token_count: Annotated[int, pydantic.Field(alias='promptTokenCount')]
    candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]]
    total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]
    cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
    """Extract token usage from a response's `usage_metadata`, if present.

    Streamed responses omit the metadata until the final chunk, in which case an
    empty `Usage` is returned.
    """
    metadata = response.get('usage_metadata')
    if metadata is None:
        return usage.Usage()
    details: dict[str, int] = {}
    cached = metadata.get('cached_content_token_count')
    if cached:
        # only surface the cached-token detail when it's present and non-zero
        details['cached_content_token_count'] = cached
    return usage.Usage(
        request_tokens=metadata.get('prompt_token_count', 0),
        response_tokens=metadata.get('candidates_token_count', 0),
        total_tokens=metadata.get('total_token_count', 0),
        details=details,
    )
class _GeminiSafetyRating(TypedDict):
    """A per-category safety assessment for a candidate.

    See <https://ai.google.dev/gemini-api/docs/safety-settings#safety-filters>.
    """

    category: Literal[
        'HARM_CATEGORY_HARASSMENT',
        'HARM_CATEGORY_HATE_SPEECH',
        'HARM_CATEGORY_SEXUALLY_EXPLICIT',
        'HARM_CATEGORY_DANGEROUS_CONTENT',
        'HARM_CATEGORY_CIVIC_INTEGRITY',
    ]
    probability: Literal['NEGLIGIBLE', 'LOW', 'MEDIUM', 'HIGH']
class _GeminiPromptFeedback(TypedDict):
    """Feedback about why a prompt was blocked.

    See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>.
    """

    block_reason: Annotated[str, pydantic.Field(alias='blockReason')]
    safety_ratings: Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]
# Type adapters used to validate and serialize Gemini request/response payloads.
_gemini_request_ta = pydantic.TypeAdapter(_GeminiRequest)
_gemini_response_ta = pydantic.TypeAdapter(_GeminiResponse)

# stream requests return a list of https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
_gemini_streamed_response_ta = pydantic.TypeAdapter(list[_GeminiResponse], config=pydantic.ConfigDict(defer_build=True))
649class _GeminiJsonSchema:
650 """Transforms the JSON Schema from Pydantic to be suitable for Gemini.
652 Gemini which [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations)
653 a subset of OpenAPI v3.0.3.
655 Specifically:
656 * gemini doesn't allow the `title` keyword to be set
657 * gemini doesn't allow `$defs` — we need to inline the definitions where possible
658 """
660 def __init__(self, schema: _utils.ObjectJsonSchema):
661 self.schema = deepcopy(schema)
662 self.defs = self.schema.pop('$defs', {})
664 def simplify(self) -> dict[str, Any]:
665 self._simplify(self.schema, refs_stack=())
666 return self.schema
668 def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
669 schema.pop('title', None)
670 schema.pop('default', None)
671 if ref := schema.pop('$ref', None):
672 # noinspection PyTypeChecker
673 key = re.sub(r'^#/\$defs/', '', ref)
674 if key in refs_stack:
675 raise exceptions.UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
676 refs_stack += (key,)
677 schema_def = self.defs[key]
678 self._simplify(schema_def, refs_stack)
679 schema.update(schema_def)
680 return
682 if any_of := schema.get('anyOf'):
683 for item_schema in any_of:
684 self._simplify(item_schema, refs_stack)
685 if len(any_of) == 2 and {'type': 'null'} in any_of: 685 ↛ 693line 685 didn't jump to line 693 because the condition on line 685 was always true
686 for item_schema in any_of: 686 ↛ 693line 686 didn't jump to line 693 because the loop on line 686 didn't complete
687 if item_schema != {'type': 'null'}: 687 ↛ 686line 687 didn't jump to line 686 because the condition on line 687 was always true
688 schema.clear()
689 schema.update(item_schema)
690 schema['nullable'] = True
691 return
693 type_ = schema.get('type')
695 if type_ == 'object':
696 self._object(schema, refs_stack)
697 elif type_ == 'array':
698 return self._array(schema, refs_stack)
699 elif type_ == 'string' and (fmt := schema.pop('format', None)):
700 description = schema.get('description')
701 if description:
702 schema['description'] = f'{description} (format: {fmt})'
703 else:
704 schema['description'] = f'Format: {fmt}'
706 def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
707 ad_props = schema.pop('additionalProperties', None)
708 if ad_props: 708 ↛ 709line 708 didn't jump to line 709 because the condition on line 708 was never true
709 raise exceptions.UserError('Additional properties in JSON Schema are not supported by Gemini')
711 if properties := schema.get('properties'): # pragma: no branch
712 for value in properties.values():
713 self._simplify(value, refs_stack)
715 def _array(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
716 if prefix_items := schema.get('prefixItems'):
717 # TODO I think this not is supported by Gemini, maybe we should raise an error?
718 for prefix_item in prefix_items:
719 self._simplify(prefix_item, refs_stack)
721 if items_schema := schema.get('items'): # pragma: no branch
722 self._simplify(items_schema, refs_stack)