Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
govulncheck

docsPython

yazi
];
};
}
Expand Down
102 changes: 75 additions & 27 deletions internal/coderbootstrap/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"context"
"errors"
"net/http"
"net/url"
"time"

"golang.org/x/xerrors"
Expand All @@ -17,6 +16,65 @@ const (
coderSDKRequestTimeout = 30 * time.Second
)

type bypassRateLimitContextKey struct{}

func withRateLimitBypass(ctx context.Context) context.Context {
if ctx == nil {
return nil
}
return context.WithValue(ctx, bypassRateLimitContextKey{}, true)
}

func shouldBypassRateLimit(ctx context.Context) bool {
if ctx == nil {
return false
}
enabled, _ := ctx.Value(bypassRateLimitContextKey{}).(bool)
return enabled
}

func isRateLimitBypassRejected(err error) bool {
var apiErr *codersdk.Error
return errors.As(err, &apiErr) && apiErr.StatusCode() == http.StatusPreconditionRequired
}

// IsRateLimitError reports whether err (or any wrapped cause) is a codersdk
// API error with HTTP 429 Too Many Requests.
func IsRateLimitError(err error) bool {
var apiErr *codersdk.Error
return errors.As(err, &apiErr) && apiErr.StatusCode() == http.StatusTooManyRequests
}

func withOptionalRateLimitBypass[T any](ctx context.Context, operation func(context.Context) (T, error)) (T, error) {
result, err := operation(withRateLimitBypass(ctx))
if err == nil {
return result, nil
}
if !isRateLimitBypassRejected(err) {
return result, err
}
return operation(ctx)
}

type bypassRateLimitRoundTripper struct {
base http.RoundTripper
}

func (rt bypassRateLimitRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if req == nil {
return nil, xerrors.New("assertion failed: request must not be nil")
}
if !shouldBypassRateLimit(req.Context()) {
return rt.base.RoundTrip(req)
}

cloned := req.Clone(req.Context())
cloned.Header = req.Header.Clone()
cloned.Header.Set(codersdk.BypassRatelimitHeader, "true")

return rt.base.RoundTrip(cloned)
}

// RegisterWorkspaceProxyRequest describes how to register a workspace proxy in Coder.
type RegisterWorkspaceProxyRequest struct {
CoderURL string
Expand Down Expand Up @@ -60,24 +118,10 @@ func (c *SDKClient) EnsureWorkspaceProxy(ctx context.Context, req RegisterWorksp
return RegisterWorkspaceProxyResponse{}, xerrors.New("proxy name is required")
}

coderURL, err := url.Parse(req.CoderURL)
client, err := newAuthenticatedClient(req.CoderURL, req.SessionToken)
if err != nil {
return RegisterWorkspaceProxyResponse{}, xerrors.Errorf("parse coder URL: %w", err)
}

client := codersdk.New(coderURL)
client.SetSessionToken(req.SessionToken)
if client.HTTPClient == nil {
client.HTTPClient = &http.Client{}
}
defaultTransport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
return RegisterWorkspaceProxyResponse{}, xerrors.New("assertion failed: http.DefaultTransport is not *http.Transport")
return RegisterWorkspaceProxyResponse{}, err
}
// Use a dedicated transport to avoid sharing http.DefaultTransport's
// connection pool across parallel test servers.
client.HTTPClient.Transport = defaultTransport.Clone()
client.HTTPClient.Timeout = coderSDKRequestTimeout

existing, err := client.WorkspaceProxyByName(ctx, req.ProxyName)
if err != nil {
Expand All @@ -86,10 +130,12 @@ func (c *SDKClient) EnsureWorkspaceProxy(ctx context.Context, req RegisterWorksp
return RegisterWorkspaceProxyResponse{}, xerrors.Errorf("query workspace proxy %q: %w", req.ProxyName, err)
}

created, createErr := client.CreateWorkspaceProxy(ctx, codersdk.CreateWorkspaceProxyRequest{
Name: req.ProxyName,
DisplayName: req.DisplayName,
Icon: req.Icon,
created, createErr := withOptionalRateLimitBypass(ctx, func(requestCtx context.Context) (codersdk.UpdateWorkspaceProxyResponse, error) {
return client.CreateWorkspaceProxy(requestCtx, codersdk.CreateWorkspaceProxyRequest{
Name: req.ProxyName,
DisplayName: req.DisplayName,
Icon: req.Icon,
})
})
if createErr != nil {
return RegisterWorkspaceProxyResponse{}, xerrors.Errorf("create workspace proxy %q: %w", req.ProxyName, createErr)
Expand All @@ -109,12 +155,14 @@ func (c *SDKClient) EnsureWorkspaceProxy(ctx context.Context, req RegisterWorksp
icon = existing.IconURL
}

updated, err := client.PatchWorkspaceProxy(ctx, codersdk.PatchWorkspaceProxy{
ID: existing.ID,
Name: existing.Name,
DisplayName: displayName,
Icon: icon,
RegenerateToken: true,
updated, err := withOptionalRateLimitBypass(ctx, func(requestCtx context.Context) (codersdk.UpdateWorkspaceProxyResponse, error) {
return client.PatchWorkspaceProxy(requestCtx, codersdk.PatchWorkspaceProxy{
ID: existing.ID,
Name: existing.Name,
DisplayName: displayName,
Icon: icon,
RegenerateToken: true,
})
})
if err != nil {
return RegisterWorkspaceProxyResponse{}, xerrors.Errorf("update workspace proxy %q: %w", req.ProxyName, err)
Expand Down
67 changes: 67 additions & 0 deletions internal/coderbootstrap/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"
"time"

"github.com/coder/coder/v2/codersdk"
"github.com/google/uuid"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -200,6 +201,72 @@ func TestEnsureWorkspaceProxyUpdatesExistingProxy(t *testing.T) {
require.Equal(t, "token-updated", result.ProxyToken)
}

func TestEnsureWorkspaceProxyCreateFallsBackWhenRateLimitBypassRejected(t *testing.T) {
t.Parallel()

const proxyName = "proxy-fallback"
now := time.Now().UTC().Format(time.RFC3339)
proxyID := uuid.NewString()
createCalls := 0
bypassRejectedCalls := 0

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && r.URL.Path == "/api/v2/workspaceproxies/"+proxyName:
w.WriteHeader(http.StatusNotFound)
_ = json.NewEncoder(w).Encode(map[string]string{"message": "not found"})
return
case r.Method == http.MethodPost && r.URL.Path == "/api/v2/workspaceproxies":
createCalls++
if r.Header.Get(codersdk.BypassRatelimitHeader) == "true" {
bypassRejectedCalls++
w.WriteHeader(http.StatusPreconditionRequired)
_ = json.NewEncoder(w).Encode(map[string]string{"message": "bypass is not allowed"})
return
}

response := proxyResponse{}
response.Proxy.ID = proxyID
response.Proxy.Name = proxyName
response.Proxy.DisplayName = "Proxy Fallback"
response.Proxy.IconURL = "/emojis/1f5fa.png"
response.Proxy.Healthy = true
response.Proxy.PathAppURL = "https://proxy-fallback.example.com"
response.Proxy.WildcardHostname = "*.proxy-fallback.example.com"
response.Proxy.Status.Status = "unregistered"
response.Proxy.Status.CheckedAt = now
response.Proxy.CreatedAt = now
response.Proxy.UpdatedAt = now
response.Proxy.Version = "2.0.0"
response.ProxyToken = "token-created"

w.WriteHeader(http.StatusCreated)
err := json.NewEncoder(w).Encode(response)
require.NoError(t, err)
return
default:
w.WriteHeader(http.StatusNotFound)
_ = json.NewEncoder(w).Encode(map[string]string{"message": "unexpected route"})
return
}
}))
defer server.Close()

client := coderbootstrap.NewSDKClient()
result, err := client.EnsureWorkspaceProxy(context.Background(), coderbootstrap.RegisterWorkspaceProxyRequest{
CoderURL: server.URL,
SessionToken: "session-token",
ProxyName: proxyName,
DisplayName: "Proxy Fallback",
Icon: "/emojis/1f5fa.png",
})
require.NoError(t, err)
require.Equal(t, proxyName, result.ProxyName)
require.Equal(t, "token-created", result.ProxyToken)
require.Equal(t, 2, createCalls)
require.Equal(t, 1, bypassRejectedCalls)
}

func TestEnsureWorkspaceProxyValidatesInputs(t *testing.T) {
t.Parallel()

Expand Down
18 changes: 12 additions & 6 deletions internal/coderbootstrap/provisionerkeys.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ func (c *SDKClient) EnsureProvisionerKey(ctx context.Context, req EnsureProvisio
}, nil
}

created, err := client.CreateProvisionerKey(ctx, organization.ID, codersdk.CreateProvisionerKeyRequest{
Name: req.KeyName,
Tags: req.Tags,
created, err := withOptionalRateLimitBypass(ctx, func(requestCtx context.Context) (codersdk.CreateProvisionerKeyResponse, error) {
return client.CreateProvisionerKey(requestCtx, organization.ID, codersdk.CreateProvisionerKeyRequest{
Name: req.KeyName,
Tags: req.Tags,
})
})
if err != nil {
return EnsureProvisionerKeyResponse{}, xerrors.Errorf("create provisioner key %q: %w", req.KeyName, err)
Expand All @@ -78,7 +80,9 @@ func (c *SDKClient) EnsureProvisionerKey(ctx context.Context, req EnsureProvisio
return EnsureProvisionerKeyResponse{}, xerrors.Errorf("assertion failed: created provisioner key %q returned an empty key", req.KeyName)
}

createdMetadata, err := findOrganizationProvisionerKey(ctx, client, organization.ID, req.KeyName)
createdMetadata, err := withOptionalRateLimitBypass(ctx, func(requestCtx context.Context) (*codersdk.ProvisionerKey, error) {
return findOrganizationProvisionerKey(requestCtx, client, organization.ID, req.KeyName)
})
if err != nil {
return EnsureProvisionerKeyResponse{}, xerrors.Errorf("query created provisioner key %q: %w", req.KeyName, err)
}
Expand Down Expand Up @@ -119,7 +123,9 @@ func (c *SDKClient) DeleteProvisionerKey(ctx context.Context, coderURL, sessionT
return err
}

err = client.DeleteProvisionerKey(ctx, organization.ID, keyName)
_, err = withOptionalRateLimitBypass(ctx, func(requestCtx context.Context) (struct{}, error) {
return struct{}{}, client.DeleteProvisionerKey(requestCtx, organization.ID, keyName)
})
if err == nil {
return nil
}
Expand Down Expand Up @@ -166,7 +172,7 @@ func newAuthenticatedClient(coderURL, sessionToken string) (*codersdk.Client, er
}
// Use a dedicated transport to avoid sharing http.DefaultTransport's
// connection pool across parallel test servers.
client.HTTPClient.Transport = defaultTransport.Clone()
client.HTTPClient.Transport = bypassRateLimitRoundTripper{base: defaultTransport.Clone()}
client.HTTPClient.Timeout = coderSDKRequestTimeout

return client, nil
Expand Down
Loading