Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/enclave/Enclave/Aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ void non_oblivious_aggregate(
count += 1;
}

// Skip outputting the final row if the number of input rows is 0 AND
// 1. It's a grouping aggregation, OR
// Skip outputting the final row if:
// 1. The number of input rows is 0 AND it's a grouping aggregation, OR
Comment on lines -33 to +34
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this comment changed? I don't think the new meaning is equivalent to what the code says?

// 2. The number of input rows is 0 AND it's a global aggregation whose mode is final
if (!(count == 0 && (agg_op_eval.get_num_grouping_keys() > 0 || (agg_op_eval.get_num_grouping_keys() == 0 && !is_partial)))) {
w.append(agg_op_eval.evaluate());
Expand Down
30 changes: 30 additions & 0 deletions src/enclave/Enclave/ExpressionEvaluation.h
Original file line number Diff line number Diff line change
Expand Up @@ -1811,6 +1811,9 @@ class AggregateExpressionEvaluator {
std::unique_ptr<FlatbuffersExpressionEvaluator>(
new FlatbuffersExpressionEvaluator(eval_expr)));
}
is_distinct = expr->is_distinct();
value_selector = std::unique_ptr<FlatbuffersExpressionEvaluator>(
new FlatbuffersExpressionEvaluator(expr->value_selector()));
}

std::vector<const tuix::Field *> initial_values(const tuix::Row *unused) {
Expand All @@ -1824,6 +1827,15 @@ class AggregateExpressionEvaluator {
std::vector<const tuix::Field *> update(const tuix::Row *concat) {
std::vector<const tuix::Field *> result;
for (auto&& e : update_evaluators) {
if (is_distinct) {
std::string value = to_string(value_selector->eval(concat));
/* Check to see if this distinct value has already been counted */
if (observed_values.count(value)) {
std::vector<const tuix::Field *> vect(1, nullptr);
return vect;
}
observed_values.insert(value);
}
result.push_back(e->eval(concat));
}
return result;
Expand All @@ -1837,11 +1849,18 @@ class AggregateExpressionEvaluator {
return result;
}

// Forgets every value that update() has recorded for DISTINCT de-duplication,
// so that subsequent update() calls start from an empty set of observed values.
// NOTE(review): the visible caller invokes this while writing the initial
// aggregation values, presumably once per new group -- confirm against caller.
void clear_observed_values() {
observed_values.clear();
}

private:
flatbuffers::FlatBufferBuilder builder;
std::vector<std::unique_ptr<FlatbuffersExpressionEvaluator>> initial_value_evaluators;
std::vector<std::unique_ptr<FlatbuffersExpressionEvaluator>> update_evaluators;
std::vector<std::unique_ptr<FlatbuffersExpressionEvaluator>> evaluate_evaluators;
bool is_distinct;
std::unique_ptr<FlatbuffersExpressionEvaluator> value_selector;
std::set<std::string> observed_values;
};

class FlatbuffersAggOpEvaluator {
Expand Down Expand Up @@ -1880,6 +1899,7 @@ class FlatbuffersAggOpEvaluator {
// Write initial values to a
std::vector<flatbuffers::Offset<tuix::Field>> init_fields;
for (auto&& e : aggregate_evaluators) {
e->clear_observed_values();
for (auto f : e->initial_values(nullptr)) {
init_fields.push_back(flatbuffers_copy<tuix::Field>(f, builder2));
}
Expand All @@ -1901,6 +1921,7 @@ class FlatbuffersAggOpEvaluator {
void aggregate(const tuix::Row *row) {
builder.Clear();
flatbuffers::Offset<tuix::Row> concat;
int a_length = a->field_values()->size();

std::vector<flatbuffers::Offset<tuix::Field>> concat_fields;
// concat row to a
Expand All @@ -1918,9 +1939,18 @@ class FlatbuffersAggOpEvaluator {
std::vector<flatbuffers::Offset<tuix::Field>> output_fields;
for (auto&& e : aggregate_evaluators) {
for (auto f : e->update(concat_ptr)) {
if (f == nullptr) { // Only triggered on EXPR(distinct expr ...)
output_fields.clear();
for (int i = 0; i < a_length; i++) {
auto f = concat_ptr->field_values()->Get(i);
output_fields.push_back(flatbuffers_copy<tuix::Field>(f, builder2));
}
goto save_a;
}
output_fields.push_back(flatbuffers_copy<tuix::Field>(f, builder2));
}
}
save_a:
a = flatbuffers::GetTemporaryPointer<tuix::Row>(
builder2, tuix::CreateRowDirect(builder2, &output_fields));
}
Expand Down
3 changes: 3 additions & 0 deletions src/flatbuffers/operators.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ table AggregateExpr {
initial_values: [Expr];
update_exprs: [Expr];
evaluate_exprs: [Expr];
// Items below are used for EXPR(distinct col_name ...)
is_distinct: bool;
value_selector: Expr;
}
// Supported: Average, Count, First, Last, Max, Min, Sum

Expand Down
50 changes: 41 additions & 9 deletions src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1371,13 +1371,25 @@ object Utils extends Logging {
updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray),
tuix.AggregateExpr.createEvaluateExprsVector(
builder,
evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray)
evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray),
false,
0
)

case c @ Count(children) =>
val count = c.aggBufferAttributes(0)
// COUNT(*) should count NULL values
// COUNT(expr) should return the number of rows for which the supplied expressions are non-NULL
// COUNT(distinct expr ...) should return the number of rows that contain UNIQUE values of expr

val ar = e.aggregateFunction.children(0)
val colNum = concatSchema.indexWhere(_.semanticEquals(ar))
val (isDistinct, valueSelector) = (e.isDistinct, colNum) match {
case (true, x) if x >= 0 => // If colNum < 0, then the given schema does not contain the attribute
(true, flatbuffersSerializeExpression(builder, ar, concatSchema))
case _ =>
(false, 0)
}

val (updateExprs: Seq[Expression], evaluateExprs: Seq[Expression]) = e.mode match {
case Partial => {
Expand All @@ -1396,7 +1408,7 @@ object Utils extends Logging {
val countUpdateExpr = Add(count, Literal(1L))
(Seq(countUpdateExpr), Seq(count))
}
case _ =>
case _ =>
}

tuix.AggregateExpr.createAggregateExpr(
Expand All @@ -1410,7 +1422,9 @@ object Utils extends Logging {
updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray),
tuix.AggregateExpr.createEvaluateExprsVector(
builder,
evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray)
evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray),
isDistinct,
valueSelector
)

case f @ First(child, false) =>
Expand Down Expand Up @@ -1449,7 +1463,10 @@ object Utils extends Logging {
updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray),
tuix.AggregateExpr.createEvaluateExprsVector(
builder,
evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray))
evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray),
false,
0
)

case l @ Last(child, false) =>
val last = l.aggBufferAttributes(0)
Expand Down Expand Up @@ -1487,7 +1504,10 @@ object Utils extends Logging {
updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray),
tuix.AggregateExpr.createEvaluateExprsVector(
builder,
evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray))
evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray),
false,
0
)

case m @ Max(child) =>
val max = m.aggBufferAttributes(0)
Expand Down Expand Up @@ -1520,7 +1540,10 @@ object Utils extends Logging {
updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray),
tuix.AggregateExpr.createEvaluateExprsVector(
builder,
evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray))
evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray),
false,
0
)

case m @ Min(child) =>
val min = m.aggBufferAttributes(0)
Expand Down Expand Up @@ -1553,7 +1576,10 @@ object Utils extends Logging {
updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray),
tuix.AggregateExpr.createEvaluateExprsVector(
builder,
evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray))
evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray),
false,
0
)

case s @ Sum(child) =>
val sum = s.aggBufferAttributes(0)
Expand Down Expand Up @@ -1591,7 +1617,10 @@ object Utils extends Logging {
updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray),
tuix.AggregateExpr.createEvaluateExprsVector(
builder,
evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray))
evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray),
false,
0
)

case vs @ ScalaUDAF(Seq(child), _: VectorSum, _, _) =>
val sum = vs.aggBufferAttributes(0)
Expand Down Expand Up @@ -1626,7 +1655,10 @@ object Utils extends Logging {
updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray),
tuix.AggregateExpr.createEvaluateExprsVector(
builder,
evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray))
evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray),
false,
0
)
}
}

Expand Down
60 changes: 41 additions & 19 deletions src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,25 +109,47 @@ object OpaqueOperators extends Strategy {
if (isEncrypted(child) && aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression])) =>

val aggregateExpressions = aggExpressions.map(expr => expr.asInstanceOf[AggregateExpression])

if (groupingExpressions.size == 0) {
// Global aggregation
val partialAggregate = EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Partial, planLater(child))
val partialOutput = partialAggregate.output
val (projSchema, tag) = tagForGlobalAggregate(partialOutput)

EncryptedProjectExec(resultExpressions,
EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Final,
EncryptedProjectExec(partialOutput,
EncryptedSortExec(Seq(SortOrder(tag, Ascending)), true,
EncryptedProjectExec(projSchema, partialAggregate))))) :: Nil
} else {
// Grouping aggregation
EncryptedProjectExec(resultExpressions,
EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Final,
EncryptedSortExec(groupingExpressions.map(_.toAttribute).map(e => SortOrder(e, Ascending)), true,
EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Partial,
EncryptedSortExec(groupingExpressions.map(e => SortOrder(e, Ascending)), false, planLater(child)))))) :: Nil
val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct)

// Choose the physical plan according to how many of the aggregate functions are
// DISTINCT. Only zero or one DISTINCT aggregate is supported; anything else is
// rejected with an explicit error instead of an opaque scala.MatchError.
functionsWithDistinct.size match {
  case 0 => // No distinct aggregate operations
    if (groupingExpressions.isEmpty) {
      // Global aggregation: partial aggregate, project in a tag column, sort on
      // the tag, project back to the partial output, then run the final aggregate.
      val partialAggregate =
        EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Partial, planLater(child))
      val partialOutput = partialAggregate.output
      val (projSchema, tag) = tagForGlobalAggregate(partialOutput)

      EncryptedProjectExec(resultExpressions,
        EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Final,
          EncryptedProjectExec(partialOutput,
            EncryptedSortExec(Seq(SortOrder(tag, Ascending)), true,
              EncryptedProjectExec(projSchema, partialAggregate))))) :: Nil
    } else {
      // Grouping aggregation: sort on the grouping keys, aggregate partially,
      // sort the partial results on the grouping attributes, then finalize.
      EncryptedProjectExec(resultExpressions,
        EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Final,
          EncryptedSortExec(groupingExpressions.map(_.toAttribute).map(e => SortOrder(e, Ascending)), true,
            EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Partial,
              EncryptedSortExec(groupingExpressions.map(e => SortOrder(e, Ascending)), false, planLater(child)))))) :: Nil
    }

  case 1 => // One distinct aggregate operation
    if (groupingExpressions.isEmpty) {
      // Global aggregation: same plan shape as the non-distinct global case.
      // NOTE(review): presumably distinct values are de-duplicated per partial
      // aggregate only; verify that a value occurring in several partitions is
      // not counted more than once by the final aggregate.
      val partialAggregate =
        EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Partial, planLater(child))
      val partialOutput = partialAggregate.output
      val (projSchema, tag) = tagForGlobalAggregate(partialOutput)

      EncryptedProjectExec(resultExpressions,
        EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Final,
          EncryptedProjectExec(partialOutput,
            EncryptedSortExec(Seq(SortOrder(tag, Ascending)), true,
              EncryptedProjectExec(projSchema, partialAggregate))))) :: Nil
    } else {
      // Grouping aggregation: a single Complete-mode aggregate over the input
      // sorted on the grouping keys.
      EncryptedProjectExec(resultExpressions,
        EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Complete,
          EncryptedSortExec(groupingExpressions.map(e => SortOrder(e, Ascending)), true, planLater(child)))) :: Nil
    }

  case unsupported =>
    // Multiple DISTINCT aggregates in one query are not supported; fail with a
    // message that names the problem.
    throw new UnsupportedOperationException(
      s"Aggregation with $unsupported distinct aggregate functions is not supported; " +
        "at most one distinct aggregate function is allowed")
}

case p @ Union(Seq(left, right)) if isEncrypted(p) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,13 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
.collect.sortBy { case Row(category: String, _) => category }
}

// Grouped COUNT(DISTINCT price): 32 input rows whose price is i % 8, grouped by
// category; the collected rows are sorted by category so the comparison with
// Spark does not depend on output order.
testAgainstSpark("aggregate count - distinct") { securityLevel =>
  val rows = (0 until 32).map { i => (abc(i), i % 8) }.toSeq
  val priced = makeDF(rows, securityLevel, "category", "price")
  val distinctPerCategory =
    priced.groupBy("category").agg(countDistinct("price").as("distinctPrices"))
  distinctPerCategory.collect.sortBy { case Row(category: String, _) => category }
}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add another test for global distinct aggregation, as well as tests for when the number of distinct items is 0?

testAgainstSpark("aggregate first") { securityLevel =>
val data = for (i <- 0 until 256) yield (i, abc(i), 1)
val words = makeDF(data, securityLevel, "id", "category", "price")
Expand Down