Coverage for tests/providers/test_google_vertex.py: 98.36%
59 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
from __future__ import annotations as _annotations

import json
from dataclasses import dataclass
from pathlib import Path
from unittest.mock import MagicMock, patch

import httpx
import pytest
from inline_snapshot import snapshot

from ..conftest import try_import
14with try_import() as imports_successful:
15 from google.auth.transport.requests import Request
17 from pydantic_ai.providers.google_vertex import GoogleVertexProvider
# Module-wide marks: skip everything when the optional google dependencies failed
# to import, and run the async tests through the anyio pytest plugin.
pytestmark = [pytest.mark.skipif(not imports_successful(), reason='google-genai not installed'), pytest.mark.anyio()]
@pytest.fixture()
def http_client():
    """Provide an ``httpx.AsyncClient`` backed by a mock transport.

    The transport answers exactly one path, the Vertex ``generateContent`` endpoint for
    ``gemini-1.0-pro`` in project ``my-project-id`` / region ``us-central1``, with a 200
    JSON response. Any other request raises ``NotImplementedError``.
    """
    generate_content_path = (
        '/v1/projects/my-project-id/locations/us-central1/publishers/google/models/gemini-1.0-pro:generateContent'
    )

    async def handler(request: httpx.Request):
        if request.url.path != generate_content_path:
            raise NotImplementedError(f'Unexpected request: {request.url!r}')  # pragma: no cover
        return httpx.Response(200, json={'content': 'success'})

    mock_transport = httpx.MockTransport(handler=handler)
    return httpx.AsyncClient(transport=mock_transport)
def test_google_vertex_provider(allow_model_requests: None) -> None:
    """A provider built without arguments reports its name, default base URL and async client.

    The snapshot records the URL as it is today, including the literal ``None``
    project segment when no project id has been supplied.
    """
    vertex = GoogleVertexProvider()
    assert vertex.name == 'google-vertex'
    assert vertex.base_url == snapshot(
        'https://us-central1-aiplatform.googleapis.com/v1/projects/None/locations/us-central1/publishers/google/models/'
    )
    assert isinstance(vertex.client, httpx.AsyncClient)
@dataclass
class NoOpCredentials:
    """Minimal stand-in for google-auth credentials that already holds a token.

    Used as the credentials half of the value returned by the patched
    ``google.auth.default`` in ``test_google_vertex_provider_auth``.
    """

    # Annotated so ``@dataclass`` registers it as a real field; an unannotated
    # class attribute is ignored by the dataclass machinery.
    token: str = 'my-token'

    def refresh(self, request: Request) -> None:  # pragma: no cover
        # The token is fixed, so there is nothing to refresh. The tests in this
        # module never invoke this method, so it is excluded from coverage.
        ...
@patch('pydantic_ai.providers.google_vertex.google.auth.default', return_value=(NoOpCredentials(), 'my-project-id'))
async def test_google_vertex_provider_auth(
    mock_google_auth_default: MagicMock, allow_model_requests: None, http_client: httpx.AsyncClient
):
    """Credentials and project id come from ``google.auth.default`` when none are passed.

    ``@patch`` without a ``new`` argument passes the created mock as the first positional
    argument, and pytest skips that parameter when resolving fixtures. The mock therefore
    needs its own parameter; otherwise it silently takes the place of ``allow_model_requests``.
    """
    provider = GoogleVertexProvider(http_client=http_client)
    await provider.client.post('/gemini-1.0-pro:generateContent')
    mock_google_auth_default.assert_called()
    assert provider.region == 'us-central1'
    assert getattr(provider.client.auth, 'project_id') == 'my-project-id'
async def mock_refresh_token():
    """Replacement for the auth object's ``_refresh_token``; always returns a fixed token."""
    fixed_token = 'my-token'
    return fixed_token
async def test_google_vertex_provider_service_account_file(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path, allow_model_requests: None, http_client: httpx.AsyncClient
):
    """Credentials and project id are loaded from a service-account JSON file on disk.

    The request is sent through the mocked ``http_client`` fixture, so the test does
    not contact the real Vertex AI endpoint.
    """
    service_account_path = tmp_path / 'service_account.json'
    save_service_account(service_account_path, 'my-project-id')

    provider = GoogleVertexProvider(service_account_file=service_account_path, http_client=http_client)
    # Stub the token refresh so the dummy key is never exchanged for a real token
    # (presumably that would contact Google's token endpoint).
    monkeypatch.setattr(provider.client.auth, '_refresh_token', mock_refresh_token)
    response = await provider.client.post('/gemini-1.0-pro:generateContent')
    assert response.json() == {'content': 'success'}
    assert provider.region == 'us-central1'
    assert getattr(provider.client.auth, 'project_id') == 'my-project-id'
async def test_google_vertex_provider_service_account_file_info(
    monkeypatch: pytest.MonkeyPatch, allow_model_requests: None, http_client: httpx.AsyncClient
):
    """Credentials and project id are taken from an in-memory service-account info dict.

    The request is sent through the mocked ``http_client`` fixture, so the test does
    not contact the real Vertex AI endpoint.
    """
    account_info = prepare_service_account_contents('my-project-id')

    provider = GoogleVertexProvider(service_account_info=account_info, http_client=http_client)
    # Stub the token refresh so the dummy key is never exchanged for a real token
    # (presumably that would contact Google's token endpoint).
    monkeypatch.setattr(provider.client.auth, '_refresh_token', mock_refresh_token)
    response = await provider.client.post('/gemini-1.0-pro:generateContent')
    assert response.json() == {'content': 'success'}
    assert provider.region == 'us-central1'
    assert getattr(provider.client.auth, 'project_id') == 'my-project-id'
async def test_google_vertex_provider_service_account_xor(allow_model_requests: None):
    """Supplying both a service-account file and inline service-account info is rejected."""
    inline_info = prepare_service_account_contents('my-project-id')
    expected_error = 'Only one of `service_account_file` or `service_account_info` can be provided'
    with pytest.raises(ValueError, match=expected_error):
        GoogleVertexProvider(  # type: ignore[reportCallIssue]
            service_account_file='path/to/service-account.json',
            service_account_info=inline_info,
        )
def prepare_service_account_contents(project_id: str) -> dict[str, str]:
    """Build a fake service-account info dict for ``project_id``.

    Every value except ``project_id`` is a fixed dummy. The private key is a throwaway
    PEM key generated locally with ``openssl genpkey``; it grants access to nothing.
    """
    # Base64 body of the throwaway key, one PEM line per entry.
    key_body = (
        'MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAMFrZYX4gZ20qv88',
        'jD0QCswXgcxgP7Ta06G47QEFprDVcv4WMUBDJVAKofzVcYyhsasWsOSxcpA8LIi9',
        '/VS2Otf8CmIK6nPBCD17Qgt8/IQYXOS4U2EBh0yjo0HQ4vFpkqium4lLWxrAZohA',
        '8r82clV08iLRUW3J+xvN23iPHyVDAgMBAAECgYBScRJe3iNxMvbHv+kOhe30O/jJ',
        'QiUlUzhtcEMk8mGwceqHvrHTcEtRKJcPC3NQvALcp9lSQQhRzjQ1PLXkC6BcfKFd',
        '03q5tVPmJiqsHbSyUyHWzdlHP42xWpl/RmX/DfRKGhPOvufZpSTzkmKWtN+7osHu',
        '7eiMpg2EDswCvOgf0QJBAPXLYwHbZLaM2KEMDgJSse5ZTE/0VMf+5vSTGUmHkr9c',
        'Wx2G1i258kc/JgsXInPbq4BnK9hd0Xj2T5cmEmQtm4UCQQDJc02DFnPnjPnnDUwg',
        'BPhrCyW+rnBGUVjehveu4XgbGx7l3wsbORTaKdCX3HIKUupgfFwFcDlMUzUy6fPO',
        'IuQnAkA8FhVE/fIX4kSO0hiWnsqafr/2B7+2CG1DOraC0B6ioxwvEqhHE17T5e8R',
        '5PzqH7hEMnR4dy7fCC+avpbeYHvVAkA5W58iR+5Qa49r/hlCtKeWsuHYXQqSuu62',
        'zW8QWBo+fYZapRsgcSxCwc0msBm4XstlFYON+NoXpUlsabiFZOHZAkEA8Ffq3xoU',
        'y0eYGy3MEzxx96F+tkl59lfkwHKWchWZJ95vAKWJaHx9WFxSWiJofbRna8Iim6pY',
        'BootYWyTCfjjwA==',
    )
    pem_lines = ('-----BEGIN PRIVATE KEY-----', *key_body, '-----END PRIVATE KEY-----')
    # Every PEM line, including the END footer, is newline-terminated.
    pem_private_key = ''.join(f'{pem_line}\n' for pem_line in pem_lines)

    return {
        'type': 'service_account',
        'project_id': project_id,
        'private_key_id': 'abc',
        'private_key': pem_private_key,
        'client_email': 'testing-pydantic-ai@pydantic-ai.iam.gserviceaccount.com',
        'client_id': '123',
        'auth_uri': 'https://accounts.google.com/o/oauth2/auth',
        'token_uri': 'https://oauth2.googleapis.com/token',
        'auth_provider_x509_cert_url': 'https://www.googleapis.com/oauth2/v1/certs',
        'client_x509_cert_url': 'https://www.googleapis.com/...',
        'universe_domain': 'googleapis.com',
    }
def save_service_account(service_account_path: Path, project_id: str) -> None:
    """Write a dummy service-account JSON document for ``project_id`` to ``service_account_path``.

    The JSON is indented by two spaces.
    """
    serialized = json.dumps(prepare_service_account_contents(project_id), indent=2)
    service_account_path.write_text(serialized)