Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions pkg/runtime/compaction_context_limit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package runtime

import (
"context"
"errors"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/docker-agent/pkg/agent"
"github.com/docker/docker-agent/pkg/chat"
"github.com/docker/docker-agent/pkg/config/latest"
"github.com/docker/docker-agent/pkg/model/provider/base"
"github.com/docker/docker-agent/pkg/modelsdev"
"github.com/docker/docker-agent/pkg/team"
"github.com/docker/docker-agent/pkg/tools"
)

// providerOptsProvider is a minimal provider stub used to test that
// [providerContextLimit] reads the user-supplied context_size from
// the resolved [latest.ModelConfig.ProviderOpts] map.
type providerOptsProvider struct {
// id is the raw model id string; ID parses it on demand.
id string
// opts is handed back as ModelConfig.ProviderOpts by BaseConfig.
opts map[string]any
}

// ID parses the stub's id string into a [modelsdev.ID] using
// modelsdev.ParseIDOrZero.
func (p *providerOptsProvider) ID() modelsdev.ID {
	return modelsdev.ParseIDOrZero(p.id)
}

// CreateChatCompletionStream ignores all of its arguments and always
// hands back a fresh mockStream with a nil error.
func (p *providerOptsProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) {
	stream := &mockStream{}
	return stream, nil
}

// BaseConfig returns a [base.Config] whose ModelConfig carries only the
// stub's provider options; every other field keeps its zero value.
func (p *providerOptsProvider) BaseConfig() base.Config {
	modelCfg := latest.ModelConfig{ProviderOpts: p.opts}
	return base.Config{ModelConfig: modelCfg}
}

// MaxTokens always reports 0 for this stub.
func (p *providerOptsProvider) MaxTokens() int {
	return 0
}

// TestProviderContextLimit checks how [providerContextLimit] interprets
// the context_size entry of a provider's provider_opts map. That value
// is what lets compaction trigger for local models that aren't
// catalogued in models.dev.
//
// The table covers every scalar shape the helper accepts (int, int32,
// int64, float32, float64 and decimal strings, with or without
// surrounding whitespace) and the inputs it must reject by returning 0:
// absent or nil values, non-positive numbers, strings that are not
// base-10 integers, and unsupported types.
func TestProviderContextLimit(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		opts map[string]any
		want int64
	}{
		// Key absent.
		{name: "nil opts", opts: nil, want: 0},
		{name: "empty opts", opts: map[string]any{}, want: 0},
		{name: "missing key", opts: map[string]any{"other": 123}, want: 0},

		// Accepted numeric shapes.
		{name: "int", opts: map[string]any{"context_size": 32768}, want: 32768},
		{name: "int32", opts: map[string]any{"context_size": int32(2048)}, want: 2048},
		{name: "int64", opts: map[string]any{"context_size": int64(65536)}, want: 65536},
		{name: "float32", opts: map[string]any{"context_size": float32(1024)}, want: 1024},
		{name: "float64 (json)", opts: map[string]any{"context_size": float64(8192)}, want: 8192},

		// Accepted string shapes.
		{name: "string decimal", opts: map[string]any{"context_size": "16384"}, want: 16384},
		{name: "string with whitespace", opts: map[string]any{"context_size": " 4096 "}, want: 4096},

		// Rejected strings.
		{name: "non-numeric string", opts: map[string]any{"context_size": "lots"}, want: 0},
		{name: "empty string", opts: map[string]any{"context_size": ""}, want: 0},
		{name: "negative string", opts: map[string]any{"context_size": "-4096"}, want: 0},
		{name: "fractional string", opts: map[string]any{"context_size": "8192.5"}, want: 0},

		// Non-positive numbers.
		{name: "negative", opts: map[string]any{"context_size": -1}, want: 0},
		{name: "negative float", opts: map[string]any{"context_size": float64(-8192)}, want: 0},
		{name: "zero", opts: map[string]any{"context_size": 0}, want: 0},

		// Unsupported value types.
		{name: "bool", opts: map[string]any{"context_size": true}, want: 0},
		{name: "nil value", opts: map[string]any{"context_size": nil}, want: 0},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			p := &providerOptsProvider{id: "dmr/test-model", opts: tt.opts}
			assert.Equal(t, tt.want, providerContextLimit(p))
		})
	}
}

// TestProviderContextLimit_NilProvider verifies the helper handles a
// nil provider safely (returns 0). Belt-and-braces for callers that
// can't statically prove non-nil.
func TestProviderContextLimit_NilProvider(t *testing.T) {
t.Parallel()
assert.Equal(t, int64(0), providerContextLimit(nil))
}

// errorModelStore is a ModelStore whose GetModel always fails with the
// configured err. Tests use it to simulate a models.dev catalogue that
// has no entry for the configured model (the case reported for DMR +
// HuggingFace GGUF models).
type errorModelStore struct {
// Embedded so the remaining ModelStore methods are satisfied; the
// tests in this file leave it nil.
ModelStore

// err is returned verbatim by GetModel.
err error
}

// GetModel ignores its arguments and always returns a nil model
// together with the store's configured error.
func (s errorModelStore) GetModel(context.Context, modelsdev.ID) (*modelsdev.Model, error) {
	var none *modelsdev.Model
	return none, s.err
}

// TestCompactionContextLimit_FallsBackToProviderOpts verifies that the
// runtime resolves a usable context limit from provider_opts.context_size
// when the models.dev catalogue lookup fails.
//
// This is the core of the fix for the reported bug: DMR users with a
// model not catalogued in models.dev (e.g. a HuggingFace GGUF) could
// supply context_size via provider_opts but compaction silently became
// a no-op, eventually surfacing as "Failed to get model definition"
// when overflow recovery was attempted.
func TestCompactionContextLimit_FallsBackToProviderOpts(t *testing.T) {
t.Parallel()

prov := &providerOptsProvider{
id: "dmr/hf.co/unsloth/qwen3-4b-gguf:Q4_K_M",
opts: map[string]any{"context_size": 32768},
}
root := agent.New("root", "test", agent.WithModel(prov))
tm := team.New(team.WithAgents(root))

rt, err := NewLocalRuntime(tm, WithModelStore(errorModelStore{err: errors.New("not in catalogue")}))
require.NoError(t, err)

got := rt.compactionContextLimit(t.Context(), root)
assert.Equal(t, int64(32768), got,
"context limit must fall back to provider_opts.context_size when models.dev has no entry")
}

// TestCompactionContextLimit_PrefersProviderOpts checks that an explicit
// provider_opts.context_size wins over the models.dev catalogue even
// when the catalogue has its own (here: much larger) entry. A user who
// sets a smaller-than-catalogue value, for cost or memory tuning, wants
// compaction to plan against that value rather than the catalogue's.
func TestCompactionContextLimit_PrefersProviderOpts(t *testing.T) {
	t.Parallel()

	const want int64 = 8192

	stub := &providerOptsProvider{
		id:   "openai/gpt-5",
		opts: map[string]any{"context_size": 8192},
	}
	rootAgent := agent.New("root", "test", agent.WithModel(stub))
	catalogue := mockModelStoreWithLimit{limit: 200_000}

	rt, err := NewLocalRuntime(team.New(team.WithAgents(rootAgent)), WithModelStore(catalogue))
	require.NoError(t, err)

	assert.Equal(t, want, rt.compactionContextLimit(t.Context(), rootAgent),
		"explicit provider_opts.context_size must take precedence over the catalogue")
}

// TestCompactionContextLimit_FallsBackToCatalogue checks that, with no
// context_size in provider_opts, the runtime reports the limit supplied
// by the model store (the models.dev catalogue). This is the path most
// hosted-model users hit.
func TestCompactionContextLimit_FallsBackToCatalogue(t *testing.T) {
	t.Parallel()

	const want int64 = 200_000

	// No opts: the provider contributes nothing to the resolution.
	stub := &providerOptsProvider{id: "openai/gpt-5"}
	rootAgent := agent.New("root", "test", agent.WithModel(stub))
	catalogue := mockModelStoreWithLimit{limit: 200_000}

	rt, err := NewLocalRuntime(team.New(team.WithAgents(rootAgent)), WithModelStore(catalogue))
	require.NoError(t, err)

	assert.Equal(t, want, rt.compactionContextLimit(t.Context(), rootAgent))
}

// TestCompactionContextLimit_NoSourcesYieldsZero pins the legacy
// behaviour: when neither provider_opts nor the model store yields a
// limit, the result is 0. Callers treat 0 as "can't compact"; the LLM
// strategy enforces ContextLimit > 0.
func TestCompactionContextLimit_NoSourcesYieldsZero(t *testing.T) {
	t.Parallel()

	const want int64 = 0

	// No opts on the provider, and a plain mock store.
	stub := &providerOptsProvider{id: "unknown/model"}
	rootAgent := agent.New("root", "test", agent.WithModel(stub))

	rt, err := NewLocalRuntime(team.New(team.WithAgents(rootAgent)), WithModelStore(mockModelStore{}))
	require.NoError(t, err)

	assert.Equal(t, want, rt.compactionContextLimit(t.Context(), rootAgent))
}
32 changes: 17 additions & 15 deletions pkg/runtime/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,20 +362,23 @@ func (r *LocalRuntime) runStreamLoop(ctx context.Context, sess *session.Session,
if err != nil {
slog.DebugContext(ctx, "Failed to get model definition", "error", err)
}
// We can only compact if we know the limit.
var contextLimit int64
if m != nil {
contextLimit = int64(m.Limit.Context)

if r.sessionCompaction && compaction.ShouldCompact(sess.InputTokens, sess.OutputTokens, 0, contextLimit) {
r.compactWithReason(ctx, sess, "", compactionReasonThreshold, sink)
}
// We can only compact if we know the context limit.
// resolveContextLimit prefers provider_opts.context_size when set
// (some providers — notably Docker Model Runner — use it to size
// the actual inference context), then falls back to the models.dev
// catalogue. resolveContextLimit performs its own catalogue lookup
// when context_size isn't supplied; we keep the explicit lookup above
// as well because m is also threaded into [recordAssistantMessage] for
// per-message cost computation.
contextLimit := r.resolveContextLimit(ctx, model, modelID)
if contextLimit > 0 && r.sessionCompaction && compaction.ShouldCompact(sess.InputTokens, sess.OutputTokens, 0, contextLimit) {
r.compactWithReason(ctx, sess, "", compactionReasonThreshold, sink)
}

// Drain steer messages queued while idle or before the first model call
// (covers idle-window and first-turn-miss races).
if drained, messageCountBeforeSteer := r.drainAndEmitSteered(ctx, sess, sink); drained {
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeSteer, sink)
r.compactIfNeeded(ctx, sess, a, contextLimit, messageCountBeforeSteer, sink)
}

// Everything from turn_start onwards is wrapped in a closure so a
Expand Down Expand Up @@ -640,7 +643,7 @@ func (r *LocalRuntime) runTurn(

// Drain steer messages that arrived during tool calls.
if drained, _ := r.drainAndEmitSteered(ctx, sess, events); drained {
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
r.compactIfNeeded(ctx, sess, a, contextLimit, messageCountBeforeTools, events)
endReason = turnEndReasonSteered
return turnContinue
}
Expand All @@ -651,7 +654,7 @@ func (r *LocalRuntime) runTurn(

// Re-check steer queue: closes the race between the mid-loop drain and this stop.
if drained, _ := r.drainAndEmitSteered(ctx, sess, events); drained {
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
r.compactIfNeeded(ctx, sess, a, contextLimit, messageCountBeforeTools, events)
endReason = turnEndReasonSteered
return turnContinue
}
Expand All @@ -666,7 +669,7 @@ func (r *LocalRuntime) runTurn(
userMsg := session.UserMessage(followUp.Content, followUp.MultiContent...)
sess.AddMessage(userMsg)
events.Emit(UserMessage(followUp.Content, sess.ID, followUp.MultiContent, len(sess.Messages)-1))
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
r.compactIfNeeded(ctx, sess, a, contextLimit, messageCountBeforeTools, events)
endReason = turnEndReasonContinue
return turnContinue // re-enter the loop for a new turn
}
Expand All @@ -675,7 +678,7 @@ func (r *LocalRuntime) runTurn(
return turnExit
}

r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
r.compactIfNeeded(ctx, sess, a, contextLimit, messageCountBeforeTools, events)
endReason = turnEndReasonContinue
return turnContinue
}
Expand Down Expand Up @@ -774,12 +777,11 @@ func (r *LocalRuntime) compactIfNeeded(
ctx context.Context,
sess *session.Session,
a *agent.Agent,
m *modelsdev.Model,
contextLimit int64,
messageCountBefore int,
events EventSink,
) {
if m == nil || !r.sessionCompaction || contextLimit <= 0 {
if !r.sessionCompaction || contextLimit <= 0 {
return
}

Expand Down
10 changes: 2 additions & 8 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -1029,10 +1029,7 @@ func (r *LocalRuntime) EmitStartupInfo(ctx context.Context, sess *session.Sessio
// sub-sessions won't emit their own events, so the parent must include
// their costs.
if sess != nil && (sess.InputTokens > 0 || sess.OutputTokens > 0) {
var contextLimit int64
if m, err := r.modelsStore.GetModel(ctx, modelID); err == nil && m != nil {
contextLimit = int64(m.Limit.Context)
}
contextLimit := r.resolveContextLimit(ctx, a.Model(ctx), modelID)
usage := SessionUsage(sess, contextLimit)
usage.Cost = sess.TotalCost()

Expand Down Expand Up @@ -1301,10 +1298,7 @@ func (r *LocalRuntime) compactWithReason(ctx context.Context, sess *session.Sess
// compaction: tokens drop to the summary size, context % drops, and
// cost increases by the summary generation cost.
modelID := r.getEffectiveModelID(a)
var contextLimit int64
if m, err := r.modelsStore.GetModel(ctx, modelID); err == nil && m != nil {
contextLimit = int64(m.Limit.Context)
}
contextLimit := r.resolveContextLimit(ctx, a.Model(ctx), modelID)
events.Emit(NewTokenUsageEvent(sess.ID, a.Name(), SessionUsage(sess, contextLimit)))
}

Expand Down
77 changes: 74 additions & 3 deletions pkg/runtime/session_compaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ package runtime
import (
"context"
"log/slog"
"strconv"
"strings"

"github.com/docker/docker-agent/pkg/agent"
"github.com/docker/docker-agent/pkg/chat"
"github.com/docker/docker-agent/pkg/compaction"
"github.com/docker/docker-agent/pkg/hooks"
"github.com/docker/docker-agent/pkg/model/provider"
"github.com/docker/docker-agent/pkg/model/provider/options"
"github.com/docker/docker-agent/pkg/modelsdev"
"github.com/docker/docker-agent/pkg/runtime/compactor"
"github.com/docker/docker-agent/pkg/session"
"github.com/docker/docker-agent/pkg/team"
Expand Down Expand Up @@ -162,6 +165,10 @@ func summaryFromHook(sess *session.Session, a *agent.Agent, pre *hooks.Result) *
// when it can't be resolved. Failure is non-fatal: a before_compaction
// hook may supply its own summary and never need the model definition.
// The LLM strategy itself enforces ContextLimit > 0.
//
// See [LocalRuntime.resolveContextLimit] for the resolution order; we
// pass the cloned summary-call provider so its provider_opts (which
// match the underlying model) are considered.
func (r *LocalRuntime) compactionContextLimit(ctx context.Context, a *agent.Agent) int64 {
if a == nil || a.Model(ctx) == nil {
return 0
Expand All @@ -170,11 +177,75 @@ func (r *LocalRuntime) compactionContextLimit(ctx context.Context, a *agent.Agen
options.WithStructuredOutput(nil),
options.WithMaxTokens(compactor.MaxSummaryTokens),
)
m, err := r.modelsStore.GetModel(ctx, summaryModel.ID())
if err != nil || m == nil {
return r.resolveContextLimit(ctx, summaryModel, summaryModel.ID())
}

// resolveContextLimit returns the effective context window size for a
// model, or 0 when none is known (callers treat 0 as "can't compact").
//
// Sources are consulted in this order:
//
//  1. provider_opts.context_size on p, as read by
//     [providerContextLimit]. A positive value is authoritative: some
//     providers (notably Docker Model Runner) use it to size the actual
//     inference context, so we plan against the number the engine will
//     enforce. It also makes compaction work for local models that
//     aren't catalogued in models.dev (e.g. a HuggingFace GGUF).
//  2. The models.dev catalogue entry for id, when the lookup succeeds
//     and reports a positive context limit.
func (r *LocalRuntime) resolveContextLimit(ctx context.Context, p provider.Provider, id modelsdev.ID) int64 {
	if fromOpts := providerContextLimit(p); fromOpts > 0 {
		return fromOpts
	}

	entry, err := r.modelsStore.GetModel(ctx, id)
	if err != nil || entry == nil || entry.Limit.Context <= 0 {
		return 0
	}
	return int64(entry.Limit.Context)
}

// providerContextLimit reads provider_opts.context_size from a
// provider's resolved model config and returns it as a positive
// integer. It returns 0 when p is nil, the key is absent, or the value
// cannot be used as a positive integer.
//
// [LocalRuntime.resolveContextLimit] consults this value before the
// models.dev catalogue, which is what lets compaction work for models
// the catalogue does not know (typically Docker Model Runner with a
// HuggingFace GGUF model).
//
// Accepted shapes mirror what YAML/JSON decoders may produce: int,
// int32, int64, float32, float64 and base-10 strings (surrounding
// whitespace is trimmed). Floats are truncated toward zero. Zero,
// negative, NaN, infinite and out-of-range values are treated as
// "unset" so callers don't trigger compaction with a degenerate limit.
func providerContextLimit(p provider.Provider) int64 {
	if p == nil {
		return 0
	}
	// Indexing a nil ProviderOpts map is safe and reports ok == false.
	v, ok := p.BaseConfig().ModelConfig.ProviderOpts["context_size"]
	if !ok {
		return 0
	}

	var n int64
	switch t := v.(type) {
	case int64:
		n = t
	case int:
		n = int64(t)
	case int32:
		n = int64(t)
	case float64:
		n = contextSizeFromFloat(t)
	case float32:
		n = contextSizeFromFloat(float64(t))
	case string:
		parsed, err := strconv.ParseInt(strings.TrimSpace(t), 10, 64)
		if err != nil {
			return 0
		}
		n = parsed
	default:
		return 0
	}

	if n <= 0 {
		return 0
	}
	return n
}

// contextSizeFromFloat converts f to int64 by truncation, returning 0
// unless the result is a representable positive integer.
func contextSizeFromFloat(f float64) int64 {
	// Converting a float that int64 cannot represent (NaN, ±Inf, or a
	// magnitude of 2^63 and above) yields an implementation-dependent
	// value in Go, so such inputs are rejected up front. NaN fails both
	// comparisons and therefore falls through to the 0 return as well.
	const upperBound = 1 << 63
	if f >= 1 && f < upperBound {
		return int64(f)
	}
	return 0
}

// runCompactionAgent runs an agent against a sub-session for compaction.
Expand Down
Loading