Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 82 additions & 20 deletions libs/dyn/convert/normalize.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,51 @@ import (
"github.com/databricks/cli/libs/dyn"
)

func Normalize(dst any, src dyn.Value) (dyn.Value, diag.Diagnostics) {
return normalizeType(reflect.TypeOf(dst), src)
// NormalizeOption is the type for options that can be passed to Normalize.
type NormalizeOption int

const (
// IncludeMissingFields causes the normalization to include fields that defined on the given
// type but are missing in the source value. They are included with their zero values.
IncludeMissingFields NormalizeOption = iota
)

type normalizeOptions struct {
includeMissingFields bool
}

func Normalize(dst any, src dyn.Value, opts ...NormalizeOption) (dyn.Value, diag.Diagnostics) {
var n normalizeOptions
for _, opt := range opts {
switch opt {
case IncludeMissingFields:
n.includeMissingFields = true
}
}

return n.normalizeType(reflect.TypeOf(dst), src)
}

func normalizeType(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
func (n normalizeOptions) normalizeType(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
for typ.Kind() == reflect.Pointer {
typ = typ.Elem()
}

switch typ.Kind() {
case reflect.Struct:
return normalizeStruct(typ, src)
return n.normalizeStruct(typ, src)
case reflect.Map:
return normalizeMap(typ, src)
return n.normalizeMap(typ, src)
case reflect.Slice:
return normalizeSlice(typ, src)
return n.normalizeSlice(typ, src)
case reflect.String:
return normalizeString(typ, src)
return n.normalizeString(typ, src)
case reflect.Bool:
return normalizeBool(typ, src)
return n.normalizeBool(typ, src)
case reflect.Int, reflect.Int32, reflect.Int64:
return normalizeInt(typ, src)
return n.normalizeInt(typ, src)
case reflect.Float32, reflect.Float64:
return normalizeFloat(typ, src)
return n.normalizeFloat(typ, src)
}

return dyn.InvalidValue, diag.Errorf("unsupported type: %s", typ.Kind())
Expand All @@ -46,7 +67,7 @@ func typeMismatch(expected dyn.Kind, src dyn.Value) diag.Diagnostic {
}
}

func normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
var diags diag.Diagnostics

switch src.Kind() {
Expand All @@ -65,7 +86,7 @@ func normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnosti
}

// Normalize the value according to the field type.
v, err := normalizeType(typ.FieldByIndex(index).Type, v)
v, err := n.normalizeType(typ.FieldByIndex(index).Type, v)
if err != nil {
diags = diags.Extend(err)
// Skip the element if it cannot be normalized.
Expand All @@ -77,6 +98,47 @@ func normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnosti
out[k] = v
}

// Return the normalized value if missing fields are not included.
if !n.includeMissingFields {
return dyn.NewValue(out, src.Location()), diags
}

// Populate missing fields with their zero values.
for k, index := range info.Fields {
if _, ok := out[k]; ok {
continue
}

// Optionally dereference pointers to get the underlying field type.
ftyp := typ.FieldByIndex(index).Type
for ftyp.Kind() == reflect.Pointer {
ftyp = ftyp.Elem()
}

var v dyn.Value
switch ftyp.Kind() {
case reflect.Struct, reflect.Map:
v, _ = n.normalizeType(ftyp, dyn.V(map[string]dyn.Value{}))
case reflect.Slice:
v, _ = n.normalizeType(ftyp, dyn.V([]dyn.Value{}))
case reflect.String:
v, _ = n.normalizeType(ftyp, dyn.V(""))
case reflect.Bool:
v, _ = n.normalizeType(ftyp, dyn.V(false))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
v, _ = n.normalizeType(ftyp, dyn.V(int64(0)))
case reflect.Float32, reflect.Float64:
v, _ = n.normalizeType(ftyp, dyn.V(float64(0)))
default:
// Skip fields for which we do not have a natural [dyn.Value] equivalent.
// For example, we don't handle reflect.Complex* and reflect.Uint* types.
continue
}
if v.IsValid() {
out[k] = v
}
}

return dyn.NewValue(out, src.Location()), diags
case dyn.KindNil:
return src, diags
Expand All @@ -85,15 +147,15 @@ func normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnosti
return dyn.InvalidValue, diags.Append(typeMismatch(dyn.KindMap, src))
}

func normalizeMap(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
var diags diag.Diagnostics

switch src.Kind() {
case dyn.KindMap:
out := make(map[string]dyn.Value)
for k, v := range src.MustMap() {
// Normalize the value according to the map element type.
v, err := normalizeType(typ.Elem(), v)
v, err := n.normalizeType(typ.Elem(), v)
if err != nil {
diags = diags.Extend(err)
// Skip the element if it cannot be normalized.
Expand All @@ -113,15 +175,15 @@ func normalizeMap(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics)
return dyn.InvalidValue, diags.Append(typeMismatch(dyn.KindMap, src))
}

func normalizeSlice(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
func (n normalizeOptions) normalizeSlice(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
var diags diag.Diagnostics

switch src.Kind() {
case dyn.KindSequence:
out := make([]dyn.Value, 0, len(src.MustSequence()))
for _, v := range src.MustSequence() {
// Normalize the value according to the slice element type.
v, err := normalizeType(typ.Elem(), v)
v, err := n.normalizeType(typ.Elem(), v)
if err != nil {
diags = diags.Extend(err)
// Skip the element if it cannot be normalized.
Expand All @@ -141,7 +203,7 @@ func normalizeSlice(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostic
return dyn.InvalidValue, diags.Append(typeMismatch(dyn.KindSequence, src))
}

func normalizeString(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
func (n normalizeOptions) normalizeString(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
var diags diag.Diagnostics
var out string

Expand All @@ -161,7 +223,7 @@ func normalizeString(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnosti
return dyn.NewValue(out, src.Location()), diags
}

func normalizeBool(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
func (n normalizeOptions) normalizeBool(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
var diags diag.Diagnostics
var out bool

Expand All @@ -186,7 +248,7 @@ func normalizeBool(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics
return dyn.NewValue(out, src.Location()), diags
}

func normalizeInt(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
func (n normalizeOptions) normalizeInt(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
var diags diag.Diagnostics
var out int64

Expand All @@ -210,7 +272,7 @@ func normalizeInt(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics)
return dyn.NewValue(out, src.Location()), diags
}

func normalizeFloat(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
func (n normalizeOptions) normalizeFloat(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
var diags diag.Diagnostics
var out float64

Expand Down
47 changes: 47 additions & 0 deletions libs/dyn/convert/normalize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,53 @@ func TestNormalizeStructNestedError(t *testing.T) {
)
}

func TestNormalizeStructIncludeMissingFields(t *testing.T) {
type Nested struct {
String string `json:"string"`
}

type Tmp struct {
// Verify that fields that are already set in the dynamic value are not overridden.
Existing string `json:"existing"`

// Verify that structs are recursively normalized if not set.
Nested Nested `json:"nested"`
Ptr *Nested `json:"ptr"`

// Verify that containers are also zero-initialized if not set.
Map map[string]string `json:"map"`
Slice []string `json:"slice"`

// Verify that primitive types are zero-initialized if not set.
String string `json:"string"`
Bool bool `json:"bool"`
Int int `json:"int"`
Float float64 `json:"float"`
}

var typ Tmp
vin := dyn.V(map[string]dyn.Value{
"existing": dyn.V("already set"),
})
vout, err := Normalize(typ, vin, IncludeMissingFields)
assert.Empty(t, err)
assert.Equal(t, dyn.V(map[string]dyn.Value{
"existing": dyn.V("already set"),
"nested": dyn.V(map[string]dyn.Value{
"string": dyn.V(""),
}),
"ptr": dyn.V(map[string]dyn.Value{
"string": dyn.V(""),
}),
"map": dyn.V(map[string]dyn.Value{}),
"slice": dyn.V([]dyn.Value{}),
"string": dyn.V(""),
"bool": dyn.V(false),
"int": dyn.V(int64(0)),
"float": dyn.V(float64(0)),
}), vout)
}

func TestNormalizeMap(t *testing.T) {
var typ map[string]string
vin := dyn.V(map[string]dyn.Value{
Expand Down