From 27477f2d87f862554374b65c26787070fe936ca8 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Fri, 27 Mar 2026 16:48:47 +0100 Subject: [PATCH 1/2] experimental/ssh: refactor backup logic and fix Include directive detection - Extract BackupFile into fileutil package with exported suffix constants (SuffixOriginalBak, SuffixLatestBak) so callers don't hardcode strings - Replace strings.Contains with line-aware containsLine helper in sshconfig to avoid false positives from commented-out or mid-line occurrences of the Include path - Migrate unquoted Include directives written by older CLI versions to the quoted form (handles paths with spaces) Co-authored-by: Isaac --- experimental/ssh/internal/fileutil/backup.go | 35 +++++++++ .../ssh/internal/fileutil/backup_test.go | 66 ++++++++++++++++ .../ssh/internal/sshconfig/sshconfig.go | 30 ++++++- .../ssh/internal/sshconfig/sshconfig_test.go | 78 ++++++++++++++++++- experimental/ssh/internal/vscode/settings.go | 28 ++----- .../ssh/internal/vscode/settings_test.go | 51 ++++++------ 6 files changed, 233 insertions(+), 55 deletions(-) create mode 100644 experimental/ssh/internal/fileutil/backup.go create mode 100644 experimental/ssh/internal/fileutil/backup_test.go diff --git a/experimental/ssh/internal/fileutil/backup.go b/experimental/ssh/internal/fileutil/backup.go new file mode 100644 index 0000000000..4b931f4566 --- /dev/null +++ b/experimental/ssh/internal/fileutil/backup.go @@ -0,0 +1,35 @@ +package fileutil + +import ( + "context" + "os" + "path/filepath" + + "github.com/databricks/cli/libs/log" +) + +const ( + SuffixOriginalBak = ".original.bak" + SuffixLatestBak = ".latest.bak" +) + +// BackupFile saves data to path+".original.bak" on the first call, and +// path+".latest.bak" on subsequent calls. Skips if data is empty. +func BackupFile(ctx context.Context, path string, data []byte) error { + if len(data) == 0 { + return nil + } + originalBak := path + SuffixOriginalBak + latestBak := path + SuffixLatestBak + var bakPath string + if _, err := os.Stat(originalBak); os.IsNotExist(err) { + bakPath = originalBak + } else { + bakPath = latestBak + } + if err := os.WriteFile(bakPath, data, 0o600); err != nil { + return err + } + log.Infof(ctx, "Backed up %s to %s", filepath.ToSlash(path), filepath.ToSlash(bakPath)) + return nil +} diff --git a/experimental/ssh/internal/fileutil/backup_test.go b/experimental/ssh/internal/fileutil/backup_test.go new file mode 100644 index 0000000000..54c6adcf17 --- /dev/null +++ b/experimental/ssh/internal/fileutil/backup_test.go @@ -0,0 +1,66 @@ +package fileutil_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/databricks/cli/experimental/ssh/internal/fileutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBackupFile_EmptyData(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "file.json") + + err := fileutil.BackupFile(t.Context(), path, []byte{}) + require.NoError(t, err) + + _, err = os.Stat(path + fileutil.SuffixOriginalBak) + assert.True(t, os.IsNotExist(err)) +} + +func TestBackupFile_FirstBackup(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "file.json") + data := []byte(`{"key": "value"}`) + + err := fileutil.BackupFile(t.Context(), path, data) + require.NoError(t, err) + + content, err := os.ReadFile(path + fileutil.SuffixOriginalBak) + require.NoError(t, err) + assert.Equal(t, data, content) + + _, err = os.Stat(path + fileutil.SuffixLatestBak) + assert.True(t, os.IsNotExist(err)) +} + +func TestBackupFile_SubsequentBackup(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "file.json") + original := []byte(`{"key": "value"}`) + updated := []byte(`{"key": "updated"}`) + + err := fileutil.BackupFile(t.Context(), path, original) + require.NoError(t, err) + + err = fileutil.BackupFile(t.Context(), path, updated) + require.NoError(t, err) + + // .original.bak must remain unchanged + content, err := os.ReadFile(path + fileutil.SuffixOriginalBak) + require.NoError(t, err) + assert.Equal(t, original, content) + + // .latest.bak should have the updated content + content, err = os.ReadFile(path + fileutil.SuffixLatestBak) + require.NoError(t, err) + assert.Equal(t, updated, content) +} + +func TestBackupFile_WriteError(t *testing.T) { + err := fileutil.BackupFile(t.Context(), "/nonexistent/dir/file.json", []byte("data")) + assert.Error(t, err) +} diff --git a/experimental/ssh/internal/sshconfig/sshconfig.go b/experimental/ssh/internal/sshconfig/sshconfig.go index f6886a4be9..6454a62745 100644 --- a/experimental/ssh/internal/sshconfig/sshconfig.go +++ b/experimental/ssh/internal/sshconfig/sshconfig.go @@ -7,6 +7,7 @@ import ( "path/filepath" "strings" + "github.com/databricks/cli/experimental/ssh/internal/fileutil" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/env" ) @@ -80,11 +81,25 @@ func EnsureIncludeDirective(ctx context.Context, configPath string) error { // Convert path to forward slashes for SSH config compatibility across platforms configDirUnix := filepath.ToSlash(configDir) - includeLine := fmt.Sprintf("Include %s/*", configDirUnix) - if strings.Contains(string(content), includeLine) { + // Quoted to handle paths with spaces; OpenSSH still expands globs inside quotes. + includeLine := fmt.Sprintf(`Include "%s/*"`, configDirUnix) + if containsLine(content, includeLine) { return nil } + // Migrate unquoted Include written by older versions of the CLI. + oldIncludeLine := fmt.Sprintf("Include %s/*", configDirUnix) + if containsLine(content, oldIncludeLine) { + if err := fileutil.BackupFile(ctx, configPath, content); err != nil { + return fmt.Errorf("failed to backup SSH config before migration: %w", err) + } + migrated := strings.Replace(string(content), oldIncludeLine, includeLine, 1) + return os.WriteFile(configPath, []byte(migrated), 0o600) + } + + if err := fileutil.BackupFile(ctx, configPath, content); err != nil { + return fmt.Errorf("failed to backup SSH config: %w", err) + } newContent := includeLine + "\n" if len(content) > 0 && !strings.HasPrefix(string(content), "\n") { newContent += "\n" @@ -99,6 +114,17 @@ func EnsureIncludeDirective(ctx context.Context, configPath string) error { return nil } +// containsLine reports whether data contains line as an exact line match, +// trimming \r to handle Windows line endings. +func containsLine(data []byte, line string) bool { + for l := range strings.SplitSeq(string(data), "\n") { + if strings.TrimRight(l, "\r") == line { + return true + } + } + return false +} + func GetHostConfigPath(ctx context.Context, hostName string) (string, error) { configDir, err := GetConfigDir(ctx) if err != nil { diff --git a/experimental/ssh/internal/sshconfig/sshconfig_test.go b/experimental/ssh/internal/sshconfig/sshconfig_test.go index b2abf22cf1..0578f8f6fc 100644 --- a/experimental/ssh/internal/sshconfig/sshconfig_test.go +++ b/experimental/ssh/internal/sshconfig/sshconfig_test.go @@ -80,9 +80,9 @@ func TestEnsureIncludeDirective_AlreadyExists(t *testing.T) { configDir, err := GetConfigDir(t.Context()) require.NoError(t, err) - // Use forward slashes as that's what SSH config uses + // Use forward slashes and quotes as that's what SSH config uses configDirUnix := filepath.ToSlash(configDir) - existingContent := "Include " + configDirUnix + "/*\n\nHost example\n User test\n" + existingContent := `Include "` + configDirUnix + `/*"` + "\n\nHost example\n User test\n" err = os.MkdirAll(filepath.Dir(configPath), 0o700) require.NoError(t, err) err = os.WriteFile(configPath, []byte(existingContent), 0o600) @@ -96,6 +96,59 @@ func TestEnsureIncludeDirective_AlreadyExists(t *testing.T) { assert.Equal(t, existingContent, string(content)) } +func TestEnsureIncludeDirective_MigratesOldUnquotedFormat(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configPath := filepath.Join(tmpDir, ".ssh", "config") + + configDir, err := GetConfigDir(t.Context()) + require.NoError(t, err) + + configDirUnix := filepath.ToSlash(configDir) + oldContent := "Include " + configDirUnix + "/*\n\nHost example\n User test\n" + err = os.MkdirAll(filepath.Dir(configPath), 0o700) + require.NoError(t, err) + err = os.WriteFile(configPath, []byte(oldContent), 0o600) + require.NoError(t, err) + + err = EnsureIncludeDirective(t.Context(), configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + configStr := string(content) + + assert.Contains(t, configStr, `Include "`+configDirUnix+`/*"`) + assert.NotContains(t, configStr, "Include "+configDirUnix+"/*\n") + assert.Contains(t, configStr, "Host example") +} + +func TestEnsureIncludeDirective_NotFooledBySubstring(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configPath := filepath.Join(tmpDir, ".ssh", "config") + + configDir, err := GetConfigDir(t.Context()) + require.NoError(t, err) + + configDirUnix := filepath.ToSlash(configDir) + // The include path appears only inside a comment, not as a standalone directive. + existingContent := `# Include "` + configDirUnix + `/*"` + "\nHost example\n User test\n" + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o700)) + require.NoError(t, os.WriteFile(configPath, []byte(existingContent), 0o600)) + + err = EnsureIncludeDirective(t.Context(), configPath) + require.NoError(t, err) + + content, err := os.ReadFile(configPath) + require.NoError(t, err) + assert.Contains(t, string(content), `Include "`+configDirUnix+`/*"`) +} + func TestEnsureIncludeDirective_PrependsToExisting(t *testing.T) { tmpDir := t.TempDir() configPath := filepath.Join(tmpDir, ".ssh", "config") @@ -127,6 +180,27 @@ func TestEnsureIncludeDirective_PrependsToExisting(t *testing.T) { assert.Less(t, includeIndex, hostIndex, "Include directive should come before existing content") } +func TestContainsLine(t *testing.T) { + tests := []struct { + name string + data string + line string + found bool + }{ + {"exact match", `Include "/path/*"` + "\nHost example\n", `Include "/path/*"`, true}, + {"not present", "Host example\n", `Include "/path/*"`, false}, + {"substring only", `# Include "/path/*"`, `Include "/path/*"`, false}, + {"commented line", `# Include "/path/*"` + "\n" + `Include "/path/*"` + "\n", `Include "/path/*"`, true}, + {"windows line ending", `Include "/path/*"` + "\r\nHost example\r\n", `Include "/path/*"`, true}, + {"empty data", "", `Include "/path/*"`, false}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.found, containsLine([]byte(tc.data), tc.line)) + }) + } +} + func TestGetHostConfigPath(t *testing.T) { path, err := GetHostConfigPath(t.Context(), "test-host") assert.NoError(t, err) diff --git a/experimental/ssh/internal/vscode/settings.go b/experimental/ssh/internal/vscode/settings.go index 19b40858a2..c8b13e3d56 100644 --- a/experimental/ssh/internal/vscode/settings.go +++ b/experimental/ssh/internal/vscode/settings.go @@ -9,6 +9,7 @@ import ( "runtime" "strings" + "github.com/databricks/cli/experimental/ssh/internal/fileutil" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/log" @@ -96,8 +97,10 @@ func CheckAndUpdateSettings(ctx context.Context, ide, connectionName string) err return nil } - if err := backupSettings(ctx, settingsPath); err != nil { - log.Warnf(ctx, "Failed to backup settings: %v. Continuing with update.", err) + if data, err := os.ReadFile(settingsPath); err == nil { + if err := fileutil.BackupFile(ctx, settingsPath, data); err != nil { + return fmt.Errorf("failed to backup settings: %w", err) + } } if err := updateSettings(&settings, connectionName, missing); err != nil { @@ -277,27 +280,6 @@ func handleMissingFile(ctx context.Context, ide, connectionName, settingsPath st return nil } -func backupSettings(ctx context.Context, path string) error { - data, err := os.ReadFile(path) - if err != nil { - return err - } - if len(data) == 0 { - return nil - } - - originalBak := path + ".original.bak" - latestBak := path + ".latest.bak" - - if _, err := os.Stat(originalBak); os.IsNotExist(err) { - log.Infof(ctx, "Backing up settings to %s", filepath.ToSlash(originalBak)) - return os.WriteFile(originalBak, data, 0o600) - } - - log.Infof(ctx, "Backing up settings to %s", filepath.ToSlash(latestBak)) - return os.WriteFile(latestBak, data, 0o600) -} - // subKeyOp returns a patch op that sets key/subKey=value, creating the parent object if absent. func subKeyOp(v *hujson.Value, key, subKey, value string) patchOp { if v.Find(jsonPtr(key)) == nil { diff --git a/experimental/ssh/internal/vscode/settings_test.go b/experimental/ssh/internal/vscode/settings_test.go index c76f93c956..47f8dd5408 100644 --- a/experimental/ssh/internal/vscode/settings_test.go +++ b/experimental/ssh/internal/vscode/settings_test.go @@ -2,11 +2,13 @@ package vscode import ( "encoding/json" + "io" "os" "path/filepath" "runtime" "testing" + "github.com/databricks/cli/experimental/ssh/internal/fileutil" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/env" "github.com/stretchr/testify/assert" @@ -428,45 +430,38 @@ func TestUpdateSettings_PartialUpdate(t *testing.T) { assert.Len(t, exts, 2) } -func TestBackupSettings(t *testing.T) { - tmpDir := t.TempDir() - settingsPath := filepath.Join(tmpDir, "settings.json") - originalBak := settingsPath + ".original.bak" - latestBak := settingsPath + ".latest.bak" +func TestCheckAndUpdateSettings_CreatesBackup(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("path setup differs on windows") + } - originalContent := []byte(`{"key": "value"}`) - err := os.WriteFile(settingsPath, originalContent, 0o600) - require.NoError(t, err) + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) - ctx, _ := cmdio.NewTestContextWithStderr(t.Context()) + ctx, tst := cmdio.SetupTest(t.Context(), cmdio.TestOptions{PromptSupported: true}) + defer tst.Done() - // First backup: should create .original.bak - err = backupSettings(ctx, settingsPath) + settingsPath, err := getDefaultSettingsPath(ctx, "cursor") require.NoError(t, err) + require.NoError(t, os.MkdirAll(filepath.Dir(settingsPath), 0o755)) - content, err := os.ReadFile(originalBak) - require.NoError(t, err) - assert.Equal(t, originalContent, content) - _, err = os.Stat(latestBak) - assert.True(t, os.IsNotExist(err)) + // Settings file with no Databricks-required keys → triggers an update prompt. + originalContent := []byte(`{}`) + require.NoError(t, os.WriteFile(settingsPath, originalContent, 0o600)) - // Second backup: .original.bak exists, should create .latest.bak - updatedContent := []byte(`{"key": "updated"}`) - err = os.WriteFile(settingsPath, updatedContent, 0o600) - require.NoError(t, err) + // Drain stderr (where the prompt is written) and feed "y" to stdin. + go func() { _, _ = io.Copy(io.Discard, tst.Stderr) }() + go func() { + _, _ = tst.Stdin.WriteString("y\n") + _ = tst.Stdin.Flush() + }() - err = backupSettings(ctx, settingsPath) + err = CheckAndUpdateSettings(ctx, "cursor", "my-host") require.NoError(t, err) - // .original.bak must remain unchanged - content, err = os.ReadFile(originalBak) + content, err := os.ReadFile(settingsPath + fileutil.SuffixOriginalBak) require.NoError(t, err) assert.Equal(t, originalContent, content) - - // .latest.bak should have the updated content - content, err = os.ReadFile(latestBak) - require.NoError(t, err) - assert.Equal(t, updatedContent, content) } func TestSaveSettings_Formatting(t *testing.T) { From accb7dcf4c24510dec842a7f9f806cd868df5669 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Mon, 30 Mar 2026 10:40:51 +0200 Subject: [PATCH 2/2] experimental/ssh: address review comments on backup and sshconfig - BackupFile now returns an error if os.Stat on the original backup fails with anything other than IsNotExist, instead of silently falling through to latestBak - containsLine trims leading spaces/tabs so indented SSH config directives (common inside Host blocks) are matched correctly - Add replaceLine helper with the same trim logic as containsLine, replacing the substring-based strings.Replace in the migration path to prevent false matches inside comments - Extend tests: TestBackupFile_StatError, TestReplaceLine, indented/comment-substring cases for EnsureIncludeDirective, and .latest.bak coverage in TestCheckAndUpdateSettings_CreatesBackup Co-authored-by: Isaac --- experimental/ssh/internal/fileutil/backup.go | 6 +- .../ssh/internal/fileutil/backup_test.go | 19 ++++ .../ssh/internal/sshconfig/sshconfig.go | 23 +++- .../ssh/internal/sshconfig/sshconfig_test.go | 100 ++++++++++++++++++ .../ssh/internal/vscode/settings_test.go | 25 ++++- 5 files changed, 165 insertions(+), 8 deletions(-) diff --git a/experimental/ssh/internal/fileutil/backup.go b/experimental/ssh/internal/fileutil/backup.go index 4b931f4566..c9e07503ef 100644 --- a/experimental/ssh/internal/fileutil/backup.go +++ b/experimental/ssh/internal/fileutil/backup.go @@ -22,7 +22,11 @@ func BackupFile(ctx context.Context, path string, data []byte) error { originalBak := path + SuffixOriginalBak latestBak := path + SuffixLatestBak var bakPath string - if _, err := os.Stat(originalBak); os.IsNotExist(err) { + _, statErr := os.Stat(originalBak) + if statErr != nil && !os.IsNotExist(statErr) { + return statErr + } + if os.IsNotExist(statErr) { bakPath = originalBak } else { bakPath = latestBak diff --git a/experimental/ssh/internal/fileutil/backup_test.go b/experimental/ssh/internal/fileutil/backup_test.go index 54c6adcf17..f57e82367a 100644 --- a/experimental/ssh/internal/fileutil/backup_test.go +++ b/experimental/ssh/internal/fileutil/backup_test.go @@ -3,6 +3,7 @@ package fileutil_test import ( "os" "path/filepath" + "runtime" "testing" "github.com/databricks/cli/experimental/ssh/internal/fileutil" @@ -64,3 +65,21 @@ func TestBackupFile_WriteError(t *testing.T) { err := fileutil.BackupFile(t.Context(), "/nonexistent/dir/file.json", []byte("data")) assert.Error(t, err) } + +func TestBackupFile_StatError(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("chmod is not supported on windows") + } + + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "file.json") + + // Create the .original.bak file so os.Stat would find it — but make the + // parent directory unreadable so Stat returns a permission error instead. + require.NoError(t, os.WriteFile(path+fileutil.SuffixOriginalBak, []byte("existing"), 0o600)) + require.NoError(t, os.Chmod(tmpDir, 0o000)) + t.Cleanup(func() { _ = os.Chmod(tmpDir, 0o700) }) + + err := fileutil.BackupFile(t.Context(), path, []byte("data")) + assert.Error(t, err) +} diff --git a/experimental/ssh/internal/sshconfig/sshconfig.go b/experimental/ssh/internal/sshconfig/sshconfig.go index 6454a62745..df7fbf1226 100644 --- a/experimental/ssh/internal/sshconfig/sshconfig.go +++ b/experimental/ssh/internal/sshconfig/sshconfig.go @@ -93,8 +93,7 @@ func EnsureIncludeDirective(ctx context.Context, configPath string) error { if err := fileutil.BackupFile(ctx, configPath, content); err != nil { return fmt.Errorf("failed to backup SSH config before migration: %w", err) } - migrated := strings.Replace(string(content), oldIncludeLine, includeLine, 1) - return os.WriteFile(configPath, []byte(migrated), 0o600) + return os.WriteFile(configPath, replaceLine(content, oldIncludeLine, includeLine), 0o600) } if err := fileutil.BackupFile(ctx, configPath, content); err != nil { @@ -114,17 +113,31 @@ func EnsureIncludeDirective(ctx context.Context, configPath string) error { return nil } -// containsLine reports whether data contains line as an exact line match, -// trimming \r to handle Windows line endings. +// containsLine reports whether data contains line as a line match, +// trimming leading whitespace and \r (Windows line endings) before comparing. func containsLine(data []byte, line string) bool { for l := range strings.SplitSeq(string(data), "\n") { - if strings.TrimRight(l, "\r") == line { + if strings.TrimLeft(strings.TrimRight(l, "\r"), " \t") == line { return true } } return false } +// replaceLine replaces the first line in data whose trimmed content matches old +// with new. Uses the same trim logic as containsLine. Returns data unchanged if +// no match. +func replaceLine(data []byte, old, new string) []byte { + lines := strings.Split(string(data), "\n") + for i, l := range lines { + if strings.TrimLeft(strings.TrimRight(l, "\r"), " \t") == old { + lines[i] = new + break + } + } + return []byte(strings.Join(lines, "\n")) +} + func GetHostConfigPath(ctx context.Context, hostName string) (string, error) { configDir, err := GetConfigDir(ctx) if err != nil { diff --git a/experimental/ssh/internal/sshconfig/sshconfig_test.go b/experimental/ssh/internal/sshconfig/sshconfig_test.go index 0578f8f6fc..6c453910cd 100644 --- a/experimental/ssh/internal/sshconfig/sshconfig_test.go +++ b/experimental/ssh/internal/sshconfig/sshconfig_test.go @@ -193,6 +193,8 @@ func TestContainsLine(t *testing.T) { {"commented line", `# Include "/path/*"` + "\n" + `Include "/path/*"` + "\n", `Include "/path/*"`, true}, {"windows line ending", `Include "/path/*"` + "\r\nHost example\r\n", `Include "/path/*"`, true}, {"empty data", "", `Include "/path/*"`, false}, + {"indented with spaces", " " + `Include "/path/*"` + "\nHost example\n", `Include "/path/*"`, true}, + {"indented with tab", "\t" + `Include "/path/*"` + "\nHost example\n", `Include "/path/*"`, true}, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { @@ -201,6 +203,104 @@ func TestContainsLine(t *testing.T) { } } +func TestReplaceLine(t *testing.T) { + tests := []struct { + name string + data string + old string + new string + expected string + }{ + { + "exact match", + `Include "/p/*"` + "\nHost x\n", + `Include "/p/*"`, `Include "/p/*" NEW`, + `Include "/p/*" NEW` + "\nHost x\n", + }, + { + "indented match", + " " + `Include "/p/*"` + "\nHost x\n", + `Include "/p/*"`, `Include "/p/*" NEW`, + `Include "/p/*" NEW` + "\nHost x\n", + }, + { + "no match", + "Host x\n", + `Include "/p/*"`, `Include "/p/*" NEW`, + "Host x\n", + }, + { + "substring in comment — must not be replaced", + `# Include "/p/*"` + "\nHost x\n", + `Include "/p/*"`, `Include "/p/*" NEW`, + `# Include "/p/*"` + "\nHost x\n", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := replaceLine([]byte(tc.data), tc.old, tc.new) + assert.Equal(t, tc.expected, string(got)) + }) + } +} + +func TestEnsureIncludeDirective_MigratesIndentedOldFormat(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configPath := filepath.Join(tmpDir, ".ssh", "config") + + configDir, err := GetConfigDir(t.Context()) + require.NoError(t, err) + + configDirUnix := filepath.ToSlash(configDir) + // Old format with leading whitespace — should still be detected and migrated. + oldContent := " Include " + configDirUnix + "/*\n\nHost example\n User test\n" + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o700)) + require.NoError(t, os.WriteFile(configPath, []byte(oldContent), 0o600)) + + err = EnsureIncludeDirective(t.Context(), configPath) + require.NoError(t, err) + + content, err := os.ReadFile(configPath) + require.NoError(t, err) + configStr := string(content) + + assert.Contains(t, configStr, `Include "`+configDirUnix+`/*"`) + assert.NotContains(t, configStr, " Include "+configDirUnix+"/*") + assert.Contains(t, configStr, "Host example") +} + +func TestEnsureIncludeDirective_NotFooledByOldFormatSubstring(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configPath := filepath.Join(tmpDir, ".ssh", "config") + + configDir, err := GetConfigDir(t.Context()) + require.NoError(t, err) + + configDirUnix := filepath.ToSlash(configDir) + // Old unquoted form appears only inside a comment — must not be migrated. + existingContent := "# Include " + configDirUnix + "/*\nHost example\n User test\n" + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o700)) + require.NoError(t, os.WriteFile(configPath, []byte(existingContent), 0o600)) + + err = EnsureIncludeDirective(t.Context(), configPath) + require.NoError(t, err) + + content, err := os.ReadFile(configPath) + require.NoError(t, err) + configStr := string(content) + + // New quoted directive should have been prepended (not a migration). + assert.Contains(t, configStr, `Include "`+configDirUnix+`/*"`) + // Comment line must be preserved unchanged. + assert.Contains(t, configStr, "# Include "+configDirUnix+"/*") +} + func TestGetHostConfigPath(t *testing.T) { path, err := GetHostConfigPath(t.Context(), "test-host") assert.NoError(t, err) diff --git a/experimental/ssh/internal/vscode/settings_test.go b/experimental/ssh/internal/vscode/settings_test.go index 47f8dd5408..a6fcf77988 100644 --- a/experimental/ssh/internal/vscode/settings_test.go +++ b/experimental/ssh/internal/vscode/settings_test.go @@ -459,9 +459,30 @@ func TestCheckAndUpdateSettings_CreatesBackup(t *testing.T) { err = CheckAndUpdateSettings(ctx, "cursor", "my-host") require.NoError(t, err) - content, err := os.ReadFile(settingsPath + fileutil.SuffixOriginalBak) + originalBakContent, err := os.ReadFile(settingsPath + fileutil.SuffixOriginalBak) require.NoError(t, err) - assert.Equal(t, originalContent, content) + assert.Equal(t, originalContent, originalBakContent) + + // Second update for a new connection triggers another backup into .latest.bak. + postFirstContent, err := os.ReadFile(settingsPath) + require.NoError(t, err) + + go func() { + _, _ = tst.Stdin.WriteString("y\n") + _ = tst.Stdin.Flush() + }() + + err = CheckAndUpdateSettings(ctx, "cursor", "my-host-2") + require.NoError(t, err) + + latestBakContent, err := os.ReadFile(settingsPath + fileutil.SuffixLatestBak) + require.NoError(t, err) + assert.Equal(t, postFirstContent, latestBakContent) + + // .original.bak must still hold the very first snapshot. + originalBakContent2, err := os.ReadFile(settingsPath + fileutil.SuffixOriginalBak) + require.NoError(t, err) + assert.Equal(t, originalContent, originalBakContent2) } func TestSaveSettings_Formatting(t *testing.T) {