Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions experimental/ssh/internal/fileutil/backup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package fileutil

import (
"context"
"os"
"path/filepath"

"github.com/databricks/cli/libs/log"
)

const (
SuffixOriginalBak = ".original.bak"
SuffixLatestBak = ".latest.bak"
)

// BackupFile saves data to path+".original.bak" on the first call, and
// path+".latest.bak" on subsequent calls. Skips if data is empty.
func BackupFile(ctx context.Context, path string, data []byte) error {
if len(data) == 0 {
return nil
}
originalBak := path + SuffixOriginalBak
latestBak := path + SuffixLatestBak
var bakPath string
if _, err := os.Stat(originalBak); os.IsNotExist(err) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If os.Stat returns an error that is not IsNotExist (e.g. permission denied on the directory), the code silently falls through to latestBak. This could mask real filesystem errors. Consider checking err != nil && !os.IsNotExist(err) and returning the error in that case.

bakPath = originalBak
} else {
bakPath = latestBak
}
if err := os.WriteFile(bakPath, data, 0o600); err != nil {
return err
}
log.Infof(ctx, "Backed up %s to %s", filepath.ToSlash(path), filepath.ToSlash(bakPath))
return nil
}
66 changes: 66 additions & 0 deletions experimental/ssh/internal/fileutil/backup_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package fileutil_test

import (
"os"
"path/filepath"
"testing"

"github.com/databricks/cli/experimental/ssh/internal/fileutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestBackupFile_EmptyData(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "file.json")

err := fileutil.BackupFile(t.Context(), path, []byte{})
require.NoError(t, err)

_, err = os.Stat(path + fileutil.SuffixOriginalBak)
assert.True(t, os.IsNotExist(err))
}

func TestBackupFile_FirstBackup(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "file.json")
data := []byte(`{"key": "value"}`)

err := fileutil.BackupFile(t.Context(), path, data)
require.NoError(t, err)

content, err := os.ReadFile(path + fileutil.SuffixOriginalBak)
require.NoError(t, err)
assert.Equal(t, data, content)

_, err = os.Stat(path + fileutil.SuffixLatestBak)
assert.True(t, os.IsNotExist(err))
}

func TestBackupFile_SubsequentBackup(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "file.json")
original := []byte(`{"key": "value"}`)
updated := []byte(`{"key": "updated"}`)

err := fileutil.BackupFile(t.Context(), path, original)
require.NoError(t, err)

err = fileutil.BackupFile(t.Context(), path, updated)
require.NoError(t, err)

// .original.bak must remain unchanged
content, err := os.ReadFile(path + fileutil.SuffixOriginalBak)
require.NoError(t, err)
assert.Equal(t, original, content)

// .latest.bak should have the updated content
content, err = os.ReadFile(path + fileutil.SuffixLatestBak)
require.NoError(t, err)
assert.Equal(t, updated, content)
}

func TestBackupFile_WriteError(t *testing.T) {
err := fileutil.BackupFile(t.Context(), "/nonexistent/dir/file.json", []byte("data"))
assert.Error(t, err)
}
30 changes: 28 additions & 2 deletions experimental/ssh/internal/sshconfig/sshconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"path/filepath"
"strings"

"github.com/databricks/cli/experimental/ssh/internal/fileutil"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/env"
)
Expand Down Expand Up @@ -80,11 +81,25 @@ func EnsureIncludeDirective(ctx context.Context, configPath string) error {
// Convert path to forward slashes for SSH config compatibility across platforms
configDirUnix := filepath.ToSlash(configDir)

includeLine := fmt.Sprintf("Include %s/*", configDirUnix)
if strings.Contains(string(content), includeLine) {
// Quoted to handle paths with spaces; OpenSSH still expands globs inside quotes.
includeLine := fmt.Sprintf(`Include "%s/*"`, configDirUnix)
if containsLine(content, includeLine) {
return nil
}

// Migrate unquoted Include written by older versions of the CLI.
oldIncludeLine := fmt.Sprintf("Include %s/*", configDirUnix)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a subtle mismatch: containsLine (line-aware) is used to detect the old directive, but strings.Replace on line 95 (substring-based) is used to perform the replacement. These have different matching semantics — containsLine could find a standalone line match while strings.Replace could substitute a different occurrence if the same text appears as a substring elsewhere in the file. Using the same line-aware approach for both detection and replacement would be safer.

if containsLine(content, oldIncludeLine) {
if err := fileutil.BackupFile(ctx, configPath, content); err != nil {
return fmt.Errorf("failed to backup SSH config before migration: %w", err)
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strings.Replace operates on raw substrings, not lines. Even though containsLine above does a line-aware check, strings.Replace could match the old include path as a substring of something else (e.g. inside a comment like # Include /path/*). Consider doing a line-aware replacement to stay consistent with the check — or at minimum replace the full line (with newline) rather than just the directive text.

migrated := strings.Replace(string(content), oldIncludeLine, includeLine, 1)
return os.WriteFile(configPath, []byte(migrated), 0o600)
}

if err := fileutil.BackupFile(ctx, configPath, content); err != nil {
return fmt.Errorf("failed to backup SSH config: %w", err)
}
newContent := includeLine + "\n"
if len(content) > 0 && !strings.HasPrefix(string(content), "\n") {
newContent += "\n"
Expand All @@ -99,6 +114,17 @@ func EnsureIncludeDirective(ctx context.Context, configPath string) error {
return nil
}

// containsLine reports whether data contains line as an exact line match,
// trimming \r to handle Windows line endings.
func containsLine(data []byte, line string) bool {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does an exact match with no left-side whitespace trimming. SSH config commonly has indented directives — an Include with a leading space or tab would be missed by this check, causing a duplicate Include to be prepended. Consider strings.TrimSpace or at least strings.TrimLeft(l, " \t") before comparing.

for l := range strings.SplitSeq(string(data), "\n") {
if strings.TrimRight(l, "\r") == line {
return true
}
}
return false
}

func GetHostConfigPath(ctx context.Context, hostName string) (string, error) {
configDir, err := GetConfigDir(ctx)
if err != nil {
Expand Down
78 changes: 76 additions & 2 deletions experimental/ssh/internal/sshconfig/sshconfig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ func TestEnsureIncludeDirective_AlreadyExists(t *testing.T) {
configDir, err := GetConfigDir(t.Context())
require.NoError(t, err)

// Use forward slashes as that's what SSH config uses
// Use forward slashes and quotes as that's what SSH config uses
configDirUnix := filepath.ToSlash(configDir)
existingContent := "Include " + configDirUnix + "/*\n\nHost example\n User test\n"
existingContent := `Include "` + configDirUnix + `/*"` + "\n\nHost example\n User test\n"
err = os.MkdirAll(filepath.Dir(configPath), 0o700)
require.NoError(t, err)
err = os.WriteFile(configPath, []byte(existingContent), 0o600)
Expand All @@ -96,6 +96,59 @@ func TestEnsureIncludeDirective_AlreadyExists(t *testing.T) {
assert.Equal(t, existingContent, string(content))
}

func TestEnsureIncludeDirective_MigratesOldUnquotedFormat(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("HOME", tmpDir)
t.Setenv("USERPROFILE", tmpDir)

configPath := filepath.Join(tmpDir, ".ssh", "config")

configDir, err := GetConfigDir(t.Context())
require.NoError(t, err)

configDirUnix := filepath.ToSlash(configDir)
oldContent := "Include " + configDirUnix + "/*\n\nHost example\n User test\n"
err = os.MkdirAll(filepath.Dir(configPath), 0o700)
require.NoError(t, err)
err = os.WriteFile(configPath, []byte(oldContent), 0o600)
require.NoError(t, err)

err = EnsureIncludeDirective(t.Context(), configPath)
assert.NoError(t, err)

content, err := os.ReadFile(configPath)
assert.NoError(t, err)
configStr := string(content)

assert.Contains(t, configStr, `Include "`+configDirUnix+`/*"`)
assert.NotContains(t, configStr, "Include "+configDirUnix+"/*\n")
assert.Contains(t, configStr, "Host example")
}

func TestEnsureIncludeDirective_NotFooledBySubstring(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("HOME", tmpDir)
t.Setenv("USERPROFILE", tmpDir)

configPath := filepath.Join(tmpDir, ".ssh", "config")

configDir, err := GetConfigDir(t.Context())
require.NoError(t, err)

configDirUnix := filepath.ToSlash(configDir)
// The include path appears only inside a comment, not as a standalone directive.
existingContent := `# Include "` + configDirUnix + `/*"` + "\nHost example\n User test\n"
require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o700))
require.NoError(t, os.WriteFile(configPath, []byte(existingContent), 0o600))

err = EnsureIncludeDirective(t.Context(), configPath)
require.NoError(t, err)

content, err := os.ReadFile(configPath)
require.NoError(t, err)
assert.Contains(t, string(content), `Include "`+configDirUnix+`/*"`)
}

func TestEnsureIncludeDirective_PrependsToExisting(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, ".ssh", "config")
Expand Down Expand Up @@ -127,6 +180,27 @@ func TestEnsureIncludeDirective_PrependsToExisting(t *testing.T) {
assert.Less(t, includeIndex, hostIndex, "Include directive should come before existing content")
}

func TestContainsLine(t *testing.T) {
tests := []struct {
name string
data string
line string
found bool
}{
{"exact match", `Include "/path/*"` + "\nHost example\n", `Include "/path/*"`, true},
{"not present", "Host example\n", `Include "/path/*"`, false},
{"substring only", `# Include "/path/*"`, `Include "/path/*"`, false},
{"commented line", `# Include "/path/*"` + "\n" + `Include "/path/*"` + "\n", `Include "/path/*"`, true},
{"windows line ending", `Include "/path/*"` + "\r\nHost example\r\n", `Include "/path/*"`, true},
{"empty data", "", `Include "/path/*"`, false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.found, containsLine([]byte(tc.data), tc.line))
})
}
}

func TestGetHostConfigPath(t *testing.T) {
path, err := GetHostConfigPath(t.Context(), "test-host")
assert.NoError(t, err)
Expand Down
28 changes: 5 additions & 23 deletions experimental/ssh/internal/vscode/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"runtime"
"strings"

"github.com/databricks/cli/experimental/ssh/internal/fileutil"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/env"
"github.com/databricks/cli/libs/log"
Expand Down Expand Up @@ -96,8 +97,10 @@ func CheckAndUpdateSettings(ctx context.Context, ide, connectionName string) err
return nil
}

if err := backupSettings(ctx, settingsPath); err != nil {
log.Warnf(ctx, "Failed to backup settings: %v. Continuing with update.", err)
if data, err := os.ReadFile(settingsPath); err == nil {
if err := fileutil.BackupFile(ctx, settingsPath, data); err != nil {
return fmt.Errorf("failed to backup settings: %w", err)
}
}

if err := updateSettings(&settings, connectionName, missing); err != nil {
Expand Down Expand Up @@ -277,27 +280,6 @@ func handleMissingFile(ctx context.Context, ide, connectionName, settingsPath st
return nil
}

func backupSettings(ctx context.Context, path string) error {
data, err := os.ReadFile(path)
if err != nil {
return err
}
if len(data) == 0 {
return nil
}

originalBak := path + ".original.bak"
latestBak := path + ".latest.bak"

if _, err := os.Stat(originalBak); os.IsNotExist(err) {
log.Infof(ctx, "Backing up settings to %s", filepath.ToSlash(originalBak))
return os.WriteFile(originalBak, data, 0o600)
}

log.Infof(ctx, "Backing up settings to %s", filepath.ToSlash(latestBak))
return os.WriteFile(latestBak, data, 0o600)
}

// subKeyOp returns a patch op that sets key/subKey=value, creating the parent object if absent.
func subKeyOp(v *hujson.Value, key, subKey, value string) patchOp {
if v.Find(jsonPtr(key)) == nil {
Expand Down
51 changes: 23 additions & 28 deletions experimental/ssh/internal/vscode/settings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package vscode

import (
"encoding/json"
"io"
"os"
"path/filepath"
"runtime"
"testing"

"github.com/databricks/cli/experimental/ssh/internal/fileutil"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/env"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -428,45 +430,38 @@ func TestUpdateSettings_PartialUpdate(t *testing.T) {
assert.Len(t, exts, 2)
}

func TestBackupSettings(t *testing.T) {
tmpDir := t.TempDir()
settingsPath := filepath.Join(tmpDir, "settings.json")
originalBak := settingsPath + ".original.bak"
latestBak := settingsPath + ".latest.bak"
func TestCheckAndUpdateSettings_CreatesBackup(t *testing.T) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test only verifies that .original.bak is created, but does not cover the .latest.bak path (second backup round). The old TestBackupSettings tested both rounds. Since the backup logic moved to fileutil.BackupFile (which has its own unit tests), this is not critical — but having an integration-level test that exercises the full two-backup cycle through CheckAndUpdateSettings would provide more confidence that the wiring is correct.

if runtime.GOOS == "windows" {
t.Skip("path setup differs on windows")
}

originalContent := []byte(`{"key": "value"}`)
err := os.WriteFile(settingsPath, originalContent, 0o600)
require.NoError(t, err)
tmpDir := t.TempDir()
t.Setenv("HOME", tmpDir)

ctx, _ := cmdio.NewTestContextWithStderr(t.Context())
ctx, tst := cmdio.SetupTest(t.Context(), cmdio.TestOptions{PromptSupported: true})
defer tst.Done()

// First backup: should create .original.bak
err = backupSettings(ctx, settingsPath)
settingsPath, err := getDefaultSettingsPath(ctx, "cursor")
require.NoError(t, err)
require.NoError(t, os.MkdirAll(filepath.Dir(settingsPath), 0o755))

content, err := os.ReadFile(originalBak)
require.NoError(t, err)
assert.Equal(t, originalContent, content)
_, err = os.Stat(latestBak)
assert.True(t, os.IsNotExist(err))
// Settings file with no Databricks-required keys → triggers an update prompt.
originalContent := []byte(`{}`)
require.NoError(t, os.WriteFile(settingsPath, originalContent, 0o600))

// Second backup: .original.bak exists, should create .latest.bak
updatedContent := []byte(`{"key": "updated"}`)
err = os.WriteFile(settingsPath, updatedContent, 0o600)
require.NoError(t, err)
// Drain stderr (where the prompt is written) and feed "y" to stdin.
go func() { _, _ = io.Copy(io.Discard, tst.Stderr) }()
go func() {
_, _ = tst.Stdin.WriteString("y\n")
_ = tst.Stdin.Flush()
}()

err = backupSettings(ctx, settingsPath)
err = CheckAndUpdateSettings(ctx, "cursor", "my-host")
require.NoError(t, err)

// .original.bak must remain unchanged
content, err = os.ReadFile(originalBak)
content, err := os.ReadFile(settingsPath + fileutil.SuffixOriginalBak)
require.NoError(t, err)
assert.Equal(t, originalContent, content)

// .latest.bak should have the updated content
content, err = os.ReadFile(latestBak)
require.NoError(t, err)
assert.Equal(t, updatedContent, content)
}

func TestSaveSettings_Formatting(t *testing.T) {
Expand Down
Loading