Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/cmd/auth/login/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func NewCmd(p *print.Printer) *cobra.Command {
"$ stackit auth login"),
),
RunE: func(cmd *cobra.Command, args []string) error {
err := auth.AuthorizeUser()
err := auth.AuthorizeUser(p, false)
if err != nil {
return fmt.Errorf("authorization failed: %w", err)
}
Expand Down
5 changes: 2 additions & 3 deletions internal/pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type tokenClaims struct {
// It returns the configuration option that can be used to create an authenticated SDK client.
//
// If the user was logged in and the user session expired, reauthorizeUserRoutine is called to reauthenticate the user again.
func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func() error) (authCfgOption sdkConfig.ConfigurationOption, err error) {
func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func(p *print.Printer, isReauthentication bool) error) (authCfgOption sdkConfig.ConfigurationOption, err error) {
flow, err := GetAuthFlow()
if err != nil {
return nil, fmt.Errorf("get authentication flow: %w", err)
Expand Down Expand Up @@ -57,8 +57,7 @@ func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func() error)
authCfgOption = sdkConfig.WithCustomAuth(keyFlow)
case AUTH_FLOW_USER_TOKEN:
if userSessionExpired {
p.Warn("Session expired, logging in again...\n")
err = reauthorizeUserRoutine()
err = reauthorizeUserRoutine(p, true)
if err != nil {
return nil, fmt.Errorf("user login: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/pkg/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ func TestAuthenticationConfig(t *testing.T) {
}

reauthorizeUserCalled := false
reauthenticateUser := func() error {
reauthenticateUser := func(p *print.Printer, isReauthentication bool) error {
if reauthorizeUserCalled {
t.Errorf("user reauthorized more than once")
}
Expand Down
11 changes: 10 additions & 1 deletion internal/pkg/auth/user_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (
"time"

"golang.org/x/oauth2"

"github.com/stackitcloud/stackit-cli/internal/pkg/print"
)

const (
Expand All @@ -36,7 +38,14 @@ type User struct {
}

// AuthorizeUser implements the PKCE OAuth2 flow.
func AuthorizeUser() error {
func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
if isReauthentication {
err := p.PromptForEnter("Your session has expired, press Enter to login again...")
if err != nil {
return err
}
}

listener, err := net.Listen("tcp", ":0")
if err != nil {
return fmt.Errorf("bind port for login redirect: %w", err)
Expand Down
5 changes: 2 additions & 3 deletions internal/pkg/auth/user_token_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

type userTokenFlow struct {
printer *print.Printer
reauthorizeUserRoutine func() error // Called if the user needs to login again
reauthorizeUserRoutine func(p *print.Printer, isReauthentication bool) error // Called if the user needs to login again
client *http.Client
authFlow AuthFlow
accessToken string
Expand Down Expand Up @@ -59,7 +59,6 @@ func (utf *userTokenFlow) RoundTrip(req *http.Request) (*http.Response, error) {
}

if !accessTokenValid {
utf.printer.Warn("Session expired, logging in again...")
err = reauthenticateUser(utf)
if err != nil {
return nil, fmt.Errorf("reauthenticate user: %w", err)
Expand Down Expand Up @@ -91,7 +90,7 @@ func loadVarsFromStorage(utf *userTokenFlow) error {
}

func reauthenticateUser(utf *userTokenFlow) error {
err := utf.reauthorizeUserRoutine()
err := utf.reauthorizeUserRoutine(utf.printer, true)
if err != nil {
return fmt.Errorf("authenticate user: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/pkg/auth/user_token_flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ func TestRoundTrip(t *testing.T) {
authorizeUserCalled: &authorizeUserCalled,
tokensRefreshed: &tokensRefreshed,
}
authorizeUserRoutine := func() error {
authorizeUserRoutine := func(p *print.Printer, isReauthentication bool) error {
return reauthorizeUser(authorizeUserContext)
}

Expand Down
21 changes: 21 additions & 0 deletions internal/pkg/print/print.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"errors"
"fmt"

"log/slog"
"os"
"os/exec"
Expand Down Expand Up @@ -130,6 +131,26 @@ func (p *Printer) PromptForConfirmation(prompt string) error {
return fmt.Errorf("max number of wrong inputs")
}

// Prompts the user for confirmation by pressing Enter.
//
// Returns nil only if the user (explicitly) press directly enter.
// Returns ErrAborted if the user press anything else before pressing enter.
func (p *Printer) PromptForEnter(prompt string) error {
reader := bufio.NewReaderSize(p.Cmd.InOrStdin(), 1)

p.Cmd.PrintErr(prompt)
answer, err := reader.ReadByte()
if err != nil {
return fmt.Errorf("read user response: %w", err)
}

// ASCII code for Enter (newline) is 10.
if answer == 10 {
return nil
}
return errAborted
}

// Shows the content in the command's stdout using the "less" command
// If output format is set to none, it does nothing
func (p *Printer) PagerDisplay(content string) error {
Expand Down