Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions builtin/credential/okta/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/cidrutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/okta/okta-sdk-golang/v2/okta"
"github.com/okta/okta-sdk-golang/v5/okta"
"github.com/patrickmn/go-cache"
)

Expand Down Expand Up @@ -118,6 +118,7 @@ func (b *backend) Login(ctx context.Context, req *logical.Request, username, pas
StateToken string `json:"stateToken"`
}

// The okta-sdk-golang API says to construct your own requests for auth, and the Request Executor is gone, so
authReq, err := shim.NewRequest("POST", "authn", map[string]interface{}{
"username": username,
"password": password,
Expand All @@ -129,9 +130,6 @@ func (b *backend) Login(ctx context.Context, req *logical.Request, username, pas
var result authResult
rsp, err := shim.Do(authReq, &result)
if err != nil {
if oe, ok := err.(*okta.Error); ok {
return nil, logical.ErrorResponse("Okta auth failed: %v (code=%v)", err, oe.ErrorCode), nil, nil
}
return nil, logical.ErrorResponse(fmt.Sprintf("Okta auth failed: %v", err)), nil, nil
}
if rsp == nil {
Expand Down Expand Up @@ -370,23 +368,23 @@ func (b *backend) Login(ctx context.Context, req *logical.Request, username, pas
return policies, oktaResponse, allGroups, nil
}

func (b *backend) getOktaGroups(ctx context.Context, client *okta.Client, user *okta.User) ([]string, error) {
groups, resp, err := client.User.ListUserGroups(ctx, user.Id)
func (b *backend) getOktaGroups(ctx context.Context, client *okta.APIClient, user *okta.User) ([]string, error) {
groups, resp, err := client.UserAPI.ListUserGroups(ctx, user.GetId()).Execute()
if err != nil {
return nil, err
}
oktaGroups := make([]string, 0, len(groups))
for _, group := range groups {
oktaGroups = append(oktaGroups, group.Profile.Name)
oktaGroups = append(oktaGroups, group.Profile.GetName())
}
for resp.HasNextPage() {
var nextGroups []*okta.Group
resp, err = resp.Next(ctx, &nextGroups)
resp, err = resp.Next(&nextGroups)
if err != nil {
return nil, err
}
for _, group := range nextGroups {
oktaGroups = append(oktaGroups, group.Profile.Name)
oktaGroups = append(oktaGroups, group.Profile.GetName())
}
}
if b.Logger().IsDebug() {
Expand Down
34 changes: 16 additions & 18 deletions builtin/credential/okta/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ import (
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/helper/policyutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/okta/okta-sdk-golang/v2/okta"
"github.com/okta/okta-sdk-golang/v2/okta/query"
"github.com/okta/okta-sdk-golang/v5/okta"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -115,15 +114,15 @@ func TestBackend_Config(t *testing.T) {

func createOktaGroups(t *testing.T, username string, token string, org string) []string {
orgURL := "https://" + org + "." + previewBaseURL
ctx, client, err := okta.NewClient(context.Background(), okta.WithOrgUrl(orgURL), okta.WithToken(token))
cfg, err := okta.NewConfiguration(okta.WithOrgUrl(orgURL), okta.WithToken(token))
require.Nil(t, err)
client := okta.NewAPIClient(cfg)
ctx := context.Background()

users, _, err := client.User.ListUsers(ctx, &query.Params{
Q: username,
})
users, _, err := client.UserAPI.ListUsers(ctx).Q(username).Execute()
require.Nil(t, err)
require.Len(t, users, 1)
userID := users[0].Id
userID := users[0].GetId()
var groupIDs []string

// Verify that login's call to list the groups of the user logging in will page
Expand All @@ -133,38 +132,37 @@ func createOktaGroups(t *testing.T, username string, token string, org string) [
// only 200 results are returned for most orgs."
for i := 0; i < 201; i++ {
name := fmt.Sprintf("TestGroup%d", i)
groups, _, err := client.Group.ListGroups(ctx, &query.Params{
Q: name,
})
groups, _, err := client.GroupAPI.ListGroups(ctx).Q(name).Execute()
require.Nil(t, err)

var groupID string
if len(groups) == 0 {
group, _, err := client.Group.CreateGroup(ctx, okta.Group{
group, _, err := client.GroupAPI.CreateGroup(ctx).Group(okta.Group{
Profile: &okta.GroupProfile{
Name: fmt.Sprintf("TestGroup%d", i),
Name: okta.PtrString(fmt.Sprintf("TestGroup%d", i)),
},
})
}).Execute()
require.Nil(t, err)
groupID = group.Id
groupID = group.GetId()
} else {
groupID = groups[0].Id
groupID = groups[0].GetId()
}
groupIDs = append(groupIDs, groupID)

_, err = client.Group.AddUserToGroup(ctx, groupID, userID)
_, err = client.GroupAPI.AssignUserToGroup(ctx, groupID, userID).Execute()
require.Nil(t, err)
}
return groupIDs
}

func deleteOktaGroups(t *testing.T, token string, org string, groupIDs []string) {
orgURL := "https://" + org + "." + previewBaseURL
ctx, client, err := okta.NewClient(context.Background(), okta.WithOrgUrl(orgURL), okta.WithToken(token))
cfg, err := okta.NewConfiguration(okta.WithOrgUrl(orgURL), okta.WithToken(token))
require.Nil(t, err)
client := okta.NewAPIClient(cfg)

for _, groupID := range groupIDs {
_, err := client.Group.DeleteGroup(ctx, groupID)
_, err := client.GroupAPI.DeleteGroup(context.Background(), groupID).Execute()
require.Nil(t, err)
}
}
Expand Down
133 changes: 123 additions & 10 deletions builtin/credential/okta/path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
package okta

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
Expand All @@ -16,7 +19,7 @@ import (
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/tokenutil"
"github.com/hashicorp/vault/sdk/logical"
oktanew "github.com/okta/okta-sdk-golang/v2/okta"
oktanew "github.com/okta/okta-sdk-golang/v5/okta"
)

const (
Expand Down Expand Up @@ -290,36 +293,128 @@ func (b *backend) pathConfigExistenceCheck(ctx context.Context, req *logical.Req
}

type oktaShim interface {
Client() (*oktanew.Client, context.Context)
Client() (*oktanew.APIClient, context.Context)
NewRequest(method string, url string, body interface{}) (*http.Request, error)
Do(req *http.Request, v interface{}) (interface{}, error)
}

type oktaShimNew struct {
client *oktanew.Client
cfg *oktanew.Configuration
client *oktanew.APIClient
ctx context.Context
}

func (new *oktaShimNew) Client() (*oktanew.Client, context.Context) {
func (new *oktaShimNew) Client() (*oktanew.APIClient, context.Context) {
return new.client, new.ctx
}

func (new *oktaShimNew) NewRequest(method string, url string, body interface{}) (*http.Request, error) {
if !strings.HasPrefix(url, "/") {
url = "/api/v1/" + url
}
return new.client.GetRequestExecutor().NewRequest(method, url, body)

// reimplementation of RequestExecutor.NewRequest() in v2 of okta-golang-sdk
var buff io.ReadWriter
if body != nil {
switch v := body.(type) {
case []byte:
buff = bytes.NewBuffer(v)
case *bytes.Buffer:
buff = v
default:
buff = &bytes.Buffer{}
// need to create an encoder specifically to disable html escaping
encoder := json.NewEncoder(buff)
encoder.SetEscapeHTML(false)
err := encoder.Encode(body)
if err != nil {
return nil, err
}
}
}

url = new.cfg.Okta.Client.OrgUrl + url
req, err := http.NewRequest(method, url, buff)
if err != nil {
return nil, err
}
//
var auth oktanew.Authorization
//
switch new.cfg.Okta.Client.AuthorizationMode {
case "SSWS":
auth = oktanew.NewSSWSAuth(new.cfg.Okta.Client.Token, req)
case "Bearer":
auth = oktanew.NewBearerAuth(new.cfg.Okta.Client.Token, req)
case "PrivateKey":
auth = oktanew.NewPrivateKeyAuth(oktanew.PrivateKeyAuthConfig{
// TokenCache: new.cfg., hmm
HttpClient: new.cfg.HTTPClient,
PrivateKeySigner: new.cfg.PrivateKeySigner,
PrivateKey: new.cfg.Okta.Client.PrivateKey,
PrivateKeyId: new.cfg.Okta.Client.PrivateKeyId,
ClientId: new.cfg.Okta.Client.ClientId,
OrgURL: new.cfg.Okta.Client.OrgUrl,
Scopes: new.cfg.Okta.Client.Scopes,
MaxRetries: new.cfg.Okta.Client.RateLimit.MaxRetries,
MaxBackoff: new.cfg.Okta.Client.RateLimit.MaxBackoff,
Req: req,
})
case "JWT":
auth = oktanew.NewJWTAuth(oktanew.JWTAuthConfig{
HttpClient: new.cfg.HTTPClient,
OrgURL: new.cfg.Okta.Client.OrgUrl,
Scopes: new.cfg.Okta.Client.Scopes,
ClientAssertion: new.cfg.Okta.Client.ClientAssertion,
MaxRetries: new.cfg.Okta.Client.RateLimit.MaxRetries,
MaxBackoff: new.cfg.Okta.Client.RateLimit.MaxBackoff,
Req: req,
})
default:
return nil, fmt.Errorf("unknown authorization mode %v", new.cfg.Okta.Client.AuthorizationMode)
}

err = auth.Authorize("POST", url)
if err != nil {
return nil, err
}

// req.Header.Add("User-Agent", NewUserAgent(re.config).String())
req.Header.Add("Accept", "application/json")

if body != nil {
req.Header.Set("Content-Type", "application/json")
}

return req, nil
}

func (new *oktaShimNew) Do(req *http.Request, v interface{}) (interface{}, error) {
return new.client.GetRequestExecutor().Do(new.ctx, req, v)
resp, err := new.cfg.HTTPClient.Do(req)
if err != nil {
return nil, err
}

if resp.Body == nil {
return nil, nil
}
defer resp.Body.Close()

bt, err := io.ReadAll(resp.Body)
err = json.Unmarshal(bt, v)
if err != nil {
return nil, err
}

// as far as i can tell, we only use the first return to check if it is nil, and assume that means an error happened.
return resp, nil
}

type oktaShimOld struct {
client *oktaold.Client
}

func (new *oktaShimOld) Client() (*oktanew.Client, context.Context) {
func (new *oktaShimOld) Client() (*oktanew.APIClient, context.Context) {
return nil, nil
}

Expand All @@ -331,7 +426,25 @@ func (new *oktaShimOld) Do(req *http.Request, v interface{}) (interface{}, error
return new.client.Do(req, v)
}

// OktaClient creates a basic okta client connection
func (c *ConfigEntry) OktaConfiguration(ctx context.Context) (*oktanew.Configuration, error) {
baseURL := defaultBaseURL
if c.Production != nil {
if !*c.Production {
baseURL = previewBaseURL
}
}
if c.BaseURL != "" {
baseURL = c.BaseURL
}

cfg, err := oktanew.NewConfiguration(oktanew.WithOrgUrl("https://"+c.Org+"."+baseURL), oktanew.WithToken(c.Token))
if err != nil {
return nil, err
}
return cfg, nil
}

// OktaClient returns an OktaShim, based on the presence of a token in the ConfigEntry.
func (c *ConfigEntry) OktaClient(ctx context.Context) (oktaShim, error) {
baseURL := defaultBaseURL
if c.Production != nil {
Expand All @@ -344,13 +457,13 @@ func (c *ConfigEntry) OktaClient(ctx context.Context) (oktaShim, error) {
}

if c.Token != "" {
ctx, client, err := oktanew.NewClient(ctx,
cfg, err := oktanew.NewConfiguration(
oktanew.WithOrgUrl("https://"+c.Org+"."+baseURL),
oktanew.WithToken(c.Token))
if err != nil {
return nil, err
}
return &oktaShimNew{client, ctx}, nil
return &oktaShimNew{cfg, oktanew.NewAPIClient(cfg), ctx}, nil
}
client, err := oktaold.NewClientWithDomain(cleanhttp.DefaultClient(), c.Org, baseURL, "")
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions changelog/28121.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
auth/okta: update to okta sdk v4
```
9 changes: 8 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ require (
github.com/mitchellh/reflectwalk v1.0.2
github.com/ncw/swift v1.0.47
github.com/oklog/run v1.1.0
github.com/okta/okta-sdk-golang/v2 v2.20.0
github.com/okta/okta-sdk-golang/v5 v5.0.2
github.com/oracle/oci-go-sdk v24.3.0+incompatible
github.com/ory/dockertest v3.3.5+incompatible
github.com/ory/dockertest/v3 v3.10.0
Expand Down Expand Up @@ -230,11 +230,18 @@ require (
require (
cel.dev/expr v0.15.0 // indirect
cloud.google.com/go/longrunning v0.6.0 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/fxamacker/cbor/v2 v2.7.0 // indirect
github.com/go-viper/mapstructure/v2 v2.1.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect
github.com/hashicorp/go-secure-stdlib/httputil v0.1.0 // indirect
github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
github.com/lestrrat-go/iter v1.0.2 // indirect
github.com/lestrrat-go/jwx v1.2.29 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
github.com/mitchellh/go-testing-interface v1.14.1 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/sys/userns v0.1.0 // indirect
Expand Down
Loading