Coverage for fastagency/api/openapi/security.py: 91%

139 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-19 12:16 +0000

1import base64 1aghefibcd

2import logging 1aghefibcd

3from typing import Any, ClassVar, Literal, Optional, Protocol 1aghefibcd

4 

5import requests 1aghefibcd

6from pydantic import BaseModel, model_validator 1aghefibcd

7from typing_extensions import TypeAlias 1aghefibcd

8 

9# Get the logger 

10logger = logging.getLogger(__name__) 1aghefibcd

11logger.setLevel(logging.DEBUG) 1aghefibcd

12 

13BaseSecurityType: TypeAlias = type["BaseSecurity"] 1aghefibcd

14 

15 

class BaseSecurity(BaseModel):
    """Base class for security classes.

    Subclasses pin the OpenAPI security-scheme ``type`` and the credential
    location ``in_value`` as class variables; instances carry the ``name``
    under which the credential is sent.
    """

    # OpenAPI security-scheme type handled by the subclass.
    type: ClassVar[
        Literal["apiKey", "http", "mutualTLS", "oauth2", "openIdConnect", "unsupported"]
    ]
    # Where / how the credential is transmitted for this scheme.
    in_value: ClassVar[
        Literal["header", "query", "cookie", "bearer", "basic", "tls", "unsupported"]
    ]
    name: str

    @model_validator(mode="after")  # type: ignore[misc]
    def __post_init__(
        self,
    ) -> "BaseSecurity":
        """Validate the in_value based on the type.

        Returns:
            The validated instance. Pydantic "after" model validators must
            return the instance they validated.

        Raises:
            ValueError: If ``in_value`` is not allowed for ``type`` (pydantic
                surfaces this as a ``ValidationError``).
        """
        valid_in_values = {
            "apiKey": ["header", "query", "cookie"],
            "http": ["bearer", "basic"],
            "oauth2": ["bearer"],
            "openIdConnect": ["bearer"],
            "mutualTLS": ["tls"],
            "unsupported": ["unsupported"],
        }
        # ``.get`` so an unexpected ``type`` is reported as a ValueError
        # instead of leaking a KeyError out of validation.
        if self.in_value not in valid_in_values.get(self.type, ()):
            raise ValueError(
                f"Invalid in_value '{self.in_value}' for type '{self.type}'"
            )
        return self

    def accept(self, security_params: "BaseSecurityParameters") -> bool:
        """Return True if ``security_params`` targets this security class."""
        return isinstance(self, security_params.get_security_class())

    @classmethod
    def is_supported(cls, type: str, schema_parameters: dict[str, Any]) -> bool:
        """Return True if this class handles scheme ``type`` with the schema's ``in``."""
        return cls.type == type and cls.in_value == schema_parameters.get("in")

    @classmethod
    def get_security_class(
        cls, type: str, schema_parameters: dict[str, Any]
    ) -> BaseSecurityType:
        """Return the first direct subclass supporting the scheme.

        Only direct subclasses (``cls.__subclasses__()``) are considered. When
        none matches, an error is logged and ``UnsuportedSecurityStub`` is
        returned.
        """
        sub_classes = cls.__subclasses__()

        for sub_class in sub_classes:
            if sub_class.is_supported(type, schema_parameters):
                return sub_class

        logger.error(
            f"Unsupported type '{type}' and schema_parameters '{schema_parameters}' combination"
        )
        return UnsuportedSecurityStub

    @classmethod
    def get_security_parameters(cls, schema_parameters: dict[str, Any]) -> str:
        """Return a Python source snippet constructing this class.

        For example ``APIKeyHeader(name="X-Key")``.
        NOTE(review): ``name`` is interpolated without escaping; a name
        containing ``"`` would produce invalid source.
        """
        return f'{cls.__name__}(name="{schema_parameters.get("name")}")'

70 

71 

class BaseSecurityParameters(Protocol):
    """Structural interface for security parameters.

    Each security class's nested ``Parameters`` model implements this
    interface.
    """

    def apply(
        self,
        q_params: dict[str, Any],
        body_dict: dict[str, Any],
        security: BaseSecurity,
    ) -> None:
        """Write the credential into the query params and/or request body dict."""

    def get_security_class(self) -> type[BaseSecurity]:
        """Return the security class these parameters belong to."""

83 

84 

class UnsuportedSecurityStub(BaseSecurity):
    """Unsupported security stub class.

    ``BaseSecurity.get_security_class`` falls back to this class when no
    concrete subclass matches a scheme. It never claims support for a scheme
    itself.
    """

    type: ClassVar[Literal["unsupported"]] = "unsupported"
    in_value: ClassVar[Literal["unsupported"]] = "unsupported"

    @classmethod
    def is_supported(cls, type: str, schema_parameters: dict[str, Any]) -> bool:
        """Never match, so the stub is only reached as a fallback."""
        return False

    def accept(self, security_params: "BaseSecurityParameters") -> bool:
        """Refuse all parameters.

        Raises:
            RuntimeError: If ``security_params`` explicitly targets this stub.
        """
        if isinstance(self, security_params.get_security_class()):
            raise RuntimeError("Trying to set UnsuportedSecurityStub params")
        return False

    class Parameters(BaseModel):  # BaseSecurityParameters
        """Unsupported security stub parameters class (applies nothing)."""

        def apply(
            self,
            q_params: dict[str, Any],
            body_dict: dict[str, Any],
            security: BaseSecurity,
        ) -> None:
            """Do nothing: there is no credential to apply."""
            pass

        def get_security_class(self) -> type[BaseSecurity]:
            """Return ``UnsuportedSecurityStub``."""
            return UnsuportedSecurityStub

113 

114 

class APIKeyHeader(BaseSecurity):
    """API key sent in a request header.

    The header name is taken from the security instance's ``name``.
    """

    type: ClassVar[Literal["apiKey"]] = "apiKey"
    in_value: ClassVar[Literal["header"]] = "header"

    class Parameters(BaseModel):  # BaseSecurityParameters
        """Parameters holding the API key sent in a header."""

        value: str

        def apply(
            self,
            q_params: dict[str, Any],
            body_dict: dict[str, Any],
            security: BaseSecurity,
        ) -> None:
            """Store the key under ``headers[security.name]``, creating ``headers`` if absent."""
            header_map = body_dict.setdefault("headers", {})
            header_map[security.name] = self.value

        def get_security_class(self) -> type[BaseSecurity]:
            """Return ``APIKeyHeader``."""
            return APIKeyHeader

141 

142 

class APIKeyQuery(BaseSecurity):
    """API key sent as a query-string parameter.

    The parameter name is taken from the security instance's ``name``.
    """

    type: ClassVar[Literal["apiKey"]] = "apiKey"
    in_value: ClassVar[Literal["query"]] = "query"

    @classmethod
    def is_supported(cls, type: str, schema_parameters: dict[str, Any]) -> bool:
        """Match like the base class, but also require a ``name`` in the schema."""
        matches_scheme = super().is_supported(type, schema_parameters)
        return matches_scheme and "name" in schema_parameters

    class Parameters(BaseModel):  # BaseSecurityParameters
        """Parameters holding the API key sent in the query string."""

        value: str

        def apply(
            self,
            q_params: dict[str, Any],
            body_dict: dict[str, Any],
            security: BaseSecurity,
        ) -> None:
            """Set ``q_params[security.name]`` to the API key."""
            param_name = security.name
            q_params[param_name] = self.value

        def get_security_class(self) -> type[BaseSecurity]:
            """Return ``APIKeyQuery``."""
            return APIKeyQuery

173 

174 

class APIKeyCookie(BaseSecurity):
    """API key sent in a cookie.

    The cookie name is taken from the security instance's ``name``.
    """

    type: ClassVar[Literal["apiKey"]] = "apiKey"
    in_value: ClassVar[Literal["cookie"]] = "cookie"

    class Parameters(BaseModel):  # BaseSecurityParameters
        """Parameters holding the API key sent in a cookie."""

        value: str

        def apply(
            self,
            q_params: dict[str, Any],
            body_dict: dict[str, Any],
            security: BaseSecurity,
        ) -> None:
            """Store the key under ``cookies[security.name]``, creating ``cookies`` if absent."""
            cookie_jar = body_dict.setdefault("cookies", {})
            cookie_jar[security.name] = self.value

        def get_security_class(self) -> type[BaseSecurity]:
            """Return ``APIKeyCookie``."""
            return APIKeyCookie

201 

202 

class HTTPBearer(BaseSecurity):
    """HTTP ``Bearer`` authentication scheme."""

    type: ClassVar[Literal["http"]] = "http"
    in_value: ClassVar[Literal["bearer"]] = "bearer"

    @classmethod
    def is_supported(cls, type: str, schema_parameters: dict[str, Any]) -> bool:
        """Match ``http`` schemes whose ``scheme`` field is ``bearer``."""
        if cls.type != type:
            return False
        return cls.in_value == schema_parameters.get("scheme")

    class Parameters(BaseModel):  # BaseSecurityParameters
        """Parameters holding the bearer token."""

        value: str

        def apply(
            self,
            q_params: dict[str, Any],
            body_dict: dict[str, Any],
            security: BaseSecurity,
        ) -> None:
            """Set ``headers["Authorization"]`` to ``Bearer <token>``."""
            header_map = body_dict.setdefault("headers", {})
            header_map["Authorization"] = f"Bearer {self.value}"

        def get_security_class(self) -> type[BaseSecurity]:
            """Return ``HTTPBearer``."""
            return HTTPBearer

231 

232 

class HTTPBasic(BaseSecurity):
    """HTTP Basic security class."""

    type: ClassVar[Literal["http"]] = "http"
    in_value: ClassVar[Literal["basic"]] = "basic"

    @classmethod
    def is_supported(cls, type: str, schema_parameters: dict[str, Any]) -> bool:
        """Match ``http`` schemes whose ``scheme`` field is ``basic``."""
        return cls.type == type and cls.in_value == schema_parameters.get("scheme")

    class Parameters(BaseModel):  # BaseSecurityParameters
        """HTTP Basic security parameters class."""

        username: str
        password: str

        def apply(
            self,
            q_params: dict[str, Any],
            body_dict: dict[str, Any],
            security: BaseSecurity,
        ) -> None:
            """Set ``headers["Authorization"]`` to ``Basic <base64(user:pass)>``.

            The credentials are UTF-8 encoded before base64 encoding. The
            ``headers`` dict is created if it is absent.
            """
            headers = body_dict.setdefault("headers", {})

            credentials = f"{self.username}:{self.password}"
            encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode(
                "utf-8"
            )

            headers["Authorization"] = f"Basic {encoded_credentials}"

        def get_security_class(self) -> type[BaseSecurity]:
            """Return ``HTTPBasic``."""
            return HTTPBasic

267 

268 

class OAuth2PasswordBearer(BaseSecurity):
    """OAuth2 Password Bearer security class.

    Sends a bearer token in the ``Authorization`` header. Unless a token is
    supplied, it is first obtained from ``token_url`` with the OAuth2
    password flow.
    """

    type: ClassVar[Literal["oauth2"]] = "oauth2"
    in_value: ClassVar[Literal["bearer"]] = "bearer"
    # Endpoint that ``Parameters.apply`` posts credentials to for a token.
    token_url: str

    @classmethod
    def is_supported(cls, type: str, schema_parameters: dict[str, Any]) -> bool:
        """Match ``oauth2`` schemes that declare a ``password`` flow."""
        return type == cls.type and "password" in schema_parameters.get("flows", {})

    @classmethod
    def get_security_parameters(cls, schema_parameters: dict[str, Any]) -> str:
        """Return a Python source snippet constructing this class.

        An absolute ``http(s)://`` ``tokenUrl`` is used as-is. A relative one
        is appended to ``server_url`` with exactly one ``/`` between them.
        """
        name = schema_parameters.get("name")
        token_url = str(schema_parameters["flows"]["password"]["tokenUrl"])
        if not token_url.lower().startswith(("http://", "https://")):
            # Avoid "http://host//token" when either side carries a slash.
            server_url = str(schema_parameters.get("server_url")).rstrip("/")
            token_url = f"{server_url}/{token_url.lstrip('/')}"
        return f'{cls.__name__}(name="{name}", token_url="{token_url}")'

    class Parameters(BaseModel):  # BaseSecurityParameters
        """OAuth2 Password Bearer security parameters class.

        Either ``bearer_token`` or both ``username`` and ``password`` must be
        provided.
        """

        username: Optional[str] = None
        password: Optional[str] = None
        bearer_token: Optional[str] = None
        # NOTE(review): not read by ``apply``, which always uses
        # ``security.token_url``. Confirm whether it should override that.
        token_url: Optional[str] = None

        @model_validator(mode="before")
        def check_credentials(cls, values: dict[str, Any]) -> Any:  # noqa
            """Require a bearer token or a username/password pair.

            Raises:
                ValueError: If no bearer token is given and either the
                    username or the password is missing or empty.
            """
            if not isinstance(values, dict):
                # Non-mapping input (e.g. an existing Parameters instance) has
                # no ``.get``; leave it to pydantic's normal validation.
                return values

            username = values.get("username")
            password = values.get("password")
            bearer_token = values.get("bearer_token")

            if not bearer_token and (not username or not password):
                # If bearer_token is not provided, both username and password must be defined
                raise ValueError(
                    "Both username and password are required if bearer_token is not provided."
                )

            return values

        def get_token(self, token_url: str) -> str:
            """POST the username/password to ``token_url`` and return ``access_token``.

            Raises:
                requests.HTTPError: If the token endpoint returns an error status.
                KeyError: If the JSON response has no ``access_token``.
            """
            response = requests.post(
                token_url,
                data={
                    "username": self.username,
                    "password": self.password,
                },
                timeout=5,
            )
            response.raise_for_status()
            return response.json()["access_token"]  # type: ignore

        def apply(
            self,
            q_params: dict[str, Any],
            body_dict: dict[str, Any],
            security: BaseSecurity,
        ) -> None:
            """Set ``headers["Authorization"]`` to ``Bearer <token>``.

            If no bearer token is set, one is fetched from
            ``security.token_url``. It is stored on this instance, so later
            calls reuse it.

            Raises:
                ValueError: If a token must be fetched but ``security.token_url``
                    is ``None``.
            """
            if not self.bearer_token:
                if security.token_url is None:  # type: ignore
                    raise ValueError("Token URL is not defined")
                self.bearer_token = self.get_token(security.token_url)  # type: ignore

            headers = body_dict.setdefault("headers", {})
            headers["Authorization"] = f"Bearer {self.bearer_token}"

        def get_security_class(self) -> type[BaseSecurity]:
            """Return ``OAuth2PasswordBearer``."""
            return OAuth2PasswordBearer