Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
}
}

override def foldable: Boolean = child.foldable

override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable

override def toString: String = s"CAST($child, $dataType)"
Expand Down Expand Up @@ -426,10 +424,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w

private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)

override def eval(input: InternalRow): Any = {
val evaluated = child.eval(input)
if (evaluated == null) null else cast(evaluated)
}
// Converts a child value to the target type using the lazily built `cast` function.
protected override def nullSafeEval(input: Any): Any = {
  cast(input)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
// TODO: Add support for more data types.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,27 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
override def foldable: Boolean = child.foldable
override def nullable: Boolean = child.nullable

/**
 * Default evaluation, matching the default nullability of UnaryExpression: the child is
 * evaluated first, a null child value yields null, and any other value is handed to
 * [[nullSafeEval]]. A subclass that overrides `nullable` should probably override this
 * method as well.
 */
override def eval(input: InternalRow): Any = {
  val childValue = child.eval(input)
  // Only a non-null child value reaches nullSafeEval; null is propagated unchanged.
  if (childValue != null) nullSafeEval(childValue) else null
}

/**
 * Invoked by the default [[eval]] implementation once the child value is known to be
 * non-null. A subclass of UnaryExpression that keeps the default nullability can override
 * this method to avoid writing the null check itself; a subclass that needs full control
 * over evaluation should override [[eval]] instead. Unless overridden, calling this method
 * raises an error through `sys.error`.
 */
protected def nullSafeEval(input: Any): Any = {
  sys.error("UnaryExpressions must override either eval or nullSafeEval")
}

/**
* Called by unary expressions to generate a code block that returns null if its child returns
* null, and if not null, use `f` to generate the expression.
Expand All @@ -198,21 +219,24 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: String => String): String = {
nullSafeCodeGen(ctx, ev, (result, eval) => {
s"$result = ${f(eval)};"
nullSafeCodeGen(ctx, ev, eval => {
s"${ev.primitive} = ${f(eval)};"
})
}

/**
* Called by unary expressions to generate a code block that returns null if its child returns
* null, and if not null, use `f` to generate the expression.
*
* @param f function that accepts the non-null evaluation result name of child and returns Java
* code to compute the output.
*/
protected def nullSafeCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (String, String) => String): String = {
f: String => String): String = {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you update the javadoc above to include param doc for f, defining what the input / output is?

val eval = child.gen(ctx)
val resultCode = f(ev.primitive, eval.primitive)
val resultCode = f(eval.primitive)
eval.code + s"""
boolean ${ev.isNull} = ${eval.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
Expand All @@ -235,6 +259,32 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express

override def nullable: Boolean = left.nullable || right.nullable

/**
 * Default evaluation, matching the default nullability of BinaryExpression: the result is
 * null as soon as either operand evaluates to null, otherwise both values are handed to
 * [[nullSafeEval]]. A subclass that overrides `nullable` should probably override this
 * method as well.
 */
override def eval(input: InternalRow): Any = {
  val leftValue = left.eval(input)
  if (leftValue != null) {
    // The right operand is evaluated only after the left one is known to be non-null.
    val rightValue = right.eval(input)
    if (rightValue != null) nullSafeEval(leftValue, rightValue) else null
  } else {
    null
  }
}

/**
 * Invoked by the default [[eval]] implementation once both operand values are known to be
 * non-null. A subclass of BinaryExpression that keeps the default nullability can override
 * this method to avoid writing the null checks itself; a subclass that needs full control
 * over evaluation should override [[eval]] instead. Unless overridden, calling this method
 * raises an error through `sys.error`.
 */
protected def nullSafeEval(input1: Any, input2: Any): Any = {
  sys.error("BinaryExpressions must override either eval or nullSafeEval")
}

/**
* Short hand for generating binary evaluation code.
* If either of the sub-expressions is null, the result of this computation
Expand All @@ -246,23 +296,26 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (String, String) => String): String = {
nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => {
s"$result = ${f(eval1, eval2)};"
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"${ev.primitive} = ${f(eval1, eval2)};"
})
}

/**
* Short hand for generating binary evaluation code.
* If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f function that accepts the 2 non-null evaluation result names of children
* and returns Java code to compute the output.
*/
protected def nullSafeCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (String, String, String) => String): String = {
f: (String, String) => String): String = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive)
val resultCode = f(eval1.primitive, eval2.primitive)
s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,18 +122,16 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int)
override def dataType: DataType = field.dataType
override def nullable: Boolean = child.nullable || field.nullable

override def eval(input: InternalRow): Any = {
val baseValue = child.eval(input).asInstanceOf[InternalRow]
if (baseValue == null) null else baseValue(ordinal)
}
// Treats the child value as a row and reads the field stored at `ordinal`.
protected override def nullSafeEval(input: Any): Any = {
  val row = input.asInstanceOf[InternalRow]
  row(ordinal)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, (result, eval) => {
nullSafeCodeGen(ctx, ev, eval => {
s"""
if ($eval.isNullAt($ordinal)) {
${ev.isNull} = true;
} else {
$result = ${ctx.getColumn(eval, dataType, ordinal)};
${ev.primitive} = ${ctx.getColumn(eval, dataType, ordinal)};
}
"""
})
Expand All @@ -152,20 +150,17 @@ case class GetArrayStructFields(
override def dataType: DataType = ArrayType(field.dataType, containsNull)
override def nullable: Boolean = child.nullable || containsNull || field.nullable

override def eval(input: InternalRow): Any = {
val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]]
if (baseValue == null) null else {
baseValue.map { row =>
if (row == null) null else row(ordinal)
}
// Treats the child value as a sequence of rows and extracts the field at `ordinal`
// from each one; a null row produces a null element in the result.
protected override def nullSafeEval(input: Any): Any = {
  val rows = input.asInstanceOf[Seq[InternalRow]]
  rows.map { r =>
    if (r != null) r(ordinal) else null
  }
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val arraySeqClass = "scala.collection.mutable.ArraySeq"
// TODO: consider using Array[_] for ArrayType child to avoid
// boxing of primitives
nullSafeCodeGen(ctx, ev, (result, eval) => {
nullSafeCodeGen(ctx, ev, eval => {
s"""
final int n = $eval.size();
final $arraySeqClass<Object> values = new $arraySeqClass<Object>(n);
Expand All @@ -175,7 +170,7 @@ case class GetArrayStructFields(
values.update(j, ${ctx.getColumn("row", field.dataType, ordinal)});
}
}
$result = (${ctx.javaType(dataType)}) values;
${ev.primitive} = (${ctx.javaType(dataType)}) values;
"""
})
}
Expand All @@ -193,22 +188,6 @@ abstract class ExtractValueWithOrdinal extends BinaryExpression with ExtractValu
/** `Null` is returned for invalid ordinals. */
override def nullable: Boolean = true
override def toString: String = s"$child[$ordinal]"

override def eval(input: InternalRow): Any = {
val value = child.eval(input)
if (value == null) {
null
} else {
val o = ordinal.eval(input)
if (o == null) {
null
} else {
evalNotNull(value, o)
}
}
}

protected def evalNotNull(value: Any, ordinal: Any): Any
}

/**
Expand All @@ -219,7 +198,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)

override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType

protected def evalNotNull(value: Any, ordinal: Any) = {
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
// TODO: consider using Array[_] for ArrayType child to avoid
// boxing of primitives
val baseValue = value.asInstanceOf[Seq[_]]
Expand All @@ -232,13 +211,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
final int index = (int)$eval2;
if (index >= $eval1.size() || index < 0) {
${ev.isNull} = true;
} else {
$result = (${ctx.boxedType(dataType)})$eval1.apply(index);
${ev.primitive} = (${ctx.boxedType(dataType)})$eval1.apply(index);
}
"""
})
Expand All @@ -253,16 +232,16 @@ case class GetMapValue(child: Expression, ordinal: Expression)

override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType

protected def evalNotNull(value: Any, ordinal: Any) = {
// Looks up `ordinal` as a key in the map value; a missing key yields null.
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
  value.asInstanceOf[Map[Any, _]].get(ordinal).orNull
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
if ($eval1.contains($eval2)) {
$result = (${ctx.boxedType(dataType)})$eval1.apply($eval2);
${ev.primitive} = (${ctx.boxedType(dataType)})$eval1.apply($eval2);
} else {
${ev.isNull} = true;
}
Expand Down
Loading