Skip to content
100 changes: 100 additions & 0 deletions bundle/config/mutator/trampoline.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package mutator

import (
"context"
"fmt"
"os"
"path"
"path/filepath"
"text/template"

"github.com/databricks/cli/bundle"
"github.com/databricks/databricks-sdk-go/service/jobs"
)

type TaskWithJobKey struct {
Task *jobs.Task
JobKey string
}

type TrampolineFunctions interface {
GetTemplateData(task *jobs.Task) (map[string]any, error)
GetTasks(b *bundle.Bundle) []TaskWithJobKey
CleanUp(task *jobs.Task) error
}
type trampoline struct {
name string
functions TrampolineFunctions
template string
}

func NewTrampoline(
name string,
functions TrampolineFunctions,
template string,
) *trampoline {
return &trampoline{name, functions, template}
}

func (m *trampoline) Name() string {
return fmt.Sprintf("trampoline(%s)", m.name)
}

func (m *trampoline) Apply(ctx context.Context, b *bundle.Bundle) error {
tasks := m.functions.GetTasks(b)
for _, task := range tasks {
err := m.generateNotebookWrapper(b, task)
if err != nil {
return err
}
}
return nil
}

func (m *trampoline) generateNotebookWrapper(b *bundle.Bundle, task TaskWithJobKey) error {
internalDir, err := b.InternalDir()
if err != nil {
return err
}

notebookName := fmt.Sprintf("notebook_%s_%s", task.JobKey, task.Task.TaskKey)
localNotebookPath := filepath.Join(internalDir, notebookName+".py")

err = os.MkdirAll(filepath.Dir(localNotebookPath), 0755)
if err != nil {
return err
}

f, err := os.Create(localNotebookPath)
if err != nil {
return err
}
defer f.Close()

data, err := m.functions.GetTemplateData(task.Task)
if err != nil {
return err
}

t, err := template.New(notebookName).Parse(m.template)
if err != nil {
return err
}

internalDirRel, err := filepath.Rel(b.Config.Path, internalDir)
if err != nil {
return err
}

err = m.functions.CleanUp(task.Task)
if err != nil {
return err
}
remotePath := path.Join(b.Config.Workspace.FilesPath, filepath.ToSlash(internalDirRel), notebookName)

task.Task.NotebookTask = &jobs.NotebookTask{
NotebookPath: remotePath,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you comment on how we expect the wheel args, or task params, to be passed in the template?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a reference example, in Python trampoline we do it with sys.argv as you can see here
https://github.com/databricks/cli/pull/635/files#diff-8ddf2564fcd580399d61df5283d0c0c7b614c476e0803f63f9f785b62dc69a04R23

I'd expect it should work similarly for task params

}

return t.Execute(f, data)
}
97 changes: 97 additions & 0 deletions bundle/config/mutator/trampoline_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package mutator

import (
"context"
"fmt"
"os"
"path/filepath"
"testing"

"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/require"
)

type functions struct{}

func (f *functions) GetTasks(b *bundle.Bundle) []TaskWithJobKey {
tasks := make([]TaskWithJobKey, 0)
for k := range b.Config.Resources.Jobs["test"].Tasks {
tasks = append(tasks, TaskWithJobKey{
JobKey: "test",
Task: &b.Config.Resources.Jobs["test"].Tasks[k],
})
}

return tasks
}

func (f *functions) GetTemplateData(task *jobs.Task) (map[string]any, error) {
if task.PythonWheelTask == nil {
return nil, fmt.Errorf("PythonWheelTask cannot be nil")
}

data := make(map[string]any)
data["MyName"] = "Trampoline"
return data, nil
}

func (f *functions) CleanUp(task *jobs.Task) error {
task.PythonWheelTask = nil
return nil
}

func TestGenerateTrampoline(t *testing.T) {
tmpDir := t.TempDir()

tasks := []jobs.Task{
{
TaskKey: "to_trampoline",
PythonWheelTask: &jobs.PythonWheelTask{
PackageName: "test",
EntryPoint: "run",
}},
}

b := &bundle.Bundle{
Config: config.Root{
Path: tmpDir,
Bundle: config.Bundle{
Target: "development",
},
Resources: config.Resources{
Jobs: map[string]*resources.Job{
"test": {
Paths: resources.Paths{
ConfigFilePath: tmpDir,
},
JobSettings: &jobs.JobSettings{
Tasks: tasks,
},
},
},
},
},
}
ctx := context.Background()

funcs := functions{}
trampoline := NewTrampoline("test_trampoline", &funcs, "Hello from {{.MyName}}")
err := bundle.Apply(ctx, b, trampoline)
require.NoError(t, err)

dir, err := b.InternalDir()
require.NoError(t, err)
filename := filepath.Join(dir, "notebook_test_to_trampoline.py")

bytes, err := os.ReadFile(filename)
require.NoError(t, err)

require.Equal(t, "Hello from Trampoline", string(bytes))

task := b.Config.Resources.Jobs["test"].Tasks[0]
require.Equal(t, task.NotebookTask.NotebookPath, ".databricks/bundle/development/.internal/notebook_test_to_trampoline")
require.Nil(t, task.PythonWheelTask)
}
4 changes: 3 additions & 1 deletion bundle/phases/deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/databricks/cli/bundle/deploy/lock"
"github.com/databricks/cli/bundle/deploy/terraform"
"github.com/databricks/cli/bundle/libraries"
"github.com/databricks/cli/bundle/python"
)

// The deploy phase deploys artifacts and resources.
Expand All @@ -17,10 +18,11 @@ func Deploy() bundle.Mutator {
bundle.Defer(
bundle.Seq(
mutator.ValidateGitDetails(),
files.Upload(),
libraries.MatchWithArtifacts(),
artifacts.CleanUp(),
artifacts.UploadAll(),
python.TransformWheelTask(),
files.Upload(),
terraform.Interpolate(),
terraform.Write(),
terraform.StatePull(),
Expand Down
109 changes: 109 additions & 0 deletions bundle/python/transform.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package python

import (
"fmt"
"strconv"
"strings"

"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config/mutator"
"github.com/databricks/databricks-sdk-go/service/jobs"
)

const NOTEBOOK_TEMPLATE = `# Databricks notebook source
%python
{{range .Libraries}}
%pip install --force-reinstall {{.Whl}}
{{end}}

try:
from importlib import metadata
except ImportError: # for Python<3.8
import subprocess
import sys

subprocess.check_call([sys.executable, "-m", "pip", "install", "importlib-metadata"])
import importlib_metadata as metadata

from contextlib import redirect_stdout
import io
import sys
sys.argv = [{{.Params}}]

entry = [ep for ep in metadata.distribution("{{.Task.PackageName}}").entry_points if ep.name == "{{.Task.EntryPoint}}"]

f = io.StringIO()
with redirect_stdout(f):
if entry:
entry[0].load()()
else:
raise ImportError("Entry point '{{.Task.EntryPoint}}' not found")
s = f.getvalue()
dbutils.notebook.exit(s)
`

// This mutator takes the wheel task and transforms it into notebook
// which installs uploaded wheels using %pip and then calling corresponding
// entry point.
func TransformWheelTask() bundle.Mutator {
return mutator.NewTrampoline(
"python_wheel",
&pythonTrampoline{},
NOTEBOOK_TEMPLATE,
)
}

type pythonTrampoline struct{}

func (t *pythonTrampoline) CleanUp(task *jobs.Task) error {
task.PythonWheelTask = nil
task.Libraries = nil

return nil
}

func (t *pythonTrampoline) GetTasks(b *bundle.Bundle) []mutator.TaskWithJobKey {
r := b.Config.Resources
result := make([]mutator.TaskWithJobKey, 0)
for k := range b.Config.Resources.Jobs {
tasks := r.Jobs[k].JobSettings.Tasks
for i := range tasks {
task := &tasks[i]
result = append(result, mutator.TaskWithJobKey{
JobKey: k,
Task: task,
})
}
}
return result
}

func (t *pythonTrampoline) GetTemplateData(task *jobs.Task) (map[string]any, error) {
params, err := t.generateParameters(task.PythonWheelTask)
if err != nil {
return nil, err
}

data := map[string]any{
"Libraries": task.Libraries,
"Params": params,
"Task": task.PythonWheelTask,
}

return data, nil
}

func (t *pythonTrampoline) generateParameters(task *jobs.PythonWheelTask) (string, error) {
if task.Parameters != nil && task.NamedParameters != nil {
return "", fmt.Errorf("not allowed to pass both paramaters and named_parameters")
}
params := append([]string{"python"}, task.Parameters...)
for k, v := range task.NamedParameters {
params = append(params, fmt.Sprintf("%s=%s", k, v))
}

for i := range params {
params[i] = strconv.Quote(params[i])
}
return strings.Join(params, ", "), nil
}
66 changes: 66 additions & 0 deletions bundle/python/transform_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package python

import (
"strings"
"testing"

"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/require"
)

type testCase struct {
Actual []string
Expected string
}
type NamedParams map[string]string
type testCaseNamed struct {
Actual NamedParams
Expected string
}

var paramsTestCases []testCase = []testCase{
{[]string{}, `"python"`},
{[]string{"a"}, `"python", "a"`},
{[]string{"a", "b"}, `"python", "a", "b"`},
{[]string{"123!@#$%^&*()-="}, `"python", "123!@#$%^&*()-="`},
{[]string{`{"a": 1}`}, `"python", "{\"a\": 1}"`},
}

var paramsTestCasesNamed []testCaseNamed = []testCaseNamed{
{NamedParams{}, `"python"`},
{NamedParams{"a": "1"}, `"python", "a=1"`},
{NamedParams{"a": "'1'"}, `"python", "a='1'"`},
{NamedParams{"a": `"1"`}, `"python", "a=\"1\""`},
{NamedParams{"a": "1", "b": "2"}, `"python", "a=1", "b=2"`},
{NamedParams{"data": `{"a": 1}`}, `"python", "data={\"a\": 1}"`},
}

func TestGenerateParameters(t *testing.T) {
trampoline := pythonTrampoline{}
for _, c := range paramsTestCases {
task := &jobs.PythonWheelTask{Parameters: c.Actual}
result, err := trampoline.generateParameters(task)
require.NoError(t, err)
require.Equal(t, c.Expected, result)
}
}

func TestGenerateNamedParameters(t *testing.T) {
trampoline := pythonTrampoline{}
for _, c := range paramsTestCasesNamed {
task := &jobs.PythonWheelTask{NamedParameters: c.Actual}
result, err := trampoline.generateParameters(task)
require.NoError(t, err)

// parameters order can be undetermenistic, so just check that they exist as expected
require.ElementsMatch(t, strings.Split(c.Expected, ","), strings.Split(result, ","))
}
}

func TestGenerateBoth(t *testing.T) {
trampoline := pythonTrampoline{}
task := &jobs.PythonWheelTask{NamedParameters: map[string]string{"a": "1"}, Parameters: []string{"b"}}
_, err := trampoline.generateParameters(task)
require.Error(t, err)
require.ErrorContains(t, err, "not allowed to pass both paramaters and named_parameters")
}