From eabb65b6375967b954c93224dc9ea7b9c96568b3 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 5 Sep 2018 08:10:47 +0900 Subject: [PATCH 01/60] WIP nothing worked, just recording the progress --- .../streaming/UpdatingSessionIterator.scala | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/streaming/UpdatingSessionIterator.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/UpdatingSessionIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/UpdatingSessionIterator.scala new file mode 100644 index 0000000000000..76258ef75d400 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/UpdatingSessionIterator.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} + +class UpdateSessionIterator( + iter: Iterator[InternalRow], + groupWithoutSessionExpressions: Seq[Expression], + sessionExpression: Expression, + aggregateExpressions: Seq[Expression], + inputSchema: Seq[Attribute]) extends Iterator[InternalRow] { + +} From c3076d2e62631b0a7a3a5e930b67ab2082d31e83 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 6 Sep 2018 13:36:11 +0900 Subject: [PATCH 02/60] WIP not working yet... lots of implementations needed --- .../streaming/UpdatingSessionIterator.scala | 128 +++++++++++++++++- 1 file changed, 122 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/UpdatingSessionIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/UpdatingSessionIterator.scala index 76258ef75d400..6a98dc2f34ec1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/UpdatingSessionIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/UpdatingSessionIterator.scala @@ -17,14 +17,130 @@ package org.apache.spark.sql.streaming +import scala.collection.mutable + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, CreateNamedStruct, Expression, Literal, PreciseTimestampConversion, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.types.{LongType, TimestampType} class UpdateSessionIterator( - iter: Iterator[InternalRow], - groupWithoutSessionExpressions: Seq[Expression], - sessionExpression: Expression, - aggregateExpressions: Seq[Expression], - inputSchema: Seq[Attribute]) extends Iterator[InternalRow] { + iter: Iterator[InternalRow], + groupWithoutSessionExpressions: 
Seq[Expression], + sessionExpression: Expression, + aggregateExpressions: Seq[Expression], + inputSchema: Seq[Attribute]) extends Iterator[InternalRow] { + + var currentKeys: InternalRow = _ + var currentSessionStart: Long = null + var currentSessionEnd: Long = null + + var currentRows: mutable.MutableList[InternalRow] = _ + + var returnRowsIter: Iterator[InternalRow] = _ + + val keysProjection = GenerateUnsafeProjection.generate(groupWithoutSessionExpressions) + val sessionProjection = GenerateUnsafeProjection.generate(Seq(sessionExpression)) + val aggregateProjections = GenerateUnsafeProjection.generate(Seq(aggregateProjections)) + + override def hasNext: Boolean = { + if (returnRowsIter != null && returnRowsIter.hasNext) { + return true + } + + if (returnRowsIter != null) { + returnRowsIter = null + } + + iter.hasNext + } + + override def next(): InternalRow = { + if (returnRowsIter != null && returnRowsIter.hasNext) { + return returnRowsIter.next() + } + + while (iter.hasNext) { + val row = iter.next() + + val keys = keysProjection(row) + val session = sessionProjection(row) + val sessionStart = session.getLong(0) + val sessionEnd = session.getLong(1) + + if (keys != currentKeys) { + closeCurrentSession() + startNewSession(row, keys, sessionStart, sessionEnd) + } else { + if (sessionStart < currentSessionStart) { + throw new IllegalStateException("The iterator must be sorted by key and session start!") + } else if (sessionStart <= currentSessionEnd) { + // expanding session length if needed + expandEndOfCurrentSession(sessionEnd) + currentRows += row + } else { + closeCurrentSession() + startNewSession(row, keys, sessionStart, sessionEnd) + } + } + } + + if (!iter.hasNext) { + // no further row: closing session + closeCurrentSession() + } + + // here returnRowsIter should be at least one row + require(returnRowsIter != null && returnRowsIter.hasNext) + + returnRowsIter.next() + } + + private def expandEndOfCurrentSession(sessionEnd: Long): Unit = { + if 
(sessionEnd > currentSessionEnd) { + currentSessionEnd = sessionEnd + } + } + + private def startNewSession(row: InternalRow, keys: UnsafeRow, sessionStart: Long, + sessionEnd: Long): Unit = { + currentKeys = keys + currentSessionStart = sessionStart + currentSessionEnd = sessionEnd + currentRows = new mutable.MutableList[InternalRow]() + currentRows += row + } + + private def closeCurrentSession(): Unit = { + val convertedGroupWithoutSessionExpressions = groupWithoutSessionExpressions.map { x => + BindReferences.bindReference[Expression](x, inputSchema) + } + val convertedAggregateExpressions = aggregateExpressions.map { + x => BindReferences.bindReference[Expression](x, inputSchema) + } + + val returnRows = currentRows.map { internalRow => + val sessionStruct = CreateNamedStruct( + Literal("start") :: + PreciseTimestampConversion( + Literal(currentSessionStart, LongType), LongType, TimestampType) :: + Literal("end") :: + PreciseTimestampConversion( + Literal(currentSessionEnd, LongType), LongType, TimestampType) :: + Nil) + + val valueExpressions = convertedGroupWithoutSessionExpressions ++ Seq(sessionStruct) ++ + convertedAggregateExpressions + + val proj = GenerateUnsafeProjection.generate(valueExpressions, inputSchema) + proj(internalRow) + }.toList + + returnRowsIter = returnRows.iterator + currentKeys = null + currentSessionStart = null + currentSessionEnd = null + currentRows = null + } } From 9d59c7af301e1dc34f4f02cb1cc298ba0a71b9b5 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 6 Sep 2018 22:31:08 +0900 Subject: [PATCH 03/60] WIP Finished implementing UpdatingSessionIterator --- .../streaming/UpdatingSessionIterator.scala | 47 ++- .../UpdatingSessionIteratorSuite.scala | 301 ++++++++++++++++++ 2 files changed, 334 insertions(+), 14 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/UpdatingSessionIteratorSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/UpdatingSessionIterator.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/UpdatingSessionIterator.scala index 6a98dc2f34ec1..5afd964076392 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/UpdatingSessionIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/UpdatingSessionIterator.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Cre import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.types.{LongType, TimestampType} -class UpdateSessionIterator( +class UpdatingSessionIterator( iter: Iterator[InternalRow], groupWithoutSessionExpressions: Seq[Expression], sessionExpression: Expression, @@ -32,18 +32,17 @@ class UpdateSessionIterator( inputSchema: Seq[Attribute]) extends Iterator[InternalRow] { var currentKeys: InternalRow = _ - var currentSessionStart: Long = null - var currentSessionEnd: Long = null + var currentSessionStart: Long = Long.MaxValue + var currentSessionEnd: Long = Long.MinValue var currentRows: mutable.MutableList[InternalRow] = _ var returnRowsIter: Iterator[InternalRow] = _ - - val keysProjection = GenerateUnsafeProjection.generate(groupWithoutSessionExpressions) - val sessionProjection = GenerateUnsafeProjection.generate(Seq(sessionExpression)) - val aggregateProjections = GenerateUnsafeProjection.generate(Seq(aggregateProjections)) + var errorOnIterator: Boolean = false override def hasNext: Boolean = { + assertIteratorNotCorrupted() + if (returnRowsIter != null && returnRowsIter.hasNext) { return true } @@ -56,23 +55,35 @@ class UpdateSessionIterator( } override def next(): InternalRow = { + assertIteratorNotCorrupted() + if (returnRowsIter != null && returnRowsIter.hasNext) { return returnRowsIter.next() } - while (iter.hasNext) { + var exitCondition = false + while (iter.hasNext && !exitCondition) { val row = iter.next() + val keysProjection = GenerateUnsafeProjection.generate(groupWithoutSessionExpressions, + 
inputSchema) + val sessionProjection = GenerateUnsafeProjection.generate(Seq(sessionExpression), inputSchema) + val keys = keysProjection(row) val session = sessionProjection(row) - val sessionStart = session.getLong(0) - val sessionEnd = session.getLong(1) + val sessionRow = session.getStruct(0, 2) + val sessionStart = sessionRow.getLong(0) + val sessionEnd = sessionRow.getLong(1) - if (keys != currentKeys) { + if (currentKeys == null) { + startNewSession(row, keys, sessionStart, sessionEnd) + } else if (keys != currentKeys) { closeCurrentSession() startNewSession(row, keys, sessionStart, sessionEnd) + exitCondition = true } else { if (sessionStart < currentSessionStart) { + errorOnIterator = true throw new IllegalStateException("The iterator must be sorted by key and session start!") } else if (sessionStart <= currentSessionEnd) { // expanding session length if needed @@ -81,6 +92,7 @@ class UpdateSessionIterator( } else { closeCurrentSession() startNewSession(row, keys, sessionStart, sessionEnd) + exitCondition = true } } } @@ -90,7 +102,7 @@ class UpdateSessionIterator( closeCurrentSession() } - // here returnRowsIter should be at least one row + // here returnRowsIter should be able to provide at least one row require(returnRowsIter != null && returnRowsIter.hasNext) returnRowsIter.next() @@ -139,8 +151,15 @@ class UpdateSessionIterator( returnRowsIter = returnRows.iterator currentKeys = null - currentSessionStart = null - currentSessionEnd = null + currentSessionStart = Long.MaxValue + currentSessionEnd = Long.MinValue currentRows = null } + + private def assertIteratorNotCorrupted(): Unit = { + if (errorOnIterator) { + throw new IllegalStateException("The iterator is already corrupted.") + } + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/UpdatingSessionIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/UpdatingSessionIteratorSuite.scala new file mode 100644 index 0000000000000..b1c88c70e0214 --- 
/dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/UpdatingSessionIteratorSuite.scala @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class UpdatingSessionIteratorSuite extends SharedSQLContext { + + val rowSchema = new StructType().add("key1", StringType).add("key2", IntegerType) + .add("session", new StructType().add("start", LongType).add("end", LongType)) + .add("aggVal1", LongType).add("aggVal2", DoubleType) + val rowAttributes = rowSchema.toAttributes + + val keysWithoutSessionSchema = rowSchema.filter(st => List("key1", "key2").contains(st.name)) + val keysWithoutSessionAttributes = rowAttributes.filter { + attr => List("key1", "key2").contains(attr.name) + } + + val sessionSchema = rowSchema.filter(st => st.name == "session").head + val sessionAttribute = 
rowAttributes.filter(attr => attr.name == "session").head + + val valuesSchema = rowSchema.filter(st => List("aggVal1", "aggVal2").contains(st.name)) + val valuesAttributes = rowAttributes.filter { + attr => List("aggVal1", "aggVal2").contains(attr.name) + } + + test("only one row") { + val rows = List(createRow("a", 1, 100, 110, 10, 1.1)) + + val iterator = new UpdatingSessionIterator(rows.iterator, keysWithoutSessionAttributes, + sessionAttribute, valuesAttributes, rowAttributes) + + assert(iterator.hasNext) + + val retRow = iterator.next() + assertRowsEquals(retRow, rows.head) + + assert(!iterator.hasNext) + } + + test("one session per key, one key") { + val row1 = createRow("a", 1, 100, 110, 10, 1.1) + val row2 = createRow("a", 1, 100, 110, 20, 1.2) + val row3 = createRow("a", 1, 105, 115, 30, 1.3) + val row4 = createRow("a", 1, 113, 123, 40, 1.4) + val rows = List(row1, row2, row3, row4) + + val iterator = new UpdatingSessionIterator(rows.iterator, keysWithoutSessionAttributes, + sessionAttribute, valuesAttributes, rowAttributes) + + val retRows = rows.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + + retRows.zip(rows).foreach { case (retRow, expectedRow) => + // session being expanded to (100 ~ 123) + assertRowsEqualsWithNewSession(expectedRow, retRow, 100, 123) + } + + assert(iterator.hasNext === false) + } + + test("one session per key, multi keys") { + val row1 = createRow("a", 1, 100, 110, 10, 1.1) + val row2 = createRow("a", 1, 100, 110, 20, 1.2) + val row3 = createRow("a", 1, 105, 115, 30, 1.3) + val row4 = createRow("a", 1, 113, 123, 40, 1.4) + val rows1 = List(row1, row2, row3, row4) + + val row5 = createRow("a", 2, 110, 120, 10, 1.1) + val row6 = createRow("a", 2, 115, 125, 20, 1.2) + val row7 = createRow("a", 2, 117, 127, 30, 1.3) + val row8 = createRow("a", 2, 125, 135, 40, 1.4) + val rows2 = List(row5, row6, row7, row8) + + val rowsAll = rows1 ++ rows2 + + val iterator = new UpdatingSessionIterator(rowsAll.iterator, 
keysWithoutSessionAttributes, + sessionAttribute, valuesAttributes, rowAttributes) + + val retRows1 = rows1.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + val retRows2 = rows2.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + + retRows1.zip(rows1).foreach { case (retRow, expectedRow) => + // session being expanded to (100 ~ 123) + assertRowsEqualsWithNewSession(expectedRow, retRow, 100, 123) + } + + retRows2.zip(rows2).foreach { case (retRow, expectedRow) => + // session being expanded to (110 ~ 135) + assertRowsEqualsWithNewSession(expectedRow, retRow, 110, 135) + } + + assert(iterator.hasNext === false) + } + + test("multiple sessions per key, single key") { + val row1 = createRow("a", 1, 100, 110, 10, 1.1) + val row2 = createRow("a", 1, 105, 115, 20, 1.2) + val rows1 = List(row1, row2) + + val row3 = createRow("a", 1, 125, 135, 30, 1.3) + val row4 = createRow("a", 1, 127, 137, 40, 1.4) + val rows2 = List(row3, row4) + + val rowsAll = rows1 ++ rows2 + + val iterator = new UpdatingSessionIterator(rowsAll.iterator, keysWithoutSessionAttributes, + sessionAttribute, valuesAttributes, rowAttributes) + + val retRows1 = rows1.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + + val retRows2 = rows2.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + + retRows1.zip(rows1).foreach { case (retRow, expectedRow) => + // session being expanded to (100 ~ 115) + assertRowsEqualsWithNewSession(expectedRow, retRow, 100, 115) + } + + retRows2.zip(rows2).foreach { case (retRow, expectedRow) => + // session being expanded to (125 ~ 137) + assertRowsEqualsWithNewSession(expectedRow, retRow, 125, 137) + } + + assert(iterator.hasNext === false) + } + + test("multiple sessions per key, multi keys") { + val row1 = createRow("a", 1, 100, 110, 10, 1.1) + val row2 = createRow("a", 1, 100, 110, 20, 1.2) + val rows1 = List(row1, row2) + + val row3 = createRow("a", 1, 115, 125, 30, 1.3) + val row4 = createRow("a", 
1, 119, 129, 40, 1.4) + val rows2 = List(row3, row4) + + val row5 = createRow("a", 2, 110, 120, 10, 1.1) + val row6 = createRow("a", 2, 115, 125, 20, 1.2) + val rows3 = List(row5, row6) + + val row7 = createRow("a", 2, 127, 137, 30, 1.3) + val row8 = createRow("a", 2, 135, 145, 40, 1.4) + val rows4 = List(row7, row8) + + val rowsAll = rows1 ++ rows2 ++ rows3 ++ rows4 + + val iterator = new UpdatingSessionIterator(rowsAll.iterator, keysWithoutSessionAttributes, + sessionAttribute, valuesAttributes, rowAttributes) + + val retRows1 = rows1.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + + val retRows2 = rows2.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + + val retRows3 = rows3.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + + val retRows4 = rows4.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + + retRows1.zip(rows1).foreach { case (retRow, expectedRow) => + // session being expanded to (100 ~ 110) + assertRowsEqualsWithNewSession(expectedRow, retRow, 100, 110) + } + + retRows2.zip(rows2).foreach { case (retRow, expectedRow) => + // session being expanded to (115 ~ 129) + assertRowsEqualsWithNewSession(expectedRow, retRow, 115, 129) + } + + retRows3.zip(rows3).foreach { case (retRow, expectedRow) => + // session being expanded to (110 ~ 125) + assertRowsEqualsWithNewSession(expectedRow, retRow, 110, 125) + } + + retRows4.zip(rows4).foreach { case (retRow, expectedRow) => + // session being expanded to (127 ~ 145) + assertRowsEqualsWithNewSession(expectedRow, retRow, 127, 145) + } + + assert(iterator.hasNext === false) + } + + test("throws exception if data is not sorted by session start") { + val row1 = createRow("a", 1, 100, 110, 10, 1.1) + val row2 = createRow("a", 1, 100, 110, 20, 1.2) + val row3 = createRow("a", 1, 95, 105, 30, 1.3) + val row4 = createRow("a", 1, 113, 123, 40, 1.4) + val rows = List(row1, row2, row3, row4) + + val iterator = new 
UpdatingSessionIterator(rows.iterator, keysWithoutSessionAttributes, + sessionAttribute, valuesAttributes, rowAttributes) + + // UpdatingSessionIterator can't detect error on hasNext + iterator.hasNext + + // when calling next() it can detect error and throws IllegalStateException + intercept[IllegalStateException] { + iterator.next() + } + + // afterwards, calling either hasNext() or next() will throw IllegalStateException + intercept[IllegalStateException] { + iterator.hasNext + } + + intercept[IllegalStateException] { + iterator.next() + } + } + + private def createRow(key1: String, key2: Int, sessionStart: Long, sessionEnd: Long, + aggVal1: Long, aggVal2: Double): UnsafeRow = { + val genericRow = new GenericInternalRow(6) + if (key1 != null) { + genericRow.update(0, UTF8String.fromString(key1)) + } else { + genericRow.setNullAt(0) + } + genericRow.setInt(1, key2) + + val session: Array[Any] = new Array[Any](2) + session(0) = sessionStart + session(1) = sessionEnd + + val sessionRow = new GenericInternalRow(session) + genericRow.update(2, sessionRow) + + genericRow.setLong(3, aggVal1) + genericRow.setDouble(4, aggVal2) + + val rowProjection = GenerateUnsafeProjection.generate(rowAttributes, rowAttributes) + rowProjection(genericRow) + } + + private def doubleEquals(value1: Double, value2: Double): Boolean = { + value1 > value2 - 0.000001 && value1 < value2 + 0.000001 + } + + private def assertRowsEquals(expectedRow: InternalRow, retRow: InternalRow): Unit = { + assert(retRow.getString(0) === expectedRow.getString(0)) + assert(retRow.getInt(1) === expectedRow.getInt(1)) + assert(retRow.getStruct(2, 2).getLong(0) == expectedRow.getStruct(2, 2).getLong(0)) + assert(retRow.getStruct(2, 2).getLong(1) == expectedRow.getStruct(2, 2).getLong(1)) + assert(retRow.getLong(3) === expectedRow.getLong(3)) + assert(doubleEquals(retRow.getDouble(3), expectedRow.getDouble(3))) + } + + private def assertRowsEqualsWithNewSession(expectedRow: InternalRow, retRow: InternalRow, + 
newSessionStart: Long, newSessionEnd: Long): Unit = { + assert(retRow.getString(0) === expectedRow.getString(0)) + assert(retRow.getInt(1) === expectedRow.getInt(1)) + assert(retRow.getStruct(2, 2).getLong(0) == newSessionStart) + assert(retRow.getStruct(2, 2).getLong(1) == newSessionEnd) + assert(retRow.getLong(3) === expectedRow.getLong(3)) + assert(doubleEquals(retRow.getDouble(3), expectedRow.getDouble(3))) + } + +} From b38f2b980e815a56889354911c677759bbb8988d Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 7 Sep 2018 17:35:32 +0900 Subject: [PATCH 04/60] WIP add verification on precondition "rows in iterator are sorted by key" --- .../streaming/UpdatingSessionIterator.scala | 31 +++++++++++----- .../UpdatingSessionIteratorSuite.scala | 35 ++++++++++++++++++- 2 files changed, 56 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/UpdatingSessionIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/UpdatingSessionIterator.scala index 5afd964076392..7995b9b5e8070 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/UpdatingSessionIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/UpdatingSessionIterator.scala @@ -35,11 +35,13 @@ class UpdatingSessionIterator( var currentSessionStart: Long = Long.MaxValue var currentSessionEnd: Long = Long.MinValue - var currentRows: mutable.MutableList[InternalRow] = _ + val currentRows: mutable.MutableList[InternalRow] = new mutable.MutableList[InternalRow]() var returnRowsIter: Iterator[InternalRow] = _ var errorOnIterator: Boolean = false + val processedKeys: mutable.HashSet[InternalRow] = new mutable.HashSet[InternalRow]() + override def hasNext: Boolean = { assertIteratorNotCorrupted() @@ -78,19 +80,19 @@ class UpdatingSessionIterator( if (currentKeys == null) { startNewSession(row, keys, sessionStart, sessionEnd) } else if (keys != currentKeys) { - closeCurrentSession() + closeCurrentSession(keyChanged = true) + 
processedKeys.add(currentKeys) startNewSession(row, keys, sessionStart, sessionEnd) exitCondition = true } else { if (sessionStart < currentSessionStart) { - errorOnIterator = true - throw new IllegalStateException("The iterator must be sorted by key and session start!") + handleBrokenPreconditionForSort() } else if (sessionStart <= currentSessionEnd) { // expanding session length if needed expandEndOfCurrentSession(sessionEnd) currentRows += row } else { - closeCurrentSession() + closeCurrentSession(keyChanged = false) startNewSession(row, keys, sessionStart, sessionEnd) exitCondition = true } @@ -99,7 +101,7 @@ class UpdatingSessionIterator( if (!iter.hasNext) { // no further row: closing session - closeCurrentSession() + closeCurrentSession(keyChanged = false) } // here returnRowsIter should be able to provide at least one row @@ -116,14 +118,23 @@ class UpdatingSessionIterator( private def startNewSession(row: InternalRow, keys: UnsafeRow, sessionStart: Long, sessionEnd: Long): Unit = { + if (processedKeys.contains(keys)) { + handleBrokenPreconditionForSort() + } + currentKeys = keys currentSessionStart = sessionStart currentSessionEnd = sessionEnd - currentRows = new mutable.MutableList[InternalRow]() + currentRows.clear() currentRows += row } - private def closeCurrentSession(): Unit = { + private def handleBrokenPreconditionForSort(): Unit = { + errorOnIterator = true + throw new IllegalStateException("The iterator must be sorted by key and session start!") + } + + private def closeCurrentSession(keyChanged: Boolean): Unit = { val convertedGroupWithoutSessionExpressions = groupWithoutSessionExpressions.map { x => BindReferences.bindReference[Expression](x, inputSchema) } @@ -150,10 +161,12 @@ class UpdatingSessionIterator( returnRowsIter = returnRows.iterator + if (keyChanged) processedKeys.add(currentKeys) + currentKeys = null currentSessionStart = Long.MaxValue currentSessionEnd = Long.MinValue - currentRows = null + currentRows.clear() } private def 
assertIteratorNotCorrupted(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/UpdatingSessionIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/UpdatingSessionIteratorSuite.scala index b1c88c70e0214..7df1fb39d1110 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/UpdatingSessionIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/UpdatingSessionIteratorSuite.scala @@ -234,7 +234,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { sessionAttribute, valuesAttributes, rowAttributes) // UpdatingSessionIterator can't detect error on hasNext - iterator.hasNext + assert(iterator.hasNext) // when calling next() it can detect error and throws IllegalStateException intercept[IllegalStateException] { @@ -251,6 +251,39 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { } } + test("throws exception if data is not sorted by key") { + val row1 = createRow("a", 1, 100, 110, 10, 1.1) + val row2 = createRow("a", 2, 100, 110, 20, 1.2) + val row3 = createRow("a", 1, 113, 123, 40, 1.4) + val rows = List(row1, row2, row3) + + val iterator = new UpdatingSessionIterator(rows.iterator, keysWithoutSessionAttributes, + sessionAttribute, valuesAttributes, rowAttributes) + + // UpdatingSessionIterator can't detect error on hasNext + assert(iterator.hasNext) + + assertRowsEquals(row1, iterator.next()) + + assert(iterator.hasNext) + + // second row itself is OK but while finding end of session it reads third row, and finds + // its key is already finished processing, hence precondition for sorting is broken, and + // it throws IllegalStateException + intercept[IllegalStateException] { + iterator.next() + } + + // afterwards, calling either hasNext() or next() will throw IllegalStateException + intercept[IllegalStateException] { + iterator.hasNext + } + + intercept[IllegalStateException] { + iterator.next() + } + } + private def createRow(key1: String, key2: Int, 
sessionStart: Long, sessionEnd: Long, aggVal1: Long, aggVal2: Double): UnsafeRow = { val genericRow = new GenericInternalRow(6) From 668c1f5ad913ef4d6e96786d3d47b55a1fe0d2b7 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 8 Sep 2018 13:36:46 +0900 Subject: [PATCH 05/60] Rename SymmetricHashJoinStateManager to MultiValuesStateManager * This will be also used from session window state as well --- .../StreamingSymmetricHashJoinExec.scala | 14 ++++-- ...er.scala => MultiValuesStateManager.scala} | 49 ++++++++----------- ...ala => MultiValuesStateManagerSuite.scala} | 33 ++++++------- 3 files changed, 47 insertions(+), 49 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/{SymmetricHashJoinStateManager.scala => MultiValuesStateManager.scala} (91%) rename sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/{SymmetricHashJoinStateManagerSuite.scala => MultiValuesStateManagerSuite.scala} (81%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 50cf971e4ec3c..494e5579646fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.execution.streaming.state.MultiValuesStateManager.{getStateStoreName, KeyToNumValuesType, KeyWithIndexToValueType, StateStoreType} import org.apache.spark.sql.internal.SessionState import org.apache.spark.util.{CompletionIterator, 
SerializableConfiguration} @@ -200,7 +201,7 @@ case class StreamingSymmetricHashJoinExec( protected override def doExecute(): RDD[InternalRow] = { val stateStoreCoord = sqlContext.sessionState.streamingQueryManager.stateStoreCoordinator - val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) + val stateStoreNames = allStateStoreNames(LeftSide, RightSide) left.execute().stateStoreAwareZipPartitions( right.execute(), stateInfo.get, stateStoreNames, stateStoreCoord)(processPartitions) } @@ -394,8 +395,8 @@ case class StreamingSymmetricHashJoinExec( val preJoinFilter = newPredicate(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _ - private val joinStateManager = new SymmetricHashJoinStateManager( - joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value) + private val joinStateManager = new MultiValuesStateManager(joinSide.toString, inputAttributes, + joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value) private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes) private[this] val stateKeyWatermarkPredicateFunc = stateWatermarkPredicate match { @@ -504,4 +505,11 @@ case class StreamingSymmetricHashJoinExec( def numUpdatedStateRows: Long = updatedStateRowsCount } + + private def allStateStoreNames(joinSides: JoinSide*): Seq[String] = { + val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToValueType) + for (joinSide <- joinSides; stateStoreType <- allStateStoreTypes) yield { + getStateStoreName(joinSide.toString, stateStoreType) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala similarity index 91% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala rename to 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala index 43f22803e7685..80d24227a6895 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala @@ -24,21 +24,20 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Literal, SpecificInternalRow, UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec} -import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ +import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo import org.apache.spark.sql.types.{LongType, StructField, StructType} import org.apache.spark.util.NextIterator /** - * Helper class to manage state required by a single side of [[StreamingSymmetricHashJoinExec]]. + * Helper class to manage state which is useful for operations which require to store multiple + * values for a key. * The interface of this class is basically that of a multi-map: * - Get: Returns an iterator of multiple values for given key * - Append: Append a new value to the given key * - Remove Data by predicate: Drop any state using a predicate condition on keys or values * - * @param joinSide Defines the join side * @param inputValueAttributes Attributes of the input row which will be stored as value - * @param joinKeys Expressions to generate rows that will be used to key the value rows + * @param keys Expressions to generate rows that will be used to key the value rows * @param stateInfo Information about how to retrieve the correct version of state * @param storeConf Configuration for the state store. 
* @param hadoopConf Hadoop configuration for reading state data from storage @@ -59,15 +58,15 @@ import org.apache.spark.util.NextIterator * by overwriting with the value of (key, maxIndex), and removing [(key, maxIndex), * decrement corresponding num values in KeyToNumValuesStore */ -class SymmetricHashJoinStateManager( - val joinSide: JoinSide, +class MultiValuesStateManager( + storeNamePrefix: String, inputValueAttributes: Seq[Attribute], - joinKeys: Seq[Expression], + keys: Seq[Expression], stateInfo: Option[StatefulOperatorStateInfo], storeConf: StateStoreConf, hadoopConf: Configuration) extends Logging { - import SymmetricHashJoinStateManager._ + import MultiValuesStateManager._ /* ===================================================== @@ -265,7 +264,7 @@ class SymmetricHashJoinStateManager( def metrics: StateStoreMetrics = { val keyToNumValuesMetrics = keyToNumValues.metrics val keyWithIndexToValueMetrics = keyWithIndexToValue.metrics - def newDesc(desc: String): String = s"${joinSide.toString.toUpperCase(Locale.ROOT)}: $desc" + def newDesc(desc: String): String = s"${storeNamePrefix.toUpperCase(Locale.ROOT)}: $desc" StateStoreMetrics( keyWithIndexToValueMetrics.numKeys, // represent each buffered row only once @@ -291,7 +290,7 @@ class SymmetricHashJoinStateManager( */ private val keySchema = StructType( - joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) }) + keys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) }) private val keyAttributes = keySchema.toAttributes private val keyToNumValues = new KeyToNumValuesStore() private val keyWithIndexToValue = new KeyWithIndexToValueStore() @@ -321,8 +320,8 @@ class SymmetricHashJoinStateManager( /** Get the StateStore with the given schema */ protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = { - val storeProviderId = StateStoreProviderId( - stateInfo.get, TaskContext.getPartitionId(), 
getStateStoreName(joinSide, stateStoreType)) + val storeProviderId = StateStoreProviderId(stateInfo.get, TaskContext.getPartitionId(), + getStateStoreName(storeNamePrefix, stateStoreType)) val store = StateStore.get( storeProviderId, keySchema, valueSchema, None, stateInfo.get.storeVersion, storeConf, hadoopConf) @@ -380,8 +379,8 @@ class SymmetricHashJoinStateManager( * Helper class for representing data returned by [[KeyWithIndexToValueStore]]. * Designed for object reuse. */ - private case class KeyWithIndexAndValue( - var key: UnsafeRow = null, var valueIndex: Long = -1, var value: UnsafeRow = null) { + private case class KeyWithIndexAndValue(var key: UnsafeRow = null, var valueIndex: Long = -1, + var value: UnsafeRow = null) { def withNew(newKey: UnsafeRow, newIndex: Long, newValue: UnsafeRow): this.type = { this.key = newKey this.valueIndex = newIndex @@ -475,26 +474,18 @@ class SymmetricHashJoinStateManager( } } -object SymmetricHashJoinStateManager { +object MultiValuesStateManager { + sealed trait StateStoreType - def allStateStoreNames(joinSides: JoinSide*): Seq[String] = { - val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToValueType) - for (joinSide <- joinSides; stateStoreType <- allStateStoreTypes) yield { - getStateStoreName(joinSide, stateStoreType) - } - } - - private sealed trait StateStoreType - - private case object KeyToNumValuesType extends StateStoreType { + case object KeyToNumValuesType extends StateStoreType { override def toString(): String = "keyToNumValues" } - private case object KeyWithIndexToValueType extends StateStoreType { + case object KeyWithIndexToValueType extends StateStoreType { override def toString(): String = "keyWithIndexToValue" } - private def getStateStoreName(joinSide: JoinSide, storeType: StateStoreType): String = { - s"$joinSide-$storeType" + def getStateStoreName(storeNamePrefix: String, storeType: StateStoreType): String = { + s"$storeNamePrefix-$storeType" } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManagerSuite.scala similarity index 81% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManagerSuite.scala index c0216a2ef3e61..7843dddee84e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManagerSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types._ -class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter { +class MultiValuesStateManagerSuite extends StreamTest with BeforeAndAfter { before { SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec' @@ -39,8 +39,8 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter } - test("SymmetricHashJoinStateManager - all operations") { - withJoinStateManager(inputValueAttribs, joinKeyExprs) { manager => + test("MultiValuesStateManager - all operations") { + withStateManager(inputValueAttribs, keyExprs) { manager => implicit val mgr = manager assert(get(20) === Seq.empty) // initially empty @@ -106,10 +106,10 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter .add(StructField("value", BooleanType)) val inputValueAttribs = inputValueSchema.toAttributes val inputValueAttribWithWatermark = inputValueAttribs(0) - val joinKeyExprs = Seq[Expression](Literal(false), inputValueAttribWithWatermark, Literal(10.0)) + val keyExprs = 
Seq[Expression](Literal(false), inputValueAttribWithWatermark, Literal(10.0)) val inputValueGen = UnsafeProjection.create(inputValueAttribs.map(_.dataType).toArray) - val joinKeyGen = UnsafeProjection.create(joinKeyExprs.map(_.dataType).toArray) + val keyGen = UnsafeProjection.create(keyExprs.map(_.dataType).toArray) def toInputValue(i: Int): UnsafeRow = { @@ -117,21 +117,21 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter } def toJoinKeyRow(i: Int): UnsafeRow = { - joinKeyGen.apply(new GenericInternalRow(Array[Any](false, i, 10.0))) + keyGen.apply(new GenericInternalRow(Array[Any](false, i, 10.0))) } def toValueInt(inputValueRow: UnsafeRow): Int = inputValueRow.getInt(0) - def append(key: Int, value: Int)(implicit manager: SymmetricHashJoinStateManager): Unit = { + def append(key: Int, value: Int)(implicit manager: MultiValuesStateManager): Unit = { manager.append(toJoinKeyRow(key), toInputValue(value)) } - def get(key: Int)(implicit manager: SymmetricHashJoinStateManager): Seq[Int] = { + def get(key: Int)(implicit manager: MultiValuesStateManager): Seq[Int] = { manager.get(toJoinKeyRow(key)).map(toValueInt).toSeq.sorted } /** Remove keys (and corresponding values) where `time <= threshold` */ - def removeByKey(threshold: Long)(implicit manager: SymmetricHashJoinStateManager): Unit = { + def removeByKey(threshold: Long)(implicit manager: MultiValuesStateManager): Unit = { val expr = LessThanOrEqual( BoundReference( @@ -142,27 +142,26 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter } /** Remove values where `time <= threshold` */ - def removeByValue(watermark: Long)(implicit manager: SymmetricHashJoinStateManager): Unit = { + def removeByValue(watermark: Long)(implicit manager: MultiValuesStateManager): Unit = { val expr = LessThanOrEqual(inputValueAttribWithWatermark, Literal(watermark)) val iter = manager.removeByValueCondition( GeneratePredicate.generate(expr, inputValueAttribs).eval _) while 
(iter.hasNext) iter.next() } - def numRows(implicit manager: SymmetricHashJoinStateManager): Long = { + def numRows(implicit manager: MultiValuesStateManager): Long = { manager.metrics.numKeys } - - def withJoinStateManager( - inputValueAttribs: Seq[Attribute], - joinKeyExprs: Seq[Expression])(f: SymmetricHashJoinStateManager => Unit): Unit = { + def withStateManager( + inputValueAttribs: Seq[Attribute], + keyExprs: Seq[Expression])(f: MultiValuesStateManager => Unit): Unit = { withTempDir { file => val storeConf = new StateStoreConf() val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5) - val manager = new SymmetricHashJoinStateManager( - LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration) + val manager = new MultiValuesStateManager(LeftSide.toString(), inputValueAttribs, keyExprs, + Some(stateInfo), storeConf, new Configuration) try { f(manager) } finally { From 9f63a3c1bea67337cfb87e35e888d65a4b348967 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 8 Sep 2018 13:41:37 +0900 Subject: [PATCH 06/60] Move package of UpdatingSessionIterator --- .../{ => execution}/streaming/UpdatingSessionIterator.scala | 3 ++- .../streaming/UpdatingSessionIteratorSuite.scala | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/streaming/UpdatingSessionIterator.scala (98%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/streaming/UpdatingSessionIteratorSuite.scala (99%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/UpdatingSessionIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/streaming/UpdatingSessionIterator.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala index 7995b9b5e8070..598206beb56da 
100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/UpdatingSessionIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.streaming +package org.apache.spark.sql.execution.streaming import scala.collection.mutable @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Cre import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.types.{LongType, TimestampType} +// FIXME: javadoc!! class UpdatingSessionIterator( iter: Iterator[InternalRow], groupWithoutSessionExpressions: Seq[Expression], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/UpdatingSessionIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/UpdatingSessionIteratorSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala index 7df1fb39d1110..3670d27eea8e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/UpdatingSessionIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.streaming +package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} From 5d17ac8e76f41326c0bd79391f9af9a4ff190e28 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 10 Sep 2018 14:52:28 +0900 Subject: [PATCH 07/60] WIP add MergingSortWithMultiValuesStateIterator, now integrating with stateful operators (WIP...) 
--- ...gingSortWithMultiValuesStateIterator.scala | 132 +++++++++++ .../state/MultiValuesStateManager.scala | 6 + .../state/MultiValuesStateStoreRDD.scala | 78 +++++++ .../execution/streaming/state/package.scala | 43 ++++ .../streaming/statefulOperators.scala | 184 +++++++++++++++ ...ortWithMultiValuesStateIteratorSuite.scala | 214 ++++++++++++++++++ .../UpdatingSessionIteratorSuite.scala | 7 + 7 files changed, 664 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateStoreRDD.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala new file mode 100644 index 0000000000000..3c646e3d278b4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.streaming.state.MultiValuesStateManager + +class MergingSortWithMultiValuesStateIterator( + iter: Iterator[InternalRow], + stateManager: MultiValuesStateManager, + groupWithoutSessionExpressions: Seq[Expression], + sessionExpression: Expression, + inputSchema: Seq[Attribute]) extends Iterator[InternalRow] { + + // FIXME: handle watermark in input rows, not from state + + private case class SessionRowInformation(keys: UnsafeRow, sessionStart: Long, sessionEnd: Long, + row: InternalRow) + + private object SessionRowInformation { + def of(row: InternalRow): SessionRowInformation = { + val keysProjection = GenerateUnsafeProjection.generate(groupWithoutSessionExpressions, + inputSchema) + val sessionProjection = GenerateUnsafeProjection.generate(Seq(sessionExpression), inputSchema) + + val keys = keysProjection(row) + val session = sessionProjection(row) + val sessionRow = session.getStruct(0, 2) + val sessionStart = sessionRow.getLong(0) + val sessionEnd = sessionRow.getLong(1) + + SessionRowInformation(keys, sessionStart, sessionEnd, row) + } + } + + private var currentRow: SessionRowInformation = _ + private var currentStateRow: SessionRowInformation = _ + private var currentStateIter: Iterator[InternalRow] = _ + private var currentStateFetchedKey: UnsafeRow = _ + + override def hasNext: Boolean = { + currentRow != null || currentStateRow != null || + (currentStateIter != null && currentStateIter.hasNext) || iter.hasNext + } + + override def next(): InternalRow = { + if (currentRow == null) { + mayFillCurrentRow() + } + + if (currentStateRow == null) { 
+ mayFillCurrentStateRow() + } + + if (currentRow == null && currentStateRow == null) { + throw new IllegalStateException("No Row to provide in next() which should not happen!") + } + + // return current row vs current state row, should return smaller key, earlier session start + val returnCurrentRow: Boolean = { + if (currentRow == null) { + false + } else if (currentStateRow == null) { + true + } else { + // compare + if (currentRow.keys != currentStateRow.keys) { + // state row cannot advance to row in input, so state row should be lower + false + } else { + currentRow.sessionStart < currentStateRow.sessionStart + } + } + } + + val ret: SessionRowInformation = { + if (returnCurrentRow) { + val toRet = currentRow + currentRow = null + mayFillCurrentRow() + toRet + } else { + val toRet = currentStateRow + currentStateRow = null + mayFillCurrentStateRow() + toRet + } + } + + ret.row + } + + private def mayFillCurrentRow(): Unit = { + if (iter.hasNext) { + currentRow = SessionRowInformation.of(iter.next()) + } + } + + private def mayFillCurrentStateRow(): Unit = { + if (currentStateIter != null && currentStateIter.hasNext) { + currentStateRow = SessionRowInformation.of(currentStateIter.next()) + } else { + currentStateIter = null + + if (currentRow != null && currentRow.keys != currentStateFetchedKey) { + currentStateIter = stateManager.get(currentRow.keys) + currentStateFetchedKey = currentRow.keys + if (currentStateIter.hasNext) { + currentStateRow = SessionRowInformation.of(currentStateIter.next()) + } + } + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala index 80d24227a6895..19c2a677f1066 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala @@ -87,6 +87,12 @@ class MultiValuesStateManager( keyToNumValues.put(key, numExistingValues + 1) } + def removeKey(key: UnsafeRow): Unit = { + val numExistingValues = keyToNumValues.get(key) + keyToNumValues.remove(key) + (0 until numExistingValues).foreach(keyWithIndexToValue.remove(key, _)) + } + /** * Remove using a predicate on keys. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateStoreRDD.scala new file mode 100644 index 0000000000000..7d03e3662a6d0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateStoreRDD.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import scala.reflect.ClassTag + +import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.execution.streaming.continuous.EpochTracker +import org.apache.spark.sql.internal.SessionState +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +class MultiValuesStateStoreRDD[T: ClassTag, U: ClassTag]( + dataRDD: RDD[T], + storeUpdateFunction: (MultiValuesStateManager, Iterator[T]) => Iterator[U], + stateInfo: StatefulOperatorStateInfo, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int], + sessionState: SessionState, + @transient private val storeCoordinator: Option[StateStoreCoordinatorRef]) + extends RDD[U](dataRDD) { + + private val storeConf = new StateStoreConf(sessionState.conf) + + // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it + private val hadoopConfBroadcast = dataRDD.context.broadcast( + new SerializableConfiguration(sessionState.newHadoopConf())) + + override protected def getPartitions: Array[Partition] = dataRDD.partitions + + /** + * Set the preferred location of each partition using the executor that has the related + * [[StateStoreProvider]] already loaded. + */ + override def getPreferredLocations(partition: Partition): Seq[String] = { + val stateStoreProviderId = StateStoreProviderId( + StateStoreId(stateInfo.checkpointLocation, stateInfo.operatorId, partition.index), + stateInfo.queryRunId) + storeCoordinator.flatMap(_.getLocation(stateStoreProviderId)).toSeq + } + + override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { + // If we're in continuous processing mode, we should get the store version for the current + // epoch rather than the one at planning time. 
+ val currentVersion = EpochTracker.getCurrentEpoch match { + case None => stateInfo.storeVersion + case Some(value) => value + } + + val modifiedStateInfo = stateInfo.copy(storeVersion = currentVersion) + + val stateManager: MultiValuesStateManager = new MultiValuesStateManager("session-", + valueSchema.toAttributes, keySchema.toAttributes, Some(modifiedStateInfo), storeConf, + hadoopConfBroadcast.value.value) + + val inputIter = dataRDD.iterator(partition, ctxt) + storeUpdateFunction(stateManager, inputIter) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index b6021438e902b..00b51c0bdcedc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -79,6 +79,49 @@ package object state { indexOrdinal, sessionState, storeCoordinator) + } + + /** Map each partition of an RDD along with data in a [[MultiValuesStateManager]]. */ + def mapPartitionsWithMultiValuesStateManager[U: ClassTag]( + sqlContext: SQLContext, + stateInfo: StatefulOperatorStateInfo, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int])( + storeUpdateFunction: (MultiValuesStateManager, Iterator[T]) => Iterator[U]) + : MultiValuesStateStoreRDD[T, U] = { + + mapPartitionsWithMultiValuesStateManager( + stateInfo, + keySchema, + valueSchema, + indexOrdinal, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator))( + storeUpdateFunction) + } + + /** Map each partition of an RDD along with data in a [[MultiValuesStateManager]]. 
*/ + private[streaming] def mapPartitionsWithMultiValuesStateManager[U: ClassTag]( + stateInfo: StatefulOperatorStateInfo, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int], + sessionState: SessionState, + storeCoordinator: Option[StateStoreCoordinatorRef])( + storeUpdateFunction: (MultiValuesStateManager, Iterator[T]) => Iterator[U]) + : MultiValuesStateStoreRDD[T, U] = { + + val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) + new MultiValuesStateStoreRDD( + dataRDD, + cleanedF, + stateInfo, + keySchema, + valueSchema, + indexOrdinal, + sessionState, + storeCoordinator) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index c11af345b0248..e79b617c27b7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -21,6 +21,7 @@ import java.util.UUID import java.util.concurrent.TimeUnit._ import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -424,6 +425,189 @@ case class StateStoreSaveExec( } } +// FIXME: javadoc! +// FIXME: keyExpressions shouldn't have 'session': otherwise we should exclude it... +case class SessionWindowStateStoreRestoreExec( + keyExpressions: Seq[Attribute], + sessionExpression: Attribute, + stateInfo: Option[StatefulOperatorStateInfo], + child: SparkPlan) + extends UnaryExecNode with StateStoreReader { + + // FIXME: does we really need to have global aggregation from here? 
+ require(keyExpressions.nonEmpty, "Key expressions should be presented.") + + override protected def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + + child.execute().mapPartitionsWithMultiValuesStateManager( + getStateInfo, + keyExpressions.toStructType, + child.output.toStructType, + indexOrdinal = None, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { case (stateManager, iter) => + + new MergingSortWithMultiValuesStateIterator(iter, stateManager, keyExpressions, + sessionExpression, child.output).map { row => + numOutputRows += 1 + row + } + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = { + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + } +} + +/** + * For each input tuple, the key is calculated and sessions are being `put` into + * the [[MultiValuesStateManager]]. + */ +case class SessionWindowStateStoreSaveExec( + keyExpressions: Seq[Attribute], + sessionExpression: Attribute, + stateInfo: Option[StatefulOperatorStateInfo] = None, + outputMode: Option[OutputMode] = None, + eventTimeWatermark: Option[Long] = None, + child: SparkPlan) + extends UnaryExecNode with StateStoreWriter with WatermarkSupport { + + // FIXME: does we really need to have global aggregation from here? 
+ require(keyExpressions.nonEmpty, "Key expressions should be presented.") + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + assert(outputMode.nonEmpty, + "Incorrect planning in IncrementalExecution, outputMode has not been set") + + child.execute().mapPartitionsWithMultiValuesStateManager( + getStateInfo, + keyExpressions.toStructType, + child.output.toStructType, + indexOrdinal = None, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { (stateManager, iter) => + + def evictSessionsByWatermark(manager: MultiValuesStateManager): Iterator[UnsafeRowPair] = { + manager.removeByValueCondition { row => watermarkPredicateForData match { + case Some(predicate) => predicate.eval(row) + case None => false + } + } + } + + val numOutputRows = longMetric("numOutputRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val allRemovalsTimeMs = longMetric("allRemovalsTimeMs") + val commitTimeMs = longMetric("commitTimeMs") + + val keyProjection = GenerateUnsafeProjection.generate(keyExpressions, child.output) + + val alreadyRemovedKeys = new mutable.HashSet[UnsafeRow]() + + // assuming late events were dropped from MergingSortWithMultiValuesStateIterator + outputMode match { + // Update and output only sessions being evicted from the MultiValuesStateManager + // Assumption: watermark predicates must be non-empty if append mode is allowed + case Some(Append) => + allUpdatesTimeMs += timeTakenMs { + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + val keys = keyProjection(row) + if (!alreadyRemovedKeys.contains(keys)) { + stateManager.removeKey(keys) + alreadyRemovedKeys.add(keys) + } + stateManager.append(keys, row) + numUpdatedStateRows += 1 + } + } + + val removalStartTimeNs = System.nanoTime + + val retIter = evictSessionsByWatermark(stateManager).map(_.value) + + CompletionIterator(retIter, { + 
allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) + commitTimeMs += timeTakenMs { stateManager.commit() } + setStoreMetrics(stateManager) + }) + + // Update and output modified rows from the MultiValuesStateManager. + case Some(Update) => + + // FIXME: it doesn't compare all output rows with current state rows, so all sessions + // including previous sessions will be provided + + new NextIterator[InternalRow] { + private val updatesStartTimeNs = System.nanoTime + + override protected def getNext(): InternalRow = { + if (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + val keys = keyProjection(row) + if (!alreadyRemovedKeys.contains(keys)) { + stateManager.removeKey(keys) + alreadyRemovedKeys.add(keys) + } + stateManager.append(keys, row) + numOutputRows += 1 + numUpdatedStateRows += 1 + row + } else { + finished = true + null + } + } + + override protected def close(): Unit = { + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + + // Remove old aggregates if watermark specified + allRemovalsTimeMs += timeTakenMs { + evictSessionsByWatermark(stateManager) + } + commitTimeMs += timeTakenMs { stateManager.commit() } + setStoreMetrics(stateManager) + } + } + + case _ => throw new UnsupportedOperationException(s"Invalid output mode: $outputMode") + } + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = { + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + } + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + (outputMode.contains(Append) || outputMode.contains(Update)) && + eventTimeWatermark.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermark.get + } + + protected def setStoreMetrics(manager: MultiValuesStateManager): Unit = { + val storeMetrics = manager.metrics + 
longMetric("numTotalStateRows") += storeMetrics.numKeys + longMetric("stateMemory") += storeMetrics.memoryUsedBytes + storeMetrics.customMetrics.foreach { case (metric, value) => + longMetric(metric.name) += value + } + } +} + /** Physical operator for executing streaming Deduplicate. */ case class StreamingDeduplicateExec( keyExpressions: Seq[Attribute], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala new file mode 100644 index 0000000000000..7054a9a896e77 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import java.util.UUID + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.streaming.state.{MultiValuesStateManager, StateStore, StateStoreConf} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { + + val rowSchema = new StructType().add("key1", StringType).add("key2", IntegerType) + .add("session", new StructType().add("start", LongType).add("end", LongType)) + .add("aggVal1", LongType).add("aggVal2", DoubleType) + val rowAttributes = rowSchema.toAttributes + + val keysWithoutSessionSchema = rowSchema.filter(st => List("key1", "key2").contains(st.name)) + val keysWithoutSessionAttributes = rowAttributes.filter { + attr => List("key1", "key2").contains(attr.name) + } + + val sessionSchema = rowSchema.filter(st => st.name == "session").head + val sessionAttribute = rowAttributes.filter(attr => attr.name == "session").head + + val valuesSchema = rowSchema.filter(st => List("aggVal1", "aggVal2").contains(st.name)) + val valuesAttributes = rowAttributes.filter { + attr => List("aggVal1", "aggVal2").contains(attr.name) + } + + test("no row in input data") { + withStateManager(rowAttributes, keysWithoutSessionAttributes) { manager => + val iterator = new MergingSortWithMultiValuesStateIterator(None.iterator, + manager, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + + assert(!iterator.hasNext) + } + } + + test("no row in input data but having state") { + withStateManager(rowAttributes, keysWithoutSessionAttributes) { manager => + val srow11 = createRow("a", 1, 55, 
85, 50, 2.5) + val srow12 = createRow("a", 1, 105, 140, 30, 2.0) + appendRowToStateManager(manager, srow11, srow12) + + val iterator = new MergingSortWithMultiValuesStateIterator(None.iterator, + manager, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + + assert(!iterator.hasNext) + } + } + + test("no previous state") { + withStateManager(rowAttributes, keysWithoutSessionAttributes) { manager => + val row1 = createRow("a", 1, 100, 110, 10, 1.1) + val row2 = createRow("a", 1, 100, 110, 20, 1.2) + val row3 = createRow("a", 2, 110, 120, 10, 1.1) + val row4 = createRow("a", 2, 115, 125, 20, 1.2) + val rows = List(row1, row2, row3, row4) + + val iterator = new MergingSortWithMultiValuesStateIterator(rows.iterator, + manager, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + + rows.foreach { row => + assert(iterator.hasNext) + assertRowsEquals(row, iterator.next()) + } + + assert(!iterator.hasNext) + } + } + + test("multiple keys in input data and state") { + withStateManager(rowAttributes, keysWithoutSessionAttributes) { manager => + // key 1 - placing sessions in state to start and end + val row11 = createRow("a", 1, 100, 110, 10, 1.1) + val row12 = createRow("a", 1, 100, 110, 20, 1.2) + + val srow11 = createRow("a", 1, 55, 85, 50, 2.5) + val srow12 = createRow("a", 1, 105, 140, 30, 2.0) + appendRowToStateManager(manager, srow11, srow12) + + // key 2 - no state + val row21 = createRow("a", 2, 110, 120, 10, 1.1) + val row22 = createRow("a", 2, 115, 125, 20, 1.2) + + // key 3 - placing sessions in state to only start + val row31 = createRow("a", 3, 130, 140, 10, 1.1) + val row32 = createRow("a", 3, 135, 145, 20, 1.2) + + val srow31 = createRow("a", 3, 105, 140, 30, 2.0) + val srow32 = createRow("a", 3, 120, 150, 30, 2.0) + appendRowToStateManager(manager, srow31, srow32) + + // key 4 - placing sessions in state to only end + val row41 = createRow("a", 4, 100, 110, 10, 1.1) + val row42 = createRow("a", 4, 100, 115, 20, 1.2) + + val srow41 = 
createRow("a", 4, 105, 140, 30, 2.0) + val srow42 = createRow("a", 4, 120, 150, 30, 2.0) + appendRowToStateManager(manager, srow41, srow42) + + // key 5 - placing sessions in state like one row and state session and another + val row51 = createRow("a", 5, 100, 110, 10, 1.1) + val row52 = createRow("a", 5, 120, 130, 20, 1.2) + + val srow51 = createRow("a", 5, 90, 120, 30, 2.0) + val srow52 = createRow("a", 5, 110, 125, 30, 2.0) + val srow53 = createRow("a", 5, 130, 150, 30, 2.0) + appendRowToStateManager(manager, srow51, srow52, srow53) + + val rows = List(row11, row12, row21, row22, row31, row32, row41, row42, row51, row52) + + val expectedRowSequence = List(srow11, row11, row12, srow12, row21, row22, srow31, srow32, + row31, row32, row41, row42, srow41, srow42, srow51, row51, srow52, row52, srow53) + + val iterator = new MergingSortWithMultiValuesStateIterator(rows.iterator, + manager, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + + expectedRowSequence.foreach { row => + assert(iterator.hasNext) + assertRowsEquals(row, iterator.next()) + } + + assert(!iterator.hasNext) + } + } + + private def getKeyRow(row: UnsafeRow): UnsafeRow = { + val keyProjection = GenerateUnsafeProjection.generate(keysWithoutSessionAttributes, + rowAttributes) + keyProjection(row) + } + + private def createRow(key1: String, key2: Int, sessionStart: Long, sessionEnd: Long, + aggVal1: Long, aggVal2: Double): UnsafeRow = { + val genericRow = new GenericInternalRow(6) + if (key1 != null) { + genericRow.update(0, UTF8String.fromString(key1)) + } else { + genericRow.setNullAt(0) + } + genericRow.setInt(1, key2) + + val session: Array[Any] = new Array[Any](2) + session(0) = sessionStart + session(1) = sessionEnd + + val sessionRow = new GenericInternalRow(session) + genericRow.update(2, sessionRow) + + genericRow.setLong(3, aggVal1) + genericRow.setDouble(4, aggVal2) + + val rowProjection = GenerateUnsafeProjection.generate(rowAttributes, rowAttributes) + 
rowProjection(genericRow) + } + + private def appendRowToStateManager(manager: MultiValuesStateManager, rows: UnsafeRow*): Unit = { + rows.foreach(row => manager.append(getKeyRow(row), row)) + } + + private def doubleEquals(value1: Double, value2: Double): Boolean = { + value1 > value2 - 0.000001 && value1 < value2 + 0.000001 + } + + private def assertRowsEquals(expectedRow: InternalRow, retRow: InternalRow): Unit = { + assert(retRow.getString(0) === expectedRow.getString(0)) + assert(retRow.getInt(1) === expectedRow.getInt(1)) + assert(retRow.getStruct(2, 2).getLong(0) == expectedRow.getStruct(2, 2).getLong(0)) + assert(retRow.getStruct(2, 2).getLong(1) == expectedRow.getStruct(2, 2).getLong(1)) + assert(retRow.getLong(3) === expectedRow.getLong(3)) + assert(doubleEquals(retRow.getDouble(3), expectedRow.getDouble(3))) + } + + private def withStateManager( + inputValueAttribs: Seq[Attribute], + keyExprs: Seq[Expression])(f: MultiValuesStateManager => Unit): Unit = { + + withTempDir { file => + val storeConf = new StateStoreConf() + val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5) + val manager = new MultiValuesStateManager("session-", inputValueAttribs, keyExprs, + Some(stateInfo), storeConf, new Configuration) + try { + f(manager) + } finally { + manager.abortIfNeeded() + } + } + StateStore.stop() + } +} \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala index 3670d27eea8e0..9e703721e8ea8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala @@ -44,6 +44,13 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { attr => List("aggVal1", 
"aggVal2").contains(attr.name) } + test("no row") { + val iterator = new UpdatingSessionIterator(None.iterator, keysWithoutSessionAttributes, + sessionAttribute, valuesAttributes, rowAttributes) + + assert(!iterator.hasNext) + } + test("only one row") { val rows = List(createRow("a", 1, 100, 110, 10, 1.1)) From ec33265a71ba3d8896c595b4be271c171d4f96c6 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 13 Sep 2018 15:54:37 +0900 Subject: [PATCH 08/60] WIP the first version of working one! Still have lots of TODOs and FIXMEs to go --- .../sql/catalyst/analysis/Analyzer.scala | 65 ++++++++++ .../catalyst/expressions/SessionWindow.scala | 112 ++++++++++++++++++ .../spark/sql/execution/SparkStrategies.scala | 31 +++-- .../sql/execution/aggregate/AggUtils.scala | 91 +++++++++++++- .../streaming/IncrementalExecution.scala | 39 ++++++ ...gingSortWithMultiValuesStateIterator.scala | 39 ++++-- .../streaming/UpdatingSessionIterator.scala | 60 ++++++++-- .../state/MultiValuesStateManager.scala | 7 +- .../streaming/statefulOperators.scala | 98 ++++++++++++++- .../org/apache/spark/sql/functions.scala | 7 ++ ...ortWithMultiValuesStateIteratorSuite.scala | 10 +- .../UpdatingSessionIteratorSuite.scala | 16 +-- .../streaming/EventTimeWatermarkSuite.scala | 88 +++++++++++++- 13 files changed, 614 insertions(+), 49 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d72e512e0df56..80bd8230ee2df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -181,6 +181,7 @@ class Analyzer( GlobalAggregates :: ResolveAggregateFunctions :: TimeWindowing :: + SessionWindowing :: ResolveInlineTables(conf) :: 
ResolveHigherOrderFunctions(catalog) :: ResolveLambdaVariables(conf) :: @@ -2721,6 +2722,70 @@ object TimeWindowing extends Rule[LogicalPlan] { } } +// FIXME: javadoc +object SessionWindowing extends Rule[LogicalPlan] { + import org.apache.spark.sql.catalyst.dsl.expressions._ + + private final val SESSION_COL_NAME = "session" + private final val SESSION_START = "start" + private final val SESSION_END = "end" + + // FIXME: javadoc + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case p: LogicalPlan if p.children.size == 1 => + val child = p.children.head + val sessionExpressions = + p.expressions.flatMap(_.collect { case s: SessionWindow => s }).toSet + + // FIXME: we would want to also couple window and session on restriction + val numSessionExpr = sessionExpressions.size + // Only support a single session expression for now + if (numSessionExpr == 1 && + sessionExpressions.head.timeColumn.resolved && + sessionExpressions.head.checkInputDataTypes().isSuccess) { + + val session = sessionExpressions.head + + val metadata = session.timeColumn match { + case a: Attribute => a.metadata + case _ => Metadata.empty + } + + val sessionAttr = AttributeReference( + SESSION_COL_NAME, session.dataType, metadata = metadata)() + + val sessionStart = PreciseTimestampConversion(session.timeColumn, TimestampType, LongType) + val sessionEnd = sessionStart + session.gapDuration + + val literalSessionStruct = CreateNamedStruct( + Literal(SESSION_START) :: + PreciseTimestampConversion(sessionStart, LongType, TimestampType) :: + Literal(SESSION_END) :: + PreciseTimestampConversion(sessionEnd, LongType, TimestampType) :: + Nil) + + val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)( + exprId = sessionAttr.exprId, explicitMetadata = Some(metadata)) + + val replacedPlan = p transformExpressions { + case s: SessionWindow => sessionAttr + } + + // For backwards compatibility we add a filter to filter out nulls + val filterExpr = 
IsNotNull(session.timeColumn) + + replacedPlan.withNewChildren( + Filter(filterExpr, + Project(sessionStruct +: child.output, child)) :: Nil) + } else if (numSessionExpr > 1) { + p.failAnalysis("Multiple time session expressions would result in a cartesian product " + + "of rows, therefore they are currently not supported.") + } else { + p // Return unchanged. Analyzer will throw exception later + } + } +} + /** * Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala new file mode 100644 index 0000000000000..75c9e663d5db7 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends UnaryExpression + with ImplicitCastInputTypes + with Unevaluable + with NonSQLExpression { + + ////////////////////////// + // SQL Constructors + ////////////////////////// + + def this(timeColumn: Expression, gapDuration: Expression) = { + this(timeColumn, SessionWindow.parseExpression(gapDuration)) + } + + override def child: Expression = timeColumn + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + override def dataType: DataType = new StructType() + .add(StructField("start", TimestampType)) + .add(StructField("end", TimestampType)) + + // This expression is replaced in the analyzer. + override lazy val resolved = false + + /** Validate the inputs for the gap duration in addition to the input data type. */ + override def checkInputDataTypes(): TypeCheckResult = { + val dataTypeCheck = super.checkInputDataTypes() + if (dataTypeCheck.isSuccess) { + if (gapDuration <= 0) { + return TypeCheckFailure(s"The window duration ($gapDuration) must be greater than 0.") + } + } + dataTypeCheck + } +} + +object SessionWindow { + /** + * Parses the interval string for a valid time duration. CalendarInterval expects interval + * strings to start with the string `interval`. For usability, we prepend `interval` to the string + * if the user omitted it. + * + * @param interval The interval string + * @return The interval duration in microseconds. SparkSQL casts TimestampType has microsecond + * precision. 
+ */ + private def getIntervalInMicroSeconds(interval: String): Long = { + if (StringUtils.isBlank(interval)) { + throw new IllegalArgumentException( + "The window duration, slide duration and start time cannot be null or blank.") + } + val intervalString = if (interval.startsWith("interval")) { + interval + } else { + "interval " + interval + } + val cal = CalendarInterval.fromString(intervalString) + if (cal == null) { + throw new IllegalArgumentException( + s"The provided interval ($interval) did not correspond to a valid interval string.") + } + if (cal.months > 0) { + throw new IllegalArgumentException( + s"Intervals greater than a month is not supported ($interval).") + } + cal.microseconds + } + + /** + * Parses the duration expression to generate the long value for the original constructor so + * that we can use `window` in SQL. + */ + private def parseExpression(expr: Expression): Long = expr match { + case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString) + case IntegerLiteral(i) => i.toLong + case NonNullLiteral(l, LongType) => l.toString.toLong + case _ => throw new AnalysisException("The duration and time inputs to window must be " + + "an integer, long or string literal.") + } + + def apply( + timeColumn: Expression, + gapDuration: String): SessionWindow = { + SessionWindow(timeColumn, + getIntervalInMicroSeconds(gapDuration)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index dbc6db62bd820..0dc84ed640a68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -329,14 +329,29 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "Streaming aggregation doesn't support group aggregate pandas UDF") } - val stateVersion = 
conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION) - - aggregate.AggUtils.planStreamingAggregation( - namedGroupingExpressions, - aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), - rewrittenResultExpressions, - stateVersion, - planLater(child)) + val sessionWindowOption = namedGroupingExpressions.find { p => + p.name == "session" && p.dataType.isInstanceOf[StructType] + } + + sessionWindowOption match { + case Some(sessionWindow) => + aggregate.AggUtils.planStreamingAggregationForSession( + namedGroupingExpressions, + sessionWindow, + aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), + rewrittenResultExpressions, + planLater(child)) + + case None => + val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION) + + aggregate.AggUtils.planStreamingAggregation( + namedGroupingExpressions, + aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), + rewrittenResultExpressions, + stateVersion, + planLater(child)) + } case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 6be88c463dbd9..a751324a80b3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec} -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.execution.streaming.{SessionWindowStateStoreRestoreExec, SessionWindowStateStoreSaveExec, StateStoreRestoreExec, StateStoreSaveExec} /** * Utility functions used by the 
query planner to convert our plan to new aggregation code path. @@ -338,4 +337,92 @@ object AggUtils { finalAndCompleteAggregate :: Nil } + + // FIXME: change! + /** + * Plans a streaming aggregation using the following progression: + * + * - Partial Aggregation + * - Shuffle + * - Partial Merge (now there is at most 1 tuple per group) + * - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous) + * - PartialMerge (now there is at most 1 tuple per group) + * - StateStoreSave (saves the tuple for the next batch) + * - Complete (output the current result of the aggregation) + */ + def planStreamingAggregationForSession( + groupingExpressions: Seq[NamedExpression], + sessionExpression: NamedExpression, + functionsWithoutDistinct: Seq[AggregateExpression], + resultExpressions: Seq[NamedExpression], + child: SparkPlan): Seq[SparkPlan] = { + + val groupWithoutSessionExpression = groupingExpressions.filterNot { p => + p.semanticEquals(sessionExpression) + } + + val groupingWithoutSessionAttributes = groupWithoutSessionExpression.map(_.toAttribute) + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + + val partialAggregate: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = child) + } + + // shuffle & sort happens here + val restored = SessionWindowStateStoreRestoreExec(groupingWithoutSessionAttributes, + sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None, partialAggregate) + + val partialMerged: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + 
val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = restored) + } + + // Note: stateId and returnAllStates are filled in later with preparation rules + // in IncrementalExecution. + val saved = + SessionWindowStateStoreSaveExec( + groupingWithoutSessionAttributes, + sessionExpression.toAttribute, + stateInfo = None, + outputMode = None, + eventTimeWatermark = None, + partialMerged) + + val finalAndCompleteAggregate: SparkPlan = { + val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) + + createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions, + aggregateAttributes = finalAggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = resultExpressions, + child = saved) + } + + finalAndCompleteAggregate :: Nil + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index fad287e28877d..c86ac09644f5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -118,6 +118,45 @@ 
class IncrementalExecution( Some(aggStateInfo), stateFormatVersion, child) :: Nil)) +/* + case SessionWindowStateStoreSaveExec(keys, session, None, None, None, + UnaryExecNode(agg, + UnaryExecNode(agg2, + SessionWindowStateStoreRestoreExec(_, _, None, None, child)))) => + val aggStateInfo = nextStatefulOperationStateInfo + SessionWindowStateStoreSaveExec( + keys, + session, + Some(aggStateInfo), + Some(outputMode), + Some(offsetSeqMetadata.batchWatermarkMs), + agg.withNewChildren( + agg2.withNewChildren( + SessionWindowStateStoreRestoreExec( + keys, + session, + Some(aggStateInfo), + Some(offsetSeqMetadata.batchWatermarkMs), + child) :: Nil) :: Nil)) + */ + + case SessionWindowStateStoreSaveExec(keys, session, None, None, None, + UnaryExecNode(agg, + SessionWindowStateStoreRestoreExec(_, _, None, None, child))) => + val aggStateInfo = nextStatefulOperationStateInfo + SessionWindowStateStoreSaveExec( + keys, + session, + Some(aggStateInfo), + Some(outputMode), + Some(offsetSeqMetadata.batchWatermarkMs), + agg.withNewChildren( + SessionWindowStateStoreRestoreExec( + keys, + session, + Some(aggStateInfo), + Some(offsetSeqMetadata.batchWatermarkMs), + child) :: Nil)) case StreamingDeduplicateExec(keys, child, None, None) => StreamingDeduplicateExec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala index 3c646e3d278b4..09061d7ceb203 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeRow} -import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} import org.apache.spark.sql.execution.streaming.state.MultiValuesStateManager class MergingSortWithMultiValuesStateIterator( @@ -27,6 +27,7 @@ class MergingSortWithMultiValuesStateIterator( stateManager: MultiValuesStateManager, groupWithoutSessionExpressions: Seq[Expression], sessionExpression: Expression, + watermarkPredicateForData: Option[Predicate], inputSchema: Seq[Attribute]) extends Iterator[InternalRow] { // FIXME: handle watermark in input rows, not from state @@ -55,9 +56,20 @@ class MergingSortWithMultiValuesStateIterator( private var currentStateIter: Iterator[InternalRow] = _ private var currentStateFetchedKey: UnsafeRow = _ + private val baseIterator = watermarkPredicateForData match { + case Some(predicate) => iter.filter((row: InternalRow) => { + val pr = !predicate.eval(row) + if (!pr) { + System.err.println(s"DEBUG - evicting input due to watermark... 
$row") + } + pr + }) + case None => iter + } + override def hasNext: Boolean = { currentRow != null || currentStateRow != null || - (currentStateIter != null && currentStateIter.hasNext) || iter.hasNext + (currentStateIter != null && currentStateIter.hasNext) || baseIterator.hasNext } override def next(): InternalRow = { @@ -85,6 +97,7 @@ class MergingSortWithMultiValuesStateIterator( // state row cannot advance to row in input, so state row should be lower false } else { + System.err.println(s"DEBUG: WARN - comparing row ${currentRow} and state row ${currentStateRow}") currentRow.sessionStart < currentStateRow.sessionStart } } @@ -94,36 +107,48 @@ class MergingSortWithMultiValuesStateIterator( if (returnCurrentRow) { val toRet = currentRow currentRow = null - mayFillCurrentRow() toRet } else { val toRet = currentStateRow currentStateRow = null - mayFillCurrentStateRow() toRet } } + System.err.println(s"DEBUG: WARN - returning row ${ret.row} for iterator") + ret.row } private def mayFillCurrentRow(): Unit = { - if (iter.hasNext) { - currentRow = SessionRowInformation.of(iter.next()) + if (baseIterator.hasNext) { + currentRow = SessionRowInformation.of(baseIterator.next()) + System.err.println(s"DEBUG - filling current row... current row: $currentRow") } } private def mayFillCurrentStateRow(): Unit = { if (currentStateIter != null && currentStateIter.hasNext) { currentStateRow = SessionRowInformation.of(currentStateIter.next()) + System.err.println(s"DEBUG - filling state row... 
current state row: $currentStateRow") } else { currentStateIter = null if (currentRow != null && currentRow.keys != currentStateFetchedKey) { - currentStateIter = stateManager.get(currentRow.keys) + val unsortedIter = stateManager.get(currentRow.keys) + currentStateIter = unsortedIter.toList.sortWith((row1, row2) => { + val rowInfo1 = SessionRowInformation.of(row1) + val rowInfo2 = SessionRowInformation.of(row2) + // here sorting is based on the fact that keys are same + rowInfo1.sessionStart.compareTo(rowInfo2.sessionStart) < 0 + }).iterator + currentStateFetchedKey = currentRow.keys if (currentStateIter.hasNext) { currentStateRow = SessionRowInformation.of(currentStateIter.next()) + System.err.println(s"DEBUG: WARN - read data ${currentStateRow.row} from state for key ${currentRow.keys}") + } else { + System.err.println(s"DEBUG: WARN - no state data for key ${currentRow.keys}") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala index 598206beb56da..60b992f164c30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, CreateNamedStruct, Expression, Literal, PreciseTimestampConversion, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, BindReferences, CreateNamedStruct, Expression, Literal, PreciseTimestampConversion, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.types.{LongType, TimestampType} @@ -29,9 +29,13 @@ class 
UpdatingSessionIterator( iter: Iterator[InternalRow], groupWithoutSessionExpressions: Seq[Expression], sessionExpression: Expression, - aggregateExpressions: Seq[Expression], inputSchema: Seq[Attribute]) extends Iterator[InternalRow] { + val sessionIndex = inputSchema.indexOf(sessionExpression) + + val valuesExpressions: Seq[Attribute] = inputSchema.diff(groupWithoutSessionExpressions) + .diff(Seq(sessionExpression)) + var currentKeys: InternalRow = _ var currentSessionStart: Long = Long.MaxValue var currentSessionEnd: Long = Long.MinValue @@ -43,6 +47,8 @@ class UpdatingSessionIterator( val processedKeys: mutable.HashSet[InternalRow] = new mutable.HashSet[InternalRow]() + // FIXME: data loss seen... one data from input and one data from state + override def hasNext: Boolean = { assertIteratorNotCorrupted() @@ -61,6 +67,11 @@ class UpdatingSessionIterator( assertIteratorNotCorrupted() if (returnRowsIter != null && returnRowsIter.hasNext) { + System.err.println(s"DEBUG: has remaining returnRowsIter - not going into loop - " + + s"current session - currentKeys: $currentKeys / " + + s"currentRows: $currentRows / currentSessionStart: $currentSessionStart / " + + s"currentSessionEnd: $currentSessionEnd") + return returnRowsIter.next() } @@ -105,6 +116,10 @@ class UpdatingSessionIterator( closeCurrentSession(keyChanged = false) } + System.err.println(s"DEBUG: end of loop - current session - currentKeys: $currentKeys / " + + s"currentRows: $currentRows / currentSessionStart: $currentSessionStart / " + + s"currentSessionEnd: $currentSessionEnd") + // here returnRowsIter should be able to provide at least one row require(returnRowsIter != null && returnRowsIter.hasNext) @@ -128,6 +143,10 @@ class UpdatingSessionIterator( currentSessionEnd = sessionEnd currentRows.clear() currentRows += row + + System.err.println(s"DEBUG: started new session - currentKeys: $currentKeys / " + + s"currentRows: $currentRows / currentSessionStart: $currentSessionStart / " + + 
s"currentSessionEnd: $currentSessionEnd") } private def handleBrokenPreconditionForSort(): Unit = { @@ -136,12 +155,9 @@ class UpdatingSessionIterator( } private def closeCurrentSession(keyChanged: Boolean): Unit = { - val convertedGroupWithoutSessionExpressions = groupWithoutSessionExpressions.map { x => - BindReferences.bindReference[Expression](x, inputSchema) - } - val convertedAggregateExpressions = aggregateExpressions.map { - x => BindReferences.bindReference[Expression](x, inputSchema) - } + System.err.println(s"DEBUG: closing current session - currentKeys: $currentKeys / " + + s"currentRows: $currentRows / currentSessionStart: $currentSessionStart / " + + s"currentSessionEnd: $currentSessionEnd") val returnRows = currentRows.map { internalRow => val sessionStruct = CreateNamedStruct( @@ -153,14 +169,34 @@ class UpdatingSessionIterator( Literal(currentSessionEnd, LongType), LongType, TimestampType) :: Nil) - val valueExpressions = convertedGroupWithoutSessionExpressions ++ Seq(sessionStruct) ++ - convertedAggregateExpressions + val convertedAllExpressions = inputSchema.map { x => + BindReferences.bindReference[Expression](x, inputSchema) + } - val proj = GenerateUnsafeProjection.generate(valueExpressions, inputSchema) + val newSchemaExpressions = convertedAllExpressions.indices.map { idx => + if (idx == sessionIndex) { + sessionStruct + } else { + convertedAllExpressions(idx) + } + } + + val proj = GenerateUnsafeProjection.generate(newSchemaExpressions, inputSchema) proj(internalRow) }.toList - returnRowsIter = returnRows.iterator + if (returnRowsIter != null && returnRowsIter.hasNext) { + returnRowsIter = returnRowsIter ++ returnRows.iterator + } else { + returnRowsIter = returnRows.iterator + } + + //returnRowsIter = returnRows.iterator + + // FIXME: DEBUG + val (rIter, tmpReturnRowsIter) = returnRowsIter.duplicate + returnRowsIter = rIter + System.err.println(s"DEBUG: closing current session - return rows iter will return: ${tmpReturnRowsIter.toList}") 
if (keyChanged) processedKeys.add(currentKeys) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala index 19c2a677f1066..a94cbd2968e2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala @@ -90,7 +90,7 @@ class MultiValuesStateManager( def removeKey(key: UnsafeRow): Unit = { val numExistingValues = keyToNumValues.get(key) keyToNumValues.remove(key) - (0 until numExistingValues).foreach(keyWithIndexToValue.remove(key, _)) + (0L until numExistingValues).foreach(keyWithIndexToValue.remove(key, _)) } /** @@ -159,6 +159,9 @@ class MultiValuesStateManager( * * This implies the iterator must be consumed fully without any other operations on this manager * or the underlying store being interleaved. + * + * NOTE: if any value is remove for the key, the order of values will be non-deterministic. + * It doesn't keep the order stable when removing value for gaining performance. 
*/ def removeByValueCondition(removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRowPair] = { new NextIterator[UnsafeRowPair] { @@ -312,6 +315,8 @@ class MultiValuesStateManager( def commit(): Unit = { stateStore.commit() + // FIXME: DEBUG + logInfo("Committed, metrics = " + stateStore.metrics) logDebug("Committed, metrics = " + stateStore.metrics) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index e79b617c27b7e..b50824273e6f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -23,6 +23,8 @@ import java.util.concurrent.TimeUnit._ import scala.collection.JavaConverters._ import scala.collection.mutable +import org.apache.spark.TaskContext + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ @@ -431,10 +433,12 @@ case class SessionWindowStateStoreRestoreExec( keyExpressions: Seq[Attribute], sessionExpression: Attribute, stateInfo: Option[StatefulOperatorStateInfo], + eventTimeWatermark: Option[Long], child: SparkPlan) - extends UnaryExecNode with StateStoreReader { + extends UnaryExecNode with StateStoreReader with WatermarkSupport { // FIXME: does we really need to have global aggregation from here? + // yes... 
it even has a case which session is only key require(keyExpressions.nonEmpty, "Key expressions should be presented.") override protected def doExecute(): RDD[InternalRow] = { @@ -448,9 +452,61 @@ case class SessionWindowStateStoreRestoreExec( sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (stateManager, iter) => - new MergingSortWithMultiValuesStateIterator(iter, stateManager, keyExpressions, - sessionExpression, child.output).map { row => + // FIXME: would we want to break down into two multiple physical plans? + // + // new MergingSortWithMultiValuesStateIterator(iter, stateManager, keyExpressions, + // sessionExpression, child.output).map { row => + // numOutputRows += 1 + // row + // } + + val debugPartitionId = TaskContext.get().partitionId() + + val debugIter = iter.map { row => + val keysProjection = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val sessionProjection = GenerateUnsafeProjection.generate( + Seq(sessionExpression), child.output) + val rowProjection = GenerateUnsafeProjection.generate(child.output, child.output) + + logWarning(s"DEBUG: partitionId $debugPartitionId - input row - keys ${keysProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - input row - session ${sessionProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - input row - row (proj) ${rowProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - input row - row $row") + + row + } + + val mergedIter = new MergingSortWithMultiValuesStateIterator(debugIter, stateManager, + keyExpressions, sessionExpression, watermarkPredicateForData, child.output) + + val debugMergedIter = mergedIter.map { row => + val keysProjection = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val sessionProjection = GenerateUnsafeProjection.generate( + Seq(sessionExpression), child.output) + val rowProjection = GenerateUnsafeProjection.generate(child.output, child.output) + + 
logWarning(s"DEBUG: partitionId $debugPartitionId - merged row - keys ${keysProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - merged row - session ${sessionProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - merged row - row (proj) ${rowProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - merged row - row ${row}") + + row + } + + new UpdatingSessionIterator(debugMergedIter, keyExpressions, sessionExpression, + child.output).map { row => numOutputRows += 1 + + val keysProjection = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val sessionProjection = GenerateUnsafeProjection.generate( + Seq(sessionExpression), child.output) + val rowProjection = GenerateUnsafeProjection.generate(child.output, child.output) + + logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - keys ${keysProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - session ${sessionProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - row (proj) ${rowProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - row ${row}") + row } } @@ -463,6 +519,10 @@ case class SessionWindowStateStoreRestoreExec( override def requiredChildDistribution: Seq[Distribution] = { ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + Seq((keyExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending))) + } } /** @@ -508,10 +568,14 @@ case class SessionWindowStateStoreSaveExec( val allRemovalsTimeMs = longMetric("allRemovalsTimeMs") val commitTimeMs = longMetric("commitTimeMs") - val keyProjection = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val keyProjection = GenerateUnsafeProjection.generate(keyExpressions, + child.output) val alreadyRemovedKeys = new mutable.HashSet[UnsafeRow]() + // 
FIXME: DEBUG + val debugPartitionId = TaskContext.get().partitionId() + // assuming late events were dropped from MergingSortWithMultiValuesStateIterator outputMode match { // Update and output only sessions being evicted from the MultiValuesStateManager @@ -522,20 +586,38 @@ case class SessionWindowStateStoreSaveExec( val row = iter.next().asInstanceOf[UnsafeRow] val keys = keyProjection(row) if (!alreadyRemovedKeys.contains(keys)) { + logWarning(s"DEBUG: partitionId $debugPartitionId - removing key ${keys} ...") stateManager.removeKey(keys) alreadyRemovedKeys.add(keys) } + + val sessionProjection = GenerateUnsafeProjection.generate( + Seq(sessionExpression), child.output) + val rowProjection = GenerateUnsafeProjection.generate(child.output, child.output) + + logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - keys ${keyProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - session ${sessionProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - row (proj) ${rowProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - row ${row}") + + logWarning(s"DEBUG: partitionId $debugPartitionId - adding row for key ${keys} row ${row} ...") stateManager.append(keys, row) numUpdatedStateRows += 1 } + + logWarning(s"DEBUG: partitionId $debugPartitionId - finished iterating...") } val removalStartTimeNs = System.nanoTime - val retIter = evictSessionsByWatermark(stateManager).map(_.value) + val retIter = evictSessionsByWatermark(stateManager).map(_.value).map { row => + logWarning(s"DEBUG: partitionId $debugPartitionId - evicting row ${row} ...") + row + } - CompletionIterator(retIter, { + CompletionIterator[InternalRow, Iterator[InternalRow]](retIter, { allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) + logWarning(s"DEBUG: partitionId $debugPartitionId - committing...") commitTimeMs += timeTakenMs { stateManager.commit() } 
setStoreMetrics(stateManager) }) @@ -592,6 +674,10 @@ case class SessionWindowStateStoreSaveExec( ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + Seq((keyExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending))) + } + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { (outputMode.contains(Append) || outputMode.contains(Update)) && eventTimeWatermark.isDefined && diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4247d3110f1e1..55bd85f459ddf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3262,6 +3262,13 @@ object functions { window(timeColumn, windowDuration, windowDuration, "0 second") } + // FIXME: javadoc! + def session(timeColumn: Column, gapDuration: String): Column = { + withExpr { + SessionWindow(timeColumn.expr, gapDuration) + }.as("session") + } + ////////////////////////////////////////////////////////////////////////////////////////////// // Collection functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala index 7054a9a896e77..c00e3ffcbeaea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala @@ -49,10 +49,12 @@ class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { attr => List("aggVal1", "aggVal2").contains(attr.name) } + // 
FIXME: add test for watermark + test("no row in input data") { withStateManager(rowAttributes, keysWithoutSessionAttributes) { manager => val iterator = new MergingSortWithMultiValuesStateIterator(None.iterator, - manager, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + manager, keysWithoutSessionAttributes, sessionAttribute, None, rowAttributes) assert(!iterator.hasNext) } @@ -65,7 +67,7 @@ class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { appendRowToStateManager(manager, srow11, srow12) val iterator = new MergingSortWithMultiValuesStateIterator(None.iterator, - manager, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + manager, keysWithoutSessionAttributes, sessionAttribute, None, rowAttributes) assert(!iterator.hasNext) } @@ -80,7 +82,7 @@ class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { val rows = List(row1, row2, row3, row4) val iterator = new MergingSortWithMultiValuesStateIterator(rows.iterator, - manager, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + manager, keysWithoutSessionAttributes, sessionAttribute, None, rowAttributes) rows.foreach { row => assert(iterator.hasNext) @@ -136,7 +138,7 @@ class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { row31, row32, row41, row42, srow41, srow42, srow51, row51, srow52, row52, srow53) val iterator = new MergingSortWithMultiValuesStateIterator(rows.iterator, - manager, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + manager, keysWithoutSessionAttributes, sessionAttribute, None, rowAttributes) expectedRowSequence.foreach { row => assert(iterator.hasNext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala index 9e703721e8ea8..c6a30d0f0d8e7 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala @@ -46,7 +46,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { test("no row") { val iterator = new UpdatingSessionIterator(None.iterator, keysWithoutSessionAttributes, - sessionAttribute, valuesAttributes, rowAttributes) + sessionAttribute, rowAttributes) assert(!iterator.hasNext) } @@ -55,7 +55,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val rows = List(createRow("a", 1, 100, 110, 10, 1.1)) val iterator = new UpdatingSessionIterator(rows.iterator, keysWithoutSessionAttributes, - sessionAttribute, valuesAttributes, rowAttributes) + sessionAttribute, rowAttributes) assert(iterator.hasNext) @@ -73,7 +73,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val rows = List(row1, row2, row3, row4) val iterator = new UpdatingSessionIterator(rows.iterator, keysWithoutSessionAttributes, - sessionAttribute, valuesAttributes, rowAttributes) + sessionAttribute, rowAttributes) val retRows = rows.indices.map { _ => assert(iterator.hasNext) @@ -104,7 +104,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val rowsAll = rows1 ++ rows2 val iterator = new UpdatingSessionIterator(rowsAll.iterator, keysWithoutSessionAttributes, - sessionAttribute, valuesAttributes, rowAttributes) + sessionAttribute, rowAttributes) val retRows1 = rows1.indices.map { _ => assert(iterator.hasNext) @@ -140,7 +140,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val rowsAll = rows1 ++ rows2 val iterator = new UpdatingSessionIterator(rowsAll.iterator, keysWithoutSessionAttributes, - sessionAttribute, valuesAttributes, rowAttributes) + sessionAttribute, rowAttributes) val retRows1 = rows1.indices.map { _ => assert(iterator.hasNext) @@ -185,7 +185,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val 
rowsAll = rows1 ++ rows2 ++ rows3 ++ rows4 val iterator = new UpdatingSessionIterator(rowsAll.iterator, keysWithoutSessionAttributes, - sessionAttribute, valuesAttributes, rowAttributes) + sessionAttribute, rowAttributes) val retRows1 = rows1.indices.map { _ => assert(iterator.hasNext) @@ -238,7 +238,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val rows = List(row1, row2, row3, row4) val iterator = new UpdatingSessionIterator(rows.iterator, keysWithoutSessionAttributes, - sessionAttribute, valuesAttributes, rowAttributes) + sessionAttribute, rowAttributes) // UpdatingSessionIterator can't detect error on hasNext assert(iterator.hasNext) @@ -265,7 +265,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val rows = List(row1, row2, row3) val iterator = new UpdatingSessionIterator(rows.iterator, keysWithoutSessionAttributes, - sessionAttribute, valuesAttributes, rowAttributes) + sessionAttribute, rowAttributes) // UpdatingSessionIterator can't detect error on hasNext assert(iterator.hasNext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index c696204cecc2c..a8d52400f59d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.functions.{count, window} +import org.apache.spark.sql.functions.{count, window, session} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ import org.apache.spark.util.Utils @@ -276,6 +276,92 @@ class 
EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche ) } + test("append mode - session") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*", "CAST(value / 10 AS INT) AS valuegroup") + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(session($"eventTime", "5 seconds") as 'session, 'valuegroup) + .agg(count("*") as 'count) + .select($"valuegroup", $"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 10, 11), // sessions: key 1 => (10,16) + AssertOnQuery(execution => { + execution.explain(true) + true + }), + CheckNewAnswer(), + AddData(inputData, 17), + // Advance watermark to 7 seconds + // sessions: key 1 => (10,16), (17,23) + CheckNewAnswer(), + AddData(inputData, 25), + // Advance watermark to 15 seconds + // sessions: key 1 => (10,16), (17,23) / key 2 => (25,30) + CheckNewAnswer(), + AddData(inputData, 35), + // Advance watermark to 25 seconds + // sessions: key 1 => (10,16), (17,22) / key 2 => (25,30) / key 3 => (35,40) + // evicts: key 1 => (10,16), (17,22) + CheckNewAnswer((1, 10, 16, 2), (1, 17, 22, 1)), + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckNewAnswer(), + AddData(inputData, 40), + // Advance watermark to 30 seconds + // sessions: key 2 => (25,30) / key 3 => (35,45) + // evicts: key 2 => (25,30) + CheckNewAnswer((2, 25, 30, 1)) + ) + } + + test("update mode - session") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*", "CAST(value / 10 AS INT) AS valuegroup") + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(session($"eventTime", "5 seconds") as 'session, 'valuegroup) + .agg(count("*") as 'count) + .select($"valuegroup", 
$"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation, OutputMode.Update())( + + AddData(inputData, 10, 11), + // Advance watermark to 1 seconds + // sessions: key 1 => (10,16) + CheckNewAnswer((1, 10, 16, 2)), + AssertOnQuery(execution => { + execution.explain(true) + true + }), + AddData(inputData, 17), + // Advance watermark to 7 seconds + // sessions: key 1 => (10,16), (17,23) + // FIXME: subtract with previous state? or leave it as it is? + CheckNewAnswer((1, 10, 16, 2), (1, 17, 22, 1)), + AddData(inputData, 25), + // Advance watermark to 15 seconds + // sessions: key 1 => (10,20) / key 2 => (25,30) + CheckNewAnswer((2, 25, 30, 1)), + AddData(inputData, 35), + // Advance watermark to 25 seconds + // sessions: key 1 => (10,20) / key 2 => (25,30) / key 3 => (35,40) + CheckNewAnswer((3, 35, 40, 1)), + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckNewAnswer(), + AddData(inputData, 40), + // Advance watermark to 30 seconds + // sessions: key 1 => (10,20) / key 2 => (25,30) / key 3 => (35,45) + CheckNewAnswer((4, 40, 45, 1)) + ) + } + test("update mode") { val inputData = MemoryStream[Int] spark.conf.set("spark.sql.shuffle.partitions", "10") From 8b210d5431576f832fa1df408645870b2393f3c5 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 13 Sep 2018 17:13:45 +0900 Subject: [PATCH 09/60] Add more explanations --- .../sql/catalyst/analysis/Analyzer.scala | 23 ++++++++++------ .../sql/execution/aggregate/AggUtils.scala | 26 +++++++++++------- ...gingSortWithMultiValuesStateIterator.scala | 27 +++++++------------ .../state/MultiValuesStateManager.scala | 4 +-- .../streaming/statefulOperators.scala | 27 +++++++++++-------- 5 files changed, 59 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 80bd8230ee2df..e52c62bdf07a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2644,7 +2644,11 @@ object TimeWindowing extends Rule[LogicalPlan] { val windowExpressions = p.expressions.flatMap(_.collect { case t: TimeWindow => t }).toSet - val numWindowExpr = windowExpressions.size + val numWindowExpr = p.expressions.flatMap(_.collect { + case s: SessionWindow => s + case t: TimeWindow => t + }).toSet.size + // Only support a single window expression for now if (numWindowExpr == 1 && windowExpressions.head.timeColumn.resolved && @@ -2714,7 +2718,7 @@ object TimeWindowing extends Rule[LogicalPlan] { renamedPlan.withNewChildren(substitutedPlan :: Nil) } } else if (numWindowExpr > 1) { - p.failAnalysis("Multiple time window expressions would result in a cartesian product " + + p.failAnalysis("Multiple time/session window expressions would result in a cartesian product " + "of rows, therefore they are currently not supported.") } else { p // Return unchanged. 
Analyzer will throw exception later @@ -2737,10 +2741,13 @@ object SessionWindowing extends Rule[LogicalPlan] { val sessionExpressions = p.expressions.flatMap(_.collect { case s: SessionWindow => s }).toSet - // FIXME: we would want to also couple window and session on restriction - val numSessionExpr = sessionExpressions.size + val numWindowExpr = p.expressions.flatMap(_.collect { + case s: SessionWindow => s + case t: TimeWindow => t + }).toSet.size + // Only support a single session expression for now - if (numSessionExpr == 1 && + if (numWindowExpr == 1 && sessionExpressions.head.timeColumn.resolved && sessionExpressions.head.checkInputDataTypes().isSuccess) { @@ -2777,9 +2784,9 @@ object SessionWindowing extends Rule[LogicalPlan] { replacedPlan.withNewChildren( Filter(filterExpr, Project(sessionStruct +: child.output, child)) :: Nil) - } else if (numSessionExpr > 1) { - p.failAnalysis("Multiple time session expressions would result in a cartesian product " + - "of rows, therefore they are currently not supported.") + } else if (numWindowExpr > 1) { + p.failAnalysis("Multiple time/session window expressions would result in a " + + "cartesian product of rows, therefore they are currently not supported.") } else { p // Return unchanged. Analyzer will throw exception later } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index a751324a80b3f..3b2c9eb5f19bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -115,6 +115,8 @@ object AggUtils { finalAggregate :: Nil } + // FIXME: distinct in session makes sense? 
+ def planAggregateWithOneDistinct( groupingExpressions: Seq[NamedExpression], functionsWithDistinct: Seq[AggregateExpression], @@ -338,16 +340,19 @@ object AggUtils { finalAndCompleteAggregate :: Nil } - // FIXME: change! /** - * Plans a streaming aggregation using the following progression: + * Plans a streaming session aggregation using the following progression: * - * - Partial Aggregation - * - Shuffle - * - Partial Merge (now there is at most 1 tuple per group) - * - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous) - * - PartialMerge (now there is at most 1 tuple per group) - * - StateStoreSave (saves the tuple for the next batch) + * - Partial Merge (group: all keys) + * - all tuples will have aggregated columns with initial value + * - Shuffle & Sort (distribution: keys "without" session, sort: all keys) + * - SessionWindowStateStoreRestore (group: keys "without" session) + * - merge input tuples with stored tuples (sessions) respecting sort order + * - calculate session among tuples, and update all tuples to get correct session range + * - PartialMerge (group: all keys) + * - now there is at most 1 tuple per group + * - SessionWindowStateStoreSave (group: keys "without" session) + * - saves tuple(s) for the next batch (multiple sessions could co-exist at the same time) * - Complete (output the current result of the aggregation) */ def planStreamingAggregationForSession( @@ -365,6 +370,9 @@ object AggUtils { val groupingAttributes = groupingExpressions.map(_.toAttribute) + // we don't do partial aggregate here, because it requires additional shuffle + // and there will be less rows which have same session start + // here doing partial merge is to have aggregated columns with default value for each row val partialAggregate: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) @@ -377,7 +385,7 @@ object 
AggUtils { child = child) } - // shuffle & sort happens here + // shuffle & sort happens here: most of details are also handled in this physical plan val restored = SessionWindowStateStoreRestoreExec(groupingWithoutSessionAttributes, sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None, partialAggregate) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala index 09061d7ceb203..e5a157b1009cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala @@ -30,8 +30,6 @@ class MergingSortWithMultiValuesStateIterator( watermarkPredicateForData: Option[Predicate], inputSchema: Seq[Attribute]) extends Iterator[InternalRow] { - // FIXME: handle watermark in input rows, not from state - private case class SessionRowInformation(keys: UnsafeRow, sessionStart: Long, sessionEnd: Long, row: InternalRow) @@ -56,20 +54,9 @@ class MergingSortWithMultiValuesStateIterator( private var currentStateIter: Iterator[InternalRow] = _ private var currentStateFetchedKey: UnsafeRow = _ - private val baseIterator = watermarkPredicateForData match { - case Some(predicate) => iter.filter((row: InternalRow) => { - val pr = !predicate.eval(row) - if (!pr) { - System.err.println(s"DEBUG - evicting input due to watermark... 
$row") - } - pr - }) - case None => iter - } - override def hasNext: Boolean = { currentRow != null || currentStateRow != null || - (currentStateIter != null && currentStateIter.hasNext) || baseIterator.hasNext + (currentStateIter != null && currentStateIter.hasNext) || iter.hasNext } override def next(): InternalRow = { @@ -121,8 +108,8 @@ class MergingSortWithMultiValuesStateIterator( } private def mayFillCurrentRow(): Unit = { - if (baseIterator.hasNext) { - currentRow = SessionRowInformation.of(baseIterator.next()) + if (iter.hasNext) { + currentRow = SessionRowInformation.of(iter.next()) System.err.println(s"DEBUG - filling current row... current row: $currentRow") } } @@ -135,6 +122,10 @@ class MergingSortWithMultiValuesStateIterator( currentStateIter = null if (currentRow != null && currentRow.keys != currentStateFetchedKey) { + + // This is necessary because MultiValuesStateManager doesn't guarantee stable ordering + // The number of values for the given key is expected to be likely small, + // so sorting it here doesn't hurt. 
val unsortedIter = stateManager.get(currentRow.keys) currentStateIter = unsortedIter.toList.sortWith((row1, row2) => { val rowInfo1 = SessionRowInformation.of(row1) @@ -146,9 +137,9 @@ class MergingSortWithMultiValuesStateIterator( currentStateFetchedKey = currentRow.keys if (currentStateIter.hasNext) { currentStateRow = SessionRowInformation.of(currentStateIter.next()) - System.err.println(s"DEBUG: WARN - read data ${currentStateRow.row} from state for key ${currentRow.keys}") + System.err.println(s"DEBUG: read data ${currentStateRow.row} from state for key ${currentRow.keys}") } else { - System.err.println(s"DEBUG: WARN - no state data for key ${currentRow.keys}") + System.err.println(s"DEBUG: no state data for key ${currentRow.keys}") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala index a94cbd2968e2e..648cadbafb1c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala @@ -57,6 +57,7 @@ import org.apache.spark.util.NextIterator * the predicate, delete corresponding (key, indexToDelete) from KeyWithIndexToValueStore * by overwriting with the value of (key, maxIndex), and removing [(key, maxIndex), * decrement corresponding num values in KeyToNumValuesStore + * (the operation doesn't guarantee stable ordering once value is removed) */ class MultiValuesStateManager( storeNamePrefix: String, @@ -160,8 +161,7 @@ class MultiValuesStateManager( * This implies the iterator must be consumed fully without any other operations on this manager * or the underlying store being interleaved. * - * NOTE: if any value is remove for the key, the order of values will be non-deterministic. 
- * It doesn't keep the order stable when removing value for gaining performance. + * NOTE: It doesn't keep order of values being stable when removing one for performance gain. */ def removeByValueCondition(removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRowPair] = { new NextIterator[UnsafeRowPair] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index b50824273e6f9..9687ab97959f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -438,7 +438,7 @@ case class SessionWindowStateStoreRestoreExec( extends UnaryExecNode with StateStoreReader with WatermarkSupport { // FIXME: does we really need to have global aggregation from here? - // yes... it even has a case which session is only key + // FIXME: yes... it even has a case which session is only key require(keyExpressions.nonEmpty, "Key expressions should be presented.") override protected def doExecute(): RDD[InternalRow] = { @@ -452,14 +452,7 @@ case class SessionWindowStateStoreRestoreExec( sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (stateManager, iter) => - // FIXME: would we want to break down into two multiple physical plans? 
- // - // new MergingSortWithMultiValuesStateIterator(iter, stateManager, keyExpressions, - // sessionExpression, child.output).map { row => - // numOutputRows += 1 - // row - // } - + // FIXME: remove val debugPartitionId = TaskContext.get().partitionId() val debugIter = iter.map { row => @@ -476,7 +469,19 @@ case class SessionWindowStateStoreRestoreExec( row } - val mergedIter = new MergingSortWithMultiValuesStateIterator(debugIter, stateManager, + // We need to filter out outdated inputs + val filteredIterator = watermarkPredicateForData match { + case Some(predicate) => debugIter.filter((row: InternalRow) => { + val pr = !predicate.eval(row) + if (!pr) { + logWarning(s"DEBUG - evicting input due to watermark... $row") + } + pr + }) + case None => debugIter + } + + val mergedIter = new MergingSortWithMultiValuesStateIterator(filteredIterator, stateManager, keyExpressions, sessionExpression, watermarkPredicateForData, child.output) val debugMergedIter = mergedIter.map { row => @@ -585,6 +590,7 @@ case class SessionWindowStateStoreSaveExec( while (iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] val keys = keyProjection(row) + if (!alreadyRemovedKeys.contains(keys)) { logWarning(s"DEBUG: partitionId $debugPartitionId - removing key ${keys} ...") stateManager.removeKey(keys) @@ -594,7 +600,6 @@ case class SessionWindowStateStoreSaveExec( val sessionProjection = GenerateUnsafeProjection.generate( Seq(sessionExpression), child.output) val rowProjection = GenerateUnsafeProjection.generate(child.output, child.output) - logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - keys ${keyProjection(row)}") logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - session ${sessionProjection(row)}") logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - row (proj) ${rowProjection(row)}") From 7255bcace4b8d94f5b5139d84fc0c3240d26410f Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 13 Sep 2018 17:49:01 +0900 Subject: [PATCH 10/60] 
Silly bugfix & block session window for batch query as of now We can enable it but there're lots of approaches on aggregations in batch side... * AggUtils.planAggregateWithoutDistinct * AggUtils.planAggregateWithOneDistinct * RewriteDistinctAggregates * AggregateInPandasExec So unless we are sure which things to support, just block them for now... --- .../spark/sql/catalyst/analysis/Analyzer.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e52c62bdf07a1..64ceda601abbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2650,7 +2650,7 @@ object TimeWindowing extends Rule[LogicalPlan] { }).toSet.size // Only support a single window expression for now - if (numWindowExpr == 1 && + if (numWindowExpr == 1 && windowExpressions.nonEmpty && windowExpressions.head.timeColumn.resolved && windowExpressions.head.checkInputDataTypes().isSuccess) { @@ -2718,8 +2718,8 @@ object TimeWindowing extends Rule[LogicalPlan] { renamedPlan.withNewChildren(substitutedPlan :: Nil) } } else if (numWindowExpr > 1) { - p.failAnalysis("Multiple time/session window expressions would result in a cartesian product " + - "of rows, therefore they are currently not supported.") + p.failAnalysis("Multiple time/session window expressions would result in a cartesian " + + "product of rows, therefore they are currently not supported.") } else { p // Return unchanged. 
Analyzer will throw exception later } @@ -2747,10 +2747,15 @@ object SessionWindowing extends Rule[LogicalPlan] { }).toSet.size // Only support a single session expression for now - if (numWindowExpr == 1 && + if (numWindowExpr == 1 && sessionExpressions.nonEmpty && sessionExpressions.head.timeColumn.resolved && sessionExpressions.head.checkInputDataTypes().isSuccess) { + // FIXME: where it needs to place the check? In UnsupportedOperationsSuite? + if (!p.isStreaming) { + p.failAnalysis("Session window is not supported for batch query as of now.") + } + val session = sessionExpressions.head val metadata = session.timeColumn match { From 7b57fe5d4fb1ca4a62ae9ed8d7851ffee6633fb9 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 13 Sep 2018 18:28:34 +0900 Subject: [PATCH 11/60] More works: majorly split out updating session to individual physical node * we will leverage such node for batch case if we want --- .../sql/execution/aggregate/AggUtils.scala | 12 ++- .../aggregate/UpdatingSessionExec.scala | 79 +++++++++++++++++++ .../streaming/IncrementalExecution.scala | 21 +---- .../streaming/UpdatingSessionIterator.scala | 6 +- .../streaming/statefulOperators.scala | 22 ++---- .../streaming/EventTimeWatermarkSuite.scala | 2 +- 6 files changed, 96 insertions(+), 46 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 3b2c9eb5f19bb..7de71c0b47aac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import 
org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.{SessionWindowStateStoreRestoreExec, SessionWindowStateStoreSaveExec, StateStoreRestoreExec, StateStoreSaveExec} +import org.apache.spark.sql.execution.streaming._ /** * Utility functions used by the query planner to convert our plan to new aggregation code path. @@ -115,8 +115,6 @@ object AggUtils { finalAggregate :: Nil } - // FIXME: distinct in session makes sense? - def planAggregateWithOneDistinct( groupingExpressions: Seq[NamedExpression], functionsWithDistinct: Seq[AggregateExpression], @@ -348,7 +346,9 @@ object AggUtils { * - Shuffle & Sort (distribution: keys "without" session, sort: all keys) * - SessionWindowStateStoreRestore (group: keys "without" session) * - merge input tuples with stored tuples (sessions) respecting sort order + * - UpdatingSessionExec * - calculate session among tuples, and update all tuples to get correct session range + * - NOTE: it leverages the fact that the output of SessionWindowStateStoreRestore is sorted * - PartialMerge (group: all keys) * - now there is at most 1 tuple per group * - SessionWindowStateStoreSave (group: keys "without" session) @@ -389,6 +389,10 @@ object AggUtils { val restored = SessionWindowStateStoreRestoreExec(groupingWithoutSessionAttributes, sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None, partialAggregate) + val updatedSession = UpdatingSessionExec(groupingWithoutSessionAttributes, + sessionExpression.toAttribute, optRequiredChildDistribution = None, + optRequiredChildOrdering = None, restored) + val partialMerged: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) @@ -401,7 +405,7 @@ object AggUtils { initialInputBufferOffset = groupingAttributes.length, resultExpressions = groupingAttributes ++ 
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), - child = restored) + child = updatedSession) } // Note: stateId and returnAllStates are filled in later with preparation rules diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala new file mode 100644 index 0000000000000..1167e302150c3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.TaskContext + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning} +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.streaming.UpdatingSessionIterator + +// FIXME: javadoc should provide precondition that input must be sorted +// or both required child distribution as well as required child ordering should be presented +// to guarantee input will be sorted +case class UpdatingSessionExec( + keyExpressions: Seq[Attribute], + sessionExpression: Attribute, + optRequiredChildDistribution: Option[Seq[Distribution]], + optRequiredChildOrdering: Option[Seq[Seq[SortOrder]]], + child: SparkPlan) extends UnaryExecNode { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + val newIter = new UpdatingSessionIterator(iter, keyExpressions, sessionExpression, + child.output) + + val debugIter = newIter.map { row => + val keysProjection = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val sessionProjection = GenerateUnsafeProjection.generate( + Seq(sessionExpression), child.output) + val rowProjection = GenerateUnsafeProjection.generate(child.output, child.output) + + // FIXME: remove + val debugPartitionId = TaskContext.get().partitionId() + + logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - keys ${keysProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - session ${sessionProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - row (proj) ${rowProjection(row)}") + logWarning(s"DEBUG: partitionId 
$debugPartitionId - updated session row - row ${row}") + + row + } + + debugIter + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = optRequiredChildDistribution match { + case Some(distribution) => distribution + case None => super.requiredChildDistribution + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = optRequiredChildOrdering match { + case Some(ordering) => ordering + case None => super.requiredChildOrdering + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index c86ac09644f5f..0ee41ea8fc077 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -118,7 +118,7 @@ class IncrementalExecution( Some(aggStateInfo), stateFormatVersion, child) :: Nil)) -/* + case SessionWindowStateStoreSaveExec(keys, session, None, None, None, UnaryExecNode(agg, UnaryExecNode(agg2, @@ -138,25 +138,6 @@ class IncrementalExecution( Some(aggStateInfo), Some(offsetSeqMetadata.batchWatermarkMs), child) :: Nil) :: Nil)) - */ - - case SessionWindowStateStoreSaveExec(keys, session, None, None, None, - UnaryExecNode(agg, - SessionWindowStateStoreRestoreExec(_, _, None, None, child))) => - val aggStateInfo = nextStatefulOperationStateInfo - SessionWindowStateStoreSaveExec( - keys, - session, - Some(aggStateInfo), - Some(outputMode), - Some(offsetSeqMetadata.batchWatermarkMs), - agg.withNewChildren( - SessionWindowStateStoreRestoreExec( - keys, - session, - Some(aggStateInfo), - Some(offsetSeqMetadata.batchWatermarkMs), - child) :: Nil)) case StreamingDeduplicateExec(keys, child, None, None) => StreamingDeduplicateExec( diff 
--git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala index 60b992f164c30..2bb5b3dd2d857 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, BindReferences, CreateNamedStruct, Expression, Literal, PreciseTimestampConversion, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.types.{LongType, TimestampType} @@ -47,7 +47,7 @@ class UpdatingSessionIterator( val processedKeys: mutable.HashSet[InternalRow] = new mutable.HashSet[InternalRow]() - // FIXME: data loss seen... 
one data from input and one data from state + // FIXME: check whether it can be run with such situation: empty groupWithoutSessionExpressions override def hasNext: Boolean = { assertIteratorNotCorrupted() @@ -191,8 +191,6 @@ class UpdatingSessionIterator( returnRowsIter = returnRows.iterator } - //returnRowsIter = returnRows.iterator - // FIXME: DEBUG val (rIter, tmpReturnRowsIter) = returnRowsIter.duplicate returnRowsIter = rIter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 9687ab97959f7..430251541d4b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -482,7 +482,10 @@ case class SessionWindowStateStoreRestoreExec( } val mergedIter = new MergingSortWithMultiValuesStateIterator(filteredIterator, stateManager, - keyExpressions, sessionExpression, watermarkPredicateForData, child.output) + keyExpressions, sessionExpression, watermarkPredicateForData, child.output).map { row => + numOutputRows += 1 + row + } val debugMergedIter = mergedIter.map { row => val keysProjection = GenerateUnsafeProjection.generate(keyExpressions, child.output) @@ -498,22 +501,7 @@ case class SessionWindowStateStoreRestoreExec( row } - new UpdatingSessionIterator(debugMergedIter, keyExpressions, sessionExpression, - child.output).map { row => - numOutputRows += 1 - - val keysProjection = GenerateUnsafeProjection.generate(keyExpressions, child.output) - val sessionProjection = GenerateUnsafeProjection.generate( - Seq(sessionExpression), child.output) - val rowProjection = GenerateUnsafeProjection.generate(child.output, child.output) - - logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - keys ${keysProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - 
updated session row - session ${sessionProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - row (proj) ${rowProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - row ${row}") - - row - } + debugMergedIter } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index a8d52400f59d8..33a674f87d712 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.functions.{count, window, session} +import org.apache.spark.sql.functions.{count, session, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ import org.apache.spark.util.Utils From 969859b8ea9d416f91762d88f2c223b520491937 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 13 Sep 2018 18:38:00 +0900 Subject: [PATCH 12/60] Fix a silly bug and also add check for session window against batch query --- .../sql/DataFrameTimeWindowingSuite.scala | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 2953425b1db49..12f8a730a9591 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -241,7 +241,7 @@ class DataFrameTimeWindowingSuite 
extends QueryTest with SharedSQLContext with B df.select(window($"time", "10 second"), window($"time", "15 second")).collect() } assert(e.getMessage.contains( - "Multiple time window expressions would result in a cartesian product")) + "Multiple time/session window expressions would result in a cartesian product")) } test("aliased windows") { @@ -354,4 +354,21 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B ) } } + + // TODO: we can add session window tests when session window is enabled for batch + + test("Session window in batch query throws nice exception") { + val df = Seq( + ("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id") + + val e = intercept[AnalysisException] { + df.groupBy(session($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"session.start".asc) + .select($"session.start".cast("string"), $"session.end".cast("string"), $"counts") + } + + assert(e.getMessage.contains("Session window is not supported for batch query")) + } + } From f5ecbdd62d1ec4fb405100f04794a4cd2a85802f Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 13 Sep 2018 20:30:15 +0900 Subject: [PATCH 13/60] WIP Fixed eviction on update mode --- ...gingSortWithMultiValuesStateIterator.scala | 1 - .../streaming/statefulOperators.scala | 24 +++-- ...ortWithMultiValuesStateIteratorSuite.scala | 63 +++++++++++- .../UpdatingSessionIteratorSuite.scala | 57 ++++++++++- .../streaming/EventTimeWatermarkSuite.scala | 96 ++++++++++++++++++- 5 files changed, 223 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala index e5a157b1009cf..32d1d9ba3e2f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala @@ -144,5 +144,4 @@ class MergingSortWithMultiValuesStateIterator( } } } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 430251541d4b9..4be6aeb5ade86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -437,10 +437,6 @@ case class SessionWindowStateStoreRestoreExec( child: SparkPlan) extends UnaryExecNode with StateStoreReader with WatermarkSupport { - // FIXME: does we really need to have global aggregation from here? - // FIXME: yes... it even has a case which session is only key - require(keyExpressions.nonEmpty, "Key expressions should be presented.") - override protected def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") @@ -510,7 +506,11 @@ case class SessionWindowStateStoreRestoreExec( override def outputPartitioning: Partitioning = child.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = { - ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + if (keyExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + } } override def requiredChildOrdering: Seq[Seq[SortOrder]] = { @@ -531,9 +531,6 @@ case class SessionWindowStateStoreSaveExec( child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { - // FIXME: does we really need to have global aggregation from here? 
- require(keyExpressions.nonEmpty, "Key expressions should be presented.") - override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver assert(outputMode.nonEmpty, @@ -647,7 +644,10 @@ case class SessionWindowStateStoreSaveExec( // Remove old aggregates if watermark specified allRemovalsTimeMs += timeTakenMs { - evictSessionsByWatermark(stateManager) + // fully consume iterator to let removal take effect + evictSessionsByWatermark(stateManager).map { rowPair => + System.err.println(s"DEBUG: evicting row ${rowPair.value}") + }.toList } commitTimeMs += timeTakenMs { stateManager.commit() } setStoreMetrics(stateManager) @@ -664,7 +664,11 @@ case class SessionWindowStateStoreSaveExec( override def outputPartitioning: Partitioning = child.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = { - ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + if (keyExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + } } override def requiredChildOrdering: Seq[Seq[SortOrder]] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala index c00e3ffcbeaea..ffb66326c0031 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala @@ -49,8 +49,6 @@ class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { attr => List("aggVal1", "aggVal2").contains(attr.name) } - // FIXME: add test for watermark - test("no row in input data") { withStateManager(rowAttributes, keysWithoutSessionAttributes) { manager => val iterator = new 
MergingSortWithMultiValuesStateIterator(None.iterator, @@ -149,6 +147,67 @@ class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { } } + test("no keys in input data and state") { + val noKeyRowSchema = new StructType() + .add("session", new StructType().add("start", LongType).add("end", LongType)) + .add("aggVal1", LongType).add("aggVal2", DoubleType) + val noKeyRowAttributes = noKeyRowSchema.toAttributes + + val noKeySessionAttribute = noKeyRowAttributes.filter(attr => attr.name == "session").head + + def createNoKeyRow(sessionStart: Long, sessionEnd: Long, + aggVal1: Long, aggVal2: Double): UnsafeRow = { + val genericRow = new GenericInternalRow(4) + val session: Array[Any] = new Array[Any](2) + session(0) = sessionStart + session(1) = sessionEnd + + val sessionRow = new GenericInternalRow(session) + genericRow.update(0, sessionRow) + + genericRow.setLong(1, aggVal1) + genericRow.setDouble(2, aggVal2) + + val rowProjection = GenerateUnsafeProjection.generate(noKeyRowAttributes, noKeyRowAttributes) + rowProjection(genericRow) + } + + def assertNoKeyRowsEquals(expectedRow: InternalRow, retRow: InternalRow): Unit = { + assert(retRow.getStruct(0, 2).getLong(0) == expectedRow.getStruct(0, 2).getLong(0)) + assert(retRow.getStruct(0, 2).getLong(1) == expectedRow.getStruct(0, 2).getLong(1)) + assert(retRow.getLong(1) === expectedRow.getLong(1)) + assert(doubleEquals(retRow.getDouble(2), expectedRow.getDouble(2))) + } + + def appendNoKeyRowToStateManager(manager: MultiValuesStateManager, rows: UnsafeRow*): Unit = { + rows.foreach(row => manager.append(new UnsafeRow(0), row)) + } + + withStateManager(noKeyRowAttributes, Seq.empty[Attribute]) { manager => + // only input data + val row1 = createNoKeyRow(100, 110, 10, 1.1) + val row2 = createNoKeyRow(100, 110, 20, 1.2) + + val srow1 = createNoKeyRow(55, 85, 50, 2.5) + val srow2 = createNoKeyRow(105, 140, 30, 2.0) + appendNoKeyRowToStateManager(manager, srow1, srow2) + + val rows = List(row1, row2) + + 
val expectedRowSequence = List(srow1, row1, row2, srow2) + + val iterator = new MergingSortWithMultiValuesStateIterator(rows.iterator, + manager, Seq.empty[Attribute], noKeySessionAttribute, None, noKeyRowAttributes) + + expectedRowSequence.foreach { row => + assert(iterator.hasNext) + assertNoKeyRowsEquals(row, iterator.next()) + } + + assert(!iterator.hasNext) + } + } + private def getKeyRow(row: UnsafeRow): UnsafeRow = { val keyProjection = GenerateUnsafeProjection.generate(keysWithoutSessionAttributes, rowAttributes) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala index c6a30d0f0d8e7..2471d809a183b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -291,6 +291,61 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { } } + test("no key") { + val noKeyRowSchema = new StructType() + .add("session", new StructType().add("start", LongType).add("end", LongType)) + .add("aggVal1", LongType).add("aggVal2", DoubleType) + val noKeyRowAttributes = noKeyRowSchema.toAttributes + + val noKeySessionAttribute = noKeyRowAttributes.filter(attr => attr.name == "session").head + + def createNoKeyRow(sessionStart: Long, sessionEnd: Long, + aggVal1: Long, aggVal2: Double): UnsafeRow = { + 
val genericRow = new GenericInternalRow(4) + val session: Array[Any] = new Array[Any](2) + session(0) = sessionStart + session(1) = sessionEnd + + val sessionRow = new GenericInternalRow(session) + genericRow.update(0, sessionRow) + + genericRow.setLong(1, aggVal1) + genericRow.setDouble(2, aggVal2) + + val rowProjection = GenerateUnsafeProjection.generate(noKeyRowAttributes, noKeyRowAttributes) + rowProjection(genericRow) + } + + def assertNoKeyRowsEqualsWithNewSession(expectedRow: InternalRow, retRow: InternalRow, + newSessionStart: Long, newSessionEnd: Long): Unit = { + assert(retRow.getStruct(0, 2).getLong(0) == newSessionStart) + assert(retRow.getStruct(0, 2).getLong(1) == newSessionEnd) + assert(retRow.getLong(1) === expectedRow.getLong(1)) + assert(doubleEquals(retRow.getDouble(2), expectedRow.getDouble(2))) + } + + val row1 = createNoKeyRow(100, 110, 10, 1.1) + val row2 = createNoKeyRow(100, 110, 20, 1.2) + val row3 = createNoKeyRow(105, 115, 30, 1.3) + val row4 = createNoKeyRow(113, 123, 40, 1.4) + val rows = List(row1, row2, row3, row4) + + val iterator = new UpdatingSessionIterator(rows.iterator, Seq.empty[Attribute], + noKeySessionAttribute, noKeyRowAttributes) + + val retRows = rows.indices.map { _ => + assert(iterator.hasNext) + iterator.next() + } + + retRows.zip(rows).foreach { case (retRow, expectedRow) => + // session being expanded to (100 ~ 123) + assertNoKeyRowsEqualsWithNewSession(expectedRow, retRow, 100, 123) + } + + assert(iterator.hasNext === false) + } + private def createRow(key1: String, key2: Int, sessionStart: Long, sessionEnd: Long, aggVal1: Long, aggVal2: Double): UnsafeRow = { val genericRow = new GenericInternalRow(6) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 33a674f87d712..a2746556322b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala 
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -318,6 +318,48 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche ) } + test("append mode - session - no key") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*") + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(session($"eventTime", "5 seconds") as 'session) + .agg(count("*") as 'count) + .select($"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 10, 11), // sessions: (10,16) + AssertOnQuery(execution => { + execution.explain(true) + true + }), + CheckNewAnswer(), + AddData(inputData, 17), + // Advance watermark to 7 seconds + // sessions: (10,16), (17,23) + CheckNewAnswer(), + AddData(inputData, 25), + // Advance watermark to 15 seconds + // sessions: (10,16), (17,23), (25,30) + CheckNewAnswer(), + AddData(inputData, 35), + // Advance watermark to 25 seconds + // sessions: (10,16), (17,22), (25,30), (35,40) + // evicts: (10,16), (17,22) + CheckNewAnswer((10, 16, 2), (17, 22, 1)), + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckNewAnswer(), + AddData(inputData, 40), + // Advance watermark to 30 seconds + // sessions: (25,30) / (35,45) + // evicts: (25,30) + CheckNewAnswer((25, 30, 1)) + ) + } + test("update mode - session") { val inputData = MemoryStream[Int] @@ -342,26 +384,72 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche }), AddData(inputData, 17), // Advance watermark to 7 seconds - // sessions: key 1 => (10,16), (17,23) + // sessions: key 1 => (10,16), (17,22) // FIXME: subtract with previous state? or leave it as it is? 
CheckNewAnswer((1, 10, 16, 2), (1, 17, 22, 1)), AddData(inputData, 25), // Advance watermark to 15 seconds - // sessions: key 1 => (10,20) / key 2 => (25,30) + // sessions: key 1 => (10,16), (17,22) / key 2 => (25,30) CheckNewAnswer((2, 25, 30, 1)), AddData(inputData, 35), // Advance watermark to 25 seconds - // sessions: key 1 => (10,20) / key 2 => (25,30) / key 3 => (35,40) + // sessions: key 1 => (10,16), (17,22) / key 2 => (25,30) / key 3 => (35,40) + // evicts: key 1 => (10,16), (17,22) CheckNewAnswer((3, 35, 40, 1)), AddData(inputData, 10), // Should not emit anything as data less than watermark CheckNewAnswer(), AddData(inputData, 40), // Advance watermark to 30 seconds - // sessions: key 1 => (10,20) / key 2 => (25,30) / key 3 => (35,45) + // sessions: key 2 => (25,30) / key 3 => (35,40) / key 4 => (40, 45) CheckNewAnswer((4, 40, 45, 1)) ) } + test("update mode - session - no key") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*") + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(session($"eventTime", "5 seconds") as 'session) + .agg(count("*") as 'count) + .select($"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation, OutputMode.Update())( + + AddData(inputData, 10, 11), + // Advance watermark to 1 seconds + // sessions: (10,16) + CheckNewAnswer((10, 16, 2)), + AssertOnQuery(execution => { + execution.explain(true) + true + }), + AddData(inputData, 17), + // Advance watermark to 7 seconds + // sessions: (10,16), (17,22) + // FIXME: subtract with previous state? or leave it as it is? 
+ CheckNewAnswer((10, 16, 2), (17, 22, 1)), + AddData(inputData, 25), + // Advance watermark to 15 seconds + // sessions: (10,16), (17,22), (25,30) + CheckNewAnswer((10, 16, 2), (17, 22, 1), (25, 30, 1)), + AddData(inputData, 35), + // Advance watermark to 25 seconds + // sessions: (10,16), (17,22), (25,30), (35,40) + // evicts: (10,16), (17,22) + CheckNewAnswer((10, 16, 2), (17, 22, 1), (25, 30, 1), (35, 40, 1)), + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckNewAnswer(), + AddData(inputData, 40), + // Advance watermark to 30 seconds + // sessions: (25,30), (35,45) + CheckNewAnswer((25, 30, 1), (35, 45, 2)) + ) + } + test("update mode") { val inputData = MemoryStream[Int] spark.conf.set("spark.sql.shuffle.partitions", "10") From b180772ad70481b0bd9736bed6182561aa82344b Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 13 Sep 2018 21:48:07 +0900 Subject: [PATCH 14/60] WIP found root reason of broken UT... fixed it --- .../streaming/UpdatingSessionIterator.scala | 48 +++++++++++-------- .../streaming/EventTimeWatermarkSuite.scala | 46 +++++++++--------- 2 files changed, 50 insertions(+), 44 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala index 2bb5b3dd2d857..b9fbd606b5639 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala @@ -77,7 +77,10 @@ class UpdatingSessionIterator( var exitCondition = false while (iter.hasNext && !exitCondition) { - val row = iter.next() + // we are going to modify the row, so we should make sure multiple objects are not + // referencing same memory, which could be possible when optimizing iterator + // without this, multiple rows in same key will be returned with same 
content + val row = iter.next().copy() val keysProjection = GenerateUnsafeProjection.generate(groupWithoutSessionExpressions, inputSchema) @@ -103,6 +106,9 @@ class UpdatingSessionIterator( // expanding session length if needed expandEndOfCurrentSession(sessionEnd) currentRows += row + + System.err.println(s"DEBUG: - adding row: $row / currentRows: $currentRows") + } else { closeCurrentSession(keyChanged = false) startNewSession(row, keys, sessionStart, sessionEnd) @@ -159,29 +165,29 @@ class UpdatingSessionIterator( s"currentRows: $currentRows / currentSessionStart: $currentSessionStart / " + s"currentSessionEnd: $currentSessionEnd") - val returnRows = currentRows.map { internalRow => - val sessionStruct = CreateNamedStruct( - Literal("start") :: - PreciseTimestampConversion( - Literal(currentSessionStart, LongType), LongType, TimestampType) :: - Literal("end") :: - PreciseTimestampConversion( - Literal(currentSessionEnd, LongType), LongType, TimestampType) :: - Nil) - - val convertedAllExpressions = inputSchema.map { x => - BindReferences.bindReference[Expression](x, inputSchema) - } + val sessionStruct = CreateNamedStruct( + Literal("start") :: + PreciseTimestampConversion( + Literal(currentSessionStart, LongType), LongType, TimestampType) :: + Literal("end") :: + PreciseTimestampConversion( + Literal(currentSessionEnd, LongType), LongType, TimestampType) :: + Nil) + + val convertedAllExpressions = inputSchema.map { x => + BindReferences.bindReference[Expression](x, inputSchema) + } - val newSchemaExpressions = convertedAllExpressions.indices.map { idx => - if (idx == sessionIndex) { - sessionStruct - } else { - convertedAllExpressions(idx) - } + val newSchemaExpressions = convertedAllExpressions.indices.map { idx => + if (idx == sessionIndex) { + sessionStruct + } else { + convertedAllExpressions(idx) } + } - val proj = GenerateUnsafeProjection.generate(newSchemaExpressions, inputSchema) + val returnRows = currentRows.map { internalRow => + val proj = 
UnsafeProjection.create(newSchemaExpressions, inputSchema) proj(internalRow) }.toList diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index a2746556322b6..cda3c5de6fe2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.functions.{count, session, window} +import org.apache.spark.sql.functions.{count, session, sum, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ import org.apache.spark.util.Utils @@ -284,9 +284,9 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche .withColumn("eventTime", $"value".cast("timestamp")) .withWatermark("eventTime", "10 seconds") .groupBy(session($"eventTime", "5 seconds") as 'session, 'valuegroup) - .agg(count("*") as 'count) + .agg(count("*") as 'count, sum("value") as 'sum) .select($"valuegroup", $"session".getField("start").cast("long").as[Long], - $"session".getField("end").cast("long").as[Long], $"count".as[Long]) + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) testStream(windowedAggregation)( AddData(inputData, 10, 11), // sessions: key 1 => (10,16) @@ -307,14 +307,14 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche // Advance watermark to 25 seconds // sessions: key 1 => (10,16), (17,22) / key 2 => (25,30) / key 3 => (35,40) // evicts: key 1 => (10,16), (17,22) - CheckNewAnswer((1, 10, 16, 2), (1, 17, 22, 1)), + 
CheckNewAnswer((1, 10, 16, 2, 21), (1, 17, 22, 1, 17)), AddData(inputData, 10), // Should not emit anything as data less than watermark CheckNewAnswer(), AddData(inputData, 40), // Advance watermark to 30 seconds // sessions: key 2 => (25,30) / key 3 => (35,45) // evicts: key 2 => (25,30) - CheckNewAnswer((2, 25, 30, 1)) + CheckNewAnswer((2, 25, 30, 1, 25)) ) } @@ -326,9 +326,9 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche .withColumn("eventTime", $"value".cast("timestamp")) .withWatermark("eventTime", "10 seconds") .groupBy(session($"eventTime", "5 seconds") as 'session) - .agg(count("*") as 'count) + .agg(count("*") as 'count, sum("value") as 'sum) .select($"session".getField("start").cast("long").as[Long], - $"session".getField("end").cast("long").as[Long], $"count".as[Long]) + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) testStream(windowedAggregation)( AddData(inputData, 10, 11), // sessions: (10,16) @@ -349,14 +349,14 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche // Advance watermark to 25 seconds // sessions: (10,16), (17,22), (25,30), (35,40) // evicts: (10,16), (17,22) - CheckNewAnswer((10, 16, 2), (17, 22, 1)), + CheckNewAnswer((10, 16, 2, 21), (17, 22, 1, 17)), AddData(inputData, 10), // Should not emit anything as data less than watermark CheckNewAnswer(), AddData(inputData, 40), // Advance watermark to 30 seconds // sessions: (25,30) / (35,45) // evicts: (25,30) - CheckNewAnswer((25, 30, 1)) + CheckNewAnswer((25, 30, 1, 25)) ) } @@ -368,16 +368,16 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche .withColumn("eventTime", $"value".cast("timestamp")) .withWatermark("eventTime", "10 seconds") .groupBy(session($"eventTime", "5 seconds") as 'session, 'valuegroup) - .agg(count("*") as 'count) + .agg(count("*") as 'count, sum("value") as 'sum) .select($"valuegroup", 
$"session".getField("start").cast("long").as[Long], - $"session".getField("end").cast("long").as[Long], $"count".as[Long]) + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) testStream(windowedAggregation, OutputMode.Update())( AddData(inputData, 10, 11), // Advance watermark to 1 seconds // sessions: key 1 => (10,16) - CheckNewAnswer((1, 10, 16, 2)), + CheckNewAnswer((1, 10, 16, 2, 21)), AssertOnQuery(execution => { execution.explain(true) true @@ -386,22 +386,22 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche // Advance watermark to 7 seconds // sessions: key 1 => (10,16), (17,22) // FIXME: subtract with previous state? or leave it as it is? - CheckNewAnswer((1, 10, 16, 2), (1, 17, 22, 1)), + CheckNewAnswer((1, 10, 16, 2, 21), (1, 17, 22, 1, 17)), AddData(inputData, 25), // Advance watermark to 15 seconds // sessions: key 1 => (10,16), (17,22) / key 2 => (25,30) - CheckNewAnswer((2, 25, 30, 1)), + CheckNewAnswer((2, 25, 30, 1, 25)), AddData(inputData, 35), // Advance watermark to 25 seconds // sessions: key 1 => (10,16), (17,22) / key 2 => (25,30) / key 3 => (35,40) // evicts: key 1 => (10,16), (17,22) - CheckNewAnswer((3, 35, 40, 1)), + CheckNewAnswer((3, 35, 40, 1, 35)), AddData(inputData, 10), // Should not emit anything as data less than watermark CheckNewAnswer(), AddData(inputData, 40), // Advance watermark to 30 seconds // sessions: key 2 => (25,30) / key 3 => (35,40) / key 4 => (40, 45) - CheckNewAnswer((4, 40, 45, 1)) + CheckNewAnswer((4, 40, 45, 1, 40)) ) } @@ -413,16 +413,16 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche .withColumn("eventTime", $"value".cast("timestamp")) .withWatermark("eventTime", "10 seconds") .groupBy(session($"eventTime", "5 seconds") as 'session) - .agg(count("*") as 'count) + .agg(count("*") as 'count, sum("value") as 'sum) .select($"session".getField("start").cast("long").as[Long], - 
$"session".getField("end").cast("long").as[Long], $"count".as[Long]) + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) testStream(windowedAggregation, OutputMode.Update())( AddData(inputData, 10, 11), // Advance watermark to 1 seconds // sessions: (10,16) - CheckNewAnswer((10, 16, 2)), + CheckNewAnswer((10, 16, 2, 21)), AssertOnQuery(execution => { execution.explain(true) true @@ -431,22 +431,22 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche // Advance watermark to 7 seconds // sessions: (10,16), (17,22) // FIXME: subtract with previous state? or leave it as it is? - CheckNewAnswer((10, 16, 2), (17, 22, 1)), + CheckNewAnswer((10, 16, 2, 21), (17, 22, 1, 17)), AddData(inputData, 25), // Advance watermark to 15 seconds // sessions: (10,16), (17,22), (25,30) - CheckNewAnswer((10, 16, 2), (17, 22, 1), (25, 30, 1)), + CheckNewAnswer((10, 16, 2, 21), (17, 22, 1, 17), (25, 30, 1, 25)), AddData(inputData, 35), // Advance watermark to 25 seconds // sessions: (10,16), (17,22), (25,30), (35,40) // evicts: (10,16), (17,22) - CheckNewAnswer((10, 16, 2), (17, 22, 1), (25, 30, 1), (35, 40, 1)), + CheckNewAnswer((10, 16, 2, 21), (17, 22, 1, 17), (25, 30, 1, 25), (35, 40, 1, 35)), AddData(inputData, 10), // Should not emit anything as data less than watermark CheckNewAnswer(), AddData(inputData, 40), // Advance watermark to 30 seconds // sessions: (25,30), (35,45) - CheckNewAnswer((25, 30, 1), (35, 45, 2)) + CheckNewAnswer((25, 30, 1, 25), (35, 45, 2, 75)) ) } From 0a2d7312ca98d6b562d7046b3c8751e5387de088 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 13 Sep 2018 21:50:31 +0900 Subject: [PATCH 15/60] WIP remove printing "explain" on UTs --- .../sql/streaming/EventTimeWatermarkSuite.scala | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index cda3c5de6fe2b..fb45a39c39ed0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -290,10 +290,6 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(windowedAggregation)( AddData(inputData, 10, 11), // sessions: key 1 => (10,16) - AssertOnQuery(execution => { - execution.explain(true) - true - }), CheckNewAnswer(), AddData(inputData, 17), // Advance watermark to 7 seconds @@ -332,10 +328,6 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(windowedAggregation)( AddData(inputData, 10, 11), // sessions: (10,16) - AssertOnQuery(execution => { - execution.explain(true) - true - }), CheckNewAnswer(), AddData(inputData, 17), // Advance watermark to 7 seconds @@ -373,15 +365,10 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) testStream(windowedAggregation, OutputMode.Update())( - AddData(inputData, 10, 11), // Advance watermark to 1 seconds // sessions: key 1 => (10,16) CheckNewAnswer((1, 10, 16, 2, 21)), - AssertOnQuery(execution => { - execution.explain(true) - true - }), AddData(inputData, 17), // Advance watermark to 7 seconds // sessions: key 1 => (10,16), (17,22) @@ -423,10 +410,6 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche // Advance watermark to 1 seconds // sessions: (10,16) CheckNewAnswer((10, 16, 2, 21)), - AssertOnQuery(execution => { - execution.explain(true) - true - }), AddData(inputData, 17), // Advance watermark to 7 seconds // sessions: (10,16), (17,22) From 6ee901ec88e6b68f3c3070f73d5dd1475b03eaf8 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 13 Sep 2018 23:43:32 +0900 
Subject: [PATCH 16/60] WIP address session to batch query (+ python) as well... not having tests for some aggregations * distinct * two distincts * pandas --- python/pyspark/sql/functions.py | 14 ++ .../sql/catalyst/analysis/Analyzer.scala | 14 +- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../catalyst/expressions/SessionWindow.scala | 2 + .../spark/sql/execution/SparkStrategies.scala | 67 ++++-- .../sql/execution/aggregate/AggUtils.scala | 218 ++++++++++++++++++ .../python/AggregateInPandasExec.scala | 36 ++- .../sql/DataFrameSessionWindowingSuite.scala | 169 ++++++++++++++ .../sql/DataFrameTimeWindowingSuite.scala | 17 -- 9 files changed, 490 insertions(+), 48 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5425d311f8c7f..fd89b4834cbc6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1396,6 +1396,20 @@ def check_string_field(field, fieldName): res = sc._jvm.functions.window(time_col, windowDuration) return Column(res) +@since(3.0) +@ignore_unicode_prefix +def session(timeColumn, gapDuration): + # FIXME: python doc!! 
+ def check_string_field(field, fieldName): + if not field or type(field) is not str: + raise TypeError("%s should be provided as a string" % fieldName) + + sc = SparkContext._active_spark_context + time_col = _to_java_column(timeColumn) + check_string_field(gapDuration, "gapDuration") + res = sc._jvm.functions.session(time_col, gapDuration) + return Column(res) + # ---------------------------- misc functions ---------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 64ceda601abbd..e3c3765fcd334 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2751,11 +2751,6 @@ object SessionWindowing extends Rule[LogicalPlan] { sessionExpressions.head.timeColumn.resolved && sessionExpressions.head.checkInputDataTypes().isSuccess) { - // FIXME: where it needs to place the check? In UnsupportedOperationsSuite? 
- if (!p.isStreaming) { - p.failAnalysis("Session window is not supported for batch query as of now.") - } - val session = sessionExpressions.head val metadata = session.timeColumn match { @@ -2763,8 +2758,13 @@ object SessionWindowing extends Rule[LogicalPlan] { case _ => Metadata.empty } + val newMetadata = new MetadataBuilder() + .withMetadata(metadata) + .putBoolean(SessionWindow.marker, true) + .build() + val sessionAttr = AttributeReference( - SESSION_COL_NAME, session.dataType, metadata = metadata)() + SESSION_COL_NAME, session.dataType, metadata = newMetadata)() val sessionStart = PreciseTimestampConversion(session.timeColumn, TimestampType, LongType) val sessionEnd = sessionStart + session.gapDuration @@ -2777,7 +2777,7 @@ object SessionWindowing extends Rule[LogicalPlan] { Nil) val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)( - exprId = sessionAttr.exprId, explicitMetadata = Some(metadata)) + exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata)) val replacedPlan = p transformExpressions { case s: SessionWindow => sessionAttr diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 7dafebff79874..36d1fff4d834b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -398,6 +398,7 @@ object FunctionRegistry { expression[WeekOfYear]("weekofyear"), expression[Year]("year"), expression[TimeWindow]("window"), + expression[SessionWindow]("session"), // collection functions expression[CreateArray]("array"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala index 75c9e663d5db7..54c516a93c18d 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala @@ -60,6 +60,8 @@ case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends Unar } object SessionWindow { + val marker = "spark.sessionWindow" + /** * Parses the interval string for a valid time duration. CalendarInterval expects interval * strings to start with the string `interval`. For usability, we prepend `interval` to the string diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0dc84ed640a68..72132f197ac0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -330,7 +330,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } val sessionWindowOption = namedGroupingExpressions.find { p => - p.name == "session" && p.dataType.isInstanceOf[StructType] + p.metadata.contains(SessionWindow.marker) } sessionWindowOption match { @@ -429,30 +429,63 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "Spark user mailing list.") } - val aggregateOperator = - if (functionsWithDistinct.isEmpty) { - aggregate.AggUtils.planAggregateWithoutDistinct( - groupingExpressions, - aggregateExpressions, - resultExpressions, - planLater(child)) - } else { - aggregate.AggUtils.planAggregateWithOneDistinct( - groupingExpressions, - functionsWithDistinct, - functionsWithoutDistinct, - resultExpressions, - planLater(child)) - } + val sessionWindowOption = groupingExpressions.find { p => + p.metadata.contains(SessionWindow.marker) + } + + sessionWindowOption match { + case Some(sessionWindow) => + val aggregateOperator = + if (functionsWithDistinct.isEmpty) { + 
aggregate.AggUtils.planSessionAggregateWithoutDistinct( + groupingExpressions, + sessionWindow, + aggregateExpressions, + resultExpressions, + planLater(child)) + } else { + aggregate.AggUtils.planSessionAggregateWithOneDistinct( + groupingExpressions, + sessionWindow, + functionsWithDistinct, + functionsWithoutDistinct, + resultExpressions, + planLater(child)) + } + + aggregateOperator - aggregateOperator + case None => + val aggregateOperator = + if (functionsWithDistinct.isEmpty) { + aggregate.AggUtils.planAggregateWithoutDistinct( + groupingExpressions, + aggregateExpressions, + resultExpressions, + planLater(child)) + } else { + aggregate.AggUtils.planAggregateWithOneDistinct( + groupingExpressions, + functionsWithDistinct, + functionsWithoutDistinct, + resultExpressions, + planLater(child)) + } + + aggregateOperator + } case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) if aggExpressions.forall(expr => expr.isInstanceOf[PythonUDF]) => val udfExpressions = aggExpressions.map(expr => expr.asInstanceOf[PythonUDF]) + val sessionWindowOption = groupingExpressions.find { p => + p.metadata.contains(SessionWindow.marker) + } + Seq(execution.python.AggregateInPandasExec( groupingExpressions, + sessionWindowOption, udfExpressions, resultExpressions, planLater(child))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 7de71c0b47aac..6c6bba29c72e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution} import 
org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming._ @@ -245,6 +246,223 @@ object AggUtils { finalAndCompleteAggregate :: Nil } + def planSessionAggregateWithoutDistinct( + groupingExpressions: Seq[NamedExpression], + sessionExpression: NamedExpression, + aggregateExpressions: Seq[AggregateExpression], + resultExpressions: Seq[NamedExpression], + child: SparkPlan): Seq[SparkPlan] = { + // Check if we can use HashAggregate. + + val groupWithoutSessionExpression = groupingExpressions.filterNot { p => + p.semanticEquals(sessionExpression) + } + + val groupingWithoutSessionAttributes = groupWithoutSessionExpression.map(_.toAttribute) + + // 1. Create an Aggregate Operator for partial aggregations. + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) + val partialAggregateAttributes = + partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + val partialResultExpressions = + groupingAttributes ++ + partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + + val partialAggregate = createAggregate( + requiredChildDistributionExpressions = None, + groupingExpressions = groupingExpressions, + aggregateExpressions = partialAggregateExpressions, + aggregateAttributes = partialAggregateAttributes, + initialInputBufferOffset = 0, + resultExpressions = partialResultExpressions, + child = child) + + val childDistribution = if (groupWithoutSessionExpression.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupWithoutSessionExpression) :: Nil + } + + val childOrdering = Seq((groupingWithoutSessionAttributes ++ Seq(sessionExpression)) + .map(SortOrder(_, Ascending))) + + val updatedSession = UpdatingSessionExec(groupingWithoutSessionAttributes, + sessionExpression.toAttribute, + optRequiredChildDistribution = Some(childDistribution), + optRequiredChildOrdering = Some(childOrdering), + 
partialAggregate) + + // 2. Create an Aggregate Operator for final aggregations. + val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) + + val finalAggregate = createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions, + aggregateAttributes = finalAggregateAttributes, + initialInputBufferOffset = groupingExpressions.length, + resultExpressions = resultExpressions, + child = updatedSession) + + finalAggregate :: Nil + } + + def planSessionAggregateWithOneDistinct( + groupingExpressions: Seq[NamedExpression], + sessionExpression: NamedExpression, + functionsWithDistinct: Seq[AggregateExpression], + functionsWithoutDistinct: Seq[AggregateExpression], + resultExpressions: Seq[NamedExpression], + child: SparkPlan): Seq[SparkPlan] = { + + val groupWithoutSessionExpression = groupingExpressions.filterNot { p => + p.semanticEquals(sessionExpression) + } + + val groupingWithoutSessionAttributes = groupWithoutSessionExpression.map(_.toAttribute) + + // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one + // DISTINCT aggregate function, all of those functions will have the same column expressions. + // For example, it would be valid for functionsWithDistinct to be + // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is + // disallowed because those two distinct aggregates have different column expressions. 
+ val distinctExpressions = functionsWithDistinct.head.aggregateFunction.children + val namedDistinctExpressions = distinctExpressions.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + val distinctAttributes = namedDistinctExpressions.map(_.toAttribute) + val groupingAttributes = groupingExpressions.map(_.toAttribute) + + // 1. Create an Aggregate Operator for partial aggregations. + val partialAggregate: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + // We will group by the original grouping expression, plus an additional expression for the + // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping + // expressions will be [key, value]. + createAggregate( + groupingExpressions = groupingExpressions ++ namedDistinctExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + resultExpressions = groupingAttributes ++ distinctAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = child) + } + + // 2. Create an Aggregate Operator for partial merge aggregations. + val partialMergeAggregate: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes ++ distinctAttributes), + groupingExpressions = groupingAttributes ++ distinctAttributes, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, + resultExpressions = groupingAttributes ++ distinctAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate) + } + + // 3. 
Create an Aggregate operator for partial aggregation (for distinct) + val distinctColumnAttributeLookup = distinctExpressions.zip(distinctAttributes).toMap + val rewrittenDistinctFunctions = functionsWithDistinct.map { + // Children of an AggregateFunction with DISTINCT keyword has already + // been evaluated. At here, we need to replace original children + // to AttributeReferences. + case agg @ AggregateExpression(aggregateFunction, mode, true, _) => + aggregateFunction.transformDown(distinctColumnAttributeLookup) + .asInstanceOf[AggregateFunction] + case agg => + throw new IllegalArgumentException( + "Non-distinct aggregate is found in functionsWithDistinct " + + s"at planAggregateWithOneDistinct: $agg") + } + + val partialDistinctAggregate: SparkPlan = { + val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val mergeAggregateAttributes = mergeAggregateExpressions.map(_.resultAttribute) + val (distinctAggregateExpressions, distinctAggregateAttributes) = + rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => + // We rewrite the aggregate function to a non-distinct aggregation because + // its input will have distinct arguments. + // We just keep the isDistinct setting to true, so when users look at the query plan, + // they still can see distinct aggregations. 
+ val expr = AggregateExpression(func, Partial, isDistinct = true) + // Use original AggregationFunction to lookup attributes, which is used to build + // aggregateFunctionToAttribute + val attr = functionsWithDistinct(i).resultAttribute + (expr, attr) + }.unzip + + val partialAggregateResult = groupingAttributes ++ + mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++ + distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + createAggregate( + groupingExpressions = groupingAttributes, + aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions, + aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes, + initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, + resultExpressions = partialAggregateResult, + child = partialMergeAggregate) + } + + val childDistribution = if (groupWithoutSessionExpression.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupWithoutSessionExpression) :: Nil + } + + val childOrdering = Seq((groupingWithoutSessionAttributes ++ Seq(sessionExpression)) + .map(SortOrder(_, Ascending))) + + val updatedSession = UpdatingSessionExec(groupingWithoutSessionAttributes, + sessionExpression.toAttribute, + optRequiredChildDistribution = Some(childDistribution), + optRequiredChildOrdering = Some(childOrdering), + partialAggregate) + + // 4. Create an Aggregate Operator for the final aggregation. 
+ val finalAndCompleteAggregate: SparkPlan = { + val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) + + val (distinctAggregateExpressions, distinctAggregateAttributes) = + rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => + // We rewrite the aggregate function to a non-distinct aggregation because + // its input will have distinct arguments. + // We just keep the isDistinct setting to true, so when users look at the query plan, + // they still can see distinct aggregations. + val expr = AggregateExpression(func, Final, isDistinct = true) + // Use original AggregationFunction to lookup attributes, which is used to build + // aggregateFunctionToAttribute + val attr = functionsWithDistinct(i).resultAttribute + (expr, attr) + }.unzip + + createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions ++ distinctAggregateExpressions, + aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = resultExpressions, + child = updatedSession) + } + + finalAndCompleteAggregate :: Nil + } + /** * Plans a streaming aggregation using the following progression: * - Partial Aggregation diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 2ab7240556aaa..4a23a37e02152 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -29,6 +29,7 @@ import 
org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.arrow.ArrowUtils +import org.apache.spark.sql.execution.streaming.UpdatingSessionIterator import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils @@ -42,6 +43,7 @@ import org.apache.spark.util.Utils */ case class AggregateInPandasExec( groupingExpressions: Seq[NamedExpression], + optSessionExpression: Option[NamedExpression], udfExpressions: Seq[PythonUDF], resultExpressions: Seq[NamedExpression], child: SparkPlan) @@ -53,14 +55,29 @@ case class AggregateInPandasExec( override def producedAttributes: AttributeSet = AttributeSet(output) + val groupingWithoutSessionExpressions = optSessionExpression match { + case Some(sessionExpression) => + groupingExpressions.filterNot { p => p.semanticEquals(sessionExpression) } + + case None => groupingExpressions + } + override def requiredChildDistribution: Seq[Distribution] = { - if (groupingExpressions.isEmpty) { + if (groupingWithoutSessionExpressions.isEmpty) { AllTuples :: Nil } else { - ClusteredDistribution(groupingExpressions) :: Nil + ClusteredDistribution(groupingWithoutSessionExpressions) :: Nil } } + override def requiredChildOrdering: Seq[Seq[SortOrder]] = optSessionExpression match { + case Some(sessionExpression) => + Seq((groupingWithoutSessionExpressions ++ Seq(sessionExpression)) + .map(SortOrder(_, Ascending))) + + case None => Seq(groupingExpressions.map(SortOrder(_, Ascending))) + } + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { udf.children match { case Seq(u: PythonUDF) => @@ -73,9 +90,6 @@ case class AggregateInPandasExec( } } - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(groupingExpressions.map(SortOrder(_, 
Ascending))) - override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() @@ -106,13 +120,21 @@ case class AggregateInPandasExec( }) inputRDD.mapPartitionsInternal { iter => + val newIter = optSessionExpression match { + case Some(sessionExpression) => + new UpdatingSessionIterator(iter, groupingWithoutSessionExpressions, sessionExpression, + child.output) + + case None => iter + } + val prunedProj = UnsafeProjection.create(allInputs, child.output) val grouped = if (groupingExpressions.isEmpty) { // Use an empty unsafe row as a place holder for the grouping key - Iterator((new UnsafeRow(), iter)) + Iterator((new UnsafeRow(), newIter)) } else { - GroupedIterator(iter, groupingExpressions, child.output) + GroupedIterator(newIter, groupingExpressions, child.output) }.map { case (key, rows) => (key, rows.map(prunedProj)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala new file mode 100644 index 0000000000000..37fa28d0905d3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.sql.catalyst.plans.logical.Expand +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StringType + +class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext + with BeforeAndAfterEach { + + import testImplicits._ + + // FIXME: session window + + test("simple session window with record at window start") { + val df = Seq( + ("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(session($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"session.start".asc) + .select($"session.start".cast("string"), $"session.end".cast("string"), $"counts"), + Seq( + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1) + ) + ) + } + + test("session window groupBy statement") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sorting while applying group by + + checkAnswer( + df.groupBy(session($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"session.start".asc) + .select("counts"), + Seq(Row(2), Row(1)) + ) + } + + test("session window with multi-column projection") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + .select(session($"time", "10 seconds"), $"value") + .orderBy($"session.start".asc) + .select($"session.start".cast("string"), $"session.end".cast("string"), $"value") + + val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand]) + assert(expands.isEmpty, "Session windows shouldn't require expand") + + checkAnswer( + df, + Seq( + Row("2016-03-27 19:39:27", 
"2016-03-27 19:39:37", 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:44", 1), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:06", 2) + ) + ) + } + + test("session window combined with explode expression") { + val df = Seq( + ("2016-03-27 19:39:34", 1, Seq("a", "b")), + ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids") + + checkAnswer( + df.select(session($"time", "10 seconds"), $"value", explode($"ids")) + .orderBy($"session.start".asc).select("value"), + // first window exploded to two rows for "a", and "b", second window exploded to 3 rows + Seq(Row(1), Row(1), Row(2), Row(2), Row(2)) + ) + } + + test("null timestamps") { + val df = Seq( + ("2016-03-27 09:00:05", 1), + ("2016-03-27 09:00:32", 2), + (null, 3), + (null, 4)).toDF("time", "value") + + checkDataset( + df.select(session($"time", "10 seconds"), $"value") + .orderBy($"session.start".asc) + .select("value") + .as[Int], + 1, 2) // null columns are dropped + } + + // NOTE: unlike time window, joining session windows without grouping + // doesn't arrange session, so two rows will be joined only if session range is exactly same + + test("multiple session windows in a single operator throws nice exception") { + val df = Seq( + ("2016-03-27 09:00:02", 3), + ("2016-03-27 09:00:35", 6)).toDF("time", "value") + val e = intercept[AnalysisException] { + df.select(session($"time", "10 second"), session($"time", "15 second")).collect() + } + assert(e.getMessage.contains( + "Multiple time/session window expressions would result in a cartesian product")) + } + + test("aliased session windows") { + val df = Seq( + ("2016-03-27 19:39:34", 1, Seq("a", "b")), + ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids") + + checkAnswer( + df.select(session($"time", "10 seconds").as("session_window"), $"value") + .orderBy($"session_window.start".asc) + .select("value"), + Seq(Row(1), Row(2)) + ) + } + + private def withTempTable(f: String => Unit): Unit = { + val tableName = 
"temp" + Seq( + ("2016-03-27 19:39:34", 1), + ("2016-03-27 19:39:56", 2), + ("2016-03-27 19:39:27", 4)).toDF("time", "value").createOrReplaceTempView(tableName) + try { + f(tableName) + } finally { + spark.catalog.dropTempView(tableName) + } + } + + test("time window in SQL with single string expression") { + withTempTable { table => + checkAnswer( + spark.sql(s"""select session(time, "10 seconds"), value from $table""") + .select($"session.start".cast(StringType), $"session.end".cast(StringType), $"value"), + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:44", 1), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:06", 2) + ) + ) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 12f8a730a9591..f1be46bf759c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -354,21 +354,4 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B ) } } - - // TODO: we can add session window tests when session window is enabled for batch - - test("Session window in batch query throws nice exception") { - val df = Seq( - ("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id") - - val e = intercept[AnalysisException] { - df.groupBy(session($"time", "10 seconds")) - .agg(count("*").as("counts")) - .orderBy($"session.start".asc) - .select($"session.start".cast("string"), $"session.end".cast("string"), $"counts") - } - - assert(e.getMessage.contains("Session window is not supported for batch query")) - } - } From 395606bb08ab2dfc6c20e2d2f2b9708b73974aad Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 13 Sep 2018 23:59:39 +0900 Subject: [PATCH 17/60] WIP add more test on session batch query --- 
.../sql/DataFrameSessionWindowingSuite.scala | 35 +++++++++++++++++-- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala index 37fa28d0905d3..ef816aff212fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala @@ -29,8 +29,6 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext import testImplicits._ - // FIXME: session window - test("simple session window with record at window start") { val df = Seq( ("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id") @@ -52,7 +50,8 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext ("2016-03-27 19:39:56", 2, "a"), ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") - // session window handles sorting while applying group by + // session window handles sort while applying group by + // whereas time window doesn't checkAnswer( df.groupBy(session($"time", "10 seconds")) @@ -63,6 +62,36 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext ) } + test("session window groupBy with multiple keys statement") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:39", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:40:04", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14) + // key "b" => (19:39:27 ~ 19:39:37) + + checkAnswer( + df.groupBy(session($"time", "10 seconds"), 'id) + .agg(count("*").as("counts"), sum("value").as("sum")) + .orderBy($"session.start".asc) + .selectExpr("CAST(session.start AS STRING)", 
"CAST(session.end AS STRING)", "id", + "counts", "sum"), + + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 2, 2), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:14", "a", 2, 4) + ) + ) + } + test("session window with multi-column projection") { val df = Seq( ("2016-03-27 19:39:34", 1, "a"), From f6bb34d919780aa72e1bffce4dd41ba98cbb4a89 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 14 Sep 2018 14:13:32 +0900 Subject: [PATCH 18/60] WIP add UT for sessions with keys overlapped --- .../sql/DataFrameSessionWindowingSuite.scala | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala index ef816aff212fb..11930e67a6e41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala @@ -92,6 +92,36 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext ) } + test("session window groupBy with multiple keys statement - keys overlapped with sessions") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:39", 1, "b"), + ("2016-03-27 19:39:40", 2, "a"), + ("2016-03-27 19:39:45", 2, "b"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // a => (19:39:34 ~ 19:39:50) + // b => (19:39:27 ~ 19:39:37), (19:39:39 ~ 19:39:55) + + checkAnswer( + df.groupBy(session($"time", "10 seconds"), 'id) + .agg(count("*").as("counts"), sum("value").as("sum")) + .orderBy($"session.start".asc) + .selectExpr("CAST(session.start AS STRING)", "CAST(session.end AS STRING)", "id", + "counts", "sum"), + + Seq( + Row("2016-03-27 19:39:27", 
"2016-03-27 19:39:37", "b", 1, 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:50", "a", 2, 3), + Row("2016-03-27 19:39:39", "2016-03-27 19:39:55", "b", 2, 3) + ) + ) + } + test("session window with multi-column projection") { val df = Seq( ("2016-03-27 19:39:34", 1, "a"), From 22fffd2ae078a799ee82271e55210ff41026170a Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 14 Sep 2018 14:49:47 +0900 Subject: [PATCH 19/60] WIP refactor a bit --- .../spark/sql/execution/SparkStrategies.scala | 65 ++--- .../sql/execution/aggregate/AggUtils.scala | 258 +++--------------- .../python/AggregateInPandasExec.scala | 30 +- 3 files changed, 76 insertions(+), 277 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 72132f197ac0b..7268d2bcc07f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -429,63 +429,30 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "Spark user mailing list.") } - val sessionWindowOption = groupingExpressions.find { p => - p.metadata.contains(SessionWindow.marker) - } - - sessionWindowOption match { - case Some(sessionWindow) => - val aggregateOperator = - if (functionsWithDistinct.isEmpty) { - aggregate.AggUtils.planSessionAggregateWithoutDistinct( - groupingExpressions, - sessionWindow, - aggregateExpressions, - resultExpressions, - planLater(child)) - } else { - aggregate.AggUtils.planSessionAggregateWithOneDistinct( - groupingExpressions, - sessionWindow, - functionsWithDistinct, - functionsWithoutDistinct, - resultExpressions, - planLater(child)) - } - - aggregateOperator + val aggregateOperator = + if (functionsWithDistinct.isEmpty) { + aggregate.AggUtils.planAggregateWithoutDistinct( + groupingExpressions, + aggregateExpressions, + resultExpressions, + planLater(child)) 
+ } else { + aggregate.AggUtils.planAggregateWithOneDistinct( + groupingExpressions, + functionsWithDistinct, + functionsWithoutDistinct, + resultExpressions, + planLater(child)) + } - case None => - val aggregateOperator = - if (functionsWithDistinct.isEmpty) { - aggregate.AggUtils.planAggregateWithoutDistinct( - groupingExpressions, - aggregateExpressions, - resultExpressions, - planLater(child)) - } else { - aggregate.AggUtils.planAggregateWithOneDistinct( - groupingExpressions, - functionsWithDistinct, - functionsWithoutDistinct, - resultExpressions, - planLater(child)) - } - - aggregateOperator - } + aggregateOperator case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) if aggExpressions.forall(expr => expr.isInstanceOf[PythonUDF]) => val udfExpressions = aggExpressions.map(expr => expr.asInstanceOf[PythonUDF]) - val sessionWindowOption = groupingExpressions.find { p => - p.metadata.contains(SessionWindow.marker) - } - Seq(execution.python.AggregateInPandasExec( groupingExpressions, - sessionWindowOption, udfExpressions, resultExpressions, planLater(child))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 6c6bba29c72e9..abd81c46f477f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -98,200 +98,7 @@ object AggUtils { resultExpressions = partialResultExpressions, child = child) - // 2. Create an Aggregate Operator for final aggregations. 
- val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) - // The attributes of the final aggregation buffer, which is presented as input to the result - // projection: - val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) - - val finalAggregate = createAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - aggregateExpressions = finalAggregateExpressions, - aggregateAttributes = finalAggregateAttributes, - initialInputBufferOffset = groupingExpressions.length, - resultExpressions = resultExpressions, - child = partialAggregate) - - finalAggregate :: Nil - } - - def planAggregateWithOneDistinct( - groupingExpressions: Seq[NamedExpression], - functionsWithDistinct: Seq[AggregateExpression], - functionsWithoutDistinct: Seq[AggregateExpression], - resultExpressions: Seq[NamedExpression], - child: SparkPlan): Seq[SparkPlan] = { - - // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one - // DISTINCT aggregate function, all of those functions will have the same column expressions. - // For example, it would be valid for functionsWithDistinct to be - // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is - // disallowed because those two distinct aggregates have different column expressions. - val distinctExpressions = functionsWithDistinct.head.aggregateFunction.children - val namedDistinctExpressions = distinctExpressions.map { - case ne: NamedExpression => ne - case other => Alias(other, other.toString)() - } - val distinctAttributes = namedDistinctExpressions.map(_.toAttribute) - val groupingAttributes = groupingExpressions.map(_.toAttribute) - - // 1. Create an Aggregate Operator for partial aggregations. 
- val partialAggregate: SparkPlan = { - val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) - val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) - // We will group by the original grouping expression, plus an additional expression for the - // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping - // expressions will be [key, value]. - createAggregate( - groupingExpressions = groupingExpressions ++ namedDistinctExpressions, - aggregateExpressions = aggregateExpressions, - aggregateAttributes = aggregateAttributes, - resultExpressions = groupingAttributes ++ distinctAttributes ++ - aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), - child = child) - } - - // 2. Create an Aggregate Operator for partial merge aggregations. - val partialMergeAggregate: SparkPlan = { - val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) - val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) - createAggregate( - requiredChildDistributionExpressions = - Some(groupingAttributes ++ distinctAttributes), - groupingExpressions = groupingAttributes ++ distinctAttributes, - aggregateExpressions = aggregateExpressions, - aggregateAttributes = aggregateAttributes, - initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, - resultExpressions = groupingAttributes ++ distinctAttributes ++ - aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), - child = partialAggregate) - } - - // 3. Create an Aggregate operator for partial aggregation (for distinct) - val distinctColumnAttributeLookup = distinctExpressions.zip(distinctAttributes).toMap - val rewrittenDistinctFunctions = functionsWithDistinct.map { - // Children of an AggregateFunction with DISTINCT keyword has already - // been evaluated. At here, we need to replace original children - // to AttributeReferences. 
- case agg @ AggregateExpression(aggregateFunction, mode, true, _) => - aggregateFunction.transformDown(distinctColumnAttributeLookup) - .asInstanceOf[AggregateFunction] - case agg => - throw new IllegalArgumentException( - "Non-distinct aggregate is found in functionsWithDistinct " + - s"at planAggregateWithOneDistinct: $agg") - } - - val partialDistinctAggregate: SparkPlan = { - val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) - // The attributes of the final aggregation buffer, which is presented as input to the result - // projection: - val mergeAggregateAttributes = mergeAggregateExpressions.map(_.resultAttribute) - val (distinctAggregateExpressions, distinctAggregateAttributes) = - rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => - // We rewrite the aggregate function to a non-distinct aggregation because - // its input will have distinct arguments. - // We just keep the isDistinct setting to true, so when users look at the query plan, - // they still can see distinct aggregations. - val expr = AggregateExpression(func, Partial, isDistinct = true) - // Use original AggregationFunction to lookup attributes, which is used to build - // aggregateFunctionToAttribute - val attr = functionsWithDistinct(i).resultAttribute - (expr, attr) - }.unzip - - val partialAggregateResult = groupingAttributes ++ - mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++ - distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - createAggregate( - groupingExpressions = groupingAttributes, - aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions, - aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes, - initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, - resultExpressions = partialAggregateResult, - child = partialMergeAggregate) - } - - // 4. 
Create an Aggregate Operator for the final aggregation. - val finalAndCompleteAggregate: SparkPlan = { - val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) - // The attributes of the final aggregation buffer, which is presented as input to the result - // projection: - val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) - - val (distinctAggregateExpressions, distinctAggregateAttributes) = - rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => - // We rewrite the aggregate function to a non-distinct aggregation because - // its input will have distinct arguments. - // We just keep the isDistinct setting to true, so when users look at the query plan, - // they still can see distinct aggregations. - val expr = AggregateExpression(func, Final, isDistinct = true) - // Use original AggregationFunction to lookup attributes, which is used to build - // aggregateFunctionToAttribute - val attr = functionsWithDistinct(i).resultAttribute - (expr, attr) - }.unzip - - createAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - aggregateExpressions = finalAggregateExpressions ++ distinctAggregateExpressions, - aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes, - initialInputBufferOffset = groupingAttributes.length, - resultExpressions = resultExpressions, - child = partialDistinctAggregate) - } - - finalAndCompleteAggregate :: Nil - } - - def planSessionAggregateWithoutDistinct( - groupingExpressions: Seq[NamedExpression], - sessionExpression: NamedExpression, - aggregateExpressions: Seq[AggregateExpression], - resultExpressions: Seq[NamedExpression], - child: SparkPlan): Seq[SparkPlan] = { - // Check if we can use HashAggregate. 
- - val groupWithoutSessionExpression = groupingExpressions.filterNot { p => - p.semanticEquals(sessionExpression) - } - - val groupingWithoutSessionAttributes = groupWithoutSessionExpression.map(_.toAttribute) - - // 1. Create an Aggregate Operator for partial aggregations. - val groupingAttributes = groupingExpressions.map(_.toAttribute) - val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) - val partialAggregateAttributes = - partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - val partialResultExpressions = - groupingAttributes ++ - partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - - val partialAggregate = createAggregate( - requiredChildDistributionExpressions = None, - groupingExpressions = groupingExpressions, - aggregateExpressions = partialAggregateExpressions, - aggregateAttributes = partialAggregateAttributes, - initialInputBufferOffset = 0, - resultExpressions = partialResultExpressions, - child = child) - - val childDistribution = if (groupWithoutSessionExpression.isEmpty) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupWithoutSessionExpression) :: Nil - } - - val childOrdering = Seq((groupingWithoutSessionAttributes ++ Seq(sessionExpression)) - .map(SortOrder(_, Ascending))) - - val updatedSession = UpdatingSessionExec(groupingWithoutSessionAttributes, - sessionExpression.toAttribute, - optRequiredChildDistribution = Some(childDistribution), - optRequiredChildOrdering = Some(childOrdering), - partialAggregate) + val interExec: SparkPlan = mayAppendUpdatingSessionExec(groupingExpressions, partialAggregate) // 2. Create an Aggregate Operator for final aggregations. 
val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) @@ -306,25 +113,18 @@ object AggUtils { aggregateAttributes = finalAggregateAttributes, initialInputBufferOffset = groupingExpressions.length, resultExpressions = resultExpressions, - child = updatedSession) + child = interExec) finalAggregate :: Nil } - def planSessionAggregateWithOneDistinct( + def planAggregateWithOneDistinct( groupingExpressions: Seq[NamedExpression], - sessionExpression: NamedExpression, functionsWithDistinct: Seq[AggregateExpression], functionsWithoutDistinct: Seq[AggregateExpression], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { - val groupWithoutSessionExpression = groupingExpressions.filterNot { p => - p.semanticEquals(sessionExpression) - } - - val groupingWithoutSessionAttributes = groupWithoutSessionExpression.map(_.toAttribute) - // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one // DISTINCT aggregate function, all of those functions will have the same column expressions. // For example, it would be valid for functionsWithDistinct to be @@ -415,20 +215,8 @@ object AggUtils { child = partialMergeAggregate) } - val childDistribution = if (groupWithoutSessionExpression.isEmpty) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupWithoutSessionExpression) :: Nil - } - - val childOrdering = Seq((groupingWithoutSessionAttributes ++ Seq(sessionExpression)) - .map(SortOrder(_, Ascending))) - - val updatedSession = UpdatingSessionExec(groupingWithoutSessionAttributes, - sessionExpression.toAttribute, - optRequiredChildDistribution = Some(childDistribution), - optRequiredChildOrdering = Some(childOrdering), - partialAggregate) + val interExec: SparkPlan = mayAppendUpdatingSessionExec(groupingExpressions, + partialDistinctAggregate) // 4. Create an Aggregate Operator for the final aggregation. 
val finalAndCompleteAggregate: SparkPlan = { @@ -457,7 +245,7 @@ object AggUtils { aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes, initialInputBufferOffset = groupingAttributes.length, resultExpressions = resultExpressions, - child = updatedSession) + child = interExec) } finalAndCompleteAggregate :: Nil @@ -655,4 +443,38 @@ object AggUtils { finalAndCompleteAggregate :: Nil } + + private def mayAppendUpdatingSessionExec(groupingExpressions: Seq[NamedExpression], + maybeChildPlan: SparkPlan): SparkPlan = { + val interExec = groupingExpressions.find(_.metadata.contains(SessionWindow.marker)) match { + case Some(sessionExpression) => + val groupWithoutSessionExpression = groupingExpressions.filterNot { p => + p.semanticEquals(sessionExpression) + } + + val groupingWithoutSessionAttributes = groupWithoutSessionExpression.map(_.toAttribute) + + val childDistribution = if (groupWithoutSessionExpression.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupWithoutSessionExpression) :: Nil + } + + val childOrdering = Seq((groupingWithoutSessionAttributes ++ Seq(sessionExpression)) + .map(SortOrder(_, Ascending))) + + val updatedSession = UpdatingSessionExec(groupingWithoutSessionAttributes, + sessionExpression.toAttribute, + optRequiredChildDistribution = Some(childDistribution), + optRequiredChildOrdering = Some(childOrdering), + maybeChildPlan) + + updatedSession + + case None => maybeChildPlan + } + + interExec + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 4a23a37e02152..38a39e2252685 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -43,7 +43,6 @@ import org.apache.spark.util.Utils */ case class 
AggregateInPandasExec( groupingExpressions: Seq[NamedExpression], - optSessionExpression: Option[NamedExpression], udfExpressions: Seq[PythonUDF], resultExpressions: Seq[NamedExpression], child: SparkPlan) @@ -55,7 +54,11 @@ case class AggregateInPandasExec( override def producedAttributes: AttributeSet = AttributeSet(output) - val groupingWithoutSessionExpressions = optSessionExpression match { + val sessionWindowOption = groupingExpressions.find { p => + p.metadata.contains(SessionWindow.marker) + } + + val groupingWithoutSessionExpressions = sessionWindowOption match { case Some(sessionExpression) => groupingExpressions.filterNot { p => p.semanticEquals(sessionExpression) } @@ -70,7 +73,7 @@ case class AggregateInPandasExec( } } - override def requiredChildOrdering: Seq[Seq[SortOrder]] = optSessionExpression match { + override def requiredChildOrdering: Seq[Seq[SortOrder]] = sessionWindowOption match { case Some(sessionExpression) => Seq((groupingWithoutSessionExpressions ++ Seq(sessionExpression)) .map(SortOrder(_, Ascending))) @@ -120,13 +123,7 @@ case class AggregateInPandasExec( }) inputRDD.mapPartitionsInternal { iter => - val newIter = optSessionExpression match { - case Some(sessionExpression) => - new UpdatingSessionIterator(iter, groupingWithoutSessionExpressions, sessionExpression, - child.output) - - case None => iter - } + val newIter: Iterator[InternalRow] = mayAppendUpdatingSessionIterator(iter) val prunedProj = UnsafeProjection.create(allInputs, child.output) @@ -175,4 +172,17 @@ case class AggregateInPandasExec( } } } + + private def mayAppendUpdatingSessionIterator(iter: Iterator[InternalRow]) + : Iterator[InternalRow] = { + val newIter = sessionWindowOption match { + case Some(sessionExpression) => + new UpdatingSessionIterator(iter, groupingWithoutSessionExpressions, sessionExpression, + child.output) + + case None => iter + } + + newIter + } } From 847f69efe7b29ab5f318e3992c06494e02129f36 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 
14 Sep 2018 15:22:45 +0900 Subject: [PATCH 20/60] WIP add more FIXMEs for javadoc, and remove invalid FIXMEs --- ...gingSortWithMultiValuesStateIterator.scala | 1 + .../streaming/UpdatingSessionIterator.scala | 2 - .../state/MultiValuesStateManager.scala | 2 - .../state/MultiValuesStateStoreRDD.scala | 1 + .../streaming/statefulOperators.scala | 39 ++++++++++++------- 5 files changed, 27 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala index 32d1d9ba3e2f9..3a0cc32cca034 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeR import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} import org.apache.spark.sql.execution.streaming.state.MultiValuesStateManager +// FIXME: javadoc!! 
class MergingSortWithMultiValuesStateIterator( iter: Iterator[InternalRow], stateManager: MultiValuesStateManager, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala index b9fbd606b5639..51a5b329fdeb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala @@ -47,8 +47,6 @@ class UpdatingSessionIterator( val processedKeys: mutable.HashSet[InternalRow] = new mutable.HashSet[InternalRow]() - // FIXME: check whether it can be run with such situation: empty groupWithoutSessionExpressions - override def hasNext: Boolean = { assertIteratorNotCorrupted() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala index 648cadbafb1c6..1abf330b3b1cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala @@ -315,8 +315,6 @@ class MultiValuesStateManager( def commit(): Unit = { stateStore.commit() - // FIXME: DEBUG - logInfo("Committed, metrics = " + stateStore.metrics) logDebug("Committed, metrics = " + stateStore.metrics) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateStoreRDD.scala index 7d03e3662a6d0..652b463d85b72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateStoreRDD.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateStoreRDD.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration +// FIXME: javadoc!! class MultiValuesStateStoreRDD[T: ClassTag, U: ClassTag]( dataRDD: RDD[T], storeUpdateFunction: (MultiValuesStateManager, Iterator[T]) => Iterator[U], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 4be6aeb5ade86..8ab72c1a9da86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -430,13 +430,15 @@ case class StateStoreSaveExec( // FIXME: javadoc! // FIXME: keyExpressions shouldn't have 'session': otherwise we should exclude it... 
case class SessionWindowStateStoreRestoreExec( - keyExpressions: Seq[Attribute], + keyWithoutSessionExpressions: Seq[Attribute], sessionExpression: Attribute, stateInfo: Option[StatefulOperatorStateInfo], eventTimeWatermark: Option[Long], child: SparkPlan) extends UnaryExecNode with StateStoreReader with WatermarkSupport { + override def keyExpressions: Seq[Attribute] = keyWithoutSessionExpressions + override protected def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") @@ -452,7 +454,8 @@ case class SessionWindowStateStoreRestoreExec( val debugPartitionId = TaskContext.get().partitionId() val debugIter = iter.map { row => - val keysProjection = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val keysProjection = GenerateUnsafeProjection.generate(keyWithoutSessionExpressions, + child.output) val sessionProjection = GenerateUnsafeProjection.generate( Seq(sessionExpression), child.output) val rowProjection = GenerateUnsafeProjection.generate(child.output, child.output) @@ -477,14 +480,20 @@ case class SessionWindowStateStoreRestoreExec( case None => debugIter } - val mergedIter = new MergingSortWithMultiValuesStateIterator(filteredIterator, stateManager, - keyExpressions, sessionExpression, watermarkPredicateForData, child.output).map { row => + val mergedIter = new MergingSortWithMultiValuesStateIterator( + filteredIterator, + stateManager, + keyWithoutSessionExpressions, + sessionExpression, + watermarkPredicateForData, + child.output).map { row => numOutputRows += 1 row } val debugMergedIter = mergedIter.map { row => - val keysProjection = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val keysProjection = GenerateUnsafeProjection.generate(keyWithoutSessionExpressions, + child.output) val sessionProjection = GenerateUnsafeProjection.generate( Seq(sessionExpression), child.output) val rowProjection = GenerateUnsafeProjection.generate(child.output, child.output) @@ -506,15 +515,15 @@ case class 
SessionWindowStateStoreRestoreExec( override def outputPartitioning: Partitioning = child.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = { - if (keyExpressions.isEmpty) { + if (keyWithoutSessionExpressions.isEmpty) { AllTuples :: Nil } else { - ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil } } override def requiredChildOrdering: Seq[Seq[SortOrder]] = { - Seq((keyExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending))) + Seq((keyWithoutSessionExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending))) } } @@ -523,7 +532,7 @@ case class SessionWindowStateStoreRestoreExec( * the [[MultiValuesStateManager]]. */ case class SessionWindowStateStoreSaveExec( - keyExpressions: Seq[Attribute], + keyWithoutSessionExpressions: Seq[Attribute], sessionExpression: Attribute, stateInfo: Option[StatefulOperatorStateInfo] = None, outputMode: Option[OutputMode] = None, @@ -531,6 +540,8 @@ case class SessionWindowStateStoreSaveExec( child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { + override def keyExpressions: Seq[Attribute] = keyWithoutSessionExpressions + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver assert(outputMode.nonEmpty, @@ -538,7 +549,7 @@ case class SessionWindowStateStoreSaveExec( child.execute().mapPartitionsWithMultiValuesStateManager( getStateInfo, - keyExpressions.toStructType, + keyWithoutSessionExpressions.toStructType, child.output.toStructType, indexOrdinal = None, sqlContext.sessionState, @@ -558,7 +569,7 @@ case class SessionWindowStateStoreSaveExec( val allRemovalsTimeMs = longMetric("allRemovalsTimeMs") val commitTimeMs = longMetric("commitTimeMs") - val keyProjection = GenerateUnsafeProjection.generate(keyExpressions, + val keyProjection = GenerateUnsafeProjection.generate(keyWithoutSessionExpressions, 
child.output) val alreadyRemovedKeys = new mutable.HashSet[UnsafeRow]() @@ -664,15 +675,15 @@ case class SessionWindowStateStoreSaveExec( override def outputPartitioning: Partitioning = child.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = { - if (keyExpressions.isEmpty) { + if (keyWithoutSessionExpressions.isEmpty) { AllTuples :: Nil } else { - ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil } } override def requiredChildOrdering: Seq[Seq[SortOrder]] = { - Seq((keyExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending))) + Seq((keyWithoutSessionExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending))) } override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { From 104df13bc90191d12a5300de70e6fdb5eef86293 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 17 Sep 2018 13:41:56 +0900 Subject: [PATCH 21/60] WIP Repackage & remove unnecessary field --- .../sql/execution/aggregate/UpdatingSessionExec.scala | 1 - .../UpdatingSessionIterator.scala | 2 +- .../sql/execution/python/AggregateInPandasExec.scala | 3 ++- .../MergingSortWithMultiValuesStateIterator.scala | 1 - .../sql/execution/streaming/statefulOperators.scala | 1 - .../MergingSortWithMultiValuesStateIteratorSuite.scala | 10 +++++----- .../streaming/UpdatingSessionIteratorSuite.scala | 1 + 7 files changed, 9 insertions(+), 10 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{streaming => aggregate}/UpdatingSessionIterator.scala (99%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala index 1167e302150c3..7375cc77e4e4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning} import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} -import org.apache.spark.sql.execution.streaming.UpdatingSessionIterator // FIXME: javadoc should provide precondition that input must be sorted // or both required child distribution as well as required child ordering should be presented diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala index 51a5b329fdeb9..f45b5c0783d35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.aggregate import scala.collection.mutable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 38a39e2252685..9286635a17c9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -22,14 +22,15 @@ import java.io.File import scala.collection.mutable.ArrayBuffer import org.apache.spark.{SparkEnv, TaskContext} + import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.aggregate.UpdatingSessionIterator import org.apache.spark.sql.execution.arrow.ArrowUtils -import org.apache.spark.sql.execution.streaming.UpdatingSessionIterator import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala index 3a0cc32cca034..c8b53ba85f68c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala @@ -28,7 +28,6 @@ class MergingSortWithMultiValuesStateIterator( 
stateManager: MultiValuesStateManager, groupWithoutSessionExpressions: Seq[Expression], sessionExpression: Expression, - watermarkPredicateForData: Option[Predicate], inputSchema: Seq[Attribute]) extends Iterator[InternalRow] { private case class SessionRowInformation(keys: UnsafeRow, sessionStart: Long, sessionEnd: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 8ab72c1a9da86..6bfa1f5fdda80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -485,7 +485,6 @@ case class SessionWindowStateStoreRestoreExec( stateManager, keyWithoutSessionExpressions, sessionExpression, - watermarkPredicateForData, child.output).map { row => numOutputRows += 1 row diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala index ffb66326c0031..929385322f977 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala @@ -52,7 +52,7 @@ class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { test("no row in input data") { withStateManager(rowAttributes, keysWithoutSessionAttributes) { manager => val iterator = new MergingSortWithMultiValuesStateIterator(None.iterator, - manager, keysWithoutSessionAttributes, sessionAttribute, None, rowAttributes) + manager, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) assert(!iterator.hasNext) } @@ -65,7 +65,7 @@ class 
MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { appendRowToStateManager(manager, srow11, srow12) val iterator = new MergingSortWithMultiValuesStateIterator(None.iterator, - manager, keysWithoutSessionAttributes, sessionAttribute, None, rowAttributes) + manager, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) assert(!iterator.hasNext) } @@ -80,7 +80,7 @@ class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { val rows = List(row1, row2, row3, row4) val iterator = new MergingSortWithMultiValuesStateIterator(rows.iterator, - manager, keysWithoutSessionAttributes, sessionAttribute, None, rowAttributes) + manager, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) rows.foreach { row => assert(iterator.hasNext) @@ -136,7 +136,7 @@ class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { row31, row32, row41, row42, srow41, srow42, srow51, row51, srow52, row52, srow53) val iterator = new MergingSortWithMultiValuesStateIterator(rows.iterator, - manager, keysWithoutSessionAttributes, sessionAttribute, None, rowAttributes) + manager, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) expectedRowSequence.foreach { row => assert(iterator.hasNext) @@ -197,7 +197,7 @@ class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { val expectedRowSequence = List(srow1, row1, row2, srow2) val iterator = new MergingSortWithMultiValuesStateIterator(rows.iterator, - manager, Seq.empty[Attribute], noKeySessionAttribute, None, noKeyRowAttributes) + manager, Seq.empty[Attribute], noKeySessionAttribute, noKeyRowAttributes) expectedRowSequence.foreach { row => assert(iterator.hasNext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala index 2471d809a183b..ea0fa651d2f40 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.aggregate.UpdatingSessionIterator import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String From 8108fc5100db1e7f998ce782aaa3784d2b9b3faf Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 17 Sep 2018 15:12:50 +0900 Subject: [PATCH 22/60] WIP addressed UPDATE mode, but it doesn't look performant --- .../streaming/statefulOperators.scala | 86 +++++++++++++++---- .../streaming/EventTimeWatermarkSuite.scala | 76 ++++++++++++++-- 2 files changed, 136 insertions(+), 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 6bfa1f5fdda80..80687be7fef49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjecti import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.execution._ import
org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.streaming.state._ @@ -572,6 +573,9 @@ case class SessionWindowStateStoreSaveExec( child.output) val alreadyRemovedKeys = new mutable.HashSet[UnsafeRow]() + var previousSessions: Seq[UnsafeRow] = null + + val ordering = TypeUtils.getInterpretedOrdering(child.output.toStructType) // FIXME: DEBUG val debugPartitionId = TaskContext.get().partitionId() @@ -587,22 +591,35 @@ case class SessionWindowStateStoreSaveExec( val keys = keyProjection(row) if (!alreadyRemovedKeys.contains(keys)) { - logWarning(s"DEBUG: partitionId $debugPartitionId - removing key ${keys} ...") + // This is necessary because MultiValuesStateManager doesn't guarantee + // stable ordering. + // The number of values for the given key is expected to be likely small, + // so sorting it here doesn't hurt. + previousSessions = stateManager.get(keys).toList + stateManager.removeKey(keys) alreadyRemovedKeys.add(keys) } - val sessionProjection = GenerateUnsafeProjection.generate( - Seq(sessionExpression), child.output) - val rowProjection = GenerateUnsafeProjection.generate(child.output, child.output) - logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - keys ${keyProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - session ${sessionProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - row (proj) ${rowProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - row ${row}") - - logWarning(s"DEBUG: partitionId $debugPartitionId - adding row for key ${keys} row ${row} ...") - stateManager.append(keys, row) - numUpdatedStateRows += 1 + if (!previousSessions.exists(p => ordering.equiv(row, p))) { + // such session is not in previous session + + val sessionProjection = GenerateUnsafeProjection.generate( + Seq(sessionExpression), child.output) + val rowProjection = GenerateUnsafeProjection.generate(child.output, 
child.output) + logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - keys ${keyProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - session ${sessionProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - row (proj) ${rowProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - row ${row}") + + logWarning(s"DEBUG: partitionId $debugPartitionId - adding row for key ${keys} row ${row} ...") + + stateManager.append(keys, row) + numUpdatedStateRows += 1 + } else { + logWarning(s"DEBUG: partitionId $debugPartitionId - add row for key ${keys} row ${row} but don't update count since it is already existed...") + stateManager.append(keys, row) + } } logWarning(s"DEBUG: partitionId $debugPartitionId - finished iterating...") @@ -632,20 +649,53 @@ case class SessionWindowStateStoreSaveExec( private val updatesStartTimeNs = System.nanoTime override protected def getNext(): InternalRow = { - if (iter.hasNext) { + var ret: InternalRow = null + + while (ret == null && iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] val keys = keyProjection(row) if (!alreadyRemovedKeys.contains(keys)) { + // This is necessary because MultiValuesStateManager doesn't guarantee + // stable ordering. + // The number of values for the given key is expected to be likely small, + // so sorting it here doesn't hurt. 
+ previousSessions = stateManager.get(keys).toList + stateManager.removeKey(keys) alreadyRemovedKeys.add(keys) } - stateManager.append(keys, row) - numOutputRows += 1 - numUpdatedStateRows += 1 - row - } else { + + if (!previousSessions.exists(p => ordering.equiv(row, p))) { + // such session is not in previous session + + val sessionProjection = GenerateUnsafeProjection.generate( + Seq(sessionExpression), child.output) + val rowProjection = GenerateUnsafeProjection.generate(child.output, child.output) + logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - keys ${keyProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - session ${sessionProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - row (proj) ${rowProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - row ${row}") + + logWarning(s"DEBUG: partitionId $debugPartitionId - adding row for key ${keys} row ${row} ...") + + stateManager.append(keys, row) + numUpdatedStateRows += 1 + + ret = row + } else { + logWarning(s"DEBUG: partitionId $debugPartitionId - add row for key ${keys} row ${row} but don't update count since it is already existed...") + stateManager.append(keys, row) + } + } + + if (ret == null && !iter.hasNext) { finished = true null + } else { + // !iter.hasNext && ret != null => can return ret, and next getNext() call will + // set finished = true + // iter.hasNext && (ret != null || ret == null) => not possible + ret } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index fb45a39c39ed0..cb68299b9b98b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -371,16 +371,18 @@ class EventTimeWatermarkSuite extends StreamTest 
with BeforeAndAfter with Matche CheckNewAnswer((1, 10, 16, 2, 21)), AddData(inputData, 17), // Advance watermark to 7 seconds - // sessions: key 1 => (10,16), (17,22) - // FIXME: subtract with previous state? or leave it as it is? - CheckNewAnswer((1, 10, 16, 2, 21), (1, 17, 22, 1, 17)), + // sessions: key 1 => (10,16), (17,22) <- updated + // updated: key 1 => (17,22) + CheckNewAnswer((1, 17, 22, 1, 17)), AddData(inputData, 25), // Advance watermark to 15 seconds // sessions: key 1 => (10,16), (17,22) / key 2 => (25,30) + // updated: key 2 => (25,30) CheckNewAnswer((2, 25, 30, 1, 25)), AddData(inputData, 35), // Advance watermark to 25 seconds // sessions: key 1 => (10,16), (17,22) / key 2 => (25,30) / key 3 => (35,40) + // updated: key 3 => (35,40) // evicts: key 1 => (10,16), (17,22) CheckNewAnswer((3, 35, 40, 1, 35)), AddData(inputData, 10), // Should not emit anything as data less than watermark @@ -388,10 +390,65 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche AddData(inputData, 40), // Advance watermark to 30 seconds // sessions: key 2 => (25,30) / key 3 => (35,40) / key 4 => (40, 45) + // updated: key 4 => (40,45) CheckNewAnswer((4, 40, 45, 1, 40)) ) } + test("update mode - session - keys overlapped with sessions") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*", "CAST(MOD(value, 2) AS INT) AS valuegroup") + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(session($"eventTime", "5 seconds") as 'session, 'valuegroup) + .agg(count("*") as 'count, sum("value") as 'sum) + .select($"valuegroup", $"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + + testStream(windowedAggregation, OutputMode.Update())( + AddData(inputData, 10, 11, 12, 13), + // Advance watermark to 3 seconds + // sessions: key 0 => (10,17) / key 1 => (11, 18) + 
CheckNewAnswer((0, 10, 17, 2, 22), (1, 11, 18, 2, 24)), + + AddData(inputData, 17), + // Advance watermark to 7 seconds + // sessions: key 0 => (10,17) / key 1 => (11, 22) + // updated: key 1 => (11,22) + CheckNewAnswer((1, 11, 22, 3, 41)), + + AddData(inputData, 25), + // Advance watermark to 15 seconds + // sessions: key 0 => (10,17) / key 1 => (11,22), (25,30) + // updated: key 1 => (25,30) + CheckNewAnswer((1, 25, 30, 1, 25)), + + AddData(inputData, 35), + // Advance watermark to 25 seconds + // sessions: key 0 => (10,17) / key 1 => (11,22), (25,30), (35,40) + // updated: key 1 => (35,40) + // evicts: key 1 => (10,17), (11,22) + CheckNewAnswer((1, 35, 40, 1, 35)), + + AddData(inputData, 27), + // don't advance watermark + // sessions: key 1 => (25,32), (35,40) + // updated: key 1 => (25,32) + CheckNewAnswer((1, 25, 32, 2, 52)), + + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckNewAnswer(), + + AddData(inputData, 40), + // Advance watermark to 30 seconds + // sessions: key 0 => (40,45) / key 1 => (25,32), (35,40) + // updated: key 0 => (40,45) + CheckNewAnswer((0, 40, 45, 1, 40)) + ) + } + test("update mode - session - no key") { val inputData = MemoryStream[Int] @@ -413,23 +470,26 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche AddData(inputData, 17), // Advance watermark to 7 seconds // sessions: (10,16), (17,22) - // FIXME: subtract with previous state? or leave it as it is? 
- CheckNewAnswer((10, 16, 2, 21), (17, 22, 1, 17)), + // updated: (17,22) + CheckNewAnswer((17, 22, 1, 17)), AddData(inputData, 25), // Advance watermark to 15 seconds // sessions: (10,16), (17,22), (25,30) - CheckNewAnswer((10, 16, 2, 21), (17, 22, 1, 17), (25, 30, 1, 25)), + // updated: (25,30) + CheckNewAnswer((25, 30, 1, 25)), AddData(inputData, 35), // Advance watermark to 25 seconds // sessions: (10,16), (17,22), (25,30), (35,40) + // updated: (35, 40) // evicts: (10,16), (17,22) - CheckNewAnswer((10, 16, 2, 21), (17, 22, 1, 17), (25, 30, 1, 25), (35, 40, 1, 35)), + CheckNewAnswer((35, 40, 1, 35)), AddData(inputData, 10), // Should not emit anything as data less than watermark CheckNewAnswer(), AddData(inputData, 40), // Advance watermark to 30 seconds // sessions: (25,30), (35,45) - CheckNewAnswer((25, 30, 1, 25), (35, 45, 2, 75)) + // updated: (35, 45) + CheckNewAnswer((35, 45, 2, 75)) ) } From 35c8fef49fc2f3e75ae0fcad83d4e88907c684b0 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 17 Sep 2018 15:57:19 +0900 Subject: [PATCH 23/60] WIP remove FIXME since it is not relevant --- .../spark/sql/execution/streaming/statefulOperators.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 80687be7fef49..d0f9eced2640f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -642,9 +642,6 @@ case class SessionWindowStateStoreSaveExec( // Update and output modified rows from the MultiValuesStateManager. 
case Some(Update) => - // FIXME: it doesn't compare all output rows with current state rows, so all sessions - // including previous sessions will be provided - new NextIterator[InternalRow] { private val updatesStartTimeNs = System.nanoTime From 6b1d1e0407f52b63c19ece05555df3f008c1499a Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 17 Sep 2018 17:50:29 +0900 Subject: [PATCH 24/60] WIP update numOutputRows for Append mode --- .../apache/spark/sql/execution/streaming/statefulOperators.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index d0f9eced2640f..d54ce9f010aaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -629,6 +629,7 @@ case class SessionWindowStateStoreSaveExec( val retIter = evictSessionsByWatermark(stateManager).map(_.value).map { row => logWarning(s"DEBUG: partitionId $debugPartitionId - evicting row ${row} ...") + numOutputRows += 1 row } From 86b3060ef30e7af59a77d522613db7d063c4649b Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 18 Sep 2018 12:35:18 +0900 Subject: [PATCH 25/60] WIP apply aggregations when merging sessions --- .../sql/execution/aggregate/AggUtils.scala | 103 ++++---- .../aggregate/MergingSessionsExec.scala | 114 ++++++++ .../aggregate/MergingSessionsIterator.scala | 249 ++++++++++++++++++ .../aggregate/UpdatingSessionExec.scala | 78 ------ .../streaming/IncrementalExecution.scala | 16 +- .../streaming/statefulOperators.scala | 4 + .../sql/DataFrameSessionWindowingSuite.scala | 14 +- .../streaming/EventTimeWatermarkSuite.scala | 20 ++ 8 files changed, 454 insertions(+), 144 deletions(-) create mode 100644 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index abd81c46f477f..299e1841bb8b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming._ @@ -98,7 +97,8 @@ object AggUtils { resultExpressions = partialResultExpressions, child = child) - val interExec: SparkPlan = mayAppendUpdatingSessionExec(groupingExpressions, partialAggregate) + val interExec: SparkPlan = mayAppendMergingSessionExec(groupingExpressions, + aggregateExpressions, partialAggregate) // 2. Create an Aggregate Operator for final aggregations. val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) @@ -118,6 +118,8 @@ object AggUtils { finalAggregate :: Nil } + // FIXME: now I'm not sure sessionization works with distinct (don't have imagination) + // FIXME: maybe it is not easy to add sessionization with distinct? 
def planAggregateWithOneDistinct( groupingExpressions: Seq[NamedExpression], functionsWithDistinct: Seq[AggregateExpression], @@ -206,17 +208,19 @@ object AggUtils { val partialAggregateResult = groupingAttributes ++ mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++ distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - createAggregate( + + val partialDistinctAggregate = createAggregate( groupingExpressions = groupingAttributes, aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions, aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes, initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, resultExpressions = partialAggregateResult, child = partialMergeAggregate) - } - val interExec: SparkPlan = mayAppendUpdatingSessionExec(groupingExpressions, - partialDistinctAggregate) + mayAppendMergingSessionExec(groupingExpressions, + mergeAggregateExpressions ++ distinctAggregateExpressions, + partialDistinctAggregate) + } // 4. Create an Aggregate Operator for the final aggregation. 
val finalAndCompleteAggregate: SparkPlan = { @@ -245,7 +249,7 @@ object AggUtils { aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes, initialInputBufferOffset = groupingAttributes.length, resultExpressions = resultExpressions, - child = interExec) + child = partialDistinctAggregate) } finalAndCompleteAggregate :: Nil @@ -352,11 +356,10 @@ object AggUtils { * - Shuffle & Sort (distribution: keys "without" session, sort: all keys) * - SessionWindowStateStoreRestore (group: keys "without" session) * - merge input tuples with stored tuples (sessions) respecting sort order - * - UpdatingSessionExec - * - calculate session among tuples, and update all tuples to get correct session range + * - MergingSessionExec + * - calculate session among tuples, and aggregate tuples in session with partial merge * - NOTE: it leverages the fact that the output of SessionWindowStateStoreRestore is sorted - * - PartialMerge (group: all keys) - * - now there is at most 1 tuple per group + * - now there is at most 1 tuple per group, key with session * - SessionWindowStateStoreSave (group: keys "without" session) * - saves tuple(s) for the next batch (multiple sessions could co-exist at the same time) * - Complete (output the current result of the aggregation) @@ -395,23 +398,21 @@ object AggUtils { val restored = SessionWindowStateStoreRestoreExec(groupingWithoutSessionAttributes, sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None, partialAggregate) - val updatedSession = UpdatingSessionExec(groupingWithoutSessionAttributes, - sessionExpression.toAttribute, optRequiredChildDistribution = None, - optRequiredChildOrdering = None, restored) - - val partialMerged: SparkPlan = { + val mergedSessions = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) - createAggregate( - requiredChildDistributionExpressions = - Some(groupingAttributes), + 
MergingSessionsExec( + requiredChildDistributionExpressions = None, + requiredChildDistributionOption = Some(restored.requiredChildDistribution), groupingExpressions = groupingAttributes, + sessionExpression = sessionExpression, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, initialInputBufferOffset = groupingAttributes.length, resultExpressions = groupingAttributes ++ - aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), - child = updatedSession) + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = restored + ) } // Note: stateId and returnAllStates are filled in later with preparation rules @@ -423,7 +424,7 @@ object AggUtils { stateInfo = None, outputMode = None, eventTimeWatermark = None, - partialMerged) + mergedSessions) val finalAndCompleteAggregate: SparkPlan = { val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) @@ -444,37 +445,35 @@ object AggUtils { finalAndCompleteAggregate :: Nil } - private def mayAppendUpdatingSessionExec(groupingExpressions: Seq[NamedExpression], - maybeChildPlan: SparkPlan): SparkPlan = { - val interExec = groupingExpressions.find(_.metadata.contains(SessionWindow.marker)) match { - case Some(sessionExpression) => - val groupWithoutSessionExpression = groupingExpressions.filterNot { p => - p.semanticEquals(sessionExpression) - } - - val groupingWithoutSessionAttributes = groupWithoutSessionExpression.map(_.toAttribute) - - val childDistribution = if (groupWithoutSessionExpression.isEmpty) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupWithoutSessionExpression) :: Nil - } - - val childOrdering = Seq((groupingWithoutSessionAttributes ++ Seq(sessionExpression)) - .map(SortOrder(_, Ascending))) - - val updatedSession = UpdatingSessionExec(groupingWithoutSessionAttributes, - sessionExpression.toAttribute, - optRequiredChildDistribution = Some(childDistribution), - optRequiredChildOrdering = 
Some(childOrdering), - maybeChildPlan) - - updatedSession + private def mayAppendMergingSessionExec( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + partialAggregate: SparkPlan): SparkPlan = { - case None => maybeChildPlan + groupingExpressions.find(_.metadata.contains(SessionWindow.marker)) match { + case Some(sessionExpression) => + val aggExpressions = aggregateExpressions.map(_.copy(mode = PartialMerge)) + val aggAttributes = aggregateExpressions.map(_.resultAttribute) + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val groupingWithoutSessionExpressions = groupingExpressions.diff(Seq(sessionExpression)) + val groupingWithoutSessionsAttributes = groupingWithoutSessionExpressions + .map(_.toAttribute) + + MergingSessionsExec( + requiredChildDistributionExpressions = Some(groupingWithoutSessionsAttributes), + requiredChildDistributionOption = None, + groupingExpressions = groupingAttributes, + sessionExpression = sessionExpression, + aggregateExpressions = aggExpressions, + aggregateAttributes = aggAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate + ) + + case None => partialAggregate } - - interExec } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsExec.scala new file mode 100644 index 0000000000000..1819f5966e969 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsExec.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.TaskContext + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, NamedExpression, SortOrder, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.metric.SQLMetrics + +// FIXME: javadoc should provide precondition that input must be sorted +// or both required child distribution as well as required child ordering should be presented +// to guarantee input will be sorted +case class MergingSessionsExec( + requiredChildDistributionExpressions: Option[Seq[Expression]], + requiredChildDistributionOption: Option[Seq[Distribution]], + groupingExpressions: Seq[NamedExpression], + sessionExpression: NamedExpression, + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) extends UnaryExecNode { + + val keyWithoutSessionExpressions = groupingExpressions.diff(Seq(sessionExpression)) + + private[this] val aggregateBufferAttributes = { + 
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + } + + override def producedAttributes: AttributeSet = + AttributeSet(aggregateAttributes) ++ + AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ + AttributeSet(aggregateBufferAttributes) + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def requiredChildDistribution: Seq[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.isEmpty => AllTuples :: Nil + case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil + case None => requiredChildDistributionOption match { + case Some(distributions) => distributions + case None => UnspecifiedDistribution :: Nil + } + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + Seq((keyWithoutSessionExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending))) + } + + override protected def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) => + // Because the constructor of an aggregation iterator will read at least the first row, + // we need to get the value of iter.hasNext first. + val hasInput = iter.hasNext + if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. 
+ Iterator[UnsafeRow]() + } else { + val outputIter = new MergingSessionsIterator( + partIndex, + groupingExpressions, + sessionExpression, + child.output, + iter, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + (expressions, inputSchema) => + newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), + numOutputRows) + if (!hasInput && groupingExpressions.isEmpty) { + // There is no input and there is no grouping expressions. + // We need to output a single row as the output. + numOutputRows += 1 + Iterator[UnsafeRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) + } else { + outputIter + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala new file mode 100644 index 0000000000000..3ce4b5c519a20 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, CreateNamedStruct, Expression, GenericInternalRow, Literal, MutableProjection, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.types.{LongType, TimestampType} + +// FIXME: javadoc! +// FIXME: groupingExpressions should contain sessionExpression +class MergingSessionsIterator( + partIndex: Int, + groupingExpressions: Seq[NamedExpression], + sessionExpression: NamedExpression, + valueAttributes: Seq[Attribute], + inputIterator: Iterator[InternalRow], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, + numOutputRows: SQLMetric) + extends AggregationIterator( + partIndex, + groupingExpressions, + valueAttributes, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection) { + + val groupingWithoutSession: Seq[NamedExpression] = + groupingExpressions.diff(Seq(sessionExpression)) + val groupingWithoutSessionAttributes: Seq[Attribute] = groupingWithoutSession.map(_.toAttribute) + + + /** + * Creates a new aggregation buffer and initializes buffer values + * for all aggregate functions. 
+ */ + private def newBuffer: InternalRow = { + val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) + val bufferRowSize: Int = bufferSchema.length + + val genericMutableBuffer = new GenericInternalRow(bufferRowSize) + val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable) + + val buffer = if (useUnsafeBuffer) { + val unsafeProjection = + UnsafeProjection.create(bufferSchema.map(_.dataType)) + unsafeProjection.apply(genericMutableBuffer) + } else { + genericMutableBuffer + } + initializeBuffer(buffer) + buffer + } + + /////////////////////////////////////////////////////////////////////////// + // Mutable states for sort based aggregation. + /////////////////////////////////////////////////////////////////////////// + + // The partition key of the current partition. + private[this] var currentGroupingKey: UnsafeRow = _ + + private[this] var currentSessionStart: Long = Long.MaxValue + + private[this] var currentSessionEnd: Long = Long.MinValue + + // The partition key of next partition. + private[this] var nextGroupingKey: UnsafeRow = _ + + private[this] var nextGroupingSessionStart: Long = Long.MaxValue + + private[this] var nextGroupingSessionEnd: Long = Long.MinValue + + // The first row of next partition. + private[this] var firstRowInNextGroup: InternalRow = _ + + // Indicates if we has new group of rows from the sorted input iterator + private[this] var sortedInputHasNewGroup: Boolean = false + + // The aggregation buffer used by the sort-based aggregation. 
+ private[this] val sortBasedAggregationBuffer: InternalRow = newBuffer + + private[this] val groupingWithoutSessionProjection: UnsafeProjection = + UnsafeProjection.create(groupingWithoutSession, valueAttributes) + + private[this] val sessionIndex = resultExpressions.indexOf(sessionExpression) + + private[this] val sessionProjection: UnsafeProjection = + UnsafeProjection.create(Seq(sessionExpression), valueAttributes) + + protected def initialize(): Unit = { + if (inputIterator.hasNext) { + initializeBuffer(sortBasedAggregationBuffer) + val inputRow = inputIterator.next() + nextGroupingKey = groupingWithoutSessionProjection(inputRow).copy() + val session = sessionProjection(inputRow) + val sessionRow = session.getStruct(0, 2) + nextGroupingSessionStart = sessionRow.getLong(0) + nextGroupingSessionEnd = sessionRow.getLong(1) + firstRowInNextGroup = inputRow.copy() + sortedInputHasNewGroup = true + } else { + // This inputIter is empty. + sortedInputHasNewGroup = false + } + } + + initialize() + + /** Processes rows in the current group. It will stop when it find a new group. */ + protected def processCurrentSortedGroup(): Unit = { + currentGroupingKey = nextGroupingKey + currentSessionStart = nextGroupingSessionStart + currentSessionEnd = nextGroupingSessionEnd + + // Now, we will start to find all rows belonging to this group. + // We create a variable to track if we see the next group. + var findNextPartition = false + // firstRowInNextGroup is the first row of this group. We first process it. + processRow(sortBasedAggregationBuffer, firstRowInNextGroup) + + // The search will stop when we see the next group or there is no + // input row left in the iter. + while (!findNextPartition && inputIterator.hasNext) { + // Get the grouping key. 
+ val currentRow = inputIterator.next() + val groupingKey = groupingWithoutSessionProjection(currentRow) + + val session = sessionProjection(currentRow) + val sessionRow = session.getStruct(0, 2) + val sessionStart = sessionRow.getLong(0) + val sessionEnd = sessionRow.getLong(1) + + // Check if the current row belongs the current input row. + if (currentGroupingKey == groupingKey) { + if (sessionStart < currentSessionStart) { + throw new IllegalArgumentException("Input iterator is not sorted based on session!") + } else if (sessionStart <= currentSessionEnd) { + // expanding session length if needed + expandEndOfCurrentSession(sessionEnd) + processRow(sortBasedAggregationBuffer, currentRow) + } else { + // We find a new group. + findNextPartition = true + startNewSession(currentRow, groupingKey, sessionStart, sessionEnd) + } + } else { + // We find a new group. + findNextPartition = true + startNewSession(currentRow, groupingKey, sessionStart, sessionEnd) + } + } + + // We have not seen a new group. It means that there is no new row in the input + // iter. The current group is the last group of the iter. + if (!findNextPartition) { + sortedInputHasNewGroup = false + } + } + + private def startNewSession(currentRow: InternalRow, groupingKey: UnsafeRow, sessionStart: Long, + sessionEnd: Long): Unit = { + nextGroupingKey = groupingKey.copy() + nextGroupingSessionStart = sessionStart + nextGroupingSessionEnd = sessionEnd + firstRowInNextGroup = currentRow.copy() + } + + private def expandEndOfCurrentSession(sessionEnd: Long): Unit = { + if (sessionEnd > currentSessionEnd) { + currentSessionEnd = sessionEnd + } + } + + /////////////////////////////////////////////////////////////////////////// + // Iterator's public methods + /////////////////////////////////////////////////////////////////////////// + + override final def hasNext: Boolean = sortedInputHasNewGroup + + override final def next(): UnsafeRow = { + if (hasNext) { + // Process the current group. 
+ processCurrentSortedGroup() + // Generate output row for the current group. + + val groupingKey = generateGroupingKey() + + val outputRow = generateOutput(groupingKey, sortBasedAggregationBuffer) + // Initialize buffer values for the next group. + initializeBuffer(sortBasedAggregationBuffer) + numOutputRows += 1 + outputRow + } else { + // no more result + throw new NoSuchElementException + } + } + + private def generateGroupingKey(): UnsafeRow = { + val sessionStruct = CreateNamedStruct( + Literal("start") :: + PreciseTimestampConversion( + Literal(currentSessionStart, LongType), LongType, TimestampType) :: + Literal("end") :: + PreciseTimestampConversion( + Literal(currentSessionEnd, LongType), LongType, TimestampType) :: + Nil) + + val convertedGroupingExpressions = groupingExpressions.map { x => + if (x.semanticEquals(sessionExpression)) { + sessionStruct + } else { + BindReferences.bindReference[Expression](x, groupingWithoutSessionAttributes) + } + } + + val proj = GenerateUnsafeProjection.generate(convertedGroupingExpressions, + groupingWithoutSessionAttributes) + proj(currentGroupingKey) + } + + def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { + initializeBuffer(sortBasedAggregationBuffer) + generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala deleted file mode 100644 index 7375cc77e4e4b..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.TaskContext - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning} -import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} - -// FIXME: javadoc should provide precondition that input must be sorted -// or both required child distribution as well as required child ordering should be presented -// to guarantee input will be sorted -case class UpdatingSessionExec( - keyExpressions: Seq[Attribute], - sessionExpression: Attribute, - optRequiredChildDistribution: Option[Seq[Distribution]], - optRequiredChildOrdering: Option[Seq[Seq[SortOrder]]], - child: SparkPlan) extends UnaryExecNode { - - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => - val newIter = new UpdatingSessionIterator(iter, keyExpressions, sessionExpression, - child.output) - - val debugIter = newIter.map { row => - val keysProjection = GenerateUnsafeProjection.generate(keyExpressions, child.output) - val sessionProjection = GenerateUnsafeProjection.generate( - Seq(sessionExpression), child.output) - val rowProjection = 
GenerateUnsafeProjection.generate(child.output, child.output) - - // FIXME: remove - val debugPartitionId = TaskContext.get().partitionId() - - logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - keys ${keysProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - session ${sessionProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - row (proj) ${rowProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - row ${row}") - - row - } - - debugIter - } - } - - override def output: Seq[Attribute] = child.output - - override def outputPartitioning: Partitioning = child.outputPartitioning - - override def requiredChildDistribution: Seq[Distribution] = optRequiredChildDistribution match { - case Some(distribution) => distribution - case None => super.requiredChildDistribution - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = optRequiredChildOrdering match { - case Some(ordering) => ordering - case None => super.requiredChildOrdering - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 0ee41ea8fc077..3a55cfab5dded 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -121,8 +121,7 @@ class IncrementalExecution( case SessionWindowStateStoreSaveExec(keys, session, None, None, None, UnaryExecNode(agg, - UnaryExecNode(agg2, - SessionWindowStateStoreRestoreExec(_, _, None, None, child)))) => + SessionWindowStateStoreRestoreExec(_, _, None, None, child))) => val aggStateInfo = nextStatefulOperationStateInfo SessionWindowStateStoreSaveExec( keys, @@ -131,13 +130,12 @@ class IncrementalExecution( Some(outputMode), 
Some(offsetSeqMetadata.batchWatermarkMs), agg.withNewChildren( - agg2.withNewChildren( - SessionWindowStateStoreRestoreExec( - keys, - session, - Some(aggStateInfo), - Some(offsetSeqMetadata.batchWatermarkMs), - child) :: Nil) :: Nil)) + SessionWindowStateStoreRestoreExec( + keys, + session, + Some(aggStateInfo), + Some(offsetSeqMetadata.batchWatermarkMs), + child) :: Nil)) case StreamingDeduplicateExec(keys, child, None, None) => StreamingDeduplicateExec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index d54ce9f010aaa..e208ee4c7dc15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -514,6 +514,10 @@ case class SessionWindowStateStoreRestoreExec( override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = { + (keyWithoutSessionExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending)) + } + override def requiredChildDistribution: Seq[Distribution] = { if (keyWithoutSessionExpressions.isEmpty) { AllTuples :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala index 11930e67a6e41..4e9b41258f21c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala @@ -107,12 +107,16 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext // a => (19:39:34 ~ 19:39:50) // b => (19:39:27 ~ 19:39:37), (19:39:39 ~ 19:39:55) + val df2 = df.groupBy(session($"time", "10 seconds"), 'id) + .agg(count("*").as("counts"), sum("value").as("sum")) + 
.orderBy($"session.start".asc) + .selectExpr("CAST(session.start AS STRING)", "CAST(session.end AS STRING)", "id", + "counts", "sum") + + df2.explain(extended = true) + checkAnswer( - df.groupBy(session($"time", "10 seconds"), 'id) - .agg(count("*").as("counts"), sum("value").as("sum")) - .orderBy($"session.start".asc) - .selectExpr("CAST(session.start AS STRING)", "CAST(session.end AS STRING)", "id", - "counts", "sum"), + df2, Seq( Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index cb68299b9b98b..9437a609a1c9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -296,6 +296,10 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche // sessions: key 1 => (10,16), (17,23) CheckNewAnswer(), AddData(inputData, 25), + AssertOnQuery { se => + se.explain(true) + true + }, // Advance watermark to 15 seconds // sessions: key 1 => (10,16), (17,23) / key 2 => (25,30) CheckNewAnswer(), @@ -333,6 +337,10 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche // Advance watermark to 7 seconds // sessions: (10,16), (17,23) CheckNewAnswer(), + AssertOnQuery { se => + se.explain(true) + true + }, AddData(inputData, 25), // Advance watermark to 15 seconds // sessions: (10,16), (17,23), (25,30) @@ -374,6 +382,10 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche // sessions: key 1 => (10,16), (17,22) <- updated // updated: key 1 => (17,22) CheckNewAnswer((1, 17, 22, 1, 17)), + AssertOnQuery { se => + se.explain(true) + true + }, AddData(inputData, 25), // Advance watermark to 15 seconds // sessions: key 1 => (10,16), (17,22) / key 2 => (25,30) @@ -424,6 
+436,10 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche // sessions: key 0 => (10,17) / key 1 => (11,22), (25,30) // updated: key 1 => (25,30) CheckNewAnswer((1, 25, 30, 1, 25)), + AssertOnQuery { se => + se.explain(true) + true + }, AddData(inputData, 35), // Advance watermark to 25 seconds @@ -477,6 +493,10 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche // sessions: (10,16), (17,22), (25,30) // updated: (25,30) CheckNewAnswer((25, 30, 1, 25)), + AssertOnQuery { se => + se.explain(true) + true + }, AddData(inputData, 35), // Advance watermark to 25 seconds // sessions: (10,16), (17,22), (25,30), (35,40) From f7c2deb5a5166bdb8d46c05642317867313af8aa Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 18 Sep 2018 14:28:54 +0900 Subject: [PATCH 26/60] WIP simplify the code a bit --- .../streaming/statefulOperators.scala | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index e208ee4c7dc15..265200fa166a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -429,7 +429,6 @@ case class StateStoreSaveExec( } // FIXME: javadoc! -// FIXME: keyExpressions shouldn't have 'session': otherwise we should exclude it... 
case class SessionWindowStateStoreRestoreExec( keyWithoutSessionExpressions: Seq[Attribute], sessionExpression: Attribute, @@ -576,10 +575,8 @@ case class SessionWindowStateStoreSaveExec( val keyProjection = GenerateUnsafeProjection.generate(keyWithoutSessionExpressions, child.output) - val alreadyRemovedKeys = new mutable.HashSet[UnsafeRow]() - var previousSessions: Seq[UnsafeRow] = null - - val ordering = TypeUtils.getInterpretedOrdering(child.output.toStructType) + var currentKey: UnsafeRow = null + var previousSessions: scala.collection.immutable.Set[UnsafeRow] = null // FIXME: DEBUG val debugPartitionId = TaskContext.get().partitionId() @@ -594,18 +591,18 @@ case class SessionWindowStateStoreSaveExec( val row = iter.next().asInstanceOf[UnsafeRow] val keys = keyProjection(row) - if (!alreadyRemovedKeys.contains(keys)) { + if (currentKey == null || currentKey != keys) { + currentKey = keys // This is necessary because MultiValuesStateManager doesn't guarantee // stable ordering. // The number of values for the given key is expected to be likely small, - // so sorting it here doesn't hurt. - previousSessions = stateManager.get(keys).toList + // so listing it here doesn't hurt. + previousSessions = stateManager.get(keys).toSet stateManager.removeKey(keys) - alreadyRemovedKeys.add(keys) } - if (!previousSessions.exists(p => ordering.equiv(row, p))) { + if (!previousSessions.contains(row)) { // such session is not in previous session val sessionProjection = GenerateUnsafeProjection.generate( @@ -656,18 +653,18 @@ case class SessionWindowStateStoreSaveExec( while (ret == null && iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] val keys = keyProjection(row) - if (!alreadyRemovedKeys.contains(keys)) { + if (currentKey == null || currentKey != keys) { + currentKey = keys // This is necessary because MultiValuesStateManager doesn't guarantee // stable ordering. 
// The number of values for the given key is expected to be likely small, - // so sorting it here doesn't hurt. - previousSessions = stateManager.get(keys).toList + // so listing it here doesn't hurt. + previousSessions = stateManager.get(keys).toSet stateManager.removeKey(keys) - alreadyRemovedKeys.add(keys) } - if (!previousSessions.exists(p => ordering.equiv(row, p))) { + if (!previousSessions.contains(row)) { // such session is not in previous session val sessionProjection = GenerateUnsafeProjection.generate( @@ -697,6 +694,7 @@ case class SessionWindowStateStoreSaveExec( // !iter.hasNext && ret != null => can return ret, and next getNext() call will // set finished = true // iter.hasNext && (ret != null || ret == null) => not possible + numOutputRows += 1 ret } } From a81616b91679f52821785f75baeede7d4757f2ee Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 18 Sep 2018 14:56:41 +0900 Subject: [PATCH 27/60] WIP address batch distinct query for sessionization --- .../sql/execution/aggregate/AggUtils.scala | 40 ++++++++-- .../aggregate/UpdatingSessionExec.scala | 78 +++++++++++++++++++ .../sql/DataFrameSessionWindowingSuite.scala | 68 ++++++++++++++++ 3 files changed, 178 insertions(+), 8 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 299e1841bb8b8..85e1af959601a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution} 
import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming._ @@ -118,8 +119,6 @@ object AggUtils { finalAggregate :: Nil } - // FIXME: now I'm not sure sessionization works with distinct (don't have imagination) - // FIXME: maybe it is not easy to add sessionization with distinct? def planAggregateWithOneDistinct( groupingExpressions: Seq[NamedExpression], functionsWithDistinct: Seq[AggregateExpression], @@ -127,6 +126,8 @@ object AggUtils { resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { + val maySessionChild = mayAppendUpdatingSessionExec(groupingExpressions, child) + // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one // DISTINCT aggregate function, all of those functions will have the same column expressions. // For example, it would be valid for functionsWithDistinct to be @@ -153,7 +154,7 @@ object AggUtils { aggregateAttributes = aggregateAttributes, resultExpressions = groupingAttributes ++ distinctAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), - child = child) + child = maySessionChild) } // 2. Create an Aggregate Operator for partial merge aggregations. 
@@ -209,17 +210,13 @@ object AggUtils { mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++ distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - val partialDistinctAggregate = createAggregate( + createAggregate( groupingExpressions = groupingAttributes, aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions, aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes, initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, resultExpressions = partialAggregateResult, child = partialMergeAggregate) - - mayAppendMergingSessionExec(groupingExpressions, - mergeAggregateExpressions ++ distinctAggregateExpressions, - partialDistinctAggregate) } // 4. Create an Aggregate Operator for the final aggregation. @@ -445,6 +442,33 @@ object AggUtils { finalAndCompleteAggregate :: Nil } + private def mayAppendUpdatingSessionExec(groupingExpressions: Seq[NamedExpression], + maybeChildPlan: SparkPlan): SparkPlan = + groupingExpressions.find(_.metadata.contains(SessionWindow.marker)) match { + case Some(sessionExpression) => + val groupWithoutSessionExpression = groupingExpressions.filterNot { + p => p.semanticEquals(sessionExpression) + } + + val groupingWithoutSessionAttributes = groupWithoutSessionExpression.map(_.toAttribute) + + val childDistribution = if (groupWithoutSessionExpression.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupWithoutSessionExpression) :: Nil + } + val childOrdering = Seq((groupingWithoutSessionAttributes ++ Seq(sessionExpression)) + .map(SortOrder(_, Ascending))) + val updatedSession = UpdatingSessionExec( + groupingWithoutSessionAttributes, + sessionExpression.toAttribute, + optRequiredChildDistribution = Some(childDistribution), + optRequiredChildOrdering = Some(childOrdering), + maybeChildPlan) + updatedSession + case None => maybeChildPlan + } + private def mayAppendMergingSessionExec( 
groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala new file mode 100644 index 0000000000000..2e5d753debc39 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.TaskContext + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning} +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} + +// FIXME: javadoc should provide precondition that input must be sorted +// or both required child distribution as well as required child ordering should be presented +// to guarantee input will be sorted +case class UpdatingSessionExec( + keyExpressions: Seq[Attribute], + sessionExpression: Attribute, + optRequiredChildDistribution: Option[Seq[Distribution]], + optRequiredChildOrdering: Option[Seq[Seq[SortOrder]]], + child: SparkPlan) extends UnaryExecNode { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + val newIter = new UpdatingSessionIterator(iter, keyExpressions, sessionExpression, + child.output) + + val debugIter = newIter.map { row => + val keysProjection = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val sessionProjection = GenerateUnsafeProjection.generate( + Seq(sessionExpression), child.output) + val rowProjection = GenerateUnsafeProjection.generate(child.output, child.output) + + // FIXME: remove + val debugPartitionId = TaskContext.get().partitionId() + + logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - keys ${keysProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - session ${sessionProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - row (proj) ${rowProjection(row)}") + logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - row ${row}") + + row + } + + debugIter + } + 
} + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = optRequiredChildDistribution match { + case Some(distribution) => distribution + case None => super.requiredChildDistribution + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = optRequiredChildOrdering match { + case Some(ordering) => ordering + case None => super.requiredChildOrdering + } +} \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala index 4e9b41258f21c..d11418ef41f93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala @@ -92,6 +92,74 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext ) } + test("session window groupBy with multiple keys statement - one distinct") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:39", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:40:04", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14) + // key "b" => (19:39:27 ~ 19:39:37) + + val df2 = df.groupBy(session($"time", "10 seconds"), 'id) + .agg(count("*").as("counts"), sumDistinct("value").as("sum")) + .orderBy($"session.start".asc) + .selectExpr("CAST(session.start AS STRING)", "CAST(session.end AS STRING)", "id", + "counts", "sum") + + df2.explain(extended = true) + + checkAnswer( + df2, + + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 2, 
1), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:14", "a", 2, 2) + ) + ) + } + + test("session window groupBy with multiple keys statement - two distincts") { + val df = Seq( + ("2016-03-27 19:39:34", 1, 2, "a"), + ("2016-03-27 19:39:39", 1, 2, "a"), + ("2016-03-27 19:39:56", 2, 4, "a"), + ("2016-03-27 19:40:04", 2, 4, "a"), + ("2016-03-27 19:39:27", 4, 8, "b")).toDF("time", "value", "value2", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14) + // key "b" => (19:39:27 ~ 19:39:37) + + val df2 = df.groupBy(session($"time", "10 seconds"), 'id) + .agg(sumDistinct("value").as("sum"), sumDistinct("value2").as("sum2")) + .orderBy($"session.start".asc) + .selectExpr("CAST(session.start AS STRING)", "CAST(session.end AS STRING)", "id", + "sum", "sum2") + + df2.explain(extended = true) + + checkAnswer( + df2, + + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 4, 8), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 1, 2), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:14", "a", 2, 4) + ) + ) + } + test("session window groupBy with multiple keys statement - keys overlapped with sessions") { val df = Seq( ("2016-03-27 19:39:34", 1, "a"), From 013785ded4db1671cce258ad8c0b8879cacc2f8d Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 18 Sep 2018 15:17:15 +0900 Subject: [PATCH 28/60] WIP remove debug statements for test code --- .../sql/DataFrameSessionWindowingSuite.scala | 44 +++++++------------ .../streaming/EventTimeWatermarkSuite.scala | 40 ++++++++--------- 2 files changed, 35 insertions(+), 49 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala index d11418ef41f93..0c5d858330d1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala @@ -107,17 +107,12 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext // key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14) // key "b" => (19:39:27 ~ 19:39:37) - val df2 = df.groupBy(session($"time", "10 seconds"), 'id) - .agg(count("*").as("counts"), sumDistinct("value").as("sum")) - .orderBy($"session.start".asc) - .selectExpr("CAST(session.start AS STRING)", "CAST(session.end AS STRING)", "id", - "counts", "sum") - - df2.explain(extended = true) - checkAnswer( - df2, - + df.groupBy(session($"time", "10 seconds"), 'id) + .agg(count("*").as("counts"), sumDistinct("value").as("sum")) + .orderBy($"session.start".asc) + .selectExpr("CAST(session.start AS STRING)", "CAST(session.end AS STRING)", "id", + "counts", "sum"), Seq( Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 2, 1), @@ -141,17 +136,12 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext // key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14) // key "b" => (19:39:27 ~ 19:39:37) - val df2 = df.groupBy(session($"time", "10 seconds"), 'id) - .agg(sumDistinct("value").as("sum"), sumDistinct("value2").as("sum2")) - .orderBy($"session.start".asc) - .selectExpr("CAST(session.start AS STRING)", "CAST(session.end AS STRING)", "id", - "sum", "sum2") - - df2.explain(extended = true) - checkAnswer( - df2, - + df.groupBy(session($"time", "10 seconds"), 'id) + .agg(sumDistinct("value").as("sum"), sumDistinct("value2").as("sum2")) + .orderBy($"session.start".asc) + .selectExpr("CAST(session.start AS STRING)", "CAST(session.end AS STRING)", "id", + "sum", "sum2"), Seq( Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 4, 8), Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 1, 2), @@ -175,16 +165,12 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext // a => (19:39:34 ~ 
19:39:50) // b => (19:39:27 ~ 19:39:37), (19:39:39 ~ 19:39:55) - val df2 = df.groupBy(session($"time", "10 seconds"), 'id) - .agg(count("*").as("counts"), sum("value").as("sum")) - .orderBy($"session.start".asc) - .selectExpr("CAST(session.start AS STRING)", "CAST(session.end AS STRING)", "id", - "counts", "sum") - - df2.explain(extended = true) - checkAnswer( - df2, + df.groupBy(session($"time", "10 seconds"), 'id) + .agg(count("*").as("counts"), sum("value").as("sum")) + .orderBy($"session.start".asc) + .selectExpr("CAST(session.start AS STRING)", "CAST(session.end AS STRING)", "id", + "counts", "sum"), Seq( Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 9437a609a1c9e..f4f94a7a4a35e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -291,25 +291,26 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(windowedAggregation)( AddData(inputData, 10, 11), // sessions: key 1 => (10,16) CheckNewAnswer(), + AddData(inputData, 17), // Advance watermark to 7 seconds // sessions: key 1 => (10,16), (17,23) CheckNewAnswer(), + AddData(inputData, 25), - AssertOnQuery { se => - se.explain(true) - true - }, // Advance watermark to 15 seconds // sessions: key 1 => (10,16), (17,23) / key 2 => (25,30) CheckNewAnswer(), + AddData(inputData, 35), // Advance watermark to 25 seconds // sessions: key 1 => (10,16), (17,22) / key 2 => (25,30) / key 3 => (35,40) // evicts: key 1 => (10,16), (17,22) CheckNewAnswer((1, 10, 16, 2, 21), (1, 17, 22, 1, 17)), + AddData(inputData, 10), // Should not emit anything as data less than watermark CheckNewAnswer(), + AddData(inputData, 40), // Advance watermark to 30 
seconds // sessions: key 2 => (25,30) / key 3 => (35,45) @@ -333,25 +334,26 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(windowedAggregation)( AddData(inputData, 10, 11), // sessions: (10,16) CheckNewAnswer(), + AddData(inputData, 17), // Advance watermark to 7 seconds // sessions: (10,16), (17,23) CheckNewAnswer(), - AssertOnQuery { se => - se.explain(true) - true - }, + AddData(inputData, 25), // Advance watermark to 15 seconds // sessions: (10,16), (17,23), (25,30) CheckNewAnswer(), + AddData(inputData, 35), // Advance watermark to 25 seconds // sessions: (10,16), (17,22), (25,30), (35,40) // evicts: (10,16), (17,22) CheckNewAnswer((10, 16, 2, 21), (17, 22, 1, 17)), + AddData(inputData, 10), // Should not emit anything as data less than watermark CheckNewAnswer(), + AddData(inputData, 40), // Advance watermark to 30 seconds // sessions: (25,30) / (35,45) @@ -377,28 +379,29 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche // Advance watermark to 1 seconds // sessions: key 1 => (10,16) CheckNewAnswer((1, 10, 16, 2, 21)), + AddData(inputData, 17), // Advance watermark to 7 seconds // sessions: key 1 => (10,16), (17,22) <- updated // updated: key 1 => (17,22) CheckNewAnswer((1, 17, 22, 1, 17)), - AssertOnQuery { se => - se.explain(true) - true - }, + AddData(inputData, 25), // Advance watermark to 15 seconds // sessions: key 1 => (10,16), (17,22) / key 2 => (25,30) // updated: key 2 => (25,30) CheckNewAnswer((2, 25, 30, 1, 25)), + AddData(inputData, 35), // Advance watermark to 25 seconds // sessions: key 1 => (10,16), (17,22) / key 2 => (25,30) / key 3 => (35,40) // updated: key 3 => (35,40) // evicts: key 1 => (10,16), (17,22) CheckNewAnswer((3, 35, 40, 1, 35)), + AddData(inputData, 10), // Should not emit anything as data less than watermark CheckNewAnswer(), + AddData(inputData, 40), // Advance watermark to 30 seconds // sessions: key 2 => (25,30) / key 3 => (35,40) / key 4 => 
(40, 45) @@ -436,10 +439,6 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche // sessions: key 0 => (10,17) / key 1 => (11,22), (25,30) // updated: key 1 => (25,30) CheckNewAnswer((1, 25, 30, 1, 25)), - AssertOnQuery { se => - se.explain(true) - true - }, AddData(inputData, 35), // Advance watermark to 25 seconds @@ -483,28 +482,29 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche // Advance watermark to 1 seconds // sessions: (10,16) CheckNewAnswer((10, 16, 2, 21)), + AddData(inputData, 17), // Advance watermark to 7 seconds // sessions: (10,16), (17,22) // updated: (17,22) CheckNewAnswer((17, 22, 1, 17)), + AddData(inputData, 25), // Advance watermark to 15 seconds // sessions: (10,16), (17,22), (25,30) // updated: (25,30) CheckNewAnswer((25, 30, 1, 25)), - AssertOnQuery { se => - se.explain(true) - true - }, + AddData(inputData, 35), // Advance watermark to 25 seconds // sessions: (10,16), (17,22), (25,30), (35,40) // updated: (35, 40) // evicts: (10,16), (17,22) CheckNewAnswer((35, 40, 1, 35)), + AddData(inputData, 10), // Should not emit anything as data less than watermark CheckNewAnswer(), + AddData(inputData, 40), // Advance watermark to 30 seconds // sessions: (25,30), (35,45) From 37fffefd1362493039ed8c1188f5cc2cfcdfc743 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 18 Sep 2018 15:29:21 +0900 Subject: [PATCH 29/60] WIP remove debug informations --- .../aggregate/UpdatingSessionExec.scala | 21 +---- .../aggregate/UpdatingSessionIterator.scala | 25 ------ ...gingSortWithMultiValuesStateIterator.scala | 8 -- .../streaming/statefulOperators.scala | 90 ++----------------- 4 files changed, 9 insertions(+), 135 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala index 2e5d753debc39..4592b72be3d14 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala @@ -38,27 +38,8 @@ case class UpdatingSessionExec( override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitions { iter => - val newIter = new UpdatingSessionIterator(iter, keyExpressions, sessionExpression, + new UpdatingSessionIterator(iter, keyExpressions, sessionExpression, child.output) - - val debugIter = newIter.map { row => - val keysProjection = GenerateUnsafeProjection.generate(keyExpressions, child.output) - val sessionProjection = GenerateUnsafeProjection.generate( - Seq(sessionExpression), child.output) - val rowProjection = GenerateUnsafeProjection.generate(child.output, child.output) - - // FIXME: remove - val debugPartitionId = TaskContext.get().partitionId() - - logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - keys ${keysProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - session ${sessionProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - row (proj) ${rowProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - updated session row - row ${row}") - - row - } - - debugIter } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala index f45b5c0783d35..6714ac86b73ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala @@ -65,11 +65,6 @@ class UpdatingSessionIterator( assertIteratorNotCorrupted() if (returnRowsIter != null && returnRowsIter.hasNext) { - System.err.println(s"DEBUG: has remaining returnRowsIter - not 
going into loop - " + - s"current session - currentKeys: $currentKeys / " + - s"currentRows: $currentRows / currentSessionStart: $currentSessionStart / " + - s"currentSessionEnd: $currentSessionEnd") - return returnRowsIter.next() } @@ -104,9 +99,6 @@ class UpdatingSessionIterator( // expanding session length if needed expandEndOfCurrentSession(sessionEnd) currentRows += row - - System.err.println(s"DEBUG: - adding row: $row / currentRows: $currentRows") - } else { closeCurrentSession(keyChanged = false) startNewSession(row, keys, sessionStart, sessionEnd) @@ -120,10 +112,6 @@ class UpdatingSessionIterator( closeCurrentSession(keyChanged = false) } - System.err.println(s"DEBUG: end of loop - current session - currentKeys: $currentKeys / " + - s"currentRows: $currentRows / currentSessionStart: $currentSessionStart / " + - s"currentSessionEnd: $currentSessionEnd") - // here returnRowsIter should be able to provide at least one row require(returnRowsIter != null && returnRowsIter.hasNext) @@ -147,10 +135,6 @@ class UpdatingSessionIterator( currentSessionEnd = sessionEnd currentRows.clear() currentRows += row - - System.err.println(s"DEBUG: started new session - currentKeys: $currentKeys / " + - s"currentRows: $currentRows / currentSessionStart: $currentSessionStart / " + - s"currentSessionEnd: $currentSessionEnd") } private def handleBrokenPreconditionForSort(): Unit = { @@ -159,10 +143,6 @@ class UpdatingSessionIterator( } private def closeCurrentSession(keyChanged: Boolean): Unit = { - System.err.println(s"DEBUG: closing current session - currentKeys: $currentKeys / " + - s"currentRows: $currentRows / currentSessionStart: $currentSessionStart / " + - s"currentSessionEnd: $currentSessionEnd") - val sessionStruct = CreateNamedStruct( Literal("start") :: PreciseTimestampConversion( @@ -195,11 +175,6 @@ class UpdatingSessionIterator( returnRowsIter = returnRows.iterator } - // FIXME: DEBUG - val (rIter, tmpReturnRowsIter) = returnRowsIter.duplicate - returnRowsIter = 
rIter - System.err.println(s"DEBUG: closing current session - return rows iter will return: ${tmpReturnRowsIter.toList}") - if (keyChanged) processedKeys.add(currentKeys) currentKeys = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala index c8b53ba85f68c..90a34227a015e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala @@ -84,7 +84,6 @@ class MergingSortWithMultiValuesStateIterator( // state row cannot advance to row in input, so state row should be lower false } else { - System.err.println(s"DEBUG: WARN - comparing row ${currentRow} and state row ${currentStateRow}") currentRow.sessionStart < currentStateRow.sessionStart } } @@ -102,22 +101,18 @@ class MergingSortWithMultiValuesStateIterator( } } - System.err.println(s"DEBUG: WARN - returning row ${ret.row} for iterator") - ret.row } private def mayFillCurrentRow(): Unit = { if (iter.hasNext) { currentRow = SessionRowInformation.of(iter.next()) - System.err.println(s"DEBUG - filling current row... current row: $currentRow") } } private def mayFillCurrentStateRow(): Unit = { if (currentStateIter != null && currentStateIter.hasNext) { currentStateRow = SessionRowInformation.of(currentStateIter.next()) - System.err.println(s"DEBUG - filling state row... 
current state row: $currentStateRow") } else { currentStateIter = null @@ -137,9 +132,6 @@ class MergingSortWithMultiValuesStateIterator( currentStateFetchedKey = currentRow.keys if (currentStateIter.hasNext) { currentStateRow = SessionRowInformation.of(currentStateIter.next()) - System.err.println(s"DEBUG: read data ${currentStateRow.row} from state for key ${currentRow.keys}") - } else { - System.err.println(s"DEBUG: no state data for key ${currentRow.keys}") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 265200fa166a9..f7b3b7d98e27b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -450,37 +450,13 @@ case class SessionWindowStateStoreRestoreExec( sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (stateManager, iter) => - // FIXME: remove - val debugPartitionId = TaskContext.get().partitionId() - - val debugIter = iter.map { row => - val keysProjection = GenerateUnsafeProjection.generate(keyWithoutSessionExpressions, - child.output) - val sessionProjection = GenerateUnsafeProjection.generate( - Seq(sessionExpression), child.output) - val rowProjection = GenerateUnsafeProjection.generate(child.output, child.output) - - logWarning(s"DEBUG: partitionId $debugPartitionId - input row - keys ${keysProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - input row - session ${sessionProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - input row - row (proj) ${rowProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - input row - row $row") - - row - } - // We need to filter out outdated inputs val filteredIterator = watermarkPredicateForData match { - case Some(predicate) => 
debugIter.filter((row: InternalRow) => { - val pr = !predicate.eval(row) - if (!pr) { - logWarning(s"DEBUG - evicting input due to watermark... $row") - } - pr - }) - case None => debugIter + case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) + case None => iter } - val mergedIter = new MergingSortWithMultiValuesStateIterator( + new MergingSortWithMultiValuesStateIterator( filteredIterator, stateManager, keyWithoutSessionExpressions, @@ -489,23 +465,6 @@ case class SessionWindowStateStoreRestoreExec( numOutputRows += 1 row } - - val debugMergedIter = mergedIter.map { row => - val keysProjection = GenerateUnsafeProjection.generate(keyWithoutSessionExpressions, - child.output) - val sessionProjection = GenerateUnsafeProjection.generate( - Seq(sessionExpression), child.output) - val rowProjection = GenerateUnsafeProjection.generate(child.output, child.output) - - logWarning(s"DEBUG: partitionId $debugPartitionId - merged row - keys ${keysProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - merged row - session ${sessionProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - merged row - row (proj) ${rowProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - merged row - row ${row}") - - row - } - - debugMergedIter } } @@ -602,41 +561,24 @@ case class SessionWindowStateStoreSaveExec( stateManager.removeKey(keys) } + stateManager.append(keys, row) + if (!previousSessions.contains(row)) { // such session is not in previous session - - val sessionProjection = GenerateUnsafeProjection.generate( - Seq(sessionExpression), child.output) - val rowProjection = GenerateUnsafeProjection.generate(child.output, child.output) - logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - keys ${keyProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - session ${sessionProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - row (proj) 
${rowProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - row ${row}") - - logWarning(s"DEBUG: partitionId $debugPartitionId - adding row for key ${keys} row ${row} ...") - - stateManager.append(keys, row) numUpdatedStateRows += 1 - } else { - logWarning(s"DEBUG: partitionId $debugPartitionId - add row for key ${keys} row ${row} but don't update count since it is already existed...") - stateManager.append(keys, row) } } - - logWarning(s"DEBUG: partitionId $debugPartitionId - finished iterating...") } val removalStartTimeNs = System.nanoTime val retIter = evictSessionsByWatermark(stateManager).map(_.value).map { row => - logWarning(s"DEBUG: partitionId $debugPartitionId - evicting row ${row} ...") numOutputRows += 1 row } CompletionIterator[InternalRow, Iterator[InternalRow]](retIter, { allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) - logWarning(s"DEBUG: partitionId $debugPartitionId - committing...") commitTimeMs += timeTakenMs { stateManager.commit() } setStoreMetrics(stateManager) }) @@ -664,26 +606,12 @@ case class SessionWindowStateStoreSaveExec( stateManager.removeKey(keys) } + stateManager.append(keys, row) + if (!previousSessions.contains(row)) { // such session is not in previous session - - val sessionProjection = GenerateUnsafeProjection.generate( - Seq(sessionExpression), child.output) - val rowProjection = GenerateUnsafeProjection.generate(child.output, child.output) - logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - keys ${keyProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - session ${sessionProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - row (proj) ${rowProjection(row)}") - logWarning(s"DEBUG: partitionId $debugPartitionId - adding row - row ${row}") - - logWarning(s"DEBUG: partitionId $debugPartitionId - adding row for key ${keys} row ${row} ...") - - stateManager.append(keys, row) 
numUpdatedStateRows += 1 - ret = row - } else { - logWarning(s"DEBUG: partitionId $debugPartitionId - add row for key ${keys} row ${row} but don't update count since it is already existed...") - stateManager.append(keys, row) } } @@ -705,9 +633,7 @@ case class SessionWindowStateStoreSaveExec( // Remove old aggregates if watermark specified allRemovalsTimeMs += timeTakenMs { // fully consume iterator to let removal take effect - evictSessionsByWatermark(stateManager).map { rowPair => - System.err.println(s"DEBUG: evicting row ${rowPair.value}") - }.toList + evictSessionsByWatermark(stateManager).toList } commitTimeMs += timeTakenMs { stateManager.commit() } setStoreMetrics(stateManager) From df95e72630254bbe2b4757442c0aa922a32205a0 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 19 Sep 2018 05:51:50 +0900 Subject: [PATCH 30/60] WIP port Sessionization example to UT of session window --- .../streaming/EventTimeWatermarkSuite.scala | 60 ++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index f4f94a7a4a35e..71d51f1ab03e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.functions.{count, session, sum, window} +import org.apache.spark.sql.functions.{count, max, session, sum, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ import org.apache.spark.util.Utils @@ -464,6 +464,64 @@ class 
EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche ) } + test("StructuredSessionization - example") { + // Implements StructuredSessionization.scala leveraging "session" function + // as a test, to verify the sessionization works with simple example + + val inputData = MemoryStream[(String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("timestamp")) + .withColumn("eventTime", $"timestamp".cast("timestamp")) + .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") + .withWatermark("eventTime", "10 seconds") + + val sessionUpdates = events + .groupBy(session($"eventTime", "10 seconds") as 'session, 'sessionId) + .agg(count("*").as("numEvents"), max("eventTime").as("max_timestamp")) + .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") + + testStream(sessionUpdates, OutputMode.Update())( + AddData(inputData, ("hello world spark", 10L), ("world hello structured streaming", 11L)), + // Advance watermark to 1 seconds + CheckNewAnswer( + ("hello", 10, 21, 11, 2), + ("world", 10, 21, 11, 2), + ("spark", 10, 20, 10, 1), + ("structured", 11, 21, 10, 1), + ("streaming", 11, 21, 10, 1) + ), + + AddData(inputData, ("spark streaming", 15L)), + // Advance watermark to 5 seconds + CheckNewAnswer(("spark", 10, 25, 15, 2), ("streaming", 11, 25, 14, 2)), + + AddData(inputData, ("hello world", 25L)), + // Advance watermark to 15 seconds + // ("hello", 10L) and ("world", 11L) are not evicted yet + // but new input rows doesn't fall into existing sessions + CheckNewAnswer(("hello", 25, 35, 10, 1), ("world", 25, 35, 10, 1)), + + AddData(inputData, ("hello world", 35L)), + // Advance watermark to 25 seconds + // ("hello", 10L) and ("world", 11L) are evicted + CheckNewAnswer(("hello", 25, 45, 20, 2), ("world", 25, 45, 20, 2)), + + 
AddData(inputData, ("hello world", 10L)), + // Should not emit anything as data less than watermark + // FIXME: this works but why watermark doesn't effect when we don't add 35L? + // FIXME: investigate what's happening here... + CheckNewAnswer(), + + AddData(inputData, ("hello apache spark", 40L)), + // Advance watermark to 30 seconds + CheckNewAnswer(("hello", 25, 50, 25, 3), ("apache", 40, 50, 10, 1), ("spark", 40, 50, 10, 1)) + ) + } + test("update mode - session - no key") { val inputData = MemoryStream[Int] From 16d6421e9b8ffb34fbab33ab3fbd98af4e7ff04e Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 19 Sep 2018 06:38:35 +0900 Subject: [PATCH 31/60] WIP remove unnecessary thing --- .../apache/spark/sql/streaming/EventTimeWatermarkSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 71d51f1ab03e1..0888287e08781 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -479,7 +479,7 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val sessionUpdates = events .groupBy(session($"eventTime", "10 seconds") as 'session, 'sessionId) - .agg(count("*").as("numEvents"), max("eventTime").as("max_timestamp")) + .agg(count("*").as("numEvents")) .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", "numEvents") From dc4330061dc626d12386c0eb752da330bf525357 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 19 Sep 2018 09:25:01 +0900 Subject: [PATCH 32/60] WIP fix all the issues with sessionization example UTs --- .../streaming/statefulOperators.scala | 23 +- 
.../streaming/EventTimeWatermarkSuite.scala | 585 +++++++++--------- 2 files changed, 303 insertions(+), 305 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index f7b3b7d98e27b..de3ef8788728e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -535,10 +535,10 @@ case class SessionWindowStateStoreSaveExec( child.output) var currentKey: UnsafeRow = null - var previousSessions: scala.collection.immutable.Set[UnsafeRow] = null + var previousSessions: List[UnsafeRow] = null - // FIXME: DEBUG - val debugPartitionId = TaskContext.get().partitionId() + val keyOrdering = TypeUtils.getInterpretedOrdering(keyWithoutSessionExpressions.toStructType) + val valueOrdering = TypeUtils.getInterpretedOrdering(child.output.toStructType) // assuming late events were dropped from MergingSortWithMultiValuesStateIterator outputMode match { @@ -550,20 +550,21 @@ case class SessionWindowStateStoreSaveExec( val row = iter.next().asInstanceOf[UnsafeRow] val keys = keyProjection(row) - if (currentKey == null || currentKey != keys) { + if (currentKey == null || !keyOrdering.equiv(currentKey, keys)) { currentKey = keys + // This is necessary because MultiValuesStateManager doesn't guarantee // stable ordering. // The number of values for the given key is expected to be likely small, // so listing it here doesn't hurt. 
- previousSessions = stateManager.get(keys).toSet + previousSessions = stateManager.get(keys).toList stateManager.removeKey(keys) } stateManager.append(keys, row) - if (!previousSessions.contains(row)) { + if (!previousSessions.exists(p => valueOrdering.equiv(row, p))) { // such session is not in previous session numUpdatedStateRows += 1 } @@ -595,20 +596,22 @@ case class SessionWindowStateStoreSaveExec( while (ret == null && iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] val keys = keyProjection(row) - if (currentKey == null || currentKey != keys) { - currentKey = keys + + if (currentKey == null || !keyOrdering.equiv(currentKey, keys)) { + currentKey = keys.copy() + // This is necessary because MultiValuesStateManager doesn't guarantee // stable ordering. // The number of values for the given key is expected to be likely small, // so listing it here doesn't hurt. - previousSessions = stateManager.get(keys).toSet + previousSessions = stateManager.get(keys).toList stateManager.removeKey(keys) } stateManager.append(keys, row) - if (!previousSessions.contains(row)) { + if (!previousSessions.exists(p => valueOrdering.equiv(row, p))) { // such session is not in previous session numUpdatedStateRows += 1 ret = row diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 0888287e08781..f719576eb27ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -276,301 +276,6 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche ) } - test("append mode - session") { - val inputData = MemoryStream[Int] - - val windowedAggregation = inputData.toDF() - .selectExpr("*", "CAST(value / 10 AS INT) AS valuegroup") - .withColumn("eventTime", $"value".cast("timestamp")) - 
.withWatermark("eventTime", "10 seconds") - .groupBy(session($"eventTime", "5 seconds") as 'session, 'valuegroup) - .agg(count("*") as 'count, sum("value") as 'sum) - .select($"valuegroup", $"session".getField("start").cast("long").as[Long], - $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) - - testStream(windowedAggregation)( - AddData(inputData, 10, 11), // sessions: key 1 => (10,16) - CheckNewAnswer(), - - AddData(inputData, 17), - // Advance watermark to 7 seconds - // sessions: key 1 => (10,16), (17,23) - CheckNewAnswer(), - - AddData(inputData, 25), - // Advance watermark to 15 seconds - // sessions: key 1 => (10,16), (17,23) / key 2 => (25,30) - CheckNewAnswer(), - - AddData(inputData, 35), - // Advance watermark to 25 seconds - // sessions: key 1 => (10,16), (17,22) / key 2 => (25,30) / key 3 => (35,40) - // evicts: key 1 => (10,16), (17,22) - CheckNewAnswer((1, 10, 16, 2, 21), (1, 17, 22, 1, 17)), - - AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckNewAnswer(), - - AddData(inputData, 40), - // Advance watermark to 30 seconds - // sessions: key 2 => (25,30) / key 3 => (35,45) - // evicts: key 2 => (25,30) - CheckNewAnswer((2, 25, 30, 1, 25)) - ) - } - - test("append mode - session - no key") { - val inputData = MemoryStream[Int] - - val windowedAggregation = inputData.toDF() - .selectExpr("*") - .withColumn("eventTime", $"value".cast("timestamp")) - .withWatermark("eventTime", "10 seconds") - .groupBy(session($"eventTime", "5 seconds") as 'session) - .agg(count("*") as 'count, sum("value") as 'sum) - .select($"session".getField("start").cast("long").as[Long], - $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) - - testStream(windowedAggregation)( - AddData(inputData, 10, 11), // sessions: (10,16) - CheckNewAnswer(), - - AddData(inputData, 17), - // Advance watermark to 7 seconds - // sessions: (10,16), (17,23) - CheckNewAnswer(), - - 
AddData(inputData, 25), - // Advance watermark to 15 seconds - // sessions: (10,16), (17,23), (25,30) - CheckNewAnswer(), - - AddData(inputData, 35), - // Advance watermark to 25 seconds - // sessions: (10,16), (17,22), (25,30), (35,40) - // evicts: (10,16), (17,22) - CheckNewAnswer((10, 16, 2, 21), (17, 22, 1, 17)), - - AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckNewAnswer(), - - AddData(inputData, 40), - // Advance watermark to 30 seconds - // sessions: (25,30) / (35,45) - // evicts: (25,30) - CheckNewAnswer((25, 30, 1, 25)) - ) - } - - test("update mode - session") { - val inputData = MemoryStream[Int] - - val windowedAggregation = inputData.toDF() - .selectExpr("*", "CAST(value / 10 AS INT) AS valuegroup") - .withColumn("eventTime", $"value".cast("timestamp")) - .withWatermark("eventTime", "10 seconds") - .groupBy(session($"eventTime", "5 seconds") as 'session, 'valuegroup) - .agg(count("*") as 'count, sum("value") as 'sum) - .select($"valuegroup", $"session".getField("start").cast("long").as[Long], - $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) - - testStream(windowedAggregation, OutputMode.Update())( - AddData(inputData, 10, 11), - // Advance watermark to 1 seconds - // sessions: key 1 => (10,16) - CheckNewAnswer((1, 10, 16, 2, 21)), - - AddData(inputData, 17), - // Advance watermark to 7 seconds - // sessions: key 1 => (10,16), (17,22) <- updated - // updated: key 1 => (17,22) - CheckNewAnswer((1, 17, 22, 1, 17)), - - AddData(inputData, 25), - // Advance watermark to 15 seconds - // sessions: key 1 => (10,16), (17,22) / key 2 => (25,30) - // updated: key 2 => (25,30) - CheckNewAnswer((2, 25, 30, 1, 25)), - - AddData(inputData, 35), - // Advance watermark to 25 seconds - // sessions: key 1 => (10,16), (17,22) / key 2 => (25,30) / key 3 => (35,40) - // updated: key 3 => (35,40) - // evicts: key 1 => (10,16), (17,22) - CheckNewAnswer((3, 35, 40, 1, 35)), - - AddData(inputData, 
10), // Should not emit anything as data less than watermark - CheckNewAnswer(), - - AddData(inputData, 40), - // Advance watermark to 30 seconds - // sessions: key 2 => (25,30) / key 3 => (35,40) / key 4 => (40, 45) - // updated: key 4 => (40,45) - CheckNewAnswer((4, 40, 45, 1, 40)) - ) - } - - test("update mode - session - keys overlapped with sessions") { - val inputData = MemoryStream[Int] - - val windowedAggregation = inputData.toDF() - .selectExpr("*", "CAST(MOD(value, 2) AS INT) AS valuegroup") - .withColumn("eventTime", $"value".cast("timestamp")) - .withWatermark("eventTime", "10 seconds") - .groupBy(session($"eventTime", "5 seconds") as 'session, 'valuegroup) - .agg(count("*") as 'count, sum("value") as 'sum) - .select($"valuegroup", $"session".getField("start").cast("long").as[Long], - $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) - - testStream(windowedAggregation, OutputMode.Update())( - AddData(inputData, 10, 11, 12, 13), - // Advance watermark to 3 seconds - // sessions: key 0 => (10,17) / key 1 => (11, 18) - CheckNewAnswer((0, 10, 17, 2, 22), (1, 11, 18, 2, 24)), - - AddData(inputData, 17), - // Advance watermark to 7 seconds - // sessions: key 0 => (10,17) / key 1 => (11, 22) - // updated: key 1 => (11,22) - CheckNewAnswer((1, 11, 22, 3, 41)), - - AddData(inputData, 25), - // Advance watermark to 15 seconds - // sessions: key 0 => (10,17) / key 1 => (11,22), (25,30) - // updated: key 1 => (25,30) - CheckNewAnswer((1, 25, 30, 1, 25)), - - AddData(inputData, 35), - // Advance watermark to 25 seconds - // sessions: key 0 => (10,17) / key 1 => (11,22), (25,30), (35,40) - // updated: key 1 => (35,40) - // evicts: key 1 => (10,17), (11,22) - CheckNewAnswer((1, 35, 40, 1, 35)), - - AddData(inputData, 27), - // don't advance watermark - // sessions: key 1 => (25,32), (35,40) - // updated: key 1 => (25,32) - CheckNewAnswer((1, 25, 32, 2, 52)), - - AddData(inputData, 10), // Should not emit anything as data less than 
watermark - CheckNewAnswer(), - - AddData(inputData, 40), - // Advance watermark to 30 seconds - // sessions: key 0 => (40,45) / key 1 => (25,32), (35,40) - // updated: key 0 => (40,45) - CheckNewAnswer((0, 40, 45, 1, 40)) - ) - } - - test("StructuredSessionization - example") { - // Implements StructuredSessionization.scala leveraging "session" function - // as a test, to verify the sessionization works with simple example - - val inputData = MemoryStream[(String, Long)] - - // Split the lines into words, treat words as sessionId of events - val events = inputData.toDF() - .select($"_1".as("value"), $"_2".as("timestamp")) - .withColumn("eventTime", $"timestamp".cast("timestamp")) - .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") - .withWatermark("eventTime", "10 seconds") - - val sessionUpdates = events - .groupBy(session($"eventTime", "10 seconds") as 'session, 'sessionId) - .agg(count("*").as("numEvents")) - .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", - "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", - "numEvents") - - testStream(sessionUpdates, OutputMode.Update())( - AddData(inputData, ("hello world spark", 10L), ("world hello structured streaming", 11L)), - // Advance watermark to 1 seconds - CheckNewAnswer( - ("hello", 10, 21, 11, 2), - ("world", 10, 21, 11, 2), - ("spark", 10, 20, 10, 1), - ("structured", 11, 21, 10, 1), - ("streaming", 11, 21, 10, 1) - ), - - AddData(inputData, ("spark streaming", 15L)), - // Advance watermark to 5 seconds - CheckNewAnswer(("spark", 10, 25, 15, 2), ("streaming", 11, 25, 14, 2)), - - AddData(inputData, ("hello world", 25L)), - // Advance watermark to 15 seconds - // ("hello", 10L) and ("world", 11L) are not evicted yet - // but new input rows doesn't fall into existing sessions - CheckNewAnswer(("hello", 25, 35, 10, 1), ("world", 25, 35, 10, 1)), - - AddData(inputData, ("hello world", 35L)), - // Advance watermark to 25 seconds - // 
("hello", 10L) and ("world", 11L) are evicted - CheckNewAnswer(("hello", 25, 45, 20, 2), ("world", 25, 45, 20, 2)), - - AddData(inputData, ("hello world", 10L)), - // Should not emit anything as data less than watermark - // FIXME: this works but why watermark doesn't effect when we don't add 35L? - // FIXME: investigate what's happening here... - CheckNewAnswer(), - - AddData(inputData, ("hello apache spark", 40L)), - // Advance watermark to 30 seconds - CheckNewAnswer(("hello", 25, 50, 25, 3), ("apache", 40, 50, 10, 1), ("spark", 40, 50, 10, 1)) - ) - } - - test("update mode - session - no key") { - val inputData = MemoryStream[Int] - - val windowedAggregation = inputData.toDF() - .selectExpr("*") - .withColumn("eventTime", $"value".cast("timestamp")) - .withWatermark("eventTime", "10 seconds") - .groupBy(session($"eventTime", "5 seconds") as 'session) - .agg(count("*") as 'count, sum("value") as 'sum) - .select($"session".getField("start").cast("long").as[Long], - $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) - - testStream(windowedAggregation, OutputMode.Update())( - - AddData(inputData, 10, 11), - // Advance watermark to 1 seconds - // sessions: (10,16) - CheckNewAnswer((10, 16, 2, 21)), - - AddData(inputData, 17), - // Advance watermark to 7 seconds - // sessions: (10,16), (17,22) - // updated: (17,22) - CheckNewAnswer((17, 22, 1, 17)), - - AddData(inputData, 25), - // Advance watermark to 15 seconds - // sessions: (10,16), (17,22), (25,30) - // updated: (25,30) - CheckNewAnswer((25, 30, 1, 25)), - - AddData(inputData, 35), - // Advance watermark to 25 seconds - // sessions: (10,16), (17,22), (25,30), (35,40) - // updated: (35, 40) - // evicts: (10,16), (17,22) - CheckNewAnswer((35, 40, 1, 35)), - - AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckNewAnswer(), - - AddData(inputData, 40), - // Advance watermark to 30 seconds - // sessions: (25,30), (35,45) - // updated: (35, 45) - 
CheckNewAnswer((35, 45, 2, 75)) - ) - } - test("update mode") { val inputData = MemoryStream[Int] spark.conf.set("spark.sql.shuffle.partitions", "10") @@ -767,6 +472,296 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche ) } + test("append mode - session window") { + // Implements StructuredSessionization.scala leveraging "session" function + // as a test, to verify the sessionization works with simple example + + val inputData = MemoryStream[(String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("timestamp")) + .withColumn("eventTime", $"timestamp".cast("timestamp")) + .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") + .withWatermark("eventTime", "10 seconds") + + val sessionUpdates = events + .groupBy(session($"eventTime", "10 seconds") as 'session, 'sessionId) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") + + testStream(sessionUpdates, OutputMode.Append())( + AddData(inputData, ("hello world spark", 10L), ("world hello structured streaming", 11L)), + // Advance watermark to 1 seconds + // current sessions after batch: + // ("hello", 10, 21, 11, 2) + // ("world", 10, 21, 11, 2) + // ("spark", 10, 20, 10, 1) + // ("structured", 11, 21, 10, 2) + // ("streaming", 11, 21, 10, 2) + CheckNewAnswer(), + + AddData(inputData, ("spark streaming", 15L)), + // Advance watermark to 5 seconds + // current sessions after batch: + // ("hello", 10, 21, 11, 2) + // ("world", 10, 21, 11, 2) + // ("structured", 11, 21, 10, 2) + // ("spark", 10, 25, 15, 2) + // ("streaming", 11, 25, 14, 2) + CheckNewAnswer(), + + AddData(inputData, ("hello world", 25L)), + // Advance watermark to 15 seconds + // current sessions after batch: + // ("hello", 10, 21, 11, 2) + // ("world", 
10, 21, 11, 2) + // ("structured", 11, 21, 10, 2) + // ("spark", 10, 25, 15, 2) + // ("streaming", 11, 25, 14, 2) + // ("hello", 25, 35, 10, 1) + // ("world", 25, 35, 10, 1) + CheckNewAnswer(), + + AddData(inputData, ("hello world", 3L)), + // input can match to not-yet-evicted sessions, but input itself is less than watermark + // so it should not match exiting sessions + // Watermark kept 15 seconds + // current sessions after batch: + // ("hello", 10, 21, 11, 2) + // ("world", 10, 21, 11, 2) + // ("structured", 11, 21, 10, 2) + // ("spark", 10, 25, 15, 2) + // ("streaming", 11, 25, 14, 2) + // ("hello", 25, 35, 10, 1) + // ("world", 25, 35, 10, 1) + CheckNewAnswer(), + + AddData(inputData, ("hello", 31L)), + // Advance watermark to 21 seconds + // current sessions after batch: + // ("spark", 10, 25, 15, 2) + // ("streaming", 11, 25, 14, 2) + // ("hello", 25, 41, 16, 2) + // ("world", 25, 35, 10, 1) + CheckNewAnswer( + ("hello", 10, 21, 11, 2), + ("world", 10, 21, 11, 2), + ("structured", 11, 21, 10, 1) + ), + + AddData(inputData, ("hello", 35L)), + // Advance watermark to 25 seconds + // current sessions after batch: + // ("hello", 25, 45, 20, 3) + // ("world", 25, 35, 10, 1) + CheckNewAnswer( + ("spark", 10, 25, 15, 2), + ("streaming", 11, 25, 14, 2) + ), + + AddData(inputData, ("hello apache spark", 60L)), + // Advance watermark to 50 seconds + // current sessions after batch: + // ("hello", 60, 70, 10, 1) + // ("apache", 60, 70, 10, 1) + // ("spark", 60, 70, 10, 1) + CheckNewAnswer(("hello", 25, 45, 20, 3), ("world", 25, 35, 10, 1)) + ) + } + + test("append mode - session window - no key") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*") + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(session($"eventTime", "5 seconds") as 'session) + .agg(count("*") as 'count, sum("value") as 'sum) + .select($"session".getField("start").cast("long").as[Long], + 
$"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 10, 11), // sessions: (10,16) + CheckNewAnswer(), + + AddData(inputData, 17), + // Advance watermark to 7 seconds + // sessions: (10,16), (17,23) + CheckNewAnswer(), + + AddData(inputData, 25), + // Advance watermark to 15 seconds + // sessions: (10,16), (17,23), (25,30) + CheckNewAnswer(), + + AddData(inputData, 35), + // Advance watermark to 25 seconds + // sessions: (10,16), (17,22), (25,30), (35,40) + // evicts: (10,16), (17,22) + CheckNewAnswer((10, 16, 2, 21), (17, 22, 1, 17)), + + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckNewAnswer(), + + AddData(inputData, 40), + // Advance watermark to 30 seconds + // sessions: (25,30) / (35,45) + // evicts: (25,30) + CheckNewAnswer((25, 30, 1, 25)) + ) + } + + test("update mode - session window") { + // Implements StructuredSessionization.scala leveraging "session" function + // as a test, to verify the sessionization works with simple example + + val inputData = MemoryStream[(String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("timestamp")) + .withColumn("eventTime", $"timestamp".cast("timestamp")) + .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") + .withWatermark("eventTime", "10 seconds") + + val sessionUpdates = events + .groupBy(session($"eventTime", "10 seconds") as 'session, 'sessionId) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") + + testStream(sessionUpdates, OutputMode.Update())( + AddData(inputData, ("hello world spark", 10L), ("world hello structured streaming", 11L)), + // Advance watermark to 1 seconds + // ("hello", 10, 21, 11, 2) + // 
("world", 10, 21, 11, 2) + // ("spark", 10, 20, 10, 1) + // ("structured", 11, 21, 10, 1) + // ("streaming", 11, 21, 10, 1) + CheckNewAnswer( + ("hello", 10, 21, 11, 2), + ("world", 10, 21, 11, 2), + ("spark", 10, 20, 10, 1), + ("structured", 11, 21, 10, 1), + ("streaming", 11, 21, 10, 1) + ), + + AddData(inputData, ("spark streaming", 15L)), + // Advance watermark to 5 seconds + // current sessions after batch: + // ("hello", 10, 21, 11, 2) + // ("world", 10, 21, 11, 2) + // ("structured", 11, 21, 10, 2) + // ("spark", 10, 25, 15, 2) + // ("streaming", 11, 25, 14, 2) + CheckNewAnswer(("spark", 10, 25, 15, 2), ("streaming", 11, 25, 14, 2)), + + AddData(inputData, ("hello world", 25L)), + // Advance watermark to 15 seconds + // current sessions after batch: + // ("hello", 10, 21, 11, 2) + // ("world", 10, 21, 11, 2) + // ("structured", 11, 21, 10, 2) + // ("spark", 10, 25, 15, 2) + // ("streaming", 11, 25, 14, 2) + // ("hello", 25, 35, 10, 1) + // ("world", 25, 35, 10, 1) + CheckNewAnswer(("hello", 25, 35, 10, 1), ("world", 25, 35, 10, 1)), + + AddData(inputData, ("hello world", 3L)), + // input can match to not-yet-evicted sessions, but input itself is less than watermark + // so it should not match exiting sessions + // Watermark kept 15 seconds + // current sessions after batch: + // ("hello", 10, 21, 11, 2) + // ("world", 10, 21, 11, 2) + // ("structured", 11, 21, 10, 2) + // ("spark", 10, 25, 15, 2) + // ("streaming", 11, 25, 14, 2) + // ("hello", 25, 35, 10, 1) + // ("world", 25, 35, 10, 1) + CheckNewAnswer(), + + AddData(inputData, ("hello", 31L)), + // Advance watermark to 21 seconds + // current sessions after batch: + // ("spark", 10, 25, 15, 2) + // ("streaming", 11, 25, 14, 2) + // ("hello", 25, 41, 16, 2) + // ("world", 25, 35, 10, 1) + CheckNewAnswer(("hello", 25, 41, 16, 2)), + + AddData(inputData, ("hello", 35L)), + // Advance watermark to 25 seconds + // current sessions after batch: + // ("hello", 25, 45, 20, 3) + // ("world", 25, 35, 10, 1) + 
CheckNewAnswer(("hello", 25, 45, 20, 3)), + + AddData(inputData, ("hello apache spark", 60L)), + // Advance watermark to 50 seconds + // current sessions after batch: + // ("hello", 60, 70, 10, 1) + // ("apache", 60, 70, 10, 1) + // ("spark", 60, 70, 10, 1) + CheckNewAnswer(("hello", 60, 70, 10, 1), ("apache", 60, 70, 10, 1), ("spark", 60, 70, 10, 1)) + ) + } + + test("update mode - session window - no key") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*") + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(session($"eventTime", "5 seconds") as 'session) + .agg(count("*") as 'count, sum("value") as 'sum) + .select($"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + + testStream(windowedAggregation, OutputMode.Update())( + + AddData(inputData, 10, 11), + // Advance watermark to 1 seconds + // sessions: (10,16) + CheckNewAnswer((10, 16, 2, 21)), + + AddData(inputData, 17), + // Advance watermark to 7 seconds + // sessions: (10,16), (17,22) + // updated: (17,22) + CheckNewAnswer((17, 22, 1, 17)), + + AddData(inputData, 25), + // Advance watermark to 15 seconds + // sessions: (10,16), (17,22), (25,30) + // updated: (25,30) + CheckNewAnswer((25, 30, 1, 25)), + + AddData(inputData, 35), + // Advance watermark to 25 seconds + // sessions: (10,16), (17,22), (25,30), (35,40) + // updated: (35, 40) + // evicts: (10,16), (17,22) + CheckNewAnswer((35, 40, 1, 35)), + + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckNewAnswer(), + + AddData(inputData, 40), + // Advance watermark to 30 seconds + // sessions: (25,30), (35,45) + // updated: (35, 45) + CheckNewAnswer((35, 45, 2, 75)) + ) + } + test("group by on raw timestamp") { val inputData = MemoryStream[Int] From 3637f6054ca41badfb22a80c32f059b17e5d0c65 Mon Sep 17 00:00:00 2001 From: 
Jungtaek Lim Date: Wed, 19 Sep 2018 09:48:17 +0900 Subject: [PATCH 33/60] WIP apply merging session in each partition before shuffling --- .../sql/execution/aggregate/AggUtils.scala | 26 +++++++++++++++++-- .../streaming/EventTimeWatermarkSuite.scala | 1 + 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 85e1af959601a..9eadb1e5b74fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -348,8 +348,11 @@ object AggUtils { /** * Plans a streaming session aggregation using the following progression: * - * - Partial Merge (group: all keys) + * - Partial Aggregation * - all tuples will have aggregated columns with initial value + * - Sort within partition (sort: all keys) + * - SessionWindowStateStoreRestore (group: keys "without" session) + * - This will play as "Partial Merge" in each partition * - Shuffle & Sort (distribution: keys "without" session, sort: all keys) * - SessionWindowStateStoreRestore (group: keys "without" session) * - merge input tuples with stored tuples (sessions) respecting sort order @@ -391,9 +394,28 @@ object AggUtils { child = child) } + // sort happens here to merge sessions on each partition + // this is to reduce amount of rows to shuffle + val partialMerged1: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + MergingSessionsExec( + requiredChildDistributionExpressions = None, + requiredChildDistributionOption = None, + groupingExpressions = groupingAttributes, + sessionExpression = sessionExpression, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + 
initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate + ) + } + // shuffle & sort happens here: most of details are also handled in this physical plan val restored = SessionWindowStateStoreRestoreExec(groupingWithoutSessionAttributes, - sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None, partialAggregate) + sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None, partialMerged1) val mergedSessions = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index f719576eb27ce..7e60c2ccc55a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -638,6 +638,7 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(sessionUpdates, OutputMode.Update())( AddData(inputData, ("hello world spark", 10L), ("world hello structured streaming", 11L)), // Advance watermark to 1 seconds + // current sessions after batch: // ("hello", 10, 21, 11, 2) // ("world", 10, 21, 11, 2) // ("spark", 10, 20, 10, 1) From 0d53831ba591655ac9b64a1cd59261f7c4d7723a Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 20 Sep 2018 14:14:37 +0900 Subject: [PATCH 34/60] Fix scala checkstyle --- .../spark/sql/execution/aggregate/MergingSessionsExec.scala | 2 -- .../sql/execution/aggregate/MergingSessionsIterator.scala | 6 +++--- .../spark/sql/execution/aggregate/UpdatingSessionExec.scala | 5 +---- .../spark/sql/execution/python/AggregateInPandasExec.scala | 1 - .../streaming/state/MultiValuesStateStoreRDD.scala | 6 +++--- 
.../spark/sql/execution/streaming/statefulOperators.scala | 3 --- .../MergingSortWithMultiValuesStateIteratorSuite.scala | 2 +- 7 files changed, 8 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsExec.scala index 1819f5966e969..e2784c0d2bd0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsExec.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.aggregate -import org.apache.spark.TaskContext - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, NamedExpression, SortOrder, UnsafeRow} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala index 3ce4b5c519a20..fffda1aab489a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala @@ -54,9 +54,9 @@ class MergingSessionsIterator( /** - * Creates a new aggregation buffer and initializes buffer values - * for all aggregate functions. - */ + * Creates a new aggregation buffer and initializes buffer values + * for all aggregate functions. 
+ */ private def newBuffer: InternalRow = { val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) val bufferRowSize: Int = bufferSchema.length diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala index 4592b72be3d14..75d1eb1282f3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala @@ -17,12 +17,9 @@ package org.apache.spark.sql.execution.aggregate -import org.apache.spark.TaskContext - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning} import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} @@ -56,4 +53,4 @@ case class UpdatingSessionExec( case Some(ordering) => ordering case None => super.requiredChildOrdering } -} \ No newline at end of file +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 9286635a17c9f..a9c709722c909 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -22,7 +22,6 @@ import java.io.File import scala.collection.mutable.ArrayBuffer import org.apache.spark.{SparkEnv, TaskContext} - import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateStoreRDD.scala index 652b463d85b72..32d63296a1abb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateStoreRDD.scala @@ -48,9 +48,9 @@ class MultiValuesStateStoreRDD[T: ClassTag, U: ClassTag]( override protected def getPartitions: Array[Partition] = dataRDD.partitions /** - * Set the preferred location of each partition using the executor that has the related - * [[StateStoreProvider]] already loaded. - */ + * Set the preferred location of each partition using the executor that has the related + * [[StateStoreProvider]] already loaded. + */ override def getPreferredLocations(partition: Partition): Seq[String] = { val stateStoreProviderId = StateStoreProviderId( StateStoreId(stateInfo.checkpointLocation, stateInfo.operatorId, partition.index), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index de3ef8788728e..1774662e4dee9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -21,9 +21,6 @@ import java.util.UUID import java.util.concurrent.TimeUnit._ import scala.collection.JavaConverters._ -import scala.collection.mutable - -import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala index 929385322f977..7cd45b9f0a5da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala @@ -272,4 +272,4 @@ class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { } StateStore.stop() } -} \ No newline at end of file +} From a781400281dc02646843365f02e96f6c2398691f Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 20 Sep 2018 14:26:21 +0900 Subject: [PATCH 35/60] Fix python style check --- python/pyspark/sql/functions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index fd89b4834cbc6..794ba4f83b2af 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1396,10 +1396,13 @@ def check_string_field(field, fieldName): res = sc._jvm.functions.window(time_col, windowDuration) return Column(res) + @since(3.0) @ignore_unicode_prefix def session(timeColumn, gapDuration): + """ # FIXME: python doc!! 
+ """ def check_string_field(field, fieldName): if not field or type(field) is not str: raise TypeError("%s should be provided as a string" % fieldName) From 918dad221bf4d4c884fa0587ba65f1c2eb4b8f64 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 8 Oct 2018 17:44:47 +0900 Subject: [PATCH 36/60] WIP add complete mode, fix tricky bugs, apply ExternalAppendOnlyUnsafeRowArray * also add UTs which verify edge-case scenarios --- .../aggregate/UpdatingSessionExec.scala | 5 +- .../aggregate/UpdatingSessionIterator.scala | 28 +- .../python/AggregateInPandasExec.scala | 5 +- ...gingSortWithMultiValuesStateIterator.scala | 2 +- .../state/MultiValuesStateManager.scala | 45 ++ .../streaming/statefulOperators.scala | 35 +- .../sql/DataFrameSessionWindowingSuite.scala | 2 +- .../UpdatingSessionIteratorSuite.scala | 39 +- .../streaming/EventTimeWatermarkSuite.scala | 291 --------- .../StreamingSessionWindowSuite.scala | 565 ++++++++++++++++++ 10 files changed, 703 insertions(+), 314 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala index 75d1eb1282f3d..0c625f6cf1ce5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionExec.scala @@ -34,9 +34,12 @@ case class UpdatingSessionExec( child: SparkPlan) extends UnaryExecNode { override protected def doExecute(): RDD[InternalRow] = { + val inMemoryThreshold = sqlContext.conf.windowExecBufferInMemoryThreshold + val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold + child.execute().mapPartitions { iter => new UpdatingSessionIterator(iter, keyExpressions, sessionExpression, - child.output) + child.output, 
inMemoryThreshold, spillThreshold) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala index 6714ac86b73ca..14746098052e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray import org.apache.spark.sql.types.{LongType, TimestampType} // FIXME: javadoc!! @@ -29,7 +30,9 @@ class UpdatingSessionIterator( iter: Iterator[InternalRow], groupWithoutSessionExpressions: Seq[Expression], sessionExpression: Expression, - inputSchema: Seq[Attribute]) extends Iterator[InternalRow] { + inputSchema: Seq[Attribute], + inMemoryThreshold: Int, + spillThreshold: Int) extends Iterator[InternalRow] { val sessionIndex = inputSchema.indexOf(sessionExpression) @@ -40,8 +43,10 @@ class UpdatingSessionIterator( var currentSessionStart: Long = Long.MaxValue var currentSessionEnd: Long = Long.MinValue - val currentRows: mutable.MutableList[InternalRow] = new mutable.MutableList[InternalRow]() + var currentRows: ExternalAppendOnlyUnsafeRowArray = new ExternalAppendOnlyUnsafeRowArray( + inMemoryThreshold, spillThreshold) + var returnRows: ExternalAppendOnlyUnsafeRowArray = _ var returnRowsIter: Iterator[InternalRow] = _ var errorOnIterator: Boolean = false @@ -56,6 +61,7 @@ class UpdatingSessionIterator( if (returnRowsIter != null) { returnRowsIter = null + returnRows.clear() } iter.hasNext @@ -98,7 +104,7 @@ class UpdatingSessionIterator( } else if (sessionStart <= currentSessionEnd) { 
// expanding session length if needed expandEndOfCurrentSession(sessionEnd) - currentRows += row + currentRows.add(row.asInstanceOf[UnsafeRow]) } else { closeCurrentSession(keyChanged = false) startNewSession(row, keys, sessionStart, sessionEnd) @@ -133,8 +139,9 @@ class UpdatingSessionIterator( currentKeys = keys currentSessionStart = sessionStart currentSessionEnd = sessionEnd + currentRows.clear() - currentRows += row + currentRows.add(row.asInstanceOf[UnsafeRow]) } private def handleBrokenPreconditionForSort(): Unit = { @@ -164,15 +171,19 @@ class UpdatingSessionIterator( } } - val returnRows = currentRows.map { internalRow => + returnRows = currentRows + currentRows = new ExternalAppendOnlyUnsafeRowArray( + inMemoryThreshold, spillThreshold) + + val currentRowsIter = returnRows.generateIterator().map { internalRow => val proj = UnsafeProjection.create(newSchemaExpressions, inputSchema) proj(internalRow) - }.toList + } if (returnRowsIter != null && returnRowsIter.hasNext) { - returnRowsIter = returnRowsIter ++ returnRows.iterator + returnRowsIter = returnRowsIter ++ currentRowsIter } else { - returnRowsIter = returnRows.iterator + returnRowsIter = currentRowsIter } if (keyChanged) processedKeys.add(currentKeys) @@ -180,7 +191,6 @@ class UpdatingSessionIterator( currentKeys = null currentSessionStart = Long.MaxValue currentSessionEnd = Long.MinValue - currentRows.clear() } private def assertIteratorNotCorrupted(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index a9c709722c909..80f5aee29e5c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -177,8 +177,11 @@ case class AggregateInPandasExec( : Iterator[InternalRow] = { val newIter = sessionWindowOption 
match { case Some(sessionExpression) => + val inMemoryThreshold = sqlContext.conf.windowExecBufferInMemoryThreshold + val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold + new UpdatingSessionIterator(iter, groupingWithoutSessionExpressions, sessionExpression, - child.output) + child.output, inMemoryThreshold, spillThreshold) case None => iter } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala index 90a34227a015e..089f7fbe10b54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala @@ -122,7 +122,7 @@ class MergingSortWithMultiValuesStateIterator( // The number of values for the given key is expected to be likely small, // so sorting it here doesn't hurt. 
val unsortedIter = stateManager.get(currentRow.keys) - currentStateIter = unsortedIter.toList.sortWith((row1, row2) => { + currentStateIter = unsortedIter.map(_.copy()).toList.sortWith((row1, row2) => { val rowInfo1 = SessionRowInformation.of(row1) val rowInfo2 = SessionRowInformation.of(row2) // here sorting is based on the fact that keys are same diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala index 1abf330b3b1cd..633c4326e7406 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala @@ -257,6 +257,51 @@ class MultiValuesStateManager( } } + /** Provide all (key, value) row pairs (key can be exposed multiple times) */ + def getAllRowPairs: Iterator[UnsafeRowPair] = { + new NextIterator[UnsafeRowPair] { + // Reuse this object to avoid creation+GC overhead. 
+ private val reusedPair = new UnsafeRowPair() + + private val allKeyToNumValues = keyToNumValues.iterator + + private var currentKey: UnsafeRow = null + private var numValues: Long = 0L + private var index: Long = 0L + + override def getNext(): UnsafeRowPair = { + if (currentKey != null && index < numValues) { + provideCurrentRow() + } else { + if (!allKeyToNumValues.hasNext) { + // finished + finished = true + null + } else { + advanceGroup() + assert(numValues != 0) + provideCurrentRow() + } + } + } + + private def advanceGroup(): Unit = { + val currentKeyToNumValue = allKeyToNumValues.next() + currentKey = currentKeyToNumValue.key + numValues = currentKeyToNumValue.numValue + index = 0 + } + + private def provideCurrentRow(): UnsafeRowPair = { + val currentRow = keyWithIndexToValue.get(currentKey, index) + index += 1 + reusedPair.withRows(currentKey, currentRow) + } + + override def close: Unit = {} + } + } + /** Commit all the changes to all the state stores */ def commit(): Unit = { keyToNumValues.commit() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 1774662e4dee9..bfbd4eda5e5e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -539,6 +539,39 @@ case class SessionWindowStateStoreSaveExec( // assuming late events were dropped from MergingSortWithMultiValuesStateIterator outputMode match { + case Some(Complete) => + allUpdatesTimeMs += timeTakenMs { + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + val keys = keyProjection(row) + + if (currentKey == null || !keyOrdering.equiv(currentKey, keys)) { + currentKey = keys.copy() + + // This is necessary because MultiValuesStateManager doesn't guarantee + // stable ordering. 
+ // The number of values for the given key is expected to be likely small, + // so listing it here doesn't hurt. + previousSessions = stateManager.get(keys).toList + + stateManager.removeKey(keys) + } + + stateManager.append(keys, row) + + if (!previousSessions.exists(p => valueOrdering.equiv(row, p))) { + // such session is not in previous session + numUpdatedStateRows += 1 + } + } + } + + CompletionIterator[InternalRow, Iterator[InternalRow]]( + stateManager.getAllRowPairs.map(_.value), { + commitTimeMs += timeTakenMs { stateManager.commit() } + setStoreMetrics(stateManager) + }) + // Update and output only sessions being evicted from the MultiValuesStateManager // Assumption: watermark predicates must be non-empty if append mode is allowed case Some(Append) => @@ -548,7 +581,7 @@ case class SessionWindowStateStoreSaveExec( val keys = keyProjection(row) if (currentKey == null || !keyOrdering.equiv(currentKey, keys)) { - currentKey = keys + currentKey = keys.copy() // This is necessary because MultiValuesStateManager doesn't guarantee // stable ordering. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala index 0c5d858330d1b..c41390140c7a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala @@ -121,7 +121,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext ) } - test("session window groupBy with multiple keys statement - two distincts") { + test("session window groupBy with multiple keys statement - two distinct") { val df = Seq( ("2016-03-27 19:39:34", 1, 2, "a"), ("2016-03-27 19:39:39", 1, 2, "a"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala index ea0fa651d2f40..e941002a31189 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.execution.streaming +import java.util.Properties + +import org.apache.spark._ +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection @@ -45,9 +49,26 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { attr => List("aggVal1", "aggVal2").contains(attr.name) } + override def beforeAll(): Unit = { + super.beforeAll() + val taskManager = new TaskMemoryManager(new TestMemoryManager(sqlContext.sparkContext.conf), 0) + TaskContext.setTaskContext( + new TaskContextImpl(0, 0, 0, 0, 0, 
taskManager, new Properties, null)) + } + + override def afterAll(): Unit = try { + TaskContext.unset() + } finally { + super.afterAll() + } + + // just copying default values to avoid bothering with SQLContext + val inMemoryThreshold = 4096 + val spillThreshold = Int.MaxValue + test("no row") { val iterator = new UpdatingSessionIterator(None.iterator, keysWithoutSessionAttributes, - sessionAttribute, rowAttributes) + sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) assert(!iterator.hasNext) } @@ -56,7 +77,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val rows = List(createRow("a", 1, 100, 110, 10, 1.1)) val iterator = new UpdatingSessionIterator(rows.iterator, keysWithoutSessionAttributes, - sessionAttribute, rowAttributes) + sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) assert(iterator.hasNext) @@ -74,7 +95,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val rows = List(row1, row2, row3, row4) val iterator = new UpdatingSessionIterator(rows.iterator, keysWithoutSessionAttributes, - sessionAttribute, rowAttributes) + sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) val retRows = rows.indices.map { _ => assert(iterator.hasNext) @@ -105,7 +126,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val rowsAll = rows1 ++ rows2 val iterator = new UpdatingSessionIterator(rowsAll.iterator, keysWithoutSessionAttributes, - sessionAttribute, rowAttributes) + sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) val retRows1 = rows1.indices.map { _ => assert(iterator.hasNext) @@ -141,7 +162,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val rowsAll = rows1 ++ rows2 val iterator = new UpdatingSessionIterator(rowsAll.iterator, keysWithoutSessionAttributes, - sessionAttribute, rowAttributes) + sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) val retRows1 = rows1.indices.map { _ => assert(iterator.hasNext) @@ -186,7 
+207,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val rowsAll = rows1 ++ rows2 ++ rows3 ++ rows4 val iterator = new UpdatingSessionIterator(rowsAll.iterator, keysWithoutSessionAttributes, - sessionAttribute, rowAttributes) + sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) val retRows1 = rows1.indices.map { _ => assert(iterator.hasNext) @@ -239,7 +260,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val rows = List(row1, row2, row3, row4) val iterator = new UpdatingSessionIterator(rows.iterator, keysWithoutSessionAttributes, - sessionAttribute, rowAttributes) + sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) // UpdatingSessionIterator can't detect error on hasNext assert(iterator.hasNext) @@ -266,7 +287,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val rows = List(row1, row2, row3) val iterator = new UpdatingSessionIterator(rows.iterator, keysWithoutSessionAttributes, - sessionAttribute, rowAttributes) + sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) // UpdatingSessionIterator can't detect error on hasNext assert(iterator.hasNext) @@ -332,7 +353,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val rows = List(row1, row2, row3, row4) val iterator = new UpdatingSessionIterator(rows.iterator, Seq.empty[Attribute], - noKeySessionAttribute, noKeyRowAttributes) + noKeySessionAttribute, noKeyRowAttributes, inMemoryThreshold, spillThreshold) val retRows = rows.indices.map { _ => assert(iterator.hasNext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 7e60c2ccc55a1..f51fdab0377c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -472,297 +472,6 @@ class 
EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche ) } - test("append mode - session window") { - // Implements StructuredSessionization.scala leveraging "session" function - // as a test, to verify the sessionization works with simple example - - val inputData = MemoryStream[(String, Long)] - - // Split the lines into words, treat words as sessionId of events - val events = inputData.toDF() - .select($"_1".as("value"), $"_2".as("timestamp")) - .withColumn("eventTime", $"timestamp".cast("timestamp")) - .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") - .withWatermark("eventTime", "10 seconds") - - val sessionUpdates = events - .groupBy(session($"eventTime", "10 seconds") as 'session, 'sessionId) - .agg(count("*").as("numEvents")) - .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", - "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", - "numEvents") - - testStream(sessionUpdates, OutputMode.Append())( - AddData(inputData, ("hello world spark", 10L), ("world hello structured streaming", 11L)), - // Advance watermark to 1 seconds - // current sessions after batch: - // ("hello", 10, 21, 11, 2) - // ("world", 10, 21, 11, 2) - // ("spark", 10, 20, 10, 1) - // ("structured", 11, 21, 10, 2) - // ("streaming", 11, 21, 10, 2) - CheckNewAnswer(), - - AddData(inputData, ("spark streaming", 15L)), - // Advance watermark to 5 seconds - // current sessions after batch: - // ("hello", 10, 21, 11, 2) - // ("world", 10, 21, 11, 2) - // ("structured", 11, 21, 10, 2) - // ("spark", 10, 25, 15, 2) - // ("streaming", 11, 25, 14, 2) - CheckNewAnswer(), - - AddData(inputData, ("hello world", 25L)), - // Advance watermark to 15 seconds - // current sessions after batch: - // ("hello", 10, 21, 11, 2) - // ("world", 10, 21, 11, 2) - // ("structured", 11, 21, 10, 2) - // ("spark", 10, 25, 15, 2) - // ("streaming", 11, 25, 14, 2) - // ("hello", 25, 35, 10, 1) - // ("world", 25, 35, 10, 1) - 
CheckNewAnswer(), - - AddData(inputData, ("hello world", 3L)), - // input can match to not-yet-evicted sessions, but input itself is less than watermark - // so it should not match exiting sessions - // Watermark kept 15 seconds - // current sessions after batch: - // ("hello", 10, 21, 11, 2) - // ("world", 10, 21, 11, 2) - // ("structured", 11, 21, 10, 2) - // ("spark", 10, 25, 15, 2) - // ("streaming", 11, 25, 14, 2) - // ("hello", 25, 35, 10, 1) - // ("world", 25, 35, 10, 1) - CheckNewAnswer(), - - AddData(inputData, ("hello", 31L)), - // Advance watermark to 21 seconds - // current sessions after batch: - // ("spark", 10, 25, 15, 2) - // ("streaming", 11, 25, 14, 2) - // ("hello", 25, 41, 16, 2) - // ("world", 25, 35, 10, 1) - CheckNewAnswer( - ("hello", 10, 21, 11, 2), - ("world", 10, 21, 11, 2), - ("structured", 11, 21, 10, 1) - ), - - AddData(inputData, ("hello", 35L)), - // Advance watermark to 25 seconds - // current sessions after batch: - // ("hello", 25, 45, 20, 3) - // ("world", 25, 35, 10, 1) - CheckNewAnswer( - ("spark", 10, 25, 15, 2), - ("streaming", 11, 25, 14, 2) - ), - - AddData(inputData, ("hello apache spark", 60L)), - // Advance watermark to 50 seconds - // current sessions after batch: - // ("hello", 60, 70, 10, 1) - // ("apache", 60, 70, 10, 1) - // ("spark", 60, 70, 10, 1) - CheckNewAnswer(("hello", 25, 45, 20, 3), ("world", 25, 35, 10, 1)) - ) - } - - test("append mode - session window - no key") { - val inputData = MemoryStream[Int] - - val windowedAggregation = inputData.toDF() - .selectExpr("*") - .withColumn("eventTime", $"value".cast("timestamp")) - .withWatermark("eventTime", "10 seconds") - .groupBy(session($"eventTime", "5 seconds") as 'session) - .agg(count("*") as 'count, sum("value") as 'sum) - .select($"session".getField("start").cast("long").as[Long], - $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) - - testStream(windowedAggregation)( - AddData(inputData, 10, 11), // sessions: (10,16) - 
CheckNewAnswer(), - - AddData(inputData, 17), - // Advance watermark to 7 seconds - // sessions: (10,16), (17,23) - CheckNewAnswer(), - - AddData(inputData, 25), - // Advance watermark to 15 seconds - // sessions: (10,16), (17,23), (25,30) - CheckNewAnswer(), - - AddData(inputData, 35), - // Advance watermark to 25 seconds - // sessions: (10,16), (17,22), (25,30), (35,40) - // evicts: (10,16), (17,22) - CheckNewAnswer((10, 16, 2, 21), (17, 22, 1, 17)), - - AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckNewAnswer(), - - AddData(inputData, 40), - // Advance watermark to 30 seconds - // sessions: (25,30) / (35,45) - // evicts: (25,30) - CheckNewAnswer((25, 30, 1, 25)) - ) - } - - test("update mode - session window") { - // Implements StructuredSessionization.scala leveraging "session" function - // as a test, to verify the sessionization works with simple example - - val inputData = MemoryStream[(String, Long)] - - // Split the lines into words, treat words as sessionId of events - val events = inputData.toDF() - .select($"_1".as("value"), $"_2".as("timestamp")) - .withColumn("eventTime", $"timestamp".cast("timestamp")) - .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") - .withWatermark("eventTime", "10 seconds") - - val sessionUpdates = events - .groupBy(session($"eventTime", "10 seconds") as 'session, 'sessionId) - .agg(count("*").as("numEvents")) - .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", - "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", - "numEvents") - - testStream(sessionUpdates, OutputMode.Update())( - AddData(inputData, ("hello world spark", 10L), ("world hello structured streaming", 11L)), - // Advance watermark to 1 seconds - // current sessions after batch: - // ("hello", 10, 21, 11, 2) - // ("world", 10, 21, 11, 2) - // ("spark", 10, 20, 10, 1) - // ("structured", 11, 21, 10, 1) - // ("streaming", 11, 21, 10, 1) - CheckNewAnswer( 
- ("hello", 10, 21, 11, 2), - ("world", 10, 21, 11, 2), - ("spark", 10, 20, 10, 1), - ("structured", 11, 21, 10, 1), - ("streaming", 11, 21, 10, 1) - ), - - AddData(inputData, ("spark streaming", 15L)), - // Advance watermark to 5 seconds - // current sessions after batch: - // ("hello", 10, 21, 11, 2) - // ("world", 10, 21, 11, 2) - // ("structured", 11, 21, 10, 2) - // ("spark", 10, 25, 15, 2) - // ("streaming", 11, 25, 14, 2) - CheckNewAnswer(("spark", 10, 25, 15, 2), ("streaming", 11, 25, 14, 2)), - - AddData(inputData, ("hello world", 25L)), - // Advance watermark to 15 seconds - // current sessions after batch: - // ("hello", 10, 21, 11, 2) - // ("world", 10, 21, 11, 2) - // ("structured", 11, 21, 10, 2) - // ("spark", 10, 25, 15, 2) - // ("streaming", 11, 25, 14, 2) - // ("hello", 25, 35, 10, 1) - // ("world", 25, 35, 10, 1) - CheckNewAnswer(("hello", 25, 35, 10, 1), ("world", 25, 35, 10, 1)), - - AddData(inputData, ("hello world", 3L)), - // input can match to not-yet-evicted sessions, but input itself is less than watermark - // so it should not match exiting sessions - // Watermark kept 15 seconds - // current sessions after batch: - // ("hello", 10, 21, 11, 2) - // ("world", 10, 21, 11, 2) - // ("structured", 11, 21, 10, 2) - // ("spark", 10, 25, 15, 2) - // ("streaming", 11, 25, 14, 2) - // ("hello", 25, 35, 10, 1) - // ("world", 25, 35, 10, 1) - CheckNewAnswer(), - - AddData(inputData, ("hello", 31L)), - // Advance watermark to 21 seconds - // current sessions after batch: - // ("spark", 10, 25, 15, 2) - // ("streaming", 11, 25, 14, 2) - // ("hello", 25, 41, 16, 2) - // ("world", 25, 35, 10, 1) - CheckNewAnswer(("hello", 25, 41, 16, 2)), - - AddData(inputData, ("hello", 35L)), - // Advance watermark to 25 seconds - // current sessions after batch: - // ("hello", 25, 45, 20, 3) - // ("world", 25, 35, 10, 1) - CheckNewAnswer(("hello", 25, 45, 20, 3)), - - AddData(inputData, ("hello apache spark", 60L)), - // Advance watermark to 50 seconds - // current 
sessions after batch: - // ("hello", 60, 70, 10, 1) - // ("apache", 60, 70, 10, 1) - // ("spark", 60, 70, 10, 1) - CheckNewAnswer(("hello", 60, 70, 10, 1), ("apache", 60, 70, 10, 1), ("spark", 60, 70, 10, 1)) - ) - } - - test("update mode - session window - no key") { - val inputData = MemoryStream[Int] - - val windowedAggregation = inputData.toDF() - .selectExpr("*") - .withColumn("eventTime", $"value".cast("timestamp")) - .withWatermark("eventTime", "10 seconds") - .groupBy(session($"eventTime", "5 seconds") as 'session) - .agg(count("*") as 'count, sum("value") as 'sum) - .select($"session".getField("start").cast("long").as[Long], - $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) - - testStream(windowedAggregation, OutputMode.Update())( - - AddData(inputData, 10, 11), - // Advance watermark to 1 seconds - // sessions: (10,16) - CheckNewAnswer((10, 16, 2, 21)), - - AddData(inputData, 17), - // Advance watermark to 7 seconds - // sessions: (10,16), (17,22) - // updated: (17,22) - CheckNewAnswer((17, 22, 1, 17)), - - AddData(inputData, 25), - // Advance watermark to 15 seconds - // sessions: (10,16), (17,22), (25,30) - // updated: (25,30) - CheckNewAnswer((25, 30, 1, 25)), - - AddData(inputData, 35), - // Advance watermark to 25 seconds - // sessions: (10,16), (17,22), (25,30), (35,40) - // updated: (35, 40) - // evicts: (10,16), (17,22) - CheckNewAnswer((35, 40, 1, 35)), - - AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckNewAnswer(), - - AddData(inputData, 40), - // Advance watermark to 30 seconds - // sessions: (25,30), (35,45) - // updated: (35, 45) - CheckNewAnswer((35, 45, 2, 75)) - ) - } - test("group by on raw timestamp") { val inputData = MemoryStream[Int] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala new file mode 100644 index 
0000000000000..e4b3aaeaeb141 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala @@ -0,0 +1,565 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.scalatest.{BeforeAndAfter, Matchers} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions.{count, session, sum} + +class StreamingSessionWindowSuite extends StreamTest + with BeforeAndAfter with Matchers with Logging { + + import testImplicits._ + + after { + sqlContext.streams.active.foreach(_.stop()) + } + + test("complete mode - session window") { + // Implements StructuredSessionization.scala leveraging "session" function + // as a test, to verify the sessionization works with simple example + + // note that complete mode doesn't honor watermark: even it is specified, watermark will be + // always Unix timestamp 0 + + val inputData = MemoryStream[(String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("timestamp")) + .withColumn("eventTime", 
$"timestamp".cast("timestamp")) + .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") + + val sessionUpdates = events + .groupBy(session($"eventTime", "10 seconds") as 'session, 'sessionId) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") + + testStream(sessionUpdates, OutputMode.Complete())( + AddData(inputData, ("hello world spark", 10L), ("world hello structured streaming", 11L)), + CheckNewAnswer( + ("hello", 10, 21, 11, 2), + ("world", 10, 21, 11, 2), + ("spark", 10, 20, 10, 1), + ("structured", 11, 21, 10, 1), + ("streaming", 11, 21, 10, 1) + ), + + AddData(inputData, ("spark streaming", 15L)), + CheckNewAnswer( + ("hello", 10, 21, 11, 2), + ("world", 10, 21, 11, 2), + ("spark", 10, 25, 15, 2), + ("structured", 11, 21, 10, 1), + ("streaming", 11, 25, 14, 2) + ), + + AddData(inputData, ("hello world", 25L)), + CheckNewAnswer( + ("hello", 10, 21, 11, 2), + ("world", 10, 21, 11, 2), + ("spark", 10, 25, 15, 2), + ("structured", 11, 21, 10, 1), + ("streaming", 11, 25, 14, 2), + ("hello", 25, 35, 10, 1), + ("world", 25, 35, 10, 1) + ), + + AddData(inputData, ("hello world", 3L)), + CheckNewAnswer( + ("hello", 3, 21, 18, 3), + ("world", 3, 21, 18, 3), + ("spark", 10, 25, 15, 2), + ("structured", 11, 21, 10, 1), + ("streaming", 11, 25, 14, 2), + ("hello", 25, 35, 10, 1), + ("world", 25, 35, 10, 1) + ), + + AddData(inputData, ("hello", 31L)), + CheckNewAnswer( + ("hello", 3, 21, 18, 3), + ("world", 3, 21, 18, 3), + ("spark", 10, 25, 15, 2), + ("structured", 11, 21, 10, 1), + ("streaming", 11, 25, 14, 2), + ("hello", 25, 41, 16, 2), + ("world", 25, 35, 10, 1) + ), + + AddData(inputData, ("hello", 35L)), + CheckNewAnswer( + ("hello", 3, 21, 18, 3), + ("world", 3, 21, 18, 3), + ("spark", 10, 25, 15, 2), + ("structured", 11, 21, 10, 1), + ("streaming", 11, 25, 14, 2), + ("hello", 25, 45, 20, 3), 
+ ("world", 25, 35, 10, 1) + ), + + AddData(inputData, ("hello apache spark", 60L)), + CheckNewAnswer( + ("hello", 3, 21, 18, 3), + ("world", 3, 21, 18, 3), + ("spark", 10, 25, 15, 2), + ("structured", 11, 21, 10, 1), + ("streaming", 11, 25, 14, 2), + ("hello", 25, 45, 20, 3), + ("world", 25, 35, 10, 1), + ("hello", 60, 70, 10, 1), + ("apache", 60, 70, 10, 1), + ("spark", 60, 70, 10, 1) + ) + ) + } + + test("complete mode - session window - no key") { + // complete mode doesn't honor watermark: even it is specified, watermark will be + // always Unix timestamp 0 + + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*") + .withColumn("eventTime", $"value".cast("timestamp")) + .groupBy(session($"eventTime", "5 seconds") as 'session) + .agg(count("*") as 'count, sum("value") as 'sum) + .select($"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + + testStream(windowedAggregation, OutputMode.Complete())( + AddData(inputData, 10, 11), + CheckNewAnswer((10, 16, 2, 21)), + + AddData(inputData, 17), + CheckNewAnswer( + (10, 16, 2, 21), + (17, 22, 1, 17) + ), + + AddData(inputData, 35), + CheckNewAnswer( + (10, 16, 2, 21), + (17, 22, 1, 17), + (35, 40, 1, 35) + ), + + // should reflect late row + AddData(inputData, 22), + CheckNewAnswer( + (10, 16, 2, 21), + (17, 27, 2, 39), + (35, 40, 1, 35) + ), + + AddData(inputData, 40), + CheckNewAnswer( + (10, 16, 2, 21), + (17, 27, 2, 39), + (35, 45, 2, 75) + ) + ) + } + + test("append mode - session window") { + // Implements StructuredSessionization.scala leveraging "session" function + // as a test, to verify the sessionization works with simple example + + val inputData = MemoryStream[(String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("timestamp")) + .withColumn("eventTime", 
$"timestamp".cast("timestamp")) + .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") + .withWatermark("eventTime", "10 seconds") + + val sessionUpdates = events + .groupBy(session($"eventTime", "10 seconds") as 'session, 'sessionId) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") + + testStream(sessionUpdates, OutputMode.Append())( + AddData(inputData, ("hello world spark", 10L), ("world hello structured streaming", 11L)), + // Advance watermark to 1 seconds + // current sessions after batch: + // ("hello", 10, 21, 11, 2) + // ("world", 10, 21, 11, 2) + // ("spark", 10, 20, 10, 1) + // ("structured", 11, 21, 10, 2) + // ("streaming", 11, 21, 10, 2) + CheckNewAnswer(), + + AddData(inputData, ("spark streaming", 15L)), + // Advance watermark to 5 seconds + // current sessions after batch: + // ("hello", 10, 21, 11, 2) + // ("world", 10, 21, 11, 2) + // ("structured", 11, 21, 10, 2) + // ("spark", 10, 25, 15, 2) + // ("streaming", 11, 25, 14, 2) + CheckNewAnswer(), + + AddData(inputData, ("hello world", 25L)), + // Advance watermark to 15 seconds + // current sessions after batch: + // ("hello", 10, 21, 11, 2) + // ("world", 10, 21, 11, 2) + // ("structured", 11, 21, 10, 2) + // ("spark", 10, 25, 15, 2) + // ("streaming", 11, 25, 14, 2) + // ("hello", 25, 35, 10, 1) + // ("world", 25, 35, 10, 1) + CheckNewAnswer(), + + AddData(inputData, ("hello world", 3L)), + // input can match to not-yet-evicted sessions, but input itself is less than watermark + // so it should not match exiting sessions + // Watermark kept 15 seconds + // current sessions after batch: + // ("hello", 10, 21, 11, 2) + // ("world", 10, 21, 11, 2) + // ("structured", 11, 21, 10, 2) + // ("spark", 10, 25, 15, 2) + // ("streaming", 11, 25, 14, 2) + // ("hello", 25, 35, 10, 1) + // ("world", 25, 35, 10, 1) + CheckNewAnswer(), 
+ + AddData(inputData, ("hello", 31L)), + // Advance watermark to 21 seconds + // current sessions after batch: + // ("spark", 10, 25, 15, 2) + // ("streaming", 11, 25, 14, 2) + // ("hello", 25, 41, 16, 2) + // ("world", 25, 35, 10, 1) + CheckNewAnswer( + ("hello", 10, 21, 11, 2), + ("world", 10, 21, 11, 2), + ("structured", 11, 21, 10, 1) + ), + + AddData(inputData, ("hello", 35L)), + // Advance watermark to 25 seconds + // current sessions after batch: + // ("hello", 25, 45, 20, 3) + // ("world", 25, 35, 10, 1) + CheckNewAnswer( + ("spark", 10, 25, 15, 2), + ("streaming", 11, 25, 14, 2) + ), + + AddData(inputData, ("hello apache spark", 60L)), + // Advance watermark to 50 seconds + // current sessions after batch: + // ("hello", 60, 70, 10, 1) + // ("apache", 60, 70, 10, 1) + // ("spark", 60, 70, 10, 1) + CheckNewAnswer(("hello", 25, 45, 20, 3), ("world", 25, 35, 10, 1)) + ) + } + + test("append mode - session window - storing multiple sessions in given key") { + val inputData = MemoryStream[Int] + val windowedAggregation = inputData.toDF() + .selectExpr("*", "CAST(MOD(value, 2) AS INT) AS valuegroup") + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(session($"eventTime", "5 seconds") as 'session, 'valuegroup) + .agg(count("*") as 'count, sum("value") as 'sum) + .select($"valuegroup", $"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + + testStream(windowedAggregation, OutputMode.Append())( + AddData(inputData, 10, 11, 12, 13), + // Advance watermark to 3 seconds + // sessions: key 0 => (10, 17, 2, 22) / key 1 => (11, 18, 2, 24) + CheckNewAnswer(), + AddData(inputData, 17), + // Advance watermark to 7 seconds + // sessions: key 0 => (10, 17, 2, 22) / key 1 => (11, 22, 3, 41) + CheckNewAnswer(), + AddData(inputData, 25), + // Advance watermark to 15 seconds + // sessions: key 0 => (10, 17, 2, 22) / key 1 => (11, 22, 
3, 41), (25, 30, 1, 25) + CheckNewAnswer(), + AddData(inputData, 35), + // Advance watermark to 25 seconds + // sessions: key 0 => (10, 17, 2, 22) / key 1 => (11, 22, 3, 41), (25, 30, 1, 25), + // (35, 40, 1, 35) + // evicts: key 0 => (10, 17, 2, 22) / key 1 => (11, 22, 3, 41) + CheckNewAnswer((0, 10, 17, 2, 22), (1, 11, 22, 3, 41)), + AddData(inputData, 27), + // don't advance watermark + // sessions: key 1 => (25, 32, 2, 52), (35, 40, 1, 35) + CheckNewAnswer(), + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckNewAnswer(), + AddData(inputData, 40), + // Advance watermark to 30 seconds + // sessions: key 0 => (40, 45, 1, 40) / key 1 => (25, 32, 2, 52), (35, 40, 1, 35) + CheckNewAnswer() + ) + } + + test("append mode - session window - no key") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*") + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(session($"eventTime", "5 seconds") as 'session) + .agg(count("*") as 'count, sum("value") as 'sum) + .select($"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 10, 11), // sessions: (10,16) + CheckNewAnswer(), + + AddData(inputData, 17), + // Advance watermark to 7 seconds + // sessions: (10,16), (17,23) + CheckNewAnswer(), + + AddData(inputData, 25), + // Advance watermark to 15 seconds + // sessions: (10,16), (17,23), (25,30) + CheckNewAnswer(), + + AddData(inputData, 35), + // Advance watermark to 25 seconds + // sessions: (10,16), (17,22), (25,30), (35,40) + // evicts: (10,16), (17,22) + CheckNewAnswer((10, 16, 2, 21), (17, 22, 1, 17)), + + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckNewAnswer(), + + AddData(inputData, 40), + // Advance watermark to 30 seconds + // sessions: (25,30) / 
(35,45) + // evicts: (25,30) + CheckNewAnswer((25, 30, 1, 25)) + ) + } + + test("update mode - session window") { + // Implements StructuredSessionization.scala leveraging "session" function + // as a test, to verify the sessionization works with simple example + + val inputData = MemoryStream[(String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("timestamp")) + .withColumn("eventTime", $"timestamp".cast("timestamp")) + .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") + .withWatermark("eventTime", "10 seconds") + + val sessionUpdates = events + .groupBy(session($"eventTime", "10 seconds") as 'session, 'sessionId) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") + + testStream(sessionUpdates, OutputMode.Update())( + AddData(inputData, ("hello world spark", 10L), ("world hello structured streaming", 11L)), + // Advance watermark to 1 seconds + // current sessions after batch: + // ("hello", 10, 21, 11, 2) + // ("world", 10, 21, 11, 2) + // ("spark", 10, 20, 10, 1) + // ("structured", 11, 21, 10, 1) + // ("streaming", 11, 21, 10, 1) + CheckNewAnswer( + ("hello", 10, 21, 11, 2), + ("world", 10, 21, 11, 2), + ("spark", 10, 20, 10, 1), + ("structured", 11, 21, 10, 1), + ("streaming", 11, 21, 10, 1) + ), + + AddData(inputData, ("spark streaming", 15L)), + // Advance watermark to 5 seconds + // current sessions after batch: + // ("hello", 10, 21, 11, 2) + // ("world", 10, 21, 11, 2) + // ("structured", 11, 21, 10, 2) + // ("spark", 10, 25, 15, 2) + // ("streaming", 11, 25, 14, 2) + CheckNewAnswer(("spark", 10, 25, 15, 2), ("streaming", 11, 25, 14, 2)), + + AddData(inputData, ("hello world", 25L)), + // Advance watermark to 15 seconds + // current sessions after batch: + // ("hello", 10, 21, 11, 
2) + // ("world", 10, 21, 11, 2) + // ("structured", 11, 21, 10, 2) + // ("spark", 10, 25, 15, 2) + // ("streaming", 11, 25, 14, 2) + // ("hello", 25, 35, 10, 1) + // ("world", 25, 35, 10, 1) + CheckNewAnswer(("hello", 25, 35, 10, 1), ("world", 25, 35, 10, 1)), + + AddData(inputData, ("hello world", 3L)), + // input can match to not-yet-evicted sessions, but input itself is less than watermark + // so it should not match exiting sessions + // Watermark kept 15 seconds + // current sessions after batch: + // ("hello", 10, 21, 11, 2) + // ("world", 10, 21, 11, 2) + // ("structured", 11, 21, 10, 2) + // ("spark", 10, 25, 15, 2) + // ("streaming", 11, 25, 14, 2) + // ("hello", 25, 35, 10, 1) + // ("world", 25, 35, 10, 1) + CheckNewAnswer(), + + AddData(inputData, ("hello", 31L)), + // Advance watermark to 21 seconds + // current sessions after batch: + // ("spark", 10, 25, 15, 2) + // ("streaming", 11, 25, 14, 2) + // ("hello", 25, 41, 16, 2) + // ("world", 25, 35, 10, 1) + CheckNewAnswer(("hello", 25, 41, 16, 2)), + + AddData(inputData, ("hello", 35L)), + // Advance watermark to 25 seconds + // current sessions after batch: + // ("hello", 25, 45, 20, 3) + // ("world", 25, 35, 10, 1) + CheckNewAnswer(("hello", 25, 45, 20, 3)), + + AddData(inputData, ("hello apache spark", 60L)), + // Advance watermark to 50 seconds + // current sessions after batch: + // ("hello", 60, 70, 10, 1) + // ("apache", 60, 70, 10, 1) + // ("spark", 60, 70, 10, 1) + CheckNewAnswer(("hello", 60, 70, 10, 1), ("apache", 60, 70, 10, 1), ("spark", 60, 70, 10, 1)) + ) + } + + test("update mode - session window - storing multiple sessions in given key") { + val inputData = MemoryStream[Int] + val windowedAggregation = inputData.toDF() + .selectExpr("*", "CAST(MOD(value, 2) AS INT) AS valuegroup") + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(session($"eventTime", "5 seconds") as 'session, 'valuegroup) + .agg(count("*") as 'count, 
sum("value") as 'sum) + .select($"valuegroup", $"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + + testStream(windowedAggregation, OutputMode.Update())( + AddData(inputData, 10, 11, 12, 13), + // Advance watermark to 3 seconds + // sessions: key 0 => (10,17) / key 1 => (11, 18) + CheckNewAnswer((0, 10, 17, 2, 22), (1, 11, 18, 2, 24)), + AddData(inputData, 17), + // Advance watermark to 7 seconds + // sessions: key 0 => (10,17) / key 1 => (11, 22) + // updated: key 1 => (11,22) + CheckNewAnswer((1, 11, 22, 3, 41)), + AddData(inputData, 25), + // Advance watermark to 15 seconds + // sessions: key 0 => (10,17) / key 1 => (11,22), (25,30) + // updated: key 1 => (25,30) + CheckNewAnswer((1, 25, 30, 1, 25)), + AddData(inputData, 35), + // Advance watermark to 25 seconds + // sessions: key 0 => (10,17) / key 1 => (11,22), (25,30), (35,40) + // updated: key 1 => (35,40) + // evicts: key 0 => (10,17) / key 1 => (11,22) + CheckNewAnswer((1, 35, 40, 1, 35)), + AddData(inputData, 27), + // don't advance watermark + // sessions: key 1 => (25,32), (35,40) + // updated: key 1 => (25,32) + CheckNewAnswer((1, 25, 32, 2, 52)), + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckNewAnswer(), + AddData(inputData, 40), + // Advance watermark to 30 seconds + // sessions: key 0 => (40,45) / key 1 => (25,32), (35,40) + // updated: key 0 => (40,45) + CheckNewAnswer((0, 40, 45, 1, 40)) + ) + } + + test("update mode - session window - no key") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*") + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(session($"eventTime", "5 seconds") as 'session) + .agg(count("*") as 'count, sum("value") as 'sum) + .select($"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], 
$"count".as[Long], $"sum".as[Long]) + + testStream(windowedAggregation, OutputMode.Update())( + + AddData(inputData, 10, 11), + // Advance watermark to 1 seconds + // sessions: (10,16) + CheckNewAnswer((10, 16, 2, 21)), + + AddData(inputData, 17), + // Advance watermark to 7 seconds + // sessions: (10,16), (17,22) + // updated: (17,22) + CheckNewAnswer((17, 22, 1, 17)), + + AddData(inputData, 25), + // Advance watermark to 15 seconds + // sessions: (10,16), (17,22), (25,30) + // updated: (25,30) + CheckNewAnswer((25, 30, 1, 25)), + + AddData(inputData, 35), + // Advance watermark to 25 seconds + // sessions: (10,16), (17,22), (25,30), (35,40) + // updated: (35, 40) + // evicts: (10,16), (17,22) + CheckNewAnswer((35, 40, 1, 35)), + + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckNewAnswer(), + + AddData(inputData, 40), + // Advance watermark to 30 seconds + // sessions: (25,30), (35,45) + // updated: (35, 45) + CheckNewAnswer((35, 45, 2, 75)) + ) + } + +} From fd6377b69eaa8e1891448219744810562ebf4586 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 9 Oct 2018 07:20:29 +0900 Subject: [PATCH 37/60] WIP add "session" function to exclude list for description --- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 631ab1b7ece7f..4c99a86deb19a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -107,7 +107,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-14415: All functions should have own descriptions") { for (f <- spark.sessionState.functionRegistry.listFunction()) { - if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f.unquotedString)) { + val excludes = 
Seq("cube", "grouping", "grouping_id", "rollup", "window", "session") + if (!excludes.contains(f.unquotedString)) { checkKeywordsNotExist(sql(s"describe function `$f`"), "N/A.") } } From e029e12753091480450e22ab39ba6d6cd1ad9a82 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 10 Oct 2018 15:03:48 +0900 Subject: [PATCH 38/60] WIP rename function & column name "session" to "session_window" * This is to be compatible with Baidu's patch --- python/pyspark/sql/functions.py | 4 +- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../org/apache/spark/sql/functions.scala | 4 +- .../sql/DataFrameSessionWindowingSuite.scala | 68 ++++++++++--------- .../streaming/EventTimeWatermarkSuite.scala | 2 +- .../StreamingSessionWindowSuite.scala | 18 ++--- 7 files changed, 52 insertions(+), 48 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 794ba4f83b2af..19c85b8d3c94e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1399,7 +1399,7 @@ def check_string_field(field, fieldName): @since(3.0) @ignore_unicode_prefix -def session(timeColumn, gapDuration): +def session_window(timeColumn, gapDuration): """ # FIXME: python doc!! 
""" @@ -1410,7 +1410,7 @@ def check_string_field(field, fieldName): sc = SparkContext._active_spark_context time_col = _to_java_column(timeColumn) check_string_field(gapDuration, "gapDuration") - res = sc._jvm.functions.session(time_col, gapDuration) + res = sc._jvm.functions.session_window(time_col, gapDuration) return Column(res) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e3c3765fcd334..94e4a2fa00ffe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2730,7 +2730,7 @@ object TimeWindowing extends Rule[LogicalPlan] { object SessionWindowing extends Rule[LogicalPlan] { import org.apache.spark.sql.catalyst.dsl.expressions._ - private final val SESSION_COL_NAME = "session" + private final val SESSION_COL_NAME = "session_window" private final val SESSION_START = "start" private final val SESSION_END = "end" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 36d1fff4d834b..620b2b396ce40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -398,7 +398,7 @@ object FunctionRegistry { expression[WeekOfYear]("weekofyear"), expression[Year]("year"), expression[TimeWindow]("window"), - expression[SessionWindow]("session"), + expression[SessionWindow]("session_window"), // collection functions expression[CreateArray]("array"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 55bd85f459ddf..192688ded73e0 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3263,10 +3263,10 @@ object functions { } // FIXME: javadoc! - def session(timeColumn: Column, gapDuration: String): Column = { + def session_window(timeColumn: Column, gapDuration: String): Column = { withExpr { SessionWindow(timeColumn.expr, gapDuration) - }.as("session") + }.as("session_window") } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala index c41390140c7a0..9f56d9f5962f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala @@ -34,10 +34,11 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext ("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id") checkAnswer( - df.groupBy(session($"time", "10 seconds")) + df.groupBy(session_window($"time", "10 seconds")) .agg(count("*").as("counts")) - .orderBy($"session.start".asc) - .select($"session.start".cast("string"), $"session.end".cast("string"), $"counts"), + .orderBy($"session_window.start".asc) + .select($"session_window.start".cast("string"), $"session_window.end".cast("string"), + $"counts"), Seq( Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1) ) @@ -54,9 +55,9 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext // whereas time window doesn't checkAnswer( - df.groupBy(session($"time", "10 seconds")) + df.groupBy(session_window($"time", "10 seconds")) .agg(count("*").as("counts")) - .orderBy($"session.start".asc) + .orderBy($"session_window.start".asc) .select("counts"), Seq(Row(2), Row(1)) ) @@ -78,11 +79,11 @@ class DataFrameSessionWindowingSuite extends 
QueryTest with SharedSQLContext // key "b" => (19:39:27 ~ 19:39:37) checkAnswer( - df.groupBy(session($"time", "10 seconds"), 'id) + df.groupBy(session_window($"time", "10 seconds"), 'id) .agg(count("*").as("counts"), sum("value").as("sum")) - .orderBy($"session.start".asc) - .selectExpr("CAST(session.start AS STRING)", "CAST(session.end AS STRING)", "id", - "counts", "sum"), + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "counts", "sum"), Seq( Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), @@ -108,11 +109,11 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext // key "b" => (19:39:27 ~ 19:39:37) checkAnswer( - df.groupBy(session($"time", "10 seconds"), 'id) + df.groupBy(session_window($"time", "10 seconds"), 'id) .agg(count("*").as("counts"), sumDistinct("value").as("sum")) - .orderBy($"session.start".asc) - .selectExpr("CAST(session.start AS STRING)", "CAST(session.end AS STRING)", "id", - "counts", "sum"), + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "counts", "sum"), Seq( Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 2, 1), @@ -137,11 +138,11 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext // key "b" => (19:39:27 ~ 19:39:37) checkAnswer( - df.groupBy(session($"time", "10 seconds"), 'id) + df.groupBy(session_window($"time", "10 seconds"), 'id) .agg(sumDistinct("value").as("sum"), sumDistinct("value2").as("sum2")) - .orderBy($"session.start".asc) - .selectExpr("CAST(session.start AS STRING)", "CAST(session.end AS STRING)", "id", - "sum", "sum2"), + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "sum", "sum2"), Seq( Row("2016-03-27 19:39:27", 
"2016-03-27 19:39:37", "b", 4, 8), Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 1, 2), @@ -166,11 +167,11 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext // b => (19:39:27 ~ 19:39:37), (19:39:39 ~ 19:39:55) checkAnswer( - df.groupBy(session($"time", "10 seconds"), 'id) + df.groupBy(session_window($"time", "10 seconds"), 'id) .agg(count("*").as("counts"), sum("value").as("sum")) - .orderBy($"session.start".asc) - .selectExpr("CAST(session.start AS STRING)", "CAST(session.end AS STRING)", "id", - "counts", "sum"), + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "counts", "sum"), Seq( Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), @@ -185,9 +186,10 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext ("2016-03-27 19:39:34", 1, "a"), ("2016-03-27 19:39:56", 2, "a"), ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") - .select(session($"time", "10 seconds"), $"value") - .orderBy($"session.start".asc) - .select($"session.start".cast("string"), $"session.end".cast("string"), $"value") + .select(session_window($"time", "10 seconds"), $"value") + .orderBy($"session_window.start".asc) + .select($"session_window.start".cast("string"), $"session_window.end".cast("string"), + $"value") val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand]) assert(expands.isEmpty, "Session windows shouldn't require expand") @@ -208,8 +210,8 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids") checkAnswer( - df.select(session($"time", "10 seconds"), $"value", explode($"ids")) - .orderBy($"session.start".asc).select("value"), + df.select(session_window($"time", "10 seconds"), $"value", explode($"ids")) + .orderBy($"session_window.start".asc).select("value"), // first window exploded to two 
rows for "a", and "b", second window exploded to 3 rows Seq(Row(1), Row(1), Row(2), Row(2), Row(2)) ) @@ -223,8 +225,8 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext (null, 4)).toDF("time", "value") checkDataset( - df.select(session($"time", "10 seconds"), $"value") - .orderBy($"session.start".asc) + df.select(session_window($"time", "10 seconds"), $"value") + .orderBy($"session_window.start".asc) .select("value") .as[Int], 1, 2) // null columns are dropped @@ -238,7 +240,8 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext ("2016-03-27 09:00:02", 3), ("2016-03-27 09:00:35", 6)).toDF("time", "value") val e = intercept[AnalysisException] { - df.select(session($"time", "10 second"), session($"time", "15 second")).collect() + df.select(session_window($"time", "10 second"), session_window($"time", "15 second")) + .collect() } assert(e.getMessage.contains( "Multiple time/session window expressions would result in a cartesian product")) @@ -250,7 +253,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids") checkAnswer( - df.select(session($"time", "10 seconds").as("session_window"), $"value") + df.select(session_window($"time", "10 seconds").as("session_window"), $"value") .orderBy($"session_window.start".asc) .select("value"), Seq(Row(1), Row(2)) @@ -273,8 +276,9 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSQLContext test("time window in SQL with single string expression") { withTempTable { table => checkAnswer( - spark.sql(s"""select session(time, "10 seconds"), value from $table""") - .select($"session.start".cast(StringType), $"session.end".cast(StringType), $"value"), + spark.sql(s"""select session_window(time, "10 seconds"), value from $table""") + .select($"session_window.start".cast(StringType), $"session_window.end".cast(StringType), + $"value"), Seq( Row("2016-03-27 
19:39:27", "2016-03-27 19:39:37", 4), Row("2016-03-27 19:39:34", "2016-03-27 19:39:44", 1), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index f51fdab0377c5..1558669f3fd95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.functions.{count, max, session, sum, window} +import org.apache.spark.sql.functions.{count, max, session_window, sum, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala index e4b3aaeaeb141..d4eca1de34b86 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala @@ -21,7 +21,7 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.functions.{count, session, sum} +import org.apache.spark.sql.functions.{count, session_window, sum} class StreamingSessionWindowSuite extends StreamTest with BeforeAndAfter with Matchers with Logging { @@ -48,7 +48,7 @@ class StreamingSessionWindowSuite extends StreamTest .selectExpr("explode(split(value, ' ')) AS sessionId", 
"eventTime") val sessionUpdates = events - .groupBy(session($"eventTime", "10 seconds") as 'session, 'sessionId) + .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId) .agg(count("*").as("numEvents")) .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", @@ -142,7 +142,7 @@ class StreamingSessionWindowSuite extends StreamTest val windowedAggregation = inputData.toDF() .selectExpr("*") .withColumn("eventTime", $"value".cast("timestamp")) - .groupBy(session($"eventTime", "5 seconds") as 'session) + .groupBy(session_window($"eventTime", "5 seconds") as 'session) .agg(count("*") as 'count, sum("value") as 'sum) .select($"session".getField("start").cast("long").as[Long], $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) @@ -195,7 +195,7 @@ class StreamingSessionWindowSuite extends StreamTest .withWatermark("eventTime", "10 seconds") val sessionUpdates = events - .groupBy(session($"eventTime", "10 seconds") as 'session, 'sessionId) + .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId) .agg(count("*").as("numEvents")) .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", @@ -287,7 +287,7 @@ class StreamingSessionWindowSuite extends StreamTest .selectExpr("*", "CAST(MOD(value, 2) AS INT) AS valuegroup") .withColumn("eventTime", $"value".cast("timestamp")) .withWatermark("eventTime", "10 seconds") - .groupBy(session($"eventTime", "5 seconds") as 'session, 'valuegroup) + .groupBy(session_window($"eventTime", "5 seconds") as 'session, 'valuegroup) .agg(count("*") as 'count, sum("value") as 'sum) .select($"valuegroup", $"session".getField("start").cast("long").as[Long], $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) @@ -331,7 +331,7 @@ class 
StreamingSessionWindowSuite extends StreamTest .selectExpr("*") .withColumn("eventTime", $"value".cast("timestamp")) .withWatermark("eventTime", "10 seconds") - .groupBy(session($"eventTime", "5 seconds") as 'session) + .groupBy(session_window($"eventTime", "5 seconds") as 'session) .agg(count("*") as 'count, sum("value") as 'sum) .select($"session".getField("start").cast("long").as[Long], $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) @@ -381,7 +381,7 @@ class StreamingSessionWindowSuite extends StreamTest .withWatermark("eventTime", "10 seconds") val sessionUpdates = events - .groupBy(session($"eventTime", "10 seconds") as 'session, 'sessionId) + .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId) .agg(count("*").as("numEvents")) .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", @@ -472,7 +472,7 @@ class StreamingSessionWindowSuite extends StreamTest .selectExpr("*", "CAST(MOD(value, 2) AS INT) AS valuegroup") .withColumn("eventTime", $"value".cast("timestamp")) .withWatermark("eventTime", "10 seconds") - .groupBy(session($"eventTime", "5 seconds") as 'session, 'valuegroup) + .groupBy(session_window($"eventTime", "5 seconds") as 'session, 'valuegroup) .agg(count("*") as 'count, sum("value") as 'sum) .select($"valuegroup", $"session".getField("start").cast("long").as[Long], $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) @@ -520,7 +520,7 @@ class StreamingSessionWindowSuite extends StreamTest .selectExpr("*") .withColumn("eventTime", $"value".cast("timestamp")) .withWatermark("eventTime", "10 seconds") - .groupBy(session($"eventTime", "5 seconds") as 'session) + .groupBy(session_window($"eventTime", "5 seconds") as 'session) .agg(count("*") as 'count, sum("value") as 'sum) .select($"session".getField("start").cast("long").as[Long], 
$"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) From a2fc65230f16cc2a55d62253622029b27daf5af6 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 16 Oct 2018 04:54:05 +0900 Subject: [PATCH 39/60] WIP reducing unnecessary codegen which seriously harmed performance --- .../aggregate/MergingSessionsIterator.scala | 1 + .../aggregate/UpdatingSessionIterator.scala | 14 ++++++++------ .../MergingSortWithMultiValuesStateIterator.scala | 12 ++++++------ 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala index fffda1aab489a..b282decc1c5ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala @@ -220,6 +220,7 @@ class MergingSessionsIterator( } private def generateGroupingKey(): UnsafeRow = { + // FIXME: Convert to JoinRow if possible to reduce codegen for unsafe projection val sessionStruct = CreateNamedStruct( Literal("start") :: PreciseTimestampConversion( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala index 14746098052e0..3cae4d55789a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala @@ -39,6 +39,10 @@ class UpdatingSessionIterator( val valuesExpressions: Seq[Attribute] = inputSchema.diff(groupWithoutSessionExpressions) .diff(Seq(sessionExpression)) + val keysProjection = GenerateUnsafeProjection.generate(groupWithoutSessionExpressions, + inputSchema) + val 
sessionProjection = GenerateUnsafeProjection.generate(Seq(sessionExpression), inputSchema) + var currentKeys: InternalRow = _ var currentSessionStart: Long = Long.MaxValue var currentSessionEnd: Long = Long.MinValue @@ -81,12 +85,8 @@ class UpdatingSessionIterator( // without this, multiple rows in same key will be returned with same content val row = iter.next().copy() - val keysProjection = GenerateUnsafeProjection.generate(groupWithoutSessionExpressions, - inputSchema) - val sessionProjection = GenerateUnsafeProjection.generate(Seq(sessionExpression), inputSchema) - - val keys = keysProjection(row) - val session = sessionProjection(row) + val keys = keysProjection(row).copy() + val session = sessionProjection(row).copy() val sessionRow = session.getStruct(0, 2) val sessionStart = sessionRow.getLong(0) val sessionEnd = sessionRow.getLong(1) @@ -150,6 +150,7 @@ class UpdatingSessionIterator( } private def closeCurrentSession(keyChanged: Boolean): Unit = { + // FIXME: Convert to JoinRow if possible to reduce codegen for unsafe projection val sessionStruct = CreateNamedStruct( Literal("start") :: PreciseTimestampConversion( @@ -176,6 +177,7 @@ class UpdatingSessionIterator( inMemoryThreshold, spillThreshold) val currentRowsIter = returnRows.generateIterator().map { internalRow => + // FIXME: is there any way to change this? 
val proj = UnsafeProjection.create(newSchemaExpressions, inputSchema) proj(internalRow) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala index 089f7fbe10b54..2b502f917300c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala @@ -30,17 +30,17 @@ class MergingSortWithMultiValuesStateIterator( sessionExpression: Expression, inputSchema: Seq[Attribute]) extends Iterator[InternalRow] { + val keysProjection = GenerateUnsafeProjection.generate(groupWithoutSessionExpressions, + inputSchema) + val sessionProjection = GenerateUnsafeProjection.generate(Seq(sessionExpression), inputSchema) + private case class SessionRowInformation(keys: UnsafeRow, sessionStart: Long, sessionEnd: Long, row: InternalRow) private object SessionRowInformation { def of(row: InternalRow): SessionRowInformation = { - val keysProjection = GenerateUnsafeProjection.generate(groupWithoutSessionExpressions, - inputSchema) - val sessionProjection = GenerateUnsafeProjection.generate(Seq(sessionExpression), inputSchema) - - val keys = keysProjection(row) - val session = sessionProjection(row) + val keys = keysProjection(row).copy() + val session = sessionProjection(row).copy() val sessionRow = session.getStruct(0, 2) val sessionStart = sessionRow.getLong(0) val sessionEnd = sessionRow.getLong(1) From 2dc413beff4630b98e4b718d2146b34855816752 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 16 Oct 2018 05:36:25 +0900 Subject: [PATCH 40/60] WIP reduce codegen once again for MergingSessionsIterator --- .../aggregate/MergingSessionsIterator.scala | 21 ++++++++----------- .../aggregate/UpdatingSessionIterator.scala | 1 + 2 files 
changed, 10 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala index b282decc1c5ab..64c6ffc8c5fb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, CreateNamedStruct, Expression, GenericInternalRow, Literal, MutableProjection, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, CreateNamedStruct, Expression, GenericInternalRow, JoinedRow, Literal, MutableProjection, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.metric.SQLMetric @@ -219,8 +219,12 @@ class MergingSessionsIterator( } } + private val join = new JoinedRow + + private val groupingKeyProj = GenerateUnsafeProjection.generate(groupingExpressions, + groupingWithoutSessionAttributes :+ sessionExpression.toAttribute) + private def generateGroupingKey(): UnsafeRow = { - // FIXME: Convert to JoinRow if possible to reduce codegen for unsafe projection val sessionStruct = CreateNamedStruct( Literal("start") :: PreciseTimestampConversion( @@ -230,17 +234,10 @@ class MergingSessionsIterator( Literal(currentSessionEnd, LongType), LongType, TimestampType) :: Nil) - val convertedGroupingExpressions = groupingExpressions.map { x => - if 
(x.semanticEquals(sessionExpression)) { - sessionStruct - } else { - BindReferences.bindReference[Expression](x, groupingWithoutSessionAttributes) - } - } + val joined = join(currentGroupingKey, + UnsafeProjection.create(sessionStruct).apply(InternalRow.empty)) - val proj = GenerateUnsafeProjection.generate(convertedGroupingExpressions, - groupingWithoutSessionAttributes) - proj(currentGroupingKey) + groupingKeyProj(joined) } def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala index 3cae4d55789a3..00b380b320749 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala @@ -151,6 +151,7 @@ class UpdatingSessionIterator( private def closeCurrentSession(keyChanged: Boolean): Unit = { // FIXME: Convert to JoinRow if possible to reduce codegen for unsafe projection + // FIXME: Same approach on MergingSessionsIterator.generateGroupingKey doesn't work here, why? val sessionStruct = CreateNamedStruct( Literal("start") :: PreciseTimestampConversion( From 4dd0e89eb3e619ee8bcb5e744e9853ff6821b350 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 16 Oct 2018 07:00:40 +0900 Subject: [PATCH 41/60] WIP optimize a bit more on codegen... 
--- ...gingSortWithMultiValuesStateIterator.scala | 30 ++++++++++++++----- .../streaming/statefulOperators.scala | 13 ++++++-- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala index 2b502f917300c..e58e7241396b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.streaming.state.MultiValuesStateManager // FIXME: javadoc!! 
@@ -28,11 +28,21 @@ class MergingSortWithMultiValuesStateIterator( stateManager: MultiValuesStateManager, groupWithoutSessionExpressions: Seq[Expression], sessionExpression: Expression, + keysProjection: UnsafeProjection, + sessionProjection: UnsafeProjection, inputSchema: Seq[Attribute]) extends Iterator[InternalRow] { - val keysProjection = GenerateUnsafeProjection.generate(groupWithoutSessionExpressions, - inputSchema) - val sessionProjection = GenerateUnsafeProjection.generate(Seq(sessionExpression), inputSchema) + def this( + iter: Iterator[InternalRow], + stateManager: MultiValuesStateManager, + groupWithoutSessionExpressions: Seq[Expression], + sessionExpression: Expression, + inputSchema: Seq[Attribute]) { + this(iter, stateManager, groupWithoutSessionExpressions, sessionExpression, + GenerateUnsafeProjection.generate(groupWithoutSessionExpressions, inputSchema), + GenerateUnsafeProjection.generate(Seq(sessionExpression), inputSchema), + inputSchema) + } private case class SessionRowInformation(keys: UnsafeRow, sessionStart: Long, sessionEnd: Long, row: InternalRow) @@ -123,10 +133,14 @@ class MergingSortWithMultiValuesStateIterator( // so sorting it here doesn't hurt. 
val unsortedIter = stateManager.get(currentRow.keys) currentStateIter = unsortedIter.map(_.copy()).toList.sortWith((row1, row2) => { - val rowInfo1 = SessionRowInformation.of(row1) - val rowInfo2 = SessionRowInformation.of(row2) + def getSessionStart(r: InternalRow): Long = { + val session = sessionProjection(r) + val sessionRow = session.getStruct(0, 2) + sessionRow.getLong(0) + } + // here sorting is based on the fact that keys are same - rowInfo1.sessionStart.compareTo(rowInfo2.sessionStart) < 0 + getSessionStart(row1).compareTo(getSessionStart(row2)) < 0 }).iterator currentStateFetchedKey = currentRow.keys diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index bfbd4eda5e5e5..9601a4788385d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -447,6 +447,11 @@ case class SessionWindowStateStoreRestoreExec( sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (stateManager, iter) => + val keyWithoutSessionProjection = GenerateUnsafeProjection.generate( + keyWithoutSessionExpressions, child.output) + val sessionProjection = GenerateUnsafeProjection.generate(Seq(sessionExpression), + child.output) + // We need to filter out outdated inputs val filteredIterator = watermarkPredicateForData match { case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) @@ -458,6 +463,8 @@ case class SessionWindowStateStoreRestoreExec( stateManager, keyWithoutSessionExpressions, sessionExpression, + keyWithoutSessionProjection, + sessionProjection, child.output).map { row => numOutputRows += 1 row @@ -528,12 +535,12 @@ case class SessionWindowStateStoreSaveExec( val allRemovalsTimeMs = longMetric("allRemovalsTimeMs") val commitTimeMs = 
longMetric("commitTimeMs") - val keyProjection = GenerateUnsafeProjection.generate(keyWithoutSessionExpressions, - child.output) - var currentKey: UnsafeRow = null var previousSessions: List[UnsafeRow] = null + val keyProjection = GenerateUnsafeProjection.generate(keyWithoutSessionExpressions, + child.output) + val keyOrdering = TypeUtils.getInterpretedOrdering(keyWithoutSessionExpressions.toStructType) val valueOrdering = TypeUtils.getInterpretedOrdering(child.output.toStructType) From cf520444e8ed01ed829133c6881c7db033c5d535 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 16 Oct 2018 07:37:16 +0900 Subject: [PATCH 42/60] WIP make the feature "merge session in local partition" optional * the option brings performance hit in cases there're less rows in each session in a batch * the option brings performance gain in cases there're plenty of rows in same session in same partition in source * former is more usual case, so setting the default value as 'false' --- .../apache/spark/sql/internal/SQLConf.scala | 12 +++++++ .../spark/sql/execution/SparkStrategies.scala | 1 + .../sql/execution/aggregate/AggUtils.scala | 17 ++++++---- .../StreamingSessionWindowSuite.scala | 33 ++++++++++++++----- 4 files changed, 49 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b699707d85235..675e66bd173be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1222,6 +1222,15 @@ object SQLConf { .intConf .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) + val STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION = + buildConf("spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition") + .internal() + .doc("When true, streaming session window sorts and merge sessions in local 
partition " + + "prior to shuffle. This is to reduce the rows to shuffle, but only beneficial when " + + "there're lots of rows in a batch being assigned to same sessions.") + .booleanConf + .createWithDefault(false) + val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD = buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold") .internal() @@ -1750,6 +1759,9 @@ class SQLConf extends Serializable with Logging { def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT) + def streamingSessionWindowMergeSessionInLocalPartition: Boolean = + getConf(STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7268d2bcc07f2..0774b9657951b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -340,6 +340,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { sessionWindow, aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), rewrittenResultExpressions, + conf.streamingSessionWindowMergeSessionInLocalPartition, planLater(child)) case None => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 9eadb1e5b74fc..5ade7852dbffe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -350,9 +350,10 @@ object AggUtils { * * - Partial Aggregation * - all tuples will have aggregated columns with initial value - * - Sort within 
partition (sort: all keys) - * - SessionWindowStateStoreRestore (group: keys "without" session) - * - This will play as "Partial Merge" in each partition + * - (If "spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition" is enabled) + * - Sort within partition (sort: all keys) + * - MergingSessionExec + * - calculate session among tuples, and aggregate tuples in session with partial merge * - Shuffle & Sort (distribution: keys "without" session, sort: all keys) * - SessionWindowStateStoreRestore (group: keys "without" session) * - merge input tuples with stored tuples (sessions) respecting sort order @@ -369,6 +370,7 @@ object AggUtils { sessionExpression: NamedExpression, functionsWithoutDistinct: Seq[AggregateExpression], resultExpressions: Seq[NamedExpression], + mergeSessionsInLocalPartition: Boolean, child: SparkPlan): Seq[SparkPlan] = { val groupWithoutSessionExpression = groupingExpressions.filterNot { p => @@ -394,11 +396,12 @@ object AggUtils { child = child) } - // sort happens here to merge sessions on each partition - // this is to reduce amount of rows to shuffle - val partialMerged1: SparkPlan = { + val partialMerged1: SparkPlan = if (mergeSessionsInLocalPartition) { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + + // sort happens here to merge sessions on each partition + // this is to reduce amount of rows to shuffle MergingSessionsExec( requiredChildDistributionExpressions = None, requiredChildDistributionOption = None, @@ -411,6 +414,8 @@ object AggUtils { aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), child = partialAggregate ) + } else { + partialAggregate } // shuffle & sort happens here: most of details are also handled in this physical plan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala index d4eca1de34b86..eb7e705a8ced6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions.{count, session_window, sum} +import org.apache.spark.sql.internal.SQLConf class StreamingSessionWindowSuite extends StreamTest with BeforeAndAfter with Matchers with Logging { @@ -32,7 +33,21 @@ class StreamingSessionWindowSuite extends StreamTest sqlContext.streams.active.foreach(_.stop()) } - test("complete mode - session window") { + def testWithAllOptionsMergingSessionInLocalPartition(name: String, confPairs: (String, String)*) + (func: => Any): Unit = { + val key = SQLConf.STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION.key + val availableOptions = Seq(true, false) + + for (enabled <- availableOptions) { + test(s"$name - merging sessions in local partition: $enabled") { + withSQLConf(confPairs ++ Seq(key -> enabled.toString): _*) { + func + } + } + } + } + + testWithAllOptionsMergingSessionInLocalPartition("complete mode - session window") { // Implements StructuredSessionization.scala leveraging "session" function // as a test, to verify the sessionization works with simple example @@ -133,7 +148,7 @@ class StreamingSessionWindowSuite extends StreamTest ) } - test("complete mode - session window - no key") { + testWithAllOptionsMergingSessionInLocalPartition("complete mode - session window - no key") { // complete mode doesn't honor watermark: even it is specified, watermark will be // always Unix timestamp 0 @@ -181,7 +196,7 @@ class StreamingSessionWindowSuite extends StreamTest ) } - test("append mode - session window") { + 
testWithAllOptionsMergingSessionInLocalPartition("append mode - session window") { // Implements StructuredSessionization.scala leveraging "session" function // as a test, to verify the sessionization works with simple example @@ -281,7 +296,8 @@ class StreamingSessionWindowSuite extends StreamTest ) } - test("append mode - session window - storing multiple sessions in given key") { + testWithAllOptionsMergingSessionInLocalPartition("append mode - session window - " + + "storing multiple sessions in given key") { val inputData = MemoryStream[Int] val windowedAggregation = inputData.toDF() .selectExpr("*", "CAST(MOD(value, 2) AS INT) AS valuegroup") @@ -324,7 +340,7 @@ class StreamingSessionWindowSuite extends StreamTest ) } - test("append mode - session window - no key") { + testWithAllOptionsMergingSessionInLocalPartition("append mode - session window - no key") { val inputData = MemoryStream[Int] val windowedAggregation = inputData.toDF() @@ -367,7 +383,7 @@ class StreamingSessionWindowSuite extends StreamTest ) } - test("update mode - session window") { + testWithAllOptionsMergingSessionInLocalPartition("update mode - session window") { // Implements StructuredSessionization.scala leveraging "session" function // as a test, to verify the sessionization works with simple example @@ -466,7 +482,8 @@ class StreamingSessionWindowSuite extends StreamTest ) } - test("update mode - session window - storing multiple sessions in given key") { + testWithAllOptionsMergingSessionInLocalPartition("update mode - session window - " + + "storing multiple sessions in given key") { val inputData = MemoryStream[Int] val windowedAggregation = inputData.toDF() .selectExpr("*", "CAST(MOD(value, 2) AS INT) AS valuegroup") @@ -513,7 +530,7 @@ class StreamingSessionWindowSuite extends StreamTest ) } - test("update mode - session window - no key") { + testWithAllOptionsMergingSessionInLocalPartition("update mode - session window - no key") { val inputData = MemoryStream[Int] val 
windowedAggregation = inputData.toDF() From 5c746090a8d5560f043754383656d54653a315dc Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 17 Oct 2018 09:41:15 +0900 Subject: [PATCH 43/60] WIP add "session_window" to exclude list --- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 4c99a86deb19a..7b60ca0510aad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -107,7 +107,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-14415: All functions should have own descriptions") { for (f <- spark.sessionState.functionRegistry.listFunction()) { - val excludes = Seq("cube", "grouping", "grouping_id", "rollup", "window", "session") + val excludes = Seq("cube", "grouping", "grouping_id", "rollup", "window", "session_window") if (!excludes.contains(f.unquotedString)) { checkKeywordsNotExist(sql(s"describe function `$f`"), "N/A.") } From fb6c59fea5a6d97aa7f412627f9773324f12097a Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 18 Oct 2018 18:26:55 +0900 Subject: [PATCH 44/60] WIP Enable versioning of session window state format * also moved most of the logic related to session window state to manager --- .../apache/spark/sql/internal/SQLConf.scala | 29 +++- .../spark/sql/execution/SparkStrategies.scala | 3 + .../sql/execution/aggregate/AggUtils.scala | 5 +- .../streaming/IncrementalExecution.scala | 6 +- ...gingSortWithMultiValuesStateIterator.scala | 6 +- .../sql/execution/streaming/OffsetSeq.scala | 5 +- .../state/StreamingSessionStateManager.scala | 159 ++++++++++++++++++ .../streaming/statefulOperators.scala | 118 ++++--------- ...ortWithMultiValuesStateIteratorSuite.scala | 65 ++++--- 9 files changed, 274 insertions(+), 122 
deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionStateManager.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 675e66bd173be..db898a99e4630 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -915,6 +915,26 @@ object SQLConf { .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") .createWithDefault(2) + val STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION = + buildConf("spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition") + .internal() + .doc("When true, streaming session window sorts and merge sessions in local partition " + + "prior to shuffle. This is to reduce the rows to shuffle, but only beneficial when " + + "there're lots of rows in a batch being assigned to same sessions.") + .booleanConf + .createWithDefault(false) + + val STREAMING_SESSION_WINDOW_STATE_FORMAT_VERSION = + buildConf("spark.sql.streaming.sessionWindow.stateFormatVersion") + .internal() + .doc("State format version used by streaming session window aggregation operations " + + "in a streaming query. 
" + + "State between versions are tend to be incompatible, so state format version shouldn't " + + "be modified after running.") + .intConf + .checkValue(v => Set(1).contains(v), "Valid versions are 1") + .createWithDefault(1) + val UNSUPPORTED_OPERATION_CHECK_ENABLED = buildConf("spark.sql.streaming.unsupportedOperationCheck") .internal() @@ -1222,15 +1242,6 @@ object SQLConf { .intConf .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) - val STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION = - buildConf("spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition") - .internal() - .doc("When true, streaming session window sorts and merge sessions in local partition " + - "prior to shuffle. This is to reduce the rows to shuffle, but only beneficial when " + - "there're lots of rows in a batch being assigned to same sessions.") - .booleanConf - .createWithDefault(false) - val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD = buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0774b9657951b..3a69cc5e54a09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -335,12 +335,15 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { sessionWindowOption match { case Some(sessionWindow) => + val stateVersion = conf.getConf(SQLConf.STREAMING_SESSION_WINDOW_STATE_FORMAT_VERSION) + aggregate.AggUtils.planStreamingAggregationForSession( namedGroupingExpressions, sessionWindow, aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), rewrittenResultExpressions, conf.streamingSessionWindowMergeSessionInLocalPartition, + stateVersion, planLater(child)) case None => diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 5ade7852dbffe..6042fc64a466a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -371,6 +371,7 @@ object AggUtils { functionsWithoutDistinct: Seq[AggregateExpression], resultExpressions: Seq[NamedExpression], mergeSessionsInLocalPartition: Boolean, + stateFormatVersion: Int, child: SparkPlan): Seq[SparkPlan] = { val groupWithoutSessionExpression = groupingExpressions.filterNot { p => @@ -420,7 +421,8 @@ object AggUtils { // shuffle & sort happens here: most of details are also handled in this physical plan val restored = SessionWindowStateStoreRestoreExec(groupingWithoutSessionAttributes, - sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None, partialMerged1) + sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None, + stateFormatVersion, partialMerged1) val mergedSessions = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) @@ -448,6 +450,7 @@ object AggUtils { stateInfo = None, outputMode = None, eventTimeWatermark = None, + stateFormatVersion, mergedSessions) val finalAndCompleteAggregate: SparkPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 3a55cfab5dded..10ec0daeb24ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -119,9 +119,9 @@ class IncrementalExecution( stateFormatVersion, child) :: Nil)) - case SessionWindowStateStoreSaveExec(keys, session, None, None, None, + case 
SessionWindowStateStoreSaveExec(keys, session, None, None, None, stateFormatVersion, UnaryExecNode(agg, - SessionWindowStateStoreRestoreExec(_, _, None, None, child))) => + SessionWindowStateStoreRestoreExec(_, _, None, None, _, child))) => val aggStateInfo = nextStatefulOperationStateInfo SessionWindowStateStoreSaveExec( keys, @@ -129,12 +129,14 @@ class IncrementalExecution( Some(aggStateInfo), Some(outputMode), Some(offsetSeqMetadata.batchWatermarkMs), + stateFormatVersion, agg.withNewChildren( SessionWindowStateStoreRestoreExec( keys, session, Some(aggStateInfo), Some(offsetSeqMetadata.batchWatermarkMs), + stateFormatVersion, child) :: Nil)) case StreamingDeduplicateExec(keys, child, None, None) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala index e58e7241396b8..3054b2268f41c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.execution.streaming.state.MultiValuesStateManager +import org.apache.spark.sql.execution.streaming.state.StreamingSessionStateManager // FIXME: javadoc!! 
class MergingSortWithMultiValuesStateIterator( iter: Iterator[InternalRow], - stateManager: MultiValuesStateManager, + stateManager: StreamingSessionStateManager, groupWithoutSessionExpressions: Seq[Expression], sessionExpression: Expression, keysProjection: UnsafeProjection, @@ -34,7 +34,7 @@ class MergingSortWithMultiValuesStateIterator( def this( iter: Iterator[InternalRow], - stateManager: MultiValuesStateManager, + stateManager: StreamingSessionStateManager, groupWithoutSessionExpressions: Seq[Expression], sessionExpression: Expression, inputSchema: Seq[Attribute]) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 73cf355dbe758..7e2f7d95353ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -22,7 +22,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig -import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager} +import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager, StreamingSessionStateManager} import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _} /** @@ -89,7 +89,8 @@ object OffsetSeqMetadata extends Logging { private implicit val format = Serialization.formats(NoTypeHints) private val relevantSQLConfs = Seq( SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY, - FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION) + FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION, + STREAMING_SESSION_WINDOW_STATE_FORMAT_VERSION) /** * Default values of relevant 
configurations that are used for backward compatibility. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionStateManager.scala new file mode 100644 index 0000000000000..539f3339ab4d0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionStateManager.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types.StructType + +// FIXME: javadoc! 
+sealed trait StreamingSessionStateManager extends Serializable { + def getKey(row: UnsafeRow): UnsafeRow + + def getStateValueSchema: StructType + + def get(key: UnsafeRow): Iterator[UnsafeRow] + + def append(session: UnsafeRow): Boolean + + def doFinalize(): Unit + + def getAll(): Iterator[UnsafeRow] + + def evictSessionsByWatermark(): Iterator[UnsafeRow] + + def doEvictSessionsByWatermark(): Unit +} + +// FIXME: javadoc! +trait MultiValuesStateManagerInjectable { + def setMultiValuesStateManager(manager: MultiValuesStateManager): Unit +} + +object StreamingSessionStateManager extends Logging { + val supportedVersions = Seq(1) + + def createStateManager( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute], + watermarkPredicateForData: Option[Predicate], + stateFormatVersion: Int): StreamingSessionStateManager = { + stateFormatVersion match { + case 1 => new StreamingSessionStateManagerImplV1(keyExpressions, inputRowAttributes, + watermarkPredicateForData) + case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid") + } + } +} + +abstract class StreamingSessionStateManagerBaseImpl( + protected val keyExpressions: Seq[Attribute], + protected val inputRowAttributes: Seq[Attribute]) extends StreamingSessionStateManager { + + @transient protected lazy val keyProjector = + GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes) + + override def getKey(row: UnsafeRow): UnsafeRow = keyProjector(row) +} + +// FIXME: javadoc! 
+class StreamingSessionStateManagerImplV1( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute], + watermarkPredicateForData: Option[Predicate]) + extends StreamingSessionStateManagerBaseImpl(keyExpressions, inputRowAttributes) + with MultiValuesStateManagerInjectable { + + var stateManager: MultiValuesStateManager = _ + var currentKey: UnsafeRow = _ + var previousSessions: List[UnsafeRow] = _ + + @transient protected lazy val keyOrdering = TypeUtils.getInterpretedOrdering( + keyExpressions.toStructType) + @transient protected lazy val valueOrdering = TypeUtils.getInterpretedOrdering( + inputRowAttributes.toStructType) + + override def setMultiValuesStateManager(manager: MultiValuesStateManager): Unit = { + stateManager = manager + } + + override def getStateValueSchema: StructType = inputRowAttributes.toStructType + + override def get(key: UnsafeRow): Iterator[UnsafeRow] = { + assertAvailability() + + stateManager.get(key) + } + + override def append(session: UnsafeRow): Boolean = { + assertAvailability() + + val key = keyProjector(session) + + if (currentKey == null || !keyOrdering.equiv(currentKey, key)) { + currentKey = key.copy() + + // This is necessary because MultiValuesStateManager doesn't guarantee + // stable ordering. + // The number of values for the given key is expected to be likely small, + // so listing it here doesn't hurt. 
+ previousSessions = stateManager.get(key).toList + + stateManager.removeKey(key) + } + + stateManager.append(key, session) + + !previousSessions.exists(p => valueOrdering.equiv(session, p)) + } + + override def doFinalize(): Unit = { + assertAvailability() + + // do nothing + } + + override def getAll(): Iterator[UnsafeRow] = { + assertAvailability() + + stateManager.getAllRowPairs.map(_.value) + } + + override def evictSessionsByWatermark(): Iterator[UnsafeRow] = { + assertAvailability() + + stateManager.removeByValueCondition { row => watermarkPredicateForData match { + case Some(predicate) => predicate.eval(row) + case None => false + } + }.map(_.value) + } + + override def doEvictSessionsByWatermark(): Unit = { + assertAvailability() + + // consume all elements to let removal take effect + evictSessionsByWatermark().toList + } + + private def assertAvailability(): Unit = { + require(stateManager != null, "MultiValuesStateManager should be set before calling methods!") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 9601a4788385d..df2c1df7c274b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -431,6 +431,7 @@ case class SessionWindowStateStoreRestoreExec( sessionExpression: Attribute, stateInfo: Option[StatefulOperatorStateInfo], eventTimeWatermark: Option[Long], + stateFormatVersion: Int, child: SparkPlan) extends UnaryExecNode with StateStoreReader with WatermarkSupport { @@ -445,7 +446,17 @@ case class SessionWindowStateStoreRestoreExec( child.output.toStructType, indexOrdinal = None, sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator)) { case (stateManager, iter) => + Some(sqlContext.streams.stateStoreCoordinator)) { case (store, 
iter) => + + val stateManager = StreamingSessionStateManager.createStateManager( + keyWithoutSessionExpressions, child.output, watermarkPredicateForData, stateFormatVersion) + + stateManager match { + case mvState: MultiValuesStateManagerInjectable => + mvState.setMultiValuesStateManager(store) + case _ => throw new IllegalStateException("Session state manager is expected to work " + + "with MultiValuesStateManager") + } val keyWithoutSessionProjection = GenerateUnsafeProjection.generate( keyWithoutSessionExpressions, child.output) @@ -503,6 +514,7 @@ case class SessionWindowStateStoreSaveExec( stateInfo: Option[StatefulOperatorStateInfo] = None, outputMode: Option[OutputMode] = None, eventTimeWatermark: Option[Long] = None, + stateFormatVersion: Int, child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { @@ -519,14 +531,16 @@ case class SessionWindowStateStoreSaveExec( child.output.toStructType, indexOrdinal = None, sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator)) { (stateManager, iter) => + Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => - def evictSessionsByWatermark(manager: MultiValuesStateManager): Iterator[UnsafeRowPair] = { - manager.removeByValueCondition { row => watermarkPredicateForData match { - case Some(predicate) => predicate.eval(row) - case None => false - } - } + val stateManager = StreamingSessionStateManager.createStateManager( + keyWithoutSessionExpressions, child.output, watermarkPredicateForData, stateFormatVersion) + + stateManager match { + case mvState: MultiValuesStateManagerInjectable => + mvState.setMultiValuesStateManager(store) + case _ => throw new IllegalStateException("Session state manager is expected to work " + + "with MultiValuesStateManager") } val numOutputRows = longMetric("numOutputRows") @@ -535,48 +549,24 @@ case class SessionWindowStateStoreSaveExec( val allRemovalsTimeMs = longMetric("allRemovalsTimeMs") val commitTimeMs = 
longMetric("commitTimeMs") - var currentKey: UnsafeRow = null - var previousSessions: List[UnsafeRow] = null - - val keyProjection = GenerateUnsafeProjection.generate(keyWithoutSessionExpressions, - child.output) - - val keyOrdering = TypeUtils.getInterpretedOrdering(keyWithoutSessionExpressions.toStructType) - val valueOrdering = TypeUtils.getInterpretedOrdering(child.output.toStructType) - // assuming late events were dropped from MergingSortWithMultiValuesStateIterator outputMode match { case Some(Complete) => allUpdatesTimeMs += timeTakenMs { while (iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] - val keys = keyProjection(row) - - if (currentKey == null || !keyOrdering.equiv(currentKey, keys)) { - currentKey = keys.copy() - - // This is necessary because MultiValuesStateManager doesn't guarantee - // stable ordering. - // The number of values for the given key is expected to be likely small, - // so listing it here doesn't hurt. - previousSessions = stateManager.get(keys).toList - - stateManager.removeKey(keys) - } - - stateManager.append(keys, row) - - if (!previousSessions.exists(p => valueOrdering.equiv(row, p))) { - // such session is not in previous session + if (stateManager.append(row)) { numUpdatedStateRows += 1 } } + + stateManager.doFinalize() } CompletionIterator[InternalRow, Iterator[InternalRow]]( - stateManager.getAllRowPairs.map(_.value), { - commitTimeMs += timeTakenMs { stateManager.commit() } - setStoreMetrics(stateManager) + stateManager.getAll(), { + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) }) // Update and output only sessions being evicted from the MultiValuesStateManager @@ -585,24 +575,7 @@ case class SessionWindowStateStoreSaveExec( allUpdatesTimeMs += timeTakenMs { while (iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] - val keys = keyProjection(row) - - if (currentKey == null || !keyOrdering.equiv(currentKey, keys)) { - currentKey = keys.copy() - - // This is necessary 
because MultiValuesStateManager doesn't guarantee - // stable ordering. - // The number of values for the given key is expected to be likely small, - // so listing it here doesn't hurt. - previousSessions = stateManager.get(keys).toList - - stateManager.removeKey(keys) - } - - stateManager.append(keys, row) - - if (!previousSessions.exists(p => valueOrdering.equiv(row, p))) { - // such session is not in previous session + if (stateManager.append(row)) { numUpdatedStateRows += 1 } } @@ -610,15 +583,15 @@ case class SessionWindowStateStoreSaveExec( val removalStartTimeNs = System.nanoTime - val retIter = evictSessionsByWatermark(stateManager).map(_.value).map { row => + val retIter = stateManager.evictSessionsByWatermark().map { row => numOutputRows += 1 row } CompletionIterator[InternalRow, Iterator[InternalRow]](retIter, { allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) - commitTimeMs += timeTakenMs { stateManager.commit() } - setStoreMetrics(stateManager) + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) }) // Update and output modified rows from the MultiValuesStateManager. @@ -632,24 +605,7 @@ case class SessionWindowStateStoreSaveExec( while (ret == null && iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] - val keys = keyProjection(row) - - if (currentKey == null || !keyOrdering.equiv(currentKey, keys)) { - currentKey = keys.copy() - - // This is necessary because MultiValuesStateManager doesn't guarantee - // stable ordering. - // The number of values for the given key is expected to be likely small, - // so listing it here doesn't hurt. 
- previousSessions = stateManager.get(keys).toList - - stateManager.removeKey(keys) - } - - stateManager.append(keys, row) - - if (!previousSessions.exists(p => valueOrdering.equiv(row, p))) { - // such session is not in previous session + if (stateManager.append(row)) { numUpdatedStateRows += 1 ret = row } @@ -668,15 +624,15 @@ case class SessionWindowStateStoreSaveExec( } override protected def close(): Unit = { + stateManager.doFinalize() allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) // Remove old aggregates if watermark specified allRemovalsTimeMs += timeTakenMs { - // fully consume iterator to let removal take effect - evictSessionsByWatermark(stateManager).toList + stateManager.doEvictSessionsByWatermark() } - commitTimeMs += timeTakenMs { stateManager.commit() } - setStoreMetrics(stateManager) + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala index 7cd45b9f0a5da..5bb349ee38336 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.execution.streaming.state.{MultiValuesStateManager, StateStore, StateStoreConf} +import org.apache.spark.sql.execution.streaming.state.{MultiValuesStateManager, MultiValuesStateManagerInjectable, StateStore, 
StateStoreConf, StreamingSessionStateManager} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -49,8 +49,12 @@ class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { attr => List("aggVal1", "aggVal2").contains(attr.name) } + // TODO: would we want to randomize or test all? + val stateStoreVersion = StreamingSessionStateManager.supportedVersions.last + test("no row in input data") { - withStateManager(rowAttributes, keysWithoutSessionAttributes) { manager => + withStreamingSessionStateManager(rowAttributes, keysWithoutSessionAttributes, + stateStoreVersion) { manager => val iterator = new MergingSortWithMultiValuesStateIterator(None.iterator, manager, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) @@ -59,7 +63,8 @@ class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { } test("no row in input data but having state") { - withStateManager(rowAttributes, keysWithoutSessionAttributes) { manager => + withStreamingSessionStateManager(rowAttributes, keysWithoutSessionAttributes, + stateStoreVersion) { manager => val srow11 = createRow("a", 1, 55, 85, 50, 2.5) val srow12 = createRow("a", 1, 105, 140, 30, 2.0) appendRowToStateManager(manager, srow11, srow12) @@ -72,7 +77,8 @@ class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { } test("no previous state") { - withStateManager(rowAttributes, keysWithoutSessionAttributes) { manager => + withStreamingSessionStateManager(rowAttributes, keysWithoutSessionAttributes, + stateStoreVersion) { manager => val row1 = createRow("a", 1, 100, 110, 10, 1.1) val row2 = createRow("a", 1, 100, 110, 20, 1.2) val row3 = createRow("a", 2, 110, 120, 10, 1.1) @@ -92,7 +98,8 @@ class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { } test("multiple keys in input data and state") { - withStateManager(rowAttributes, keysWithoutSessionAttributes) { manager 
=> + withStreamingSessionStateManager(rowAttributes, keysWithoutSessionAttributes, + stateStoreVersion) { manager => // key 1 - placing sessions in state to start and end val row11 = createRow("a", 1, 100, 110, 10, 1.1) val row12 = createRow("a", 1, 100, 110, 20, 1.2) @@ -179,11 +186,13 @@ class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { assert(doubleEquals(retRow.getDouble(2), expectedRow.getDouble(2))) } - def appendNoKeyRowToStateManager(manager: MultiValuesStateManager, rows: UnsafeRow*): Unit = { - rows.foreach(row => manager.append(new UnsafeRow(0), row)) + def appendNoKeyRowToStateManager(manager: StreamingSessionStateManager, rows: UnsafeRow*) + : Unit = { + rows.foreach(manager.append) } - withStateManager(noKeyRowAttributes, Seq.empty[Attribute]) { manager => + withStreamingSessionStateManager(noKeyRowAttributes, Seq.empty[Attribute], + stateStoreVersion) { manager => // only input data val row1 = createNoKeyRow(100, 110, 10, 1.1) val row2 = createNoKeyRow(100, 110, 20, 1.2) @@ -208,12 +217,6 @@ class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { } } - private def getKeyRow(row: UnsafeRow): UnsafeRow = { - val keyProjection = GenerateUnsafeProjection.generate(keysWithoutSessionAttributes, - rowAttributes) - keyProjection(row) - } - private def createRow(key1: String, key2: Int, sessionStart: Long, sessionEnd: Long, aggVal1: Long, aggVal2: Double): UnsafeRow = { val genericRow = new GenericInternalRow(6) @@ -238,8 +241,9 @@ class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { rowProjection(genericRow) } - private def appendRowToStateManager(manager: MultiValuesStateManager, rows: UnsafeRow*): Unit = { - rows.foreach(row => manager.append(getKeyRow(row), row)) + private def appendRowToStateManager(manager: StreamingSessionStateManager, rows: UnsafeRow*) + : Unit = { + rows.foreach(row => manager.append(row)) } private def doubleEquals(value1: Double, value2: Double): Boolean = { 
@@ -255,19 +259,32 @@ class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { assert(doubleEquals(retRow.getDouble(3), expectedRow.getDouble(3))) } - private def withStateManager( + private def withStreamingSessionStateManager( inputValueAttribs: Seq[Attribute], - keyExprs: Seq[Expression])(f: MultiValuesStateManager => Unit): Unit = { + keyAttribs: Seq[Attribute], + stateVersion: Int)(f: StreamingSessionStateManager => Unit): Unit = { withTempDir { file => val storeConf = new StateStoreConf() val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5) - val manager = new MultiValuesStateManager("session-", inputValueAttribs, keyExprs, - Some(stateInfo), storeConf, new Configuration) - try { - f(manager) - } finally { - manager.abortIfNeeded() + + val manager = StreamingSessionStateManager.createStateManager( + keyAttribs, inputValueAttribs, None, stateVersion) + + manager match { + case mvState: MultiValuesStateManagerInjectable => + val store = new MultiValuesStateManager("session-", inputValueAttribs, keyAttribs, + Some(stateInfo), storeConf, new Configuration) + mvState.setMultiValuesStateManager(store) + + try { + f(manager) + } finally { + + store.abortIfNeeded() + } + + case _ => throw new IllegalStateException("Should inject matching underlying state store!") } } StateStore.stop() From dd29af2799b341afcfdd75292947e6d30e0c9f30 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 19 Oct 2018 18:23:13 +0900 Subject: [PATCH 45/60] WIP some correction in comment --- .../streaming/StreamingSessionWindowSuite.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala index eb7e705a8ced6..a372daf391344 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala @@ -223,8 +223,8 @@ class StreamingSessionWindowSuite extends StreamTest // ("hello", 10, 21, 11, 2) // ("world", 10, 21, 11, 2) // ("spark", 10, 20, 10, 1) - // ("structured", 11, 21, 10, 2) - // ("streaming", 11, 21, 10, 2) + // ("structured", 11, 21, 10, 1) + // ("streaming", 11, 21, 10, 1) CheckNewAnswer(), AddData(inputData, ("spark streaming", 15L)), @@ -232,7 +232,7 @@ class StreamingSessionWindowSuite extends StreamTest // current sessions after batch: // ("hello", 10, 21, 11, 2) // ("world", 10, 21, 11, 2) - // ("structured", 11, 21, 10, 2) + // ("structured", 11, 21, 10, 1) // ("spark", 10, 25, 15, 2) // ("streaming", 11, 25, 14, 2) CheckNewAnswer(), @@ -251,7 +251,7 @@ class StreamingSessionWindowSuite extends StreamTest AddData(inputData, ("hello world", 3L)), // input can match to not-yet-evicted sessions, but input itself is less than watermark - // so it should not match exiting sessions + // so it should not match existing sessions // Watermark kept 15 seconds // current sessions after batch: // ("hello", 10, 21, 11, 2) @@ -425,7 +425,7 @@ class StreamingSessionWindowSuite extends StreamTest // current sessions after batch: // ("hello", 10, 21, 11, 2) // ("world", 10, 21, 11, 2) - // ("structured", 11, 21, 10, 2) + // ("structured", 11, 21, 10, 1) // ("spark", 10, 25, 15, 2) // ("streaming", 11, 25, 14, 2) CheckNewAnswer(("spark", 10, 25, 15, 2), ("streaming", 11, 25, 14, 2)), @@ -435,7 +435,7 @@ class StreamingSessionWindowSuite extends StreamTest // current sessions after batch: // ("hello", 10, 21, 11, 2) // ("world", 10, 21, 11, 2) - // ("structured", 11, 21, 10, 2) + // ("structured", 11, 21, 10, 1) // ("spark", 10, 25, 15, 2) // ("streaming", 11, 25, 14, 2) // ("hello", 25, 35, 10, 1) @@ -449,7 +449,7 @@ class StreamingSessionWindowSuite extends StreamTest // current sessions after batch: // ("hello", 10, 21, 11, 2) // ("world", 10, 21, 11, 2) - // 
("structured", 11, 21, 10, 2) + // ("structured", 11, 21, 10, 1) // ("spark", 10, 25, 15, 2) // ("streaming", 11, 25, 14, 2) // ("hello", 25, 35, 10, 1) From 1f6e496c9b3474d37e735fcede4ac3587136de35 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 20 Oct 2018 22:15:38 +0900 Subject: [PATCH 46/60] WIP cover all the cases for session window in UTs --- .../StreamingSessionWindowSuite.scala | 488 ++++++++---------- 1 file changed, 207 insertions(+), 281 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala index a372daf391344..1297b5408cf32 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala @@ -70,80 +70,65 @@ class StreamingSessionWindowSuite extends StreamTest "numEvents") testStream(sessionUpdates, OutputMode.Complete())( - AddData(inputData, ("hello world spark", 10L), ("world hello structured streaming", 11L)), - CheckNewAnswer( - ("hello", 10, 21, 11, 2), - ("world", 10, 21, 11, 2), - ("spark", 10, 20, 10, 1), - ("structured", 11, 21, 10, 1), - ("streaming", 11, 21, 10, 1) - ), - - AddData(inputData, ("spark streaming", 15L)), - CheckNewAnswer( - ("hello", 10, 21, 11, 2), - ("world", 10, 21, 11, 2), - ("spark", 10, 25, 15, 2), - ("structured", 11, 21, 10, 1), - ("streaming", 11, 25, 14, 2) + AddData(inputData, + ("hello world spark streaming", 40L), + ("world hello structured streaming", 41L) ), - - AddData(inputData, ("hello world", 25L)), CheckNewAnswer( - ("hello", 10, 21, 11, 2), - ("world", 10, 21, 11, 2), - ("spark", 10, 25, 15, 2), - ("structured", 11, 21, 10, 1), - ("streaming", 11, 25, 14, 2), - ("hello", 25, 35, 10, 1), - ("world", 25, 35, 10, 1) + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("streaming", 40, 51, 11, 2), + ("spark", 40, 50, 
10, 1), + ("structured", 41, 51, 10, 1) ), - AddData(inputData, ("hello world", 3L)), + // placing new sessions "before" previous sessions + AddData(inputData, ("spark streaming", 25L)), CheckNewAnswer( - ("hello", 3, 21, 18, 3), - ("world", 3, 21, 18, 3), - ("spark", 10, 25, 15, 2), - ("structured", 11, 21, 10, 1), - ("streaming", 11, 25, 14, 2), - ("hello", 25, 35, 10, 1), - ("world", 25, 35, 10, 1) + ("spark", 25, 35, 10, 1), + ("streaming", 25, 35, 10, 1), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("streaming", 40, 51, 11, 2), + ("spark", 40, 50, 10, 1), + ("structured", 41, 51, 10, 1) ), - AddData(inputData, ("hello", 31L)), + // concatenating multiple previous sessions into one + AddData(inputData, ("spark streaming", 30L)), CheckNewAnswer( - ("hello", 3, 21, 18, 3), - ("world", 3, 21, 18, 3), - ("spark", 10, 25, 15, 2), - ("structured", 11, 21, 10, 1), - ("streaming", 11, 25, 14, 2), - ("hello", 25, 41, 16, 2), - ("world", 25, 35, 10, 1) + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1) ), - AddData(inputData, ("hello", 35L)), + // placing new sessions after previous sessions + AddData(inputData, ("hello apache spark", 60L)), CheckNewAnswer( - ("hello", 3, 21, 18, 3), - ("world", 3, 21, 18, 3), - ("spark", 10, 25, 15, 2), - ("structured", 11, 21, 10, 1), - ("streaming", 11, 25, 14, 2), - ("hello", 25, 45, 20, 3), - ("world", 25, 35, 10, 1) + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1), + ("hello", 60, 70, 10, 1), + ("apache", 60, 70, 10, 1), + ("spark", 60, 70, 10, 1) ), - AddData(inputData, ("hello apache spark", 60L)), + AddData(inputData, ("structured streaming", 90L)), CheckNewAnswer( - ("hello", 3, 21, 18, 3), - ("world", 3, 21, 18, 3), - ("spark", 10, 25, 15, 2), - ("structured", 11, 21, 10, 1), - ("streaming", 11, 25, 14, 2), - ("hello", 
25, 45, 20, 3), - ("world", 25, 35, 10, 1), + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1), ("hello", 60, 70, 10, 1), ("apache", 60, 70, 10, 1), - ("spark", 60, 70, 10, 1) + ("spark", 60, 70, 10, 1), + ("structured", 90, 100, 10, 1), + ("streaming", 90, 100, 10, 1) ) ) } @@ -207,7 +192,7 @@ class StreamingSessionWindowSuite extends StreamTest .select($"_1".as("value"), $"_2".as("timestamp")) .withColumn("eventTime", $"timestamp".cast("timestamp")) .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") - .withWatermark("eventTime", "10 seconds") + .withWatermark("eventTime", "30 seconds") val sessionUpdates = events .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId) @@ -217,126 +202,91 @@ class StreamingSessionWindowSuite extends StreamTest "numEvents") testStream(sessionUpdates, OutputMode.Append())( - AddData(inputData, ("hello world spark", 10L), ("world hello structured streaming", 11L)), - // Advance watermark to 1 seconds - // current sessions after batch: - // ("hello", 10, 21, 11, 2) - // ("world", 10, 21, 11, 2) - // ("spark", 10, 20, 10, 1) - // ("structured", 11, 21, 10, 1) - // ("streaming", 11, 21, 10, 1) - CheckNewAnswer(), - - AddData(inputData, ("spark streaming", 15L)), - // Advance watermark to 5 seconds - // current sessions after batch: - // ("hello", 10, 21, 11, 2) - // ("world", 10, 21, 11, 2) - // ("structured", 11, 21, 10, 1) - // ("spark", 10, 25, 15, 2) - // ("streaming", 11, 25, 14, 2) - CheckNewAnswer(), + AddData(inputData, + ("hello world spark streaming", 40L), + ("world hello structured streaming", 41L) + ), - AddData(inputData, ("hello world", 25L)), - // Advance watermark to 15 seconds - // current sessions after batch: - // ("hello", 10, 21, 11, 2) - // ("world", 10, 21, 11, 2) - // ("structured", 11, 21, 10, 2) - // ("spark", 10, 25, 15, 2) - // ("streaming", 11, 25, 14, 2) - // ("hello", 25, 
35, 10, 1) - // ("world", 25, 35, 10, 1) - CheckNewAnswer(), + // watermark: 11 + // current sessions + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), - AddData(inputData, ("hello world", 3L)), - // input can match to not-yet-evicted sessions, but input itself is less than watermark - // so it should not match existing sessions - // Watermark kept 15 seconds - // current sessions after batch: - // ("hello", 10, 21, 11, 2) - // ("world", 10, 21, 11, 2) - // ("structured", 11, 21, 10, 2) - // ("spark", 10, 25, 15, 2) - // ("streaming", 11, 25, 14, 2) - // ("hello", 25, 35, 10, 1) - // ("world", 25, 35, 10, 1) - CheckNewAnswer(), + // placing new sessions "before" previous sessions + AddData(inputData, ("spark streaming", 25L)), + // watermark: 11 + // current sessions + // ("spark", 25, 35, 10, 1), + // ("streaming", 25, 35, 10, 1), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), - AddData(inputData, ("hello", 31L)), - // Advance watermark to 21 seconds - // current sessions after batch: - // ("spark", 10, 25, 15, 2) - // ("streaming", 11, 25, 14, 2) - // ("hello", 25, 41, 16, 2) - // ("world", 25, 35, 10, 1) + // late event whose session end 10 would be earlier than watermark 11: should be dropped + AddData(inputData, ("spark streaming", 0L)), + // watermark: 11 + // current sessions + // ("spark", 25, 35, 10, 1), + // ("streaming", 25, 35, 10, 1), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) CheckNewAnswer( - ("hello", 10, 21, 11, 2), - ("world", 10, 21, 11, 2), - ("structured", 11, 21, 10, 1) ), - AddData(inputData, ("hello", 35L)), - // Advance watermark to 25 seconds - //
current sessions after batch: - // ("hello", 25, 45, 20, 3) - // ("world", 25, 35, 10, 1) + // concatenating multiple previous sessions into one + AddData(inputData, ("spark streaming", 30L)), + // watermark: 11 + // current sessions + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1) CheckNewAnswer( - ("spark", 10, 25, 15, 2), - ("streaming", 11, 25, 14, 2) ), + // placing new sessions after previous sessions AddData(inputData, ("hello apache spark", 60L)), - // Advance watermark to 50 seconds - // current sessions after batch: - // ("hello", 60, 70, 10, 1) - // ("apache", 60, 70, 10, 1) + // watermark: 30 + // current sessions + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1), + // ("hello", 60, 70, 10, 1), + // ("apache", 60, 70, 10, 1), // ("spark", 60, 70, 10, 1) - CheckNewAnswer(("hello", 25, 45, 20, 3), ("world", 25, 35, 10, 1)) - ) - } - - testWithAllOptionsMergingSessionInLocalPartition("append mode - session window - " + - "storing multiple sessions in given key") { - val inputData = MemoryStream[Int] - val windowedAggregation = inputData.toDF() - .selectExpr("*", "CAST(MOD(value, 2) AS INT) AS valuegroup") - .withColumn("eventTime", $"value".cast("timestamp")) - .withWatermark("eventTime", "10 seconds") - .groupBy(session_window($"eventTime", "5 seconds") as 'session, 'valuegroup) - .agg(count("*") as 'count, sum("value") as 'sum) - .select($"valuegroup", $"session".getField("start").cast("long").as[Long], - $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + CheckNewAnswer( + ), - testStream(windowedAggregation, OutputMode.Append())( - AddData(inputData, 10, 11, 12, 13), - // Advance watermark to 3 seconds - // sessions: key 0 => (10, 17, 2, 22) / key 1 => (11, 18, 2, 24) - CheckNewAnswer(), - 
AddData(inputData, 17), - // Advance watermark to 7 seconds - // sessions: key 0 => (10, 17, 2, 22) / key 1 => (11, 22, 3, 41) - CheckNewAnswer(), - AddData(inputData, 25), - // Advance watermark to 15 seconds - // sessions: key 0 => (10, 17, 2, 22) / key 1 => (11, 22, 3, 41), (25, 30, 1, 25) - CheckNewAnswer(), - AddData(inputData, 35), - // Advance watermark to 25 seconds - // sessions: key 0 => (10, 17, 2, 22) / key 1 => (11, 22, 3, 41), (25, 30, 1, 25), - // (35, 40, 1, 35) - // evicts: key 0 => (10, 17, 2, 22) / key 1 => (11, 22, 3, 41) - CheckNewAnswer((0, 10, 17, 2, 22), (1, 11, 22, 3, 41)), - AddData(inputData, 27), - // don't advance watermark - // sessions: key 1 => (25, 32, 2, 52), (35, 40, 1, 35) - CheckNewAnswer(), - AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckNewAnswer(), - AddData(inputData, 40), - // Advance watermark to 30 seconds - // sessions: key 0 => (40, 45, 1, 40) / key 1 => (25, 32, 2, 52), (35, 40, 1, 35) - CheckNewAnswer() + AddData(inputData, ("structured streaming", 90L)), + // watermark: 60 + // current sessions + // ("hello", 60, 70, 10, 1), + // ("apache", 60, 70, 10, 1), + // ("spark", 60, 70, 10, 1), + // ("structured", 90, 100, 10, 1), + // ("streaming", 90, 100, 10, 1) + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1) + ) ) } @@ -404,129 +354,105 @@ class StreamingSessionWindowSuite extends StreamTest "numEvents") testStream(sessionUpdates, OutputMode.Update())( - AddData(inputData, ("hello world spark", 10L), ("world hello structured streaming", 11L)), - // Advance watermark to 1 seconds - // current sessions after batch: - // ("hello", 10, 21, 11, 2) - // ("world", 10, 21, 11, 2) - // ("spark", 10, 20, 10, 1) - // ("structured", 11, 21, 10, 1) - // ("streaming", 11, 21, 10, 1) + AddData(inputData, + ("hello world spark streaming", 40L), + ("world hello structured 
streaming", 41L) + ), + // watermark: 11 + // current sessions + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) CheckNewAnswer( - ("hello", 10, 21, 11, 2), - ("world", 10, 21, 11, 2), - ("spark", 10, 20, 10, 1), - ("structured", 11, 21, 10, 1), - ("streaming", 11, 21, 10, 1) + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("streaming", 40, 51, 11, 2), + ("spark", 40, 50, 10, 1), + ("structured", 41, 51, 10, 1) ), - AddData(inputData, ("spark streaming", 15L)), - // Advance watermark to 5 seconds - // current sessions after batch: - // ("hello", 10, 21, 11, 2) - // ("world", 10, 21, 11, 2) - // ("structured", 11, 21, 10, 1) - // ("spark", 10, 25, 15, 2) - // ("streaming", 11, 25, 14, 2) - CheckNewAnswer(("spark", 10, 25, 15, 2), ("streaming", 11, 25, 14, 2)), - - AddData(inputData, ("hello world", 25L)), - // Advance watermark to 15 seconds - // current sessions after batch: - // ("hello", 10, 21, 11, 2) - // ("world", 10, 21, 11, 2) - // ("structured", 11, 21, 10, 1) - // ("spark", 10, 25, 15, 2) - // ("streaming", 11, 25, 14, 2) - // ("hello", 25, 35, 10, 1) - // ("world", 25, 35, 10, 1) - CheckNewAnswer(("hello", 25, 35, 10, 1), ("world", 25, 35, 10, 1)), - - AddData(inputData, ("hello world", 3L)), - // input can match to not-yet-evicted sessions, but input itself is less than watermark - // so it should not match exiting sessions - // Watermark kept 15 seconds - // current sessions after batch: - // ("hello", 10, 21, 11, 2) - // ("world", 10, 21, 11, 2) - // ("structured", 11, 21, 10, 1) - // ("spark", 10, 25, 15, 2) - // ("streaming", 11, 25, 14, 2) - // ("hello", 25, 35, 10, 1) - // ("world", 25, 35, 10, 1) - CheckNewAnswer(), + // placing new sessions "before" previous sessions + AddData(inputData, ("spark streaming", 25L)), + // watermark: 11 + // current sessions + // ("spark", 25, 35, 10, 1), + // ("streaming", 25, 35, 10, 1), + // 
("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ("spark", 25, 35, 10, 1), + ("streaming", 25, 35, 10, 1) + ), - AddData(inputData, ("hello", 31L)), - // Advance watermark to 21 seconds - // current sessions after batch: - // ("spark", 10, 25, 15, 2) - // ("streaming", 11, 25, 14, 2) - // ("hello", 25, 41, 16, 2) - // ("world", 25, 35, 10, 1) - CheckNewAnswer(("hello", 25, 41, 16, 2)), + // late event whose session end 10 would be earlier than watermark 11: should be dropped + AddData(inputData, ("spark streaming", 0L)), + // watermark: 11 + // current sessions + // ("spark", 25, 35, 10, 1), + // ("streaming", 25, 35, 10, 1), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), - AddData(inputData, ("hello", 35L)), - // Advance watermark to 25 seconds - // current sessions after batch: - // ("hello", 25, 45, 20, 3) - // ("world", 25, 35, 10, 1) - CheckNewAnswer(("hello", 25, 45, 20, 3)), + // concatenating multiple previous sessions into one + AddData(inputData, ("spark streaming", 30L)), + // watermark: 11 + // current sessions + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4) + ), + // placing new sessions after previous sessions AddData(inputData, ("hello apache spark", 60L)), - // Advance watermark to 50 seconds - // current sessions after batch: - // ("hello", 60, 70, 10, 1) - // ("apache", 60, 70, 10, 1) + // watermark: 30 + // current sessions + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1), + //
("hello", 60, 70, 10, 1), + // ("apache", 60, 70, 10, 1), // ("spark", 60, 70, 10, 1) - CheckNewAnswer(("hello", 60, 70, 10, 1), ("apache", 60, 70, 10, 1), ("spark", 60, 70, 10, 1)) - ) - } - - testWithAllOptionsMergingSessionInLocalPartition("update mode - session window - " + - "storing multiple sessions in given key") { - val inputData = MemoryStream[Int] - val windowedAggregation = inputData.toDF() - .selectExpr("*", "CAST(MOD(value, 2) AS INT) AS valuegroup") - .withColumn("eventTime", $"value".cast("timestamp")) - .withWatermark("eventTime", "10 seconds") - .groupBy(session_window($"eventTime", "5 seconds") as 'session, 'valuegroup) - .agg(count("*") as 'count, sum("value") as 'sum) - .select($"valuegroup", $"session".getField("start").cast("long").as[Long], - $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + CheckNewAnswer( + ("hello", 60, 70, 10, 1), + ("apache", 60, 70, 10, 1), + ("spark", 60, 70, 10, 1) + ), - testStream(windowedAggregation, OutputMode.Update())( - AddData(inputData, 10, 11, 12, 13), - // Advance watermark to 3 seconds - // sessions: key 0 => (10,17) / key 1 => (11, 18) - CheckNewAnswer((0, 10, 17, 2, 22), (1, 11, 18, 2, 24)), - AddData(inputData, 17), - // Advance watermark to 7 seconds - // sessions: key 0 => (10,17) / key 1 => (11, 22) - // updated: key 1 => (11,22) - CheckNewAnswer((1, 11, 22, 3, 41)), - AddData(inputData, 25), - // Advance watermark to 15 seconds - // sessions: key 0 => (10,17) / key 1 => (11,22), (25,30) - // updated: key 1 => (25,30) - CheckNewAnswer((1, 25, 30, 1, 25)), - AddData(inputData, 35), - // Advance watermark to 25 seconds - // sessions: key 0 => (10,17) / key 1 => (11,22), (25,30), (35,40) - // updated: key 1 => (35,40) - // evicts: key 0 => (10,17) / key 1 => (11,22) - CheckNewAnswer((1, 35, 40, 1, 35)), - AddData(inputData, 27), - // don't advance watermark - // sessions: key 1 => (25,32), (35,40) - // updated: key 1 => (25,32) - CheckNewAnswer((1, 25, 32, 2, 52)), 
- AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckNewAnswer(), - AddData(inputData, 40), - // Advance watermark to 30 seconds - // sessions: key 0 => (40,45) / key 1 => (25,32), (35,40) - // updated: key 0 => (40,45) - CheckNewAnswer((0, 40, 45, 1, 40)) + AddData(inputData, ("structured streaming", 90L)), + // watermark: 60 + // current sessions + // ("hello", 60, 70, 10, 1), + // ("apache", 60, 70, 10, 1), + // ("spark", 60, 70, 10, 1), + // ("structured", 90, 100, 10, 1), + // ("streaming", 90, 100, 10, 1) + // evicted + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ("structured", 90, 100, 10, 1), + ("streaming", 90, 100, 10, 1) + ) ) } From 0673b6e9b7f5be4138e71a8f5ed5f60e5c7ffa93 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 23 Oct 2018 16:57:05 +0900 Subject: [PATCH 47/60] WIP Add Linked List data structure for storing session windows --- .../state/SessionWindowLinkedListState.scala | 647 ++++++++++++++++++ .../SessionWindowLinkedListStateSuite.scala | 312 +++++++++ 2 files changed, 959 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala new file mode 100644 index 0000000000000..1eedd0ed34d24 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala @@ -0,0 +1,647 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.Locale + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Literal, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.types.{LongType, StructField, StructType} +import org.apache.spark.util.NextIterator + +// FIXME: javadoc!! 
+/**
+ * Stores session windows per grouping key as a doubly linked list backed by three state stores:
+ *
+ *  - key -> session start of the head element of the list
+ *  - (key, session start) -> session starts of the previous and next elements
+ *  - (key, session start) -> the session row
+ *
+ * [[addBefore]] and [[addAfter]] only validate the new session start against the given target
+ * element; keeping the list as a whole ordered is up to the caller.
+ *
+ * @param storeNamePrefix prefix used to derive state store names and metric descriptions
+ * @param inputValueAttributes attributes describing the schema of the stored session rows
+ * @param keys expressions whose data types and nullability define the key schema
+ * @param stateInfo state info of the stateful operator; must be defined when stores are loaded
+ * @param storeConf configuration handed over to the state stores
+ * @param hadoopConf Hadoop configuration handed over to the state stores
+ */
+class SessionWindowLinkedListState(
+    storeNamePrefix: String,
+    inputValueAttributes: Seq[Attribute],
+    keys: Seq[Expression],
+    stateInfo: Option[StatefulOperatorStateInfo],
+    storeConf: StateStoreConf,
+    hadoopConf: Configuration) extends Logging {
+
+  import SessionWindowLinkedListState._
+
+  /*
+  =====================================================
+                  Public methods
+  =====================================================
+   */
+
+  /**
+   * Returns the session rows of the given key, starting from the head of the list and following
+   * the "next" pointers. The iterator is empty when the key has no head.
+   */
+  def get(key: UnsafeRow): Iterator[UnsafeRow] = {
+    keyToHeadSessionStartStore.get(key) match {
+      case Some(headSessionStart) =>
+        new NextIterator[UnsafeRow] {
+          var curSessionStart: Option[Long] = Some(headSessionStart)
+
+          override protected def getNext(): UnsafeRow = {
+            curSessionStart match {
+              case Some(sessionStart) =>
+                val ret = keyAndSessionStartToValueStore.get(key, sessionStart)
+                // advance to the next element before handing out the current one
+                curSessionStart = keyAndSessionStartToPointerStore.get(key, sessionStart)._2
+                ret
+
+              case None =>
+                finished = true
+                null
+            }
+          }
+
+          override protected def close(): Unit = {}
+        }
+
+      case None =>
+        Seq.empty[UnsafeRow].iterator
+    }
+  }
+
+  /** Returns whatever the value store holds for (key, sessionStart). */
+  def get(key: UnsafeRow, sessionStart: Long): UnsafeRow = {
+    keyAndSessionStartToValueStore.get(key, sessionStart)
+  }
+
+  /**
+   * Registers the very first session of a key: it becomes the head and has neither a previous
+   * nor a next element. Fails with IllegalArgumentException if the key already has a head.
+   */
+  def setHead(key: UnsafeRow, sessionStart: Long, value: UnsafeRow): Unit = {
+    require(keyToHeadSessionStartStore.get(key).isEmpty, "Head should not be exist.")
+
+    keyToHeadSessionStartStore.put(key, sessionStart)
+    keyAndSessionStartToPointerStore.put(key, sessionStart, None, None)
+    keyAndSessionStartToValueStore.put(key, sessionStart, value)
+  }
+
+  /**
+   * Inserts a new session directly before the session starting at `targetSessionStart`.
+   * Requires `sessionStart < targetSessionStart` and an existing target; the new session
+   * becomes the head when the target was the head.
+   */
+  def addBefore(key: UnsafeRow, sessionStart: Long, value: UnsafeRow,
+      targetSessionStart: Long): Unit = {
+    require(sessionStart < targetSessionStart)
+
+    val targetPointer = keyAndSessionStartToPointerStore.get(key, targetSessionStart)
+    assertValidPointer(targetPointer)
+
+    targetPointer._1 match {
+      case Some(prev) =>
+        keyAndSessionStartToPointerStore.updateNext(key, prev, Some(sessionStart))
+        keyAndSessionStartToPointerStore.updatePrev(key, targetSessionStart, Some(sessionStart))
+        keyAndSessionStartToPointerStore.put(key, sessionStart,
+          Some(prev), Some(targetSessionStart))
+
+      case None =>
+        // we're changing head
+        keyAndSessionStartToPointerStore.updatePrev(key, targetSessionStart, Some(sessionStart))
+        keyAndSessionStartToPointerStore.put(key, sessionStart, None, Some(targetSessionStart))
+        keyToHeadSessionStartStore.put(key, sessionStart)
+    }
+
+    keyAndSessionStartToValueStore.put(key, sessionStart, value)
+  }
+
+  /**
+   * Inserts a new session directly after the session starting at `targetSessionStart`.
+   * Requires `sessionStart > targetSessionStart` and an existing target.
+   */
+  def addAfter(key: UnsafeRow, sessionStart: Long, value: UnsafeRow,
+      targetSessionStart: Long): Unit = {
+    require(sessionStart > targetSessionStart)
+
+    val targetPointer = keyAndSessionStartToPointerStore.get(key, targetSessionStart)
+    assertValidPointer(targetPointer)
+
+    targetPointer._2 match {
+      case Some(next) =>
+        keyAndSessionStartToPointerStore.updatePrev(key, next, Some(sessionStart))
+        keyAndSessionStartToPointerStore.updateNext(key, targetSessionStart, Some(sessionStart))
+        keyAndSessionStartToPointerStore.put(key, sessionStart, Some(targetSessionStart),
+          Some(next))
+
+      case None =>
+        keyAndSessionStartToPointerStore.updateNext(key, targetSessionStart, Some(sessionStart))
+        keyAndSessionStartToPointerStore.put(key, sessionStart, Some(targetSessionStart), None)
+    }
+
+    keyAndSessionStartToValueStore.put(key, sessionStart, value)
+  }
+
+  /** Replaces the row of an existing session; the position in the list is left untouched. */
+  def update(key: UnsafeRow, sessionStart: Long, newValue: UnsafeRow): Unit = {
+    val targetPointer = keyAndSessionStartToPointerStore.get(key, sessionStart)
+    assertValidPointer(targetPointer)
+
+    keyAndSessionStartToValueStore.put(key, sessionStart, newValue)
+  }
+
+  /**
+   * Unlinks the session starting at `sessionStart` from the list of the key and deletes both
+   * its pointer entry and its value entry. The head is moved to the next element, or dropped
+   * when the removed session was the only element.
+   *
+   * Throws IllegalArgumentException when the session has no pointer entry, and
+   * IllegalStateException when the session looks like the only element but is not the head.
+   */
+  def remove(key: UnsafeRow, sessionStart: Long): Unit = {
+    val targetPointer = keyAndSessionStartToPointerStore.get(key, sessionStart)
+    assertValidPointer(targetPointer)
+
+    val prevOption = targetPointer._1
+    val nextOption = targetPointer._2
+
+    targetPointer match {
+      case (Some(prev), Some(next)) =>
+        keyAndSessionStartToPointerStore.updateNext(key, prev, nextOption)
+        keyAndSessionStartToPointerStore.updatePrev(key, next, prevOption)
+
+      case (Some(prev), None) =>
+        keyAndSessionStartToPointerStore.updateNext(key, prev, None)
+
+      case (None, Some(next)) =>
+        keyAndSessionStartToPointerStore.updatePrev(key, next, None)
+        keyToHeadSessionStartStore.put(key, next)
+
+      case (None, None) =>
+        if (!keyToHeadSessionStartStore.get(key).contains(sessionStart)) {
+          throw new IllegalStateException("The element has pointer information for head, " +
+            "but the list has different head.")
+        }
+        keyToHeadSessionStartStore.remove(key)
+    }
+
+    // The pointer entry has to go in every case; otherwise a removed session would still pass
+    // assertValidPointer and its pointer row would stay in the store forever.
+    keyAndSessionStartToPointerStore.remove(key, sessionStart)
+    keyAndSessionStartToValueStore.remove(key, sessionStart)
+  }
+
+  /**
+   * Walks the lists of all keys from their heads and removes every session whose row satisfies
+   * `removalCondition`, returning the removed (key, session row) pairs lazily.
+   *
+   * @param removalCondition predicate evaluated against each session row
+   * @param stopOnConditionMismatch when true, the rest of a key's list is skipped as soon as
+   *                                one of its sessions does not satisfy the condition
+   */
+  def removeByValueCondition(removalCondition: UnsafeRow => Boolean,
+      stopOnConditionMismatch: Boolean = false): Iterator[UnsafeRowPair] = {
+    new NextIterator[UnsafeRowPair] {
+
+      // Reuse this object to avoid creation+GC overhead.
+      private val reusedPair = new UnsafeRowPair()
+
+      // NOTE(review): remove() may put/remove entries of the head store while this iterator
+      // over the same store is still open - confirm the store implementation tolerates that.
+      private val allKeysToHeadSessionStarts = keyToHeadSessionStartStore.iterator
+
+      private var currentKey: UnsafeRow = null
+      private var currentSessionStart: Option[Long] = None
+
+      override protected def getNext(): UnsafeRowPair = {
+
+        // first setup
+        if (currentKey == null) {
+          if (!setupNextKey()) {
+            finished = true
+            return null
+          }
+        }
+
+        val retVal = findNextValueToRemove()
+        if (retVal == null) {
+          finished = true
+          return null
+        }
+
+        reusedPair.withRows(currentKey.copy(), retVal)
+      }
+
+      override protected def close(): Unit = {}
+
+      // Moves to the head of the next key; returns false when no key is left.
+      private def setupNextKey(): Boolean = {
+        if (!allKeysToHeadSessionStarts.hasNext) {
+          false
+        } else {
+          val keyAndHeadSessionStart = allKeysToHeadSessionStarts.next()
+          // the iterator hands out a reused holder, hence the copy of the key
+          currentKey = keyAndHeadSessionStart.key.copy()
+          currentSessionStart = Some(keyAndHeadSessionStart.sessionStart)
+          true
+        }
+      }
+
+      // Removes and returns the next matching session, or null once every key is exhausted.
+      private def findNextValueToRemove(): UnsafeRow = {
+        var nextValue: UnsafeRow = null
+        while (nextValue == null) {
+          currentSessionStart match {
+            case Some(sessionStart) =>
+              val pointers = keyAndSessionStartToPointerStore.get(currentKey, sessionStart)
+              val session = keyAndSessionStartToValueStore.get(currentKey, sessionStart)
+
+              if (pointers == null || session == null) {
+                throw new IllegalStateException("Should not happen!")
+              }
+
+              if (removalCondition(session)) {
+                nextValue = session
+                remove(currentKey, sessionStart)
+                // pointers were read before the removal, so the next element is still known
+                currentSessionStart = pointers._2
+              } else {
+                if (stopOnConditionMismatch) {
+                  currentSessionStart = None
+                } else {
+                  currentSessionStart = pointers._2
+                }
+              }
+
+            case None =>
+              if (!setupNextKey()) {
+                return null
+              }
+          }
+        }
+
+        nextValue
+      }
+    }
+  }
+
+  /**
+   * Returns all (key, session row) pairs, key by key, each list walked from its head.
+   * The returned pair object is reused between calls of next().
+   */
+  def getAllRowPairs: Iterator[UnsafeRowPair] = {
+    new NextIterator[UnsafeRowPair] {
+      // Reuse this object to avoid creation+GC overhead.
+      private val reusedPair = new UnsafeRowPair()
+
+      private val allKeysToHeadSessionStarts = keyToHeadSessionStartStore.iterator
+
+      private var currentKey: UnsafeRow = _
+      private var currentSessionStart: Option[Long] = None
+
+      override def getNext(): UnsafeRowPair = {
+        // first setup
+        if (currentKey == null) {
+          if (!setupNextKey()) {
+            finished = true
+            return null
+          }
+        }
+
+        val nextValue = findNextValue()
+        if (nextValue == null) {
+          finished = true
+          return null
+        }
+
+        reusedPair.withRows(currentKey, nextValue)
+      }
+
+      override def close(): Unit = {}
+
+      // Moves to the head of the next key; returns false when no key is left.
+      private def setupNextKey(): Boolean = {
+        if (!allKeysToHeadSessionStarts.hasNext) {
+          false
+        } else {
+          val keyAndHeadSessionStart = allKeysToHeadSessionStarts.next()
+          currentKey = keyAndHeadSessionStart.key.copy()
+          currentSessionStart = Some(keyAndHeadSessionStart.sessionStart)
+          true
+        }
+      }
+
+      // Returns the next session row, switching to the next key when a list is exhausted.
+      private def findNextValue(): UnsafeRow = {
+        var nextValue: UnsafeRow = null
+        while (nextValue == null) {
+          currentSessionStart match {
+            case Some(sessionStart) =>
+              val pointers = keyAndSessionStartToPointerStore.get(currentKey, sessionStart)
+              val session = keyAndSessionStartToValueStore.get(currentKey, sessionStart)
+
+              currentSessionStart = pointers._2
+              nextValue = session
+
+            case None =>
+              if (!setupNextKey()) {
+                finished = true
+                return null
+              }
+          }
+        }
+
+        nextValue
+      }
+
+    }
+  }
+
+  /** Commit all the changes to all the state stores */
+  def commit(): Unit = {
+    keyToHeadSessionStartStore.commit()
+    keyAndSessionStartToPointerStore.commit()
+    keyAndSessionStartToValueStore.commit()
+  }
+
+  /** Abort any changes to the state stores if needed */
+  def abortIfNeeded(): Unit = {
+    keyToHeadSessionStartStore.abortIfNeeded()
+    keyAndSessionStartToPointerStore.abortIfNeeded()
+    keyAndSessionStartToValueStore.abortIfNeeded()
+  }
+
+
+  /**
+   * Get the combined metrics of all the state stores: the number of keys and the custom
+   * metrics are taken from the value store only, the memory usage is the sum of all three.
+   */
+  def metrics: StateStoreMetrics = {
+    val keyToHeadSessionStartMetrics = keyToHeadSessionStartStore.metrics
+    val keyAndSessionStartToPointerMetrics = keyAndSessionStartToPointerStore.metrics
+    val keyAndSessionStartToValueMetrics = keyAndSessionStartToValueStore.metrics
+    def newDesc(desc: String): String = s"${storeNamePrefix.toUpperCase(Locale.ROOT)}: $desc"
+
+    val totalSize = keyToHeadSessionStartMetrics.memoryUsedBytes +
+      keyAndSessionStartToPointerMetrics.memoryUsedBytes +
+      keyAndSessionStartToValueMetrics.memoryUsedBytes
+    StateStoreMetrics(
+      keyAndSessionStartToValueMetrics.numKeys, // represent each buffered row only once
+      totalSize,
+      keyAndSessionStartToValueMetrics.customMetrics.map {
+        case (s @ StateStoreCustomSumMetric(_, desc), value) =>
+          s.copy(desc = newDesc(desc)) -> value
+        case (s @ StateStoreCustomSizeMetric(_, desc), value) =>
+          s.copy(desc = newDesc(desc)) -> value
+        case (s @ StateStoreCustomTimingMetric(_, desc), value) =>
+          s.copy(desc = newDesc(desc)) -> value
+        case (s, _) =>
+          throw new IllegalArgumentException(
+            s"Unknown state store custom metric is found at metrics: $s")
+      }
+    )
+  }
+
+  /*
+  =====================================================
+            Private methods and inner classes
+
=====================================================
+   */
+
+  /**
+   * Fails with IllegalArgumentException when the pointer looked up from the pointer store is
+   * null, i.e. the store has no entry for the requested (key, session start).
+   */
+  private def assertValidPointer(targetPointer: (Option[Long], Option[Long])): Unit = {
+    if (targetPointer == null) {
+      throw new IllegalArgumentException("Update must be against existing session start.")
+    }
+  }
+
+  // Key schema derived from the key expressions; fields are named field0, field1, ...
+  private val keySchema = StructType(
+    keys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) })
+  private val keyAttributes = keySchema.toAttributes
+
+  // The three stores backing the linked list: head per key, prev/next pointers, session rows.
+  private val keyToHeadSessionStartStore = new KeyToHeadSessionStartStore()
+  private val keyAndSessionStartToPointerStore = new KeyAndSessionStartToPointerStore()
+  private val keyAndSessionStartToValueStore = new KeyAndSessionStartToValueStore()
+
+  // Clean up any state store resources if necessary at the end of the task
+  Option(TaskContext.get()).foreach { _.addTaskCompletionListener[Unit] { _ => abortIfNeeded() } }
+
+  /** Helper base class for invoking common functionalities of a state store. */
+  private abstract class StateStoreHandler(stateStoreType: StateStoreType) extends Logging {
+
+    /** StateStore that the subclasses of this class are going to operate on */
+    protected def stateStore: StateStore
+
+    /** Commits the underlying store and logs its metrics at debug level. */
+    def commit(): Unit = {
+      stateStore.commit()
+      logDebug("Committed, metrics = " + stateStore.metrics)
+    }
+
+    /** Aborts the underlying store unless it has already been committed. */
+    def abortIfNeeded(): Unit = {
+      if (!stateStore.hasCommitted) {
+        logInfo(s"Aborted store ${stateStore.id}")
+        stateStore.abort()
+      }
+    }
+
+    def metrics: StateStoreMetrics = stateStore.metrics
+
+    /**
+     * Get the StateStore with the given schema. The provider id is built from the operator
+     * state info, the partition id of the current task and a store name derived from
+     * `storeNamePrefix` and `stateStoreType`; `stateInfo` must therefore be defined.
+     */
+    protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = {
+      val storeProviderId = StateStoreProviderId(stateInfo.get, TaskContext.getPartitionId(),
+        getStateStoreName(storeNamePrefix, stateStoreType))
+      val store = StateStore.get(
+        storeProviderId, keySchema, valueSchema, None,
+        stateInfo.get.storeVersion, storeConf, hadoopConf)
+      logInfo(s"Loaded store ${store.id}")
+      store
+    }
+  }
+
+  /**
+   * Helper class for
representing data returned by [[KeyToHeadSessionStartStore]].
+   * Designed for object reuse: `withNew` overwrites the fields of this instance and returns
+   * the same instance, so a reference obtained earlier observes the new values.
+   */
+  private case class KeyAndHeadSessionStart(var key: UnsafeRow = null, var sessionStart: Long = 0) {
+    def withNew(newKey: UnsafeRow, newSessionStart: Long): this.type = {
+      this.key = newKey
+      this.sessionStart = newSessionStart
+      this
+    }
+  }
+
+  /**
+   * Helper class for representing data returned by [[KeyAndSessionStartToPointerStore]].
+   * Designed for object reuse: `withNew` overwrites the fields of this instance and returns
+   * the same instance.
+   */
+  private case class KeyWithSessionStartAndPointers(
+      var key: UnsafeRow = null,
+      var sessionStart: Long = 0,
+      var prevSessionStart: Option[Long] = None,
+      var nextSessionStart: Option[Long] = None) {
+    // The parameters shadow the fields of the same name, hence the explicit `this.` below.
+    def withNew(newKey: UnsafeRow, sessionStart: Long, prevSessionStart: Option[Long],
+        nextSessionStart: Option[Long]): this.type = {
+      this.key = newKey
+      this.sessionStart = sessionStart
+      this.prevSessionStart = prevSessionStart
+      this.nextSessionStart = nextSessionStart
+      this
+    }
+  }
+
+  /**
+   * Helper class for representing data returned by [[KeyAndSessionStartToValueStore]].
+   * Designed for object reuse.
+   */
+  private case class KeyWithSessionStartAndValue(
+      var key: UnsafeRow = null,
+      var sessionStart: Long = 0,
+      var value: UnsafeRow = null) {
+    // The `sessionStart` parameter shadows the field, hence the explicit `this.` below.
+    def withNew(newKey: UnsafeRow, sessionStart: Long, newValue: UnsafeRow): this.type = {
+      this.key = newKey
+      this.sessionStart = sessionStart
+      this.value = newValue
+      this
+    }
+  }
+
+  /** Store mapping a key to the session start of the head element of its list. */
+  private class KeyToHeadSessionStartStore extends StateStoreHandler(KeyToHeadSessionStartType) {
+    private val longValueSchema = new StructType().add("value", "long")
+    private val longToUnsafeRow = UnsafeProjection.create(longValueSchema)
+    // Single value row that is overwritten on every put().
+    private val valueRow = longToUnsafeRow(new SpecificInternalRow(longValueSchema))
+    protected val stateStore: StateStore = getStateStore(keySchema, longValueSchema)
+
+    /** Get the head of list via session start the key has */
+    def get(key: UnsafeRow): Option[Long] = {
+      val longValueRow = stateStore.get(key)
+      if (longValueRow != null) {
+        Some(longValueRow.getLong(0))
+      } else {
+        None
+      }
+    }
+
+    /** Set the head of list via session start the key has */
+    def put(key: UnsafeRow, sessionStart: Long): Unit = {
+      valueRow.setLong(0, sessionStart)
+      stateStore.put(key, valueRow)
+    }
+
+    /** Drop the head information of the key */
+    def remove(key: UnsafeRow): Unit = {
+      stateStore.remove(key)
+    }
+
+    /**
+     * Iterate over all (key, head session start) entries. The same holder instance is handed
+     * out for every element, so callers must copy what they want to keep.
+     */
+    def iterator: Iterator[KeyAndHeadSessionStart] = {
+      val keyAndHeadSessionStart = KeyAndHeadSessionStart()
+      stateStore.getRange(None, None).map { pair =>
+        keyAndHeadSessionStart.withNew(pair.key, pair.value.getLong(0))
+      }
+    }
+  }
+
+  /** Base class of the stores which are keyed by (key + session start). */
+  private abstract class KeyAndSessionStartAsKeyStore(t: StateStoreType)
+    extends StateStoreHandler(t) {
+    // The literal is only a placeholder; keyWithSessionStartRow() overwrites it per call.
+    protected val keyWithSessionStartExprs = keyAttributes :+ Literal(1L)
+    protected val keyWithSessionStartSchema = keySchema.add("sessionStart", LongType)
+    protected val indexOrdinalInKeyWithSessionStartRow = keyAttributes.size
+
+    // Projection to generate (key + session start) row from key row
+    protected val keyWithSessionStartRowGenerator = UnsafeProjection.create(
+      keyWithSessionStartExprs, keyAttributes)
+
+    // Projection to generate key row from (key + session start) row
+    protected val keyRowGenerator = UnsafeProjection.create(
+      keyAttributes, keyAttributes :+ AttributeReference("sessionStart", LongType)())
+
+    /**
+     * Generate a row using the key and session start. The row comes out of the projection
+     * above and has its session start field overwritten in place.
+     */
+    protected def keyWithSessionStartRow(key: UnsafeRow, sessionStart: Long): UnsafeRow = {
+      val row = keyWithSessionStartRowGenerator(key)
+      row.setLong(indexOrdinalInKeyWithSessionStartRow, sessionStart)
+      row
+    }
+  }
+
+  /** Store mapping (key, session start) to the session starts of the prev / next elements. */
+  private class KeyAndSessionStartToPointerStore extends KeyAndSessionStartAsKeyStore(
+    KeyAndSessionStartToPointerType) {
+    private val doublyPointersValueSchema = new StructType()
+      .add("prev", "long", nullable = true).add("next", "long", nullable = true)
+    private val doublyPointersToUnsafeRow = UnsafeProjection.create(doublyPointersValueSchema)
+    // Single value row that is overwritten on every put().
+    private val valueRow = doublyPointersToUnsafeRow(
+      new SpecificInternalRow(doublyPointersValueSchema))
+    // Every key written to this store is a (key + session start) row, so the store has to be
+    // opened with that schema, same as the value store does.
+    protected val stateStore: StateStore = getStateStore(keyWithSessionStartSchema,
+      doublyPointersValueSchema)
+
+    /** Get the prev/next pointer of current session, or null if there is no such session */
+    def get(key: UnsafeRow, sessionStart: Long): (Option[Long], Option[Long]) = {
+      val actualRow = stateStore.get(keyWithSessionStartRow(key, sessionStart))
+      if (actualRow != null) {
+        (getPrevSessionStart(actualRow), getNextSessionStart(actualRow))
+      } else {
+        null
+      }
+    }
+
+    /** Overwrite only the prev pointer of an existing session */
+    def updatePrev(key: UnsafeRow, sessionStart: Long, prevSessionStart: Option[Long]): Unit = {
+      val actualKeyRow = keyWithSessionStartRow(key, sessionStart)
+      val row = copyOfExistingRow(actualKeyRow)
+      setPrevSessionStart(row, prevSessionStart)
+      stateStore.put(actualKeyRow, row)
+    }
+
+    /** Overwrite only the next pointer of an existing session */
+    def updateNext(key: UnsafeRow, sessionStart: Long, nextSessionStart: Option[Long]): Unit = {
+      val actualKeyRow = keyWithSessionStartRow(key, sessionStart)
+      val row = copyOfExistingRow(actualKeyRow)
+      setNextSessionStart(row, nextSessionStart)
+      stateStore.put(actualKeyRow, row)
+    }
+
+    /** Set both the prev and the next pointer of the session, creating the entry if needed */
+    def put(key: UnsafeRow, sessionStart: Long, prevSessionStart: Option[Long],
+        nextSessionStart: Option[Long]): Unit = {
+      setPrevSessionStart(valueRow, prevSessionStart)
+      setNextSessionStart(valueRow, nextSessionStart)
+      stateStore.put(keyWithSessionStartRow(key, sessionStart), valueRow)
+    }
+
+    /** Drop the pointer entry of the session */
+    def remove(key: UnsafeRow, sessionStart: Long): Unit = {
+      stateStore.remove(keyWithSessionStartRow(key, sessionStart))
+    }
+
+    /**
+     * Iterate over all pointer entries. The same holder instance is handed out for every
+     * element, so callers must copy what they want to keep.
+     */
+    def iterator: Iterator[KeyWithSessionStartAndPointers] = {
+      val keyWithSessionStartAndPointers = KeyWithSessionStartAndPointers()
+      stateStore.getRange(None, None).map { pair =>
+        val keyPart = keyRowGenerator(pair.key)
+        val sessionStart = pair.key.getLong(indexOrdinalInKeyWithSessionStartRow)
+        val prevSessionStart = getPrevSessionStart(pair.value)
+        val nextSessionStart = getNextSessionStart(pair.value)
+        keyWithSessionStartAndPointers.withNew(keyPart, sessionStart, prevSessionStart,
+          nextSessionStart)
+      }
+    }
+
+    /**
+     * Look up the pointer row of an entry that has to exist and return a private copy of it.
+     * The copy is what gets modified: the row handed out by the store may be the instance the
+     * store keeps internally, and must not be mutated in place.
+     */
+    private def copyOfExistingRow(actualKeyRow: UnsafeRow): UnsafeRow = {
+      val row = stateStore.get(actualKeyRow)
+      if (row == null) {
+        throw new IllegalStateException(
+          "Pointer information should exist for the session whose pointer is being updated.")
+      }
+      row.copy()
+    }
+
+    private def getPrevSessionStart(value: UnsafeRow): Option[Long] = {
+      if (value.isNullAt(0)) {
+        None
+      } else {
+        Some(value.getLong(0))
+      }
+    }
+
+    private def setPrevSessionStart(value: UnsafeRow, sessionStart: Option[Long]): Unit = {
+      sessionStart match {
+        case Some(l) => value.setLong(0, l)
+        case None => value.setNullAt(0)
+      }
+    }
+
+    private def getNextSessionStart(value: UnsafeRow): Option[Long] = {
+      if (value.isNullAt(1)) {
+        None
+      } else {
+        Some(value.getLong(1))
+      }
+    }
+
+    private def setNextSessionStart(value: UnsafeRow, sessionStart: Option[Long]): Unit = {
+      sessionStart match {
+        case Some(l) => value.setLong(1, l)
+        case None => value.setNullAt(1)
+      }
+    }
+  }
+
+  private class KeyAndSessionStartToValueStore extends KeyAndSessionStartAsKeyStore(
+    KeyAndSessionStartToValueType) {
+    protected val stateStore = getStateStore(keyWithSessionStartSchema,
+      inputValueAttributes.toStructType)
+
+    def get(key: UnsafeRow, sessionStart: Long): UnsafeRow = {
+      stateStore.get(keyWithSessionStartRow(key, sessionStart))
+    }
+
+ /** Put new value for key at the given index */ + def put(key: UnsafeRow, sessionStart: Long, value: UnsafeRow): Unit = { + val keyWithSessionStart = keyWithSessionStartRow(key, sessionStart) + stateStore.put(keyWithSessionStart, value) + } + + /** + * Remove key and value at given session start. + */ + def remove(key: UnsafeRow, sessionStart: Long): Unit = { + stateStore.remove(keyWithSessionStartRow(key, sessionStart)) + } + } +} + +object SessionWindowLinkedListState { + sealed trait StateStoreType + + case object KeyToHeadSessionStartType extends StateStoreType { + override def toString(): String = "keyToHeadSessionStart" + } + + case object KeyAndSessionStartToPointerType extends StateStoreType { + override def toString(): String = "keyAndSessionStartToPointer" + } + + case object KeyAndSessionStartToValueType extends StateStoreType { + override def toString(): String = "keyAndSessionStartToValue" + } + + def getStateStoreName(storeNamePrefix: String, storeType: StateStoreType): String = { + s"$storeNamePrefix-$storeType" + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateSuite.scala new file mode 100644 index 0000000000000..4c2736b251c17 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateSuite.scala @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.UUID + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, LessThanOrEqual, Literal, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark +import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types._ + +class SessionWindowLinkedListStateSuite extends StreamTest { + + test("add sessions - normal case") { + withSessionWindowLinkedListState(inputValueAttribs, keyExprs) { state => + implicit val st = state + + assert(get(20) === Seq.empty) + setHead(20, 3, time = 3) + assert(get(20) === Seq(3)) + assert(numRows === 1) + + // add element before head: 1 is the new head + addBefore(20, 1, time = 1, targetTime = 3) + assert(get(20) === Seq(1, 3)) + assert(numRows === 2) + + // add element before other element but after head + addBefore(20, 2, time = 2, targetTime = 3) + assert(get(20) === Seq(1, 2, 3)) + assert(numRows === 3) + + // add element at the end + addAfter(20, 5, time = 5, targetTime = 3) + assert(get(20) === Seq(1, 2, 3, 5)) + assert(numRows === 4) + + // add element after other element but before tail element + addAfter(20, 4, time = 4, targetTime = 3) + assert(get(20) === Seq(1, 2, 3, 4, 5)) + assert(numRows === 5) + + update(20, 100, time = 3) + 
assert(get(20) === Seq(1, 2, 100, 4, 5)) + assert(numRows === 5) + + assert(get(30) === Seq.empty) + setHead(30, 1, time = 1) + assert(get(30) === Seq(1)) + assert(get(20) === Seq(1, 2, 100, 4, 5)) + assert(numRows === 6) + } + } + + test("add sessions - improper usage") { + withSessionWindowLinkedListState(inputValueAttribs, keyExprs) { state => + implicit val st = state + + assert(get(20) === Seq.empty) + + setHead(20, 2, time = 2) + // setting head twice + intercept[IllegalArgumentException] { + setHead(20, 2, time = 2) + } + + // add element with dangling pointer + intercept[IllegalArgumentException] { + addBefore(20, 1, time = 1, targetTime = 3) + } + + // add element with dangling pointer + intercept[IllegalArgumentException] { + addAfter(20, 2, time = 5, targetTime = 3) + } + } + } + + test("remove sessions - normal usage") { + withSessionWindowLinkedListState(inputValueAttribs, keyExprs) { state => + implicit val st = state + + assert(numRows === 0) + + setHead(20, 1, time = 1) + addAfter(20, 2, time = 2, targetTime = 1) + addAfter(20, 3, time = 3, targetTime = 2) + addAfter(20, 4, time = 4, targetTime = 3) + assert(numRows === 4) + + // remove head which list has another elements as well + remove(20, time = 1) + assert(get(20) === Seq(2, 3, 4)) + assert(numRows === 3) + + // remove intermediate element + remove(20, time = 3) + assert(get(20) === Seq(2, 4)) + assert(numRows === 2) + + // remove tail element + remove(20, time = 4) + assert(get(20) === Seq(2)) + assert(numRows === 1) + + // remove head which list has only one element + remove(20, time = 2) + assert(get(20) === Seq.empty) + assert(numRows === 0) + } + } + + test("remove sessions - improper usage") { + withSessionWindowLinkedListState(inputValueAttribs, keyExprs) { state => + implicit val st = state + + assert(get(20) === Seq.empty) + setHead(20, 2, time = 2) + + // try to remove non-exist time + intercept[IllegalArgumentException] { + remove(20, 3) + } + + assert(get(20) === Seq(2)) + 
assert(numRows === 1) + } + } + + test("get all pairs") { + withSessionWindowLinkedListState(inputValueAttribs, keyExprs) { state => + implicit val st = state + assert(numRows === 0) + + setHead(20, 1, time = 1) + addAfter(20, 2, time = 2, targetTime = 1) + addAfter(20, 3, time = 3, targetTime = 2) + addAfter(20, 4, time = 4, targetTime = 3) + + setHead(30, 5, time = 5) + addAfter(30, 6, time = 6, targetTime = 5) + addAfter(30, 7, time = 7, targetTime = 6) + addAfter(30, 8, time = 8, targetTime = 7) + + setHead(40, 10, time = 10) + addAfter(40, 11, time = 11, targetTime = 10) + addAfter(40, 12, time = 12, targetTime = 11) + addAfter(40, 13, time = 13, targetTime = 12) + + assert(numRows === 12) + + // must keep input order per key + val groupedTuples = getAllRowPairs.groupBy(_._1) + assert(groupedTuples(20).map(_._2) === Seq(1, 2, 3, 4)) + assert(groupedTuples(30).map(_._2) === Seq(5, 6, 7, 8)) + assert(groupedTuples(40).map(_._2) === Seq(10, 11, 12, 13)) + } + } + + test("remove by watermark - stop on condition mismatch == true") { + removeByWatermarkTest(stopOnConditionMismatch = true) + } + + test("remove by watermark - stop on condition mismatch == false") { + removeByWatermarkTest(stopOnConditionMismatch = false) + } + + private def removeByWatermarkTest(stopOnConditionMismatch: Boolean): Unit = { + withSessionWindowLinkedListState(inputValueAttribs, keyExprs) { state => + implicit val st = state + assert(numRows === 0) + + setHead(20, 1, time = 1) + addAfter(20, 2, time = 2, targetTime = 1) + addAfter(20, 3, time = 3, targetTime = 2) + addAfter(20, 4, time = 4, targetTime = 3) + + setHead(30, 5, time = 5) + addAfter(30, 6, time = 6, targetTime = 5) + addAfter(30, 7, time = 7, targetTime = 6) + addAfter(30, 8, time = 8, targetTime = 7) + + setHead(40, 10, time = 10) + addAfter(40, 11, time = 11, targetTime = 10) + addAfter(40, 12, time = 12, targetTime = 11) + addAfter(40, 13, time = 13, targetTime = 12) + + assert(numRows === 12) + + // must keep input order 
per key + val groupedTuples = removeByValue(6, stopOnConditionMismatch).groupBy(_._1) + assert(groupedTuples(20).map(_._2) === Seq(1, 2, 3, 4)) + assert(groupedTuples(30).map(_._2) === Seq(5, 6)) + assert(groupedTuples.get(40).isEmpty) + + assert(get(20) === Seq.empty) + assert(get(30) === Seq(7, 8)) + assert(get(40) === Seq(10, 11, 12, 13)) + assert(numRows === 6) + } + } + + val watermarkMetadata = new MetadataBuilder().putLong(EventTimeWatermark.delayKey, 10).build() + val inputValueSchema = new StructType() + .add(StructField("time", IntegerType, metadata = watermarkMetadata)) + .add(StructField("value", BooleanType)) + val inputValueAttribs = inputValueSchema.toAttributes + val inputValueAttribWithWatermark = inputValueAttribs(0) + val keyExprs = Seq[Expression](Literal(false), inputValueAttribWithWatermark, Literal(10.0)) + + val inputValueGen = UnsafeProjection.create(inputValueAttribs.map(_.dataType).toArray) + val keyGen = UnsafeProjection.create(keyExprs.map(_.dataType).toArray) + + def toInputValue(i: Int): UnsafeRow = { + inputValueGen.apply(new GenericInternalRow(Array[Any](i, false))) + } + + def toKeyRow(i: Int): UnsafeRow = { + keyGen.apply(new GenericInternalRow(Array[Any](false, i, 10.0))) + } + + def toKeyInt(inputKeyRow: UnsafeRow): Int = inputKeyRow.getInt(1) + + def toValueInt(inputValueRow: UnsafeRow): Int = inputValueRow.getInt(0) + + def setHead(key: Int, value: Int, time: Int) + (implicit state: SessionWindowLinkedListState): Unit = { + state.setHead(toKeyRow(key), time, toInputValue(value)) + } + + def addBefore(key: Int, value: Int, time: Int, targetTime: Int) + (implicit state: SessionWindowLinkedListState): Unit = { + state.addBefore(toKeyRow(key), time, toInputValue(value), targetTime) + } + + def addAfter(key: Int, value: Int, time: Int, targetTime: Int) + (implicit state: SessionWindowLinkedListState): Unit = { + state.addAfter(toKeyRow(key), time, toInputValue(value), targetTime) + } + + def update(key: Int, value: Int, time: Int) 
+ (implicit state: SessionWindowLinkedListState): Unit = { + state.update(toKeyRow(key), time, toInputValue(value)) + } + + def remove(key: Int, time: Int)(implicit state: SessionWindowLinkedListState): Unit = { + state.remove(toKeyRow(key), time) + } + + def get(key: Int)(implicit state: SessionWindowLinkedListState): Seq[Int] = { + state.get(toKeyRow(key)).map(toValueInt).toSeq + } + + def getAllRowPairs(implicit state: SessionWindowLinkedListState): Seq[(Int, Int)] = { + state.getAllRowPairs + .map(pair => (toKeyInt(pair.key), toValueInt(pair.value))) + .toSeq + } + + /** Remove values where `time <= threshold` */ + def removeByValue(watermark: Long, stopOnConditionMismatch: Boolean) + (implicit state: SessionWindowLinkedListState) + : Seq[(Int, Int)] = { + val expr = LessThanOrEqual(inputValueAttribWithWatermark, Literal(watermark)) + state.removeByValueCondition( + GeneratePredicate.generate(expr, inputValueAttribs).eval _, + stopOnConditionMismatch) + .map(pair => (toKeyInt(pair.key), toValueInt(pair.value))) + .toSeq + } + + def numRows(implicit state: SessionWindowLinkedListState): Long = { + state.metrics.numKeys + } + + def withSessionWindowLinkedListState( + inputValueAttribs: Seq[Attribute], + keyExprs: Seq[Expression])(f: SessionWindowLinkedListState => Unit): Unit = { + + withTempDir { file => + val storeConf = new StateStoreConf() + val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5) + val state = new SessionWindowLinkedListState("testing", inputValueAttribs, keyExprs, + Some(stateInfo), storeConf, new Configuration) + try { + f(state) + } finally { + state.abortIfNeeded() + } + } + StateStore.stop() + } +} From 4698f6d097ce7d5a4435d7c16ed3f5d18d3ce72b Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 24 Oct 2018 08:19:43 +0900 Subject: [PATCH 48/60] WIP add SessionWindowLinkedListStateStoreRDD --- ...SessionWindowLinkedListStateStoreRDD.scala | 79 +++++++++++ .../execution/streaming/state/package.scala | 
125 ++++++++++++------ 2 files changed, 163 insertions(+), 41 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateStoreRDD.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateStoreRDD.scala new file mode 100644 index 0000000000000..4594df679fb66 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateStoreRDD.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import scala.reflect.ClassTag + +import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.execution.streaming.continuous.EpochTracker +import org.apache.spark.sql.internal.SessionState +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +// FIXME: javadoc!! 
+class SessionWindowLinkedListStateStoreRDD[T: ClassTag, U: ClassTag]( + dataRDD: RDD[T], + storeUpdateFunction: (SessionWindowLinkedListState, Iterator[T]) => Iterator[U], + stateInfo: StatefulOperatorStateInfo, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int], + sessionState: SessionState, + @transient private val storeCoordinator: Option[StateStoreCoordinatorRef]) + extends RDD[U](dataRDD) { + + private val storeConf = new StateStoreConf(sessionState.conf) + + // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it + private val hadoopConfBroadcast = dataRDD.context.broadcast( + new SerializableConfiguration(sessionState.newHadoopConf())) + + override protected def getPartitions: Array[Partition] = dataRDD.partitions + + /** + * Set the preferred location of each partition using the executor that has the related + * [[StateStoreProvider]] already loaded. + */ + override def getPreferredLocations(partition: Partition): Seq[String] = { + val stateStoreProviderId = StateStoreProviderId( + StateStoreId(stateInfo.checkpointLocation, stateInfo.operatorId, partition.index), + stateInfo.queryRunId) + storeCoordinator.flatMap(_.getLocation(stateStoreProviderId)).toSeq + } + + override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { + // If we're in continuous processing mode, we should get the store version for the current + // epoch rather than the one at planning time. 
+ val currentVersion = EpochTracker.getCurrentEpoch match { + case None => stateInfo.storeVersion + case Some(value) => value + } + + val modifiedStateInfo = stateInfo.copy(storeVersion = currentVersion) + + val state = new SessionWindowLinkedListState(s"session-${stateInfo.operatorId}-", + valueSchema.toAttributes, keySchema.toAttributes, Some(modifiedStateInfo), storeConf, + hadoopConfBroadcast.value.value) + + val inputIter = dataRDD.iterator(partition, ctxt) + storeUpdateFunction(state, inputIter) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 00b51c0bdcedc..d522a48404d29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -79,49 +79,92 @@ package object state { indexOrdinal, sessionState, storeCoordinator) - } + } - /** Map each partition of an RDD along with data in a [[MultiValuesStateManager]]. */ - def mapPartitionsWithMultiValuesStateManager[U: ClassTag]( - sqlContext: SQLContext, - stateInfo: StatefulOperatorStateInfo, - keySchema: StructType, - valueSchema: StructType, - indexOrdinal: Option[Int])( - storeUpdateFunction: (MultiValuesStateManager, Iterator[T]) => Iterator[U]) - : MultiValuesStateStoreRDD[T, U] = { - - mapPartitionsWithMultiValuesStateManager( - stateInfo, - keySchema, - valueSchema, - indexOrdinal, - sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator))( - storeUpdateFunction) - } + /** Map each partition of an RDD along with data in a [[MultiValuesStateManager]]. 
*/ + def mapPartitionsWithMultiValuesStateManager[U: ClassTag]( + sqlContext: SQLContext, + stateInfo: StatefulOperatorStateInfo, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int])( + storeUpdateFunction: (MultiValuesStateManager, Iterator[T]) => Iterator[U]) + : MultiValuesStateStoreRDD[T, U] = { + + mapPartitionsWithMultiValuesStateManager( + stateInfo, + keySchema, + valueSchema, + indexOrdinal, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator))( + storeUpdateFunction) + } + + /** Map each partition of an RDD along with data in a [[MultiValuesStateManager]]. */ + private[streaming] def mapPartitionsWithMultiValuesStateManager[U: ClassTag]( + stateInfo: StatefulOperatorStateInfo, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int], + sessionState: SessionState, + storeCoordinator: Option[StateStoreCoordinatorRef])( + storeUpdateFunction: (MultiValuesStateManager, Iterator[T]) => Iterator[U]) + : MultiValuesStateStoreRDD[T, U] = { + + val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) + new MultiValuesStateStoreRDD( + dataRDD, + cleanedF, + stateInfo, + keySchema, + valueSchema, + indexOrdinal, + sessionState, + storeCoordinator) + } + + /** Map each partition of an RDD along with data in a [[SessionWindowLinkedListState]]. */ + def mapPartitionsWithSessionWindowLinkedListState[U: ClassTag]( + sqlContext: SQLContext, + stateInfo: StatefulOperatorStateInfo, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int])( + storeUpdateFunction: (SessionWindowLinkedListState, Iterator[T]) => Iterator[U]) + : SessionWindowLinkedListStateStoreRDD[T, U] = { + + mapPartitionsWithSessionWindowLinkedListState( + stateInfo, + keySchema, + valueSchema, + indexOrdinal, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator))( + storeUpdateFunction) + } + + /** Map each partition of an RDD along with data in a [[SessionWindowLinkedListState]]. 
*/ + private[streaming] def mapPartitionsWithSessionWindowLinkedListState[U: ClassTag]( + stateInfo: StatefulOperatorStateInfo, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int], + sessionState: SessionState, + storeCoordinator: Option[StateStoreCoordinatorRef])( + storeUpdateFunction: (SessionWindowLinkedListState, Iterator[T]) => Iterator[U]) + : SessionWindowLinkedListStateStoreRDD[T, U] = { - /** Map each partition of an RDD along with data in a [[MultiValuesStateManager]]. */ - private[streaming] def mapPartitionsWithMultiValuesStateManager[U: ClassTag]( - stateInfo: StatefulOperatorStateInfo, - keySchema: StructType, - valueSchema: StructType, - indexOrdinal: Option[Int], - sessionState: SessionState, - storeCoordinator: Option[StateStoreCoordinatorRef])( - storeUpdateFunction: (MultiValuesStateManager, Iterator[T]) => Iterator[U]) - : MultiValuesStateStoreRDD[T, U] = { - - val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) - new MultiValuesStateStoreRDD( - dataRDD, - cleanedF, - stateInfo, - keySchema, - valueSchema, - indexOrdinal, - sessionState, - storeCoordinator) + val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) + new SessionWindowLinkedListStateStoreRDD( + dataRDD, + cleanedF, + stateInfo, + keySchema, + valueSchema, + indexOrdinal, + sessionState, + storeCoordinator) } } } From 5c67f72e1d1c5539ad037045a1cee36ac3284879 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 24 Oct 2018 14:01:02 +0900 Subject: [PATCH 49/60] WIP add more functionalities to SessionWindowLinkedListState --- .../state/SessionWindowLinkedListState.scala | 77 +++++++++++++++++++ .../SessionWindowLinkedListStateSuite.scala | 63 ++++++++++++++- 2 files changed, 139 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala index 
1eedd0ed34d24..a66e9bc7f6359 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala @@ -76,6 +76,34 @@ class SessionWindowLinkedListState( keyAndSessionStartToValueStore.get(key, sessionStart) } + def iteratePointers(key: UnsafeRow): Iterator[(Long, Option[Long], Option[Long])] = { + keyToHeadSessionStartStore.get(key) match { + case Some(headSessionStart) => + new NextIterator[(Long, Option[Long], Option[Long])] { + var curSessionStart: Option[Long] = Some(headSessionStart) + + override protected def getNext(): (Long, Option[Long], Option[Long]) = { + curSessionStart match { + case Some(sessionStart) => + val ret = keyAndSessionStartToPointerStore.get(key, sessionStart) + assertValidPointer(ret) + curSessionStart = ret._2 + (sessionStart, ret._1, ret._2) + + case None => + finished = true + null + } + } + + override protected def close(): Unit = {} + } + + case None => + Seq.empty[(Long, Option[Long], Option[Long])].iterator + } + } + def setHead(key: UnsafeRow, sessionStart: Long, value: UnsafeRow): Unit = { require(keyToHeadSessionStartStore.get(key).isEmpty, "Head should not be exist.") @@ -137,6 +165,55 @@ class SessionWindowLinkedListState( keyAndSessionStartToValueStore.put(key, sessionStart, newValue) } + def isEmpty(key: UnsafeRow): Boolean = { + keyToHeadSessionStartStore.get(key).isEmpty + } + + def findFirstSessionStartEnsurePredicate(key: UnsafeRow, predicate: Long => Boolean, + startIndex: Long): Option[Long] = { + + val pointers = keyAndSessionStartToPointerStore.get(key, startIndex) + assertValidPointer(pointers) + + var currentSessionStart: Option[Long] = Some(startIndex) + var ret: Option[Long] = None + var found = false + + while (!found && currentSessionStart.isDefined) { + val cur = currentSessionStart.get + if (predicate.apply(cur)) { + ret = Some(cur) + found = true + } 
else { + currentSessionStart = getNextSessionStart(key, cur) + } + } + + ret + } + + def findFirstSessionStartEnsurePredicate(key: UnsafeRow, predicate: Long => Boolean) + : Option[Long] = { + val head = keyToHeadSessionStartStore.get(key) + if (head.isEmpty) { + return None + } + + findFirstSessionStartEnsurePredicate(key, predicate, head.get) + } + + def getPrevSessionStart(key: UnsafeRow, sessionStart: Long): Option[Long] = { + val pointers = keyAndSessionStartToPointerStore.get(key, sessionStart) + assertValidPointer(pointers) + pointers._1 + } + + def getNextSessionStart(key: UnsafeRow, sessionStart: Long): Option[Long] = { + val pointers = keyAndSessionStartToPointerStore.get(key, sessionStart) + assertValidPointer(pointers) + pointers._2 + } + def remove(key: UnsafeRow, sessionStart: Long): Unit = { val targetPointer = keyAndSessionStartToPointerStore.get(key, sessionStart) assertValidPointer(targetPointer) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateSuite.scala index 4c2736b251c17..556394b66c17e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateSuite.scala @@ -146,7 +146,7 @@ class SessionWindowLinkedListStateSuite extends StreamTest { } } - test("get all pairs") { + test("get all pairs, iterate pointers, find first") { withSessionWindowLinkedListState(inputValueAttribs, keyExprs) { state => implicit val st = state assert(numRows === 0) @@ -173,6 +173,36 @@ class SessionWindowLinkedListStateSuite extends StreamTest { assert(groupedTuples(20).map(_._2) === Seq(1, 2, 3, 4)) assert(groupedTuples(30).map(_._2) === Seq(5, 6, 7, 8)) assert(groupedTuples(40).map(_._2) === Seq(10, 11, 12, 13)) + + // iterate 
pointers + + val expected = Seq((1, None, Some(2)), (2, Some(1), Some(3)), (3, Some(2), Some(4)), + (4, Some(3), None)) + expected.foreach { case (current, expectedPrev, expectedNext) => + assert(getPrevTime(20, current) == expectedPrev) + assert(getNextTime(20, current) == expectedNext) + } + + assert(iterateTimes(20).toSeq === expected.map(s => (s._1, s._2, s._3))) + + // against non-exist key + assert(iterateTimes(100).toSeq === Seq.empty) + + // find first + + assert(findFirstTime(20, time => time > 0) === Some(1)) + assert(findFirstTime(20, time => time > 3) === Some(4)) + assert(findFirstTime(20, time => time > 5) === None) + + // using start time to skip elements + assert(findFirstTime(20, time => time > 0, startTime = 3) === Some(3)) + assert(findFirstTime(20, time => time > 3, startTime = 1) === Some(4)) + intercept[IllegalArgumentException] { + findFirstTime(20, time => time > 3, startTime = 7) + } + + // against non-exist key + assert(findFirstTime(100, time => time > 1) === None) } } @@ -270,6 +300,37 @@ class SessionWindowLinkedListStateSuite extends StreamTest { state.get(toKeyRow(key)).map(toValueInt).toSeq } + def iterateTimes(key: Int)(implicit state: SessionWindowLinkedListState) + : Iterator[(Int, Option[Int], Option[Int])] = { + state.iteratePointers(toKeyRow(key)).map { s => + (s._1.toInt, s._2.map(_.toInt), s._3.map(_.toInt)) + } + } + + def getPrevTime(key: Int, time: Int)(implicit state: SessionWindowLinkedListState) + : Option[Int] = { + state.getPrevSessionStart(toKeyRow(key), time).map(_.toInt) + } + + def getNextTime(key: Int, time: Int)(implicit state: SessionWindowLinkedListState) + : Option[Int] = { + state.getNextSessionStart(toKeyRow(key), time).map(_.toInt) + } + + def findFirstTime(key: Int, predicate: Int => Boolean) + (implicit state: SessionWindowLinkedListState): Option[Int] = { + val ret = state.findFirstSessionStartEnsurePredicate( + toKeyRow(key), (s: Long) => predicate.apply(s.intValue())) + ret.map(_.intValue()) + } + + 
def findFirstTime(key: Int, predicate: Int => Boolean, startTime: Int) + (implicit state: SessionWindowLinkedListState): Option[Int] = { + val ret = state.findFirstSessionStartEnsurePredicate( + toKeyRow(key), (s: Long) => predicate.apply(s.intValue()), startTime) + ret.map(_.intValue()) + } + def getAllRowPairs(implicit state: SessionWindowLinkedListState): Seq[(Int, Int)] = { state.getAllRowPairs .map(pair => (toKeyInt(pair.key), toValueInt(pair.value))) From 7bb0060f62331301ad6eaf3a6280fa196828f16c Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 25 Oct 2018 18:33:05 +0900 Subject: [PATCH 50/60] WIP it works but a bit suboptimal --- ...SessionWindowLinkedListStateIterator.scala | 312 +++++++++++++ .../state/SessionWindowLinkedListState.scala | 28 +- .../streaming/statefulOperators.scala | 197 ++++++-- ...onWindowLinkedListStateIteratorSuite.scala | 433 ++++++++++++++++++ 4 files changed, 923 insertions(+), 47 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIterator.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIteratorSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIterator.scala new file mode 100644 index 0000000000000..482079938d957 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIterator.scala @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.execution.streaming.state.SessionWindowLinkedListState + +// FIXME: javadoc!! 
+class MergingSortWithSessionWindowLinkedListStateIterator( + iter: Iterator[InternalRow], + state: SessionWindowLinkedListState, + groupWithoutSessionExpressions: Seq[Attribute], + sessionExpression: Attribute, + keysProjection: UnsafeProjection, + sessionProjection: UnsafeProjection, + inputSchema: Seq[Attribute]) extends Iterator[InternalRow] { + + def this( + iter: Iterator[InternalRow], + state: SessionWindowLinkedListState, + groupWithoutSessionExpressions: Seq[Attribute], + sessionExpression: Attribute, + inputSchema: Seq[Attribute]) { + this(iter, state, groupWithoutSessionExpressions, sessionExpression, + GenerateUnsafeProjection.generate(groupWithoutSessionExpressions, inputSchema), + GenerateUnsafeProjection.generate(Seq(sessionExpression), inputSchema), + inputSchema) + } + + private case class SessionRowInformation(keys: UnsafeRow, sessionStart: Long, sessionEnd: Long, + row: InternalRow) + + private object SessionRowInformation { + def of(row: InternalRow): SessionRowInformation = { + val keys = keysProjection(row).copy() + val session = sessionProjection(row).copy() + val sessionRow = session.getStruct(0, 2) + val sessionStart = sessionRow.getLong(0) + val sessionEnd = sessionRow.getLong(1) + + SessionRowInformation(keys, sessionStart, sessionEnd, row) + } + } + + private def findSessionPointerEnclosingEvent(row: SessionRowInformation, + startPointer: Option[Long]) + : Option[(Option[Long], Option[Long])] = { + val startOption = startPointer match { + case None => state.getFirstSessionStart(currentRow.keys) + case _ => startPointer + } + + startOption match { + // empty list + case None => None + case Some(start) => + var currOption: Option[Long] = Some(start) + + var enclosingSessions: Option[(Option[Long], Option[Long])] = None + while (enclosingSessions.isEmpty && currOption.isDefined) { + val curr = currOption.get + val newPrev = state.getPrevSessionStart(currentRow.keys, curr) + val newNext = state.getNextSessionStart(currentRow.keys, curr) + + 
val isEventEnclosed = newPrev match { + case Some(prev) => + prev <= currentRow.sessionStart && currentRow.sessionStart <= curr + case None => currentRow.sessionStart <= curr + } + + val willNotBeEnclosed = newPrev match { + case Some(prev) => prev > currentRow.sessionStart + case None => false + } + + if (isEventEnclosed) { + enclosingSessions = Some((newPrev, currOption)) + } else if (willNotBeEnclosed) { + enclosingSessions = Some((None, None)) + } else if (newNext.isEmpty) { + // curr is the last session in state + if (currentRow.sessionStart >= curr) { + enclosingSessions = Some((currOption, None)) + } else { + enclosingSessions = Some((None, None)) + } + } + + currOption = newNext + } + + // enclosingSessions should not be None unless list is empty + enclosingSessions + } + } + + private def isSessionsOverlap(s1: SessionRowInformation, s2: SessionRowInformation): Boolean = { + (s1.sessionStart >= s2.sessionStart && s1.sessionStart <= s2.sessionEnd) || + (s2.sessionStart >= s1.sessionStart && s2.sessionStart <= s1.sessionEnd) + } + + private var lastKey: UnsafeRow = _ + private var currentRow: SessionRowInformation = _ + + private val stateRowsToEmit: scala.collection.mutable.ListBuffer[SessionRowInformation] = + new scala.collection.mutable.ListBuffer[SessionRowInformation]() + private val stateRowsChecked: scala.collection.mutable.HashSet[SessionRowInformation] = + new scala.collection.mutable.HashSet[SessionRowInformation]() + + private var lastEmittedStateSessionKey: UnsafeRow = _ + private var lastEmittedStateSessionStartOption: Option[Long] = None + private var stateRowWaitForEmit: SessionRowInformation = _ + + private val keyOrdering: Ordering[UnsafeRow] = TypeUtils.getInterpretedOrdering( + groupWithoutSessionExpressions.toStructType).asInstanceOf[Ordering[UnsafeRow]] + + override def hasNext: Boolean = { + currentRow != null || iter.hasNext || stateRowsToEmit.nonEmpty + } + + override def next(): InternalRow = { + if (currentRow == null) { + 
mayFillCurrentRow() + } + + if (currentRow == null && stateRowsToEmit.isEmpty) { + throw new IllegalStateException("No Row to provide in next() which should not happen!") + } + + // early return on input rows vs state row waiting for emitting + val returnCurrentRow = if (currentRow == null) { + false + } else if (stateRowsToEmit.isEmpty) { + true + } else { + // compare between current row and state row waiting for emitting + val stateRow = stateRowsToEmit.head + if (!keyOrdering.equiv(currentRow.keys, stateRow.keys)) { + // state row cannot advance to row in input, so state row should be lower + false + } else { + currentRow.sessionStart < stateRow.sessionStart + } + } + + // if state row should be emitted, do emit + if (!returnCurrentRow) { + val stateRow = stateRowsToEmit.head + stateRowsToEmit.remove(0) + return stateRow.row + } + + if (lastKey == null || !keyOrdering.equiv(lastKey, currentRow.keys)) { + // new key + stateRowsToEmit.clear() + stateRowsChecked.clear() + lastKey = currentRow.keys + } + + // FIXME: how to provide start pointer to avoid reiterating? 
+ val stateSessionsEnclosingCurrentRow = findSessionPointerEnclosingEvent(currentRow, + startPointer = None) + + stateSessionsEnclosingCurrentRow match { + case None => + case Some(x) => + x._1 match { + case Some(prev) => + val prevSession = SessionRowInformation.of(state.get(currentRow.keys, prev)) + + if (!stateRowsChecked.contains(prevSession)) { + // based on definition of session window and the fact that events are sorted, + // if the state session is not matched to this event, it will not be matched with + // later events as well + stateRowsChecked += prevSession + + if (isSessionsOverlap(currentRow, prevSession)) { + stateRowsToEmit += prevSession + } + } + + case None => + } + + x._2 match { + case Some(next) => + val nextSession = SessionRowInformation.of(state.get(currentRow.keys, next)) + + if (!stateRowsChecked.contains(nextSession)) { + // next session could be matched to latter events even it doesn't match to + // current event, so unless it is added to rows to emit, don't add to checked set + if (isSessionsOverlap(currentRow, nextSession)) { + stateRowsToEmit += nextSession + stateRowsChecked += nextSession + } + } + + case None => + } + } + + if (stateRowsToEmit.isEmpty) { + emitCurrentRow() + } else if (currentRow.sessionStart < stateRowsToEmit.head.sessionStart) { + emitCurrentRow() + } else { + val stateRow = stateRowsToEmit.head + stateRowsToEmit.remove(0) + stateRow.row + } + } + + private def emitCurrentRow(): InternalRow = { + val ret = currentRow + currentRow = null + ret.row + } + + private def emitStateRowForWaiting(): InternalRow = { + val ret = stateRowWaitForEmit + stateRowWaitForEmit = null + recordStateRowToEmit(ret) + ret.row + } + + private def recordStateRowToEmit(stateRow: SessionRowInformation) = { + lastEmittedStateSessionStartOption = Some(stateRow.sessionStart) + lastEmittedStateSessionKey = stateRow.keys + } + + private def mayFillCurrentRow(): Unit = { + if (iter.hasNext) { + currentRow = 
SessionRowInformation.of(iter.next()) + } + } + + private def currentRowIsSmallerThanWaitingStateRow(): Boolean = { + // compare between current row and state row waiting for emitting + if (!keyOrdering.equiv(currentRow.keys, stateRowWaitForEmit.keys)) { + // state row cannot advance to row in input, so state row should be lower + false + } else { + currentRow.sessionStart < stateRowWaitForEmit.sessionStart + } + } + + private def getEnclosingStatesForEvent(row: SessionRowInformation) + : (Option[SessionRowInformation], Option[SessionRowInformation]) = { + // find two state sessions wrapping current row + + if (lastEmittedStateSessionKey != null) { + mayInvalidateLastEmittedStateSession() + } + + val nextStateSessionStart: Option[Long] = lastEmittedStateSessionStartOption match { + case Some(lastEmittedStateSessionStart) => + state.findFirstSessionStartEnsurePredicate(currentRow.keys, + _ >= currentRow.sessionEnd, lastEmittedStateSessionStart) + case None => + state.findFirstSessionStartEnsurePredicate(currentRow.keys, + _ >= currentRow.sessionEnd) + } + + val prevStateSessionStart: Option[Long] = nextStateSessionStart match { + case Some(next) => state.getPrevSessionStart(currentRow.keys, next) + case None => state.getLastSessionStart(currentRow.keys) + } + + // only return sessions which overlap with current row + val pSession = if (prevStateSessionStart.isDefined) { + Some(SessionRowInformation.of(state.get(currentRow.keys, prevStateSessionStart.get))) + } else { + None + } + + val nSession = if (nextStateSessionStart.isDefined) { + Some(SessionRowInformation.of(state.get(currentRow.keys, nextStateSessionStart.get))) + } else { + None + } + + (pSession, nSession) + } + + private def mayInvalidateLastEmittedStateSession(): Unit = { + // invalidate last emitted state session key as well as session start + // if keys are changed + if (!keyOrdering.equiv(lastEmittedStateSessionKey, currentRow.keys)) { + lastEmittedStateSessionKey = null + 
lastEmittedStateSessionStartOption = None + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala index a66e9bc7f6359..890b7220f1237 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala @@ -214,6 +214,32 @@ class SessionWindowLinkedListState( pointers._2 } + // FIXME: cover with test cases + def getFirstSessionStart(key: UnsafeRow): Option[Long] = { + keyToHeadSessionStartStore.get(key) + } + + // FIXME: cover with test cases + def getLastSessionStart(key: UnsafeRow): Option[Long] = { + getFirstSessionStart(key) match { + case Some(start) => getLastSessionStart(key, start) + case None => None + } + } + + // FIXME: cover with test cases + def getLastSessionStart(key: UnsafeRow, startIndex: Long): Option[Long] = { + val pointers = keyAndSessionStartToPointerStore.get(key, startIndex) + assertValidPointer(pointers) + + var lastSessionStart = startIndex + while (getNextSessionStart(key, lastSessionStart).isDefined) { + lastSessionStart = getNextSessionStart(key, lastSessionStart).get + } + + Some(lastSessionStart) + } + def remove(key: UnsafeRow, sessionStart: Long): Unit = { val targetPointer = keyAndSessionStartToPointerStore.get(key, sessionStart) assertValidPointer(targetPointer) @@ -441,7 +467,7 @@ class SessionWindowLinkedListState( private def assertValidPointer(targetPointer: (Option[Long], Option[Long])): Unit = { if (targetPointer == null) { - throw new IllegalArgumentException("Update must be against existing session start.") + throw new IllegalArgumentException("Invalid pointer is provided.") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index df2c1df7c274b..3dcbf3c8c835e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -440,23 +440,13 @@ case class SessionWindowStateStoreRestoreExec( override protected def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - child.execute().mapPartitionsWithMultiValuesStateManager( + child.execute().mapPartitionsWithSessionWindowLinkedListState( getStateInfo, keyExpressions.toStructType, child.output.toStructType, indexOrdinal = None, sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - - val stateManager = StreamingSessionStateManager.createStateManager( - keyWithoutSessionExpressions, child.output, watermarkPredicateForData, stateFormatVersion) - - stateManager match { - case mvState: MultiValuesStateManagerInjectable => - mvState.setMultiValuesStateManager(store) - case _ => throw new IllegalStateException("Session state manager is expected to work " + - "with MultiValuesStateManager") - } + Some(sqlContext.streams.stateStoreCoordinator)) { case (state, iter) => val keyWithoutSessionProjection = GenerateUnsafeProjection.generate( keyWithoutSessionExpressions, child.output) @@ -469,9 +459,9 @@ case class SessionWindowStateStoreRestoreExec( case None => iter } - new MergingSortWithMultiValuesStateIterator( + new MergingSortWithSessionWindowLinkedListStateIterator( filteredIterator, - stateManager, + state, keyWithoutSessionExpressions, sessionExpression, keyWithoutSessionProjection, @@ -525,23 +515,13 @@ case class SessionWindowStateStoreSaveExec( assert(outputMode.nonEmpty, "Incorrect planning in IncrementalExecution, outputMode has not been set") - child.execute().mapPartitionsWithMultiValuesStateManager( + 
child.execute().mapPartitionsWithSessionWindowLinkedListState( getStateInfo, keyWithoutSessionExpressions.toStructType, child.output.toStructType, indexOrdinal = None, sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => - - val stateManager = StreamingSessionStateManager.createStateManager( - keyWithoutSessionExpressions, child.output, watermarkPredicateForData, stateFormatVersion) - - stateManager match { - case mvState: MultiValuesStateManagerInjectable => - mvState.setMultiValuesStateManager(store) - case _ => throw new IllegalStateException("Session state manager is expected to work " + - "with MultiValuesStateManager") - } + Some(sqlContext.streams.stateStoreCoordinator)) { (state, iter) => val numOutputRows = longMetric("numOutputRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") @@ -549,25 +529,144 @@ case class SessionWindowStateStoreSaveExec( val allRemovalsTimeMs = longMetric("allRemovalsTimeMs") val commitTimeMs = longMetric("commitTimeMs") - // assuming late events were dropped from MergingSortWithMultiValuesStateIterator + val keyProjection = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val sessionProjection = GenerateUnsafeProjection.generate(Seq(sessionExpression), + child.output) + + val keyOrdering = TypeUtils.getInterpretedOrdering(keyExpressions.toStructType) + .asInstanceOf[Ordering[UnsafeRow]] + val valueOrdering = TypeUtils.getInterpretedOrdering(child.output.toStructType) + .asInstanceOf[Ordering[UnsafeRow]] + + var lastSearchedSessionStartOption: Option[Long] = None + var stateFetchedKey: UnsafeRow = null + + def reflectNewSession(row: UnsafeRow): Boolean = { + val key = keyProjection(row) + val session = sessionProjection(row).getStruct(0, 2) + val sessionStart = session.getLong(0) + val sessionEnd = session.getLong(1) + + if (state.isEmpty(key)) { + state.setHead(key, sessionStart, row) + return true + } + + val existing = state.get(key, sessionStart) + if 
(existing != null && valueOrdering.equiv(existing, row)) { + // session already exist and do not need to update + return false + } + + if (stateFetchedKey == null || keyOrdering.equiv(stateFetchedKey, key)) { + stateFetchedKey = key + lastSearchedSessionStartOption = None + } + + // need to find sessions which could be replaced with new session + // new session should enclose previous session(s) if it overlaps, + // since session always expands + + // find the first state session which is enclosed by new session + val firstStateSessionEnclosedByNewSession = lastSearchedSessionStartOption match { + case Some(lastSearchedSessionStart) => + state.findFirstSessionStartEnsurePredicate(key, start => start >= sessionStart, + lastSearchedSessionStart) + + case None => + state.findFirstSessionStartEnsurePredicate(key, start => start >= sessionStart) + } + + firstStateSessionEnclosedByNewSession match { + case Some(firstStateSessionStart) => + // get previous earlier to enable addAfter on new session after removal + val prevForFirstStateSession = state.getPrevSessionStart(key, firstStateSessionStart) + + // search and remove sessions which is enclosed by new session + var currentStateSessionStart: Option[Long] = Some(firstStateSessionStart) + var stop = false + while (!stop && currentStateSessionStart.isDefined) { + val stateSession = state.get(key, currentStateSessionStart.get) + + val stateSessionStart = sessionProjection(stateSession).getStruct(0, 2).getLong(0) + val stateSessionEnd = sessionProjection(stateSession).getStruct(0, 2).getLong(1) + + require(stateSessionStart == currentStateSessionStart.get, + "Session pointer doesn't match with actual session start!") + + // get next to continue searching after removal + val nextStateSessionStart = state.getNextSessionStart(key, stateSessionStart) + + // remove session if it is enclosed + if (stateSessionStart >= sessionStart && stateSessionEnd <= sessionEnd) { + state.remove(key, stateSessionStart) + 
currentStateSessionStart = nextStateSessionStart + } else { + stop = true + } + } + + // currentStateSessionStart is now the earliest session in state which + // new session should be added before + (prevForFirstStateSession, currentStateSessionStart) match { + case (_, Some(next)) => + state.addBefore(key, sessionStart, row, next) + lastSearchedSessionStartOption = Some(sessionStart) + + case (Some(prev), None) => + state.addAfter(key, sessionStart, row, prev) + lastSearchedSessionStartOption = Some(sessionStart) + + case (None, None) => + // we removed all elements + require(state.isEmpty(key), "It must be empty list since all elements are removed!") + state.setHead(key, sessionStart, row) + } + + case None => + // add to last: we got rid of the case list is empty + val lastSessionStartOption = lastSearchedSessionStartOption match { + case Some(lastSearchedSessionStart) => + state.getLastSessionStart(key, lastSearchedSessionStart) + case None => state.getLastSessionStart(key) + } + + lastSessionStartOption match { + case Some(lastSessionStart) => + state.addAfter(key, sessionStart, row, lastSessionStart) + + case None => + throw new IllegalStateException("List should not be empty!") + } + } + + // we don't need to search before the start of new session, since new sessions are sorted + // by session start + lastSearchedSessionStartOption = Some(sessionStart) + + true + } + + // assuming late events were dropped before + outputMode match { case Some(Complete) => allUpdatesTimeMs += timeTakenMs { while (iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] - if (stateManager.append(row)) { + + if (reflectNewSession(row)) { numUpdatedStateRows += 1 } } - - stateManager.doFinalize() } CompletionIterator[InternalRow, Iterator[InternalRow]]( - stateManager.getAll(), { - commitTimeMs += timeTakenMs { store.commit() } - setStoreMetrics(store) - }) + state.getAllRowPairs.map(_.value), { + commitTimeMs += timeTakenMs { state.commit() } + setStoreMetrics(state) + } + 
) // Update and output only sessions being evicted from the MultiValuesStateManager // Assumption: watermark predicates must be non-empty if append mode is allowed @@ -575,7 +674,7 @@ case class SessionWindowStateStoreSaveExec( allUpdatesTimeMs += timeTakenMs { while (iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] - if (stateManager.append(row)) { + if (reflectNewSession(row)) { numUpdatedStateRows += 1 } } @@ -583,15 +682,18 @@ case class SessionWindowStateStoreSaveExec( val removalStartTimeNs = System.nanoTime - val retIter = stateManager.evictSessionsByWatermark().map { row => + val retIter = state.removeByValueCondition(row => watermarkPredicateForData match { + case Some(predicate) => predicate.eval(row) + case None => false + }, stopOnConditionMismatch = true).map { row => numOutputRows += 1 - row + row.value } CompletionIterator[InternalRow, Iterator[InternalRow]](retIter, { allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) - commitTimeMs += timeTakenMs { store.commit() } - setStoreMetrics(store) + commitTimeMs += timeTakenMs { state.commit() } + setStoreMetrics(state) }) // Update and output modified rows from the MultiValuesStateManager. 
@@ -605,7 +707,7 @@ case class SessionWindowStateStoreSaveExec( while (ret == null && iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] - if (stateManager.append(row)) { + if (reflectNewSession(row)) { numUpdatedStateRows += 1 ret = row } @@ -624,15 +726,18 @@ case class SessionWindowStateStoreSaveExec( } override protected def close(): Unit = { - stateManager.doFinalize() allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) // Remove old aggregates if watermark specified allRemovalsTimeMs += timeTakenMs { - stateManager.doEvictSessionsByWatermark() + // fully consume iterator to ensure all necessary elements are evicted + state.removeByValueCondition(row => watermarkPredicateForData match { + case Some(predicate) => predicate.eval(row) + case None => false + }, stopOnConditionMismatch = true).toList } - commitTimeMs += timeTakenMs { store.commit() } - setStoreMetrics(store) + commitTimeMs += timeTakenMs { state.commit() } + setStoreMetrics(state) } } @@ -663,8 +768,8 @@ case class SessionWindowStateStoreSaveExec( newMetadata.batchWatermarkMs > eventTimeWatermark.get } - protected def setStoreMetrics(manager: MultiValuesStateManager): Unit = { - val storeMetrics = manager.metrics + protected def setStoreMetrics(state: SessionWindowLinkedListState): Unit = { + val storeMetrics = state.metrics longMetric("numTotalStateRows") += storeMetrics.numKeys longMetric("stateMemory") += storeMetrics.memoryUsedBytes storeMetrics.customMetrics.foreach { case (metric, value) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIteratorSuite.scala new file mode 100644 index 0000000000000..546e598cb62af --- /dev/null +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIteratorSuite.scala @@ -0,0 +1,433 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.UUID + +import org.apache.hadoop.conf.Configuration +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.streaming.state.{SessionWindowLinkedListState, StateStore, StateStoreConf} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class MergingSortWithSessionWindowLinkedListStateIteratorSuite extends SharedSQLContext { + + val rowSchema = new StructType().add("key1", StringType).add("key2", IntegerType) + .add("session", new StructType().add("start", LongType).add("end", LongType)) + .add("aggVal1", LongType).add("aggVal2", DoubleType) + val rowAttributes = rowSchema.toAttributes + + val 
keysWithoutSessionSchema = rowSchema.filter(st => List("key1", "key2").contains(st.name)) + val keysWithoutSessionAttributes = rowAttributes.filter { + attr => List("key1", "key2").contains(attr.name) + } + + val sessionSchema = rowSchema.filter(st => st.name == "session").head + val sessionAttribute = rowAttributes.filter(attr => attr.name == "session").head + + val valuesSchema = rowSchema.filter(st => List("aggVal1", "aggVal2").contains(st.name)) + val valuesAttributes = rowAttributes.filter { + attr => List("aggVal1", "aggVal2").contains(attr.name) + } + + val keyProjection = GenerateUnsafeProjection.generate(keysWithoutSessionAttributes, rowAttributes) + val sessionProjection = GenerateUnsafeProjection.generate(Seq(sessionAttribute), rowAttributes) + + test("no row in input data") { + withSessionWindowLinkedListState(rowAttributes, keysWithoutSessionAttributes) { state => + val iterator = new MergingSortWithSessionWindowLinkedListStateIterator(None.iterator, + state, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + + assert(!iterator.hasNext) + } + } + + test("no row in input data but having state") { + withSessionWindowLinkedListState(rowAttributes, keysWithoutSessionAttributes) { state => + val srow11 = createRow("a", 1, 55, 85, 50, 2.5) + val srow12 = createRow("a", 1, 105, 140, 30, 2.0) + + setRowsInState(state, keyProjection(srow11), srow11, srow12) + + val iterator = new MergingSortWithSessionWindowLinkedListStateIterator(None.iterator, + state, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + + assert(!iterator.hasNext) + } + } + + test("no previous state") { + withSessionWindowLinkedListState(rowAttributes, keysWithoutSessionAttributes) { state => + val row1 = createRow("a", 1, 100, 110, 10, 1.1) + val row2 = createRow("a", 1, 100, 110, 20, 1.2) + val row3 = createRow("a", 2, 110, 120, 10, 1.1) + val row4 = createRow("a", 2, 115, 125, 20, 1.2) + val rows = List(row1, row2, row3, row4) + + val iterator = new 
MergingSortWithSessionWindowLinkedListStateIterator(rows.iterator, + state, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + + rows.foreach { row => + assert(iterator.hasNext) + assertRowsEquals(row, iterator.next()) + } + + assert(!iterator.hasNext) + } + } + + test("single previous state") { + withSessionWindowLinkedListState(rowAttributes, keysWithoutSessionAttributes) { state => + // key1 - events are earlier than state + val row11 = createRow("a", 1, 100, 110, 10, 1.1) + val row12 = createRow("a", 1, 110, 120, 20, 1.2) + + // below will not be picked up since the session is not matched to new events + val srow11 = createRow("a", 1, 200, 220, 10, 1.3) + setRowsInState(state, keyProjection(srow11), srow11) + + // key2 - events are later than state + // below will not be picked up since the session is not matched to new events + val srow21 = createRow("a", 2, 50, 70, 10, 1.1) + + val row21 = createRow("a", 2, 100, 110, 10, 1.1) + val row22 = createRow("a", 2, 110, 120, 20, 1.2) + setRowsInState(state, keyProjection(srow21), srow21) + + // key3 - events are enclosing the state + val row31 = createRow("a", 3, 90, 100, 10, 1.1) + val srow31 = createRow("a", 3, 100, 110, 10, 1.1) + val row32 = createRow("a", 3, 105, 115, 20, 1.2) + setRowsInState(state, keyProjection(srow31), srow31) + + val rows = List(row11, row12) ++ List(row21, row22) ++ List(row31, row32) + + val expectedRows = List(row11, row12) ++ List(row21, row22) ++ + List(row31, srow31, row32) + + val iterator = new MergingSortWithSessionWindowLinkedListStateIterator(rows.iterator, + state, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + + expectedRows.foreach { row => + assert(iterator.hasNext, "Iterator.hasNext is false while we expected row " + + s"${getTupleFromRow(row)}") + assertRowsEquals(row, iterator.next()) + } + + assert(!iterator.hasNext) + } + } + + test("only emitting sessions in state which enclose events") { + withSessionWindowLinkedListState(rowAttributes, 
keysWithoutSessionAttributes) { state => + // below example is group by line separated + + val row1 = createRow("a", 1, 10, 20, 1, 1.1) + val row2 = createRow("a", 1, 20, 30, 1, 1.1) + val row3 = createRow("a", 1, 30, 40, 1, 1.1) + val srow1 = createRow("a", 1, 40, 60, 2, 2.2) + val row4 = createRow("a", 1, 40, 50, 1, 1.1) + + // below will not be picked up since the session is not matched to new events + val srow2 = createRow("a", 1, 80, 90, 2, 2.2) + val srow3 = createRow("a", 1, 100, 110, 2, 2.2) + + val srow4 = createRow("a", 1, 120, 130, 2, 2.2) + val row5 = createRow("a", 1, 125, 135, 1, 1.1) + val row6 = createRow("a", 1, 140, 150, 1, 1.1) + + // below will not be picked up since the session is not matched to new events + val srow5 = createRow("a", 1, 180, 200, 2, 2.2) + val srow6 = createRow("a", 1, 220, 260, 2, 2.2) + + setRowsInState(state, keyProjection(srow1), srow1, srow2, srow3, srow4, srow5, srow6) + + val rows = List(row1, row2, row3, row4, row5, row6) + + val expectedRowSequence = List(row1, row2, row3, srow1, row4, srow4, + row5, row6) + + val iterator = new MergingSortWithSessionWindowLinkedListStateIterator(rows.iterator, + state, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + + expectedRowSequence.foreach { row => + assert(iterator.hasNext) + assertRowsEquals(row, iterator.next()) + } + + assert(!iterator.hasNext) + } + } + + test("multiple keys in input data and state") { + withSessionWindowLinkedListState(rowAttributes, keysWithoutSessionAttributes) { state => + // key 1 - placing sessions in state to start and end + val srow11 = createRow("a", 1, 85, 105, 50, 2.5) + val row11 = createRow("a", 1, 100, 110, 10, 1.1) + val row12 = createRow("a", 1, 100, 110, 20, 1.2) + val srow12 = createRow("a", 1, 105, 140, 30, 2.0) + + val key1 = keyProjection(srow11) + setRowsInState(state, key1, srow11, srow12) + + val rowsForKey1 = List(row11, row12) + val expectedForKey1 = List(srow11, row11, row12, srow12) + + // key 2 - no state + val 
row21 = createRow("a", 2, 110, 120, 10, 1.1) + val row22 = createRow("a", 2, 115, 125, 20, 1.2) + + val rowsForKey2 = List(row21, row22) + val expectedForKey2 = List(row21, row22) + + // key 3 - placing sessions in state to only start + + // below will not be picked up since the session is not matched to new events + val srow31 = createRow("a", 3, 105, 115, 30, 2.0) + val srow32 = createRow("a", 3, 120, 125, 30, 2.0) + + val row31 = createRow("a", 3, 130, 140, 10, 1.1) + val row32 = createRow("a", 3, 135, 145, 20, 1.2) + + val key3 = keyProjection(srow31) + setRowsInState(state, key3, srow31, srow32) + + val rowsForKey3 = List(row31, row32) + val expectedForKey3 = List(row31, row32) + + // key 4 - placing sessions in state to only end + val row41 = createRow("a", 4, 100, 110, 10, 1.1) + val row42 = createRow("a", 4, 100, 115, 20, 1.2) + + // below will not be picked up since the session is not matched to new events + val srow41 = createRow("a", 4, 120, 140, 30, 2.0) + val srow42 = createRow("a", 4, 150, 180, 30, 2.0) + + val key4 = keyProjection(srow41) + setRowsInState(state, key4, srow41, srow42) + + val rowsForKey4 = List(row41, row42) + val expectedForKey4 = List(row41, row42) + + // key 5 - placing sessions in state like one row and state session and another + val srow51 = createRow("a", 5, 90, 120, 30, 2.0) + val row51 = createRow("a", 5, 100, 110, 10, 1.1) + val srow52 = createRow("a", 5, 130, 155, 30, 2.0) + val row52 = createRow("a", 5, 140, 160, 20, 1.2) + val srow53 = createRow("a", 5, 160, 190, 30, 2.0) + + val key5 = keyProjection(srow51) + setRowsInState(state, key5, srow51, srow52, srow53) + + val rowsForKey5 = List(row51, row52) + val expectedForKey5 = List(srow51, row51, srow52, row52, srow53) + + val rows = rowsForKey1 ++ rowsForKey2 ++ rowsForKey3 ++ rowsForKey4 ++ rowsForKey5 + + val expectedRowSequence = expectedForKey1 ++ expectedForKey2 ++ expectedForKey3 ++ + expectedForKey4 ++ expectedForKey5 + + val iterator = new 
MergingSortWithSessionWindowLinkedListStateIterator(rows.iterator, + state, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) + + expectedRowSequence.foreach { row => + assert(iterator.hasNext, s"Iterator closed while we expect ${getTupleFromRow(row)}") + assertRowsEquals(row, iterator.next()) + } + + assert(!iterator.hasNext) + } + } + + test("no keys in input data and state") { + val noKeyRowSchema = new StructType() + .add("session", new StructType().add("start", LongType).add("end", LongType)) + .add("aggVal1", LongType).add("aggVal2", DoubleType) + val noKeyRowAttributes = noKeyRowSchema.toAttributes + + val noKeySessionAttribute = noKeyRowAttributes.filter(attr => attr.name == "session").head + + def createNoKeyRow(sessionStart: Long, sessionEnd: Long, + aggVal1: Long, aggVal2: Double): UnsafeRow = { + val genericRow = new GenericInternalRow(4) + val session: Array[Any] = new Array[Any](2) + session(0) = sessionStart + session(1) = sessionEnd + + val sessionRow = new GenericInternalRow(session) + genericRow.update(0, sessionRow) + + genericRow.setLong(1, aggVal1) + genericRow.setDouble(2, aggVal2) + + val rowProjection = GenerateUnsafeProjection.generate(noKeyRowAttributes, noKeyRowAttributes) + rowProjection(genericRow) + } + + def assertNoKeyRowsEquals(expectedRow: InternalRow, retRow: InternalRow): Unit = { + assert(retRow.getStruct(0, 2).getLong(0) == expectedRow.getStruct(0, 2).getLong(0)) + assert(retRow.getStruct(0, 2).getLong(1) == expectedRow.getStruct(0, 2).getLong(1)) + assert(retRow.getLong(1) === expectedRow.getLong(1)) + assert(doubleEquals(retRow.getDouble(2), expectedRow.getDouble(2))) + } + + def setNoKeyRowsInState(state: SessionWindowLinkedListState, rows: UnsafeRow*) + : Unit = { + def getSessionStart(row: UnsafeRow): Long = { + row.getStruct(0, 2).getLong(0) + } + + val key = new UnsafeRow(0) + val iter = rows.sortBy(getSessionStart).iterator + + var prevSessionStart: Option[Long] = None + while (iter.hasNext) { + val row = 
iter.next() + val sessionStart = getSessionStart(row) + if (prevSessionStart.isDefined) { + state.addAfter(key, sessionStart, row, prevSessionStart.get) + } else { + state.setHead(key, sessionStart, row) + } + + prevSessionStart = Some(sessionStart) + } + } + + withSessionWindowLinkedListState(noKeyRowAttributes, Seq.empty[Attribute]) { state => + // this will not be picked up because the session in state is not enclosing events + val srow1 = createNoKeyRow(10, 16, 10, 21) + val srow2 = createNoKeyRow(17, 27, 2, 39) + + val srow3 = createNoKeyRow(35, 40, 1, 35) + val row1 = createNoKeyRow(40, 45, 10, 45) + setNoKeyRowsInState(state, srow1, srow2, srow3) + + val rows = List(row1) + + val expectedRowSequence = List(srow3, row1) + + val iterator = new MergingSortWithSessionWindowLinkedListStateIterator(rows.iterator, + state, Seq.empty[Attribute], noKeySessionAttribute, noKeyRowAttributes) + + expectedRowSequence.foreach { row => + assert(iterator.hasNext) + assertNoKeyRowsEquals(row, iterator.next()) + } + + assert(!iterator.hasNext) + } + } + + private def setRowsInState(state: SessionWindowLinkedListState, key: UnsafeRow, + rows: UnsafeRow*): Unit = { + def getSessionStart(row: UnsafeRow): Long = { + row.getStruct(2, 2).getLong(0) + } + + val iter = rows.sortBy(getSessionStart).iterator + + var prevSessionStart: Option[Long] = None + while (iter.hasNext) { + val row = iter.next() + val sessionStart = getSessionStart(row) + if (prevSessionStart.isDefined) { + state.addAfter(key, sessionStart, row, prevSessionStart.get) + } else { + state.setHead(key, sessionStart, row) + } + + prevSessionStart = Some(sessionStart) + } + } + + private def createRow(key1: String, key2: Int, sessionStart: Long, sessionEnd: Long, + aggVal1: Long, aggVal2: Double): UnsafeRow = { + val genericRow = new GenericInternalRow(6) + if (key1 != null) { + genericRow.update(0, UTF8String.fromString(key1)) + } else { + genericRow.setNullAt(0) + } + genericRow.setInt(1, key2) + + val session: 
Array[Any] = new Array[Any](2) + session(0) = sessionStart + session(1) = sessionEnd + + val sessionRow = new GenericInternalRow(session) + genericRow.update(2, sessionRow) + + genericRow.setLong(3, aggVal1) + genericRow.setDouble(4, aggVal2) + + val rowProjection = GenerateUnsafeProjection.generate(rowAttributes, rowAttributes) + rowProjection(genericRow) + } + + private def doubleEquals(value1: Double, value2: Double): Boolean = { + value1 > value2 - 0.000001 && value1 < value2 + 0.000001 + } + + private def getTupleFromRow(row: InternalRow): (String, Int, Long, Long, Long, Double) = { + (row.getString(0), row.getInt(1), row.getStruct(2, 2).getLong(0), + row.getStruct(2, 2).getLong(1), row.getLong(3), row.getDouble(4)) + } + + private def assertRowsEquals(expectedRow: InternalRow, retRow: InternalRow): Unit = { + + val tupleFromExpectedRow = getTupleFromRow(expectedRow) + val tupleFromInternalRow = getTupleFromRow(retRow) + try { + assert(tupleFromExpectedRow._1 === tupleFromInternalRow._1) + assert(tupleFromExpectedRow._2 === tupleFromInternalRow._2) + assert(tupleFromExpectedRow._3 === tupleFromInternalRow._3) + assert(tupleFromExpectedRow._4 === tupleFromInternalRow._4) + assert(tupleFromExpectedRow._5 === tupleFromInternalRow._5) + assert(doubleEquals(tupleFromExpectedRow._6, tupleFromInternalRow._6)) + } catch { + case e: TestFailedException => + throw new TestFailedException(s"$tupleFromExpectedRow did not equal $tupleFromInternalRow", + e, e.failedCodeStackDepth) + } + } + + private def withSessionWindowLinkedListState( + inputValueAttribs: Seq[Attribute], + keyAttribs: Seq[Attribute])(f: SessionWindowLinkedListState => Unit): Unit = { + + withTempDir { file => + val storeConf = new StateStoreConf() + val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5) + + val state = new SessionWindowLinkedListState(s"session-${stateInfo.operatorId}-", + inputValueAttribs, keyAttribs, Some(stateInfo), storeConf, new Configuration) + 
try { + f(state) + } finally { + state.abortIfNeeded() + } + } + StateStore.stop() + } +} From 35c9712c9ea0ad56c215b101fb8943f8abf8f374 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 25 Oct 2018 19:02:19 +0900 Subject: [PATCH 51/60] WIP optimized! --- ...SessionWindowLinkedListStateIterator.scala | 228 ++++++------------ 1 file changed, 80 insertions(+), 148 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIterator.scala index 482079938d957..2c513b3f2b278 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIterator.scala @@ -60,80 +60,16 @@ class MergingSortWithSessionWindowLinkedListStateIterator( } } - private def findSessionPointerEnclosingEvent(row: SessionRowInformation, - startPointer: Option[Long]) - : Option[(Option[Long], Option[Long])] = { - val startOption = startPointer match { - case None => state.getFirstSessionStart(currentRow.keys) - case _ => startPointer - } - - startOption match { - // empty list - case None => None - case Some(start) => - var currOption: Option[Long] = Some(start) - - var enclosingSessions: Option[(Option[Long], Option[Long])] = None - while (enclosingSessions.isEmpty && currOption.isDefined) { - val curr = currOption.get - val newPrev = state.getPrevSessionStart(currentRow.keys, curr) - val newNext = state.getNextSessionStart(currentRow.keys, curr) - - val isEventEnclosed = newPrev match { - case Some(prev) => - prev <= currentRow.sessionStart && currentRow.sessionStart <= curr - case None => currentRow.sessionStart <= curr - } - - val willNotBeEnclosed = newPrev match { - case Some(prev) => prev > 
currentRow.sessionStart - case None => false - } - - if (isEventEnclosed) { - enclosingSessions = Some((newPrev, currOption)) - } else if (willNotBeEnclosed) { - enclosingSessions = Some((None, None)) - } else if (newNext.isEmpty) { - // curr is the last session in state - if (currentRow.sessionStart >= curr) { - enclosingSessions = Some((currOption, None)) - } else { - enclosingSessions = Some((None, None)) - } - } - - currOption = newNext - } - - // enclosingSessions should not be None unless list is empty - enclosingSessions - } - } - - private def isSessionsOverlap(s1: SessionRowInformation, s2: SessionRowInformation): Boolean = { - (s1.sessionStart >= s2.sessionStart && s1.sessionStart <= s2.sessionEnd) || - (s2.sessionStart >= s1.sessionStart && s2.sessionStart <= s1.sessionEnd) - } - private var lastKey: UnsafeRow = _ private var currentRow: SessionRowInformation = _ - - private val stateRowsToEmit: scala.collection.mutable.ListBuffer[SessionRowInformation] = - new scala.collection.mutable.ListBuffer[SessionRowInformation]() - private val stateRowsChecked: scala.collection.mutable.HashSet[SessionRowInformation] = - new scala.collection.mutable.HashSet[SessionRowInformation]() - - private var lastEmittedStateSessionKey: UnsafeRow = _ - private var lastEmittedStateSessionStartOption: Option[Long] = None + private var lastCheckpointOnStateRows: Option[Long] = _ private var stateRowWaitForEmit: SessionRowInformation = _ private val keyOrdering: Ordering[UnsafeRow] = TypeUtils.getInterpretedOrdering( groupWithoutSessionExpressions.toStructType).asInstanceOf[Ordering[UnsafeRow]] override def hasNext: Boolean = { - currentRow != null || iter.hasNext || stateRowsToEmit.nonEmpty + currentRow != null || iter.hasNext || stateRowWaitForEmit != null } override def next(): InternalRow = { @@ -141,44 +77,45 @@ class MergingSortWithSessionWindowLinkedListStateIterator( mayFillCurrentRow() } - if (currentRow == null && stateRowsToEmit.isEmpty) { + if (currentRow == null && 
stateRowWaitForEmit == null) { throw new IllegalStateException("No Row to provide in next() which should not happen!") } // early return on input rows vs state row waiting for emitting val returnCurrentRow = if (currentRow == null) { false - } else if (stateRowsToEmit.isEmpty) { + } else if (stateRowWaitForEmit == null) { true } else { // compare between current row and state row waiting for emitting - val stateRow = stateRowsToEmit.head - if (!keyOrdering.equiv(currentRow.keys, stateRow.keys)) { + if (!keyOrdering.equiv(currentRow.keys, stateRowWaitForEmit.keys)) { // state row cannot advance to row in input, so state row should be lower false } else { - currentRow.sessionStart < stateRow.sessionStart + currentRow.sessionStart < stateRowWaitForEmit.sessionStart } } // if state row should be emitted, do emit if (!returnCurrentRow) { - val stateRow = stateRowsToEmit.head - stateRowsToEmit.remove(0) + val stateRow = stateRowWaitForEmit + stateRowWaitForEmit = null return stateRow.row } if (lastKey == null || !keyOrdering.equiv(lastKey, currentRow.keys)) { // new key - stateRowsToEmit.clear() - stateRowsChecked.clear() + stateRowWaitForEmit = null + lastCheckpointOnStateRows = None lastKey = currentRow.keys } - // FIXME: how to provide start pointer to avoid reiterating? 
+ // we don't need to check against sessions which are already candidate to emit + // so we apply checkpoint to skip some sessions val stateSessionsEnclosingCurrentRow = findSessionPointerEnclosingEvent(currentRow, - startPointer = None) + startPointer = lastCheckpointOnStateRows) + var prevSessionToEmit: Option[SessionRowInformation] = None stateSessionsEnclosingCurrentRow match { case None => case Some(x) => @@ -186,14 +123,19 @@ class MergingSortWithSessionWindowLinkedListStateIterator( case Some(prev) => val prevSession = SessionRowInformation.of(state.get(currentRow.keys, prev)) - if (!stateRowsChecked.contains(prevSession)) { + val sessionLaterThanCheckpoint = lastCheckpointOnStateRows match { + case Some(lastCheckpoint) => lastCheckpoint < prevSession.sessionStart + case None => true + } + + if (sessionLaterThanCheckpoint) { // based on definition of session window and the fact that events are sorted, // if the state session is not matched to this event, it will not be matched with // later events as well - stateRowsChecked += prevSession + lastCheckpointOnStateRows = Some(prevSession.sessionStart) if (isSessionsOverlap(currentRow, prevSession)) { - stateRowsToEmit += prevSession + prevSessionToEmit = Some(prevSession) } } @@ -204,12 +146,17 @@ class MergingSortWithSessionWindowLinkedListStateIterator( case Some(next) => val nextSession = SessionRowInformation.of(state.get(currentRow.keys, next)) - if (!stateRowsChecked.contains(nextSession)) { + val sessionLaterThanCheckpoint = lastCheckpointOnStateRows match { + case Some(lastCheckpoint) => lastCheckpoint < nextSession.sessionStart + case None => true + } + + if (sessionLaterThanCheckpoint) { // next session could be matched to latter events even it doesn't match to // current event, so unless it is added to rows to emit, don't add to checked set if (isSessionsOverlap(currentRow, nextSession)) { - stateRowsToEmit += nextSession - stateRowsChecked += nextSession + stateRowWaitForEmit = nextSession + 
lastCheckpointOnStateRows = Some(nextSession.sessionStart) } } @@ -217,14 +164,11 @@ class MergingSortWithSessionWindowLinkedListStateIterator( } } - if (stateRowsToEmit.isEmpty) { - emitCurrentRow() - } else if (currentRow.sessionStart < stateRowsToEmit.head.sessionStart) { - emitCurrentRow() - } else { - val stateRow = stateRowsToEmit.head - stateRowsToEmit.remove(0) - stateRow.row + // emitting sessions always follows the pattern: + // previous sessions if any -> current event -> (later events) -> next sessions + prevSessionToEmit match { + case Some(prevSession) => prevSession.row + case None => emitCurrentRow() } } @@ -234,79 +178,67 @@ class MergingSortWithSessionWindowLinkedListStateIterator( ret.row } - private def emitStateRowForWaiting(): InternalRow = { - val ret = stateRowWaitForEmit - stateRowWaitForEmit = null - recordStateRowToEmit(ret) - ret.row - } - - private def recordStateRowToEmit(stateRow: SessionRowInformation) = { - lastEmittedStateSessionStartOption = Some(stateRow.sessionStart) - lastEmittedStateSessionKey = stateRow.keys - } - private def mayFillCurrentRow(): Unit = { if (iter.hasNext) { currentRow = SessionRowInformation.of(iter.next()) } } - private def currentRowIsSmallerThanWaitingStateRow(): Boolean = { - // compare between current row and state row waiting for emitting - if (!keyOrdering.equiv(currentRow.keys, stateRowWaitForEmit.keys)) { - // state row cannot advance to row in input, so state row should be lower - false - } else { - currentRow.sessionStart < stateRowWaitForEmit.sessionStart + private def findSessionPointerEnclosingEvent(row: SessionRowInformation, + startPointer: Option[Long]) + : Option[(Option[Long], Option[Long])] = { + val startOption = startPointer match { + case None => state.getFirstSessionStart(currentRow.keys) + case _ => startPointer } - } - private def getEnclosingStatesForEvent(row: SessionRowInformation) - : (Option[SessionRowInformation], Option[SessionRowInformation]) = { - // find two state sessions 
wrapping current row + startOption match { + // empty list + case None => None + case Some(start) => + var currOption: Option[Long] = Some(start) - if (lastEmittedStateSessionKey != null) { - mayInvalidateLastEmittedStateSession() - } + var enclosingSessions: Option[(Option[Long], Option[Long])] = None + while (enclosingSessions.isEmpty && currOption.isDefined) { + val curr = currOption.get + val newPrev = state.getPrevSessionStart(currentRow.keys, curr) + val newNext = state.getNextSessionStart(currentRow.keys, curr) - val nextStateSessionStart: Option[Long] = lastEmittedStateSessionStartOption match { - case Some(lastEmittedStateSessionStart) => - state.findFirstSessionStartEnsurePredicate(currentRow.keys, - _ >= currentRow.sessionEnd, lastEmittedStateSessionStart) - case None => - state.findFirstSessionStartEnsurePredicate(currentRow.keys, - _ >= currentRow.sessionEnd) - } + val isEventEnclosed = newPrev match { + case Some(prev) => + prev <= currentRow.sessionStart && currentRow.sessionStart <= curr + case None => currentRow.sessionStart <= curr + } - val prevStateSessionStart: Option[Long] = nextStateSessionStart match { - case Some(next) => state.getPrevSessionStart(currentRow.keys, next) - case None => state.getLastSessionStart(currentRow.keys) - } + val willNotBeEnclosed = newPrev match { + case Some(prev) => prev > currentRow.sessionStart + case None => false + } - // only return sessions which overlap with current row - val pSession = if (prevStateSessionStart.isDefined) { - Some(SessionRowInformation.of(state.get(currentRow.keys, prevStateSessionStart.get))) - } else { - None - } + if (isEventEnclosed) { + enclosingSessions = Some((newPrev, currOption)) + } else if (willNotBeEnclosed) { + enclosingSessions = Some((None, None)) + } else if (newNext.isEmpty) { + // curr is the last session in state + if (currentRow.sessionStart >= curr) { + enclosingSessions = Some((currOption, None)) + } else { + enclosingSessions = Some((None, None)) + } + } - val 
nSession = if (nextStateSessionStart.isDefined) { - Some(SessionRowInformation.of(state.get(currentRow.keys, nextStateSessionStart.get))) - } else { - None - } + currOption = newNext + } - (pSession, nSession) + // enclosingSessions should not be None unless list is empty + enclosingSessions + } } - private def mayInvalidateLastEmittedStateSession(): Unit = { - // invalidate last emitted state session key as well as session start - // if keys are changed - if (!keyOrdering.equiv(lastEmittedStateSessionKey, currentRow.keys)) { - lastEmittedStateSessionKey = null - lastEmittedStateSessionStartOption = None - } + private def isSessionsOverlap(s1: SessionRowInformation, s2: SessionRowInformation): Boolean = { + (s1.sessionStart >= s2.sessionStart && s1.sessionStart <= s2.sessionEnd) || + (s2.sessionStart >= s1.sessionStart && s2.sessionStart <= s1.sessionEnd) } } From ede078ad103f8d9fe06da6a7f058f11cee5b26e8 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 27 Oct 2018 22:42:44 +0900 Subject: [PATCH 52/60] WIP remove requirement on sort, add UT to test linked list state with numbers of randomized operations --- .../state/SessionWindowLinkedListState.scala | 46 ++-- .../SessionWindowLinkedListStateSuite.scala | 197 ++++++++++++++++++ 2 files changed, 229 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala index 890b7220f1237..0119a5b985b93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala @@ -114,8 +114,6 @@ class SessionWindowLinkedListState( def addBefore(key: UnsafeRow, sessionStart: Long, value: UnsafeRow, targetSessionStart: Long): Unit = { - require(sessionStart < 
targetSessionStart) - val targetPointer = keyAndSessionStartToPointerStore.get(key, targetSessionStart) assertValidPointer(targetPointer) @@ -138,8 +136,6 @@ class SessionWindowLinkedListState( def addAfter(key: UnsafeRow, sessionStart: Long, value: UnsafeRow, targetSessionStart: Long): Unit = { - require(sessionStart > targetSessionStart) - val targetPointer = keyAndSessionStartToPointerStore.get(key, targetSessionStart) assertValidPointer(targetPointer) @@ -431,7 +427,6 @@ class SessionWindowLinkedListState( keyAndSessionStartToValueStore.abortIfNeeded() } - /** Get the combined metrics of all the state stores */ def metrics: StateStoreMetrics = { val keyToHeadSessionStartMetrics = keyToHeadSessionStartStore.metrics @@ -459,6 +454,18 @@ class SessionWindowLinkedListState( ) } + private[state] def getIteratorOfHeadPointers: Iterator[KeyAndHeadSessionStart] = { + keyToHeadSessionStartStore.iterator + } + + private[state] def getIteratorOfRawPointers: Iterator[KeyWithSessionStartAndPointers] = { + keyAndSessionStartToPointerStore.iterator + } + + private[state] def getIteratorOfRawValues: Iterator[KeyWithSessionStartAndValue] = { + keyAndSessionStartToValueStore.iterator + } + /* ===================================================== Private methods and inner classes @@ -515,10 +522,11 @@ class SessionWindowLinkedListState( } /** - * Helper class for representing data returned by [[KeyToHeadSessionStartStore]]. - * Designed for object reuse. - */ - private case class KeyAndHeadSessionStart(var key: UnsafeRow = null, var sessionStart: Long = 0) { + * Helper class for representing data returned by [[KeyToHeadSessionStartStore]]. + * Designed for object reuse. 
+ */ + private[state] case class KeyAndHeadSessionStart(var key: UnsafeRow = null, + var sessionStart: Long = 0) { def withNew(newKey: UnsafeRow, newSessionStart: Long): this.type = { this.key = newKey this.sessionStart = newSessionStart @@ -530,7 +538,7 @@ class SessionWindowLinkedListState( * Helper class for representing data returned by [[KeyAndSessionStartToPointerStore]]. * Designed for object reuse. */ - private case class KeyWithSessionStartAndPointers( + private[state] case class KeyWithSessionStartAndPointers( var key: UnsafeRow = null, var sessionStart: Long = 0, var prevSessionStart: Option[Long] = None, @@ -549,7 +557,7 @@ class SessionWindowLinkedListState( * Helper class for representing data returned by [[KeyAndSessionStartToValueStore]]. * Designed for object reuse. */ - private case class KeyWithSessionStartAndValue( + private[state] case class KeyWithSessionStartAndValue( var key: UnsafeRow = null, var sessionStart: Long = 0, var value: UnsafeRow = null) { @@ -645,7 +653,7 @@ class SessionWindowLinkedListState( def updateNext(key: UnsafeRow, sessionStart: Long, nextSessionStart: Option[Long]): Unit = { val actualKeyRow = keyWithSessionStartRow(key, sessionStart) - val row = stateStore.get(actualKeyRow) + val row = stateStore.get(actualKeyRow).copy() setNextSessionStart(row, nextSessionStart) stateStore.put(actualKeyRow, row) } @@ -721,11 +729,21 @@ class SessionWindowLinkedListState( } /** - * Remove key and value at given session start. - */ + * Remove key and value at given session start. 
+ */ def remove(key: UnsafeRow, sessionStart: Long): Unit = { stateStore.remove(keyWithSessionStartRow(key, sessionStart)) } + + def iterator: Iterator[KeyWithSessionStartAndValue] = { + val keyWithSessionStartAndValue = KeyWithSessionStartAndValue() + stateStore.getRange(None, None).map { pair => + val keyPart = keyRowGenerator(pair.key) + val sessionStart = pair.key.getLong(indexOrdinalInKeyWithSessionStartRow) + val value = pair.value + keyWithSessionStartAndValue.withNew(keyPart, sessionStart, value) + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateSuite.scala index 556394b66c17e..15a79e7a83fcb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateSuite.scala @@ -19,7 +19,10 @@ package org.apache.spark.sql.execution.streaming.state import java.util.UUID +import scala.util.Random + import org.apache.hadoop.conf.Configuration +import org.scalatest.exceptions.TestFailedException import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, LessThanOrEqual, Literal, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate @@ -214,6 +217,200 @@ class SessionWindowLinkedListStateSuite extends StreamTest { removeByWatermarkTest(stopOnConditionMismatch = false) } + test("run chaos monkey") { + + // FIXME: too many args + def printFailureInformation( + ex: TestFailedException, + state: SessionWindowLinkedListState, + operation: Int, + addBefore: Boolean, + opIdx: Int, + targetIdx: Int, + key: UnsafeRow, + headPointersBeforeOp: List[(Int, Long)], + rawPointersBeforeOp: List[(Int, Long, Option[Long], Option[Long])], + pointersBeforeOp: 
List[Long], + valuesBeforeOp: List[(Int, Int)], + refListBeforeOp: java.util.LinkedList[String], + refList: java.util.LinkedList[String]): Unit = { + logError("Assertion failure!", ex) + + logError("===== Operation information =====") + val opString = operation match { + case 0 => "Append" + case 1 => "Remove" + case 2 => "RemoveValuesByCondition" + case _ => throw new IllegalStateException(s"Unknown operation $operation") + } + val addPositionString = if (addBefore) "AddBefore" else "AddAfter" + logError(s"Operation Index: $opIdx") + logError(s"Operation: $opString") + logError(s"Position to add: $addPositionString") + logError(s"Target index: $targetIdx") + + logError("===== Before applying operation =====") + logError(s"Head pointers in state: $headPointersBeforeOp") + logError(s"Raw pointers in state: $rawPointersBeforeOp") + logError(s"Pointers in state via iteratePointers: $pointersBeforeOp") + logError(s"Values in state: $valuesBeforeOp") + logError(s"Values in reference list: $refListBeforeOp") + + logError("===== After applying operation =====") + + val headPointers = state.getIteratorOfHeadPointers.map { pair => + (toKeyInt(pair.key), pair.sessionStart) + }.toList + val pointers = state.iteratePointers(key).map(_._1).toList + val rawPointers = state.getIteratorOfRawPointers.map { pointer => + (toKeyInt(pointer.key), pointer.sessionStart, pointer.prevSessionStart, + pointer.nextSessionStart) + }.toList + val values = state.getIteratorOfRawValues.map { value => + (toKeyInt(value.key), toValueInt(value.value)) + }.toList + + logError(s"Head pointers in state: $headPointers") + logError(s"Raw pointers in state: $rawPointers") + logError(s"Pointers in state via iteratePointers: $pointers") + logError(s"Values in state: $values") + logError(s"Values in reference list: $refList") + } + + withSessionWindowLinkedListState(inputValueAttribs, keyExprs) { state => + implicit val st = state + + assert(numRows === 0) + + val rand = new Random() + + val keys = (0 to 
2).map(id => toKeyRow(id).copy()) + // using String type to avoid confusion in remove(int) vs remove(Object) + // which LinkedList[Integer] will be remove(int) vs remove(Integer) + val refLists = keys.map(_ => new java.util.LinkedList[String]()) + + val maxOperations = 100000 + (0 until maxOperations).foreach { opIdx => + + val selectedKeyIdx = rand.nextInt(keys.length) + val selectedKey = keys(selectedKeyIdx) + val selectedRefList = refLists(selectedKeyIdx) + + // 0: append, 1: remove, 2: removeValueByCondition + val operation = rand.nextInt(3) + val addBefore = rand.nextBoolean() + val targetIdx = if (selectedRefList.isEmpty) -1 else rand.nextInt(selectedRefList.size()) + + val headPointersBeforeOp = state.getIteratorOfHeadPointers.map { pair => + (toKeyInt(pair.key), pair.sessionStart) + }.toList + val pointersBeforeOp = state.iteratePointers(selectedKey).map(_._1).toList + val rawPointersBeforeOp = state.getIteratorOfRawPointers.map { pointer => + (toKeyInt(pointer.key), pointer.sessionStart, pointer.prevSessionStart, + pointer.nextSessionStart) + }.toList + val valuesBeforeOp = state.getIteratorOfRawValues.map { value => + (toKeyInt(value.key), toValueInt(value.value)) + }.toList + + val refListBeforeOp = new java.util.LinkedList[String](refLists(selectedKeyIdx)) + + operation match { + case 0 => + if (selectedRefList.isEmpty) { + assert(state.isEmpty(selectedKey)) + state.setHead(selectedKey, opIdx, toInputValue(opIdx)) + selectedRefList.addFirst(String.valueOf(opIdx)) + } else { + val addBefore = rand.nextBoolean() + if (addBefore) { + val idxToAddBefore = selectedRefList.get(targetIdx) + selectedRefList.add(targetIdx, String.valueOf(opIdx)) + state.addBefore(selectedKey, opIdx, toInputValue(opIdx), idxToAddBefore.toInt) + } else { + val idxToAddAfter = selectedRefList.get(targetIdx) + selectedRefList.add(targetIdx + 1, String.valueOf(opIdx)) + state.addAfter(selectedKey, opIdx, toInputValue(opIdx), idxToAddAfter.toInt) + } + } + + case 1 => + if 
(selectedRefList.isEmpty) { + assert(state.isEmpty(selectedKey)) + // skip removing + } else { + val pointerToRemove = selectedRefList.get(targetIdx) + selectedRefList.remove(targetIdx) + state.remove(selectedKey, pointerToRemove.toInt) + } + + case 2 => + if (selectedRefList.isEmpty) { + assert(state.isEmpty(selectedKey)) + // skip removing + } else { + val pointerToRemove = selectedRefList.get(targetIdx) + val removedIter = state.removeByValueCondition { r => + toValueInt(r) <= pointerToRemove.toInt + } + + val valuesFromRef = new scala.collection.mutable.MutableList[Int]() + refLists.foreach { refList => + val refIter = refList.iterator() + while (refIter.hasNext) { + val ref = refIter.next() + if (ref.toInt <= pointerToRemove.toInt) { + valuesFromRef += ref.toInt + refIter.remove() + } + } + } + + try { + assert(removedIter.map(pair => toValueInt(pair.value)).toSet == + valuesFromRef.toSet) + } catch { + case ex: TestFailedException => + printFailureInformation(ex, state, operation, addBefore, opIdx, targetIdx, + selectedKey, headPointersBeforeOp, rawPointersBeforeOp, pointersBeforeOp, + valuesBeforeOp, refListBeforeOp, selectedRefList) + + throw ex + } + } + } + + keys.indices.foreach { index => + val key = keys(index) + val refList = refLists(index) + + try { + if (refList.isEmpty) { + assert(state.isEmpty(key), s"Reference list is empty but " + + s"state list for $key is not empty") + } else { + import scala.collection.JavaConverters._ + val statePointers = state.iteratePointers(key).map(_._1).toList + assert(refList.asScala.map(_.toInt) === statePointers, + s"State pointers for $key is expected to be $refList but $statePointers") + + val stateValues = state.get(key).map(toValueInt).toList + assert(refList.asScala.map(_.toInt) === + stateValues, s"State list for $key is expected to be $refList but $stateValues") + } + } catch { + case ex: TestFailedException => + printFailureInformation(ex, state, operation, addBefore, opIdx, targetIdx, + selectedKey, 
headPointersBeforeOp, rawPointersBeforeOp, pointersBeforeOp, + valuesBeforeOp, refListBeforeOp, selectedRefList) + + throw ex + } + } + } + } + } + private def removeByWatermarkTest(stopOnConditionMismatch: Boolean): Unit = { withSessionWindowLinkedListState(inputValueAttribs, keyExprs) { state => implicit val st = state From f8e8ff6067736260f52001a5d1689a0a6cfae5cb Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 27 Oct 2018 22:54:08 +0900 Subject: [PATCH 53/60] WIP add code to print out information when task crashes with dangling pointer --- ...SessionWindowLinkedListStateIterator.scala | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIterator.scala index 2c513b3f2b278..84dc60c2bddc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIterator.scala @@ -115,13 +115,30 @@ class MergingSortWithSessionWindowLinkedListStateIterator( val stateSessionsEnclosingCurrentRow = findSessionPointerEnclosingEvent(currentRow, startPointer = lastCheckpointOnStateRows) + // FIXME: debugging... 
+ def loadSession( + keys: UnsafeRow, + sessionStart: Long, + stateSessionsEnclosingCurrentRow: Option[(Option[Long], Option[Long])]) + : UnsafeRow = { + val sessionState = state.get(currentRow.keys, sessionStart) + require(sessionState != null, + s"Session must be presented in state: it may represent dangling pointer - " + + s"$sessionStart / key: $keys / currentRow: $currentRow" + + s"sessionsEnclosingCurrentRow: $stateSessionsEnclosingCurrentRow" + + s"Pointers: ${state.iteratePointers(currentRow.keys).toList}" + + s"Values: ${state.get(currentRow.keys).toList}") + sessionState + } + var prevSessionToEmit: Option[SessionRowInformation] = None stateSessionsEnclosingCurrentRow match { case None => case Some(x) => x._1 match { case Some(prev) => - val prevSession = SessionRowInformation.of(state.get(currentRow.keys, prev)) + val prevSession = SessionRowInformation.of( + loadSession(currentRow.keys, prev, stateSessionsEnclosingCurrentRow)) val sessionLaterThanCheckpoint = lastCheckpointOnStateRows match { case Some(lastCheckpoint) => lastCheckpoint < prevSession.sessionStart @@ -144,7 +161,8 @@ class MergingSortWithSessionWindowLinkedListStateIterator( x._2 match { case Some(next) => - val nextSession = SessionRowInformation.of(state.get(currentRow.keys, next)) + val nextSession = SessionRowInformation.of( + loadSession(currentRow.keys, next, stateSessionsEnclosingCurrentRow)) val sessionLaterThanCheckpoint = lastCheckpointOnStateRows match { case Some(lastCheckpoint) => lastCheckpoint < nextSession.sessionStart From b05abc71a103662448db649fc215dd17a2df3c11 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 29 Oct 2018 15:55:00 +0900 Subject: [PATCH 54/60] WIP fixed the issue with benchmark run --- ...SessionWindowLinkedListStateIterator.scala | 24 +-- .../state/SessionWindowLinkedListState.scala | 15 +- ...onWindowLinkedListStateIteratorSuite.scala | 1 - .../SessionWindowLinkedListStateSuite.scala | 194 ------------------ 4 files changed, 11 insertions(+), 223 
deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIterator.scala index 84dc60c2bddc3..9955aa5045ae8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIterator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.execution.streaming.state.SessionWindowLinkedListState @@ -115,30 +115,13 @@ class MergingSortWithSessionWindowLinkedListStateIterator( val stateSessionsEnclosingCurrentRow = findSessionPointerEnclosingEvent(currentRow, startPointer = lastCheckpointOnStateRows) - // FIXME: debugging... 
- def loadSession( - keys: UnsafeRow, - sessionStart: Long, - stateSessionsEnclosingCurrentRow: Option[(Option[Long], Option[Long])]) - : UnsafeRow = { - val sessionState = state.get(currentRow.keys, sessionStart) - require(sessionState != null, - s"Session must be presented in state: it may represent dangling pointer - " + - s"$sessionStart / key: $keys / currentRow: $currentRow" + - s"sessionsEnclosingCurrentRow: $stateSessionsEnclosingCurrentRow" + - s"Pointers: ${state.iteratePointers(currentRow.keys).toList}" + - s"Values: ${state.get(currentRow.keys).toList}") - sessionState - } - var prevSessionToEmit: Option[SessionRowInformation] = None stateSessionsEnclosingCurrentRow match { case None => case Some(x) => x._1 match { case Some(prev) => - val prevSession = SessionRowInformation.of( - loadSession(currentRow.keys, prev, stateSessionsEnclosingCurrentRow)) + val prevSession = SessionRowInformation.of(state.get(currentRow.keys, prev)) val sessionLaterThanCheckpoint = lastCheckpointOnStateRows match { case Some(lastCheckpoint) => lastCheckpoint < prevSession.sessionStart @@ -161,8 +144,7 @@ class MergingSortWithSessionWindowLinkedListStateIterator( x._2 match { case Some(next) => - val nextSession = SessionRowInformation.of( - loadSession(currentRow.keys, next, stateSessionsEnclosingCurrentRow)) + val nextSession = SessionRowInformation.of(state.get(currentRow.keys, next)) val sessionLaterThanCheckpoint = lastCheckpointOnStateRows match { case Some(lastCheckpoint) => lastCheckpoint < nextSession.sessionStart diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala index 0119a5b985b93..ee899a6882adf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala @@ -157,7 +157,6 @@ class SessionWindowLinkedListState( def update(key: UnsafeRow, sessionStart: Long, newValue: UnsafeRow): Unit = { val targetPointer = keyAndSessionStartToPointerStore.get(key, sessionStart) assertValidPointer(targetPointer) - keyAndSessionStartToValueStore.put(key, sessionStart, newValue) } @@ -243,6 +242,9 @@ class SessionWindowLinkedListState( val prevOption = targetPointer._1 val nextOption = targetPointer._2 + keyAndSessionStartToPointerStore.remove(key, sessionStart) + keyAndSessionStartToValueStore.remove(key, sessionStart) + targetPointer match { case (Some(prev), Some(next)) => keyAndSessionStartToPointerStore.updateNext(key, prev, nextOption) @@ -260,11 +262,10 @@ class SessionWindowLinkedListState( throw new IllegalStateException("The element has pointer information for head, " + "but the list has different head.") } - keyAndSessionStartToPointerStore.remove(key, sessionStart) + keyToHeadSessionStartStore.remove(key) } - keyAndSessionStartToValueStore.remove(key, sessionStart) } def removeByValueCondition(removalCondition: UnsafeRow => Boolean, @@ -454,15 +455,15 @@ class SessionWindowLinkedListState( ) } - private[state] def getIteratorOfHeadPointers: Iterator[KeyAndHeadSessionStart] = { + private[sql] def getIteratorOfHeadPointers: Iterator[KeyAndHeadSessionStart] = { keyToHeadSessionStartStore.iterator } - private[state] def getIteratorOfRawPointers: Iterator[KeyWithSessionStartAndPointers] = { + private[sql] def getIteratorOfRawPointers: Iterator[KeyWithSessionStartAndPointers] = { keyAndSessionStartToPointerStore.iterator } - private[state] def getIteratorOfRawValues: Iterator[KeyWithSessionStartAndValue] = { + private[sql] def getIteratorOfRawValues: Iterator[KeyWithSessionStartAndValue] = { keyAndSessionStartToValueStore.iterator } @@ -646,7 +647,7 @@ class SessionWindowLinkedListState( def updatePrev(key: UnsafeRow, 
sessionStart: Long, prevSessionStart: Option[Long]): Unit = { val actualKeyRow = keyWithSessionStartRow(key, sessionStart) - val row = stateStore.get(actualKeyRow) + val row = stateStore.get(actualKeyRow).copy() setPrevSessionStart(row, prevSessionStart) stateStore.put(actualKeyRow, row) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIteratorSuite.scala index 546e598cb62af..6e7b13c63fd04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowLinkedListStateIteratorSuite.scala @@ -395,7 +395,6 @@ class MergingSortWithSessionWindowLinkedListStateIteratorSuite extends SharedSQL } private def assertRowsEquals(expectedRow: InternalRow, retRow: InternalRow): Unit = { - val tupleFromExpectedRow = getTupleFromRow(expectedRow) val tupleFromInternalRow = getTupleFromRow(retRow) try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateSuite.scala index 15a79e7a83fcb..af7f2f986a30e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateSuite.scala @@ -217,200 +217,6 @@ class SessionWindowLinkedListStateSuite extends StreamTest { removeByWatermarkTest(stopOnConditionMismatch = false) } - test("run chaos monkey") { - - // FIXME: too many args - def printFailureInformation( - ex: TestFailedException, - state: SessionWindowLinkedListState, - operation: 
Int, - addBefore: Boolean, - opIdx: Int, - targetIdx: Int, - key: UnsafeRow, - headPointersBeforeOp: List[(Int, Long)], - rawPointersBeforeOp: List[(Int, Long, Option[Long], Option[Long])], - pointersBeforeOp: List[Long], - valuesBeforeOp: List[(Int, Int)], - refListBeforeOp: java.util.LinkedList[String], - refList: java.util.LinkedList[String]): Unit = { - logError("Assertion failure!", ex) - - logError("===== Operation information =====") - val opString = operation match { - case 0 => "Append" - case 1 => "Remove" - case 2 => "RemoveValuesByCondition" - case _ => throw new IllegalStateException(s"Unknown operation $operation") - } - val addPositionString = if (addBefore) "AddBefore" else "AddAfter" - logError(s"Operation Index: $opIdx") - logError(s"Operation: $opString") - logError(s"Position to add: $addPositionString") - logError(s"Target index: $targetIdx") - - logError("===== Before applying operation =====") - logError(s"Head pointers in state: $headPointersBeforeOp") - logError(s"Raw pointers in state: $rawPointersBeforeOp") - logError(s"Pointers in state via iteratePointers: $pointersBeforeOp") - logError(s"Values in state: $valuesBeforeOp") - logError(s"Values in reference list: $refListBeforeOp") - - logError("===== After applying operation =====") - - val headPointers = state.getIteratorOfHeadPointers.map { pair => - (toKeyInt(pair.key), pair.sessionStart) - }.toList - val pointers = state.iteratePointers(key).map(_._1).toList - val rawPointers = state.getIteratorOfRawPointers.map { pointer => - (toKeyInt(pointer.key), pointer.sessionStart, pointer.prevSessionStart, - pointer.nextSessionStart) - }.toList - val values = state.getIteratorOfRawValues.map { value => - (toKeyInt(value.key), toValueInt(value.value)) - }.toList - - logError(s"Head pointers in state: $headPointers") - logError(s"Raw pointers in state: $rawPointers") - logError(s"Pointers in state via iteratePointers: $pointers") - logError(s"Values in state: $values") - logError(s"Values in 
reference list: $refList") - } - - withSessionWindowLinkedListState(inputValueAttribs, keyExprs) { state => - implicit val st = state - - assert(numRows === 0) - - val rand = new Random() - - val keys = (0 to 2).map(id => toKeyRow(id).copy()) - // using String type to avoid confusion in remove(int) vs remove(Object) - // which LinkedList[Integer] will be remove(int) vs remove(Integer) - val refLists = keys.map(_ => new java.util.LinkedList[String]()) - - val maxOperations = 100000 - (0 until maxOperations).foreach { opIdx => - - val selectedKeyIdx = rand.nextInt(keys.length) - val selectedKey = keys(selectedKeyIdx) - val selectedRefList = refLists(selectedKeyIdx) - - // 0: append, 1: remove, 2: removeValueByCondition - val operation = rand.nextInt(3) - val addBefore = rand.nextBoolean() - val targetIdx = if (selectedRefList.isEmpty) -1 else rand.nextInt(selectedRefList.size()) - - val headPointersBeforeOp = state.getIteratorOfHeadPointers.map { pair => - (toKeyInt(pair.key), pair.sessionStart) - }.toList - val pointersBeforeOp = state.iteratePointers(selectedKey).map(_._1).toList - val rawPointersBeforeOp = state.getIteratorOfRawPointers.map { pointer => - (toKeyInt(pointer.key), pointer.sessionStart, pointer.prevSessionStart, - pointer.nextSessionStart) - }.toList - val valuesBeforeOp = state.getIteratorOfRawValues.map { value => - (toKeyInt(value.key), toValueInt(value.value)) - }.toList - - val refListBeforeOp = new java.util.LinkedList[String](refLists(selectedKeyIdx)) - - operation match { - case 0 => - if (selectedRefList.isEmpty) { - assert(state.isEmpty(selectedKey)) - state.setHead(selectedKey, opIdx, toInputValue(opIdx)) - selectedRefList.addFirst(String.valueOf(opIdx)) - } else { - val addBefore = rand.nextBoolean() - if (addBefore) { - val idxToAddBefore = selectedRefList.get(targetIdx) - selectedRefList.add(targetIdx, String.valueOf(opIdx)) - state.addBefore(selectedKey, opIdx, toInputValue(opIdx), idxToAddBefore.toInt) - } else { - val idxToAddAfter = 
selectedRefList.get(targetIdx) - selectedRefList.add(targetIdx + 1, String.valueOf(opIdx)) - state.addAfter(selectedKey, opIdx, toInputValue(opIdx), idxToAddAfter.toInt) - } - } - - case 1 => - if (selectedRefList.isEmpty) { - assert(state.isEmpty(selectedKey)) - // skip removing - } else { - val pointerToRemove = selectedRefList.get(targetIdx) - selectedRefList.remove(targetIdx) - state.remove(selectedKey, pointerToRemove.toInt) - } - - case 2 => - if (selectedRefList.isEmpty) { - assert(state.isEmpty(selectedKey)) - // skip removing - } else { - val pointerToRemove = selectedRefList.get(targetIdx) - val removedIter = state.removeByValueCondition { r => - toValueInt(r) <= pointerToRemove.toInt - } - - val valuesFromRef = new scala.collection.mutable.MutableList[Int]() - refLists.foreach { refList => - val refIter = refList.iterator() - while (refIter.hasNext) { - val ref = refIter.next() - if (ref.toInt <= pointerToRemove.toInt) { - valuesFromRef += ref.toInt - refIter.remove() - } - } - } - - try { - assert(removedIter.map(pair => toValueInt(pair.value)).toSet == - valuesFromRef.toSet) - } catch { - case ex: TestFailedException => - printFailureInformation(ex, state, operation, addBefore, opIdx, targetIdx, - selectedKey, headPointersBeforeOp, rawPointersBeforeOp, pointersBeforeOp, - valuesBeforeOp, refListBeforeOp, selectedRefList) - - throw ex - } - } - } - - keys.indices.foreach { index => - val key = keys(index) - val refList = refLists(index) - - try { - if (refList.isEmpty) { - assert(state.isEmpty(key), s"Reference list is empty but " + - s"state list for $key is not empty") - } else { - import scala.collection.JavaConverters._ - val statePointers = state.iteratePointers(key).map(_._1).toList - assert(refList.asScala.map(_.toInt) === statePointers, - s"State pointers for $key is expected to be $refList but $statePointers") - - val stateValues = state.get(key).map(toValueInt).toList - assert(refList.asScala.map(_.toInt) === - stateValues, s"State list for 
$key is expected to be $refList but $stateValues") - } - } catch { - case ex: TestFailedException => - printFailureInformation(ex, state, operation, addBefore, opIdx, targetIdx, - selectedKey, headPointersBeforeOp, rawPointersBeforeOp, pointersBeforeOp, - valuesBeforeOp, refListBeforeOp, selectedRefList) - - throw ex - } - } - } - } - } - private def removeByWatermarkTest(stopOnConditionMismatch: Boolean): Unit = { withSessionWindowLinkedListState(inputValueAttribs, keyExprs) { state => implicit val st = state From 17570f24644805a7d478c2b4d68a531eee2aa015 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 30 Oct 2018 18:18:12 +0900 Subject: [PATCH 55/60] WIP optimize a bit on storing new sessions --- .../state/SessionWindowLinkedListState.scala | 4 +++ .../streaming/statefulOperators.scala | 34 ++++++++++++++----- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala index ee899a6882adf..6088543f3cea1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala @@ -197,6 +197,10 @@ class SessionWindowLinkedListState( findFirstSessionStartEnsurePredicate(key, predicate, head.get) } + def getSessionStartOnNearest(key: UnsafeRow, sessionStart: Long): (Option[Long], Option[Long]) = { + keyAndSessionStartToPointerStore.get(key, sessionStart) + } + def getPrevSessionStart(key: UnsafeRow, sessionStart: Long): Option[Long] = { val pointers = keyAndSessionStartToPointerStore.get(key, sessionStart) assertValidPointer(pointers) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 3dcbf3c8c835e..d07405f32dfbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -552,10 +552,32 @@ case class SessionWindowStateStoreSaveExec( return true } - val existing = state.get(key, sessionStart) - if (existing != null && valueOrdering.equiv(existing, row)) { - // session already exist and do not need to update - return false + // need to find sessions which could be replaced with new session + // new session should enclose previous session(s) if it overlaps, + // since session always expands + + val nearestSessions = state.getSessionStartOnNearest(key, sessionStart) + if (nearestSessions != null) { + // there's rare chance that existing session and row is equivalent + // because in MergingSortWithSessionWindowLinkedListStateIterator, + // we emit existing sessions only when it overlaps with input row + // so unless aggregation make no difference, it will not happen + // always replace instead of comparing with actual value + + // if the old session can be replaced with new session, + // (condition: 1:1 match, no change on "session start") + // just replace it to avoid overhead on manipulating linked list + + nearestSessions._2 match { + case Some(next) if next > sessionEnd => + state.update(key, sessionStart, row) + return true + case None => + state.update(key, sessionStart, row) + return true + + case _ => + } } if (stateFetchedKey == null || keyOrdering.equiv(stateFetchedKey, key)) { @@ -563,10 +585,6 @@ case class SessionWindowStateStoreSaveExec( lastSearchedSessionStartOption = None } - // need to find sessions which could be replaced with new session - // new session should enclose previous session(s) if it overlaps, - // since session always expands - // find the first state session which is enclosed by new 
session val firstStateSessionEnclosedByNewSession = lastSearchedSessionStartOption match { case Some(lastSearchedSessionStart) => From 958de3192fb782372c58e494ba7956cbb6d87304 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 31 Oct 2018 10:37:02 +0900 Subject: [PATCH 56/60] WIP Fixed critical bug which tasks don't respect preference on state store --- .../StreamingSymmetricHashJoinExec.scala | 5 +- .../state/MultiValuesStateManager.scala | 5 ++ .../state/MultiValuesStateStoreRDD.scala | 79 ------------------- .../state/SessionWindowLinkedListState.scala | 7 +- ...SessionWindowLinkedListStateStoreRDD.scala | 12 +-- .../execution/streaming/state/package.scala | 43 ---------- 6 files changed, 19 insertions(+), 132 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateStoreRDD.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 494e5579646fb..cddac5cdf1ef7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -507,9 +507,6 @@ case class StreamingSymmetricHashJoinExec( } private def allStateStoreNames(joinSides: JoinSide*): Seq[String] = { - val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToValueType) - for (joinSide <- joinSides; stateStoreType <- allStateStoreTypes) yield { - getStateStoreName(joinSide.toString, stateStoreType) - } + joinSides.flatMap(j => MultiValuesStateManager.getAllStateStoreName(j.toString)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala index 633c4326e7406..c112fc28bfce4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala @@ -542,4 +542,9 @@ object MultiValuesStateManager { def getStateStoreName(storeNamePrefix: String, storeType: StateStoreType): String = { s"$storeNamePrefix-$storeType" } + + def getAllStateStoreName(storeNamePrefix: String): Seq[String] = { + val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToValueType) + allStateStoreTypes.map(getStateStoreName(storeNamePrefix, _)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateStoreRDD.scala deleted file mode 100644 index 32d63296a1abb..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateStoreRDD.scala +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.streaming.state - -import scala.reflect.ClassTag - -import org.apache.spark.{Partition, TaskContext} -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo -import org.apache.spark.sql.execution.streaming.continuous.EpochTracker -import org.apache.spark.sql.internal.SessionState -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.SerializableConfiguration - -// FIXME: javadoc!! -class MultiValuesStateStoreRDD[T: ClassTag, U: ClassTag]( - dataRDD: RDD[T], - storeUpdateFunction: (MultiValuesStateManager, Iterator[T]) => Iterator[U], - stateInfo: StatefulOperatorStateInfo, - keySchema: StructType, - valueSchema: StructType, - indexOrdinal: Option[Int], - sessionState: SessionState, - @transient private val storeCoordinator: Option[StateStoreCoordinatorRef]) - extends RDD[U](dataRDD) { - - private val storeConf = new StateStoreConf(sessionState.conf) - - // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it - private val hadoopConfBroadcast = dataRDD.context.broadcast( - new SerializableConfiguration(sessionState.newHadoopConf())) - - override protected def getPartitions: Array[Partition] = dataRDD.partitions - - /** - * Set the preferred location of each partition using the executor that has the related - * [[StateStoreProvider]] already loaded. - */ - override def getPreferredLocations(partition: Partition): Seq[String] = { - val stateStoreProviderId = StateStoreProviderId( - StateStoreId(stateInfo.checkpointLocation, stateInfo.operatorId, partition.index), - stateInfo.queryRunId) - storeCoordinator.flatMap(_.getLocation(stateStoreProviderId)).toSeq - } - - override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { - // If we're in continuous processing mode, we should get the store version for the current - // epoch rather than the one at planning time. 
- val currentVersion = EpochTracker.getCurrentEpoch match { - case None => stateInfo.storeVersion - case Some(value) => value - } - - val modifiedStateInfo = stateInfo.copy(storeVersion = currentVersion) - - val stateManager: MultiValuesStateManager = new MultiValuesStateManager("session-", - valueSchema.toAttributes, keySchema.toAttributes, Some(modifiedStateInfo), storeConf, - hadoopConfBroadcast.value.value) - - val inputIter = dataRDD.iterator(partition, ctxt) - storeUpdateFunction(stateManager, inputIter) - } - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala index 6088543f3cea1..0e021e1c7c410 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.streaming.state import java.util.Locale import org.apache.hadoop.conf.Configuration - import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Literal, SpecificInternalRow, UnsafeProjection, UnsafeRow} @@ -770,4 +769,10 @@ object SessionWindowLinkedListState { def getStateStoreName(storeNamePrefix: String, storeType: StateStoreType): String = { s"$storeNamePrefix-$storeType" } + + def getAllStateStoreName(storeNamePrefix: String): Seq[String] = { + val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToHeadSessionStartType, + KeyAndSessionStartToPointerType, KeyAndSessionStartToValueType) + allStateStoreTypes.map(getStateStoreName(storeNamePrefix, _)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateStoreRDD.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateStoreRDD.scala index 4594df679fb66..51bb709deb209 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListStateStoreRDD.scala @@ -45,6 +45,8 @@ class SessionWindowLinkedListStateStoreRDD[T: ClassTag, U: ClassTag]( private val hadoopConfBroadcast = dataRDD.context.broadcast( new SerializableConfiguration(sessionState.newHadoopConf())) + private val stateStorePrefix: String = s"sessionwindow-${stateInfo.operatorId}" + override protected def getPartitions: Array[Partition] = dataRDD.partitions /** @@ -52,10 +54,10 @@ class SessionWindowLinkedListStateStoreRDD[T: ClassTag, U: ClassTag]( * [[StateStoreProvider]] already loaded. */ override def getPreferredLocations(partition: Partition): Seq[String] = { - val stateStoreProviderId = StateStoreProviderId( - StateStoreId(stateInfo.checkpointLocation, stateInfo.operatorId, partition.index), - stateInfo.queryRunId) - storeCoordinator.flatMap(_.getLocation(stateStoreProviderId)).toSeq + SessionWindowLinkedListState.getAllStateStoreName(stateStorePrefix).flatMap { storeName => + val stateStoreProviderId = StateStoreProviderId(stateInfo, partition.index, storeName) + storeCoordinator.flatMap(_.getLocation(stateStoreProviderId)) + }.distinct } override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { @@ -68,7 +70,7 @@ class SessionWindowLinkedListStateStoreRDD[T: ClassTag, U: ClassTag]( val modifiedStateInfo = stateInfo.copy(storeVersion = currentVersion) - val state = new SessionWindowLinkedListState(s"session-${stateInfo.operatorId}-", + val state = new SessionWindowLinkedListState(stateStorePrefix, valueSchema.toAttributes, keySchema.toAttributes, Some(modifiedStateInfo), storeConf, hadoopConfBroadcast.value.value) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index d522a48404d29..0495da280ae6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -81,49 +81,6 @@ package object state { storeCoordinator) } - /** Map each partition of an RDD along with data in a [[MultiValuesStateManager]]. */ - def mapPartitionsWithMultiValuesStateManager[U: ClassTag]( - sqlContext: SQLContext, - stateInfo: StatefulOperatorStateInfo, - keySchema: StructType, - valueSchema: StructType, - indexOrdinal: Option[Int])( - storeUpdateFunction: (MultiValuesStateManager, Iterator[T]) => Iterator[U]) - : MultiValuesStateStoreRDD[T, U] = { - - mapPartitionsWithMultiValuesStateManager( - stateInfo, - keySchema, - valueSchema, - indexOrdinal, - sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator))( - storeUpdateFunction) - } - - /** Map each partition of an RDD along with data in a [[MultiValuesStateManager]]. */ - private[streaming] def mapPartitionsWithMultiValuesStateManager[U: ClassTag]( - stateInfo: StatefulOperatorStateInfo, - keySchema: StructType, - valueSchema: StructType, - indexOrdinal: Option[Int], - sessionState: SessionState, - storeCoordinator: Option[StateStoreCoordinatorRef])( - storeUpdateFunction: (MultiValuesStateManager, Iterator[T]) => Iterator[U]) - : MultiValuesStateStoreRDD[T, U] = { - - val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) - new MultiValuesStateStoreRDD( - dataRDD, - cleanedF, - stateInfo, - keySchema, - valueSchema, - indexOrdinal, - sessionState, - storeCoordinator) - } - /** Map each partition of an RDD along with data in a [[SessionWindowLinkedListState]]. 
*/ def mapPartitionsWithSessionWindowLinkedListState[U: ClassTag]( sqlContext: SQLContext, From ee67bcaf6fa2d1ab17e755cb7d5edd5dd10115bc Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 1 Nov 2018 13:36:59 +0900 Subject: [PATCH 57/60] WIP Fix critical perf. issue: remove codegen on generating session row for group key Also modify CodeGenerator to print out debug information when code generation takes too long --- .../expressions/codegen/CodeGenerator.scala | 25 +++++++ .../aggregate/MergingSessionsIterator.scala | 68 +++++++++---------- .../state/SessionWindowLinkedListState.scala | 1 + 3 files changed, 58 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d5857e060a2c4..d47c161ae7d0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1229,6 +1229,10 @@ object CodeGenerator extends Logging { // bytecode instruction final val MUTABLESTATEARRAY_SIZE_LIMIT = 32768 + // This is the threshold to print out debug information when code generation takes more + // than this value. + final val SLOW_CODEGEN_MILLIS_THRESHOLD = 100 + /** * Compile the Java source code into a Java class, using Janino. * @@ -1375,6 +1379,27 @@ object CodeGenerator extends Logging { CodegenMetrics.METRIC_SOURCE_CODE_SIZE.update(code.body.length) CodegenMetrics.METRIC_COMPILATION_TIME.update(timeMs.toLong) logInfo(s"Code generated in $timeMs ms") + + if (timeMs > SLOW_CODEGEN_MILLIS_THRESHOLD) { + logWarning(s"Code generation took more than $SLOW_CODEGEN_MILLIS_THRESHOLD ms. " + + "Please set logger level to DEBUG to see further debug information.") + + logDebug(s"Printing out debug information - body: ${code.body}... 
/ " + + s"comment: ${code.comment}") + + def getRelevantStackTraceForDebug(): Array[StackTraceElement] = { + Thread.currentThread().getStackTrace.drop(1) + .filterNot { p => + p.getClassName.startsWith("com.google.common") || + p.getClassName.startsWith("org.apache.spark.sql.catalyst") || + p.getClassName.startsWith("org.apache.spark.rdd") + } + } + + logDebug(s"Stack trace - " + + s"${getRelevantStackTraceForDebug().take(30).map(_.toString).mkString("\n")}") + } + result } }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala index 64c6ffc8c5fb3..feace961e3bed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, CreateNamedStruct, Expression, GenericInternalRow, JoinedRow, Literal, MutableProjection, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, CreateNamedStruct, Expression, GenericInternalRow, JoinedRow, Literal, MutableProjection, NamedExpression, PreciseTimestampConversion, SpecificInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.metric.SQLMetric @@ -82,16 +82,12 @@ class MergingSessionsIterator( // The partition key of the current partition. 
private[this] var currentGroupingKey: UnsafeRow = _ - private[this] var currentSessionStart: Long = Long.MaxValue - - private[this] var currentSessionEnd: Long = Long.MinValue + private[this] var currentSession: UnsafeRow = _ // The partition key of next partition. private[this] var nextGroupingKey: UnsafeRow = _ - private[this] var nextGroupingSessionStart: Long = Long.MaxValue - - private[this] var nextGroupingSessionEnd: Long = Long.MinValue + private[this] var nextGroupingSession: UnsafeRow = _ // The first row of next partition. private[this] var firstRowInNextGroup: InternalRow = _ @@ -116,9 +112,7 @@ class MergingSessionsIterator( val inputRow = inputIterator.next() nextGroupingKey = groupingWithoutSessionProjection(inputRow).copy() val session = sessionProjection(inputRow) - val sessionRow = session.getStruct(0, 2) - nextGroupingSessionStart = sessionRow.getLong(0) - nextGroupingSessionEnd = sessionRow.getLong(1) + nextGroupingSession = session.getStruct(0, 2).copy() firstRowInNextGroup = inputRow.copy() sortedInputHasNewGroup = true } else { @@ -132,8 +126,7 @@ class MergingSessionsIterator( /** Processes rows in the current group. It will stop when it find a new group. */ protected def processCurrentSortedGroup(): Unit = { currentGroupingKey = nextGroupingKey - currentSessionStart = nextGroupingSessionStart - currentSessionEnd = nextGroupingSessionEnd + currentSession = nextGroupingSession // Now, we will start to find all rows belonging to this group. // We create a variable to track if we see the next group. 
@@ -149,27 +142,27 @@ class MergingSessionsIterator( val groupingKey = groupingWithoutSessionProjection(currentRow) val session = sessionProjection(currentRow) - val sessionRow = session.getStruct(0, 2) - val sessionStart = sessionRow.getLong(0) - val sessionEnd = sessionRow.getLong(1) + val sessionStruct = session.getStruct(0, 2) + val sessionStart = getSessionStart(sessionStruct) + val sessionEnd = getSessionEnd(sessionStruct) // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { - if (sessionStart < currentSessionStart) { + if (sessionStart < getSessionStart(currentSession)) { throw new IllegalArgumentException("Input iterator is not sorted based on session!") - } else if (sessionStart <= currentSessionEnd) { + } else if (sessionStart <= getSessionEnd(currentSession)) { // expanding session length if needed expandEndOfCurrentSession(sessionEnd) processRow(sortBasedAggregationBuffer, currentRow) } else { // We find a new group. findNextPartition = true - startNewSession(currentRow, groupingKey, sessionStart, sessionEnd) + startNewSession(currentRow, groupingKey, sessionStruct) } } else { // We find a new group. 
findNextPartition = true - startNewSession(currentRow, groupingKey, sessionStart, sessionEnd) + startNewSession(currentRow, groupingKey, sessionStruct) } } @@ -180,17 +173,28 @@ class MergingSessionsIterator( } } - private def startNewSession(currentRow: InternalRow, groupingKey: UnsafeRow, sessionStart: Long, - sessionEnd: Long): Unit = { + private def startNewSession(currentRow: InternalRow, groupingKey: UnsafeRow, + sessionStruct: UnsafeRow): Unit = { nextGroupingKey = groupingKey.copy() - nextGroupingSessionStart = sessionStart - nextGroupingSessionEnd = sessionEnd + nextGroupingSession = sessionStruct.copy() firstRowInNextGroup = currentRow.copy() } + private def getSessionStart(sessionStruct: UnsafeRow): Long = { + sessionStruct.getLong(0) + } + + private def getSessionEnd(sessionStruct: UnsafeRow): Long = { + sessionStruct.getLong(1) + } + + def updateSessionEnd(sessionStruct: UnsafeRow, sessionEnd: Long): Unit = { + sessionStruct.setLong(1, sessionEnd) + } + private def expandEndOfCurrentSession(sessionEnd: Long): Unit = { - if (sessionEnd > currentSessionEnd) { - currentSessionEnd = sessionEnd + if (sessionEnd > getSessionEnd(currentSession)) { + updateSessionEnd(currentSession, sessionEnd) } } @@ -225,17 +229,9 @@ class MergingSessionsIterator( groupingWithoutSessionAttributes :+ sessionExpression.toAttribute) private def generateGroupingKey(): UnsafeRow = { - val sessionStruct = CreateNamedStruct( - Literal("start") :: - PreciseTimestampConversion( - Literal(currentSessionStart, LongType), LongType, TimestampType) :: - Literal("end") :: - PreciseTimestampConversion( - Literal(currentSessionEnd, LongType), LongType, TimestampType) :: - Nil) - - val joined = join(currentGroupingKey, - UnsafeProjection.create(sessionStruct).apply(InternalRow.empty)) + val newRow = new SpecificInternalRow(Seq(sessionExpression.toAttribute).toStructType) + newRow.update(0, currentSession) + val joined = join(currentGroupingKey, newRow) groupingKeyProj(joined) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala index 0e021e1c7c410..eaac81d23724b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SessionWindowLinkedListState.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.streaming.state import java.util.Locale import org.apache.hadoop.conf.Configuration + import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Literal, SpecificInternalRow, UnsafeProjection, UnsafeRow} From 8a0331e253ebd54654c2464c1fa7b0935b7f711a Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 2 Nov 2018 13:10:02 +0900 Subject: [PATCH 58/60] WIP Rolling back unnecessary changes --- .../aggregate/MergingSessionsIterator.scala | 3 +- ...gingSortWithMultiValuesStateIterator.scala | 153 --------- .../StreamingSymmetricHashJoinExec.scala | 11 +- .../state/StreamingSessionStateManager.scala | 159 ---------- ...la => SymmetricHashJoinStateManager.scala} | 108 ++----- ...ortWithMultiValuesStateIteratorSuite.scala | 292 ------------------ ... 
SymmetricHashJoinStateManagerSuite.scala} | 35 ++- 7 files changed, 51 insertions(+), 710 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionStateManager.scala rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/{MultiValuesStateManager.scala => SymmetricHashJoinStateManager.scala} (85%) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala rename sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/{MultiValuesStateManagerSuite.scala => SymmetricHashJoinStateManagerSuite.scala} (81%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala index feace961e3bed..1b9f78274c566 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/MergingSessionsIterator.scala @@ -18,11 +18,10 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, CreateNamedStruct, Expression, GenericInternalRow, JoinedRow, Literal, MutableProjection, NamedExpression, PreciseTimestampConversion, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Literal, MutableProjection, NamedExpression, PreciseTimestampConversion, SpecificInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.types.{LongType, TimestampType} // FIXME: javadoc! // FIXME: groupingExpressions should contain sessionExpression diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala deleted file mode 100644 index 3054b2268f41c..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIterator.scala +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.execution.streaming.state.StreamingSessionStateManager - -// FIXME: javadoc!! 
-class MergingSortWithMultiValuesStateIterator( - iter: Iterator[InternalRow], - stateManager: StreamingSessionStateManager, - groupWithoutSessionExpressions: Seq[Expression], - sessionExpression: Expression, - keysProjection: UnsafeProjection, - sessionProjection: UnsafeProjection, - inputSchema: Seq[Attribute]) extends Iterator[InternalRow] { - - def this( - iter: Iterator[InternalRow], - stateManager: StreamingSessionStateManager, - groupWithoutSessionExpressions: Seq[Expression], - sessionExpression: Expression, - inputSchema: Seq[Attribute]) { - this(iter, stateManager, groupWithoutSessionExpressions, sessionExpression, - GenerateUnsafeProjection.generate(groupWithoutSessionExpressions, inputSchema), - GenerateUnsafeProjection.generate(Seq(sessionExpression), inputSchema), - inputSchema) - } - - private case class SessionRowInformation(keys: UnsafeRow, sessionStart: Long, sessionEnd: Long, - row: InternalRow) - - private object SessionRowInformation { - def of(row: InternalRow): SessionRowInformation = { - val keys = keysProjection(row).copy() - val session = sessionProjection(row).copy() - val sessionRow = session.getStruct(0, 2) - val sessionStart = sessionRow.getLong(0) - val sessionEnd = sessionRow.getLong(1) - - SessionRowInformation(keys, sessionStart, sessionEnd, row) - } - } - - private var currentRow: SessionRowInformation = _ - private var currentStateRow: SessionRowInformation = _ - private var currentStateIter: Iterator[InternalRow] = _ - private var currentStateFetchedKey: UnsafeRow = _ - - override def hasNext: Boolean = { - currentRow != null || currentStateRow != null || - (currentStateIter != null && currentStateIter.hasNext) || iter.hasNext - } - - override def next(): InternalRow = { - if (currentRow == null) { - mayFillCurrentRow() - } - - if (currentStateRow == null) { - mayFillCurrentStateRow() - } - - if (currentRow == null && currentStateRow == null) { - throw new IllegalStateException("No Row to provide in next() which should not 
happen!") - } - - // return current row vs current state row, should return smaller key, earlier session start - val returnCurrentRow: Boolean = { - if (currentRow == null) { - false - } else if (currentStateRow == null) { - true - } else { - // compare - if (currentRow.keys != currentStateRow.keys) { - // state row cannot advance to row in input, so state row should be lower - false - } else { - currentRow.sessionStart < currentStateRow.sessionStart - } - } - } - - val ret: SessionRowInformation = { - if (returnCurrentRow) { - val toRet = currentRow - currentRow = null - toRet - } else { - val toRet = currentStateRow - currentStateRow = null - toRet - } - } - - ret.row - } - - private def mayFillCurrentRow(): Unit = { - if (iter.hasNext) { - currentRow = SessionRowInformation.of(iter.next()) - } - } - - private def mayFillCurrentStateRow(): Unit = { - if (currentStateIter != null && currentStateIter.hasNext) { - currentStateRow = SessionRowInformation.of(currentStateIter.next()) - } else { - currentStateIter = null - - if (currentRow != null && currentRow.keys != currentStateFetchedKey) { - - // This is necessary because MultiValuesStateManager doesn't guarantee stable ordering - // The number of values for the given key is expected to be likely small, - // so sorting it here doesn't hurt. 
- val unsortedIter = stateManager.get(currentRow.keys) - currentStateIter = unsortedIter.map(_.copy()).toList.sortWith((row1, row2) => { - def getSessionStart(r: InternalRow): Long = { - val session = sessionProjection(r) - val sessionRow = session.getStruct(0, 2) - sessionRow.getLong(0) - } - - // here sorting is based on the fact that keys are same - getSessionStart(row1).compareTo(getSessionStart(row2)) < 0 - }).iterator - - currentStateFetchedKey = currentRow.keys - if (currentStateIter.hasNext) { - currentStateRow = SessionRowInformation.of(currentStateIter.next()) - } - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index cddac5cdf1ef7..50cf971e4ec3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.execution.streaming.state.MultiValuesStateManager.{getStateStoreName, KeyToNumValuesType, KeyWithIndexToValueType, StateStoreType} import org.apache.spark.sql.internal.SessionState import org.apache.spark.util.{CompletionIterator, SerializableConfiguration} @@ -201,7 +200,7 @@ case class StreamingSymmetricHashJoinExec( protected override def doExecute(): RDD[InternalRow] = { val stateStoreCoord = sqlContext.sessionState.streamingQueryManager.stateStoreCoordinator - val stateStoreNames = allStateStoreNames(LeftSide, RightSide) + val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, 
RightSide) left.execute().stateStoreAwareZipPartitions( right.execute(), stateInfo.get, stateStoreNames, stateStoreCoord)(processPartitions) } @@ -395,8 +394,8 @@ case class StreamingSymmetricHashJoinExec( val preJoinFilter = newPredicate(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _ - private val joinStateManager = new MultiValuesStateManager(joinSide.toString, inputAttributes, - joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value) + private val joinStateManager = new SymmetricHashJoinStateManager( + joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value) private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes) private[this] val stateKeyWatermarkPredicateFunc = stateWatermarkPredicate match { @@ -505,8 +504,4 @@ case class StreamingSymmetricHashJoinExec( def numUpdatedStateRows: Long = updatedStateRowsCount } - - private def allStateStoreNames(joinSides: JoinSide*): Seq[String] = { - joinSides.flatMap(j => MultiValuesStateManager.getAllStateStoreName(j.toString)) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionStateManager.scala deleted file mode 100644 index 539f3339ab4d0..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionStateManager.scala +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.state - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} -import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.types.StructType - -// FIXME: javadoc! -sealed trait StreamingSessionStateManager extends Serializable { - def getKey(row: UnsafeRow): UnsafeRow - - def getStateValueSchema: StructType - - def get(key: UnsafeRow): Iterator[UnsafeRow] - - def append(session: UnsafeRow): Boolean - - def doFinalize(): Unit - - def getAll(): Iterator[UnsafeRow] - - def evictSessionsByWatermark(): Iterator[UnsafeRow] - - def doEvictSessionsByWatermark(): Unit -} - -// FIXME: javadoc! 
-trait MultiValuesStateManagerInjectable { - def setMultiValuesStateManager(manager: MultiValuesStateManager): Unit -} - -object StreamingSessionStateManager extends Logging { - val supportedVersions = Seq(1) - - def createStateManager( - keyExpressions: Seq[Attribute], - inputRowAttributes: Seq[Attribute], - watermarkPredicateForData: Option[Predicate], - stateFormatVersion: Int): StreamingSessionStateManager = { - stateFormatVersion match { - case 1 => new StreamingSessionStateManagerImplV1(keyExpressions, inputRowAttributes, - watermarkPredicateForData) - case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid") - } - } -} - -abstract class StreamingSessionStateManagerBaseImpl( - protected val keyExpressions: Seq[Attribute], - protected val inputRowAttributes: Seq[Attribute]) extends StreamingSessionStateManager { - - @transient protected lazy val keyProjector = - GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes) - - override def getKey(row: UnsafeRow): UnsafeRow = keyProjector(row) -} - -// FIXME: javadoc! 
-class StreamingSessionStateManagerImplV1( - keyExpressions: Seq[Attribute], - inputRowAttributes: Seq[Attribute], - watermarkPredicateForData: Option[Predicate]) - extends StreamingSessionStateManagerBaseImpl(keyExpressions, inputRowAttributes) - with MultiValuesStateManagerInjectable { - - var stateManager: MultiValuesStateManager = _ - var currentKey: UnsafeRow = _ - var previousSessions: List[UnsafeRow] = _ - - @transient protected lazy val keyOrdering = TypeUtils.getInterpretedOrdering( - keyExpressions.toStructType) - @transient protected lazy val valueOrdering = TypeUtils.getInterpretedOrdering( - inputRowAttributes.toStructType) - - override def setMultiValuesStateManager(manager: MultiValuesStateManager): Unit = { - stateManager = manager - } - - override def getStateValueSchema: StructType = inputRowAttributes.toStructType - - override def get(key: UnsafeRow): Iterator[UnsafeRow] = { - assertAvailability() - - stateManager.get(key) - } - - override def append(session: UnsafeRow): Boolean = { - assertAvailability() - - val key = keyProjector(session) - - if (currentKey == null || !keyOrdering.equiv(currentKey, key)) { - currentKey = key.copy() - - // This is necessary because MultiValuesStateManager doesn't guarantee - // stable ordering. - // The number of values for the given key is expected to be likely small, - // so listing it here doesn't hurt. 
- previousSessions = stateManager.get(key).toList - - stateManager.removeKey(key) - } - - stateManager.append(key, session) - - !previousSessions.exists(p => valueOrdering.equiv(session, p)) - } - - override def doFinalize(): Unit = { - assertAvailability() - - // do nothing - } - - override def getAll(): Iterator[UnsafeRow] = { - assertAvailability() - - stateManager.getAllRowPairs.map(_.value) - } - - override def evictSessionsByWatermark(): Iterator[UnsafeRow] = { - assertAvailability() - - stateManager.removeByValueCondition { row => watermarkPredicateForData match { - case Some(predicate) => predicate.eval(row) - case None => false - } - }.map(_.value) - } - - override def doEvictSessionsByWatermark(): Unit = { - assertAvailability() - - // consume all elements to let removal take effect - evictSessionsByWatermark().toList - } - - private def assertAvailability(): Unit = { - require(stateManager != null, "MultiValuesStateManager should be set before calling methods!") - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala similarity index 85% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index c112fc28bfce4..43f22803e7685 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -24,20 +24,21 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Literal, SpecificInternalRow, 
UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo +import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec} +import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.types.{LongType, StructField, StructType} import org.apache.spark.util.NextIterator /** - * Helper class to manage state which is useful for operations which require to store multiple - * values for a key. + * Helper class to manage state required by a single side of [[StreamingSymmetricHashJoinExec]]. * The interface of this class is basically that of a multi-map: * - Get: Returns an iterator of multiple values for given key * - Append: Append a new value to the given key * - Remove Data by predicate: Drop any state using a predicate condition on keys or values * + * @param joinSide Defines the join side * @param inputValueAttributes Attributes of the input row which will be stored as value - * @param keys Expressions to generate rows that will be used to key the value rows + * @param joinKeys Expressions to generate rows that will be used to key the value rows * @param stateInfo Information about how to retrieve the correct version of state * @param storeConf Configuration for the state store. 
* @param hadoopConf Hadoop configuration for reading state data from storage @@ -57,17 +58,16 @@ import org.apache.spark.util.NextIterator * the predicate, delete corresponding (key, indexToDelete) from KeyWithIndexToValueStore * by overwriting with the value of (key, maxIndex), and removing [(key, maxIndex), * decrement corresponding num values in KeyToNumValuesStore - * (the operation doesn't guarantee stable ordering once value is removed) */ -class MultiValuesStateManager( - storeNamePrefix: String, +class SymmetricHashJoinStateManager( + val joinSide: JoinSide, inputValueAttributes: Seq[Attribute], - keys: Seq[Expression], + joinKeys: Seq[Expression], stateInfo: Option[StatefulOperatorStateInfo], storeConf: StateStoreConf, hadoopConf: Configuration) extends Logging { - import MultiValuesStateManager._ + import SymmetricHashJoinStateManager._ /* ===================================================== @@ -88,12 +88,6 @@ class MultiValuesStateManager( keyToNumValues.put(key, numExistingValues + 1) } - def removeKey(key: UnsafeRow): Unit = { - val numExistingValues = keyToNumValues.get(key) - keyToNumValues.remove(key) - (0L until numExistingValues).foreach(keyWithIndexToValue.remove(key, _)) - } - /** * Remove using a predicate on keys. * @@ -160,8 +154,6 @@ class MultiValuesStateManager( * * This implies the iterator must be consumed fully without any other operations on this manager * or the underlying store being interleaved. - * - * NOTE: It doesn't keep order of values being stable when removing one for performance gain. */ def removeByValueCondition(removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRowPair] = { new NextIterator[UnsafeRowPair] { @@ -257,51 +249,6 @@ class MultiValuesStateManager( } } - /** Provide all (key, value) row pairs (key can be exposed multiple times) */ - def getAllRowPairs: Iterator[UnsafeRowPair] = { - new NextIterator[UnsafeRowPair] { - // Reuse this object to avoid creation+GC overhead. 
- private val reusedPair = new UnsafeRowPair() - - private val allKeyToNumValues = keyToNumValues.iterator - - private var currentKey: UnsafeRow = null - private var numValues: Long = 0L - private var index: Long = 0L - - override def getNext(): UnsafeRowPair = { - if (currentKey != null && index < numValues) { - provideCurrentRow() - } else { - if (!allKeyToNumValues.hasNext) { - // finished - finished = true - null - } else { - advanceGroup() - assert(numValues != 0) - provideCurrentRow() - } - } - } - - private def advanceGroup(): Unit = { - val currentKeyToNumValue = allKeyToNumValues.next() - currentKey = currentKeyToNumValue.key - numValues = currentKeyToNumValue.numValue - index = 0 - } - - private def provideCurrentRow(): UnsafeRowPair = { - val currentRow = keyWithIndexToValue.get(currentKey, index) - index += 1 - reusedPair.withRows(currentKey, currentRow) - } - - override def close: Unit = {} - } - } - /** Commit all the changes to all the state stores */ def commit(): Unit = { keyToNumValues.commit() @@ -318,7 +265,7 @@ class MultiValuesStateManager( def metrics: StateStoreMetrics = { val keyToNumValuesMetrics = keyToNumValues.metrics val keyWithIndexToValueMetrics = keyWithIndexToValue.metrics - def newDesc(desc: String): String = s"${storeNamePrefix.toUpperCase(Locale.ROOT)}: $desc" + def newDesc(desc: String): String = s"${joinSide.toString.toUpperCase(Locale.ROOT)}: $desc" StateStoreMetrics( keyWithIndexToValueMetrics.numKeys, // represent each buffered row only once @@ -344,7 +291,7 @@ class MultiValuesStateManager( */ private val keySchema = StructType( - keys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) }) + joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) }) private val keyAttributes = keySchema.toAttributes private val keyToNumValues = new KeyToNumValuesStore() private val keyWithIndexToValue = new KeyWithIndexToValueStore() @@ -374,8 +321,8 @@ class 
MultiValuesStateManager( /** Get the StateStore with the given schema */ protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = { - val storeProviderId = StateStoreProviderId(stateInfo.get, TaskContext.getPartitionId(), - getStateStoreName(storeNamePrefix, stateStoreType)) + val storeProviderId = StateStoreProviderId( + stateInfo.get, TaskContext.getPartitionId(), getStateStoreName(joinSide, stateStoreType)) val store = StateStore.get( storeProviderId, keySchema, valueSchema, None, stateInfo.get.storeVersion, storeConf, hadoopConf) @@ -433,8 +380,8 @@ class MultiValuesStateManager( * Helper class for representing data returned by [[KeyWithIndexToValueStore]]. * Designed for object reuse. */ - private case class KeyWithIndexAndValue(var key: UnsafeRow = null, var valueIndex: Long = -1, - var value: UnsafeRow = null) { + private case class KeyWithIndexAndValue( + var key: UnsafeRow = null, var valueIndex: Long = -1, var value: UnsafeRow = null) { def withNew(newKey: UnsafeRow, newIndex: Long, newValue: UnsafeRow): this.type = { this.key = newKey this.valueIndex = newIndex @@ -528,23 +475,26 @@ class MultiValuesStateManager( } } -object MultiValuesStateManager { - sealed trait StateStoreType +object SymmetricHashJoinStateManager { - case object KeyToNumValuesType extends StateStoreType { - override def toString(): String = "keyToNumValues" + def allStateStoreNames(joinSides: JoinSide*): Seq[String] = { + val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToValueType) + for (joinSide <- joinSides; stateStoreType <- allStateStoreTypes) yield { + getStateStoreName(joinSide, stateStoreType) + } } - case object KeyWithIndexToValueType extends StateStoreType { - override def toString(): String = "keyWithIndexToValue" + private sealed trait StateStoreType + + private case object KeyToNumValuesType extends StateStoreType { + override def toString(): String = "keyToNumValues" } - def 
getStateStoreName(storeNamePrefix: String, storeType: StateStoreType): String = { - s"$storeNamePrefix-$storeType" + private case object KeyWithIndexToValueType extends StateStoreType { + override def toString(): String = "keyWithIndexToValue" } - def getAllStateStoreName(storeNamePrefix: String): Seq[String] = { - val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToValueType) - allStateStoreTypes.map(getStateStoreName(storeNamePrefix, _)) + private def getStateStoreName(joinSide: JoinSide, storeType: StateStoreType): String = { + s"$joinSide-$storeType" } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala deleted file mode 100644 index 5bb349ee38336..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithMultiValuesStateIteratorSuite.scala +++ /dev/null @@ -1,292 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.streaming - -import java.util.UUID - -import org.apache.hadoop.conf.Configuration - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.execution.streaming.state.{MultiValuesStateManager, MultiValuesStateManagerInjectable, StateStore, StateStoreConf, StreamingSessionStateManager} -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -class MergingSortWithMultiValuesStateIteratorSuite extends SharedSQLContext { - - val rowSchema = new StructType().add("key1", StringType).add("key2", IntegerType) - .add("session", new StructType().add("start", LongType).add("end", LongType)) - .add("aggVal1", LongType).add("aggVal2", DoubleType) - val rowAttributes = rowSchema.toAttributes - - val keysWithoutSessionSchema = rowSchema.filter(st => List("key1", "key2").contains(st.name)) - val keysWithoutSessionAttributes = rowAttributes.filter { - attr => List("key1", "key2").contains(attr.name) - } - - val sessionSchema = rowSchema.filter(st => st.name == "session").head - val sessionAttribute = rowAttributes.filter(attr => attr.name == "session").head - - val valuesSchema = rowSchema.filter(st => List("aggVal1", "aggVal2").contains(st.name)) - val valuesAttributes = rowAttributes.filter { - attr => List("aggVal1", "aggVal2").contains(attr.name) - } - - // TODO: would we want to randomize or test all? 
- val stateStoreVersion = StreamingSessionStateManager.supportedVersions.last - - test("no row in input data") { - withStreamingSessionStateManager(rowAttributes, keysWithoutSessionAttributes, - stateStoreVersion) { manager => - val iterator = new MergingSortWithMultiValuesStateIterator(None.iterator, - manager, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) - - assert(!iterator.hasNext) - } - } - - test("no row in input data but having state") { - withStreamingSessionStateManager(rowAttributes, keysWithoutSessionAttributes, - stateStoreVersion) { manager => - val srow11 = createRow("a", 1, 55, 85, 50, 2.5) - val srow12 = createRow("a", 1, 105, 140, 30, 2.0) - appendRowToStateManager(manager, srow11, srow12) - - val iterator = new MergingSortWithMultiValuesStateIterator(None.iterator, - manager, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) - - assert(!iterator.hasNext) - } - } - - test("no previous state") { - withStreamingSessionStateManager(rowAttributes, keysWithoutSessionAttributes, - stateStoreVersion) { manager => - val row1 = createRow("a", 1, 100, 110, 10, 1.1) - val row2 = createRow("a", 1, 100, 110, 20, 1.2) - val row3 = createRow("a", 2, 110, 120, 10, 1.1) - val row4 = createRow("a", 2, 115, 125, 20, 1.2) - val rows = List(row1, row2, row3, row4) - - val iterator = new MergingSortWithMultiValuesStateIterator(rows.iterator, - manager, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) - - rows.foreach { row => - assert(iterator.hasNext) - assertRowsEquals(row, iterator.next()) - } - - assert(!iterator.hasNext) - } - } - - test("multiple keys in input data and state") { - withStreamingSessionStateManager(rowAttributes, keysWithoutSessionAttributes, - stateStoreVersion) { manager => - // key 1 - placing sessions in state to start and end - val row11 = createRow("a", 1, 100, 110, 10, 1.1) - val row12 = createRow("a", 1, 100, 110, 20, 1.2) - - val srow11 = createRow("a", 1, 55, 85, 50, 2.5) - val srow12 = 
createRow("a", 1, 105, 140, 30, 2.0) - appendRowToStateManager(manager, srow11, srow12) - - // key 2 - no state - val row21 = createRow("a", 2, 110, 120, 10, 1.1) - val row22 = createRow("a", 2, 115, 125, 20, 1.2) - - // key 3 - placing sessions in state to only start - val row31 = createRow("a", 3, 130, 140, 10, 1.1) - val row32 = createRow("a", 3, 135, 145, 20, 1.2) - - val srow31 = createRow("a", 3, 105, 140, 30, 2.0) - val srow32 = createRow("a", 3, 120, 150, 30, 2.0) - appendRowToStateManager(manager, srow31, srow32) - - // key 4 - placing sessions in state to only end - val row41 = createRow("a", 4, 100, 110, 10, 1.1) - val row42 = createRow("a", 4, 100, 115, 20, 1.2) - - val srow41 = createRow("a", 4, 105, 140, 30, 2.0) - val srow42 = createRow("a", 4, 120, 150, 30, 2.0) - appendRowToStateManager(manager, srow41, srow42) - - // key 5 - placing sessions in state like one row and state session and another - val row51 = createRow("a", 5, 100, 110, 10, 1.1) - val row52 = createRow("a", 5, 120, 130, 20, 1.2) - - val srow51 = createRow("a", 5, 90, 120, 30, 2.0) - val srow52 = createRow("a", 5, 110, 125, 30, 2.0) - val srow53 = createRow("a", 5, 130, 150, 30, 2.0) - appendRowToStateManager(manager, srow51, srow52, srow53) - - val rows = List(row11, row12, row21, row22, row31, row32, row41, row42, row51, row52) - - val expectedRowSequence = List(srow11, row11, row12, srow12, row21, row22, srow31, srow32, - row31, row32, row41, row42, srow41, srow42, srow51, row51, srow52, row52, srow53) - - val iterator = new MergingSortWithMultiValuesStateIterator(rows.iterator, - manager, keysWithoutSessionAttributes, sessionAttribute, rowAttributes) - - expectedRowSequence.foreach { row => - assert(iterator.hasNext) - assertRowsEquals(row, iterator.next()) - } - - assert(!iterator.hasNext) - } - } - - test("no keys in input data and state") { - val noKeyRowSchema = new StructType() - .add("session", new StructType().add("start", LongType).add("end", LongType)) - .add("aggVal1", 
LongType).add("aggVal2", DoubleType) - val noKeyRowAttributes = noKeyRowSchema.toAttributes - - val noKeySessionAttribute = noKeyRowAttributes.filter(attr => attr.name == "session").head - - def createNoKeyRow(sessionStart: Long, sessionEnd: Long, - aggVal1: Long, aggVal2: Double): UnsafeRow = { - val genericRow = new GenericInternalRow(4) - val session: Array[Any] = new Array[Any](2) - session(0) = sessionStart - session(1) = sessionEnd - - val sessionRow = new GenericInternalRow(session) - genericRow.update(0, sessionRow) - - genericRow.setLong(1, aggVal1) - genericRow.setDouble(2, aggVal2) - - val rowProjection = GenerateUnsafeProjection.generate(noKeyRowAttributes, noKeyRowAttributes) - rowProjection(genericRow) - } - - def assertNoKeyRowsEquals(expectedRow: InternalRow, retRow: InternalRow): Unit = { - assert(retRow.getStruct(0, 2).getLong(0) == expectedRow.getStruct(0, 2).getLong(0)) - assert(retRow.getStruct(0, 2).getLong(1) == expectedRow.getStruct(0, 2).getLong(1)) - assert(retRow.getLong(1) === expectedRow.getLong(1)) - assert(doubleEquals(retRow.getDouble(2), expectedRow.getDouble(2))) - } - - def appendNoKeyRowToStateManager(manager: StreamingSessionStateManager, rows: UnsafeRow*) - : Unit = { - rows.foreach(manager.append) - } - - withStreamingSessionStateManager(noKeyRowAttributes, Seq.empty[Attribute], - stateStoreVersion) { manager => - // only input data - val row1 = createNoKeyRow(100, 110, 10, 1.1) - val row2 = createNoKeyRow(100, 110, 20, 1.2) - - val srow1 = createNoKeyRow(55, 85, 50, 2.5) - val srow2 = createNoKeyRow(105, 140, 30, 2.0) - appendNoKeyRowToStateManager(manager, srow1, srow2) - - val rows = List(row1, row2) - - val expectedRowSequence = List(srow1, row1, row2, srow2) - - val iterator = new MergingSortWithMultiValuesStateIterator(rows.iterator, - manager, Seq.empty[Attribute], noKeySessionAttribute, noKeyRowAttributes) - - expectedRowSequence.foreach { row => - assert(iterator.hasNext) - assertNoKeyRowsEquals(row, iterator.next()) 
- } - - assert(!iterator.hasNext) - } - } - - private def createRow(key1: String, key2: Int, sessionStart: Long, sessionEnd: Long, - aggVal1: Long, aggVal2: Double): UnsafeRow = { - val genericRow = new GenericInternalRow(6) - if (key1 != null) { - genericRow.update(0, UTF8String.fromString(key1)) - } else { - genericRow.setNullAt(0) - } - genericRow.setInt(1, key2) - - val session: Array[Any] = new Array[Any](2) - session(0) = sessionStart - session(1) = sessionEnd - - val sessionRow = new GenericInternalRow(session) - genericRow.update(2, sessionRow) - - genericRow.setLong(3, aggVal1) - genericRow.setDouble(4, aggVal2) - - val rowProjection = GenerateUnsafeProjection.generate(rowAttributes, rowAttributes) - rowProjection(genericRow) - } - - private def appendRowToStateManager(manager: StreamingSessionStateManager, rows: UnsafeRow*) - : Unit = { - rows.foreach(row => manager.append(row)) - } - - private def doubleEquals(value1: Double, value2: Double): Boolean = { - value1 > value2 - 0.000001 && value1 < value2 + 0.000001 - } - - private def assertRowsEquals(expectedRow: InternalRow, retRow: InternalRow): Unit = { - assert(retRow.getString(0) === expectedRow.getString(0)) - assert(retRow.getInt(1) === expectedRow.getInt(1)) - assert(retRow.getStruct(2, 2).getLong(0) == expectedRow.getStruct(2, 2).getLong(0)) - assert(retRow.getStruct(2, 2).getLong(1) == expectedRow.getStruct(2, 2).getLong(1)) - assert(retRow.getLong(3) === expectedRow.getLong(3)) - assert(doubleEquals(retRow.getDouble(3), expectedRow.getDouble(3))) - } - - private def withStreamingSessionStateManager( - inputValueAttribs: Seq[Attribute], - keyAttribs: Seq[Attribute], - stateVersion: Int)(f: StreamingSessionStateManager => Unit): Unit = { - - withTempDir { file => - val storeConf = new StateStoreConf() - val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5) - - val manager = StreamingSessionStateManager.createStateManager( - keyAttribs, inputValueAttribs, None, 
stateVersion) - - manager match { - case mvState: MultiValuesStateManagerInjectable => - val store = new MultiValuesStateManager("session-", inputValueAttribs, keyAttribs, - Some(stateInfo), storeConf, new Configuration) - mvState.setMultiValuesStateManager(store) - - try { - f(manager) - } finally { - - store.abortIfNeeded() - } - - case _ => throw new IllegalStateException("Should inject matching underlying state store!") - } - } - StateStore.stop() - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala similarity index 81% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManagerSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala index 7843dddee84e0..f3c334040c2ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MultiValuesStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types._ -class MultiValuesStateManagerSuite extends StreamTest with BeforeAndAfter { +class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter { before { SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec' @@ -39,8 +39,8 @@ class MultiValuesStateManagerSuite extends StreamTest with BeforeAndAfter { } - test("MultiValuesStateManager - all operations") { - withStateManager(inputValueAttribs, keyExprs) { manager => + test("SymmetricHashJoinStateManager - all operations") { + withJoinStateManager(inputValueAttribs, joinKeyExprs) { 
manager => implicit val mgr = manager assert(get(20) === Seq.empty) // initially empty @@ -106,10 +106,10 @@ class MultiValuesStateManagerSuite extends StreamTest with BeforeAndAfter { .add(StructField("value", BooleanType)) val inputValueAttribs = inputValueSchema.toAttributes val inputValueAttribWithWatermark = inputValueAttribs(0) - val keyExprs = Seq[Expression](Literal(false), inputValueAttribWithWatermark, Literal(10.0)) + val joinKeyExprs = Seq[Expression](Literal(false), inputValueAttribWithWatermark, Literal(10.0)) val inputValueGen = UnsafeProjection.create(inputValueAttribs.map(_.dataType).toArray) - val keyGen = UnsafeProjection.create(keyExprs.map(_.dataType).toArray) + val joinKeyGen = UnsafeProjection.create(joinKeyExprs.map(_.dataType).toArray) def toInputValue(i: Int): UnsafeRow = { @@ -117,21 +117,21 @@ class MultiValuesStateManagerSuite extends StreamTest with BeforeAndAfter { } def toJoinKeyRow(i: Int): UnsafeRow = { - keyGen.apply(new GenericInternalRow(Array[Any](false, i, 10.0))) + joinKeyGen.apply(new GenericInternalRow(Array[Any](false, i, 10.0))) } def toValueInt(inputValueRow: UnsafeRow): Int = inputValueRow.getInt(0) - def append(key: Int, value: Int)(implicit manager: MultiValuesStateManager): Unit = { + def append(key: Int, value: Int)(implicit manager: SymmetricHashJoinStateManager): Unit = { manager.append(toJoinKeyRow(key), toInputValue(value)) } - def get(key: Int)(implicit manager: MultiValuesStateManager): Seq[Int] = { + def get(key: Int)(implicit manager: SymmetricHashJoinStateManager): Seq[Int] = { manager.get(toJoinKeyRow(key)).map(toValueInt).toSeq.sorted } /** Remove keys (and corresponding values) where `time <= threshold` */ - def removeByKey(threshold: Long)(implicit manager: MultiValuesStateManager): Unit = { + def removeByKey(threshold: Long)(implicit manager: SymmetricHashJoinStateManager): Unit = { val expr = LessThanOrEqual( BoundReference( @@ -142,26 +142,27 @@ class MultiValuesStateManagerSuite extends StreamTest 
with BeforeAndAfter { } /** Remove values where `time <= threshold` */ - def removeByValue(watermark: Long)(implicit manager: MultiValuesStateManager): Unit = { + def removeByValue(watermark: Long)(implicit manager: SymmetricHashJoinStateManager): Unit = { val expr = LessThanOrEqual(inputValueAttribWithWatermark, Literal(watermark)) val iter = manager.removeByValueCondition( GeneratePredicate.generate(expr, inputValueAttribs).eval _) while (iter.hasNext) iter.next() } - def numRows(implicit manager: MultiValuesStateManager): Long = { + def numRows(implicit manager: SymmetricHashJoinStateManager): Long = { manager.metrics.numKeys } - def withStateManager( - inputValueAttribs: Seq[Attribute], - keyExprs: Seq[Expression])(f: MultiValuesStateManager => Unit): Unit = { + + def withJoinStateManager( + inputValueAttribs: Seq[Attribute], + joinKeyExprs: Seq[Expression])(f: SymmetricHashJoinStateManager => Unit): Unit = { withTempDir { file => val storeConf = new StateStoreConf() val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5) - val manager = new MultiValuesStateManager(LeftSide.toString(), inputValueAttribs, keyExprs, - Some(stateInfo), storeConf, new Configuration) + val manager = new SymmetricHashJoinStateManager( + LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration) try { f(manager) } finally { @@ -170,4 +171,4 @@ class MultiValuesStateManagerSuite extends StreamTest with BeforeAndAfter { } StateStore.stop() } -} +} \ No newline at end of file From b6ccecdfa3a3d31305667541b8d8fd761e5d3aee Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 2 Nov 2018 14:12:56 +0900 Subject: [PATCH 59/60] WIP Apply removing codegen to UpdatingSessionIterator as well --- .../sql/execution/aggregate/AggUtils.scala | 2 +- .../aggregate/UpdatingSessionIterator.scala | 133 ++++++++++-------- .../UpdatingSessionIteratorSuite.scala | 26 ++-- .../SymmetricHashJoinStateManagerSuite.scala | 2 +- 4 files changed, 90 
insertions(+), 73 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 6042fc64a466a..ae4da8d6b1ee8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -490,7 +490,7 @@ object AggUtils { val childOrdering = Seq((groupingWithoutSessionAttributes ++ Seq(sessionExpression)) .map(SortOrder(_, Ascending))) val updatedSession = UpdatingSessionExec( - groupingWithoutSessionAttributes, + groupingExpressions.map(_.toAttribute), sessionExpression.toAttribute, optRequiredChildDistribution = Some(childDistribution), optRequiredChildOrdering = Some(childOrdering), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala index 00b380b320749..ab54fa2220cbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionIterator.scala @@ -28,24 +28,28 @@ import org.apache.spark.sql.types.{LongType, TimestampType} // FIXME: javadoc!! 
class UpdatingSessionIterator( iter: Iterator[InternalRow], - groupWithoutSessionExpressions: Seq[Expression], - sessionExpression: Expression, + groupingExpressions: Seq[NamedExpression], + sessionExpression: NamedExpression, inputSchema: Seq[Attribute], inMemoryThreshold: Int, spillThreshold: Int) extends Iterator[InternalRow] { val sessionIndex = inputSchema.indexOf(sessionExpression) - val valuesExpressions: Seq[Attribute] = inputSchema.diff(groupWithoutSessionExpressions) - .diff(Seq(sessionExpression)) + private val groupingWithoutSession: Seq[NamedExpression] = + groupingExpressions.diff(Seq(sessionExpression)) + private val groupingWithoutSessionAttributes: Seq[Attribute] = + groupingWithoutSession.map(_.toAttribute) + private[this] val groupingWithoutSessionProjection: UnsafeProjection = + UnsafeProjection.create(groupingWithoutSession, inputSchema) - val keysProjection = GenerateUnsafeProjection.generate(groupWithoutSessionExpressions, - inputSchema) - val sessionProjection = GenerateUnsafeProjection.generate(Seq(sessionExpression), inputSchema) + val valuesExpressions: Seq[Attribute] = inputSchema.diff(groupingWithoutSession) + + private[this] val sessionProjection: UnsafeProjection = + UnsafeProjection.create(Seq(sessionExpression), inputSchema) var currentKeys: InternalRow = _ - var currentSessionStart: Long = Long.MaxValue - var currentSessionEnd: Long = Long.MinValue + var currentSession: UnsafeRow = _ var currentRows: ExternalAppendOnlyUnsafeRowArray = new ExternalAppendOnlyUnsafeRowArray( inMemoryThreshold, spillThreshold) @@ -85,29 +89,29 @@ class UpdatingSessionIterator( // without this, multiple rows in same key will be returned with same content val row = iter.next().copy() - val keys = keysProjection(row).copy() - val session = sessionProjection(row).copy() - val sessionRow = session.getStruct(0, 2) - val sessionStart = sessionRow.getLong(0) - val sessionEnd = sessionRow.getLong(1) + val keys = groupingWithoutSessionProjection(row) + val 
session = sessionProjection(row) + val sessionStruct = session.getStruct(0, 2) + val sessionStart = getSessionStart(sessionStruct) + val sessionEnd = getSessionEnd(sessionStruct) if (currentKeys == null) { - startNewSession(row, keys, sessionStart, sessionEnd) + startNewSession(row, keys, sessionStruct) } else if (keys != currentKeys) { closeCurrentSession(keyChanged = true) processedKeys.add(currentKeys) - startNewSession(row, keys, sessionStart, sessionEnd) + startNewSession(row, keys, sessionStruct) exitCondition = true } else { - if (sessionStart < currentSessionStart) { + if (sessionStart < getSessionStart(currentSession)) { handleBrokenPreconditionForSort() - } else if (sessionStart <= currentSessionEnd) { + } else if (sessionStart <= getSessionEnd(currentSession)) { // expanding session length if needed expandEndOfCurrentSession(sessionEnd) currentRows.add(row.asInstanceOf[UnsafeRow]) } else { closeCurrentSession(keyChanged = false) - startNewSession(row, keys, sessionStart, sessionEnd) + startNewSession(row, keys, sessionStruct) exitCondition = true } } @@ -124,24 +128,35 @@ class UpdatingSessionIterator( returnRowsIter.next() } - private def expandEndOfCurrentSession(sessionEnd: Long): Unit = { - if (sessionEnd > currentSessionEnd) { - currentSessionEnd = sessionEnd - } - } - - private def startNewSession(row: InternalRow, keys: UnsafeRow, sessionStart: Long, - sessionEnd: Long): Unit = { - if (processedKeys.contains(keys)) { + private def startNewSession(currentRow: InternalRow, groupingKey: UnsafeRow, + sessionStruct: UnsafeRow): Unit = { + if (processedKeys.contains(groupingKey)) { handleBrokenPreconditionForSort() } - currentKeys = keys - currentSessionStart = sessionStart - currentSessionEnd = sessionEnd + currentKeys = groupingKey.copy() + currentSession = sessionStruct.copy() currentRows.clear() - currentRows.add(row.asInstanceOf[UnsafeRow]) + currentRows.add(currentRow.asInstanceOf[UnsafeRow]) + } + + private def getSessionStart(sessionStruct: 
UnsafeRow): Long = { + sessionStruct.getLong(0) + } + + private def getSessionEnd(sessionStruct: UnsafeRow): Long = { + sessionStruct.getLong(1) + } + + def updateSessionEnd(sessionStruct: UnsafeRow, sessionEnd: Long): Unit = { + sessionStruct.setLong(1, sessionEnd) + } + + private def expandEndOfCurrentSession(sessionEnd: Long): Unit = { + if (sessionEnd > getSessionEnd(currentSession)) { + updateSessionEnd(currentSession, sessionEnd) + } } private def handleBrokenPreconditionForSort(): Unit = { @@ -149,38 +164,39 @@ class UpdatingSessionIterator( throw new IllegalStateException("The iterator must be sorted by key and session start!") } - private def closeCurrentSession(keyChanged: Boolean): Unit = { - // FIXME: Convert to JoinRow if possible to reduce codegen for unsafe projection - // FIXME: Same approach on MergingSessionsIterator.generateGroupingKey doesn't work here, why? - val sessionStruct = CreateNamedStruct( - Literal("start") :: - PreciseTimestampConversion( - Literal(currentSessionStart, LongType), LongType, TimestampType) :: - Literal("end") :: - PreciseTimestampConversion( - Literal(currentSessionEnd, LongType), LongType, TimestampType) :: - Nil) - - val convertedAllExpressions = inputSchema.map { x => - BindReferences.bindReference[Expression](x, inputSchema) - } + private def createSessionRow(): InternalRow = { + val sessionRow = new SpecificInternalRow(Seq(sessionExpression.toAttribute).toStructType) + sessionRow.update(0, currentSession) + sessionRow + } - val newSchemaExpressions = convertedAllExpressions.indices.map { idx => - if (idx == sessionIndex) { - sessionStruct - } else { - convertedAllExpressions(idx) - } - } + private val join = new JoinedRow + private val join2 = new JoinedRow + private val groupingKeyProj = GenerateUnsafeProjection.generate(groupingExpressions, + groupingWithoutSessionAttributes :+ sessionExpression.toAttribute) + private val valueProj = GenerateUnsafeProjection.generate(valuesExpressions, inputSchema) + private val 
restoreProj = GenerateUnsafeProjection.generate(inputSchema, + groupingExpressions.map(_.toAttribute) ++ valuesExpressions.map(_.toAttribute)) + + private def generateGroupingKey(): UnsafeRow = { + val newRow = new SpecificInternalRow(Seq(sessionExpression.toAttribute).toStructType) + newRow.update(0, currentSession) + val joined = join(currentKeys, newRow) + + groupingKeyProj(joined) + } + + private def closeCurrentSession(keyChanged: Boolean): Unit = { returnRows = currentRows currentRows = new ExternalAppendOnlyUnsafeRowArray( inMemoryThreshold, spillThreshold) + val groupingKey = generateGroupingKey() + val currentRowsIter = returnRows.generateIterator().map { internalRow => - // FIXME: is there any way to change this? - val proj = UnsafeProjection.create(newSchemaExpressions, inputSchema) - proj(internalRow) + val valueRow = valueProj(internalRow) + restoreProj(join2(groupingKey, valueRow)).copy() } if (returnRowsIter != null && returnRowsIter.hasNext) { @@ -192,8 +208,7 @@ class UpdatingSessionIterator( if (keyChanged) processedKeys.add(currentKeys) currentKeys = null - currentSessionStart = Long.MaxValue - currentSessionEnd = Long.MinValue + currentSession = null } private def assertIteratorNotCorrupted(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala index e941002a31189..ff6d774200f48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionIteratorSuite.scala @@ -36,9 +36,11 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { .add("aggVal1", LongType).add("aggVal2", DoubleType) val rowAttributes = rowSchema.toAttributes - val keysWithoutSessionSchema = rowSchema.filter(st => List("key1", "key2").contains(st.name)) - val 
keysWithoutSessionAttributes = rowAttributes.filter { - attr => List("key1", "key2").contains(attr.name) + val keysWithSessionSchema = rowSchema.filter { attr => + List("key1", "key2", "session").contains(attr.name) + } + val keysWithSessionAttributes = rowAttributes.filter { attr => + List("key1", "key2", "session").contains(attr.name) } val sessionSchema = rowSchema.filter(st => st.name == "session").head @@ -67,7 +69,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val spillThreshold = Int.MaxValue test("no row") { - val iterator = new UpdatingSessionIterator(None.iterator, keysWithoutSessionAttributes, + val iterator = new UpdatingSessionIterator(None.iterator, keysWithSessionAttributes, sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) assert(!iterator.hasNext) @@ -76,7 +78,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { test("only one row") { val rows = List(createRow("a", 1, 100, 110, 10, 1.1)) - val iterator = new UpdatingSessionIterator(rows.iterator, keysWithoutSessionAttributes, + val iterator = new UpdatingSessionIterator(rows.iterator, keysWithSessionAttributes, sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) assert(iterator.hasNext) @@ -94,7 +96,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val row4 = createRow("a", 1, 113, 123, 40, 1.4) val rows = List(row1, row2, row3, row4) - val iterator = new UpdatingSessionIterator(rows.iterator, keysWithoutSessionAttributes, + val iterator = new UpdatingSessionIterator(rows.iterator, keysWithSessionAttributes, sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) val retRows = rows.indices.map { _ => @@ -125,7 +127,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val rowsAll = rows1 ++ rows2 - val iterator = new UpdatingSessionIterator(rowsAll.iterator, keysWithoutSessionAttributes, + val iterator = new UpdatingSessionIterator(rowsAll.iterator, keysWithSessionAttributes, 
sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) val retRows1 = rows1.indices.map { _ => @@ -161,7 +163,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val rowsAll = rows1 ++ rows2 - val iterator = new UpdatingSessionIterator(rowsAll.iterator, keysWithoutSessionAttributes, + val iterator = new UpdatingSessionIterator(rowsAll.iterator, keysWithSessionAttributes, sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) val retRows1 = rows1.indices.map { _ => @@ -206,7 +208,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val rowsAll = rows1 ++ rows2 ++ rows3 ++ rows4 - val iterator = new UpdatingSessionIterator(rowsAll.iterator, keysWithoutSessionAttributes, + val iterator = new UpdatingSessionIterator(rowsAll.iterator, keysWithSessionAttributes, sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) val retRows1 = rows1.indices.map { _ => @@ -259,7 +261,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val row4 = createRow("a", 1, 113, 123, 40, 1.4) val rows = List(row1, row2, row3, row4) - val iterator = new UpdatingSessionIterator(rows.iterator, keysWithoutSessionAttributes, + val iterator = new UpdatingSessionIterator(rows.iterator, keysWithSessionAttributes, sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) // UpdatingSessionIterator can't detect error on hasNext @@ -286,7 +288,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val row3 = createRow("a", 1, 113, 123, 40, 1.4) val rows = List(row1, row2, row3) - val iterator = new UpdatingSessionIterator(rows.iterator, keysWithoutSessionAttributes, + val iterator = new UpdatingSessionIterator(rows.iterator, keysWithSessionAttributes, sessionAttribute, rowAttributes, inMemoryThreshold, spillThreshold) // UpdatingSessionIterator can't detect error on hasNext @@ -352,7 +354,7 @@ class UpdatingSessionIteratorSuite extends SharedSQLContext { val row4 = createNoKeyRow(113, 123, 40, 1.4) val 
rows = List(row1, row2, row3, row4) - val iterator = new UpdatingSessionIterator(rows.iterator, Seq.empty[Attribute], + val iterator = new UpdatingSessionIterator(rows.iterator, Seq(noKeySessionAttribute), noKeySessionAttribute, noKeyRowAttributes, inMemoryThreshold, spillThreshold) val retRows = rows.indices.map { _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala index f3c334040c2ab..c0216a2ef3e61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala @@ -171,4 +171,4 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter } StateStore.stop() } -} \ No newline at end of file +} From 75c7611996ddadfefed12655b6e039646ff6f3d4 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 2 Nov 2018 15:35:02 +0900 Subject: [PATCH 60/60] WIP remove state version for now: it will be reintroduced when actual review is in progress --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 11 ----------- .../apache/spark/sql/execution/SparkStrategies.scala | 2 -- .../spark/sql/execution/aggregate/AggUtils.scala | 4 +--- .../execution/streaming/IncrementalExecution.scala | 6 ++---- .../spark/sql/execution/streaming/OffsetSeq.scala | 5 ++--- .../sql/execution/streaming/statefulOperators.scala | 6 +----- 6 files changed, 6 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index db898a99e4630..88c513a3e46bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -924,17 +924,6 @@ object SQLConf { .booleanConf .createWithDefault(false) - val STREAMING_SESSION_WINDOW_STATE_FORMAT_VERSION = - buildConf("spark.sql.streaming.sessionWindow.stateFormatVersion") - .internal() - .doc("State format version used by streaming session window aggregation operations " + - "in a streaming query. " + - "State between versions are tend to be incompatible, so state format version shouldn't " + - "be modified after running.") - .intConf - .checkValue(v => Set(1).contains(v), "Valid versions are 1") - .createWithDefault(1) - val UNSUPPORTED_OPERATION_CHECK_ENABLED = buildConf("spark.sql.streaming.unsupportedOperationCheck") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 3a69cc5e54a09..2cc8ddaf99799 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -335,7 +335,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { sessionWindowOption match { case Some(sessionWindow) => - val stateVersion = conf.getConf(SQLConf.STREAMING_SESSION_WINDOW_STATE_FORMAT_VERSION) aggregate.AggUtils.planStreamingAggregationForSession( namedGroupingExpressions, @@ -343,7 +342,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), rewrittenResultExpressions, conf.streamingSessionWindowMergeSessionInLocalPartition, - stateVersion, planLater(child)) case None => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index ae4da8d6b1ee8..faca078041efb 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -371,7 +371,6 @@ object AggUtils { functionsWithoutDistinct: Seq[AggregateExpression], resultExpressions: Seq[NamedExpression], mergeSessionsInLocalPartition: Boolean, - stateFormatVersion: Int, child: SparkPlan): Seq[SparkPlan] = { val groupWithoutSessionExpression = groupingExpressions.filterNot { p => @@ -422,7 +421,7 @@ object AggUtils { // shuffle & sort happens here: most of details are also handled in this physical plan val restored = SessionWindowStateStoreRestoreExec(groupingWithoutSessionAttributes, sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None, - stateFormatVersion, partialMerged1) + partialMerged1) val mergedSessions = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) @@ -450,7 +449,6 @@ object AggUtils { stateInfo = None, outputMode = None, eventTimeWatermark = None, - stateFormatVersion, mergedSessions) val finalAndCompleteAggregate: SparkPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 10ec0daeb24ea..3a55cfab5dded 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -119,9 +119,9 @@ class IncrementalExecution( stateFormatVersion, child) :: Nil)) - case SessionWindowStateStoreSaveExec(keys, session, None, None, None, stateFormatVersion, + case SessionWindowStateStoreSaveExec(keys, session, None, None, None, UnaryExecNode(agg, - SessionWindowStateStoreRestoreExec(_, _, None, None, _, child))) => + SessionWindowStateStoreRestoreExec(_, _, None, None, child))) => val aggStateInfo = 
nextStatefulOperationStateInfo SessionWindowStateStoreSaveExec( keys, @@ -129,14 +129,12 @@ class IncrementalExecution( Some(aggStateInfo), Some(outputMode), Some(offsetSeqMetadata.batchWatermarkMs), - stateFormatVersion, agg.withNewChildren( SessionWindowStateStoreRestoreExec( keys, session, Some(aggStateInfo), Some(offsetSeqMetadata.batchWatermarkMs), - stateFormatVersion, child) :: Nil)) case StreamingDeduplicateExec(keys, child, None, None) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 7e2f7d95353ff..73cf355dbe758 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -22,7 +22,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig -import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager, StreamingSessionStateManager} +import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager} import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _} /** @@ -89,8 +89,7 @@ object OffsetSeqMetadata extends Logging { private implicit val format = Serialization.formats(NoTypeHints) private val relevantSQLConfs = Seq( SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY, - FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION, - STREAMING_SESSION_WINDOW_STATE_FORMAT_VERSION) + FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION) /** * Default values of relevant configurations that are used for backward compatibility. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index d07405f32dfbb..e0a98bb064ac6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -431,7 +431,6 @@ case class SessionWindowStateStoreRestoreExec( sessionExpression: Attribute, stateInfo: Option[StatefulOperatorStateInfo], eventTimeWatermark: Option[Long], - stateFormatVersion: Int, child: SparkPlan) extends UnaryExecNode with StateStoreReader with WatermarkSupport { @@ -496,7 +495,7 @@ case class SessionWindowStateStoreRestoreExec( /** * For each input tuple, the key is calculated and sessions are being `put` into - * the [[MultiValuesStateManager]]. + * the [[SessionWindowLinkedListState]]. */ case class SessionWindowStateStoreSaveExec( keyWithoutSessionExpressions: Seq[Attribute], @@ -504,7 +503,6 @@ case class SessionWindowStateStoreSaveExec( stateInfo: Option[StatefulOperatorStateInfo] = None, outputMode: Option[OutputMode] = None, eventTimeWatermark: Option[Long] = None, - stateFormatVersion: Int, child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { @@ -535,8 +533,6 @@ case class SessionWindowStateStoreSaveExec( val keyOrdering = TypeUtils.getInterpretedOrdering(keyExpressions.toStructType) .asInstanceOf[Ordering[UnsafeRow]] - val valueOrdering = TypeUtils.getInterpretedOrdering(child.output.toStructType) - .asInstanceOf[Ordering[UnsafeRow]] var lastSearchedSessionStartOption: Option[Long] = None var stateFetchedKey: UnsafeRow = null