Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions brain/scripts/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Shared helpers for tracked brain/scripts/* operational scripts.

Consolidates two duplicated bits of plumbing:

1. ``ensure_sdk_on_path()`` — wires ``<repo>/sdk/src`` into ``sys.path`` so
``from gradata...`` imports resolve when a script is invoked directly.
2. ``ollama_generate()`` — a thin POST wrapper around Ollama's
``/api/generate`` endpoint with sensible timeouts and a uniform
error-marker return value.

Keep this module dependency-free (stdlib only) — these scripts run in
minimal environments (cron, worktrees, one-off invocations).
"""

from __future__ import annotations

import json
import logging
import sys
import urllib.error
import urllib.request
from pathlib import Path

log = logging.getLogger(__name__)

DEFAULT_OLLAMA_URL = "http://localhost:11434/api/generate"
DEFAULT_OLLAMA_MODEL = "gemma4:e4b"


def ensure_sdk_on_path() -> Path:
"""Insert the repo's SDK source root into ``sys.path`` so ``from gradata...``
resolves when a script is invoked directly.

Prefers ``src/`` (legacy root layout) then falls back to ``sdk/src/``.
Returns the resolved SDK root so callers can log it if they want.
"""
repo_root = Path(__file__).resolve().parent.parent.parent
for candidate in (repo_root / "src", repo_root / "sdk" / "src"):
if (candidate / "gradata").is_dir():
sdk_root = candidate
break
else:
sdk_root = repo_root / "src"
sdk_root_str = str(sdk_root)
if sdk_root_str not in sys.path:
sys.path.insert(0, sdk_root_str)
return sdk_root


def ollama_generate(
prompt: str,
*,
system: str = "",
model: str = DEFAULT_OLLAMA_MODEL,
url: str = DEFAULT_OLLAMA_URL,
timeout: int = 120,
num_predict: int = 500,
temperature: float = 0.7,
) -> str:
"""Call Ollama ``/api/generate``. Returns response text or an error marker.

The error-marker shape (``"[Generation failed: ...]"``) is load-bearing —
callers downstream (A/B judge parse, MiroFish post body) tolerate it as
a regular string rather than raising.
"""
payload = json.dumps(
{
"model": model,
"prompt": prompt,
"system": system,
"stream": False,
"options": {
"temperature": temperature,
"num_predict": num_predict,
},
}
).encode("utf-8")

req = urllib.request.Request(
url, data=payload, headers={"Content-Type": "application/json"}
)
try:
with urllib.request.urlopen(req, timeout=timeout) as resp:
data = json.loads(resp.read().decode("utf-8"))
return data.get("response", "").strip()
except (urllib.error.URLError, TimeoutError, json.JSONDecodeError) as e:
log.warning("Ollama call failed: %s", e)
return f"[Generation failed: {e}]"
44 changes: 13 additions & 31 deletions brain/scripts/ab_test_constitutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,15 @@
import random
import re
import sys
import urllib.error
import urllib.request
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from statistics import mean, pstdev
from typing import Any

# Wire the SDK into sys.path so `from gradata...` works when invoked
# directly (same pattern as brain_benchmark.py).
_SDK_ROOT = Path(__file__).resolve().parent.parent.parent / "src"
if str(_SDK_ROOT) not in sys.path:
sys.path.insert(0, str(_SDK_ROOT))
from _common import DEFAULT_OLLAMA_MODEL, DEFAULT_OLLAMA_URL, ensure_sdk_on_path, ollama_generate

ensure_sdk_on_path()

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger("ab_test_constitutional")
Expand All @@ -64,8 +60,8 @@
# Constants
# ---------------------------------------------------------------------------

DEFAULT_MODEL = "gemma4:e4b"
OLLAMA_URL = "http://localhost:11434/api/generate"
DEFAULT_MODEL = DEFAULT_OLLAMA_MODEL
OLLAMA_URL = DEFAULT_OLLAMA_URL
DEFAULT_BRAIN_DIR = r"C:/Users/olive/SpritesWork/brain"
DEFAULT_OUTPUT = ".tmp/ab_test_results"

Expand Down Expand Up @@ -221,29 +217,15 @@ def _ollama_generate(
temperature: float = 0.7,
) -> str:
"""Call Ollama /api/generate. Returns response text or an error marker."""
payload = json.dumps(
{
"model": model,
"prompt": prompt,
"system": system,
"stream": False,
"options": {
"temperature": temperature,
"num_predict": num_predict,
},
}
).encode("utf-8")

req = urllib.request.Request(
OLLAMA_URL, data=payload, headers={"Content-Type": "application/json"}
return ollama_generate(
prompt,
system=system,
model=model,
url=OLLAMA_URL,
timeout=timeout,
num_predict=num_predict,
temperature=temperature,
)
try:
with urllib.request.urlopen(req, timeout=timeout) as resp:
data = json.loads(resp.read().decode("utf-8"))
return data.get("response", "").strip()
except (urllib.error.URLError, TimeoutError, json.JSONDecodeError) as e:
log.warning("Ollama call failed: %s", e)
return f"[Generation failed: {e}]"


# ---------------------------------------------------------------------------
Expand Down
20 changes: 4 additions & 16 deletions brain/scripts/brain_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@
import argparse
import json
import sqlite3
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any

# Add SDK to path for graduation simulation
_sdk_root = Path(__file__).resolve().parent.parent.parent / "src"
if str(_sdk_root) not in sys.path:
sys.path.insert(0, str(_sdk_root))
from _common import ensure_sdk_on_path

ensure_sdk_on_path()


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -339,18 +337,8 @@ def _simulate_graduation(events: list[dict]) -> list[dict]:
except ImportError:
return [] # SDK not available, skip simulation

# Group corrections by session
# Group corrections by session.
corrections_by_session: dict[int, list[dict]] = defaultdict(list)
for ev in events:
if ev.get("type") == "CORRECTION":
sess = ev.get("session")
if sess is not None:
try:
corrections_by_session[int(sess)] += 1 # type: ignore[assignment]
except (ValueError, TypeError):
pass
# Fix: actually append the event dict, not increment
corrections_by_session = defaultdict(list)
for ev in events:
if ev.get("type") == "CORRECTION":
sess = ev.get("session")
Expand Down
46 changes: 12 additions & 34 deletions brain/scripts/mirofish_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
import random
import sys
import uuid
from dataclasses import dataclass, field
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import urllib.request
import urllib.error
from _common import DEFAULT_OLLAMA_MODEL, DEFAULT_OLLAMA_URL, ollama_generate

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger("mirofish")
Expand Down Expand Up @@ -146,11 +145,6 @@ def add_like(self, post_id: str, agent_name: str) -> None:
p.likes = len(self._likes[post_id])
break

def add_disagree(self, post_id: str, agent_name: str) -> None:
"""Register a disagreement (tracked but doesn't affect like count)."""
# Could be extended with a separate disagree counter
pass

def get_likes(self, post_id: str) -> int:
return len(self._likes.get(post_id, set()))

Expand Down Expand Up @@ -185,37 +179,21 @@ def summary_context(self, max_posts: int = 10) -> str:
# ---------------------------------------------------------------------------
# Ollama generation
# ---------------------------------------------------------------------------
DEFAULT_MODEL = "gemma4:e4b"
OLLAMA_URL = "http://localhost:11434/api/generate"
DEFAULT_MODEL = DEFAULT_OLLAMA_MODEL
OLLAMA_URL = DEFAULT_OLLAMA_URL


def _generate(prompt: str, system: str = "", model: str = DEFAULT_MODEL) -> str:
"""Call Ollama generate endpoint. Returns response text."""
payload = json.dumps(
{
"model": model,
"prompt": prompt,
"system": system,
"stream": False,
"options": {
"temperature": 0.8,
"num_predict": 500,
},
}
).encode("utf-8")

req = urllib.request.Request(
OLLAMA_URL,
data=payload,
headers={"Content-Type": "application/json"},
return ollama_generate(
prompt,
system=system,
model=model,
url=OLLAMA_URL,
timeout=120,
num_predict=500,
temperature=0.8,
)
try:
with urllib.request.urlopen(req, timeout=120) as resp:
data = json.loads(resp.read().decode("utf-8"))
return data.get("response", "").strip()
except (urllib.error.URLError, TimeoutError, json.JSONDecodeError) as e:
log.warning("Ollama generation failed: %s", e)
return f"[Generation failed: {e}]"


# ---------------------------------------------------------------------------
Expand Down
Loading