From 9be4e48fc500a2e1aa0ca85ee2969b0f3b13a08b Mon Sep 17 00:00:00 2001 From: Andrew Nester Date: Thu, 8 Feb 2024 17:14:44 +0100 Subject: [PATCH 1/3] Added --restart flag for bundle run command --- bundle/run/job.go | 39 +++++++++++++ bundle/run/pipeline.go | 14 +++++ bundle/run/runner.go | 3 + cmd/bundle/run.go | 10 ++++ internal/bundle/bind_resource_test.go | 81 +++++++++++++++++++++++++++ 5 files changed, 147 insertions(+) create mode 100644 internal/bundle/bind_resource_test.go diff --git a/bundle/run/job.go b/bundle/run/job.go index a54279c11f..2ce326327a 100644 --- a/bundle/run/job.go +++ b/bundle/run/job.go @@ -15,6 +15,7 @@ import ( "github.com/databricks/cli/libs/log" "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/fatih/color" + "golang.org/x/sync/errgroup" ) // Default timeout for waiting for a job run to complete. @@ -275,3 +276,41 @@ func (r *jobRunner) convertPythonParams(opts *Options) error { return nil } + +func (r *jobRunner) Cancel(ctx context.Context) error { + w := r.bundle.WorkspaceClient() + jobID, err := strconv.ParseInt(r.job.ID, 10, 64) + if err != nil { + return fmt.Errorf("job ID is not an integer: %s", r.job.ID) + } + + runs, err := w.Jobs.ListRunsAll(ctx, jobs.ListRunsRequest{ + ActiveOnly: true, + JobId: jobID, + }) + + if err != nil { + return err + } + + if len(runs) == 0 { + return nil + } + + errGroup, errCtx := errgroup.WithContext(ctx) + for _, run := range runs { + runId := run.RunId + errGroup.Go(func() error { + wait, err := w.Jobs.CancelRun(errCtx, jobs.CancelRun{ + RunId: runId, + }) + if err != nil { + return err + } + _, err = wait.GetWithTimeout(jobRunTimeout) + return err + }) + } + + return errGroup.Wait() +} diff --git a/bundle/run/pipeline.go b/bundle/run/pipeline.go index 342a771b13..b5e289da14 100644 --- a/bundle/run/pipeline.go +++ b/bundle/run/pipeline.go @@ -166,3 +166,17 @@ func (r *pipelineRunner) Run(ctx context.Context, opts *Options) (output.RunOutp time.Sleep(time.Second) } } + +func (r *pipelineRunner) Cancel(ctx context.Context) error { + w := r.bundle.WorkspaceClient() + wait, err := w.Pipelines.Stop(ctx, pipelines.StopRequest{ + PipelineId: r.pipeline.ID, + }) + + if err != nil { + return err + } + + _, err = wait.GetWithTimeout(jobRunTimeout) + return err +} diff --git a/bundle/run/runner.go b/bundle/run/runner.go index 7d3c2c297a..de2a1ae7a6 100644 --- a/bundle/run/runner.go +++ b/bundle/run/runner.go @@ -26,6 +26,9 @@ type Runner interface { // Run the underlying worklow. Run(ctx context.Context, opts *Options) (output.RunOutput, error) + + // Cancel the underlying workflow. + Cancel(ctx context.Context) error } // Find locates a runner matching the specified argument. diff --git a/cmd/bundle/run.go b/cmd/bundle/run.go index a4b1065882..e8ecf4b7d0 100644 --- a/cmd/bundle/run.go +++ b/cmd/bundle/run.go @@ -27,7 +27,9 @@ func newRunCommand() *cobra.Command { runOptions.Define(cmd) var noWait bool + var restart bool cmd.Flags().BoolVar(&noWait, "no-wait", false, "Don't wait for the run to complete.") + cmd.Flags().BoolVar(&restart, "restart", false, "Restart the run if it is already running.") cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() @@ -68,6 +70,14 @@ func newRunCommand() *cobra.Command { } runOptions.NoWait = noWait + if restart { + cmdio.LogString(ctx, "Cancelling the run...") + err := runner.Cancel(ctx) + if err != nil { + return err + } + cmdio.LogString(ctx, "All runs have been cancelled, starting a new run") + } output, err := runner.Run(ctx, &runOptions) if err != nil { return err diff --git a/internal/bundle/bind_resource_test.go b/internal/bundle/bind_resource_test.go new file mode 100644 index 0000000000..6337e2dfd4 --- /dev/null +++ b/internal/bundle/bind_resource_test.go @@ -0,0 +1,81 @@ +package bundle + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/databricks/cli/internal" + "github.com/databricks/cli/internal/acc" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +func TestAccBindJobToExistingJob(t *testing.T) { + env := internal.GetEnvOrSkipTest(t, "CLOUD_ENV") + t.Log(env) + + ctx, wt := acc.WorkspaceTest(t) + gt := &generateJobTest{T: t, w: wt.W} + + nodeTypeId := internal.GetNodeTypeId(env) + uniqueId := uuid.New().String() + bundleRoot, err := initTestTemplate(t, ctx, "basic", map[string]any{ + "unique_id": uniqueId, + "spark_version": "13.3.x-scala2.12", + "node_type_id": nodeTypeId, + }) + require.NoError(t, err) + + jobId := gt.createTestJob(ctx) + t.Cleanup(func() { + gt.destroyJob(ctx, jobId) + require.NoError(t, err) + }) + + t.Setenv("BUNDLE_ROOT", bundleRoot) + c := internal.NewCobraTestRunner(t, "bundle", "deployment", "bind", "foo", fmt.Sprint(jobId), "--auto-approve") + _, _, err = c.Run() + require.NoError(t, err) + + // Remove .databricks directory to simulate a fresh deployment + err = os.RemoveAll(filepath.Join(bundleRoot, ".databricks")) + require.NoError(t, err) + + err = deployBundle(t, ctx, bundleRoot) + require.NoError(t, err) + + w, err := databricks.NewWorkspaceClient() + require.NoError(t, err) + + // Check that job is bound and updated with config from bundle + job, err := w.Jobs.Get(ctx, jobs.GetJobRequest{ + JobId: jobId, + }) + require.NoError(t, err) + require.Equal(t, job.Settings.Name, fmt.Sprintf("test-job-basic-%s", uniqueId)) + require.Contains(t, job.Settings.Tasks[0].SparkPythonTask.PythonFile, "hello_world.py") + + c = internal.NewCobraTestRunner(t, "bundle", "deployment", "unbind", "foo") + _, _, err = c.Run() + require.NoError(t, err) + + // Remove .databricks directory to simulate a fresh deployment + err = os.RemoveAll(filepath.Join(bundleRoot, ".databricks")) + require.NoError(t, err) + + err = destroyBundle(t, ctx, bundleRoot) + require.NoError(t, err) + + // Check that job is unbound and exists after bundle is destroyed + job, err = w.Jobs.Get(ctx, jobs.GetJobRequest{ + JobId: jobId, + }) + require.NoError(t, err) + require.Equal(t, job.Settings.Name, fmt.Sprintf("test-job-basic-%s", uniqueId)) + require.Contains(t, job.Settings.Tasks[0].SparkPythonTask.PythonFile, "hello_world.py") + +} From 3d40b04ee4fe2e811daf7c803a3a1dd65cabcbe7 Mon Sep 17 00:00:00 2001 From: Andrew Nester Date: Thu, 8 Feb 2024 17:39:21 +0100 Subject: [PATCH 2/3] fix --- bundle/run/job.go | 1 + bundle/run/pipeline.go | 1 + cmd/bundle/run.go | 5 +- internal/bundle/bind_resource_test.go | 81 --------------------------- 4 files changed, 5 insertions(+), 83 deletions(-) delete mode 100644 internal/bundle/bind_resource_test.go diff --git a/bundle/run/job.go b/bundle/run/job.go index 2ce326327a..043ea846a2 100644 --- a/bundle/run/job.go +++ b/bundle/run/job.go @@ -307,6 +307,7 @@ func (r *jobRunner) Cancel(ctx context.Context) error { if err != nil { return err } + // Waits for the Terminated or Skipped state _, err = wait.GetWithTimeout(jobRunTimeout) return err }) diff --git a/bundle/run/pipeline.go b/bundle/run/pipeline.go index b5e289da14..e1f5bfe5f2 100644 --- a/bundle/run/pipeline.go +++ b/bundle/run/pipeline.go @@ -177,6 +177,7 @@ func (r *pipelineRunner) Cancel(ctx context.Context) error { return err } + // Waits for the Idle state of the pipeline _, err = wait.GetWithTimeout(jobRunTimeout) return err } diff --git a/cmd/bundle/run.go b/cmd/bundle/run.go index e8ecf4b7d0..c1a8d4ea92 100644 --- a/cmd/bundle/run.go +++ b/cmd/bundle/run.go @@ -71,12 +71,13 @@ func newRunCommand() *cobra.Command { runOptions.NoWait = noWait if restart { - cmdio.LogString(ctx, "Cancelling the run...") + s := cmdio.Spinner(ctx) + s <- "Cancelling all runs" err := runner.Cancel(ctx) + close(s) if err != nil { return err } - cmdio.LogString(ctx, "All runs have been cancelled, starting a new run") } output, err := runner.Run(ctx, &runOptions) if err != nil { diff --git a/internal/bundle/bind_resource_test.go b/internal/bundle/bind_resource_test.go deleted file mode 100644 index 6337e2dfd4..0000000000 --- a/internal/bundle/bind_resource_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package bundle - -import ( - "fmt" - "os" - "path/filepath" - "testing" - - "github.com/databricks/cli/internal" - "github.com/databricks/cli/internal/acc" - "github.com/databricks/databricks-sdk-go" - "github.com/databricks/databricks-sdk-go/service/jobs" - "github.com/google/uuid" - "github.com/stretchr/testify/require" -) - -func TestAccBindJobToExistingJob(t *testing.T) { - env := internal.GetEnvOrSkipTest(t, "CLOUD_ENV") - t.Log(env) - - ctx, wt := acc.WorkspaceTest(t) - gt := &generateJobTest{T: t, w: wt.W} - - nodeTypeId := internal.GetNodeTypeId(env) - uniqueId := uuid.New().String() - bundleRoot, err := initTestTemplate(t, ctx, "basic", map[string]any{ - "unique_id": uniqueId, - "spark_version": "13.3.x-scala2.12", - "node_type_id": nodeTypeId, - }) - require.NoError(t, err) - - jobId := gt.createTestJob(ctx) - t.Cleanup(func() { - gt.destroyJob(ctx, jobId) - require.NoError(t, err) - }) - - t.Setenv("BUNDLE_ROOT", bundleRoot) - c := internal.NewCobraTestRunner(t, "bundle", "deployment", "bind", "foo", fmt.Sprint(jobId), "--auto-approve") - _, _, err = c.Run() - require.NoError(t, err) - - // Remove .databricks directory to simulate a fresh deployment - err = os.RemoveAll(filepath.Join(bundleRoot, ".databricks")) - require.NoError(t, err) - - err = deployBundle(t, ctx, bundleRoot) - require.NoError(t, err) - - w, err := databricks.NewWorkspaceClient() - require.NoError(t, err) - - // Check that job is bound and updated with config from bundle - job, err := w.Jobs.Get(ctx, jobs.GetJobRequest{ - JobId: jobId, - }) - require.NoError(t, err) - require.Equal(t, job.Settings.Name, fmt.Sprintf("test-job-basic-%s", uniqueId)) - require.Contains(t, job.Settings.Tasks[0].SparkPythonTask.PythonFile, "hello_world.py") - - c = internal.NewCobraTestRunner(t, "bundle", "deployment", "unbind", "foo") - _, _, err = c.Run() - require.NoError(t, err) - - // Remove .databricks directory to simulate a fresh deployment - err = os.RemoveAll(filepath.Join(bundleRoot, ".databricks")) - require.NoError(t, err) - - err = destroyBundle(t, ctx, bundleRoot) - require.NoError(t, err) - - // Check that job is unbound and exists after bundle is destroyed - job, err = w.Jobs.Get(ctx, jobs.GetJobRequest{ - JobId: jobId, - }) - require.NoError(t, err) - require.Equal(t, job.Settings.Name, fmt.Sprintf("test-job-basic-%s", uniqueId)) - require.Contains(t, job.Settings.Tasks[0].SparkPythonTask.PythonFile, "hello_world.py") - -} From 51ef679f966ebc9e022cc9e2878315628990ea8d Mon Sep 17 00:00:00 2001 From: Andrew Nester Date: Fri, 9 Feb 2024 11:36:51 +0100 Subject: [PATCH 3/3] Added unit tests --- bundle/run/job_test.go | 79 +++++++++++++++++++++++++++++++++++++ bundle/run/pipeline_test.go | 49 +++++++++++++++++++++++ 2 files changed, 128 insertions(+) create mode 100644 bundle/run/pipeline_test.go diff --git a/bundle/run/job_test.go b/bundle/run/job_test.go index e4cb4e7e82..be189306b2 100644 --- a/bundle/run/job_test.go +++ b/bundle/run/job_test.go @@ -1,12 +1,16 @@ package run import ( + "context" "testing" + "time" "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config/resources" + "github.com/databricks/databricks-sdk-go/experimental/mocks" "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -47,3 +51,78 @@ func TestConvertPythonParams(t *testing.T) { require.Contains(t, opts.Job.notebookParams, "__python_params") require.Equal(t, opts.Job.notebookParams["__python_params"], `["param1","param2","param3"]`) } + +func TestJobRunnerCancel(t *testing.T) { + job := &resources.Job{ + ID: "123", + } + b := &bundle.Bundle{ + Config: config.Root{ + Resources: config.Resources{ + Jobs: map[string]*resources.Job{ + "test_job": job, + }, + }, + }, + } + + runner := jobRunner{key: "test", bundle: b, job: job} + + m := mocks.NewMockWorkspaceClient(t) + b.SetWorkpaceClient(m.WorkspaceClient) + + jobApi := m.GetMockJobsAPI() + jobApi.EXPECT().ListRunsAll(mock.Anything, jobs.ListRunsRequest{ + ActiveOnly: true, + JobId: 123, + }).Return([]jobs.BaseRun{ + {RunId: 1}, + {RunId: 2}, + }, nil) + + mockWait := &jobs.WaitGetRunJobTerminatedOrSkipped[struct{}]{ + Poll: func(time time.Duration, f func(j *jobs.Run)) (*jobs.Run, error) { + return nil, nil + }, + } + jobApi.EXPECT().CancelRun(mock.Anything, jobs.CancelRun{ + RunId: 1, + }).Return(mockWait, nil) + jobApi.EXPECT().CancelRun(mock.Anything, jobs.CancelRun{ + RunId: 2, + }).Return(mockWait, nil) + + err := runner.Cancel(context.Background()) + require.NoError(t, err) +} + +func TestJobRunnerCancelWithNoActiveRuns(t *testing.T) { + job := &resources.Job{ + ID: "123", + } + b := &bundle.Bundle{ + Config: config.Root{ + Resources: config.Resources{ + Jobs: map[string]*resources.Job{ + "test_job": job, + }, + }, + }, + } + + runner := jobRunner{key: "test", bundle: b, job: job} + + m := mocks.NewMockWorkspaceClient(t) + b.SetWorkpaceClient(m.WorkspaceClient) + + jobApi := m.GetMockJobsAPI() + jobApi.EXPECT().ListRunsAll(mock.Anything, jobs.ListRunsRequest{ + ActiveOnly: true, + JobId: 123, + }).Return([]jobs.BaseRun{}, nil) + + jobApi.AssertNotCalled(t, "CancelRun") + + err := runner.Cancel(context.Background()) + require.NoError(t, err) +} diff --git a/bundle/run/pipeline_test.go b/bundle/run/pipeline_test.go new file mode 100644 index 0000000000..29b57ffdb2 --- /dev/null +++ b/bundle/run/pipeline_test.go @@ -0,0 +1,49 @@ +package run + +import ( + "context" + "testing" + "time" + + "github.com/databricks/cli/bundle" + "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/bundle/config/resources" + "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/service/pipelines" + "github.com/stretchr/testify/require" +) + +func TestPipelineRunnerCancel(t *testing.T) { + pipeline := &resources.Pipeline{ + ID: "123", + } + + b := &bundle.Bundle{ + Config: config.Root{ + Resources: config.Resources{ + Pipelines: map[string]*resources.Pipeline{ + "test_pipeline": pipeline, + }, + }, + }, + } + + runner := pipelineRunner{key: "test", bundle: b, pipeline: pipeline} + + m := mocks.NewMockWorkspaceClient(t) + b.SetWorkpaceClient(m.WorkspaceClient) + + mockWait := &pipelines.WaitGetPipelineIdle[struct{}]{ + Poll: func(time.Duration, func(*pipelines.GetPipelineResponse)) (*pipelines.GetPipelineResponse, error) { + return nil, nil + }, + } + + pipelineApi := m.GetMockPipelinesAPI() + pipelineApi.EXPECT().Stop(context.Background(), pipelines.StopRequest{ + PipelineId: "123", + }).Return(mockWait, nil) + + err := runner.Cancel(context.Background()) + require.NoError(t, err) +}