Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions pkg/state/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,35 @@ func NewLocalStore(appName string, storeName string) (*LocalStore, error) {
}, nil
}

// getFilePath returns the full file path for a configuration
func (s *LocalStore) getFilePath(name string) string {
// getFilePath returns the full file path for a configuration.
// It validates that the resolved path remains within basePath to prevent
// path traversal attacks via crafted names containing ".." or separators.
func (s *LocalStore) getFilePath(name string) (string, error) {
// Ensure the name has the correct extension
if !strings.HasSuffix(name, FileExtension) {
name = name + FileExtension
}
return filepath.Join(s.basePath, name)

resolved := filepath.Clean(filepath.Join(s.basePath, name))

// Verify the resolved path is contained within basePath.
// The trailing separator prevents prefix collisions (e.g. basePath
// "/state/toolhive" matching "/state/toolhive-evil/foo").
if !strings.HasPrefix(resolved, s.basePath+string(os.PathSeparator)) {
return "", fmt.Errorf("invalid state name %q: path traversal detected", name)
}

return resolved, nil
}

// GetReader returns a reader for the state data
func (s *LocalStore) GetReader(_ context.Context, name string) (io.ReadCloser, error) {
// Open the file
filePath := s.getFilePath(name)
// #nosec G304 - filePath is controlled by getFilePath which ensures it's within our designated directory
filePath, err := s.getFilePath(name)
if err != nil {
return nil, err
}
// #nosec G304 - filePath is validated by getFilePath to stay within our designated directory
file, err := os.Open(filePath)
if err != nil {
if os.IsNotExist(err) {
Expand All @@ -80,8 +95,11 @@ func (s *LocalStore) GetReader(_ context.Context, name string) (io.ReadCloser, e
// GetWriter returns a writer for the state data
func (s *LocalStore) GetWriter(_ context.Context, name string) (io.WriteCloser, error) {
// Create the file
filePath := s.getFilePath(name)
// #nosec G304 - filePath is controlled by getFilePath which ensures it's within our designated directory
filePath, err := s.getFilePath(name)
if err != nil {
return nil, err
}
// #nosec G304 - filePath is validated by getFilePath to stay within our designated directory
file, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return nil, fmt.Errorf("failed to create file: %w", err)
Expand All @@ -93,9 +111,12 @@ func (s *LocalStore) GetWriter(_ context.Context, name string) (io.WriteCloser,
// CreateExclusive creates a new state entry exclusively, failing if it already exists.
// This provides atomic check-and-create semantics using O_EXCL to prevent race conditions.
func (s *LocalStore) CreateExclusive(_ context.Context, name string) (io.WriteCloser, error) {
filePath := s.getFilePath(name)
filePath, err := s.getFilePath(name)
if err != nil {
return nil, err
}
// O_EXCL with O_CREATE provides atomic check-and-create behavior
// #nosec G304 - filePath is controlled by getFilePath which ensures it's within our designated directory
// #nosec G304 - filePath is validated by getFilePath to stay within our designated directory
file, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600)
if err != nil {
if os.IsExist(err) {
Expand All @@ -112,8 +133,11 @@ func (s *LocalStore) CreateExclusive(_ context.Context, name string) (io.WriteCl

// Delete removes the data for the given name
func (s *LocalStore) Delete(_ context.Context, name string) error {
filePath := s.getFilePath(name)
// #nosec G304 - filePath is controlled by getFilePath which ensures it's within our designated directory
filePath, err := s.getFilePath(name)
if err != nil {
return err
}
// #nosec G304 - filePath is validated by getFilePath to stay within our designated directory
if err := os.Remove(filePath); err != nil {
if os.IsNotExist(err) {
return fmt.Errorf("state '%s' not found", name)
Expand Down Expand Up @@ -154,8 +178,11 @@ func (s *LocalStore) List(_ context.Context) ([]string, error) {

// Exists checks if data exists for the given name
func (s *LocalStore) Exists(_ context.Context, name string) (bool, error) {
filePath := s.getFilePath(name)
_, err := os.Stat(filePath)
filePath, err := s.getFilePath(name)
if err != nil {
return false, err
}
_, err = os.Stat(filePath)
if err != nil {
if os.IsNotExist(err) {
return false, nil
Expand Down
205 changes: 205 additions & 0 deletions pkg/state/local_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package state

import (
"context"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// newTestStore creates a LocalStore backed by a resolved temp directory.
func newTestStore(t *testing.T) *LocalStore {
t.Helper()
dir := t.TempDir()
resolved, err := filepath.EvalSymlinks(dir)
require.NoError(t, err)
return &LocalStore{basePath: resolved}
}

func TestGetFilePath(t *testing.T) {
t.Parallel()

store := newTestStore(t)

tests := []struct {
name string
input string
expectError bool
}{
// Valid names
{name: "simple name", input: "my-workload", expectError: false},
{name: "with dots", input: "workload.v2", expectError: false},
{name: "with underscores", input: "my_workload", expectError: false},
{name: "alphanumeric", input: "abc123", expectError: false},
{name: "already has extension", input: "config.json", expectError: false},

// Path traversal attacks
{name: "parent directory", input: "..", expectError: true},
{name: "relative escape", input: "../secret", expectError: true},
{name: "nested escape", input: "../../etc/passwd", expectError: true},
{name: "mid-path traversal", input: "foo/../../../etc/shadow", expectError: true},
{name: "absolute unix", input: "/etc/passwd", expectError: true},

// Path separators
{name: "forward slash subdirectory", input: "sub/file", expectError: true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

result, err := store.getFilePath(tt.input)

if tt.expectError {
assert.Error(t, err)
assert.Empty(t, result)
assert.Contains(t, err.Error(), "path traversal detected")
} else {
require.NoError(t, err)
assert.True(t, filepath.IsAbs(result), "result should be an absolute path")
// Verify the result is inside basePath
dir := filepath.Dir(result)
assert.Equal(t, store.basePath, dir, "file should be inside basePath")
}
})
}
}

func TestGetFilePathSecurityCases(t *testing.T) {
t.Parallel()

store := newTestStore(t)

// Real-world attack patterns that must always be rejected.
attacks := []string{
"../../../etc/passwd",
"./../../../etc/shadow",
"../../../../../../etc/passwd",
"..\\..\\Windows\\System32",
"foo/../../bar",
}

for _, pattern := range attacks {
t.Run("reject_"+pattern, func(t *testing.T) {
t.Parallel()

result, err := store.getFilePath(pattern)
assert.Error(t, err, "should reject attack pattern: %q", pattern)
assert.Empty(t, result)
assert.Contains(t, err.Error(), "path traversal detected")
})
}
}

func TestLocalStoreOperationsRejectTraversal(t *testing.T) {
t.Parallel()

store := newTestStore(t)
ctx := context.Background()
malicious := "../../../etc/passwd"

t.Run("GetReader", func(t *testing.T) {
t.Parallel()
reader, err := store.GetReader(ctx, malicious)
assert.Error(t, err)
assert.Nil(t, reader)
assert.Contains(t, err.Error(), "path traversal detected")
})

t.Run("GetWriter", func(t *testing.T) {
t.Parallel()
writer, err := store.GetWriter(ctx, malicious)
assert.Error(t, err)
assert.Nil(t, writer)
assert.Contains(t, err.Error(), "path traversal detected")
})

t.Run("CreateExclusive", func(t *testing.T) {
t.Parallel()
writer, err := store.CreateExclusive(ctx, malicious)
assert.Error(t, err)
assert.Nil(t, writer)
assert.Contains(t, err.Error(), "path traversal detected")
})

t.Run("Delete", func(t *testing.T) {
t.Parallel()
err := store.Delete(ctx, malicious)
assert.Error(t, err)
assert.Contains(t, err.Error(), "path traversal detected")
})

t.Run("Exists", func(t *testing.T) {
t.Parallel()
exists, err := store.Exists(ctx, malicious)
assert.Error(t, err)
assert.False(t, exists)
assert.Contains(t, err.Error(), "path traversal detected")
})
}

func TestLocalStoreRoundTrip(t *testing.T) {
t.Parallel()

store := newTestStore(t)
ctx := context.Background()

// Write data
writer, err := store.GetWriter(ctx, "test-roundtrip")
require.NoError(t, err)
_, err = writer.Write([]byte(`{"key":"value"}`))
require.NoError(t, err)
require.NoError(t, writer.Close())

// Verify it exists
exists, err := store.Exists(ctx, "test-roundtrip")
require.NoError(t, err)
assert.True(t, exists)

// Read it back
reader, err := store.GetReader(ctx, "test-roundtrip")
require.NoError(t, err)
buf := make([]byte, 256)
n, err := reader.Read(buf)
require.NoError(t, err)
assert.Equal(t, `{"key":"value"}`, string(buf[:n]))
require.NoError(t, reader.Close())

// Verify it appears in list
names, err := store.List(ctx)
require.NoError(t, err)
assert.Contains(t, names, "test-roundtrip")

// Delete and verify
require.NoError(t, store.Delete(ctx, "test-roundtrip"))
exists, err = store.Exists(ctx, "test-roundtrip")
require.NoError(t, err)
assert.False(t, exists)
}

func TestLocalStoreCreateExclusiveConflict(t *testing.T) {
t.Parallel()

store := newTestStore(t)
ctx := context.Background()

// First create succeeds
writer, err := store.CreateExclusive(ctx, "exclusive-test")
require.NoError(t, err)
require.NoError(t, writer.Close())

// Second create fails with conflict
writer, err = store.CreateExclusive(ctx, "exclusive-test")
assert.Error(t, err)
assert.Nil(t, writer)
assert.Contains(t, err.Error(), "already exists")

// Cleanup
require.NoError(t, os.Remove(filepath.Join(store.basePath, "exclusive-test.json")))
}