diff --git a/experimental/ssh/internal/fileutil/backup.go b/experimental/ssh/internal/fileutil/backup.go new file mode 100644 index 0000000000..c9e07503ef --- /dev/null +++ b/experimental/ssh/internal/fileutil/backup.go @@ -0,0 +1,39 @@ +package fileutil + +import ( + "context" + "os" + "path/filepath" + + "github.com/databricks/cli/libs/log" +) + +const ( + SuffixOriginalBak = ".original.bak" + SuffixLatestBak = ".latest.bak" +) + +// BackupFile saves data to path+".original.bak" on the first call, and +// path+".latest.bak" on subsequent calls. Skips if data is empty. +func BackupFile(ctx context.Context, path string, data []byte) error { + if len(data) == 0 { + return nil + } + originalBak := path + SuffixOriginalBak + latestBak := path + SuffixLatestBak + var bakPath string + _, statErr := os.Stat(originalBak) + if statErr != nil && !os.IsNotExist(statErr) { + return statErr + } + if os.IsNotExist(statErr) { + bakPath = originalBak + } else { + bakPath = latestBak + } + if err := os.WriteFile(bakPath, data, 0o600); err != nil { + return err + } + log.Infof(ctx, "Backed up %s to %s", filepath.ToSlash(path), filepath.ToSlash(bakPath)) + return nil +} diff --git a/experimental/ssh/internal/fileutil/backup_test.go b/experimental/ssh/internal/fileutil/backup_test.go new file mode 100644 index 0000000000..f57e82367a --- /dev/null +++ b/experimental/ssh/internal/fileutil/backup_test.go @@ -0,0 +1,85 @@ +package fileutil_test + +import ( + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/databricks/cli/experimental/ssh/internal/fileutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBackupFile_EmptyData(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "file.json") + + err := fileutil.BackupFile(t.Context(), path, []byte{}) + require.NoError(t, err) + + _, err = os.Stat(path + fileutil.SuffixOriginalBak) + assert.True(t, os.IsNotExist(err)) +} + +func TestBackupFile_FirstBackup(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "file.json") + data := []byte(`{"key": "value"}`) + + err := fileutil.BackupFile(t.Context(), path, data) + require.NoError(t, err) + + content, err := os.ReadFile(path + fileutil.SuffixOriginalBak) + require.NoError(t, err) + assert.Equal(t, data, content) + + _, err = os.Stat(path + fileutil.SuffixLatestBak) + assert.True(t, os.IsNotExist(err)) +} + +func TestBackupFile_SubsequentBackup(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "file.json") + original := []byte(`{"key": "value"}`) + updated := []byte(`{"key": "updated"}`) + + err := fileutil.BackupFile(t.Context(), path, original) + require.NoError(t, err) + + err = fileutil.BackupFile(t.Context(), path, updated) + require.NoError(t, err) + + // .original.bak must remain unchanged + content, err := os.ReadFile(path + fileutil.SuffixOriginalBak) + require.NoError(t, err) + assert.Equal(t, original, content) + + // .latest.bak should have the updated content + content, err = os.ReadFile(path + fileutil.SuffixLatestBak) + require.NoError(t, err) + assert.Equal(t, updated, content) +} + +func TestBackupFile_WriteError(t *testing.T) { + err := fileutil.BackupFile(t.Context(), "/nonexistent/dir/file.json", []byte("data")) + assert.Error(t, err) +} + +func TestBackupFile_StatError(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("chmod is not supported on windows") + } + + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "file.json") + + // Create the .original.bak file so os.Stat would find it — but make the + // parent directory unreadable so Stat returns a permission error instead. + require.NoError(t, os.WriteFile(path+fileutil.SuffixOriginalBak, []byte("existing"), 0o600)) + require.NoError(t, os.Chmod(tmpDir, 0o000)) + t.Cleanup(func() { _ = os.Chmod(tmpDir, 0o700) }) + + err := fileutil.BackupFile(t.Context(), path, []byte("data")) + assert.Error(t, err) +} diff --git a/experimental/ssh/internal/sshconfig/sshconfig.go b/experimental/ssh/internal/sshconfig/sshconfig.go index f6886a4be9..df7fbf1226 100644 --- a/experimental/ssh/internal/sshconfig/sshconfig.go +++ b/experimental/ssh/internal/sshconfig/sshconfig.go @@ -7,6 +7,7 @@ import ( "path/filepath" "strings" + "github.com/databricks/cli/experimental/ssh/internal/fileutil" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/env" ) @@ -80,11 +81,24 @@ func EnsureIncludeDirective(ctx context.Context, configPath string) error { // Convert path to forward slashes for SSH config compatibility across platforms configDirUnix := filepath.ToSlash(configDir) - includeLine := fmt.Sprintf("Include %s/*", configDirUnix) - if strings.Contains(string(content), includeLine) { + // Quoted to handle paths with spaces; OpenSSH still expands globs inside quotes. + includeLine := fmt.Sprintf(`Include "%s/*"`, configDirUnix) + if containsLine(content, includeLine) { return nil } + // Migrate unquoted Include written by older versions of the CLI. + oldIncludeLine := fmt.Sprintf("Include %s/*", configDirUnix) + if containsLine(content, oldIncludeLine) { + if err := fileutil.BackupFile(ctx, configPath, content); err != nil { + return fmt.Errorf("failed to backup SSH config before migration: %w", err) + } + return os.WriteFile(configPath, replaceLine(content, oldIncludeLine, includeLine), 0o600) + } + + if err := fileutil.BackupFile(ctx, configPath, content); err != nil { + return fmt.Errorf("failed to backup SSH config: %w", err) + } newContent := includeLine + "\n" if len(content) > 0 && !strings.HasPrefix(string(content), "\n") { newContent += "\n" @@ -99,6 +113,31 @@ func EnsureIncludeDirective(ctx context.Context, configPath string) error { return nil } +// containsLine reports whether data contains line as a line match, +// trimming leading whitespace and \r (Windows line endings) before comparing. +func containsLine(data []byte, line string) bool { + for l := range strings.SplitSeq(string(data), "\n") { + if strings.TrimLeft(strings.TrimRight(l, "\r"), " \t") == line { + return true + } + } + return false +} + +// replaceLine replaces the first line in data whose trimmed content matches old +// with new. Uses the same trim logic as containsLine. Returns data unchanged if +// no match. +func replaceLine(data []byte, old, new string) []byte { + lines := strings.Split(string(data), "\n") + for i, l := range lines { + if strings.TrimLeft(strings.TrimRight(l, "\r"), " \t") == old { + lines[i] = new + break + } + } + return []byte(strings.Join(lines, "\n")) +} + func GetHostConfigPath(ctx context.Context, hostName string) (string, error) { configDir, err := GetConfigDir(ctx) if err != nil { diff --git a/experimental/ssh/internal/sshconfig/sshconfig_test.go b/experimental/ssh/internal/sshconfig/sshconfig_test.go index b2abf22cf1..6c453910cd 100644 --- a/experimental/ssh/internal/sshconfig/sshconfig_test.go +++ b/experimental/ssh/internal/sshconfig/sshconfig_test.go @@ -80,9 +80,9 @@ func TestEnsureIncludeDirective_AlreadyExists(t *testing.T) { configDir, err := GetConfigDir(t.Context()) require.NoError(t, err) - // Use forward slashes as that's what SSH config uses + // Use forward slashes and quotes as that's what SSH config uses configDirUnix := filepath.ToSlash(configDir) - existingContent := "Include " + configDirUnix + "/*\n\nHost example\n User test\n" + existingContent := `Include "` + configDirUnix + `/*"` + "\n\nHost example\n User test\n" err = os.MkdirAll(filepath.Dir(configPath), 0o700) require.NoError(t, err) err = os.WriteFile(configPath, []byte(existingContent), 0o600) @@ -96,6 +96,59 @@ func TestEnsureIncludeDirective_AlreadyExists(t *testing.T) { assert.Equal(t, existingContent, string(content)) } +func TestEnsureIncludeDirective_MigratesOldUnquotedFormat(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configPath := filepath.Join(tmpDir, ".ssh", "config") + + configDir, err := GetConfigDir(t.Context()) + require.NoError(t, err) + + configDirUnix := filepath.ToSlash(configDir) + oldContent := "Include " + configDirUnix + "/*\n\nHost example\n User test\n" + err = os.MkdirAll(filepath.Dir(configPath), 0o700) + require.NoError(t, err) + err = os.WriteFile(configPath, []byte(oldContent), 0o600) + require.NoError(t, err) + + err = EnsureIncludeDirective(t.Context(), configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + configStr := string(content) + + assert.Contains(t, configStr, `Include "`+configDirUnix+`/*"`) + assert.NotContains(t, configStr, "Include "+configDirUnix+"/*\n") + assert.Contains(t, configStr, "Host example") +} + +func TestEnsureIncludeDirective_NotFooledBySubstring(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configPath := filepath.Join(tmpDir, ".ssh", "config") + + configDir, err := GetConfigDir(t.Context()) + require.NoError(t, err) + + configDirUnix := filepath.ToSlash(configDir) + // The include path appears only inside a comment, not as a standalone directive. + existingContent := `# Include "` + configDirUnix + `/*"` + "\nHost example\n User test\n" + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o700)) + require.NoError(t, os.WriteFile(configPath, []byte(existingContent), 0o600)) + + err = EnsureIncludeDirective(t.Context(), configPath) + require.NoError(t, err) + + content, err := os.ReadFile(configPath) + require.NoError(t, err) + assert.Contains(t, string(content), `Include "`+configDirUnix+`/*"`) +} + func TestEnsureIncludeDirective_PrependsToExisting(t *testing.T) { tmpDir := t.TempDir() configPath := filepath.Join(tmpDir, ".ssh", "config") @@ -127,6 +180,127 @@ func TestEnsureIncludeDirective_PrependsToExisting(t *testing.T) { assert.Less(t, includeIndex, hostIndex, "Include directive should come before existing content") } +func TestContainsLine(t *testing.T) { + tests := []struct { + name string + data string + line string + found bool + }{ + {"exact match", `Include "/path/*"` + "\nHost example\n", `Include "/path/*"`, true}, + {"not present", "Host example\n", `Include "/path/*"`, false}, + {"substring only", `# Include "/path/*"`, `Include "/path/*"`, false}, + {"commented line", `# Include "/path/*"` + "\n" + `Include "/path/*"` + "\n", `Include "/path/*"`, true}, + {"windows line ending", `Include "/path/*"` + "\r\nHost example\r\n", `Include "/path/*"`, true}, + {"empty data", "", `Include "/path/*"`, false}, + {"indented with spaces", " " + `Include "/path/*"` + "\nHost example\n", `Include "/path/*"`, true}, + {"indented with tab", "\t" + `Include "/path/*"` + "\nHost example\n", `Include "/path/*"`, true}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.found, containsLine([]byte(tc.data), tc.line)) + }) + } +} + +func TestReplaceLine(t *testing.T) { + tests := []struct { + name string + data string + old string + new string + expected string + }{ + { + "exact match", + `Include "/p/*"` + "\nHost x\n", + `Include "/p/*"`, `Include "/p/*" NEW`, + `Include "/p/*" NEW` + "\nHost x\n", + }, + { + "indented match", + " " + `Include "/p/*"` + "\nHost x\n", + `Include "/p/*"`, `Include "/p/*" NEW`, + `Include "/p/*" NEW` + "\nHost x\n", + }, + { + "no match", + "Host x\n", + `Include "/p/*"`, `Include "/p/*" NEW`, + "Host x\n", + }, + { + "substring in comment — must not be replaced", + `# Include "/p/*"` + "\nHost x\n", + `Include "/p/*"`, `Include "/p/*" NEW`, + `# Include "/p/*"` + "\nHost x\n", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := replaceLine([]byte(tc.data), tc.old, tc.new) + assert.Equal(t, tc.expected, string(got)) + }) + } +} + +func TestEnsureIncludeDirective_MigratesIndentedOldFormat(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configPath := filepath.Join(tmpDir, ".ssh", "config") + + configDir, err := GetConfigDir(t.Context()) + require.NoError(t, err) + + configDirUnix := filepath.ToSlash(configDir) + // Old format with leading whitespace — should still be detected and migrated. + oldContent := " Include " + configDirUnix + "/*\n\nHost example\n User test\n" + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o700)) + require.NoError(t, os.WriteFile(configPath, []byte(oldContent), 0o600)) + + err = EnsureIncludeDirective(t.Context(), configPath) + require.NoError(t, err) + + content, err := os.ReadFile(configPath) + require.NoError(t, err) + configStr := string(content) + + assert.Contains(t, configStr, `Include "`+configDirUnix+`/*"`) + assert.NotContains(t, configStr, " Include "+configDirUnix+"/*") + assert.Contains(t, configStr, "Host example") +} + +func TestEnsureIncludeDirective_NotFooledByOldFormatSubstring(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configPath := filepath.Join(tmpDir, ".ssh", "config") + + configDir, err := GetConfigDir(t.Context()) + require.NoError(t, err) + + configDirUnix := filepath.ToSlash(configDir) + // Old unquoted form appears only inside a comment — must not be migrated. + existingContent := "# Include " + configDirUnix + "/*\nHost example\n User test\n" + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o700)) + require.NoError(t, os.WriteFile(configPath, []byte(existingContent), 0o600)) + + err = EnsureIncludeDirective(t.Context(), configPath) + require.NoError(t, err) + + content, err := os.ReadFile(configPath) + require.NoError(t, err) + configStr := string(content) + + // New quoted directive should have been prepended (not a migration). + assert.Contains(t, configStr, `Include "`+configDirUnix+`/*"`) + // Comment line must be preserved unchanged. + assert.Contains(t, configStr, "# Include "+configDirUnix+"/*") +} + func TestGetHostConfigPath(t *testing.T) { path, err := GetHostConfigPath(t.Context(), "test-host") assert.NoError(t, err) diff --git a/experimental/ssh/internal/vscode/settings.go b/experimental/ssh/internal/vscode/settings.go index 19b40858a2..c8b13e3d56 100644 --- a/experimental/ssh/internal/vscode/settings.go +++ b/experimental/ssh/internal/vscode/settings.go @@ -9,6 +9,7 @@ import ( "runtime" "strings" + "github.com/databricks/cli/experimental/ssh/internal/fileutil" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/log" @@ -96,8 +97,10 @@ func CheckAndUpdateSettings(ctx context.Context, ide, connectionName string) err return nil } - if err := backupSettings(ctx, settingsPath); err != nil { - log.Warnf(ctx, "Failed to backup settings: %v. Continuing with update.", err) + if data, err := os.ReadFile(settingsPath); err == nil { + if err := fileutil.BackupFile(ctx, settingsPath, data); err != nil { + return fmt.Errorf("failed to backup settings: %w", err) + } } if err := updateSettings(&settings, connectionName, missing); err != nil { @@ -277,27 +280,6 @@ func handleMissingFile(ctx context.Context, ide, connectionName, settingsPath st return nil } -func backupSettings(ctx context.Context, path string) error { - data, err := os.ReadFile(path) - if err != nil { - return err - } - if len(data) == 0 { - return nil - } - - originalBak := path + ".original.bak" - latestBak := path + ".latest.bak" - - if _, err := os.Stat(originalBak); os.IsNotExist(err) { - log.Infof(ctx, "Backing up settings to %s", filepath.ToSlash(originalBak)) - return os.WriteFile(originalBak, data, 0o600) - } - - log.Infof(ctx, "Backing up settings to %s", filepath.ToSlash(latestBak)) - return os.WriteFile(latestBak, data, 0o600) -} - // subKeyOp returns a patch op that sets key/subKey=value, creating the parent object if absent. func subKeyOp(v *hujson.Value, key, subKey, value string) patchOp { if v.Find(jsonPtr(key)) == nil { diff --git a/experimental/ssh/internal/vscode/settings_test.go b/experimental/ssh/internal/vscode/settings_test.go index c76f93c956..a6fcf77988 100644 --- a/experimental/ssh/internal/vscode/settings_test.go +++ b/experimental/ssh/internal/vscode/settings_test.go @@ -2,11 +2,13 @@ package vscode import ( "encoding/json" + "io" "os" "path/filepath" "runtime" "testing" + "github.com/databricks/cli/experimental/ssh/internal/fileutil" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/env" "github.com/stretchr/testify/assert" @@ -428,45 +430,59 @@ func TestUpdateSettings_PartialUpdate(t *testing.T) { assert.Len(t, exts, 2) } -func TestBackupSettings(t *testing.T) { +func TestCheckAndUpdateSettings_CreatesBackup(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("path setup differs on windows") + } + tmpDir := t.TempDir() - settingsPath := filepath.Join(tmpDir, "settings.json") - originalBak := settingsPath + ".original.bak" - latestBak := settingsPath + ".latest.bak" + t.Setenv("HOME", tmpDir) + + ctx, tst := cmdio.SetupTest(t.Context(), cmdio.TestOptions{PromptSupported: true}) + defer tst.Done() - originalContent := []byte(`{"key": "value"}`) - err := os.WriteFile(settingsPath, originalContent, 0o600) + settingsPath, err := getDefaultSettingsPath(ctx, "cursor") require.NoError(t, err) + require.NoError(t, os.MkdirAll(filepath.Dir(settingsPath), 0o755)) - ctx, _ := cmdio.NewTestContextWithStderr(t.Context()) + // Settings file with no Databricks-required keys → triggers an update prompt. + originalContent := []byte(`{}`) + require.NoError(t, os.WriteFile(settingsPath, originalContent, 0o600)) - // First backup: should create .original.bak - err = backupSettings(ctx, settingsPath) + // Drain stderr (where the prompt is written) and feed "y" to stdin. + go func() { _, _ = io.Copy(io.Discard, tst.Stderr) }() + go func() { + _, _ = tst.Stdin.WriteString("y\n") + _ = tst.Stdin.Flush() + }() + + err = CheckAndUpdateSettings(ctx, "cursor", "my-host") require.NoError(t, err) - content, err := os.ReadFile(originalBak) + originalBakContent, err := os.ReadFile(settingsPath + fileutil.SuffixOriginalBak) require.NoError(t, err) - assert.Equal(t, originalContent, content) - _, err = os.Stat(latestBak) - assert.True(t, os.IsNotExist(err)) + assert.Equal(t, originalContent, originalBakContent) - // Second backup: .original.bak exists, should create .latest.bak - updatedContent := []byte(`{"key": "updated"}`) - err = os.WriteFile(settingsPath, updatedContent, 0o600) + // Second update for a new connection triggers another backup into .latest.bak. + postFirstContent, err := os.ReadFile(settingsPath) require.NoError(t, err) - err = backupSettings(ctx, settingsPath) + go func() { + _, _ = tst.Stdin.WriteString("y\n") + _ = tst.Stdin.Flush() + }() + + err = CheckAndUpdateSettings(ctx, "cursor", "my-host-2") require.NoError(t, err) - // .original.bak must remain unchanged - content, err = os.ReadFile(originalBak) + latestBakContent, err := os.ReadFile(settingsPath + fileutil.SuffixLatestBak) require.NoError(t, err) - assert.Equal(t, originalContent, content) + assert.Equal(t, postFirstContent, latestBakContent) - // .latest.bak should have the updated content - content, err = os.ReadFile(latestBak) + // .original.bak must still hold the very first snapshot. + originalBakContent2, err := os.ReadFile(settingsPath + fileutil.SuffixOriginalBak) require.NoError(t, err) - assert.Equal(t, updatedContent, content) + assert.Equal(t, originalContent, originalBakContent2) } func TestSaveSettings_Formatting(t *testing.T) {