Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 54 additions & 24 deletions src/balatrollm/bot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Core LLM-powered Balatro bot implementation."""

import asyncio
import base64
import json
import logging
import time
Expand All @@ -25,6 +26,23 @@
logger = logging.getLogger(__name__)


def _to_wine_path(path: Path) -> str:
    """Convert a Linux path under Wine's drive_c to a Windows-style path.

    Balatro (via BalatroBot Lua mod) runs inside Wine/Proton and can only write
    to Windows-style paths. This converts e.g.:
        /home/user/.../drive_c/users/foo/bar.png
        -> C:\\users\\foo\\bar.png

    Args:
        path: Host-side path, typically located somewhere below a ``drive_c``
            directory.

    Returns:
        ``C:\\`` followed by the components after the first ``drive_c``
        component, joined with backslashes. If no component equals
        ``drive_c`` the original path is returned unchanged as a string.
    """
    parts = path.parts
    try:
        # tuple.index returns the first match, so a nested "drive_c" directory
        # further down the path is kept as an ordinary component.
        idx = parts.index("drive_c")
    except ValueError:
        # Not under a Wine prefix: fall back to the path as given.
        return str(path)
    return "C:\\" + "\\".join(parts[idx + 1 :])


class BotError(Exception):
"""Base exception for bot errors."""

Expand Down Expand Up @@ -66,6 +84,7 @@ async def __aenter__(self) -> "Bot":
self._llm = LLMClient(
base_url=self.config.base_url,
api_key=self.config.api_key or "",
vision=self.config.vision,
)
await self._llm.__aenter__()

Expand Down Expand Up @@ -198,14 +217,11 @@ async def _run_game_loop(self, gamestate: dict[str, Any]) -> None:
await self._balatro.call("gamestate")

match current_state:
case "SELECTING_HAND" | "SHOP" | "SMODS_BOOSTER_OPENED":
case "SELECTING_HAND" | "SHOP" | "SMODS_BOOSTER_OPENED" | "BLIND_SELECT":
response = await self._get_llm_response(gamestate)
gamestate = await self._execute_tool_call(response)
case "ROUND_EVAL":
gamestate = await self._balatro.call("cash_out")
case "BLIND_SELECT":
# NOTE: This bot always selects and never skips blinds
gamestate = await self._balatro.call("select")
case "GAME_OVER":
self._finish_reason = "lost"
logger.info("Game over!")
Expand All @@ -220,6 +236,20 @@ async def _get_llm_response(self, gamestate: dict[str, Any]) -> ChatCompletion:
assert self._llm is not None
assert self._collector is not None

# Take screenshot BEFORE building the request so it can be included in the prompt
next_custom_id = self._collector.peek_next_custom_id()
screenshot_path = self._collector.screenshot_dir / f"{next_custom_id}.png"
screenshot_b64: str | None = None
try:
await self._balatro.call(
"screenshot", {"path": _to_wine_path(screenshot_path)}
)
screenshot_b64 = base64.b64encode(screenshot_path.read_bytes()).decode()
except BalatroError as e:
logger.warning(f"Screenshot failed: {e}")
except Exception as e:
logger.warning(f"Screenshot read failed: {e}")

strategy_content = self.strategy.render_strategy(gamestate)
gamestate_content = self.strategy.render_gamestate(gamestate)
memory_content = self.strategy.render_memory(
Expand All @@ -228,20 +258,28 @@ async def _get_llm_response(self, gamestate: dict[str, Any]) -> ChatCompletion:
last_failure=self._last_failed_msg,
)

messages = [
content: list[dict[str, Any]] = [
{
"role": "user",
"content": [
{
"type": "text",
"text": strategy_content,
"cache_control": {"type": "ephemeral"},
},
{"type": "text", "text": gamestate_content},
{"type": "text", "text": memory_content},
],
}
"type": "text",
"text": strategy_content,
"cache_control": {"type": "ephemeral"},
},
]
if screenshot_b64:
content.append(
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{screenshot_b64}"},
}
)
content.extend(
[
{"type": "text", "text": gamestate_content},
{"type": "text", "text": memory_content},
]
)

messages = [{"role": "user", "content": content}]

tools = self.strategy.get_tools(gamestate["state"])

Expand All @@ -263,14 +301,6 @@ async def _get_llm_response(self, gamestate: dict[str, Any]) -> ChatCompletion:
model_config=self.model_config,
)

try:
await self._balatro.call(
"screenshot",
{"path": str(self._collector.screenshot_dir / f"{custom_id}.png")},
)
except BalatroError as e:
logger.warning(f"Screenshot failed: {e}")

self._collector.write_response(
id=str(time.time_ns() // 1_000_000),
custom_id=custom_id,
Expand Down
4 changes: 4 additions & 0 deletions src/balatrollm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def create_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--views", action="store_true", help="Start HTTP server on port 12345 for views"
)
parser.add_argument(
"--no-vision", dest="vision", action="store_false", default=None,
help="Disable screenshots (required for non-vision models via Ollama)",
)

return parser

Expand Down
4 changes: 4 additions & 0 deletions src/balatrollm/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,10 @@ def _write_latest_json(self) -> None:
f,
)

def peek_next_custom_id(self) -> str:
    """Preview the custom_id for the upcoming request without consuming it.

    Only reads ``self._request_count``; the counter itself is left untouched.
    The result is the counter plus one, zero-padded to five digits and
    prefixed with ``request-``.
    NOTE(review): assumes write_request formats its custom_id the same way
    after incrementing the counter; confirm against write_request.
    """
    upcoming = self._request_count + 1
    return f"request-{upcoming:05}"

def write_request(self, body: dict[str, Any]) -> str:
"""Write request to requests.jsonl. Returns custom_id."""
self._request_count += 1
Expand Down
6 changes: 5 additions & 1 deletion src/balatrollm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,14 @@
"base_url": "BALATROLLM_BASE_URL",
"api_key": "BALATROLLM_API_KEY",
"views": "BALATROLLM_VIEWS",
"vision": "BALATROLLM_VISION",
}

################################################################################
# Types for config conversion
################################################################################

BOOL_FIELDS: frozenset[str] = frozenset({"views"})
BOOL_FIELDS: frozenset[str] = frozenset({"views", "vision"})
LIST_FIELDS: frozenset[str] = frozenset({"model", "seed", "deck", "stake", "strategy"})
STRING_FIELDS: frozenset[str] = frozenset({"host", "base_url", "api_key"})
INT_FIELDS: frozenset[str] = frozenset({"parallel", "port"})
Expand Down Expand Up @@ -149,6 +150,8 @@ def _load_from_args(args: Namespace) -> dict[str, Any]:
for field_name in BOOL_FIELDS:
if getattr(args, field_name, False):
result[field_name] = True
if getattr(args, "vision", None) is False:
result["vision"] = False
return result


Expand Down Expand Up @@ -183,6 +186,7 @@ class Config:
# Execution
parallel: int = 1
views: bool = False
vision: bool = True

# Connection
host: str = "127.0.0.1"
Expand Down
46 changes: 45 additions & 1 deletion src/balatrollm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@ class LLMRetryExhaustedError(LLMClientError):
pass


def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Remove image_url blocks from message content lists.

    Args:
        messages: Chat messages. Each message's ``content`` may be a list of
            content parts or any other value (e.g. a plain string).

    Returns:
        A new list of messages. Messages whose ``content`` is a list are
        shallow-copied with every part of type ``image_url`` dropped; all
        other messages are passed through as the same objects. The input
        list and its messages are not mutated.
    """

    def _is_image(part: Any) -> bool:
        # Only dict parts can carry a "type"; anything else (e.g. a bare
        # string) is kept instead of raising AttributeError on .get().
        return isinstance(part, dict) and part.get("type") == "image_url"

    return [
        {**msg, "content": [part for part in msg["content"] if not _is_image(part)]}
        if isinstance(msg.get("content"), list)
        else msg
        for msg in messages
    ]


@dataclass
class LLMClient:
"""Async OpenAI client wrapper with retry logic."""
Expand All @@ -38,9 +51,11 @@ class LLMClient:
api_key: str
timeout: float = 240.0 # We assume that LLMs respond in 240s
max_retries: int = 3
vision: bool = True

_client: openai.AsyncOpenAI | None = field(default=None, init=False, repr=False)
_consecutive_timeouts: int = field(default=0, init=False, repr=False)
_vision_supported: bool = field(default=True, init=False, repr=False)

async def __aenter__(self) -> "LLMClient":
"""Create the async OpenAI client."""
Expand All @@ -50,6 +65,7 @@ async def __aenter__(self) -> "LLMClient":
timeout=self.timeout,
)
self._consecutive_timeouts = 0
self._vision_supported = self.vision
return self

async def __aexit__(self, *_: Any) -> None:
Expand All @@ -71,10 +87,15 @@ async def call(
"Client not connected. Use 'async with LLMClient() as client:'"
)

effective_messages = (
_strip_image_content(messages) if not self._vision_supported else messages
)

request_data: dict[str, Any] = {
"model": model,
"messages": messages,
"messages": effective_messages,
"tools": tools,
"tool_choice": "required",
}

if model_config:
Expand All @@ -83,6 +104,7 @@ async def call(

retry_delay = 1.0
last_exception: Exception | None = None
vision_stripped = False

for attempt in range(self.max_retries):
try:
Expand All @@ -108,6 +130,23 @@ async def call(
last_exception = e

except openai.APIStatusError as e:
if e.status_code == 404 and "image input" in str(e).lower() and not vision_stripped:
logger.warning("Model does not support vision — disabling screenshots for this session")
self._vision_supported = False
vision_stripped = True
request_data["messages"] = _strip_image_content(
request_data["messages"]
)
# Retry immediately without consuming a retry slot
try:
response = await self._client.chat.completions.create(**request_data)
self._consecutive_timeouts = 0
if not response.choices:
raise LLMClientError("API returned empty response (no choices)")
return response
except Exception as inner_e:
last_exception = inner_e
continue
logger.error(f"LLM status error ({e.status_code}): {e}")
last_exception = e

Expand All @@ -134,6 +173,11 @@ async def call(
f"All {self.max_retries} retry attempts exhausted"
) from last_exception

@property
def vision_supported(self) -> bool:
    """Whether image content is still sent to the model.

    Starts as the configured ``vision`` value when the client is entered and
    is switched to False after the first 404 "image input" status error.
    While False, image_url blocks are stripped from outgoing messages.
    """
    return self._vision_supported

@property
def consecutive_timeouts(self) -> int:
"""Get current consecutive timeout count."""
Expand Down
2 changes: 1 addition & 1 deletion src/balatrollm/strategies/aggressive/TOOLS.json
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@
"function": {
"name": "next_round",
"strict": false,
"description": "Leave the shop and advance to blind selection.",
"description": "Exit the shop and advance to blind selection. You MUST invoke this as a function tool call — never output raw JSON. This is the only valid way to leave the shop; omitting the call or describing it in text will be treated as an error.",
"parameters": {
"type": "object",
"properties": {
Expand Down
2 changes: 1 addition & 1 deletion src/balatrollm/strategies/conservative/TOOLS.json
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@
"function": {
"name": "next_round",
"strict": false,
"description": "Leave the shop and advance to blind selection.",
"description": "Exit the shop and advance to blind selection. You MUST invoke this as a function tool call — never output raw JSON. This is the only valid way to leave the shop; omitting the call or describing it in text will be treated as an error.",
"parameters": {
"type": "object",
"properties": {
Expand Down
2 changes: 1 addition & 1 deletion src/balatrollm/strategies/default/TOOLS.json
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@
"function": {
"name": "next_round",
"strict": false,
"description": "Leave the shop and advance to blind selection.",
"description": "Exit the shop and advance to blind selection. You MUST invoke this as a function tool call — never output raw JSON. This is the only valid way to leave the shop; omitting the call or describing it in text will be treated as an error.",
"parameters": {
"type": "object",
"properties": {
Expand Down
Loading