diff --git a/clients/python/src/objectstore_client/__init__.py b/clients/python/src/objectstore_client/__init__.py index 41435258..25d81a6b 100644 --- a/clients/python/src/objectstore_client/__init__.py +++ b/clients/python/src/objectstore_client/__init__.py @@ -6,6 +6,7 @@ Session, Usecase, ) +from objectstore_client.many import Delete, Get, ManyResponse, Operation, Put from objectstore_client.metadata import ( Compression, ExpirationPolicy, @@ -22,6 +23,11 @@ "Session", "GetResponse", "RequestError", + "Put", + "Get", + "Delete", + "ManyResponse", + "Operation", "Compression", "ExpirationPolicy", "Metadata", diff --git a/clients/python/src/objectstore_client/client.py b/clients/python/src/objectstore_client/client.py index b1d647dc..c8fa0374 100644 --- a/clients/python/src/objectstore_client/client.py +++ b/clients/python/src/objectstore_client/client.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Mapping, Sequence +from collections.abc import Iterable, Iterator, Mapping, Sequence from dataclasses import asdict, dataclass from io import BytesIO from typing import IO, Any, Literal, NamedTuple, cast @@ -248,6 +248,10 @@ def _make_url(self, key: str | None, full: bool = False) -> str: return f"http://{self._pool.host}:{self._pool.port}{path}" return path + def _make_batch_url(self) -> str: + relative_path = f"/v1/objects:batch/{self._usecase.name}/{self._scope}/" + return self._base_path.rstrip("/") + relative_path + def put( self, contents: bytes | IO[bytes], @@ -386,6 +390,40 @@ def get( return GetResponse(metadata, stream) + def many( + self, + operations: Iterable[Any], + concurrency: int = 3, + ) -> Iterator[Any]: + """Execute multiple get, put, and delete operations as optimized batch requests. + + Operations are automatically batched when possible, with oversized puts (> 1 MB) + or streaming (IO[bytes]) puts routed to the individual endpoint. + + Args: + operations: An iterable of :class:`Put`, :class:`Get`, or :class:`Delete` + instances. Generators are accepted and consumed exactly once. + concurrency: Maximum number of concurrent HTTP requests. Defaults to ``3``. + Set to ``1`` for sequential execution with no thread pool. Must be >= 1. + + Returns: + An iterator of :class:`ManyResponse`. With ``concurrency=1`` results arrive + in input order. With ``concurrency > 1`` results arrive in completion order. + + Raises: + ValueError: If ``concurrency`` is <= 0. + + Example:: + + from objectstore_client import Put, Get, Delete + + for result in session.many([Put(b"hello", key="k1"), Get("k2")]): + print(result.key, result.response) + """ + from objectstore_client.many import execute_many + + return execute_many(self, operations, concurrency=concurrency) + def object_url(self, key: str) -> str: """ Generates a GET url to the object with the given `key`. diff --git a/clients/python/src/objectstore_client/many.py b/clients/python/src/objectstore_client/many.py new file mode 100644 index 00000000..c2e4a8f7 --- /dev/null +++ b/clients/python/src/objectstore_client/many.py @@ -0,0 +1,552 @@ +"""Batch operations API for executing multiple get/put/delete operations.""" + +from __future__ import annotations + +from collections.abc import Iterable, Iterator, Sequence +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from io import BytesIO +from typing import IO, TYPE_CHECKING, Literal, NamedTuple +from urllib.parse import quote, unquote + +import zstandard + +from objectstore_client.metadata import ( + HEADER_EXPIRATION, + HEADER_META_PREFIX, + HEADER_ORIGIN, + Compression, + ExpirationPolicy, + Metadata, + format_expiration, +) +from objectstore_client.multipart import ( + RequestPart, + ResponsePart, + encode_multipart, + iter_multipart_response, +) + +if TYPE_CHECKING: + from objectstore_client.client import GetResponse, RequestError, Session + +# --------------------------------------------------------------------------- +# Constants (matching Rust client) +# --------------------------------------------------------------------------- + +MAX_BATCH_OPS: int = 1000 +"""Maximum number of operations per batch request.""" + +MAX_BATCH_PART_SIZE: int = 1024 * 1024 # 1 MB +"""Maximum body size for a single part in a batch request.""" + +MAX_BATCH_BODY_SIZE: int = 100 * 1024 * 1024 # 100 MB +"""Maximum total body size for a single batch request.""" + +MAX_BATCH_CONCURRENCY: int = 3 +"""Default maximum number of concurrent batch/individual HTTP requests.""" + +# --------------------------------------------------------------------------- +# Batch protocol header constants +# --------------------------------------------------------------------------- + +HEADER_BATCH_OP_KIND = "x-sn-batch-operation-kind" +HEADER_BATCH_OP_KEY = "x-sn-batch-operation-key" +HEADER_BATCH_OP_INDEX = "x-sn-batch-operation-index" +HEADER_BATCH_OP_STATUS = "x-sn-batch-operation-status" + +# --------------------------------------------------------------------------- +# Operation types +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class Put: + """A put operation to enqueue in a batch.""" + + contents: bytes | IO[bytes] + key: str | None = None + compression: Compression | Literal["none"] | None = None + content_type: str | None = None + metadata: dict[str, str] | None = None + expiration_policy: ExpirationPolicy | None = None + origin: str | None = None + + +@dataclass(frozen=True) +class Get: + """A get operation to enqueue in a batch.""" + + key: str + + +@dataclass(frozen=True) +class Delete: + """A delete operation to enqueue in a batch.""" + + key: str + + +Operation = Put | Get | Delete + +# --------------------------------------------------------------------------- +# Result type +# --------------------------------------------------------------------------- + + +class ManyResponse(NamedTuple): + """Result for a single operation in a batch. + + The ``key`` is always the object key (for puts: the server-assigned key). + The ``response`` is: + - ``GetResponse`` for a successful get + - ``RequestError`` for a per-operation failure (not raised) + - ``None`` for a successful put, successful delete, or get-not-found (404) + """ + + key: str + response: GetResponse | RequestError | None + + +# --------------------------------------------------------------------------- +# Internal: prepared put +# --------------------------------------------------------------------------- + + +@dataclass +class _PreparedPut: + """A Put operation with body materialized, compressed, and headers built.""" + + key: str | None + body: bytes + headers: dict[str, str] + + +def _prepare_put( + op: Put, + default_compression: Compression | Literal["none"], + default_expiration: ExpirationPolicy | None, +) -> _PreparedPut: + """Materialize a Put's body (compress if needed) and build metadata headers.""" + if isinstance(op.contents, bytes): + raw = op.contents + else: + raw = op.contents.read() + + compression = op.compression or default_compression + headers: dict[str, str] = {} + + if compression == "zstd": + cctx = zstandard.ZstdCompressor() + body = cctx.compress(raw) + headers["Content-Encoding"] = "zstd" + else: + body = raw + + if op.content_type: + headers["Content-Type"] = op.content_type + + expiration = op.expiration_policy or default_expiration + if expiration: + headers[HEADER_EXPIRATION] = format_expiration(expiration) + + if op.origin: + headers[HEADER_ORIGIN] = op.origin + + if op.metadata: + for k, v in op.metadata.items(): + headers[f"{HEADER_META_PREFIX}{k}"] = v + + return _PreparedPut(key=op.key, body=body, headers=headers) + + +# --------------------------------------------------------------------------- +# Classification +# --------------------------------------------------------------------------- + + +def _classify( + op: Operation, body_size: int +) -> tuple[Literal["batchable", "individual"], int]: + """Classify an operation as batchable or individual. + + Get and Delete are always batchable (size 0). + Put is individual if the compressed body exceeds MAX_BATCH_PART_SIZE. + """ + if isinstance(op, (Get, Delete)): + return "batchable", 0 + # Put + if body_size > MAX_BATCH_PART_SIZE: + return "individual", body_size + return "batchable", body_size + + +# --------------------------------------------------------------------------- +# Batching +# --------------------------------------------------------------------------- + +# A classified op ready for dispatch: (original_index, operation, prepared_put_or_none) +_ClassifiedOp = tuple[int, Operation, _PreparedPut | None] + + +def _iter_batches( + ops: Sequence[tuple[int, Operation, _PreparedPut | None, int]], +) -> Iterator[list[_ClassifiedOp]]: + """Split batchable operations into batches respecting count and size limits. + + Each element is (original_index, operation, prepared, body_size). + Yields lists of (original_index, operation, prepared) with the size dropped. + """ + remaining = iter(ops) + pending: tuple[int, Operation, _PreparedPut | None, int] | None = None + + while True: + batch: list[_ClassifiedOp] = [] + batch_size = 0 + + if pending is not None: + idx, op, prepared, op_size = pending + batch.append((idx, op, prepared)) + batch_size += op_size + pending = None + + for idx, op, prepared, op_size in remaining: + if len(batch) >= MAX_BATCH_OPS: + pending = (idx, op, prepared, op_size) + break + if batch and batch_size + op_size > MAX_BATCH_BODY_SIZE: + pending = (idx, op, prepared, op_size) + break + batch.append((idx, op, prepared)) + batch_size += op_size + + if not batch: + return + + yield batch + + +# --------------------------------------------------------------------------- +# Batch request/response +# --------------------------------------------------------------------------- + + +def _build_batch_parts( + ops: Sequence[tuple[int, Operation, _PreparedPut | None]], +) -> Iterator[RequestPart]: + """Yield multipart request parts from classified operations.""" + for _idx, op, prepared in ops: + headers: dict[str, str] = {} + + if isinstance(op, Get): + headers[HEADER_BATCH_OP_KIND] = "get" + headers[HEADER_BATCH_OP_KEY] = quote(op.key, safe="") + yield RequestPart(headers=headers, body=b"") + + elif isinstance(op, Delete): + headers[HEADER_BATCH_OP_KIND] = "delete" + headers[HEADER_BATCH_OP_KEY] = quote(op.key, safe="") + yield RequestPart(headers=headers, body=b"") + + elif isinstance(op, Put): + assert prepared is not None + headers[HEADER_BATCH_OP_KIND] = "insert" + if prepared.key is not None: + headers[HEADER_BATCH_OP_KEY] = quote(prepared.key, safe="") + headers.update(prepared.headers) + yield RequestPart(headers=headers, body=prepared.body) + + +def _parse_batch_response( + response_parts: Iterable[ResponsePart], + ops: Sequence[tuple[int, Operation, _PreparedPut | None]], +) -> Iterator[tuple[int, ManyResponse]]: + """Stream multipart response parts into indexed ManyResponse tuples.""" + from objectstore_client.client import GetResponse, RequestError + + # Build a map from batch-local index to (original_index, operation, prepared) + index_map = {batch_idx: entry for batch_idx, entry in enumerate(ops)} + + seen_indices: set[int] = set() + + for part in response_parts: + part_headers = part.headers + + # Parse operation index + index_str = part_headers.get(HEADER_BATCH_OP_INDEX) + if index_str is None: + continue + batch_idx = int(index_str) + seen_indices.add(batch_idx) + + entry = index_map.get(batch_idx) + if entry is None: + continue + original_idx, op, prepared = entry + + # Parse status + status_str = part_headers.get(HEADER_BATCH_OP_STATUS, "") + status_code_str = status_str.split(" ", 1)[0] if status_str else "0" + status_code = int(status_code_str) + + # Parse key from response + encoded_key = part_headers.get(HEADER_BATCH_OP_KEY) + response_key = unquote(encoded_key) if encoded_key else None + + # Determine the key to use in the result + if isinstance(op, (Get, Delete)): + result_key = response_key or op.key + else: + result_key = ( + response_key or (prepared.key if prepared else None) or "" + ) + + # Handle errors (status >= 400, except 404 for GET) + is_get_not_found = isinstance(op, Get) and status_code == 404 + if status_code >= 400 and not is_get_not_found: + error_message = part.body.decode("utf-8", "replace") + error = RequestError( + f"Batch operation failed with status {status_code}", + status_code, + error_message, + ) + yield (original_idx, ManyResponse(key=result_key, response=error)) + continue + + # Handle GET not found + if is_get_not_found: + yield (original_idx, ManyResponse(key=result_key, response=None)) + continue + + # Successful operations + if isinstance(op, Get): + metadata = Metadata.from_headers(part_headers) + payload = BytesIO(part.body) + + # Decompress if needed + if metadata.compression == "zstd": + dctx = zstandard.ZstdDecompressor() + decompressed = dctx.decompress(part.body) + payload = BytesIO(decompressed) + metadata.compression = None + + response = GetResponse(metadata=metadata, payload=payload) + yield (original_idx, ManyResponse(key=result_key, response=response)) + + elif isinstance(op, Put): + yield (original_idx, ManyResponse(key=result_key, response=None)) + + elif isinstance(op, Delete): + yield (original_idx, ManyResponse(key=result_key, response=None)) + + # After all parts arrive, report any operations the server didn't respond to. + for batch_idx, entry in index_map.items(): + if batch_idx not in seen_indices: + original_idx, op, prepared = entry + if isinstance(op, (Get, Delete)): + key = op.key + else: + key = (prepared.key if prepared else None) or "" + + error = RequestError( + f"Server did not return a response for operation at index {batch_idx}", + 0, + "", + ) + yield (original_idx, ManyResponse(key=key, response=error)) + + +def _send_batch( + session: Session, + ops: Sequence[tuple[int, Operation, _PreparedPut | None]], +) -> Iterator[tuple[int, ManyResponse]]: + """Send a batch of operations as a single multipart request.""" + from objectstore_client.client import RequestError + + parts = _build_batch_parts(ops) + content_type, body_iter = encode_multipart(parts) + + batch_url = session._make_batch_url() + headers = session._make_headers() + headers["Content-Type"] = content_type + + try: + response = session._pool.request( + "POST", + batch_url, + body=body_iter, + headers=headers, + preload_content=False, + ) + + if response.status >= 400: + error_body = response.read().decode("utf-8", "replace") + error = RequestError( + f"Batch request failed with status {response.status}", + response.status, + error_body, + ) + yield from _batch_level_error(ops, error) + return + + response_content_type = response.headers.get("Content-Type", "") + yield from _parse_batch_response( + iter_multipart_response(response_content_type, response.stream(65536)), + ops, + ) + except RequestError: + raise + except Exception as exc: + error = RequestError(f"Batch request failed: {exc}", 0, str(exc)) + yield from _batch_level_error(ops, error) + + +def _batch_level_error( + ops: Sequence[tuple[int, Operation, _PreparedPut | None]], + error: RequestError, +) -> list[tuple[int, ManyResponse]]: + """Produce error results for all operations when the entire batch fails.""" + results: list[tuple[int, ManyResponse]] = [] + for original_idx, op, prepared in ops: + if isinstance(op, (Get, Delete)): + key = op.key + else: + key = (prepared.key if prepared else None) or "" + results.append((original_idx, ManyResponse(key=key, response=error))) + return results + + +# --------------------------------------------------------------------------- +# Individual execution (for oversized ops) +# --------------------------------------------------------------------------- + + +def _execute_individual( + session: Session, original_idx: int, op: Operation, prepared: _PreparedPut | None +) -> tuple[int, ManyResponse]: + """Execute a single operation via the non-batch endpoint.""" + from objectstore_client.client import RequestError + + try: + if isinstance(op, Get): + response = session.get(op.key) + return (original_idx, ManyResponse(key=op.key, response=response)) + + elif isinstance(op, Delete): + session.delete(op.key) + return (original_idx, ManyResponse(key=op.key, response=None)) + + elif isinstance(op, Put): + if prepared is not None: + # Already compressed: pass compression="none" to avoid re-compressing. + key = session.put( + prepared.body, + key=prepared.key, + compression="none", + content_type=op.content_type, + metadata=op.metadata, + expiration_policy=op.expiration_policy, + origin=op.origin, + ) + else: + # IO[bytes] body: let session.put() handle compression normally. + key = session.put( + op.contents, + key=op.key, + compression=op.compression, + content_type=op.content_type, + metadata=op.metadata, + expiration_policy=op.expiration_policy, + origin=op.origin, + ) + return (original_idx, ManyResponse(key=key, response=None)) + + except RequestError as exc: + if isinstance(op, (Get, Delete)): + key = op.key + else: + # op is Put: use prepared.key if available, else op.key (for IO[bytes] path) + key = (prepared.key if prepared is not None else op.key) or "" + return (original_idx, ManyResponse(key=key, response=exc)) + + +# --------------------------------------------------------------------------- +# Orchestration +# --------------------------------------------------------------------------- + + +def execute_many( + session: Session, + operations: Iterable[Operation], + concurrency: int = MAX_BATCH_CONCURRENCY, +) -> Iterator[ManyResponse]: + """Execute multiple operations, batching where possible. + + Args: + session: The session to execute operations against. + operations: The operations to execute. Any iterable is accepted, + including generators; it is consumed exactly once. + concurrency: Max parallel HTTP requests. Default is 3. + Set to 1 for sequential execution (no thread pool). + Must be >= 1. + + Returns: + An iterator of ManyResponse. With concurrency=1 results arrive in + input order. With concurrency > 1 results arrive in completion order. + """ + if concurrency <= 0: + raise ValueError(f"concurrency must be >= 1, got {concurrency}") + return _execute_many_gen(session, operations, concurrency) + + +def _execute_many_gen( + session: Session, + operations: Iterable[Operation], + concurrency: int, +) -> Iterator[ManyResponse]: + default_compression = session._usecase._compression + default_expiration = session._usecase._expiration_policy + + # Step 1: Consume the iterable once, preparing and classifying as we go. + batchable: list[tuple[int, Operation, _PreparedPut | None, int]] = [] + individual: list[_ClassifiedOp] = [] + + for idx, op in enumerate(operations): + if isinstance(op, Put): + if isinstance(op.contents, bytes): + prepared = _prepare_put(op, default_compression, default_expiration) + kind, size = _classify(op, body_size=len(prepared.body)) + if kind == "individual": + individual.append((idx, op, prepared)) + else: + batchable.append((idx, op, prepared, size)) + else: + # IO[bytes] bodies are always sent individually to avoid eager reading. + individual.append((idx, op, None)) + else: + batchable.append((idx, op, None, 0)) + + # Step 2: Partition batchable ops into batch chunks. + batch_chunks = list(_iter_batches(batchable)) + + def run_individual(entry: _ClassifiedOp) -> list[tuple[int, ManyResponse]]: + idx, op, prepared = entry + return [_execute_individual(session, idx, op, prepared)] + + # Step 3: Execute and yield results as they arrive. + if concurrency == 1: + for chunk in batch_chunks: + for _, result in _send_batch(session, chunk): + yield result + for entry in individual: + for _, result in run_individual(entry): + yield result + else: + with ThreadPoolExecutor(max_workers=concurrency) as executor: + futures = [ + executor.submit(_send_batch, session, chunk) for chunk in batch_chunks + ] + futures += [executor.submit(run_individual, entry) for entry in individual] + for future in as_completed(futures): + for _, result in future.result(): + yield result diff --git a/clients/python/src/objectstore_client/multipart.py b/clients/python/src/objectstore_client/multipart.py new file mode 100644 index 00000000..eeaaeb5f --- /dev/null +++ b/clients/python/src/objectstore_client/multipart.py @@ -0,0 +1,150 @@ +"""Custom multipart encoder/decoder for batch requests. + +urllib3's encode_multipart_formdata doesn't support per-part custom headers, +which we need for the batch protocol's operation-kind/key/index headers. +""" + +from __future__ import annotations + +import re +import secrets +from collections.abc import Iterable, Iterator +from dataclasses import dataclass +from typing import IO + +# 64kB matches urllib3's default chunk size +_CHUNK_SIZE = 64 * 1024 + + +@dataclass +class RequestPart: + """A single part in a multipart request.""" + + headers: dict[str, str] + body: bytes | IO[bytes] + + +@dataclass +class ResponsePart: + """A single part parsed from a multipart response.""" + + headers: dict[str, str] + body: bytes + + +def encode_multipart(parts: Iterable[RequestPart]) -> tuple[str, Iterator[bytes]]: + """Encode parts into a multipart/form-data body. + + Returns (content_type_header, body_iterator). The iterator yields chunks + lazily; urllib3 will send them via chunked transfer encoding. + """ + boundary = f"os-boundary-{secrets.token_hex(16)}" + + def _generate() -> Iterator[bytes]: + for part in parts: + yield f"--{boundary}\r\n".encode() + yield b"content-disposition: form-data; name=part\r\n" + for name, value in part.headers.items(): + yield f"{name}: {value}\r\n".encode() + yield b"\r\n" + if isinstance(part.body, bytes): + yield part.body + else: + while True: + chunk = part.body.read(_CHUNK_SIZE) + if not chunk: + break + yield chunk + yield b"\r\n" + yield f"--{boundary}--".encode() + + content_type = f'multipart/form-data; boundary="{boundary}"' + return content_type, _generate() + + +def _extract_boundary(content_type: str) -> str: + """Extract the boundary string from a Content-Type header.""" + match = re.search(r'boundary="?([^";]+)"?', content_type) + if not match: + raise ValueError(f"No boundary found in Content-Type: {content_type}") + return match.group(1) + + +def _parse_part(data: bytes) -> ResponsePart: + """Parse a single multipart part from its raw bytes (headers + body).""" + header_end = data.find(b"\r\n\r\n") + if header_end == -1: + header_bytes = data + part_body = b"" + else: + header_bytes = data[:header_end] + part_body = data[header_end + 4 :] + + headers: dict[str, str] = {} + for line in header_bytes.split(b"\r\n"): + if not line: + continue + colon_idx = line.find(b": ") + if colon_idx == -1: + continue + name = line[:colon_idx].decode("ascii").lower() + value = line[colon_idx + 2 :].decode("utf-8") + headers[name] = value + + return ResponsePart(headers=headers, body=part_body) + + +def iter_multipart_response( + content_type: str, chunks: Iterable[bytes] +) -> Iterator[ResponsePart]: + """Parse a streaming multipart response, yielding one ResponsePart at a time. + + Accepts an iterable of byte chunks (e.g. from urllib3's response.stream()). + Parts are yielded as soon as their full content has been received; the caller + does not need to wait for the entire response body. + """ + boundary = _extract_boundary(content_type) + opening = f"--{boundary}\r\n".encode() + # Between consecutive parts the delimiter is \r\n--boundary; the \r\n + # belongs to the trailer of the preceding part, not the next part's header. + delimiter = f"\r\n--{boundary}".encode() + + buf = bytearray() + started = False + + for chunk in chunks: + buf.extend(chunk) + + if not started: + pos = buf.find(opening) + if pos == -1: + # Discard everything except the last len(opening)-1 bytes — + # only that suffix could form a partial match across the next chunk. + del buf[: -(len(opening) - 1)] + continue + del buf[: pos + len(opening)] + started = True + + while True: + pos = buf.find(delimiter) + if pos == -1: + break + after_delim = pos + len(delimiter) + # Need at least 2 bytes after the delimiter to distinguish + # \r\n (next part follows) from -- (closing boundary). + if len(buf) < after_delim + 2: + break + yield _parse_part(bytes(buf[:pos])) + suffix = bytes(buf[after_delim : after_delim + 2]) + if suffix == b"--": + return + if suffix != b"\r\n": + raise ValueError( + f"Malformed multipart body: unexpected bytes {suffix!r}" + ) + del buf[: after_delim + 2] # consume delimiter + \r\n + + +def parse_multipart_response(content_type: str, body: bytes) -> list[ResponsePart]: + """Parse a multipart/form-data response body into parts.""" + return list(iter_multipart_response(content_type, [body])) diff --git a/clients/python/tests/test_e2e.py b/clients/python/tests/test_e2e.py index 9632d489..8cdb7129 100644 --- a/clients/python/tests/test_e2e.py +++ b/clients/python/tests/test_e2e.py @@ -12,9 +12,9 @@ import pytest import urllib3 import zstandard -from objectstore_client import Client, Usecase +from objectstore_client import Client, Delete, Get, ManyResponse, Put, Usecase from objectstore_client.auth import Permission, TokenGenerator -from objectstore_client.client import RequestError +from objectstore_client.client import GetResponse, RequestError from objectstore_client.metadata import TimeToLive from objectstore_client.scope import Scope @@ -342,3 +342,89 @@ def test_connect_timeout() -> None: with pytest.raises(urllib3.exceptions.MaxRetryError): session.put(b"test data", compression="zstd") + + +def test_many_full_cycle(server_url: str) -> None: + client = Client(server_url, token=TestTokenGenerator.get()) + usecase = Usecase("test-usecase", expiration_policy=TimeToLive(timedelta(days=1))) + session = client.session(usecase, org=42, project=2001) + + # Put and get in the same batch + results = list( + session.many( + [ + Put(b"many-data", key="many-k1", compression="none"), + Get("many-k1"), + ], + concurrency=1, + ) + ) + assert len(results) == 2 + put_result = results[0] + get_result = results[1] + + assert isinstance(put_result, ManyResponse) + assert put_result.key == "many-k1" + assert put_result.response is None + + assert isinstance(get_result, ManyResponse) + assert get_result.key == "many-k1" + assert isinstance(get_result.response, GetResponse) + assert get_result.response.payload.read() == b"many-data" + + # Delete in a subsequent batch + delete_results = list(session.many([Delete("many-k1")], concurrency=1)) + assert delete_results[0].response is None + + # Get after delete should return None (not found) + not_found_results = list(session.many([Get("many-k1")], concurrency=1)) + assert not_found_results[0].response is None + + +def test_many_get_not_found(server_url: str) -> None: + client = Client(server_url, token=TestTokenGenerator.get()) + usecase = Usecase("test-usecase", expiration_policy=TimeToLive(timedelta(days=1))) + session = client.session(usecase, org=42, project=2002) + + results = list(session.many([Get("nonexistent-key")], concurrency=1)) + assert len(results) == 1 + assert results[0].response is None + + +def test_many_mixed_sizes(server_url: str) -> None: + """Small puts go through batch, large puts go through the individual endpoint.""" + client = Client(server_url, token=TestTokenGenerator.get()) + usecase = Usecase( + "test-usecase", + compression="none", + expiration_policy=TimeToLive(timedelta(days=1)), + ) + session = client.session(usecase, org=42, project=2003) + + small_data = b"small" + large_data = b"x" * (1024 * 1024 + 1) # just over 1 MB + + results = list( + session.many( + [ + Put(small_data, key="mixed-small"), + Put(large_data, key="mixed-large"), + ], + concurrency=1, + ) + ) + assert len(results) == 2 + assert results[0].key == "mixed-small" + assert results[1].key == "mixed-large" + + # Verify contents + get_results = list( + session.many( + [Get("mixed-small"), Get("mixed-large")], + concurrency=1, + ) + ) + assert isinstance(get_results[0].response, GetResponse) + assert isinstance(get_results[1].response, GetResponse) + assert get_results[0].response.payload.read() == small_data + assert get_results[1].response.payload.read() == large_data diff --git a/clients/python/tests/test_many.py b/clients/python/tests/test_many.py new file mode 100644 index 00000000..ee432e62 --- /dev/null +++ b/clients/python/tests/test_many.py @@ -0,0 +1,502 @@ +import io +from collections.abc import Iterator +from unittest.mock import MagicMock + +import pytest +from objectstore_client.client import RequestError +from objectstore_client.many import ( + HEADER_BATCH_OP_INDEX, + HEADER_BATCH_OP_KEY, + HEADER_BATCH_OP_KIND, + HEADER_BATCH_OP_STATUS, + MAX_BATCH_BODY_SIZE, + MAX_BATCH_OPS, + MAX_BATCH_PART_SIZE, + Delete, + Get, + Put, + _classify, + _iter_batches, + _parse_batch_response, + _prepare_put, + _PreparedPut, + execute_many, +) +from objectstore_client.metadata import TimeToLive +from objectstore_client.multipart import ResponsePart + +# Shorthand for test ops lists: (batch_index, operation, prepared_put_or_none) +_Ops = list[tuple[int, Put | Get | Delete, _PreparedPut | None]] + +# --------------------------------------------------------------------------- +# _prepare_put +# --------------------------------------------------------------------------- + + +def test_prepare_put_bytes_no_compression() -> None: + op = Put(contents=b"hello world") + prepared = _prepare_put(op, default_compression="none", default_expiration=None) + assert prepared.body == b"hello world" + assert "Content-Encoding" not in prepared.headers + + +def test_prepare_put_bytes_with_zstd() -> None: + op = Put(contents=b"hello world") + prepared = _prepare_put(op, default_compression="zstd", default_expiration=None) + assert prepared.body != b"hello world" # compressed + assert prepared.headers["Content-Encoding"] == "zstd" + + +def test_prepare_put_io_materialized() -> None: + data = b"stream data" + op = Put(contents=io.BytesIO(data)) + prepared = _prepare_put(op, default_compression="none", default_expiration=None) + assert prepared.body == data + + +def test_prepare_put_explicit_compression_override() -> None: + op = Put(contents=b"data", compression="none") + prepared = _prepare_put(op, default_compression="zstd", default_expiration=None) + assert prepared.body == b"data" + assert "Content-Encoding" not in prepared.headers + + +def test_prepare_put_metadata_headers() -> None: + from datetime import timedelta + + op = Put( + contents=b"x", + key="k", + content_type="text/plain", + metadata={"custom-key": "custom-val"}, + expiration_policy=TimeToLive(timedelta(days=1)), + origin="127.0.0.1", + ) + prepared = _prepare_put(op, default_compression="none", default_expiration=None) + assert prepared.headers["Content-Type"] == "text/plain" + assert prepared.headers["x-snme-custom-key"] == "custom-val" + assert "x-sn-expiration" in prepared.headers + assert prepared.headers["x-sn-origin"] == "127.0.0.1" + + +def test_prepare_put_default_expiration() -> None: + from datetime import timedelta + + op = Put(contents=b"x") + default_exp = TimeToLive(timedelta(hours=1)) + prepared = _prepare_put( + op, default_compression="none", default_expiration=default_exp + ) + assert "x-sn-expiration" in prepared.headers + + +def test_prepare_put_explicit_expiration_overrides_default() -> None: + from datetime import timedelta + + explicit = TimeToLive(timedelta(days=7)) + default = TimeToLive(timedelta(hours=1)) + op = Put(contents=b"x", expiration_policy=explicit) + prepared = _prepare_put(op, default_compression="none", default_expiration=default) + assert "7 days" in prepared.headers["x-sn-expiration"] + + +# --------------------------------------------------------------------------- +# _classify +# --------------------------------------------------------------------------- + + +def test_classify_get_is_batchable() -> None: + kind, size = _classify(Get("key"), body_size=0) + assert kind == "batchable" + assert size == 0 + + +def test_classify_delete_is_batchable() -> None: + kind, size = _classify(Delete("key"), body_size=0) + assert kind == "batchable" + assert size == 0 + + +def test_classify_small_put_is_batchable() -> None: + kind, size = _classify(Put(b"x"), body_size=100) + assert kind == "batchable" + assert size == 100 + + +def test_classify_large_put_is_individual() -> None: + kind, size = _classify(Put(b"x"), body_size=MAX_BATCH_PART_SIZE + 1) + assert kind == "individual" + assert size == MAX_BATCH_PART_SIZE + 1 + + +def test_classify_put_at_exact_limit_is_batchable() -> None: + kind, _ = _classify(Put(b"x"), body_size=MAX_BATCH_PART_SIZE) + assert kind == "batchable" + + +# --------------------------------------------------------------------------- +# _iter_batches +# --------------------------------------------------------------------------- + + +def _batchable_op(i: int, size: int) -> tuple[int, Delete, None, int]: + """Create a dummy batchable op 4-tuple for _iter_batches tests.""" + return (i, Delete("k"), None, size) + + +def test_iter_batches_empty() -> None: + assert list(_iter_batches([])) == [] + + +def test_iter_batches_single_batch_at_count_limit() -> None: + ops = [_batchable_op(i, 1) for i in range(MAX_BATCH_OPS)] + batches = list(_iter_batches(ops)) + assert len(batches) == 1 + assert len(batches[0]) == MAX_BATCH_OPS + + +def test_iter_batches_splits_on_count() -> None: + ops = [_batchable_op(i, 1) for i in range(MAX_BATCH_OPS + 1)] + batches = list(_iter_batches(ops)) + assert len(batches) == 2 + assert len(batches[0]) == MAX_BATCH_OPS + assert len(batches[1]) == 1 + + +def test_iter_batches_exactly_at_size_limit() -> None: + op_size = 1024 * 1024 # 1 MB + count = MAX_BATCH_BODY_SIZE // op_size # 100 + ops = [_batchable_op(i, op_size) for i in range(count)] + batches = list(_iter_batches(ops)) + assert len(batches) == 1 + assert len(batches[0]) == count + + +def test_iter_batches_splits_on_size() -> None: + op_size = 1024 * 1024 # 1 MB + count = MAX_BATCH_BODY_SIZE // op_size + 1 # 101 + ops = [_batchable_op(i, op_size) for i in range(count)] + batches = list(_iter_batches(ops)) + assert len(batches) == 2 + assert len(batches[0]) == MAX_BATCH_BODY_SIZE // op_size + assert len(batches[1]) == 1 + + +def test_iter_batches_size_hits_before_count() -> None: + op_size = 600 * 1024 # 600 KB + ops = [_batchable_op(i, op_size) for i in range(200)] + batches = list(_iter_batches(ops)) + per_batch = MAX_BATCH_BODY_SIZE // op_size + assert len(batches) > 1 + for batch in batches[:-1]: + assert len(batch) == per_batch + + +# --------------------------------------------------------------------------- +# _parse_batch_response +# --------------------------------------------------------------------------- + + +def _make_response_part( + index: int, + status: str, + key: str, + kind: str = "get", + body: bytes = b"", + extra_headers: dict[str, str] | None = None, +) -> ResponsePart: + headers = { + HEADER_BATCH_OP_INDEX: str(index), + HEADER_BATCH_OP_STATUS: status, + HEADER_BATCH_OP_KEY: key, + HEADER_BATCH_OP_KIND: kind, + } + if extra_headers: + headers.update(extra_headers) + return ResponsePart(headers=headers, body=body) + + +def test_parse_successful_get() -> None: + parts = [_make_response_part(0, "200 OK", "k1", body=b"payload")] + ops: _Ops = [(0, Get("k1"), None)] + results = list(_parse_batch_response(parts, ops)) + assert len(results) == 1 + idx, result = results[0] + assert idx == 0 + assert result.key == "k1" + # response should be a GetResponse + from objectstore_client.client import GetResponse + + assert isinstance(result.response, GetResponse) + assert result.response.payload.read() == b"payload" + + +def test_parse_get_not_found() -> None: + parts = [_make_response_part(0, "404 Not Found", "k1")] + ops: _Ops = [(0, Get("k1"), None)] + results = list(_parse_batch_response(parts, ops)) + assert len(results) == 1 + _, result = results[0] + assert result.response is None + + +def test_parse_successful_put() -> None: + from objectstore_client.many import _PreparedPut + + prepared = _PreparedPut(key="k1", body=b"", headers={}) + parts = [_make_response_part(0, "200 OK", "k1", kind="insert")] + ops = [(0, Put(b"data", key="k1"), prepared)] + results = list(_parse_batch_response(parts, ops)) + assert len(results) == 1 + _, result = results[0] + assert result.key == "k1" + assert result.response is None + + +def test_parse_successful_delete() -> None: + parts = [_make_response_part(0, "200 OK", "k1", kind="delete")] + ops: _Ops = [(0, Delete("k1"), None)] + results = list(_parse_batch_response(parts, ops)) + assert len(results) == 1 + _, result = results[0] + assert result.key == "k1" + assert result.response is None + + +def test_parse_per_operation_error() -> None: + parts = [ + _make_response_part( + 0, "500 Internal Server Error", "k1", body=b"something broke" + ) + ] + ops: _Ops = [(0, Get("k1"), None)] + results = list(_parse_batch_response(parts, ops)) + assert len(results) == 1 + _, result = results[0] + assert isinstance(result.response, RequestError) + assert result.response.status == 500 + + +def test_parse_mixed_operations() -> None: + from objectstore_client.many import _PreparedPut + + prepared = _PreparedPut(key=None, body=b"", headers={}) + parts = [ + _make_response_part(0, "200 OK", "assigned-key", kind="insert"), + _make_response_part(1, "200 OK", "k2", kind="get", body=b"data"), + _make_response_part(2, "200 OK", "k3", kind="delete"), + ] + ops: _Ops = [ + (5, Put(b"data"), prepared), + (10, Get("k2"), None), + (15, Delete("k3"), None), + ] + results = list(_parse_batch_response(parts, ops)) + assert len(results) == 3 + + result_map = {idx: result for idx, result in results} + assert result_map[5].key == "assigned-key" + assert result_map[5].response is None + assert result_map[10].key == "k2" + assert result_map[15].key == "k3" + + +def test_parse_missing_response_part() -> None: + # Server returns response for index 0 but not index 1 + parts = [_make_response_part(0, "200 OK", "k1")] + ops: _Ops = [ + (0, Get("k1"), None), + (1, Get("k2"), None), + ] + results = list(_parse_batch_response(parts, ops)) + assert len(results) == 2 + result_map = {idx: result for idx, result in results} + assert result_map[0].response is not None # success + assert isinstance(result_map[1].response, RequestError) # missing + + +# --------------------------------------------------------------------------- +# execute_many +# --------------------------------------------------------------------------- + + +def _make_mock_session( + batch_response_parts: list[ResponsePart] | None = None, +) -> MagicMock: + """Create a mock Session for testing execute_many.""" + session = MagicMock() + session._usecase.name = "test-usecase" + session._usecase._compression = "none" + session._usecase._expiration_policy = None + session._make_batch_url.return_value = "/v1/objects:batch/test-usecase/org=1/" + session._make_headers.return_value = {} + + if batch_response_parts is not None: + from objectstore_client.multipart import RequestPart, encode_multipart + + # Build a fake multipart response body + fake_parts = [ + RequestPart(headers=p.headers, body=p.body) for p in batch_response_parts + ] + content_type, body_iter = encode_multipart(fake_parts) + body = b"".join(body_iter) + + mock_response = MagicMock() + mock_response.status = 200 + mock_response.headers = {"Content-Type": content_type} + # _send_batch uses preload_content=False and response.stream(chunk_size) + mock_response.stream.return_value = iter([body]) + session._pool.request.return_value = mock_response + + return session + + +def test_execute_many_empty() -> None: + session = _make_mock_session() + results = list(execute_many(session, [])) + assert results == [] + + +def test_execute_many_accepts_generator() -> None: + """execute_many should accept a generator (consumed only once).""" + response_parts = [_make_response_part(0, "200 OK", "k1", kind="delete")] + session = _make_mock_session(response_parts) + + def ops_gen() -> Iterator[Delete]: + yield Delete("k1") + + results = list(execute_many(session, ops_gen(), concurrency=1)) + assert len(results) == 1 + assert results[0].key == "k1" + + +def test_execute_many_returns_iterator() -> None: + """execute_many should return an iterator, not a list.""" + response_parts = [_make_response_part(0, "200 OK", "k1", kind="delete")] + session = _make_mock_session(response_parts) + result = execute_many(session, [Delete("k1")], concurrency=1) + assert hasattr(result, "__iter__") and hasattr(result, "__next__") + + +def test_execute_many_concurrency_zero_raises() -> None: + session = _make_mock_session() + with pytest.raises(ValueError, match="concurrency must be >= 1"): + execute_many(session, [Get("k")], concurrency=0) + + +def test_execute_many_concurrency_negative_raises() -> None: + session = _make_mock_session() + with pytest.raises(ValueError, match="concurrency must be >= 1"): + execute_many(session, [Get("k")], concurrency=-1) + + +def test_execute_many_single_batch_sequential() -> None: + """Test a simple batch with concurrency=1.""" + response_parts = [ + _make_response_part(0, "200 OK", "k1", kind="get", body=b"hello"), + _make_response_part(1, "200 OK", "k2", kind="delete"), + ] + session = _make_mock_session(response_parts) + + results = list(execute_many(session, [Get("k1"), Delete("k2")], concurrency=1)) + assert len(results) == 2 + assert results[0].key == "k1" + assert results[1].key == "k2" + assert results[1].response is None + + +def test_execute_many_preserves_order_sequential() -> None: + """With concurrency=1, results arrive in input order.""" + response_parts = [ + _make_response_part(0, "200 OK", "k3", kind="delete"), + _make_response_part(1, "200 OK", "k1", kind="get", body=b"data"), + _make_response_part(2, "200 OK", "k2", kind="delete"), + ] + session = _make_mock_session(response_parts) + + ops: list[Put | Get | Delete] = [Delete("k3"), Get("k1"), Delete("k2")] + results = list(execute_many(session, ops, concurrency=1)) + assert results[0].key == "k3" + assert results[1].key == "k1" + assert results[2].key == "k2" + + +def test_execute_many_individual_put() -> None: + """Large puts should be routed to individual endpoints.""" + large_body = b"x" * (MAX_BATCH_PART_SIZE + 1) + session = _make_mock_session() + session.put.return_value = "assigned-key" + + results = list( + execute_many(session, [Put(large_body, compression="none")], concurrency=1) + ) + assert len(results) == 1 + assert results[0].key == "assigned-key" + assert results[0].response is None + # Should have called session.put, not session._pool.request for batch + session.put.assert_called_once() + + +def test_execute_many_mixed_batch_and_individual() -> None: + """Mix of small (batchable) and large (individual) operations.""" + large_body = b"x" * (MAX_BATCH_PART_SIZE + 1) + + # Response for the batch (index 0 = Get) + response_parts = [_make_response_part(0, "200 OK", "k1", kind="get", body=b"hi")] + session = _make_mock_session(response_parts) + session.put.return_value = "big-key" + + ops: list[Put | Get | Delete] = [ + Get("k1"), + Put(large_body, key="big-key", compression="none"), + ] + results = list(execute_many(session, ops, concurrency=1)) + assert len(results) == 2 + assert results[0].key == "k1" + assert results[1].key == "big-key" + + +def test_execute_many_io_put_goes_individual() -> None: + """IO[bytes] puts should be routed individually without eager reading.""" + session = _make_mock_session() + session.put.return_value = "assigned-key" + + results = list( + execute_many(session, [Put(io.BytesIO(b"data"), key="my-key")], concurrency=1) + ) + assert len(results) == 1 + assert results[0].key == "assigned-key" + assert results[0].response is None + session.put.assert_called_once() + session._pool.request.assert_not_called() + + +def test_execute_many_io_put_passes_compression_to_session() -> None: + """IO[bytes] puts should forward their compression setting to session.put().""" + session = _make_mock_session() + session.put.return_value = "assigned-key" + + list( + execute_many( + session, + [Put(io.BytesIO(b"data"), key="k", compression="zstd")], + concurrency=1, + ) + ) + _, kwargs = session.put.call_args + assert kwargs.get("compression") == "zstd" + + +def test_execute_many_io_put_mixed_with_batch() -> None: + """IO[bytes] put goes individual while a concurrent Get is batched.""" + response_parts = [_make_response_part(0, "200 OK", "k1", kind="get", body=b"hi")] + session = _make_mock_session(response_parts) + session.put.return_value = "io-key" + + ops: list[Put | Get | Delete] = [Get("k1"), Put(io.BytesIO(b"data"), key="io-key")] + results = list(execute_many(session, ops, concurrency=1)) + assert len(results) == 2 + assert results[0].key == "k1" + assert results[1].key == "io-key" + session.put.assert_called_once() + session._pool.request.assert_called_once() diff --git a/clients/python/tests/test_multipart.py b/clients/python/tests/test_multipart.py new file mode 100644 index 00000000..04ac8861 --- /dev/null +++ b/clients/python/tests/test_multipart.py @@ -0,0 +1,278 @@ +import io + +from objectstore_client.multipart import ( + RequestPart, + encode_multipart, + iter_multipart_response, + parse_multipart_response, +) + +# --------------------------------------------------------------------------- +# encode_multipart +# --------------------------------------------------------------------------- + + +def test_encode_single_part() -> None: + parts = [RequestPart(headers={"x-custom": "value"}, body=b"hello")] + content_type, body_iter = encode_multipart(parts) + body = b"".join(body_iter) + + assert content_type.startswith("multipart/form-data; boundary=") + assert b"hello" in body + assert b"x-custom: value" in body + + +def test_encode_multiple_parts() -> None: + parts = [ + RequestPart(headers={"x-kind": "get"}, body=b""), + RequestPart(headers={"x-kind": "insert"}, body=b"payload"), + ] + content_type, body_iter = encode_multipart(parts) + body = b"".join(body_iter) + + assert body.count(b"x-kind") == 2 + assert b"payload" in body + + +def test_encode_empty_body() -> None: + parts = [RequestPart(headers={"x-op": "delete"}, body=b"")] + _, body_iter = encode_multipart(parts) + body = b"".join(body_iter) + assert b"x-op: delete" in body + + +def test_encode_binary_body() -> None: + binary_data = bytes(range(256)) + parts = [RequestPart(headers={}, body=binary_data)] + _, body_iter = encode_multipart(parts) + body = b"".join(body_iter) + assert binary_data in body + + +def test_encode_io_body() -> None: + data = b"streamed content" + parts = [RequestPart(headers={"x-kind": "insert"}, body=io.BytesIO(data))] + _, body_iter = encode_multipart(parts) + body = b"".join(body_iter) + assert data in body + assert b"x-kind: insert" in body + + +def test_encode_accepts_generator_of_parts() -> None: + """encode_multipart should accept any Iterable, not just a list.""" + + def parts_gen() -> object: + yield RequestPart(headers={"x-i": "0"}, body=b"a") + yield RequestPart(headers={"x-i": "1"}, body=b"b") + + _, body_iter = encode_multipart(parts_gen()) # type: ignore[arg-type] + body = b"".join(body_iter) + assert b"x-i: 0" in body + assert b"x-i: 1" in body + + +def test_content_disposition_header_included() -> None: + """Each encoded part should include content-disposition: form-data; name=part.""" + parts = [RequestPart(headers={"x-op": "get"}, body=b"")] + _, body_iter = encode_multipart(parts) + body = b"".join(body_iter) + assert b"content-disposition: form-data; name=part" in body + + +# --------------------------------------------------------------------------- +# round-trip: encode then parse +# --------------------------------------------------------------------------- + + +def test_round_trip() -> None: + original_parts = [ + RequestPart(headers={"x-index": "0", "x-kind": "get"}, body=b""), + RequestPart(headers={"x-index": "1", "x-kind": "insert"}, body=b"data here"), + RequestPart(headers={"x-index": "2", "x-kind": "delete"}, body=b""), + ] + content_type, body_iter = encode_multipart(original_parts) + body = b"".join(body_iter) + + parsed = parse_multipart_response(content_type, body) + assert len(parsed) == 3 + + for i, (original, parsed_part) in enumerate(zip(original_parts, parsed)): + assert parsed_part.body == original.body, f"body mismatch at part {i}" + for key, value in original.headers.items(): + assert parsed_part.headers[key] == value, ( + f"header {key} mismatch at part {i}" + ) + + +def test_round_trip_binary() -> None: + binary_data = bytes(range(256)) + b"\r\n--boundary\r\n" + original = [RequestPart(headers={"x-test": "bin"}, body=binary_data)] + content_type, body_iter = encode_multipart(original) + body = b"".join(body_iter) + + parsed = parse_multipart_response(content_type, body) + assert len(parsed) == 1 + assert parsed[0].body == binary_data + + +def test_parse_extracts_boundary_from_content_type() -> None: + parts = [RequestPart(headers={"x-a": "1"}, body=b"test")] + content_type, body_iter = encode_multipart(parts) + body = b"".join(body_iter) + + parsed = parse_multipart_response(content_type, body) + assert len(parsed) == 1 + assert parsed[0].body == b"test" + + +def test_parse_empty_response() -> None: + """A response with just the closing boundary and no parts.""" + boundary = "test-boundary" + content_type = f'multipart/form-data; boundary="{boundary}"' + body = f"--{boundary}--\r\n".encode() + + parsed = parse_multipart_response(content_type, body) + assert len(parsed) == 0 + + +# --------------------------------------------------------------------------- +# iter_multipart_response +# --------------------------------------------------------------------------- + + +def _make_multipart_body( + boundary: str, parts: list[tuple[dict[str, str], bytes]] +) -> bytes: + """Build a multipart body from boundary and (headers, body) pairs.""" + chunks: list[bytes] = [] + for headers, body in parts: + chunks.append(f"--{boundary}\r\n".encode()) + chunks.append(b"content-disposition: form-data; name=part\r\n") + for name, value in headers.items(): + chunks.append(f"{name}: {value}\r\n".encode()) + chunks.append(b"\r\n") + chunks.append(body) + chunks.append(b"\r\n") + chunks.append(f"--{boundary}--".encode()) + return b"".join(chunks) + + +def test_iter_multipart_yields_parts_from_single_chunk() -> None: + boundary = "bnd" + body = _make_multipart_body( + boundary, + [ + ({"x-i": "0"}, b"hello"), + ({"x-i": "1"}, b"world"), + ], + ) + content_type = f'multipart/form-data; boundary="{boundary}"' + parts = list(iter_multipart_response(content_type, [body])) + + assert len(parts) == 2 + assert parts[0].body == b"hello" + assert parts[1].body == b"world" + + +def test_iter_multipart_yields_parts_split_across_chunks() -> None: + """Parts should be yielded correctly even when the boundary spans chunks.""" + boundary = "bnd" + body = _make_multipart_body( + boundary, + [ + ({"x-i": "0"}, b"part-zero-data"), + ({"x-i": "1"}, b"part-one-data"), + ], + ) + content_type = f'multipart/form-data; boundary="{boundary}"' + + # Split the body at every byte to maximally stress boundary detection. + chunks = [bytes([b]) for b in body] + parts = list(iter_multipart_response(content_type, chunks)) + + assert len(parts) == 2 + assert parts[0].body == b"part-zero-data" + assert parts[1].body == b"part-one-data" + + +def test_iter_multipart_yields_parts_split_at_boundary() -> None: + """Split exactly at the boundary delimiter.""" + boundary = "myboundary" + body = _make_multipart_body( + boundary, + [ + ({"x-n": "a"}, b"AAA"), + ({"x-n": "b"}, b"BBB"), + ], + ) + content_type = f'multipart/form-data; boundary="{boundary}"' + + # Find the delimiter position and split there + delim = f"\r\n--{boundary}".encode() + split_pos = body.find(delim) + 3 # split mid-delimiter + chunks = [body[:split_pos], body[split_pos:]] + parts = list(iter_multipart_response(content_type, chunks)) + + assert len(parts) == 2 + assert parts[0].body == b"AAA" + assert parts[1].body == b"BBB" + + +def test_iter_multipart_empty_response() -> None: + boundary = "bnd" + body = f"--{boundary}--".encode() + content_type = f'multipart/form-data; boundary="{boundary}"' + parts = list(iter_multipart_response(content_type, [body])) + assert parts == [] + + +def test_iter_multipart_headers_parsed_correctly() -> None: + boundary = "bnd" + body = _make_multipart_body( + boundary, + [ + ({"x-status": "200 OK", "x-key": "my-key"}, b"payload"), + ], + ) + content_type = f'multipart/form-data; boundary="{boundary}"' + parts = list(iter_multipart_response(content_type, [body])) + + assert len(parts) == 1 + assert parts[0].headers["x-status"] == "200 OK" + assert parts[0].headers["x-key"] == "my-key" + assert parts[0].body == b"payload" + + +def test_iter_multipart_raises_on_malformed_post_delimiter_bytes() -> None: + """Bytes after the boundary delimiter that are neither \\r\\n nor -- raise.""" + import pytest + + boundary = "bnd" + body = ( + f"--{boundary}\r\n".encode() + + b"x-i: 0\r\n\r\nbody" + + f"\r\n--{boundary}".encode() + + b"XX" # garbage instead of \r\n or -- + ) + content_type = f'multipart/form-data; boundary="{boundary}"' + with pytest.raises(ValueError, match="Malformed multipart"): + list(iter_multipart_response(content_type, [body])) + + +def test_iter_multipart_round_trip_matches_parse() -> None: + """iter_multipart_response and parse_multipart_response must agree.""" + original_parts = [ + RequestPart(headers={"x-index": "0"}, body=b"alpha"), + RequestPart(headers={"x-index": "1"}, body=b"beta"), + RequestPart(headers={"x-index": "2"}, body=b""), + ] + content_type, body_iter = encode_multipart(original_parts) + body = b"".join(body_iter) + + from_list = parse_multipart_response(content_type, body) + from_iter = list(iter_multipart_response(content_type, [body])) + + assert len(from_list) == len(from_iter) + for a, b_part in zip(from_list, from_iter): + assert a.headers == b_part.headers + assert a.body == b_part.body