Coverage for requests_tracker/sql/sql_tracker.py: 95%
131 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-18 22:19 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-18 22:19 +0000
1import contextlib
2import datetime
3import json
4import re
5import types
6from contextvars import ContextVar
7from decimal import Decimal
8from time import time
9from typing import (
10 TYPE_CHECKING,
11 Any,
12 Callable,
13 Dict,
14 List,
15 Mapping,
16 Optional,
17 Sequence,
18 Tuple,
19 Union,
20)
21from uuid import UUID
23from django.db.backends.base.base import BaseDatabaseWrapper
24from django.db.backends.utils import CursorWrapper
25from django.utils.encoding import force_str
27from requests_tracker import settings as dr_settings
28from requests_tracker.sql.dataclasses import SQLQueryInfo
29from requests_tracker.sql.sql_collector import SQLCollector
30from requests_tracker.stack_trace import get_stack_trace
32try:
33 from psycopg2._json import Json as PostgresJson
34 from psycopg2.extensions import STATUS_IN_TRANSACTION
35except ImportError:
36 PostgresJson = None # type: ignore
37 STATUS_IN_TRANSACTION = None # type: ignore
if TYPE_CHECKING:
    # Recursive result type of SQLTracker._decode: a str, or lists / dicts whose
    # leaves are further decoded values. Only needed by static type checkers.
    DecodeReturn = Union[
        List[Union["DecodeReturn", str]], Dict[str, Union["DecodeReturn", str]], str
    ]

# Python values accepted as a single query parameter. ``datetime.time`` is listed
# because SQLTracker._decode explicitly converts time values alongside dates and
# datetimes.
SQLType = Union[
    None,
    bool,
    int,
    float,
    Decimal,
    str,
    bytes,
    datetime.date,
    datetime.datetime,
    datetime.time,
    UUID,
    Tuple[Any, ...],
    List[Any],
]

# Parameters for one ``execute`` call: positional, named, or absent.
ExecuteParameters = Optional[Union[Sequence[SQLType], Mapping[str, SQLType]]]
# Either one parameter set (``execute``) or a sequence of them (``executemany``).
ExecuteParametersOrSequence = Union[ExecuteParameters, Sequence[ExecuteParameters]]

# Result of SQLTracker._quote_params: quoted values keyed like an input dict, a
# list of quoted values, or None when the parameters were None.
QuoteParamsReturn = Optional[Union[Dict[str, str], List[str]]]

# Holds the SQLTracker that is current for the running context; read and
# lazily populated by SQLTrackerMeta.current.
_local: ContextVar["SQLTracker"] = ContextVar("current_sql_tracker")
class SQLTrackerMeta(type):
    """Metaclass that gives ``SQLTracker`` a class-level ``current`` property."""

    @property
    def current(cls) -> "SQLTracker":
        """Return the tracker bound to the active context.

        If no tracker is bound yet, a fresh ``SQLTracker()`` is created, bound
        to the context, and returned.
        """
        existing = _local.get(None)
        if existing is not None:
            return existing

        created = SQLTracker()
        _local.set(created)
        return created
class SQLTracker(metaclass=SQLTrackerMeta):
    """Time wrapped cursor calls and report each query to an SQL collector.

    ``record`` is the entry point. It runs the wrapped cursor method and, when
    an ``SQLCollector`` is attached, builds an ``SQLQueryInfo`` for the query
    and passes it to the collector. A tracker without a collector is a pure
    pass-through.

    Instances are context managers. Entering one makes it ``SQLTracker.current``
    for the active context; exiting restores the tracker that was current before.
    """

    _old_sql_trackers: List["SQLTracker"]
    _sql_collector: Optional[SQLCollector]
    database_wrapper: Optional[BaseDatabaseWrapper]

    def __init__(self, sql_collector: Optional[SQLCollector] = None) -> None:
        # Stack of previously current trackers, one entry per nested __enter__.
        self._old_sql_trackers = []
        self._sql_collector = sql_collector
        # Must be set via set_database_wrapper() before record() can collect.
        self.database_wrapper = None

    def __enter__(self) -> "SQLTracker":
        """Make this tracker current, remembering the one it replaces."""
        self._old_sql_trackers.append(SQLTracker.current)
        _local.set(self)
        return self

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_value: Optional[BaseException],
        tb: Optional[types.TracebackType],
    ) -> None:
        """Restore the tracker that was current before the matching __enter__.

        Returns None, so exceptions raised in the ``with`` body propagate.
        """
        old = self._old_sql_trackers.pop()
        _local.set(old)

    def set_database_wrapper(self, database_wrapper: BaseDatabaseWrapper) -> None:
        """Attach the Django database wrapper that record() reads alias,
        vendor, connection and query-rendering operations from."""
        self.database_wrapper = database_wrapper

    @staticmethod
    def _quote_expr(element: Any) -> str:
        """Render one parameter as an SQL-ish literal for display.

        Strings are wrapped in single quotes with embedded quotes doubled;
        every other value is rendered with ``repr``.
        """
        if isinstance(element, str):
            return f"""'{element.replace("'", "''")}'"""
        else:
            return repr(element)

    def _quote_params(self, params: ExecuteParametersOrSequence) -> QuoteParamsReturn:
        """Quote every parameter with ``_quote_expr``.

        Returns None for None. A dict returns a dict with the same keys and
        quoted values. Any other iterable returns a list of quoted items.
        """
        if params is None:
            return params
        if isinstance(params, dict):
            return {key: self._quote_expr(value) for key, value in params.items()}
        return [self._quote_expr(p) for p in params]

    def _get_raw_sql(
        self,
        cursor_self: CursorWrapper,
        sql: str,
        params: ExecuteParametersOrSequence,
        many: bool,
        vendor: str,
    ) -> str:
        """Render ``sql`` with its parameters interpolated, for display only.

        For SQLite ``executemany`` with a non-empty list/tuple of row tuples,
        the parenthesised placeholder group (e.g. ``(%s, %s)``) is repeated
        once per row. The rows are then flattened in order to fill the groups.
        If there are no rows, or the SQL has no such group, the generic
        rendering below is used instead.
        """
        # This is a hacky way to get the parameters correct for sqlite in
        # executemany.
        if (
            vendor == "sqlite"
            and many
            and isinstance(params, (tuple, list))
            and params  # an empty batch has no first row to inspect
            and isinstance(params[0], (tuple, list))
        ):
            group_match = re.search(r"(\(\s*%s.*\))", sql)
            if group_match is not None:
                part_to_replace = group_match[1]
                sql = sql.replace(
                    part_to_replace, ", ".join(part_to_replace for _ in params)
                )
                # Flatten row by row so the values line up with the repeated
                # groups: [(a1, b1), (a2, b2)] -> [a1, b1, a2, b2].
                final_params: List[Any] = [value for row in params for value in row]

                return self.database_wrapper.ops.last_executed_query(  # type: ignore
                    cursor_self,
                    sql,
                    self._quote_params(final_params),
                )

        return self.database_wrapper.ops.last_executed_query(  # type: ignore
            cursor_self,
            sql,
            self._quote_params(params),
        )

    def _decode(self, param: ExecuteParametersOrSequence) -> "DecodeReturn":
        """Convert query parameters into values suitable for ``json.dumps``.

        - psycopg2 ``Json`` wrappers become their dumped string.
        - Lists/tuples and dicts are decoded element-wise.
        - datetime/date/time values are forced to str.
        - Everything else goes through ``force_str(strings_only=True)``.
        - Bytes that cannot be decoded become ``"(encoded string)"``.
        """
        if PostgresJson is not None and isinstance(param, PostgresJson):
            return param.dumps(param.adapted)

        # If a sequence type, decode each element separately
        if isinstance(param, (tuple, list)):
            return [self._decode(element) for element in param]

        # If a dictionary type, decode each value separately
        if isinstance(param, dict):
            return {key: self._decode(value) for key, value in param.items()}

        # make sure datetime, date and time are converted to string by force_str
        CONVERT_TYPES = (datetime.datetime, datetime.date, datetime.time)
        try:
            return force_str(param, strings_only=not isinstance(param, CONVERT_TYPES))
        except UnicodeDecodeError:
            return "(encoded string)"

    @staticmethod
    def _get_postgres_isolation_level(conn: Any) -> Any:
        """Return ``conn.isolation_level``, or ``"unknown"`` if it is unavailable.

        If an erroneous query was run on the connection, it might be in a state
        where reading isolation_level raises the driver's InternalError. A
        connection object without these attributes also yields ``"unknown"``.
        """
        try:
            try:
                return conn.isolation_level
            except conn.InternalError:
                return "unknown"
        except AttributeError:
            return "unknown"

    def _get_postgres_transaction_id(
        self,
        conn: Any,
        initial_conn_status: int,
        alias: str,
    ) -> Optional[str]:
        """Return a synthetic transaction ID for the query just executed.

        PostgreSQL does not expose a transaction ID, so one is synthesised from
        the connection status before and after the query:

        - If the connection is in a transaction after the query but was not
          before it, a new transaction started. A new ID is taken from the
          collector's ``new_transaction_id(alias)``.
        - If it was in a transaction both before and after, it is assumed to be
          the same transaction. The collector's ``current_transaction_id(alias)``
          is used.
        - Otherwise, or without a collector, None is returned.

        Django can start a transaction before the first query executes. In that
        case ``current_transaction_id`` is presumably expected to create an ID
        when none exists yet (NOTE(review): confirm in SQLCollector).
        """
        final_conn_status = conn.status
        if (
            final_conn_status == STATUS_IN_TRANSACTION
            and self._sql_collector is not None
        ):
            if initial_conn_status == STATUS_IN_TRANSACTION:
                return self._sql_collector.current_transaction_id(alias)
            else:
                return self._sql_collector.new_transaction_id(alias)
        return None

    def record(
        self,
        method: Callable[[CursorWrapper, str, Any], Any],
        cursor_self: CursorWrapper,
        sql: str,
        params: ExecuteParametersOrSequence,
        many: bool = False,
    ) -> Any:  # sourcery skip: remove-unnecessary-cast
        """Execute ``method(cursor_self, sql, params)`` and record the query.

        Without a collector this is a plain call. Otherwise the query is timed,
        and an ``SQLQueryInfo`` is sent to the collector even when ``method``
        raises. PostgreSQL queries also get transaction and isolation details.

        :returns: whatever ``method`` returns.
        :raises RuntimeError: if a collector is set but no database wrapper is.
        """
        # If we're not tracking SQL, just call the original method
        if self._sql_collector is None:
            return method(cursor_self, sql, params)

        if self.database_wrapper is None:
            raise RuntimeError("SQLTracker not correctly initialized")

        alias = self.database_wrapper.alias
        vendor = self.database_wrapper.vendor

        if vendor == "postgresql":
            # Driver version decides where the status attributes live.
            db_version_string_match = re.match(
                r"(\d+\.\d+\.\d+)",
                self.database_wrapper.Database.__version__,  # type: ignore
            )
            db_version = (
                tuple(map(int, db_version_string_match.group(1).split(".")))
                if db_version_string_match
                else (0, 0, 0)
            )

            # The underlying DB connection (as opposed to Django's wrapper).
            conn = self.database_wrapper.connection
            # Below 3.0 the connection itself is used; from 3.0 its ``pgconn``.
            pgconn = conn if db_version < (3, 0, 0) else conn.pgconn
            initial_conn_status = pgconn.status

        start_time = time()
        try:
            return method(cursor_self, sql, params)
        finally:
            stop_time = time()
            duration = (stop_time - start_time) * 1000
            _params = ""
            # Parameters that are not JSON-serialisable are left as "".
            with contextlib.suppress(TypeError):
                _params = json.dumps(self._decode(params))
            # Sql might be an object (such as psycopg Composed).
            # For logging purposes, make sure it's str.
            sql = str(sql)

            sql_query_info = SQLQueryInfo(
                vendor=vendor,
                alias=alias,
                sql=sql,
                duration=duration,
                raw_sql=self._get_raw_sql(cursor_self, sql, params, many, vendor),
                params=_params,
                raw_params=params,
                stacktrace=get_stack_trace(skip=2),
                start_time=start_time,
                stop_time=stop_time,
                is_slow=duration > dr_settings.get_config()["SQL_WARNING_THRESHOLD"],
                is_select=sql.lower().strip().startswith("select"),
            )

            if vendor == "postgresql":
                sql_query_info.trans_id = self._get_postgres_transaction_id(
                    conn=pgconn,
                    initial_conn_status=initial_conn_status,
                    alias=alias,
                )
                # Method form first; fall back to the attribute form.
                try:
                    sql_query_info.trans_status = pgconn.get_transaction_status()
                except AttributeError:
                    sql_query_info.trans_status = pgconn.transaction_status
                sql_query_info.iso_level = self._get_postgres_isolation_level(pgconn)

            self._sql_collector.record(sql_query_info)
# Default tracker, bound as current in the context that imports this module.
# It has no SQLCollector, so its record() simply calls the wrapped cursor method.
GLOBAL_SQL_TRACKER = SQLTracker()
_local.set(GLOBAL_SQL_TRACKER)