Coverage for fastagency/api/openapi/openapi.py: 88%

155 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-19 12:16 +0000

1import builtins 1ahfgeibcd

2import importlib 1ahfgeibcd

3import inspect 1ahfgeibcd

4import json 1ahfgeibcd

5import re 1ahfgeibcd

6import shutil 1ahfgeibcd

7import sys 1ahfgeibcd

8from collections.abc import Iterable, Iterator, Mapping 1ahfgeibcd

9from contextlib import contextmanager 1ahfgeibcd

10from functools import wraps 1ahfgeibcd

11from pathlib import Path 1ahfgeibcd

12from types import ModuleType 1ahfgeibcd

13from typing import ( 1ahfgeibcd

14 TYPE_CHECKING, 

15 Any, 

16 Callable, 

17 Literal, 

18 Optional, 

19 Union, 

20) 

21 

22import fastapi 1ahfgeibcd

23import requests 1ahfgeibcd

24from datamodel_code_generator import DataModelType 1ahfgeibcd

25from fastapi_code_generator.__main__ import generate_code 1ahfgeibcd

26from pydantic_core import PydanticUndefined 1ahfgeibcd

27 

28from fastagency.helpers import optional_temp_path 1ahfgeibcd

29 

30from ...logging import get_logger 1ahfgeibcd

31from .fastapi_code_generator_helpers import patch_get_parameter_type 1ahfgeibcd

32from .security import BaseSecurity, BaseSecurityParameters 1ahfgeibcd

33 

34if TYPE_CHECKING: 1ahfgeibcd

35 from autogen.agentchat import ConversableAgent 

36 

logger = get_logger(__name__)

# The OpenAPI client class is the only public name exported by this module.
__all__ = ["OpenAPI"]

40 

41 

@contextmanager
def add_to_builtins(new_globals: dict[str, Any]) -> Iterator[None]:
    """Temporarily inject ``new_globals`` into the ``builtins`` module.

    On exit (normal or via exception) every injected name is restored to the
    value it had before, or removed again if it did not exist before.

    Args:
        new_globals: Mapping of names to the values to expose as builtins.

    Yields:
        ``None``; the names are visible as builtins while the block runs.
    """
    # Private sentinel so "attribute absent" is distinguishable from
    # "attribute present with value None" (the latter must be restored,
    # not deleted).
    missing = object()
    old_globals = {key: getattr(builtins, key, missing) for key in new_globals}

    try:
        for key, value in new_globals.items():
            setattr(builtins, key, value)  # Inject new global
        yield
    finally:
        for key, value in old_globals.items():
            if value is missing:
                # Remove added globals. The hasattr guard covers keys that were
                # never set because an earlier setattr raised.
                if hasattr(builtins, key):
                    delattr(builtins, key)
            else:
                setattr(builtins, key, value)  # Restore original value

56 

57 

class OpenAPI:
    """Client proxy built from an OpenAPI schema.

    Functions decorated with :meth:`get`, :meth:`post`, :meth:`put`,
    :meth:`delete`, :meth:`head` or :meth:`patch` are wrapped so that calling
    them sends the corresponding HTTP request with ``requests`` and returns the
    decoded JSON response. The wrapped functions are recorded on the instance
    and can be registered with an autogen ``ConversableAgent``.
    """

    def __init__(
        self, servers: list[dict[str, Any]], title: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Proxy class to generate client from OpenAPI schema.

        Args:
            servers: Server entries; ``servers[0]["url"]`` is used as the base
                URL of every request.
            title: Optional title, stored on the instance.
            **kwargs: Extra keyword arguments, stored on the instance as-is.
        """
        self._servers = servers
        self._title = title
        self._kwargs = kwargs
        self._registered_funcs: list[Callable[..., Any]] = []
        self._globals: dict[str, Any] = {}

        # Security schemes declared per function name, and the security
        # parameters configured per function name. The ``None`` key of
        # ``_security_params`` holds the defaults used when a function has no
        # parameters of its own.
        self._security: dict[str, list[BaseSecurity]] = {}
        self._security_params: dict[Optional[str], BaseSecurityParameters] = {}

    @staticmethod
    def _convert_camel_case_within_braces_to_snake(text: str) -> str:
        """Convert camelCase ``{...}`` placeholders in ``text`` to snake_case.

        Only placeholders made of ASCII letters and digits are rewritten, e.g.
        ``"/users/{userId}"`` becomes ``"/users/{user_id}"``.
        """

        def camel_to_snake(match: re.Match[str]) -> str:
            # Insert "_" before every uppercase letter except a leading one.
            return re.sub(r"(?<!^)(?=[A-Z])", "_", match.group(1)).lower()

        return re.sub(
            r"\{([a-zA-Z0-9]+)\}", lambda m: "{" + camel_to_snake(m) + "}", text
        )

    @staticmethod
    def _get_params(
        path: str, func: Callable[..., Any]
    ) -> tuple[set[str], set[str], Optional[str], bool]:
        """Classify the parameters of ``func`` for a request to ``path``.

        Returns:
            ``(query_params, path_params, body, security)``: the parameter
            names that are neither path, body nor security parameters; the
            ``{...}`` placeholders of ``path``; ``"body"`` if ``func`` has a
            ``body`` parameter (else ``None``); and whether ``func`` has a
            ``security`` parameter.

        Raises:
            ValueError: If a path placeholder is not a parameter of ``func``.
        """
        sig = inspect.signature(func)

        params_names = set(sig.parameters.keys())

        path_params = set(re.findall(r"\{(.+?)\}", path))
        if not path_params.issubset(params_names):
            raise ValueError(f"Path params {path_params} not in {params_names}")

        body = "body" if "body" in params_names else None

        security = "security" in params_names

        q_params = set(params_names) - path_params - {body} - {"security"}

        return q_params, path_params, body, security

    def _process_params(
        self, path: str, func: Callable[..., Any], /, **kwargs: Any
    ) -> tuple[str, dict[str, Any], dict[str, Any]]:
        """Build the URL, query parameters and ``requests`` kwargs for a call.

        ``path`` and ``func`` are positional-only so that API parameters that
        happen to be named ``path`` or ``func`` can still arrive via ``kwargs``.

        Args:
            path: Path template such as ``"/items/{itemId}"``.
            func: The decorated stub whose signature defines the parameters.
            **kwargs: Argument values the stub was called with.

        Returns:
            ``(url, params, body_dict)``: the full URL, the query parameters,
            and extra keyword arguments (``json`` / ``headers``) for ``requests``.
        """
        path = OpenAPI._convert_camel_case_within_braces_to_snake(path)
        q_params, path_params, body, security = OpenAPI._get_params(path, func)

        expanded_path = path.format(**{p: kwargs[p] for p in path_params})

        url = self._servers[0]["url"] + expanded_path

        body_dict: dict[str, Any] = {}
        if body and body in kwargs:
            body_value = kwargs[body]
            if isinstance(body_value, dict):
                body_dict = {"json": body_value}
            elif hasattr(body_value, "model_dump"):
                # Pydantic v2 style model.
                body_dict = {"json": body_value.model_dump()}
            else:
                # Presumably a Pydantic v1 style model exposing ``.dict()``.
                body_dict = {"json": body_value.dict()}

        body_dict["headers"] = {"Content-Type": "application/json"}
        if security:
            q_params, body_dict = kwargs["security"].add_security(q_params, body_dict)

        params = {k: v for k, v in kwargs.items() if k in q_params}

        return url, params, body_dict

    def set_security_params(
        self, security_params: BaseSecurityParameters, name: Optional[str] = None
    ) -> None:
        """Set the security parameters for function ``name`` (or the default).

        Args:
            security_params: Parameters applied to requests of the function.
            name: Function name. ``None`` sets the default parameters, used for
                any function without parameters of its own.

        Raises:
            ValueError: If ``name`` has no declared security, or none of its
                security schemes accepts ``security_params``.
        """
        if name is not None:
            security = self._security.get(name)
            if security is None:
                raise ValueError(f"Security is not set for '{name}'")

            if not any(candidate.accept(security_params) for candidate in security):
                raise ValueError(
                    f"Security parameters {security_params} do not match security {security}"
                )

        self._security_params[name] = security_params

    def _get_matching_security(
        self, security: list[BaseSecurity], security_params: BaseSecurityParameters
    ) -> BaseSecurity:
        """Return the first scheme in ``security`` accepting ``security_params``.

        Raises:
            ValueError: If no scheme accepts the parameters.
        """
        for match_security in security:
            if match_security.accept(security_params):
                return match_security
        raise ValueError(
            f"Security parameters {security_params} does not match any given security {security}"
        )

    def _get_security_params(
        self, name: str
    ) -> tuple[Optional[BaseSecurityParameters], Optional[BaseSecurity]]:
        """Resolve the security parameters and matching scheme for ``name``.

        Returns:
            ``(None, None)`` if ``name`` declares no security. Otherwise the
            function's own parameters (or the defaults) and the first
            security scheme that accepts them.

        Raises:
            ValueError: If neither function-specific nor default parameters are
                set, or no scheme accepts them.
        """
        security = self._security.get(name)
        if not security:
            return None, None

        security_params = self._security_params.get(name)
        if security_params is None:
            # Fall back to the default security parameters.
            security_params = self._security_params.get(None)
            if security_params is None:
                raise ValueError(
                    f"Security parameters are not set for {name} and there are no default security parameters"
                )

        match_security = self._get_matching_security(security, security_params)

        return security_params, match_security

    def _request(
        self,
        method: Literal["put", "get", "post", "head", "delete", "patch"],
        path: str,
        description: Optional[str] = None,
        security: Optional[list[BaseSecurity]] = None,
        **kwargs: Any,
    ) -> Callable[..., dict[str, Any]]:
        """Return a decorator that turns a stub function into an HTTP call.

        Args:
            method: HTTP method; the ``requests`` function of that name is used.
            path: Path template, appended to the first server URL.
            description: Description used when registering the function with an
                LLM. Falls back to the stub's stripped docstring.
            security: Security schemes accepted by this operation.
            **kwargs: Accepted and ignored.
        """

        def decorator(func: Callable[..., Any]) -> Callable[..., dict[str, Any]]:
            name = func.__name__

            if security is not None:
                self._security[name] = security

            @wraps(func)
            def wrapper(*args: Any, **kwargs: Any) -> dict[str, Any]:
                if args:
                    # Map positional arguments onto the stub's parameter names;
                    # otherwise they would be silently dropped from the request.
                    positional = inspect.signature(func).bind_partial(*args).arguments
                    duplicated = positional.keys() & kwargs.keys()
                    if duplicated:
                        raise TypeError(
                            f"{name}() got multiple values for argument(s) {sorted(duplicated)}"
                        )
                    kwargs = {**positional, **kwargs}

                url, params, body_dict = self._process_params(path, func, **kwargs)

                operation_security = self._security.get(name)
                if operation_security is not None:
                    security_params, matched_security = self._get_security_params(name)
                    if security_params is None:
                        raise ValueError(
                            f"Security parameters are not set for '{name}'"
                        )
                    else:
                        security_params.apply(params, body_dict, matched_security)  # type: ignore [arg-type]

                # NOTE(review): no ``timeout`` is passed, so a stalled server
                # blocks this call indefinitely.
                response = getattr(requests, method)(url, params=params, **body_dict)
                return response.json()  # type: ignore [no-any-return]

            docstring = func.__doc__.strip() if func.__doc__ is not None else None
            # An explicit description takes precedence over the docstring, even
            # when the stub has no docstring at all.
            wrapper._description = description or docstring  # type: ignore [attr-defined]

            self._registered_funcs.append(wrapper)

            return wrapper

        return decorator  # type: ignore [return-value]

    def put(self, path: str, **kwargs: Any) -> Callable[..., dict[str, Any]]:
        """Decorator for a PUT operation on ``path``."""
        return self._request("put", path, **kwargs)

    def get(self, path: str, **kwargs: Any) -> Callable[..., dict[str, Any]]:
        """Decorator for a GET operation on ``path``."""
        return self._request("get", path, **kwargs)

    def post(self, path: str, **kwargs: Any) -> Callable[..., dict[str, Any]]:
        """Decorator for a POST operation on ``path``."""
        return self._request("post", path, **kwargs)

    def delete(self, path: str, **kwargs: Any) -> Callable[..., dict[str, Any]]:
        """Decorator for a DELETE operation on ``path``."""
        return self._request("delete", path, **kwargs)

    def head(self, path: str, **kwargs: Any) -> Callable[..., dict[str, Any]]:
        """Decorator for a HEAD operation on ``path``."""
        return self._request("head", path, **kwargs)

    def patch(self, path: str, **kwargs: Any) -> Callable[..., dict[str, Any]]:
        """Decorator for a PATCH operation on ``path``."""
        return self._request("patch", path, **kwargs)

    @classmethod
    def _get_template_dir(cls) -> Path:
        """Return the ``templates`` directory three levels above this package.

        Raises:
            RuntimeError: If the directory does not exist.
        """
        path = Path(__file__).parents[3] / "templates"
        if not path.exists():
            raise RuntimeError(f"Template directory {path.resolve()} not found.")
        return path

    @classmethod
    def generate_code(
        cls,
        input_text: str,
        output_dir: Path,
        disable_timestamp: bool = False,
        custom_visitors: Optional[list[Path]] = None,
    ) -> str:
        """Generate client code for the schema ``input_text`` into ``output_dir``.

        The generated ``main.py`` / ``models.py`` are renamed to
        ``main_<dir>.py`` / ``models_<dir>.py``, where ``<dir>`` is
        ``output_dir.name``. The models import in the main file is rewritten to
        match, so several generated clients can be imported side by side.

        Args:
            input_text: OpenAPI schema as JSON text.
            output_dir: Directory receiving the generated files.
            disable_timestamp: Forwarded to the code generator.
            custom_visitors: Extra visitor files. The caller's list is not
                modified; the bundled security schema visitor is always added.

        Returns:
            The module name of the generated main file (``main_<dir>``).
        """
        # Copy instead of appending in place: mutating the caller's list would
        # make repeated calls accumulate duplicate visitor paths.
        visitors = list(custom_visitors) if custom_visitors is not None else []
        visitors.append(Path(__file__).parent / "security_schema_visitor.py")

        with patch_get_parameter_type():
            generate_code(
                input_name="openapi.json",
                input_text=input_text,
                encoding="utf-8",
                output_dir=output_dir,
                template_dir=cls._get_template_dir(),
                disable_timestamp=disable_timestamp,
                custom_visitors=visitors,
                output_model_type=DataModelType.PydanticV2BaseModel,
            )

        suffix = output_dir.name

        # Use unique file name for main.py
        main_name = f"main_{suffix}"
        main_path = output_dir / f"{main_name}.py"
        shutil.move(output_dir / "main.py", main_path)

        # Change "from .models import" to "from models_<dir> import"
        main_py_code = main_path.read_text()
        main_py_code = main_py_code.replace(
            "from .models import", f"from models_{suffix} import"
        )
        main_path.write_text(main_py_code)

        # Use unique file name for models.py
        models_name = f"models_{suffix}"
        shutil.move(output_dir / "models.py", output_dir / f"{models_name}.py")

        return main_name

    def set_globals(self, main: ModuleType, suffix: str) -> None:
        """Collect names from the generated ``main`` module.

        Keeps the non-dunder names of ``main`` whose values come from the
        ``models_<suffix>`` module or from ``typing``. They are exposed as
        builtins while functions are registered for the LLM.
        """
        xs = {k: v for k, v in main.__dict__.items() if not k.startswith("__")}
        self._globals = {
            k: v
            for k, v in xs.items()
            if hasattr(v, "__module__")
            and v.__module__ in [f"models_{suffix}", "typing"]
        }

    @classmethod
    def create(
        cls,
        *,
        openapi_json: Optional[str] = None,
        openapi_url: Optional[str] = None,
        client_source_path: Optional[str] = None,
        servers: Optional[list[dict[str, Any]]] = None,
    ) -> "OpenAPI":
        """Generate, import and return a client for an OpenAPI schema.

        Exactly one of ``openapi_json`` and ``openapi_url`` must be given.

        Args:
            openapi_json: Schema as a JSON string.
            openapi_url: URL to fetch the schema from (10 second timeout).
            client_source_path: Directory for the generated code. When ``None``,
                ``optional_temp_path`` presumably provides a temporary one.
            servers: If given, replaces the schema's ``servers`` entry.

        Returns:
            The ``app`` object defined in the generated main module.

        Raises:
            ValueError: If both or neither of ``openapi_json`` / ``openapi_url``
                are given.
        """
        if (openapi_json is None) == (openapi_url is None):
            raise ValueError("Either openapi_json or openapi_url should be provided")

        if openapi_json is None and openapi_url is not None:
            with requests.get(openapi_url, timeout=10) as response:
                response.raise_for_status()
                openapi_json = response.text

        if servers:
            openapi_parsed = json.loads(openapi_json)  # type: ignore [arg-type]
            openapi_parsed["servers"] = servers
            openapi_json = json.dumps(openapi_parsed)

        with optional_temp_path(client_source_path) as td:
            suffix = td.name

            main_name = cls.generate_code(
                input_text=openapi_json,  # type: ignore [arg-type]
                output_dir=td,
            )
            # Make the generated modules importable only for the import itself.
            try:
                sys.path.append(str(td))
                main = importlib.import_module(main_name, package=td.name)  # nosemgrep
            finally:
                sys.path.remove(str(td))

            client: OpenAPI = main.app  # type: ignore [attr-defined]
            client.set_globals(main, suffix=suffix)

            return client

    def _get_functions_to_register(
        self,
        functions: Optional[
            Iterable[Union[str, Mapping[str, Mapping[str, str]]]]
        ] = None,
    ) -> dict[Callable[..., Any], dict[str, Union[str, None]]]:
        """Select registered functions along with optional name/description overrides.

        Args:
            functions: ``None`` selects every registered function. Otherwise an
                iterable of function names, or mappings of function name to
                ``{"name": ..., "description": ...}`` overrides.

        Returns:
            Mapping of each selected wrapper to its ``name`` / ``description``.

        Raises:
            ValueError: On an entry that is neither ``str`` nor a mapping, or on
                names that do not match any registered function.
        """
        if functions is None:
            return {
                f: {
                    "name": None,
                    "description": f._description
                    if hasattr(f, "_description")
                    else None,
                }
                for f in self._registered_funcs
            }

        functions_with_name_desc: dict[str, dict[str, Union[str, None]]] = {}

        for f in functions:
            if isinstance(f, str):
                functions_with_name_desc[f] = {"name": None, "description": None}
            elif isinstance(f, Mapping):
                # Accept any Mapping, as the signature promises, not only dict.
                functions_with_name_desc.update(
                    {
                        k: {
                            "name": v.get("name", None),
                            "description": v.get("description", None),
                        }
                        for k, v in f.items()
                    }
                )
            else:
                raise ValueError(f"Invalid type {type(f)} for function {f}")

        funcs_to_register: dict[Callable[..., Any], dict[str, Union[str, None]]] = {
            f: functions_with_name_desc[f.__name__]
            for f in self._registered_funcs
            if f.__name__ in functions_with_name_desc
        }
        missing_functions = set(functions_with_name_desc.keys()) - {
            f.__name__ for f in funcs_to_register
        }
        if missing_functions:
            raise ValueError(
                f"Following functions {missing_functions} are not valid functions"
            )

        return funcs_to_register

    @staticmethod
    def _remove_pydantic_undefined_from_tools(
        tools: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Drop undefined FastAPI ``Path``/``Query`` defaults from tool schemas.

        A parameter whose ``default`` is a ``fastapi.params.Path`` or
        ``fastapi.params.Query`` without a real default (``PydanticUndefined``)
        has that ``default`` removed and is added to the ``required`` list.
        ``tools`` is modified in place and also returned.
        """
        for tool in tools:
            if "function" not in tool:
                continue

            function = tool["function"]
            if (
                "parameters" not in function
                or "properties" not in function["parameters"]
            ):
                continue

            parameters = function["parameters"]
            for param_name, param_value in parameters["properties"].items():
                if "default" not in param_value:
                    continue

                default = param_value.get("default")
                if (
                    isinstance(default, (fastapi.params.Path, fastapi.params.Query))
                    and default.default is PydanticUndefined
                ):
                    param_value.pop("default")
                    # We removed the default value, so the parameter is required.
                    # setdefault stores a new "required" list in the schema when
                    # it is absent, instead of appending to a throwaway list.
                    required = parameters.setdefault("required", [])
                    if param_name not in required:
                        required.append(param_name)

        return tools

    def _register_for_llm(
        self,
        agent: "ConversableAgent",
        functions: Optional[
            Iterable[Union[str, Mapping[str, Mapping[str, str]]]]
        ] = None,
    ) -> None:
        """Register the selected functions with ``agent`` for LLM tool calls.

        The generated model/typing names collected by :meth:`set_globals` are
        exposed as builtins during registration (presumably so the stubs'
        annotations can be resolved). Afterwards, undefined Path/Query defaults
        are stripped from ``agent.llm_config["tools"]``.
        """
        funcs_to_register = self._get_functions_to_register(functions)

        with add_to_builtins(
            new_globals=self._globals,
        ):
            for f, v in funcs_to_register.items():
                agent.register_for_llm(name=v["name"], description=v["description"])(f)

            agent.llm_config["tools"] = OpenAPI._remove_pydantic_undefined_from_tools(
                agent.llm_config["tools"]
            )

    def _register_for_execution(
        self,
        agent: "ConversableAgent",
        functions: Optional[
            Iterable[Union[str, Mapping[str, Mapping[str, str]]]]
        ] = None,
    ) -> None:
        """Register the selected functions with ``agent`` for execution."""
        funcs_to_register = self._get_functions_to_register(functions)

        for f, v in funcs_to_register.items():
            agent.register_for_execution(name=v["name"])(f)

    def get_functions(self) -> list[str]:
        """Deprecated: always raises; use :attr:`function_names` instead."""
        raise DeprecationWarning(
            "Use function_names property instead of get_functions method"
        )

    @property
    def function_names(self) -> list[str]:
        """Names of all registered functions, in registration order."""
        return [f.__name__ for f in self._registered_funcs]

    def get_function(self, name: str) -> Callable[..., dict[str, Any]]:
        """Return the registered function called ``name``.

        Raises:
            ValueError: If no registered function has that name.
        """
        for f in self._registered_funcs:
            if f.__name__ == name:
                return f
        raise ValueError(f"Function {name} not found")

    def set_function(self, name: str, func: Callable[..., dict[str, Any]]) -> None:
        """Replace the first registered function called ``name`` with ``func``.

        Raises:
            ValueError: If no registered function has that name.
        """
        for i, f in enumerate(self._registered_funcs):
            if f.__name__ == name:
                self._registered_funcs[i] = func
                return

        raise ValueError(f"Function {name} not found")

    def inject_parameters(self, name: str, **kwargs: Any) -> None:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("Injecting parameters is not implemented yet")