diff --git a/sdks/go/pkg/beam/core/graph/coder/map.go b/sdks/go/pkg/beam/core/graph/coder/map.go new file mode 100644 index 000000000000..4e5dc2ccddb2 --- /dev/null +++ b/sdks/go/pkg/beam/core/graph/coder/map.go @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package coder + +import ( + "io" + "reflect" +) + +// TODO(lostluck): 2020.08.04 export these for use for others? + +// mapDecoder produces a decoder for the beam schema map encoding. +func mapDecoder(rt reflect.Type, decodeToKey, decodeToElem func(reflect.Value, io.Reader) error) func(reflect.Value, io.Reader) error { + return func(ret reflect.Value, r io.Reader) error { + // (1) Read count prefixed encoded data + size, err := DecodeInt32(r) + if err != nil { + return err + } + n := int(size) + ret.Set(reflect.MakeMapWithSize(rt, n)) + for i := 0; i < n; i++ { + rvk := reflect.New(rt.Key()).Elem() + if err := decodeToKey(rvk, r); err != nil { + return err + } + rvv := reflect.New(rt.Elem()).Elem() + if err := decodeToElem(rvv, r); err != nil { + return err + } + ret.SetMapIndex(rvk, rvv) + } + return nil + } +} + +// containerNilDecoder handles when a value is nillable for map or iterable components. +// Nillable types have an extra byte prefixing them indicating nil status. +func containerNilDecoder(decodeToElem func(reflect.Value, io.Reader) error) func(reflect.Value, io.Reader) error { + return func(ret reflect.Value, r io.Reader) error { + hasValue, err := DecodeBool(r) + if err != nil { + return err + } + if !hasValue { + return nil + } + rv := reflect.New(ret.Type().Elem()) + if err := decodeToElem(rv.Elem(), r); err != nil { + return err + } + ret.Set(rv) + return nil + } +} + +// mapEncoder reflectively encodes a map or array type using the beam map encoding. +func mapEncoder(rt reflect.Type, encodeKey, encodeValue func(reflect.Value, io.Writer) error) func(reflect.Value, io.Writer) error { + return func(rv reflect.Value, w io.Writer) error { + size := rv.Len() + if err := EncodeInt32((int32)(size), w); err != nil { + return err + } + iter := rv.MapRange() + for iter.Next() { + if err := encodeKey(iter.Key(), w); err != nil { + return err + } + if err := encodeValue(iter.Value(), w); err != nil { + return err + } + } + return nil + } +} + +// containerNilEncoder handles when a value is nillable for map or iterable components. +// Nillable types have an extra byte prefixing them indicating nil status. +func containerNilEncoder(encodeElem func(reflect.Value, io.Writer) error) func(reflect.Value, io.Writer) error { + return func(rv reflect.Value, w io.Writer) error { + if rv.IsNil() { + return EncodeBool(false, w) + } + if err := EncodeBool(true, w); err != nil { + return err + } + return encodeElem(rv.Elem(), w) + } +} diff --git a/sdks/go/pkg/beam/core/graph/coder/map_test.go b/sdks/go/pkg/beam/core/graph/coder/map_test.go new file mode 100644 index 000000000000..0b825c2d5105 --- /dev/null +++ b/sdks/go/pkg/beam/core/graph/coder/map_test.go @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package coder + +import ( + "bytes" + "fmt" + "io" + "reflect" + "testing" + + "github.com/apache/beam/sdks/go/pkg/beam/core/util/reflectx" + "github.com/google/go-cmp/cmp" +) + +func TestEncodeDecodeMap(t *testing.T) { + byteEnc := containerEncoderForType(reflectx.Uint8) + byteDec := containerDecoderForType(reflectx.Uint8) + bytePtrEnc := containerEncoderForType(reflect.PtrTo(reflectx.Uint8)) + bytePtrDec := containerDecoderForType(reflect.PtrTo(reflectx.Uint8)) + + ptrByte := byte(42) + + tests := []struct { + v interface{} + encK, encV func(reflect.Value, io.Writer) error + decK, decV func(reflect.Value, io.Reader) error + encoded []byte + decodeOnly bool + }{ + { + v: map[byte]byte{10: 42}, + encK: byteEnc, + encV: byteEnc, + decK: byteDec, + decV: byteDec, + encoded: []byte{0, 0, 0, 1, 10, 42}, + }, { + v: map[byte]*byte{10: &ptrByte}, + encK: byteEnc, + encV: bytePtrEnc, + decK: byteDec, + decV: bytePtrDec, + encoded: []byte{0, 0, 0, 1, 10, 1, 42}, + }, { + v: map[byte]*byte{10: &ptrByte, 23: nil, 53: nil}, + encK: byteEnc, + encV: bytePtrEnc, + decK: byteDec, + decV: bytePtrDec, + encoded: []byte{0, 0, 0, 3, 10, 1, 42, 23, 0, 53, 0}, + decodeOnly: true, + }, + } + for _, test := range tests { + test := test + if !test.decodeOnly { + t.Run(fmt.Sprintf("encode %q", test.v), func(t *testing.T) { + var buf bytes.Buffer + err := mapEncoder(reflect.TypeOf(test.v), test.encK, test.encV)(reflect.ValueOf(test.v), &buf) + if err != nil { + t.Fatalf("mapEncoder(%q) = %v", test.v, err) + } + if d := cmp.Diff(test.encoded, buf.Bytes()); d != "" { + t.Errorf("mapEncoder(%q) = %v, want %v diff(-want,+got):\n %v", test.v, buf.Bytes(), test.encoded, d) + } + }) + } + t.Run(fmt.Sprintf("decode %v", test.v), func(t *testing.T) { + buf := bytes.NewBuffer(test.encoded) + rt := reflect.TypeOf(test.v) + var dec func(reflect.Value, io.Reader) error + dec = mapDecoder(rt, test.decK, test.decV) + rv := reflect.New(rt).Elem() + err := dec(rv, buf) + if err != nil { + t.Fatalf("mapDecoder(%q) = %v", test.encoded, err) + } + got := rv.Interface() + if d := cmp.Diff(test.v, got); d != "" { + t.Errorf("mapDecoder(%q) = %q, want %v diff(-want,+got):\n %v", test.encoded, got, test.v, d) + } + }) + } +} diff --git a/sdks/go/pkg/beam/core/graph/coder/row.go b/sdks/go/pkg/beam/core/graph/coder/row.go index aac34acf7e4c..00b4c265da00 100644 --- a/sdks/go/pkg/beam/core/graph/coder/row.go +++ b/sdks/go/pkg/beam/core/graph/coder/row.go @@ -152,21 +152,26 @@ func decoderForSingleTypeReflect(t reflect.Type) func(reflect.Value, io.Reader) return nil } } - decf := decoderForSingleTypeReflect(t.Elem()) - sdec := iterableDecoderForSlice(t, decf) - return func(rv reflect.Value, r io.Reader) error { - return sdec(rv, r) - } + decf := containerDecoderForType(t.Elem()) + return iterableDecoderForSlice(t, decf) case reflect.Array: - decf := decoderForSingleTypeReflect(t.Elem()) - sdec := iterableDecoderForArray(t, decf) - return func(rv reflect.Value, r io.Reader) error { - return sdec(rv, r) - } + decf := containerDecoderForType(t.Elem()) + return iterableDecoderForArray(t, decf) + case reflect.Map: + decK := containerDecoderForType(t.Key()) + decV := containerDecoderForType(t.Elem()) + return mapDecoder(t, decK, decV) } panic(fmt.Sprintf("unimplemented type to decode: %v", t)) } +func containerDecoderForType(t reflect.Type) func(reflect.Value, io.Reader) error { + if t.Kind() == reflect.Ptr { + return containerNilDecoder(decoderForSingleTypeReflect(t.Elem())) + } + return decoderForSingleTypeReflect(t) +} + type typeDecoderReflect struct { typ reflect.Type fields []func(reflect.Value, io.Reader) error @@ -270,15 +275,26 @@ func encoderForSingleTypeReflect(t reflect.Type) func(reflect.Value, io.Writer) return EncodeBytes(rv.Bytes(), w) } } - encf := encoderForSingleTypeReflect(t.Elem()) + encf := containerEncoderForType(t.Elem()) return iterableEncoder(t, encf) case reflect.Array: - encf := encoderForSingleTypeReflect(t.Elem()) + encf := containerEncoderForType(t.Elem()) return iterableEncoder(t, encf) + case reflect.Map: + encK := containerEncoderForType(t.Key()) + encV := containerEncoderForType(t.Elem()) + return mapEncoder(t, encK, encV) } panic(fmt.Sprintf("unimplemented type to encode: %v", t)) } +func containerEncoderForType(t reflect.Type) func(reflect.Value, io.Writer) error { + if t.Kind() == reflect.Ptr { + return containerNilEncoder(encoderForSingleTypeReflect(t.Elem())) + } + return encoderForSingleTypeReflect(t) +} + type typeEncoderReflect struct { fields []func(reflect.Value, io.Writer) error } diff --git a/sdks/go/pkg/beam/core/graph/coder/row_test.go b/sdks/go/pkg/beam/core/graph/coder/row_test.go index f1089b898171..38b7c5dfb1f9 100644 --- a/sdks/go/pkg/beam/core/graph/coder/row_test.go +++ b/sdks/go/pkg/beam/core/graph/coder/row_test.go @@ -78,16 +78,18 @@ func TestReflectionRowCoderGeneration(t *testing.T) { V12 [0]int V13 [2]int V14 []int - // V15 map[string]int // not yet a standard coder (BEAM-7996) + V15 map[string]int V16 float32 V17 float64 V18 []byte + V19 [2]*int + V20 map[*string]*int }{}, }, { want: struct { V00 bool - V01 byte - V02 uint8 + V01 byte // unsupported by spec (same as uint8) + V02 uint8 // unsupported by spec V03 int16 // V04 uint16 // unsupported by spec V05 int32 @@ -100,10 +102,13 @@ func TestReflectionRowCoderGeneration(t *testing.T) { V12 [0]int V13 [2]int V14 []int - // V15 map[string]int // not yet a standard coder (BEAM-7996) (encoding unspecified) + V15 map[string]int V16 float32 V17 float64 V18 []byte + V19 [2]*int + V20 map[string]*int + V21 []*int }{ V00: true, V01: 1, @@ -117,9 +122,16 @@ func TestReflectionRowCoderGeneration(t *testing.T) { V12: [0]int{}, V13: [2]int{72, 908}, V14: []int{12, 9326, 641346, 6}, + V15: map[string]int{"pants": 42}, V16: 3.14169, V17: 2.6e100, V18: []byte{21, 17, 65, 255, 0, 16}, + V19: [2]*int{nil, &num}, + V20: map[string]*int{ + "notnil": &num, + "nil": nil, + }, + V21: []*int{nil, &num, nil}, }, // TODO add custom types such as protocol buffers. },