From 1b76012f82a04b781f6f67ab1aeeeb0b2eeab599 Mon Sep 17 00:00:00 2001 From: Sebastien Corbin Date: Sun, 20 Nov 2022 19:19:26 +0100 Subject: [PATCH] #214 Use Django 2.0 execute_wrapper() --- project/tests/test_db.py | 26 +++---- project/tests/test_execute_sql.py | 120 ------------------------------ silk/apps.py | 8 ++ silk/middleware.py | 17 +---- silk/sql.py | 103 ++++++++++++------------- 5 files changed, 73 insertions(+), 201 deletions(-) delete mode 100644 project/tests/test_execute_sql.py diff --git a/project/tests/test_db.py b/project/tests/test_db.py index 3157725d..cb0eb67d 100644 --- a/project/tests/test_db.py +++ b/project/tests/test_db.py @@ -2,6 +2,7 @@ Test profiling of DB queries without mocking, to catch possible incompatibility """ + from django.shortcuts import reverse from django.test import Client, TestCase @@ -20,24 +21,28 @@ def setUpClass(cls): BlindFactory.create_batch(size=5) SilkyConfig().SILKY_META = False + def setUp(self): + DataCollector().clear() + def test_profile_request_to_db(self): DataCollector().configure(Request(reverse('example_app:index'))) with silk_profile(name='test_profile'): resp = self.client.get(reverse('example_app:index')) - DataCollector().profiles.values() - assert len(resp.context['blinds']) == 5 + self.assertEqual(len(DataCollector().queries), 1, [q['query'] for q in DataCollector().queries.values()]) + self.assertEqual(len(resp.context['blinds']), 5) def test_profile_request_to_db_with_constraints(self): DataCollector().configure(Request(reverse('example_app:create'))) resp = self.client.post(reverse('example_app:create'), {'name': 'Foo'}) + self.assertEqual(len(DataCollector().queries), 2) + self.assertTrue(list(DataCollector().queries.values())[-1]['query'].startswith('INSERT')) self.assertEqual(resp.status_code, 302) class TestAnalyzeQueries(TestCase): - @classmethod def setUpClass(cls): super().setUpClass() @@ -48,7 +53,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): super().tearDownClass() - SilkyConfig().SILKLY_ANALYZE_QUERIES = False + SilkyConfig().SILKY_ANALYZE_QUERIES = False def test_analyze_queries(self): DataCollector().configure(Request(reverse('example_app:index'))) @@ -59,16 +64,3 @@ def test_analyze_queries(self): DataCollector().profiles.values() assert len(resp.context['blinds']) == 5 - - -class TestAnalyzeQueriesExplainParams(TestAnalyzeQueries): - - @classmethod - def setUpClass(cls): - super().setUpClass() - SilkyConfig().SILKY_EXPLAIN_FLAGS = {'verbose': True} - - @classmethod - def tearDownClass(cls): - super().tearDownClass() - SilkyConfig().SILKY_EXPLAIN_FLAGS = None diff --git a/project/tests/test_execute_sql.py b/project/tests/test_execute_sql.py deleted file mode 100644 index 7e9c5b20..00000000 --- a/project/tests/test_execute_sql.py +++ /dev/null @@ -1,120 +0,0 @@ -from unittest.mock import Mock, NonCallableMagicMock, NonCallableMock, patch - -from django.test import TestCase - -from silk.collector import DataCollector -from silk.models import Request, SQLQuery -from silk.sql import execute_sql - -from .util import delete_all_models - - -def mock_sql(): - mock_sql_query = Mock(spec_set=['_execute_sql', 'query', 'as_sql', 'connection']) - mock_sql_query._execute_sql = Mock() - mock_sql_query.query = NonCallableMock(spec_set=['model']) - mock_sql_query.query.model = Mock() - query_string = 'SELECT * from table_name' - mock_sql_query.as_sql = Mock(return_value=(query_string, ())) - - mock_sql_query.connection = NonCallableMock( - spec_set=['cursor', 'features', 'ops'], - cursor=Mock( - spec_set=['__call__'], - return_value=NonCallableMagicMock(spec_set=['__enter__', '__exit__', 'execute']) - ), - features=NonCallableMock( - spec_set=['supports_explaining_query_execution'], - supports_explaining_query_execution=True - ), - ops=NonCallableMock(spec_set=['explain_query_prefix']), - ) - - return mock_sql_query, query_string - - -def call_execute_sql(cls, request): - DataCollector().configure(request=request) - delete_all_models(SQLQuery) - cls.mock_sql, cls.query_string = mock_sql() - kwargs = { - 'one': 1, - 'two': 2 - } - cls.args = [1, 2] - cls.kwargs = kwargs - execute_sql(cls.mock_sql, *cls.args, **cls.kwargs) - - -class TestCallNoRequest(TestCase): - @classmethod - def setUpClass(cls): - super().setUpClass() - call_execute_sql(cls, None) - - def test_called(self): - self.mock_sql._execute_sql.assert_called_once_with(*self.args, **self.kwargs) - - def test_count(self): - self.assertEqual(0, len(DataCollector().queries)) - - -class TestCallRequest(TestCase): - @classmethod - def setUpClass(cls): - super().setUpClass() - call_execute_sql(cls, Request()) - - def test_called(self): - self.mock_sql._execute_sql.assert_called_once_with(*self.args, **self.kwargs) - - def test_count(self): - self.assertEqual(1, len(DataCollector().queries)) - - def test_query(self): - query = list(DataCollector().queries.values())[0] - self.assertEqual(query['query'], self.query_string) - - -class TestCallSilky(TestCase): - def test_no_effect(self): - DataCollector().configure() - sql, _ = mock_sql() - sql.query.model = NonCallableMagicMock(spec_set=['__module__']) - sql.query.model.__module__ = 'silk.models' - # No SQLQuery models should be created for silk requests for obvious reasons - with patch('silk.sql.DataCollector', return_value=Mock()) as mock_DataCollector: - execute_sql(sql) - self.assertFalse(mock_DataCollector().register_query.call_count) - - -class TestCollectorInteraction(TestCase): - def _query(self): - try: - query = list(DataCollector().queries.values())[0] - except IndexError: - self.fail('No queries created') - return query - - def test_request(self): - DataCollector().configure(request=Request.objects.create(path='/path/to/somewhere')) - sql, _ = mock_sql() - execute_sql(sql) - query = self._query() - self.assertEqual(query['request'], DataCollector().request) - - def test_registration(self): - DataCollector().configure(request=Request.objects.create(path='/path/to/somewhere')) - sql, _ = mock_sql() - execute_sql(sql) - query = self._query() - self.assertIn(query, DataCollector().queries.values()) - - def test_explain(self): - DataCollector().configure(request=Request.objects.create(path='/path/to/somewhere')) - sql, qs = mock_sql() - prefix = "EXPLAIN" - mock_cursor = sql.connection.cursor.return_value.__enter__.return_value - sql.connection.ops.explain_query_prefix.return_value = prefix - execute_sql(sql) - mock_cursor.execute.assert_called_once_with(f"{prefix} {qs}", ()) diff --git a/silk/apps.py b/silk/apps.py index 57828bea..023621f8 100644 --- a/silk/apps.py +++ b/silk/apps.py @@ -1,6 +1,14 @@ from django.apps import AppConfig +from django.db import connection + +from silk.sql import SilkQueryWrapper class SilkAppConfig(AppConfig): default_auto_field = "django.db.models.AutoField" name = "silk" + + def ready(self): + # Add wrapper to db connection + if not any(isinstance(wrapper, SilkQueryWrapper) for wrapper in connection.execute_wrappers): + connection.execute_wrappers.append(SilkQueryWrapper()) diff --git a/silk/middleware.py b/silk/middleware.py index 2bbc1049..376ce1c1 100644 --- a/silk/middleware.py +++ b/silk/middleware.py @@ -2,7 +2,6 @@ import random from django.db import DatabaseError, transaction -from django.db.models.sql.compiler import SQLCompiler from django.urls import NoReverseMatch, reverse from django.utils import timezone @@ -11,7 +10,6 @@ from silk.model_factory import RequestModelFactory, ResponseModelFactory from silk.profiling import dynamic from silk.profiling.profiler import silk_meta_profiler -from silk.sql import execute_sql Logger = logging.getLogger('silk.middleware') @@ -85,15 +83,11 @@ def _apply_dynamic_mappings(self): name = conf.get('name') if module and function: if start_line and end_line: # Dynamic context manager - dynamic.inject_context_manager_func(module=module, - func=function, - start_line=start_line, - end_line=end_line, - name=name) + dynamic.inject_context_manager_func( + module=module, func=function, start_line=start_line, end_line=end_line, name=name + ) else: # Dynamic decorator - dynamic.profile_function_or_method(module=module, - func=function, - name=name) + dynamic.profile_function_or_method(module=module, func=function, name=name) else: raise KeyError('Invalid dynamic mapping %s' % conf) @@ -107,9 +101,6 @@ def process_request(self, request): Logger.debug('process_request') request.silk_is_intercepted = True self._apply_dynamic_mappings() - if not hasattr(SQLCompiler, '_execute_sql'): - SQLCompiler._execute_sql = SQLCompiler.execute_sql - SQLCompiler.execute_sql = execute_sql silky_config = SilkyConfig() diff --git a/silk/sql.py b/silk/sql.py index ff3fbe4a..881c1d8a 100644 --- a/silk/sql.py +++ b/silk/sql.py @@ -1,26 +1,16 @@ import logging import traceback -from django.core.exceptions import EmptyResultSet +from django.apps import apps +from django.db import connection from django.utils import timezone from django.utils.encoding import force_str -from silk.collector import DataCollector from silk.config import SilkyConfig Logger = logging.getLogger('silk.sql') -def _should_wrap(sql_query): - if not DataCollector().request: - return False - - for ignore_str in SilkyConfig().SILKY_IGNORE_QUERIES: - if ignore_str in sql_query: - return False - return True - - def _unpack_explanation(result): for row in result: if not isinstance(row, str): @@ -34,16 +24,14 @@ def _explain_query(connection, q, params): if SilkyConfig().SILKY_ANALYZE_QUERIES: # Work around some DB engines not supporting analyze option try: - prefix = connection.ops.explain_query_prefix( - analyze=True, **(SilkyConfig().SILKY_EXPLAIN_FLAGS or {}) - ) + prefix = connection.ops.explain_query_prefix(analyze=True, **(SilkyConfig().SILKY_EXPLAIN_FLAGS or {})) except ValueError as error: error_str = str(error) if error_str.startswith("Unknown options:"): Logger.warning( - "Database does not support analyzing queries with provided params. %s." + "Database does not support analyzing queries with provided params. %s. " "SILKY_ANALYZE_QUERIES option will be ignored", - error_str + error_str, ) prefix = connection.ops.explain_query_prefix() else: @@ -61,40 +49,53 @@ def _explain_query(connection, q, params): return None -def execute_sql(self, *args, **kwargs): - """wrapper around real execute_sql in order to extract information""" +class SilkQueryWrapper: + def __init__(self): + # Local import to prevent messing app.ready() + from silk.collector import DataCollector - try: - q, params = self.as_sql() - if not q: - raise EmptyResultSet - except EmptyResultSet: - try: - result_type = args[0] - except IndexError: - result_type = kwargs.get('result_type', 'multi') - if result_type == 'multi': - return iter([]) - else: - return - tb = ''.join(reversed(traceback.format_stack())) - sql_query = q % tuple(force_str(param) for param in params) - if _should_wrap(sql_query): - query_dict = { - 'query': sql_query, - 'start_time': timezone.now(), - 'traceback': tb - } + self.data_collector = DataCollector() + self.silk_model_table_names = [model._meta.db_table for model in apps.get_app_config('silk').get_models()] + + def __call__(self, execute, sql, params, many, context): + sql_query = sql % tuple(force_str(param) for param in params) if params else sql + query_dict = None + if self._should_wrap(sql_query): + tb = ''.join(reversed(traceback.format_stack())) + query_dict = {'query': sql_query, 'start_time': timezone.now(), 'traceback': tb} try: - return self._execute_sql(*args, **kwargs) + return execute(sql, params, many, context) finally: - query_dict['end_time'] = timezone.now() - request = DataCollector().request - if request: - query_dict['request'] = request - if getattr(self.query.model, '__module__', '') != 'silk.models': - query_dict['analysis'] = _explain_query(self.connection, q, params) - DataCollector().register_query(query_dict) - else: - DataCollector().register_silk_query(query_dict) - return self._execute_sql(*args, **kwargs) + if query_dict: + query_dict['end_time'] = timezone.now() + request = self.data_collector.request + if request: + query_dict['request'] = request + if not any(table_name in sql_query for table_name in self.silk_model_table_names): + query_dict['analysis'] = _explain_query(connection, sql, params) + self.data_collector.register_query(query_dict) + else: + self.data_collector.register_silk_query(query_dict) + + def _should_wrap(self, sql_query): + # Must have a request ongoing + if not self.data_collector.request: + return False + + # Must not try to explain 'EXPLAIN' queries or transaction stuff + if any( + sql_query.startswith(keyword) + for keyword in [ + 'SAVEPOINT', + 'RELEASE SAVEPOINT', + 'ROLLBACK TO SAVEPOINT', + 'PRAGMA', + connection.ops.explain_query_prefix(), + ] + ): + return False + + for ignore_str in SilkyConfig().SILKY_IGNORE_QUERIES: + if ignore_str in sql_query: + return False + return True