Coverage for pydantic_ai_slim/pydantic_ai/models/gemini.py: 92.75%

423 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1from __future__ import annotations as _annotations 

2 

3import base64 

4import re 

5from collections.abc import AsyncIterator, Sequence 

6from contextlib import asynccontextmanager 

7from copy import deepcopy 

8from dataclasses import dataclass, field 

9from datetime import datetime 

10from typing import Annotated, Any, Literal, Protocol, Union, cast 

11from uuid import uuid4 

12 

13import httpx 

14import pydantic 

15from httpx import USE_CLIENT_DEFAULT, Response as HTTPResponse 

16from typing_extensions import NotRequired, TypedDict, assert_never 

17 

18from pydantic_ai.providers import Provider, infer_provider 

19 

20from .. import ModelHTTPError, UnexpectedModelBehavior, UserError, _utils, usage 

21from ..messages import ( 

22 AudioUrl, 

23 BinaryContent, 

24 DocumentUrl, 

25 ImageUrl, 

26 ModelMessage, 

27 ModelRequest, 

28 ModelResponse, 

29 ModelResponsePart, 

30 ModelResponseStreamEvent, 

31 RetryPromptPart, 

32 SystemPromptPart, 

33 TextPart, 

34 ToolCallPart, 

35 ToolReturnPart, 

36 UserPromptPart, 

37) 

38from ..settings import ModelSettings 

39from ..tools import ToolDefinition 

40from . import ( 

41 Model, 

42 ModelRequestParameters, 

43 StreamedResponse, 

44 cached_async_http_client, 

45 check_allow_model_requests, 

46 get_user_agent, 

47) 

48 

LatestGeminiModelNames = Literal[
    'gemini-1.5-flash',
    'gemini-1.5-flash-8b',
    'gemini-1.5-pro',
    'gemini-1.0-pro',
    'gemini-2.0-flash-exp',
    'gemini-2.0-flash-thinking-exp-01-21',
    'gemini-exp-1206',
    'gemini-2.0-flash',
    'gemini-2.0-flash-lite-preview-02-05',
    'gemini-2.0-pro-exp-02-05',
]
"""Latest Gemini models."""

GeminiModelName = Union[str, LatestGeminiModelNames]
"""Possible Gemini model names.

Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
allow any name in the type hints.
See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
"""

70 

71 

class GeminiModelSettings(ModelSettings):
    """Settings used for a Gemini model request.

    ALL FIELDS MUST BE `gemini_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
    """

    # Safety filters to apply to the request; forwarded verbatim as `safety_settings`.
    gemini_safety_settings: list[GeminiSafetySettings]

79 

80 

@dataclass(init=False)
class GeminiModel(Model):
    """A model that uses Gemini via `generativelanguage.googleapis.com` API.

    This is implemented from scratch rather than using a dedicated SDK, good API documentation is
    available [here](https://ai.google.dev/api).

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    client: httpx.AsyncClient = field(repr=False)

    _model_name: GeminiModelName = field(repr=False)
    _provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] | None = field(repr=False)
    _auth: AuthProtocol | None = field(repr=False)
    _url: str | None = field(repr=False)
    _system: str = field(default='gemini', repr=False)

    def __init__(
        self,
        model_name: GeminiModelName,
        *,
        provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] = 'google-gla',
    ):
        """Initialize a Gemini model.

        Args:
            model_name: The name of the model to use.
            provider: The provider to use for authentication and API access. Can be either the string
                'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`.
                If not provided, a new provider will be created using the other parameters.
        """
        self._model_name = model_name
        self._provider = provider

        if isinstance(provider, str):
            provider = infer_provider(provider)
        self._system = provider.name
        self.client = provider.client
        self._url = str(self.client.base_url)

    @property
    def base_url(self) -> str:
        """The base URL of the underlying HTTP client."""
        assert self._url is not None, 'URL not initialized'
        return self._url

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, usage.Usage]:
        """Make a non-streaming request, returning the model response and token usage."""
        check_allow_model_requests()
        async with self._make_request(
            messages, False, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
        ) as http_response:
            response = _gemini_response_ta.validate_json(await http_response.aread())
            return self._process_response(response), _metadata_as_usage(response)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming request, yielding a `StreamedResponse` once the stream has started."""
        check_allow_model_requests()
        async with self._make_request(
            messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
        ) as http_response:
            yield await self._process_streamed_response(http_response)

    @property
    def model_name(self) -> GeminiModelName:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The system / model provider."""
        return self._system

    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
        """Convert function and result tools to Gemini function declarations, or `None` if there are none."""
        tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools]
        if model_request_parameters.result_tools:
            tools += [_function_from_abstract_tool(t) for t in model_request_parameters.result_tools]
        return _GeminiTools(function_declarations=tools) if tools else None

    def _get_tool_config(
        self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None
    ) -> _GeminiToolConfig | None:
        """Force a function call (mode `ANY`) when a plain-text result is not allowed."""
        if model_request_parameters.allow_text_result:
            return None
        elif tools:
            return _tool_config([t['name'] for t in tools['function_declarations']])
        else:
            return _tool_config([])

    @asynccontextmanager
    async def _make_request(
        self,
        messages: list[ModelMessage],
        streamed: bool,
        model_settings: GeminiModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[HTTPResponse]:
        """Build and send a generate-content request, yielding the raw HTTP response.

        Args:
            messages: The message history to send.
            streamed: Whether to call `streamGenerateContent` instead of `generateContent`.
            model_settings: Settings for this request (may be empty).
            model_request_parameters: Tool/result configuration for this request.

        Raises:
            ModelHTTPError: If the API responds with a 4xx/5xx status code.
            UnexpectedModelBehavior: If the API responds with any other non-200 status code.
        """
        tools = self._get_tools(model_request_parameters)
        tool_config = self._get_tool_config(model_request_parameters, tools)
        sys_prompt_parts, contents = await self._message_to_gemini_content(messages)

        request_data = _GeminiRequest(contents=contents)
        if sys_prompt_parts:
            request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
        if tools is not None:
            request_data['tools'] = tools
        if tool_config is not None:
            request_data['tool_config'] = tool_config

        generation_config: _GeminiGenerationConfig = {}
        if model_settings:
            if (max_tokens := model_settings.get('max_tokens')) is not None:
                generation_config['max_output_tokens'] = max_tokens
            if (temperature := model_settings.get('temperature')) is not None:
                generation_config['temperature'] = temperature
            if (top_p := model_settings.get('top_p')) is not None:
                generation_config['top_p'] = top_p
            if (presence_penalty := model_settings.get('presence_penalty')) is not None:
                generation_config['presence_penalty'] = presence_penalty
            if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
                generation_config['frequency_penalty'] = frequency_penalty
            # Bug fix: `.get(...)` returns `None` when the setting is absent, and `None != []`
            # is true, so the previous `!= []` comparison put `safety_settings: null` into the
            # request body. Only include safety settings when some were actually provided.
            if gemini_safety_settings := model_settings.get('gemini_safety_settings'):
                request_data['safety_settings'] = gemini_safety_settings
        if generation_config:
            request_data['generation_config'] = generation_config

        headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
        url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'

        request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)
        async with self.client.stream(
            'POST',
            url,
            content=request_json,
            headers=headers,
            timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
        ) as r:
            if (status_code := r.status_code) != 200:
                # Read the body so `r.text` is available for the error message.
                await r.aread()
                if status_code >= 400:
                    raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=r.text)
                raise UnexpectedModelBehavior(f'Unexpected response from gemini {status_code}', r.text)
            yield r

    def _process_response(self, response: _GeminiResponse) -> ModelResponse:
        """Convert a validated non-streaming API response into a `ModelResponse`.

        Raises:
            UnexpectedModelBehavior: If the response has no usable content, e.g. because safety
                settings blocked the generation.
        """
        if len(response['candidates']) != 1:
            raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
        if 'content' not in response['candidates'][0]:
            if response['candidates'][0].get('finish_reason') == 'SAFETY':
                raise UnexpectedModelBehavior('Safety settings triggered', str(response))
            else:
                raise UnexpectedModelBehavior('Content field missing from Gemini response', str(response))
        parts = response['candidates'][0]['content']['parts']
        return _process_response_from_parts(parts, model_name=response.get('model_version', self._model_name))

    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        aiter_bytes = http_response.aiter_bytes()
        start_response: _GeminiResponse | None = None
        content = bytearray()

        # Consume chunks until we've seen at least one complete candidate with parts;
        # only then do we hand off to `GeminiStreamedResponse` for incremental parsing.
        async for chunk in aiter_bytes:
            content.extend(chunk)
            responses = _gemini_streamed_response_ta.validate_json(
                _ensure_decodeable(content),
                experimental_allow_partial='trailing-strings',
            )
            if responses:
                last = responses[-1]
                if last['candidates'] and last['candidates'][0].get('content', {}).get('parts'):
                    start_response = last
                    break

        if start_response is None:
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)

    @classmethod
    async def _message_to_gemini_content(
        cls, messages: list[ModelMessage]
    ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
        """Split messages into system-instruction text parts and user/model contents."""
        sys_prompt_parts: list[_GeminiTextPart] = []
        contents: list[_GeminiContent] = []
        for m in messages:
            if isinstance(m, ModelRequest):
                message_parts: list[_GeminiPartUnion] = []

                for part in m.parts:
                    if isinstance(part, SystemPromptPart):
                        sys_prompt_parts.append(_GeminiTextPart(text=part.content))
                    elif isinstance(part, UserPromptPart):
                        message_parts.extend(await cls._map_user_prompt(part))
                    elif isinstance(part, ToolReturnPart):
                        message_parts.append(_response_part_from_response(part.tool_name, part.model_response_object()))
                    elif isinstance(part, RetryPromptPart):
                        if part.tool_name is None:
                            message_parts.append(_GeminiTextPart(text=part.model_response()))
                        else:
                            response = {'call_error': part.model_response()}
                            message_parts.append(_response_part_from_response(part.tool_name, response))
                    else:
                        assert_never(part)

                if message_parts:
                    contents.append(_GeminiContent(role='user', parts=message_parts))
            elif isinstance(m, ModelResponse):
                contents.append(_content_model_response(m))
            else:
                assert_never(m)

        return sys_prompt_parts, contents

    @staticmethod
    async def _map_user_prompt(part: UserPromptPart) -> list[_GeminiPartUnion]:
        """Convert a user prompt into Gemini parts, downloading and inlining any referenced URLs."""
        if isinstance(part.content, str):
            return [{'text': part.content}]
        else:
            content: list[_GeminiPartUnion] = []
            for item in part.content:
                if isinstance(item, str):
                    content.append({'text': item})
                elif isinstance(item, BinaryContent):
                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
                    content.append(
                        _GeminiInlineDataPart(inline_data={'data': base64_encoded, 'mime_type': item.media_type})
                    )
                elif isinstance(item, (AudioUrl, ImageUrl, DocumentUrl)):
                    # Fetch the URL ourselves and send the data inline, since the generative
                    # language API doesn't accept arbitrary external URLs.
                    client = cached_async_http_client()
                    response = await client.get(item.url, follow_redirects=True)
                    response.raise_for_status()
                    mime_type = response.headers['Content-Type'].split(';')[0]
                    inline_data = _GeminiInlineDataPart(
                        inline_data={'data': base64.b64encode(response.content).decode('utf-8'), 'mime_type': mime_type}
                    )
                    content.append(inline_data)
                else:
                    assert_never(item)
        return content

329 

330 

class AuthProtocol(Protocol):
    """Abstract definition for Gemini authentication."""

    async def headers(self) -> dict[str, str]: ...

335 

336 

@dataclass
class ApiKeyAuth:
    """Authentication using an API key for the `X-Goog-Api-Key` header."""

    api_key: str

    async def headers(self) -> dict[str, str]:
        # Header format per https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
        return {'X-Goog-Api-Key': self.api_key}

346 

347 

@dataclass
class GeminiStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for the Gemini model."""

    _model_name: GeminiModelName
    _content: bytearray
    _stream: AsyncIterator[bytes]
    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for response in self._get_gemini_responses():
            candidate = response['candidates'][0]
            if 'content' not in candidate:
                raise UnexpectedModelBehavior('Streamed response has no content field')
            part: _GeminiPartUnion
            for part in candidate['content']['parts']:
                if 'text' in part:
                    # Using vendor_part_id=None means we can produce multiple text parts if their
                    # deltas are sprinkled amongst the tool call deltas.
                    yield self._parts_manager.handle_text_delta(vendor_part_id=None, content=part['text'])
                elif 'function_call' in part:
                    # Here, we assume all function_call parts are complete and don't have deltas,
                    # which we enforce by assigning a unique randomly generated "vendor_part_id".
                    # We need to confirm whether this is actually true, but if it isn't, we can still
                    # handle it properly, it would just be a bit more complicated (and we'd need to
                    # confirm the intended semantics).
                    event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=uuid4(),
                        tool_name=part['function_call']['name'],
                        args=part['function_call']['args'],
                        tool_call_id=None,
                    )
                    if event is not None:
                        yield event
                else:
                    assert 'function_response' in part, f'Unexpected part: {part}'

    async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]:
        # This method exists to ensure we only yield completed items, so we don't need to worry
        # about partial gemini responses, which would make everything more complicated.

        responses: list[_GeminiResponse] = []
        num_yielded = 0
        # Right now, there are some circumstances where we will have information that could be
        # yielded sooner than it is, but changing that would make things a lot more complicated.
        async for chunk in self._stream:
            self._content.extend(chunk)

            responses = _gemini_streamed_response_ta.validate_json(
                _ensure_decodeable(self._content),
                experimental_allow_partial='trailing-strings',
            )

            # Yield only up to (but excluding) the latest response, which might still be partial.
            # Note that if the latest response is complete, we could yield it immediately, but
            # there's not a good allow_partial API to determine if the last item is complete.
            for complete_response in responses[num_yielded : len(responses) - 1]:
                num_yielded += 1
                self._usage += _metadata_as_usage(complete_response)
                yield complete_response

        # The stream has ended, so the final response should be complete now.
        if responses:
            final = responses[-1]
            self._usage += _metadata_as_usage(final)
            yield final

    @property
    def model_name(self) -> GeminiModelName:
        """Get the model name of the response."""
        return self._model_name

    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._timestamp

425 

426 

427# We use typed dicts to define the Gemini API response schema 

428# once Pydantic partial validation supports, dataclasses, we could revert to using them 

429# TypeAdapters take care of validation and serialization 

430 

431 

@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
class _GeminiRequest(TypedDict):
    """Schema for an API request to the Gemini API.

    See <https://ai.google.dev/api/generate-content#request-body> for API docs.
    """

    contents: list[_GeminiContent]
    tools: NotRequired[_GeminiTools]
    tool_config: NotRequired[_GeminiToolConfig]
    safety_settings: NotRequired[list[GeminiSafetySettings]]
    system_instruction: NotRequired[_GeminiTextContent]
    """
    Developer generated system instructions, see
    <https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest>
    """
    generation_config: NotRequired[_GeminiGenerationConfig]

450 

451 

class GeminiSafetySettings(TypedDict):
    """Safety settings options for Gemini model request.

    See [Gemini API docs](https://ai.google.dev/gemini-api/docs/safety-settings) for safety category and threshold descriptions.
    For an example on how to use `GeminiSafetySettings`, see [here](../../agents.md#model-specific-settings).
    """

    category: Literal[
        'HARM_CATEGORY_UNSPECIFIED',
        'HARM_CATEGORY_HARASSMENT',
        'HARM_CATEGORY_HATE_SPEECH',
        'HARM_CATEGORY_SEXUALLY_EXPLICIT',
        'HARM_CATEGORY_DANGEROUS_CONTENT',
        'HARM_CATEGORY_CIVIC_INTEGRITY',
    ]
    """
    Safety settings category.
    """

    threshold: Literal[
        'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
        'BLOCK_LOW_AND_ABOVE',
        'BLOCK_MEDIUM_AND_ABOVE',
        'BLOCK_ONLY_HIGH',
        'BLOCK_NONE',
        'OFF',
    ]
    """
    Safety settings threshold.
    """

482 

483 

484class _GeminiGenerationConfig(TypedDict, total=False): 

485 """Schema for an API request to the Gemini API. 

486 

487 Note there are many additional fields available that have not been added yet. 

488 

489 See <https://ai.google.dev/api/generate-content#generationconfig> for API docs. 

490 """ 

491 

492 max_output_tokens: int 

493 temperature: float 

494 top_p: float 

495 presence_penalty: float 

496 frequency_penalty: float 

497 

498 

class _GeminiContent(TypedDict):
    """A single turn in the conversation: a role plus its parts."""

    role: Literal['user', 'model']
    parts: list[_GeminiPartUnion]

502 

503 

def _content_model_response(m: ModelResponse) -> _GeminiContent:
    """Convert a `ModelResponse` into a Gemini `model`-role content object."""
    converted: list[_GeminiPartUnion] = []
    for response_part in m.parts:
        if isinstance(response_part, ToolCallPart):
            converted.append(_function_call_part_from_call(response_part))
        elif isinstance(response_part, TextPart):
            # Skip empty text parts — there is nothing to send for them.
            if response_part.content:
                converted.append(_GeminiTextPart(text=response_part.content))
        else:
            assert_never(response_part)
    return _GeminiContent(role='model', parts=converted)

515 

516 

517class _GeminiTextPart(TypedDict): 

518 text: str 

519 

520 

class _GeminiInlineData(TypedDict):
    """Base64-encoded media data with its MIME type."""

    data: str
    mime_type: Annotated[str, pydantic.Field(alias='mimeType')]

524 

525 

class _GeminiInlineDataPart(TypedDict):
    """See <https://ai.google.dev/api/caching#Blob>."""

    inline_data: Annotated[_GeminiInlineData, pydantic.Field(alias='inlineData')]

530 

531 

class _GeminiFileData(TypedDict):
    """See <https://ai.google.dev/api/caching#FileData>."""

    file_uri: Annotated[str, pydantic.Field(alias='fileUri')]
    mime_type: Annotated[str, pydantic.Field(alias='mimeType')]

537 

538 

class _GeminiFileDataPart(TypedDict):
    """A part referencing previously-uploaded file data."""

    file_data: Annotated[_GeminiFileData, pydantic.Field(alias='fileData')]

541 

542 

class _GeminiFunctionCallPart(TypedDict):
    """A part carrying a model-initiated function call."""

    function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]

545 

546 

def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:
    """Convert a `ToolCallPart` into the Gemini function-call wire format."""
    call = _GeminiFunctionCall(name=tool.tool_name, args=tool.args_as_dict())
    return _GeminiFunctionCallPart(function_call=call)

549 

550 

def _process_response_from_parts(
    parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, timestamp: datetime | None = None
) -> ModelResponse:
    """Build a `ModelResponse` from Gemini parts; only text and function-call parts are supported.

    Raises:
        UnexpectedModelBehavior: If a function-response part appears in the model output.
    """
    response_parts: list[ModelResponsePart] = []
    for part in parts:
        if 'text' in part:
            response_parts.append(TextPart(content=part['text']))
        elif 'function_call' in part:
            call = part['function_call']
            response_parts.append(ToolCallPart(tool_name=call['name'], args=call['args']))
        elif 'function_response' in part:
            raise UnexpectedModelBehavior(
                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
            )
    return ModelResponse(parts=response_parts, model_name=model_name, timestamp=timestamp or _utils.now_utc())

565 

566 

567class _GeminiFunctionCall(TypedDict): 

568 """See <https://ai.google.dev/api/caching#FunctionCall>.""" 

569 

570 name: str 

571 args: dict[str, Any] 

572 

573 

class _GeminiFunctionResponsePart(TypedDict):
    """A part carrying the result of a previously requested function call."""

    function_response: Annotated[_GeminiFunctionResponse, pydantic.Field(alias='functionResponse')]

576 

577 

def _response_part_from_response(name: str, response: dict[str, Any]) -> _GeminiFunctionResponsePart:
    """Wrap a tool result dict in the Gemini function-response wire format."""
    payload = _GeminiFunctionResponse(name=name, response=response)
    return _GeminiFunctionResponsePart(function_response=payload)

580 

581 

582class _GeminiFunctionResponse(TypedDict): 

583 """See <https://ai.google.dev/api/caching#FunctionResponse>.""" 

584 

585 name: str 

586 response: dict[str, Any] 

587 

588 

589def _part_discriminator(v: Any) -> str: 

590 if isinstance(v, dict): 590 ↛ 601line 590 didn't jump to line 601 because the condition on line 590 was always true

591 if 'text' in v: 

592 return 'text' 

593 elif 'inlineData' in v: 593 ↛ 594line 593 didn't jump to line 594 because the condition on line 593 was never true

594 return 'inline_data' 

595 elif 'fileData' in v: 595 ↛ 596line 595 didn't jump to line 596 because the condition on line 595 was never true

596 return 'file_data' 

597 elif 'functionCall' in v or 'function_call' in v: 

598 return 'function_call' 

599 elif 'functionResponse' in v or 'function_response' in v: 

600 return 'function_response' 

601 return 'text' 

602 

603 

# See <https://ai.google.dev/api/caching#Part>
# we don't currently support other part types
# TODO discriminator
_GeminiPartUnion = Annotated[
    Union[
        Annotated[_GeminiTextPart, pydantic.Tag('text')],
        Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
        Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
        Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')],
        Annotated[_GeminiFileDataPart, pydantic.Tag('file_data')],
    ],
    pydantic.Discriminator(_part_discriminator),
]

617 

618 

class _GeminiTextContent(TypedDict):
    """Content restricted to text parts, used for `system_instruction`."""

    role: Literal['user', 'model']
    parts: list[_GeminiTextPart]

622 

623 

class _GeminiTools(TypedDict):
    """Container for the function declarations sent with a request."""

    function_declarations: list[Annotated[_GeminiFunction, pydantic.Field(alias='functionDeclarations')]]

626 

627 

class _GeminiFunction(TypedDict):
    """A single function declaration in the Gemini request."""

    name: str
    description: str
    parameters: NotRequired[dict[str, Any]]
    """
    ObjectJsonSchema isn't really true since Gemini only accepts a subset of JSON Schema
    <https://ai.google.dev/gemini-api/docs/function-calling#function_declarations>
    and
    <https://ai.google.dev/api/caching#FunctionDeclaration>
    """

638 

639 

def _function_from_abstract_tool(tool: ToolDefinition) -> _GeminiFunction:
    """Convert a `ToolDefinition` into a Gemini function declaration."""
    simplified_schema = _GeminiJsonSchema(tool.parameters_json_schema).simplify()
    declaration = _GeminiFunction(
        name=tool.name,
        description=tool.description,
    )
    # Omit `parameters` entirely when the simplified schema has no properties.
    if simplified_schema.get('properties'):
        declaration['parameters'] = simplified_schema
    return declaration

649 

650 

class _GeminiToolConfig(TypedDict):
    """Wrapper for the function-calling configuration of a request."""

    function_calling_config: _GeminiFunctionCallingConfig

653 

654 

def _tool_config(function_names: list[str]) -> _GeminiToolConfig:
    """Build a tool config that forces the model to call one of the given functions."""
    calling_config = _GeminiFunctionCallingConfig(mode='ANY', allowed_function_names=function_names)
    return _GeminiToolConfig(function_calling_config=calling_config)

659 

660 

661class _GeminiFunctionCallingConfig(TypedDict): 

662 mode: Literal['ANY', 'AUTO'] 

663 allowed_function_names: list[str] 

664 

665 

@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
class _GeminiResponse(TypedDict):
    """Schema for the response from the Gemini API.

    See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>
    and <https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerateContentResponse>
    """

    candidates: list[_GeminiCandidates]
    # usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
    usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
    prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
    model_version: NotRequired[Annotated[str, pydantic.Field(alias='modelVersion')]]

679 

680 

class _GeminiCandidates(TypedDict):
    """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""

    content: NotRequired[_GeminiContent]
    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS', 'SAFETY'], pydantic.Field(alias='finishReason')]]
    """
    See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
    but let's wait until we see them and know what they mean to add them here.
    """
    avg_log_probs: NotRequired[Annotated[float, pydantic.Field(alias='avgLogProbs')]]
    index: NotRequired[int]
    safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]]

693 

694 

class _GeminiUsageMetaData(TypedDict, total=False):
    """See <https://ai.google.dev/api/generate-content#FinishReason>.

    The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
    """

    prompt_token_count: Annotated[int, pydantic.Field(alias='promptTokenCount')]
    candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]]
    total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]
    cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]

705 

706 

def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
    """Extract token usage from a Gemini response; returns empty usage if metadata is absent."""
    metadata = response.get('usage_metadata')
    if metadata is None:
        return usage.Usage()
    details: dict[str, int] = {}
    # Only report the cached-token detail when it's present and non-zero.
    if cached_tokens := metadata.get('cached_content_token_count'):
        details['cached_content_token_count'] = cached_tokens
    return usage.Usage(
        request_tokens=metadata.get('prompt_token_count', 0),
        response_tokens=metadata.get('candidates_token_count', 0),
        total_tokens=metadata.get('total_token_count', 0),
        details=details,
    )

720 

721 

class _GeminiSafetyRating(TypedDict):
    """See <https://ai.google.dev/gemini-api/docs/safety-settings#safety-filters>."""

    category: Literal[
        'HARM_CATEGORY_HARASSMENT',
        'HARM_CATEGORY_HATE_SPEECH',
        'HARM_CATEGORY_SEXUALLY_EXPLICIT',
        'HARM_CATEGORY_DANGEROUS_CONTENT',
        'HARM_CATEGORY_CIVIC_INTEGRITY',
    ]
    probability: Literal['NEGLIGIBLE', 'LOW', 'MEDIUM', 'HIGH']
    blocked: NotRequired[bool]

734 

735 

class _GeminiPromptFeedback(TypedDict):
    """See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>."""

    block_reason: Annotated[str, pydantic.Field(alias='blockReason')]
    safety_ratings: Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]

741 

742 

_gemini_request_ta = pydantic.TypeAdapter(_GeminiRequest)
_gemini_response_ta = pydantic.TypeAdapter(_GeminiResponse)

# stream requests return a list of https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
_gemini_streamed_response_ta = pydantic.TypeAdapter(list[_GeminiResponse], config=pydantic.ConfigDict(defer_build=True))

748 

749 

750class _GeminiJsonSchema: 

751 """Transforms the JSON Schema from Pydantic to be suitable for Gemini. 

752 

753 Gemini which [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations) 

754 a subset of OpenAPI v3.0.3. 

755 

756 Specifically: 

757 * gemini doesn't allow the `title` keyword to be set 

758 * gemini doesn't allow `$defs` — we need to inline the definitions where possible 

759 """ 

760 

761 def __init__(self, schema: _utils.ObjectJsonSchema): 

762 self.schema = deepcopy(schema) 

763 self.defs = self.schema.pop('$defs', {}) 

764 

765 def simplify(self) -> dict[str, Any]: 

766 self._simplify(self.schema, refs_stack=()) 

767 return self.schema 

768 

769 def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None: 

770 schema.pop('title', None) 

771 schema.pop('default', None) 

772 if ref := schema.pop('$ref', None): 

773 # noinspection PyTypeChecker 

774 key = re.sub(r'^#/\$defs/', '', ref) 

775 if key in refs_stack: 

776 raise UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini') 

777 refs_stack += (key,) 

778 schema_def = self.defs[key] 

779 self._simplify(schema_def, refs_stack) 

780 schema.update(schema_def) 

781 return 

782 

783 if any_of := schema.get('anyOf'): 

784 for item_schema in any_of: 

785 self._simplify(item_schema, refs_stack) 

786 if len(any_of) == 2 and {'type': 'null'} in any_of: 786 ↛ 794line 786 didn't jump to line 794 because the condition on line 786 was always true

787 for item_schema in any_of: 787 ↛ 794line 787 didn't jump to line 794 because the loop on line 787 didn't complete

788 if item_schema != {'type': 'null'}: 788 ↛ 787line 788 didn't jump to line 787 because the condition on line 788 was always true

789 schema.clear() 

790 schema.update(item_schema) 

791 schema['nullable'] = True 

792 return 

793 

794 type_ = schema.get('type') 

795 

796 if type_ == 'object': 

797 self._object(schema, refs_stack) 

798 elif type_ == 'array': 

799 return self._array(schema, refs_stack) 

800 elif type_ == 'string' and (fmt := schema.pop('format', None)): 

801 description = schema.get('description') 

802 if description: 

803 schema['description'] = f'{description} (format: {fmt})' 

804 else: 

805 schema['description'] = f'Format: {fmt}' 

806 

807 def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None: 

808 ad_props = schema.pop('additionalProperties', None) 

809 if ad_props: 809 ↛ 810line 809 didn't jump to line 810 because the condition on line 809 was never true

810 raise UserError('Additional properties in JSON Schema are not supported by Gemini') 

811 

812 if properties := schema.get('properties'): # pragma: no branch 

813 for value in properties.values(): 

814 self._simplify(value, refs_stack) 

815 

816 def _array(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None: 

817 if prefix_items := schema.get('prefixItems'): 

818 # TODO I think this not is supported by Gemini, maybe we should raise an error? 

819 for prefix_item in prefix_items: 

820 self._simplify(prefix_item, refs_stack) 

821 

822 if items_schema := schema.get('items'): # pragma: no branch 

823 self._simplify(items_schema, refs_stack) 

824 

825 

826def _ensure_decodeable(content: bytearray) -> bytearray: 

827 """Trim any invalid unicode point bytes off the end of a bytearray. 

828 

829 This is necessary before attempting to parse streaming JSON bytes. 

830 

831 This is a temporary workaround until https://github.com/pydantic/pydantic-core/issues/1633 is resolved 

832 """ 

833 while True: 

834 try: 

835 content.decode() 

836 except UnicodeDecodeError: 

837 content = content[:-1] # this will definitely succeed before we run out of bytes 

838 else: 

839 return content