Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion clients/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ dependencies = [
"grpcio>=1.67.1",
"grpcio-health-checking>=1.67.1",
"msgpack>=1.0.0",
"orjson>=3.10.10",
"protobuf>=5.28.3",
"redis>=3.4.1",
"redis-py-cluster>=2.1.0",
Expand Down
13 changes: 0 additions & 13 deletions clients/python/src/taskbroker_client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,3 @@ class CompressionType(Enum):

ZSTD = "zstd"
PLAINTEXT = "plaintext"


class ParametersFormat(Enum):
    """
    How the producer populates the legacy `parameters` (JSON) and new
    `parameters_bytes` (msgpack) fields on TaskActivation.

    Set via env var `TASKBROKER_CLIENT_PARAMETERS_FORMAT`. Default BOTH.
    """

    # Populate both the msgpack field and the legacy JSON field.
    BOTH = "both"
    # Populate only the legacy JSON `parameters` field.
    JSON = "json"
    # Populate only the msgpack `parameters_bytes` field.
    MSGPACK = "msgpack"
46 changes: 4 additions & 42 deletions clients/python/src/taskbroker_client/task.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
from __future__ import annotations

import base64
import datetime
import inspect
import os
import time
from collections.abc import Callable, Collection, Mapping, MutableMapping
from functools import update_wrapper
from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar, get_origin
from uuid import uuid4

import msgpack
import orjson
import sentry_sdk
import zstandard as zstd
from google.protobuf.timestamp_pb2 import Timestamp
Expand All @@ -25,22 +22,9 @@
DEFAULT_PROCESSING_DEADLINE,
MAX_PARAMETER_BYTES_BEFORE_COMPRESSION,
CompressionType,
ParametersFormat,
)
from taskbroker_client.retry import Retry


def _get_parameters_format() -> ParametersFormat:
    """
    Resolve the producer's parameters format from the environment.

    Reads `TASKBROKER_CLIENT_PARAMETERS_FORMAT` and falls back to
    `ParametersFormat.BOTH` when the variable is unset. Matching is
    case-insensitive and ignores surrounding whitespace.

    Returns:
        The `ParametersFormat` member whose value matches the env var.

    Raises:
        ValueError: if the variable is set to a value that does not match
            any `ParametersFormat` member value.
    """
    raw = os.environ.get("TASKBROKER_CLIENT_PARAMETERS_FORMAT", ParametersFormat.BOTH.value)
    try:
        # Strip as well as lowercase so values like " both" or "json\n"
        # (common in env files) are accepted.
        return ParametersFormat(raw.strip().lower())
    except ValueError:
        # `from None` drops the Enum's own lookup error from the traceback so
        # only the actionable message below is shown.
        raise ValueError(
            f"Invalid TASKBROKER_CLIENT_PARAMETERS_FORMAT={raw!r}. "
            f"Expected one of: {', '.join(f.value for f in ParametersFormat)}"
        ) from None


if TYPE_CHECKING:
from taskbroker_client.registry import TaskNamespace

Expand Down Expand Up @@ -270,38 +254,18 @@ def create_activation(
f"The `{key}` header value is of type {type(value)}"
)

parameters_format = _get_parameters_format()
data = {"args": args, "kwargs": kwargs}

msgpack_bytes = (
msgpack.packb(data, use_bin_type=True)
if parameters_format in (ParametersFormat.BOTH, ParametersFormat.MSGPACK)
else b""
)
# JSON can't encode some values msgpack can (e.g. raw bytes). In
# JSON-only mode we surface the TypeError; in BOTH mode we silently
# skip the legacy field so msgpack-aware workers can still run.
json_bytes: bytes | None = None
if parameters_format in (ParametersFormat.BOTH, ParametersFormat.JSON):
try:
json_bytes = orjson.dumps(data)
except TypeError:
if parameters_format == ParametersFormat.JSON:
raise
msgpack_bytes = msgpack.packb(data, use_bin_type=True)

should_compress = (
self.compression_type == CompressionType.ZSTD
or (len(msgpack_bytes) + len(json_bytes or b""))
> MAX_PARAMETER_BYTES_BEFORE_COMPRESSION
or len(msgpack_bytes) > MAX_PARAMETER_BYTES_BEFORE_COMPRESSION
)

if should_compress:
headers["compression-type"] = CompressionType.ZSTD.value
start_time = time.perf_counter()
parameters_bytes_val = zstd.compress(msgpack_bytes) if msgpack_bytes else b""
parameters_str = (
base64.b64encode(zstd.compress(json_bytes)).decode("utf8") if json_bytes else ""
)
parameters_bytes_val = zstd.compress(msgpack_bytes)
elapsed = time.perf_counter() - start_time

metric_tags = {
Expand All @@ -311,7 +275,7 @@ def create_activation(
}
self.namespace.metrics.distribution(
"taskworker.producer.compressed_parameters_size",
len(parameters_bytes_val) or len(parameters_str),
len(parameters_bytes_val),
tags=metric_tags,
)
self.namespace.metrics.distribution(
Expand All @@ -321,15 +285,13 @@ def create_activation(
)
else:
parameters_bytes_val = msgpack_bytes
parameters_str = json_bytes.decode("utf8") if json_bytes else ""

return TaskActivation(
id=uuid4().hex,
application=self._namespace.application,
namespace=self._namespace.name,
taskname=self.name,
headers=headers,
parameters=parameters_str,
parameters_bytes=parameters_bytes_val,
retry_state=self._create_retry_state(),
received_at=received_at,
Expand Down
4 changes: 2 additions & 2 deletions clients/python/src/taskbroker_client/worker/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import hashlib
import hmac
import json
import logging
import random
import threading
Expand All @@ -10,7 +11,6 @@
from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, Union

import grpc
import orjson
from google.protobuf.message import Message
from sentry_protos.taskbroker.v1.taskbroker_pb2 import (
FetchNextTask,
Expand Down Expand Up @@ -130,7 +130,7 @@ def parse_rpc_secret_list(rpc_secret: str | None) -> list[str] | None:
return None

# Try to parse the provided secret
parsed = orjson.loads(rpc_secret)
parsed = json.loads(rpc_secret)

if not isinstance(parsed, list) or len(parsed) == 0:
# If the secret string is not a list with at least one element, it is invalid
Expand Down
24 changes: 4 additions & 20 deletions clients/python/src/taskbroker_client/worker/workerchild.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import base64
import contextlib
import logging
import queue
Expand All @@ -15,7 +14,6 @@

# XXX: Don't import any modules that will import django here; do those imports within child_process
import msgpack
import orjson
import sentry_sdk
import zstandard as zstd
from arroyo.backends.abstract import ProducerFuture
Expand Down Expand Up @@ -94,24 +92,10 @@ def load_parameters(activation: TaskActivation) -> dict[str, Any]:
headers = dict(activation.headers)
compression_type = headers.get("compression-type", None)

# Prefer new msgpack field
if activation.parameters_bytes:
data = activation.parameters_bytes
if compression_type == CompressionType.ZSTD.value:
data = zstd.decompress(data)
return msgpack.unpackb(data, raw=False)

# Legacy JSON fallback
data_str = activation.parameters
if not compression_type or compression_type == CompressionType.PLAINTEXT.value:
return orjson.loads(data_str)
elif compression_type == CompressionType.ZSTD.value:
return orjson.loads(zstd.decompress(base64.b64decode(data_str)))
else:
logger.error(
"Unsupported compression type: %s. Continuing with plaintext.", compression_type
)
return orjson.loads(data_str)
data = activation.parameters_bytes
if compression_type == CompressionType.ZSTD.value:
data = zstd.decompress(data)
return msgpack.unpackb(data, raw=False)

Comment on lines +96 to 99

This comment was marked as spam.


def status_name(status: TaskActivationStatus.ValueType) -> str:
Expand Down
3 changes: 2 additions & 1 deletion clients/python/tests/test_app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import msgpack
import pytest
from sentry_protos.taskbroker.v1.taskbroker_pb2 import TaskActivation

Expand Down Expand Up @@ -42,7 +43,7 @@ def test_should_attempt_at_most_once() -> None:
id="111",
taskname="examples.simple_task",
namespace="examples",
parameters='{"args": [], "kwargs": {}}',
parameters_bytes=msgpack.packb({"args": [], "kwargs": {}}, use_bin_type=True),
processing_deadline_duration=2,
)
at_most = StubAtMostOnce()
Expand Down
10 changes: 2 additions & 8 deletions clients/python/tests/test_registry.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import base64
from concurrent.futures import Future
from unittest.mock import Mock

import msgpack
import orjson
import pytest
import zstandard as zstd
from sentry_protos.taskbroker.v1.taskbroker_pb2 import (
Expand Down Expand Up @@ -179,9 +177,7 @@ def simple_task_with_compression(param: str) -> None:
actual_params = msgpack.unpackb(decompressed_data, raw=False)

assert actual_params == expected_params

legacy_decompressed = zstd.decompress(base64.b64decode(activation.parameters.encode("utf-8")))
assert orjson.loads(legacy_decompressed) == expected_params
assert activation.parameters == ""


def test_namespace_send_task_with_auto_compression() -> None:
Expand Down Expand Up @@ -211,9 +207,7 @@ def simple_task_with_compression(param: str) -> None:
actual_params = msgpack.unpackb(decompressed_data, raw=False)

assert actual_params == expected_params

legacy_decompressed = zstd.decompress(base64.b64decode(activation.parameters.encode("utf-8")))
assert orjson.loads(legacy_decompressed) == expected_params
assert activation.parameters == ""


def test_namespace_send_task_with_retry() -> None:
Expand Down
8 changes: 3 additions & 5 deletions clients/python/tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from unittest.mock import patch

import msgpack
import orjson
import pytest
import sentry_sdk
from sentry_protos.taskbroker.v1.taskbroker_pb2 import (
Expand Down Expand Up @@ -81,7 +80,7 @@ def test_func(*args: Any, **kwargs: Any) -> None:
assert activation.expires == 10
expected_params = {"args": ["arg2"], "kwargs": {"org_id": 2}}
assert msgpack.unpackb(activation.parameters_bytes, raw=False) == expected_params
assert orjson.loads(activation.parameters) == expected_params
assert activation.parameters == ""


def test_apply_async_countdown(task_namespace: TaskNamespace) -> None:
Expand All @@ -102,7 +101,7 @@ def test_func(*args: Any, **kwargs: Any) -> None:
assert activation.delay == 600
expected_params = {"args": ["arg2"], "kwargs": {"org_id": 2}}
assert msgpack.unpackb(activation.parameters_bytes, raw=False) == expected_params
assert orjson.loads(activation.parameters) == expected_params
assert activation.parameters == ""


def test_delay_immediate_mode(task_namespace: TaskNamespace) -> None:
Expand Down Expand Up @@ -274,8 +273,7 @@ def with_parameters(one: str, two: int, org_id: int) -> None:
assert params["args"] == ["one", 22]
assert params["kwargs"] == {"org_id": 99}

json_params = orjson.loads(activation.parameters)
assert json_params == params
assert activation.parameters == ""


def test_create_activation_tracing(task_namespace: TaskNamespace) -> None:
Expand Down
Loading
Loading