@@ -12,6 +12,12 @@ use tauri_plugin_store::StoreExt;
1212static API_KEY_CACHE : Lazy < Mutex < HashMap < String , String > > > =
1313 Lazy :: new ( || Mutex :: new ( HashMap :: new ( ) ) ) ;
1414
15+ const DEFAULT_OPENAI_BASE_URL : & str = "https://api.openai.com/v1" ;
16+ const CUSTOM_BASE_URL_KEY : & str = "ai_custom_base_url" ;
17+ const CUSTOM_NO_AUTH_KEY : & str = "ai_custom_no_auth" ;
18+ const LEGACY_OPENAI_BASE_URL_KEY : & str = "ai_openai_base_url" ;
19+ const LEGACY_OPENAI_NO_AUTH_KEY : & str = "ai_openai_no_auth" ;
20+
1521// Helper: determine if we should consider that the app "has an API key" for a provider
1622// For OpenAI-compatible providers, a configured no_auth=true also counts as "has key"
1723fn check_has_api_key < R : tauri:: Runtime > (
@@ -20,7 +26,10 @@ fn check_has_api_key<R: tauri::Runtime>(
2026 cache : & HashMap < String , String > ,
2127) -> bool {
2228 if provider == "openai" {
23- let configured_base = store. get ( "ai_openai_base_url" ) . is_some ( ) ;
29+ cache. contains_key ( "ai_api_key_openai" ) || store. get ( LEGACY_OPENAI_BASE_URL_KEY ) . is_some ( )
30+ } else if provider == "custom" {
31+ let configured_base = store. get ( CUSTOM_BASE_URL_KEY ) . is_some ( )
32+ || store. get ( LEGACY_OPENAI_BASE_URL_KEY ) . is_some ( ) ;
2433 configured_base || cache. contains_key ( & format ! ( "ai_api_key_{}" , provider) )
2534 } else {
2635 cache. contains_key ( & format ! ( "ai_api_key_{}" , provider) )
@@ -50,7 +59,7 @@ lazy_static::lazy_static! {
5059}
5160
5261// Supported AI providers
53- const ALLOWED_PROVIDERS : & [ & str ] = & [ "gemini" , "openai" , "anthropic" ] ;
62+ const ALLOWED_PROVIDERS : & [ & str ] = & [ "gemini" , "openai" , "anthropic" , "custom" ] ;
5463
5564fn validate_provider_name ( provider : & str ) -> Result < ( ) , String > {
5665 // First check format
@@ -219,13 +228,21 @@ pub async fn validate_and_cache_api_key(
219228 let provided_key = api_key. clone ( ) . unwrap_or_default ( ) ;
220229 let inferred_no_auth = no_auth. unwrap_or ( false ) || provided_key. trim ( ) . is_empty ( ) ;
221230
222- if provider == "openai" {
231+ if provider == "openai" || provider == "custom" {
223232 let store = app. store ( "settings" ) . map_err ( |e| e. to_string ( ) ) ?;
224233 if let Some ( url) = base_url. clone ( ) {
225- store. set ( "ai_openai_base_url" , serde_json:: Value :: String ( url) ) ;
234+ if provider == "custom" {
235+ store. set ( CUSTOM_BASE_URL_KEY , serde_json:: Value :: String ( url) ) ;
236+ } else {
237+ store. set ( LEGACY_OPENAI_BASE_URL_KEY , serde_json:: Value :: String ( url) ) ;
238+ }
226239 }
227240 store. set (
228- "ai_openai_no_auth" ,
241+ if provider == "custom" {
242+ CUSTOM_NO_AUTH_KEY
243+ } else {
244+ LEGACY_OPENAI_NO_AUTH_KEY
245+ } ,
229246 serde_json:: Value :: Bool ( inferred_no_auth) ,
230247 ) ;
231248 if let Some ( m) = model. clone ( ) {
@@ -236,10 +253,28 @@ pub async fn validate_and_cache_api_key(
236253 . map_err ( |e| format ! ( "Failed to save AI settings: {}" , e) ) ?;
237254 }
238255
239- if provider == "openai" {
256+ if provider == "openai" || provider == "custom" {
257+ let store = app. store ( "settings" ) . map_err ( |e| e. to_string ( ) ) ?;
258+
240259 let base = base_url
241260 . clone ( )
242- . unwrap_or_else ( || "https://api.openai.com/v1" . to_string ( ) ) ;
261+ . or_else ( || {
262+ if provider == "custom" {
263+ store
264+ . get ( CUSTOM_BASE_URL_KEY )
265+ . and_then ( |v| v. as_str ( ) . map ( |s| s. to_string ( ) ) )
266+ . or_else ( || {
267+ store
268+ . get ( LEGACY_OPENAI_BASE_URL_KEY )
269+ . and_then ( |v| v. as_str ( ) . map ( |s| s. to_string ( ) ) )
270+ } )
271+ } else {
272+ store
273+ . get ( LEGACY_OPENAI_BASE_URL_KEY )
274+ . and_then ( |v| v. as_str ( ) . map ( |s| s. to_string ( ) ) )
275+ }
276+ } )
277+ . unwrap_or_else ( || DEFAULT_OPENAI_BASE_URL . to_string ( ) ) ;
243278 let validate_url = normalize_chat_completions_url ( & base) ;
244279
245280 let client = reqwest:: Client :: new ( ) ;
@@ -404,15 +439,16 @@ pub async fn update_ai_settings(
404439
405440 // Check if API key exists when enabling
406441 if enabled {
407- if provider == "openai " {
442+ if provider == "custom " {
408443 let store = app. store ( "settings" ) . map_err ( |e| e. to_string ( ) ) ?;
409444 let cache_has_key = {
410445 let cache = API_KEY_CACHE
411446 . lock ( )
412447 . map_err ( |_| "Failed to access cache" . to_string ( ) ) ?;
413- cache. contains_key ( & format ! ( "ai_api_key_{}" , provider ) )
448+ cache. contains_key ( "ai_api_key_custom" )
414449 } ;
415- let configured_base = store. get ( "ai_openai_base_url" ) . is_some ( ) ;
450+ let configured_base = store. get ( CUSTOM_BASE_URL_KEY ) . is_some ( )
451+ || store. get ( LEGACY_OPENAI_BASE_URL_KEY ) . is_some ( ) ;
416452
417453 if !( cache_has_key || configured_base) {
418454 log:: warn!(
@@ -421,6 +457,23 @@ pub async fn update_ai_settings(
421457 ) ;
422458 return Err ( "API key not found. Please add an API key first." . to_string ( ) ) ;
423459 }
460+ } else if provider == "openai" {
461+ let store = app. store ( "settings" ) . map_err ( |e| e. to_string ( ) ) ?;
462+ let cache_has_key = {
463+ let cache = API_KEY_CACHE
464+ . lock ( )
465+ . map_err ( |_| "Failed to access cache" . to_string ( ) ) ?;
466+ cache. contains_key ( "ai_api_key_openai" )
467+ } ;
468+ let legacy_custom_config = store. get ( LEGACY_OPENAI_BASE_URL_KEY ) . is_some ( ) ;
469+
470+ if !( cache_has_key || legacy_custom_config) {
471+ log:: warn!(
472+ "Attempted to enable AI enhancement without cached API key for provider: {}" ,
473+ provider
474+ ) ;
475+ return Err ( "API key not found. Please add an API key first." . to_string ( ) ) ;
476+ }
424477 } else {
425478 let cache_has_key = {
426479 let cache = API_KEY_CACHE
@@ -554,25 +607,73 @@ pub async fn enhance_transcription(text: String, app: tauri::AppHandle) -> Resul
554607 }
555608
556609 // Determine provider-specific config
557- let ( api_key, options) = if provider == "openai" {
610+ let ( factory_provider, api_key, options) = if provider == "openai" {
611+ let cache = API_KEY_CACHE . lock ( ) . map_err ( |e| {
612+ log:: error!( "Failed to access API key cache: {}" , e) ;
613+ "Failed to access cache" . to_string ( )
614+ } ) ?;
615+
616+ let openai_cached = cache. get ( "ai_api_key_openai" ) . cloned ( ) ;
617+ let custom_cached = cache. get ( "ai_api_key_custom" ) . cloned ( ) ;
618+ drop ( cache) ;
619+
620+ if let Some ( cached) = openai_cached {
621+ let mut opts = std:: collections:: HashMap :: new ( ) ;
622+ opts. insert (
623+ "base_url" . into ( ) ,
624+ serde_json:: Value :: String ( DEFAULT_OPENAI_BASE_URL . to_string ( ) ) ,
625+ ) ;
626+ opts. insert ( "no_auth" . into ( ) , serde_json:: Value :: Bool ( false ) ) ;
627+
628+ ( "openai" . to_string ( ) , cached, opts)
629+ } else if let Some ( legacy_base_url) = store
630+ . get ( LEGACY_OPENAI_BASE_URL_KEY )
631+ . and_then ( |v| v. as_str ( ) . map ( |s| s. to_string ( ) ) )
632+ {
633+ log:: warn!( "Using legacy OpenAI-compatible configuration for openai provider" ) ;
634+ let mut opts = std:: collections:: HashMap :: new ( ) ;
635+ opts. insert (
636+ "base_url" . into ( ) ,
637+ serde_json:: Value :: String ( legacy_base_url) ,
638+ ) ;
639+ opts. insert (
640+ "no_auth" . into ( ) ,
641+ serde_json:: Value :: Bool ( custom_cached. is_none ( ) ) ,
642+ ) ;
643+
644+ (
645+ "openai" . to_string ( ) ,
646+ custom_cached. unwrap_or_default ( ) ,
647+ opts,
648+ )
649+ } else {
650+ log:: error!(
651+ "API key not found in cache for OpenAI provider. Cache keys unavailable for OpenAI path"
652+ ) ;
653+ return Err ( "API key not found in cache" . to_string ( ) ) ;
654+ }
655+ } else if provider == "custom" {
558656 let base_url = store
559- . get ( "ai_openai_base_url" )
657+ . get ( CUSTOM_BASE_URL_KEY )
560658 . and_then ( |v| v. as_str ( ) . map ( |s| s. to_string ( ) ) )
561- . unwrap_or_else ( || "https://api.openai.com/v1" . to_string ( ) ) ;
659+ . or_else ( || {
660+ store
661+ . get ( LEGACY_OPENAI_BASE_URL_KEY )
662+ . and_then ( |v| v. as_str ( ) . map ( |s| s. to_string ( ) ) )
663+ } )
664+ . unwrap_or_else ( || DEFAULT_OPENAI_BASE_URL . to_string ( ) ) ;
562665
563- // Send Authorization only if a key is cached
564666 let cache = API_KEY_CACHE . lock ( ) . map_err ( |e| {
565667 log:: error!( "Failed to access API key cache: {}" , e) ;
566668 "Failed to access cache" . to_string ( )
567669 } ) ?;
568- let key_name = format ! ( "ai_api_key_{}" , provider) ;
569- let cached = cache. get ( & key_name) . cloned ( ) ;
570670
571- // Log detailed information about API key lookup
671+ let cached = cache. get ( "ai_api_key_custom" ) . cloned ( ) ;
672+
572673 if cached. is_some ( ) {
573- log:: info!( "Using cached API key for OpenAI provider" ) ;
674+ log:: info!( "Using cached API key for custom provider" ) ;
574675 } else {
575- log:: warn!( "No cached API key found for OpenAI provider, using no-auth mode" ) ;
676+ log:: warn!( "No cached API key found for custom provider, using no-auth mode" ) ;
576677 log:: debug!(
577678 "Available cache keys: {:?}" ,
578679 cache. keys( ) . collect:: <Vec <_>>( )
@@ -584,7 +685,7 @@ pub async fn enhance_transcription(text: String, app: tauri::AppHandle) -> Resul
584685 opts. insert ( "base_url" . into ( ) , serde_json:: Value :: String ( base_url) ) ;
585686 opts. insert ( "no_auth" . into ( ) , serde_json:: Value :: Bool ( cached. is_none ( ) ) ) ;
586687
587- ( cached. unwrap_or_default ( ) , opts)
688+ ( "openai" . to_string ( ) , cached. unwrap_or_default ( ) , opts)
588689 } else if provider == "gemini" || provider == "anthropic" {
589690 // Require API key from in-memory cache
590691 let cache = API_KEY_CACHE
@@ -600,7 +701,7 @@ pub async fn enhance_transcription(text: String, app: tauri::AppHandle) -> Resul
600701 "API key not found in cache" . to_string ( )
601702 } ) ?;
602703
603- ( api_key, std:: collections:: HashMap :: new ( ) )
704+ ( provider . clone ( ) , api_key, std:: collections:: HashMap :: new ( ) )
604705 } else {
605706 return Err ( "Unsupported provider" . to_string ( ) ) ;
606707 } ;
@@ -629,7 +730,7 @@ pub async fn enhance_transcription(text: String, app: tauri::AppHandle) -> Resul
629730
630731 // Create provider config
631732 let config = AIProviderConfig {
632- provider,
733+ provider : factory_provider ,
633734 model,
634735 api_key,
635736 enabled : true ,
@@ -688,12 +789,12 @@ pub async fn set_openai_config(
688789) -> Result < ( ) , String > {
689790 let store = app. store ( "settings" ) . map_err ( |e| e. to_string ( ) ) ?;
690791 store. set (
691- "ai_openai_base_url" ,
792+ CUSTOM_BASE_URL_KEY ,
692793 serde_json:: Value :: String ( args. base_url ) ,
693794 ) ;
694795 if let Some ( no_auth) = args. no_auth {
695796 // Backward-compatibility: accept but not required
696- store. set ( "ai_openai_no_auth" , serde_json:: Value :: Bool ( no_auth) ) ;
797+ store. set ( CUSTOM_NO_AUTH_KEY , serde_json:: Value :: Bool ( no_auth) ) ;
697798 }
698799 store
699800 . save ( )
@@ -705,12 +806,22 @@ pub async fn set_openai_config(
705806pub async fn get_openai_config ( app : tauri:: AppHandle ) -> Result < OpenAIConfig , String > {
706807 let store = app. store ( "settings" ) . map_err ( |e| e. to_string ( ) ) ?;
707808 let base_url = store
708- . get ( "ai_openai_base_url" )
809+ . get ( CUSTOM_BASE_URL_KEY )
709810 . and_then ( |v| v. as_str ( ) . map ( |s| s. to_string ( ) ) )
710- . unwrap_or_else ( || "https://api.openai.com/v1" . to_string ( ) ) ;
811+ . or_else ( || {
812+ store
813+ . get ( LEGACY_OPENAI_BASE_URL_KEY )
814+ . and_then ( |v| v. as_str ( ) . map ( |s| s. to_string ( ) ) )
815+ } )
816+ . unwrap_or_else ( || DEFAULT_OPENAI_BASE_URL . to_string ( ) ) ;
711817 let no_auth = store
712- . get ( "ai_openai_no_auth" )
818+ . get ( CUSTOM_NO_AUTH_KEY )
713819 . and_then ( |v| v. as_bool ( ) )
820+ . or_else ( || {
821+ store
822+ . get ( LEGACY_OPENAI_NO_AUTH_KEY )
823+ . and_then ( |v| v. as_bool ( ) )
824+ } )
714825 . unwrap_or ( false ) ;
715826 Ok ( OpenAIConfig { base_url, no_auth } )
716827}
@@ -777,11 +888,18 @@ pub async fn list_provider_models(
777888) -> Result < Vec < ProviderModel > , String > {
778889 // Validate provider
779890 if ![ "openai" , "anthropic" , "gemini" ] . contains ( & provider. as_str ( ) ) {
780- return Err ( format ! ( "Unsupported provider for model listing: {}" , provider) ) ;
891+ return Err ( format ! (
892+ "Unsupported provider for model listing: {}" ,
893+ provider
894+ ) ) ;
781895 }
782896
783897 let models = get_curated_models ( & provider) ;
784- log:: info!( "Returning {} curated models for provider {}" , models. len( ) , provider) ;
898+ log:: info!(
899+ "Returning {} curated models for provider {}" ,
900+ models. len( ) ,
901+ provider
902+ ) ;
785903
786904 Ok ( models)
787905}
@@ -796,10 +914,11 @@ mod tests {
796914 assert ! ( validate_provider_name( "gemini" ) . is_ok( ) ) ;
797915 assert ! ( validate_provider_name( "openai" ) . is_ok( ) ) ;
798916 assert ! ( validate_provider_name( "anthropic" ) . is_ok( ) ) ;
799-
917+ assert ! ( validate_provider_name( "custom" ) . is_ok( ) ) ;
918+
800919 // Groq is no longer supported
801920 assert ! ( validate_provider_name( "groq" ) . is_err( ) ) ;
802-
921+
803922 // Invalid formats
804923 assert ! ( validate_provider_name( "test-provider" ) . is_err( ) ) ;
805924 assert ! ( validate_provider_name( "test_provider" ) . is_err( ) ) ;
@@ -815,20 +934,28 @@ mod tests {
815934 assert_eq ! ( openai_models. len( ) , 2 ) ;
816935 assert ! ( openai_models. iter( ) . any( |m| m. id == "gpt-5-nano" ) ) ;
817936 assert ! ( openai_models. iter( ) . any( |m| m. id == "gpt-5-mini" ) ) ;
818-
937+
819938 // Anthropic models
820939 let anthropic_models = get_curated_models ( "anthropic" ) ;
821940 assert_eq ! ( anthropic_models. len( ) , 2 ) ;
822- assert ! ( anthropic_models. iter( ) . any( |m| m. id == "claude-haiku-4-5-latest" ) ) ;
823- assert ! ( anthropic_models. iter( ) . any( |m| m. id == "claude-sonnet-4-5-latest" ) ) ;
824-
941+ assert ! ( anthropic_models
942+ . iter( )
943+ . any( |m| m. id == "claude-haiku-4-5-latest" ) ) ;
944+ assert ! ( anthropic_models
945+ . iter( )
946+ . any( |m| m. id == "claude-sonnet-4-5-latest" ) ) ;
947+
825948 // Gemini models
826949 let gemini_models = get_curated_models ( "gemini" ) ;
827950 assert_eq ! ( gemini_models. len( ) , 3 ) ;
828- assert ! ( gemini_models. iter( ) . any( |m| m. id == "gemini-3-flash-preview" ) ) ;
951+ assert ! ( gemini_models
952+ . iter( )
953+ . any( |m| m. id == "gemini-3-flash-preview" ) ) ;
829954 assert ! ( gemini_models. iter( ) . any( |m| m. id == "gemini-2.5-flash" ) ) ;
830- assert ! ( gemini_models. iter( ) . any( |m| m. id == "gemini-2.5-flash-lite" ) ) ;
831-
955+ assert ! ( gemini_models
956+ . iter( )
957+ . any( |m| m. id == "gemini-2.5-flash-lite" ) ) ;
958+
832959 // Unknown provider returns empty list
833960 let unknown_models = get_curated_models ( "unknown" ) ;
834961 assert ! ( unknown_models. is_empty( ) ) ;
0 commit comments