Coverage for tests/providers/test_google_vertex.py: 98.36%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1from __future__ import annotations as _annotations 

2 

3import json 

4from dataclasses import dataclass 

5from pathlib import Path 

6from unittest.mock import patch 

7 

8import httpx 

9import pytest 

10from inline_snapshot import snapshot 

11 

12from ..conftest import try_import 

13 

14with try_import() as imports_successful: 

15 from google.auth.transport.requests import Request 

16 

17 from pydantic_ai.providers.google_vertex import GoogleVertexProvider 

18 

# Skip the whole module when the optional Vertex dependencies could not be imported, and run every
# async test under the anyio plugin.
pytestmark = [
    # The guarded imports are `google.auth` and the Vertex provider (which itself imports `google.auth`),
    # so the missing dependency to report is google-auth.
    pytest.mark.skipif(not imports_successful(), reason='google-auth not installed'),
    pytest.mark.anyio(),
]

23 

24 

@pytest.fixture()
def http_client():
    """Return an ``httpx.AsyncClient`` backed by an in-memory transport.

    The transport answers only the Vertex ``generateContent`` route for ``my-project-id`` in
    ``us-central1`` with a canned 200 JSON body; any other URL raises ``NotImplementedError`` so an
    unexpected request fails the test instead of going unnoticed.
    """
    expected_path = (
        '/v1/projects/my-project-id/locations/us-central1'
        '/publishers/google/models/gemini-1.0-pro:generateContent'
    )

    async def handler(request: httpx.Request):
        if request.url.path != expected_path:
            raise NotImplementedError(f'Unexpected request: {request.url!r}')  # pragma: no cover
        return httpx.Response(200, json={'content': 'success'})

    return httpx.AsyncClient(transport=httpx.MockTransport(handler=handler))

36 

37 

def test_google_vertex_provider(allow_model_requests: None) -> None:
    """A provider built with no arguments reports its name, default base URL and an async HTTP client.

    No project id is supplied here, so the snapshot records ``projects/None`` in the base URL.
    """
    vertex = GoogleVertexProvider()

    assert vertex.name == 'google-vertex'
    assert vertex.base_url == snapshot(
        'https://us-central1-aiplatform.googleapis.com/v1/projects/None/locations/us-central1/publishers/google/models/'
    )
    assert isinstance(vertex.client, httpx.AsyncClient)

45 

46 

@dataclass
class NoOpCredentials:
    """Minimal stand-in for Google credentials that carries a fixed token.

    Used as the credentials half of the value returned by a patched ``google.auth.default``.
    """

    # Annotated so that ``@dataclass`` registers it as a field (an un-annotated class attribute is
    # ignored by ``dataclass``); the default keeps the no-argument ``NoOpCredentials()`` call working.
    token: str = 'my-token'

    def refresh(self, request: Request) -> None:  # pragma: no cover
        """Do nothing: the token is static, so there is nothing to refresh."""

52 

53 

async def test_google_vertex_provider_auth(allow_model_requests: None, http_client: httpx.AsyncClient):
    """Credentials and project id from ``google.auth.default`` are picked up by the provider's client.

    ``google.auth.default`` is patched with a context manager rather than a ``@patch`` decorator: the
    decorator form passes the mock as an extra positional argument, which pytest maps onto the first
    parameter, so ``allow_model_requests`` would receive the mock instead of the fixture.
    """
    with patch(
        'pydantic_ai.providers.google_vertex.google.auth.default',
        return_value=(NoOpCredentials(), 'my-project-id'),
    ):
        provider = GoogleVertexProvider(http_client=http_client)
        response = await provider.client.post('/gemini-1.0-pro:generateContent')

        # The `http_client` fixture only answers 200 for the my-project-id generateContent route.
        assert response.status_code == 200
        assert provider.region == 'us-central1'
        assert getattr(provider.client.auth, 'project_id') == 'my-project-id'

60 

61 

async def mock_refresh_token():
    """Async replacement for an auth object's ``_refresh_token``; performs no I/O and returns a fixed token."""
    fixed_token = 'my-token'
    return fixed_token

64 

65 

async def test_google_vertex_provider_service_account_file(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path, allow_model_requests: None
):
    """A provider configured from a service-account JSON file takes its project id from that file.

    The request is served by an in-memory ``httpx.MockTransport``. Without an injected client the
    provider would presumably send it to the real Vertex endpoint, making the test network-dependent.
    """
    service_account_path = tmp_path / 'service_account.json'
    save_service_account(service_account_path, 'my-project-id')

    # Answer every request locally: this test checks how auth is configured, not which URL is targeted.
    async with httpx.AsyncClient(
        transport=httpx.MockTransport(lambda request: httpx.Response(200, json={'content': 'success'}))
    ) as offline_client:
        provider = GoogleVertexProvider(service_account_file=service_account_path, http_client=offline_client)
        # Replace the token refresh on this auth instance only, so no OAuth exchange is attempted.
        monkeypatch.setattr(provider.client.auth, '_refresh_token', mock_refresh_token)
        response = await provider.client.post('/gemini-1.0-pro:generateContent')

    assert response.status_code == 200
    assert provider.region == 'us-central1'
    assert getattr(provider.client.auth, 'project_id') == 'my-project-id'

77 

78 

async def test_google_vertex_provider_service_account_file_info(
    monkeypatch: pytest.MonkeyPatch, allow_model_requests: None
):
    """A provider configured from an in-memory service-account dict takes its project id from it.

    The request is served by an in-memory ``httpx.MockTransport``. Without an injected client the
    provider would presumably send it to the real Vertex endpoint, making the test network-dependent.
    """
    account_info = prepare_service_account_contents('my-project-id')

    # Answer every request locally: this test checks how auth is configured, not which URL is targeted.
    async with httpx.AsyncClient(
        transport=httpx.MockTransport(lambda request: httpx.Response(200, json={'content': 'success'}))
    ) as offline_client:
        provider = GoogleVertexProvider(service_account_info=account_info, http_client=offline_client)
        # Replace the token refresh on this auth instance only, so no OAuth exchange is attempted.
        monkeypatch.setattr(provider.client.auth, '_refresh_token', mock_refresh_token)
        response = await provider.client.post('/gemini-1.0-pro:generateContent')

    assert response.status_code == 200
    assert provider.region == 'us-central1'
    assert getattr(provider.client.auth, 'project_id') == 'my-project-id'

89 

90 

async def test_google_vertex_provider_service_account_xor(allow_model_requests: None):
    """Passing both ``service_account_file`` and ``service_account_info`` is rejected with ``ValueError``."""
    expected_message = 'Only one of `service_account_file` or `service_account_info` can be provided'
    conflicting_info = prepare_service_account_contents('my-project-id')

    with pytest.raises(ValueError, match=expected_message):
        GoogleVertexProvider(  # type: ignore[reportCallIssue]
            service_account_file='path/to/service-account.json',
            service_account_info=conflicting_info,
        )

99 

100 

def prepare_service_account_contents(project_id: str) -> dict[str, str]:
    """Return a fake Google service-account document for ``project_id``.

    The dict has the same keys as a downloaded service-account JSON file, so it can be passed as
    ``service_account_info`` or serialized to disk. A fresh dict is built on every call.
    """
    return {
        'type': 'service_account',
        'project_id': project_id,
        'private_key_id': 'abc',
        # this is just a random private key I created with `openssl genpkey ...`, it doesn't do anything
        'private_key': (
            '-----BEGIN PRIVATE KEY-----\n'
            'MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAMFrZYX4gZ20qv88\n'
            'jD0QCswXgcxgP7Ta06G47QEFprDVcv4WMUBDJVAKofzVcYyhsasWsOSxcpA8LIi9\n'
            '/VS2Otf8CmIK6nPBCD17Qgt8/IQYXOS4U2EBh0yjo0HQ4vFpkqium4lLWxrAZohA\n'
            '8r82clV08iLRUW3J+xvN23iPHyVDAgMBAAECgYBScRJe3iNxMvbHv+kOhe30O/jJ\n'
            'QiUlUzhtcEMk8mGwceqHvrHTcEtRKJcPC3NQvALcp9lSQQhRzjQ1PLXkC6BcfKFd\n'
            '03q5tVPmJiqsHbSyUyHWzdlHP42xWpl/RmX/DfRKGhPOvufZpSTzkmKWtN+7osHu\n'
            '7eiMpg2EDswCvOgf0QJBAPXLYwHbZLaM2KEMDgJSse5ZTE/0VMf+5vSTGUmHkr9c\n'
            'Wx2G1i258kc/JgsXInPbq4BnK9hd0Xj2T5cmEmQtm4UCQQDJc02DFnPnjPnnDUwg\n'
            'BPhrCyW+rnBGUVjehveu4XgbGx7l3wsbORTaKdCX3HIKUupgfFwFcDlMUzUy6fPO\n'
            'IuQnAkA8FhVE/fIX4kSO0hiWnsqafr/2B7+2CG1DOraC0B6ioxwvEqhHE17T5e8R\n'
            '5PzqH7hEMnR4dy7fCC+avpbeYHvVAkA5W58iR+5Qa49r/hlCtKeWsuHYXQqSuu62\n'
            'zW8QWBo+fYZapRsgcSxCwc0msBm4XstlFYON+NoXpUlsabiFZOHZAkEA8Ffq3xoU\n'
            'y0eYGy3MEzxx96F+tkl59lfkwHKWchWZJ95vAKWJaHx9WFxSWiJofbRna8Iim6pY\n'
            'BootYWyTCfjjwA==\n'
            '-----END PRIVATE KEY-----\n'
        ),
        'client_email': 'testing-pydantic-ai@pydantic-ai.iam.gserviceaccount.com',
        'client_id': '123',
        'auth_uri': 'https://accounts.google.com/o/oauth2/auth',
        'token_uri': 'https://oauth2.googleapis.com/token',
        'auth_provider_x509_cert_url': 'https://www.googleapis.com/oauth2/v1/certs',
        'client_x509_cert_url': 'https://www.googleapis.com/...',
        'universe_domain': 'googleapis.com',
    }

133 

134 

def save_service_account(service_account_path: Path, project_id: str) -> None:
    """Write the fake service-account document for ``project_id`` to ``service_account_path`` as indented JSON."""
    serialized = json.dumps(prepare_service_account_contents(project_id), indent=2)
    service_account_path.write_text(serialized)