Skip to content

Commit 97ef633

Browse files
author
chujian
committed
fix(registry): address review feedback
1 parent dae8463 commit 97ef633

File tree

2 files changed

+61
-10
lines changed

2 files changed

+61
-10
lines changed

internal/registry/model_registry.go

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
173173
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
174174
return cloneModelInfo(info)
175175
}
176-
return LookupStaticModelInfo(modelID)
176+
return cloneModelInfo(LookupStaticModelInfo(modelID))
177177
}
178178

179179
// SetHook sets an optional hook for observing model registration changes.
@@ -490,7 +490,6 @@ func (r *ModelRegistry) removeModelRegistration(clientID, modelID, provider stri
490490
registration.LastUpdated = now
491491
if registration.QuotaExceededClients != nil {
492492
delete(registration.QuotaExceededClients, clientID)
493-
r.invalidateAvailableModelsCacheLocked()
494493
}
495494
if registration.SuspendedClients != nil {
496495
delete(registration.SuspendedClients, clientID)
@@ -842,13 +841,34 @@ func cloneModelMaps(models []map[string]any) []map[string]any {
842841
}
843842
copyModel := make(map[string]any, len(model))
844843
for key, value := range model {
845-
copyModel[key] = value
844+
copyModel[key] = cloneModelMapValue(value)
846845
}
847846
cloned = append(cloned, copyModel)
848847
}
849848
return cloned
850849
}
851850

851+
func cloneModelMapValue(value any) any {
852+
switch typed := value.(type) {
853+
case map[string]any:
854+
copyMap := make(map[string]any, len(typed))
855+
for key, entry := range typed {
856+
copyMap[key] = cloneModelMapValue(entry)
857+
}
858+
return copyMap
859+
case []any:
860+
copySlice := make([]any, len(typed))
861+
for i, entry := range typed {
862+
copySlice[i] = cloneModelMapValue(entry)
863+
}
864+
return copySlice
865+
case []string:
866+
return append([]string(nil), typed...)
867+
default:
868+
return value
869+
}
870+
}
871+
852872
// GetAvailableModelsByProvider returns models available for the given provider identifier.
853873
// Parameters:
854874
// - provider: Provider identifier (e.g., "codex", "gemini", "antigravity")
@@ -1298,10 +1318,3 @@ func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo {
12981318
}
12991319
return result
13001320
}
1301-
1302-
1303-
1304-
1305-
1306-
1307-

internal/registry/model_registry_safety_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,41 @@ func TestCleanupExpiredQuotasInvalidatesAvailableModelsCache(t *testing.T) {
109109
t.Fatalf("expected model id m1, got %v", got)
110110
}
111111
}
112+
113+
func TestGetAvailableModelsReturnsClonedSupportedParameters(t *testing.T) {
114+
r := newTestModelRegistry()
115+
r.RegisterClient("client-1", "openai", []*ModelInfo{{
116+
ID: "m1",
117+
DisplayName: "Model One",
118+
SupportedParameters: []string{"temperature", "top_p"},
119+
}})
120+
121+
first := r.GetAvailableModels("openai")
122+
if len(first) != 1 {
123+
t.Fatalf("expected one model, got %d", len(first))
124+
}
125+
params, ok := first[0]["supported_parameters"].([]string)
126+
if !ok || len(params) != 2 {
127+
t.Fatalf("expected supported_parameters slice, got %#v", first[0]["supported_parameters"])
128+
}
129+
params[0] = "mutated"
130+
131+
second := r.GetAvailableModels("openai")
132+
params, ok = second[0]["supported_parameters"].([]string)
133+
if !ok || len(params) != 2 || params[0] != "temperature" {
134+
t.Fatalf("expected cloned supported_parameters, got %#v", second[0]["supported_parameters"])
135+
}
136+
}
137+
138+
func TestLookupModelInfoReturnsCloneForStaticDefinitions(t *testing.T) {
139+
first := LookupModelInfo("glm-4.6")
140+
if first == nil || first.Thinking == nil || len(first.Thinking.Levels) == 0 {
141+
t.Fatalf("expected static model with thinking levels, got %+v", first)
142+
}
143+
first.Thinking.Levels[0] = "mutated"
144+
145+
second := LookupModelInfo("glm-4.6")
146+
if second == nil || second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] == "mutated" {
147+
t.Fatalf("expected static lookup clone, got %+v", second)
148+
}
149+
}

0 commit comments

Comments
 (0)