diff --git a/cmd/root/run_event_hooks.go b/cmd/root/run_event_hooks.go index 07df4be88..30b635efe 100644 --- a/cmd/root/run_event_hooks.go +++ b/cmd/root/run_event_hooks.go @@ -12,6 +12,7 @@ import ( tea "charm.land/bubbletea/v2" "github.com/docker/docker-agent/pkg/app" + "github.com/docker/docker-agent/pkg/concurrent" "github.com/docker/docker-agent/pkg/runtime" "github.com/docker/docker-agent/pkg/shellpath" ) @@ -80,7 +81,7 @@ func runEventHook(command string, payload []byte) { // it, and exits; the spawning goroutine ends with the subprocess. cmd := exec.CommandContext(context.Background(), shell, append(argsPrefix, command)...) cmd.Stdin = bytes.NewReader(payload) - var out boundedBuffer + var out boundedWriter cmd.Stdout = &out cmd.Stderr = &out if err := cmd.Run(); err != nil { @@ -88,25 +89,31 @@ func runEventHook(command string, payload []byte) { } } -// boundedBuffer captures up to maxHookOutput bytes from a hook subprocess -// and silently discards the rest. It implements only io.Writer so it can be -// assigned to exec.Cmd's Stdout/Stderr without forcing exec to buffer the -// full output internally. -type boundedBuffer struct { - buf bytes.Buffer +// boundedWriter captures up to maxHookOutput bytes from a hook subprocess +// and silently discards the rest. It satisfies io.Writer so it can be +// assigned to exec.Cmd's Stdout/Stderr. +// +// exec.Cmd spawns separate copy goroutines for Stdout and Stderr, so the +// underlying buffer must be safe for concurrent writes; that's what +// [concurrent.Buffer] gives us. The cap is enforced softly: between +// Len() and Write() another goroutine may slip in a chunk, so we may +// over-shoot maxHookOutput by at most one Write per concurrent stream +// (a few KB) — acceptable for diagnostic output. +type boundedWriter struct { + buf concurrent.Buffer } -func (b *boundedBuffer) Write(p []byte) (int, error) { +func (b *boundedWriter) Write(p []byte) (int, error) { if remaining := maxHookOutput - b.buf.Len(); remaining > 0 { - if len(p) > remaining { - b.buf.Write(p[:remaining]) - } else { - b.buf.Write(p) + chunk := p + if len(chunk) > remaining { + chunk = chunk[:remaining] } + _, _ = b.buf.Write(chunk) } return len(p), nil } -func (b *boundedBuffer) String() string { +func (b *boundedWriter) String() string { return b.buf.String() } diff --git a/cmd/root/run_event_hooks_test.go b/cmd/root/run_event_hooks_test.go index 559f02935..6959120cb 100644 --- a/cmd/root/run_event_hooks_test.go +++ b/cmd/root/run_event_hooks_test.go @@ -30,8 +30,8 @@ func TestParseOnEventFlags_BadFormat(t *testing.T) { } } -func TestBoundedBuffer_CapsAtMaxHookOutput(t *testing.T) { - var b boundedBuffer +func TestBoundedWriter_CapsAtMaxHookOutput(t *testing.T) { + var b boundedWriter n, err := b.Write(bytes.Repeat([]byte("a"), maxHookOutput-3)) require.NoError(t, err) diff --git a/pkg/chatserver/conversation_lock.go b/pkg/chatserver/conversation_lock.go index c5c9a1141..70a3488a3 100644 --- a/pkg/chatserver/conversation_lock.go +++ b/pkg/chatserver/conversation_lock.go @@ -1,6 +1,6 @@ package chatserver -import "sync" +import "github.com/docker/docker-agent/pkg/concurrent" // conversationLockSet ensures only one in-flight request at a time per // conversation id. Concurrent requests sharing an id would otherwise share @@ -12,12 +12,11 @@ import "sync" // for two reasons: it surfaces the misuse to the client immediately, and it // keeps the handler's resource cost bounded (no queue, no waiting goroutines). type conversationLockSet struct { - mu sync.Mutex - active map[string]struct{} + active concurrent.Map[string, struct{}] } func newConversationLockSet() *conversationLockSet { - return &conversationLockSet{active: make(map[string]struct{})} + return &conversationLockSet{} } // tryAcquire returns true when id was not already in flight. The caller @@ -27,13 +26,8 @@ func (l *conversationLockSet) tryAcquire(id string) bool { if l == nil || id == "" { return true } - l.mu.Lock() - defer l.mu.Unlock() - if _, ok := l.active[id]; ok { - return false - } - l.active[id] = struct{}{} - return true + _, loaded := l.active.LoadOrStore(id, struct{}{}) + return !loaded } // release marks id as no longer in flight. Safe to call when id is the @@ -42,7 +36,5 @@ func (l *conversationLockSet) release(id string) { if l == nil || id == "" { return } - l.mu.Lock() - delete(l.active, id) - l.mu.Unlock() + l.active.Delete(id) } diff --git a/pkg/concurrent/buffer.go b/pkg/concurrent/buffer.go index 44d8b9745..6efd1182e 100644 --- a/pkg/concurrent/buffer.go +++ b/pkg/concurrent/buffer.go @@ -2,6 +2,7 @@ package concurrent import ( "bytes" + "slices" "sync" ) @@ -27,6 +28,21 @@ func (b *Buffer) String() string { return b.buf.String() } +// Bytes returns a copy of the buffered content as a byte slice. +// The returned slice is safe to retain and modify. +func (b *Buffer) Bytes() []byte { + b.mu.Lock() + defer b.mu.Unlock() + return slices.Clone(b.buf.Bytes()) +} + +// Len returns the number of bytes currently buffered. +func (b *Buffer) Len() int { + b.mu.Lock() + defer b.mu.Unlock() + return b.buf.Len() +} + // Reset clears the buffer. func (b *Buffer) Reset() { b.mu.Lock() diff --git a/pkg/concurrent/buffer_test.go b/pkg/concurrent/buffer_test.go new file mode 100644 index 000000000..7d827abd5 --- /dev/null +++ b/pkg/concurrent/buffer_test.go @@ -0,0 +1,90 @@ +package concurrent + +import ( + "fmt" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuffer_Write(t *testing.T) { + var b Buffer + + n, err := b.Write([]byte("hello")) + require.NoError(t, err) + assert.Equal(t, 5, n) + + n, err = b.Write([]byte(" world")) + require.NoError(t, err) + assert.Equal(t, 6, n) + + assert.Equal(t, "hello world", b.String()) +} + +func TestBuffer_Bytes(t *testing.T) { + var b Buffer + _, _ = b.Write([]byte("hello")) + + got := b.Bytes() + assert.Equal(t, []byte("hello"), got) + + // Mutating the returned slice must not affect the buffer. + got[0] = 'H' + assert.Equal(t, "hello", b.String()) +} + +func TestBuffer_Len(t *testing.T) { + var b Buffer + assert.Equal(t, 0, b.Len()) + + _, _ = b.Write([]byte("abc")) + assert.Equal(t, 3, b.Len()) + + _, _ = b.Write([]byte("de")) + assert.Equal(t, 5, b.Len()) +} + +func TestBuffer_Reset(t *testing.T) { + var b Buffer + _, _ = b.Write([]byte("hello")) + + b.Reset() + assert.Equal(t, 0, b.Len()) + assert.Empty(t, b.String()) +} + +func TestBuffer_Drain(t *testing.T) { + var b Buffer + _, _ = b.Write([]byte("hello")) + + got := b.Drain() + assert.Equal(t, "hello", got) + assert.Equal(t, 0, b.Len()) + assert.Empty(t, b.String()) +} + +func TestBuffer_Concurrent(t *testing.T) { + var b Buffer + var wg sync.WaitGroup + + const writers = 100 + for i := range writers { + wg.Go(func() { + _, _ = b.Write(fmt.Appendf(nil, "%03d", i)) + }) + } + + // Concurrent readers should not race with writers. + for range 50 { + wg.Go(func() { + _ = b.String() + _ = b.Len() + _ = b.Bytes() + }) + } + + wg.Wait() + assert.Equal(t, writers*3, b.Len()) +} diff --git a/pkg/concurrent/map.go b/pkg/concurrent/map.go index 9c69f27db..030d71df9 100644 --- a/pkg/concurrent/map.go +++ b/pkg/concurrent/map.go @@ -1,6 +1,9 @@ package concurrent -import "sync" +import ( + "maps" + "sync" +) type Map[K comparable, V any] struct { mu sync.RWMutex @@ -25,6 +28,9 @@ func (m *Map[K, V]) Store(key K, value V) { m.mu.Lock() defer m.mu.Unlock() + if m.values == nil { + m.values = make(map[K]V) + } m.values[key] = value } @@ -42,11 +48,54 @@ func (m *Map[K, V]) Length() int { return len(m.values) } +// LoadOrStore returns the existing value for key if present; otherwise it +// stores and returns value. The loaded result is true if the value was +// loaded, false if stored. +func (m *Map[K, V]) LoadOrStore(key K, value V) (V, bool) { + m.mu.RLock() + if existing, ok := m.values[key]; ok { + m.mu.RUnlock() + return existing, true + } + m.mu.RUnlock() + + m.mu.Lock() + defer m.mu.Unlock() + + // Re-check under the write lock: another goroutine may have stored + // the key between releasing the read lock and acquiring the write lock. + if existing, ok := m.values[key]; ok { + return existing, true + } + if m.values == nil { + m.values = make(map[K]V) + } + m.values[key] = value + return value, false +} + +// Clear removes all entries from the map. +func (m *Map[K, V]) Clear() { + m.mu.Lock() + defer m.mu.Unlock() + + m.values = make(map[K]V) +} + +// Range calls f for every key/value pair in the map. Iteration stops early if +// f returns false. +// +// Range iterates over a snapshot of the map taken under a read lock; f is +// invoked without holding any lock, which means callbacks may safely call +// other methods on the same Map (including Store and Delete) without +// deadlocking. As a consequence, mutations performed during iteration are not +// reflected in the values seen by f. func (m *Map[K, V]) Range(f func(key K, value V) bool) { m.mu.RLock() - defer m.mu.RUnlock() + snapshot := maps.Clone(m.values) + m.mu.RUnlock() - for k, v := range m.values { + for k, v := range snapshot { if !f(k, v) { break } diff --git a/pkg/concurrent/map_test.go b/pkg/concurrent/map_test.go new file mode 100644 index 000000000..fe972bb44 --- /dev/null +++ b/pkg/concurrent/map_test.go @@ -0,0 +1,212 @@ +package concurrent + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMap_StoreLoad(t *testing.T) { + m := NewMap[string, int]() + + m.Store("a", 1) + m.Store("b", 2) + + val, ok := m.Load("a") + assert.True(t, ok) + assert.Equal(t, 1, val) + + val, ok = m.Load("b") + assert.True(t, ok) + assert.Equal(t, 2, val) + + _, ok = m.Load("missing") + assert.False(t, ok) +} + +func TestMap_StoreOverwrites(t *testing.T) { + m := NewMap[string, int]() + m.Store("k", 1) + m.Store("k", 2) + + val, ok := m.Load("k") + assert.True(t, ok) + assert.Equal(t, 2, val) + assert.Equal(t, 1, m.Length()) +} + +func TestMap_Delete(t *testing.T) { + m := NewMap[string, int]() + m.Store("a", 1) + m.Store("b", 2) + + m.Delete("a") + _, ok := m.Load("a") + assert.False(t, ok) + assert.Equal(t, 1, m.Length()) + + // Deleting a missing key is a no-op. + m.Delete("missing") + assert.Equal(t, 1, m.Length()) +} + +func TestMap_Length(t *testing.T) { + m := NewMap[string, int]() + assert.Equal(t, 0, m.Length()) + + m.Store("a", 1) + m.Store("b", 2) + m.Store("c", 3) + assert.Equal(t, 3, m.Length()) +} + +func TestMap_Range(t *testing.T) { + m := NewMap[string, int]() + m.Store("a", 1) + m.Store("b", 2) + m.Store("c", 3) + + collected := map[string]int{} + m.Range(func(k string, v int) bool { + collected[k] = v + return true + }) + assert.Equal(t, map[string]int{"a": 1, "b": 2, "c": 3}, collected) + + // Early termination: stop after the first element. + count := 0 + m.Range(func(_ string, _ int) bool { + count++ + return false + }) + assert.Equal(t, 1, count) +} + +func TestMap_RangeCallbackCanMutate(t *testing.T) { + // Range iterates over a snapshot, so callbacks may safely mutate the map + // without deadlocking. + m := NewMap[string, int]() + m.Store("a", 1) + m.Store("b", 2) + + m.Range(func(k string, _ int) bool { + m.Store(k+"!", 0) + return true + }) + + assert.Equal(t, 4, m.Length()) +} + +func TestMap_ZeroValueStore(t *testing.T) { + // The zero value of Map must be usable: Store should lazily initialise + // the underlying map instead of panicking. + var m Map[string, int] + m.Store("a", 1) + + val, ok := m.Load("a") + assert.True(t, ok) + assert.Equal(t, 1, val) +} + +func TestMap_LoadOrStore(t *testing.T) { + m := NewMap[string, int]() + + val, loaded := m.LoadOrStore("a", 1) + assert.False(t, loaded) + assert.Equal(t, 1, val) + + val, loaded = m.LoadOrStore("a", 2) + assert.True(t, loaded) + assert.Equal(t, 1, val) + + // The original value is preserved even after a same-key LoadOrStore. + val, ok := m.Load("a") + assert.True(t, ok) + assert.Equal(t, 1, val) +} + +func TestMap_LoadOrStoreZeroValue(t *testing.T) { + // The zero value of Map must be usable for LoadOrStore as well. + var m Map[string, int] + val, loaded := m.LoadOrStore("a", 42) + assert.False(t, loaded) + assert.Equal(t, 42, val) +} + +func TestMap_LoadOrStoreConcurrent(t *testing.T) { + // Concurrent LoadOrStore calls for the same key must all return the + // same value, with exactly one of them reporting loaded == false. + m := NewMap[int, int]() + var wg sync.WaitGroup + const writers = 100 + + values := make([]int, writers) + loadedFlags := make([]bool, writers) + for i := range writers { + wg.Go(func() { + val, loaded := m.LoadOrStore(0, i) + values[i] = val + loadedFlags[i] = loaded + }) + } + wg.Wait() + + first := values[0] + newCount := 0 + for i := range writers { + require.Equal(t, first, values[i]) + if !loadedFlags[i] { + newCount++ + } + } + require.Equal(t, 1, newCount, "exactly one caller should report loaded == false") +} + +func TestMap_Concurrent(t *testing.T) { + m := NewMap[int, int]() + var wg sync.WaitGroup + + for i := range 100 { + wg.Add(1) + go func(n int) { + defer wg.Done() + m.Store(n, n*2) + }(i) + } + + wg.Wait() + require.Equal(t, 100, m.Length()) + + for i := range 100 { + val, ok := m.Load(i) + require.True(t, ok) + require.Equal(t, i*2, val) + } +} + +func TestMap_Clear(t *testing.T) { + m := NewMap[string, int]() + m.Store("a", 1) + m.Store("b", 2) + m.Store("c", 3) + + m.Clear() + assert.Equal(t, 0, m.Length()) + + _, ok := m.Load("a") + assert.False(t, ok) + + // Map is usable after Clear. + m.Store("d", 4) + val, ok := m.Load("d") + assert.True(t, ok) + assert.Equal(t, 4, val) +} + +func TestMap_ClearZeroValue(t *testing.T) { + var m Map[string, int] + m.Clear() // should not panic on zero-value map + assert.Equal(t, 0, m.Length()) +} diff --git a/pkg/concurrent/slice.go b/pkg/concurrent/slice.go index 539c523ff..6e5631af0 100644 --- a/pkg/concurrent/slice.go +++ b/pkg/concurrent/slice.go @@ -11,7 +11,9 @@ type Slice[V any] struct { } func NewSlice[V any]() *Slice[V] { - return &Slice[V]{} + return &Slice[V]{ + values: []V{}, + } } func (s *Slice[V]) Append(value V) { @@ -57,6 +59,12 @@ func (s *Slice[V]) All() []V { return slices.Clone(s.values) } +// Range calls f for every element in the slice. Iteration stops early if f +// returns false. +// +// f is invoked while a read lock is held on the slice. Callbacks must not +// call methods that acquire the write lock (Append, Set, Update, Clear) on +// the same Slice, or a deadlock will occur. func (s *Slice[V]) Range(f func(index int, value V) bool) { s.mu.RLock() defer s.mu.RUnlock() @@ -68,6 +76,12 @@ func (s *Slice[V]) Range(f func(index int, value V) bool) { } } +// Find returns the first element for which predicate returns true, along with +// its index, or the zero value and -1 if no element matches. +// +// predicate is invoked while a read lock is held on the slice. It must not +// call methods that acquire the write lock (Append, Set, Update, Clear) on +// the same Slice, or a deadlock will occur. func (s *Slice[V]) Find(predicate func(V) bool) (V, int) { s.mu.RLock() defer s.mu.RUnlock() @@ -81,6 +95,12 @@ func (s *Slice[V]) Find(predicate func(V) bool) (V, int) { return zero, -1 } +// Update replaces the element at index with the result of f applied to the +// current value, returning true on success. If index is out of range, Update +// returns false and f is not called. +// +// f is invoked while the write lock is held on the slice. It must not call +// any other method on the same Slice, or a deadlock will occur. func (s *Slice[V]) Update(index int, f func(V) V) bool { s.mu.Lock() defer s.mu.Unlock() diff --git a/pkg/concurrent/slice_test.go b/pkg/concurrent/slice_test.go index 608f2dd66..bbdafffaf 100644 --- a/pkg/concurrent/slice_test.go +++ b/pkg/concurrent/slice_test.go @@ -71,6 +71,13 @@ func TestSlice_All(t *testing.T) { assert.Equal(t, 1, val) } +func TestSlice_AllEmpty(t *testing.T) { + s := NewSlice[int]() + all := s.All() + assert.NotNil(t, all) + assert.Empty(t, all) +} + func TestSlice_Range(t *testing.T) { s := NewSlice[int]() s.Append(10) diff --git a/pkg/evaluation/progress.go b/pkg/evaluation/progress.go index 8b124eca9..780f8dae6 100644 --- a/pkg/evaluation/progress.go +++ b/pkg/evaluation/progress.go @@ -9,6 +9,8 @@ import ( "time" "golang.org/x/term" + + "github.com/docker/docker-agent/pkg/concurrent" ) // progressBar provides a live-updating progress display for evaluation runs. @@ -20,10 +22,10 @@ type progressBar struct { completed atomic.Int32 passed atomic.Int32 failed atomic.Int32 - relevanceFailed atomic.Int32 // count of evals with relevance failures - sizeFailed atomic.Int32 // count of evals with size failures - toolCallsFailed atomic.Int32 // count of evals with tool call failures - running sync.Map // map[string]bool for currently running evals + relevanceFailed atomic.Int32 // count of evals with relevance failures + sizeFailed atomic.Int32 // count of evals with size failures + toolCallsFailed atomic.Int32 // count of evals with tool call failures + running concurrent.Map[string, struct{}] // titles of currently running evals done chan struct{} stopped chan struct{} // signals that the goroutine has finished ticker *time.Ticker @@ -67,7 +69,7 @@ func (p *progressBar) stop() { } func (p *progressBar) setRunning(title string) { - p.running.Store(title, true) + p.running.Store(title, struct{}{}) } func (p *progressBar) complete(title string, success bool) { @@ -180,10 +182,10 @@ func (p *progressBar) render(final bool) { // Count running evals runningCount := 0 var firstName string - p.running.Range(func(key, _ any) bool { + p.running.Range(func(key string, _ struct{}) bool { runningCount++ if firstName == "" { - firstName = key.(string) + firstName = key } return true }) diff --git a/pkg/hooks/handler.go b/pkg/hooks/handler.go index 2d5a2974a..eaf21fdc1 100644 --- a/pkg/hooks/handler.go +++ b/pkg/hooks/handler.go @@ -11,8 +11,8 @@ import ( "path/filepath" "slices" "strings" - "sync" + "github.com/docker/docker-agent/pkg/concurrent" "github.com/docker/docker-agent/pkg/shellpath" ) @@ -60,18 +60,14 @@ type BuiltinFunc func(ctx context.Context, in *Input, args []string) (*Output, e // Registry maps [HookType] to [HandlerFactory], plus a name → [BuiltinFunc] // table for [HookTypeBuiltin]. Safe for concurrent use. type Registry struct { - mu sync.RWMutex - factories map[HookType]HandlerFactory - builtins map[string]BuiltinFunc + factories concurrent.Map[HookType, HandlerFactory] + builtins concurrent.Map[string, BuiltinFunc] } // NewRegistry returns a registry pre-populated with [HookTypeCommand] // (shell command hooks) and [HookTypeBuiltin] (in-process functions). func NewRegistry() *Registry { - r := &Registry{ - factories: map[HookType]HandlerFactory{}, - builtins: map[string]BuiltinFunc{}, - } + r := &Registry{} r.Register(HookTypeCommand, newCommandFactory()) r.Register(HookTypeBuiltin, r.builtinFactory) return r @@ -79,17 +75,12 @@ func NewRegistry() *Registry { // Register associates a factory with a hook type, replacing any prior one. func (r *Registry) Register(t HookType, f HandlerFactory) { - r.mu.Lock() - defer r.mu.Unlock() - r.factories[t] = f + r.factories.Store(t, f) } // Lookup returns the factory registered for t, or (nil, false). func (r *Registry) Lookup(t HookType) (HandlerFactory, bool) { - r.mu.RLock() - defer r.mu.RUnlock() - f, ok := r.factories[t] - return f, ok + return r.factories.Load(t) } // RegisterBuiltin makes fn callable as `{type: builtin, command: name}`. @@ -101,18 +92,13 @@ func (r *Registry) RegisterBuiltin(name string, fn BuiltinFunc) error { if fn == nil { return errors.New("builtin hook function must not be nil") } - r.mu.Lock() - defer r.mu.Unlock() - r.builtins[name] = fn + r.builtins.Store(name, fn) return nil } // LookupBuiltin returns the function registered as name, or (nil, false). func (r *Registry) LookupBuiltin(name string) (BuiltinFunc, bool) { - r.mu.RLock() - defer r.mu.RUnlock() - fn, ok := r.builtins[name] - return fn, ok + return r.builtins.Load(name) } // DefaultRegistry is the process-wide registry used by [NewExecutor]. diff --git a/pkg/hooks/model_handler.go b/pkg/hooks/model_handler.go index f35057e86..c7388cfc4 100644 --- a/pkg/hooks/model_handler.go +++ b/pkg/hooks/model_handler.go @@ -7,9 +7,9 @@ import ( "errors" "fmt" "strings" - "sync" "text/template" + "github.com/docker/docker-agent/pkg/concurrent" "github.com/docker/docker-agent/pkg/config/latest" ) @@ -48,13 +48,9 @@ type ResponseShape func(raw string, in *Input) (*Output, error) // (or any other registry sharing the same package state) sees it. The // process-wide default is harmless because shapes are pure functions // of (raw, input). -var modelRegistry = struct { - mu sync.RWMutex - shapes map[string]ResponseShape - schemas map[string]*latest.StructuredOutput -}{ - shapes: map[string]ResponseShape{}, - schemas: map[string]*latest.StructuredOutput{}, +var modelRegistry struct { + shapes concurrent.Map[string, ResponseShape] + schemas concurrent.Map[string, *latest.StructuredOutput] } // RegisterResponseShape registers a [ResponseShape] under name. The @@ -68,9 +64,7 @@ func RegisterResponseShape(name string, shape ResponseShape) error { if shape == nil { return errors.New("response shape must not be nil") } - modelRegistry.mu.Lock() - defer modelRegistry.mu.Unlock() - modelRegistry.shapes[name] = shape + modelRegistry.shapes.Store(name, shape) return nil } @@ -81,9 +75,7 @@ func RegisterResponseSchema(name string, schema *latest.StructuredOutput) error if name == "" { return errors.New("response schema name must not be empty") } - modelRegistry.mu.Lock() - defer modelRegistry.mu.Unlock() - modelRegistry.schemas[name] = schema + modelRegistry.schemas.Store(name, schema) return nil } @@ -93,10 +85,7 @@ func lookupShape(name string) (ResponseShape, bool) { if name == "" { return defaultShape, true } - modelRegistry.mu.RLock() - defer modelRegistry.mu.RUnlock() - s, ok := modelRegistry.shapes[name] - return s, ok + return modelRegistry.shapes.Load(name) } // lookupSchema returns the structured-output schema for name, or nil @@ -105,9 +94,8 @@ func lookupSchema(name string) *latest.StructuredOutput { if name == "" { return nil } - modelRegistry.mu.RLock() - defer modelRegistry.mu.RUnlock() - return modelRegistry.schemas[name] + s, _ := modelRegistry.schemas.Load(name) + return s } // defaultShape passes the model's reply through as additional_context. diff --git a/pkg/snapshot/snapshot.go b/pkg/snapshot/snapshot.go index bfa5c64de..ad452c37b 100644 --- a/pkg/snapshot/snapshot.go +++ b/pkg/snapshot/snapshot.go @@ -17,6 +17,7 @@ import ( "strings" "sync" + "github.com/docker/docker-agent/pkg/concurrent" "github.com/docker/docker-agent/pkg/paths" ) @@ -52,8 +53,7 @@ type revertOp struct { // Manager opens per-worktree shadow repositories under a data directory. type Manager struct { dataDir string - mu sync.Mutex - locks map[string]*sync.Mutex + locks *concurrent.Map[string, *sync.Mutex] } // NewManager creates a snapshot manager rooted at dataDir. @@ -61,7 +61,7 @@ func NewManager(dataDir string) *Manager { if dataDir == "" { dataDir = paths.GetDataDir() } - return &Manager{dataDir: dataDir, locks: map[string]*sync.Mutex{}} + return &Manager{dataDir: dataDir, locks: concurrent.NewMap[string, *sync.Mutex]()} } // Open returns the shadow repository for the git worktree containing dir. @@ -96,14 +96,8 @@ func (m *Manager) Cleanup(ctx context.Context, dir string) error { } func (m *Manager) lock(key string) *sync.Mutex { - m.mu.Lock() - defer m.mu.Unlock() - if l := m.locks[key]; l != nil { - return l - } - l := &sync.Mutex{} - m.locks[key] = l - return l + lock, _ := m.locks.LoadOrStore(key, &sync.Mutex{}) + return lock } // Repo is a shadow git repository paired with a source worktree. diff --git a/pkg/tui/components/markdown/fast_renderer.go b/pkg/tui/components/markdown/fast_renderer.go index b0b836307..b53128582 100644 --- a/pkg/tui/components/markdown/fast_renderer.go +++ b/pkg/tui/components/markdown/fast_renderer.go @@ -17,6 +17,7 @@ import ( xansi "github.com/charmbracelet/x/ansi" runewidth "github.com/mattn/go-runewidth" + "github.com/docker/docker-agent/pkg/concurrent" "github.com/docker/docker-agent/pkg/tui/styles" ) @@ -160,9 +161,7 @@ func ResetStyles() { globalStylesMu.Unlock() // Also clear chroma syntax highlighting caches - chromaStyleCacheMu.Lock() - chromaStyleCache = make(map[chroma.TokenType]ansiStyle) - chromaStyleCacheMu.Unlock() + chromaStyleCache.Clear() syntaxHighlightCacheMu.Lock() syntaxHighlightCache.clear() @@ -2369,12 +2368,10 @@ type syntaxCacheKey struct { } var ( - lexerCache = make(map[string]chroma.Lexer) - lexerCacheMu sync.RWMutex + lexerCache concurrent.Map[string, chroma.Lexer] // Cache for chroma token type to ansiStyle conversion (with code bg) - chromaStyleCache = make(map[chroma.TokenType]ansiStyle) - chromaStyleCacheMu sync.RWMutex + chromaStyleCache concurrent.Map[chroma.TokenType, ansiStyle] // Cache for syntax highlighting results to avoid re-tokenizing unchanged code blocks. // Uses an LRU cache bounded to 128 entries to prevent unbounded memory growth @@ -2440,14 +2437,11 @@ func (p *parser) getLexer(lang string) chroma.Lexer { return nil } - lexerCacheMu.RLock() - lexer := lexerCache[lang] - lexerCacheMu.RUnlock() - if lexer != nil { + if lexer, ok := lexerCache.Load(lang); ok { return lexer } - lexer = lexers.Get(lang) + lexer := lexers.Get(lang) if lexer == nil { lexer = lexers.Match("file." + lang) } @@ -2456,27 +2450,20 @@ func (p *parser) getLexer(lang string) chroma.Lexer { } lexer = chroma.Coalesce(lexer) - lexerCacheMu.Lock() - lexerCache[lang] = lexer - lexerCacheMu.Unlock() + lexerCache.Store(lang, lexer) return lexer } func (p *parser) getCodeStyle(tokenType chroma.TokenType) ansiStyle { - chromaStyleCacheMu.RLock() - style, ok := chromaStyleCache[tokenType] - chromaStyleCacheMu.RUnlock() - if ok { + if style, ok := chromaStyleCache.Load(tokenType); ok { return style } // Build lipgloss style with code background inherited lipStyle := chromaToLipgloss(tokenType, p.styles.chromaStyle).Inherit(p.styles.styleCodeBg) - style = buildAnsiStyle(lipStyle) + style := buildAnsiStyle(lipStyle) - chromaStyleCacheMu.Lock() - chromaStyleCache[tokenType] = style - chromaStyleCacheMu.Unlock() + chromaStyleCache.Store(tokenType, style) return style } diff --git a/pkg/tui/components/tool/editfile/render.go b/pkg/tui/components/tool/editfile/render.go index e7d80fe28..f78680b65 100644 --- a/pkg/tui/components/tool/editfile/render.go +++ b/pkg/tui/components/tool/editfile/render.go @@ -14,6 +14,7 @@ import ( "github.com/aymanbagabas/go-udiff" "github.com/mattn/go-runewidth" + "github.com/docker/docker-agent/pkg/concurrent" "github.com/docker/docker-agent/pkg/tools" "github.com/docker/docker-agent/pkg/tools/builtin/filesystem" "github.com/docker/docker-agent/pkg/tui/styles" @@ -44,8 +45,7 @@ var ( cache = make(map[string]*toolRenderCache) // keyed by toolCallID cacheMu sync.RWMutex - lexerCache = make(map[string]chroma.Lexer) - lexerCacheMu sync.RWMutex + lexerCache concurrent.Map[string, chroma.Lexer] ) // InvalidateCaches clears all render caches. @@ -257,22 +257,14 @@ func normalizeDiff(diff []*udiff.Hunk) []*udiff.Hunk { func syntaxHighlight(code, filePath string) []chromaToken { ext := filepath.Ext(filePath) - // Try to get lexer from cache - lexerCacheMu.RLock() - lexer, ok := lexerCache[ext] - lexerCacheMu.RUnlock() - + lexer, ok := lexerCache.Load(ext) if !ok { - // Cache miss - compute and store lexer = lexers.Match(filePath) if lexer == nil { lexer = lexers.Fallback } lexer = chroma.Coalesce(lexer) - - lexerCacheMu.Lock() - lexerCache[ext] = lexer - lexerCacheMu.Unlock() + lexerCache.Store(ext, lexer) } style := styles.ChromaStyle() diff --git a/pkg/tui/styles/composite.go b/pkg/tui/styles/composite.go index 650dcf85c..74879d27a 100644 --- a/pkg/tui/styles/composite.go +++ b/pkg/tui/styles/composite.go @@ -2,9 +2,10 @@ package styles import ( "strings" - "sync" "charm.land/lipgloss/v2" + + "github.com/docker/docker-agent/pkg/concurrent" ) // ANSI reset sequences we need to handle @@ -15,17 +16,12 @@ const ( // styleSeqCache caches the style sequence for common styles. // The cache maps a style's string representation to its escape sequence. -var ( - styleSeqCache = make(map[string]string) - styleSeqCacheMu sync.RWMutex -) +var styleSeqCache concurrent.Map[string, string] // clearStyleSeqCache clears the style sequence cache. // Called when the theme changes to ensure styles are re-computed with new colors. func clearStyleSeqCache() { - styleSeqCacheMu.Lock() - styleSeqCache = make(map[string]string) - styleSeqCacheMu.Unlock() + styleSeqCache.Clear() } // getStyleSeq returns the ANSI escape sequence for a style's colors only. @@ -35,12 +31,9 @@ func getStyleSeq(style lipgloss.Style) string { // This is a simple way to identify the style cacheKey := style.Render("") - styleSeqCacheMu.RLock() - if seq, ok := styleSeqCache[cacheKey]; ok { - styleSeqCacheMu.RUnlock() + if seq, ok := styleSeqCache.Load(cacheKey); ok { return seq } - styleSeqCacheMu.RUnlock() // Compute the style sequence cleanStyle := style. @@ -66,9 +59,7 @@ func getStyleSeq(style lipgloss.Style) string { styleSeq = strings.TrimSuffix(styleSeq, resetFull) styleSeq = strings.TrimSuffix(styleSeq, resetShort) - styleSeqCacheMu.Lock() - styleSeqCache[cacheKey] = styleSeq - styleSeqCacheMu.Unlock() + styleSeqCache.Store(cacheKey, styleSeq) return styleSeq }