Skip to content
Merged
2 changes: 1 addition & 1 deletion bundle/artifacts/whl/infer.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (m *infer) Apply(ctx context.Context, b *bundle.Bundle) error {
// version=datetime.datetime.utcnow().strftime("%Y%m%d.%H%M%S"),
// ...
//)
artifact.BuildCommand = fmt.Sprintf("%s setup.py bdist_wheel", py)
artifact.BuildCommand = fmt.Sprintf(`"%s" setup.py bdist_wheel`, py)

return nil
}
Expand Down
21 changes: 5 additions & 16 deletions bundle/config/artifact.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
package config

import (
"bytes"
"context"
"fmt"
"path"
"strings"

"github.com/databricks/cli/bundle/config/paths"
"github.com/databricks/cli/libs/process"
"github.com/databricks/cli/libs/exec"
"github.com/databricks/databricks-sdk-go/service/compute"
)

Expand Down Expand Up @@ -52,20 +50,11 @@ func (a *Artifact) Build(ctx context.Context) ([]byte, error) {
return nil, fmt.Errorf("no build property defined")
}

out := make([][]byte, 0)
commands := strings.Split(a.BuildCommand, " && ")
for _, command := range commands {
buildParts := strings.Split(command, " ")
var buf bytes.Buffer
_, err := process.Background(ctx, buildParts,
process.WithCombinedOutput(&buf),
process.WithDir(a.Path))
if err != nil {
return buf.Bytes(), err
}
out = append(out, buf.Bytes())
e, err := exec.NewCommandExecutor(a.Path)
if err != nil {
return nil, err
}
return bytes.Join(out, []byte{}), nil
return e.Exec(ctx, a.BuildCommand)
}

func (a *Artifact) NormalisePaths() {
Expand Down
18 changes: 18 additions & 0 deletions bundle/config/artifacts_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package config

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
)

func TestArtifactBuild(t *testing.T) {
artifact := Artifact{
BuildCommand: "echo 'Hello from build command'",
}
res, err := artifact.Build(context.Background())
assert.NoError(t, err)
assert.NotNil(t, res)
assert.Equal(t, "Hello from build command\n", string(res))
}
34 changes: 10 additions & 24 deletions bundle/scripts/scripts.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ import (
"context"
"fmt"
"io"
"os/exec"
"strings"

"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/exec"
"github.com/databricks/cli/libs/log"
)

Expand All @@ -29,7 +29,12 @@ func (m *script) Name() string {
}

func (m *script) Apply(ctx context.Context, b *bundle.Bundle) error {
cmd, out, err := executeHook(ctx, b, m.scriptHook)
executor, err := exec.NewCommandExecutor(b.Config.Path)
if err != nil {
return err
}

cmd, out, err := executeHook(ctx, executor, b, m.scriptHook)
if err != nil {
return err
}
Expand All @@ -50,32 +55,18 @@ func (m *script) Apply(ctx context.Context, b *bundle.Bundle) error {
return cmd.Wait()
}

func executeHook(ctx context.Context, b *bundle.Bundle, hook config.ScriptHook) (*exec.Cmd, io.Reader, error) {
func executeHook(ctx context.Context, executor *exec.Executor, b *bundle.Bundle, hook config.ScriptHook) (exec.Command, io.Reader, error) {
command := getCommmand(b, hook)
if command == "" {
return nil, nil, nil
}

interpreter, err := findInterpreter()
cmd, err := executor.StartCommand(ctx, string(command))
if err != nil {
return nil, nil, err
}

// TODO: switch to process.Background(...)
cmd := exec.CommandContext(ctx, interpreter, "-c", string(command))
cmd.Dir = b.Config.Path

outPipe, err := cmd.StdoutPipe()
if err != nil {
return nil, nil, err
}

errPipe, err := cmd.StderrPipe()
if err != nil {
return nil, nil, err
}

return cmd, io.MultiReader(outPipe, errPipe), cmd.Start()
return cmd, io.MultiReader(cmd.Stdout(), cmd.Stderr()), nil
}

func getCommmand(b *bundle.Bundle, hook config.ScriptHook) config.Command {
Expand All @@ -85,8 +76,3 @@ func getCommmand(b *bundle.Bundle, hook config.ScriptHook) config.Command {

return b.Config.Experimental.Scripts[hook]
}

func findInterpreter() (string, error) {
// At the moment we just return 'sh' on all platforms and use it to execute scripts
return "sh", nil
}
6 changes: 5 additions & 1 deletion bundle/scripts/scripts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/libs/exec"
"github.com/stretchr/testify/require"
)

Expand All @@ -21,7 +22,10 @@ func TestExecutesHook(t *testing.T) {
},
},
}
_, out, err := executeHook(context.Background(), b, config.ScriptPreBuild)

executor, err := exec.NewCommandExecutor(b.Config.Path)
require.NoError(t, err)
_, out, err := executeHook(context.Background(), executor, b, config.ScriptPreBuild)
require.NoError(t, err)

reader := bufio.NewReader(out)
Expand Down
101 changes: 101 additions & 0 deletions libs/exec/exec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package exec

import (
"context"
"io"
"os"
osexec "os/exec"
)

type Command interface {
// Wait for command to terminate. It must have been previously started.
Wait() error

// StdinPipe returns a pipe that will be connected to the command's standard input when the command starts.
Stdout() io.ReadCloser

// StderrPipe returns a pipe that will be connected to the command's standard error when the command starts.
Stderr() io.ReadCloser
}

type command struct {
cmd *osexec.Cmd
execContext *execContext
stdout io.ReadCloser
stderr io.ReadCloser
}

func (c *command) Wait() error {
// After the command has finished (cmd.Wait call), remove the temporary script file
defer os.Remove(c.execContext.scriptFile)

err := c.cmd.Wait()
if err != nil {
return err
}

return nil
}

func (c *command) Stdout() io.ReadCloser {
return c.stdout
}

func (c *command) Stderr() io.ReadCloser {
return c.stderr
}

type Executor struct {
interpreter interpreter
dir string
}

func NewCommandExecutor(dir string) (*Executor, error) {
interpreter, err := findInterpreter()
if err != nil {
return nil, err
}
return &Executor{
interpreter: interpreter,
dir: dir,
}, nil
}

func (e *Executor) StartCommand(ctx context.Context, command string) (Command, error) {
ec, err := e.interpreter.prepare(command)
if err != nil {
return nil, err
}
return e.start(ctx, ec)
}

func (e *Executor) start(ctx context.Context, ec *execContext) (Command, error) {
cmd := osexec.CommandContext(ctx, ec.executable, ec.args...)
cmd.Dir = e.dir

stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, err
}

stderr, err := cmd.StderrPipe()
if err != nil {
return nil, err
}

return &command{cmd, ec, stdout, stderr}, cmd.Start()
}

func (e *Executor) Exec(ctx context.Context, command string) ([]byte, error) {
cmd, err := e.StartCommand(ctx, command)
if err != nil {
return nil, err
}

res, err := io.ReadAll(io.MultiReader(cmd.Stdout(), cmd.Stderr()))
if err != nil {
return nil, err
}

return res, cmd.Wait()
}
Loading