Coverage for pydantic_ai_slim/pydantic_ai/providers/google_vertex.py: 96.04%
85 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1from __future__ import annotations as _annotations
3import functools
4from collections.abc import AsyncGenerator, Mapping
5from pathlib import Path
6from typing import Literal, overload
8import anyio.to_thread
9import httpx
11from pydantic_ai.exceptions import UserError
12from pydantic_ai.models import cached_async_http_client
13from pydantic_ai.providers import Provider
15try:
16 import google.auth
17 from google.auth.credentials import Credentials as BaseCredentials
18 from google.auth.transport.requests import Request
19 from google.oauth2.service_account import Credentials as ServiceAccountCredentials
20except ImportError as _import_error:
21 raise ImportError(
22 'Please install the `google-auth` package to use the Google Vertex AI provider, '
23 'you can use the `vertexai` optional group — `pip install "pydantic-ai-slim[vertexai]"`'
24 ) from _import_error
27__all__ = ('GoogleVertexProvider',)
class GoogleVertexProvider(Provider[httpx.AsyncClient]):
    """Provider for Vertex AI API."""

    @property
    def name(self) -> str:
        return 'google-vertex'

    @property
    def base_url(self) -> str:
        # NOTE: if no `project_id` was passed to `__init__`, this URL contains the literal
        # segment `projects/None`. `_VertexAIAuth.async_auth_flow` rewrites that segment in
        # every outgoing request using the project ID found in the credentials. The
        # provider's own `project_id` attribute is not updated by that lookup.
        return (
            f'https://{self.region}-aiplatform.googleapis.com/v1'
            f'/projects/{self.project_id}'
            f'/locations/{self.region}'
            f'/publishers/{self.model_publisher}/models/'
        )

    @property
    def client(self) -> httpx.AsyncClient:
        return self._client

    @overload
    def __init__(
        self,
        *,
        service_account_file: Path | str | None = None,
        project_id: str | None = None,
        region: VertexAiRegion = 'us-central1',
        model_publisher: str = 'google',
        http_client: httpx.AsyncClient | None = None,
    ) -> None: ...

    @overload
    def __init__(
        self,
        *,
        service_account_info: Mapping[str, str] | None = None,
        project_id: str | None = None,
        region: VertexAiRegion = 'us-central1',
        model_publisher: str = 'google',
        http_client: httpx.AsyncClient | None = None,
    ) -> None: ...

    def __init__(
        self,
        *,
        service_account_file: Path | str | None = None,
        service_account_info: Mapping[str, str] | None = None,
        project_id: str | None = None,
        region: VertexAiRegion = 'us-central1',
        model_publisher: str = 'google',
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a new Vertex AI provider.

        Args:
            service_account_file: Path to a service account file.
                If not provided, the service_account_info or default environment credentials will be used.
            service_account_info: The loaded service_account_file contents.
                If not provided, the service_account_file or default environment credentials will be used.
            project_id: The project ID to use. If not provided, it will be taken from the credentials.
            region: The region to make requests to.
            model_publisher: The model publisher to use. No authoritative list of publishers has been
                found, and trial and error suggests non-google models don't work with the `generateContent`
                and `streamGenerateContent` functions, so only `google` is currently supported.
                Please create an issue or PR if you know how to use other publishers.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.

        Raises:
            ValueError: If both `service_account_file` and `service_account_info` are provided.
        """
        # Compare against `None` rather than using truthiness. `_VertexAIAuth` picks the
        # credential source with `is not None`, so a combination such as an empty path plus
        # an info mapping must be rejected here rather than letting the empty path win later.
        if service_account_file is not None and service_account_info is not None:
            raise ValueError('Only one of `service_account_file` or `service_account_info` can be provided.')

        self._client = http_client or cached_async_http_client(provider='google-vertex')
        self.service_account_file = service_account_file
        self.service_account_info = service_account_info
        self.project_id = project_id
        self.region = region
        self.model_publisher = model_publisher

        # NOTE(review): this mutates the client in place. When `http_client` is omitted, the
        # client comes from `cached_async_http_client`, which is presumably shared. Confirm
        # that two providers with different settings cannot overwrite each other's
        # `auth` / `base_url`.
        self._client.auth = _VertexAIAuth(service_account_file, service_account_info, project_id, region)
        self._client.base_url = self.base_url
class _VertexAIAuth(httpx.Auth):
    """`httpx.Auth` implementation that signs Vertex AI requests with a Google bearer token.

    Credentials are resolved lazily on the first request. The source is chosen in this
    order of preference: `service_account_file`, then `service_account_info`, then
    `google.auth.default()`.
    """

    credentials: BaseCredentials | ServiceAccountCredentials | None

    def __init__(
        self,
        service_account_file: Path | str | None = None,
        service_account_info: Mapping[str, str] | None = None,
        project_id: str | None = None,
        region: VertexAiRegion = 'us-central1',
    ) -> None:
        # No credentials are loaded here; `async_auth_flow` fetches them on first use.
        self.credentials = None
        self.service_account_file = service_account_file
        self.service_account_info = service_account_info
        self.project_id = project_id
        # Stored for reference only; the auth flow below does not read it.
        self.region = region

    async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
        """Attach the bearer token, send `request`, and resend it once if the API answers 401.

        This method also replaces a `projects/None` path segment with the resolved project ID.
        That segment appears when the provider was built without an explicit `project_id`.
        """
        if self.credentials is None:
            self.credentials = await self._get_credentials()
        if self.credentials.token is None:  # type: ignore[reportUnknownMemberType]
            await self._refresh_token()

        bearer = f'Bearer {self.credentials.token}'  # type: ignore[reportUnknownMemberType]
        request.headers['Authorization'] = bearer
        patched_url = str(request.url).replace('projects/None', f'projects/{self.project_id}')
        request.url = httpx.URL(patched_url)

        response = yield request
        if response.status_code != 401:
            return

        # The token was rejected (presumably because it expired): refresh it and retry exactly once.
        await self._refresh_token()
        request.headers['Authorization'] = f'Bearer {self.credentials.token}'  # type: ignore[reportUnknownMemberType]
        yield request

    async def _get_credentials(self) -> BaseCredentials | ServiceAccountCredentials:
        """Load credentials from the configured source and adopt their project ID if none was given.

        Raises:
            UserError: If no `project_id` was configured and the credentials don't provide one.
        """
        credentials: BaseCredentials | ServiceAccountCredentials
        discovered_project_id: str | None
        if self.service_account_file is not None:
            credentials = await _creds_from_file(self.service_account_file)
            assert credentials.project_id is None or isinstance(credentials.project_id, str)  # type: ignore[reportUnknownMemberType]
            discovered_project_id = credentials.project_id
            source = 'service account file'
        elif self.service_account_info is not None:
            credentials = await _creds_from_info(self.service_account_info)
            assert credentials.project_id is None or isinstance(credentials.project_id, str)  # type: ignore[reportUnknownMemberType]
            discovered_project_id = credentials.project_id
            source = 'service account info'
        else:
            credentials, discovered_project_id = await _async_google_auth()
            source = '`google.auth.default()`'

        # An explicitly configured project ID always wins over the one in the credentials.
        if self.project_id is not None:
            return credentials
        if discovered_project_id is None:
            raise UserError(f'No project_id provided and none found in {source}')
        self.project_id = discovered_project_id
        return credentials

    async def _refresh_token(self) -> str:  # pragma: no cover
        """Refresh the access token in a worker thread and return it."""
        assert self.credentials is not None
        await anyio.to_thread.run_sync(self.credentials.refresh, Request())  # type: ignore[reportUnknownMemberType]
        fresh_token = self.credentials.token  # type: ignore[reportUnknownMemberType]
        assert isinstance(fresh_token, str), f'Expected token to be a string, got {fresh_token}'
        return fresh_token
async def _async_google_auth() -> tuple[BaseCredentials, str | None]:
    """Run `google.auth.default()` in a worker thread so the event loop is not blocked.

    Returns:
        The `(credentials, project_id)` pair produced by `google.auth.default()`.
    """
    cloud_platform_scopes = ['https://www.googleapis.com/auth/cloud-platform']
    return await anyio.to_thread.run_sync(google.auth.default, cloud_platform_scopes)  # type: ignore
async def _creds_from_file(service_account_file: str | Path) -> ServiceAccountCredentials:
    """Load service-account credentials from a key file in a worker thread.

    The credentials are scoped to `https://www.googleapis.com/auth/cloud-platform`.
    """
    load_credentials = functools.partial(
        ServiceAccountCredentials.from_service_account_file,  # type: ignore[reportUnknownMemberType]
        str(service_account_file),
        scopes=['https://www.googleapis.com/auth/cloud-platform'],
    )
    return await anyio.to_thread.run_sync(load_credentials)
async def _creds_from_info(service_account_info: Mapping[str, str]) -> ServiceAccountCredentials:
    """Build service-account credentials from already-loaded key contents in a worker thread.

    The credentials are scoped to `https://www.googleapis.com/auth/cloud-platform`.
    """
    load_credentials = functools.partial(
        ServiceAccountCredentials.from_service_account_info,  # type: ignore[reportUnknownMemberType]
        service_account_info,
        scopes=['https://www.googleapis.com/auth/cloud-platform'],
    )
    return await anyio.to_thread.run_sync(load_credentials)
VertexAiRegion = Literal[
    # asia-* and australia-*
    'asia-east1',
    'asia-east2',
    'asia-northeast1',
    'asia-northeast3',
    'asia-south1',
    'asia-southeast1',
    'australia-southeast1',
    # europe-*
    'europe-central2',
    'europe-north1',
    'europe-southwest1',
    'europe-west1',
    'europe-west2',
    'europe-west3',
    'europe-west4',
    'europe-west6',
    'europe-west8',
    'europe-west9',
    # me-*
    'me-central1',
    'me-central2',
    'me-west1',
    # northamerica-*, southamerica-* and us-*
    'northamerica-northeast1',
    'southamerica-east1',
    'us-central1',
    'us-east1',
    'us-east4',
    'us-east5',
    'us-south1',
    'us-west1',
    'us-west4',
]
"""Regions in which Vertex AI can be used.

See the [Vertex AI generative AI locations](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#genai-locations) page for more details.
"""