Coverage for fastagency/api/openapi/openapi.py: 88%
155 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-19 12:16 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-19 12:16 +0000
1import builtins 1ahfgeibcd
2import importlib 1ahfgeibcd
3import inspect 1ahfgeibcd
4import json 1ahfgeibcd
5import re 1ahfgeibcd
6import shutil 1ahfgeibcd
7import sys 1ahfgeibcd
8from collections.abc import Iterable, Iterator, Mapping 1ahfgeibcd
9from contextlib import contextmanager 1ahfgeibcd
10from functools import wraps 1ahfgeibcd
11from pathlib import Path 1ahfgeibcd
12from types import ModuleType 1ahfgeibcd
13from typing import ( 1ahfgeibcd
14 TYPE_CHECKING,
15 Any,
16 Callable,
17 Literal,
18 Optional,
19 Union,
20)
22import fastapi 1ahfgeibcd
23import requests 1ahfgeibcd
24from datamodel_code_generator import DataModelType 1ahfgeibcd
25from fastapi_code_generator.__main__ import generate_code 1ahfgeibcd
26from pydantic_core import PydanticUndefined 1ahfgeibcd
28from fastagency.helpers import optional_temp_path 1ahfgeibcd
30from ...logging import get_logger 1ahfgeibcd
31from .fastapi_code_generator_helpers import patch_get_parameter_type 1ahfgeibcd
32from .security import BaseSecurity, BaseSecurityParameters 1ahfgeibcd
34if TYPE_CHECKING: 1ahfgeibcd
35 from autogen.agentchat import ConversableAgent
# Module-level logger, obtained from the package's ``get_logger`` helper.
logger = get_logger(__name__)

# Public API of this module: only the ``OpenAPI`` client class is exported.
__all__ = ["OpenAPI"]
@contextmanager
def add_to_builtins(new_globals: dict[str, Any]) -> Iterator[None]:
    """Temporarily expose ``new_globals`` as attributes of :mod:`builtins`.

    While the context is active every key of ``new_globals`` resolves as a bare
    name from any module. On exit (normal or via an exception) each name is
    restored to the value it had before, or removed if it did not exist.

    Args:
        new_globals: Names mapped to the objects to inject.

    Yields:
        None
    """
    # Private sentinel distinguishing "attribute absent" from "attribute whose
    # value is None"; using None for both deleted None-valued builtins on exit.
    missing = object()
    saved = {key: getattr(builtins, key, missing) for key in new_globals}

    try:
        for key, value in new_globals.items():
            setattr(builtins, key, value)  # Inject new global
        yield
    finally:
        for key, old_value in saved.items():
            if old_value is missing:
                # The name was not a builtin before. Remove it, tolerating the
                # case where injection failed before this key was set.
                if hasattr(builtins, key):
                    delattr(builtins, key)
            else:
                setattr(builtins, key, old_value)  # Restore original value
class OpenAPI:
    """Client proxy for an HTTP API described by an OpenAPI schema.

    ``OpenAPI.create`` renders a client module from the schema and imports it.
    That module's ``app`` attribute is an instance of this class. One function
    per endpoint has been registered on it through the ``get``/``post``/
    ``put``/``patch``/``delete``/``head`` decorators. Registered functions can
    then be registered on an autogen ``ConversableAgent`` for LLM tool calling
    and for execution.
    """

    def __init__(
        self, servers: list[dict[str, Any]], title: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Proxy class to generate client from OpenAPI schema.

        Args:
            servers: OpenAPI ``servers`` entries; ``servers[0]["url"]`` is used
                as the base URL of every request.
            title: Optional API title (stored only).
            **kwargs: Any further keyword arguments (stored only).
        """
        self._servers = servers
        self._title = title
        self._kwargs = kwargs
        # Wrappers created by the HTTP-verb decorators, in registration order.
        self._registered_funcs: list[Callable[..., Any]] = []
        # Names collected from the generated module (see ``set_globals``).
        self._globals: dict[str, Any] = {}

        # Security schemes declared per function name.
        self._security: dict[str, list[BaseSecurity]] = {}
        # Security parameters per function name. The ``None`` key holds the
        # default parameters used when a function has none of its own.
        self._security_params: dict[Optional[str], BaseSecurityParameters] = {}

    @staticmethod
    def _convert_camel_case_within_braces_to_snake(text: str) -> str:
        """Rewrite every ``{camelCase}`` placeholder in ``text`` as ``{camel_case}``.

        Only placeholders consisting of ASCII letters and digits are rewritten;
        the rest of ``text`` is returned unchanged.
        """

        def camel_to_snake(match: re.Match[str]) -> str:
            # Insert "_" before each uppercase letter (except a leading one),
            # then lowercase the whole name.
            return re.sub(r"(?<!^)(?=[A-Z])", "_", match.group(1)).lower()

        return re.sub(
            r"\{([a-zA-Z0-9]+)\}", lambda m: "{" + camel_to_snake(m) + "}", text
        )

    @staticmethod
    def _get_params(
        path: str, func: Callable[..., Any]
    ) -> tuple[set[str], set[str], Optional[str], bool]:
        """Classify the parameters of ``func`` for a request to ``path``.

        Args:
            path: URL path template; its ``{name}`` placeholders are path params.
            func: Endpoint function whose signature lists all parameters.

        Returns:
            ``(query_params, path_params, body, security)``:
            - ``body`` is ``"body"`` when ``func`` has a ``body`` parameter,
              else ``None``.
            - ``security`` tells whether ``func`` has a ``security`` parameter.
            - ``query_params`` are all remaining parameter names.

        Raises:
            ValueError: If a path placeholder is not a parameter of ``func``.
        """
        params_names = set(inspect.signature(func).parameters)

        path_params = set(re.findall(r"\{(.+?)\}", path))
        if not path_params.issubset(params_names):
            raise ValueError(f"Path params {path_params} not in {params_names}")

        body = "body" if "body" in params_names else None
        security = "security" in params_names
        q_params = params_names - path_params - {body} - {"security"}

        return q_params, path_params, body, security

    def _process_params(
        self, path: str, func: Callable[..., Any], **kwargs: Any
    ) -> tuple[str, dict[str, Any], dict[str, Any]]:
        """Build the URL, query parameters and extra request kwargs for one call.

        Args:
            path: Endpoint path template. camelCase placeholders are converted
                to snake_case to match the function's parameter names.
            func: The endpoint function (only its signature is used).
            **kwargs: Argument values the endpoint was called with.

        Returns:
            ``(url, params, body_dict)``:
            - ``url``: the first server URL plus the expanded path.
            - ``params``: the query parameters present in ``kwargs``.
            - ``body_dict``: keyword arguments for the ``requests`` call. It
              always contains a JSON ``Content-Type`` header, plus ``json``
              when a body was passed.

        Raises:
            KeyError: If a path parameter is missing from ``kwargs``.
            ValueError: Propagated from ``_get_params``.
        """
        path = OpenAPI._convert_camel_case_within_braces_to_snake(path)
        q_params, path_params, body, security = OpenAPI._get_params(path, func)

        expanded_path = path.format(**{p: kwargs[p] for p in path_params})

        url = self._servers[0]["url"] + expanded_path

        body_dict: dict[str, Any] = {}
        if body and body in kwargs:
            body_value = kwargs[body]
            if isinstance(body_value, dict):
                body_dict = {"json": body_value}
            elif hasattr(body_value, "model_dump"):
                # e.g. a Pydantic v2 model.
                body_dict = {"json": body_value.model_dump()}
            else:
                # Presumably a model exposing the Pydantic v1 ``.dict()`` API.
                body_dict = {"json": body_value.dict()}

        body_dict["headers"] = {"Content-Type": "application/json"}
        if security:
            # NOTE(review): expects the caller to pass a ``security`` object
            # with an ``add_security`` method; its contract is not visible here.
            q_params, body_dict = kwargs["security"].add_security(q_params, body_dict)

        params = {k: v for k, v in kwargs.items() if k in q_params}

        return url, params, body_dict

    def set_security_params(
        self, security_params: BaseSecurityParameters, name: Optional[str] = None
    ) -> None:
        """Register security parameters for one function or as the default.

        Args:
            security_params: Parameters (credentials) to apply to requests.
            name: Function name to bind them to. ``None`` makes them the
                default for every function without parameters of its own.

        Raises:
            ValueError: If ``name`` has no declared security, or if none of its
                security schemes accepts ``security_params``.
        """
        if name is not None:
            security = self._security.get(name)
            if security is None:
                raise ValueError(f"Security is not set for '{name}'")

            if not any(s.accept(security_params) for s in security):
                raise ValueError(
                    f"Security parameters {security_params} do not match security {security}"
                )

        self._security_params[name] = security_params

    def _get_matching_security(
        self, security: list[BaseSecurity], security_params: BaseSecurityParameters
    ) -> BaseSecurity:
        """Return the first scheme in ``security`` that accepts ``security_params``.

        Raises:
            ValueError: If no scheme accepts them.
        """
        for match_security in security:
            if match_security.accept(security_params):
                return match_security
        raise ValueError(
            f"Security parameters {security_params} does not match any given security {security}"
        )

    def _get_security_params(
        self, name: str
    ) -> tuple[Optional[BaseSecurityParameters], Optional[BaseSecurity]]:
        """Resolve the security parameters and matching scheme for function ``name``.

        Function-specific parameters take precedence over the default ones
        (registered with ``name=None``).

        Returns:
            ``(None, None)`` if ``name`` declares no security, otherwise
            ``(security_params, matching_scheme)``.

        Raises:
            ValueError: If neither specific nor default parameters are set, or
                if the parameters match none of the declared schemes.
        """
        # check if security is set for the method
        security = self._security.get(name)
        if not security:
            return None, None

        security_params = self._security_params.get(name)
        if security_params is None:
            # check if default security parameters are set
            security_params = self._security_params.get(None)
            if security_params is None:
                raise ValueError(
                    f"Security parameters are not set for {name} and there are no default security parameters"
                )

        match_security = self._get_matching_security(security, security_params)

        return security_params, match_security

    def _request(
        self,
        method: Literal["put", "get", "post", "head", "delete", "patch"],
        path: str,
        description: Optional[str] = None,
        security: Optional[list[BaseSecurity]] = None,
        **kwargs: Any,
    ) -> Callable[..., dict[str, Any]]:
        """Return a decorator that registers ``func`` as an HTTP endpoint.

        The decorated function's body is never executed; only its name,
        signature and docstring are used. The returned wrapper, called with
        keyword arguments, sends the request with ``requests`` and returns the
        decoded JSON response.

        Args:
            method: HTTP method, used as the name of the ``requests`` function.
            path: Endpoint path template, appended to the first server URL.
            description: Description exposed for the function. When empty, the
                function's stripped docstring is used instead.
            security: Security schemes accepted by this endpoint.
            **kwargs: Extra decorator arguments; ignored.
        """

        def decorator(func: Callable[..., Any]) -> Callable[..., dict[str, Any]]:
            name = func.__name__

            if security is not None:
                self._security[name] = security

            @wraps(func)
            def wrapper(*args: Any, **kwargs: Any) -> dict[str, Any]:
                # NOTE(review): positional ``args`` are ignored; only keyword
                # arguments reach the request.
                url, params, body_dict = self._process_params(path, func, **kwargs)

                func_security = self._security.get(name)
                if func_security is not None:
                    security_params, matched_security = self._get_security_params(name)
                    if security_params is None:
                        raise ValueError(
                            f"Security parameters are not set for '{name}'"
                        )
                    # The return value is unused, so ``apply`` presumably
                    # updates ``params`` / ``body_dict`` in place.
                    security_params.apply(params, body_dict, matched_security)  # type: ignore [arg-type]

                response = getattr(requests, method)(url, params=params, **body_dict)
                return response.json()  # type: ignore [no-any-return]

            # An explicit description wins. Previously, operator precedence
            # discarded it whenever ``func`` had no docstring:
            # ``(description or doc) if doc is not None else None``.
            if description:
                func_description: Optional[str] = description
            elif func.__doc__ is not None:
                func_description = func.__doc__.strip()
            else:
                func_description = None
            wrapper._description = func_description  # type: ignore [attr-defined]

            self._registered_funcs.append(wrapper)

            return wrapper

        return decorator  # type: ignore [return-value]

    def put(self, path: str, **kwargs: Any) -> Callable[..., dict[str, Any]]:
        """Decorator registering a PUT endpoint (see ``_request``)."""
        return self._request("put", path, **kwargs)

    def get(self, path: str, **kwargs: Any) -> Callable[..., dict[str, Any]]:
        """Decorator registering a GET endpoint (see ``_request``)."""
        return self._request("get", path, **kwargs)

    def post(self, path: str, **kwargs: Any) -> Callable[..., dict[str, Any]]:
        """Decorator registering a POST endpoint (see ``_request``)."""
        return self._request("post", path, **kwargs)

    def delete(self, path: str, **kwargs: Any) -> Callable[..., dict[str, Any]]:
        """Decorator registering a DELETE endpoint (see ``_request``)."""
        return self._request("delete", path, **kwargs)

    def head(self, path: str, **kwargs: Any) -> Callable[..., dict[str, Any]]:
        """Decorator registering a HEAD endpoint (see ``_request``)."""
        return self._request("head", path, **kwargs)

    def patch(self, path: str, **kwargs: Any) -> Callable[..., dict[str, Any]]:
        """Decorator registering a PATCH endpoint (see ``_request``)."""
        return self._request("patch", path, **kwargs)

    @classmethod
    def _get_template_dir(cls) -> Path:
        """Return ``Path(__file__).parents[3] / "templates"``.

        Raises:
            RuntimeError: If that directory does not exist.
        """
        path = Path(__file__).parents[3] / "templates"
        if not path.exists():
            raise RuntimeError(f"Template directory {path.resolve()} not found.")
        return path

    @classmethod
    def generate_code(
        cls,
        input_text: str,
        output_dir: Path,
        disable_timestamp: bool = False,
        custom_visitors: Optional[list[Path]] = None,
    ) -> str:
        """Generate client sources for an OpenAPI schema into ``output_dir``.

        The generated ``main.py`` and ``models.py`` are renamed to
        ``main_<dir>.py`` and ``models_<dir>.py``, where ``<dir>`` is
        ``output_dir.name``, so that several clients can be imported side by
        side. The models import inside the main module is rewritten to match.

        Args:
            input_text: OpenAPI schema text.
            output_dir: Directory that receives the generated files.
            disable_timestamp: Passed through to the code generator.
            custom_visitors: Extra code-generator visitor files. The bundled
                ``security_schema_visitor.py`` is always appended. The caller's
                list is not modified.

        Returns:
            The module name of the generated main module (``main_<dir>``).
        """
        # Build a new list instead of appending to the caller's one; in-place
        # appends made repeated calls accumulate duplicate visitors.
        visitors = [
            *(custom_visitors or []),
            Path(__file__).parent / "security_schema_visitor.py",
        ]

        with patch_get_parameter_type():
            generate_code(
                input_name="openapi.json",
                input_text=input_text,
                encoding="utf-8",
                output_dir=output_dir,
                template_dir=cls._get_template_dir(),
                disable_timestamp=disable_timestamp,
                custom_visitors=visitors,
                output_model_type=DataModelType.PydanticV2BaseModel,
            )

        suffix = output_dir.name

        # Use unique file name for main.py
        main_name = f"main_{suffix}"
        main_path = output_dir / f"{main_name}.py"
        shutil.move(output_dir / "main.py", main_path)

        # Change "from .models import" to "from models_<dir> import". Use UTF-8
        # explicitly rather than the platform's locale default encoding.
        main_py_code = main_path.read_text(encoding="utf-8").replace(
            "from .models import", f"from models_{suffix} import"
        )
        main_path.write_text(main_py_code, encoding="utf-8")

        # Use unique file name for models.py
        models_path = output_dir / f"models_{suffix}.py"
        shutil.move(output_dir / "models.py", models_path)

        return main_name

    def set_globals(self, main: ModuleType, suffix: str) -> None:
        """Collect names from the generated main module for later injection.

        Keeps the non-dunder attributes of ``main`` whose ``__module__`` is the
        generated models module (``models_<suffix>``) or ``typing``. They are
        injected into builtins while functions are registered for the LLM.
        """
        public_items = {
            k: v for k, v in main.__dict__.items() if not k.startswith("__")
        }
        self._globals = {
            k: v
            for k, v in public_items.items()
            if hasattr(v, "__module__")
            and v.__module__ in [f"models_{suffix}", "typing"]
        }

    @classmethod
    def create(
        cls,
        *,
        openapi_json: Optional[str] = None,
        openapi_url: Optional[str] = None,
        client_source_path: Optional[str] = None,
        servers: Optional[list[dict[str, Any]]] = None,
    ) -> "OpenAPI":
        """Generate, import and return a client for an OpenAPI schema.

        Exactly one of ``openapi_json`` or ``openapi_url`` must be given.

        Args:
            openapi_json: Schema as a JSON string.
            openapi_url: URL to download the schema from (10 s timeout).
            client_source_path: Directory for the generated sources, passed to
                ``optional_temp_path``. Presumably a temporary directory is
                used when ``None``.
            servers: If given, replaces the schema's ``servers`` list.

        Returns:
            The ``app`` object of the generated client module.

        Raises:
            ValueError: If neither or both schema sources are given.
            requests.HTTPError: If downloading the schema fails.
        """
        if (openapi_json is None) == (openapi_url is None):
            raise ValueError("Either openapi_json or openapi_url should be provided")

        if openapi_json is None and openapi_url is not None:
            with requests.get(openapi_url, timeout=10) as response:
                response.raise_for_status()
                openapi_json = response.text

        if servers:
            openapi_parsed = json.loads(openapi_json)  # type: ignore [arg-type]
            openapi_parsed["servers"] = servers
            openapi_json = json.dumps(openapi_parsed)

        with optional_temp_path(client_source_path) as td:
            suffix = td.name

            main_name = cls.generate_code(
                input_text=openapi_json,  # type: ignore [arg-type]
                output_dir=td,
            )
            # Make the generated modules importable only during the import.
            # Append before ``try`` so ``finally`` removes only an entry that
            # was actually added.
            sys.path.append(str(td))
            try:
                main = importlib.import_module(main_name, package=td.name)  # nosemgrep
            finally:
                sys.path.remove(str(td))

            client: OpenAPI = main.app  # type: ignore [attr-defined]
            client.set_globals(main, suffix=suffix)

        return client

    def _get_functions_to_register(
        self,
        functions: Optional[
            Iterable[Union[str, Mapping[str, Mapping[str, str]]]]
        ] = None,
    ) -> dict[Callable[..., Any], dict[str, Union[str, None]]]:
        """Select registered functions and their optional name/description overrides.

        Args:
            functions: ``None`` selects every registered function. Otherwise
                an iterable whose items are either a function name or a
                mapping ``{function_name: {"name": ..., "description": ...}}``
                (both inner keys optional).

        Returns:
            Mapping from each selected wrapper to its ``"name"`` and
            ``"description"``. Values may be ``None``; they are passed as-is to
            the agent registration calls.

        Raises:
            ValueError: For an item that is neither ``str`` nor a mapping, or
                for names that match no registered function.
        """
        if functions is None:
            return {
                f: {
                    "name": None,
                    "description": f._description
                    if hasattr(f, "_description")
                    else None,
                }
                for f in self._registered_funcs
            }

        functions_with_name_desc: dict[str, dict[str, Union[str, None]]] = {}

        for f in functions:
            if isinstance(f, str):
                functions_with_name_desc[f] = {"name": None, "description": None}
            elif isinstance(f, Mapping):
                # Accept any Mapping, as the signature promises, not only dict.
                functions_with_name_desc.update(
                    {
                        k: {
                            "name": v.get("name", None),
                            "description": v.get("description", None),
                        }
                        for k, v in f.items()
                    }
                )
            else:
                raise ValueError(f"Invalid type {type(f)} for function {f}")

        funcs_to_register: dict[Callable[..., Any], dict[str, Union[str, None]]] = {
            f: functions_with_name_desc[f.__name__]
            for f in self._registered_funcs
            if f.__name__ in functions_with_name_desc
        }
        missing_functions = set(functions_with_name_desc.keys()) - {
            f.__name__ for f in funcs_to_register
        }
        if missing_functions:
            raise ValueError(
                f"Following functions {missing_functions} are not valid functions"
            )

        return funcs_to_register

    @staticmethod
    def _remove_pydantic_undefined_from_tools(
        tools: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Drop placeholder defaults from tool schemas and mark those params required.

        A parameter is affected when its ``default`` is a FastAPI ``Path`` or
        ``Query`` object whose own default is ``PydanticUndefined``. Its
        ``default`` key is removed and its name is added to the function's
        ``required`` list. ``tools`` is modified in place and also returned.
        """
        for tool in tools:
            if "function" not in tool:
                continue

            function = tool["function"]
            if (
                "parameters" not in function
                or "properties" not in function["parameters"]
            ):
                continue

            parameters = function["parameters"]
            # When the schema has no ``required`` list, create and attach one
            # only if a parameter must be added. Previously a detached ``[]``
            # was used, so such additions were silently lost.
            required = parameters.get("required")
            for param_name, param_value in parameters["properties"].items():
                if "default" not in param_value:
                    continue

                default = param_value["default"]
                if (
                    isinstance(default, (fastapi.params.Path, fastapi.params.Query))
                    and default.default is PydanticUndefined
                ):
                    param_value.pop("default")
                    # We removed the default value, so the parameter is now required.
                    if required is None:
                        required = parameters["required"] = []
                    if param_name not in required:
                        required.append(param_name)

        return tools

    def _register_for_llm(
        self,
        agent: "ConversableAgent",
        functions: Optional[
            Iterable[Union[str, Mapping[str, Mapping[str, str]]]]
        ] = None,
    ) -> None:
        """Register the selected functions as LLM tools on ``agent``.

        The generated-module names in ``self._globals`` are injected into
        builtins during registration. Afterwards, the agent's tool schemas are
        cleaned with ``_remove_pydantic_undefined_from_tools``.
        """
        funcs_to_register = self._get_functions_to_register(functions)

        with add_to_builtins(
            new_globals=self._globals,
        ):
            for f, v in funcs_to_register.items():
                agent.register_for_llm(name=v["name"], description=v["description"])(f)

            agent.llm_config["tools"] = OpenAPI._remove_pydantic_undefined_from_tools(
                agent.llm_config["tools"]
            )

    def _register_for_execution(
        self,
        agent: "ConversableAgent",
        functions: Optional[
            Iterable[Union[str, Mapping[str, Mapping[str, str]]]]
        ] = None,
    ) -> None:
        """Register the selected functions for execution on ``agent``."""
        funcs_to_register = self._get_functions_to_register(functions)

        for f, v in funcs_to_register.items():
            agent.register_for_execution(name=v["name"])(f)

    def get_functions(self) -> list[str]:
        """Deprecated: always raises; use the ``function_names`` property."""
        raise DeprecationWarning(
            "Use function_names property instead of get_functions method"
        )

    @property
    def function_names(self) -> list[str]:
        """Names of all registered functions, in registration order."""
        return [f.__name__ for f in self._registered_funcs]

    def get_function(self, name: str) -> Callable[..., dict[str, Any]]:
        """Return the first registered function called ``name``.

        Raises:
            ValueError: If no such function is registered.
        """
        for f in self._registered_funcs:
            if f.__name__ == name:
                return f
        raise ValueError(f"Function {name} not found")

    def set_function(self, name: str, func: Callable[..., dict[str, Any]]) -> None:
        """Replace the first registered function called ``name`` with ``func``.

        Raises:
            ValueError: If no such function is registered.
        """
        for i, f in enumerate(self._registered_funcs):
            if f.__name__ == name:
                self._registered_funcs[i] = func
                return

        raise ValueError(f"Function {name} not found")

    def inject_parameters(self, name: str, **kwargs: Any) -> None:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("Injecting parameters is not implemented yet")