Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pkg/evaluation/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
_ "embed"
"errors"
"fmt"
"io/fs"
"os"
"os/exec"
"path/filepath"
Expand Down Expand Up @@ -91,7 +92,7 @@ func (r *Runner) buildEvalImage(ctx context.Context, evals *session.EvalCriteria
data.CopyWorkingDir = false
} else {
buildContext = filepath.Join(r.EvalsDir, "working_dirs", evals.WorkingDir)
if _, err := os.Stat(buildContext); os.IsNotExist(err) {
if _, err := os.Stat(buildContext); errors.Is(err, fs.ErrNotExist) {
return "", fmt.Errorf("working directory not found: %s", buildContext)
}
data.CopyWorkingDir = true
Expand Down
3 changes: 2 additions & 1 deletion pkg/rag/strategy/bm25.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"errors"
"fmt"
"io/fs"
"log/slog"
"math"
"os"
Expand Down Expand Up @@ -642,7 +643,7 @@ func (s *BM25Strategy) addPathToWatcher(ctx context.Context, path string) error

stat, err := os.Stat(absPath)
if err != nil {
if os.IsNotExist(err) {
if errors.Is(err, fs.ErrNotExist) {
return nil
}
return fmt.Errorf("failed to stat path: %w", err)
Expand Down
5 changes: 3 additions & 2 deletions pkg/rag/strategy/bm25_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"errors"
"fmt"
"io/fs"
"log/slog"
"os"
"path/filepath"
Expand Down Expand Up @@ -159,7 +160,7 @@ func (d *bm25DB) GetFileMetadata(ctx context.Context, sourcePath string) (*datab
fmt.Sprintf("SELECT source_path, file_hash, last_indexed, chunk_count FROM %s WHERE source_path = ?", d.metadataTable),
sourcePath).Scan(&metadata.SourcePath, &metadata.FileHash, &metadata.LastIndexed, &metadata.ChunkCount)

if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
if err != nil {
Expand Down Expand Up @@ -220,7 +221,7 @@ func ensureDir(filePath string) error {
return nil
}

if _, err := os.Stat(dir); os.IsNotExist(err) {
if _, err := os.Stat(dir); errors.Is(err, fs.ErrNotExist) {
return os.MkdirAll(dir, 0o700)
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/rag/strategy/chunked_embeddings_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func (d *chunkedVectorDB) GetFileMetadata(ctx context.Context, sourcePath string
GROUP BY f.source_path, f.file_hash, f.indexed_at`, d.filesTable, d.chunksTable),
sourcePath).Scan(&metadata.SourcePath, &metadata.FileHash, &metadata.LastIndexed, &metadata.ChunkCount)

if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/rag/strategy/semantic_embeddings_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func (d *semanticVectorDB) GetFileMetadata(ctx context.Context, sourcePath strin
GROUP BY f.source_path, f.file_hash, f.indexed_at`, d.filesTable, d.chunksTable),
sourcePath).Scan(&metadata.SourcePath, &metadata.FileHash, &metadata.LastIndexed, &metadata.ChunkCount)

if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
if err != nil {
Expand Down
114 changes: 88 additions & 26 deletions pkg/skills/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"net/http"
"os"
Expand All @@ -15,25 +17,32 @@ import (
"strings"
"time"

"github.com/docker/docker-agent/pkg/remote"
"github.com/docker/docker-agent/pkg/httpclient"
)

// remoteHTTPTimeout caps each HTTP request made to a remote skills source.
const remoteHTTPTimeout = 30 * time.Second

// httpGet performs a GET request using the standard remote transport so that
// Docker Desktop proxy/SSL settings are honoured. The returned response body
// must be closed by the caller.
// skillsHTTPClient is used for outbound calls to remote skill registries.
// The base URL is operator-supplied and the contents are fed to the model
// as instructions, so a hostile (or compromised) registry could otherwise
// be used to read internal endpoints (loopback, RFC1918, link-local incl.
// cloud metadata at 169.254.169.254) and exfiltrate them through prompt
// injection. The SSRF-safe client refuses such targets at dial time, after
// DNS resolution, defeating DNS rebinding.
//
// Tests in this package replace the var via TestMain (see main_test.go)
// because httptest.NewServer binds to 127.0.0.1.
var skillsHTTPClient = httpclient.NewSafeClient(remoteHTTPTimeout, false)

// httpGet performs a GET request using the SSRF-safe HTTP client. The
// returned response body must be closed by the caller.
func httpGet(ctx context.Context, url string) (*http.Response, error) {
client := &http.Client{
Timeout: remoteHTTPTimeout,
Transport: remote.NewTransport(ctx),
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {
return nil, fmt.Errorf("creating request for %s: %w", url, err)
}
return client.Do(req)
return skillsHTTPClient.Do(req)
}

type diskCache struct {
Expand Down Expand Up @@ -61,13 +70,20 @@ func (c *diskCache) cacheDir(baseURL, skillName string) string {
}

// Get returns the cached content for a file if it exists and is not expired.
// Treats missing-file errors as a cache miss (returns false). Other I/O
// errors (e.g. EACCES, corrupt JSON) are surfaced through a debug log so
// they don't masquerade as a benign refetch trigger but still don't break
// the caller — a refetch is the right fallback.
func (c *diskCache) Get(baseURL, skillName, filePath string) (string, bool) {
dir := c.cacheDir(baseURL, skillName)
contentPath := filepath.Join(dir, filePath)
metaPath := contentPath + ".meta"

meta, err := c.readMetadata(metaPath)
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
slog.Debug("Skill cache metadata unreadable, treating as miss", "path", metaPath, "error", err)
}
return "", false
}

Expand All @@ -77,14 +93,29 @@ func (c *diskCache) Get(baseURL, skillName, filePath string) (string, bool) {

data, err := os.ReadFile(contentPath)
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
slog.Debug("Skill cache content unreadable, treating as miss", "path", contentPath, "error", err)
}
return "", false
}

return string(data), true
}

// FetchAndStore downloads a file from the given URL and stores it in the cache.
// It respects Cache-Control headers to determine expiry.
// It respects Cache-Control headers to determine expiry: no-cache forces
// immediate expiry, max-age=N sets a TTL of N seconds, and unknown headers
// fall back to defaultCacheTTL.
//
// no-store is treated as "do not retain across Load() cycles": the entry is
// written to disk but with an immediate-expiry marker so the next prefetch
// refetches it. We do not skip the disk write entirely because callers
// (notably the read_skill tool, see pkg/tools/builtin/skills) consume the
// content by re-reading skill.FilePath, not by going through diskCache.Get.
// Skipping the write would render the skill unreadable for the rest of the
// current process. A future improvement is to keep no-store content in an
// in-memory map shared with the reader; until then we trade strict RFC
// 9111 §5.2.2.5 compliance for a working tool.
func (c *diskCache) FetchAndStore(ctx context.Context, baseURL, skillName, filePath, fileURL string) (string, error) {
slog.DebugContext(ctx, "Fetching remote skill file", "url", fileURL)

Expand All @@ -103,7 +134,7 @@ func (c *diskCache) FetchAndStore(ctx context.Context, baseURL, skillName, fileP
return "", fmt.Errorf("reading %s: %w", fileURL, err)
}

expiresAt := parseCacheExpiry(resp.Header.Get("Cache-Control"))
directive := parseCacheControl(resp.Header.Get("Cache-Control"))

dir := c.cacheDir(baseURL, skillName)
contentPath := filepath.Join(dir, filePath)
Expand All @@ -120,7 +151,7 @@ func (c *diskCache) FetchAndStore(ctx context.Context, baseURL, skillName, fileP
meta := cacheMetadata{
URL: fileURL,
CachedAt: time.Now(),
ExpiresAt: expiresAt,
ExpiresAt: directive.expiresAt(),
}
metaJSON, _ := json.Marshal(meta)
if err := os.WriteFile(metaPath, metaJSON, 0o600); err != nil {
Expand All @@ -145,28 +176,59 @@ func (c *diskCache) readMetadata(metaPath string) (cacheMetadata, error) {

const defaultCacheTTL = 1 * time.Hour

// parseCacheExpiry extracts the expiry time from a Cache-Control header value.
// Falls back to defaultCacheTTL if the header is missing or unparseable.
func parseCacheExpiry(cacheControl string) time.Time {
if cacheControl == "" {
return time.Now().Add(defaultCacheTTL)
// cacheDirective is the parsed subset of Cache-Control we care about.
type cacheDirective struct {
noStore bool
noCache bool
// hasMaxAge is true when the header explicitly carried max-age=N.
hasMaxAge bool
maxAge time.Duration
}

// expiresAt returns the absolute time after which the cached entry must
// not be reused. no-store and no-cache both force immediate expiry: the
// response may be on disk for the duration of the current process (so the
// in-process reader can consume it), but the next Load() cycle will
// refetch instead of reusing the stored copy. We currently approximate
// no-cache without conditional-GET support; the practical effect is the
// same as no-store with respect to whether the next read sees fresh
// content.
func (d cacheDirective) expiresAt() time.Time {
now := time.Now()
if d.noStore || d.noCache {
return now
}
if d.hasMaxAge {
return now.Add(d.maxAge)
}
return now.Add(defaultCacheTTL)
}

// parseCacheControl extracts the directives we honour from a Cache-Control
// header value. Unknown directives are ignored; an empty header yields the
// zero value, which falls back to defaultCacheTTL via expiresAt().
func parseCacheControl(header string) cacheDirective {
var d cacheDirective
if header == "" {
return d
}

for directive := range strings.SplitSeq(cacheControl, ",") {
for directive := range strings.SplitSeq(header, ",") {
directive = strings.TrimSpace(directive)

if strings.EqualFold(directive, "no-store") || strings.EqualFold(directive, "no-cache") {
// Still cache, but with zero TTL so it's refetched next time
return time.Now()
}

if strings.HasPrefix(strings.ToLower(directive), "max-age=") {
switch {
case strings.EqualFold(directive, "no-store"):
d.noStore = true
case strings.EqualFold(directive, "no-cache"):
d.noCache = true
case strings.HasPrefix(strings.ToLower(directive), "max-age="):
ageStr := directive[len("max-age="):]
if seconds, err := strconv.ParseInt(ageStr, 10, 64); err == nil && seconds >= 0 {
return time.Now().Add(time.Duration(seconds) * time.Second)
d.hasMaxAge = true
d.maxAge = time.Duration(seconds) * time.Second
}
}
}

return time.Now().Add(defaultCacheTTL)
return d
}
Loading
Loading