Skip to content

Commit c9852e6

Browse files
authored
Merge pull request #2294 from hajiler/wt-auth-error-fix
Add fail-close on auth error for GKE
2 parents 680e193 + bf43985 commit c9852e6

File tree

3 files changed

+73
-16
lines changed

3 files changed

+73
-16
lines changed

cmd/gce-pd-csi-driver/main.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ import (
5252

5353
var (
5454
cloudConfigFilePath = flag.String("cloud-config", "", "Path to GCE cloud provider config")
55+
failCloseOnAuthError = flag.Bool("fail-close-on-auth-error", false, "If set to true, the CSI driver will fail its controller service to start if it cannot fetch the initial authentication token during startup. Default is false (legacy behavior).")
5556
endpoint = flag.String("endpoint", "unix:/tmp/csi.sock", "CSI endpoint")
5657
runControllerService = flag.Bool("run-controller-service", true, "If set to false then the CSI driver does not activate its controller service (default: true)")
5758
runNodeService = flag.Bool("run-node-service", true, "If set to false then the CSI driver does not activate its node service (default: true)")
@@ -269,7 +270,7 @@ func handle() {
269270
// Initialize requirements for the controller service
270271
var controllerServer *driver.GCEControllerServer
271272
if *runControllerService {
272-
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, computeEndpoint, computeEnvironment, waitForAttachConfig, listInstancesConfig, *enableMultitenancyFlag)
273+
cloudProvider, err := gce.CreateCloudProvider(ctx, version, *cloudConfigFilePath, computeEndpoint, computeEnvironment, waitForAttachConfig, listInstancesConfig, *enableMultitenancyFlag, *failCloseOnAuthError)
273274
if err != nil {
274275
klog.Fatalf("Failed to get cloud provider: %v", err.Error())
275276
}

pkg/gce-cloud-provider/compute/gce.go

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ const (
7676
// gcpTagsRequestTokenBucketSize is the burst/token bucket size used
7777
// for limiting API requests.
7878
gcpTagsRequestTokenBucketSize = 8
79+
80+
pollTimeout = 30 * time.Second
7981
)
8082

8183
var (
@@ -148,7 +150,7 @@ type ConfigGlobal struct {
148150
Zone string `gcfg:"zone"`
149151
}
150152

151-
func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint *url.URL, computeEnvironment Environment, waitForAttachConfig WaitForAttachConfig, listInstancesConfig ListInstancesConfig, multiTenancyEnabled bool) (*CloudProvider, error) {
153+
func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath string, computeEndpoint *url.URL, computeEnvironment Environment, waitForAttachConfig WaitForAttachConfig, listInstancesConfig ListInstancesConfig, multiTenancyEnabled bool, failCloseOnAuthError bool) (*CloudProvider, error) {
152154
configFile, err := readConfig(configPath)
153155
if err != nil {
154156
return nil, err
@@ -163,13 +165,13 @@ func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath s
163165
return nil, err
164166
}
165167

166-
svc, err := createCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment)
168+
svc, err := createCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment, failCloseOnAuthError, pollTimeout)
167169
if err != nil {
168170
return nil, err
169171
}
170172
klog.Infof("Compute endpoint for V1 version: %s", svc.BasePath)
171173

172-
betasvc, err := createBetaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment)
174+
betasvc, err := createBetaCloudService(ctx, vendorVersion, tokenSource, computeEndpoint, computeEnvironment, failCloseOnAuthError, pollTimeout)
173175
if err != nil {
174176
return nil, err
175177
}
@@ -217,7 +219,7 @@ func CreateCloudProvider(ctx context.Context, vendorVersion string, configPath s
217219
return nil, fmt.Errorf("error during tenant token source generation: %w", err)
218220
}
219221

220-
tenantComputeService, err := createCloudService(ctx, vendorVersion, tenantTokenSource, computeEndpoint, computeEnvironment)
222+
tenantComputeService, err := createCloudService(ctx, vendorVersion, tenantTokenSource, computeEndpoint, computeEnvironment, failCloseOnAuthError, pollTimeout)
221223
if err != nil {
222224
klog.Errorf("Error while creating compute service with tenant identity for %s: %v", tenantMeta.TenantName, err)
223225
return nil, fmt.Errorf("error while creating compute service with tenant identity: %w", err)
@@ -291,10 +293,13 @@ func readConfig(configPath string) (*ConfigFile, error) {
291293
return cfg, nil
292294
}
293295

294-
func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment) (*computebeta.Service, error) {
295-
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, GCEAPIVersionBeta)
296+
func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment, failCloseOnAuthError bool, timeout time.Duration) (*computebeta.Service, error) {
297+
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, GCEAPIVersionBeta, timeout)
296298
if err != nil {
297299
klog.Errorf("Failed to get compute endpoint: %s", err)
300+
if failCloseOnAuthError {
301+
return nil, err
302+
}
298303
}
299304
service, err := computebeta.NewService(ctx, computeOpts...)
300305
if err != nil {
@@ -304,10 +309,13 @@ func createBetaCloudService(ctx context.Context, vendorVersion string, tokenSour
304309
return service, nil
305310
}
306311

307-
func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment) (*compute.Service, error) {
308-
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, GCEAPIVersionV1)
312+
func createCloudService(ctx context.Context, vendorVersion string, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment, failCloseOnAuthError bool, timeout time.Duration) (*compute.Service, error) {
313+
computeOpts, err := getComputeVersion(ctx, tokenSource, computeEndpoint, computeEnvironment, GCEAPIVersionV1, timeout)
309314
if err != nil {
310315
klog.Errorf("Failed to get compute endpoint: %s", err)
316+
if failCloseOnAuthError {
317+
return nil, err
318+
}
311319
}
312320
service, err := compute.NewService(ctx, computeOpts...)
313321
if err != nil {
@@ -317,8 +325,8 @@ func createCloudService(ctx context.Context, vendorVersion string, tokenSource o
317325
return service, nil
318326
}
319327

320-
func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment, computeVersion GCEAPIVersion) ([]option.ClientOption, error) {
321-
client, err := newOauthClient(ctx, tokenSource)
328+
func getComputeVersion(ctx context.Context, tokenSource oauth2.TokenSource, computeEndpoint *url.URL, computeEnvironment Environment, computeVersion GCEAPIVersion, timeout time.Duration) ([]option.ClientOption, error) {
329+
client, err := newOauthClient(ctx, tokenSource, timeout)
322330
if err != nil {
323331
return nil, err
324332
}
@@ -342,7 +350,7 @@ func constructComputeEndpointPath(env Environment, version GCEAPIVersion) string
342350
}
343351

344352
func createTagValuesClient(ctx context.Context, tokenSource oauth2.TokenSource, resourceManagerHostSubPath string) (*rscmgr.TagValuesClient, error) {
345-
client, err := newOauthClient(ctx, tokenSource)
353+
client, err := newOauthClient(ctx, tokenSource, pollTimeout)
346354
if err != nil {
347355
return nil, err
348356
}
@@ -356,7 +364,7 @@ func createTagValuesClient(ctx context.Context, tokenSource oauth2.TokenSource,
356364
}
357365

358366
func createTagBindingsClient(ctx context.Context, tokenSource oauth2.TokenSource, location string, resourceManagerHostSubPath string) (*rscmgr.TagBindingsClient, error) {
359-
client, err := newOauthClient(ctx, tokenSource)
367+
client, err := newOauthClient(ctx, tokenSource, pollTimeout)
360368
if err != nil {
361369
return nil, err
362370
}
@@ -374,8 +382,8 @@ func createTagBindingsClient(ctx context.Context, tokenSource oauth2.TokenSource
374382
return rscmgr.NewTagBindingsRESTClient(ctx, opts...)
375383
}
376384

377-
func newOauthClient(ctx context.Context, tokenSource oauth2.TokenSource) (*http.Client, error) {
378-
if err := wait.PollImmediate(5*time.Second, 30*time.Second, func() (bool, error) {
385+
func newOauthClient(ctx context.Context, tokenSource oauth2.TokenSource, timeout time.Duration) (*http.Client, error) {
386+
if err := wait.PollImmediate(5*time.Second, timeout, func() (bool, error) {
379387
if _, err := tokenSource.Token(); err != nil {
380388
klog.Errorf("error fetching initial token: %v", err.Error())
381389
return false, nil

pkg/gce-cloud-provider/compute/gce_test.go

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ func TestGetComputeVersion(t *testing.T) {
202202
}
203203
for _, tc := range testCases {
204204
ctx := context.Background()
205-
computeOpts, err := getComputeVersion(ctx, &mockTokenSource{}, tc.computeEndpoint, tc.computeEnvironment, tc.computeVersion)
205+
computeOpts, err := getComputeVersion(ctx, &mockTokenSource{}, tc.computeEndpoint, tc.computeEnvironment, tc.computeVersion, 0)
206206
service, _ := compute.NewService(ctx, computeOpts...)
207207
gotEndpoint := service.BasePath
208208
if err != nil && !tc.expectError {
@@ -222,3 +222,51 @@ func convertStringToURL(urlString string) *url.URL {
222222
}
223223
return parsedURL
224224
}
225+
226+
type errorTokenSource struct{}
227+
228+
func (*errorTokenSource) Token() (*oauth2.Token, error) {
229+
return nil, errors.New("auth failed transient error")
230+
}
231+
232+
func TestCreateCloudService(t *testing.T) {
233+
testCases := []struct {
234+
name string
235+
tokenSource oauth2.TokenSource
236+
failCloseOnAuthError bool
237+
expectError bool
238+
}{
239+
{
240+
name: "failClose=false, auth fails -> Success (legacy behavior, proceed with nil options)",
241+
tokenSource: &errorTokenSource{},
242+
failCloseOnAuthError: false,
243+
expectError: false,
244+
},
245+
{
246+
name: "failClose=true, auth fails -> Error (fail-close behavior)",
247+
tokenSource: &errorTokenSource{},
248+
failCloseOnAuthError: true,
249+
expectError: true,
250+
},
251+
{
252+
name: "failClose=true, auth succeeds -> Success",
253+
tokenSource: &mockTokenSource{},
254+
failCloseOnAuthError: true,
255+
expectError: false,
256+
},
257+
}
258+
259+
for _, tc := range testCases {
260+
t.Run(tc.name, func(t *testing.T) {
261+
ctx := context.Background()
262+
// We use dummy values for other parameters as they are not critical for auth failure testing
263+
_, err := createCloudService(ctx, "test-version", tc.tokenSource, nil, EnvironmentProduction, tc.failCloseOnAuthError, 1*time.Nanosecond)
264+
if tc.expectError && err == nil {
265+
t.Fatalf("Expected error, but got nil")
266+
}
267+
if !tc.expectError && err != nil {
268+
t.Fatalf("Expected no error, but got: %v", err)
269+
}
270+
})
271+
}
272+
}

0 commit comments

Comments
 (0)