Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions libs/config/merge/merge.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package merge

import (
"fmt"

"github.com/databricks/cli/libs/config"
)

// Merge recursively merges the specified values.
//
// Semantics are as follows:
// * Merging x with nil or nil with x always yields x.
// * Merging maps a and b means entries from map b take precedence.
// * Merging sequences a and b means concatenating them.
func Merge(a, b config.Value) (config.Value, error) {
return merge(a, b)
}

func merge(a, b config.Value) (config.Value, error) {
ak := a.Kind()
bk := b.Kind()

// If a is nil, return b.
if ak == config.KindNil {
return b, nil
}

// If b is nil, return a.
if bk == config.KindNil {
return a, nil
}

// Call the appropriate merge function based on the kind of a and b.
switch ak {
case config.KindMap:
if bk != config.KindMap {
return config.NilValue, fmt.Errorf("cannot merge map with %s", bk)
}
return mergeMap(a, b)
case config.KindSequence:
if bk != config.KindSequence {
return config.NilValue, fmt.Errorf("cannot merge sequence with %s", bk)
}
return mergeSequence(a, b)
default:
if ak != bk {
return config.NilValue, fmt.Errorf("cannot merge %s with %s", ak, bk)
}
return mergePrimitive(a, b)
}
}

func mergeMap(a, b config.Value) (config.Value, error) {
out := make(map[string]config.Value)
am := a.MustMap()
bm := b.MustMap()

// Add the values from a into the output map.
for k, v := range am {
out[k] = v
}

// Merge the values from b into the output map.
for k, v := range bm {
if _, ok := out[k]; ok {
// If the key already exists, merge the values.
merged, err := merge(out[k], v)
if err != nil {
return config.NilValue, err
}
out[k] = merged
} else {
// Otherwise, just set the value.
out[k] = v
}
}

// Preserve the location of the first value.
return config.NewValue(out, a.Location()), nil
}

func mergeSequence(a, b config.Value) (config.Value, error) {
as := a.MustSequence()
bs := b.MustSequence()

// Merging sequences means concatenating them.
out := make([]config.Value, len(as)+len(bs))
copy(out[:], as)
copy(out[len(as):], bs)

// Preserve the location of the first value.
return config.NewValue(out, a.Location()), nil
}

func mergePrimitive(a, b config.Value) (config.Value, error) {
// Merging primitive values means using the incoming value.
return b, nil
}
207 changes: 207 additions & 0 deletions libs/config/merge/merge_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
package merge

import (
"testing"

"github.com/databricks/cli/libs/config"
"github.com/stretchr/testify/assert"
)

func TestMergeMaps(t *testing.T) {
v1 := config.V(map[string]config.Value{
"foo": config.V("bar"),
"bar": config.V("baz"),
})

v2 := config.V(map[string]config.Value{
"bar": config.V("qux"),
"qux": config.V("foo"),
})

// Merge v2 into v1.
{
out, err := Merge(v1, v2)
assert.NoError(t, err)
assert.Equal(t, map[string]any{
"foo": "bar",
"bar": "qux",
"qux": "foo",
}, out.AsAny())
}

// Merge v1 into v2.
{
out, err := Merge(v2, v1)
assert.NoError(t, err)
assert.Equal(t, map[string]any{
"foo": "bar",
"bar": "baz",
"qux": "foo",
}, out.AsAny())
}
}

func TestMergeMapsNil(t *testing.T) {
v := config.V(map[string]config.Value{
"foo": config.V("bar"),
})

// Merge nil into v.
{
out, err := Merge(v, config.NilValue)
assert.NoError(t, err)
assert.Equal(t, map[string]any{
"foo": "bar",
}, out.AsAny())
}

// Merge v into nil.
{
out, err := Merge(config.NilValue, v)
assert.NoError(t, err)
assert.Equal(t, map[string]any{
"foo": "bar",
}, out.AsAny())
}
}

func TestMergeMapsError(t *testing.T) {
v := config.V(map[string]config.Value{
"foo": config.V("bar"),
})

other := config.V("string")

// Merge a string into v.
{
out, err := Merge(v, other)
assert.EqualError(t, err, "cannot merge map with string")
assert.Equal(t, config.NilValue, out)
}
}

func TestMergeSequences(t *testing.T) {
v1 := config.V([]config.Value{
config.V("bar"),
config.V("baz"),
})

v2 := config.V([]config.Value{
config.V("qux"),
config.V("foo"),
})

// Merge v2 into v1.
{
out, err := Merge(v1, v2)
assert.NoError(t, err)
assert.Equal(t, []any{
"bar",
"baz",
"qux",
"foo",
}, out.AsAny())
}

// Merge v1 into v2.
{
out, err := Merge(v2, v1)
assert.NoError(t, err)
assert.Equal(t, []any{
"qux",
"foo",
"bar",
"baz",
}, out.AsAny())
}
}

func TestMergeSequencesNil(t *testing.T) {
v := config.V([]config.Value{
config.V("bar"),
})

// Merge nil into v.
{
out, err := Merge(v, config.NilValue)
assert.NoError(t, err)
assert.Equal(t, []any{
"bar",
}, out.AsAny())
}

// Merge v into nil.
{
out, err := Merge(config.NilValue, v)
assert.NoError(t, err)
assert.Equal(t, []any{
"bar",
}, out.AsAny())
}
}

func TestMergeSequencesError(t *testing.T) {
v := config.V([]config.Value{
config.V("bar"),
})

other := config.V("string")

// Merge a string into v.
{
out, err := Merge(v, other)
assert.EqualError(t, err, "cannot merge sequence with string")
assert.Equal(t, config.NilValue, out)
}
}

func TestMergePrimitives(t *testing.T) {
v1 := config.V("bar")
v2 := config.V("baz")

// Merge v2 into v1.
{
out, err := Merge(v1, v2)
assert.NoError(t, err)
assert.Equal(t, "baz", out.AsAny())
}

// Merge v1 into v2.
{
out, err := Merge(v2, v1)
assert.NoError(t, err)
assert.Equal(t, "bar", out.AsAny())
}
}

func TestMergePrimitivesNil(t *testing.T) {
v := config.V("bar")

// Merge nil into v.
{
out, err := Merge(v, config.NilValue)
assert.NoError(t, err)
assert.Equal(t, "bar", out.AsAny())
}

// Merge v into nil.
{
out, err := Merge(config.NilValue, v)
assert.NoError(t, err)
assert.Equal(t, "bar", out.AsAny())
}
}

func TestMergePrimitivesError(t *testing.T) {
v := config.V("bar")
other := config.V(map[string]config.Value{
"foo": config.V("bar"),
})

// Merge a map into v.
{
out, err := Merge(v, other)
assert.EqualError(t, err, "cannot merge string with map")
assert.Equal(t, config.NilValue, out)
}
}