From cc2991d1bb8d11a881b18915806616fe844eba18 Mon Sep 17 00:00:00 2001 From: steven-passynkov Date: Mon, 6 Apr 2026 16:45:35 -0400 Subject: [PATCH 1/4] Refactor --- examples/async_quickstart.py | 3 +- examples/quickstart.py | 3 +- leap0/_async/desktop.py | 67 +++++----- leap0/_async/filesystem.py | 39 +++--- leap0/_async/process.py | 1 + leap0/_async/sandbox.py | 19 ++- leap0/_schemas/desktop.py | 20 +-- leap0/_schemas/process.py | 3 +- leap0/_sync/desktop.py | 62 ++++----- leap0/_sync/filesystem.py | 37 +++--- leap0/_sync/process.py | 1 + leap0/_sync/sandbox.py | 7 +- leap0/_utils/stream.py | 40 ++++-- leap0/models/config.py | 27 +++- leap0/models/desktop.py | 209 +++++++++++++++++++++++++----- leap0/models/filesystem.py | 53 ++++++++ leap0/models/process.py | 9 +- leap0/models/sandbox.py | 62 +++++++++ leap0/models/snapshot.py | 17 ++- tests/_async/test_desktop.py | 121 +++++++++++++++++ tests/_async/test_filesystem.py | 23 ++++ tests/_async/test_process.py | 5 +- tests/_async/test_sandboxes.py | 40 ++++++ tests/_sync/test_client_config.py | 20 +++ tests/_sync/test_desktop.py | 71 +++++++++- tests/_sync/test_filesystem.py | 26 ++++ tests/_sync/test_process.py | 5 +- tests/_sync/test_sandboxes.py | 21 +++ tests/_utils/test_stream.py | 2 +- tests/models/test_desktop.py | 53 ++++++-- tests/models/test_process.py | 5 +- tests/models/test_sandbox.py | 17 ++- tests/models/test_snapshot.py | 18 ++- 33 files changed, 907 insertions(+), 199 deletions(-) create mode 100644 tests/_async/test_desktop.py diff --git a/examples/async_quickstart.py b/examples/async_quickstart.py index 00f7ebf..5ff419c 100644 --- a/examples/async_quickstart.py +++ b/examples/async_quickstart.py @@ -17,7 +17,8 @@ async def main() -> None: result: ProcessResult = await sandbox.process.execute(command="echo hello from async leap0") print("sandbox:", sandbox.id) print("exit code:", result.exit_code) - print("result:", result.result.strip()) + print("stdout:", result.stdout.strip()) + print("stderr:", result.stderr.strip()) finally: await sandbox.delete() diff --git a/examples/quickstart.py b/examples/quickstart.py index c9e2333..1c73545 100644 --- a/examples/quickstart.py +++ b/examples/quickstart.py @@ -15,7 +15,8 @@ def main() -> None: result: ProcessResult = sandbox.process.execute(command="echo hello from leap0") print("sandbox:", sandbox.id) print("exit code:", result.exit_code) - print("result:", result.result.strip()) + print("stdout:", result.stdout.strip()) + print("stderr:", result.stderr.strip()) finally: sandbox.delete() client.close() diff --git a/leap0/_async/desktop.py b/leap0/_async/desktop.py index 3d019d0..028fce9 100644 --- a/leap0/_async/desktop.py +++ b/leap0/_async/desktop.py @@ -9,10 +9,12 @@ from .._internal.types import JsonObject from ..models.desktop import ( + DesktopClickParams, DesktopDisplayInfo, DesktopDisplayInfoDict, DesktopHealth, DesktopHealthDict, + DesktopOkResponse, DesktopPointerPosition, DesktopPointerPositionDict, DesktopProcessErrors, @@ -29,6 +31,10 @@ DesktopRecordingStatusDict, DesktopRecordingSummary, DesktopRecordingSummaryDict, + DesktopResizeScreenParams, + DesktopScreenshotParams, + DesktopScreenshotRegionParams, + DesktopStatusStreamErrorEvent, DesktopWindow, DesktopWindowsDict, ) @@ -138,7 +144,8 @@ async def resize_screen(self, sandbox: SandboxRef, *, width: int, height: int, h Returns: object: Result returned by this operation. """ - data = cast(DesktopDisplayInfoDict, await self._request_json("POST", sandbox, "/api/display/screen", json={"width": width, "height": height}, http_timeout=http_timeout)) + payload = DesktopResizeScreenParams(width=width, height=height).model_dump() + data = cast(DesktopDisplayInfoDict, await self._request_json("POST", sandbox, "/api/display/screen", json=payload, http_timeout=http_timeout)) return DesktopDisplayInfo.from_dict(data) @intercept_errors("Failed to list windows: ") @@ -182,19 +189,14 @@ async def screenshot( Returns: object: Result returned by this operation. """ - params: JsonObject = {} - if image_format is not None: - params["format"] = image_format - if quality is not None: - params["quality"] = quality - if x is not None: - params["x"] = x - if y is not None: - params["y"] = y - if width is not None: - params["width"] = width - if height is not None: - params["height"] = height + params = DesktopScreenshotParams( + format=image_format, + quality=quality, + x=x, + y=y, + width=width, + height=height, + ).model_dump(exclude_none=True) response = await self._request("GET", sandbox, "/api/screenshot", params=params or None, http_timeout=http_timeout) return response.content @@ -220,11 +222,14 @@ async def screenshot_region( Returns: object: Result returned by this operation. """ - payload: JsonObject = {"x": x, "y": y, "width": width, "height": height} - if image_format is not None: - payload["format"] = image_format - if quality is not None: - payload["quality"] = quality + payload = DesktopScreenshotRegionParams( + x=x, + y=y, + width=width, + height=height, + format=image_format, + quality=quality, + ).model_dump(exclude_none=True) response = await self._request("POST", sandbox, "/api/screenshot/region", json=payload, http_timeout=http_timeout) return response.content @@ -269,13 +274,7 @@ async def click(self, sandbox: SandboxRef, *, x: int | None = None, y: int | Non Returns: object: Result returned by this operation. """ - payload: JsonObject = {} - if x is not None: - payload["x"] = x - if y is not None: - payload["y"] = y - if button is not None: - payload["button"] = button + payload = DesktopClickParams(x=x, y=y, button=button).model_dump(exclude_none=True) data = cast(DesktopPointerPositionDict, await self._request_json("POST", sandbox, "/api/input/click", json=payload, http_timeout=http_timeout)) return DesktopPointerPosition.from_dict(data) @@ -328,7 +327,7 @@ async def type_text(self, sandbox: SandboxRef, *, text: str, http_timeout: float object: Result returned by this operation. """ data = await self._request_json("POST", sandbox, "/api/input/type", json={"text": text}, http_timeout=http_timeout) - return bool(data.get("ok", False)) + return DesktopOkResponse.model_validate(data).ok @intercept_errors("Failed to press key: ") async def press_key(self, sandbox: SandboxRef, *, key: str, http_timeout: float | None = None) -> bool: @@ -342,7 +341,7 @@ async def press_key(self, sandbox: SandboxRef, *, key: str, http_timeout: float object: Result returned by this operation. """ data = await self._request_json("POST", sandbox, "/api/input/press", json={"key": key}, http_timeout=http_timeout) - return bool(data.get("ok", False)) + return DesktopOkResponse.model_validate(data).ok @intercept_errors("Failed to press hotkey: ") async def hotkey(self, sandbox: SandboxRef, *, keys: list[str]) -> bool: @@ -356,7 +355,7 @@ async def hotkey(self, sandbox: SandboxRef, *, keys: list[str]) -> bool: object: Result returned by this operation. """ data = await self._request_json("POST", sandbox, "/api/input/hotkey", json={"keys": keys}) - return bool(data.get("ok", False)) + return DesktopOkResponse.model_validate(data).ok @intercept_errors("Failed to get recording status: ") async def recording_status(self, sandbox: SandboxRef, http_timeout: float | None = None) -> DesktopRecordingStatus: @@ -567,9 +566,15 @@ async def status_stream(self, sandbox: SandboxRef, *, deadline: float | None = N except StopAsyncIteration: break if not isinstance(event, dict): - continue + raise ValueError( + "Malformed desktop status stream event " + f"for sandbox={sandbox_id_of(sandbox)!r}, source='status_stream': {event!r}" + ) if "error" in event: raise Leap0Error("Desktop status stream error", body=str(event["error"])) + if {"error", "message"}.intersection(event) and not {"status", "items", "running", "total"}.intersection(event): + error_event = DesktopStatusStreamErrorEvent.model_validate(event) + raise Leap0Error("Desktop status stream error", body=error_event.detail) yield DesktopProcessStatusList.from_dict(cast(DesktopProcessStatusListDict, event)) finally: await response.aclose() @@ -598,7 +603,7 @@ async def wait_until_ready(self, sandbox: SandboxRef, *, timeout: float = 60.0, while time.monotonic() < deadline: try: async for status in self.status_stream(sandbox, deadline=deadline, http_timeout=http_timeout): - if status.status == "running": + if status.status == "running" or (status.total > 0 and status.running >= status.total): return raise Leap0Error("Desktop status stream ended without reaching 'running' state") except Leap0TimeoutError as exc: diff --git a/leap0/_async/filesystem.py b/leap0/_async/filesystem.py index 8a22ddc..34233d4 100644 --- a/leap0/_async/filesystem.py +++ b/leap0/_async/filesystem.py @@ -3,7 +3,7 @@ from typing import Any, cast from .._internal.types import JsonObject -from ..models.filesystem import EditFileResult, EditResult, FileEdit, FileInfo, LsResult, SearchMatch, TreeResult +from ..models.filesystem import EditFileResult, EditResult, FileEdit, FileInfo, LsResult, ReadFileParams, SearchMatch, SetPermissionsParams, TreeResult from ..models.sandbox import SandboxRef, sandbox_id_of from .._schemas.filesystem import EditFileResponseDict, EditFilesResponseDict, ExistsResponseDict, FileInfoDict, GlobResponseDict, GrepResponseDict, LsResponseDict, TreeResponseDict from .._utils.errors import intercept_errors @@ -213,17 +213,7 @@ async def read_bytes( print(content) ``` """ - if head is not None and tail is not None: - raise ValueError("`head` and `tail` are mutually exclusive") - payload: JsonObject = {"path": path} - if offset is not None: - payload["offset"] = offset - if limit is not None: - payload["limit"] = limit - if head is not None: - payload["head"] = head - if tail is not None: - payload["tail"] = tail + payload = ReadFileParams(path=path, offset=offset, limit=limit, head=head, tail=tail).model_dump(exclude_none=True) response = await self._transport.request("POST", f"/v1/sandbox/{sandbox_id_of(sandbox)}/filesystem/read-file", json=payload, timeout=http_timeout) return response.content @@ -347,13 +337,7 @@ async def set_permissions( Args: http_timeout: Optional HTTP request timeout in seconds for this SDK call. """ - payload: JsonObject = {"path": path} - if mode is not None: - payload["mode"] = mode - if owner is not None: - payload["owner"] = owner - if group is not None: - payload["group"] = group + payload = SetPermissionsParams(path=path, mode=mode, owner=owner, group=group).model_dump(exclude_none=True) await self._transport.request("POST", f"/v1/sandbox/{sandbox_id_of(sandbox)}/filesystem/set-permissions", json=payload, expected_status=204, timeout=http_timeout) @intercept_errors("Failed to glob: ") @@ -528,8 +512,17 @@ def _parse_multipart_response(content_type: str, body: bytes) -> dict[str, bytes ) for part in msg.get_payload(): # type: ignore[union-attr] name = part.get_param("name", header="content-disposition") - if name: - payload = part.get_payload(decode=True) - if payload is not None: - result[str(name)] = payload + if not name: + continue + content_type = part.get_content_type() + if content_type != "application/octet-stream": + raise ValueError( + f"Failed to parse /read-files response: expected file bytes for entry {name!r}, got {content_type}" + ) + payload = part.get_payload(decode=True) + if payload is None: + raise ValueError( + f"Failed to parse /read-files response: expected file bytes for entry {name!r}, got {content_type}" + ) + result[str(name)] = payload return result diff --git a/leap0/_async/process.py b/leap0/_async/process.py index 04858a7..064f5c5 100644 --- a/leap0/_async/process.py +++ b/leap0/_async/process.py @@ -44,6 +44,7 @@ async def execute(self, sandbox: SandboxRef, *, command: str, cwd: str | None = command="ls -la /workspace", ) print(result.stdout) + print(result.stderr) ``` """ payload: JsonObject = {"command": command} diff --git a/leap0/_async/sandbox.py b/leap0/_async/sandbox.py index 15a804a..1c75983 100644 --- a/leap0/_async/sandbox.py +++ b/leap0/_async/sandbox.py @@ -102,13 +102,16 @@ async def refresh(self) -> AsyncSandbox: self._data = latest._data return self - async def pause(self) -> AsyncSandbox: + async def pause(self, http_timeout: float | None = None) -> AsyncSandbox: """Pause the sandbox and return updated metadata. + Args: + http_timeout: Optional HTTP request timeout in seconds for this SDK call. + Returns: AsyncSandbox: This sandbox object with updated metadata. """ - latest = await self._client.sandboxes.pause(self) + latest = await self._client.sandboxes.pause(self, http_timeout=http_timeout) self._data = latest._data return self @@ -237,17 +240,25 @@ async def create( return self._wrap_sandbox(SandboxData.from_dict(data)) @intercept_errors("Failed to pause sandbox: ") - async def pause(self, sandbox: SandboxRef) -> AsyncSandboxT | SandboxData | SandboxStatus: + async def pause( + self, + sandbox: SandboxRef, + http_timeout: float | None = None, + ) -> AsyncSandboxT | SandboxData | SandboxStatus: """Pause the sandbox and return updated metadata. Args: sandbox: Sandbox ID or object. + http_timeout: Optional HTTP request timeout in seconds for this SDK call. Returns: AsyncSandboxT | SandboxData | SandboxStatus: Updated sandbox object. """ data: SandboxCreateResponseDict = await self._transport.request_json( - "POST", f"/v1/sandbox/{sandbox_id_of(sandbox)}/pause", expected_status=201 + "POST", + f"/v1/sandbox/{sandbox_id_of(sandbox)}/pause", + expected_status=201, + timeout=http_timeout, ) return self._wrap_sandbox(SandboxData.from_dict(data)) diff --git a/leap0/_schemas/desktop.py b/leap0/_schemas/desktop.py index a238d52..fed2949 100644 --- a/leap0/_schemas/desktop.py +++ b/leap0/_schemas/desktop.py @@ -2,7 +2,9 @@ from typing import Any, Literal, TypedDict, cast -class DesktopDisplayInfoDict(TypedDict, total=False): +from typing_extensions import NotRequired, Required + +class DesktopDisplayInfoDict(TypedDict): """Wire schema for desktop display information.""" display: str width: int @@ -26,7 +28,7 @@ class DesktopWindowsDict(TypedDict): """Wire schema for desktop window listings.""" items: list[DesktopWindowDict] -class DesktopPointerPositionDict(TypedDict, total=False): +class DesktopPointerPositionDict(TypedDict): """Wire schema for pointer position.""" x: int y: int @@ -53,19 +55,19 @@ class DesktopRecordingSummaryDict(TypedDict, total=False): created_at: str active: bool -class DesktopHealthDict(TypedDict, total=False): +class DesktopHealthDict(TypedDict): """Wire schema for desktop health state.""" ok: bool class DesktopProcessStatusDict(TypedDict, total=False): """Wire schema for one desktop process status.""" - name: str - running: bool - pid: int - stdout_log: str - stderr_log: str + name: Required[str] + running: Required[bool] + pid: NotRequired[int] + stdout_log: Required[str] + stderr_log: Required[str] -class DesktopProcessStatusListDict(TypedDict, total=False): +class DesktopProcessStatusListDict(TypedDict): """Wire schema for desktop process status listings.""" status: str items: list[DesktopProcessStatusDict] diff --git a/leap0/_schemas/process.py b/leap0/_schemas/process.py index 554823c..55af855 100644 --- a/leap0/_schemas/process.py +++ b/leap0/_schemas/process.py @@ -5,4 +5,5 @@ class ProcessResultDict(TypedDict, total=False): """Wire schema for process execution results.""" exit_code: int - result: str + stdout: str + stderr: str diff --git a/leap0/_sync/desktop.py b/leap0/_sync/desktop.py index c9ff346..ff20cca 100644 --- a/leap0/_sync/desktop.py +++ b/leap0/_sync/desktop.py @@ -11,10 +11,12 @@ from .._internal.types import JsonObject from ..models.desktop import ( + DesktopClickParams, DesktopDisplayInfo, DesktopDisplayInfoDict, DesktopHealth, DesktopHealthDict, + DesktopOkResponse, DesktopPointerPosition, DesktopPointerPositionDict, DesktopProcessErrors, @@ -31,6 +33,10 @@ DesktopRecordingStatusDict, DesktopRecordingSummary, DesktopRecordingSummaryDict, + DesktopResizeScreenParams, + DesktopScreenshotParams, + DesktopScreenshotRegionParams, + DesktopStatusStreamErrorEvent, DesktopWindow, DesktopWindowsDict, ) @@ -146,11 +152,12 @@ def resize_screen(self, sandbox: SandboxRef, *, width: int, height: int, http_ti Returns: object: Result returned by this operation. """ + payload = DesktopResizeScreenParams(width=width, height=height).model_dump() data = cast(DesktopDisplayInfoDict, self._request_json( "POST", sandbox, "/api/display/screen", - json={"width": width, "height": height}, + json=payload, http_timeout=http_timeout, )) return DesktopDisplayInfo.from_dict(data) @@ -196,19 +203,14 @@ def screenshot( Returns: object: Result returned by this operation. """ - params: JsonObject = {} - if image_format is not None: - params["format"] = image_format - if quality is not None: - params["quality"] = quality - if x is not None: - params["x"] = x - if y is not None: - params["y"] = y - if width is not None: - params["width"] = width - if height is not None: - params["height"] = height + params = DesktopScreenshotParams( + format=image_format, + quality=quality, + x=x, + y=y, + width=width, + height=height, + ).model_dump(exclude_none=True) response = self._request("GET", sandbox, "/api/screenshot", params=params or None, http_timeout=http_timeout) return response.content @@ -234,11 +236,14 @@ def screenshot_region( Returns: object: Result returned by this operation. """ - payload: JsonObject = {"x": x, "y": y, "width": width, "height": height} - if image_format is not None: - payload["format"] = image_format - if quality is not None: - payload["quality"] = quality + payload = DesktopScreenshotRegionParams( + x=x, + y=y, + width=width, + height=height, + format=image_format, + quality=quality, + ).model_dump(exclude_none=True) response = self._request("POST", sandbox, "/api/screenshot/region", json=payload, http_timeout=http_timeout) return response.content @@ -291,13 +296,7 @@ def click( Returns: object: Result returned by this operation. """ - payload: JsonObject = {} - if x is not None: - payload["x"] = x - if y is not None: - payload["y"] = y - if button is not None: - payload["button"] = button + payload = DesktopClickParams(x=x, y=y, button=button).model_dump(exclude_none=True) data = cast(DesktopPointerPositionDict, self._request_json("POST", sandbox, "/api/input/click", json=payload, http_timeout=http_timeout)) return DesktopPointerPosition.from_dict(data) @@ -364,7 +363,7 @@ def type_text(self, sandbox: SandboxRef, *, text: str) -> bool: object: Result returned by this operation. """ data = self._request_json("POST", sandbox, "/api/input/type", json={"text": text}) - return bool(data.get("ok", False)) + return DesktopOkResponse.model_validate(data).ok @intercept_errors("Failed to press key: ") def press_key(self, sandbox: SandboxRef, *, key: str, http_timeout: float | None = None) -> bool: @@ -378,7 +377,7 @@ def press_key(self, sandbox: SandboxRef, *, key: str, http_timeout: float | None object: Result returned by this operation. """ data = self._request_json("POST", sandbox, "/api/input/press", json={"key": key}, http_timeout=http_timeout) - return bool(data.get("ok", False)) + return DesktopOkResponse.model_validate(data).ok @intercept_errors("Failed to press hotkey: ") def hotkey(self, sandbox: SandboxRef, *, keys: list[str]) -> bool: @@ -392,7 +391,7 @@ def hotkey(self, sandbox: SandboxRef, *, keys: list[str]) -> bool: object: Result returned by this operation. """ data = self._request_json("POST", sandbox, "/api/input/hotkey", json={"keys": keys}) - return bool(data.get("ok", False)) + return DesktopOkResponse.model_validate(data).ok @intercept_errors("Failed to get recording status: ") def recording_status(self, sandbox: SandboxRef, http_timeout: float | None = None) -> DesktopRecordingStatus: @@ -640,6 +639,9 @@ def _read_event() -> None: "Desktop status stream error", body=str(event["error"]), ) + if {"error", "message"}.intersection(event) and not {"status", "items", "running", "total"}.intersection(event): + error_event = DesktopStatusStreamErrorEvent.model_validate(event) + raise Leap0Error("Desktop status stream error", body=error_event.detail) yield DesktopProcessStatusList.from_dict(cast(DesktopProcessStatusListDict, event)) finally: response.close() @@ -676,7 +678,7 @@ def _is_transient_leap0(exc: BaseException) -> bool: ) def _poll() -> None: for status in self.status_stream(sandbox, deadline=deadline, http_timeout=http_timeout): - if status.status == "running": + if status.status == "running" or (status.total > 0 and status.running >= status.total): return raise Leap0Error("Desktop status stream ended without reaching 'running' state", retryable=True) diff --git a/leap0/_sync/filesystem.py b/leap0/_sync/filesystem.py index f9e8692..83c3209 100644 --- a/leap0/_sync/filesystem.py +++ b/leap0/_sync/filesystem.py @@ -3,7 +3,7 @@ from typing import Any, cast from .._internal.types import JsonObject -from ..models.filesystem import EditFileResult, EditResult, FileEdit, FileInfo, LsResult, SearchMatch, TreeResult +from ..models.filesystem import EditFileResult, EditResult, FileEdit, FileInfo, LsResult, ReadFileParams, SearchMatch, SetPermissionsParams, TreeResult from ..models.sandbox import SandboxRef, sandbox_id_of from .._schemas.filesystem import EditFileResponseDict, EditFilesResponseDict, ExistsResponseDict, FileInfoDict, GlobResponseDict, GrepResponseDict, LsResponseDict, TreeResponseDict from .._utils.errors import intercept_errors @@ -211,15 +211,7 @@ def read_bytes( print(content) ``` """ - payload: JsonObject = {"path": path} - if offset is not None: - payload["offset"] = offset - if limit is not None: - payload["limit"] = limit - if head is not None: - payload["head"] = head - if tail is not None: - payload["tail"] = tail + payload = ReadFileParams(path=path, offset=offset, limit=limit, head=head, tail=tail).model_dump(exclude_none=True) response = self._transport.request("POST", f"/v1/sandbox/{sandbox_id_of(sandbox)}/filesystem/read-file", json=payload, timeout=http_timeout) return response.content @@ -341,13 +333,7 @@ def set_permissions( Args: http_timeout: Optional HTTP request timeout in seconds for this SDK call. """ - payload: JsonObject = {"path": path} - if mode is not None: - payload["mode"] = mode - if owner is not None: - payload["owner"] = owner - if group is not None: - payload["group"] = group + payload = SetPermissionsParams(path=path, mode=mode, owner=owner, group=group).model_dump(exclude_none=True) self._transport.request("POST", f"/v1/sandbox/{sandbox_id_of(sandbox)}/filesystem/set-permissions", json=payload, expected_status=204, timeout=http_timeout) @intercept_errors("Failed to glob: ") @@ -527,8 +513,17 @@ def _parse_multipart_response(content_type: str, body: bytes) -> dict[str, bytes ) for part in msg.get_payload(): # type: ignore[union-attr] name = part.get_param("name", header="content-disposition") - if name: - payload = part.get_payload(decode=True) - if payload is not None: - result[str(name)] = payload + if not name: + continue + content_type = part.get_content_type() + if content_type != "application/octet-stream": + raise ValueError( + f"Failed to parse /read-files response: expected file bytes for entry {name!r}, got {content_type}" + ) + payload = part.get_payload(decode=True) + if payload is None: + raise ValueError( + f"Failed to parse /read-files response: expected file bytes for entry {name!r}, got {content_type}" + ) + result[str(name)] = payload return result diff --git a/leap0/_sync/process.py b/leap0/_sync/process.py index 861f1e7..0589453 100644 --- a/leap0/_sync/process.py +++ b/leap0/_sync/process.py @@ -45,6 +45,7 @@ def execute(self, sandbox: SandboxRef, *, command: str, cwd: str | None = None, command="ls -la /workspace", ) print(result.stdout) + print(result.stderr) ``` """ payload: JsonObject = {"command": command} diff --git a/leap0/_sync/sandbox.py b/leap0/_sync/sandbox.py index dd31cbc..51a8af3 100644 --- a/leap0/_sync/sandbox.py +++ b/leap0/_sync/sandbox.py @@ -99,13 +99,16 @@ def refresh(self) -> Sandbox: self._data = latest._data return self - def pause(self) -> Sandbox: + def pause(self, http_timeout: float | None = None) -> Sandbox: """Pause the sandbox and update this handle with the latest metadata. + Args: + http_timeout: Optional HTTP request timeout in seconds for this SDK call. + Returns: Sandbox: This sandbox object with updated metadata. """ - latest = self._client.sandboxes.pause(self) + latest = self._client.sandboxes.pause(self, http_timeout=http_timeout) self._data = latest._data return self diff --git a/leap0/_utils/stream.py b/leap0/_utils/stream.py index 5a8d841..eab3ec9 100644 --- a/leap0/_utils/stream.py +++ b/leap0/_utils/stream.py @@ -29,6 +29,22 @@ def _parse_sse_data(data: str) -> dict[str, Any] | str: return parsed if isinstance(parsed, dict) else data +def _emit_sse_event(buffer: list[str]) -> dict[str, Any] | str | None: + data_lines = [_sse_data_value(item) for item in buffer if item.startswith("data:")] + if not data_lines: + return None + + event_name: str | None = None + for item in buffer: + if item.startswith("event:"): + event_name = item[6:].lstrip(" ") + + data = "\n".join(data_lines) + if event_name == "error": + return {"error": data} + return _parse_sse_data(data) + + def iter_sse_events(lines: Iterable[str]) -> Iterator[dict[str, Any] | str]: """Yield parsed events from an SSE line iterator.""" buffer: list[str] = [] @@ -36,18 +52,18 @@ def iter_sse_events(lines: Iterable[str]) -> Iterator[dict[str, Any] | str]: stripped = line.rstrip("\r\n") if stripped == "": if buffer: - data_lines = [_sse_data_value(item) for item in buffer if item.startswith("data:")] - if data_lines: - yield _parse_sse_data("\n".join(data_lines)) + event = _emit_sse_event(buffer) + if event is not None: + yield event buffer.clear() continue if stripped.startswith(":"): continue buffer.append(stripped) if buffer: - data_lines = [_sse_data_value(item) for item in buffer if item.startswith("data:")] - if data_lines: - yield _parse_sse_data("\n".join(data_lines)) + event = _emit_sse_event(buffer) + if event is not None: + yield event async def aiter_sse_events(lines: AsyncIterable[str]) -> AsyncIterator[dict[str, Any] | str]: @@ -57,15 +73,15 @@ async def aiter_sse_events(lines: AsyncIterable[str]) -> AsyncIterator[dict[str, stripped = line.rstrip("\r\n") if stripped == "": if buffer: - data_lines = [_sse_data_value(item) for item in buffer if item.startswith("data:")] - if data_lines: - yield _parse_sse_data("\n".join(data_lines)) + event = _emit_sse_event(buffer) + if event is not None: + yield event buffer.clear() continue if stripped.startswith(":"): continue buffer.append(stripped) if buffer: - data_lines = [_sse_data_value(item) for item in buffer if item.startswith("data:")] - if data_lines: - yield _parse_sse_data("\n".join(data_lines)) + event = _emit_sse_event(buffer) + if event is not None: + yield event diff --git a/leap0/models/config.py b/leap0/models/config.py index fe959a6..e79c605 100644 --- a/leap0/models/config.py +++ b/leap0/models/config.py @@ -30,6 +30,25 @@ DEFAULT_CLIENT_TIMEOUT = 300.0 + +def _resolve_sdk_otel_enabled(value: bool | None) -> bool: + if value is not None: + return value + + sdk_otel_env = os.environ.get(LEAP0_SDK_OTEL_ENABLED_ENV) + sdk_otel_env = sdk_otel_env.strip() if sdk_otel_env is not None else None + if not sdk_otel_env: + return bool(os.environ.get(OTEL_EXPORTER_OTLP_ENDPOINT_ENV)) + + lowered = sdk_otel_env.lower() + if lowered == "true": + return True + if lowered == "false": + return False + raise ValueError( + f"invalid {LEAP0_SDK_OTEL_ENABLED_ENV} value: {sdk_otel_env}" + ) + def _resolve_env_str(value: str | None, env_var: str, default: str) -> str: resolved = value.strip() if value else None if not resolved: @@ -73,13 +92,7 @@ def _resolve_and_validate(self) -> Leap0Config: self.timeout = timeout self.api_key = api_key self.auth_header = auth_header - if self.sdk_otel_enabled is None: - sdk_otel_env = os.environ.get(LEAP0_SDK_OTEL_ENABLED_ENV) - sdk_otel_env = sdk_otel_env.strip() if sdk_otel_env is not None else None - if sdk_otel_env: - self.sdk_otel_enabled = sdk_otel_env.lower() == "true" - else: - self.sdk_otel_enabled = bool(os.environ.get(OTEL_EXPORTER_OTLP_ENDPOINT_ENV)) + self.sdk_otel_enabled = _resolve_sdk_otel_enabled(self.sdk_otel_enabled) self.base_url = _resolve_env_str(self.base_url, LEAP0_BASE_URL_ENV, DEFAULT_BASE_URL) self.sandbox_domain = _resolve_env_str( self.sandbox_domain, diff --git a/leap0/models/desktop.py b/leap0/models/desktop.py index 62b6981..09a1258 100644 --- a/leap0/models/desktop.py +++ b/leap0/models/desktop.py @@ -3,16 +3,167 @@ from dataclasses import dataclass, field from typing import Any +from pydantic import BaseModel, ConfigDict, StrictBool, field_validator, model_validator + from .._schemas.desktop import DesktopDisplayInfoDict, DesktopHealthDict, DesktopPointerPositionDict, DesktopProcessErrorsDict, DesktopProcessLogsDict, DesktopProcessRestartDict, DesktopProcessStatusDict, DesktopProcessStatusListDict, DesktopRecordingStatusDict, DesktopRecordingSummaryDict, DesktopWindowDict, DesktopWindowsDict -def _safe_int(value: Any, default: int = 0) -> int: - """Parse *value* as an integer, returning *default* on ``None`` or invalid input.""" + +class DesktopResizeScreenParams(BaseModel): + """Validated request payload for resizing the desktop screen.""" + + model_config = ConfigDict(extra="forbid") + + width: int + height: int + + @model_validator(mode="after") + def _validate_bounds(self) -> "DesktopResizeScreenParams": + if not 320 <= self.width <= 7680: + raise ValueError("width must be between 320 and 7680") + if not 320 <= self.height <= 4320: + raise ValueError("height must be between 320 and 4320") + return self + + +class DesktopScreenshotParams(BaseModel): + """Validated query parameters for desktop screenshots.""" + + model_config = ConfigDict(extra="forbid") + + format: str | None = None + quality: int | None = None + x: int | None = None + y: int | None = None + width: int | None = None + height: int | None = None + + @model_validator(mode="after") + def _validate_region(self) -> "DesktopScreenshotParams": + if self.format is not None and self.format not in {"png", "jpg", "jpeg"}: + raise ValueError("format must be one of: png, jpg, jpeg") + if self.quality is not None and not 1 <= self.quality <= 100: + raise ValueError("quality must be between 1 and 100") + if (self.width is None) != (self.height is None): + raise ValueError("width and height must be provided together") + for name, value in (("x", self.x), ("y", self.y)): + if value is not None and value < 0: + raise ValueError(f"{name} must be >= 0") + for name, value in (("width", self.width), ("height", self.height)): + if value is not None and value < 0: + raise ValueError(f"{name} must be >= 0") + return self + + +class DesktopScreenshotRegionParams(BaseModel): + """Validated request payload for region screenshot capture.""" + + model_config = ConfigDict(extra="forbid") + + x: int + y: int + width: int + height: int + format: str | None = None + quality: int | None = None + + @model_validator(mode="after") + def _validate_region(self) -> "DesktopScreenshotRegionParams": + if self.format is not None and self.format not in {"png", "jpg", "jpeg"}: + raise ValueError("format must be one of: png, jpg, jpeg") + if self.quality is not None and not 1 <= self.quality <= 100: + raise ValueError("quality must be between 1 and 100") + if self.x < 0: + raise ValueError("x must be >= 0") + if self.y < 0: + raise ValueError("y must be >= 0") + if self.width < 1: + raise ValueError("width must be >= 1") + if self.height < 1: + raise ValueError("height must be >= 1") + return self + + +class DesktopClickParams(BaseModel): + """Validated request payload for desktop click operations.""" + + model_config = ConfigDict(extra="forbid") + + x: int | None = None + y: int | None = None + button: int | None = None + + @model_validator(mode="after") + def _validate_click(self) -> "DesktopClickParams": + if (self.x is None) != (self.y is None): + raise ValueError("x and y must be provided together or both omitted") + for name, value in (("x", self.x), ("y", self.y)): + if value is not None and value < 0: + raise ValueError(f"{name} must be >= 0") + return self + + +class DesktopOkResponse(BaseModel): + """Validated response shape for desktop endpoints that return ``ok``.""" + + model_config = ConfigDict(extra="allow") + + ok: StrictBool + + @field_validator("ok", mode="before") + @classmethod + def _validate_ok(cls, value: object) -> object: + if not isinstance(value, bool): + raise ValueError(f"Desktop response missing boolean 'ok', got: {value!r}") + return value + + +class DesktopStatusStreamErrorEvent(BaseModel): + """Validated error envelope emitted by the desktop status SSE stream.""" + + model_config = ConfigDict(extra="allow") + + error: str | None = None + message: str | None = None + + @model_validator(mode="after") + def _validate_error(self) -> "DesktopStatusStreamErrorEvent": + if self.error is None and self.message is None: + raise ValueError("Desktop status stream error event must include error or message") + return self + + @property + def detail(self) -> str: + """Return the normalized human-readable error detail.""" + return self.error or self.message or "unknown desktop status stream error" + +def _require_str(data: dict[str, Any], field: str) -> str: + value = data.get(field) + if not isinstance(value, str): + raise ValueError(f"Desktop response missing string '{field}', got: {value!r}") + return value + + +def _require_bool(data: dict[str, Any], field: str) -> bool: + value = data.get(field) + if not isinstance(value, bool): + raise ValueError(f"Desktop response missing boolean '{field}', got: {value!r}") + return value + + +def _require_int(data: dict[str, Any], field: str) -> int: + value = data.get(field) + if isinstance(value, bool) or not isinstance(value, int): + raise ValueError(f"Desktop response missing integer '{field}', got: {value!r}") + return value + + +def _optional_int(data: dict[str, Any], field: str) -> int | None: + value = data.get(field) if value is None: - return default - try: - return int(value) - except (TypeError, ValueError): - return default + return None + if isinstance(value, bool) or not isinstance(value, int): + raise ValueError(f"Desktop response has invalid integer '{field}', got: {value!r}") + return value @dataclass(slots=True) class DesktopDisplayInfo: @@ -25,9 +176,9 @@ class DesktopDisplayInfo: def from_dict(cls, data: DesktopDisplayInfoDict) -> DesktopDisplayInfo: """Build an instance from a wire-format dictionary.""" return cls( - display=data.get("display", ""), - width=_safe_int(data.get("width"), 0), - height=_safe_int(data.get("height"), 0), + display=_require_str(data, "display"), + width=_require_int(data, "width"), + height=_require_int(data, "height"), ) @dataclass(slots=True) @@ -50,16 +201,16 @@ def from_dict(cls, data: DesktopWindowDict) -> DesktopWindow: """Build an instance from a wire-format dictionary.""" return cls( id=data.get("id", ""), - desktop=_safe_int(data.get("desktop"), 0), - pid=_safe_int(data.get("pid"), 0), - x=_safe_int(data.get("x"), 0), - y=_safe_int(data.get("y"), 0), - width=_safe_int(data.get("width"), 0), - height=_safe_int(data.get("height"), 0), + desktop=_optional_int(data, "desktop") or 0, + pid=_optional_int(data, "pid") or 0, + x=_optional_int(data, "x") or 0, + y=_optional_int(data, "y") or 0, + width=_optional_int(data, "width") or 0, + height=_optional_int(data, "height") or 0, window_class=data.get("class", data.get("class_", "")), host=data.get("host", ""), title=data.get("title", ""), - focused=bool(data.get("focused", False)), + focused=_require_bool(data, "focused") if "focused" in data else False, ) @dataclass(slots=True) @@ -71,7 +222,7 @@ class DesktopPointerPosition: @classmethod def from_dict(cls, data: DesktopPointerPositionDict) -> DesktopPointerPosition: """Build an instance from a wire-format dictionary.""" - return cls(x=_safe_int(data.get("x"), 0), y=_safe_int(data.get("y"), 0)) + return cls(x=_require_int(data, "x"), y=_require_int(data, "y")) @dataclass(slots=True) class DesktopRecordingStatus: @@ -120,7 +271,7 @@ def from_dict(cls, data: DesktopRecordingSummaryDict) -> DesktopRecordingSummary file_name=data.get("file_name", ""), download=data.get("download", ""), mime_type=data.get("mime_type", ""), - size_bytes=_safe_int(data.get("size_bytes"), 0), + size_bytes=_require_int(data, "size_bytes"), created_at=data.get("created_at", ""), active=bool(data.get("active", False)), ) @@ -133,7 +284,7 @@ class DesktopHealth: @classmethod def from_dict(cls, data: DesktopHealthDict) -> DesktopHealth: """Build an instance from a wire-format dictionary.""" - return cls(ok=bool(data.get("ok", False))) + return cls(ok=_require_bool(data, "ok")) @dataclass(slots=True) class DesktopProcessStatus: @@ -148,11 +299,11 @@ class DesktopProcessStatus: def from_dict(cls, data: DesktopProcessStatusDict) -> DesktopProcessStatus: """Build an instance from a wire-format dictionary.""" return cls( - name=data.get("name", ""), - running=bool(data.get("running", False)), - pid=_safe_int(data.get("pid"), 0), - stdout_log=data.get("stdout_log", ""), - stderr_log=data.get("stderr_log", ""), + name=_require_str(data, "name"), + running=_require_bool(data, "running"), + pid=_optional_int(data, "pid") or 0, + stdout_log=_require_str(data, "stdout_log"), + stderr_log=_require_str(data, "stderr_log"), ) @dataclass(slots=True) @@ -168,16 +319,16 @@ def from_dict(cls, data: DesktopProcessStatusListDict) -> DesktopProcessStatusLi """Build an instance from a wire-format dictionary.""" raw_items = data.get("items") if not isinstance(raw_items, (list, tuple)): - raw_items = [] + raise ValueError(f"Desktop response missing array 'items', got: {raw_items!r}") return cls( - status=data.get("status", ""), + status=_require_str(data, "status"), items=[ DesktopProcessStatus.from_dict(item) # type: ignore[arg-type] for item in raw_items if isinstance(item, dict) ], - running=_safe_int(data.get("running"), 0), - total=_safe_int(data.get("total"), 0), + running=_require_int(data, "running"), + total=_require_int(data, "total"), ) @dataclass(slots=True) diff --git a/leap0/models/filesystem.py b/leap0/models/filesystem.py index 9649912..a805802 100644 --- a/leap0/models/filesystem.py +++ b/leap0/models/filesystem.py @@ -1,8 +1,61 @@ from __future__ import annotations from dataclasses import dataclass, field +from typing import Annotated + +from pydantic import BaseModel, ConfigDict, StringConstraints, field_validator, model_validator + from .._schemas.filesystem import EditFileResponseDict, EditFilesResponseDict, EditResultDict, ExistsResponseDict, FileInfoDict, GlobResponseDict, GrepResponseDict, LsResponseDict, SearchMatchDict, TreeEntryDict, TreeResponseDict + +class ReadFileParams(BaseModel): + """Validated request parameters for reading a single file.""" + + model_config = ConfigDict(extra="forbid") + + path: str + offset: int | None = None + limit: int | None = None + head: int | None = None + tail: int | None = None + + @model_validator(mode="after") + def _validate_head_tail(self) -> "ReadFileParams": + if self.head is not None and self.tail is not None: + raise ValueError("`head` and `tail` are mutually exclusive") + return self + + +NonEmptyOptionalString = Annotated[str, StringConstraints(strip_whitespace=True, min_length=1)] + + +class SetPermissionsParams(BaseModel): + """Validated request parameters for the set-permissions endpoint.""" + + model_config = ConfigDict(extra="forbid") + + path: str + mode: NonEmptyOptionalString | None = None + owner: NonEmptyOptionalString | None = None + group: NonEmptyOptionalString | None = None + + @field_validator("mode", "owner", "group", mode="before") + @classmethod + def _validate_non_empty_string(cls, value: str | None, info: object) -> str | None: + if value is None: + return None + trimmed = value.strip() + if not trimmed: + field_name = getattr(info, "field_name", "value") + raise ValueError(f"set_permissions {field_name} must be a non-empty string") + return trimmed + + @model_validator(mode="after") + def _validate_updates(self) -> "SetPermissionsParams": + if self.mode is None and self.owner is None and self.group is None: + raise ValueError("set_permissions requires at least one of mode, owner, or group") + return self + @dataclass(slots=True) class FileInfo: """Filesystem metadata for a path inside a sandbox. diff --git a/leap0/models/process.py b/leap0/models/process.py index db5f950..66024c3 100644 --- a/leap0/models/process.py +++ b/leap0/models/process.py @@ -7,9 +7,14 @@ class ProcessResult: """Result of a one-shot process execution.""" exit_code: int - result: str + stdout: str + stderr: str @classmethod def from_dict(cls, data: ProcessResultDict) -> ProcessResult: """Build an instance from a wire-format dictionary.""" - return cls(exit_code=int(data.get("exit_code", 0)), result=data.get("result", "")) + return cls( + exit_code=int(data.get("exit_code", 0)), + stdout=data.get("stdout", ""), + stderr=data.get("stderr", ""), + ) diff --git a/leap0/models/sandbox.py b/leap0/models/sandbox.py index 9ffdb0d..b7e2733 100644 --- a/leap0/models/sandbox.py +++ b/leap0/models/sandbox.py @@ -2,6 +2,8 @@ from dataclasses import dataclass from enum import Enum +import ipaddress +import re from typing import TypeAlias from pydantic import BaseModel, ConfigDict, model_validator @@ -27,6 +29,65 @@ class NetworkPolicyMode(str, Enum): DENY_ALL = "deny-all" CUSTOM = "custom" + +_DOMAIN_LABEL_RE = re.compile(r"^[A-Za-z0-9-]+$") + + +def _validate_domain_pattern(value: str) -> str: + domain = value.strip() + if not domain: + raise ValueError("network policy domains must be non-empty") + + host = domain[2:] if domain.startswith("*.") else domain + if not host or host.startswith(".") or host.endswith("."): + raise ValueError(f"invalid network policy domain pattern: {value!r}") + if ":" in host: + raise ValueError(f"invalid network policy domain pattern: {value!r}") + + labels = host.split(".") + if len(labels) < 2: + raise ValueError(f"invalid network policy domain pattern: {value!r}") + for label in labels: + if not label or label.startswith("-") or label.endswith("-") or not _DOMAIN_LABEL_RE.fullmatch(label): + raise ValueError(f"invalid network policy domain pattern: {value!r}") + return domain + + +def _validate_network_policy(policy: NetworkPolicyDict | None) -> NetworkPolicyDict | None: + if policy is None: + return None + + mode = policy.get("mode") + valid_modes = {item.value for item in NetworkPolicyMode} + if mode not in valid_modes: + raise ValueError(f"network_policy.mode must be one of {sorted(valid_modes)}") + + allow_domains = policy.get("allow_domains") + if allow_domains is not None: + if len(allow_domains) > 50: + raise ValueError("network_policy.allow_domains must contain at most 50 entries") + for domain in allow_domains: + _validate_domain_pattern(domain) + + allow_cidrs = policy.get("allow_cidrs") + if allow_cidrs is not None: + if len(allow_cidrs) > 10: + raise ValueError("network_policy.allow_cidrs must contain at most 10 entries") + for cidr in allow_cidrs: + try: + ipaddress.IPv4Network(cidr, strict=False) + except ValueError as err: + raise ValueError(f"invalid network policy CIDR: {cidr!r}") from err + + transforms = policy.get("transforms") + if transforms is not None: + if len(transforms) > 20: + raise ValueError("network_policy.transforms must contain at most 20 entries") + for transform in transforms: + _validate_domain_pattern(transform["domain"]) + + return policy + class CreateSandboxParams(BaseModel): """Validated sandbox creation parameters.""" model_config = ConfigDict(extra="forbid") @@ -53,6 +114,7 @@ def _validate_values(self) -> CreateSandboxParams: raise ValueError("memory_mib must be an even number between 512 and 8192") if not 1 <= self.timeout_min <= 480: raise ValueError("timeout_min must be between 1 and 480") + self.network_policy = _validate_network_policy(self.network_policy) self.template_name = template_name return self diff --git a/leap0/models/snapshot.py b/leap0/models/snapshot.py index 8d943ad..04dd87f 100644 --- a/leap0/models/snapshot.py +++ b/leap0/models/snapshot.py @@ -71,21 +71,30 @@ class Snapshot: vcpu: int = 0 memory_mib: int = 0 disk_mib: int = 0 - state: SandboxState | str = SandboxState.STARTING + state: SandboxState | str | None = None network_policy: NetworkPolicyDict | None = None created_at: str = "" @classmethod def from_dict(cls, data: SnapshotCreateResponseDict) -> Snapshot: """Build an instance from a wire-format dictionary.""" + snapshot_id = data.get("id") + if not isinstance(snapshot_id, str) or not snapshot_id.strip(): + raise ValueError(f"Snapshot response missing required non-empty string 'id', got: {snapshot_id!r}") + snapshot_name = data.get("name") + if not isinstance(snapshot_name, str) or not snapshot_name.strip(): + raise ValueError( + f"Snapshot response missing required non-empty string 'name', got: {snapshot_name!r}" + ) + state = data.get("state") return cls( - id=data.get("id", ""), - name=data.get("name", ""), + id=snapshot_id, + name=snapshot_name, template_id=data.get("template_id", ""), vcpu=int(data.get("vcpu", 0)), memory_mib=int(data.get("memory_mib", 0)), disk_mib=int(data.get("disk_mib", 0)), - state=_parse_sandbox_state(data.get("state")), + state=_parse_sandbox_state(state) if state is not None else None, network_policy=data.get("network_policy"), created_at=data.get("created_at", ""), ) diff --git a/tests/_async/test_desktop.py b/tests/_async/test_desktop.py new file mode 100644 index 0000000..2a03ce6 --- /dev/null +++ b/tests/_async/test_desktop.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from leap0._async.desktop import AsyncDesktopClient +from leap0.models.errors import Leap0Error + + +class TestAsyncDesktopClient: + def test_validates_request_payloads(self, async_mock_transport): + async def run() -> None: + client = AsyncDesktopClient(async_mock_transport, sandbox_domain="sandbox.example.com") + + with pytest.raises(Leap0Error, match="width must be between 320 and 7680"): + await client.resize_screen("sbx-1", width=100, height=720) + with pytest.raises(Leap0Error, match="width and height must be provided together"): + await client.screenshot("sbx-1", width=100) + with pytest.raises(Leap0Error, match="format must be one of: png, jpg, jpeg"): + await client.screenshot("sbx-1", image_format="webp") + with pytest.raises(Leap0Error, match="quality must be between 1 and 100"): + await client.screenshot("sbx-1", quality=101) + with pytest.raises(Leap0Error, match="height must be >= 1"): + await client.screenshot_region("sbx-1", x=0, y=0, width=10, height=0) + with pytest.raises(Leap0Error, match="x and y must be provided together"): + await client.click("sbx-1", x=10) + + assert async_mock_transport.request.call_count == 0 + assert async_mock_transport.request_target_json.call_count == 0 + + asyncio.run(run()) + + def test_screenshot_allows_zero_sized_paired_region_query(self, async_mock_transport): + async def run() -> None: + response = MagicMock() + response.content = b"image" + async_mock_transport.request_target.return_value = response + + result = await AsyncDesktopClient(async_mock_transport, sandbox_domain="sandbox.example.com").screenshot( + "sbx-1", + width=0, + height=0, + ) + + assert result == b"image" + assert async_mock_transport.request_target.call_args.kwargs["params"] == {"width": 0, "height": 0} + + asyncio.run(run()) + + def test_requires_boolean_ok_response(self, async_mock_transport): + async def run() -> None: + async_mock_transport.request_target_json.return_value = {"ok": "false"} + + with pytest.raises(Leap0Error, match="missing boolean 'ok'"): + await AsyncDesktopClient(async_mock_transport, sandbox_domain="sandbox.example.com").type_text("sbx-1", text="hello") + + asyncio.run(run()) + + def test_status_stream_raises_on_non_dict_event(self, async_mock_transport): + async def run() -> None: + response = MagicMock() + + async def aiter_lines(): + yield "data: malformed" + yield "" + + response.aiter_lines = aiter_lines + response.aclose = AsyncMock() + async_mock_transport.stream.return_value = response + + with pytest.raises(Leap0Error, match="Malformed desktop status stream event"): + async for _ in AsyncDesktopClient(async_mock_transport, sandbox_domain="sandbox.example.com").status_stream("sbx-1"): + pass + + asyncio.run(run()) + + def test_wait_until_ready_accepts_count_only_running_updates(self, async_mock_transport): + async def run() -> None: + response = MagicMock() + + async def aiter_lines(): + yield 'data: {"status": "degraded", "items": [{"name": "xvfb", "running": true, "stdout_log": "/tmp/xvfb.stdout.log", "stderr_log": "/tmp/xvfb.stderr.log"}], "running": 4, "total": 4}' + yield "" + + response.aiter_lines = aiter_lines + response.aclose = AsyncMock() + async_mock_transport.stream.return_value = response + + await AsyncDesktopClient(async_mock_transport, sandbox_domain="sandbox.example.com").wait_until_ready("sbx-1", timeout=1) + + asyncio.run(run()) + + def test_status_stream_raises_on_plain_text_error_event(self, async_mock_transport): + async def run() -> None: + response = MagicMock() + + async def aiter_lines(): + yield "event: error" + yield "data: Desktop request failed" + yield "" + + response.aiter_lines = aiter_lines + response.aclose = AsyncMock() + async_mock_transport.stream.return_value = response + + with pytest.raises(Leap0Error, match="Desktop status stream error"): + async for _ in AsyncDesktopClient(async_mock_transport, sandbox_domain="sandbox.example.com").status_stream("sbx-1"): + pass + + asyncio.run(run()) + + def test_process_status_requires_documented_fields(self, async_mock_transport): + async def run() -> None: + async_mock_transport.request_target_json.return_value = {"items": [], "running": 0, "total": 0} + + with pytest.raises(Leap0Error, match="missing string 'status'"): + await AsyncDesktopClient(async_mock_transport, sandbox_domain="sandbox.example.com").process_status("sbx-1") + + asyncio.run(run()) diff --git a/tests/_async/test_filesystem.py b/tests/_async/test_filesystem.py index 37a8ce7..d327392 100644 --- a/tests/_async/test_filesystem.py +++ b/tests/_async/test_filesystem.py @@ -56,8 +56,31 @@ async def run() -> None: asyncio.run(run()) + def test_set_permissions_rejects_missing_or_blank_updates(self, async_mock_transport): + async def run() -> None: + client = AsyncFilesystemClient(async_mock_transport) + + with pytest.raises(Leap0Error, match="at least one of mode, owner, or group"): + await client.set_permissions("sbx-1", path="/workspace/a.txt") + with pytest.raises(Leap0Error, match="group must be a non-empty string"): + await client.set_permissions("sbx-1", path="/workspace/a.txt", group=" ") + + assert async_mock_transport.request.call_count == 0 + + asyncio.run(run()) + class TestParseMultipartResponse: def test_non_multipart_redacts_preview(self): with pytest.raises(ValueError, match=""): _parse_multipart_response("application/json", b'{"secret": "value"}') + + def test_text_part_raises(self): + boundary = "boundary123" + body = ( + f"--{boundary}\r\nContent-Disposition: form-data; name=\"/a.txt\"\r\n" + f"Content-Type: text/plain; charset=utf-8\r\n\r\ncontent a\r\n" + f"--{boundary}--\r\n" + ).encode() + with pytest.raises(ValueError, match="Failed to parse /read-files response"): + _parse_multipart_response(f"multipart/form-data; boundary={boundary}", body) diff --git a/tests/_async/test_process.py b/tests/_async/test_process.py index c2d66a7..e902477 100644 --- a/tests/_async/test_process.py +++ b/tests/_async/test_process.py @@ -8,10 +8,11 @@ class TestAsyncProcessClient: def test_execute(self, async_mock_transport): async def run() -> None: - async_mock_transport.request_json.return_value = {"exit_code": 0, "result": "hello"} + async_mock_transport.request_json.return_value = {"exit_code": 0, "stdout": "hello", "stderr": "warn"} result = await AsyncProcessClient(async_mock_transport).execute("sbx-1", command="echo hello") assert result.exit_code == 0 - assert result.result == "hello" + assert result.stdout == "hello" + assert result.stderr == "warn" assert async_mock_transport.request_json.call_args[0][:2] == ("POST", "/v1/sandbox/sbx-1/process/execute") asyncio.run(run()) diff --git a/tests/_async/test_sandboxes.py b/tests/_async/test_sandboxes.py index 684e2d9..7cbd0e3 100644 --- a/tests/_async/test_sandboxes.py +++ b/tests/_async/test_sandboxes.py @@ -72,3 +72,43 @@ async def run() -> None: await AsyncSandboxesClient(async_mock_transport, sandbox_domain="s.dev").create(otel_export=True) asyncio.run(run()) + + def test_pause_forwards_http_timeout(self, async_mock_transport): + async def run() -> None: + async_mock_transport.request_json.return_value = { + "id": "sbx-1", "template_id": "tpl-1", "vcpu": 2, "memory_mib": 2048, + "disk_mib": 10240, "state": "paused", "auto_pause": False, "created_at": "", + } + + await AsyncSandboxesClient(async_mock_transport, sandbox_domain="s.dev").pause( + "sbx-1", + http_timeout=4.0, + ) + + assert async_mock_transport.request_json.call_args.kwargs["timeout"] == 4.0 + + asyncio.run(run()) + + +class TestAsyncSandbox: + def test_pause_forwards_http_timeout(self): + async def run() -> None: + sandboxes = SimpleNamespace() + fake_client = SimpleNamespace( + _filesystem=SimpleNamespace(), _git=SimpleNamespace(), _process=SimpleNamespace(), _pty=SimpleNamespace(), + _lsp=SimpleNamespace(), _ssh=SimpleNamespace(), _code_interpreter=SimpleNamespace(), _desktop=SimpleNamespace(), + sandboxes=sandboxes, + ) + + async def pause(sandbox: object, http_timeout: float | None = None): + assert http_timeout == 2.5 + return AsyncSandbox(fake_client, Sandbox(id="sbx-1", state="paused")) + + sandboxes.pause = pause + sandbox = AsyncSandbox(fake_client, Sandbox(id="sbx-1", state="running")) + + await sandbox.pause(http_timeout=2.5) + + assert sandbox.state == "paused" + + asyncio.run(run()) diff --git a/tests/_sync/test_client_config.py b/tests/_sync/test_client_config.py index 01dab00..4582f67 100644 --- a/tests/_sync/test_client_config.py +++ b/tests/_sync/test_client_config.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + from leap0.models.config import Leap0Config @@ -31,6 +33,24 @@ def test_explicit_sdk_flag_precedence(monkeypatch): assert config.sdk_otel_enabled is False +def test_sdk_otel_enabled_accepts_case_insensitive_values(monkeypatch): + monkeypatch.setenv("LEAP0_API_KEY", "test-key") + monkeypatch.setenv("LEAP0_SDK_OTEL_ENABLED", "TrUe") + + assert Leap0Config().sdk_otel_enabled is True + + monkeypatch.setenv("LEAP0_SDK_OTEL_ENABLED", "FaLsE") + assert Leap0Config().sdk_otel_enabled is False + + +def test_sdk_otel_enabled_rejects_invalid_string(monkeypatch): + monkeypatch.setenv("LEAP0_API_KEY", "test-key") + monkeypatch.setenv("LEAP0_SDK_OTEL_ENABLED", "maybe") + + with pytest.raises(ValueError, match="invalid LEAP0_SDK_OTEL_ENABLED value: maybe"): + Leap0Config() + + def test_legacy_otel_env_no_longer_enables_sdk(monkeypatch): monkeypatch.setenv("LEAP0_API_KEY", "test-key") monkeypatch.setenv("LEAP0_OTEL_ENABLED", "true") diff --git a/tests/_sync/test_desktop.py b/tests/_sync/test_desktop.py index b0e064f..c51117f 100644 --- a/tests/_sync/test_desktop.py +++ b/tests/_sync/test_desktop.py @@ -9,6 +9,45 @@ class TestDesktopClient: + def test_validates_request_payloads(self, mock_transport): + client = DesktopClient(mock_transport, sandbox_domain="sandbox.example.com") + + with pytest.raises(Leap0Error, match="width must be between 320 and 7680"): + client.resize_screen("sbx-1", width=100, height=720) + with pytest.raises(Leap0Error, match="width and height must be provided together"): + client.screenshot("sbx-1", width=100) + with pytest.raises(Leap0Error, match="format must be one of: png, jpg, jpeg"): + client.screenshot("sbx-1", image_format="webp") + with pytest.raises(Leap0Error, match="quality must be between 1 and 100"): + client.screenshot("sbx-1", quality=101) + with pytest.raises(Leap0Error, match="height must be >= 1"): + client.screenshot_region("sbx-1", x=0, y=0, width=10, height=0) + with pytest.raises(Leap0Error, match="x and y must be provided together"): + client.click("sbx-1", x=10) + + assert mock_transport.request.call_count == 0 + assert mock_transport.request_target_json.call_count == 0 + + def test_screenshot_allows_zero_sized_paired_region_query(self, mock_transport): + response = MagicMock() + response.content = b"image" + mock_transport.request_target.return_value = response + + result = DesktopClient(mock_transport, sandbox_domain="sandbox.example.com").screenshot( + "sbx-1", + width=0, + height=0, + ) + + assert result == b"image" + assert mock_transport.request_target.call_args.kwargs["params"] == {"width": 0, "height": 0} + + def test_requires_boolean_ok_response(self, mock_transport): + mock_transport.request_target_json.return_value = {"ok": "false"} + + with pytest.raises(Leap0Error, match="missing boolean 'ok'"): + DesktopClient(mock_transport, sandbox_domain="sandbox.example.com").type_text("sbx-1", text="hello") + def test_status_stream_raises_on_non_dict_event(self, mock_transport): response = MagicMock() response.iter_lines.return_value = iter(["data: malformed", ""]) @@ -22,7 +61,7 @@ def test_wait_until_ready_retries_only_retryable_errors(self, mock_transport): first.iter_lines.return_value = iter([]) second = MagicMock() second.iter_lines.return_value = iter([ - 'data: {"status": "running", "items": []}', + 'data: {"status": "running", "items": [{"name": "xvfb", "running": true, "stdout_log": "/tmp/xvfb.stdout.log", "stderr_log": "/tmp/xvfb.stderr.log"}], "running": 1, "total": 1}', "", ]) mock_transport.stream.side_effect = [first, second] @@ -36,7 +75,7 @@ def test_wait_until_ready_stops_on_malformed_stream(self, mock_transport): bad.iter_lines.return_value = iter(["data: malformed", ""]) good = MagicMock() good.iter_lines.return_value = iter([ - 'data: {"status": "running", "items": []}', + 'data: {"status": "running", "items": [{"name": "xvfb", "running": true, "stdout_log": "/tmp/xvfb.stdout.log", "stderr_log": "/tmp/xvfb.stderr.log"}], "running": 1, "total": 1}', "", ]) mock_transport.stream.side_effect = [bad, good] @@ -45,3 +84,31 @@ def test_wait_until_ready_stops_on_malformed_stream(self, mock_transport): DesktopClient(mock_transport, sandbox_domain="sandbox.example.com").wait_until_ready("sbx-1", timeout=1) assert mock_transport.stream.call_count == 1 + + def test_wait_until_ready_accepts_count_only_running_updates(self, mock_transport): + response = MagicMock() + response.iter_lines.return_value = iter([ + 'data: {"status": "degraded", "items": [{"name": "xvfb", "running": true, "stdout_log": "/tmp/xvfb.stdout.log", "stderr_log": "/tmp/xvfb.stderr.log"}], "running": 4, "total": 4}', + "", + ]) + mock_transport.stream.return_value = response + + DesktopClient(mock_transport, sandbox_domain="sandbox.example.com").wait_until_ready("sbx-1", timeout=1) + + def test_status_stream_raises_on_plain_text_error_event(self, mock_transport): + response = MagicMock() + response.iter_lines.return_value = iter([ + "event: error", + "data: Desktop request failed", + "", + ]) + mock_transport.stream.return_value = response + + with pytest.raises(Leap0Error, match="Desktop status stream error"): + list(DesktopClient(mock_transport, sandbox_domain="sandbox.example.com").status_stream("sbx-1")) + + def test_process_status_requires_documented_fields(self, mock_transport): + mock_transport.request_target_json.return_value = {"items": [], "running": 0, "total": 0} + + with pytest.raises(Leap0Error, match="missing string 'status'"): + DesktopClient(mock_transport, sandbox_domain="sandbox.example.com").process_status("sbx-1") diff --git a/tests/_sync/test_filesystem.py b/tests/_sync/test_filesystem.py index d6bb879..fa94eaa 100644 --- a/tests/_sync/test_filesystem.py +++ b/tests/_sync/test_filesystem.py @@ -50,6 +50,22 @@ def test_read_bytes(self, mock_transport): mock_transport.request.return_value = MagicMock(content=b"Hello") assert FilesystemClient(mock_transport).read_bytes("sbx-1", path="/workspace/hello.bin") == b"Hello" + def test_read_bytes_rejects_head_and_tail(self, mock_transport): + with pytest.raises(Exception, match="mutually exclusive"): + FilesystemClient(mock_transport).read_bytes("sbx-1", path="/workspace/hello.bin", head=1, tail=1) + + def test_set_permissions_rejects_missing_or_blank_updates(self, mock_transport): + client = FilesystemClient(mock_transport) + + with pytest.raises(Exception, match="at least one of mode, owner, or group"): + client.set_permissions("sbx-1", path="/workspace/a.txt") + with pytest.raises(Exception, match="mode must be a non-empty string"): + client.set_permissions("sbx-1", path="/workspace/a.txt", mode=" ") + with pytest.raises(Exception, match="owner must be a non-empty string"): + client.set_permissions("sbx-1", path="/workspace/a.txt", owner="") + + assert mock_transport.request.call_count == 0 + def test_write_files(self, mock_transport): mock_transport.request.return_value = MagicMock(status_code=204) FilesystemClient(mock_transport).write_files("sbx-1", files={"/workspace/hello.txt": "Hello"}) @@ -83,3 +99,13 @@ def test_valid(self): def test_non_multipart_raises(self): with pytest.raises(ValueError, match="Expected multipart"): _parse_multipart_response("application/json", b'{"error": "bad"}') + + def test_text_part_raises(self): + boundary = "boundary123" + body = ( + f"--{boundary}\r\nContent-Disposition: form-data; name=\"/a.txt\"\r\n" + f"Content-Type: text/plain; charset=utf-8\r\n\r\ncontent a\r\n" + f"--{boundary}--\r\n" + ).encode() + with pytest.raises(ValueError, match="Failed to parse /read-files response"): + _parse_multipart_response(f"multipart/form-data; boundary={boundary}", body) diff --git a/tests/_sync/test_process.py b/tests/_sync/test_process.py index 3e8d26b..678a107 100644 --- a/tests/_sync/test_process.py +++ b/tests/_sync/test_process.py @@ -5,8 +5,9 @@ class TestProcessClient: def test_execute(self, mock_transport): - mock_transport.request_json.return_value = {"exit_code": 0, "result": "hello"} + mock_transport.request_json.return_value = {"exit_code": 0, "stdout": "hello", "stderr": "warn"} result = ProcessClient(mock_transport).execute("sbx-1", command="echo hello") assert result.exit_code == 0 - assert result.result == "hello" + assert result.stdout == "hello" + assert result.stderr == "warn" assert mock_transport.request_json.call_args[0][:2] == ("POST", "/v1/sandbox/sbx-1/process/execute") diff --git a/tests/_sync/test_sandboxes.py b/tests/_sync/test_sandboxes.py index 2926246..eab5319 100644 --- a/tests/_sync/test_sandboxes.py +++ b/tests/_sync/test_sandboxes.py @@ -178,3 +178,24 @@ def test_refresh_updates_metadata(self): sandbox.refresh() assert sandbox.state == "running" + + def test_pause_forwards_http_timeout(self): + sandboxes = MagicMock() + client = SimpleNamespace( + _filesystem=MagicMock(), + _git=MagicMock(), + _process=MagicMock(), + _pty=MagicMock(), + _lsp=MagicMock(), + _ssh=MagicMock(), + _code_interpreter=MagicMock(), + _desktop=MagicMock(), + sandboxes=sandboxes, + ) + sandbox = RichSandbox(client, Sandbox(id="sbx-1", state="running")) + sandboxes.pause.return_value = RichSandbox(client, Sandbox(id="sbx-1", state="paused")) + + sandbox.pause(http_timeout=7.5) + + sandboxes.pause.assert_called_once_with(sandbox, http_timeout=7.5) + assert sandbox.state == "paused" diff --git a/tests/_utils/test_stream.py b/tests/_utils/test_stream.py index b64eafc..5f30450 100644 --- a/tests/_utils/test_stream.py +++ b/tests/_utils/test_stream.py @@ -35,7 +35,7 @@ def test_non_data_fields_ignored(self): assert list(iter_sse_events(["event: update", "id: 42", "data: {\"ok\": true}", ""])) == [{"ok": True}] def test_plain_text_data_preserved(self): - assert list(iter_sse_events(["event: error", "data: desktop stream failed", ""])) == ["desktop stream failed"] + assert list(iter_sse_events(["event: error", "data: desktop stream failed", ""])) == [{"error": "desktop stream failed"}] class TestIterNdjson: diff --git a/tests/models/test_desktop.py b/tests/models/test_desktop.py index 2a1b475..6e6a4c3 100644 --- a/tests/models/test_desktop.py +++ b/tests/models/test_desktop.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + from leap0.models.desktop import ( DesktopDisplayInfo, DesktopHealth, DesktopPointerPosition, DesktopProcessErrors, DesktopProcessLogs, DesktopProcessRestart, DesktopProcessStatus, DesktopProcessStatusList, @@ -11,8 +13,9 @@ class TestDesktopHealth: def test_ok_true(self): assert DesktopHealth.from_dict({"ok": True}).ok is True - def test_empty_dict(self): - assert DesktopHealth.from_dict({}).ok is False + def test_requires_ok(self): + with pytest.raises(ValueError, match="missing boolean 'ok'"): + DesktopHealth.from_dict({}) class TestDesktopDisplayInfo: @@ -55,24 +58,58 @@ def test_from_dict(self): class TestDesktopProcessStatus: def test_from_dict(self): - p = DesktopProcessStatus.from_dict({"name": "xvfb", "running": True, "pid": 123}) + p = DesktopProcessStatus.from_dict( + { + "name": "xvfb", + "running": True, + "pid": 123, + "stdout_log": "/tmp/xvfb.stdout.log", + "stderr_log": "/tmp/xvfb.stderr.log", + } + ) assert p.running is True class TestDesktopProcessStatusList: def test_from_dict(self): - sl = DesktopProcessStatusList.from_dict({"status": "running", "items": [{"name": "xvfb", "running": True, "pid": 1}], - "running": 1, "total": 4}) + sl = DesktopProcessStatusList.from_dict( + { + "status": "running", + "items": [ + { + "name": "xvfb", + "running": True, + "pid": 1, + "stdout_log": "/tmp/xvfb.stdout.log", + "stderr_log": "/tmp/xvfb.stderr.log", + } + ], + "running": 1, + "total": 4, + } + ) assert len(sl.items) == 1 assert sl.total == 4 - def test_empty_dict(self): - assert DesktopProcessStatusList.from_dict({}).items == [] + def test_requires_status_fields(self): + with pytest.raises(ValueError, match="missing array 'items'"): + DesktopProcessStatusList.from_dict({}) class TestDesktopProcessRestart: def test_with_status(self): - r = DesktopProcessRestart.from_dict({"message": "restarted", "status": {"name": "xvfb", "running": True, "pid": 42}}) + r = DesktopProcessRestart.from_dict( + { + "message": "restarted", + "status": { + "name": "xvfb", + "running": True, + "pid": 42, + "stdout_log": "/tmp/xvfb.stdout.log", + "stderr_log": "/tmp/xvfb.stderr.log", + }, + } + ) assert r.status.pid == 42 def test_without_status(self): diff --git a/tests/models/test_process.py b/tests/models/test_process.py index 667a3f0..421dc5e 100644 --- a/tests/models/test_process.py +++ b/tests/models/test_process.py @@ -5,6 +5,7 @@ class TestProcessResult: def test_from_dict(self): - r = ProcessResult.from_dict({"exit_code": 1, "result": "error output"}) + r = ProcessResult.from_dict({"exit_code": 1, "stdout": "hello", "stderr": "error output"}) assert r.exit_code == 1 - assert r.result == "error output" + assert r.stdout == "hello" + assert r.stderr == "error output" diff --git a/tests/models/test_sandbox.py b/tests/models/test_sandbox.py index d2b2ba8..33f405e 100644 --- a/tests/models/test_sandbox.py +++ b/tests/models/test_sandbox.py @@ -2,7 +2,7 @@ import pytest -from leap0.models.sandbox import Sandbox, SandboxStatus, sandbox_id_of +from leap0.models.sandbox import CreateSandboxParams, Sandbox, SandboxStatus, sandbox_id_of class TestSandboxIdOf: @@ -52,3 +52,18 @@ def test_full_dict(self): def test_empty_dict_raises(self): with pytest.raises(ValueError, match="missing required non-empty string 'id'"): SandboxStatus.from_dict({}) + + +class TestCreateSandboxParams: + def test_rejects_invalid_network_policy(self): + with pytest.raises(ValueError, match="network_policy.mode"): + CreateSandboxParams(network_policy={"mode": "nope"}) + + with pytest.raises(ValueError, match="invalid network policy domain pattern"): + CreateSandboxParams(network_policy={"mode": "custom", "allow_domains": ["localhost"]}) + + with pytest.raises(ValueError, match="invalid network policy CIDR"): + CreateSandboxParams(network_policy={"mode": "custom", "allow_cidrs": ["bad"]}) + + with pytest.raises(ValueError, match="allow_domains must contain at most 50"): + CreateSandboxParams(network_policy={"mode": "custom", "allow_domains": ["a.example.com"] * 51}) diff --git a/tests/models/test_snapshot.py b/tests/models/test_snapshot.py index 410e782..3a012ae 100644 --- a/tests/models/test_snapshot.py +++ b/tests/models/test_snapshot.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + from leap0.models.snapshot import Snapshot, snapshot_id_of @@ -20,9 +22,17 @@ def test_from_dict_full(self): "vcpu": 2, "memory_mib": 1024, "disk_mib": 10240, "network_policy": {"mode": "deny-all"}, "created_at": "2025-01-01"}) assert s.id == "snap-1" + assert s.state is None assert s.network_policy == {"mode": "deny-all"} - def test_from_dict_minimal(self): - s = Snapshot.from_dict({}) - assert s.id == "" - assert s.name == "" + def test_from_dict_with_state(self): + s = Snapshot.from_dict({"id": "snap-1", "name": "my-snap", "state": "paused"}) + assert s.state == "paused" + + def test_from_dict_requires_id(self): + with pytest.raises(ValueError, match="Snapshot response missing required non-empty string 'id'"): + Snapshot.from_dict({"name": "my-snap"}) + + def test_from_dict_requires_name(self): + with pytest.raises(ValueError, match="Snapshot response missing required non-empty string 'name'"): + Snapshot.from_dict({"id": "snap-1"}) From c9cdc57ed06682b9a802e47db6d4307f85c8c9c6 Mon Sep 17 00:00:00 2001 From: steven-passynkov Date: Mon, 6 Apr 2026 18:05:33 -0400 Subject: [PATCH 2/4] Fixes --- leap0/_async/desktop.py | 18 ++++++++++------ leap0/_async/filesystem.py | 29 ++----------------------- leap0/_schemas/process.py | 1 + leap0/_sync/desktop.py | 25 +++++++++++++--------- leap0/_sync/filesystem.py | 30 ++------------------------ leap0/_utils/multipart.py | 31 +++++++++++++++++++++++++++ leap0/_utils/stream.py | 15 ++++++++----- leap0/models/config.py | 3 ++- leap0/models/desktop.py | 16 +++++++++++--- leap0/models/process.py | 35 ++++++++++++++++++++++++++++--- leap0/models/sandbox.py | 12 +++++++++-- leap0/models/snapshot.py | 2 ++ tests/_async/test_desktop.py | 23 +++++++++++++++++++- tests/_sync/test_client_config.py | 8 +++++++ tests/_sync/test_desktop.py | 16 +++++++++++++- tests/_sync/test_filesystem.py | 9 ++++---- tests/_utils/test_stream.py | 3 +++ tests/models/test_desktop.py | 24 ++++++++++++++------- tests/models/test_process.py | 12 +++++++++++ tests/models/test_sandbox.py | 8 ++++++- tests/models/test_snapshot.py | 29 +++++++++++++++++++++++++ 21 files changed, 250 insertions(+), 99 deletions(-) create mode 100644 leap0/_utils/multipart.py diff --git a/leap0/_async/desktop.py b/leap0/_async/desktop.py index 028fce9..71f0c26 100644 --- a/leap0/_async/desktop.py +++ b/leap0/_async/desktop.py @@ -6,6 +6,7 @@ from typing import cast import httpx +from pydantic import ValidationError from .._internal.types import JsonObject from ..models.desktop import ( @@ -570,12 +571,17 @@ async def status_stream(self, sandbox: SandboxRef, *, deadline: float | None = N "Malformed desktop status stream event " f"for sandbox={sandbox_id_of(sandbox)!r}, source='status_stream': {event!r}" ) - if "error" in event: - raise Leap0Error("Desktop status stream error", body=str(event["error"])) - if {"error", "message"}.intersection(event) and not {"status", "items", "running", "total"}.intersection(event): - error_event = DesktopStatusStreamErrorEvent.model_validate(event) - raise Leap0Error("Desktop status stream error", body=error_event.detail) - yield DesktopProcessStatusList.from_dict(cast(DesktopProcessStatusListDict, event)) + try: + yield DesktopProcessStatusList.from_dict(cast(DesktopProcessStatusListDict, event)) + continue + except (TypeError, ValueError) as status_error: + try: + error_event = DesktopStatusStreamErrorEvent.model_validate(event) + except ValidationError: + if "error" in event: + raise Leap0Error("Desktop status stream error", body=str(event["error"])) from status_error + raise status_error + raise Leap0Error("Desktop status stream error", body=error_event.detail) from status_error finally: await response.aclose() diff --git a/leap0/_async/filesystem.py b/leap0/_async/filesystem.py index 34233d4..2953ef5 100644 --- a/leap0/_async/filesystem.py +++ b/leap0/_async/filesystem.py @@ -7,6 +7,7 @@ from ..models.sandbox import SandboxRef, sandbox_id_of from .._schemas.filesystem import EditFileResponseDict, EditFilesResponseDict, ExistsResponseDict, FileInfoDict, GlobResponseDict, GrepResponseDict, LsResponseDict, TreeResponseDict from .._utils.errors import intercept_errors +from .._utils.multipart import parse_multipart_response from ._transport import AsyncTransport @@ -499,30 +500,4 @@ async def tree(self, sandbox: SandboxRef, *, path: str, max_depth: int | None = def _parse_multipart_response(content_type: str, body: bytes) -> dict[str, bytes]: - from email.parser import BytesParser - - raw = f"Content-Type: {content_type}\r\n\r\n".encode() + body - msg = BytesParser().parsebytes(raw) - - result: dict[str, bytes] = {} - if not msg.is_multipart(): - raise ValueError( - f"Expected multipart response but got content_type={content_type!r} " - f"(body length={len(body)}, preview='')" - ) - for part in msg.get_payload(): # type: ignore[union-attr] - name = part.get_param("name", header="content-disposition") - if not name: - continue - content_type = part.get_content_type() - if content_type != "application/octet-stream": - raise ValueError( - f"Failed to parse /read-files response: expected file bytes for entry {name!r}, got {content_type}" - ) - payload = part.get_payload(decode=True) - if payload is None: - raise ValueError( - f"Failed to parse /read-files response: expected file bytes for entry {name!r}, got {content_type}" - ) - result[str(name)] = payload - return result + return parse_multipart_response(content_type, body) diff --git a/leap0/_schemas/process.py b/leap0/_schemas/process.py index 55af855..d8f7a44 100644 --- a/leap0/_schemas/process.py +++ b/leap0/_schemas/process.py @@ -7,3 +7,4 @@ class ProcessResultDict(TypedDict, total=False): exit_code: int stdout: str stderr: str + result: object diff --git a/leap0/_sync/desktop.py b/leap0/_sync/desktop.py index ff20cca..41c2562 100644 --- a/leap0/_sync/desktop.py +++ b/leap0/_sync/desktop.py @@ -7,6 +7,7 @@ from typing import cast import httpx +from pydantic import ValidationError from tenacity import retry, retry_if_exception, stop_after_delay, wait_exponential from .._internal.types import JsonObject @@ -633,16 +634,20 @@ def _read_event() -> None: "Malformed desktop status stream event " f"for sandbox={sandbox_id_of(sandbox)!r}, source='status_stream': {event!r}" ) - # Explicit error envelope from the server. - if "error" in event: - raise Leap0Error( - "Desktop status stream error", - body=str(event["error"]), - ) - if {"error", "message"}.intersection(event) and not {"status", "items", "running", "total"}.intersection(event): - error_event = DesktopStatusStreamErrorEvent.model_validate(event) - raise Leap0Error("Desktop status stream error", body=error_event.detail) - yield DesktopProcessStatusList.from_dict(cast(DesktopProcessStatusListDict, event)) + try: + yield DesktopProcessStatusList.from_dict(cast(DesktopProcessStatusListDict, event)) + continue + except (TypeError, ValueError) as status_error: + try: + error_event = DesktopStatusStreamErrorEvent.model_validate(event) + except ValidationError: + if "error" in event: + raise Leap0Error( + "Desktop status stream error", + body=str(event["error"]), + ) from status_error + raise status_error + raise Leap0Error("Desktop status stream error", body=error_event.detail) from status_error finally: response.close() diff --git a/leap0/_sync/filesystem.py b/leap0/_sync/filesystem.py index 83c3209..d8c9af1 100644 --- a/leap0/_sync/filesystem.py +++ b/leap0/_sync/filesystem.py @@ -7,6 +7,7 @@ from ..models.sandbox import SandboxRef, sandbox_id_of from .._schemas.filesystem import EditFileResponseDict, EditFilesResponseDict, ExistsResponseDict, FileInfoDict, GlobResponseDict, GrepResponseDict, LsResponseDict, TreeResponseDict from .._utils.errors import intercept_errors +from .._utils.multipart import parse_multipart_response from ._transport import Transport @@ -499,31 +500,4 @@ def tree(self, sandbox: SandboxRef, *, path: str, max_depth: int | None = None, def _parse_multipart_response(content_type: str, body: bytes) -> dict[str, bytes]: - from email.parser import BytesParser - - raw = f"Content-Type: {content_type}\r\n\r\n".encode() + body - msg = BytesParser().parsebytes(raw) - - result: dict[str, bytes] = {} - if not msg.is_multipart(): - body_preview = body[:200] if len(body) > 200 else body - raise ValueError( - f"Expected multipart response but got content_type={content_type!r} " - f"(body length={len(body)}, preview={body_preview!r})" - ) - for part in msg.get_payload(): # type: ignore[union-attr] - name = part.get_param("name", header="content-disposition") - if not name: - continue - content_type = part.get_content_type() - if content_type != "application/octet-stream": - raise ValueError( - f"Failed to parse /read-files response: expected file bytes for entry {name!r}, got {content_type}" - ) - payload = part.get_payload(decode=True) - if payload is None: - raise ValueError( - f"Failed to parse /read-files response: expected file bytes for entry {name!r}, got {content_type}" - ) - result[str(name)] = payload - return result + return parse_multipart_response(content_type, body) diff --git a/leap0/_utils/multipart.py b/leap0/_utils/multipart.py new file mode 100644 index 0000000..9fc049f --- /dev/null +++ b/leap0/_utils/multipart.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from email.parser import BytesParser + + +def parse_multipart_response(content_type: str, body: bytes) -> dict[str, bytes]: + raw = f"Content-Type: {content_type}\r\n\r\n".encode() + body + msg = BytesParser().parsebytes(raw) + + result: dict[str, bytes] = {} + if not msg.is_multipart(): + raise ValueError( + f"Expected multipart response but got content_type={content_type!r} " + f"(body length={len(body)}, preview='')" + ) + for part in msg.get_payload(): # type: ignore[union-attr] + name = part.get_param("name", header="content-disposition") + if not name: + continue + part_content_type = part.get_content_type() + if part_content_type != "application/octet-stream": + raise ValueError( + f"Failed to parse /read-files response: expected file bytes for entry {name!r}, got {part_content_type}" + ) + payload = part.get_payload(decode=True) + if payload is None: + raise ValueError( + f"Failed to parse /read-files response: expected file bytes for entry {name!r}, got {part_content_type}" + ) + result[str(name)] = payload + return result diff --git a/leap0/_utils/stream.py b/leap0/_utils/stream.py index eab3ec9..2d078ff 100644 --- a/leap0/_utils/stream.py +++ b/leap0/_utils/stream.py @@ -21,15 +21,15 @@ def _sse_data_value(raw: str) -> str: return value -def _parse_sse_data(data: str) -> dict[str, Any] | str: +def _parse_sse_data(data: str) -> Any: try: parsed = json.loads(data) except json.JSONDecodeError: return data - return parsed if isinstance(parsed, dict) else data + return parsed if isinstance(parsed, (dict, list)) else data -def _emit_sse_event(buffer: list[str]) -> dict[str, Any] | str | None: +def _emit_sse_event(buffer: list[str]) -> dict[str, Any] | list[Any] | str | None: data_lines = [_sse_data_value(item) for item in buffer if item.startswith("data:")] if not data_lines: return None @@ -41,11 +41,16 @@ def _emit_sse_event(buffer: list[str]) -> dict[str, Any] | str | None: data = "\n".join(data_lines) if event_name == "error": + parsed = _parse_sse_data(data) + if isinstance(parsed, dict): + return parsed + if isinstance(parsed, list): + return {"error": parsed} return {"error": data} return _parse_sse_data(data) -def iter_sse_events(lines: Iterable[str]) -> Iterator[dict[str, Any] | str]: +def iter_sse_events(lines: Iterable[str]) -> Iterator[dict[str, Any] | list[Any] | str]: """Yield parsed events from an SSE line iterator.""" buffer: list[str] = [] for line in lines: @@ -66,7 +71,7 @@ def iter_sse_events(lines: Iterable[str]) -> Iterator[dict[str, Any] | str]: yield event -async def aiter_sse_events(lines: AsyncIterable[str]) -> AsyncIterator[dict[str, Any] | str]: +async def aiter_sse_events(lines: AsyncIterable[str]) -> AsyncIterator[dict[str, Any] | list[Any] | str]: """Yield parsed events from an asynchronous SSE line iterator.""" buffer: list[str] = [] async for line in lines: diff --git a/leap0/models/config.py b/leap0/models/config.py index e79c605..5f0f428 100644 --- a/leap0/models/config.py +++ b/leap0/models/config.py @@ -38,7 +38,8 @@ def _resolve_sdk_otel_enabled(value: bool | None) -> bool: sdk_otel_env = os.environ.get(LEAP0_SDK_OTEL_ENABLED_ENV) sdk_otel_env = sdk_otel_env.strip() if sdk_otel_env is not None else None if not sdk_otel_env: - return bool(os.environ.get(OTEL_EXPORTER_OTLP_ENDPOINT_ENV)) + endpoint = os.environ.get(OTEL_EXPORTER_OTLP_ENDPOINT_ENV) + return bool(endpoint and endpoint.strip()) lowered = sdk_otel_env.lower() if lowered == "true": diff --git a/leap0/models/desktop.py b/leap0/models/desktop.py index 09a1258..d2ab1a2 100644 --- a/leap0/models/desktop.py +++ b/leap0/models/desktop.py @@ -99,6 +99,8 @@ def _validate_click(self) -> "DesktopClickParams": for name, value in (("x", self.x), ("y", self.y)): if value is not None and value < 0: raise ValueError(f"{name} must be >= 0") + if self.button is not None and self.button not in {1, 2, 3}: + raise ValueError("button must be one of: 1, 2, 3") return self @@ -323,14 +325,22 @@ def from_dict(cls, data: DesktopProcessStatusListDict) -> DesktopProcessStatusLi return cls( status=_require_str(data, "status"), items=[ - DesktopProcessStatus.from_dict(item) # type: ignore[arg-type] - for item in raw_items - if isinstance(item, dict) + DesktopProcessStatus.from_dict(item) + for item in _validated_status_items(raw_items) ], running=_require_int(data, "running"), total=_require_int(data, "total"), ) + +def _validated_status_items(raw_items: list[Any] | tuple[Any, ...]) -> list[DesktopProcessStatusDict]: + validated_items: list[DesktopProcessStatusDict] = [] + for index, item in enumerate(raw_items): + if not isinstance(item, dict): + raise TypeError(f"Desktop response item at index {index} must be a mapping, got: {item!r}") + validated_items.append(item) + return validated_items + @dataclass(slots=True) class DesktopProcessRestart: """Result of restarting a desktop-side process.""" diff --git a/leap0/models/process.py b/leap0/models/process.py index 66024c3..47cfeb6 100644 --- a/leap0/models/process.py +++ b/leap0/models/process.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from .._schemas.process import ProcessResultDict @dataclass(slots=True) @@ -9,12 +9,41 @@ class ProcessResult: exit_code: int stdout: str stderr: str + _legacy_result: str | dict[str, str] | None = field(default=None, repr=False) + + @property + def result(self) -> str | dict[str, str]: + """Backward-compatible alias for legacy result payloads.""" + if self._legacy_result is not None: + return self._legacy_result + return self.stdout @classmethod def from_dict(cls, data: ProcessResultDict) -> ProcessResult: """Build an instance from a wire-format dictionary.""" + legacy_result = data.get("result") + stdout = data.get("stdout") + stderr = data.get("stderr") + + normalized_legacy_result: str | dict[str, str] | None = None + if isinstance(legacy_result, dict): + normalized_legacy_result = { + key: value + for key in ("stdout", "stderr") + if isinstance((value := legacy_result.get(key)), str) + } + if stdout is None: + stdout = legacy_result.get("stdout", "") + if stderr is None: + stderr = legacy_result.get("stderr", "") + elif legacy_result is not None and stdout is None and stderr is None: + normalized_legacy_result = legacy_result if isinstance(legacy_result, str) else None + stdout = legacy_result if isinstance(legacy_result, str) else "" + stderr = "" + return cls( exit_code=int(data.get("exit_code", 0)), - stdout=data.get("stdout", ""), - stderr=data.get("stderr", ""), + stdout=stdout if isinstance(stdout, str) else "", + stderr=stderr if isinstance(stderr, str) else "", + _legacy_result=normalized_legacy_result, ) diff --git a/leap0/models/sandbox.py b/leap0/models/sandbox.py index b7e2733..084a061 100644 --- a/leap0/models/sandbox.py +++ b/leap0/models/sandbox.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Mapping from dataclasses import dataclass from enum import Enum import ipaddress @@ -83,8 +84,15 @@ def _validate_network_policy(policy: NetworkPolicyDict | None) -> NetworkPolicyD if transforms is not None: if len(transforms) > 20: raise ValueError("network_policy.transforms must contain at most 20 entries") - for transform in transforms: - _validate_domain_pattern(transform["domain"]) + for index, transform in enumerate(transforms): + if not isinstance(transform, Mapping): + raise ValueError(f"network_policy.transforms[{index}] must be a mapping, got: {transform!r}") + domain = transform.get("domain") + if domain is None: + raise ValueError(f"network_policy.transforms[{index}] missing required 'domain': {transform!r}") + if not isinstance(domain, str): + raise ValueError(f"network_policy.transforms[{index}].domain must be a string, got: {domain!r}") + _validate_domain_pattern(domain) return policy diff --git a/leap0/models/snapshot.py b/leap0/models/snapshot.py index 04dd87f..1a38c4f 100644 --- a/leap0/models/snapshot.py +++ b/leap0/models/snapshot.py @@ -86,6 +86,8 @@ def from_dict(cls, data: SnapshotCreateResponseDict) -> Snapshot: raise ValueError( f"Snapshot response missing required non-empty string 'name', got: {snapshot_name!r}" ) + snapshot_id = snapshot_id.strip() + snapshot_name = snapshot_name.strip() state = data.get("state") return cls( id=snapshot_id, diff --git a/tests/_async/test_desktop.py b/tests/_async/test_desktop.py index 2a03ce6..3174264 100644 --- a/tests/_async/test_desktop.py +++ b/tests/_async/test_desktop.py @@ -27,7 +27,7 @@ async def run() -> None: with pytest.raises(Leap0Error, match="x and y must be provided together"): await client.click("sbx-1", x=10) - assert async_mock_transport.request.call_count == 0 + assert async_mock_transport.request_target.call_count == 0 assert async_mock_transport.request_target_json.call_count == 0 asyncio.run(run()) @@ -111,6 +111,27 @@ async def aiter_lines(): asyncio.run(run()) + def test_status_stream_raises_structured_error_detail(self, async_mock_transport): + async def run() -> None: + response = MagicMock() + + async def aiter_lines(): + yield 'event: error' + yield 'data: {"message": "Desktop request failed"}' + yield "" + + response.aiter_lines = aiter_lines + response.aclose = AsyncMock() + async_mock_transport.stream.return_value = response + + with pytest.raises(Leap0Error, match="Desktop status stream error") as exc_info: + async for _ in AsyncDesktopClient(async_mock_transport, sandbox_domain="sandbox.example.com").status_stream("sbx-1"): + pass + + assert exc_info.value.body == "Desktop request failed" + + asyncio.run(run()) + def test_process_status_requires_documented_fields(self, async_mock_transport): async def run() -> None: async_mock_transport.request_target_json.return_value = {"items": [], "running": 0, "total": 0} diff --git a/tests/_sync/test_client_config.py b/tests/_sync/test_client_config.py index 4582f67..004f486 100644 --- a/tests/_sync/test_client_config.py +++ b/tests/_sync/test_client_config.py @@ -59,3 +59,11 @@ def test_legacy_otel_env_no_longer_enables_sdk(monkeypatch): config = Leap0Config() assert config.sdk_otel_enabled is False + + +def test_blank_otel_endpoint_does_not_enable_sdk(monkeypatch): + monkeypatch.setenv("LEAP0_API_KEY", "test-key") + monkeypatch.delenv("LEAP0_SDK_OTEL_ENABLED", raising=False) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", " ") + + assert Leap0Config().sdk_otel_enabled is False diff --git a/tests/_sync/test_desktop.py b/tests/_sync/test_desktop.py index c51117f..9a78917 100644 --- a/tests/_sync/test_desktop.py +++ b/tests/_sync/test_desktop.py @@ -25,7 +25,7 @@ def test_validates_request_payloads(self, mock_transport): with pytest.raises(Leap0Error, match="x and y must be provided together"): client.click("sbx-1", x=10) - assert mock_transport.request.call_count == 0 + assert mock_transport.request_target.call_count == 0 assert mock_transport.request_target_json.call_count == 0 def test_screenshot_allows_zero_sized_paired_region_query(self, mock_transport): @@ -107,6 +107,20 @@ def test_status_stream_raises_on_plain_text_error_event(self, mock_transport): with pytest.raises(Leap0Error, match="Desktop status stream error"): list(DesktopClient(mock_transport, sandbox_domain="sandbox.example.com").status_stream("sbx-1")) + def test_status_stream_raises_structured_error_detail(self, mock_transport): + response = MagicMock() + response.iter_lines.return_value = iter([ + "event: error", + 'data: {"message": "Desktop request failed"}', + "", + ]) + mock_transport.stream.return_value = response + + with pytest.raises(Leap0Error, match="Desktop status stream error") as exc_info: + list(DesktopClient(mock_transport, sandbox_domain="sandbox.example.com").status_stream("sbx-1")) + + assert exc_info.value.body == "Desktop request failed" + def test_process_status_requires_documented_fields(self, mock_transport): mock_transport.request_target_json.return_value = {"items": [], "running": 0, "total": 0} diff --git a/tests/_sync/test_filesystem.py b/tests/_sync/test_filesystem.py index fa94eaa..37c95e9 100644 --- a/tests/_sync/test_filesystem.py +++ b/tests/_sync/test_filesystem.py @@ -5,6 +5,7 @@ import pytest from leap0._sync.filesystem import FilesystemClient, _parse_multipart_response +from leap0.models.errors import Leap0Error from leap0.models.filesystem import FileEdit @@ -51,17 +52,17 @@ def test_read_bytes(self, mock_transport): assert FilesystemClient(mock_transport).read_bytes("sbx-1", path="/workspace/hello.bin") == b"Hello" def test_read_bytes_rejects_head_and_tail(self, mock_transport): - with pytest.raises(Exception, match="mutually exclusive"): + with pytest.raises(Leap0Error, match="mutually exclusive"): FilesystemClient(mock_transport).read_bytes("sbx-1", path="/workspace/hello.bin", head=1, tail=1) def test_set_permissions_rejects_missing_or_blank_updates(self, mock_transport): client = FilesystemClient(mock_transport) - with pytest.raises(Exception, match="at least one of mode, owner, or group"): + with pytest.raises(Leap0Error, match="at least one of mode, owner, or group"): client.set_permissions("sbx-1", path="/workspace/a.txt") - with pytest.raises(Exception, match="mode must be a non-empty string"): + with pytest.raises(Leap0Error, match="mode must be a non-empty string"): client.set_permissions("sbx-1", path="/workspace/a.txt", mode=" ") - with pytest.raises(Exception, match="owner must be a non-empty string"): + with pytest.raises(Leap0Error, match="owner must be a non-empty string"): client.set_permissions("sbx-1", path="/workspace/a.txt", owner="") assert mock_transport.request.call_count == 0 diff --git a/tests/_utils/test_stream.py b/tests/_utils/test_stream.py index 5f30450..b31e2be 100644 --- a/tests/_utils/test_stream.py +++ b/tests/_utils/test_stream.py @@ -37,6 +37,9 @@ def test_non_data_fields_ignored(self): def test_plain_text_data_preserved(self): assert list(iter_sse_events(["event: error", "data: desktop stream failed", ""])) == [{"error": "desktop stream failed"}] + def test_error_json_data_parsed(self): + assert list(iter_sse_events(["event: error", 'data: {"error":"boom"}', ""])) == [{"error": "boom"}] + class TestIterNdjson: def test_standard(self): diff --git a/tests/models/test_desktop.py b/tests/models/test_desktop.py index 6e6a4c3..658fb98 100644 --- a/tests/models/test_desktop.py +++ b/tests/models/test_desktop.py @@ -3,7 +3,7 @@ import pytest from leap0.models.desktop import ( - DesktopDisplayInfo, DesktopHealth, DesktopPointerPosition, DesktopProcessErrors, + DesktopClickParams, DesktopDisplayInfo, DesktopHealth, DesktopPointerPosition, DesktopProcessErrors, DesktopProcessLogs, DesktopProcessRestart, DesktopProcessStatus, DesktopProcessStatusList, DesktopRecordingStatus, DesktopRecordingSummary, DesktopWindow, ) @@ -63,13 +63,19 @@ def test_from_dict(self): "name": "xvfb", "running": True, "pid": 123, - "stdout_log": "/tmp/xvfb.stdout.log", - "stderr_log": "/tmp/xvfb.stderr.log", + "stdout_log": "placeholder/xvfb.stdout.log", + "stderr_log": "placeholder/xvfb.stderr.log", } ) assert p.running is True +class TestDesktopClickParams: + def test_rejects_invalid_button(self): + with pytest.raises(ValueError, match="button must be one of: 1, 2, 3"): + DesktopClickParams(button=4) + + class TestDesktopProcessStatusList: def test_from_dict(self): sl = DesktopProcessStatusList.from_dict( @@ -80,8 +86,8 @@ def test_from_dict(self): "name": "xvfb", "running": True, "pid": 1, - "stdout_log": "/tmp/xvfb.stdout.log", - "stderr_log": "/tmp/xvfb.stderr.log", + "stdout_log": "placeholder/xvfb.stdout.log", + "stderr_log": "placeholder/xvfb.stderr.log", } ], "running": 1, @@ -95,6 +101,10 @@ def test_requires_status_fields(self): with pytest.raises(ValueError, match="missing array 'items'"): DesktopProcessStatusList.from_dict({}) + def test_rejects_non_mapping_item(self): + with pytest.raises(TypeError, match="item at index 0 must be a mapping"): + DesktopProcessStatusList.from_dict({"status": "running", "items": ["bad"], "running": 0, "total": 1}) + class TestDesktopProcessRestart: def test_with_status(self): @@ -105,8 +115,8 @@ def test_with_status(self): "name": "xvfb", "running": True, "pid": 42, - "stdout_log": "/tmp/xvfb.stdout.log", - "stderr_log": "/tmp/xvfb.stderr.log", + "stdout_log": "placeholder/xvfb.stdout.log", + "stderr_log": "placeholder/xvfb.stderr.log", }, } ) diff --git a/tests/models/test_process.py b/tests/models/test_process.py index 421dc5e..f7e5aa2 100644 --- a/tests/models/test_process.py +++ b/tests/models/test_process.py @@ -9,3 +9,15 @@ def test_from_dict(self): assert r.exit_code == 1 assert r.stdout == "hello" assert r.stderr == "error output" + + def test_from_dict_accepts_legacy_result_string(self): + r = ProcessResult.from_dict({"exit_code": 0, "result": "hello"}) + assert r.stdout == "hello" + assert r.stderr == "" + assert r.result == "hello" + + def test_from_dict_accepts_legacy_result_mapping(self): + r = ProcessResult.from_dict({"exit_code": 0, "result": {"stdout": "hello", "stderr": "warn"}}) + assert r.stdout == "hello" + assert r.stderr == "warn" + assert r.result == {"stdout": "hello", "stderr": "warn"} diff --git a/tests/models/test_sandbox.py b/tests/models/test_sandbox.py index 33f405e..9229822 100644 --- a/tests/models/test_sandbox.py +++ b/tests/models/test_sandbox.py @@ -2,7 +2,7 @@ import pytest -from leap0.models.sandbox import CreateSandboxParams, Sandbox, SandboxStatus, sandbox_id_of +from leap0.models.sandbox import CreateSandboxParams, Sandbox, SandboxStatus, _validate_network_policy, sandbox_id_of class TestSandboxIdOf: @@ -67,3 +67,9 @@ def test_rejects_invalid_network_policy(self): with pytest.raises(ValueError, match="allow_domains must contain at most 50"): CreateSandboxParams(network_policy={"mode": "custom", "allow_domains": ["a.example.com"] * 51}) + + with pytest.raises(ValueError, match=r"transforms\[0\] missing required 'domain'"): + _validate_network_policy({"mode": "custom", "transforms": [{"rewrite": "x"}]}) + + with pytest.raises(ValueError, match=r"transforms\[0\] must be a mapping"): + _validate_network_policy({"mode": "custom", "transforms": ["bad"]}) diff --git a/tests/models/test_snapshot.py b/tests/models/test_snapshot.py index 3a012ae..e7f3455 100644 --- a/tests/models/test_snapshot.py +++ b/tests/models/test_snapshot.py @@ -29,10 +29,39 @@ def test_from_dict_with_state(self): s = Snapshot.from_dict({"id": "snap-1", "name": "my-snap", "state": "paused"}) assert s.state == "paused" + def test_from_dict_strips_required_fields(self): + s = Snapshot.from_dict({"id": " snap-1 ", "name": " my-snap "}) + assert s.id == "snap-1" + assert s.name == "my-snap" + def test_from_dict_requires_id(self): with pytest.raises(ValueError, match="Snapshot response missing required non-empty string 'id'"): Snapshot.from_dict({"name": "my-snap"}) + def test_from_dict_rejects_empty_id(self): + with pytest.raises(ValueError, match="Snapshot response missing required non-empty string 'id'"): + Snapshot.from_dict({"id": "", "name": "my-snap"}) + + def test_from_dict_rejects_whitespace_only_id(self): + with pytest.raises(ValueError, match="Snapshot response missing required non-empty string 'id'"): + Snapshot.from_dict({"id": " ", "name": "my-snap"}) + + def test_from_dict_rejects_non_string_id(self): + with pytest.raises(ValueError, match="Snapshot response missing required non-empty string 'id'"): + Snapshot.from_dict({"id": 123, "name": "my-snap"}) + def test_from_dict_requires_name(self): with pytest.raises(ValueError, match="Snapshot response missing required non-empty string 'name'"): Snapshot.from_dict({"id": "snap-1"}) + + def test_from_dict_rejects_empty_name(self): + with pytest.raises(ValueError, match="Snapshot response missing required non-empty string 'name'"): + Snapshot.from_dict({"id": "snap-1", "name": ""}) + + def test_from_dict_rejects_whitespace_only_name(self): + with pytest.raises(ValueError, match="Snapshot response missing required non-empty string 'name'"): + Snapshot.from_dict({"id": "snap-1", "name": " "}) + + def test_from_dict_rejects_non_string_name(self): + with pytest.raises(ValueError, match="Snapshot response missing required non-empty string 'name'"): + Snapshot.from_dict({"id": "snap-1", "name": 123}) From 3bd158867f90fa696a0e7f2391c11415fd528b52 Mon Sep 17 00:00:00 2001 From: steven-passynkov Date: Mon, 6 Apr 2026 18:20:58 -0400 Subject: [PATCH 3/4] Fixes --- leap0/_async/filesystem.py | 2 +- leap0/_sync/desktop.py | 2 +- leap0/_sync/filesystem.py | 10 +++++----- leap0/_utils/multipart.py | 13 ++++++++++--- leap0/models/sandbox.py | 5 ++--- tests/_async/test_filesystem.py | 2 +- tests/_sync/test_filesystem.py | 11 ++++++----- tests/models/test_sandbox.py | 2 +- 8 files changed, 27 insertions(+), 20 deletions(-) diff --git a/leap0/_async/filesystem.py b/leap0/_async/filesystem.py index 2953ef5..00b366d 100644 --- a/leap0/_async/filesystem.py +++ b/leap0/_async/filesystem.py @@ -500,4 +500,4 @@ async def tree(self, sandbox: SandboxRef, *, path: str, max_depth: int | None = def _parse_multipart_response(content_type: str, body: bytes) -> dict[str, bytes]: - return parse_multipart_response(content_type, body) + return parse_multipart_response(content_type, body, operation="read_files") diff --git a/leap0/_sync/desktop.py b/leap0/_sync/desktop.py index 41c2562..9f9a44e 100644 --- a/leap0/_sync/desktop.py +++ b/leap0/_sync/desktop.py @@ -646,7 +646,7 @@ def _read_event() -> None: "Desktop status stream error", body=str(event["error"]), ) from status_error - raise status_error + raise raise Leap0Error("Desktop status stream error", body=error_event.detail) from status_error finally: response.close() diff --git a/leap0/_sync/filesystem.py b/leap0/_sync/filesystem.py index d8c9af1..ff8643d 100644 --- a/leap0/_sync/filesystem.py +++ b/leap0/_sync/filesystem.py @@ -277,7 +277,11 @@ def read_files_bytes( json={"paths": paths}, timeout=http_timeout, ) - return _parse_multipart_response(response.headers.get("content-type", ""), response.content) + return parse_multipart_response( + response.headers.get("content-type", ""), + response.content, + operation="read_files", + ) @intercept_errors("Failed to read files: ") def read_files(self, sandbox: SandboxRef, *, paths: list[str], encoding: str = "utf-8", http_timeout: float | None = None) -> dict[str, str]: @@ -497,7 +501,3 @@ def tree(self, sandbox: SandboxRef, *, path: str, max_depth: int | None = None, payload["exclude"] = exclude data = cast(TreeResponseDict, self._transport.request_json("POST", f"/v1/sandbox/{sandbox_id_of(sandbox)}/filesystem/tree", json=payload, timeout=http_timeout)) return TreeResult.from_dict(data) - - -def _parse_multipart_response(content_type: str, body: bytes) -> dict[str, bytes]: - return parse_multipart_response(content_type, body) diff --git a/leap0/_utils/multipart.py b/leap0/_utils/multipart.py index 9fc049f..02b070c 100644 --- a/leap0/_utils/multipart.py +++ b/leap0/_utils/multipart.py @@ -3,9 +3,16 @@ from email.parser import BytesParser -def parse_multipart_response(content_type: str, body: bytes) -> dict[str, bytes]: +def parse_multipart_response( + content_type: str, + body: bytes, + *, + subject: str = "multipart body", + operation: str | None = None, +) -> dict[str, bytes]: raw = f"Content-Type: {content_type}\r\n\r\n".encode() + body msg = BytesParser().parsebytes(raw) + target = f"{operation} {subject}" if operation else subject result: dict[str, bytes] = {} if not msg.is_multipart(): @@ -20,12 +27,12 @@ def parse_multipart_response(content_type: str, body: bytes) -> dict[str, bytes] part_content_type = part.get_content_type() if part_content_type != "application/octet-stream": raise ValueError( - f"Failed to parse /read-files response: expected file bytes for entry {name!r}, got {part_content_type}" + f"Failed to parse {target}: expected file bytes for entry {name!r}, got {part_content_type}" ) payload = part.get_payload(decode=True) if payload is None: raise ValueError( - f"Failed to parse /read-files response: expected file bytes for entry {name!r}, got {part_content_type}" + f"Failed to parse {target}: expected file bytes for entry {name!r}, got {part_content_type}" ) result[str(name)] = payload return result diff --git a/leap0/models/sandbox.py b/leap0/models/sandbox.py index 084a061..15130b4 100644 --- a/leap0/models/sandbox.py +++ b/leap0/models/sandbox.py @@ -67,8 +67,7 @@ def _validate_network_policy(policy: NetworkPolicyDict | None) -> NetworkPolicyD if allow_domains is not None: if len(allow_domains) > 50: raise ValueError("network_policy.allow_domains must contain at most 50 entries") - for domain in allow_domains: - _validate_domain_pattern(domain) + policy["allow_domains"] = [_validate_domain_pattern(domain) for domain in allow_domains] allow_cidrs = policy.get("allow_cidrs") if allow_cidrs is not None: @@ -92,7 +91,7 @@ def _validate_network_policy(policy: NetworkPolicyDict | None) -> NetworkPolicyD raise ValueError(f"network_policy.transforms[{index}] missing required 'domain': {transform!r}") if not isinstance(domain, str): raise ValueError(f"network_policy.transforms[{index}].domain must be a string, got: {domain!r}") - _validate_domain_pattern(domain) + transform["domain"] = _validate_domain_pattern(domain) return policy diff --git a/tests/_async/test_filesystem.py b/tests/_async/test_filesystem.py index d327392..0cb1b6f 100644 --- a/tests/_async/test_filesystem.py +++ b/tests/_async/test_filesystem.py @@ -82,5 +82,5 @@ def test_text_part_raises(self): f"Content-Type: text/plain; charset=utf-8\r\n\r\ncontent a\r\n" f"--{boundary}--\r\n" ).encode() - with pytest.raises(ValueError, match="Failed to parse /read-files response"): + with pytest.raises(ValueError, match="Failed to parse read_files multipart body"): _parse_multipart_response(f"multipart/form-data; boundary={boundary}", body) diff --git a/tests/_sync/test_filesystem.py b/tests/_sync/test_filesystem.py index 37c95e9..3cf81f4 100644 --- a/tests/_sync/test_filesystem.py +++ b/tests/_sync/test_filesystem.py @@ -4,7 +4,8 @@ import pytest -from leap0._sync.filesystem import FilesystemClient, _parse_multipart_response +from leap0._sync.filesystem import FilesystemClient +from leap0._utils.multipart import parse_multipart_response from leap0.models.errors import Leap0Error from leap0.models.filesystem import FileEdit @@ -94,12 +95,12 @@ def test_valid(self): f"Content-Type: application/octet-stream\r\n\r\ncontent a\r\n" f"--{boundary}--\r\n" ).encode() - result = _parse_multipart_response(f"multipart/form-data; boundary={boundary}", body) + result = parse_multipart_response(f"multipart/form-data; boundary={boundary}", body) assert result["/a.txt"] == b"content a" def test_non_multipart_raises(self): with pytest.raises(ValueError, match="Expected multipart"): - _parse_multipart_response("application/json", b'{"error": "bad"}') + parse_multipart_response("application/json", b'{"error": "bad"}') def test_text_part_raises(self): boundary = "boundary123" @@ -108,5 +109,5 @@ def test_text_part_raises(self): f"Content-Type: text/plain; charset=utf-8\r\n\r\ncontent a\r\n" f"--{boundary}--\r\n" ).encode() - with pytest.raises(ValueError, match="Failed to parse /read-files response"): - _parse_multipart_response(f"multipart/form-data; boundary={boundary}", body) + with pytest.raises(ValueError, match="Failed to parse read_files multipart body"): + parse_multipart_response(f"multipart/form-data; boundary={boundary}", body, operation="read_files") diff --git a/tests/models/test_sandbox.py b/tests/models/test_sandbox.py index 9229822..ef61435 100644 --- a/tests/models/test_sandbox.py +++ b/tests/models/test_sandbox.py @@ -56,7 +56,7 @@ def test_empty_dict_raises(self): class TestCreateSandboxParams: def test_rejects_invalid_network_policy(self): - with pytest.raises(ValueError, match="network_policy.mode"): + with pytest.raises(ValueError, match=r"network_policy\.mode"): CreateSandboxParams(network_policy={"mode": "nope"}) with pytest.raises(ValueError, match="invalid network policy domain pattern"): From 05ab76caaf564b4865de911c61b092169369ec0c Mon Sep 17 00:00:00 2001 From: steven-passynkov Date: Mon, 6 Apr 2026 18:35:57 -0400 Subject: [PATCH 4/4] Fixes --- leap0/_async/desktop.py | 2 -- leap0/_sync/desktop.py | 5 ----- leap0/_utils/stream.py | 12 ------------ leap0/models/desktop.py | 9 ++++----- tests/_async/test_desktop.py | 19 ------------------- tests/_sync/test_desktop.py | 12 ------------ tests/_utils/test_stream.py | 3 --- 7 files changed, 4 insertions(+), 58 deletions(-) diff --git a/leap0/_async/desktop.py b/leap0/_async/desktop.py index 71f0c26..c5d42cd 100644 --- a/leap0/_async/desktop.py +++ b/leap0/_async/desktop.py @@ -578,8 +578,6 @@ async def status_stream(self, sandbox: SandboxRef, *, deadline: float | None = N try: error_event = DesktopStatusStreamErrorEvent.model_validate(event) except ValidationError: - if "error" in event: - raise Leap0Error("Desktop status stream error", body=str(event["error"])) from status_error raise status_error raise Leap0Error("Desktop status stream error", body=error_event.detail) from status_error finally: diff --git a/leap0/_sync/desktop.py b/leap0/_sync/desktop.py index 9f9a44e..9bd8f91 100644 --- a/leap0/_sync/desktop.py +++ b/leap0/_sync/desktop.py @@ -641,11 +641,6 @@ def _read_event() -> None: try: error_event = DesktopStatusStreamErrorEvent.model_validate(event) except ValidationError: - if "error" in event: - raise Leap0Error( - "Desktop status stream error", - body=str(event["error"]), - ) from status_error raise raise Leap0Error("Desktop status stream error", body=error_event.detail) from status_error finally: diff --git a/leap0/_utils/stream.py b/leap0/_utils/stream.py index 2d078ff..15e7341 100644 --- a/leap0/_utils/stream.py +++ b/leap0/_utils/stream.py @@ -34,19 +34,7 @@ def _emit_sse_event(buffer: list[str]) -> dict[str, Any] | list[Any] | str | Non if not data_lines: return None - event_name: str | None = None - for item in buffer: - if item.startswith("event:"): - event_name = item[6:].lstrip(" ") - data = "\n".join(data_lines) - if event_name == "error": - parsed = _parse_sse_data(data) - if isinstance(parsed, dict): - return parsed - if isinstance(parsed, list): - return {"error": parsed} - return {"error": data} return _parse_sse_data(data) diff --git a/leap0/models/desktop.py b/leap0/models/desktop.py index d2ab1a2..90ab30b 100644 --- a/leap0/models/desktop.py +++ b/leap0/models/desktop.py @@ -124,19 +124,18 @@ class DesktopStatusStreamErrorEvent(BaseModel): model_config = ConfigDict(extra="allow") - error: str | None = None - message: str | None = None + message: str @model_validator(mode="after") def _validate_error(self) -> "DesktopStatusStreamErrorEvent": - if self.error is None and self.message is None: - raise ValueError("Desktop status stream error event must include error or message") + if not self.message.strip(): + raise ValueError("Desktop status stream error event must include a non-empty message") return self @property def detail(self) -> str: """Return the normalized human-readable error detail.""" - return self.error or self.message or "unknown desktop status stream error" + return self.message def _require_str(data: dict[str, Any], field: str) -> str: value = data.get(field) diff --git a/tests/_async/test_desktop.py b/tests/_async/test_desktop.py index 3174264..e8d4d63 100644 --- a/tests/_async/test_desktop.py +++ b/tests/_async/test_desktop.py @@ -92,25 +92,6 @@ async def aiter_lines(): asyncio.run(run()) - def test_status_stream_raises_on_plain_text_error_event(self, async_mock_transport): - async def run() -> None: - response = MagicMock() - - async def aiter_lines(): - yield "event: error" - yield "data: Desktop request failed" - yield "" - - response.aiter_lines = aiter_lines - response.aclose = AsyncMock() - async_mock_transport.stream.return_value = response - - with pytest.raises(Leap0Error, match="Desktop status stream error"): - async for _ in AsyncDesktopClient(async_mock_transport, sandbox_domain="sandbox.example.com").status_stream("sbx-1"): - pass - - asyncio.run(run()) - def test_status_stream_raises_structured_error_detail(self, async_mock_transport): async def run() -> None: response = MagicMock() diff --git a/tests/_sync/test_desktop.py b/tests/_sync/test_desktop.py index 9a78917..6d9c63e 100644 --- a/tests/_sync/test_desktop.py +++ b/tests/_sync/test_desktop.py @@ -95,18 +95,6 @@ def test_wait_until_ready_accepts_count_only_running_updates(self, mock_transpor DesktopClient(mock_transport, sandbox_domain="sandbox.example.com").wait_until_ready("sbx-1", timeout=1) - def test_status_stream_raises_on_plain_text_error_event(self, mock_transport): - response = MagicMock() - response.iter_lines.return_value = iter([ - "event: error", - "data: Desktop request failed", - "", - ]) - mock_transport.stream.return_value = response - - with pytest.raises(Leap0Error, match="Desktop status stream error"): - list(DesktopClient(mock_transport, sandbox_domain="sandbox.example.com").status_stream("sbx-1")) - def test_status_stream_raises_structured_error_detail(self, mock_transport): response = MagicMock() response.iter_lines.return_value = iter([ diff --git a/tests/_utils/test_stream.py b/tests/_utils/test_stream.py index b31e2be..4a10ee7 100644 --- a/tests/_utils/test_stream.py +++ b/tests/_utils/test_stream.py @@ -34,9 +34,6 @@ def test_leading_space_stripped(self): def test_non_data_fields_ignored(self): assert list(iter_sse_events(["event: update", "id: 42", "data: {\"ok\": true}", ""])) == [{"ok": True}] - def test_plain_text_data_preserved(self): - assert list(iter_sse_events(["event: error", "data: desktop stream failed", ""])) == [{"error": "desktop stream failed"}] - def test_error_json_data_parsed(self): assert list(iter_sse_events(["event: error", 'data: {"error":"boom"}', ""])) == [{"error": "boom"}]