diff --git a/py/src/braintrust/env.py b/py/src/braintrust/env.py index a4933c3e..60ab6e75 100644 --- a/py/src/braintrust/env.py +++ b/py/src/braintrust/env.py @@ -30,6 +30,14 @@ def parse_int(value: str) -> int | None: return None +def parse_positive_int(value: str) -> int | None: + """Parse a non-negative integer from a string.""" + result = parse_int(value) + if result is None or result < 0: + return None + return result + + def parse_bool(value: str) -> bool | None: """Parse common boolean environment variable values. @@ -56,6 +64,7 @@ def parse_string(value: str) -> str | None: class EnvParser(Enum): FLOAT = (parse_float,) INT = (parse_int,) + POSITIVE_INT = (parse_positive_int,) BOOL = (parse_bool,) STRING = (parse_string,) @@ -84,7 +93,7 @@ class BraintrustEnv: SYNC_FLUSH = EnvVar("BRAINTRUST_SYNC_FLUSH", EnvParser.BOOL) MAX_REQUEST_SIZE = EnvVar("BRAINTRUST_MAX_REQUEST_SIZE", EnvParser.INT) DEFAULT_BATCH_SIZE = EnvVar("BRAINTRUST_DEFAULT_BATCH_SIZE", EnvParser.INT) - NUM_RETRIES = EnvVar("BRAINTRUST_NUM_RETRIES", EnvParser.INT) + NUM_RETRIES = EnvVar("BRAINTRUST_NUM_RETRIES", EnvParser.POSITIVE_INT) QUEUE_SIZE = EnvVar("BRAINTRUST_QUEUE_SIZE", EnvParser.INT) QUEUE_DROP_LOGGING_PERIOD = EnvVar("BRAINTRUST_QUEUE_DROP_LOGGING_PERIOD", EnvParser.FLOAT) FAILED_PUBLISH_PAYLOADS_DIR = EnvVar("BRAINTRUST_FAILED_PUBLISH_PAYLOADS_DIR", EnvParser.STRING) diff --git a/py/src/braintrust/logger.py b/py/src/braintrust/logger.py index a908050d..333f2cba 100644 --- a/py/src/braintrust/logger.py +++ b/py/src/braintrust/logger.py @@ -695,6 +695,31 @@ def __init__( self.default_timeout_secs = default_timeout_secs super().__init__(*args, **kwargs) + @staticmethod + def _parse_retry_after(retry_after: str | None) -> float | None: + if not retry_after: + return None + + try: + return Retry().parse_retry_after(retry_after) + except urllib3.exceptions.InvalidHeader: + return None + + def _retry_sleep_seconds(self, num_prev_retries: int, retry_after: str | None = None) -> float: + parsed_retry_after = self._parse_retry_after(retry_after) + if parsed_retry_after is not None: + return parsed_retry_after + + # Emulates the sleeping logic in the backoff_factor of urllib3 Retry + return self.backoff_factor * (2**num_prev_retries) + + def _sleep_before_retry(self, *, reason: str, num_prev_retries: int, retry_after: str | None = None) -> int: + sleep_s = self._retry_sleep_seconds(num_prev_retries, retry_after) + print(f"Retrying request after {reason}", file=sys.stderr) + print("Sleeping for", sleep_s, "seconds", file=sys.stderr) + time.sleep(sleep_s) + return num_prev_retries + 1 + def send(self, *args, **kwargs): # Apply default timeout if none provided to prevent indefinite hangs if kwargs.get("timeout") is None: @@ -702,31 +727,44 @@ def send(self, *args, **kwargs): num_prev_retries = 0 while True: + retry_reason = None + retry_after = None + retry_error = None + response = None try: response = super().send(*args, **kwargs) # Fully-download the content to ensure we catch any errors from # downloading. if not response.is_redirect and response.content: pass - return response + if response.status_code != 429: + return response + retry_reason = "HTTP 429" + retry_after = response.headers.get("Retry-After") except (urllib3.exceptions.HTTPError, requests.exceptions.RequestException) as e: - if num_prev_retries < self.base_num_retries: - if isinstance(e, requests.exceptions.ReadTimeout): - # Clear all connection pools to discard stale connections. This - # fixes hangs caused by NAT gateways silently dropping idle TCP - # connections (e.g., Azure's ~4 min timeout). close() calls - # PoolManager.clear() which is thread-safe: in-flight requests - # keep their checked-out connections, and new requests create - # fresh pools on demand. - self.close() - # Emulates the sleeping logic in the backoff_factor of urllib3 Retry - sleep_s = self.backoff_factor * (2**num_prev_retries) - print("Retrying request after error:", e, file=sys.stderr) - print("Sleeping for", sleep_s, "seconds", file=sys.stderr) - time.sleep(sleep_s) - num_prev_retries += 1 - else: - raise e + if isinstance(e, requests.exceptions.ReadTimeout): + # Clear all connection pools to discard stale connections. This + # fixes hangs caused by NAT gateways silently dropping idle TCP + # connections (e.g., Azure's ~4 min timeout). close() calls + # PoolManager.clear() which is thread-safe: in-flight requests + # keep their checked-out connections, and new requests create + # fresh pools on demand. + self.close() + retry_reason = f"error: {e}" + retry_after = None + retry_error = e + + if num_prev_retries >= self.base_num_retries: + if retry_error: + raise retry_error + return response + + if response is not None: + response.close() + + num_prev_retries = self._sleep_before_retry( + reason=retry_reason, num_prev_retries=num_prev_retries, retry_after=retry_after + ) class HTTPConnection: @@ -747,8 +785,9 @@ def ping(self) -> bool: def make_long_lived(self) -> None: if not self.adapter: timeout_secs = BraintrustEnv.HTTP_TIMEOUT.get(60.0) + num_retries = BraintrustEnv.NUM_RETRIES.get(3) self.adapter = RetryRequestExceptionsAdapter( - base_num_retries=10, backoff_factor=0.5, default_timeout_secs=timeout_secs + base_num_retries=num_retries, backoff_factor=0.5, default_timeout_secs=timeout_secs ) self._reset() @@ -996,9 +1035,6 @@ def pop(self): return batch -BACKGROUND_LOGGER_BASE_SLEEP_TIME_S = 1.0 - - # We should only have one instance of this object in # 'BraintrustState._global_bg_logger'. Be careful about spawning multiple # instances of this class, because concurrent _BackgroundLoggers will not log to @@ -1016,7 +1052,6 @@ def __init__(self, api_conn: LazyValue[HTTPConnection]): self.sync_flush = BraintrustEnv.SYNC_FLUSH.get(False) self._max_request_size_override = BraintrustEnv.MAX_REQUEST_SIZE.get(None) self.default_batch_size = BraintrustEnv.DEFAULT_BATCH_SIZE.get(100) - self.num_tries = BraintrustEnv.NUM_RETRIES.get(2) + 1 queue_maxsize = BraintrustEnv.QUEUE_SIZE.get(None) self.queue_maxsize = DEFAULT_QUEUE_SIZE if queue_maxsize is None else queue_maxsize self.queue_drop_logging_period = BraintrustEnv.QUEUE_DROP_LOGGING_PERIOD.get(60.0) @@ -1190,60 +1225,45 @@ def flush(self, batch_size: int | None = None): def _unwrap_lazy_values( self, wrapped_items: Sequence[LazyValue[dict[str, Any]]] ) -> tuple[list[dict[str, Any]], list["BaseAttachment"]]: - for i in range(self.num_tries): - try: - unwrapped_items = [item.get() for item in wrapped_items] - merged_items = merge_row_batch(unwrapped_items) - - # Apply masking after merging but before sending to backend - if self.masking_function: - for item_idx in range(len(merged_items)): - item = merged_items[item_idx] - masked_item = item.copy() - - # Only mask specific fields if they exist - for field in REDACTION_FIELDS: - if field in item: - masked_value = _apply_masking_to_field(self.masking_function, item[field], field) - if isinstance(masked_value, _MaskingError): - # Drop the field and add error message - if field in masked_item: - del masked_item[field] - if "error" in masked_item: - masked_item["error"] = f"{masked_item['error']}; {masked_value.error_msg}" - else: - masked_item["error"] = masked_value.error_msg - else: - masked_item[field] = masked_value + try: + unwrapped_items = [item.get() for item in wrapped_items] + merged_items = merge_row_batch(unwrapped_items) - merged_items[item_idx] = masked_item + # Apply masking after merging but before sending to backend + if self.masking_function: + for item_idx in range(len(merged_items)): + item = merged_items[item_idx] + masked_item = item.copy() - attachments: list["BaseAttachment"] = [] - for item in merged_items: - _extract_attachments(item, attachments) + # Only mask specific fields if they exist + for field in REDACTION_FIELDS: + if field in item: + masked_value = _apply_masking_to_field(self.masking_function, item[field], field) + if isinstance(masked_value, _MaskingError): + # Drop the field and add error message + if field in masked_item: + del masked_item[field] + if "error" in masked_item: + masked_item["error"] = f"{masked_item['error']}; {masked_value.error_msg}" + else: + masked_item["error"] = masked_value.error_msg + else: + masked_item[field] = masked_value - return merged_items, attachments - except Exception as e: - errmsg = "Encountered error when constructing records to flush" - is_retrying = i + 1 < self.num_tries - if is_retrying: - errmsg += ". Retrying" + merged_items[item_idx] = masked_item - if not is_retrying and self.sync_flush: - raise Exception(errmsg) from e - else: - print(errmsg, file=self.outfile) - traceback.print_exc(file=self.outfile) - if is_retrying: - sleep_time_s = BACKGROUND_LOGGER_BASE_SLEEP_TIME_S * (2**i) - print(f"Sleeping for {sleep_time_s}s", file=self.outfile) - time.sleep(sleep_time_s) - - print( - f"Failed to construct log records to flush after {self.num_tries} attempts. Dropping batch", - file=self.outfile, - ) - return [], [] + attachments: list["BaseAttachment"] = [] + for item in merged_items: + _extract_attachments(item, attachments) + + return merged_items, attachments + except Exception as e: + errmsg = "Encountered error when constructing records to flush" + if self.sync_flush: + raise Exception(errmsg) from e + print(errmsg, file=self.outfile) + traceback.print_exc(file=self.outfile) + return [], [] def _request_logs3_overflow_upload( self, conn: HTTPConnection, payload_size_bytes: int, rows: list[dict[str, Any]] @@ -1329,54 +1349,38 @@ def _submit_logs_request(self, items: Sequence[LogItemWithMeta], max_request_siz if use_overflow else None ) - for i in range(self.num_tries): - start_time = time.time() - resp = None - error = None - try: - if overflow_rows: - if overflow_upload is None: - current_upload = self._request_logs3_overflow_upload(conn, payload_bytes, overflow_rows) - self._upload_logs3_overflow_payload(current_upload, dataStr) - overflow_upload = current_upload - resp = conn.post( - "/logs3", - json=construct_logs3_overflow_request(overflow_upload["key"], payload_bytes), - ) - else: - resp = conn.post("/logs3", data=dataStr.encode("utf-8")) - except Exception as e: - error = e - if error is None and resp is not None and resp.ok: - if overflow_rows: - self._overflow_upload_count += 1 - return - has_response = error is None and resp is not None - is_413 = has_response and resp.status_code == 413 - resp_errmsg = f"{resp.status_code}: {resp.text}" if has_response else str(error) - - should_retry = i + 1 < self.num_tries and not is_413 - - if not should_retry and self.failed_publish_payloads_dir: - _HTTPBackgroundLogger._write_payload_to_dir( - payload_dir=self.failed_publish_payloads_dir, payload=dataStr + start_time = time.time() + resp = None + error = None + try: + if overflow_rows: + if overflow_upload is None: + current_upload = self._request_logs3_overflow_upload(conn, payload_bytes, overflow_rows) + self._upload_logs3_overflow_payload(current_upload, dataStr) + overflow_upload = current_upload + resp = conn.post( + "/logs3", + json=construct_logs3_overflow_request(overflow_upload["key"], payload_bytes), ) - self._log_failed_payloads_dir() - - retrying_text = " Retrying" if should_retry else "" - errmsg = f"log request failed. Elapsed time: {time.time() - start_time} seconds. Payload size: {payload_bytes}.{retrying_text} Error: {resp_errmsg}" - if not should_retry and self.sync_flush: - raise Exception(errmsg) - print(errmsg, file=self.outfile) + else: + resp = conn.post("/logs3", data=dataStr.encode("utf-8")) + except Exception as e: + error = e + if error is None and resp is not None and resp.ok: + if overflow_rows: + self._overflow_upload_count += 1 + return + has_response = error is None and resp is not None + resp_errmsg = f"{resp.status_code}: {resp.text}" if has_response else str(error) - if is_413: - return - if should_retry: - sleep_time_s = BACKGROUND_LOGGER_BASE_SLEEP_TIME_S * (2**i) - print(f"Sleeping for {sleep_time_s}s", file=self.outfile) - time.sleep(sleep_time_s) + if self.failed_publish_payloads_dir: + _HTTPBackgroundLogger._write_payload_to_dir(payload_dir=self.failed_publish_payloads_dir, payload=dataStr) + self._log_failed_payloads_dir() - print(f"log request failed after {self.num_tries} retries. Dropping batch", file=self.outfile) + errmsg = f"log request failed. Elapsed time: {time.time() - start_time} seconds. Payload size: {payload_bytes}. Error: {resp_errmsg}" + if self.sync_flush: + raise Exception(errmsg) + print(errmsg, file=self.outfile) def _dump_dropped_events(self, wrapped_items): publish_payloads_dir = [x for x in [self.all_publish_payloads_dir, self.failed_publish_payloads_dir] if x] diff --git a/py/src/braintrust/test_env.py b/py/src/braintrust/test_env.py index 7d983902..68691c6d 100644 --- a/py/src/braintrust/test_env.py +++ b/py/src/braintrust/test_env.py @@ -1,4 +1,4 @@ -from .env import BraintrustEnv, EnvParser, EnvVar, parse_bool, parse_float, parse_int, parse_string +from .env import BraintrustEnv, EnvParser, EnvVar, parse_bool, parse_float, parse_int, parse_positive_int, parse_string class TestEnvParsers: @@ -16,6 +16,14 @@ def test_parse_int(self): assert parse_int("1.2") is None assert parse_int("not_an_int") is None + def test_parse_positive_int(self): + assert parse_positive_int("123") == 123 + assert parse_positive_int("0") == 0 + assert parse_positive_int("-5") is None + assert parse_positive_int("") is None + assert parse_positive_int("1.2") is None + assert parse_positive_int("not_an_int") is None + def test_parse_bool(self): for value in ("true", "True", "1", "yes", "y", "on"): assert parse_bool(value) is True @@ -59,6 +67,10 @@ def test_centralized_env_definitions_are_lazy(self, monkeypatch): monkeypatch.setenv("BRAINTRUST_HTTP_TIMEOUT", "0.2") assert BraintrustEnv.HTTP_TIMEOUT.get(60.0) == 0.2 + def test_num_retries_uses_default_for_negative_values(self, monkeypatch): + monkeypatch.setenv("BRAINTRUST_NUM_RETRIES", "-1") + assert BraintrustEnv.NUM_RETRIES.get(3) == 3 + def test_otel_compat_uses_shared_bool_parser(self, monkeypatch): for value in ("true", "1", "yes"): monkeypatch.setenv("BRAINTRUST_OTEL_COMPAT", value) diff --git a/py/src/braintrust/test_http.py b/py/src/braintrust/test_http.py index b9ede8d8..01ca6ec1 100644 --- a/py/src/braintrust/test_http.py +++ b/py/src/braintrust/test_http.py @@ -69,6 +69,38 @@ def do_GET(self): self.do_POST() +class RateLimitHandler(http.server.BaseHTTPRequestHandler): + """HTTP handler that returns 429 once before succeeding.""" + + request_count = 0 + retry_after = "2" + + def log_message(self, format, *args): + pass + + def do_POST(self): + RateLimitHandler.request_count += 1 + + if RateLimitHandler.request_count == 1: + body = b'{"error": "rate limited"}' + self.send_response(429) + self.send_header("Retry-After", RateLimitHandler.retry_after) + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + return + + body = b'{"status": "ok"}' + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + def do_GET(self): + self.do_POST() + + @pytest.fixture def hanging_server(): """Fixture that creates a server that HANGS on first request (simulates stale NAT).""" @@ -109,6 +141,26 @@ def closing_server(): server.server_close() +@pytest.fixture +def rate_limit_server(): + """Fixture that creates a server that returns 429 then 200.""" + RateLimitHandler.request_count = 0 + RateLimitHandler.retry_after = "2" + + server = socketserver.ThreadingTCPServer(("127.0.0.1", 0), RateLimitHandler) + server.daemon_threads = True + port = server.server_address[1] + + thread = threading.Thread(target=server.serve_forever) + thread.daemon = True + thread.start() + + yield f"http://127.0.0.1:{port}" + + server.shutdown() + server.server_close() + + class TestRetryRequestExceptionsAdapter: """Tests for RetryRequestExceptionsAdapter timeout and retry behavior.""" @@ -173,6 +225,58 @@ def test_adapter_resets_pool_on_timeout(self, hanging_server): assert elapsed < 10.0, f"Request took too long: {elapsed:.2f}s" assert HangingConnectionHandler.request_count >= 2 + def test_adapter_retries_429_using_retry_after_header(self, rate_limit_server, monkeypatch): + """429 retries should honor the server-provided Retry-After delay.""" + sleeps = [] + monkeypatch.setattr(time, "sleep", sleeps.append) + + adapter = RetryRequestExceptionsAdapter(base_num_retries=3, backoff_factor=0.05) + session = requests.Session() + session.mount("http://", adapter) + + resp = session.post(f"{rate_limit_server}/test", json={"hello": "world"}) + + assert resp.status_code == 200 + assert RateLimitHandler.request_count == 2 + assert sleeps == [2] + + def test_adapter_retries_429_using_backoff_when_retry_after_invalid(self, rate_limit_server, monkeypatch): + """429 retries should fall back to exponential backoff without a valid Retry-After header.""" + sleeps = [] + monkeypatch.setattr(time, "sleep", sleeps.append) + RateLimitHandler.retry_after = "not-a-delay" + + adapter = RetryRequestExceptionsAdapter(base_num_retries=3, backoff_factor=0.05) + session = requests.Session() + session.mount("http://", adapter) + + resp = session.post(f"{rate_limit_server}/test", json={"hello": "world"}) + + assert resp.status_code == 200 + assert RateLimitHandler.request_count == 2 + assert sleeps == [0.05] + + def test_adapter_returns_final_429_without_closing_response(self, rate_limit_server, monkeypatch): + """When no 429 retries remain, return the response without closing it.""" + closed_responses = [] + original_close = requests.models.Response.close + + def record_close(response): + closed_responses.append(response) + original_close(response) + + monkeypatch.setattr(requests.models.Response, "close", record_close) + + adapter = RetryRequestExceptionsAdapter(base_num_retries=0, backoff_factor=0.05) + session = requests.Session() + session.mount("http://", adapter) + + resp = session.post(f"{rate_limit_server}/test", json={"hello": "world"}) + + assert resp.status_code == 429 + assert RateLimitHandler.request_count == 1 + assert closed_responses == [] + class TestHTTPConnection: """Tests for HTTPConnection timeout configuration.""" @@ -218,6 +322,34 @@ def test_env_var_configures_timeout(self): finally: del os.environ["BRAINTRUST_HTTP_TIMEOUT"] + def test_make_long_lived_uses_default_num_retries(self): + """HTTPConnection.make_long_lived() should retry requests 3 times by default.""" + conn = HTTPConnection("http://localhost:8080") + conn.make_long_lived() + + assert hasattr(conn.adapter, "base_num_retries") + assert conn.adapter.base_num_retries == 3 + + def test_env_var_configures_num_retries(self, monkeypatch): + """BRAINTRUST_NUM_RETRIES env var configures HTTP adapter retries via make_long_lived().""" + monkeypatch.setenv("BRAINTRUST_NUM_RETRIES", "4") + + conn = HTTPConnection("http://localhost:8080") + conn.make_long_lived() + + assert hasattr(conn.adapter, "base_num_retries") + assert conn.adapter.base_num_retries == 4 + + def test_env_var_uses_default_for_negative_num_retries(self, monkeypatch): + """Negative BRAINTRUST_NUM_RETRIES values are invalid and fall back to the default.""" + monkeypatch.setenv("BRAINTRUST_NUM_RETRIES", "-1") + + conn = HTTPConnection("http://localhost:8080") + conn.make_long_lived() + + assert hasattr(conn.adapter, "base_num_retries") + assert conn.adapter.base_num_retries == 3 + class TestAdapterCloseAndReuse: """Tests verifying that adapter.close() allows subsequent requests. diff --git a/py/src/braintrust/test_logger.py b/py/src/braintrust/test_logger.py index 3fc5c833..9b1d9d2e 100644 --- a/py/src/braintrust/test_logger.py +++ b/py/src/braintrust/test_logger.py @@ -192,8 +192,8 @@ def test_init_with_saved_parameters_attaches_reference(self): class TestHTTPBackgroundLoggerLogs3(TestCase): - def test_submit_logs_request_413_skips_retries(self) -> None: - """Any 413 while publishing ``/logs3`` cannot succeed on retry with the same payload. + def test_submit_logs_request_does_not_retry(self) -> None: + """HTTP transport handles retries; background log submission attempts each batch once. ``sync_flush`` controls whether the terminal failure raises instead of printing. """ @@ -226,7 +226,6 @@ def test_submit_logs_request_413_skips_retries(self) -> None: mock_conn.post.return_value = mock_resp bg = _HTTPBackgroundLogger(LazyValue(lambda: mock_conn, use_mutex=False)) - bg.num_tries = 5 bg.sync_flush = sync_flush bg.failed_publish_payloads_dir = "/tmp/failed-payloads"