diff --git a/cmd/bundle/sync.go b/cmd/bundle/sync.go index 462ca4c19a0..f65a5780bd1 100644 --- a/cmd/bundle/sync.go +++ b/cmd/bundle/sync.go @@ -17,11 +17,13 @@ import ( ) type syncFlags struct { - interval time.Duration - full bool - watch bool - output flags.Output - dryRun bool + interval time.Duration + full bool + watch bool + output flags.Output + dryRun bool + concurrency int + retryTimeout time.Duration } func (f *syncFlags) syncOptionsFromBundle(cmd *cobra.Command, b *bundle.Bundle) (*sync.SyncOptions, error) { @@ -48,6 +50,8 @@ func (f *syncFlags) syncOptionsFromBundle(cmd *cobra.Command, b *bundle.Bundle) opts.Full = f.full opts.PollInterval = f.interval opts.DryRun = f.dryRun + opts.Concurrency = f.concurrency + opts.RetryTimeout = f.retryTimeout return opts, nil } @@ -74,8 +78,21 @@ Use 'databricks bundle deploy' for full resource deployment.`, cmd.Flags().BoolVar(&f.watch, "watch", false, "watch local file system for changes") cmd.Flags().Var(&f.output, "output", "type of the output format") cmd.Flags().BoolVar(&f.dryRun, "dry-run", false, "simulate sync execution without making actual changes") + cmd.Flags().IntVar(&f.concurrency, "concurrency", 5, "maximum number of concurrent in-flight requests during sync") + cmd.Flags().DurationVar(&f.retryTimeout, "retry-timeout", sync.DefaultRetryTimeout, "per-call deadline for retrying transient gateway errors (HTTP 502/503/504); 0 disables retries") cmd.RunE = func(cmd *cobra.Command, args []string) error { + if f.concurrency < 1 { + return fmt.Errorf("--concurrency must be a positive integer, got %d", f.concurrency) + } + if f.retryTimeout < 0 { + return fmt.Errorf("--retry-timeout must be non-negative, got %s", f.retryTimeout) + } + if f.retryTimeout == 0 { + // sync.SyncOptions treats zero as "use the default"; a negative value is what disables retries. + f.retryTimeout = -1 + } + b, err := utils.ProcessBundle(cmd, utils.ProcessOptions{}) if err != nil { return err diff --git a/cmd/sync/sync.go b/cmd/sync/sync.go index 5c110cecb2b..652d3ecc9db 100644 --- a/cmd/sync/sync.go +++ b/cmd/sync/sync.go @@ -26,15 +26,17 @@ import ( type syncFlags struct { // project files polling 
interval - interval time.Duration - full bool - watch bool - output flags.Output - exclude []string - include []string - dryRun bool - excludeFrom string - includeFrom string + interval time.Duration + full bool + watch bool + output flags.Output + exclude []string + include []string + dryRun bool + excludeFrom string + includeFrom string + concurrency int + retryTimeout time.Duration } func readPatternsFile(filePath string) ([]string, error) { @@ -89,6 +91,8 @@ func (f *syncFlags) syncOptionsFromBundle(cmd *cobra.Command, args []string, b * opts.Include = append(opts.Include, f.include...) opts.Include = append(opts.Include, includePatterns...) opts.DryRun = f.dryRun + opts.Concurrency = f.concurrency + opts.RetryTimeout = f.retryTimeout return opts, nil } @@ -163,6 +167,8 @@ func (f *syncFlags) syncOptionsFromArgs(cmd *cobra.Command, args []string) (*syn OutputHandler: outputHandler, DryRun: f.dryRun, + Concurrency: f.concurrency, + RetryTimeout: f.retryTimeout, } return &opts, nil } @@ -187,6 +193,8 @@ func New() *cobra.Command { cmd.Flags().StringVar(&f.excludeFrom, "exclude-from", "", "file containing patterns to exclude from sync (one pattern per line)") cmd.Flags().StringVar(&f.includeFrom, "include-from", "", "file containing patterns to include to sync (one pattern per line)") cmd.Flags().BoolVar(&f.dryRun, "dry-run", false, "simulate sync execution without making actual changes") + cmd.Flags().IntVar(&f.concurrency, "concurrency", 5, "maximum number of concurrent in-flight requests during sync") + cmd.Flags().DurationVar(&f.retryTimeout, "retry-timeout", sync.DefaultRetryTimeout, "per-call deadline for retrying transient gateway errors (HTTP 502/503/504); 0 disables retries") // Wrapper for [root.MustWorkspaceClient] that disables loading authentication configuration from a bundle. 
mustWorkspaceClient := func(cmd *cobra.Command, args []string) error { @@ -196,6 +204,17 @@ func New() *cobra.Command { cmd.PreRunE = mustWorkspaceClient cmd.RunE = func(cmd *cobra.Command, args []string) error { + if f.concurrency < 1 { + return fmt.Errorf("--concurrency must be a positive integer, got %d", f.concurrency) + } + if f.retryTimeout < 0 { + return fmt.Errorf("--retry-timeout must be non-negative, got %s", f.retryTimeout) + } + if f.retryTimeout == 0 { + // sync.SyncOptions treats zero as "use the default"; a negative value is what disables retries. + f.retryTimeout = -1 + } + var opts *sync.SyncOptions var err error diff --git a/libs/sync/retry.go b/libs/sync/retry.go new file mode 100644 index 00000000000..7342a48cc4f --- /dev/null +++ b/libs/sync/retry.go @@ -0,0 +1,50 @@ +package sync + +import ( + "context" + "errors" + "net/http" + "time" + + "github.com/databricks/cli/libs/log" + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/retries" +) + +// DefaultRetryTimeout bounds how long the sync layer keeps retrying transient +// gateway errors per filer call. The SDK only retries 429 and 504 +// (httpclient/errors.go DefaultErrorRetriable); 502 and 503 land here. +const DefaultRetryTimeout = 30 * time.Second + +func isTransientGatewayError(err error) bool { + var aerr *apierr.APIError + if !errors.As(err, &aerr) { + return false + } + switch aerr.StatusCode { + case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: + return true + } + return false +} + +// retryOnTransient runs fn, retrying transient gateway errors +// (HTTP 502/503/504) until timeout elapses. Backoff and jitter are provided +// by retries.Poll. 
+func retryOnTransient(ctx context.Context, timeout time.Duration, label string, fn func() error) error { + if timeout <= 0 { + return fn() + } + _, err := retries.Poll(ctx, timeout, func() (*struct{}, *retries.Err) { + err := fn() + if err == nil { + return nil, nil + } + if !isTransientGatewayError(err) { + return nil, retries.Halt(err) + } + log.Warnf(ctx, "sync %s: retrying after transient error: %s", label, err) + return nil, retries.Continue(err) + }) + return err +} diff --git a/libs/sync/retry_test.go b/libs/sync/retry_test.go new file mode 100644 index 00000000000..25896bf3ed2 --- /dev/null +++ b/libs/sync/retry_test.go @@ -0,0 +1,86 @@ +package sync + +import ( + "context" + "errors" + "net/http" + "sync/atomic" + "testing" + "time" + + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/stretchr/testify/require" +) + +func apiErr(status int) error { + return &apierr.APIError{StatusCode: status, Message: http.StatusText(status)} +} + +func TestIsTransientGatewayError(t *testing.T) { + cases := map[error]bool{ + nil: false, + apiErr(http.StatusBadGateway): true, + apiErr(http.StatusServiceUnavailable): true, + apiErr(http.StatusGatewayTimeout): true, + apiErr(http.StatusInternalServerError): false, + apiErr(http.StatusTooManyRequests): false, + apiErr(http.StatusNotFound): false, + errors.New("not an api error"): false, + } + for err, want := range cases { + require.Equal(t, want, isTransientGatewayError(err), "%v", err) + } +} + +func TestRetryOnTransient(t *testing.T) { + t.Run("succeeds after retries", func(t *testing.T) { + var calls atomic.Int32 + err := retryOnTransient(t.Context(), 30*time.Second, "test", func() error { + if calls.Add(1) < 3 { + return apiErr(http.StatusBadGateway) + } + return nil + }) + require.NoError(t, err) + require.Equal(t, int32(3), calls.Load()) + }) + + t.Run("does not retry non-transient", func(t *testing.T) { + var calls atomic.Int32 + err := retryOnTransient(t.Context(), 30*time.Second, "test", func() error { 
+ calls.Add(1) + return apiErr(http.StatusNotFound) + }) + require.Error(t, err) + require.Equal(t, int32(1), calls.Load()) + }) + + t.Run("zero timeout disables retries", func(t *testing.T) { + var calls atomic.Int32 + err := retryOnTransient(t.Context(), 0, "test", func() error { + calls.Add(1) + return apiErr(http.StatusBadGateway) + }) + require.Error(t, err) + require.Equal(t, int32(1), calls.Load()) + }) + + t.Run("times out on persistent transient error", func(t *testing.T) { + var calls atomic.Int32 + err := retryOnTransient(t.Context(), 100*time.Millisecond, "test", func() error { + calls.Add(1) + return apiErr(http.StatusBadGateway) + }) + require.Error(t, err) + require.GreaterOrEqual(t, calls.Load(), int32(1)) + }) + + t.Run("honors context cancellation", func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + cancel() + err := retryOnTransient(ctx, 30*time.Second, "test", func() error { + return apiErr(http.StatusBadGateway) + }) + require.Error(t, err) + }) +} diff --git a/libs/sync/sync.go b/libs/sync/sync.go index c65b49eb775..3484335237e 100644 --- a/libs/sync/sync.go +++ b/libs/sync/sync.go @@ -43,6 +43,15 @@ type SyncOptions struct { OutputHandler OutputHandler DryRun bool + + // Concurrency is the maximum number of in-flight filer requests during sync. + // Defaults to MaxRequestsInFlight when zero or negative. + Concurrency int + + // RetryTimeout bounds how long each filer call may keep retrying transient + // gateway errors (HTTP 502/503/504). Defaults to DefaultRetryTimeout when + // zero; a negative value disables sync-layer retries. + RetryTimeout time.Duration } type Sync struct { @@ -96,6 +105,13 @@ func New(ctx context.Context, opts SyncOptions) (*Sync, error) { return nil, errors.New("failed to resolve host for snapshot") } + if opts.Concurrency <= 0 { + opts.Concurrency = MaxRequestsInFlight + } + if opts.RetryTimeout == 0 { + opts.RetryTimeout = DefaultRetryTimeout + } + // For full sync, we start with an empty snapshot. 
// For incremental sync, we try to load an existing snapshot to start from. var snapshot *Snapshot @@ -119,7 +135,7 @@ func New(ctx context.Context, opts SyncOptions) (*Sync, error) { var notifier EventNotifier outputWaitGroup := &stdsync.WaitGroup{} if opts.OutputHandler != nil { - ch := make(chan Event, MaxRequestsInFlight) + ch := make(chan Event, opts.Concurrency) notifier = &ChannelNotifier{ch} outputWaitGroup.Go(func() { opts.OutputHandler(ctx, ch) diff --git a/libs/sync/watchdog.go b/libs/sync/watchdog.go index 4a47acfb836..4078dd6eacf 100644 --- a/libs/sync/watchdog.go +++ b/libs/sync/watchdog.go @@ -10,7 +10,8 @@ import ( "golang.org/x/sync/errgroup" ) -// Maximum number of concurrent requests during sync. +// Default maximum number of concurrent requests during sync. Override with +// SyncOptions.Concurrency. const MaxRequestsInFlight = 20 // Delete the specified path. @@ -18,7 +19,9 @@ func (s *Sync) applyDelete(ctx context.Context, remoteName string) error { s.notifyProgress(ctx, EventActionDelete, remoteName, 0.0) if !s.DryRun { - err := s.filer.Delete(ctx, remoteName) + err := retryOnTransient(ctx, s.RetryTimeout, "delete "+remoteName, func() error { + return s.filer.Delete(ctx, remoteName) + }) if err != nil && !errors.Is(err, fs.ErrNotExist) { return err } @@ -33,7 +36,9 @@ func (s *Sync) applyRmdir(ctx context.Context, remoteName string) error { s.notifyProgress(ctx, EventActionDelete, remoteName, 0.0) if !s.DryRun { - err := s.filer.Delete(ctx, remoteName) + err := retryOnTransient(ctx, s.RetryTimeout, "rmdir "+remoteName, func() error { + return s.filer.Delete(ctx, remoteName) + }) if err != nil { // Directory deletion is opportunistic, so we ignore errors. 
log.Debugf(ctx, "error removing directory %s: %s", remoteName, err) @@ -49,7 +54,9 @@ func (s *Sync) applyMkdir(ctx context.Context, localName string) error { s.notifyProgress(ctx, EventActionPut, localName, 0.0) if !s.DryRun { - err := s.filer.Mkdir(ctx, localName) + err := retryOnTransient(ctx, s.RetryTimeout, "mkdir "+localName, func() error { + return s.filer.Mkdir(ctx, localName) + }) if err != nil { return err } @@ -63,16 +70,24 @@ func (s *Sync) applyMkdir(ctx context.Context, localName string) error { func (s *Sync) applyPut(ctx context.Context, localName string) error { s.notifyProgress(ctx, EventActionPut, localName, 0.0) - localFile, err := s.LocalRoot.Open(localName) + // Surface a missing/unreadable local file even on dry-run. + f, err := s.LocalRoot.Open(localName) if err != nil { return err } - - defer localFile.Close() + f.Close() if !s.DryRun { opts := []filer.WriteMode{filer.CreateParentDirectories, filer.OverwriteIfExists} - err = s.filer.Write(ctx, localName, localFile, opts...) + err := retryOnTransient(ctx, s.RetryTimeout, "put "+localName, func() error { + // Re-open per attempt: filer.Write consumes the reader. + f, err := s.LocalRoot.Open(localName) + if err != nil { + return err + } + defer f.Close() + return s.filer.Write(ctx, localName, f, opts...) + }) if err != nil { return err } @@ -96,9 +111,9 @@ func groupRunSingle(ctx context.Context, group *errgroup.Group, fn func(context. 
}) } -func groupRunParallel(ctx context.Context, paths []string, fn func(context.Context, string) error) error { +func groupRunParallel(ctx context.Context, paths []string, limit int, fn func(context.Context, string) error) error { group, ctx := errgroup.WithContext(ctx) - group.SetLimit(MaxRequestsInFlight) + group.SetLimit(limit) for _, path := range paths { groupRunSingle(ctx, group, fn, path) @@ -109,17 +124,17 @@ func groupRunParallel(ctx context.Context, paths []string, fn func(context.Conte } func (s *Sync) applyDiff(ctx context.Context, d diff) error { - var err error + limit := s.Concurrency // Delete files in parallel. - err = groupRunParallel(ctx, d.delete, s.applyDelete) + err := groupRunParallel(ctx, d.delete, limit, s.applyDelete) if err != nil { return err } // Delete directories ordered by depth from leaf to root. for _, group := range d.groupedRmdir() { - err = groupRunParallel(ctx, group, s.applyRmdir) + err = groupRunParallel(ctx, group, limit, s.applyRmdir) if err != nil { return err } @@ -127,14 +142,12 @@ func (s *Sync) applyDiff(ctx context.Context, d diff) error { // Create directories (leafs only because intermediates are created automatically). for _, group := range d.groupedMkdir() { - err = groupRunParallel(ctx, group, s.applyMkdir) + err = groupRunParallel(ctx, group, limit, s.applyMkdir) if err != nil { return err } } // Put files in parallel. - err = groupRunParallel(ctx, d.put, s.applyPut) - - return err + return groupRunParallel(ctx, d.put, limit, s.applyPut) }