diff --git a/pkg/runtime/loop_steps.go b/pkg/runtime/loop_steps.go index 7ab158a8b..7be44cf93 100644 --- a/pkg/runtime/loop_steps.go +++ b/pkg/runtime/loop_steps.go @@ -156,27 +156,53 @@ func (r *LocalRuntime) handleStreamError( return streamErrorFatal } - // Auto-recovery: if the error is a context overflow and session - // compaction is enabled, compact the conversation and retry the - // request instead of surfacing raw errors. We allow at most - // r.maxOverflowCompactions consecutive attempts to avoid an infinite - // loop when compaction cannot reduce the context enough. - if _, ok := errors.AsType[*modelerrors.ContextOverflowError](err); ok && r.sessionCompaction && *overflowCompactions < r.maxOverflowCompactions { - *overflowCompactions++ - slog.WarnContext(ctx, "Context window overflow detected, attempting auto-compaction", - "agent", a.Name(), - "session_id", sess.ID, - "input_tokens", sess.InputTokens, - "output_tokens", sess.OutputTokens, - "context_limit", contextLimit, - "attempt", *overflowCompactions, - ) - events.Emit(Warning( - "The conversation has exceeded the model's context window. Automatically compacting the conversation history...", - a.Name(), - )) - r.compactWithReason(ctx, sess, "", compactionReasonOverflow, events) - return streamErrorRetry + // Overflow handling has two independent concerns: + // + // 1. Auto-compaction (token overflow only): summarise older + // turns to fit the context window, then retry. Gated by + // r.sessionCompaction and the per-run attempt cap. + // + // 2. Session hygiene (wire/media overflow): rewrite the + // offending user message so the same oversized payload + // cannot reload on the next call and re-poison the session. + // Always runs when the kind warrants it, independent of + // the compaction config — the hygiene step does not retry + // and is correct even when compaction is disabled. + if _, ok := errors.AsType[*modelerrors.ContextOverflowError](err); ok { + kind := modelerrors.OverflowKindOf(err) + + // Token overflow: compaction is the right recovery — older + // turns can be summarised to free up context. Wire/media do + // not benefit from compaction (the latest turn alone is + // over the cap; resending it during compaction would just + // fail again), so we fall through to the hygiene step + // below for those. + if kind == modelerrors.OverflowKindTokens && r.sessionCompaction && *overflowCompactions < r.maxOverflowCompactions { + *overflowCompactions++ + slog.WarnContext(ctx, "Context window overflow detected, attempting auto-compaction", + "agent", a.Name(), + "session_id", sess.ID, + "input_tokens", sess.InputTokens, + "output_tokens", sess.OutputTokens, + "context_limit", contextLimit, + "attempt", *overflowCompactions, + ) + events.Emit(Warning( + "The conversation has exceeded the model's context window. Automatically compacting the conversation history...", + a.Name(), + )) + r.compactWithReason(ctx, sess, "", compactionReasonOverflow, events) + return streamErrorRetry + } + + // Hygiene scrub for wire/media overflow. Runs independently + // of r.sessionCompaction: this rewrites a single message in + // place, it does not retry, and the same-process + // session-poisoning bug it fixes occurs regardless of + // whether the user opted into auto-compaction. + if kind == modelerrors.OverflowKindWire || kind == modelerrors.OverflowKindMedia { + r.recoverFromOversizedTurn(ctx, sess, kind, events) + } } streamSpan.RecordError(err) diff --git a/pkg/runtime/loop_steps_test.go b/pkg/runtime/loop_steps_test.go index 8c60a2c25..f2115acb6 100644 --- a/pkg/runtime/loop_steps_test.go +++ b/pkg/runtime/loop_steps_test.go @@ -292,3 +292,75 @@ func TestHandleStreamError_GenericError_FatalAndEmitsError(t *testing.T) { } assert.True(t, sawError, "generic error should emit ErrorEvent") } + +// TestHandleStreamError_WireOverflowSkipsCompaction verifies that wire-level +// overflow does not trigger auto-compaction. Compaction would resend the same +// oversized request that just got rejected, so it is guaranteed to fail; we +// surface the error directly instead. +func TestHandleStreamError_WireOverflowSkipsCompaction(t *testing.T) { + t.Parallel() + + rt, a := newTestRuntime(t) + sess := session.New() + events := make(chan Event, 16) + _, sp := noop.NewTracerProvider().Tracer("t").Start(t.Context(), "x") + + overflow := &modelerrors.ContextOverflowError{ + Underlying: errors.New("HTTP 413: Payload Too Large"), + Kind: modelerrors.OverflowKindWire, + } + overflowCount := 0 + + outcome := rt.handleStreamError(t.Context(), sess, a, overflow, 1000, &overflowCount, sp, NewChannelSink(events)) + + assert.Equal(t, streamErrorFatal, outcome, "wire overflow must not trigger compaction retry") + assert.Equal(t, 0, overflowCount, "wire overflow must not bump the compaction counter") + + got := drainEvents(events) + var sawError bool + var sawWarning bool + var errCode string + for _, ev := range got { + switch e := ev.(type) { + case *ErrorEvent: + sawError = true + errCode = e.Code + case *WarningEvent: + sawWarning = true + } + } + assert.True(t, sawError, "wire overflow should emit an ErrorEvent") + assert.False(t, sawWarning, "wire overflow should not emit the compaction warning") + assert.Equal(t, ErrorCodeRequestTooLarge, errCode, "ErrorEvent.Code should distinguish wire overflow") +} + +// TestHandleStreamError_MediaOverflowSkipsCompaction verifies the same skip +// behaviour for media-size rejections. Without media-stripping during +// compaction, the offending attachment would be resent and fail again. +func TestHandleStreamError_MediaOverflowSkipsCompaction(t *testing.T) { + t.Parallel() + + rt, a := newTestRuntime(t) + sess := session.New() + events := make(chan Event, 16) + _, sp := noop.NewTracerProvider().Tracer("t").Start(t.Context(), "x") + + overflow := &modelerrors.ContextOverflowError{ + Underlying: errors.New("image exceeds 5 MB maximum"), + Kind: modelerrors.OverflowKindMedia, + } + overflowCount := 0 + + outcome := rt.handleStreamError(t.Context(), sess, a, overflow, 1000, &overflowCount, sp, NewChannelSink(events)) + + assert.Equal(t, streamErrorFatal, outcome, "media overflow must not trigger compaction retry") + assert.Equal(t, 0, overflowCount, "media overflow must not bump the compaction counter") + + var errCode string + for _, ev := range drainEvents(events) { + if e, ok := ev.(*ErrorEvent); ok { + errCode = e.Code + } + } + assert.Equal(t, ErrorCodeMediaTooLarge, errCode, "ErrorEvent.Code should distinguish media overflow") +} diff --git a/pkg/runtime/overflow_recovery.go b/pkg/runtime/overflow_recovery.go new file mode 100644 index 000000000..97475083a --- /dev/null +++ b/pkg/runtime/overflow_recovery.go @@ -0,0 +1,289 @@ +package runtime + +import ( + "context" + "fmt" + "log/slog" + "path/filepath" + "strings" + + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/modelerrors" + "github.com/docker/docker-agent/pkg/session" +) + +// maxScrubbedTextBytes is the threshold at and below which a user +// message's plain-text content is preserved verbatim during scrubbing. +// Above the threshold the content is replaced with a placeholder that +// records the original size — keeping it verbatim would re-poison the +// session on the next turn by carrying the offending payload back into +// the provider request. +// +// The value is intentionally well below any major provider's wire-size +// cap. Smaller payloads pass through unchanged. +const maxScrubbedTextBytes = 1 << 20 // 1 MiB + +// scrubReport summarises what scrubMessage rewrote, for observability. +// All counters are zero on a no-op scrub. +type scrubReport struct { + // textReplaced is true when [chat.Message.Content] was over + // [maxScrubbedTextBytes] and has been replaced. + textReplaced bool + // originalBytes is the size of the original plain text, set + // only when textReplaced is true. + originalBytes int64 + // partsReplaced counts how many MultiContent parts were + // rewritten: media parts (image/file/document) are always + // counted; oversized text parts are also counted when they + // exceeded [maxScrubbedTextBytes]. + partsReplaced int +} + +func (r scrubReport) didScrub() bool { + return r.textReplaced || r.partsReplaced > 0 +} + +// scrubMessage returns a copy of msg in which media parts (image_url, +// file, document) are replaced with text placeholders and oversized +// plain-text content is replaced with a size-noting placeholder. The +// returned scrubReport describes what changed; when it reports +// [scrubReport.didScrub] false the message is byte-identical to msg. +// +// scrubMessage is pure: it does not consult the session, the model, +// or any context. Callers decide *when* to apply it (post-failure +// recovery, manual cleanup, etc.); this function only describes +// *how*. +func scrubMessage(msg chat.Message) (chat.Message, scrubReport) { + var report scrubReport + + out := msg + + // Plain text: replace only when oversized so we don't lose the + // user's intent for normal-sized messages. + if len(out.Content) > maxScrubbedTextBytes { + report.textReplaced = true + report.originalBytes = int64(len(out.Content)) + out.Content = oversizedTextPlaceholder(int64(len(out.Content))) + } + + // Multi-content parts: rewrite each media part in place. + if len(out.MultiContent) > 0 { + newParts := make([]chat.MessagePart, len(out.MultiContent)) + for i, part := range out.MultiContent { + rewritten, replaced := scrubMessagePart(part) + if replaced { + report.partsReplaced++ + } + newParts[i] = rewritten + } + out.MultiContent = newParts + } + + return out, report +} + +// scrubMessagePart replaces a single attachment part with a text +// placeholder. Returns the rewritten part and whether anything +// changed; small text parts and unrecognised types pass through +// unchanged. +// +// Text parts over [maxScrubbedTextBytes] are themselves rewritten to +// a size-noting placeholder — a single text part inside MultiContent +// can be just as poisoning as oversized [chat.Message.Content]. +// +// The placeholder describes what was attached (kind, name, size) +// without preserving any of the content, so the rewritten message +// can never re-trip the provider's media-size limits. +func scrubMessagePart(part chat.MessagePart) (chat.MessagePart, bool) { + switch part.Type { + case chat.MessagePartTypeText: + if int64(len(part.Text)) > maxScrubbedTextBytes { + return chat.MessagePart{ + Type: chat.MessagePartTypeText, + Text: oversizedTextPlaceholder(int64(len(part.Text))), + }, true + } + return part, false + + case chat.MessagePartTypeImageURL: + return chat.MessagePart{ + Type: chat.MessagePartTypeText, + Text: imagePlaceholder(part), + }, true + + case chat.MessagePartTypeFile: + return chat.MessagePart{ + Type: chat.MessagePartTypeText, + Text: filePlaceholder(part), + }, true + + case chat.MessagePartTypeDocument: + return chat.MessagePart{ + Type: chat.MessagePartTypeText, + Text: documentPlaceholder(part), + }, true + } + // Unknown part type — leave it alone rather than risk dropping + // data we don't recognise. + return part, false +} + +func oversizedTextPlaceholder(originalBytes int64) string { + return fmt.Sprintf( + "[previous message was %s of text — too large for the AI provider; "+ + "content was removed from the session so the conversation can continue]", + humanByteSize(originalBytes), + ) +} + +func imagePlaceholder(part chat.MessagePart) string { + if part.ImageURL == nil { + return "[image attachment removed: too large for the AI provider]" + } + if name := imageDisplayName(part.ImageURL.URL); name != "" { + return fmt.Sprintf("[image %q removed: too large for the AI provider]", name) + } + return "[image attachment removed: too large for the AI provider]" +} + +func filePlaceholder(part chat.MessagePart) string { + if part.File == nil { + return "[file attachment removed: too large for the AI provider]" + } + name := part.File.Path + if name != "" { + name = filepath.Base(name) + } + if name == "" { + name = part.File.FileID + } + if name == "" { + return "[file attachment removed: too large for the AI provider]" + } + return fmt.Sprintf("[file %q removed: too large for the AI provider]", name) +} + +func documentPlaceholder(part chat.MessagePart) string { + if part.Document == nil { + return "[document attachment removed: too large for the AI provider]" + } + doc := part.Document + if doc.Name != "" && doc.Size > 0 { + return fmt.Sprintf("[document %q (%s) removed: too large for the AI provider]", + doc.Name, humanByteSize(doc.Size)) + } + if doc.Name != "" { + return fmt.Sprintf("[document %q removed: too large for the AI provider]", doc.Name) + } + if doc.MimeType != "" { + return fmt.Sprintf("[%s attachment removed: too large for the AI provider]", doc.MimeType) + } + return "[document attachment removed: too large for the AI provider]" +} + +// imageDisplayName extracts a short display name from an image URL. +// Returns "" for data: URIs (where the URL itself is the payload and +// not user-meaningful) and falls back to the URL path's basename +// otherwise. +func imageDisplayName(url string) string { + if url == "" || strings.HasPrefix(url, "data:") { + return "" + } + // Strip query / fragment. + if i := strings.IndexAny(url, "?#"); i >= 0 { + url = url[:i] + } + base := filepath.Base(url) + if base == "/" || base == "." { + return "" + } + return base +} + +// humanByteSize renders n bytes as a short decimal string with binary +// units (KiB, MiB, GiB). Used for placeholder text only; precision is +// limited to one decimal place since this is informational. +func humanByteSize(n int64) string { + const ( + kib = 1 << 10 + mib = 1 << 20 + gib = 1 << 30 + ) + switch { + case n >= gib: + return fmt.Sprintf("%.1f GiB", float64(n)/float64(gib)) + case n >= mib: + return fmt.Sprintf("%.1f MiB", float64(n)/float64(mib)) + case n >= kib: + return fmt.Sprintf("%.1f KiB", float64(n)/float64(kib)) + } + return fmt.Sprintf("%d B", n) +} + +// recoverFromOversizedTurn rewrites the latest user message in sess so +// that the offending content (oversized text, media attachments) is +// neutralised. This is the runtime's in-memory hygiene step after a +// wire- or media-overflow rejection: without it, the same oversized +// turn re-sends on every subsequent call within this process and the +// conversation cannot continue. +// +// Scope: +// - In-memory only. The session store row is NOT updated; a +// docker-agent restart mid-session will reload the original +// oversized payload from disk. Mirroring the rewrite to the +// store requires propagating Message.ID from Store.AddMessage +// back into the in-memory session, which is an independent +// persistence-layer fix tracked as a separate change. +// - Only called for [modelerrors.OverflowKindWire] and +// [modelerrors.OverflowKindMedia]. Token overflow is handled by +// auto-compaction (a different mechanism). +// - Mutates only the most recent user message. Earlier turns are +// left alone — the heuristic is that the latest turn is the one +// that just tripped the provider; older turns must have been +// accepted at some point. +func (r *LocalRuntime) recoverFromOversizedTurn( + ctx context.Context, + sess *session.Session, + kind modelerrors.OverflowKind, + events EventSink, +) { + var report scrubReport + rewrote := sess.RewriteLatestUserMessage(func(msg chat.Message) (chat.Message, bool) { + rewritten, r := scrubMessage(msg) + if !r.didScrub() { + return msg, false + } + report = r + return rewritten, true + }) + if !rewrote { + // Nothing oversized to scrub (e.g. the offending content was + // already small, or the session has no user message yet). + return + } + + slog.InfoContext(ctx, "Scrubbed oversized user message after overflow", + "session_id", sess.ID, + "overflow_kind", string(kind), + "text_replaced", report.textReplaced, + "original_text_bytes", report.originalBytes, + "parts_replaced", report.partsReplaced, + ) + emitScrubNotice(events, report) +} + +// emitScrubNotice surfaces an informational warning so the UI can show +// the user that their message was rewritten in place. Without this the +// recovery is silent and the user sees only "Your message is too +// large" — they wouldn't know that the offending content has been +// removed from the conversation history. +func emitScrubNotice(events EventSink, report scrubReport) { + if events == nil || !report.didScrub() { + return + } + events.Emit(Warning( + "Your previous message was too large and has been rewritten in the "+ + "conversation history. Send a smaller message to continue.", + "", + )) +} diff --git a/pkg/runtime/overflow_recovery_test.go b/pkg/runtime/overflow_recovery_test.go new file mode 100644 index 000000000..3bad3ec46 --- /dev/null +++ b/pkg/runtime/overflow_recovery_test.go @@ -0,0 +1,390 @@ +package runtime + +import ( + "errors" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" + + "github.com/docker/docker-agent/pkg/agent" + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/modelerrors" + "github.com/docker/docker-agent/pkg/session" + "github.com/docker/docker-agent/pkg/team" +) + +// --- scrubMessage / scrubMessagePart unit tests --- + +func TestScrubMessage_TextBelowThreshold_PassesThrough(t *testing.T) { + t.Parallel() + + msg := chat.Message{ + Role: chat.MessageRoleUser, + Content: strings.Repeat("a", 1024), + } + out, report := scrubMessage(msg) + assert.False(t, report.didScrub(), "small text must not be scrubbed") + assert.Equal(t, msg, out, "small text must pass through byte-identical") +} + +func TestScrubMessage_OversizedText_Replaced(t *testing.T) { + t.Parallel() + + original := strings.Repeat("z", maxScrubbedTextBytes+1) + msg := chat.Message{Role: chat.MessageRoleUser, Content: original} + + out, report := scrubMessage(msg) + assert.True(t, report.textReplaced) + assert.Equal(t, int64(len(original)), report.originalBytes) + assert.NotEqual(t, original, out.Content, "oversized text must be rewritten") + assert.Contains(t, out.Content, "too large", "placeholder must signal the cause") +} + +func TestScrubMessage_ImageURLPart_BecomesTextPlaceholder(t *testing.T) { + t.Parallel() + + msg := chat.Message{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "look at this"}, + {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "https://example.com/foo.png"}}, + }, + } + + out, report := scrubMessage(msg) + require.True(t, report.didScrub()) + assert.Equal(t, 1, report.partsReplaced) + require.Len(t, out.MultiContent, 2) + assert.Equal(t, chat.MessagePartTypeText, out.MultiContent[0].Type, "text part untouched") + assert.Equal(t, "look at this", out.MultiContent[0].Text) + assert.Equal(t, chat.MessagePartTypeText, out.MultiContent[1].Type, "media replaced by text") + assert.Contains(t, out.MultiContent[1].Text, "foo.png", "placeholder must include the name when available") +} + +func TestScrubMessage_DataURLImage_NameOmitted(t *testing.T) { + t.Parallel() + + msg := chat.Message{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "data:image/png;base64,AAAA"}}, + }, + } + out, report := scrubMessage(msg) + assert.True(t, report.didScrub()) + assert.Equal(t, chat.MessagePartTypeText, out.MultiContent[0].Type) + // data: URI is not user-meaningful — the placeholder should not + // leak the base64 payload and should still describe what was + // removed. + assert.NotContains(t, out.MultiContent[0].Text, "AAAA") + assert.Contains(t, out.MultiContent[0].Text, "image attachment") +} + +func TestScrubMessage_FilePart_BecomesPlaceholder(t *testing.T) { + t.Parallel() + + msg := chat.Message{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeFile, File: &chat.MessageFile{Path: "/tmp/big.log"}}, + }, + } + out, report := scrubMessage(msg) + assert.True(t, report.didScrub()) + assert.Contains(t, out.MultiContent[0].Text, "big.log") +} + +func TestScrubMessage_DocumentPart_IncludesSize(t *testing.T) { + t.Parallel() + + msg := chat.Message{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeDocument, Document: &chat.Document{ + Name: "report.pdf", MimeType: "application/pdf", Size: 3 * 1024 * 1024, + }}, + }, + } + out, report := scrubMessage(msg) + assert.True(t, report.didScrub()) + assert.Contains(t, out.MultiContent[0].Text, "report.pdf") + assert.Contains(t, out.MultiContent[0].Text, "MiB", "size should be human-readable") +} + +func TestScrubMessage_SmallTextPart_PassesThrough(t *testing.T) { + t.Parallel() + + msg := chat.Message{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "just text"}, + }, + } + out, report := scrubMessage(msg) + assert.False(t, report.didScrub(), "small text in multi-content must not be scrubbed") + assert.Equal(t, msg, out) +} + +// TestScrubMessage_OversizedTextPart_Replaced verifies that an +// oversized text blob inside MultiContent is rewritten just like a +// top-level Content payload. A pure-text overflow can arrive as +// either, and the scrub must catch both shapes. +func TestScrubMessage_OversizedTextPart_Replaced(t *testing.T) { + t.Parallel() + + original := strings.Repeat("q", maxScrubbedTextBytes+1) + msg := chat.Message{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "preserved preamble"}, + {Type: chat.MessagePartTypeText, Text: original}, + }, + } + out, report := scrubMessage(msg) + assert.True(t, report.didScrub()) + assert.Equal(t, 1, report.partsReplaced) + require.Len(t, out.MultiContent, 2) + assert.Equal(t, "preserved preamble", out.MultiContent[0].Text, + "small text parts must pass through untouched") + assert.NotEqual(t, original, out.MultiContent[1].Text, + "oversized text part must be rewritten") + assert.Contains(t, out.MultiContent[1].Text, "too large") +} + +func TestScrubMessage_MultipleMediaParts_AllReplaced(t *testing.T) { + t.Parallel() + + msg := chat.Message{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "https://a/1.png"}}, + {Type: chat.MessagePartTypeFile, File: &chat.MessageFile{Path: "/tmp/2.log"}}, + {Type: chat.MessagePartTypeDocument, Document: &chat.Document{Name: "3.pdf"}}, + }, + } + out, report := scrubMessage(msg) + assert.Equal(t, 3, report.partsReplaced) + for _, part := range out.MultiContent { + assert.Equal(t, chat.MessagePartTypeText, part.Type) + } +} + +func TestScrubMessage_TextAndMediaTogether(t *testing.T) { + t.Parallel() + + oversized := strings.Repeat("x", maxScrubbedTextBytes+512) + msg := chat.Message{ + Role: chat.MessageRoleUser, + Content: oversized, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "https://a/x.png"}}, + }, + } + out, report := scrubMessage(msg) + assert.True(t, report.textReplaced) + assert.Equal(t, 1, report.partsReplaced) + assert.NotEqual(t, oversized, out.Content) + assert.Equal(t, chat.MessagePartTypeText, out.MultiContent[0].Type) +} + +// --- recoverFromOversizedTurn integration tests --- + +func TestRecoverFromOversizedTurn_NoUserMessage_NoOp(t *testing.T) { + t.Parallel() + + rt, _ := newTestRuntime(t) + sess := session.New() + events := make(chan Event, 4) + + rt.recoverFromOversizedTurn(t.Context(), sess, modelerrors.OverflowKindWire, NewChannelSink(events)) + assert.Empty(t, drainEvents(events), "empty session should produce no events") +} + +func TestRecoverFromOversizedTurn_SmallMessage_NoOp(t *testing.T) { + t.Parallel() + + rt, _ := newTestRuntime(t) + sess := session.New() + sess.AddMessage(session.UserMessage("hello")) + events := make(chan Event, 4) + + rt.recoverFromOversizedTurn(t.Context(), sess, modelerrors.OverflowKindWire, NewChannelSink(events)) + assert.Empty(t, drainEvents(events)) + assert.Equal(t, "hello", sess.GetLastUserMessageContent(), "small message stays verbatim") +} + +func TestRecoverFromOversizedTurn_OversizedText_Rewrites(t *testing.T) { + t.Parallel() + + rt, _ := newTestRuntime(t) + sess := session.New() + original := strings.Repeat("Y", maxScrubbedTextBytes+1) + sess.AddMessage(session.UserMessage(original)) + events := make(chan Event, 4) + + rt.recoverFromOversizedTurn(t.Context(), sess, modelerrors.OverflowKindWire, NewChannelSink(events)) + rewritten := sess.GetLastUserMessageContent() + assert.NotEqual(t, original, rewritten, "oversized text should have been rewritten") + assert.Contains(t, rewritten, "too large") + + var sawWarning bool + for _, ev := range drainEvents(events) { + if _, ok := ev.(*WarningEvent); ok { + sawWarning = true + } + } + assert.True(t, sawWarning, "scrub should emit a Warning so the UI can inform the user") +} + +func TestRecoverFromOversizedTurn_OnlyLatestUserMessage(t *testing.T) { + t.Parallel() + + rt, _ := newTestRuntime(t) + sess := session.New() + + old := strings.Repeat("o", maxScrubbedTextBytes+1) + sess.AddMessage(session.UserMessage(old)) + // Subsequent assistant + user messages — the scrub must only + // touch the most recent user turn. + sess.AddMessage(&session.Message{Message: chat.Message{Role: chat.MessageRoleAssistant, Content: "ok"}}) + sess.AddMessage(session.UserMessage("short")) + + events := make(chan Event, 4) + rt.recoverFromOversizedTurn(t.Context(), sess, modelerrors.OverflowKindWire, NewChannelSink(events)) + + // The latest user message ("short") is small — nothing to scrub. + assert.Equal(t, "short", sess.GetLastUserMessageContent()) + // The OLDER oversized message must NOT have been touched. + all := sess.GetAllMessages() + require.GreaterOrEqual(t, len(all), 1) + assert.Equal(t, old, all[0].Message.Content, + "older user messages must not be scrubbed — only the latest is suspect") + assert.Empty(t, drainEvents(events)) +} + +// --- handleStreamError integration --- + +func TestHandleStreamError_WireOverflowScrubsSession(t *testing.T) { + t.Parallel() + + prov := &mockProvider{id: "test/mock-model"} + root := agent.New("root", "test", agent.WithModel(prov)) + tm := team.New(team.WithAgents(root)) + rt, err := NewLocalRuntime(tm, WithModelStore(mockModelStore{})) + require.NoError(t, err) + + sess := session.New() + original := strings.Repeat("L", maxScrubbedTextBytes+10) + sess.AddMessage(session.UserMessage(original)) + + events := make(chan Event, 16) + _, sp := noop.NewTracerProvider().Tracer("t").Start(t.Context(), "x") + + overflow := &modelerrors.ContextOverflowError{ + Underlying: errors.New("HTTP 413: Payload Too Large"), + Kind: modelerrors.OverflowKindWire, + } + overflowCount := 0 + + outcome := rt.handleStreamError(t.Context(), sess, root, overflow, 1000, &overflowCount, sp, NewChannelSink(events)) + + assert.Equal(t, streamErrorFatal, outcome) + assert.NotEqual(t, original, sess.GetLastUserMessageContent(), + "wire overflow must scrub the offending user turn so future calls in this process don't re-fail") + + // The ErrorEvent for the rejection MUST still be emitted — + // scrubbing is in addition to the error, not instead of it. + var sawErrorEvent bool + for _, ev := range drainEvents(events) { + if e, ok := ev.(*ErrorEvent); ok { + sawErrorEvent = true + assert.Equal(t, ErrorCodeRequestTooLarge, e.Code, + "wire overflow still surfaces the request-too-large code") + } + } + assert.True(t, sawErrorEvent) +} + +// TestHandleStreamError_WireOverflowScrubsEvenWithCompactionDisabled +// pins that the hygiene scrub for wire/media overflow is independent +// of the session-compaction config. Compaction is irrelevant here — +// the scrub rewrites a single message and does not retry, and the +// in-process bug it fixes happens regardless of whether the user +// opted into auto-compaction. +func TestHandleStreamError_WireOverflowScrubsEvenWithCompactionDisabled(t *testing.T) { + t.Parallel() + + prov := &mockProvider{id: "test/mock-model"} + root := agent.New("root", "test", agent.WithModel(prov)) + tm := team.New(team.WithAgents(root)) + rt, err := NewLocalRuntime(tm, + WithSessionCompaction(false), + WithModelStore(mockModelStore{}), + ) + require.NoError(t, err) + + sess := session.New() + original := strings.Repeat("W", maxScrubbedTextBytes+10) + sess.AddMessage(session.UserMessage(original)) + + events := make(chan Event, 16) + _, sp := noop.NewTracerProvider().Tracer("t").Start(t.Context(), "x") + + overflow := &modelerrors.ContextOverflowError{ + Underlying: errors.New("HTTP 413: Payload Too Large"), + Kind: modelerrors.OverflowKindWire, + } + overflowCount := 0 + + outcome := rt.handleStreamError(t.Context(), sess, root, overflow, 1000, &overflowCount, sp, NewChannelSink(events)) + assert.Equal(t, streamErrorFatal, outcome) + assert.NotEqual(t, original, sess.GetLastUserMessageContent(), + "scrub must run even when session compaction is disabled") +} + +// --- Session.RewriteLatestUserMessage contract --- + +func TestSessionRewriteLatestUserMessage_FindsMostRecentUser(t *testing.T) { + t.Parallel() + + sess := session.New() + sess.AddMessage(session.UserMessage("first")) + sess.AddMessage(&session.Message{Message: chat.Message{Role: chat.MessageRoleAssistant, Content: "reply"}}) + sess.AddMessage(session.UserMessage("second")) + + var seen string + ok := sess.RewriteLatestUserMessage(func(m chat.Message) (chat.Message, bool) { + seen = m.Content + m.Content = "scrubbed" + return m, true + }) + assert.True(t, ok) + assert.Equal(t, "second", seen, "rewrite should target the latest user message") + assert.Equal(t, "scrubbed", sess.GetLastUserMessageContent()) +} + +func TestSessionRewriteLatestUserMessage_OptOut_DoesNotMutate(t *testing.T) { + t.Parallel() + + sess := session.New() + sess.AddMessage(session.UserMessage("keep")) + + ok := sess.RewriteLatestUserMessage(func(m chat.Message) (chat.Message, bool) { + return m, false + }) + assert.False(t, ok) + assert.Equal(t, "keep", sess.GetLastUserMessageContent()) +} + +func TestSessionRewriteLatestUserMessage_NoUserMessages(t *testing.T) { + t.Parallel() + + sess := session.New() + ok := sess.RewriteLatestUserMessage(func(m chat.Message) (chat.Message, bool) { + return m, true + }) + assert.False(t, ok) +} diff --git a/pkg/session/session.go b/pkg/session/session.go index e2f426827..49f09948e 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -431,6 +431,55 @@ func (s *Session) ApplyCompaction(inputTokens, outputTokens int64, item Item) { s.mu.Unlock() } +// RewriteLatestUserMessage atomically rewrites the most recent user +// message in s by passing its chat.Message to rewrite and replacing +// it with the returned value. The slot is found and updated under +// s.mu so concurrent readers (snapshotItems, persistence) cannot +// observe a torn state. +// +// rewrite is called at most once. If it returns false the message is +// left unchanged. The boolean return reports whether anything was +// rewritten — false when there is no user message in s or when +// rewrite opted out. +// +// This is the runtime's hook for in-place message hygiene after a +// failure that would otherwise poison the session — see the wire/ +// media overflow recovery in pkg/runtime. +// +// Scope: only items at the top level of s.Messages are considered. +// Sub-session items (Item.SubSession) are skipped; the function does +// not recurse into them. Callers that route user turns through a +// sub-session must rewrite the target message on the sub-session +// directly. The contract is intentional — sub-sessions own their +// own conversation transcript and the parent should not reach into +// them without going through their store. +// +// The rewrite is in-memory only. To mirror it to the session store +// (so it survives a docker-agent restart), the caller must follow up +// with a separate [Store.UpdateMessage] — which today requires a +// persistence ID that the runtime does not yet round-trip through +// [Store.AddMessage]. Closing that gap is a separate piece of work. +func (s *Session) RewriteLatestUserMessage(rewrite func(chat.Message) (chat.Message, bool)) bool { + s.mu.Lock() + defer s.mu.Unlock() + for i := range slices.Backward(s.Messages) { + item := &s.Messages[i] + if !item.IsMessage() { + continue + } + if item.Message.Message.Role != chat.MessageRoleUser { + continue + } + newMsg, ok := rewrite(item.Message.Message) + if !ok { + return false + } + item.Message.Message = newMsg + return true + } + return false +} + // AddSubSession adds a sub-session to the session func (s *Session) AddSubSession(subSession *Session) { s.mu.Lock()