From 9bbb31a28f7c684fb08116c6d70dcbff5f07e811 Mon Sep 17 00:00:00 2001 From: Abhijeet Prasad Date: Thu, 28 May 2026 14:54:25 -0400 Subject: [PATCH] fix: Parse Retry-After header for 429s We remove retry logic from the background logger, and move it into the http adapter layer. This means we only retry on meaningful errors, like network issues or 429s. If we hit 429s, we now also correctly parse the Retry-After header to determine what the exponential backoff should be. --- py/src/braintrust/env.py | 11 +- py/src/braintrust/logger.py | 240 ++++++++++++++++--------------- py/src/braintrust/test_env.py | 14 +- py/src/braintrust/test_http.py | 132 +++++++++++++++++ py/src/braintrust/test_logger.py | 5 +- 5 files changed, 279 insertions(+), 123 deletions(-) diff --git a/py/src/braintrust/env.py b/py/src/braintrust/env.py index a4933c3e..60ab6e75 100644 --- a/py/src/braintrust/env.py +++ b/py/src/braintrust/env.py @@ -30,6 +30,14 @@ def parse_int(value: str) -> int | None: return None +def parse_positive_int(value: str) -> int | None: + """Parse a non-negative integer from a string.""" + result = parse_int(value) + if result is None or result < 0: + return None + return result + + def parse_bool(value: str) -> bool | None: """Parse common boolean environment variable values. @@ -56,6 +64,7 @@ def parse_string(value: str) -> str | None: class EnvParser(Enum): FLOAT = (parse_float,) INT = (parse_int,) + POSITIVE_INT = (parse_positive_int,) BOOL = (parse_bool,) STRING = (parse_string,) @@ -84,7 +93,7 @@ class BraintrustEnv: SYNC_FLUSH = EnvVar("BRAINTRUST_SYNC_FLUSH", EnvParser.BOOL) MAX_REQUEST_SIZE = EnvVar("BRAINTRUST_MAX_REQUEST_SIZE", EnvParser.INT) DEFAULT_BATCH_SIZE = EnvVar("BRAINTRUST_DEFAULT_BATCH_SIZE", EnvParser.INT) - NUM_RETRIES = EnvVar("BRAINTRUST_NUM_RETRIES", EnvParser.INT) + NUM_RETRIES = EnvVar("BRAINTRUST_NUM_RETRIES", EnvParser.POSITIVE_INT) QUEUE_SIZE = EnvVar("BRAINTRUST_QUEUE_SIZE", EnvParser.INT) QUEUE_DROP_LOGGING_PERIOD = EnvVar("BRAINTRUST_QUEUE_DROP_LOGGING_PERIOD", EnvParser.FLOAT) FAILED_PUBLISH_PAYLOADS_DIR = EnvVar("BRAINTRUST_FAILED_PUBLISH_PAYLOADS_DIR", EnvParser.STRING) diff --git a/py/src/braintrust/logger.py b/py/src/braintrust/logger.py index a908050d..333f2cba 100644 --- a/py/src/braintrust/logger.py +++ b/py/src/braintrust/logger.py @@ -695,6 +695,31 @@ def __init__( self.default_timeout_secs = default_timeout_secs super().__init__(*args, **kwargs) + @staticmethod + def _parse_retry_after(retry_after: str | None) -> float | None: + if not retry_after: + return None + + try: + return Retry().parse_retry_after(retry_after) + except urllib3.exceptions.InvalidHeader: + return None + + def _retry_sleep_seconds(self, num_prev_retries: int, retry_after: str | None = None) -> float: + parsed_retry_after = self._parse_retry_after(retry_after) + if parsed_retry_after is not None: + return parsed_retry_after + + # Emulates the sleeping logic in the backoff_factor of urllib3 Retry + return self.backoff_factor * (2**num_prev_retries) + + def _sleep_before_retry(self, *, reason: str, num_prev_retries: int, retry_after: str | None = None) -> int: + sleep_s = self._retry_sleep_seconds(num_prev_retries, retry_after) + print(f"Retrying request after {reason}", file=sys.stderr) + print("Sleeping for", sleep_s, "seconds", file=sys.stderr) + time.sleep(sleep_s) + return num_prev_retries + 1 + def send(self, *args, **kwargs): # Apply default timeout if none provided to prevent indefinite hangs if kwargs.get("timeout") is None: @@ -702,31 +727,44 @@ def send(self, *args, **kwargs): num_prev_retries = 0 while True: + retry_reason = None + retry_after = None + retry_error = None + response = None try: response = super().send(*args, **kwargs) # Fully-download the content to ensure we catch any errors from # downloading. if not response.is_redirect and response.content: pass - return response + if response.status_code != 429: + return response + retry_reason = "HTTP 429" + retry_after = response.headers.get("Retry-After") except (urllib3.exceptions.HTTPError, requests.exceptions.RequestException) as e: - if num_prev_retries < self.base_num_retries: - if isinstance(e, requests.exceptions.ReadTimeout): - # Clear all connection pools to discard stale connections. This - # fixes hangs caused by NAT gateways silently dropping idle TCP - # connections (e.g., Azure's ~4 min timeout). close() calls - # PoolManager.clear() which is thread-safe: in-flight requests - # keep their checked-out connections, and new requests create - # fresh pools on demand. - self.close() - # Emulates the sleeping logic in the backoff_factor of urllib3 Retry - sleep_s = self.backoff_factor * (2**num_prev_retries) - print("Retrying request after error:", e, file=sys.stderr) - print("Sleeping for", sleep_s, "seconds", file=sys.stderr) - time.sleep(sleep_s) - num_prev_retries += 1 - else: - raise e + if isinstance(e, requests.exceptions.ReadTimeout): + # Clear all connection pools to discard stale connections. This + # fixes hangs caused by NAT gateways silently dropping idle TCP + # connections (e.g., Azure's ~4 min timeout). close() calls + # PoolManager.clear() which is thread-safe: in-flight requests + # keep their checked-out connections, and new requests create + # fresh pools on demand. + self.close() + retry_reason = f"error: {e}" + retry_after = None + retry_error = e + + if num_prev_retries >= self.base_num_retries: + if retry_error: + raise retry_error + return response + + if response is not None: + response.close() + + num_prev_retries = self._sleep_before_retry( + reason=retry_reason, num_prev_retries=num_prev_retries, retry_after=retry_after + ) class HTTPConnection: @@ -747,8 +785,9 @@ def ping(self) -> bool: def make_long_lived(self) -> None: if not self.adapter: timeout_secs = BraintrustEnv.HTTP_TIMEOUT.get(60.0) + num_retries = BraintrustEnv.NUM_RETRIES.get(3) self.adapter = RetryRequestExceptionsAdapter( - base_num_retries=10, backoff_factor=0.5, default_timeout_secs=timeout_secs + base_num_retries=num_retries, backoff_factor=0.5, default_timeout_secs=timeout_secs ) self._reset() @@ -996,9 +1035,6 @@ def pop(self): return batch -BACKGROUND_LOGGER_BASE_SLEEP_TIME_S = 1.0 - - # We should only have one instance of this object in # 'BraintrustState._global_bg_logger'. Be careful about spawning multiple # instances of this class, because concurrent _BackgroundLoggers will not log to @@ -1016,7 +1052,6 @@ def __init__(self, api_conn: LazyValue[HTTPConnection]): self.sync_flush = BraintrustEnv.SYNC_FLUSH.get(False) self._max_request_size_override = BraintrustEnv.MAX_REQUEST_SIZE.get(None) self.default_batch_size = BraintrustEnv.DEFAULT_BATCH_SIZE.get(100) - self.num_tries = BraintrustEnv.NUM_RETRIES.get(2) + 1 queue_maxsize = BraintrustEnv.QUEUE_SIZE.get(None) self.queue_maxsize = DEFAULT_QUEUE_SIZE if queue_maxsize is None else queue_maxsize self.queue_drop_logging_period = BraintrustEnv.QUEUE_DROP_LOGGING_PERIOD.get(60.0) @@ -1190,60 +1225,45 @@ def flush(self, batch_size: int | None = None): def _unwrap_lazy_values( self, wrapped_items: Sequence[LazyValue[dict[str, Any]]] ) -> tuple[list[dict[str, Any]], list["BaseAttachment"]]: - for i in range(self.num_tries): - try: - unwrapped_items = [item.get() for item in wrapped_items] - merged_items = merge_row_batch(unwrapped_items) - - # Apply masking after merging but before sending to backend - if self.masking_function: - for item_idx in range(len(merged_items)): - item = merged_items[item_idx] - masked_item = item.copy() - - # Only mask specific fields if they exist - for field in REDACTION_FIELDS: - if field in item: - masked_value = _apply_masking_to_field(self.masking_function, item[field], field) - if isinstance(masked_value, _MaskingError): - # Drop the field and add error message - if field in masked_item: - del masked_item[field] - if "error" in masked_item: - masked_item["error"] = f"{masked_item['error']}; {masked_value.error_msg}" - else: - masked_item["error"] = masked_value.error_msg - else: - masked_item[field] = masked_value + try: + unwrapped_items = [item.get() for item in wrapped_items] + merged_items = merge_row_batch(unwrapped_items) - merged_items[item_idx] = masked_item + # Apply masking after merging but before sending to backend + if self.masking_function: + for item_idx in range(len(merged_items)): + item = merged_items[item_idx] + masked_item = item.copy() - attachments: list["BaseAttachment"] = [] - for item in merged_items: - _extract_attachments(item, attachments) + # Only mask specific fields if they exist + for field in REDACTION_FIELDS: + if field in item: + masked_value = _apply_masking_to_field(self.masking_function, item[field], field) + if isinstance(masked_value, _MaskingError): + # Drop the field and add error message + if field in masked_item: + del masked_item[field] + if "error" in masked_item: + masked_item["error"] = f"{masked_item['error']}; {masked_value.error_msg}" + else: + masked_item["error"] = masked_value.error_msg + else: + masked_item[field] = masked_value - return merged_items, attachments - except Exception as e: - errmsg = "Encountered error when constructing records to flush" - is_retrying = i + 1 < self.num_tries - if is_retrying: - errmsg += ". Retrying" + merged_items[item_idx] = masked_item - if not is_retrying and self.sync_flush: - raise Exception(errmsg) from e - else: - print(errmsg, file=self.outfile) - traceback.print_exc(file=self.outfile) - if is_retrying: - sleep_time_s = BACKGROUND_LOGGER_BASE_SLEEP_TIME_S * (2**i) - print(f"Sleeping for {sleep_time_s}s", file=self.outfile) - time.sleep(sleep_time_s) - - print( - f"Failed to construct log records to flush after {self.num_tries} attempts. Dropping batch", - file=self.outfile, - ) - return [], [] + attachments: list["BaseAttachment"] = [] + for item in merged_items: + _extract_attachments(item, attachments) + + return merged_items, attachments + except Exception as e: + errmsg = "Encountered error when constructing records to flush" + if self.sync_flush: + raise Exception(errmsg) from e + print(errmsg, file=self.outfile) + traceback.print_exc(file=self.outfile) + return [], [] def _request_logs3_overflow_upload( self, conn: HTTPConnection, payload_size_bytes: int, rows: list[dict[str, Any]] @@ -1329,54 +1349,38 @@ def _submit_logs_request(self, items: Sequence[LogItemWithMeta], max_request_siz if use_overflow else None ) - for i in range(self.num_tries): - start_time = time.time() - resp = None - error = None - try: - if overflow_rows: - if overflow_upload is None: - current_upload = self._request_logs3_overflow_upload(conn, payload_bytes, overflow_rows) - self._upload_logs3_overflow_payload(current_upload, dataStr) - overflow_upload = current_upload - resp = conn.post( - "/logs3", - json=construct_logs3_overflow_request(overflow_upload["key"], payload_bytes), - ) - else: - resp = conn.post("/logs3", data=dataStr.encode("utf-8")) - except Exception as e: - error = e - if error is None and resp is not None and resp.ok: - if overflow_rows: - self._overflow_upload_count += 1 - return - has_response = error is None and resp is not None - is_413 = has_response and resp.status_code == 413 - resp_errmsg = f"{resp.status_code}: {resp.text}" if has_response else str(error) - - should_retry = i + 1 < self.num_tries and not is_413 - - if not should_retry and self.failed_publish_payloads_dir: - _HTTPBackgroundLogger._write_payload_to_dir( - payload_dir=self.failed_publish_payloads_dir, payload=dataStr + start_time = time.time() + resp = None + error = None + try: + if overflow_rows: + if overflow_upload is None: + current_upload = self._request_logs3_overflow_upload(conn, payload_bytes, overflow_rows) + self._upload_logs3_overflow_payload(current_upload, dataStr) + overflow_upload = current_upload + resp = conn.post( + "/logs3", + json=construct_logs3_overflow_request(overflow_upload["key"], payload_bytes), ) - self._log_failed_payloads_dir() - - retrying_text = " Retrying" if should_retry else "" - errmsg = f"log request failed. Elapsed time: {time.time() - start_time} seconds. Payload size: {payload_bytes}.{retrying_text} Error: {resp_errmsg}" - if not should_retry and self.sync_flush: - raise Exception(errmsg) - print(errmsg, file=self.outfile) + else: + resp = conn.post("/logs3", data=dataStr.encode("utf-8")) + except Exception as e: + error = e + if error is None and resp is not None and resp.ok: + if overflow_rows: + self._overflow_upload_count += 1 + return + has_response = error is None and resp is not None + resp_errmsg = f"{resp.status_code}: {resp.text}" if has_response else str(error) - if is_413: - return - if should_retry: - sleep_time_s = BACKGROUND_LOGGER_BASE_SLEEP_TIME_S * (2**i) - print(f"Sleeping for {sleep_time_s}s", file=self.outfile) - time.sleep(sleep_time_s) + if self.failed_publish_payloads_dir: + _HTTPBackgroundLogger._write_payload_to_dir(payload_dir=self.failed_publish_payloads_dir, payload=dataStr) + self._log_failed_payloads_dir() - print(f"log request failed after {self.num_tries} retries. Dropping batch", file=self.outfile) + errmsg = f"log request failed. Elapsed time: {time.time() - start_time} seconds. Payload size: {payload_bytes}. Error: {resp_errmsg}" + if self.sync_flush: + raise Exception(errmsg) + print(errmsg, file=self.outfile) def _dump_dropped_events(self, wrapped_items): publish_payloads_dir = [x for x in [self.all_publish_payloads_dir, self.failed_publish_payloads_dir] if x] diff --git a/py/src/braintrust/test_env.py b/py/src/braintrust/test_env.py index 7d983902..68691c6d 100644 --- a/py/src/braintrust/test_env.py +++ b/py/src/braintrust/test_env.py @@ -1,4 +1,4 @@ -from .env import BraintrustEnv, EnvParser, EnvVar, parse_bool, parse_float, parse_int, parse_string +from .env import BraintrustEnv, EnvParser, EnvVar, parse_bool, parse_float, parse_int, parse_positive_int, parse_string class TestEnvParsers: @@ -16,6 +16,14 @@ def test_parse_int(self): assert parse_int("1.2") is None assert parse_int("not_an_int") is None + def test_parse_positive_int(self): + assert parse_positive_int("123") == 123 + assert parse_positive_int("0") == 0 + assert parse_positive_int("-5") is None + assert parse_positive_int("") is None + assert parse_positive_int("1.2") is None + assert parse_positive_int("not_an_int") is None + def test_parse_bool(self): for value in ("true", "True", "1", "yes", "y", "on"): assert parse_bool(value) is True @@ -59,6 +67,10 @@ def test_centralized_env_definitions_are_lazy(self, monkeypatch): monkeypatch.setenv("BRAINTRUST_HTTP_TIMEOUT", "0.2") assert BraintrustEnv.HTTP_TIMEOUT.get(60.0) == 0.2 + def test_num_retries_uses_default_for_negative_values(self, monkeypatch): + monkeypatch.setenv("BRAINTRUST_NUM_RETRIES", "-1") + assert BraintrustEnv.NUM_RETRIES.get(3) == 3 + def test_otel_compat_uses_shared_bool_parser(self, monkeypatch): for value in ("true", "1", "yes"): monkeypatch.setenv("BRAINTRUST_OTEL_COMPAT", value) diff --git a/py/src/braintrust/test_http.py b/py/src/braintrust/test_http.py index b9ede8d8..01ca6ec1 100644 --- a/py/src/braintrust/test_http.py +++ b/py/src/braintrust/test_http.py @@ -69,6 +69,38 @@ def do_GET(self): self.do_POST() +class RateLimitHandler(http.server.BaseHTTPRequestHandler): + """HTTP handler that returns 429 once before succeeding.""" + + request_count = 0 + retry_after = "2" + + def log_message(self, format, *args): + pass + + def do_POST(self): + RateLimitHandler.request_count += 1 + + if RateLimitHandler.request_count == 1: + body = b'{"error": "rate limited"}' + self.send_response(429) + self.send_header("Retry-After", RateLimitHandler.retry_after) + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + return + + body = b'{"status": "ok"}' + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + def do_GET(self): + self.do_POST() + + @pytest.fixture def hanging_server(): """Fixture that creates a server that HANGS on first request (simulates stale NAT).""" @@ -109,6 +141,26 @@ def closing_server(): server.server_close() +@pytest.fixture +def rate_limit_server(): + """Fixture that creates a server that returns 429 then 200.""" + RateLimitHandler.request_count = 0 + RateLimitHandler.retry_after = "2" + + server = socketserver.ThreadingTCPServer(("127.0.0.1", 0), RateLimitHandler) + server.daemon_threads = True + port = server.server_address[1] + + thread = threading.Thread(target=server.serve_forever) + thread.daemon = True + thread.start() + + yield f"http://127.0.0.1:{port}" + + server.shutdown() + server.server_close() + + class TestRetryRequestExceptionsAdapter: """Tests for RetryRequestExceptionsAdapter timeout and retry behavior.""" @@ -173,6 +225,58 @@ def test_adapter_resets_pool_on_timeout(self, hanging_server): assert elapsed < 10.0, f"Request took too long: {elapsed:.2f}s" assert HangingConnectionHandler.request_count >= 2 + def test_adapter_retries_429_using_retry_after_header(self, rate_limit_server, monkeypatch): + """429 retries should honor the server-provided Retry-After delay.""" + sleeps = [] + monkeypatch.setattr(time, "sleep", sleeps.append) + + adapter = RetryRequestExceptionsAdapter(base_num_retries=3, backoff_factor=0.05) + session = requests.Session() + session.mount("http://", adapter) + + resp = session.post(f"{rate_limit_server}/test", json={"hello": "world"}) + + assert resp.status_code == 200 + assert RateLimitHandler.request_count == 2 + assert sleeps == [2] + + def test_adapter_retries_429_using_backoff_when_retry_after_invalid(self, rate_limit_server, monkeypatch): + """429 retries should fall back to exponential backoff without a valid Retry-After header.""" + sleeps = [] + monkeypatch.setattr(time, "sleep", sleeps.append) + RateLimitHandler.retry_after = "not-a-delay" + + adapter = RetryRequestExceptionsAdapter(base_num_retries=3, backoff_factor=0.05) + session = requests.Session() + session.mount("http://", adapter) + + resp = session.post(f"{rate_limit_server}/test", json={"hello": "world"}) + + assert resp.status_code == 200 + assert RateLimitHandler.request_count == 2 + assert sleeps == [0.05] + + def test_adapter_returns_final_429_without_closing_response(self, rate_limit_server, monkeypatch): + """When no 429 retries remain, return the response without closing it.""" + closed_responses = [] + original_close = requests.models.Response.close + + def record_close(response): + closed_responses.append(response) + original_close(response) + + monkeypatch.setattr(requests.models.Response, "close", record_close) + + adapter = RetryRequestExceptionsAdapter(base_num_retries=0, backoff_factor=0.05) + session = requests.Session() + session.mount("http://", adapter) + + resp = session.post(f"{rate_limit_server}/test", json={"hello": "world"}) + + assert resp.status_code == 429 + assert RateLimitHandler.request_count == 1 + assert closed_responses == [] + class TestHTTPConnection: """Tests for HTTPConnection timeout configuration.""" @@ -218,6 +322,34 @@ def test_env_var_configures_timeout(self): finally: del os.environ["BRAINTRUST_HTTP_TIMEOUT"] + def test_make_long_lived_uses_default_num_retries(self): + """HTTPConnection.make_long_lived() should retry requests 3 times by default.""" + conn = HTTPConnection("http://localhost:8080") + conn.make_long_lived() + + assert hasattr(conn.adapter, "base_num_retries") + assert conn.adapter.base_num_retries == 3 + + def test_env_var_configures_num_retries(self, monkeypatch): + """BRAINTRUST_NUM_RETRIES env var configures HTTP adapter retries via make_long_lived().""" + monkeypatch.setenv("BRAINTRUST_NUM_RETRIES", "4") + + conn = HTTPConnection("http://localhost:8080") + conn.make_long_lived() + + assert hasattr(conn.adapter, "base_num_retries") + assert conn.adapter.base_num_retries == 4 + + def test_env_var_uses_default_for_negative_num_retries(self, monkeypatch): + """Negative BRAINTRUST_NUM_RETRIES values are invalid and fall back to the default.""" + monkeypatch.setenv("BRAINTRUST_NUM_RETRIES", "-1") + + conn = HTTPConnection("http://localhost:8080") + conn.make_long_lived() + + assert hasattr(conn.adapter, "base_num_retries") + assert conn.adapter.base_num_retries == 3 + class TestAdapterCloseAndReuse: """Tests verifying that adapter.close() allows subsequent requests. diff --git a/py/src/braintrust/test_logger.py b/py/src/braintrust/test_logger.py index 3fc5c833..9b1d9d2e 100644 --- a/py/src/braintrust/test_logger.py +++ b/py/src/braintrust/test_logger.py @@ -192,8 +192,8 @@ def test_init_with_saved_parameters_attaches_reference(self): class TestHTTPBackgroundLoggerLogs3(TestCase): - def test_submit_logs_request_413_skips_retries(self) -> None: - """Any 413 while publishing ``/logs3`` cannot succeed on retry with the same payload. + def test_submit_logs_request_does_not_retry(self) -> None: + """HTTP transport handles retries; background log submission attempts each batch once. ``sync_flush`` controls whether the terminal failure raises instead of printing. """ @@ -226,7 +226,6 @@ def test_submit_logs_request_413_skips_retries(self) -> None: mock_conn.post.return_value = mock_resp bg = _HTTPBackgroundLogger(LazyValue(lambda: mock_conn, use_mutex=False)) - bg.num_tries = 5 bg.sync_flush = sync_flush bg.failed_publish_payloads_dir = "/tmp/failed-payloads"