Coverage for pydantic_ai_slim/pydantic_ai/providers/google_vertex.py: 96.04%

85 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1from __future__ import annotations as _annotations 

2 

3import functools 

4from collections.abc import AsyncGenerator, Mapping 

5from pathlib import Path 

6from typing import Literal, overload 

7 

8import anyio.to_thread 

9import httpx 

10 

11from pydantic_ai.exceptions import UserError 

12from pydantic_ai.models import cached_async_http_client 

13from pydantic_ai.providers import Provider 

14 

15try: 

16 import google.auth 

17 from google.auth.credentials import Credentials as BaseCredentials 

18 from google.auth.transport.requests import Request 

19 from google.oauth2.service_account import Credentials as ServiceAccountCredentials 

20except ImportError as _import_error: 

21 raise ImportError( 

22 'Please install the `google-auth` package to use the Google Vertex AI provider, ' 

23 'you can use the `vertexai` optional group — `pip install "pydantic-ai-slim[vertexai]"`' 

24 ) from _import_error 

25 

26 

27__all__ = ('GoogleVertexProvider',) 

28 

29 

class GoogleVertexProvider(Provider[httpx.AsyncClient]):
    """Provider for Vertex AI API."""

    @property
    def name(self) -> str:
        """Identifier of this provider."""
        return 'google-vertex'

    @property
    def base_url(self) -> str:
        """Root URL of the publisher's model endpoints for the configured region and project.

        While `project_id` is still `None` the URL contains the literal segment
        `projects/None`; `_VertexAIAuth.async_auth_flow` rewrites that segment on
        each outgoing request.
        """
        region = self.region
        endpoint = f'https://{region}-aiplatform.googleapis.com/v1'
        resource = f'/projects/{self.project_id}/locations/{region}/publishers/{self.model_publisher}/models/'
        return endpoint + resource

    @property
    def client(self) -> httpx.AsyncClient:
        """The `httpx.AsyncClient` configured by this provider."""
        return self._client

    @overload
    def __init__(
        self,
        *,
        service_account_file: Path | str | None = None,
        project_id: str | None = None,
        region: VertexAiRegion = 'us-central1',
        model_publisher: str = 'google',
        http_client: httpx.AsyncClient | None = None,
    ) -> None: ...

    @overload
    def __init__(
        self,
        *,
        service_account_info: Mapping[str, str] | None = None,
        project_id: str | None = None,
        region: VertexAiRegion = 'us-central1',
        model_publisher: str = 'google',
        http_client: httpx.AsyncClient | None = None,
    ) -> None: ...

    def __init__(
        self,
        *,
        service_account_file: Path | str | None = None,
        service_account_info: Mapping[str, str] | None = None,
        project_id: str | None = None,
        region: VertexAiRegion = 'us-central1',
        model_publisher: str = 'google',
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Set up a Vertex AI provider and attach authentication to its HTTP client.

        Args:
            service_account_file: Location of a service account file on disk.
                When omitted, `service_account_info` or the default environment credentials are used.
            service_account_info: The already-loaded contents of a service account file.
                When omitted, `service_account_file` or the default environment credentials are used.
            project_id: Project to address. When omitted, it is read from the credentials.
            region: Region that requests are sent to.
            model_publisher: Publisher segment of the model URL. Only `google` is currently supported:
                no reliable list of publishers was found, and by trial and error non-google models
                did not appear to work with `generateContent` / `streamGenerateContent`.
                Please create an issue or PR if you know how to use other publishers.
            http_client: A pre-built `httpx.AsyncClient` to use instead of the cached shared one.

        Raises:
            ValueError: If both `service_account_file` and `service_account_info` are given.
        """
        if service_account_file and service_account_info:
            raise ValueError('Only one of `service_account_file` or `service_account_info` can be provided.')

        # These attributes feed the `base_url` property, so they must exist before it is read below.
        self.service_account_file = service_account_file
        self.service_account_info = service_account_info
        self.project_id = project_id
        self.region = region
        self.model_publisher = model_publisher

        client = http_client or cached_async_http_client(provider='google-vertex')
        self._client = client
        client.auth = _VertexAIAuth(
            service_account_file=service_account_file,
            service_account_info=service_account_info,
            project_id=project_id,
            region=region,
        )
        client.base_url = self.base_url

109 

110 

class _VertexAIAuth(httpx.Auth):
    """`httpx` auth hook that adds a Google bearer token to Vertex AI requests."""

    # Stays `None` until the first request triggers `_get_credentials`.
    credentials: BaseCredentials | ServiceAccountCredentials | None

    def __init__(
        self,
        service_account_file: Path | str | None = None,
        service_account_info: Mapping[str, str] | None = None,
        project_id: str | None = None,
        region: VertexAiRegion = 'us-central1',
    ) -> None:
        """Remember where credentials should come from; nothing is loaded yet."""
        self.credentials = None

        self.service_account_file = service_account_file
        self.service_account_info = service_account_info
        self.project_id = project_id
        self.region = region

    async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
        """Authorize `request`, and retry it once with a refreshed token on a 401 response."""
        if self.credentials is None:
            self.credentials = await self._get_credentials()
        if self.credentials.token is None:  # type: ignore[reportUnknownMemberType]
            await self._refresh_token()

        bearer = f'Bearer {self.credentials.token}'  # type: ignore[reportUnknownMemberType]
        request.headers['Authorization'] = bearer
        # NOTE: The URL may have been built before the project ID was known (it can come
        # from the credentials), in which case it carries a `projects/None` placeholder.
        patched_url = str(request.url).replace('projects/None', f'projects/{self.project_id}')
        request.url = httpx.URL(patched_url)
        response = yield request

        if response.status_code != 401:
            return

        await self._refresh_token()
        request.headers['Authorization'] = f'Bearer {self.credentials.token}'  # type: ignore[reportUnknownMemberType]
        yield request

    async def _get_credentials(self) -> BaseCredentials | ServiceAccountCredentials:
        """Load credentials from the configured source and fill in `project_id` if it is unset.

        The service account file takes precedence, then the service account info,
        and finally `google.auth.default()`.

        Raises:
            UserError: If no `project_id` was given and the credentials carry none either.
        """
        found_project_id: str | None
        if self.service_account_file is not None:
            source = 'service account file'
            creds = await _creds_from_file(self.service_account_file)
            assert creds.project_id is None or isinstance(creds.project_id, str)  # type: ignore[reportUnknownMemberType]
            found_project_id = creds.project_id
        elif self.service_account_info is not None:
            source = 'service account info'
            creds = await _creds_from_info(self.service_account_info)
            assert creds.project_id is None or isinstance(creds.project_id, str)  # type: ignore[reportUnknownMemberType]
            found_project_id = creds.project_id
        else:
            source = '`google.auth.default()`'
            creds, found_project_id = await _async_google_auth()

        # An explicitly configured project always wins over the one in the credentials.
        if self.project_id is not None:
            return creds
        if found_project_id is None:
            raise UserError(f'No project_id provided and none found in {source}')
        self.project_id = found_project_id
        return creds

    async def _refresh_token(self) -> str:  # pragma: no cover
        """Refresh the credentials in a worker thread and return the resulting token."""
        creds = self.credentials
        assert creds is not None
        await anyio.to_thread.run_sync(creds.refresh, Request())  # type: ignore[reportUnknownMemberType]
        token = creds.token  # type: ignore[reportUnknownMemberType]
        assert isinstance(token, str), f'Expected token to be a string, got {token}'
        return token

171 

172 

async def _async_google_auth() -> tuple[BaseCredentials, str | None]:
    """Call `google.auth.default` in a worker thread, requesting the cloud-platform scope.

    Returns:
        The credentials together with the project ID reported alongside them, which may be `None`.
    """
    scopes = ['https://www.googleapis.com/auth/cloud-platform']
    return await anyio.to_thread.run_sync(google.auth.default, scopes)  # type: ignore

175 

176 

async def _creds_from_file(service_account_file: str | Path) -> ServiceAccountCredentials:
    """Build cloud-platform-scoped service account credentials from a file, off the event loop.

    Args:
        service_account_file: Path of the service account file; it is passed on as a string.
    """
    load_credentials = functools.partial(
        ServiceAccountCredentials.from_service_account_file,  # type: ignore[reportUnknownMemberType]
        str(service_account_file),
        scopes=['https://www.googleapis.com/auth/cloud-platform'],
    )
    return await anyio.to_thread.run_sync(load_credentials)

183 

184 

async def _creds_from_info(service_account_info: Mapping[str, str]) -> ServiceAccountCredentials:
    """Build cloud-platform-scoped service account credentials from loaded info, off the event loop.

    Args:
        service_account_info: The already-parsed contents of a service account file.
    """
    load_credentials = functools.partial(
        ServiceAccountCredentials.from_service_account_info,  # type: ignore[reportUnknownMemberType]
        service_account_info,
        scopes=['https://www.googleapis.com/auth/cloud-platform'],
    )
    return await anyio.to_thread.run_sync(load_credentials)

191 

192 

# Closed set of region names accepted by the `region` parameters in this module.
# The alias is defined after the classes that annotate with it; that works because
# `from __future__ import annotations` keeps those annotations unevaluated.
VertexAiRegion = Literal[
    # Asia-Pacific
    'asia-east1',
    'asia-east2',
    'asia-northeast1',
    'asia-northeast3',
    'asia-south1',
    'asia-southeast1',
    'australia-southeast1',
    # Europe
    'europe-central2',
    'europe-north1',
    'europe-southwest1',
    'europe-west1',
    'europe-west2',
    'europe-west3',
    'europe-west4',
    'europe-west6',
    'europe-west8',
    'europe-west9',
    # Middle East
    'me-central1',
    'me-central2',
    'me-west1',
    # Americas
    'northamerica-northeast1',
    'southamerica-east1',
    'us-central1',
    'us-east1',
    'us-east4',
    'us-east5',
    'us-south1',
    'us-west1',
    'us-west4',
]
"""Regions available for Vertex AI.

More details [here](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#genai-locations).
"""

227"""