Skip to content

Commit dae8463

Browse files
author
chujian
committed
fix(registry): clone model snapshots and invalidate available-model cache
1 parent 5ebc58f commit dae8463

File tree

3 files changed

+288
-23
lines changed

3 files changed

+288
-23
lines changed

internal/registry/model_registry.go

Lines changed: 123 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ type ModelInfo struct {
6262
UserDefined bool `json:"-"`
6363
}
6464

65+
// availableModelsCacheEntry is one cached result of GetAvailableModels for a
// single handler type.
type availableModelsCacheEntry struct {
	// models is the cached list of model maps for the handler type.
	models []map[string]any
	// expiresAt is the instant after which the entry must be rebuilt.
	// The zero value means the entry has no time-based expiry.
	expiresAt time.Time
}
69+
6570
// ThinkingSupport describes a model family's supported internal reasoning budget range.
6671
// Values are interpreted in provider-native token units.
6772
type ThinkingSupport struct {
@@ -116,6 +121,8 @@ type ModelRegistry struct {
116121
clientProviders map[string]string
117122
// mutex ensures thread-safe access to the registry
118123
mutex *sync.RWMutex
124+
// availableModelsCache stores per-handler snapshots for GetAvailableModels.
125+
availableModelsCache map[string]availableModelsCacheEntry
119126
// hook is an optional callback sink for model registration changes
120127
hook ModelRegistryHook
121128
}
@@ -128,15 +135,28 @@ var registryOnce sync.Once
128135
func GetGlobalRegistry() *ModelRegistry {
129136
registryOnce.Do(func() {
130137
globalRegistry = &ModelRegistry{
131-
models: make(map[string]*ModelRegistration),
132-
clientModels: make(map[string][]string),
133-
clientModelInfos: make(map[string]map[string]*ModelInfo),
134-
clientProviders: make(map[string]string),
135-
mutex: &sync.RWMutex{},
138+
models: make(map[string]*ModelRegistration),
139+
clientModels: make(map[string][]string),
140+
clientModelInfos: make(map[string]map[string]*ModelInfo),
141+
clientProviders: make(map[string]string),
142+
availableModelsCache: make(map[string]availableModelsCacheEntry),
143+
mutex: &sync.RWMutex{},
136144
}
137145
})
138146
return globalRegistry
139147
}
148+
func (r *ModelRegistry) ensureAvailableModelsCacheLocked() {
149+
if r.availableModelsCache == nil {
150+
r.availableModelsCache = make(map[string]availableModelsCacheEntry)
151+
}
152+
}
153+
154+
func (r *ModelRegistry) invalidateAvailableModelsCacheLocked() {
155+
if len(r.availableModelsCache) == 0 {
156+
return
157+
}
158+
clear(r.availableModelsCache)
159+
}
140160

141161
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
142162
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
@@ -151,7 +171,7 @@ func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
151171
}
152172

153173
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
154-
return info
174+
return cloneModelInfo(info)
155175
}
156176
return LookupStaticModelInfo(modelID)
157177
}
@@ -211,6 +231,7 @@ func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) {
211231
func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
212232
r.mutex.Lock()
213233
defer r.mutex.Unlock()
234+
r.ensureAvailableModelsCacheLocked()
214235

215236
provider := strings.ToLower(clientProvider)
216237
uniqueModelIDs := make([]string, 0, len(models))
@@ -236,6 +257,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
236257
delete(r.clientModels, clientID)
237258
delete(r.clientModelInfos, clientID)
238259
delete(r.clientProviders, clientID)
260+
r.invalidateAvailableModelsCacheLocked()
239261
misc.LogCredentialSeparator()
240262
return
241263
}
@@ -263,6 +285,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
263285
} else {
264286
delete(r.clientProviders, clientID)
265287
}
288+
r.invalidateAvailableModelsCacheLocked()
266289
r.triggerModelsRegistered(provider, clientID, models)
267290
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
268291
misc.LogCredentialSeparator()
@@ -406,6 +429,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
406429
delete(r.clientProviders, clientID)
407430
}
408431

432+
r.invalidateAvailableModelsCacheLocked()
409433
r.triggerModelsRegistered(provider, clientID, models)
410434
if len(added) == 0 && len(removed) == 0 && !providerChanged {
411435
// Only metadata (e.g., display name) changed; skip separator when no log output.
@@ -466,6 +490,7 @@ func (r *ModelRegistry) removeModelRegistration(clientID, modelID, provider stri
466490
registration.LastUpdated = now
467491
if registration.QuotaExceededClients != nil {
468492
delete(registration.QuotaExceededClients, clientID)
493+
r.invalidateAvailableModelsCacheLocked()
469494
}
470495
if registration.SuspendedClients != nil {
471496
delete(registration.SuspendedClients, clientID)
@@ -509,6 +534,13 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
509534
if len(model.SupportedOutputModalities) > 0 {
510535
copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...)
511536
}
537+
if model.Thinking != nil {
538+
copyThinking := *model.Thinking
539+
if len(model.Thinking.Levels) > 0 {
540+
copyThinking.Levels = append([]string(nil), model.Thinking.Levels...)
541+
}
542+
copyModel.Thinking = &copyThinking
543+
}
512544
return &copyModel
513545
}
514546

@@ -538,6 +570,7 @@ func (r *ModelRegistry) UnregisterClient(clientID string) {
538570
r.mutex.Lock()
539571
defer r.mutex.Unlock()
540572
r.unregisterClientInternal(clientID)
573+
r.invalidateAvailableModelsCacheLocked()
541574
}
542575

543576
// unregisterClientInternal performs the actual client unregistration (internal, no locking)
@@ -604,9 +637,12 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
604637
// SetModelQuotaExceeded records that clientID has exhausted its quota for
// modelID, stamping the entry with the current time, and invalidates the
// available-models cache so the change is visible on the next lookup.
//
// Unknown models (and nil registrations) are ignored. The quota map is created
// on demand so that a registration without one cannot trigger a nil-map write.
func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
	r.mutex.Lock()
	defer r.mutex.Unlock()
	r.ensureAvailableModelsCacheLocked()

	registration, exists := r.models[modelID]
	if !exists || registration == nil {
		return
	}
	// Other code paths treat QuotaExceededClients as possibly nil; assigning
	// into a nil map panics, so allocate it before the write.
	if registration.QuotaExceededClients == nil {
		registration.QuotaExceededClients = make(map[string]*time.Time)
	}

	now := time.Now()
	registration.QuotaExceededClients[clientID] = &now
	r.invalidateAvailableModelsCacheLocked()
	log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
}
@@ -618,9 +654,11 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
618654
// ClearModelQuotaExceeded removes clientID's quota-exceeded mark for modelID
// and invalidates the available-models cache. Unknown model IDs are ignored.
func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
	r.mutex.Lock()
	defer r.mutex.Unlock()
	r.ensureAvailableModelsCacheLocked()

	registration, found := r.models[modelID]
	if !found {
		return
	}
	delete(registration.QuotaExceededClients, clientID)
	r.invalidateAvailableModelsCacheLocked()
	// log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
}
@@ -636,6 +674,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
636674
}
637675
r.mutex.Lock()
638676
defer r.mutex.Unlock()
677+
r.ensureAvailableModelsCacheLocked()
639678

640679
registration, exists := r.models[modelID]
641680
if !exists || registration == nil {
@@ -649,6 +688,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
649688
}
650689
registration.SuspendedClients[clientID] = reason
651690
registration.LastUpdated = time.Now()
691+
r.invalidateAvailableModelsCacheLocked()
652692
if reason != "" {
653693
log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason)
654694
} else {
@@ -666,6 +706,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
666706
}
667707
r.mutex.Lock()
668708
defer r.mutex.Unlock()
709+
r.ensureAvailableModelsCacheLocked()
669710

670711
registration, exists := r.models[modelID]
671712
if !exists || registration == nil || registration.SuspendedClients == nil {
@@ -676,6 +717,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
676717
}
677718
delete(registration.SuspendedClients, clientID)
678719
registration.LastUpdated = time.Now()
720+
r.invalidateAvailableModelsCacheLocked()
679721
log.Debugf("Resumed client %s for model %s", clientID, modelID)
680722
}
681723

@@ -711,22 +753,52 @@ func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool {
711753
// Returns:
712754
// - []map[string]any: List of available models in the requested format
713755
func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
756+
now := time.Now()
757+
714758
r.mutex.RLock()
715-
defer r.mutex.RUnlock()
759+
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
760+
models := cloneModelMaps(cache.models)
761+
r.mutex.RUnlock()
762+
return models
763+
}
764+
r.mutex.RUnlock()
765+
766+
r.mutex.Lock()
767+
defer r.mutex.Unlock()
768+
r.ensureAvailableModelsCacheLocked()
716769

717-
models := make([]map[string]any, 0)
770+
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
771+
return cloneModelMaps(cache.models)
772+
}
773+
774+
models, expiresAt := r.buildAvailableModelsLocked(handlerType, now)
775+
r.availableModelsCache[handlerType] = availableModelsCacheEntry{
776+
models: cloneModelMaps(models),
777+
expiresAt: expiresAt,
778+
}
779+
780+
return models
781+
}
782+
783+
func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) {
784+
models := make([]map[string]any, 0, len(r.models))
718785
quotaExpiredDuration := 5 * time.Minute
786+
var expiresAt time.Time
719787

720788
for _, registration := range r.models {
721-
// Check if model has any non-quota-exceeded clients
722789
availableClients := registration.Count
723-
now := time.Now()
724790

725-
// Count clients that have exceeded quota but haven't recovered yet
726791
expiredClients := 0
727792
for _, quotaTime := range registration.QuotaExceededClients {
728-
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
793+
if quotaTime == nil {
794+
continue
795+
}
796+
recoveryAt := quotaTime.Add(quotaExpiredDuration)
797+
if now.Before(recoveryAt) {
729798
expiredClients++
799+
if expiresAt.IsZero() || recoveryAt.Before(expiresAt) {
800+
expiresAt = recoveryAt
801+
}
730802
}
731803
}
732804

@@ -747,7 +819,6 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
747819
effectiveClients = 0
748820
}
749821

750-
// Include models that have available clients, or those solely cooling down.
751822
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
752823
model := r.convertModelToMap(registration.Info, handlerType)
753824
if model != nil {
@@ -756,7 +827,26 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
756827
}
757828
}
758829

759-
return models
830+
return models, expiresAt
831+
}
832+
833+
// cloneModelMaps returns a copy of models that shares no mutable containers
// with the input, so callers may modify the result without affecting the
// source (for example a cached snapshot).
//
// Nil-ness is preserved: a nil slice yields nil and an empty non-nil slice
// yields an empty non-nil slice, so the result serialises identically to the
// input (`[]` rather than `null` for an empty list). Nil map elements are kept
// as nil.
func cloneModelMaps(models []map[string]any) []map[string]any {
	if models == nil {
		return nil
	}
	cloned := make([]map[string]any, len(models))
	for i, model := range models {
		cloned[i] = cloneModelMap(model)
	}
	return cloned
}

// cloneModelMap copies a single model map, cloning nested containers.
func cloneModelMap(model map[string]any) map[string]any {
	if model == nil {
		return nil
	}
	copyModel := make(map[string]any, len(model))
	for key, value := range model {
		copyModel[key] = cloneModelValue(value)
	}
	return copyModel
}

// cloneModelValue copies the container types commonly found in model maps
// (nested maps, generic slices, string slices). Any other value is returned
// unchanged.
func cloneModelValue(value any) any {
	switch v := value.(type) {
	case map[string]any:
		return cloneModelMap(v)
	case []map[string]any:
		return cloneModelMaps(v)
	case []any:
		if v == nil {
			return v
		}
		items := make([]any, len(v))
		for i, item := range v {
			items[i] = cloneModelValue(item)
		}
		return items
	case []string:
		if v == nil {
			return v
		}
		items := make([]string, len(v))
		copy(items, v)
		return items
	default:
		return value
	}
}
761851

762852
// GetAvailableModelsByProvider returns models available for the given provider identifier.
@@ -872,11 +962,11 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
872962

873963
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
874964
if entry.info != nil {
875-
result = append(result, entry.info)
965+
result = append(result, cloneModelInfo(entry.info))
876966
continue
877967
}
878968
if ok && registration != nil && registration.Info != nil {
879-
result = append(result, registration.Info)
969+
result = append(result, cloneModelInfo(registration.Info))
880970
}
881971
}
882972
}
@@ -985,13 +1075,13 @@ func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo {
9851075
if reg.Providers != nil {
9861076
if count, ok := reg.Providers[provider]; ok && count > 0 {
9871077
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
988-
return info
1078+
return cloneModelInfo(info)
9891079
}
9901080
}
9911081
}
9921082
}
9931083
// Fallback to global info (last registered)
994-
return reg.Info
1084+
return cloneModelInfo(reg.Info)
9951085
}
9961086
return nil
9971087
}
@@ -1111,15 +1201,20 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
11111201

11121202
now := time.Now()
11131203
quotaExpiredDuration := 5 * time.Minute
1204+
invalidated := false
11141205

11151206
for modelID, registration := range r.models {
11161207
for clientID, quotaTime := range registration.QuotaExceededClients {
11171208
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
11181209
delete(registration.QuotaExceededClients, clientID)
1210+
invalidated = true
11191211
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
11201212
}
11211213
}
11221214
}
1215+
if invalidated {
1216+
r.invalidateAvailableModelsCacheLocked()
1217+
}
11231218
}
11241219

11251220
// GetFirstAvailableModel returns the first available model for the given handler type.
@@ -1133,8 +1228,6 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
11331228
// - string: The model ID of the first available model, or empty string if none available
11341229
// - error: An error if no models are available
11351230
func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, error) {
1136-
r.mutex.RLock()
1137-
defer r.mutex.RUnlock()
11381231

11391232
// Get all available models for this handler type
11401233
models := r.GetAvailableModels(handlerType)
@@ -1194,14 +1287,21 @@ func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo {
11941287
// Prefer client's own model info to preserve original type/owned_by
11951288
if clientInfos != nil {
11961289
if info, ok := clientInfos[modelID]; ok && info != nil {
1197-
result = append(result, info)
1290+
result = append(result, cloneModelInfo(info))
11981291
continue
11991292
}
12001293
}
12011294
// Fallback to global registry (for backwards compatibility)
12021295
if reg, ok := r.models[modelID]; ok && reg.Info != nil {
1203-
result = append(result, reg.Info)
1296+
result = append(result, cloneModelInfo(reg.Info))
12041297
}
12051298
}
12061299
return result
12071300
}
1301+
1302+
1303+
1304+
1305+
1306+
1307+

0 commit comments

Comments
 (0)