Coverage for requests_tracker/sql/sql_tracker.py: 95%

131 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-18 22:19 +0000

1import contextlib 

2import datetime 

3import json 

4import re 

5import types 

6from contextvars import ContextVar 

7from decimal import Decimal 

8from time import time 

9from typing import ( 

10 TYPE_CHECKING, 

11 Any, 

12 Callable, 

13 Dict, 

14 List, 

15 Mapping, 

16 Optional, 

17 Sequence, 

18 Tuple, 

19 Union, 

20) 

21from uuid import UUID 

22 

23from django.db.backends.base.base import BaseDatabaseWrapper 

24from django.db.backends.utils import CursorWrapper 

25from django.utils.encoding import force_str 

26 

27from requests_tracker import settings as dr_settings 

28from requests_tracker.sql.dataclasses import SQLQueryInfo 

29from requests_tracker.sql.sql_collector import SQLCollector 

30from requests_tracker.stack_trace import get_stack_trace 

31 

32try: 

33 from psycopg2._json import Json as PostgresJson 

34 from psycopg2.extensions import STATUS_IN_TRANSACTION 

35except ImportError: 

36 PostgresJson = None # type: ignore 

37 STATUS_IN_TRANSACTION = None # type: ignore 

38 

if TYPE_CHECKING:
    # Approximate recursive return type of SQLTracker._decode: nested lists /
    # dicts mirroring the parameter structure. Leaves are declared as str, but
    # force_str(strings_only=True) can also leave scalars such as int or None
    # untouched.
    DecodeReturn = Union[
        List[Union["DecodeReturn", str]], Dict[str, Union["DecodeReturn", str]], str
    ]

# Values a single query parameter may hold. datetime.time is listed because
# SQLTracker._decode explicitly converts time values for display.
SQLType = Union[
    None,
    bool,
    int,
    float,
    Decimal,
    str,
    bytes,
    datetime.date,
    datetime.datetime,
    datetime.time,
    UUID,
    Tuple[Any, ...],
    List[Any],
]

# Parameters of one execute() call: positional (sequence) or named (mapping).
ExecuteParameters = Optional[Union[Sequence[SQLType], Mapping[str, SQLType]]]
# executemany() receives a sequence of such parameter sets instead.
ExecuteParametersOrSequence = Union[ExecuteParameters, Sequence[ExecuteParameters]]

# Return type of SQLTracker._quote_params.
QuoteParamsReturn = Optional[Union[Dict[str, str], List[str]]]

# Context-local slot holding the currently active SQLTracker.
_local: ContextVar["SQLTracker"] = ContextVar("current_sql_tracker")

65 

66 

class SQLTrackerMeta(type):
    @property
    def current(cls) -> "SQLTracker":
        """
        The SQLTracker bound to the active context.

        If the context has no tracker yet, a new collector-less SQLTracker is
        created, stored in the context variable, and returned.
        """
        tracker = _local.get(None)
        if tracker is not None:
            return tracker

        tracker = SQLTracker()
        _local.set(tracker)
        return tracker

78 

79 

class SQLTracker(metaclass=SQLTrackerMeta):
    """
    Times and records SQL queries executed through Django cursor wrappers.

    Without an ``SQLCollector`` the tracker is a pass-through: ``record`` just
    calls the wrapped cursor method. Used as a context manager, a tracker
    becomes ``SQLTracker.current`` for the active context and restores the
    previously current tracker on exit.
    """

    # Stack of trackers that were current before each pending __enter__, so
    # nested (or repeated) use of the same tracker restores correctly.
    _old_sql_trackers: List["SQLTracker"]
    # Destination for recorded queries; None disables recording entirely.
    _sql_collector: Optional[SQLCollector]
    # Connection wrapper whose alias, vendor and ops are used when recording;
    # assigned through set_database_wrapper().
    database_wrapper: Optional[BaseDatabaseWrapper]

    def __init__(self, sql_collector: Optional[SQLCollector] = None) -> None:
        self._old_sql_trackers = []
        self._sql_collector = sql_collector
        self.database_wrapper = None

    def __enter__(self) -> "SQLTracker":
        """Make this tracker current, remembering the tracker it replaces."""
        self._old_sql_trackers.append(SQLTracker.current)
        _local.set(self)
        return self

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_value: Optional[BaseException],
        tb: Optional[types.TracebackType],
    ) -> None:
        """Restore the tracker that was current before the matching __enter__."""
        old = self._old_sql_trackers.pop()
        _local.set(old)

    def set_database_wrapper(self, database_wrapper: BaseDatabaseWrapper) -> None:
        """Attach the Django connection wrapper whose queries will be recorded."""
        self.database_wrapper = database_wrapper

    @staticmethod
    def _quote_expr(element: Any) -> str:
        """
        Render ``element`` as an SQL literal for display.

        Strings are single-quoted with embedded single quotes doubled; every
        other value is rendered with ``repr()``.
        """
        if isinstance(element, str):
            return f"""'{element.replace("'", "''")}'"""
        else:
            return repr(element)

    def _quote_params(self, params: ExecuteParametersOrSequence) -> QuoteParamsReturn:
        """
        Quote every parameter value with ``_quote_expr``.

        Mappings keep their keys and get quoted values, any other iterable
        becomes a list of quoted values, and ``None`` is returned unchanged.
        """
        if params is None:
            return params
        # Any Mapping (not just dict) holds named parameters and must be quoted
        # by value; iterating it would quote the keys instead.
        if isinstance(params, Mapping):
            return {key: self._quote_expr(value) for key, value in params.items()}
        return [self._quote_expr(p) for p in params]

    def _get_raw_sql(
        self,
        cursor_self: CursorWrapper,
        sql: str,
        params: ExecuteParametersOrSequence,
        many: bool,
        vendor: str,
    ) -> str:
        """
        Build the SQL text with quoted parameters interpolated, for display.

        For sqlite ``executemany`` calls with positional parameter rows, the
        placeholder group (matched from the first ``(%s`` to the last ``)``) is
        repeated once per row and the rows are flattened in order. The result
        then reads like a single multi-row statement. If no such group exists,
        the SQL is returned without parameters, because interpolating every row
        into a single-row statement cannot work.
        """
        ops = self.database_wrapper.ops  # type: ignore

        # This is a hacky way to get the parameters correct for sqlite in executemany
        if (
            vendor == "sqlite"
            and many
            and isinstance(params, (tuple, list))
            and params  # executemany() may legitimately receive zero rows
            and isinstance(params[0], (tuple, list))
        ):
            match = re.search(r"(\(\s*%s.*\))", sql)
            if match is None:
                return sql
            part_to_replace = match[1]
            sql = sql.replace(
                part_to_replace, ", ".join(part_to_replace for _ in params)
            )
            # Row-major flattening: each repeated group consumes one complete
            # row, in order. Transposing with zip(*params) would pair the
            # wrong values.
            final_params: List[Any] = [
                value for row in params for value in row  # type: ignore
            ]
            return ops.last_executed_query(
                cursor_self,
                sql,
                self._quote_params(final_params),
            )

        return ops.last_executed_query(
            cursor_self,
            sql,
            self._quote_params(params),
        )

    def _decode(self, param: ExecuteParametersOrSequence) -> "DecodeReturn":
        """
        Convert query parameters into values that ``json.dumps`` can serialize.

        Conversion rules:

        - psycopg2 ``Json`` adapters are dumped to their JSON text.
        - Sequences and mappings are decoded element-wise.
        - Dates, times and Decimals become strings.
        - Other values go through ``force_str(strings_only=True)``.
        - Undecodable bytes become ``"(encoded string)"``.
        """
        if PostgresJson is not None and isinstance(param, PostgresJson):
            return param.dumps(param.adapted)

        # If a sequence type, decode each element separately
        if isinstance(param, (tuple, list)):
            return [self._decode(element) for element in param]

        # If a mapping type, decode each value separately
        if isinstance(param, Mapping):
            return {key: self._decode(value) for key, value in param.items()}

        # force_str(strings_only=True) returns "protected" types such as dates,
        # times and Decimal unchanged. None of those are JSON serializable.
        # Force them to str; otherwise json.dumps in record() raises TypeError
        # and the params display falls back to an empty string.
        CONVERT_TYPES = (datetime.datetime, datetime.date, datetime.time, Decimal)
        try:
            return force_str(param, strings_only=not isinstance(param, CONVERT_TYPES))
        except UnicodeDecodeError:
            return "(encoded string)"

    @staticmethod
    def _get_postgres_isolation_level(conn: Any) -> Any:
        """
        Return ``conn.isolation_level``, or ``"unknown"`` if it cannot be read.

        If an erroneous query was run on the connection, it might be in a state
        where checking isolation_level raises the driver's InternalError. A
        connection object without the attribute yields "unknown" as well.
        """
        try:
            try:
                return conn.isolation_level
            except conn.InternalError:
                return "unknown"
        except AttributeError:
            return "unknown"

    def _get_postgres_transaction_id(
        self,
        conn: Any,
        initial_conn_status: int,
        alias: str,
    ) -> Optional[str]:
        """
        Return a synthetic transaction ID for the query, or None.

        PostgreSQL does not expose any sort of transaction ID, so synthetic IDs
        are generated here, based on the connection status before and after
        the query:

        - In a transaction after the query but not before: a new transaction
          definitely started, so a new ID is taken from
          ``SQLCollector.new_transaction_id``.
        - In a transaction both before and after: it is assumed to be the same
          transaction, so ``SQLCollector.current_transaction_id`` is used.
          There is an edge case where Django starts a transaction before the
          first query executes. The collector may then have to generate an ID
          at this point; that behavior lives in the collector.

        Returns None when the connection is not in a transaction afterwards,
        or when no collector is attached.
        """
        final_conn_status = conn.status
        if (
            final_conn_status == STATUS_IN_TRANSACTION
            and self._sql_collector is not None
        ):
            if initial_conn_status == STATUS_IN_TRANSACTION:
                return self._sql_collector.current_transaction_id(alias)
            else:
                return self._sql_collector.new_transaction_id(alias)
        return None

    def record(
        self,
        method: Callable[[CursorWrapper, str, Any], Any],
        cursor_self: CursorWrapper,
        sql: str,
        params: ExecuteParametersOrSequence,
        many: bool = False,
    ) -> Any:  # sourcery skip: remove-unnecessary-cast
        """
        Run ``method(cursor_self, sql, params)`` and record the query.

        Without a collector this is a plain pass-through. Otherwise the call is
        timed (duration in milliseconds), and an ``SQLQueryInfo`` is built and
        handed to the collector. This happens in a ``finally`` block, so failed
        queries are recorded too; PostgreSQL queries also get transaction
        details.

        :param method: the original cursor method being wrapped (presumably
            ``execute``/``executemany``; ``many`` should be True for the latter)
        :param cursor_self: the cursor the method is invoked on
        :param sql: the SQL, possibly a non-str object such as psycopg Composed
        :param params: the query parameters, or a sequence of them when ``many``
        :param many: whether ``params`` holds multiple parameter sets
        :returns: whatever ``method`` returns
        :raises RuntimeError: if a collector is attached but
            ``set_database_wrapper`` was never called
        """
        # If we're not tracking SQL, just call the original method
        if self._sql_collector is None:
            return method(cursor_self, sql, params)

        if self.database_wrapper is None:
            raise RuntimeError("SQLTracker not correctly initialized")

        alias = self.database_wrapper.alias
        vendor = self.database_wrapper.vendor

        if vendor == "postgresql":
            # Parse the driver version ("2.9.9 (...)" / "3.1.18") to pick
            # where connection status is read from.
            db_version_string_match = re.match(
                r"(\d+\.\d+\.\d+)",
                self.database_wrapper.Database.__version__,  # type: ignore
            )
            db_version = (
                tuple(map(int, db_version_string_match.group(1).split(".")))
                if db_version_string_match
                else (0, 0, 0)
            )

            # The underlying DB connection (as opposed to Django's wrapper).
            # psycopg < 3 exposes status on the connection itself; psycopg 3
            # exposes it on the low-level ``pgconn`` object.
            conn = self.database_wrapper.connection
            pgconn = (
                conn
                if db_version < (3, 0, 0)
                else self.database_wrapper.connection.pgconn
            )
            initial_conn_status = pgconn.status

        start_time = time()
        try:
            return method(cursor_self, sql, params)
        finally:
            stop_time = time()
            duration = (stop_time - start_time) * 1000
            # Best effort: parameters that still aren't JSON serializable are
            # shown as an empty string rather than breaking the query.
            _params = ""
            with contextlib.suppress(TypeError):
                _params = json.dumps(self._decode(params))
            # Sql might be an object (such as psycopg Composed).
            # For logging purposes, make sure it's str.
            sql = str(sql)

            sql_query_info = SQLQueryInfo(
                vendor=vendor,
                alias=alias,
                sql=sql,
                duration=duration,
                raw_sql=self._get_raw_sql(cursor_self, sql, params, many, vendor),
                params=_params,
                raw_params=params,
                stacktrace=get_stack_trace(skip=2),
                start_time=start_time,
                stop_time=stop_time,
                is_slow=duration > dr_settings.get_config()["SQL_WARNING_THRESHOLD"],
                is_select=sql.lower().strip().startswith("select"),
            )

            if vendor == "postgresql":
                sql_query_info.trans_id = self._get_postgres_transaction_id(
                    conn=pgconn,
                    initial_conn_status=initial_conn_status,
                    alias=alias,
                )
                # psycopg2 offers get_transaction_status(); otherwise fall back
                # to the transaction_status attribute.
                try:
                    sql_query_info.trans_status = pgconn.get_transaction_status()
                except AttributeError:
                    sql_query_info.trans_status = pgconn.transaction_status
                sql_query_info.iso_level = self._get_postgres_isolation_level(pgconn)

            self._sql_collector.record(sql_query_info)

297 

298 

# Module-level default tracker. It is created without a collector, so its
# record() just forwards queries to the wrapped cursor method until a
# collecting tracker is entered as a context manager.
GLOBAL_SQL_TRACKER = SQLTracker()
_local.set(GLOBAL_SQL_TRACKER)