diff --git a/contrib/opencensus-ext-django/opencensus/ext/django/middleware.py b/contrib/opencensus-ext-django/opencensus/ext/django/middleware.py index 68e1e0b39..6ba6a80a9 100644 --- a/contrib/opencensus-ext-django/opencensus/ext/django/middleware.py +++ b/contrib/opencensus-ext-django/opencensus/ext/django/middleware.py @@ -13,11 +13,14 @@ # limitations under the License. """Django middleware helper to capture and trace a request.""" +import django import logging import six import django.conf +from django.db import connection from django.utils.deprecation import MiddlewareMixin +from google.rpc import code_pb2 from opencensus.common import configuration from opencensus.trace import attributes_helper @@ -25,6 +28,7 @@ from opencensus.trace import print_exporter from opencensus.trace import samplers from opencensus.trace import span as span_module +from opencensus.trace import status as status_module from opencensus.trace import tracer as tracer_module from opencensus.trace import utils from opencensus.trace.propagation import trace_context_http_header_format @@ -99,6 +103,37 @@ def _set_django_attributes(span, request): span.add_attribute('django.user.name', str(user_name)) +def _trace_db_call(execute, sql, params, many, context): + tracer = _get_current_tracer() + if not tracer: + return execute(sql, params, many, context) + + vendor = context['connection'].vendor + alias = context['connection'].alias + + span = tracer.start_span() + span.name = '{}.query'.format(vendor) + span.span_kind = span_module.SpanKind.CLIENT + + tracer.add_attribute_to_current_span('component', vendor) + tracer.add_attribute_to_current_span('db.instance', alias) + tracer.add_attribute_to_current_span('db.statement', sql) + tracer.add_attribute_to_current_span('db.type', 'sql') + + try: + result = execute(sql, params, many, context) + except Exception: # pragma: NO COVER + status = status_module.Status( + code=code_pb2.UNKNOWN, message='DB error' + ) + span.set_status(status) + raise + else: + return result + finally: + tracer.end_span() + + class OpencensusMiddleware(MiddlewareMixin): """Saves the request in thread local""" @@ -126,6 +161,9 @@ def __init__(self, get_response=None): self.blacklist_hostnames = settings.get(BLACKLIST_HOSTNAMES, None) + if django.VERSION >= (2,): # pragma: NO COVER + connection.execute_wrappers.append(_trace_db_call) + def process_request(self, request): """Called on each request, before Django decides which view to execute. diff --git a/contrib/opencensus-ext-django/tests/test_django_db_middleware.py b/contrib/opencensus-ext-django/tests/test_django_db_middleware.py new file mode 100644 index 000000000..18bf385c9 --- /dev/null +++ b/contrib/opencensus-ext-django/tests/test_django_db_middleware.py @@ -0,0 +1,89 @@ +# Copyright 2017, OpenCensus Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from collections import namedtuple + +import django +import mock +import pytest +from django.test.utils import teardown_test_environment + +from opencensus.trace import execution_context + + +class TestOpencensusDatabaseMiddleware(unittest.TestCase): + def setUp(self): + from django.conf import settings as django_settings + from django.test.utils import setup_test_environment + + if not django_settings.configured: + django_settings.configure() + setup_test_environment() + + def tearDown(self): + execution_context.clear() + teardown_test_environment() + + def test_process_request(self): + if django.VERSION < (2, 0): + pytest.skip("Wrong version of Django") + + from opencensus.ext.django import middleware + + sql = "SELECT * FROM users" + + MockConnection = namedtuple('Connection', ('vendor', 'alias')) + connection = MockConnection('mysql', 'default') + + mock_execute = mock.Mock() + mock_execute.return_value = "Mock result" + + middleware.OpencensusMiddleware() + + patch_no_tracer = mock.patch( + 'opencensus.ext.django.middleware._get_current_tracer', + return_value=None) + with patch_no_tracer: + result = middleware._trace_db_call( + mock_execute, sql, params=[], many=False, + context={'connection': connection}) + self.assertEqual(result, "Mock result") + + mock_tracer = mock.Mock() + mock_tracer.return_value = mock_tracer + patch = mock.patch( + 'opencensus.ext.django.middleware._get_current_tracer', + return_value=mock_tracer) + with patch: + result = middleware._trace_db_call( + mock_execute, sql, params=[], many=False, + context={'connection': connection}) + + (mock_sql, mock_params, mock_many, + mock_context) = mock_execute.call_args[0] + + self.assertEqual(mock_sql, sql) + self.assertEqual(mock_params, []) + self.assertEqual(mock_many, False) + self.assertEqual(mock_context, {'connection': connection}) + self.assertEqual(result, "Mock result") + + result = middleware._trace_db_call( + mock_execute, sql, params=[], many=True, + context={'connection': connection}) + + (mock_sql, mock_params, mock_many, + mock_context) = mock_execute.call_args[0] + self.assertEqual(mock_many, True)