Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions pkg/fileutil/fileutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,34 @@ func IsDirEmpty(path string) bool {
return len(files) == 0
}

type syncWriteCloser interface {
io.Writer
Sync() error
Close() error
}

func copyFileContents(in io.Reader, out syncWriteCloser, dst string) (err error) {
removePartial := false

defer func() {
if closeErr := out.Close(); closeErr != nil && err == nil {
err = closeErr
}
if removePartial {
if removeErr := os.Remove(dst); removeErr != nil {
log.Printf("Failed to remove partial destination file during cleanup: %s", removeErr)
}
}
}()

if _, err = io.Copy(out, in); err != nil {
removePartial = true
return err
}
Comment on lines +128 to +142

return out.Sync()
}

// CopyFile copies a file from src to dst using buffered IO.
func CopyFile(src, dst string) error {
log.Printf("Copying file: src=%s, dst=%s", src, dst)
Expand All @@ -131,17 +159,11 @@ func CopyFile(src, dst string) error {
log.Printf("Failed to create destination file: %s", err)
return err
}
defer func() { _ = out.Close() }()

if _, err = io.Copy(out, in); err != nil {
if closeErr := out.Close(); closeErr != nil {
log.Printf("Failed to close destination file during cleanup: %s", closeErr)
}
if removeErr := os.Remove(dst); removeErr != nil {
log.Printf("Failed to remove partial destination file during cleanup: %s", removeErr)
}
err = copyFileContents(in, out, dst)
if err != nil {
return err
}

log.Printf("File copied successfully: src=%s, dst=%s", src, dst)
return out.Sync()
return nil
}
60 changes: 60 additions & 0 deletions pkg/fileutil/fileutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package fileutil
import (
"archive/tar"
"bytes"
"errors"
"os"
"path/filepath"
"runtime"
Expand All @@ -15,6 +16,34 @@ import (
"github.com/stretchr/testify/require"
)

type stubSyncWriteCloser struct {
buf bytes.Buffer
writeErr error
syncErr error
closeErr error
closeCalls int
}

func (s *stubSyncWriteCloser) Write(p []byte) (int, error) {
if s.writeErr != nil {
return 0, s.writeErr
}
return s.buf.Write(p)
}

func (s *stubSyncWriteCloser) Sync() error {
return s.syncErr
}

func (s *stubSyncWriteCloser) Close() error {
s.closeCalls++
return s.closeErr
}

func (s *stubSyncWriteCloser) String() string {
return s.buf.String()
}

func TestValidateAbsolutePath(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -344,6 +373,37 @@ func TestCopyFile(t *testing.T) {
})
}

func TestCopyFileContents(t *testing.T) {
t.Run("returns close error after successful sync", func(t *testing.T) {
closeErr := errors.New("close failed")
out := &stubSyncWriteCloser{closeErr: closeErr}

err := copyFileContents(strings.NewReader("hello"), out, filepath.Join(t.TempDir(), "dst.txt"))

require.ErrorIs(t, err, closeErr)
assert.Equal(t, 1, out.closeCalls, "destination should be closed once")
assert.Equal(t, "hello", out.String(), "content should be copied before close")
})

t.Run("preserves copy error and closes destination once", func(t *testing.T) {
writeErr := errors.New("write failed")
closeErr := errors.New("close failed")
out := &stubSyncWriteCloser{
writeErr: writeErr,
closeErr: closeErr,
}

dst := filepath.Join(t.TempDir(), "dst.txt")
require.NoError(t, os.WriteFile(dst, []byte("partial"), 0600), "Should create destination placeholder")

err := copyFileContents(strings.NewReader("hello"), out, dst)

require.ErrorIs(t, err, writeErr)
assert.Equal(t, 1, out.closeCalls, "destination should be closed once during cleanup")
assert.NoFileExists(t, dst, "partial destination should be removed after copy failure")
})
}

func TestValidatePathWithinBase(t *testing.T) {
base := t.TempDir()

Expand Down