From 29bf7fc6da09e96abe614ac17617a5e42e4e688d Mon Sep 17 00:00:00 2001 From: xiami762 <> Date: Tue, 26 May 2026 17:31:43 +0800 Subject: [PATCH] feat(web2cli,session): multipart payloads, manual auth headers, tool-loop exit fix Extend web2cli spec/CLI generation for multipart uploads, manual HEADER auth, and time-range params; keep the session loop running when assistant messages still contain tool parts so results can be fed back to the model. Co-authored-by: Cursor --- .flocks/plugins/skills/onesec-use/SKILL.md | 2 +- .flocks/plugins/skills/web2cli/SKILL.md | 25 ++- .../skills/web2cli/scripts/generate-cli.py | 148 +++++++++++++++-- .../skills/web2cli/scripts/generate-spec.py | 105 +++++++++--- flocks/browser/admin.py | 2 +- flocks/session/session_loop.py | 19 ++- tests/session/test_session_abort_inject.py | 129 +++++++++++++++ tests/tool/test_web2cli_generate_cli.py | 155 +++++++++++++++++- tests/tool/test_web2cli_generate_spec.py | 115 +++++++++++++ 9 files changed, 646 insertions(+), 54 deletions(-) diff --git a/.flocks/plugins/skills/onesec-use/SKILL.md b/.flocks/plugins/skills/onesec-use/SKILL.md index 408be8dc4..968ba0725 100644 --- a/.flocks/plugins/skills/onesec-use/SKILL.md +++ b/.flocks/plugins/skills/onesec-use/SKILL.md @@ -1,6 +1,6 @@ --- name: onesec-use -description: 用于处理 OneSEC 终端安全平台相关任务,适合通过API或者结合浏览器进行以下任务: 终端安全调查、威胁事件分析、终端告警检索、行为日志排查、IOC 查询、恶意文件分析、DNS 威胁排查、软件与终端资产查询、任务进度查看、审计日志分析、病毒扫描和常见终端处置场景。只要用户提到 OneSEC、微步 EDR等相关操纵需求时,必须先加载本 skill。本 skill 是 OneSEC 平台操作的唯一决策入口:在未阅读本 skill 并完成模式判断前,不要直接调用任何 `onesec_*` tool。 +description: 用于处理 OneSEC/OneDNS 终端安全平台相关任务,适合通过API或者结合浏览器进行以下任务: 终端安全调查、威胁事件分析、终端告警检索、行为日志排查、IOC 查询、恶意文件分析、DNS 威胁排查、软件与终端资产查询、任务进度查看、审计日志分析、病毒扫描和常见终端处置场景。只要用户提到 OneSEC、微步 EDR等相关操纵需求时,必须先加载本 skill。本 skill 是 OneSEC 平台操作的唯一决策入口:在未阅读本 skill 并完成模式判断前,不要直接调用任何 `onesec_*` tool。 --- # OneSEC Use diff --git a/.flocks/plugins/skills/web2cli/SKILL.md b/.flocks/plugins/skills/web2cli/SKILL.md index 0f0311d83..27fe5685f 100644 --- a/.flocks/plugins/skills/web2cli/SKILL.md +++ b/.flocks/plugins/skills/web2cli/SKILL.md @@ -270,7 +270,7 @@ uv run python .flocks/plugins/skills/web2cli/scripts/generate-spec.py \ - 目标站点与命令名 - 鉴权策略(如 `PUBLIC` / `COOKIE` / `HEADER`) -- 主请求的 method、endpoint、query/body 模板 +- 主请求的 method、endpoint、query/body/payload 模板 - CLI 参数定义 - 固定输出列定义 - 验证材料初稿 @@ -314,16 +314,24 @@ uv run python .flocks/plugins/skills/web2cli/scripts/generate-cli.py \ --output "$CAPTURE_ROOT/${CAPTURE_NAME}_api.md" ``` -### 10. CLI工具验证 和浏览器关闭 +### 10. CLI工具验证与修改 根据生成的 CLI ,任意选择一个接口调用测试可用性 - CLI 工具可用性 - 认证状态可用性 - `verify.json` 的输出约束是否满足 +- method、endpoint、query/body/payload 的一致性,必要时根据${CAPTURE_NAME}_api.json调整 推荐先查看 `"$CAPTURE_ROOT/${CAPTURE_NAME}_verify.json"`,再用生成的 CLI 以默认参数执行一次,确认固定输出列与认证状态都正确。 -当验证完成,确保 CLI 可用后关闭浏览器或 Tab +### 11. CLI 工具集成到skill + +将 CLI 按 `references/cli-in-skill.md` 集成为 skill; + +### 12. summary并关闭浏览器 tab + +1. 总结当前生成的 CLI 工具有哪些接口/能力 +2. 确保 CLI 可用后关闭浏览器或 Tab #### 关闭浏览器或 Tab @@ -343,17 +351,6 @@ else: 必须保留用户原有的 tab 不受影响。 -### 11. CLI 工具集成到skill - -将 CLI 按 `references/cli-in-skill.md` 集成为 skill; - -### 12. summary - -总结当前生成的 CLI 工具有哪些能力,然后可提示用户下一步操作: - -- 精简或修正 CLI -- 若仍需扩展能力或沉淀为 skill,回到步骤 11 - ## 故障处理 ### Hook 注入报错 diff --git a/.flocks/plugins/skills/web2cli/scripts/generate-cli.py b/.flocks/plugins/skills/web2cli/scripts/generate-cli.py index 67295d931..c2bed6901 100644 --- a/.flocks/plugins/skills/web2cli/scripts/generate-cli.py +++ b/.flocks/plugins/skills/web2cli/scripts/generate-cli.py @@ -711,6 +711,9 @@ def generate_postman_collection_from_spec(spec: Dict[str, Any]) -> Dict[str, Any def generate_python_cli_from_spec(spec: Dict[str, Any]) -> str: """Generate a fixed command CLI script from a web2cli spec.""" spec_json = json.dumps(spec, indent=2, ensure_ascii=False) + spec_json = re.sub(r'\btrue\b', 'True', spec_json) + spec_json = re.sub(r'\bfalse\b', 'False', spec_json) + spec_json = re.sub(r'\bnull\b', 'None', spec_json) return '''#!/usr/bin/env python3 """ Auto-generated Web2CLI command script. @@ -720,6 +723,7 @@ def generate_python_cli_from_spec(spec: Dict[str, Any]) -> str: import argparse import csv import json +from pathlib import Path import re import sys from typing import Any, Dict, List @@ -764,6 +768,25 @@ def _coerce_bool(value: str) -> bool: raise argparse.ArgumentTypeError(f"invalid boolean value: {value}") +def _auth_header_dest(header_name: str) -> str: + normalized = re.sub(r"[^A-Za-z0-9]+", "_", str(header_name or "")).strip("_").lower() + return f"auth_header_{normalized or 'value'}" + + +def _manual_auth_rules() -> List[Dict[str, Any]]: + auth = SPEC.get("auth", {}) + if not isinstance(auth, dict): + return [] + rules = auth.get("requiredHeaders", []) + if not isinstance(rules, list): + return [] + return [ + rule + for rule in rules + if isinstance(rule, dict) and rule.get("source") == "manual" and rule.get("name") + ] + + def _type_name(value: Any) -> str: if value is None: return "null" @@ -992,12 +1015,55 @@ def _extract_first(cls, value: Any, path: str) -> Any: values = cls._extract_many(value, path) return values[0] if values else None - def __init__(self, base_url: str = SPEC.get("baseUrl", ""), auth_state: str = "auth-state.json"): + @staticmethod + def _stringify_multipart_value(value: Any) -> str: + if value is None: + return "" + if isinstance(value, (dict, list)): + return json.dumps(value, ensure_ascii=False) + return str(value) + + @classmethod + def _build_multipart_files( + cls, + body: Dict[str, Any], + file_fields: List[str], + ) -> tuple[List[Any], List[Any]]: + files = [] + opened_files = [] + target_fields = {str(item) for item in file_fields if item} + for key, value in (body or {}).items(): + if key in target_fields: + file_path = Path(str(value or "")) + if not str(value or "").strip(): + raise SystemExit(f"missing required multipart file path: {key}") + try: + handle = file_path.open("rb") + except OSError as error: + raise SystemExit(f"failed to open multipart file for {key}: {error}") from error + opened_files.append(handle) + files.append((key, (file_path.name, handle))) + else: + files.append((key, (None, cls._stringify_multipart_value(value)))) + return files, opened_files + + def __init__( + self, + base_url: str = SPEC.get("baseUrl", ""), + auth_state: str = "auth-state.json", + manual_headers: Dict[str, str] | None = None, + ): self.base_url = (base_url or SPEC.get("baseUrl", "")).rstrip("/") self.auth_state_path = auth_state self.auth_state = _load_json(auth_state) if auth_state else {} if not isinstance(self.auth_state, dict): self.auth_state = {} + raw_manual_headers = manual_headers if isinstance(manual_headers, dict) else {} + self.manual_headers = { + str(key): str(value) + for key, value in raw_manual_headers.items() + if value not in (None, "") + } self.session = requests.Session() self._apply_auth_state() @@ -1009,16 +1075,25 @@ def _apply_auth_state(self) -> None: self.session.headers.update(headers) if strategy == "HEADER": + missing_manual_headers = [] for rule in auth.get("requiredHeaders", []): if not isinstance(rule, dict) or not rule.get("name"): continue source = rule.get("source") if source == "cookie": value = self._resolve_cookie_value(rule.get("key")) + elif source == "manual": + value = self.manual_headers.get(str(rule["name"])) else: value = self._resolve_header_value(self.auth_state, rule) if value is not None: self.session.headers[str(rule["name"])] = value + elif source == "manual": + missing_manual_headers.append(str(rule["name"])) + if missing_manual_headers: + raise SystemExit( + "missing required auth headers: " + ", ".join(sorted(missing_manual_headers)) + ) def build_request(self, args: Dict[str, Any], entry: Dict[str, Any]) -> Dict[str, Any]: operation = entry.get("operation", {}) @@ -1035,12 +1110,24 @@ def build_request(self, args: Dict[str, Any], entry: Dict[str, Any]) -> Dict[str "params": query or None, "json": None, "data": None, + "files": None, + "opened_files": [], "headers": headers or None, } if payload_mode == "json": request_options["json"] = body or None elif payload_mode == "form": request_options["data"] = body or None + elif payload_mode == "multipart": + multipart_body = body if isinstance(body, dict) else {} + multipart_files, opened_files = self._build_multipart_files( + multipart_body, + operation.get("multipartFileFields", []), + ) + headers.pop("Content-Type", None) + headers.pop("content-type", None) + request_options["files"] = multipart_files or None + request_options["opened_files"] = opened_files elif payload_mode == "raw": request_options["data"] = raw_body or None if cookie_strategy in {"COOKIE", "HEADER"}: @@ -1053,7 +1140,9 @@ def build_request(self, args: Dict[str, Any], entry: Dict[str, Any]) -> Dict[str "params": request_options["params"], "json": request_options["json"], "data": request_options["data"], - "headers": request_options["headers"], + "files": request_options["files"], + "opened_files": request_options["opened_files"], + "headers": headers or None, } def _project_rows(self, payload: Any, entry: Dict[str, Any]) -> List[Dict[str, Any]]: @@ -1084,16 +1173,28 @@ def _project_rows(self, payload: Any, entry: Dict[str, Any]) -> List[Dict[str, A def run(self, args: Dict[str, Any], entry: Dict[str, Any] | None = None) -> List[Dict[str, Any]]: operation_entry = entry or _operation_entries()[0] request_options = self.build_request(args, operation_entry) - response = self.session.request( - request_options["method"], - request_options["url"], - params=request_options["params"], - json=request_options["json"], - data=request_options["data"], - headers=request_options["headers"], - ) - response.raise_for_status() - return self._project_rows(response.json(), operation_entry) + request_kwargs = { + "params": request_options["params"], + "json": request_options["json"], + "data": request_options["data"], + "headers": request_options["headers"], + } + if request_options["files"] is not None: + request_kwargs["files"] = request_options["files"] + try: + response = self.session.request( + request_options["method"], + request_options["url"], + **request_kwargs, + ) + response.raise_for_status() + return self._project_rows(response.json(), operation_entry) + finally: + for handle in request_options.get("opened_files", []): + try: + handle.close() + except OSError: + pass def verify_rows(rows: List[Dict[str, Any]], verify_spec: Dict[str, Any]) -> List[str]: @@ -1157,6 +1258,17 @@ def _add_output_arguments(parser: argparse.ArgumentParser) -> None: parser.add_argument("--verify-spec", help="Optional verify JSON path") +def _add_manual_auth_arguments(parser: argparse.ArgumentParser) -> None: + for rule in _manual_auth_rules(): + header_name = str(rule["name"]) + option_name = re.sub(r"[^a-z0-9]+", "-", header_name.lower()).strip("-") + parser.add_argument( + f"--auth-header-{option_name}", + dest=_auth_header_dest(header_name), + help=f"Value for required header {header_name}", + ) + + def _add_operation_arguments(parser: argparse.ArgumentParser, entry: Dict[str, Any]) -> None: for arg in entry.get("args", []): if not isinstance(arg, dict) or not arg.get("name"): @@ -1187,6 +1299,7 @@ def build_parser() -> argparse.ArgumentParser: default=(SPEC.get("auth", {}) or {}).get("stateFile", "auth-state.json"), help="Path to auth state JSON", ) + _add_manual_auth_arguments(parser) entries = _operation_entries() if _uses_subcommands(): subparsers = parser.add_subparsers(dest="command", required=True) @@ -1213,7 +1326,16 @@ def main() -> None: for item in entry.get("args", []) if isinstance(item, dict) and item.get("name") } - client = APIClient(base_url=parsed.base_url, auth_state=parsed.auth_state) + manual_headers = { + str(rule["name"]): getattr(parsed, _auth_header_dest(str(rule["name"]))) + for rule in _manual_auth_rules() + if getattr(parsed, _auth_header_dest(str(rule["name"])), None) not in (None, "") + } + client = APIClient( + base_url=parsed.base_url, + auth_state=parsed.auth_state, + manual_headers=manual_headers, + ) rows = client.run(runtime_args, entry) if parsed.verify: diff --git a/.flocks/plugins/skills/web2cli/scripts/generate-spec.py b/.flocks/plugins/skills/web2cli/scripts/generate-spec.py index 0f7e122d4..682aac5e5 100644 --- a/.flocks/plugins/skills/web2cli/scripts/generate-spec.py +++ b/.flocks/plugins/skills/web2cli/scripts/generate-spec.py @@ -13,8 +13,15 @@ from urllib.parse import parse_qsl, urlparse -PAGE_PARAM_NAMES = {"page", "pageNo", "pageNum", "current", "pageIndex", "curPage"} +PAGE_PARAM_NAMES = {"page", "pageNo", "pageNum", "current", "pageIndex", "curPage", "cur_page"} LIMIT_PARAM_NAMES = {"limit", "size", "pageSize", "page_size", "page_limit", "rows"} +TIME_PARAM_NAMES = { + "time_from", "time_to", "start_time", "end_time", + "create_time_from", "create_time_to", "update_time_from", "update_time_to", + "last_update_time_from", "last_update_time_to", "begin_time", "endtime", + "timeRange", "startTime", "endTime", +} +MIN_OPERATION_SCORE = 40 def sanitize_name(name: str) -> str: @@ -252,10 +259,11 @@ def collect_columns(item: Any) -> list[dict[str, Any]]: def build_templates( request: dict[str, Any], url_info: dict[str, Any] -) -> tuple[dict[str, Any], dict[str, Any], list[dict[str, Any]], str, str]: +) -> tuple[dict[str, Any], dict[str, Any], list[dict[str, Any]], str, str, list[str]]: """Build query/body templates, payload mode, and CLI arg definitions.""" args: list[dict[str, Any]] = [] seen_args: set[str] = set() + multipart_file_fields: list[str] = [] def add_arg(name: str, default: Any, help_text: str) -> None: if name in seen_args: @@ -264,17 +272,60 @@ def add_arg(name: str, default: Any, help_text: str) -> None: arg_type = "int" if isinstance(default, int) else "string" args.append({"name": name, "type": arg_type, "default": default, "help": help_text}) - def transform_mapping(data: dict[str, Any]) -> dict[str, Any]: + def _is_special_param_key(key: str) -> bool: + return key in PAGE_PARAM_NAMES or key in LIMIT_PARAM_NAMES or key in TIME_PARAM_NAMES + + def _template_scalar(key: str, value: Any) -> tuple[str, Any] | None: + if key in PAGE_PARAM_NAMES: + default = int(value) if str(value).isdigit() else 1 + add_arg("page", default, "Page number") + return "${page}", default + if key in LIMIT_PARAM_NAMES: + default = int(value) if str(value).isdigit() else 20 + add_arg("limit", default, "Page size") + return "${limit}", default + if key in TIME_PARAM_NAMES: + default = int(value) if str(value).isdigit() else value + add_arg(key, default, "Time parameter (Unix timestamp)") + return "${" + key + "}", default + return None + + def _multipart_file_template(path_tokens: tuple[str, ...], value: Any) -> str | None: + if payload_mode != "multipart" or value != "[file]": + return None + field_name = ".".join(path_tokens) + arg_name = sanitize_name("_".join(path_tokens) + "_file") + add_arg(arg_name, "", f"File path for multipart field {field_name}") + if field_name not in multipart_file_fields: + multipart_file_fields.append(field_name) + return "${" + arg_name + "}" + + def transform_mapping(data: dict[str, Any], path_tokens: tuple[str, ...] = ()) -> dict[str, Any]: result: dict[str, Any] = {} for key, value in data.items(): - if key in PAGE_PARAM_NAMES: - default = int(value) if str(value).isdigit() else 1 - result[key] = "${page}" - add_arg("page", default, "Page number") - elif key in LIMIT_PARAM_NAMES: - default = int(value) if str(value).isdigit() else 20 - result[key] = "${limit}" - add_arg("limit", default, "Page size") + field_path = path_tokens + (str(key),) + multipart_template = _multipart_file_template(field_path, value) + if multipart_template is not None: + result[key] = multipart_template + continue + if _is_special_param_key(key): + if isinstance(value, dict): + result[key] = transform_mapping(value, field_path) + elif isinstance(value, list): + result[key] = [ + transform_mapping(item, field_path) if isinstance(item, dict) else item + for item in value + ] + else: + template_result = _template_scalar(key, value) + result[key] = template_result[0] if template_result else value + elif isinstance(value, dict): + result[key] = transform_mapping(value, field_path) + elif isinstance(value, list): + result[key] = [ + transform_mapping(item, field_path) if isinstance(item, dict) else item + for item in value + ] else: result[key] = value return result @@ -290,9 +341,11 @@ def transform_mapping(data: dict[str, Any]) -> dict[str, Any]: if isinstance(parsed_body, dict) and "raw" not in parsed_body: body = parsed_body if body: - if body_kind in {"urlencoded", "formdata"}: + if body_kind == "formdata" or "multipart/form-data" in content_type: + payload_mode = "multipart" + elif body_kind == "urlencoded": payload_mode = "form" - elif "application/x-www-form-urlencoded" in content_type or "multipart/form-data" in content_type: + elif "application/x-www-form-urlencoded" in content_type: payload_mode = "form" else: payload_mode = "json" @@ -307,8 +360,13 @@ def transform_mapping(data: dict[str, Any]) -> dict[str, Any]: query_template = transform_mapping(url_info["query"]) body_template = transform_mapping(body) - args.sort(key=lambda item: (0 if item["name"] == "page" else 1 if item["name"] == "limit" else 2, item["name"])) - return query_template, body_template, args, payload_mode, raw_body_template + args.sort(key=lambda item: ( + 0 if item["name"] == "page" else + 1 if item["name"] == "limit" else + 2 if item["name"] in TIME_PARAM_NAMES else 3, + item["name"], + )) + return query_template, body_template, args, payload_mode, raw_body_template, multipart_file_fields def build_strategy(request: dict[str, Any]) -> tuple[str, dict[str, Any]]: @@ -332,12 +390,15 @@ def build_strategy(request: dict[str, Any]) -> tuple[str, dict[str, Any]]: return strategy, {"stateFile": "auth-state.json", "requiredCookies": [], "requiredHeaders": required_headers} -def safe_headers(request: dict[str, Any]) -> dict[str, Any]: +def safe_headers(request: dict[str, Any], payload_mode: str = "") -> dict[str, Any]: """Return non-sensitive request headers that can be replayed safely.""" headers = request.get("requestHeaders", {}) or request.get("request_headers", {}) result = {} for key, value in headers.items(): - if str(key).lower() in {"cookie", "authorization", "x-csrf-token", "x-xsrf-token", "x-auth-token"}: + key_name = str(key).lower() + if key_name in {"cookie", "authorization", "x-csrf-token", "x-xsrf-token", "x-auth-token"}: + continue + if payload_mode == "multipart" and key_name == "content-type": continue result[key] = value return result @@ -380,7 +441,7 @@ def select_operation_requests(requests: list[dict[str, Any]]) -> list[dict[str, selected[key] = {"index": index, "score": score, "request": request} candidates = sorted(selected.values(), key=lambda item: item["index"]) - filtered = [item for item in candidates if item["score"] >= 60] + filtered = [item for item in candidates if item["score"] >= MIN_OPERATION_SCORE] return [item["request"] for item in (filtered or candidates[:1])] @@ -400,7 +461,10 @@ def build_operation_entry(request: dict[str, Any]) -> dict[str, Any]: response = parse_json_text(str(request.get("response", ""))) collection = find_best_collection(response) row_item = collection["item"] if collection is not None else response - query_template, body_template, args, payload_mode, raw_body_template = build_templates(request, url_info) + query_template, body_template, args, payload_mode, raw_body_template, multipart_file_fields = build_templates( + request, + url_info, + ) columns = collect_columns(row_item) defaults = {item["name"]: item["default"] for item in args} @@ -424,7 +488,8 @@ def build_operation_entry(request: dict[str, Any]) -> dict[str, Any]: "bodyTemplate": body_template, "payloadMode": payload_mode, "rawBodyTemplate": raw_body_template, - "headers": safe_headers(request), + "multipartFileFields": multipart_file_fields, + "headers": safe_headers(request, payload_mode), "captureSource": request.get("captureSource", "pageHook"), "captureReason": request.get("captureReason", ""), "sourceRequestId": request.get("timestamp", ""), diff --git a/flocks/browser/admin.py b/flocks/browser/admin.py index 224dfeeeb..31b96020f 100644 --- a/flocks/browser/admin.py +++ b/flocks/browser/admin.py @@ -487,7 +487,7 @@ def row(label: str, ok: bool, detail: str = "") -> None: browser_running, "" if browser_running else "start Chrome, Chromium, or Edge and rerun `flocks browser --setup`", ) - row("daemon alive", daemon, "" if daemon else "not running; wait user open browser inspect page then run `flocks browser --setup` to attach") + row("daemon alive", daemon, "" if daemon else "not running; wait user open browser inspect page and run `flocks browser --setup` to attach") row("active browser connections", bool(connections), str(len(connections))) for conn in connections: page = conn.get("page") diff --git a/flocks/session/session_loop.py b/flocks/session/session_loop.py index 48d7e1bf1..6b2e1f455 100644 --- a/flocks/session/session_loop.py +++ b/flocks/session/session_loop.py @@ -662,13 +662,23 @@ async def _run_loop( log.error("loop.no_user_message", {"session_id": ctx.session.id}) break + last_assistant_parts = ( + await Message.parts(last_assistant.id, ctx.session.id) + if last_assistant + else [] + ) + # Check exit conditions (matching TUI lines 295-302) - if cls._should_exit(last_user, last_assistant): + if cls._should_exit(last_user, last_assistant, last_assistant_parts): log.info("loop.exit_condition", { "session_id": ctx.session.id, "last_user_id": last_user.id, "last_assistant_id": last_assistant.id if last_assistant else None, "finish": last_assistant.finish if last_assistant else None, + "has_tool_parts": any( + getattr(part, "type", None) == "tool" + for part in last_assistant_parts + ), }) last_message = last_assistant break @@ -1346,6 +1356,7 @@ def _should_exit( cls, last_user: MessageInfo, last_assistant: Optional[MessageInfo], + last_assistant_parts: Optional[List[Any]] = None, ) -> bool: """ Check if loop should exit @@ -1356,6 +1367,12 @@ def _should_exit( """ if not last_assistant: return False + + if any( + getattr(part, "type", None) == "tool" + for part in (last_assistant_parts or []) + ): + return False # Check finish reason if last_assistant.finish: diff --git a/tests/session/test_session_abort_inject.py b/tests/session/test_session_abort_inject.py index 90d403573..cdd9f5b7c 100644 --- a/tests/session/test_session_abort_inject.py +++ b/tests/session/test_session_abort_inject.py @@ -14,6 +14,7 @@ import pytest +from flocks.session.message import ToolPart, ToolStateCompleted from flocks.session.session_loop import SessionLoop, LoopCallbacks, LoopContext, LoopResult from flocks.session.runner import SessionRunner, StepResult from flocks.session.session import SessionInfo @@ -31,6 +32,23 @@ def _make_session_info(session_id: str = "test_session") -> SessionInfo: ) +def _make_completed_tool_part(message_id: str) -> ToolPart: + """Create a completed tool part that should force one more loop iteration.""" + return ToolPart( + sessionID="test_session", + messageID=message_id, + callID="call_001", + tool="bash", + state=ToolStateCompleted( + input={"command": "echo hi"}, + output="hi", + title="bash", + metadata={}, + time={"start": 1, "end": 2}, + ), + ) + + # --------------------------------------------------------------------------- # Abort propagation tests # --------------------------------------------------------------------------- @@ -234,6 +252,18 @@ def test_no_exit_when_assistant_not_finished(self): assert SessionLoop._should_exit(last_user, last_assistant) is False + def test_no_exit_when_assistant_has_completed_tool_parts(self): + """Should continue so completed tool results can be fed back to the model.""" + last_user = self._make_msg("msg_001", "user") + last_assistant = self._make_msg("msg_002", "assistant", finish="stop") + last_assistant_parts = [_make_completed_tool_part(last_assistant.id)] + + assert SessionLoop._should_exit( + last_user, + last_assistant, + last_assistant_parts, + ) is False + class TestQueuedUserDetection: @staticmethod @@ -348,6 +378,105 @@ async def test_pre_compact_cleanup_emits_turn_continued_before_next_iteration(se assert cleanup_turn["continue_reason"] == "pre_compact_cleanup" assert cleanup_turn["status"] == "continued" + @pytest.mark.asyncio + async def test_run_loop_skips_exit_condition_when_assistant_has_tool_parts(self): + session = SimpleNamespace( + id="loop_tool_part_session", + agent="rex", + directory="/tmp", + memory_enabled=False, + ) + ctx = LoopContext( + session=session, + provider_id="test-provider", + model_id="test-model", + agent_name="rex", + ) + messages = [ + self._make_msg("msg_001", "user"), + self._make_msg("msg_002", "assistant", finish="stop"), + ] + ctx.session_ctx = SimpleNamespace( + get_messages=AsyncMock(side_effect=[messages, messages]) + ) + event_callback = AsyncMock() + callbacks = LoopCallbacks(event_publish_callback=event_callback) + process_step = AsyncMock(return_value=StepResult(action="stop")) + log_info = MagicMock() + + with patch( + "flocks.session.session_loop.Message.parts", + AsyncMock(return_value=[_make_completed_tool_part("msg_002")]), + ), patch( + "flocks.session.session_loop.Provider.resolve_model_info", + return_value=(0, 0, None), + ), patch( + "flocks.session.lifecycle.title.SessionTitle.ensure_title", + MagicMock(return_value=None), + ), patch( + "flocks.session.session_loop.fire_and_forget", + MagicMock(), + ), patch( + "flocks.session.runner.SessionRunner._process_step", + process_step, + ), patch( + "flocks.session.session_loop.log.info", + log_info, + ): + result = await SessionLoop._run_loop(ctx, callbacks) + + assert result.action == "stop" + assert result.last_message is messages[1] + assert process_step.await_count == 1 + assert not any(call.args and call.args[0] == "loop.exit_condition" for call in log_info.call_args_list) + event_names = [call.args[0] for call in event_callback.await_args_list] + assert event_names == ["turn.started", "turn.stopped"] + + @pytest.mark.asyncio + async def test_run_loop_breaks_on_exit_condition_without_tool_parts(self): + session = SimpleNamespace( + id="loop_exit_condition_session", + agent="rex", + directory="/tmp", + memory_enabled=False, + ) + ctx = LoopContext( + session=session, + provider_id="test-provider", + model_id="test-model", + agent_name="rex", + ) + messages = [ + self._make_msg("msg_001", "user"), + self._make_msg("msg_002", "assistant", finish="stop"), + ] + ctx.session_ctx = SimpleNamespace( + get_messages=AsyncMock(return_value=messages) + ) + event_callback = AsyncMock() + callbacks = LoopCallbacks(event_publish_callback=event_callback) + process_step = AsyncMock(return_value=StepResult(action="stop")) + log_info = MagicMock() + + with patch( + "flocks.session.session_loop.Message.parts", + AsyncMock(return_value=[]), + ), patch( + "flocks.session.session_loop.log.info", + log_info, + ), patch( + "flocks.session.runner.SessionRunner._process_step", + process_step, + ): + result = await SessionLoop._run_loop(ctx, callbacks) + + assert result.action == "stop" + assert result.last_message is messages[1] + assert process_step.await_count == 0 + assert any(call.args and call.args[0] == "loop.exit_condition" for call in log_info.call_args_list) + event_names = [call.args[0] for call in event_callback.await_args_list] + assert event_names == ["turn.started"] + class TestExecuteSubtask: @pytest.mark.asyncio diff --git a/tests/tool/test_web2cli_generate_cli.py b/tests/tool/test_web2cli_generate_cli.py index 294dc7ed5..27984a95f 100644 --- a/tests/tool/test_web2cli_generate_cli.py +++ b/tests/tool/test_web2cli_generate_cli.py @@ -383,6 +383,47 @@ def _cookie_source_header_auth_spec(): return spec +def _manual_header_auth_spec(): + spec = _sample_spec() + spec["strategy"] = "HEADER" + spec["auth"] = { + "stateFile": "auth-state.json", + "requiredCookies": [], + "requiredHeaders": [ + {"name": "Authorization", "source": "manual", "key": "authorization"}, + {"name": "X-CSRF-Token", "source": "manual", "key": "x-csrf-token"}, + ], + } + return spec + + +def _multipart_spec(): + spec = _sample_spec() + spec["strategy"] = "PUBLIC" + spec["auth"] = {"stateFile": "auth-state.json", "requiredCookies": [], "requiredHeaders": []} + spec["command"] = "upload_items" + spec["description"] = "Upload items with multipart payload" + spec["operation"] = { + "method": "POST", + "endpoint": "/api/upload", + "queryTemplate": {"page": "${page}"}, + "bodyTemplate": {"note": "alpha", "upload": "${upload_file}"}, + "payloadMode": "multipart", + "rawBodyTemplate": "", + "multipartFileFields": ["upload"], + "headers": { + "Content-Type": "multipart/form-data; boundary=----WebKitFormBoundary123", + "X-Requested-With": "XMLHttpRequest", + }, + } + spec["args"] = [ + {"name": "page", "type": "int", "default": 1, "help": "Page number"}, + {"name": "upload_file", "type": "string", "default": "", "help": "File path for multipart field upload"}, + ] + spec["verify"]["args"] = {"page": 1, "upload_file": "sample.txt"} + return spec + + class _FakeResponse: def __init__(self, payload): self._payload = payload @@ -400,10 +441,29 @@ def __init__(self, payload) -> None: self._payload = payload self.request_calls = [] - def request(self, method, url, json=None, params=None, data=None, headers=None): - self.request_calls.append( - {"method": method, "url": url, "json": json, "params": params, "data": data, "headers": headers} - ) + @staticmethod + def _snapshot_files(files): + if files is None: + return None + snapshot = [] + for key, value in files: + if ( + isinstance(value, tuple) + and len(value) >= 2 + and hasattr(value[1], "read") + ): + content = value[1].read() + value[1].seek(0) + snapshot.append((key, (value[0], content))) + else: + snapshot.append((key, value)) + return snapshot + + def request(self, method, url, json=None, params=None, data=None, headers=None, files=None): + record = {"method": method, "url": url, "json": json, "params": params, "data": data, "headers": headers} + if files is not None: + record["files"] = self._snapshot_files(files) + self.request_calls.append(record) return _FakeResponse(self._payload) @@ -697,6 +757,93 @@ def test_generated_header_strategy_accepts_empty_cookie_values(tmp_path, monkeyp assert fake_session.request_calls[0]["headers"]["Cookie"] == "flag=; sid=cookie-123" +def test_generated_manual_header_strategy_requires_values(monkeypatch): + module = _load_module() + fake_session = _FakeRequestSession({"data": {"items": [{"id": "1", "title": "Alpha"}]}}) + fake_requests = types.SimpleNamespace(Session=lambda: fake_session) + monkeypatch.setitem(sys.modules, "requests", fake_requests) + + namespace = {"__name__": "generated_manual_header_cli"} + exec(module.generate_python_cli_from_spec(_manual_header_auth_spec()), namespace) + + try: + namespace["APIClient"](auth_state="auth-state.json") + except SystemExit as error: + assert str(error) == "missing required auth headers: Authorization, X-CSRF-Token" + else: + raise AssertionError("expected missing manual auth headers to exit") + + +def test_generated_manual_header_strategy_accepts_cli_values(monkeypatch): + module = _load_module() + fake_session = _FakeRequestSession({"data": {"items": [{"id": "1", "title": "Alpha"}]}}) + fake_requests = types.SimpleNamespace(Session=lambda: fake_session) + monkeypatch.setitem(sys.modules, "requests", fake_requests) + + namespace = {"__name__": "generated_manual_header_cli"} + exec(module.generate_python_cli_from_spec(_manual_header_auth_spec()), namespace) + + parser = namespace["build_parser"]() + parsed = parser.parse_args( + [ + "--auth-header-authorization", + "Bearer token-123", + "--auth-header-x-csrf-token", + "csrf-abc", + "--page", + "2", + "--limit", + "5", + ] + ) + runtime_args = { + item["name"]: getattr(parsed, item["name"]) + for item in _manual_header_auth_spec()["args"] + } + manual_headers = { + "Authorization": parsed.auth_header_authorization, + "X-CSRF-Token": parsed.auth_header_x_csrf_token, + } + + client = namespace["APIClient"](auth_state="auth-state.json", manual_headers=manual_headers) + rows = client.run(runtime_args) + + assert rows == [{"id": "1", "title": "Alpha"}] + assert fake_session.headers["Authorization"] == "Bearer token-123" + assert fake_session.headers["X-CSRF-Token"] == "csrf-abc" + + +def test_generated_multipart_spec_cli_sends_files_and_strips_content_type(tmp_path, monkeypatch): + module = _load_module() + upload_path = tmp_path / "sample.txt" + upload_path.write_text("payload-data", encoding="utf-8") + fake_session = _FakeRequestSession({"data": {"items": [{"id": "1", "title": "Alpha"}]}}) + fake_requests = types.SimpleNamespace(Session=lambda: fake_session) + monkeypatch.setitem(sys.modules, "requests", fake_requests) + + namespace = {"__name__": "generated_multipart_cli"} + exec(module.generate_python_cli_from_spec(_multipart_spec()), namespace) + + client = namespace["APIClient"]() + rows = client.run({"page": 4, "upload_file": str(upload_path)}) + + assert rows == [{"id": "1", "title": "Alpha"}] + assert fake_session.request_calls == [ + { + "method": "POST", + "url": "https://example.com/api/upload", + "json": None, + "params": {"page": 4}, + "data": None, + "headers": {"X-Requested-With": "XMLHttpRequest"}, + "files": [ + ("note", (None, "alpha")), + ("upload", ("sample.txt", b"payload-data")), + ], + } + ] + + def test_generated_cli_normalizes_non_dict_auth_state_for_headers(tmp_path, monkeypatch): module = _load_module() auth_state = tmp_path / "auth-state.json" diff --git a/tests/tool/test_web2cli_generate_spec.py b/tests/tool/test_web2cli_generate_spec.py index bbe9329ad..cddf0ee51 100644 --- a/tests/tool/test_web2cli_generate_spec.py +++ b/tests/tool/test_web2cli_generate_spec.py @@ -131,6 +131,79 @@ def _raw_request(): ] +def _multipart_request(): + return [ + { + "type": "Fetch", + "method": "POST", + "url": "https://example.com/api/upload?page=1", + "origin": "https://example.com", + "pathname": "/api/upload", + "query": {"page": "1"}, + "status": 200, + "captureReason": "nonGet", + "requestContentType": "multipart/form-data; boundary=----WebKitFormBoundary123", + "requestHeaders": { + "Content-Type": "multipart/form-data; boundary=----WebKitFormBoundary123", + "X-Requested-With": "XMLHttpRequest", + }, + "requestBodyKind": "formData", + "requestBody": '{"page":"1","note":"alpha","upload":"[file]"}', + "response": '{"data":{"items":[{"id":"1","title":"Alpha"}]}}', + } + ] + + +def _header_auth_request(): + return [ + { + "type": "Fetch", + "method": "GET", + "url": "https://example.com/api/items/list?page=1", + "origin": "https://example.com", + "pathname": "/api/items/list", + "query": {"page": "1"}, + "status": 200, + "captureReason": "includePattern", + "requestHeaders": { + "Authorization": "Bearer secret", + "X-CSRF-Token": "csrf-abc", + "Accept": "application/json", + }, + "response": '{"data":{"items":[{"id":"1","title":"Alpha"}]}}', + } + ] + + +def _mixed_score_requests(): + return [ + { + "type": "Fetch", + "method": "POST", + "url": "https://example.com/api/items/list?page=1", + "origin": "https://example.com", + "pathname": "/api/items/list", + "query": {"page": "1"}, + "status": 200, + "captureReason": "nonGet", + "requestHeaders": {"Content-Type": "application/json"}, + "requestBody": '{"page": 1, "size": 20}', + "response": '{"data":{"items":[{"id":"1","title":"Alpha"}]}}', + }, + { + "type": "Fetch", + "method": "GET", + "url": "https://example.com/api/items/detail?id=1", + "origin": "https://example.com", + "pathname": "/api/items/detail", + "query": {"id": "1"}, + "status": 200, + "requestHeaders": {"Accept": "application/json"}, + "response": '{"id":"1","title":"Alpha"}', + }, + ] + + def test_generate_spec_from_requests_picks_primary_collection_endpoint(): module = _load_module() @@ -168,6 +241,17 @@ def test_generate_spec_from_requests_includes_multi_operation_entries(): ] +def test_generate_spec_from_requests_keeps_mid_score_object_endpoints(): + module = _load_module() + + spec = module.generate_spec_from_requests(_mixed_score_requests()) + + assert [entry["operation"]["endpoint"] for entry in spec["operations"]] == [ + "/api/items/list", + "/api/items/detail", + ] + + def test_generate_spec_from_requests_preserves_form_payload_mode(): module = _load_module() @@ -189,6 +273,37 @@ def test_generate_spec_from_requests_preserves_raw_payload_mode(): assert spec["args"] == [{"name": "page", "type": "int", "default": 1, "help": "Page number"}] +def test_generate_spec_from_requests_preserves_multipart_payload_mode(): + module = _load_module() + + spec = module.generate_spec_from_requests(_multipart_request()) + + assert spec["operation"]["payloadMode"] == "multipart" + assert spec["operation"]["bodyTemplate"] == { + "page": "${page}", + "note": "alpha", + "upload": "${upload_file}", + } + assert spec["operation"]["multipartFileFields"] == ["upload"] + assert spec["args"] == [ + {"name": "page", "type": "int", "default": 1, "help": "Page number"}, + {"name": "upload_file", "type": "string", "default": "", "help": "File path for multipart field upload"}, + ] + assert spec["operation"]["headers"] == {"X-Requested-With": "XMLHttpRequest"} + + +def test_generate_spec_from_requests_marks_manual_header_auth(): + module = _load_module() + + spec = module.generate_spec_from_requests(_header_auth_request()) + + assert spec["strategy"] == "HEADER" + assert spec["auth"]["requiredHeaders"] == [ + {"name": "Authorization", "source": "manual", "key": "authorization"}, + {"name": "x-csrf-token", "source": "manual", "key": "x-csrf-token"}, + ] + + def test_main_writes_spec_file(tmp_path, monkeypatch, capsys): module = _load_module() input_path = tmp_path / "captured.json"