Coverage for pydantic_ai_slim/pydantic_ai/models/gemini.py: 93.66%

355 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-01-25 16:43 +0000

1from __future__ import annotations as _annotations 

2 

3import os 

4import re 

5from collections.abc import AsyncIterator, Sequence 

6from contextlib import asynccontextmanager 

7from copy import deepcopy 

8from dataclasses import dataclass, field 

9from datetime import datetime 

10from typing import Annotated, Any, Literal, Protocol, Union, cast 

11from uuid import uuid4 

12 

13import pydantic 

14from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse 

15from typing_extensions import NotRequired, TypedDict, assert_never 

16 

17from .. import UnexpectedModelBehavior, _utils, exceptions, usage 

18from ..messages import ( 

19 ModelMessage, 

20 ModelRequest, 

21 ModelResponse, 

22 ModelResponsePart, 

23 ModelResponseStreamEvent, 

24 RetryPromptPart, 

25 SystemPromptPart, 

26 TextPart, 

27 ToolCallPart, 

28 ToolReturnPart, 

29 UserPromptPart, 

30) 

31from ..settings import ModelSettings 

32from ..tools import ToolDefinition 

33from . import ( 

34 AgentModel, 

35 Model, 

36 StreamedResponse, 

37 cached_async_http_client, 

38 check_allow_model_requests, 

39 get_user_agent, 

40) 

41 

# A Literal (rather than plain str) so IDEs can complete the known model names.
GeminiModelName = Literal[
    'gemini-1.5-flash', 'gemini-1.5-flash-8b', 'gemini-1.5-pro', 'gemini-1.0-pro', 'gemini-2.0-flash-exp'
]
"""Named Gemini models.

See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
"""

49 

50 

class GeminiModelSettings(ModelSettings):
    """Settings used for a Gemini model request."""

    # Currently adds nothing beyond the base `ModelSettings`; a placeholder for future Gemini-specific settings.

55 

56 

@dataclass(init=False)
class GeminiModel(Model):
    """A model that uses Gemini via `generativelanguage.googleapis.com` API.

    This is implemented from scratch rather than using a dedicated SDK, good API documentation is
    available [here](https://ai.google.dev/api).

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    model_name: GeminiModelName
    auth: AuthProtocol
    http_client: AsyncHTTPClient
    url: str

    def __init__(
        self,
        model_name: GeminiModelName,
        *,
        api_key: str | None = None,
        http_client: AsyncHTTPClient | None = None,
        url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
    ):
        """Initialize a Gemini model.

        Args:
            model_name: The name of the model to use.
            api_key: The API key to use for authentication; if not provided, the `GEMINI_API_KEY`
                environment variable will be used if available.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
            url_template: The URL template to use for making requests, you shouldn't need to change this,
                docs [here](https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request),
                `model` is substituted with the model name, and `function` is added to the end of the URL.

        Raises:
            UserError: If no API key was given and `GEMINI_API_KEY` is unset or empty.
        """
        self.model_name = model_name
        if api_key is None:
            # Fall back to the environment; an empty value counts as missing.
            api_key = os.getenv('GEMINI_API_KEY')
            if not api_key:
                raise exceptions.UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
        self.auth = ApiKeyAuth(api_key)
        self.http_client = http_client or cached_async_http_client()
        self.url = url_template.format(model=model_name)

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> GeminiAgentModel:
        """Create a `GeminiAgentModel` configured with the given tool definitions."""
        check_allow_model_requests()
        return GeminiAgentModel(
            http_client=self.http_client,
            model_name=self.model_name,
            auth=self.auth,
            url=self.url,
            function_tools=function_tools,
            allow_text_result=allow_text_result,
            result_tools=result_tools,
        )

    def name(self) -> str:
        """Return the qualified model name, e.g. `google-gla:gemini-1.5-flash`."""
        return f'google-gla:{self.model_name}'

121 

122 

class AuthProtocol(Protocol):
    """Abstract definition for Gemini authentication.

    Implementations return the HTTP headers needed to authenticate a request.
    """

    async def headers(self) -> dict[str, str]: ...

127 

128 

@dataclass
class ApiKeyAuth:
    """Authentication using an API key for the `X-Goog-Api-Key` header."""

    api_key: str

    async def headers(self) -> dict[str, str]:
        """Return the authentication headers for a Gemini request.

        See https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
        """
        header_name = 'X-Goog-Api-Key'
        return {header_name: self.api_key}

138 

139 

@dataclass(init=False)
class GeminiAgentModel(AgentModel):
    """Implementation of `AgentModel` for Gemini models."""

    http_client: AsyncHTTPClient
    model_name: GeminiModelName
    auth: AuthProtocol
    tools: _GeminiTools | None
    tool_config: _GeminiToolConfig | None
    url: str

    def __init__(
        self,
        http_client: AsyncHTTPClient,
        model_name: GeminiModelName,
        auth: AuthProtocol,
        url: str,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ):
        # Both function tools and result tools are exposed to Gemini as function declarations.
        tools = [_function_from_abstract_tool(t) for t in function_tools]
        if result_tools:
            tools += [_function_from_abstract_tool(t) for t in result_tools]

        if allow_text_result:
            tool_config = None
        else:
            # Plain-text replies are not allowed, so force the model to call one of the declared tools.
            tool_config = _tool_config([t['name'] for t in tools])

        self.http_client = http_client
        self.model_name = model_name
        self.auth = auth
        self.tools = _GeminiTools(function_declarations=tools) if tools else None
        self.tool_config = tool_config
        self.url = url

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, usage.Usage]:
        """Make a non-streaming request, returning the parsed response and its token usage."""
        async with self._make_request(
            messages, False, cast(GeminiModelSettings, model_settings or {})
        ) as http_response:
            response = _gemini_response_ta.validate_json(await http_response.aread())
            return self._process_response(response), _metadata_as_usage(response)

    @asynccontextmanager
    async def request_stream(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming request, yielding a `GeminiStreamedResponse` once data starts arriving."""
        async with self._make_request(messages, True, cast(GeminiModelSettings, model_settings or {})) as http_response:
            yield await self._process_streamed_response(http_response)

    @asynccontextmanager
    async def _make_request(
        self, messages: list[ModelMessage], streamed: bool, model_settings: GeminiModelSettings
    ) -> AsyncIterator[HTTPResponse]:
        """Build the request body and POST it, yielding the raw streaming HTTP response.

        Raises:
            UnexpectedModelBehavior: If the API responds with a non-200 status code.
        """
        sys_prompt_parts, contents = self._message_to_gemini_content(messages)

        request_data = _GeminiRequest(contents=contents)
        if sys_prompt_parts:
            # System prompts are sent via `system_instruction` rather than as regular content.
            request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
        if self.tools is not None:
            request_data['tools'] = self.tools
        if self.tool_config is not None:
            request_data['tool_config'] = self.tool_config

        # Map generic model settings onto Gemini's `generationConfig` fields.
        generation_config: _GeminiGenerationConfig = {}
        if model_settings:
            if (max_tokens := model_settings.get('max_tokens')) is not None:
                generation_config['max_output_tokens'] = max_tokens
            if (temperature := model_settings.get('temperature')) is not None:
                generation_config['temperature'] = temperature
            if (top_p := model_settings.get('top_p')) is not None:
                generation_config['top_p'] = top_p
            if (presence_penalty := model_settings.get('presence_penalty')) is not None:
                generation_config['presence_penalty'] = presence_penalty
            if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
                generation_config['frequency_penalty'] = frequency_penalty
        if generation_config:
            request_data['generation_config'] = generation_config

        # The URL template ends with `:` so the API "function" name is appended directly.
        url = self.url + ('streamGenerateContent' if streamed else 'generateContent')

        headers = {
            'Content-Type': 'application/json',
            'User-Agent': get_user_agent(),
            **await self.auth.headers(),
        }

        request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)

        async with self.http_client.stream(
            'POST',
            url,
            content=request_json,
            headers=headers,
            timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
        ) as r:
            if r.status_code != 200:
                # Read the body so the error text is available for the exception message.
                await r.aread()
                raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
            yield r

    def _process_response(self, response: _GeminiResponse) -> ModelResponse:
        """Convert a completed Gemini response into a `ModelResponse`."""
        if len(response['candidates']) != 1:
            raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
        parts = response['candidates'][0]['content']['parts']
        return _process_response_from_parts(parts, model_name=self.model_name)

    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        aiter_bytes = http_response.aiter_bytes()
        start_response: _GeminiResponse | None = None
        content = bytearray()

        # Accumulate chunks until we have at least one response containing parts; incomplete
        # JSON is tolerated via pydantic's `experimental_allow_partial` validation.
        async for chunk in aiter_bytes:
            content.extend(chunk)
            responses = _gemini_streamed_response_ta.validate_json(
                content,
                experimental_allow_partial='trailing-strings',
            )
            if responses:
                last = responses[-1]
                if last['candidates'] and last['candidates'][0]['content']['parts']:
                    start_response = last
                    break

        if start_response is None:
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        # Hand the buffered bytes and the remaining byte stream over to the streamed response.
        return GeminiStreamedResponse(_model_name=self.model_name, _content=content, _stream=aiter_bytes)

    @classmethod
    def _message_to_gemini_content(
        cls, messages: list[ModelMessage]
    ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
        """Split messages into system-prompt text parts and regular Gemini contents."""
        sys_prompt_parts: list[_GeminiTextPart] = []
        contents: list[_GeminiContent] = []
        for m in messages:
            if isinstance(m, ModelRequest):
                message_parts: list[_GeminiPartUnion] = []

                for part in m.parts:
                    if isinstance(part, SystemPromptPart):
                        sys_prompt_parts.append(_GeminiTextPart(text=part.content))
                    elif isinstance(part, UserPromptPart):
                        message_parts.append(_GeminiTextPart(text=part.content))
                    elif isinstance(part, ToolReturnPart):
                        message_parts.append(_response_part_from_response(part.tool_name, part.model_response_object()))
                    elif isinstance(part, RetryPromptPart):
                        if part.tool_name is None:
                            # A retry not tied to a specific tool call is sent as plain text.
                            message_parts.append(_GeminiTextPart(text=part.model_response()))
                        else:
                            # Tool-call errors are reported back via a function response.
                            response = {'call_error': part.model_response()}
                            message_parts.append(_response_part_from_response(part.tool_name, response))
                    else:
                        assert_never(part)

                if message_parts:
                    contents.append(_GeminiContent(role='user', parts=message_parts))
            elif isinstance(m, ModelResponse):
                contents.append(_content_model_response(m))
            else:
                assert_never(m)

        return sys_prompt_parts, contents

307 

308 

@dataclass
class GeminiStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for the Gemini model."""

    # Raw bytes received so far; pre-populated with the bytes read while waiting for the first
    # complete response in `GeminiAgentModel._process_streamed_response`.
    _content: bytearray
    # The remainder of the HTTP byte stream.
    _stream: AsyncIterator[bytes]
    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Yield text / tool-call events for each part of each completed Gemini response."""
        async for gemini_response in self._get_gemini_responses():
            candidate = gemini_response['candidates'][0]
            gemini_part: _GeminiPartUnion
            for gemini_part in candidate['content']['parts']:
                if 'text' in gemini_part:
                    # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
                    # amongst the tool call deltas
                    yield self._parts_manager.handle_text_delta(vendor_part_id=None, content=gemini_part['text'])

                elif 'function_call' in gemini_part:
                    # Here, we assume all function_call parts are complete and don't have deltas.
                    # We do this by assigning a unique randomly generated "vendor_part_id".
                    # We need to confirm whether this is actually true, but if it isn't, we can still handle it properly
                    # it would just be a bit more complicated. And we'd need to confirm the intended semantics.
                    maybe_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=uuid4(),
                        tool_name=gemini_part['function_call']['name'],
                        args=gemini_part['function_call']['args'],
                        tool_call_id=None,
                    )
                    if maybe_event is not None:
                        yield maybe_event
                else:
                    assert 'function_response' in gemini_part, f'Unexpected part: {gemini_part}'

    async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]:
        # This method exists to ensure we only yield completed items, so we don't need to worry about
        # partial gemini responses, which would make everything more complicated

        gemini_responses: list[_GeminiResponse] = []
        current_gemini_response_index = 0
        # Right now, there are some circumstances where we will have information that could be yielded sooner than it is
        # But changing that would make things a lot more complicated.
        async for chunk in self._stream:
            self._content.extend(chunk)

            # Re-validate the full accumulated buffer; trailing partial JSON is tolerated.
            gemini_responses = _gemini_streamed_response_ta.validate_json(
                self._content,
                experimental_allow_partial='trailing-strings',
            )

            # The idea: yield only up to the latest response, which might still be partial.
            # Note that if the latest response is complete, we could yield it immediately, but there's not a good
            # allow_partial API to determine if the last item in the list is complete.
            responses_to_yield = gemini_responses[:-1]
            for r in responses_to_yield[current_gemini_response_index:]:
                current_gemini_response_index += 1
                self._usage += _metadata_as_usage(r)
                yield r

        # Now yield the final response, which should be complete
        if gemini_responses:
            r = gemini_responses[-1]
            self._usage += _metadata_as_usage(r)
            yield r

    def timestamp(self) -> datetime:
        """Return the time at which this streamed response object was created."""
        return self._timestamp

376 

377 

378# We use typed dicts to define the Gemini API response schema 

# once Pydantic partial validation supports dataclasses, we could revert to using them

380# TypeAdapters take care of validation and serialization 

381 

382 

@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
class _GeminiRequest(TypedDict):
    """Schema for an API request to the Gemini API.

    See <https://ai.google.dev/api/generate-content#request-body> for API docs.
    """

    contents: list[_GeminiContent]
    tools: NotRequired[_GeminiTools]
    tool_config: NotRequired[_GeminiToolConfig]
    # snake_case field names are serialized with camelCase aliases (`dump_json(..., by_alias=True)`)
    system_instruction: NotRequired[_GeminiTextContent]
    """
    Developer generated system instructions, see
    <https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest>
    """
    generation_config: NotRequired[_GeminiGenerationConfig]

400 

401 

class _GeminiGenerationConfig(TypedDict, total=False):
    """Schema for the `generationConfig` object sent as part of a Gemini request.

    Note there are many additional fields available that have not been added yet.

    See <https://ai.google.dev/api/generate-content#generationconfig> for API docs.
    """

    max_output_tokens: int
    temperature: float
    top_p: float
    presence_penalty: float
    frequency_penalty: float

415 

416 

class _GeminiContent(TypedDict):
    """A single content entry: the author role plus its parts."""

    role: Literal['user', 'model']
    parts: list[_GeminiPartUnion]

420 

421 

def _content_model_response(m: ModelResponse) -> _GeminiContent:
    """Convert a `ModelResponse` into a Gemini `model`-role content object."""
    gemini_parts: list[_GeminiPartUnion] = []
    for response_part in m.parts:
        if isinstance(response_part, ToolCallPart):
            gemini_parts.append(_function_call_part_from_call(response_part))
        elif isinstance(response_part, TextPart):
            # text parts with empty content are skipped
            if response_part.content:
                gemini_parts.append(_GeminiTextPart(text=response_part.content))
        else:
            assert_never(response_part)
    return _GeminiContent(role='model', parts=gemini_parts)

433 

434 

class _GeminiTextPart(TypedDict):
    """A plain text part of a content entry."""

    text: str

437 

438 

class _GeminiFunctionCallPart(TypedDict):
    """A part representing a function call requested by the model."""

    function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]

441 

442 

def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:
    """Build a Gemini `functionCall` part from a `ToolCallPart`."""
    call = _GeminiFunctionCall(name=tool.tool_name, args=tool.args_as_dict())
    return _GeminiFunctionCallPart(function_call=call)

445 

446 

def _process_response_from_parts(
    parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, timestamp: datetime | None = None
) -> ModelResponse:
    """Map Gemini response parts onto pydantic-ai response parts.

    Raises:
        UnexpectedModelBehavior: If a `function_response` part is present — those are only
            valid in requests, not responses.
    """
    response_parts: list[ModelResponsePart] = []
    for part in parts:
        if 'text' in part:
            response_parts.append(TextPart(content=part['text']))
        elif 'function_call' in part:
            call = part['function_call']
            response_parts.append(ToolCallPart(tool_name=call['name'], args=call['args']))
        elif 'function_response' in part:
            raise exceptions.UnexpectedModelBehavior(
                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
            )
    return ModelResponse(parts=response_parts, model_name=model_name, timestamp=timestamp or _utils.now_utc())

466 

467 

class _GeminiFunctionCall(TypedDict):
    """See <https://ai.google.dev/api/caching#FunctionCall>."""

    # name of the function the model wants to call
    name: str
    # arguments for the call, as a JSON-compatible mapping
    args: dict[str, Any]

473 

474 

class _GeminiFunctionResponsePart(TypedDict):
    """A part carrying the result of an earlier function call back to the model."""

    function_response: Annotated[_GeminiFunctionResponse, pydantic.Field(alias='functionResponse')]

477 

478 

def _response_part_from_response(name: str, response: dict[str, Any]) -> _GeminiFunctionResponsePart:
    """Wrap a tool-call result dict in a Gemini `functionResponse` part."""
    function_response = _GeminiFunctionResponse(name=name, response=response)
    return _GeminiFunctionResponsePart(function_response=function_response)

481 

482 

class _GeminiFunctionResponse(TypedDict):
    """See <https://ai.google.dev/api/caching#FunctionResponse>."""

    # name of the function that was called
    name: str
    # JSON-compatible result of the call
    response: dict[str, Any]

488 

489 

490def _part_discriminator(v: Any) -> str: 

491 if isinstance(v, dict): 491 ↛ 498line 491 didn't jump to line 498 because the condition on line 491 was always true

492 if 'text' in v: 

493 return 'text' 

494 elif 'functionCall' in v or 'function_call' in v: 

495 return 'function_call' 

496 elif 'functionResponse' in v or 'function_response' in v: 

497 return 'function_response' 

498 return 'text' 

499 

500 

# See <https://ai.google.dev/api/caching#Part>
# we don't currently support other part types
_GeminiPartUnion = Annotated[
    Union[
        Annotated[_GeminiTextPart, pydantic.Tag('text')],
        Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
        Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
    ],
    # `_part_discriminator` tags incoming dicts so validation picks the right union member
    pydantic.Discriminator(_part_discriminator),
]

512 

513 

class _GeminiTextContent(TypedDict):
    """Content restricted to text-only parts, used for `system_instruction`."""

    role: Literal['user', 'model']
    parts: list[_GeminiTextPart]

517 

518 

class _GeminiTools(TypedDict):
    """Wrapper for the function declarations sent in the request `tools` field."""

    function_declarations: list[Annotated[_GeminiFunction, pydantic.Field(alias='functionDeclarations')]]

521 

522 

class _GeminiFunction(TypedDict):
    """A single function declaration exposed to the model."""

    name: str
    description: str
    parameters: NotRequired[dict[str, Any]]
    """
    ObjectJsonSchema isn't really true since Gemini only accepts a subset of JSON Schema
    <https://ai.google.dev/gemini-api/docs/function-calling#function_declarations>
    and
    <https://ai.google.dev/api/caching#FunctionDeclaration>
    """

533 

534 

def _function_from_abstract_tool(tool: ToolDefinition) -> _GeminiFunction:
    """Convert a `ToolDefinition` into a Gemini function declaration."""
    schema = _GeminiJsonSchema(tool.parameters_json_schema).simplify()
    function = _GeminiFunction(
        name=tool.name,
        description=tool.description,
    )
    # only attach `parameters` when the simplified schema actually declares properties
    if schema.get('properties'):
        function['parameters'] = schema
    return function

544 

545 

class _GeminiToolConfig(TypedDict):
    """Request `tool_config`, controlling how the model may call the declared functions."""

    function_calling_config: _GeminiFunctionCallingConfig

548 

549 

def _tool_config(function_names: list[str]) -> _GeminiToolConfig:
    """Build a tool config forcing the model to call one of `function_names` (mode `ANY`)."""
    calling_config = _GeminiFunctionCallingConfig(mode='ANY', allowed_function_names=function_names)
    return _GeminiToolConfig(function_calling_config=calling_config)

554 

555 

class _GeminiFunctionCallingConfig(TypedDict):
    """Restricts whether and which functions the model may call."""

    mode: Literal['ANY', 'AUTO']
    allowed_function_names: list[str]

559 

560 

@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
class _GeminiResponse(TypedDict):
    """Schema for the response from the Gemini API.

    See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>
    and <https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerateContentResponse>
    """

    candidates: list[_GeminiCandidates]
    # usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
    usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
    prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]

573 

574 

class _GeminiCandidates(TypedDict):
    """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""

    content: _GeminiContent
    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS'], pydantic.Field(alias='finishReason')]]
    """
    See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
    but let's wait until we see them and know what they mean to add them here.
    """
    avg_log_probs: NotRequired[Annotated[float, pydantic.Field(alias='avgLogProbs')]]
    index: NotRequired[int]
    safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]]

587 

588 

class _GeminiUsageMetaData(TypedDict, total=False):
    """See <https://ai.google.dev/api/generate-content#UsageMetadata>.

    The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
    """

    prompt_token_count: Annotated[int, pydantic.Field(alias='promptTokenCount')]
    candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]]
    total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]
    cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]

599 

600 

def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
    """Extract token usage from a Gemini response, returning empty usage when no metadata is present."""
    metadata = response.get('usage_metadata')
    if metadata is None:
        return usage.Usage()

    details: dict[str, int] = {}
    cached = metadata.get('cached_content_token_count')
    if cached:
        # only recorded when non-zero, mirroring what the API actually reports
        details['cached_content_token_count'] = cached

    return usage.Usage(
        request_tokens=metadata.get('prompt_token_count', 0),
        response_tokens=metadata.get('candidates_token_count', 0),
        total_tokens=metadata.get('total_token_count', 0),
        details=details,
    )

614 

615 

class _GeminiSafetyRating(TypedDict):
    """See <https://ai.google.dev/gemini-api/docs/safety-settings#safety-filters>.

    Appears both on candidates (`safety_ratings`) and in prompt feedback.
    """

    category: Literal[
        'HARM_CATEGORY_HARASSMENT',
        'HARM_CATEGORY_HATE_SPEECH',
        'HARM_CATEGORY_SEXUALLY_EXPLICIT',
        'HARM_CATEGORY_DANGEROUS_CONTENT',
        'HARM_CATEGORY_CIVIC_INTEGRITY',
    ]
    probability: Literal['NEGLIGIBLE', 'LOW', 'MEDIUM', 'HIGH']

627 

628 

class _GeminiPromptFeedback(TypedDict):
    """See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>."""

    block_reason: Annotated[str, pydantic.Field(alias='blockReason')]
    safety_ratings: Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]

634 

635 

# TypeAdapters handle (de)serialization of the request/response TypedDicts.
_gemini_request_ta = pydantic.TypeAdapter(_GeminiRequest)
_gemini_response_ta = pydantic.TypeAdapter(_GeminiResponse)

# stream requests return a list of https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
_gemini_streamed_response_ta = pydantic.TypeAdapter(list[_GeminiResponse], config=pydantic.ConfigDict(defer_build=True))

641 

642 

643class _GeminiJsonSchema: 

644 """Transforms the JSON Schema from Pydantic to be suitable for Gemini. 

645 

646 Gemini which [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations) 

647 a subset of OpenAPI v3.0.3. 

648 

649 Specifically: 

650 * gemini doesn't allow the `title` keyword to be set 

651 * gemini doesn't allow `$defs` — we need to inline the definitions where possible 

652 """ 

653 

654 def __init__(self, schema: _utils.ObjectJsonSchema): 

655 self.schema = deepcopy(schema) 

656 self.defs = self.schema.pop('$defs', {}) 

657 

658 def simplify(self) -> dict[str, Any]: 

659 self._simplify(self.schema, refs_stack=()) 

660 return self.schema 

661 

662 def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None: 

663 schema.pop('title', None) 

664 schema.pop('default', None) 

665 if ref := schema.pop('$ref', None): 

666 # noinspection PyTypeChecker 

667 key = re.sub(r'^#/\$defs/', '', ref) 

668 if key in refs_stack: 

669 raise exceptions.UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini') 

670 refs_stack += (key,) 

671 schema_def = self.defs[key] 

672 self._simplify(schema_def, refs_stack) 

673 schema.update(schema_def) 

674 return 

675 

676 if any_of := schema.get('anyOf'): 

677 for item_schema in any_of: 

678 self._simplify(item_schema, refs_stack) 

679 if len(any_of) == 2 and {'type': 'null'} in any_of: 679 ↛ 687line 679 didn't jump to line 687 because the condition on line 679 was always true

680 for item_schema in any_of: 680 ↛ 687line 680 didn't jump to line 687 because the loop on line 680 didn't complete

681 if item_schema != {'type': 'null'}: 681 ↛ 680line 681 didn't jump to line 680 because the condition on line 681 was always true

682 schema.clear() 

683 schema.update(item_schema) 

684 schema['nullable'] = True 

685 return 

686 

687 type_ = schema.get('type') 

688 

689 if type_ == 'object': 

690 self._object(schema, refs_stack) 

691 elif type_ == 'array': 

692 return self._array(schema, refs_stack) 

693 elif type_ == 'string' and (fmt := schema.pop('format', None)): 

694 description = schema.get('description') 

695 if description: 

696 schema['description'] = f'{description} (format: {fmt})' 

697 else: 

698 schema['description'] = f'Format: {fmt}' 

699 

700 def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None: 

701 ad_props = schema.pop('additionalProperties', None) 

702 if ad_props: 702 ↛ 703line 702 didn't jump to line 703 because the condition on line 702 was never true

703 raise exceptions.UserError('Additional properties in JSON Schema are not supported by Gemini') 

704 

705 if properties := schema.get('properties'): # pragma: no branch 

706 for value in properties.values(): 

707 self._simplify(value, refs_stack) 

708 

709 def _array(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None: 

710 if prefix_items := schema.get('prefixItems'): 

711 # TODO I think this not is supported by Gemini, maybe we should raise an error? 

712 for prefix_item in prefix_items: 

713 self._simplify(prefix_item, refs_stack) 

714 

715 if items_schema := schema.get('items'): # pragma: no branch 

716 self._simplify(items_schema, refs_stack)