diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go index 79e1063b18..7d688832ff 100644 --- a/cmd/auth/auth.go +++ b/cmd/auth/auth.go @@ -2,9 +2,12 @@ package auth import ( "context" + "fmt" + "strings" "github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/cmdio" + "github.com/google/uuid" "github.com/spf13/cobra" ) @@ -34,25 +37,36 @@ GCP: https://docs.gcp.databricks.com/dev-tools/auth/index.html`, } func promptForHost(ctx context.Context) (string, error) { + if !cmdio.IsInTTY(ctx) { + return "", fmt.Errorf("the command is being run in a non-interactive environment, please specify a host using --host") + } + prompt := cmdio.Prompt(ctx) - prompt.Label = "Databricks Host (e.g. https://.cloud.databricks.com)" - // Validate? - host, err := prompt.Run() - if err != nil { - return "", err + prompt.Label = "Databricks host" + prompt.Validate = func(host string) error { + if !strings.HasPrefix(host, "https://") { + return fmt.Errorf("host URL must have a https:// prefix") + } + return nil } - return host, nil + return prompt.Run() } func promptForAccountID(ctx context.Context) (string, error) { + if !cmdio.IsInTTY(ctx) { + return "", fmt.Errorf("the command is being run in a non-interactive environment, please specify an account ID using --account-id") + } + prompt := cmdio.Prompt(ctx) - prompt.Label = "Databricks Account ID" + prompt.Label = "Databricks account id" prompt.Default = "" prompt.AllowEdit = true - // Validate? - accountId, err := prompt.Run() - if err != nil { - return "", err + prompt.Validate = func(accountID string) error { + _, err := uuid.Parse(accountID) + if err != nil { + return fmt.Errorf("account ID must be a valid UUID: %w", err) + } + return nil } - return accountId, nil + return prompt.Run() } diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 11cba8e5f1..c1b2d924bd 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -17,18 +17,32 @@ import ( "github.com/spf13/cobra" ) -func configureHost(ctx context.Context, persistentAuth *auth.PersistentAuth, args []string, argIndex int) error { - if len(args) > argIndex { - persistentAuth.Host = args[argIndex] - return nil +func promptForProfile(ctx context.Context, dv string) (string, error) { + if !cmdio.IsInTTY(ctx) { + return "", fmt.Errorf("the command is being run in a non-interactive environment, please specify a profile using --profile") } - host, err := promptForHost(ctx) - if err != nil { - return err + prompt := cmdio.Prompt(ctx) + prompt.Label = "Databricks profile name" + prompt.Default = dv + prompt.AllowEdit = true + return prompt.Run() +} + +func getHostFromProfile(ctx context.Context, profileName string) (string, error) { + profiler := profile.GetProfiler(ctx) + // If the chosen profile has a hostname and the user hasn't specified a host, infer the host from the profile. + profiles, err := profiler.LoadProfiles(ctx, profile.WithName(profileName)) + // Tolerate ErrNoConfiguration here, as we will write out a configuration as part of the login flow. + if err != nil && !errors.Is(err, profile.ErrNoConfiguration) { + return "", err } - persistentAuth.Host = host - return nil + + // Return host from profile + if len(profiles) > 0 && profiles[0].Host != "" { + return profiles[0].Host, nil + } + return "", nil } const minimalDbConnectVersion = "13.1" @@ -93,23 +107,18 @@ depends on the existing profiles you have set in your configuration file cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() + profileName := cmd.Flag("profile").Value.String() - var profileName string - profileFlag := cmd.Flag("profile") - if profileFlag != nil && profileFlag.Value.String() != "" { - profileName = profileFlag.Value.String() - } else if cmdio.IsInTTY(ctx) { - prompt := cmdio.Prompt(ctx) - prompt.Label = "Databricks Profile Name" - prompt.Default = persistentAuth.ProfileName() - prompt.AllowEdit = true - profile, err := prompt.Run() + // If the user has not specified a profile name, prompt for one. + if profileName == "" { + var err error + profileName, err = promptForProfile(ctx, persistentAuth.DefaultProfileName()) if err != nil { return err } - profileName = profile } + // Set the host and account-id based on the provided arguments and flags. err := setHostAndAccountId(ctx, profileName, persistentAuth, args) if err != nil { return err @@ -167,7 +176,23 @@ depends on the existing profiles you have set in your configuration file return cmd } +// Sets the host in the persistentAuth object based on the provided arguments and flags. +// Follows the following precedence: +// 1. [HOST] (first positional argument) or --host flag. Error if both are specified. +// 2. Profile host, if available. +// 3. Prompt the user for the host. +// +// Set the account in the persistentAuth object based on the flags. +// Follows the following precedence: +// 1. --account-id flag. +// 2. account-id from the specified profile, if available. +// 3. Prompt the user for the account-id. func setHostAndAccountId(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth, args []string) error { + // If both [HOST] and --host are provided, return an error. + if len(args) > 0 && persistentAuth.Host != "" { + return fmt.Errorf("please only provide a host as an argument or a flag, not both") + } + profiler := profile.GetProfiler(ctx) // If the chosen profile has a hostname and the user hasn't specified a host, infer the host from the profile. profiles, err := profiler.LoadProfiles(ctx, profile.WithName(profileName)) @@ -176,18 +201,34 @@ func setHostAndAccountId(ctx context.Context, profileName string, persistentAuth return err } - if persistentAuth.Host == "" { + // If [HOST] is provided, set the host to the provided positional argument. + if len(args) > 0 && persistentAuth.Host == "" { + persistentAuth.Host = args[0] + } + + // If neither [HOST] nor --host are provided, and the profile has a host, use it. + // Otherwise, prompt the user for a host. + if len(args) == 0 && persistentAuth.Host == "" { if len(profiles) > 0 && profiles[0].Host != "" { persistentAuth.Host = profiles[0].Host } else { - configureHost(ctx, persistentAuth, args, 0) + hostName, err := promptForHost(ctx) + if err != nil { + return err + } + persistentAuth.Host = hostName } } + + // If the account-id was not provided as a cmd line flag, try to read it from + // the specified profile. isAccountClient := (&config.Config{Host: persistentAuth.Host}).IsAccountClient() if isAccountClient && persistentAuth.AccountID == "" { if len(profiles) > 0 && profiles[0].AccountID != "" { persistentAuth.AccountID = profiles[0].AccountID } else { + // Prompt user for the account-id if it we could not get it from a + // profile. accountId, err := promptForAccountID(ctx) if err != nil { return err diff --git a/cmd/auth/login_test.go b/cmd/auth/login_test.go index ce3ca5ae57..d0fa5a16b8 100644 --- a/cmd/auth/login_test.go +++ b/cmd/auth/login_test.go @@ -5,8 +5,10 @@ import ( "testing" "github.com/databricks/cli/libs/auth" + "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/env" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) { @@ -15,3 +17,69 @@ func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) { err := setHostAndAccountId(ctx, "foo", &auth.PersistentAuth{Host: "test"}, []string{}) assert.NoError(t, err) } + +func TestSetHost(t *testing.T) { + var persistentAuth auth.PersistentAuth + t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg") + ctx, _ := cmdio.SetupTest(context.Background()) + + // Test error when both flag and argument are provided + persistentAuth.Host = "val from --host" + err := setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{"val from [HOST]"}) + assert.EqualError(t, err, "please only provide a host as an argument or a flag, not both") + + // Test setting host from flag + persistentAuth.Host = "val from --host" + err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{}) + assert.NoError(t, err) + assert.Equal(t, "val from --host", persistentAuth.Host) + + // Test setting host from argument + persistentAuth.Host = "" + err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{"val from [HOST]"}) + assert.NoError(t, err) + assert.Equal(t, "val from [HOST]", persistentAuth.Host) + + // Test setting host from profile + persistentAuth.Host = "" + err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{}) + assert.NoError(t, err) + assert.Equal(t, "https://www.host1.com", persistentAuth.Host) + + // Test setting host from profile + persistentAuth.Host = "" + err = setHostAndAccountId(ctx, "profile-2", &persistentAuth, []string{}) + assert.NoError(t, err) + assert.Equal(t, "https://www.host2.com", persistentAuth.Host) + + // Test host is not set. Should prompt. + persistentAuth.Host = "" + err = setHostAndAccountId(ctx, "", &persistentAuth, []string{}) + assert.EqualError(t, err, "the command is being run in a non-interactive environment, please specify a host using --host") +} + +func TestSetAccountId(t *testing.T) { + var persistentAuth auth.PersistentAuth + t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg") + ctx, _ := cmdio.SetupTest(context.Background()) + + // Test setting account-id from flag + persistentAuth.AccountID = "val from --account-id" + err := setHostAndAccountId(ctx, "account-profile", &persistentAuth, []string{}) + assert.NoError(t, err) + assert.Equal(t, "https://accounts.cloud.databricks.com", persistentAuth.Host) + assert.Equal(t, "val from --account-id", persistentAuth.AccountID) + + // Test setting account_id from profile + persistentAuth.AccountID = "" + err = setHostAndAccountId(ctx, "account-profile", &persistentAuth, []string{}) + require.NoError(t, err) + assert.Equal(t, "https://accounts.cloud.databricks.com", persistentAuth.Host) + assert.Equal(t, "id-from-profile", persistentAuth.AccountID) + + // Neither flag nor profile account-id is set, should prompt + persistentAuth.AccountID = "" + persistentAuth.Host = "https://accounts.cloud.databricks.com" + err = setHostAndAccountId(ctx, "", &persistentAuth, []string{}) + assert.EqualError(t, err, "the command is being run in a non-interactive environment, please specify an account ID using --account-id") +} diff --git a/cmd/auth/testdata/.databrickscfg b/cmd/auth/testdata/.databrickscfg new file mode 100644 index 0000000000..06e55224a1 --- /dev/null +++ b/cmd/auth/testdata/.databrickscfg @@ -0,0 +1,9 @@ +[profile-1] +host = https://www.host1.com + +[profile-2] +host = https://www.host2.com + +[account-profile] +host = https://accounts.cloud.databricks.com +account_id = id-from-profile diff --git a/libs/auth/oauth.go b/libs/auth/oauth.go index 1f3e032de9..ba8bb60fef 100644 --- a/libs/auth/oauth.go +++ b/libs/auth/oauth.go @@ -104,11 +104,13 @@ func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) { return refreshed, nil } -func (a *PersistentAuth) ProfileName() string { - // TODO: get profile name from interactive input +func (a *PersistentAuth) DefaultProfileName() string { if a.AccountID != "" { return fmt.Sprintf("ACCOUNT-%s", a.AccountID) } + if a.Host == "" { + return "DEFAULT" + } host := strings.TrimPrefix(a.Host, "https://") split := strings.Split(host, ".") return split[0]