diff --git a/libs/config/merge/merge.go b/libs/config/merge/merge.go new file mode 100644 index 0000000000..896e212923 --- /dev/null +++ b/libs/config/merge/merge.go @@ -0,0 +1,98 @@ +package merge + +import ( + "fmt" + + "github.com/databricks/cli/libs/config" +) + +// Merge recursively merges the specified values. +// +// Semantics are as follows: +// * Merging x with nil or nil with x always yields x. +// * Merging maps a and b means entries from map b take precedence. +// * Merging sequences a and b means concatenating them. +func Merge(a, b config.Value) (config.Value, error) { + return merge(a, b) +} + +func merge(a, b config.Value) (config.Value, error) { + ak := a.Kind() + bk := b.Kind() + + // If a is nil, return b. + if ak == config.KindNil { + return b, nil + } + + // If b is nil, return a. + if bk == config.KindNil { + return a, nil + } + + // Call the appropriate merge function based on the kind of a and b. + switch ak { + case config.KindMap: + if bk != config.KindMap { + return config.NilValue, fmt.Errorf("cannot merge map with %s", bk) + } + return mergeMap(a, b) + case config.KindSequence: + if bk != config.KindSequence { + return config.NilValue, fmt.Errorf("cannot merge sequence with %s", bk) + } + return mergeSequence(a, b) + default: + if ak != bk { + return config.NilValue, fmt.Errorf("cannot merge %s with %s", ak, bk) + } + return mergePrimitive(a, b) + } +} + +func mergeMap(a, b config.Value) (config.Value, error) { + out := make(map[string]config.Value) + am := a.MustMap() + bm := b.MustMap() + + // Add the values from a into the output map. + for k, v := range am { + out[k] = v + } + + // Merge the values from b into the output map. + for k, v := range bm { + if _, ok := out[k]; ok { + // If the key already exists, merge the values. + merged, err := merge(out[k], v) + if err != nil { + return config.NilValue, err + } + out[k] = merged + } else { + // Otherwise, just set the value. + out[k] = v + } + } + + // Preserve the location of the first value. + return config.NewValue(out, a.Location()), nil +} + +func mergeSequence(a, b config.Value) (config.Value, error) { + as := a.MustSequence() + bs := b.MustSequence() + + // Merging sequences means concatenating them. + out := make([]config.Value, len(as)+len(bs)) + copy(out[:], as) + copy(out[len(as):], bs) + + // Preserve the location of the first value. + return config.NewValue(out, a.Location()), nil +} + +func mergePrimitive(a, b config.Value) (config.Value, error) { + // Merging primitive values means using the incoming value. + return b, nil +} diff --git a/libs/config/merge/merge_test.go b/libs/config/merge/merge_test.go new file mode 100644 index 0000000000..c2e89f60a4 --- /dev/null +++ b/libs/config/merge/merge_test.go @@ -0,0 +1,207 @@ +package merge + +import ( + "testing" + + "github.com/databricks/cli/libs/config" + "github.com/stretchr/testify/assert" +) + +func TestMergeMaps(t *testing.T) { + v1 := config.V(map[string]config.Value{ + "foo": config.V("bar"), + "bar": config.V("baz"), + }) + + v2 := config.V(map[string]config.Value{ + "bar": config.V("qux"), + "qux": config.V("foo"), + }) + + // Merge v2 into v1. + { + out, err := Merge(v1, v2) + assert.NoError(t, err) + assert.Equal(t, map[string]any{ + "foo": "bar", + "bar": "qux", + "qux": "foo", + }, out.AsAny()) + } + + // Merge v1 into v2. + { + out, err := Merge(v2, v1) + assert.NoError(t, err) + assert.Equal(t, map[string]any{ + "foo": "bar", + "bar": "baz", + "qux": "foo", + }, out.AsAny()) + } +} + +func TestMergeMapsNil(t *testing.T) { + v := config.V(map[string]config.Value{ + "foo": config.V("bar"), + }) + + // Merge nil into v. + { + out, err := Merge(v, config.NilValue) + assert.NoError(t, err) + assert.Equal(t, map[string]any{ + "foo": "bar", + }, out.AsAny()) + } + + // Merge v into nil. + { + out, err := Merge(config.NilValue, v) + assert.NoError(t, err) + assert.Equal(t, map[string]any{ + "foo": "bar", + }, out.AsAny()) + } +} + +func TestMergeMapsError(t *testing.T) { + v := config.V(map[string]config.Value{ + "foo": config.V("bar"), + }) + + other := config.V("string") + + // Merge a string into v. + { + out, err := Merge(v, other) + assert.EqualError(t, err, "cannot merge map with string") + assert.Equal(t, config.NilValue, out) + } +} + +func TestMergeSequences(t *testing.T) { + v1 := config.V([]config.Value{ + config.V("bar"), + config.V("baz"), + }) + + v2 := config.V([]config.Value{ + config.V("qux"), + config.V("foo"), + }) + + // Merge v2 into v1. + { + out, err := Merge(v1, v2) + assert.NoError(t, err) + assert.Equal(t, []any{ + "bar", + "baz", + "qux", + "foo", + }, out.AsAny()) + } + + // Merge v1 into v2. + { + out, err := Merge(v2, v1) + assert.NoError(t, err) + assert.Equal(t, []any{ + "qux", + "foo", + "bar", + "baz", + }, out.AsAny()) + } +} + +func TestMergeSequencesNil(t *testing.T) { + v := config.V([]config.Value{ + config.V("bar"), + }) + + // Merge nil into v. + { + out, err := Merge(v, config.NilValue) + assert.NoError(t, err) + assert.Equal(t, []any{ + "bar", + }, out.AsAny()) + } + + // Merge v into nil. + { + out, err := Merge(config.NilValue, v) + assert.NoError(t, err) + assert.Equal(t, []any{ + "bar", + }, out.AsAny()) + } +} + +func TestMergeSequencesError(t *testing.T) { + v := config.V([]config.Value{ + config.V("bar"), + }) + + other := config.V("string") + + // Merge a string into v. + { + out, err := Merge(v, other) + assert.EqualError(t, err, "cannot merge sequence with string") + assert.Equal(t, config.NilValue, out) + } +} + +func TestMergePrimitives(t *testing.T) { + v1 := config.V("bar") + v2 := config.V("baz") + + // Merge v2 into v1. + { + out, err := Merge(v1, v2) + assert.NoError(t, err) + assert.Equal(t, "baz", out.AsAny()) + } + + // Merge v1 into v2. + { + out, err := Merge(v2, v1) + assert.NoError(t, err) + assert.Equal(t, "bar", out.AsAny()) + } +} + +func TestMergePrimitivesNil(t *testing.T) { + v := config.V("bar") + + // Merge nil into v. + { + out, err := Merge(v, config.NilValue) + assert.NoError(t, err) + assert.Equal(t, "bar", out.AsAny()) + } + + // Merge v into nil. + { + out, err := Merge(config.NilValue, v) + assert.NoError(t, err) + assert.Equal(t, "bar", out.AsAny()) + } +} + +func TestMergePrimitivesError(t *testing.T) { + v := config.V("bar") + other := config.V(map[string]config.Value{ + "foo": config.V("bar"), + }) + + // Merge a map into v. + { + out, err := Merge(v, other) + assert.EqualError(t, err, "cannot merge string with map") + assert.Equal(t, config.NilValue, out) + } +}