Skip to content

Commit a28030b

Browse files
committed
feat(ai): add OpenAI-compatible provider and endpoint normalization; fix pill feedback timing; cleanups
- Always call baseURL + /v1/chat/completions for OpenAI-compatible endpoints
- Add OpenAI provider, config modal, and backend validation (incl. no-auth mode)
- Show a short "Formatting failed" message on the pill for ~2s on enhancement errors
- Remove unused validate_api_key; drop unnecessary mut; adjust gating logic
- Wire keyring helpers and UI to the new flow
1 parent 937bbb3 commit a28030b

File tree

11 files changed

+865
-163
lines changed

11 files changed

+865
-163
lines changed

package.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,5 +70,6 @@
7070
"typescript": "~5.6.2",
7171
"vite": "^6.3.6",
7272
"vitest": "^3.2.4"
73-
}
73+
},
74+
"packageManager": "pnpm@10.13.1+sha512.37ebf1a5c7a30d5fabe0c5df44ee8da4c965ca0c5af3dbab28c3a1681b70a256218d05c81c9c0dcf767ef6b8551eb5b960042b9ed4300c59242336377e01cfad"
7475
}

src-tauri/src/ai/mod.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::collections::HashMap;
55
pub mod config;
66
pub mod gemini;
77
pub mod groq;
8+
pub mod openai;
89
pub mod prompts;
910

1011
pub use config::{MAX_CUSTOM_VOCABULARY, MAX_TEXT_LENGTH, MAX_VOCABULARY_TERM_LENGTH};
@@ -115,11 +116,16 @@ impl AIProviderFactory {
115116
config.model.clone(),
116117
config.options.clone(),
117118
)?)),
119+
"openai" => Ok(Box::new(openai::OpenAIProvider::new(
120+
config.api_key.clone(),
121+
config.model.clone(),
122+
config.options.clone(),
123+
)?)),
118124
provider => Err(AIError::ProviderNotFound(provider.to_string())),
119125
}
120126
}
121127

122128
/// Returns true when `provider` names a backend this factory can construct.
fn is_valid_provider(provider: &str) -> bool {
    // Keep in sync with the match arms in `create_provider`.
    ["groq", "gemini", "openai"].contains(&provider)
}
125131
}

src-tauri/src/ai/openai.rs

Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
use super::config::*;
2+
use super::{prompts, AIEnhancementRequest, AIEnhancementResponse, AIError, AIProvider};
3+
use async_trait::async_trait;
4+
use reqwest::Client;
5+
use serde::{Deserialize, Serialize};
6+
use std::collections::HashMap;
7+
use std::time::Duration;
8+
9+
// Accept common OpenAI chat models. Keep minimal allowlist but include the requested one.
10+
const SUPPORTED_MODELS: &[&str] = &["gpt-5-nano", "gpt-5", "gpt-4.1-2025-04-14"];
11+
12+
pub struct OpenAIProvider {
13+
#[allow(dead_code)]
14+
api_key: String,
15+
model: String,
16+
client: Client,
17+
base_url: String,
18+
options: HashMap<String, serde_json::Value>,
19+
}
20+
21+
impl OpenAIProvider {
    /// Creates a provider for an OpenAI-compatible chat-completions endpoint.
    ///
    /// Recognized `options` keys:
    /// - `no_auth` (bool): when true, skips API-key validation here and the
    ///   `Authorization` header at request time (for self-hosted servers).
    /// - `base_url` (string): endpoint root; defaults to the official OpenAI
    ///   host. The request URL is always normalized to
    ///   `<base>/v1/chat/completions`.
    ///
    /// # Errors
    /// - `AIError::ValidationError` when auth is required and the key is empty
    ///   or shorter than `MIN_API_KEY_LENGTH`.
    /// - `AIError::NetworkError` when the HTTP client cannot be constructed.
    pub fn new(
        api_key: String,
        model: String,
        mut options: HashMap<String, serde_json::Value>,
    ) -> Result<Self, AIError> {
        // Validate model (allow if in list; otherwise accept for forward compatibility).
        if !SUPPORTED_MODELS.contains(&model.as_str()) {
            // Don't hard fail; just log a warning and continue.
            log::warn!("OpenAI model not in local allowlist: {}", model);
        }

        // Determine if auth is required.
        let no_auth = options
            .get("no_auth")
            .and_then(|v| v.as_bool())
            .unwrap_or(false);

        // Validate API key format (basic length check) only if auth is required.
        if !no_auth {
            if api_key.trim().is_empty() || api_key.len() < MIN_API_KEY_LENGTH {
                return Err(AIError::ValidationError(
                    "Invalid API key format".to_string(),
                ));
            }
        }

        let client = Client::builder()
            .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
            .build()
            .map_err(|e| AIError::NetworkError(format!("Failed to create HTTP client: {}", e)))?;

        // Resolve base URL: always map the configured root to /v1/chat/completions.
        // NOTE(review): a root that already ends in `/v1` would yield `/v1/v1/...`;
        // callers are expected to pass the bare host — confirm against the UI.
        let base_root = options
            .get("base_url")
            .and_then(|v| v.as_str())
            .unwrap_or("https://api.openai.com");
        let base_trim = base_root.trim_end_matches('/');
        let base_url = format!("{}/v1/chat/completions", base_trim);

        // Keep the (un-normalized) root in options for downstream readers.
        // Two-phase borrow: `base_root` borrows `options` and is consumed while
        // evaluating the insert arguments — intentional, compiles under NLL.
        options.insert(
            "base_url".into(),
            serde_json::Value::String(base_root.to_string()),
        );

        Ok(Self { api_key, model, client, base_url, options })
    }

    /// Issues the request up to `MAX_RETRIES` times with a linear backoff of
    /// `RETRY_BASE_DELAY_MS * attempt` between attempts, returning the first
    /// success or the last error seen.
    async fn make_request_with_retry(
        &self,
        request: &OpenAIRequest,
    ) -> Result<OpenAIResponse, AIError> {
        let mut last_error = None;

        for attempt in 1..=MAX_RETRIES {
            match self.make_single_request(request).await {
                Ok(response) => return Ok(response),
                Err(e) => {
                    log::warn!("API request attempt {} failed: {}", attempt, e);
                    last_error = Some(e);

                    // Sleep only between attempts, not after the final one.
                    if attempt < MAX_RETRIES {
                        tokio::time::sleep(Duration::from_millis(
                            RETRY_BASE_DELAY_MS * attempt as u64,
                        ))
                        .await;
                    }
                }
            }
        }

        // Fallback is defensive: MAX_RETRIES >= 1 means last_error is Some here.
        Err(last_error.unwrap_or_else(|| AIError::NetworkError("Unknown error".to_string())))
    }

    /// Sends one POST to `base_url`, attaching a Bearer token unless `no_auth`
    /// is set. Maps HTTP 429 to `RateLimitExceeded`, other non-2xx statuses to
    /// `ApiError` (with response body), and JSON decode failures to
    /// `InvalidResponse`.
    async fn make_single_request(&self, request: &OpenAIRequest) -> Result<OpenAIResponse, AIError> {
        // Determine if the auth header should be sent.
        let no_auth = self
            .options
            .get("no_auth")
            .and_then(|v| v.as_bool())
            .unwrap_or(false);

        let mut req = self
            .client
            .post(&self.base_url)
            .header("Content-Type", "application/json")
            .json(request);

        if !no_auth {
            req = req.header("Authorization", format!("Bearer {}", self.api_key));
        }

        let response = req
            .send()
            .await
            .map_err(|e| AIError::NetworkError(e.to_string()))?;

        let status = response.status();

        // Rate limiting gets its own variant so callers can back off.
        if status.as_u16() == 429 {
            return Err(AIError::RateLimitExceeded);
        }

        if !status.is_success() {
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(AIError::ApiError(format!(
                "API returned {}: {}",
                status, error_text
            )));
        }

        response
            .json()
            .await
            .map_err(|e| AIError::InvalidResponse(e.to_string()))
    }
}
142+
143+
/// JSON request body for the chat-completions endpoint.
#[derive(Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<Message>,
    // Omitted from the payload when None so the server default applies.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}

/// One chat message; `role` is e.g. "system" or "user" on the wire.
/// Serialize for outgoing requests, Deserialize for the response's choice.
#[derive(Serialize, Deserialize)]
struct Message {
    role: String,
    content: String,
}

/// Minimal view of the API response: only the fields this provider reads.
#[derive(Deserialize)]
struct OpenAIResponse {
    choices: Vec<Choice>,
}

/// A single completion choice; the enhanced text lives in `message.content`.
#[derive(Deserialize)]
struct Choice {
    message: Message,
}
168+
169+
#[async_trait]
170+
impl AIProvider for OpenAIProvider {
171+
async fn enhance_text(
172+
&self,
173+
request: AIEnhancementRequest,
174+
) -> Result<AIEnhancementResponse, AIError> {
175+
request.validate()?;
176+
177+
let prompt = prompts::build_enhancement_prompt(
178+
&request.text,
179+
request.context.as_deref(),
180+
&request.options.unwrap_or_default(),
181+
);
182+
183+
let temperature = self
184+
.options
185+
.get("temperature")
186+
.and_then(|v| v.as_f64())
187+
.map(|v| v as f32)
188+
.unwrap_or(DEFAULT_TEMPERATURE);
189+
190+
let max_tokens = self
191+
.options
192+
.get("max_tokens")
193+
.and_then(|v| v.as_u64())
194+
.map(|v| v as u32);
195+
196+
let request_body = OpenAIRequest {
197+
model: self.model.clone(),
198+
messages: vec![
199+
Message {
200+
role: "system".to_string(),
201+
content: "You are a careful text formatter that only returns the cleaned text per the provided rules.".to_string(),
202+
},
203+
Message {
204+
role: "user".to_string(),
205+
content: prompt,
206+
},
207+
],
208+
temperature: Some(temperature.clamp(0.0, 2.0)),
209+
max_tokens,
210+
};
211+
212+
let api_response = self.make_request_with_retry(&request_body).await?;
213+
214+
let enhanced_text = api_response
215+
.choices
216+
.first()
217+
.ok_or_else(|| AIError::InvalidResponse("No choices in response".to_string()))?
218+
.message
219+
.content
220+
.trim()
221+
.to_string();
222+
223+
if enhanced_text.is_empty() {
224+
return Err(AIError::InvalidResponse("Empty response from API".to_string()));
225+
}
226+
227+
Ok(AIEnhancementResponse {
228+
enhanced_text,
229+
original_text: request.text,
230+
provider: self.name().to_string(),
231+
model: self.model.clone(),
232+
})
233+
}
234+
235+
fn name(&self) -> &str {
236+
"openai"
237+
}
238+
}
239+
240+
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructor validation: an empty key is rejected when auth is required,
    /// while models outside the allowlist are accepted (warning only).
    #[test]
    fn test_provider_creation() {
        // Empty API key with auth required -> validation error.
        let empty_key =
            OpenAIProvider::new("".to_string(), "gpt-5-nano".to_string(), HashMap::new());
        assert!(empty_key.is_err());

        // Unknown model should be allowed with a warning (not hard error).
        let unknown_model = OpenAIProvider::new(
            "test_key_12345".to_string(),
            "gpt-unknown".to_string(),
            HashMap::new(),
        );
        assert!(unknown_model.is_ok());

        // Known model with a plausible key succeeds.
        let known_model = OpenAIProvider::new(
            "test_key_12345".to_string(),
            "gpt-5-nano".to_string(),
            HashMap::new(),
        );
        assert!(known_model.is_ok());
    }
}

0 commit comments

Comments
 (0)