Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions internal/acp/client_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package acp

import (
"github.com/pedronauck/agh/internal/testutil"
"os"
"path/filepath"
"testing"
Expand All @@ -16,7 +17,7 @@ func TestACPIntegrationRoundTrip(t *testing.T) {
proc := startHelperProcess(t, driver, "stream_updates", "", StartOpts{})
defer stopProcess(t, driver, proc)

eventsCh, err := driver.Prompt(testContext(t), proc, PromptRequest{
eventsCh, err := driver.Prompt(testutil.Context(t), proc, PromptRequest{
TurnID: "turn-integration-roundtrip",
Message: "run roundtrip",
})
Expand Down Expand Up @@ -48,7 +49,7 @@ func TestACPIntegrationReadTextFileRequest(t *testing.T) {
})
defer stopProcess(t, driver, proc)

eventsCh, err := driver.Prompt(testContext(t), proc, PromptRequest{
eventsCh, err := driver.Prompt(testutil.Context(t), proc, PromptRequest{
TurnID: "turn-integration-fs",
Message: "read file",
})
Expand All @@ -73,7 +74,7 @@ func TestACPIntegrationRequestPermissionPolicy(t *testing.T) {
})
defer stopProcess(t, driver, proc)

eventsCh, err := driver.Prompt(testContext(t), proc, PromptRequest{
eventsCh, err := driver.Prompt(testutil.Context(t), proc, PromptRequest{
TurnID: "turn-integration-permission",
Message: "request permission",
})
Expand All @@ -98,7 +99,7 @@ func TestACPIntegrationRequestPermissionPolicy(t *testing.T) {
if pendingRequestID == "" {
t.Fatal("permission request_id = empty, want non-empty")
}
if err := driver.ApprovePermission(testContext(t), proc, ApproveRequest{
if err := driver.ApprovePermission(testutil.Context(t), proc, ApproveRequest{
RequestID: pendingRequestID,
Decision: string(decisionAllowAlways),
}); err != nil {
Expand Down Expand Up @@ -142,7 +143,7 @@ func TestACPIntegrationRequestPermissionTimeout(t *testing.T) {
})
defer stopProcess(t, driver, proc)

eventsCh, err := driver.Prompt(testContext(t), proc, PromptRequest{
eventsCh, err := driver.Prompt(testutil.Context(t), proc, PromptRequest{
TurnID: "turn-integration-timeout",
Message: "request permission",
})
Expand Down
22 changes: 8 additions & 14 deletions internal/acp/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/pedronauck/agh/internal/testutil"
"io"
"os"
"os/exec"
Expand Down Expand Up @@ -241,7 +242,7 @@ func TestPromptPrependsSystemPromptOnce(t *testing.T) {
})
defer stopProcess(t, driver, proc)

firstEventsCh, err := driver.Prompt(testContext(t), proc, PromptRequest{
firstEventsCh, err := driver.Prompt(testutil.Context(t), proc, PromptRequest{
TurnID: "turn-1",
Message: "first request",
})
Expand All @@ -262,7 +263,7 @@ func TestPromptPrependsSystemPromptOnce(t *testing.T) {
t.Fatalf("first prompt text = %q, want user request content", firstEvents[0].Text)
}

secondEventsCh, err := driver.Prompt(testContext(t), proc, PromptRequest{
secondEventsCh, err := driver.Prompt(testutil.Context(t), proc, PromptRequest{
TurnID: "turn-2",
Message: "second request",
})
Expand All @@ -285,7 +286,7 @@ func TestPromptStreamsSessionUpdates(t *testing.T) {
proc := startHelperProcess(t, driver, "stream_updates", "", StartOpts{})
defer stopProcess(t, driver, proc)

eventsCh, err := driver.Prompt(testContext(t), proc, PromptRequest{
eventsCh, err := driver.Prompt(testutil.Context(t), proc, PromptRequest{
TurnID: "turn-stream",
Message: "hello",
})
Expand Down Expand Up @@ -465,7 +466,7 @@ func TestStartResumeReturnsSentinelErrors(t *testing.T) {
t.Parallel()

driver := New()
_, err := driver.Start(testContext(t), StartOpts{
_, err := driver.Start(testutil.Context(t), StartOpts{
AgentName: "helper",
Command: helperCommand(t),
Cwd: t.TempDir(),
Expand Down Expand Up @@ -517,7 +518,7 @@ func TestProcessCrashDetected(t *testing.T) {
driver := New()
proc := startHelperProcess(t, driver, "crash_on_prompt", "", StartOpts{})

eventsCh, err := driver.Prompt(testContext(t), proc, PromptRequest{
eventsCh, err := driver.Prompt(testutil.Context(t), proc, PromptRequest{
TurnID: "turn-crash",
Message: "trigger crash",
})
Expand Down Expand Up @@ -618,7 +619,7 @@ func startHelperProcess(t *testing.T, driver *Driver, scenario string, filePath
}
opts.ResumeSessionID = overrides.ResumeSessionID

proc, err := driver.Start(testContext(t), opts)
proc, err := driver.Start(testutil.Context(t), opts)
if err != nil {
t.Fatalf("Start() error = %v", err)
}
Expand All @@ -630,7 +631,7 @@ func stopProcess(t *testing.T, driver *Driver, proc *AgentProcess) {
if proc == nil {
return
}
if err := driver.Stop(testContext(t), proc); err != nil {
if err := driver.Stop(testutil.Context(t), proc); err != nil {
t.Fatalf("Stop() error = %v", err)
}
}
Expand Down Expand Up @@ -797,13 +798,6 @@ func assertPermissionResult(t *testing.T, err error, wantOK bool) {
}
}

func testContext(t *testing.T) context.Context {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
return ctx
}

type helperACPAgent struct {
conn *acpsdk.AgentSideConnection
scenario string
Expand Down
87 changes: 24 additions & 63 deletions internal/acp/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,43 +240,21 @@ func (p *AgentProcess) handleRequestPermission(ctx context.Context, request acps
}

decision, interactive := p.permissions.permissionDecision(request)
sessionID := string(request.SessionId)
toolCallID := strings.TrimSpace(string(request.ToolCall.ToolCallId))

if !interactive {
requestID := p.nextPermissionRequestID(turnID, request)
outcome, appliedDecision := selectPermissionOutcome(request.Options, decision)
raw := buildPermissionEventRaw(requestID, appliedDecision, request)
p.emitPromptEvent(AgentEvent{
Type: EventTypePermission,
SessionID: string(request.SessionId),
TurnID: turnID,
RequestID: requestID,
Timestamp: timeNowUTC(),
Title: title,
ToolCallID: strings.TrimSpace(string(request.ToolCall.ToolCallId)),
Action: string(permissionRequestToolGrant),
Resource: resource,
Decision: string(appliedDecision),
Raw: cloneRawJSON(raw),
})
p.emitPermissionEvent(sessionID, turnID, requestID, title, toolCallID, resource, appliedDecision, raw)
return acpsdk.RequestPermissionResponse{Outcome: outcome}, nil
}

requestID, pending := p.registerPendingPermission(turnID, request)
defer p.clearPendingPermission(requestID)
raw := buildPermissionEventRaw(requestID, decisionPending, request)

p.emitPromptEvent(AgentEvent{
Type: EventTypePermission,
SessionID: string(request.SessionId),
TurnID: turnID,
RequestID: requestID,
Timestamp: timeNowUTC(),
Title: title,
ToolCallID: strings.TrimSpace(string(request.ToolCall.ToolCallId)),
Action: string(permissionRequestToolGrant),
Resource: resource,
Raw: cloneRawJSON(raw),
})
p.emitPermissionEvent(sessionID, turnID, requestID, title, toolCallID, resource, "", raw)

timer := time.NewTimer(p.permissionTimeoutOrDefault())
defer timer.Stop()
Expand All @@ -285,36 +263,12 @@ func (p *AgentProcess) handleRequestPermission(ctx context.Context, request acps
case resolvedDecision := <-pending.response:
outcome, appliedDecision := selectPermissionOutcome(request.Options, resolvedDecision)
raw = buildPermissionEventRaw(requestID, appliedDecision, request)
p.emitPromptEvent(AgentEvent{
Type: EventTypePermission,
SessionID: string(request.SessionId),
TurnID: turnID,
RequestID: requestID,
Timestamp: timeNowUTC(),
Title: title,
ToolCallID: strings.TrimSpace(string(request.ToolCall.ToolCallId)),
Action: string(permissionRequestToolGrant),
Resource: resource,
Decision: string(appliedDecision),
Raw: cloneRawJSON(raw),
})
p.emitPermissionEvent(sessionID, turnID, requestID, title, toolCallID, resource, appliedDecision, raw)
return acpsdk.RequestPermissionResponse{Outcome: outcome}, nil
case <-timer.C:
outcome, appliedDecision := selectPermissionOutcome(request.Options, decisionRejectOnce)
raw = buildPermissionEventRaw(requestID, appliedDecision, request)
p.emitPromptEvent(AgentEvent{
Type: EventTypePermission,
SessionID: string(request.SessionId),
TurnID: turnID,
RequestID: requestID,
Timestamp: timeNowUTC(),
Title: title,
ToolCallID: strings.TrimSpace(string(request.ToolCall.ToolCallId)),
Action: string(permissionRequestToolGrant),
Resource: resource,
Decision: string(appliedDecision),
Raw: cloneRawJSON(raw),
})
p.emitPermissionEvent(sessionID, turnID, requestID, title, toolCallID, resource, appliedDecision, raw)
return acpsdk.RequestPermissionResponse{Outcome: outcome}, nil
case <-ctx.Done():
return acpsdk.RequestPermissionResponse{
Expand Down Expand Up @@ -347,7 +301,7 @@ func (p *AgentProcess) handleSessionUpdate(params json.RawMessage) error {
TurnID: merged.TurnID,
Timestamp: usage.Timestamp,
Usage: &merged,
Raw: cloneRawJSON(raw.Update),
Raw: CloneRawMessage(raw.Update),
})
}
return nil
Expand All @@ -363,6 +317,22 @@ func (p *AgentProcess) handleSessionUpdate(params json.RawMessage) error {
return nil
}

func (p *AgentProcess) emitPermissionEvent(sessionID string, turnID string, requestID string, title string, toolCallID string, resource string, decision permissionDecision, raw json.RawMessage) {
p.emitPromptEvent(AgentEvent{
Type: EventTypePermission,
SessionID: sessionID,
TurnID: turnID,
RequestID: requestID,
Timestamp: timeNowUTC(),
Title: title,
ToolCallID: toolCallID,
Action: string(permissionRequestToolGrant),
Resource: resource,
Decision: string(decision),
Raw: CloneRawMessage(raw),
})
}

func (p *AgentProcess) handleCreateTerminal(request acpsdk.CreateTerminalRequest) (acpsdk.CreateTerminalResponse, error) {
if err := p.permissions.authorize(permissionCreateTerminal); err != nil {
return acpsdk.CreateTerminalResponse{}, err
Expand Down Expand Up @@ -577,7 +547,7 @@ func translateSessionUpdate(notification acpsdk.SessionNotification, rawUpdate j
SessionID: string(notification.SessionId),
TurnID: turnID,
Timestamp: timeNowUTC(),
Raw: cloneRawJSON(rawUpdate),
Raw: CloneRawMessage(rawUpdate),
}

switch {
Expand Down Expand Up @@ -750,15 +720,6 @@ func mustMarshalJSON(value any) json.RawMessage {
return encoded
}

func cloneRawJSON(value json.RawMessage) json.RawMessage {
if len(value) == 0 {
return nil
}
cloned := make([]byte, len(value))
copy(cloned, value)
return cloned
}

func timeNowUTC() time.Time {
return time.Now().UTC()
}
Expand Down
60 changes: 58 additions & 2 deletions internal/acp/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,62 @@ func TestHandleInboundPermissionRequestTimeout(t *testing.T) {
}
}

func TestEmitPermissionEvent(t *testing.T) {
t.Parallel()

tests := []struct {
name string
decision permissionDecision
}{
{name: "ShouldHandleInteractivePending", decision: ""},
{name: "ShouldAllowOnceAutomatically", decision: decisionAllowOnce},
{name: "ShouldRejectOnceOnTimeout", decision: decisionRejectOnce},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

proc := newDirectProcess(t, aghconfig.PermissionModeDenyAll)
active, err := proc.beginPrompt("turn-permission-event", 4)
if err != nil {
t.Fatalf("beginPrompt() error = %v", err)
}
defer proc.endPrompt(active)

raw := mustMarshalJSON(map[string]any{"decision": string(tt.decision), "value": "original"})
wantRaw := append(json.RawMessage(nil), raw...)

proc.emitPermissionEvent("sess-emit", "turn-permission-event", "req-1", "permission request", "tool-1", "/tmp/demo.txt", tt.decision, raw)
event := collectEventsUntilCount(t, active.events, 1)[0]

raw[0] = '['

if event.Type != EventTypePermission {
t.Fatalf("event.Type = %q, want %q", event.Type, EventTypePermission)
}
if event.SessionID != "sess-emit" || event.TurnID != "turn-permission-event" || event.RequestID != "req-1" {
t.Fatalf("event ids = %#v, want session/turn/request populated", event)
}
if event.Title != "permission request" || event.ToolCallID != "tool-1" {
t.Fatalf("event title/tool = %#v, want copied fields", event)
}
if event.Action != string(permissionRequestToolGrant) || event.Resource != "/tmp/demo.txt" {
t.Fatalf("event action/resource = %#v, want permission action/resource", event)
}
if event.Decision != string(tt.decision) {
t.Fatalf("event.Decision = %q, want %q", event.Decision, tt.decision)
}
if event.Timestamp.IsZero() {
t.Fatal("event.Timestamp = zero, want populated")
}
if string(event.Raw) != string(wantRaw) {
t.Fatalf("event.Raw = %s, want %s", string(event.Raw), string(wantRaw))
}
})
}
}

func TestResolvePermissionByTurnIDConflictsWhenMultipleRequestsPending(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -542,8 +598,8 @@ func TestHelperUtilities(t *testing.T) {
}

raw := mustMarshalJSON(map[string]string{"hello": "world"})
if string(cloneRawJSON(raw)) != string(raw) {
t.Fatalf("cloneRawJSON() = %q, want %q", string(cloneRawJSON(raw)), string(raw))
if string(CloneRawMessage(raw)) != string(raw) {
t.Fatalf("CloneRawMessage() = %q, want %q", string(CloneRawMessage(raw)), string(raw))
}

if requestError(ErrPermissionDenied) == nil {
Expand Down
13 changes: 13 additions & 0 deletions internal/acp/rawjson.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package acp

import "encoding/json"

// CloneRawMessage returns an independent copy of one raw JSON payload.
func CloneRawMessage(value json.RawMessage) json.RawMessage {
if len(value) == 0 {
return nil
}
cloned := make([]byte, len(value))
copy(cloned, value)
return cloned
}
Loading