Skip to content

Commit 48ee91b

Browse files
committed
fix: isolate custom and openai provider key/config handling
- Decouple custom and openai API key cache entries (no more custom->openai remap)
- Store custom endpoint config under dedicated keys (ai_custom_base_url, ai_custom_no_auth)
- Frontend uses provider 'custom' instead of aliasing to 'openai' in UI state
- Key removal only clears active selection when the removed provider matches active
- Backend readiness fallback for custom/openai no-auth configs without keyring keys
- Legacy config migration support for existing openai-compatible setups
- Add regression tests for cache isolation, no-auth toggle, and cross-provider removal
1 parent e5fa5d9 commit 48ee91b

File tree

5 files changed

+474
-70
lines changed

5 files changed

+474
-70
lines changed

src-tauri/src/commands/ai.rs

Lines changed: 165 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ use tauri_plugin_store::StoreExt;
1212
static API_KEY_CACHE: Lazy<Mutex<HashMap<String, String>>> =
1313
Lazy::new(|| Mutex::new(HashMap::new()));
1414

15+
const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
16+
const CUSTOM_BASE_URL_KEY: &str = "ai_custom_base_url";
17+
const CUSTOM_NO_AUTH_KEY: &str = "ai_custom_no_auth";
18+
const LEGACY_OPENAI_BASE_URL_KEY: &str = "ai_openai_base_url";
19+
const LEGACY_OPENAI_NO_AUTH_KEY: &str = "ai_openai_no_auth";
20+
1521
// Helper: determine if we should consider that the app "has an API key" for a provider
1622
// For OpenAI-compatible providers, a configured no_auth=true also counts as "has key"
1723
fn check_has_api_key<R: tauri::Runtime>(
@@ -20,7 +26,10 @@ fn check_has_api_key<R: tauri::Runtime>(
2026
cache: &HashMap<String, String>,
2127
) -> bool {
2228
if provider == "openai" {
23-
let configured_base = store.get("ai_openai_base_url").is_some();
29+
cache.contains_key("ai_api_key_openai") || store.get(LEGACY_OPENAI_BASE_URL_KEY).is_some()
30+
} else if provider == "custom" {
31+
let configured_base = store.get(CUSTOM_BASE_URL_KEY).is_some()
32+
|| store.get(LEGACY_OPENAI_BASE_URL_KEY).is_some();
2433
configured_base || cache.contains_key(&format!("ai_api_key_{}", provider))
2534
} else {
2635
cache.contains_key(&format!("ai_api_key_{}", provider))
@@ -50,7 +59,7 @@ lazy_static::lazy_static! {
5059
}
5160

5261
// Supported AI providers
53-
const ALLOWED_PROVIDERS: &[&str] = &["gemini", "openai", "anthropic"];
62+
const ALLOWED_PROVIDERS: &[&str] = &["gemini", "openai", "anthropic", "custom"];
5463

5564
fn validate_provider_name(provider: &str) -> Result<(), String> {
5665
// First check format
@@ -219,13 +228,21 @@ pub async fn validate_and_cache_api_key(
219228
let provided_key = api_key.clone().unwrap_or_default();
220229
let inferred_no_auth = no_auth.unwrap_or(false) || provided_key.trim().is_empty();
221230

222-
if provider == "openai" {
231+
if provider == "openai" || provider == "custom" {
223232
let store = app.store("settings").map_err(|e| e.to_string())?;
224233
if let Some(url) = base_url.clone() {
225-
store.set("ai_openai_base_url", serde_json::Value::String(url));
234+
if provider == "custom" {
235+
store.set(CUSTOM_BASE_URL_KEY, serde_json::Value::String(url));
236+
} else {
237+
store.set(LEGACY_OPENAI_BASE_URL_KEY, serde_json::Value::String(url));
238+
}
226239
}
227240
store.set(
228-
"ai_openai_no_auth",
241+
if provider == "custom" {
242+
CUSTOM_NO_AUTH_KEY
243+
} else {
244+
LEGACY_OPENAI_NO_AUTH_KEY
245+
},
229246
serde_json::Value::Bool(inferred_no_auth),
230247
);
231248
if let Some(m) = model.clone() {
@@ -236,10 +253,28 @@ pub async fn validate_and_cache_api_key(
236253
.map_err(|e| format!("Failed to save AI settings: {}", e))?;
237254
}
238255

239-
if provider == "openai" {
256+
if provider == "openai" || provider == "custom" {
257+
let store = app.store("settings").map_err(|e| e.to_string())?;
258+
240259
let base = base_url
241260
.clone()
242-
.unwrap_or_else(|| "https://api.openai.com/v1".to_string());
261+
.or_else(|| {
262+
if provider == "custom" {
263+
store
264+
.get(CUSTOM_BASE_URL_KEY)
265+
.and_then(|v| v.as_str().map(|s| s.to_string()))
266+
.or_else(|| {
267+
store
268+
.get(LEGACY_OPENAI_BASE_URL_KEY)
269+
.and_then(|v| v.as_str().map(|s| s.to_string()))
270+
})
271+
} else {
272+
store
273+
.get(LEGACY_OPENAI_BASE_URL_KEY)
274+
.and_then(|v| v.as_str().map(|s| s.to_string()))
275+
}
276+
})
277+
.unwrap_or_else(|| DEFAULT_OPENAI_BASE_URL.to_string());
243278
let validate_url = normalize_chat_completions_url(&base);
244279

245280
let client = reqwest::Client::new();
@@ -404,15 +439,16 @@ pub async fn update_ai_settings(
404439

405440
// Check if API key exists when enabling
406441
if enabled {
407-
if provider == "openai" {
442+
if provider == "custom" {
408443
let store = app.store("settings").map_err(|e| e.to_string())?;
409444
let cache_has_key = {
410445
let cache = API_KEY_CACHE
411446
.lock()
412447
.map_err(|_| "Failed to access cache".to_string())?;
413-
cache.contains_key(&format!("ai_api_key_{}", provider))
448+
cache.contains_key("ai_api_key_custom")
414449
};
415-
let configured_base = store.get("ai_openai_base_url").is_some();
450+
let configured_base = store.get(CUSTOM_BASE_URL_KEY).is_some()
451+
|| store.get(LEGACY_OPENAI_BASE_URL_KEY).is_some();
416452

417453
if !(cache_has_key || configured_base) {
418454
log::warn!(
@@ -421,6 +457,23 @@ pub async fn update_ai_settings(
421457
);
422458
return Err("API key not found. Please add an API key first.".to_string());
423459
}
460+
} else if provider == "openai" {
461+
let store = app.store("settings").map_err(|e| e.to_string())?;
462+
let cache_has_key = {
463+
let cache = API_KEY_CACHE
464+
.lock()
465+
.map_err(|_| "Failed to access cache".to_string())?;
466+
cache.contains_key("ai_api_key_openai")
467+
};
468+
let legacy_custom_config = store.get(LEGACY_OPENAI_BASE_URL_KEY).is_some();
469+
470+
if !(cache_has_key || legacy_custom_config) {
471+
log::warn!(
472+
"Attempted to enable AI enhancement without cached API key for provider: {}",
473+
provider
474+
);
475+
return Err("API key not found. Please add an API key first.".to_string());
476+
}
424477
} else {
425478
let cache_has_key = {
426479
let cache = API_KEY_CACHE
@@ -554,25 +607,73 @@ pub async fn enhance_transcription(text: String, app: tauri::AppHandle) -> Resul
554607
}
555608

556609
// Determine provider-specific config
557-
let (api_key, options) = if provider == "openai" {
610+
let (factory_provider, api_key, options) = if provider == "openai" {
611+
let cache = API_KEY_CACHE.lock().map_err(|e| {
612+
log::error!("Failed to access API key cache: {}", e);
613+
"Failed to access cache".to_string()
614+
})?;
615+
616+
let openai_cached = cache.get("ai_api_key_openai").cloned();
617+
let custom_cached = cache.get("ai_api_key_custom").cloned();
618+
drop(cache);
619+
620+
if let Some(cached) = openai_cached {
621+
let mut opts = std::collections::HashMap::new();
622+
opts.insert(
623+
"base_url".into(),
624+
serde_json::Value::String(DEFAULT_OPENAI_BASE_URL.to_string()),
625+
);
626+
opts.insert("no_auth".into(), serde_json::Value::Bool(false));
627+
628+
("openai".to_string(), cached, opts)
629+
} else if let Some(legacy_base_url) = store
630+
.get(LEGACY_OPENAI_BASE_URL_KEY)
631+
.and_then(|v| v.as_str().map(|s| s.to_string()))
632+
{
633+
log::warn!("Using legacy OpenAI-compatible configuration for openai provider");
634+
let mut opts = std::collections::HashMap::new();
635+
opts.insert(
636+
"base_url".into(),
637+
serde_json::Value::String(legacy_base_url),
638+
);
639+
opts.insert(
640+
"no_auth".into(),
641+
serde_json::Value::Bool(custom_cached.is_none()),
642+
);
643+
644+
(
645+
"openai".to_string(),
646+
custom_cached.unwrap_or_default(),
647+
opts,
648+
)
649+
} else {
650+
log::error!(
651+
"API key not found in cache for OpenAI provider. Cache keys unavailable for OpenAI path"
652+
);
653+
return Err("API key not found in cache".to_string());
654+
}
655+
} else if provider == "custom" {
558656
let base_url = store
559-
.get("ai_openai_base_url")
657+
.get(CUSTOM_BASE_URL_KEY)
560658
.and_then(|v| v.as_str().map(|s| s.to_string()))
561-
.unwrap_or_else(|| "https://api.openai.com/v1".to_string());
659+
.or_else(|| {
660+
store
661+
.get(LEGACY_OPENAI_BASE_URL_KEY)
662+
.and_then(|v| v.as_str().map(|s| s.to_string()))
663+
})
664+
.unwrap_or_else(|| DEFAULT_OPENAI_BASE_URL.to_string());
562665

563-
// Send Authorization only if a key is cached
564666
let cache = API_KEY_CACHE.lock().map_err(|e| {
565667
log::error!("Failed to access API key cache: {}", e);
566668
"Failed to access cache".to_string()
567669
})?;
568-
let key_name = format!("ai_api_key_{}", provider);
569-
let cached = cache.get(&key_name).cloned();
570670

571-
// Log detailed information about API key lookup
671+
let cached = cache.get("ai_api_key_custom").cloned();
672+
572673
if cached.is_some() {
573-
log::info!("Using cached API key for OpenAI provider");
674+
log::info!("Using cached API key for custom provider");
574675
} else {
575-
log::warn!("No cached API key found for OpenAI provider, using no-auth mode");
676+
log::warn!("No cached API key found for custom provider, using no-auth mode");
576677
log::debug!(
577678
"Available cache keys: {:?}",
578679
cache.keys().collect::<Vec<_>>()
@@ -584,7 +685,7 @@ pub async fn enhance_transcription(text: String, app: tauri::AppHandle) -> Resul
584685
opts.insert("base_url".into(), serde_json::Value::String(base_url));
585686
opts.insert("no_auth".into(), serde_json::Value::Bool(cached.is_none()));
586687

587-
(cached.unwrap_or_default(), opts)
688+
("openai".to_string(), cached.unwrap_or_default(), opts)
588689
} else if provider == "gemini" || provider == "anthropic" {
589690
// Require API key from in-memory cache
590691
let cache = API_KEY_CACHE
@@ -600,7 +701,7 @@ pub async fn enhance_transcription(text: String, app: tauri::AppHandle) -> Resul
600701
"API key not found in cache".to_string()
601702
})?;
602703

603-
(api_key, std::collections::HashMap::new())
704+
(provider.clone(), api_key, std::collections::HashMap::new())
604705
} else {
605706
return Err("Unsupported provider".to_string());
606707
};
@@ -629,7 +730,7 @@ pub async fn enhance_transcription(text: String, app: tauri::AppHandle) -> Resul
629730

630731
// Create provider config
631732
let config = AIProviderConfig {
632-
provider,
733+
provider: factory_provider,
633734
model,
634735
api_key,
635736
enabled: true,
@@ -688,12 +789,12 @@ pub async fn set_openai_config(
688789
) -> Result<(), String> {
689790
let store = app.store("settings").map_err(|e| e.to_string())?;
690791
store.set(
691-
"ai_openai_base_url",
792+
CUSTOM_BASE_URL_KEY,
692793
serde_json::Value::String(args.base_url),
693794
);
694795
if let Some(no_auth) = args.no_auth {
695796
// Backward-compatibility: accept but not required
696-
store.set("ai_openai_no_auth", serde_json::Value::Bool(no_auth));
797+
store.set(CUSTOM_NO_AUTH_KEY, serde_json::Value::Bool(no_auth));
697798
}
698799
store
699800
.save()
@@ -705,12 +806,22 @@ pub async fn set_openai_config(
705806
pub async fn get_openai_config(app: tauri::AppHandle) -> Result<OpenAIConfig, String> {
706807
let store = app.store("settings").map_err(|e| e.to_string())?;
707808
let base_url = store
708-
.get("ai_openai_base_url")
809+
.get(CUSTOM_BASE_URL_KEY)
709810
.and_then(|v| v.as_str().map(|s| s.to_string()))
710-
.unwrap_or_else(|| "https://api.openai.com/v1".to_string());
811+
.or_else(|| {
812+
store
813+
.get(LEGACY_OPENAI_BASE_URL_KEY)
814+
.and_then(|v| v.as_str().map(|s| s.to_string()))
815+
})
816+
.unwrap_or_else(|| DEFAULT_OPENAI_BASE_URL.to_string());
711817
let no_auth = store
712-
.get("ai_openai_no_auth")
818+
.get(CUSTOM_NO_AUTH_KEY)
713819
.and_then(|v| v.as_bool())
820+
.or_else(|| {
821+
store
822+
.get(LEGACY_OPENAI_NO_AUTH_KEY)
823+
.and_then(|v| v.as_bool())
824+
})
714825
.unwrap_or(false);
715826
Ok(OpenAIConfig { base_url, no_auth })
716827
}
@@ -777,11 +888,18 @@ pub async fn list_provider_models(
777888
) -> Result<Vec<ProviderModel>, String> {
778889
// Validate provider
779890
if !["openai", "anthropic", "gemini"].contains(&provider.as_str()) {
780-
return Err(format!("Unsupported provider for model listing: {}", provider));
891+
return Err(format!(
892+
"Unsupported provider for model listing: {}",
893+
provider
894+
));
781895
}
782896

783897
let models = get_curated_models(&provider);
784-
log::info!("Returning {} curated models for provider {}", models.len(), provider);
898+
log::info!(
899+
"Returning {} curated models for provider {}",
900+
models.len(),
901+
provider
902+
);
785903

786904
Ok(models)
787905
}
@@ -796,10 +914,11 @@ mod tests {
796914
assert!(validate_provider_name("gemini").is_ok());
797915
assert!(validate_provider_name("openai").is_ok());
798916
assert!(validate_provider_name("anthropic").is_ok());
799-
917+
assert!(validate_provider_name("custom").is_ok());
918+
800919
// Groq is no longer supported
801920
assert!(validate_provider_name("groq").is_err());
802-
921+
803922
// Invalid formats
804923
assert!(validate_provider_name("test-provider").is_err());
805924
assert!(validate_provider_name("test_provider").is_err());
@@ -815,20 +934,28 @@ mod tests {
815934
assert_eq!(openai_models.len(), 2);
816935
assert!(openai_models.iter().any(|m| m.id == "gpt-5-nano"));
817936
assert!(openai_models.iter().any(|m| m.id == "gpt-5-mini"));
818-
937+
819938
// Anthropic models
820939
let anthropic_models = get_curated_models("anthropic");
821940
assert_eq!(anthropic_models.len(), 2);
822-
assert!(anthropic_models.iter().any(|m| m.id == "claude-haiku-4-5-latest"));
823-
assert!(anthropic_models.iter().any(|m| m.id == "claude-sonnet-4-5-latest"));
824-
941+
assert!(anthropic_models
942+
.iter()
943+
.any(|m| m.id == "claude-haiku-4-5-latest"));
944+
assert!(anthropic_models
945+
.iter()
946+
.any(|m| m.id == "claude-sonnet-4-5-latest"));
947+
825948
// Gemini models
826949
let gemini_models = get_curated_models("gemini");
827950
assert_eq!(gemini_models.len(), 3);
828-
assert!(gemini_models.iter().any(|m| m.id == "gemini-3-flash-preview"));
951+
assert!(gemini_models
952+
.iter()
953+
.any(|m| m.id == "gemini-3-flash-preview"));
829954
assert!(gemini_models.iter().any(|m| m.id == "gemini-2.5-flash"));
830-
assert!(gemini_models.iter().any(|m| m.id == "gemini-2.5-flash-lite"));
831-
955+
assert!(gemini_models
956+
.iter()
957+
.any(|m| m.id == "gemini-2.5-flash-lite"));
958+
832959
// Unknown provider returns empty list
833960
let unknown_models = get_curated_models("unknown");
834961
assert!(unknown_models.is_empty());

0 commit comments

Comments (0)