Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,19 @@ public Rel visit(org.apache.calcite.rel.core.Values values) {
NamedStruct type = typeConverter.toNamedStruct(values.getRowType());

LiteralConverter literalConverter = new LiteralConverter(typeConverter);
List<Type> schemaFieldTypes = type.struct().fields();
List<Expression.NestedStruct> structs =
values.getTuples().stream()
.map(
list -> {
// Use schema nullability since Calcite infers non-nullable for all non-null
// values
List<Expression> fields =
list.stream()
.map(l -> literalConverter.convert(l))
IntStream.range(0, list.size())
.mapToObj(
i ->
literalConverter.convert(
list.get(i), schemaFieldTypes.get(i).nullable()))
.collect(Collectors.toUnmodifiableList());
return ExpressionCreator.nestedStruct(false, fields);
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import io.substrait.expression.ExpressionCreator;
import io.substrait.isthmus.TypeConverter;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.time.Duration;
Expand Down Expand Up @@ -69,72 +70,89 @@ private static BigDecimal bd(RexLiteral literal) {
public Expression.Literal convert(RexLiteral literal) {
// convert type first to guarantee we can handle the value.
final Type type = typeConverter.toSubstrait(literal.getType());
final boolean n = type.nullable();
return convert(literal, type.nullable());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm now wondering if it's the case that type.nullable() will always be false. If nullability of literals is always lost in calcite values, won't this always just be the default?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, null literals have nullable types in Calcite, so it's not always non-nullable.
I have added the nullLiteralRoundTrip test to show this.

}

/**
* Converts a RexLiteral to a Substrait Literal with the specified nullability.
*
* <p>This overload is useful when the target nullability should come from the schema rather than
* the literal's own type. For example, Calcite's LogicalValues may have literals with
* non-nullable types even when the schema field is nullable.
*
* @param literal the RexLiteral to convert
* @param nullable the nullability to use for the resulting Substrait literal
* @return the converted Substrait Literal
*/
public Expression.Literal convert(RexLiteral literal, boolean nullable) {
if (literal.isNull()) {
return ExpressionCreator.typedNull(type);
final Type type = typeConverter.toSubstrait(literal.getType());
final Type typeWithNullability =
nullable ? TypeCreator.asNullable(type) : TypeCreator.asNotNullable(type);
return ExpressionCreator.typedNull(typeWithNullability);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that as of #686, it is the case that typedNull will actually throw an exception if it is passed a non-null type. Just commenting so you are aware that if the TypeCreator.asNotNullable case is hit here, an exception will be thrown.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for sharing, I gave it a bit more thinking and I think it's fine and the thrown error will be clear enough, no need to be defensive here nor adjust the error message (which is what you are suggesting, IIUC)

}

switch (literal.getType().getSqlTypeName()) {
case TINYINT:
return ExpressionCreator.i8(n, i(literal).intValue());
return ExpressionCreator.i8(nullable, i(literal).intValue());
case SMALLINT:
return ExpressionCreator.i16(n, i(literal).intValue());
return ExpressionCreator.i16(nullable, i(literal).intValue());
case INTEGER:
return ExpressionCreator.i32(n, i(literal).intValue());
return ExpressionCreator.i32(nullable, i(literal).intValue());
case BIGINT:
return ExpressionCreator.i64(n, i(literal).longValue());
return ExpressionCreator.i64(nullable, i(literal).longValue());
case BOOLEAN:
return ExpressionCreator.bool(n, literal.getValueAs(Boolean.class));
return ExpressionCreator.bool(nullable, literal.getValueAs(Boolean.class));
case CHAR:
{
Comparable<?> val = literal.getValue();
if (val instanceof NlsString) {
NlsString nls = (NlsString) val;
return ExpressionCreator.fixedChar(n, nls.getValue());
return ExpressionCreator.fixedChar(nullable, nls.getValue());
}
throw new UnsupportedOperationException("Unable to handle char type: " + val);
}
case FLOAT:
case DOUBLE:
return ExpressionCreator.fp64(n, literal.getValueAs(Double.class));
return ExpressionCreator.fp64(nullable, literal.getValueAs(Double.class));
case REAL:
return ExpressionCreator.fp32(n, literal.getValueAs(Float.class));
return ExpressionCreator.fp32(nullable, literal.getValueAs(Float.class));

case DECIMAL:
{
BigDecimal bd = bd(literal);
return ExpressionCreator.decimal(
n, bd, literal.getType().getPrecision(), literal.getType().getScale());
nullable, bd, literal.getType().getPrecision(), literal.getType().getScale());
}
case VARCHAR:
{
if (literal.getType().getPrecision() == RelDataType.PRECISION_NOT_SPECIFIED) {
return ExpressionCreator.string(n, s(literal));
return ExpressionCreator.string(nullable, s(literal));
}

return ExpressionCreator.varChar(n, s(literal), literal.getType().getPrecision());
return ExpressionCreator.varChar(nullable, s(literal), literal.getType().getPrecision());
}
case BINARY:
return ExpressionCreator.fixedBinary(
n,
nullable,
ByteString.copyFrom(
padRightIfNeeded(
literal.getValueAs(org.apache.calcite.avatica.util.ByteString.class),
literal.getType().getPrecision())));
case VARBINARY:
return ExpressionCreator.binary(n, ByteString.copyFrom(literal.getValueAs(byte[].class)));
return ExpressionCreator.binary(
nullable, ByteString.copyFrom(literal.getValueAs(byte[].class)));
case SYMBOL:
{
Object value = literal.getValue();
if (value instanceof NlsString) {
return ExpressionCreator.string(n, ((NlsString) value).getValue());
return ExpressionCreator.string(nullable, ((NlsString) value).getValue());
} else if (value instanceof Enum) {
Enum<?> v = (Enum<?>) value;

Optional<Expression.Literal> r =
EnumConverter.canConvert(v)
? Optional.of(ExpressionCreator.string(n, v.name()))
? Optional.of(ExpressionCreator.string(nullable, v.name()))
: Optional.empty();
return r.orElseThrow(
() -> new UnsupportedOperationException("Unable to handle symbol: " + value));
Expand All @@ -146,21 +164,22 @@ public Expression.Literal convert(RexLiteral literal) {
{
DateString date = literal.getValueAs(DateString.class);
LocalDate localDate = LocalDate.parse(date.toString(), CALCITE_LOCAL_DATE_FORMATTER);
return ExpressionCreator.date(n, (int) localDate.toEpochDay());
return ExpressionCreator.date(nullable, (int) localDate.toEpochDay());
}
case TIME:
{
TimeString time = literal.getValueAs(TimeString.class);
LocalTime localTime = LocalTime.parse(time.toString(), CALCITE_LOCAL_TIME_FORMATTER);
return ExpressionCreator.time(n, TimeUnit.NANOSECONDS.toMicros(localTime.toNanoOfDay()));
return ExpressionCreator.time(
nullable, TimeUnit.NANOSECONDS.toMicros(localTime.toNanoOfDay()));
}
case TIMESTAMP:
case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
{
TimestampString timestamp = literal.getValueAs(TimestampString.class);
LocalDateTime ldt =
LocalDateTime.parse(timestamp.toString(), CALCITE_LOCAL_DATETIME_FORMATTER);
return ExpressionCreator.timestamp(n, ldt);
return ExpressionCreator.timestamp(nullable, ldt);
}
case INTERVAL_YEAR:
case INTERVAL_YEAR_MONTH:
Expand All @@ -169,7 +188,7 @@ public Expression.Literal convert(RexLiteral literal) {
long intervalLength = Objects.requireNonNull(literal.getValueAs(Long.class));
long years = intervalLength / 12;
long months = intervalLength - years * 12;
return ExpressionCreator.intervalYear(n, (int) years, (int) months);
return ExpressionCreator.intervalYear(nullable, (int) years, (int) months);
}
case INTERVAL_DAY:
case INTERVAL_DAY_HOUR:
Expand All @@ -182,29 +201,29 @@ public Expression.Literal convert(RexLiteral literal) {
case INTERVAL_MINUTE_SECOND:
case INTERVAL_SECOND:
{
// Calcite represents day/time intervals in milliseconds, despite a default scale of 6.
// Calcite represents day/time intervals in milliseconds, despite a default scale of 6
Long totalMillis = Objects.requireNonNull(literal.getValueAs(Long.class));
Duration interval = Duration.ofMillis(totalMillis);

long days = interval.toDays();
long seconds = interval.minusDays(days).toSeconds();
int micros = interval.toMillisPart() * 1000;

return ExpressionCreator.intervalDay(n, (int) days, (int) seconds, micros, 6);
return ExpressionCreator.intervalDay(nullable, (int) days, (int) seconds, micros, 6);
}

case ROW:
{
List<RexLiteral> literals = (List<RexLiteral>) literal.getValue();
return ExpressionCreator.struct(
n, literals.stream().map(this::convert).collect(Collectors.toList()));
nullable, literals.stream().map(this::convert).collect(Collectors.toList()));
}

case ARRAY:
{
List<RexLiteral> literals = (List<RexLiteral>) literal.getValue();
return ExpressionCreator.list(
n, literals.stream().map(this::convert).collect(Collectors.toList()));
nullable, literals.stream().map(this::convert).collect(Collectors.toList()));
}

default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import static org.junit.jupiter.api.Assertions.assertThrows;

import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.relation.VirtualTableScan;
import io.substrait.type.NamedStruct;
import java.io.PrintWriter;
Expand Down Expand Up @@ -81,6 +82,35 @@ void emptySchemaNonEmptyTable() {
AssertionError.class, () -> createVirtualTableScan(schema, List.of(sb.i32(3), sb.fp64(8))));
}

@Test
void nullableFieldRoundTrip() {
NamedStruct schema = NamedStruct.of(List.of("col1", "col2"), R.struct(N.I32, R.FP64));
Expression nullableI32 = ExpressionCreator.i32(true, 6);
VirtualTableScan virtualTableScan =
createVirtualTableScan(schema, List.of(nullableI32, sb.fp64(8)));
assertFullRoundTrip(virtualTableScan);
}

@Test
void nullLiteralRoundTrip() {
NamedStruct schema = NamedStruct.of(List.of("col1", "col2"), R.struct(N.I32, N.FP64));
Expression nullI32 = ExpressionCreator.typedNull(N.I32);
Expression nullFp64 = ExpressionCreator.typedNull(N.FP64);
VirtualTableScan virtualTableScan = createVirtualTableScan(schema, List.of(nullI32, nullFp64));
assertFullRoundTrip(virtualTableScan);
}

@Test
void mixedNullabilityRoundTrip() {
NamedStruct schema =
NamedStruct.of(List.of("col1", "col2", "col3"), R.struct(N.I32, R.FP64, N.STRING));
Expression nullI32 = ExpressionCreator.typedNull(N.I32);
Expression nullString = ExpressionCreator.typedNull(N.STRING);
VirtualTableScan virtualTableScan =
createVirtualTableScan(schema, List.of(nullI32, sb.fp64(8), nullString));
assertFullRoundTrip(virtualTableScan);
}

@SafeVarargs
private VirtualTableScan createVirtualTableScan(NamedStruct schema, List<Expression>... rows) {
List<Expression.NestedStruct> structs =
Expand Down