Skip to content

Commit 053aec5

Browse files
committed
Centralize config
1 parent 3c22dab commit 053aec5

3 files changed

Lines changed: 146 additions & 34 deletions

File tree

src/config.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
use anyhow::Result;
2+
use serde::{Deserialize, Serialize};
3+
use crate::touch::TriggerCorner;
4+
5+
#[derive(Serialize, Deserialize, Debug, Clone)]
6+
pub struct Config {
7+
// Direct mapping to CLI args - no arbitrary grouping
8+
pub engine: Option<String>,
9+
pub engine_base_url: Option<String>,
10+
pub engine_api_key: Option<String>,
11+
pub model: String,
12+
pub prompt: String,
13+
pub no_submit: bool,
14+
pub no_draw: bool,
15+
pub no_svg: bool,
16+
pub no_keyboard: bool,
17+
pub no_draw_progress: bool,
18+
pub input_png: Option<String>,
19+
pub output_file: Option<String>,
20+
pub model_output_file: Option<String>,
21+
pub save_screenshot: Option<String>,
22+
pub save_bitmap: Option<String>,
23+
pub no_loop: bool,
24+
pub no_trigger: bool,
25+
pub apply_segmentation: bool,
26+
pub web_search: bool,
27+
pub thinking: bool,
28+
pub thinking_tokens: u32,
29+
pub log_level: String,
30+
pub trigger_corner: String,
31+
}
32+
33+
impl Default for Config {
34+
fn default() -> Self {
35+
Self {
36+
engine: None,
37+
engine_base_url: None,
38+
engine_api_key: None,
39+
model: "claude-sonnet-4-0".to_string(),
40+
prompt: "general.json".to_string(),
41+
no_submit: false,
42+
no_draw: false,
43+
no_svg: false,
44+
no_keyboard: false,
45+
no_draw_progress: false,
46+
input_png: None,
47+
output_file: None,
48+
model_output_file: None,
49+
save_screenshot: None,
50+
save_bitmap: None,
51+
no_loop: false,
52+
no_trigger: false,
53+
apply_segmentation: false,
54+
web_search: false,
55+
thinking: false,
56+
thinking_tokens: 5000,
57+
log_level: "info".to_string(),
58+
trigger_corner: "UR".to_string(),
59+
}
60+
}
61+
}
62+
63+
impl Config {
64+
/// Validate the configuration and return any errors
65+
pub fn validate(&self) -> Result<()> {
66+
// Validate trigger corner
67+
TriggerCorner::from_string(&self.trigger_corner)?;
68+
69+
// Validate log level
70+
match self.log_level.as_str() {
71+
"error" | "warn" | "info" | "debug" | "trace" => {},
72+
_ => return Err(anyhow::anyhow!("Invalid log level: {}", self.log_level)),
73+
}
74+
75+
// Validate thinking tokens
76+
if self.thinking_tokens == 0 {
77+
return Err(anyhow::anyhow!("thinking_tokens must be greater than 0"));
78+
}
79+
80+
Ok(())
81+
}
82+
83+
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
pub mod config;
12
pub mod device;
23
pub mod embedded_assets;
34
pub mod keyboard;

src/main.rs

Lines changed: 62 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use std::thread::sleep;
1010
use std::time::Duration;
1111

1212
use ghostwriter::{
13+
config::Config,
1314
embedded_assets::load_config,
1415
keyboard::Keyboard,
1516
llm_engine::{anthropic::Anthropic, google::Google, openai::OpenAI, LLMEngine},
@@ -31,7 +32,7 @@ const VIRTUAL_HEIGHT: u32 = 1024;
3132
long_about = "Ghostwriter is an exploration of how to interact with vision-LLM through the handwritten medium of the reMarkable2. It is a pluggable system; you can provide a custom prompt and custom 'tools' that the agent can use."
3233
)]
3334
#[command(after_help = "See https://github.com/awwaiid/ghostwriter for updates!")]
34-
struct Args {
35+
pub struct Args {
3536
/// Sets the engine to use (openai, anthropic);
3637
/// Sometimes we can guess the engine from the model name
3738
#[arg(long)]
@@ -209,10 +210,37 @@ fn create_engine(engine_name: &str, engine_options: &OptionMap) -> Result<Box<dy
209210
}
210211

211212
fn ghostwriter(args: &Args) -> Result<()> {
212-
let trigger_corner = TriggerCorner::from_string(&args.trigger_corner)?;
213-
let keyboard = shared!(Keyboard::new(args.no_draw || args.no_keyboard, args.no_draw_progress,));
214-
let pen = shared!(Pen::new(args.no_draw));
215-
let touch = shared!(Touch::new(args.no_draw, trigger_corner));
213+
let config = Config {
214+
engine: args.engine.clone(),
215+
engine_base_url: args.engine_base_url.clone(),
216+
engine_api_key: args.engine_api_key.clone(),
217+
model: args.model.clone(),
218+
prompt: args.prompt.clone(),
219+
no_submit: args.no_submit,
220+
no_draw: args.no_draw,
221+
no_svg: args.no_svg,
222+
no_keyboard: args.no_keyboard,
223+
no_draw_progress: args.no_draw_progress,
224+
input_png: args.input_png.clone(),
225+
output_file: args.output_file.clone(),
226+
model_output_file: args.model_output_file.clone(),
227+
save_screenshot: args.save_screenshot.clone(),
228+
save_bitmap: args.save_bitmap.clone(),
229+
no_loop: args.no_loop,
230+
no_trigger: args.no_trigger,
231+
apply_segmentation: args.apply_segmentation,
232+
web_search: args.web_search,
233+
thinking: args.thinking,
234+
thinking_tokens: args.thinking_tokens,
235+
log_level: args.log_level.clone(),
236+
trigger_corner: args.trigger_corner.clone(),
237+
};
238+
config.validate()?;
239+
240+
let trigger_corner = TriggerCorner::from_string(&config.trigger_corner)?;
241+
let keyboard = shared!(Keyboard::new(config.no_draw || config.no_keyboard, config.no_draw_progress,));
242+
let pen = shared!(Pen::new(config.no_draw));
243+
let touch = shared!(Touch::new(config.no_draw, trigger_corner));
216244

217245
// Give time for the virtual keyboard to be plugged in
218246
sleep(Duration::from_millis(1000));
@@ -224,37 +252,37 @@ fn ghostwriter(args: &Args) -> Result<()> {
224252

225253
let mut engine_options = OptionMap::new();
226254

227-
let model = args.model.clone();
255+
let model = config.model.clone();
228256
engine_options.insert("model".to_string(), model.clone());
229257
debug!("Model: {}", model);
230258

231-
let engine_name = determine_engine_name(&args.engine, &model)?;
259+
let engine_name = determine_engine_name(&config.engine, &model)?;
232260
debug!("Engine: {}", engine_name);
233261

234-
if args.engine_base_url.is_some() {
235-
debug!("Engine base URL: {}", args.engine_base_url.clone().unwrap());
236-
engine_options.insert("base_url".to_string(), args.engine_base_url.clone().unwrap());
262+
if config.engine_base_url.is_some() {
263+
debug!("Engine base URL: {}", config.engine_base_url.clone().unwrap());
264+
engine_options.insert("base_url".to_string(), config.engine_base_url.clone().unwrap());
237265
}
238-
if args.engine_api_key.is_some() {
266+
if config.engine_api_key.is_some() {
239267
debug!("Using API key from CLI args");
240-
engine_options.insert("api_key".to_string(), args.engine_api_key.clone().unwrap());
268+
engine_options.insert("api_key".to_string(), config.engine_api_key.clone().unwrap());
241269
}
242270

243-
if args.web_search {
271+
if config.web_search {
244272
debug!("Web search tool enabled");
245273
engine_options.insert("web_search".to_string(), "true".to_string());
246274
}
247275

248-
if args.thinking {
249-
debug!("Thinking enabled with budget: {}", args.thinking_tokens);
276+
if config.thinking {
277+
debug!("Thinking enabled with budget: {}", config.thinking_tokens);
250278
engine_options.insert("thinking".to_string(), "true".to_string());
251-
engine_options.insert("thinking_tokens".to_string(), args.thinking_tokens.to_string());
279+
engine_options.insert("thinking_tokens".to_string(), config.thinking_tokens.to_string());
252280
}
253281

254282
let mut engine = create_engine(&engine_name, &engine_options)?;
255283

256-
let output_file = args.output_file.clone();
257-
let no_draw = args.no_draw;
284+
let output_file = config.output_file.clone();
285+
let no_draw = config.no_draw;
258286
let keyboard_clone = Arc::clone(&keyboard);
259287

260288
let tool_config_draw_text = load_config("tool_draw_text.json");
@@ -284,13 +312,13 @@ fn ghostwriter(args: &Args) -> Result<()> {
284312
}),
285313
);
286314

287-
let output_file = args.output_file.clone();
288-
let save_bitmap = args.save_bitmap.clone();
289-
let no_draw = args.no_draw;
315+
let output_file = config.output_file.clone();
316+
let save_bitmap = config.save_bitmap.clone();
317+
let no_draw = config.no_draw;
290318
let keyboard_clone = Arc::clone(&keyboard);
291319
let pen_clone = Arc::clone(&pen);
292320

293-
if !args.no_svg {
321+
if !config.no_svg {
294322
let tool_config_draw_svg = load_config("tool_draw_svg.json");
295323
engine.register_tool(
296324
"draw_svg",
@@ -323,12 +351,12 @@ fn ghostwriter(args: &Args) -> Result<()> {
323351
sleep(Duration::from_millis(1000));
324352

325353
loop {
326-
if args.no_trigger {
354+
if config.no_trigger {
327355
debug!("Skipping waiting for trigger");
328356
} else {
329357
info!(
330358
"Waiting for trigger (hand-touch in the {} corner)...",
331-
match TriggerCorner::from_string(&args.trigger_corner).unwrap() {
359+
match TriggerCorner::from_string(&config.trigger_corner).unwrap() {
332360
TriggerCorner::UpperRight => "upper-right",
333361
TriggerCorner::UpperLeft => "upper-left",
334362
TriggerCorner::LowerRight => "lower-right",
@@ -345,34 +373,34 @@ fn ghostwriter(args: &Args) -> Result<()> {
345373
// lock!(keyboard).progress("Taking screenshot...")?;
346374

347375
info!("Getting screenshot (or loading input image)");
348-
let base64_image = if let Some(input_png) = &args.input_png {
376+
let base64_image = if let Some(input_png) = &config.input_png {
349377
BASE64_STANDARD.encode(std::fs::read(input_png)?)
350378
} else {
351379
let mut screenshot = Screenshot::new()?;
352380
screenshot.take_screenshot()?;
353-
if let Some(save_screenshot) = &args.save_screenshot {
381+
if let Some(save_screenshot) = &config.save_screenshot {
354382
info!("Saving screenshot to {}", save_screenshot);
355383
screenshot.save_image(save_screenshot)?;
356384
}
357385
screenshot.base64()?
358386
};
359387

360-
if args.no_submit {
388+
if config.no_submit {
361389
info!("Image not submitted to model due to --no-submit flag");
362390
lock!(keyboard).progress_end()?;
363391
return Ok(());
364392
}
365393

366-
let prompt_general_raw = load_config(&args.prompt);
394+
let prompt_general_raw = load_config(&config.prompt);
367395
let prompt_general_json = serde_json::from_str::<serde_json::Value>(prompt_general_raw.as_str())?;
368396
let prompt = prompt_general_json["prompt"].as_str()
369-
.ok_or_else(|| anyhow::anyhow!("Prompt file '{}' missing required 'prompt' field", args.prompt))?;
397+
.ok_or_else(|| anyhow::anyhow!("Prompt file '{}' missing required 'prompt' field", config.prompt))?;
370398

371-
let segmentation_description = if args.apply_segmentation {
399+
let segmentation_description = if config.apply_segmentation {
372400
info!("Building image segmentation");
373401
lock!(keyboard).progress("segmenting...")?;
374-
let input_filename = args.input_png.clone()
375-
.or_else(|| args.save_screenshot.clone())
402+
let input_filename = config.input_png.clone()
403+
.or_else(|| config.save_screenshot.clone())
376404
.ok_or_else(|| anyhow::anyhow!("Segmentation requires either --input-png or --save-screenshot to be specified"))?;
377405
match analyze_image(input_filename.as_str()) {
378406
Ok(description) => description,
@@ -386,7 +414,7 @@ fn ghostwriter(args: &Args) -> Result<()> {
386414
engine.clear_content();
387415
engine.add_image_content(&base64_image);
388416

389-
if args.apply_segmentation {
417+
if config.apply_segmentation {
390418
engine.add_text_content(
391419
format!("Here are interesting regions based on an automatic segmentation algorithm. Use them to help identify the exact location of interesting features.\n\n{}", segmentation_description).as_str()
392420
);
@@ -400,7 +428,7 @@ fn ghostwriter(args: &Args) -> Result<()> {
400428
lock!(keyboard).progress(" model error. ")?;
401429
}
402430

403-
if args.no_loop {
431+
if config.no_loop {
404432
break Ok(());
405433
}
406434
}

0 commit comments

Comments
 (0)