Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 84 additions & 9 deletions src/main/presenter/configPresenter/modelConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,52 @@ const SPECIAL_CONCAT_CHAR = '-_-'

export class ModelConfigHelper {
private modelConfigStore: ElectronStore<Record<string, IModelConfig>>
private memoryCache: Map<string, IModelConfig> = new Map()
private cacheInitialized: boolean = false

constructor() {
this.modelConfigStore = new ElectronStore<Record<string, IModelConfig>>({
name: 'model-config'
})
}

/**
* Initialize memory cache by loading all data from store
* This is called lazily on first access
*/
private initializeCache(): void {
if (this.cacheInitialized) return

const allConfigs = this.modelConfigStore.store
Object.entries(allConfigs).forEach(([key, value]) => {
this.memoryCache.set(key, value)
})
this.cacheInitialized = true
}

/**
* Get model configuration with priority: user config > provider config > default config
* @param modelId - The model ID
* @param providerId - Optional provider ID
* @returns ModelConfig
*/
getModelConfig(modelId: string, providerId?: string): ModelConfig {
// Initialize cache if not already done
this.initializeCache()

// 1. First try to get user-defined config for this specific provider + model
if (providerId) {
const userConfig = this.modelConfigStore.get(providerId + SPECIAL_CONCAT_CHAR + modelId)
const cacheKey = providerId + SPECIAL_CONCAT_CHAR + modelId
let userConfig = this.memoryCache.get(cacheKey)

// If not in cache, try to load from store and cache it
if (!userConfig) {
userConfig = this.modelConfigStore.get(cacheKey)
if (userConfig) {
this.memoryCache.set(cacheKey, userConfig)
}
}

if (userConfig?.config) {
return userConfig.config
}
Expand Down Expand Up @@ -73,11 +102,16 @@ export class ModelConfigHelper {
* @param config - The model configuration
*/
setModelConfig(modelId: string, providerId: string, config: ModelConfig): void {
  const key = providerId + SPECIAL_CONCAT_CHAR + modelId
  const entry: IModelConfig = { id: modelId, providerId, config }

  // Write-through: persist first, then mirror into the memory cache.
  this.modelConfigStore.set(key, entry)
  this.memoryCache.set(key, entry)
}

/**
Expand All @@ -86,15 +120,27 @@ export class ModelConfigHelper {
* @param providerId - The provider ID
*/
resetModelConfig(modelId: string, providerId: string): void {
  const key = providerId + SPECIAL_CONCAT_CHAR + modelId

  // Delete the persisted entry and evict the cached copy together
  // so store and cache never disagree.
  this.modelConfigStore.delete(key)
  this.memoryCache.delete(key)
}

/**
* Get all user-defined model configurations
* @returns Record of all configurations
*/
getAllModelConfigs(): Record<string, IModelConfig> {
  // Warm the cache on first use, then serve entirely from memory.
  this.initializeCache()

  // Build a fresh plain object so callers cannot mutate the cache.
  return Object.fromEntries(this.memoryCache)
}

/**
Expand Down Expand Up @@ -126,8 +172,24 @@ export class ModelConfigHelper {
* @returns boolean
*/
hasUserConfig(modelId: string, providerId: string): boolean {
  this.initializeCache()

  const key = providerId + SPECIAL_CONCAT_CHAR + modelId

  // Fast path: entry already mirrored in memory.
  if (this.memoryCache.has(key)) return true

  // Slow path: consult the store and warm the cache on a hit.
  const stored = this.modelConfigStore.get(key)
  if (!stored) return false

  this.memoryCache.set(key, stored)
  return true
}

/**
Expand All @@ -137,16 +199,20 @@ export class ModelConfigHelper {
*/
importConfigs(configs: Record<string, IModelConfig>, overwrite: boolean = false): void {
  // BUGFIX: mirror pre-existing store entries into the cache BEFORE marking
  // it initialized below. Previously, importing with overwrite=false into a
  // cold cache set cacheInitialized=true with only the imported keys cached,
  // so getAllModelConfigs()/hasUserConfig() could miss configs that exist
  // only on disk.
  this.initializeCache()

  if (overwrite) {
    // Clear existing configs from both store and cache
    this.modelConfigStore.clear()
    this.memoryCache.clear()
  }

  // Import configs to both store and cache; without overwrite, keep any
  // existing entry for a key.
  Object.entries(configs).forEach(([key, value]) => {
    if (overwrite || !this.modelConfigStore.has(key)) {
      this.modelConfigStore.set(key, value)
      this.memoryCache.set(key, value)
    }
  })

  this.cacheInitialized = true
}

/**
Expand All @@ -162,6 +228,7 @@ export class ModelConfigHelper {
*/
clearAllConfigs(): void {
  // Wipe both the persisted store and the in-memory mirror so they stay in sync.
  this.modelConfigStore.clear()
  this.memoryCache.clear()
}

/**
Expand All @@ -171,4 +238,12 @@ export class ModelConfigHelper {
getStorePath(): string {
  // Filesystem path of the underlying electron-store file.
  return this.modelConfigStore.path
}

/**
* Clear memory cache (useful for testing or memory management)
*/
clearMemoryCache(): void {
  // Drop the mirror and reset the flag so the next access reloads from the store.
  this.memoryCache.clear()
  this.cacheInitialized = false
}
}
27 changes: 13 additions & 14 deletions src/main/presenter/githubCopilotDeviceFlow.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import { BrowserWindow, shell } from 'electron'
import { exec } from 'child_process';
import { exec } from 'child_process'
import { presenter } from '@/presenter'


const GITHUB_DEVICE_URL = 'https://github.com/login/device';
const GITHUB_DEVICE_URL = 'https://github.com/login/device'

export interface DeviceFlowConfig {
clientId: string
Expand Down Expand Up @@ -303,18 +302,18 @@ export class GitHubCopilotDeviceFlow {
setTimeout(async () => {
try {
// 使用固定的GitHub设备激活页面
const url = GITHUB_DEVICE_URL;
console.log('Attempting to open URL:', url);
const url = GITHUB_DEVICE_URL
console.log('Attempting to open URL:', url)

if (process.platform === 'win32') {
// 先尝试使用explorer命令
exec(`explorer "${url}"`, (error) => {
if (error) {
console.error('Explorer command failed:', error);
console.error('Explorer command failed:', error)
// 如果explorer失败,尝试使用start命令
exec(`start "" "${url}"`, (startError) => {
if (startError) {
console.error('Start command failed:', startError);
console.error('Start command failed:', startError)
// 使用更安全的方式处理剪贴板操作
instructionWindow.webContents.executeJavaScript(`
const shouldCopy = confirm('无法自动打开浏览器。是否复制链接到剪贴板?');
Expand All @@ -325,20 +324,20 @@ export class GitHubCopilotDeviceFlow {
} else {
alert('请手动访问: ${url}');
}
`);
`)
}
});
})
}
});
})
} else {
// 非Windows系统使用默认的shell.openExternal
await shell.openExternal(url);
await shell.openExternal(url)
}
} catch (error) {
console.error('Failed to open browser:', error);
console.error('Failed to open browser:', error)
instructionWindow.webContents.executeJavaScript(`
alert('无法自动打开浏览器,请手动访问: ${GITHUB_DEVICE_URL}');
`);
`)
}
}, 1000)

Expand Down
64 changes: 61 additions & 3 deletions src/main/presenter/llmProviderPresenter/baseProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import {
import { ConfigPresenter } from '../configPresenter'
import { DevicePresenter } from '../devicePresenter'
import { jsonrepair } from 'jsonrepair'
import { eventBus, SendTarget } from '@/eventbus'
import { CONFIG_EVENTS } from '@/events'

/**
* 基础LLM提供商抽象类
Expand Down Expand Up @@ -43,6 +45,9 @@ export abstract class BaseLLMProvider {
this.provider = provider
this.configPresenter = configPresenter
this.defaultHeaders = DevicePresenter.getDefaultHeaders()

// Initialize models and customModels from cached config data
this.loadCachedModels()
}

/**
Expand All @@ -53,6 +58,37 @@ export abstract class BaseLLMProvider {
return BaseLLMProvider.MAX_TOOL_CALLS
}

/**
 * Seed `models` / `customModels` from configuration cached by the
 * ConfigPresenter, so the provider has model data available without an
 * immediate network fetch. Invoked from the constructor.
 */
private loadCachedModels(): void {
  try {
    // Provider models previously persisted via the ConfigPresenter.
    const providerModels = this.configPresenter.getProviderModels(this.provider.id)
    if (providerModels?.length) {
      this.models = providerModels
      console.info(
        `Loaded ${providerModels.length} cached models for provider: ${this.provider.name}`
      )
    }

    // User-defined custom models persisted via the ConfigPresenter.
    const customModels = this.configPresenter.getCustomModels(this.provider.id)
    if (customModels?.length) {
      this.customModels = customModels
      console.info(
        `Loaded ${customModels.length} cached custom models for provider: ${this.provider.name}`
      )
    }
  } catch (error) {
    // Best-effort: fall back to empty lists instead of propagating.
    console.warn(`Failed to load cached models for provider: ${this.provider.name}`, error)
    this.models = []
    this.customModels = []
  }
}

/**
* 初始化提供商
* 包括获取模型列表、配置代理等
Expand All @@ -79,9 +115,8 @@ export abstract class BaseLLMProvider {
if (!this.models || this.models.length === 0) return
const providerId = this.provider.id

// 检查是否有自定义模型
const customModels = this.configPresenter.getCustomModels(providerId)
if (customModels && customModels.length > 0) return
// 检查是否有自定义模型 (use cached customModels)
if (this.customModels && this.customModels.length > 0) return

// 检查是否有任何模型的状态被手动修改过
const hasManuallyModifiedModels = this.models.some((model) =>
Expand Down Expand Up @@ -124,6 +159,22 @@ export abstract class BaseLLMProvider {
}
}

/**
 * Force-refresh this provider's model data.
 * Ignores any cached list: re-fetches models from the network, re-runs the
 * auto-enable logic, then notifies all renderer windows that the model list
 * changed. Resolves with no value; failures from fetch/auto-enable propagate.
 */
public async refreshModels(): Promise<void> {
  console.info(`Force refreshing models for provider: ${this.provider.name}`)
  await this.fetchModels()
  await this.autoEnableModelsIfNeeded()
  eventBus.sendToRenderer(
    CONFIG_EVENTS.MODEL_LIST_CHANGED,
    SendTarget.ALL_WINDOWS,
    this.provider.id
  )
}

/**
* 获取特定提供商的模型
* 此方法由具体的提供商子类实现
Expand Down Expand Up @@ -160,6 +211,9 @@ export abstract class BaseLLMProvider {
this.customModels.push(newModel)
}

// Sync with config
this.configPresenter.addCustomModel(this.provider.id, newModel)

return newModel
}

Expand All @@ -172,6 +226,8 @@ export abstract class BaseLLMProvider {
const index = this.customModels.findIndex((model) => model.id === modelId)
if (index !== -1) {
this.customModels.splice(index, 1)
// Sync with config
this.configPresenter.removeCustomModel(this.provider.id, modelId)
return true
}
return false
Expand All @@ -188,6 +244,8 @@ export abstract class BaseLLMProvider {
if (model) {
// 应用更新
Object.assign(model, updates)
// Sync with config
this.configPresenter.updateCustomModel(this.provider.id, modelId, updates)
return true
}
return false
Expand Down
11 changes: 11 additions & 0 deletions src/main/presenter/llmProviderPresenter/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,17 @@ export class LLMProviderPresenter implements ILlmProviderPresenter {
return provider.getKeyStatus()
}

async refreshModels(providerId: string): Promise<void> {
  try {
    // Delegate the forced refresh to the concrete provider instance.
    await this.getProviderInstance(providerId).refreshModels()
  } catch (error) {
    console.error(`Failed to refresh models for provider ${providerId}:`, error)
    // Re-throw a normalized Error so callers see a consistent message shape.
    const reason = error instanceof Error ? error.message : String(error)
    throw new Error(`Model refresh failed: ${reason}`)
  }
}

async addCustomModel(
providerId: string,
model: Omit<MODEL_META, 'providerId' | 'isCustom' | 'group'>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import {
} from '@shared/presenter'
import { ConfigPresenter } from '../../configPresenter'
import { BaseLLMProvider, SUMMARY_TITLES_PROMPT } from '../baseProvider'
import { eventBus, SendTarget } from '@/eventbus'
import { CONFIG_EVENTS } from '@/events'

// Mapping from simple keys to API HarmCategory constants
const keyToHarmCategoryMap: Record<string, HarmCategory> = {
Expand Down Expand Up @@ -295,6 +297,12 @@ export class GeminiProvider extends BaseLLMProvider {
// 使用API获取模型列表,如果失败则回退到静态列表
this.models = await this.fetchProviderModels()
await this.autoEnableModelsIfNeeded()
// gemini 比较慢,特殊补偿一下
eventBus.sendToRenderer(
CONFIG_EVENTS.MODEL_LIST_CHANGED,
SendTarget.ALL_WINDOWS,
this.provider.id
)
console.info('Provider initialized successfully:', this.provider.name)
} catch (error) {
console.warn('Provider initialization failed:', this.provider.name, error)
Expand Down
Loading