diff --git a/datadog/dogstatsd/base.py b/datadog/dogstatsd/base.py index f0091c19c..ffe3b678d 100644 --- a/datadog/dogstatsd/base.py +++ b/datadog/dogstatsd/base.py @@ -13,7 +13,7 @@ import socket import errno import time -from threading import Lock +from threading import Lock, RLock # Datadog libraries from datadog.dogstatsd.context import ( @@ -183,7 +183,7 @@ def __init__( :type telemetry_socket_path: string """ - self.lock = Lock() + self._socket_lock = Lock() # Check for deprecated option if max_buffer_size is not None: @@ -273,6 +273,8 @@ def __init__( self._telemetry_flush_interval = telemetry_min_flush_interval self._telemetry = not disable_telemetry + self._buffer_lock = RLock() + def disable_telemetry(self): self._telemetry = False @@ -311,7 +313,7 @@ def get_socket(self, telemetry=False): Note: connect the socket before assigning it to the class instance to avoid bad thread race conditions. """ - with self.lock: + with self._socket_lock: if telemetry and self._dedicated_telemetry_destination(): if not self.telemetry_socket: if self.telemetry_socket_path is not None: @@ -352,7 +354,12 @@ def open_buffer(self, max_buffer_size=None): >>> with DogStatsd() as batch: >>> batch.gauge("users.online", 123) >>> batch.gauge("active.connections", 1001) + + Note: This method must be called before close_buffer() matching invocation. """ + + self._buffer_lock.acquire() + if max_buffer_size is not None: log.warning("The parameter max_buffer_size is now deprecated and is not used anymore") self._current_buffer_total_size = 0 @@ -362,12 +369,22 @@ def open_buffer(self, max_buffer_size=None): def close_buffer(self): """ Flush the buffer and switch back to single metric packets. + + Note: This method must be called after a matching open_buffer() + invocation. """ - self._send = self._send_to_server - if self.buffer: - # Only send packets if there are packets to send - self._flush_buffer() + if not hasattr(self, 'buffer'): + raise BufferError('Cannot close buffer that was never opened') + + try: + self._send = self._send_to_server + + if self.buffer: + # Only send packets if there are packets to send + self._flush_buffer() + finally: + self._buffer_lock.release() def gauge( self, @@ -506,7 +523,7 @@ def close_socket(self): """ Closes connected socket if connected. """ - with self.lock: + with self._socket_lock: if self.socket: try: self.socket.close() diff --git a/tests/unit/dogstatsd/test_statsd.py b/tests/unit/dogstatsd/test_statsd.py index de0fdbbd7..914185fea 100644 --- a/tests/unit/dogstatsd/test_statsd.py +++ b/tests/unit/dogstatsd/test_statsd.py @@ -9,6 +9,7 @@ """ # Standard libraries from collections import deque +from threading import Thread import os import socket import errno @@ -666,13 +667,146 @@ def test_timed_start_stop_calls(self): self.assertEqual('timed_context.test', name) self.assert_almost_equal(500, float(value), 100) - def test_batched(self): + def test_batching(self): self.statsd.open_buffer() self.statsd.gauge('page.views', 123) self.statsd.timing('timer', 123) self.statsd.close_buffer() expected = "page.views:123|g\ntimer:123|ms" - self.assert_equal_telemetry(expected, self.recv(2), telemetry=telemetry_metrics(metrics=2, bytes_sent=len(expected))) + self.assert_equal_telemetry( + expected, + self.recv(2), + telemetry=telemetry_metrics(metrics=2, bytes_sent=len(expected)) + ) + + def test_batching_sequential(self): + self.statsd.open_buffer() + self.statsd.gauge('discarded.data', 123) + self.statsd.close_buffer() + + self.statsd.open_buffer() + self.statsd.gauge('page.views', 123) + self.statsd.timing('timer', 123) + self.statsd.close_buffer() + + expected1 = 'discarded.data:123|g' + expected_metrics1=telemetry_metrics(metrics=1, bytes_sent=len(expected1)) + self.assert_equal_telemetry( + expected1, + self.recv(2), + telemetry=expected_metrics1) + + + expected2 = "page.views:123|g\ntimer:123|ms" + self.assert_equal_telemetry( + expected2, + self.recv(2), + telemetry=telemetry_metrics( + metrics=2, + packets_sent=2, + bytes_sent=len(expected2 + expected_metrics1) + ) + ) + + def test_threaded_batching(self): + num_threads = 4 + threads = [] + + def batch_metrics(index, dsd): + time.sleep(0.3 * index) + + dsd.open_buffer() + + time.sleep(0.1) + dsd.gauge('page.%d.views' % index, 123) + + time.sleep(0.1) + dsd.timing('timer.%d' % index, 123) + + time.sleep(0.5) + dsd.close_buffer() + + for idx in range(num_threads): + threads.append(Thread(target=batch_metrics, args=(idx, self.statsd))) + + for thread in threads: + thread.start() + + for thread in threads: + if thread.is_alive(): + thread.join() + + # This is a bit of a tricky thing to test for - initially only our data packet is + # sent but then telemetry is flushed/reset and the subsequent metric xmit includes + # the telemetry data for the previous packet. The reason for 726 -> 727 increase is + # because packet #2 sends a three digit byte count ("726") that then increases the + # next metric size by 1 byte. + expected_xfer_metrics = [ + (33, 1), + (726, 2), + (727, 2), + (727, 2), + ] + + for idx in range(num_threads): + expected_message = "page.%d.views:123|g\ntimer.%d:123|ms" % (idx, idx) + bytes_sent, packets_sent = expected_xfer_metrics[idx] + + self.assert_equal_telemetry( + expected_message, + self.recv(2), + telemetry=telemetry_metrics( + metrics=2, + bytes_sent=bytes_sent, + packets_sent=packets_sent, + ) + ) + + def test_close_buffer_without_open(self): + dogstatsd = DogStatsd() + with self.assertRaises(BufferError): + dogstatsd.close_buffer() + + def test_threaded_close_buffer_without_open(self): + def batch_metrics(dsd): + time.sleep(0.3) + dsd.open_buffer() + + dsd.gauge('page.views', 123) + dsd.timing('timer', 123) + + time.sleep(0.5) + dsd.close_buffer() + + def close_async_buffer(self, dsd): + # Ensures that buffer is defined + dsd.open_buffer() + dsd.close_buffer() + + time.sleep(0.5) + with self.assertRaises(RuntimeError): + dsd.close_buffer() + + thread1 = Thread(target=batch_metrics, args=(self.statsd,)) + thread2 = Thread(target=close_async_buffer, args=(self, self.statsd,)) + + for thread in [thread1, thread2]: + thread.start() + + for thread in [thread1, thread2]: + if thread.is_alive(): + thread.join() + + expected_message = "page.views:123|g\ntimer:123|ms" + self.assert_equal_telemetry( + expected_message, + self.recv(2), + telemetry=telemetry_metrics( + metrics=2, + bytes_sent=29, + packets_sent=1, + ) + ) def test_telemetry(self): self.statsd.metrics_count = 1