Coverage for requests_tracker/sql/sql_hook.py: 76%

34 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-18 22:19 +0000

1from typing import Any, Sequence 

2 

3from django.db.backends.utils import CursorWrapper 

4 

5from requests_tracker.sql.sql_tracker import ExecuteParameters, SQLTracker 

6 

7 

def set_database_wrapper_if_missing(
    sql_tracker: SQLTracker,
    cursor_wrapper: CursorWrapper,
) -> None:
    """
    Give ``sql_tracker`` a database wrapper if it does not have one yet.

    ``install_sql_hook`` may be called after the database connection has
    already been established. In that case the patched ``connect`` never
    ran, so the tracker has no database wrapper. This helper fills the gap
    from ``cursor_wrapper.db``. A wrapper that is already set is never
    replaced.
    """
    already_set = sql_tracker.database_wrapper is not None
    if already_set:
        return
    sql_tracker.set_database_wrapper(cursor_wrapper.db)

20 

21 

def install_sql_hook() -> None:
    """
    Monkey-patch Django's database layer so SQL activity goes through
    ``SQLTracker.current``.

    Replaces ``CursorWrapper.execute``, ``CursorWrapper.executemany`` and
    ``CursorWrapper.callproc`` with wrappers. Each wrapper first makes sure
    the tracker has a database wrapper, then delegates to
    ``SQLTracker.record``, passing along the original (unpatched) method.
    Also replaces ``BaseDatabaseWrapper.connect`` so the tracker learns the
    database wrapper as soon as a connection is opened.

    Only the first call installs the patches. Later calls return
    immediately, so stacked wrappers never record the same query twice.
    """
    from django.db.backends.base.base import BaseDatabaseWrapper
    from django.db.backends.utils import CursorWrapper

    # Attribute placed on our patched ``execute``; its presence means the
    # hook is already installed and must not be layered a second time.
    hook_marker = "_requests_tracker_sql_hook"
    if getattr(CursorWrapper.execute, hook_marker, False):
        return

    real_execute = CursorWrapper.execute
    real_executemany = CursorWrapper.executemany
    real_call_proc = CursorWrapper.callproc
    real_connect = BaseDatabaseWrapper.connect

    def execute(self: CursorWrapper, sql: str, params: ExecuteParameters = None) -> Any:
        sql_tracker = SQLTracker.current
        set_database_wrapper_if_missing(sql_tracker, self)
        return sql_tracker.record(
            method=real_execute,
            cursor_self=self,
            sql=sql,
            params=params,
        )

    def executemany(
        self: CursorWrapper,
        sql: str,
        param_list: Sequence[ExecuteParameters],
    ) -> Any:
        sql_tracker = SQLTracker.current
        set_database_wrapper_if_missing(sql_tracker, self)
        return sql_tracker.record(
            method=real_executemany,
            cursor_self=self,
            sql=sql,
            params=param_list,
            many=True,
        )

    def callproc(
        self: CursorWrapper,
        procname: str,
        params: ExecuteParameters = None,
        kparams: Any = None,
    ) -> Any:
        # Django's CursorWrapper.callproc also accepts ``kparams`` (keyword
        # parameters, honoured only by some backends). Dropping it from the
        # signature would turn such calls into a TypeError, so accept it
        # and forward it to the real method when it is given.
        sql_tracker = SQLTracker.current
        set_database_wrapper_if_missing(sql_tracker, self)
        if kparams is None:
            method = real_call_proc
        else:
            # NOTE(review): assumes SQLTracker.record invokes ``method`` with
            # the same arguments it would pass to the real callproc
            # (cursor, procname, params); confirm against sql_tracker.
            def call_proc_with_kparams(*args: Any, **kwargs: Any) -> Any:
                return real_call_proc(*args, kparams=kparams, **kwargs)

            method = call_proc_with_kparams
        return sql_tracker.record(
            method=method,
            cursor_self=self,
            sql=procname,
            params=params,
        )

    def connect(self: BaseDatabaseWrapper) -> Any:
        ret = real_connect(self)
        sql_tracker = SQLTracker.current
        sql_tracker.set_database_wrapper(self)

        return ret

    # Mark the patched execute so a repeated install_sql_hook() is a no-op.
    setattr(execute, hook_marker, True)

    CursorWrapper.execute = execute  # type: ignore
    CursorWrapper.executemany = executemany  # type: ignore
    CursorWrapper.callproc = callproc  # type: ignore
    BaseDatabaseWrapper.connect = connect  # type: ignore