diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index 025ab8791..26cf37c0c 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -165,13 +165,19 @@ public Rel visit(org.apache.calcite.rel.core.Values values) { NamedStruct type = typeConverter.toNamedStruct(values.getRowType()); LiteralConverter literalConverter = new LiteralConverter(typeConverter); + List schemaFieldTypes = type.struct().fields(); List structs = values.getTuples().stream() .map( list -> { + // Use schema nullability since Calcite infers non-nullable for all non-null + // values List fields = - list.stream() - .map(l -> literalConverter.convert(l)) + IntStream.range(0, list.size()) + .mapToObj( + i -> + literalConverter.convert( + list.get(i), schemaFieldTypes.get(i).nullable())) .collect(Collectors.toUnmodifiableList()); return ExpressionCreator.nestedStruct(false, fields); }) diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java index 5d91e6a9a..2e1cad6ca 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java @@ -5,6 +5,7 @@ import io.substrait.expression.ExpressionCreator; import io.substrait.isthmus.TypeConverter; import io.substrait.type.Type; +import io.substrait.type.TypeCreator; import java.math.BigDecimal; import java.math.RoundingMode; import java.time.Duration; @@ -69,72 +70,89 @@ private static BigDecimal bd(RexLiteral literal) { public Expression.Literal convert(RexLiteral literal) { // convert type first to guarantee we can handle the value. 
final Type type = typeConverter.toSubstrait(literal.getType()); - final boolean n = type.nullable(); + return convert(literal, type.nullable()); + } + /** + * Converts a RexLiteral to a Substrait Literal with the specified nullability. + * + * <p>
This overload is useful when the target nullability should come from the schema rather than + * the literal's own type. For example, Calcite's LogicalValues may have literals with + * non-nullable types even when the schema field is nullable. + * + * @param literal the RexLiteral to convert + * @param nullable the nullability to use for the resulting Substrait literal + * @return the converted Substrait Literal + */ + public Expression.Literal convert(RexLiteral literal, boolean nullable) { if (literal.isNull()) { - return ExpressionCreator.typedNull(type); + final Type type = typeConverter.toSubstrait(literal.getType()); + final Type typeWithNullability = + nullable ? TypeCreator.asNullable(type) : TypeCreator.asNotNullable(type); + return ExpressionCreator.typedNull(typeWithNullability); } switch (literal.getType().getSqlTypeName()) { case TINYINT: - return ExpressionCreator.i8(n, i(literal).intValue()); + return ExpressionCreator.i8(nullable, i(literal).intValue()); case SMALLINT: - return ExpressionCreator.i16(n, i(literal).intValue()); + return ExpressionCreator.i16(nullable, i(literal).intValue()); case INTEGER: - return ExpressionCreator.i32(n, i(literal).intValue()); + return ExpressionCreator.i32(nullable, i(literal).intValue()); case BIGINT: - return ExpressionCreator.i64(n, i(literal).longValue()); + return ExpressionCreator.i64(nullable, i(literal).longValue()); case BOOLEAN: - return ExpressionCreator.bool(n, literal.getValueAs(Boolean.class)); + return ExpressionCreator.bool(nullable, literal.getValueAs(Boolean.class)); case CHAR: { Comparable val = literal.getValue(); if (val instanceof NlsString) { NlsString nls = (NlsString) val; - return ExpressionCreator.fixedChar(n, nls.getValue()); + return ExpressionCreator.fixedChar(nullable, nls.getValue()); } throw new UnsupportedOperationException("Unable to handle char type: " + val); } case FLOAT: case DOUBLE: - return ExpressionCreator.fp64(n, literal.getValueAs(Double.class)); + return 
ExpressionCreator.fp64(nullable, literal.getValueAs(Double.class)); case REAL: - return ExpressionCreator.fp32(n, literal.getValueAs(Float.class)); + return ExpressionCreator.fp32(nullable, literal.getValueAs(Float.class)); case DECIMAL: { BigDecimal bd = bd(literal); return ExpressionCreator.decimal( - n, bd, literal.getType().getPrecision(), literal.getType().getScale()); + nullable, bd, literal.getType().getPrecision(), literal.getType().getScale()); } case VARCHAR: { if (literal.getType().getPrecision() == RelDataType.PRECISION_NOT_SPECIFIED) { - return ExpressionCreator.string(n, s(literal)); + return ExpressionCreator.string(nullable, s(literal)); } - return ExpressionCreator.varChar(n, s(literal), literal.getType().getPrecision()); + return ExpressionCreator.varChar(nullable, s(literal), literal.getType().getPrecision()); } case BINARY: return ExpressionCreator.fixedBinary( - n, + nullable, ByteString.copyFrom( padRightIfNeeded( literal.getValueAs(org.apache.calcite.avatica.util.ByteString.class), literal.getType().getPrecision()))); case VARBINARY: - return ExpressionCreator.binary(n, ByteString.copyFrom(literal.getValueAs(byte[].class))); + return ExpressionCreator.binary( + nullable, ByteString.copyFrom(literal.getValueAs(byte[].class))); case SYMBOL: { Object value = literal.getValue(); if (value instanceof NlsString) { - return ExpressionCreator.string(n, ((NlsString) value).getValue()); + return ExpressionCreator.string(nullable, ((NlsString) value).getValue()); } else if (value instanceof Enum) { Enum v = (Enum) value; Optional r = EnumConverter.canConvert(v) - ? Optional.of(ExpressionCreator.string(n, v.name())) + ? 
Optional.of(ExpressionCreator.string(nullable, v.name())) : Optional.empty(); return r.orElseThrow( () -> new UnsupportedOperationException("Unable to handle symbol: " + value)); @@ -146,13 +164,14 @@ public Expression.Literal convert(RexLiteral literal) { { DateString date = literal.getValueAs(DateString.class); LocalDate localDate = LocalDate.parse(date.toString(), CALCITE_LOCAL_DATE_FORMATTER); - return ExpressionCreator.date(n, (int) localDate.toEpochDay()); + return ExpressionCreator.date(nullable, (int) localDate.toEpochDay()); } case TIME: { TimeString time = literal.getValueAs(TimeString.class); LocalTime localTime = LocalTime.parse(time.toString(), CALCITE_LOCAL_TIME_FORMATTER); - return ExpressionCreator.time(n, TimeUnit.NANOSECONDS.toMicros(localTime.toNanoOfDay())); + return ExpressionCreator.time( + nullable, TimeUnit.NANOSECONDS.toMicros(localTime.toNanoOfDay())); } case TIMESTAMP: case TIMESTAMP_WITH_LOCAL_TIME_ZONE: @@ -160,7 +179,7 @@ public Expression.Literal convert(RexLiteral literal) { TimestampString timestamp = literal.getValueAs(TimestampString.class); LocalDateTime ldt = LocalDateTime.parse(timestamp.toString(), CALCITE_LOCAL_DATETIME_FORMATTER); - return ExpressionCreator.timestamp(n, ldt); + return ExpressionCreator.timestamp(nullable, ldt); } case INTERVAL_YEAR: case INTERVAL_YEAR_MONTH: @@ -169,7 +188,7 @@ public Expression.Literal convert(RexLiteral literal) { long intervalLength = Objects.requireNonNull(literal.getValueAs(Long.class)); long years = intervalLength / 12; long months = intervalLength - years * 12; - return ExpressionCreator.intervalYear(n, (int) years, (int) months); + return ExpressionCreator.intervalYear(nullable, (int) years, (int) months); } case INTERVAL_DAY: case INTERVAL_DAY_HOUR: @@ -182,7 +201,7 @@ public Expression.Literal convert(RexLiteral literal) { case INTERVAL_MINUTE_SECOND: case INTERVAL_SECOND: { - // Calcite represents day/time intervals in milliseconds, despite a default scale of 6. 
+ // Calcite represents day/time intervals in milliseconds, despite a default scale of 6 Long totalMillis = Objects.requireNonNull(literal.getValueAs(Long.class)); Duration interval = Duration.ofMillis(totalMillis); @@ -190,21 +209,21 @@ public Expression.Literal convert(RexLiteral literal) { long seconds = interval.minusDays(days).toSeconds(); int micros = interval.toMillisPart() * 1000; - return ExpressionCreator.intervalDay(n, (int) days, (int) seconds, micros, 6); + return ExpressionCreator.intervalDay(nullable, (int) days, (int) seconds, micros, 6); } case ROW: { List literals = (List) literal.getValue(); return ExpressionCreator.struct( - n, literals.stream().map(this::convert).collect(Collectors.toList())); + nullable, literals.stream().map(this::convert).collect(Collectors.toList())); } case ARRAY: { List literals = (List) literal.getValue(); return ExpressionCreator.list( - n, literals.stream().map(this::convert).collect(Collectors.toList())); + nullable, literals.stream().map(this::convert).collect(Collectors.toList())); } default: diff --git a/isthmus/src/test/java/io/substrait/isthmus/VirtualTableScanTest.java b/isthmus/src/test/java/io/substrait/isthmus/VirtualTableScanTest.java index 71f4e1262..e038cedd6 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/VirtualTableScanTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/VirtualTableScanTest.java @@ -5,6 +5,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; import io.substrait.relation.VirtualTableScan; import io.substrait.type.NamedStruct; import java.io.PrintWriter; @@ -81,6 +82,35 @@ void emptySchemaNonEmptyTable() { AssertionError.class, () -> createVirtualTableScan(schema, List.of(sb.i32(3), sb.fp64(8)))); } + @Test + void nullableFieldRoundTrip() { + NamedStruct schema = NamedStruct.of(List.of("col1", "col2"), R.struct(N.I32, R.FP64)); + Expression nullableI32 = 
ExpressionCreator.i32(true, 6); + VirtualTableScan virtualTableScan = + createVirtualTableScan(schema, List.of(nullableI32, sb.fp64(8))); + assertFullRoundTrip(virtualTableScan); + } + + @Test + void nullLiteralRoundTrip() { + NamedStruct schema = NamedStruct.of(List.of("col1", "col2"), R.struct(N.I32, N.FP64)); + Expression nullI32 = ExpressionCreator.typedNull(N.I32); + Expression nullFp64 = ExpressionCreator.typedNull(N.FP64); + VirtualTableScan virtualTableScan = createVirtualTableScan(schema, List.of(nullI32, nullFp64)); + assertFullRoundTrip(virtualTableScan); + } + + @Test + void mixedNullabilityRoundTrip() { + NamedStruct schema = + NamedStruct.of(List.of("col1", "col2", "col3"), R.struct(N.I32, R.FP64, N.STRING)); + Expression nullI32 = ExpressionCreator.typedNull(N.I32); + Expression nullString = ExpressionCreator.typedNull(N.STRING); + VirtualTableScan virtualTableScan = + createVirtualTableScan(schema, List.of(nullI32, sb.fp64(8), nullString)); + assertFullRoundTrip(virtualTableScan); + } + @SafeVarargs private VirtualTableScan createVirtualTableScan(NamedStruct schema, List... rows) { List structs =