Coverage for pydantic_ai_slim/pydantic_ai/models/gemini.py: 93.66%

355 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-01-30 19:21 +0000

1from __future__ import annotations as _annotations 

2 

3import os 

4import re 

5from collections.abc import AsyncIterator, Sequence 

6from contextlib import asynccontextmanager 

7from copy import deepcopy 

8from dataclasses import dataclass, field 

9from datetime import datetime 

10from typing import Annotated, Any, Literal, Protocol, Union, cast 

11from uuid import uuid4 

12 

13import pydantic 

14from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse 

15from typing_extensions import NotRequired, TypedDict, assert_never 

16 

17from .. import UnexpectedModelBehavior, _utils, exceptions, usage 

18from ..messages import ( 

19 ModelMessage, 

20 ModelRequest, 

21 ModelResponse, 

22 ModelResponsePart, 

23 ModelResponseStreamEvent, 

24 RetryPromptPart, 

25 SystemPromptPart, 

26 TextPart, 

27 ToolCallPart, 

28 ToolReturnPart, 

29 UserPromptPart, 

30) 

31from ..settings import ModelSettings 

32from ..tools import ToolDefinition 

33from . import ( 

34 AgentModel, 

35 Model, 

36 StreamedResponse, 

37 cached_async_http_client, 

38 check_allow_model_requests, 

39 get_user_agent, 

40) 

41 

# Closed set of model identifiers this module accepts; passing any other string
# is a type error, keeping callers honest about which models are supported.
GeminiModelName = Literal[
    'gemini-1.5-flash',
    'gemini-1.5-flash-8b',
    'gemini-1.5-pro',
    'gemini-1.0-pro',
    'gemini-2.0-flash-exp',
    'gemini-2.0-flash-thinking-exp-01-21',
    'gemini-exp-1206',
]
"""Named Gemini models.

See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
"""

55 

56 

class GeminiModelSettings(ModelSettings):
    """Settings used for a Gemini model request.

    Currently identical to the base `ModelSettings`; it exists so Gemini-only
    options can be added later without changing callers' type annotations.
    """

    # This class is a placeholder for any future gemini-specific settings

62 

@dataclass(init=False)
class GeminiModel(Model):
    """A model that talks to Gemini through the `generativelanguage.googleapis.com` REST API.

    No SDK is used; requests are built by hand against the documented API
    ([reference](https://ai.google.dev/api)).

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    model_name: GeminiModelName
    auth: AuthProtocol
    http_client: AsyncHTTPClient
    url: str

    def __init__(
        self,
        model_name: GeminiModelName,
        *,
        api_key: str | None = None,
        http_client: AsyncHTTPClient | None = None,
        url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
    ):
        """Initialize a Gemini model.

        Args:
            model_name: The name of the model to use.
            api_key: The API key to use for authentication, if not provided, the `GEMINI_API_KEY` environment variable
                will be used if available.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
            url_template: The URL template to use for making requests, you shouldn't need to change this,
                docs [here](https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request),
                `model` is substituted with the model name, and `function` is added to the end of the URL.

        Raises:
            UserError: If no API key is given and `GEMINI_API_KEY` is unset or empty.
        """
        self.model_name = model_name
        if api_key is None:
            # Fall back to the environment; an empty string counts as "not set".
            env_key = os.getenv('GEMINI_API_KEY')
            if not env_key:
                raise exceptions.UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
            api_key = env_key
        self.auth = ApiKeyAuth(api_key)
        self.http_client = http_client or cached_async_http_client()
        self.url = url_template.format(model=model_name)

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> GeminiAgentModel:
        """Create a `GeminiAgentModel` bound to this model's client, auth and URL."""
        # Fail fast when model requests are globally disallowed (e.g. during tests).
        check_allow_model_requests()
        return GeminiAgentModel(
            http_client=self.http_client,
            model_name=self.model_name,
            auth=self.auth,
            url=self.url,
            function_tools=function_tools,
            allow_text_result=allow_text_result,
            result_tools=result_tools,
        )

    def name(self) -> str:
        """Return the qualified model name, prefixed with the `google-gla` provider tag."""
        return f'google-gla:{self.model_name}'

127 

128 

class AuthProtocol(Protocol):
    """Abstract definition for Gemini authentication.

    Implementations return the HTTP headers that authenticate a single request.
    """

    async def headers(self) -> dict[str, str]: ...

133 

134 

@dataclass
class ApiKeyAuth:
    """API-key based authentication, sent via the `X-Goog-Api-Key` header."""

    api_key: str

    async def headers(self) -> dict[str, str]:
        """Return the auth headers for one request.

        See <https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest>.
        """
        auth_headers = {'X-Goog-Api-Key': self.api_key}
        return auth_headers

144 

145 

@dataclass(init=False)
class GeminiAgentModel(AgentModel):
    """Implementation of `AgentModel` for Gemini models.

    Holds the prepared tool declarations and tool config so each request only has
    to serialize the messages.
    """

    http_client: AsyncHTTPClient
    model_name: GeminiModelName
    auth: AuthProtocol
    tools: _GeminiTools | None
    tool_config: _GeminiToolConfig | None
    url: str

    def __init__(
        self,
        http_client: AsyncHTTPClient,
        model_name: GeminiModelName,
        auth: AuthProtocol,
        url: str,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ):
        # Function tools and result tools are presented to Gemini as one flat list of declarations.
        tools = [_function_from_abstract_tool(t) for t in function_tools]
        if result_tools:
            tools += [_function_from_abstract_tool(t) for t in result_tools]

        if allow_text_result:
            tool_config = None
        else:
            # Plain-text replies are not acceptable, so force Gemini to call one of the declared functions.
            tool_config = _tool_config([t['name'] for t in tools])

        self.http_client = http_client
        self.model_name = model_name
        self.auth = auth
        self.tools = _GeminiTools(function_declarations=tools) if tools else None
        self.tool_config = tool_config
        self.url = url

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, usage.Usage]:
        """Make a non-streaming request and return the parsed response with its usage."""
        async with self._make_request(
            messages, False, cast(GeminiModelSettings, model_settings or {})
        ) as http_response:
            response = _gemini_response_ta.validate_json(await http_response.aread())
            return self._process_response(response), _metadata_as_usage(response)

    @asynccontextmanager
    async def request_stream(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming request, yielding a `StreamedResponse` once the stream has started."""
        async with self._make_request(messages, True, cast(GeminiModelSettings, model_settings or {})) as http_response:
            yield await self._process_streamed_response(http_response)

    @asynccontextmanager
    async def _make_request(
        self, messages: list[ModelMessage], streamed: bool, model_settings: GeminiModelSettings
    ) -> AsyncIterator[HTTPResponse]:
        """Build the request body, POST it, and yield the open HTTP response.

        Raises:
            UnexpectedModelBehavior: If the API returns a non-200 status code.
        """
        sys_prompt_parts, contents = self._message_to_gemini_content(messages)

        request_data = _GeminiRequest(contents=contents)
        if sys_prompt_parts:
            # System prompts go in the dedicated `system_instruction` field, not in `contents`.
            request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
        if self.tools is not None:
            request_data['tools'] = self.tools
        if self.tool_config is not None:
            request_data['tool_config'] = self.tool_config

        # Only forward settings that were explicitly provided, so Gemini's own defaults apply otherwise.
        generation_config: _GeminiGenerationConfig = {}
        if model_settings:
            if (max_tokens := model_settings.get('max_tokens')) is not None:
                generation_config['max_output_tokens'] = max_tokens
            if (temperature := model_settings.get('temperature')) is not None:
                generation_config['temperature'] = temperature
            if (top_p := model_settings.get('top_p')) is not None:
                generation_config['top_p'] = top_p
            if (presence_penalty := model_settings.get('presence_penalty')) is not None:
                generation_config['presence_penalty'] = presence_penalty
            if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
                generation_config['frequency_penalty'] = frequency_penalty
        if generation_config:
            request_data['generation_config'] = generation_config

        url = self.url + ('streamGenerateContent' if streamed else 'generateContent')

        headers = {
            'Content-Type': 'application/json',
            'User-Agent': get_user_agent(),
            **await self.auth.headers(),
        }

        request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)

        async with self.http_client.stream(
            'POST',
            url,
            content=request_json,
            headers=headers,
            timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
        ) as r:
            if r.status_code != 200:
                # Read the body so it can be included in the error message.
                await r.aread()
                raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
            yield r

    def _process_response(self, response: _GeminiResponse) -> ModelResponse:
        """Convert a validated API response into a `ModelResponse`.

        Raises:
            UnexpectedModelBehavior: If the response does not contain exactly one candidate.
        """
        if len(response['candidates']) != 1:
            raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
        parts = response['candidates'][0]['content']['parts']
        return _process_response_from_parts(parts, model_name=self.model_name)

    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        aiter_bytes = http_response.aiter_bytes()
        start_response: _GeminiResponse | None = None
        content = bytearray()

        # Buffer chunks until at least one response with parts can be parsed; partial JSON
        # is tolerated via pydantic's `experimental_allow_partial`.
        async for chunk in aiter_bytes:
            content.extend(chunk)
            responses = _gemini_streamed_response_ta.validate_json(
                content,
                experimental_allow_partial='trailing-strings',
            )
            if responses:
                last = responses[-1]
                if last['candidates'] and last['candidates'][0]['content']['parts']:
                    start_response = last
                    break

        if start_response is None:
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        # `content` holds the bytes consumed so far; the rest of the stream is drained lazily.
        return GeminiStreamedResponse(_model_name=self.model_name, _content=content, _stream=aiter_bytes)

    @classmethod
    def _message_to_gemini_content(
        cls, messages: list[ModelMessage]
    ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
        """Split messages into system-prompt parts and Gemini `contents` entries."""
        sys_prompt_parts: list[_GeminiTextPart] = []
        contents: list[_GeminiContent] = []
        for m in messages:
            if isinstance(m, ModelRequest):
                message_parts: list[_GeminiPartUnion] = []

                for part in m.parts:
                    if isinstance(part, SystemPromptPart):
                        # System prompts are collected separately; they go in `system_instruction`.
                        sys_prompt_parts.append(_GeminiTextPart(text=part.content))
                    elif isinstance(part, UserPromptPart):
                        message_parts.append(_GeminiTextPart(text=part.content))
                    elif isinstance(part, ToolReturnPart):
                        message_parts.append(_response_part_from_response(part.tool_name, part.model_response_object()))
                    elif isinstance(part, RetryPromptPart):
                        if part.tool_name is None:
                            # A retry not tied to a tool is sent as plain text.
                            message_parts.append(_GeminiTextPart(text=part.model_response()))
                        else:
                            # A tool-specific retry is sent as a function response carrying the error.
                            response = {'call_error': part.model_response()}
                            message_parts.append(_response_part_from_response(part.tool_name, response))
                    else:
                        assert_never(part)

                if message_parts:
                    contents.append(_GeminiContent(role='user', parts=message_parts))
            elif isinstance(m, ModelResponse):
                contents.append(_content_model_response(m))
            else:
                assert_never(m)

        return sys_prompt_parts, contents

313 

314 

@dataclass
class GeminiStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for the Gemini model."""

    # Bytes received so far; grows as the stream is consumed and is re-parsed as partial JSON.
    _content: bytearray
    # The remaining (not yet consumed) byte stream from the HTTP response.
    _stream: AsyncIterator[bytes]
    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Yield stream events for each text/tool-call part of each completed response."""
        async for gemini_response in self._get_gemini_responses():
            candidate = gemini_response['candidates'][0]
            gemini_part: _GeminiPartUnion
            for gemini_part in candidate['content']['parts']:
                if 'text' in gemini_part:
                    # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
                    # amongst the tool call deltas
                    yield self._parts_manager.handle_text_delta(vendor_part_id=None, content=gemini_part['text'])

                elif 'function_call' in gemini_part:
                    # Here, we assume all function_call parts are complete and don't have deltas.
                    # We do this by assigning a unique randomly generated "vendor_part_id".
                    # We need to confirm whether this is actually true, but if it isn't, we can still handle it properly
                    # it would just be a bit more complicated. And we'd need to confirm the intended semantics.
                    maybe_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=uuid4(),
                        tool_name=gemini_part['function_call']['name'],
                        args=gemini_part['function_call']['args'],
                        tool_call_id=None,
                    )
                    if maybe_event is not None:
                        yield maybe_event
                else:
                    assert 'function_response' in gemini_part, f'Unexpected part: {gemini_part}'

    async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]:
        # This method exists to ensure we only yield completed items, so we don't need to worry about
        # partial gemini responses, which would make everything more complicated

        gemini_responses: list[_GeminiResponse] = []
        current_gemini_response_index = 0
        # Right now, there are some circumstances where we will have information that could be yielded sooner than it is
        # But changing that would make things a lot more complicated.
        async for chunk in self._stream:
            self._content.extend(chunk)

            # Re-validate the whole accumulated buffer; the last entry may still be partial.
            gemini_responses = _gemini_streamed_response_ta.validate_json(
                self._content,
                experimental_allow_partial='trailing-strings',
            )

            # The idea: yield only up to the latest response, which might still be partial.
            # Note that if the latest response is complete, we could yield it immediately, but there's not a good
            # allow_partial API to determine if the last item in the list is complete.
            responses_to_yield = gemini_responses[:-1]
            for r in responses_to_yield[current_gemini_response_index:]:
                current_gemini_response_index += 1
                self._usage += _metadata_as_usage(r)
                yield r

        # Now yield the final response, which should be complete
        if gemini_responses:
            r = gemini_responses[-1]
            self._usage += _metadata_as_usage(r)
            yield r

    def timestamp(self) -> datetime:
        """Return the (UTC) time this streamed response object was created."""
        return self._timestamp

382 

383 

384# We use typed dicts to define the Gemini API response schema 

385# once Pydantic partial validation supports, dataclasses, we could revert to using them 

386# TypeAdapters take care of validation and serialization 

387 

388 

@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
class _GeminiRequest(TypedDict):
    """Schema for an API request to the Gemini API.

    See <https://ai.google.dev/api/generate-content#request-body> for API docs.
    """

    contents: list[_GeminiContent]
    tools: NotRequired[_GeminiTools]
    tool_config: NotRequired[_GeminiToolConfig]
    # only a subset of `generationConfig` fields are supported, see `_GeminiGenerationConfig`
    system_instruction: NotRequired[_GeminiTextContent]
    """
    Developer generated system instructions, see
    <https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest>
    """
    generation_config: NotRequired[_GeminiGenerationConfig]

407 

class _GeminiGenerationConfig(TypedDict, total=False):
    """Schema for an API request to the Gemini API.

    Note there are many additional fields available that have not been added yet.

    See <https://ai.google.dev/api/generate-content#generationconfig> for API docs.
    """

    # all fields optional (total=False); only explicitly-set values are serialized
    max_output_tokens: int
    temperature: float
    top_p: float
    presence_penalty: float
    frequency_penalty: float

421 

422 

class _GeminiContent(TypedDict):
    """One turn of conversation content, see <https://ai.google.dev/api/caching#Content>."""

    role: Literal['user', 'model']
    parts: list[_GeminiPartUnion]

426 

427 

def _content_model_response(m: ModelResponse) -> _GeminiContent:
    """Convert a `ModelResponse` into Gemini `content` with the `model` role."""
    converted: list[_GeminiPartUnion] = []
    for response_part in m.parts:
        if isinstance(response_part, ToolCallPart):
            converted.append(_function_call_part_from_call(response_part))
        elif isinstance(response_part, TextPart):
            # empty text parts are dropped
            if response_part.content:
                converted.append(_GeminiTextPart(text=response_part.content))
        else:
            assert_never(response_part)
    return _GeminiContent(role='model', parts=converted)

439 

440 

class _GeminiTextPart(TypedDict):
    """A plain-text part, see <https://ai.google.dev/api/caching#Part>."""

    text: str

443 

444 

class _GeminiFunctionCallPart(TypedDict):
    """A part holding a function call; serialized as camelCase `functionCall` via the alias."""

    function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]

447 

448 

def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:
    """Build a Gemini `functionCall` part from a `ToolCallPart`."""
    call = _GeminiFunctionCall(name=tool.tool_name, args=tool.args_as_dict())
    return _GeminiFunctionCallPart(function_call=call)

451 

452 

def _process_response_from_parts(
    parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, timestamp: datetime | None = None
) -> ModelResponse:
    """Convert Gemini response parts into a `ModelResponse`.

    Args:
        parts: The parts of the single response candidate.
        model_name: Name recorded on the resulting `ModelResponse`.
        timestamp: Response timestamp; defaults to the current UTC time.

    Raises:
        UnexpectedModelBehavior: If a `function_response` part is encountered — those are
            only ever sent by us, never returned by the model.
    """
    items: list[ModelResponsePart] = []
    for part in parts:
        if 'text' in part:
            items.append(TextPart(content=part['text']))
        elif 'function_call' in part:
            call = part['function_call']
            items.append(ToolCallPart(tool_name=call['name'], args=call['args']))
        elif 'function_response' in part:
            raise exceptions.UnexpectedModelBehavior(
                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
            )
    return ModelResponse(parts=items, model_name=model_name, timestamp=timestamp or _utils.now_utc())

472 

473 

class _GeminiFunctionCall(TypedDict):
    """See <https://ai.google.dev/api/caching#FunctionCall>."""

    name: str
    args: dict[str, Any]

479 

480 

class _GeminiFunctionResponsePart(TypedDict):
    """A part holding a tool's result; serialized as camelCase `functionResponse` via the alias."""

    function_response: Annotated[_GeminiFunctionResponse, pydantic.Field(alias='functionResponse')]

483 

484 

def _response_part_from_response(name: str, response: dict[str, Any]) -> _GeminiFunctionResponsePart:
    """Wrap a tool result dict in a Gemini `functionResponse` part."""
    function_response = _GeminiFunctionResponse(name=name, response=response)
    return _GeminiFunctionResponsePart(function_response=function_response)

487 

488 

class _GeminiFunctionResponse(TypedDict):
    """See <https://ai.google.dev/api/caching#FunctionResponse>."""

    name: str
    response: dict[str, Any]

494 

495 

496def _part_discriminator(v: Any) -> str: 

497 if isinstance(v, dict): 497 ↛ 504line 497 didn't jump to line 504 because the condition on line 497 was always true

498 if 'text' in v: 

499 return 'text' 

500 elif 'functionCall' in v or 'function_call' in v: 

501 return 'function_call' 

502 elif 'functionResponse' in v or 'function_response' in v: 

503 return 'function_response' 

504 return 'text' 

505 

506 

# See <https://ai.google.dev/api/caching#Part>
# we don't currently support other part types
# The callable discriminator lets validation accept both camelCase (wire) and
# snake_case (Python-side) key spellings for each part kind.
_GeminiPartUnion = Annotated[
    Union[
        Annotated[_GeminiTextPart, pydantic.Tag('text')],
        Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
        Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
    ],
    pydantic.Discriminator(_part_discriminator),
]

518 

519 

class _GeminiTextContent(TypedDict):
    """Content restricted to text-only parts; used for `system_instruction`."""

    role: Literal['user', 'model']
    parts: list[_GeminiTextPart]

523 

524 

class _GeminiTools(TypedDict):
    """Tool declarations; serialized as camelCase `functionDeclarations` via the alias."""

    function_declarations: list[Annotated[_GeminiFunction, pydantic.Field(alias='functionDeclarations')]]

527 

528 

class _GeminiFunction(TypedDict):
    """A single function (tool) declaration sent to Gemini."""

    name: str
    description: str
    # omitted entirely when the tool takes no parameters
    parameters: NotRequired[dict[str, Any]]
    """
    ObjectJsonSchema isn't really true since Gemini only accepts a subset of JSON Schema
    <https://ai.google.dev/gemini-api/docs/function-calling#function_declarations>
    and
    <https://ai.google.dev/api/caching#FunctionDeclaration>
    """

539 

540 

def _function_from_abstract_tool(tool: ToolDefinition) -> _GeminiFunction:
    """Convert a `ToolDefinition` into a Gemini function declaration.

    The tool's JSON schema is simplified to the subset Gemini accepts; `parameters`
    is omitted entirely when the schema has no properties.
    """
    simplified_schema = _GeminiJsonSchema(tool.parameters_json_schema).simplify()
    function = _GeminiFunction(name=tool.name, description=tool.description)
    if simplified_schema.get('properties'):
        function['parameters'] = simplified_schema
    return function

550 

551 

class _GeminiToolConfig(TypedDict):
    """Tool-calling configuration, see <https://ai.google.dev/api/caching#ToolConfig>."""

    function_calling_config: _GeminiFunctionCallingConfig

554 

555 

def _tool_config(function_names: list[str]) -> _GeminiToolConfig:
    """Build a tool config forcing Gemini (`mode='ANY'`) to call one of `function_names`."""
    calling_config = _GeminiFunctionCallingConfig(mode='ANY', allowed_function_names=function_names)
    return _GeminiToolConfig(function_calling_config=calling_config)

560 

561 

class _GeminiFunctionCallingConfig(TypedDict):
    """See <https://ai.google.dev/api/caching#FunctionCallingConfig>; 'ANY' forces a function call."""

    mode: Literal['ANY', 'AUTO']
    allowed_function_names: list[str]

565 

566 

@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
class _GeminiResponse(TypedDict):
    """Schema for the response from the Gemini API.

    See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>
    and <https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerateContentResponse>
    """

    candidates: list[_GeminiCandidates]
    # usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
    usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
    prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]

579 

580 

class _GeminiCandidates(TypedDict):
    """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""

    content: _GeminiContent
    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS'], pydantic.Field(alias='finishReason')]]
    """
    See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
    but let's wait until we see them and know what they mean to add them here.
    """
    avg_log_probs: NotRequired[Annotated[float, pydantic.Field(alias='avgLogProbs')]]
    index: NotRequired[int]
    safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]]

593 

594 

class _GeminiUsageMetaData(TypedDict, total=False):
    """See <https://ai.google.dev/api/generate-content#FinishReason>.

    The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
    """

    prompt_token_count: Annotated[int, pydantic.Field(alias='promptTokenCount')]
    candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]]
    total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]
    cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]

605 

606 

def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
    """Translate a response's `usage_metadata` into a `usage.Usage`.

    Returns an empty `Usage` when the metadata is absent, which happens on all
    but the final chunk of a streamed response.
    """
    metadata = response.get('usage_metadata')
    if metadata is None:
        return usage.Usage()
    details: dict[str, int] = {}
    cached_tokens = metadata.get('cached_content_token_count')
    if cached_tokens:
        details['cached_content_token_count'] = cached_tokens
    return usage.Usage(
        request_tokens=metadata.get('prompt_token_count', 0),
        response_tokens=metadata.get('candidates_token_count', 0),
        total_tokens=metadata.get('total_token_count', 0),
        details=details,
    )

620 

621 

class _GeminiSafetyRating(TypedDict):
    """See <https://ai.google.dev/gemini-api/docs/safety-settings#safety-filters>."""

    category: Literal[
        'HARM_CATEGORY_HARASSMENT',
        'HARM_CATEGORY_HATE_SPEECH',
        'HARM_CATEGORY_SEXUALLY_EXPLICIT',
        'HARM_CATEGORY_DANGEROUS_CONTENT',
        'HARM_CATEGORY_CIVIC_INTEGRITY',
    ]
    probability: Literal['NEGLIGIBLE', 'LOW', 'MEDIUM', 'HIGH']

633 

634 

class _GeminiPromptFeedback(TypedDict):
    """See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>."""

    block_reason: Annotated[str, pydantic.Field(alias='blockReason')]
    safety_ratings: Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]

640 

641 

# Module-level TypeAdapters so validation/serialization schemas are built once, not per request.
_gemini_request_ta = pydantic.TypeAdapter(_GeminiRequest)
_gemini_response_ta = pydantic.TypeAdapter(_GeminiResponse)

# stream requests return a list of https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
_gemini_streamed_response_ta = pydantic.TypeAdapter(list[_GeminiResponse], config=pydantic.ConfigDict(defer_build=True))

647 

648 

649class _GeminiJsonSchema: 

650 """Transforms the JSON Schema from Pydantic to be suitable for Gemini. 

651 

652 Gemini which [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations) 

653 a subset of OpenAPI v3.0.3. 

654 

655 Specifically: 

656 * gemini doesn't allow the `title` keyword to be set 

657 * gemini doesn't allow `$defs` — we need to inline the definitions where possible 

658 """ 

659 

660 def __init__(self, schema: _utils.ObjectJsonSchema): 

661 self.schema = deepcopy(schema) 

662 self.defs = self.schema.pop('$defs', {}) 

663 

664 def simplify(self) -> dict[str, Any]: 

665 self._simplify(self.schema, refs_stack=()) 

666 return self.schema 

667 

668 def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None: 

669 schema.pop('title', None) 

670 schema.pop('default', None) 

671 if ref := schema.pop('$ref', None): 

672 # noinspection PyTypeChecker 

673 key = re.sub(r'^#/\$defs/', '', ref) 

674 if key in refs_stack: 

675 raise exceptions.UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini') 

676 refs_stack += (key,) 

677 schema_def = self.defs[key] 

678 self._simplify(schema_def, refs_stack) 

679 schema.update(schema_def) 

680 return 

681 

682 if any_of := schema.get('anyOf'): 

683 for item_schema in any_of: 

684 self._simplify(item_schema, refs_stack) 

685 if len(any_of) == 2 and {'type': 'null'} in any_of: 685 ↛ 693line 685 didn't jump to line 693 because the condition on line 685 was always true

686 for item_schema in any_of: 686 ↛ 693line 686 didn't jump to line 693 because the loop on line 686 didn't complete

687 if item_schema != {'type': 'null'}: 687 ↛ 686line 687 didn't jump to line 686 because the condition on line 687 was always true

688 schema.clear() 

689 schema.update(item_schema) 

690 schema['nullable'] = True 

691 return 

692 

693 type_ = schema.get('type') 

694 

695 if type_ == 'object': 

696 self._object(schema, refs_stack) 

697 elif type_ == 'array': 

698 return self._array(schema, refs_stack) 

699 elif type_ == 'string' and (fmt := schema.pop('format', None)): 

700 description = schema.get('description') 

701 if description: 

702 schema['description'] = f'{description} (format: {fmt})' 

703 else: 

704 schema['description'] = f'Format: {fmt}' 

705 

706 def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None: 

707 ad_props = schema.pop('additionalProperties', None) 

708 if ad_props: 708 ↛ 709line 708 didn't jump to line 709 because the condition on line 708 was never true

709 raise exceptions.UserError('Additional properties in JSON Schema are not supported by Gemini') 

710 

711 if properties := schema.get('properties'): # pragma: no branch 

712 for value in properties.values(): 

713 self._simplify(value, refs_stack) 

714 

715 def _array(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None: 

716 if prefix_items := schema.get('prefixItems'): 

717 # TODO I think this not is supported by Gemini, maybe we should raise an error? 

718 for prefix_item in prefix_items: 

719 self._simplify(prefix_item, refs_stack) 

720 

721 if items_schema := schema.get('items'): # pragma: no branch 

722 self._simplify(items_schema, refs_stack)