Coverage for bbconf/config_parser/models.py: 88%

73 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-17 04:01 +0000

1from pathlib import Path 

2from typing import Optional, Union 

3import logging 

4 

5from pydantic import BaseModel, ConfigDict, computed_field, field_validator 

6from yacman import load_yaml 

7 

8from bbconf.config_parser.const import ( 

9 DEFAULT_DB_DIALECT, 

10 DEFAULT_DB_DRIVER, 

11 DEFAULT_DB_NAME, 

12 DEFAULT_DB_PORT, 

13 DEFAULT_PEPHUB_NAME, 

14 DEFAULT_PEPHUB_NAMESPACE, 

15 DEFAULT_PEPHUB_TAG, 

16 DEFAULT_QDRANT_COLLECTION_NAME, 

17 DEFAULT_QDRANT_PORT, 

18 DEFAULT_REGION2_VEC_MODEL, 

19 DEFAULT_S3_BUCKET, 

20 DEFAULT_SERVER_HOST, 

21 DEFAULT_SERVER_PORT, 

22 DEFAULT_TEXT2VEC_MODEL, 

23 DEFAULT_VEC2VEC_MODEL, 

24) 

25 

26_LOGGER = logging.getLogger(__name__) 

27 

28 

class ConfigDB(BaseModel):
    # Connection settings for the SQL database section of the config.
    # host, user and password are required; the remaining fields fall back to
    # the DEFAULT_DB_* constants. Unknown keys are rejected (extra="forbid").
    host: str
    port: int = DEFAULT_DB_PORT
    user: str
    password: str
    database: str = DEFAULT_DB_NAME
    dialect: str = DEFAULT_DB_DIALECT
    driver: Optional[str] = DEFAULT_DB_DRIVER

    model_config = ConfigDict(extra="forbid")

    @computed_field
    @property
    def url(self) -> str:
        """
        The URL of the database.

        The ``+driver`` suffix is only added when a driver is configured, so
        an unset driver does not end up in the scheme as the text ``None``.

        :return str: The URL of the database.
        """
        # `driver` is Optional: fall back to the bare dialect when it is
        # None or empty.
        scheme = f"{self.dialect}+{self.driver}" if self.driver else self.dialect
        return f"{scheme}://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"

49 

50 

class ConfigQdrant(BaseModel):
    # Settings for the qdrant section of the config.
    # Only `host` is required; `api_key` may be left unset.
    host: str
    port: int = DEFAULT_QDRANT_PORT
    api_key: Union[str, None] = None
    collection: str = DEFAULT_QDRANT_COLLECTION_NAME

56 

57 

class ConfigServer(BaseModel):
    # Host and port for the server section of the config.
    # Both fields are optional and fall back to the DEFAULT_SERVER_* constants.
    host: str = DEFAULT_SERVER_HOST
    port: int = DEFAULT_SERVER_PORT

61 

62 

class ConfigPath(BaseModel):
    # String settings for the region2vec, vec2vec and text2vec entries of the
    # config; each falls back to its DEFAULT_*_MODEL constant.
    # NOTE(review): presumably model identifiers or paths - confirm against
    # the code that consumes them.
    region2vec: str = DEFAULT_REGION2_VEC_MODEL
    vec2vec: str = DEFAULT_VEC2VEC_MODEL
    text2vec: str = DEFAULT_TEXT2VEC_MODEL

67 

68 

class AccessMethodsStruct(BaseModel):
    # Description of a single access method: a required `type` and `prefix`
    # plus an optional free-text `description`.
    type: str
    # Annotated Optional so that an explicit null in the config validates the
    # same way as leaving the key out (the default was already None).
    description: Optional[str] = None
    prefix: str

73 

74 

class AccessMethods(BaseModel):
    # Optional access-method definitions, one per supported kind.
    # Each entry is annotated Optional so that an explicit null in the config
    # is accepted, matching the existing default of None.
    http: Optional[AccessMethodsStruct] = None
    s3: Optional[AccessMethodsStruct] = None
    local: Optional[AccessMethodsStruct] = None

79 

80 

class ConfigS3(BaseModel):
    # Settings for the s3 section of the config. Every field may be omitted;
    # `bucket` falls back to DEFAULT_S3_BUCKET.
    endpoint_url: Union[str, None] = None
    aws_access_key_id: Union[str, None] = None
    aws_secret_access_key: Union[str, None] = None
    bucket: Union[str, None] = DEFAULT_S3_BUCKET

    @field_validator("aws_access_key_id", "aws_secret_access_key")
    @classmethod
    def validate_aws_credentials(cls, value):
        """
        Normalize placeholder AWS credentials to None.

        :param value: The credential value taken from the config.
        :return: None if the value is empty or one of the placeholder
            strings, otherwise the value unchanged.
        """
        # Values that mean "credentials are not provided".
        # NOTE(review): presumably these are environment variable names left
        # unexpanded in the config file - confirm.
        placeholders = (
            "AWS_SECRET_ACCESS_KEY",
            "AWS_ACCESS_KEY_ID",
            "",
            "$AWS_ACCESS_KEY_ID",
            "$AWS_SECRET_ACCESS_KEY",
        )
        if value in placeholders:
            return None
        return value

    @computed_field
    @property
    def modify_access(self) -> bool:
        """
        If the AWS credentials are provided, set the modify access to True. (create = True)

        A warning is logged each time this evaluates to False.

        :return bool: True if both the access key id and the secret access
            key are set, False otherwise.
        """
        if self.aws_access_key_id and self.aws_secret_access_key:
            return True
        _LOGGER.warning(
            "AWS credentials are not provided. The S3 bucket will be read-only."
        )
        return False

114 

115 

class ConfigPepHubClient(BaseModel):
    # Namespace, name and tag for the phc section of the config.
    # Each field accepts a string or None and falls back to its
    # DEFAULT_PEPHUB_* constant.
    namespace: Optional[str] = DEFAULT_PEPHUB_NAMESPACE
    name: Optional[str] = DEFAULT_PEPHUB_NAME
    tag: Optional[str] = DEFAULT_PEPHUB_TAG

120 

121 

class ConfigFile(BaseModel):
    # Top-level configuration model. `database`, `server` and `path` are
    # required sections; the others may be omitted or set to null.
    # Unknown top-level keys are kept (extra="allow").
    database: ConfigDB
    qdrant: Optional[ConfigQdrant] = None
    server: ConfigServer
    path: ConfigPath
    access_methods: Optional[AccessMethods] = None
    s3: Optional[ConfigS3] = None
    phc: Optional[ConfigPepHubClient] = None

    model_config = ConfigDict(extra="allow")

    @classmethod
    def from_yaml(cls, path: Union[str, Path]) -> "ConfigFile":
        """
        Load the configuration from a YAML file.

        :param path: The path to the YAML file (``str`` or ``Path``).

        :returns: ConfigFile: The validated configuration.
        """
        # Wrapping in Path() lets plain strings through; a Path is unchanged.
        return cls.model_validate(load_yaml(Path(path).as_posix()))