diff --git a/spanner/google/cloud/spanner_v1/_helpers.py b/spanner/google/cloud/spanner_v1/_helpers.py index 6a1a9740ba06..e83ddb2732ab 100644 --- a/spanner/google/cloud/spanner_v1/_helpers.py +++ b/spanner/google/cloud/spanner_v1/_helpers.py @@ -58,7 +58,7 @@ def _make_value_pb(value): """ if value is None: return Value(null_value='NULL_VALUE') - if isinstance(value, list): + if isinstance(value, (list, tuple)): return Value(list_value=_make_list_value_pb(value)) if isinstance(value, bool): return Value(bool_value=value) @@ -84,6 +84,8 @@ def _make_value_pb(value): return Value(string_value=value) if isinstance(value, six.text_type): return Value(string_value=value) + if isinstance(value, ListValue): + return Value(list_value=value) raise ValueError("Unknown type: %s" % (value,)) # pylint: enable=too-many-return-statements,too-many-branches diff --git a/spanner/tests/system/test_system.py b/spanner/tests/system/test_system.py index 5bcb5d54781e..f8621b310e55 100644 --- a/spanner/tests/system/test_system.py +++ b/spanner/tests/system/test_system.py @@ -23,6 +23,7 @@ from google.api_core import exceptions from google.api_core.datetime_helpers import DatetimeWithNanoseconds +from google.cloud.spanner_v1 import param_types from google.cloud.spanner_v1.proto.type_pb2 import ARRAY from google.cloud.spanner_v1.proto.type_pb2 import BOOL from google.cloud.spanner_v1.proto.type_pb2 import BYTES @@ -1557,6 +1558,196 @@ def test_execute_sql_w_query_param_transfinite(self): # NaNs cannot be searched for by equality. self.assertTrue(math.isnan(float_array[2])) + def test_execute_sql_w_query_param_struct(self): + NAME = 'Phred' + COUNT = 123 + SIZE = 23.456 + HEIGHT = 188.0 + WEIGHT = 97.6 + + record_type = param_types.Struct([ + param_types.StructField('name', param_types.STRING), + param_types.StructField('count', param_types.INT64), + param_types.StructField('size', param_types.FLOAT64), + param_types.StructField('nested', param_types.Struct([ + param_types.StructField('height', param_types.FLOAT64), + param_types.StructField('weight', param_types.FLOAT64), + ])), + ]) + + # Query with null struct, explicit type + self._check_sql_results( + self._db, + sql='SELECT @r.name, @r.count, @r.size, @r.nested.weight', + params={'r': None}, + param_types={'r': record_type}, + expected=[(None, None, None, None)], + order=False, + ) + + # Query with non-null struct, explicit type, NULL values + self._check_sql_results( + self._db, + sql='SELECT @r.name, @r.count, @r.size, @r.nested.weight', + params={'r': (None, None, None, None)}, + param_types={'r': record_type}, + expected=[(None, None, None, None)], + order=False, + ) + + # Query with non-null struct, explicit type, nested NULL values + self._check_sql_results( + self._db, + sql='SELECT @r.nested.weight', + params={'r': (None, None, None, (None, None))}, + param_types={'r': record_type}, + expected=[(None,)], + order=False, + ) + + # Query with non-null struct, explicit type + self._check_sql_results( + self._db, + sql='SELECT @r.name, @r.count, @r.size, @r.nested.weight', + params={'r': (NAME, COUNT, SIZE, (HEIGHT, WEIGHT))}, + param_types={'r': record_type}, + expected=[(NAME, COUNT, SIZE, WEIGHT)], + order=False, + ) + + # Query with empty struct, explicitly empty type + empty_type = param_types.Struct([]) + self._check_sql_results( + self._db, + sql='SELECT @r IS NULL', + params={'r': ()}, + param_types={'r': empty_type}, + expected=[(False,)], + order=False, + ) + + # Query with null struct, explicitly empty type + self._check_sql_results( + self._db, + sql='SELECT @r IS NULL', + params={'r': None}, + param_types={'r': empty_type}, + expected=[(True,)], + order=False, + ) + + # Query with equality check for struct value + struct_equality_query = ( + 'SELECT ' + '@struct_param=STRUCT(1,"bob")' + ) + struct_type = param_types.Struct([ + param_types.StructField('threadf', param_types.INT64), + param_types.StructField('userf', param_types.STRING), + ]) + self._check_sql_results( + self._db, + sql=struct_equality_query, + params={'struct_param': (1, 'bob')}, + param_types={'struct_param': struct_type}, + expected=[(True,)], + order=False, + ) + + # Query with nullness test for struct + self._check_sql_results( + self._db, + sql='SELECT @struct_param IS NULL', + params={'struct_param': None}, + param_types={'struct_param': struct_type}, + expected=[(True,)], + order=False, + ) + + # Query with null array-of-struct + array_elem_type = param_types.Struct([ + param_types.StructField('threadid', param_types.INT64), + ]) + array_type = param_types.Array(array_elem_type) + self._check_sql_results( + self._db, + sql='SELECT a.threadid FROM UNNEST(@struct_arr_param) a', + params={'struct_arr_param': None}, + param_types={'struct_arr_param': array_type}, + expected=[], + order=False, + ) + + # Query with non-null array-of-struct + self._check_sql_results( + self._db, + sql='SELECT a.threadid FROM UNNEST(@struct_arr_param) a', + params={'struct_arr_param': [(123,), (456,)]}, + param_types={'struct_arr_param': array_type}, + expected=[(123,), (456,)], + order=False, + ) + + # Query with null array-of-struct field + struct_type_with_array_field = param_types.Struct([ + param_types.StructField('intf', param_types.INT64), + param_types.StructField('arraysf', array_type), + ]) + self._check_sql_results( + self._db, + sql='SELECT a.threadid FROM UNNEST(@struct_param.arraysf) a', + params={'struct_param': (123, None)}, + param_types={'struct_param': struct_type_with_array_field}, + expected=[], + order=False, + ) + + # Query with non-null array-of-struct field + self._check_sql_results( + self._db, + sql='SELECT a.threadid FROM UNNEST(@struct_param.arraysf) a', + params={'struct_param': (123, ((456,), (789,)))}, + param_types={'struct_param': struct_type_with_array_field}, + expected=[(456,), (789,)], + order=False, + ) + + # Query with anonymous / repeated-name fields + anon_repeated_array_elem_type = param_types.Struct([ + param_types.StructField('', param_types.INT64), + param_types.StructField('', param_types.STRING), + ]) + anon_repeated_array_type = param_types.Array( + anon_repeated_array_elem_type) + self._check_sql_results( + self._db, + sql='SELECT CAST(t as STRUCT).* ' + 'FROM UNNEST(@struct_param) t', + params={'struct_param': [(123, 'abcdef')]}, + param_types={'struct_param': anon_repeated_array_type}, + expected=[(123, 'abcdef')], + order=False, + ) + + # Query and return a struct parameter + value_type = param_types.Struct([ + param_types.StructField('message', param_types.STRING), + param_types.StructField('repeat', param_types.INT64), + ]) + value_query = ( + 'SELECT ARRAY(SELECT AS STRUCT message, repeat ' + 'FROM (SELECT @value.message AS message, ' + '@value.repeat AS repeat)) AS value' + ) + self._check_sql_results( + self._db, + sql=value_query, + params={'value': ('hello', 1)}, + param_types={'value': value_type}, + expected=[([['hello', 1]],)], + order=False, + ) + def test_partition_query(self): row_count = 40 sql = 'SELECT * FROM {}'.format(self.TABLE) diff --git a/spanner/tests/unit/test__helpers.py b/spanner/tests/unit/test__helpers.py index 472affcfed93..5549e52ea131 100644 --- a/spanner/tests/unit/test__helpers.py +++ b/spanner/tests/unit/test__helpers.py @@ -60,6 +60,17 @@ def test_w_list(self): self.assertEqual([value.string_value for value in values], [u'a', u'b', u'c']) + def test_w_tuple(self): + from google.protobuf.struct_pb2 import Value + from google.protobuf.struct_pb2 import ListValue + + value_pb = self._callFUT((u'a', u'b', u'c')) + self.assertIsInstance(value_pb, Value) + self.assertIsInstance(value_pb.list_value, ListValue) + values = value_pb.list_value.values + self.assertEqual([value.string_value for value in values], + [u'a', u'b', u'c']) + def test_w_bool(self): from google.protobuf.struct_pb2 import Value @@ -124,6 +135,15 @@ def test_w_timestamp_w_nanos(self): self.assertIsInstance(value_pb, Value) self.assertEqual(value_pb.string_value, when.rfc3339()) + def test_w_listvalue(self): + from google.protobuf.struct_pb2 import Value + from google.cloud.spanner_v1._helpers import _make_list_value_pb + + list_value = _make_list_value_pb([1, 2, 3]) + value_pb = self._callFUT(list_value) + self.assertIsInstance(value_pb, Value) + self.assertEqual(value_pb.list_value, list_value) + def test_w_datetime(self): import datetime import pytz