From b30f2968b71efa4c3d81490fe2b9da4ab1ec2143 Mon Sep 17 00:00:00 2001 From: Oliver Le Date: Mon, 13 Apr 2026 21:25:14 -0700 Subject: [PATCH] refactor(brain): simplify pass on tracked ops scripts - Extract shared Ollama /api/generate wrapper + SDK-path wiring into brain/scripts/_common.py (only consolidation with >=2 call sites). - ab_test_constitutional + mirofish_sim: delegate to ollama_generate, drop duplicated urllib plumbing. - brain_benchmark: use ensure_sdk_on_path(); delete dead first-pass defaultdict loop in _simulate_graduation (was a no-op overwritten by the correct second loop). - mirofish_sim: remove dead Forum.add_disagree() no-op and the unused ``field`` import. CLI signatures preserved. All 2060 tests still pass. --- brain/scripts/_common.py | 88 +++++++++++++++++++++++++ brain/scripts/ab_test_constitutional.py | 44 ++++--------- brain/scripts/brain_benchmark.py | 20 ++---- brain/scripts/mirofish_sim.py | 46 ++++--------- 4 files changed, 117 insertions(+), 81 deletions(-) create mode 100644 brain/scripts/_common.py diff --git a/brain/scripts/_common.py b/brain/scripts/_common.py new file mode 100644 index 00000000..1afba890 --- /dev/null +++ b/brain/scripts/_common.py @@ -0,0 +1,88 @@ +"""Shared helpers for tracked brain/scripts/* operational scripts. + +Consolidates two duplicated bits of plumbing: + +1. ``ensure_sdk_on_path()`` — wires ``/sdk/src`` into ``sys.path`` so + ``from gradata...`` imports resolve when a script is invoked directly. +2. ``ollama_generate()`` — a thin POST wrapper around Ollama's + ``/api/generate`` endpoint with sensible timeouts and a uniform + error-marker return value. + +Keep this module dependency-free (stdlib only) — these scripts run in +minimal environments (cron, worktrees, one-off invocations). +""" + +from __future__ import annotations + +import json +import logging +import sys +import urllib.error +import urllib.request +from pathlib import Path + +log = logging.getLogger(__name__) + +DEFAULT_OLLAMA_URL = "http://localhost:11434/api/generate" +DEFAULT_OLLAMA_MODEL = "gemma4:e4b" + + +def ensure_sdk_on_path() -> Path: + """Insert the repo's SDK source root into ``sys.path`` so ``from gradata...`` + resolves when a script is invoked directly. + + Prefers ``src/`` (legacy root layout) then falls back to ``sdk/src/``. + Returns the resolved SDK root so callers can log it if they want. + """ + repo_root = Path(__file__).resolve().parent.parent.parent + for candidate in (repo_root / "src", repo_root / "sdk" / "src"): + if (candidate / "gradata").is_dir(): + sdk_root = candidate + break + else: + sdk_root = repo_root / "src" + sdk_root_str = str(sdk_root) + if sdk_root_str not in sys.path: + sys.path.insert(0, sdk_root_str) + return sdk_root + + +def ollama_generate( + prompt: str, + *, + system: str = "", + model: str = DEFAULT_OLLAMA_MODEL, + url: str = DEFAULT_OLLAMA_URL, + timeout: int = 120, + num_predict: int = 500, + temperature: float = 0.7, +) -> str: + """Call Ollama ``/api/generate``. Returns response text or an error marker. + + The error-marker shape (``"[Generation failed: ...]"``) is load-bearing — + callers downstream (A/B judge parse, MiroFish post body) tolerate it as + a regular string rather than raising. + """ + payload = json.dumps( + { + "model": model, + "prompt": prompt, + "system": system, + "stream": False, + "options": { + "temperature": temperature, + "num_predict": num_predict, + }, + } + ).encode("utf-8") + + req = urllib.request.Request( + url, data=payload, headers={"Content-Type": "application/json"} + ) + try: + with urllib.request.urlopen(req, timeout=timeout) as resp: + data = json.loads(resp.read().decode("utf-8")) + return data.get("response", "").strip() + except (urllib.error.URLError, TimeoutError, json.JSONDecodeError) as e: + log.warning("Ollama call failed: %s", e) + return f"[Generation failed: {e}]" diff --git a/brain/scripts/ab_test_constitutional.py b/brain/scripts/ab_test_constitutional.py index a1a377c6..9b088e9a 100644 --- a/brain/scripts/ab_test_constitutional.py +++ b/brain/scripts/ab_test_constitutional.py @@ -43,19 +43,15 @@ import random import re import sys -import urllib.error -import urllib.request from dataclasses import asdict, dataclass, field from datetime import datetime, timezone from pathlib import Path from statistics import mean, pstdev from typing import Any -# Wire the SDK into sys.path so `from gradata...` works when invoked -# directly (same pattern as brain_benchmark.py). -_SDK_ROOT = Path(__file__).resolve().parent.parent.parent / "src" -if str(_SDK_ROOT) not in sys.path: - sys.path.insert(0, str(_SDK_ROOT)) +from _common import DEFAULT_OLLAMA_MODEL, DEFAULT_OLLAMA_URL, ensure_sdk_on_path, ollama_generate + +ensure_sdk_on_path() logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") log = logging.getLogger("ab_test_constitutional") @@ -64,8 +60,8 @@ # Constants # --------------------------------------------------------------------------- -DEFAULT_MODEL = "gemma4:e4b" -OLLAMA_URL = "http://localhost:11434/api/generate" +DEFAULT_MODEL = DEFAULT_OLLAMA_MODEL +OLLAMA_URL = DEFAULT_OLLAMA_URL DEFAULT_BRAIN_DIR = r"C:/Users/olive/SpritesWork/brain" DEFAULT_OUTPUT = ".tmp/ab_test_results" @@ -221,29 +217,15 @@ def _ollama_generate( temperature: float = 0.7, ) -> str: """Call Ollama /api/generate. Returns response text or an error marker.""" - payload = json.dumps( - { - "model": model, - "prompt": prompt, - "system": system, - "stream": False, - "options": { - "temperature": temperature, - "num_predict": num_predict, - }, - } - ).encode("utf-8") - - req = urllib.request.Request( - OLLAMA_URL, data=payload, headers={"Content-Type": "application/json"} + return ollama_generate( + prompt, + system=system, + model=model, + url=OLLAMA_URL, + timeout=timeout, + num_predict=num_predict, + temperature=temperature, ) - try: - with urllib.request.urlopen(req, timeout=timeout) as resp: - data = json.loads(resp.read().decode("utf-8")) - return data.get("response", "").strip() - except (urllib.error.URLError, TimeoutError, json.JSONDecodeError) as e: - log.warning("Ollama call failed: %s", e) - return f"[Generation failed: {e}]" # --------------------------------------------------------------------------- diff --git a/brain/scripts/brain_benchmark.py b/brain/scripts/brain_benchmark.py index 2bf93d37..a9e382c8 100644 --- a/brain/scripts/brain_benchmark.py +++ b/brain/scripts/brain_benchmark.py @@ -18,15 +18,13 @@ import argparse import json import sqlite3 -import sys from collections import defaultdict from pathlib import Path from typing import Any -# Add SDK to path for graduation simulation -_sdk_root = Path(__file__).resolve().parent.parent.parent / "src" -if str(_sdk_root) not in sys.path: - sys.path.insert(0, str(_sdk_root)) +from _common import ensure_sdk_on_path + +ensure_sdk_on_path() # --------------------------------------------------------------------------- @@ -339,18 +337,8 @@ def _simulate_graduation(events: list[dict]) -> list[dict]: except ImportError: return [] # SDK not available, skip simulation - # Group corrections by session + # Group corrections by session. corrections_by_session: dict[int, list[dict]] = defaultdict(list) - for ev in events: - if ev.get("type") == "CORRECTION": - sess = ev.get("session") - if sess is not None: - try: - corrections_by_session[int(sess)] += 1 # type: ignore[assignment] - except (ValueError, TypeError): - pass - # Fix: actually append the event dict, not increment - corrections_by_session = defaultdict(list) for ev in events: if ev.get("type") == "CORRECTION": sess = ev.get("session") diff --git a/brain/scripts/mirofish_sim.py b/brain/scripts/mirofish_sim.py index 28bda93d..1df0f370 100644 --- a/brain/scripts/mirofish_sim.py +++ b/brain/scripts/mirofish_sim.py @@ -15,13 +15,12 @@ import random import sys import uuid -from dataclasses import dataclass, field +from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path from typing import Any -import urllib.request -import urllib.error +from _common import DEFAULT_OLLAMA_MODEL, DEFAULT_OLLAMA_URL, ollama_generate logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") log = logging.getLogger("mirofish") @@ -146,11 +145,6 @@ def add_like(self, post_id: str, agent_name: str) -> None: p.likes = len(self._likes[post_id]) break - def add_disagree(self, post_id: str, agent_name: str) -> None: - """Register a disagreement (tracked but doesn't affect like count).""" - # Could be extended with a separate disagree counter - pass - def get_likes(self, post_id: str) -> int: return len(self._likes.get(post_id, set())) @@ -185,37 +179,21 @@ def summary_context(self, max_posts: int = 10) -> str: # --------------------------------------------------------------------------- # Ollama generation # --------------------------------------------------------------------------- -DEFAULT_MODEL = "gemma4:e4b" -OLLAMA_URL = "http://localhost:11434/api/generate" +DEFAULT_MODEL = DEFAULT_OLLAMA_MODEL +OLLAMA_URL = DEFAULT_OLLAMA_URL def _generate(prompt: str, system: str = "", model: str = DEFAULT_MODEL) -> str: """Call Ollama generate endpoint. Returns response text.""" - payload = json.dumps( - { - "model": model, - "prompt": prompt, - "system": system, - "stream": False, - "options": { - "temperature": 0.8, - "num_predict": 500, - }, - } - ).encode("utf-8") - - req = urllib.request.Request( - OLLAMA_URL, - data=payload, - headers={"Content-Type": "application/json"}, + return ollama_generate( + prompt, + system=system, + model=model, + url=OLLAMA_URL, + timeout=120, + num_predict=500, + temperature=0.8, ) - try: - with urllib.request.urlopen(req, timeout=120) as resp: - data = json.loads(resp.read().decode("utf-8")) - return data.get("response", "").strip() - except (urllib.error.URLError, TimeoutError, json.JSONDecodeError) as e: - log.warning("Ollama generation failed: %s", e) - return f"[Generation failed: {e}]" # ---------------------------------------------------------------------------