Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 99 additions & 43 deletions cmd/root/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,57 @@ func initProfileFlag(cmd *cobra.Command) {
cmd.RegisterFlagCompletionFunc("profile", databrickscfg.ProfileCompletion)
}

func profileFlagValue(cmd *cobra.Command) (string, bool) {
profileFlag := cmd.Flag("profile")
if profileFlag == nil {
return "", false
}
value := profileFlag.Value.String()
return value, value != ""
}

// Helper function to create an account client or prompt once if the given configuration is not valid.
func accountClientOrPrompt(ctx context.Context, cfg *config.Config, allowPrompt bool) (*databricks.AccountClient, error) {
a, err := databricks.NewAccountClient((*databricks.Config)(cfg))
if err == nil {
err = a.Config.Authenticate(emptyHttpRequest(ctx))
}

prompt := false
if allowPrompt && err != nil && cmdio.IsInteractive(ctx) {
// Prompt to select a profile if the current configuration is not an account client.
prompt = prompt || errors.Is(err, databricks.ErrNotAccountClient)
// Prompt to select a profile if the current configuration doesn't resolve to a credential provider.
prompt = prompt || errors.Is(err, config.ErrCannotConfigureAuth)
}

if !prompt {
// If we are not prompting, we can return early.
return a, err
}

// Try picking a profile dynamically if the current configuration is not valid.
profile, err := askForAccountProfile(ctx)
if err != nil {
return nil, err
}
a, err = databricks.NewAccountClient(&databricks.Config{Profile: profile})
if err == nil {
err = a.Config.Authenticate(emptyHttpRequest(ctx))
if err != nil {
return nil, err
}
}
return a, nil
}

func MustAccountClient(cmd *cobra.Command, args []string) error {
cfg := &config.Config{}

// command-line flag can specify the profile in use
profileFlag := cmd.Flag("profile")
if profileFlag != nil {
cfg.Profile = profileFlag.Value.String()
// The command-line profile flag takes precedence over DATABRICKS_CONFIG_PROFILE.
profile, hasProfileFlag := profileFlagValue(cmd)
if hasProfileFlag {
cfg.Profile = profile
}

if cfg.Profile == "" {
Expand All @@ -48,16 +92,8 @@ func MustAccountClient(cmd *cobra.Command, args []string) error {
}
}

TRY_AUTH: // or try picking a config profile dynamically
a, err := databricks.NewAccountClient((*databricks.Config)(cfg))
if cmdio.IsInteractive(cmd.Context()) && errors.Is(err, databricks.ErrNotAccountClient) {
profile, err := askForAccountProfile()
if err != nil {
return err
}
cfg = &config.Config{Profile: profile}
goto TRY_AUTH
}
allowPrompt := !hasProfileFlag
a, err := accountClientOrPrompt(cmd.Context(), cfg, allowPrompt)
if err != nil {
return err
}
Expand All @@ -66,13 +102,48 @@ TRY_AUTH: // or try picking a config profile dynamically
return nil
}

// Helper function to create a workspace client or prompt once if the given configuration is not valid.
func workspaceClientOrPrompt(ctx context.Context, cfg *config.Config, allowPrompt bool) (*databricks.WorkspaceClient, error) {
w, err := databricks.NewWorkspaceClient((*databricks.Config)(cfg))
if err == nil {
err = w.Config.Authenticate(emptyHttpRequest(ctx))
}

prompt := false
if allowPrompt && err != nil && cmdio.IsInteractive(ctx) {
// Prompt to select a profile if the current configuration is not a workspace client.
prompt = prompt || errors.Is(err, databricks.ErrNotWorkspaceClient)
// Prompt to select a profile if the current configuration doesn't resolve to a credential provider.
prompt = prompt || errors.Is(err, config.ErrCannotConfigureAuth)
}

if !prompt {
// If we are not prompting, we can return early.
return w, err
}

// Try picking a profile dynamically if the current configuration is not valid.
profile, err := askForWorkspaceProfile(ctx)
if err != nil {
return nil, err
}
w, err = databricks.NewWorkspaceClient(&databricks.Config{Profile: profile})
if err == nil {
err = w.Config.Authenticate(emptyHttpRequest(ctx))
if err != nil {
return nil, err
}
}
return w, nil
}

func MustWorkspaceClient(cmd *cobra.Command, args []string) error {
cfg := &config.Config{}

// command-line flag takes precedence over environment variable
profileFlag := cmd.Flag("profile")
if profileFlag != nil {
cfg.Profile = profileFlag.Value.String()
// The command-line profile flag takes precedence over DATABRICKS_CONFIG_PROFILE.
profile, hasProfileFlag := profileFlagValue(cmd)
if hasProfileFlag {
cfg.Profile = profile
}

// try configuring a bundle
Expand All @@ -87,24 +158,13 @@ func MustWorkspaceClient(cmd *cobra.Command, args []string) error {
cfg = currentBundle.WorkspaceClient().Config
}

TRY_AUTH: // or try picking a config profile dynamically
ctx := cmd.Context()
w, err := databricks.NewWorkspaceClient((*databricks.Config)(cfg))
if err != nil {
return err
}
err = w.Config.Authenticate(emptyHttpRequest(ctx))
if cmdio.IsInteractive(ctx) && errors.Is(err, config.ErrCannotConfigureAuth) {
profile, err := askForWorkspaceProfile()
if err != nil {
return err
}
cfg = &config.Config{Profile: profile}
goto TRY_AUTH
}
allowPrompt := !hasProfileFlag
w, err := workspaceClientOrPrompt(cmd.Context(), cfg, allowPrompt)
if err != nil {
return err
}

ctx := cmd.Context()
ctx = context.WithValue(ctx, &workspaceClient, w)
cmd.SetContext(ctx)
return nil
Expand All @@ -121,7 +181,7 @@ func transformLoadError(path string, err error) error {
return err
}

func askForWorkspaceProfile() (string, error) {
func askForWorkspaceProfile(ctx context.Context) (string, error) {
path, err := databrickscfg.GetPath()
if err != nil {
return "", fmt.Errorf("cannot determine Databricks config file path: %w", err)
Expand All @@ -136,7 +196,7 @@ func askForWorkspaceProfile() (string, error) {
case 1:
return profiles[0].Name, nil
}
i, _, err := (&promptui.Select{
i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
Label: fmt.Sprintf("Workspace profiles defined in %s", file),
Items: profiles,
Searcher: profiles.SearchCaseInsensitive,
Expand All @@ -147,16 +207,14 @@ func askForWorkspaceProfile() (string, error) {
Inactive: `{{.Name}}`,
Selected: `{{ "Using workspace profile" | faint }}: {{ .Name | bold }}`,
},
Stdin: os.Stdin,
Stdout: os.Stderr,
}).Run()
})
if err != nil {
return "", err
}
return profiles[i].Name, nil
}

func askForAccountProfile() (string, error) {
func askForAccountProfile(ctx context.Context) (string, error) {
path, err := databrickscfg.GetPath()
if err != nil {
return "", fmt.Errorf("cannot determine Databricks config file path: %w", err)
Expand All @@ -171,7 +229,7 @@ func askForAccountProfile() (string, error) {
case 1:
return profiles[0].Name, nil
}
i, _, err := (&promptui.Select{
i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
Label: fmt.Sprintf("Account profiles defined in %s", file),
Items: profiles,
Searcher: profiles.SearchCaseInsensitive,
Expand All @@ -182,9 +240,7 @@ func askForAccountProfile() (string, error) {
Inactive: `{{.Name}}`,
Selected: `{{ "Using account profile" | faint }}: {{ .Name | bold }}`,
},
Stdin: os.Stdin,
Stdout: os.Stderr,
}).Run()
})
if err != nil {
return "", err
}
Expand Down
164 changes: 164 additions & 0 deletions cmd/root/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,177 @@ package root

import (
"context"
"os"
"path/filepath"
"testing"
"time"

"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/databricks-sdk-go/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestEmptyHttpRequest(t *testing.T) {
ctx, _ := context.WithCancel(context.Background())
req := emptyHttpRequest(ctx)
assert.Equal(t, req.Context(), ctx)
}

type promptFn func(ctx context.Context, cfg *config.Config, retry bool) (any, error)

var accountPromptFn = func(ctx context.Context, cfg *config.Config, retry bool) (any, error) {
return accountClientOrPrompt(ctx, cfg, retry)
}

var workspacePromptFn = func(ctx context.Context, cfg *config.Config, retry bool) (any, error) {
return workspaceClientOrPrompt(ctx, cfg, retry)
}

func expectPrompts(t *testing.T, fn promptFn, config *config.Config) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

// Channel to pass errors from the prompting function back to the test.
errch := make(chan error, 1)

ctx, io := cmdio.SetupTest(ctx)
go func() {
defer close(errch)
defer cancel()
_, err := fn(ctx, config, true)
errch <- err
}()

// Expect a prompt
line, _, err := io.Stderr.ReadLine()
if assert.NoError(t, err, "Expected to read a line from stderr") {
assert.Contains(t, string(line), "Search:")
} else {
// If there was an error reading from stderr, the prompting function must have terminated early.
assert.NoError(t, <-errch)
}
}

func expectReturns(t *testing.T, fn promptFn, config *config.Config) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

ctx, _ = cmdio.SetupTest(ctx)
client, err := fn(ctx, config, true)
require.NoError(t, err)
require.NotNil(t, client)
}

func TestAccountClientOrPrompt(t *testing.T) {
dir := t.TempDir()
configFile := filepath.Join(dir, ".databrickscfg")
err := os.WriteFile(
configFile,
[]byte(`
[account-1111]
host = https://accounts.azuredatabricks.net/
account_id = 1111
token = foobar

[account-1112]
host = https://accounts.azuredatabricks.net/
account_id = 1112
token = foobar
`),
0755)
require.NoError(t, err)
t.Setenv("DATABRICKS_CONFIG_FILE", configFile)
t.Setenv("PATH", "/nothing")

t.Run("Prompt if nothing is specified", func(t *testing.T) {
expectPrompts(t, accountPromptFn, &config.Config{})
})

t.Run("Prompt if a workspace host is specified", func(t *testing.T) {
expectPrompts(t, accountPromptFn, &config.Config{
Host: "https://adb-1234567.89.azuredatabricks.net/",
AccountID: "1234",
Token: "foobar",
})
})

t.Run("Prompt if account ID is not specified", func(t *testing.T) {
expectPrompts(t, accountPromptFn, &config.Config{
Host: "https://accounts.azuredatabricks.net/",
Token: "foobar",
})
})

t.Run("Prompt if no credential provider can be configured", func(t *testing.T) {
expectPrompts(t, accountPromptFn, &config.Config{
Host: "https://accounts.azuredatabricks.net/",
AccountID: "1234",
})
})

t.Run("Returns if configuration is valid", func(t *testing.T) {
expectReturns(t, accountPromptFn, &config.Config{
Host: "https://accounts.azuredatabricks.net/",
AccountID: "1234",
Token: "foobar",
})
})

t.Run("Returns if a valid profile is specified", func(t *testing.T) {
expectReturns(t, accountPromptFn, &config.Config{
Profile: "account-1111",
})
})
}

func TestWorkspaceClientOrPrompt(t *testing.T) {
dir := t.TempDir()
configFile := filepath.Join(dir, ".databrickscfg")
err := os.WriteFile(
configFile,
[]byte(`
[workspace-1111]
host = https://adb-1111.11.azuredatabricks.net/
token = foobar

[workspace-1112]
host = https://adb-1112.12.azuredatabricks.net/
token = foobar
`),
0755)
require.NoError(t, err)
t.Setenv("DATABRICKS_CONFIG_FILE", configFile)
t.Setenv("PATH", "/nothing")

t.Run("Prompt if nothing is specified", func(t *testing.T) {
expectPrompts(t, workspacePromptFn, &config.Config{})
})

t.Run("Prompt if an account host is specified", func(t *testing.T) {
expectPrompts(t, workspacePromptFn, &config.Config{
Host: "https://accounts.azuredatabricks.net/",
AccountID: "1234",
Token: "foobar",
})
})

t.Run("Prompt if no credential provider can be configured", func(t *testing.T) {
expectPrompts(t, workspacePromptFn, &config.Config{
Host: "https://adb-1111.11.azuredatabricks.net/",
})
})

t.Run("Returns if configuration is valid", func(t *testing.T) {
expectReturns(t, workspacePromptFn, &config.Config{
Host: "https://adb-1111.11.azuredatabricks.net/",
Token: "foobar",
})
})

t.Run("Returns if a valid profile is specified", func(t *testing.T) {
expectReturns(t, workspacePromptFn, &config.Config{
Profile: "workspace-1111",
})
})
}
Loading