Coverage for tests/models/test_gemini.py: 99.70%

323 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1# pyright: reportPrivateUsage=false 

2from __future__ import annotations as _annotations 

3 

4import datetime 

5import json 

6from collections.abc import AsyncIterator, Callable, Sequence 

7from dataclasses import dataclass 

8from datetime import timezone 

9 

10import httpx 

11import pytest 

12from inline_snapshot import snapshot 

13from pydantic import BaseModel, Field 

14from typing_extensions import Literal, TypeAlias 

15 

16from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior, UserError 

17from pydantic_ai.exceptions import ModelHTTPError 

18from pydantic_ai.messages import ( 

19 BinaryContent, 

20 DocumentUrl, 

21 ImageUrl, 

22 ModelRequest, 

23 ModelResponse, 

24 RetryPromptPart, 

25 SystemPromptPart, 

26 TextPart, 

27 ToolCallPart, 

28 ToolReturnPart, 

29 UserPromptPart, 

30) 

31from pydantic_ai.models import ModelRequestParameters 

32from pydantic_ai.models.gemini import ( 

33 GeminiModel, 

34 GeminiModelSettings, 

35 _content_model_response, 

36 _gemini_response_ta, 

37 _gemini_streamed_response_ta, 

38 _GeminiCandidates, 

39 _GeminiContent, 

40 _GeminiFunction, 

41 _GeminiFunctionCallingConfig, 

42 _GeminiResponse, 

43 _GeminiSafetyRating, 

44 _GeminiToolConfig, 

45 _GeminiTools, 

46 _GeminiUsageMetaData, 

47) 

48from pydantic_ai.providers.google_gla import GoogleGLAProvider 

49from pydantic_ai.result import Usage 

50from pydantic_ai.tools import ToolDefinition 

51 

52from ..conftest import ClientWithHandler, IsNow, IsStr, TestEnv 

53 

# Module-level marker: every test in this file is marked to run via the anyio pytest plugin.
pytestmark = pytest.mark.anyio

55 

56 

async def test_model_simple(allow_model_requests: None):
    """A bare model exposes its client/name, and builds no tools or tool config when none are given."""
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
    client = model.client
    assert isinstance(client, httpx.AsyncClient)
    assert model.model_name == 'gemini-1.5-flash'
    assert 'x-goog-api-key' in client.headers

    params = ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[])
    gemini_tools = model._get_tools(params)
    config = model._get_tool_config(params, gemini_tools)
    assert gemini_tools is None
    assert config is None

68 

69 

async def test_model_tools(allow_model_requests: None):
    """Function tools and result tools are combined into one list of Gemini function declarations.

    The expected snapshot shows `title` keys stripped from every schema; because plain-text
    results are allowed, no tool config is produced.
    """
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))

    foo_tool = ToolDefinition(
        'foo',
        'This is foo',
        {'type': 'object', 'title': 'Foo', 'properties': {'bar': {'type': 'number', 'title': 'Bar'}}},
    )
    apple_tool = ToolDefinition(
        'apple',
        'This is apple',
        {
            'type': 'object',
            'properties': {
                'banana': {'type': 'array', 'title': 'Banana', 'items': {'type': 'number', 'title': 'Bar'}}
            },
        },
    )
    result_tool = ToolDefinition(
        'result',
        'This is the tool for the final Result',
        {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}, 'required': ['spam']},
    )

    params = ModelRequestParameters(
        function_tools=[foo_tool, apple_tool], allow_text_result=True, result_tools=[result_tool]
    )
    gemini_tools = model._get_tools(params)
    tool_config = model._get_tool_config(params, gemini_tools)
    assert gemini_tools == snapshot(
        _GeminiTools(
            function_declarations=[
                _GeminiFunction(
                    name='foo',
                    description='This is foo',
                    parameters={'type': 'object', 'properties': {'bar': {'type': 'number'}}},
                ),
                _GeminiFunction(
                    name='apple',
                    description='This is apple',
                    parameters={
                        'type': 'object',
                        'properties': {'banana': {'type': 'array', 'items': {'type': 'number'}}},
                    },
                ),
                _GeminiFunction(
                    name='result',
                    description='This is the tool for the final Result',
                    parameters={
                        'type': 'object',
                        'properties': {'spam': {'type': 'number'}},
                        'required': ['spam'],
                    },
                ),
            ]
        )
    )
    assert tool_config is None

127 

128 

async def test_require_response_tool(allow_model_requests: None):
    """When text results are disallowed, the tool config forces (mode ANY) a call to the result tool."""
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
    result_schema = {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}}
    result_tool = ToolDefinition('result', 'This is the tool for the final Result', result_schema)
    params = ModelRequestParameters(function_tools=[], allow_text_result=False, result_tools=[result_tool])

    gemini_tools = model._get_tools(params)
    tool_config = model._get_tool_config(params, gemini_tools)

    assert gemini_tools == snapshot(
        _GeminiTools(
            function_declarations=[
                _GeminiFunction(
                    name='result',
                    description='This is the tool for the final Result',
                    parameters={
                        'type': 'object',
                        'properties': {'spam': {'type': 'number'}},
                    },
                ),
            ]
        )
    )
    assert tool_config == snapshot(
        _GeminiToolConfig(
            function_calling_config=_GeminiFunctionCallingConfig(mode='ANY', allowed_function_names=['result'])
        )
    )

158 

159 

async def test_json_def_replaced(allow_model_requests: None):
    """`$ref`s into `$defs` are inlined when a result tool schema is converted for Gemini.

    The expected declaration has the `Location` definition substituted in place of the `$ref`,
    with `$defs`, `title` and `default` keys absent.
    """

    class Location(BaseModel):
        lat: float
        lng: float = 1.1

    class Locations(BaseModel):
        locations: list[Location]

    # Pin the pydantic-generated input so the test documents exactly what gets transformed.
    json_schema = Locations.model_json_schema()
    assert json_schema == snapshot(
        {
            '$defs': {
                'Location': {
                    'properties': {
                        'lat': {'title': 'Lat', 'type': 'number'},
                        'lng': {'default': 1.1, 'title': 'Lng', 'type': 'number'},
                    },
                    'required': ['lat'],
                    'title': 'Location',
                    'type': 'object',
                }
            },
            'properties': {'locations': {'items': {'$ref': '#/$defs/Location'}, 'title': 'Locations', 'type': 'array'}},
            'required': ['locations'],
            'title': 'Locations',
            'type': 'object',
        }
    )

    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
    result_tool = ToolDefinition(
        'result',
        'This is the tool for the final Result',
        json_schema,
    )
    # The array `items` now holds the inlined Location schema instead of a `$ref`.
    assert m._get_tools(
        ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[result_tool])
    ) == snapshot(
        _GeminiTools(
            function_declarations=[
                _GeminiFunction(
                    name='result',
                    description='This is the tool for the final Result',
                    parameters={
                        'properties': {
                            'locations': {
                                'items': {
                                    'properties': {
                                        'lat': {'type': 'number'},
                                        'lng': {'type': 'number'},
                                    },
                                    'required': ['lat'],
                                    'type': 'object',
                                },
                                'type': 'array',
                            }
                        },
                        'required': ['locations'],
                        'type': 'object',
                    },
                )
            ]
        )
    )

224 

225 

async def test_json_def_replaced_any_of(allow_model_requests: None):
    """An optional nested model (`X | None`) is inlined as an object schema marked `nullable`."""

    class Location(BaseModel):
        lat: float
        lng: float

    class Locations(BaseModel):
        op_location: Location | None = None

    schema = Locations.model_json_schema()
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
    result_tool = ToolDefinition('result', 'This is the tool for the final Result', schema)
    params = ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[result_tool])

    gemini_tools = model._get_tools(params)
    assert gemini_tools == snapshot(
        _GeminiTools(
            function_declarations=[
                _GeminiFunction(
                    name='result',
                    description='This is the tool for the final Result',
                    parameters={
                        'properties': {
                            'op_location': {
                                'properties': {
                                    'lat': {'type': 'number'},
                                    'lng': {'type': 'number'},
                                },
                                'required': ['lat', 'lng'],
                                'nullable': True,
                                'type': 'object',
                            }
                        },
                        'type': 'object',
                    },
                )
            ]
        )
    )

268 

269 

async def test_json_def_recursive(allow_model_requests: None):
    """A self-referential schema cannot be inlined, so converting it raises `UserError`."""

    class Location(BaseModel):
        lat: float
        lng: float
        nested_locations: list[Location]

    # The generated schema refers to `#/$defs/Location` from inside that same definition.
    json_schema = Location.model_json_schema()
    assert json_schema == snapshot(
        {
            '$defs': {
                'Location': {
                    'properties': {
                        'lat': {'title': 'Lat', 'type': 'number'},
                        'lng': {'title': 'Lng', 'type': 'number'},
                        'nested_locations': {
                            'items': {'$ref': '#/$defs/Location'},
                            'title': 'Nested Locations',
                            'type': 'array',
                        },
                    },
                    'required': ['lat', 'lng', 'nested_locations'],
                    'title': 'Location',
                    'type': 'object',
                }
            },
            '$ref': '#/$defs/Location',
        }
    )

    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
    result_tool = ToolDefinition(
        'result',
        'This is the tool for the final Result',
        json_schema,
    )
    with pytest.raises(UserError, match=r'Recursive `\$ref`s in JSON Schema are not supported by Gemini'):
        m._get_tools(ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[result_tool]))

307 

308 

async def test_json_def_date(allow_model_requests: None):
    """String `format` keywords (date, date-time, time, duration) are folded into `description`.

    The expected snapshot shows:
    - with no description, the result is `'Format: <fmt>'` (an empty description counts as none);
    - with an existing description, `' (format: <fmt>)'` is appended to it.
    """

    class FormattedStringFields(BaseModel):
        d: datetime.date
        dt: datetime.datetime
        t: datetime.time = Field(description='')
        td: datetime.timedelta = Field(description='my timedelta')

    json_schema = FormattedStringFields.model_json_schema()
    assert json_schema == snapshot(
        {
            'properties': {
                'd': {'format': 'date', 'title': 'D', 'type': 'string'},
                'dt': {'format': 'date-time', 'title': 'Dt', 'type': 'string'},
                't': {'format': 'time', 'title': 'T', 'type': 'string', 'description': ''},
                'td': {'format': 'duration', 'title': 'Td', 'type': 'string', 'description': 'my timedelta'},
            },
            'required': ['d', 'dt', 't', 'td'],
            'title': 'FormattedStringFields',
            'type': 'object',
        }
    )

    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
    result_tool = ToolDefinition(
        'result',
        'This is the tool for the final Result',
        json_schema,
    )
    assert m._get_tools(
        ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[result_tool])
    ) == snapshot(
        _GeminiTools(
            function_declarations=[
                _GeminiFunction(
                    description='This is the tool for the final Result',
                    name='result',
                    parameters={
                        'properties': {
                            'd': {'description': 'Format: date', 'type': 'string'},
                            'dt': {'description': 'Format: date-time', 'type': 'string'},
                            't': {'description': 'Format: time', 'type': 'string'},
                            'td': {'description': 'my timedelta (format: duration)', 'type': 'string'},
                        },
                        'required': ['d', 'dt', 't', 'td'],
                        'type': 'object',
                    },
                )
            ]
        )
    )

359 

360 

@dataclass
class AsyncByteStreamList(httpx.AsyncByteStream):
    """An `httpx` async byte stream that yields a fixed list of byte chunks in order.

    Tests use it to simulate a streamed HTTP body that arrives split into several pieces.
    """

    data: list[bytes]

    async def __aiter__(self) -> AsyncIterator[bytes]:
        # Each iteration walks the stored list from the start, one chunk per step.
        for piece in self.data:
            yield piece

368 

369 

# Canned reply for the mock HTTP handler: one JSON-serializable Gemini response, a raw byte
# stream, or a sequence of either (the fixture below serves a sequence one item per request).
ResOrList: TypeAlias = '_GeminiResponse | httpx.AsyncByteStream | Sequence[_GeminiResponse | httpx.AsyncByteStream]'
# Signature of the factory returned by the `get_gemini_client` fixture.
GetGeminiClient: TypeAlias = 'Callable[[ResOrList], httpx.AsyncClient]'

372 

373 

@pytest.fixture
async def get_gemini_client(
    client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None
) -> GetGeminiClient:
    """Provide a factory that builds an httpx client replying with canned Gemini responses.

    A single response (or byte stream) is returned for every request. A sequence is consumed
    one element per request, in order. Every request must carry a `pydantic-ai` User-Agent.
    """
    env.set('GEMINI_API_KEY', 'via-env-var')

    def create_client(response_or_list: ResOrList) -> httpx.AsyncClient:
        served = 0

        def handler(request: httpx.Request) -> httpx.Response:
            nonlocal served

            user_agent = request.headers.get('User-Agent')
            assert isinstance(user_agent, str) and user_agent.startswith('pydantic-ai')

            if not isinstance(response_or_list, Sequence):
                reply = response_or_list
            else:
                reply = response_or_list[served]
                served += 1

            headers = {'Content-Type': 'application/json'}
            if not isinstance(reply, httpx.AsyncByteStream):
                body = _gemini_response_ta.dump_json(reply, by_alias=True)
                return httpx.Response(200, content=body, stream=None, headers=headers)
            return httpx.Response(200, content=None, stream=reply, headers=headers)

        return client_with_handler(handler)

    return create_client

412 

413 

def gemini_response(content: _GeminiContent, finish_reason: Literal['STOP'] | None = 'STOP') -> _GeminiResponse:
    """Wrap `content` in a single-candidate Gemini response with fixed usage and model version.

    `finish_reason` is set on the candidate only when truthy.
    """
    candidate = _GeminiCandidates(content=content, index=0, safety_ratings=[])
    if finish_reason:  # pragma: no cover
        candidate['finish_reason'] = finish_reason
    usage = example_usage()
    return _GeminiResponse(
        candidates=[candidate],
        usage_metadata=usage,
        model_version='gemini-1.5-flash-123',
    )

419 

420 

def example_usage() -> _GeminiUsageMetaData:
    """Return fixed usage metadata: 1 prompt token plus 2 candidate tokens, 3 in total."""
    prompt_tokens, candidate_tokens = 1, 2
    return _GeminiUsageMetaData(
        prompt_token_count=prompt_tokens,
        candidates_token_count=candidate_tokens,
        total_token_count=prompt_tokens + candidate_tokens,
    )

423 

424 

async def test_text_success(get_gemini_client: GetGeminiClient):
    """A plain-text reply becomes the run result; a second run extends the message history."""
    # The same canned response is served for every request.
    response = gemini_response(_content_model_response(ModelResponse(parts=[TextPart('Hello world')])))
    gemini_client = get_gemini_client(response)
    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client))
    agent = Agent(m)

    result = await agent.run('Hello')
    assert result.data == 'Hello world'
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='Hello world')],
                # The model name comes from the response's `model_version`, not the requested model.
                model_name='gemini-1.5-flash-123',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )
    assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3))

    # Continue from the first run's messages: the history now holds two request/response pairs.
    result = await agent.run('Hello', message_history=result.new_messages())
    assert result.data == 'Hello world'
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='Hello world')],
                model_name='gemini-1.5-flash-123',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='Hello world')],
                model_name='gemini-1.5-flash-123',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )

463 

464 

async def test_request_structured_response(get_gemini_client: GetGeminiClient):
    """A `final_result` tool call is validated into the agent's `list[int]` result type."""
    response = gemini_response(
        _content_model_response(ModelResponse(parts=[ToolCallPart('final_result', {'response': [1, 2, 123]})]))
    )
    gemini_client = get_gemini_client(response)
    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client))
    agent = Agent(m, result_type=list[int])

    result = await agent.run('Hello')
    assert result.data == [1, 2, 123]
    # The run ends with a ToolReturnPart acknowledging the final-result tool call.
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[ToolCallPart(tool_name='final_result', args={'response': [1, 2, 123]}, tool_call_id=IsStr())],
                model_name='gemini-1.5-flash-123',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        timestamp=IsNow(tz=timezone.utc),
                        tool_call_id=IsStr(),
                    )
                ]
            ),
        ]
    )

495 

496 

async def test_request_tool_call(get_gemini_client: GetGeminiClient):
    """Tool retries and parallel tool calls round-trip correctly before a final text reply.

    Scripted sequence (one canned response per request):
    1. call `get_location('San Fransisco')`, where the tool raises `ModelRetry`;
    2. two calls in one response (London, New York), which both succeed;
    3. plain text `'final response'`.
    """
    responses = [
        gemini_response(
            _content_model_response(ModelResponse(parts=[ToolCallPart('get_location', {'loc_name': 'San Fransisco'})]))
        ),
        gemini_response(
            _content_model_response(
                ModelResponse(
                    parts=[
                        ToolCallPart('get_location', {'loc_name': 'London'}),
                        ToolCallPart('get_location', {'loc_name': 'New York'}),
                    ]
                )
            )
        ),
        gemini_response(_content_model_response(ModelResponse(parts=[TextPart('final response')]))),
    ]
    gemini_client = get_gemini_client(responses)
    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client))
    agent = Agent(m, system_prompt='this is the system prompt')

    @agent.tool_plain
    async def get_location(loc_name: str) -> str:
        if loc_name == 'London':
            return json.dumps({'lat': 51, 'lng': 0})
        elif loc_name == 'New York':
            return json.dumps({'lat': 41, 'lng': -74})
        else:
            # Any other location triggers a retry prompt back to the model.
            raise ModelRetry('Wrong location, please try again')

    result = await agent.run('Hello')
    assert result.data == 'final response'
    assert result.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='this is the system prompt', timestamp=IsNow(tz=timezone.utc)),
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                ]
            ),
            ModelResponse(
                parts=[
                    ToolCallPart(tool_name='get_location', args={'loc_name': 'San Fransisco'}, tool_call_id=IsStr())
                ],
                model_name='gemini-1.5-flash-123',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        content='Wrong location, please try again',
                        tool_name='get_location',
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[
                    ToolCallPart(tool_name='get_location', args={'loc_name': 'London'}, tool_call_id=IsStr()),
                    ToolCallPart(tool_name='get_location', args={'loc_name': 'New York'}, tool_call_id=IsStr()),
                ],
                model_name='gemini-1.5-flash-123',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='get_location',
                        content='{"lat": 51, "lng": 0}',
                        timestamp=IsNow(tz=timezone.utc),
                        tool_call_id=IsStr(),
                    ),
                    ToolReturnPart(
                        tool_name='get_location',
                        content='{"lat": 41, "lng": -74}',
                        timestamp=IsNow(tz=timezone.utc),
                        tool_call_id=IsStr(),
                    ),
                ]
            ),
            ModelResponse(
                parts=[TextPart(content='final response')],
                model_name='gemini-1.5-flash-123',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )
    # Three requests, each reporting example_usage() (1 prompt / 2 response tokens).
    assert result.usage() == snapshot(Usage(requests=3, request_tokens=3, response_tokens=6, total_tokens=9))

586 

587 

async def test_unexpected_response(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None):
    """A non-success HTTP status is raised as `ModelHTTPError` with status, model name and body."""
    env.set('GEMINI_API_KEY', 'via-env-var')

    def handler(_: httpx.Request):
        return httpx.Response(401, content='invalid request')

    provider = GoogleGLAProvider(http_client=client_with_handler(handler))
    agent = Agent(GeminiModel('gemini-1.5-flash', provider=provider), system_prompt='this is the system prompt')

    with pytest.raises(ModelHTTPError) as exc_info:
        await agent.run('Hello')

    message = str(exc_info.value)
    assert message == snapshot('status_code: 401, model_name: gemini-1.5-flash, body: invalid request')

602 

603 

async def test_stream_text(get_gemini_client: GetGeminiClient):
    """Streamed text accumulates across responses; `delta=True` yields only the new fragments."""
    payload = _gemini_streamed_response_ta.dump_json(
        [
            gemini_response(_content_model_response(ModelResponse(parts=[TextPart('Hello ')]))),
            gemini_response(_content_model_response(ModelResponse(parts=[TextPart('world')]))),
        ],
        by_alias=True,
    )
    # Split the body at arbitrary byte offsets so the client has to handle partial JSON.
    stream = AsyncByteStreamList([payload[:100], payload[100:200], payload[200:]])
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=get_gemini_client(stream)))
    agent = Agent(model)

    async with agent.run_stream('Hello') as result:
        cumulative = [text async for text in result.stream(debounce_by=None)]
        assert cumulative == snapshot(
            [
                'Hello ',
                'Hello world',
                # This last value is repeated due to the debounce_by=None combined with the need to emit
                # a final empty chunk to signal the end of the stream
                'Hello world',
            ]
        )
        assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))

    async with agent.run_stream('Hello') as result:
        deltas = [text async for text in result.stream_text(delta=True, debounce_by=None)]
        assert deltas == snapshot(['Hello ', 'world'])
        assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))

632 

633 

async def test_stream_invalid_unicode_text(get_gemini_client: GetGeminiClient):
    """A multi-byte character split across two stream chunks is still decoded correctly."""
    # Probably safe to remove this test once https://github.com/pydantic/pydantic-core/issues/1633 is resolved
    payload = _gemini_streamed_response_ta.dump_json(
        [
            gemini_response(_content_model_response(ModelResponse(parts=[TextPart('abc')]))),
            gemini_response(_content_model_response(ModelResponse(parts=[TextPart('€def')]))),
        ],
        by_alias=True,
    )

    # Search (from byte 10) for the first split point whose prefix is not valid UTF-8;
    # '€' is the only multi-byte character in the payload.
    for offset in range(10, 1000):
        head, tail = payload[:offset], payload[offset:]
        try:
            head.decode()
        except UnicodeDecodeError:
            break
    else:  # pragma: no cover
        assert False, 'failed to find a spot in payload that would break unicode parsing'

    with pytest.raises(UnicodeDecodeError):
        # Ensure the first part is _not_ valid unicode
        head.decode()

    gemini_client = get_gemini_client(AsyncByteStreamList([head, tail]))
    agent = Agent(GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client)))

    async with agent.run_stream('Hello') as result:
        texts = [text async for text in result.stream(debounce_by=None)]
        assert texts == snapshot(['abc', 'abc€def', 'abc€def'])
        assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))

664 

665 

async def test_stream_text_no_data(get_gemini_client: GetGeminiClient):
    """A stream whose only response has no candidates raises `UnexpectedModelBehavior`."""
    empty = _GeminiResponse(candidates=[], usage_metadata=example_usage())
    payload = _gemini_streamed_response_ta.dump_json([empty], by_alias=True)
    stream = AsyncByteStreamList([payload[:100], payload[100:200], payload[200:]])
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=get_gemini_client(stream)))
    agent = Agent(model)
    with pytest.raises(UnexpectedModelBehavior, match='Streamed response ended without con'):
        async with agent.run_stream('Hello'):
            pass

676 

677 

async def test_stream_structured(get_gemini_client: GetGeminiClient):
    """A streamed `final_result` tool call is validated into the agent's `tuple[int, int]` result."""
    final_call = ModelResponse(parts=[ToolCallPart('final_result', {'response': [1, 2]})])
    payload = _gemini_streamed_response_ta.dump_json(
        [gemini_response(_content_model_response(final_call))],
        by_alias=True,
    )
    stream = AsyncByteStreamList([payload[:100], payload[100:200], payload[200:]])
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=get_gemini_client(stream)))
    agent = Agent(model, result_type=tuple[int, int])

    async with agent.run_stream('Hello') as result:
        items = [item async for item in result.stream(debounce_by=None)]
        assert items == snapshot([(1, 2), (1, 2)])
        assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3))

694 

695 

async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient):
    """Streamed tool calls are executed, then a second streamed request yields the final result.

    Request 1 streams calls to `foo` and `bar`; request 2 streams the `final_result` call.
    """
    first_responses = [
        gemini_response(
            _content_model_response(ModelResponse(parts=[ToolCallPart('foo', {'x': 'a'})])),
        ),
        gemini_response(
            _content_model_response(ModelResponse(parts=[ToolCallPart('bar', {'y': 'b'})])),
        ),
    ]
    d1 = _gemini_streamed_response_ta.dump_json(first_responses, by_alias=True)
    first_stream = AsyncByteStreamList([d1[:100], d1[100:200], d1[200:300], d1[300:]])

    second_responses = [
        gemini_response(
            _content_model_response(ModelResponse(parts=[ToolCallPart('final_result', {'response': [1, 2]})])),
        ),
    ]
    d2 = _gemini_streamed_response_ta.dump_json(second_responses, by_alias=True)
    second_stream = AsyncByteStreamList([d2[:100], d2[100:]])

    # One stream per request, served in order.
    gemini_client = get_gemini_client([first_stream, second_stream])
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client))
    agent = Agent(model, result_type=tuple[int, int])
    # Records tool invocations so the test can check both tools actually ran.
    tool_calls: list[str] = []

    @agent.tool_plain
    async def foo(x: str) -> str:
        tool_calls.append(f'foo({x=!r})')
        return x

    @agent.tool_plain
    async def bar(y: str) -> str:
        tool_calls.append(f'bar({y=!r})')
        return y

    async with agent.run_stream('Hello') as result:
        response = await result.get_data()
        assert response == snapshot((1, 2))
    assert result.usage() == snapshot(Usage(requests=2, request_tokens=3, response_tokens=6, total_tokens=9))
    # The two streamed responses of request 1 are merged into a single ModelResponse.
    # The snapshot expects model_name='gemini-1.5-flash' here, not the '-123' version seen in
    # non-streamed tests.
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[
                    ToolCallPart(tool_name='foo', args={'x': 'a'}, tool_call_id=IsStr()),
                    ToolCallPart(tool_name='bar', args={'y': 'b'}, tool_call_id=IsStr()),
                ],
                model_name='gemini-1.5-flash',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='foo', content='a', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr()
                    ),
                    ToolReturnPart(
                        tool_name='bar', content='b', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr()
                    ),
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='final_result', args={'response': [1, 2]}, tool_call_id=IsStr())],
                model_name='gemini-1.5-flash',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        timestamp=IsNow(tz=timezone.utc),
                        tool_call_id=IsStr(),
                    )
                ]
            ),
        ]
    )
    assert tool_calls == snapshot(["foo(x='a')", "bar(y='b')"])

774 

775 

async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient):
    """Text from a chunk that also carries a function call is still collected ('Hello foo')."""
    mixed = _GeminiContent(
        role='model',
        parts=[
            {'text': 'foo'},
            {'function_call': {'name': 'get_location', 'args': {'loc_name': 'San Fransisco'}}},
        ],
    )
    payload = _gemini_streamed_response_ta.dump_json(
        [
            gemini_response(_content_model_response(ModelResponse(parts=[TextPart('Hello ')]))),
            gemini_response(mixed),
        ],
        by_alias=True,
    )
    stream = AsyncByteStreamList([payload[:100], payload[100:200], payload[200:]])
    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=get_gemini_client(stream)))
    agent = Agent(model)

    @agent.tool_plain()
    def get_location(loc_name: str) -> str:
        return f'Location for {loc_name}'

    async with agent.run_stream('Hello') as result:
        data = await result.get_data()

    assert data == 'Hello foo'

803 

804 

async def test_empty_text_ignored():
    """`_content_model_response` keeps non-empty text parts but drops empty ones."""
    content = _content_model_response(
        ModelResponse(parts=[ToolCallPart('final_result', {'response': [1, 2, 123]}), TextPart(content='xxx')])
    )
    # text included
    assert content == snapshot(
        {
            'role': 'model',
            'parts': [
                {'function_call': {'name': 'final_result', 'args': {'response': [1, 2, 123]}}},
                {'text': 'xxx'},
            ],
        }
    )

    content = _content_model_response(
        ModelResponse(parts=[ToolCallPart('final_result', {'response': [1, 2, 123]}), TextPart(content='')])
    )
    # text skipped
    assert content == snapshot(
        {
            'role': 'model',
            'parts': [{'function_call': {'name': 'final_result', 'args': {'response': [1, 2, 123]}}}],
        }
    )

830 

831 

async def test_model_settings(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None) -> None:
    """Generic model settings are sent as Gemini's `generation_config` (`max_tokens` -> `max_output_tokens`)."""
    expected_generation_config = {
        'max_output_tokens': 1,
        'temperature': 0.1,
        'top_p': 0.2,
        'presence_penalty': 0.3,
        'frequency_penalty': 0.4,
    }

    def handler(request: httpx.Request) -> httpx.Response:
        body = json.loads(request.content)
        assert body['generation_config'] == expected_generation_config
        reply = gemini_response(_content_model_response(ModelResponse(parts=[TextPart('world')])))
        return httpx.Response(
            200,
            content=_gemini_response_ta.dump_json(reply, by_alias=True),
            headers={'Content-Type': 'application/json'},
        )

    provider = GoogleGLAProvider(http_client=client_with_handler(handler), api_key='mock')
    agent = Agent(GeminiModel('gemini-1.5-flash', provider=provider))

    result = await agent.run(
        'hello',
        model_settings={
            'max_tokens': 1,
            'temperature': 0.1,
            'top_p': 0.2,
            'presence_penalty': 0.3,
            'frequency_penalty': 0.4,
        },
    )
    assert result.data == 'world'

866 

867 

def gemini_no_content_response(
    safety_ratings: list[_GeminiSafetyRating], finish_reason: Literal['SAFETY'] | None = 'SAFETY'
) -> _GeminiResponse:
    """Build a single-candidate Gemini response with no content, only safety ratings.

    Args:
        safety_ratings: Ratings attached to the candidate.
        finish_reason: Set on the candidate when truthy; left out entirely when `None`.

    Returns:
        A response with that one candidate and `example_usage()` usage metadata.
    """
    candidate = _GeminiCandidates(safety_ratings=safety_ratings)
    # Within this module the condition is always true, so its false branch is never taken.
    # `no branch` stops branch coverage from reporting this line as a partial branch.
    if finish_reason:  # pragma: no branch
        candidate['finish_reason'] = finish_reason
    return _GeminiResponse(candidates=[candidate], usage_metadata=example_usage())

875 

876 

async def test_safety_settings_unsafe(
    client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None
) -> None:
    """A response blocked by safety filters makes the run raise `UnexpectedModelBehavior`.

    The mock handler checks that `gemini_safety_settings` are forwarded as `safety_settings` in
    the request body. It then replies with a content-less candidate whose `finish_reason` is
    `'SAFETY'`.
    """

    def handler(request: httpx.Request) -> httpx.Response:
        safety_settings = json.loads(request.content)['safety_settings']
        assert safety_settings == [
            {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
            {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
        ]

        return httpx.Response(
            200,
            content=_gemini_response_ta.dump_json(
                gemini_no_content_response(
                    finish_reason='SAFETY',
                    safety_ratings=[
                        {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'MEDIUM', 'blocked': True}
                    ],
                ),
                by_alias=True,
            ),
            headers={'Content-Type': 'application/json'},
        )

    gemini_client = client_with_handler(handler)

    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client, api_key='mock'))
    agent = Agent(m)

    # `pytest.raises` makes the test fail if the run completes without raising; a bare
    # try/except would let that case pass silently.
    with pytest.raises(UnexpectedModelBehavior) as exc_info:
        await agent.run(
            'a request for something rude',
            model_settings=GeminiModelSettings(
                gemini_safety_settings=[
                    {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
                    {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
                ]
            ),
        )
    assert repr(exc_info.value) == "UnexpectedModelBehavior('Safety settings triggered')"

919 

920 

async def test_safety_settings_safe(
    client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None
) -> None:
    """With safety settings configured but not triggered, the text reply is returned normally."""
    expected_safety_settings = [
        {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
        {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
    ]

    def handler(request: httpx.Request) -> httpx.Response:
        body = json.loads(request.content)
        assert body['safety_settings'] == expected_safety_settings

        reply = gemini_response(_content_model_response(ModelResponse(parts=[TextPart('world')])))
        return httpx.Response(
            200,
            content=_gemini_response_ta.dump_json(reply, by_alias=True),
            headers={'Content-Type': 'application/json'},
        )

    provider = GoogleGLAProvider(http_client=client_with_handler(handler), api_key='mock')
    agent = Agent(GeminiModel('gemini-1.5-flash', provider=provider))

    settings = GeminiModelSettings(
        gemini_safety_settings=[
            {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
            {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
        ]
    )
    result = await agent.run('hello', model_settings=settings)
    assert result.data == 'world'

954 

955 

@pytest.mark.vcr()
async def test_image_as_binary_content_input(
    allow_model_requests: None, gemini_api_key: str, image_content: BinaryContent
) -> None:
    """An agent accepts a text prompt combined with binary image content (HTTP via the `vcr` marker)."""
    agent = Agent(GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)))

    prompt = ['What is the name of this fruit?', image_content]
    result = await agent.run(prompt)
    assert result.data == snapshot('The fruit in the image is a kiwi.')

965 

966 

@pytest.mark.vcr()
async def test_image_url_input(allow_model_requests: None, gemini_api_key: str) -> None:
    """An agent accepts a text prompt combined with an `ImageUrl` (HTTP via the `vcr` marker)."""
    agent = Agent(GeminiModel('gemini-2.0-flash-exp', provider=GoogleGLAProvider(api_key=gemini_api_key)))

    prompt = ['What is the name of this fruit?', ImageUrl(url='https://goo.gle/instrument-img')]
    result = await agent.run(prompt)
    assert result.data == snapshot("This is not a fruit; it's a pipe organ console.")

976 

977 

@pytest.mark.vcr()
async def test_document_url_input(allow_model_requests: None, gemini_api_key: str) -> None:
    """An agent accepts a text prompt combined with a PDF `DocumentUrl` (HTTP via the `vcr` marker)."""
    provider = GoogleGLAProvider(api_key=gemini_api_key)
    agent = Agent(GeminiModel('gemini-2.0-flash-thinking-exp-01-21', provider=provider))

    pdf = DocumentUrl(url='https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf')
    result = await agent.run(['What is the main content on this document?', pdf])
    assert result.data == snapshot('The main content of this document is that it is a **dummy PDF file**.')