Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion py/src/braintrust/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ def parse_int(value: str) -> int | None:
return None


def parse_positive_int(value: str) -> int | None:
"""Parse a non-negative integer from a string."""
result = parse_int(value)
if result is None or result < 0:
return None
return result


def parse_bool(value: str) -> bool | None:
"""Parse common boolean environment variable values.

Expand All @@ -56,6 +64,7 @@ def parse_string(value: str) -> str | None:
class EnvParser(Enum):
FLOAT = (parse_float,)
INT = (parse_int,)
POSITIVE_INT = (parse_positive_int,)
BOOL = (parse_bool,)
STRING = (parse_string,)

Expand Down Expand Up @@ -84,7 +93,7 @@ class BraintrustEnv:
SYNC_FLUSH = EnvVar("BRAINTRUST_SYNC_FLUSH", EnvParser.BOOL)
MAX_REQUEST_SIZE = EnvVar("BRAINTRUST_MAX_REQUEST_SIZE", EnvParser.INT)
DEFAULT_BATCH_SIZE = EnvVar("BRAINTRUST_DEFAULT_BATCH_SIZE", EnvParser.INT)
NUM_RETRIES = EnvVar("BRAINTRUST_NUM_RETRIES", EnvParser.INT)
NUM_RETRIES = EnvVar("BRAINTRUST_NUM_RETRIES", EnvParser.POSITIVE_INT)
QUEUE_SIZE = EnvVar("BRAINTRUST_QUEUE_SIZE", EnvParser.INT)
QUEUE_DROP_LOGGING_PERIOD = EnvVar("BRAINTRUST_QUEUE_DROP_LOGGING_PERIOD", EnvParser.FLOAT)
FAILED_PUBLISH_PAYLOADS_DIR = EnvVar("BRAINTRUST_FAILED_PUBLISH_PAYLOADS_DIR", EnvParser.STRING)
Expand Down
240 changes: 122 additions & 118 deletions py/src/braintrust/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,38 +695,76 @@ def __init__(
self.default_timeout_secs = default_timeout_secs
super().__init__(*args, **kwargs)

@staticmethod
def _parse_retry_after(retry_after: str | None) -> float | None:
if not retry_after:
return None

try:
return Retry().parse_retry_after(retry_after)
except urllib3.exceptions.InvalidHeader:
return None

def _retry_sleep_seconds(self, num_prev_retries: int, retry_after: str | None = None) -> float:
parsed_retry_after = self._parse_retry_after(retry_after)
if parsed_retry_after is not None:
return parsed_retry_after

# Emulates the sleeping logic in the backoff_factor of urllib3 Retry
return self.backoff_factor * (2**num_prev_retries)

def _sleep_before_retry(self, *, reason: str, num_prev_retries: int, retry_after: str | None = None) -> int:
sleep_s = self._retry_sleep_seconds(num_prev_retries, retry_after)
print(f"Retrying request after {reason}", file=sys.stderr)
print("Sleeping for", sleep_s, "seconds", file=sys.stderr)
time.sleep(sleep_s)
return num_prev_retries + 1

def send(self, *args, **kwargs):
# Apply default timeout if none provided to prevent indefinite hangs
if kwargs.get("timeout") is None:
kwargs["timeout"] = self.default_timeout_secs

num_prev_retries = 0
while True:
retry_reason = None
retry_after = None
retry_error = None
response = None
try:
response = super().send(*args, **kwargs)
# Fully-download the content to ensure we catch any errors from
# downloading.
if not response.is_redirect and response.content:
pass
return response
if response.status_code != 429:
return response
retry_reason = "HTTP 429"
retry_after = response.headers.get("Retry-After")
except (urllib3.exceptions.HTTPError, requests.exceptions.RequestException) as e:
if num_prev_retries < self.base_num_retries:
if isinstance(e, requests.exceptions.ReadTimeout):
# Clear all connection pools to discard stale connections. This
# fixes hangs caused by NAT gateways silently dropping idle TCP
# connections (e.g., Azure's ~4 min timeout). close() calls
# PoolManager.clear() which is thread-safe: in-flight requests
# keep their checked-out connections, and new requests create
# fresh pools on demand.
self.close()
# Emulates the sleeping logic in the backoff_factor of urllib3 Retry
sleep_s = self.backoff_factor * (2**num_prev_retries)
print("Retrying request after error:", e, file=sys.stderr)
print("Sleeping for", sleep_s, "seconds", file=sys.stderr)
time.sleep(sleep_s)
num_prev_retries += 1
else:
raise e
if isinstance(e, requests.exceptions.ReadTimeout):
# Clear all connection pools to discard stale connections. This
# fixes hangs caused by NAT gateways silently dropping idle TCP
# connections (e.g., Azure's ~4 min timeout). close() calls
# PoolManager.clear() which is thread-safe: in-flight requests
# keep their checked-out connections, and new requests create
# fresh pools on demand.
self.close()
retry_reason = f"error: {e}"
retry_after = None
retry_error = e

if num_prev_retries >= self.base_num_retries:
if retry_error:
raise retry_error
return response

if response is not None:
response.close()

num_prev_retries = self._sleep_before_retry(
reason=retry_reason, num_prev_retries=num_prev_retries, retry_after=retry_after
)


class HTTPConnection:
Expand All @@ -747,8 +785,9 @@ def ping(self) -> bool:
def make_long_lived(self) -> None:
if not self.adapter:
timeout_secs = BraintrustEnv.HTTP_TIMEOUT.get(60.0)
num_retries = BraintrustEnv.NUM_RETRIES.get(3)
self.adapter = RetryRequestExceptionsAdapter(
base_num_retries=10, backoff_factor=0.5, default_timeout_secs=timeout_secs
base_num_retries=num_retries, backoff_factor=0.5, default_timeout_secs=timeout_secs
)
self._reset()

Expand Down Expand Up @@ -996,9 +1035,6 @@ def pop(self):
return batch


BACKGROUND_LOGGER_BASE_SLEEP_TIME_S = 1.0


# We should only have one instance of this object in
# 'BraintrustState._global_bg_logger'. Be careful about spawning multiple
# instances of this class, because concurrent _BackgroundLoggers will not log to
Expand All @@ -1016,7 +1052,6 @@ def __init__(self, api_conn: LazyValue[HTTPConnection]):
self.sync_flush = BraintrustEnv.SYNC_FLUSH.get(False)
self._max_request_size_override = BraintrustEnv.MAX_REQUEST_SIZE.get(None)
self.default_batch_size = BraintrustEnv.DEFAULT_BATCH_SIZE.get(100)
self.num_tries = BraintrustEnv.NUM_RETRIES.get(2) + 1
queue_maxsize = BraintrustEnv.QUEUE_SIZE.get(None)
self.queue_maxsize = DEFAULT_QUEUE_SIZE if queue_maxsize is None else queue_maxsize
self.queue_drop_logging_period = BraintrustEnv.QUEUE_DROP_LOGGING_PERIOD.get(60.0)
Expand Down Expand Up @@ -1190,60 +1225,45 @@ def flush(self, batch_size: int | None = None):
def _unwrap_lazy_values(
self, wrapped_items: Sequence[LazyValue[dict[str, Any]]]
) -> tuple[list[dict[str, Any]], list["BaseAttachment"]]:
for i in range(self.num_tries):
try:
unwrapped_items = [item.get() for item in wrapped_items]
merged_items = merge_row_batch(unwrapped_items)

# Apply masking after merging but before sending to backend
if self.masking_function:
for item_idx in range(len(merged_items)):
item = merged_items[item_idx]
masked_item = item.copy()

# Only mask specific fields if they exist
for field in REDACTION_FIELDS:
if field in item:
masked_value = _apply_masking_to_field(self.masking_function, item[field], field)
if isinstance(masked_value, _MaskingError):
# Drop the field and add error message
if field in masked_item:
del masked_item[field]
if "error" in masked_item:
masked_item["error"] = f"{masked_item['error']}; {masked_value.error_msg}"
else:
masked_item["error"] = masked_value.error_msg
else:
masked_item[field] = masked_value
try:
unwrapped_items = [item.get() for item in wrapped_items]
merged_items = merge_row_batch(unwrapped_items)

merged_items[item_idx] = masked_item
# Apply masking after merging but before sending to backend
if self.masking_function:
for item_idx in range(len(merged_items)):
item = merged_items[item_idx]
masked_item = item.copy()

attachments: list["BaseAttachment"] = []
for item in merged_items:
_extract_attachments(item, attachments)
# Only mask specific fields if they exist
for field in REDACTION_FIELDS:
if field in item:
masked_value = _apply_masking_to_field(self.masking_function, item[field], field)
if isinstance(masked_value, _MaskingError):
# Drop the field and add error message
if field in masked_item:
del masked_item[field]
if "error" in masked_item:
masked_item["error"] = f"{masked_item['error']}; {masked_value.error_msg}"
else:
masked_item["error"] = masked_value.error_msg
else:
masked_item[field] = masked_value

return merged_items, attachments
except Exception as e:
errmsg = "Encountered error when constructing records to flush"
is_retrying = i + 1 < self.num_tries
if is_retrying:
errmsg += ". Retrying"
merged_items[item_idx] = masked_item

if not is_retrying and self.sync_flush:
raise Exception(errmsg) from e
else:
print(errmsg, file=self.outfile)
traceback.print_exc(file=self.outfile)
if is_retrying:
sleep_time_s = BACKGROUND_LOGGER_BASE_SLEEP_TIME_S * (2**i)
print(f"Sleeping for {sleep_time_s}s", file=self.outfile)
time.sleep(sleep_time_s)

print(
f"Failed to construct log records to flush after {self.num_tries} attempts. Dropping batch",
file=self.outfile,
)
return [], []
attachments: list["BaseAttachment"] = []
for item in merged_items:
_extract_attachments(item, attachments)

return merged_items, attachments
except Exception as e:
errmsg = "Encountered error when constructing records to flush"
if self.sync_flush:
raise Exception(errmsg) from e
print(errmsg, file=self.outfile)
traceback.print_exc(file=self.outfile)
return [], []

def _request_logs3_overflow_upload(
self, conn: HTTPConnection, payload_size_bytes: int, rows: list[dict[str, Any]]
Expand Down Expand Up @@ -1329,54 +1349,38 @@ def _submit_logs_request(self, items: Sequence[LogItemWithMeta], max_request_siz
if use_overflow
else None
)
for i in range(self.num_tries):
start_time = time.time()
resp = None
error = None
try:
if overflow_rows:
if overflow_upload is None:
current_upload = self._request_logs3_overflow_upload(conn, payload_bytes, overflow_rows)
self._upload_logs3_overflow_payload(current_upload, dataStr)
overflow_upload = current_upload
resp = conn.post(
"/logs3",
json=construct_logs3_overflow_request(overflow_upload["key"], payload_bytes),
)
else:
resp = conn.post("/logs3", data=dataStr.encode("utf-8"))
except Exception as e:
error = e
if error is None and resp is not None and resp.ok:
if overflow_rows:
self._overflow_upload_count += 1
return
has_response = error is None and resp is not None
is_413 = has_response and resp.status_code == 413
resp_errmsg = f"{resp.status_code}: {resp.text}" if has_response else str(error)

should_retry = i + 1 < self.num_tries and not is_413

if not should_retry and self.failed_publish_payloads_dir:
_HTTPBackgroundLogger._write_payload_to_dir(
payload_dir=self.failed_publish_payloads_dir, payload=dataStr
start_time = time.time()
resp = None
error = None
try:
if overflow_rows:
if overflow_upload is None:
current_upload = self._request_logs3_overflow_upload(conn, payload_bytes, overflow_rows)
self._upload_logs3_overflow_payload(current_upload, dataStr)
overflow_upload = current_upload
resp = conn.post(
"/logs3",
json=construct_logs3_overflow_request(overflow_upload["key"], payload_bytes),
)
self._log_failed_payloads_dir()

retrying_text = " Retrying" if should_retry else ""
errmsg = f"log request failed. Elapsed time: {time.time() - start_time} seconds. Payload size: {payload_bytes}.{retrying_text} Error: {resp_errmsg}"
if not should_retry and self.sync_flush:
raise Exception(errmsg)
print(errmsg, file=self.outfile)
else:
resp = conn.post("/logs3", data=dataStr.encode("utf-8"))
except Exception as e:
error = e
if error is None and resp is not None and resp.ok:
if overflow_rows:
self._overflow_upload_count += 1
return
has_response = error is None and resp is not None
resp_errmsg = f"{resp.status_code}: {resp.text}" if has_response else str(error)

if is_413:
return
if should_retry:
sleep_time_s = BACKGROUND_LOGGER_BASE_SLEEP_TIME_S * (2**i)
print(f"Sleeping for {sleep_time_s}s", file=self.outfile)
time.sleep(sleep_time_s)
if self.failed_publish_payloads_dir:
_HTTPBackgroundLogger._write_payload_to_dir(payload_dir=self.failed_publish_payloads_dir, payload=dataStr)
self._log_failed_payloads_dir()

print(f"log request failed after {self.num_tries} retries. Dropping batch", file=self.outfile)
errmsg = f"log request failed. Elapsed time: {time.time() - start_time} seconds. Payload size: {payload_bytes}. Error: {resp_errmsg}"
if self.sync_flush:
raise Exception(errmsg)
print(errmsg, file=self.outfile)

def _dump_dropped_events(self, wrapped_items):
publish_payloads_dir = [x for x in [self.all_publish_payloads_dir, self.failed_publish_payloads_dir] if x]
Expand Down
14 changes: 13 additions & 1 deletion py/src/braintrust/test_env.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .env import BraintrustEnv, EnvParser, EnvVar, parse_bool, parse_float, parse_int, parse_string
from .env import BraintrustEnv, EnvParser, EnvVar, parse_bool, parse_float, parse_int, parse_positive_int, parse_string


class TestEnvParsers:
Expand All @@ -16,6 +16,14 @@ def test_parse_int(self):
assert parse_int("1.2") is None
assert parse_int("not_an_int") is None

def test_parse_positive_int(self):
assert parse_positive_int("123") == 123
assert parse_positive_int("0") == 0
assert parse_positive_int("-5") is None
assert parse_positive_int("") is None
assert parse_positive_int("1.2") is None
assert parse_positive_int("not_an_int") is None

def test_parse_bool(self):
for value in ("true", "True", "1", "yes", "y", "on"):
assert parse_bool(value) is True
Expand Down Expand Up @@ -59,6 +67,10 @@ def test_centralized_env_definitions_are_lazy(self, monkeypatch):
monkeypatch.setenv("BRAINTRUST_HTTP_TIMEOUT", "0.2")
assert BraintrustEnv.HTTP_TIMEOUT.get(60.0) == 0.2

def test_num_retries_uses_default_for_negative_values(self, monkeypatch):
monkeypatch.setenv("BRAINTRUST_NUM_RETRIES", "-1")
assert BraintrustEnv.NUM_RETRIES.get(3) == 3

def test_otel_compat_uses_shared_bool_parser(self, monkeypatch):
for value in ("true", "1", "yes"):
monkeypatch.setenv("BRAINTRUST_OTEL_COMPAT", value)
Expand Down
Loading
Loading