diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b54ae9082f840..567375cbd1283 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -528,6 +528,11 @@ object FunctionRegistry { expression[HllSketchAgg]("hll_sketch_agg"), expression[HllUnionAgg]("hll_union_agg"), expression[ApproxTopK]("approx_top_k"), + expression[ThetaSketchAgg]("theta_sketch_agg"), + expression[ThetaUnionAgg]("theta_union_agg"), + expression[ThetaIntersectionAgg]("theta_intersection_agg"), + + // string functions expression[Ascii]("ascii"), @@ -786,6 +791,10 @@ object FunctionRegistry { expression[EqualNull]("equal_null"), expression[HllSketchEstimate]("hll_sketch_estimate"), expression[HllUnion]("hll_union"), + expression[ThetaSketchEstimate]("theta_sketch_estimate"), + expression[ThetaUnion]("theta_union"), + expression[ThetaDifference]("theta_difference"), + expression[ThetaIntersection]("theta_intersection"), // grouping sets expression[Grouping]("grouping"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/thetasketchesAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/thetasketchesAggregates.scala new file mode 100644 index 0000000000000..a6df1bd58e205 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/thetasketchesAggregates.scala @@ -0,0 +1,668 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.datasketches.common.SketchesArgumentException +import org.apache.datasketches.memory.Memory +import org.apache.datasketches.theta.{CompactSketch, Intersection, SetOperation, Union, UpdateSketch, UpdateSketchBuilder} +import org.apache.datasketches.thetacommon.ThetaUtil + +import org.apache.spark.SparkUnsupportedOperationException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate +import org.apache.spark.sql.catalyst.trees.BinaryLike +import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.internal.types.StringTypeWithCollation +import org.apache.spark.sql.types.{AbstractDataType, ArrayType, BinaryType, DataType, DoubleType, FloatType, IntegerType, LongType, StringType, TypeCollection} +import org.apache.spark.unsafe.types.UTF8String + +sealed trait ThetaSketchState +case class UpdatableSketchBuffer(sketch: UpdateSketch) extends ThetaSketchState +case class FinalizedSketch(sketch: CompactSketch) extends ThetaSketchState +case class UnionAggregationBuffer(sketch: Union) extends ThetaSketchState +case class 
IntersectionAggregationBuffer(sketch: Intersection) extends ThetaSketchState + +/** + * The ThetaSketchAgg function utilizes a Datasketches ThetaSketch instance to count a + * probabilistic approximation of the number of unique values in a given column, and outputs the + * binary representation of the ThetaSketch. + * + * See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html]] for more information. + * + * @param left + * child expression against which unique counting will occur + * @param right + * the log-base-2 of nomEntries decides the number of buckets for the sketch + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(expr, lgNomEntries) - Returns the ThetaSketch's compact binary representation. + `lgNomEntries` (optional) the log-base-2 of Nominal Entries, with Nominal Entries deciding + the number buckets or slots for the ThetaSketch. """, + examples = """ + Examples: + > SELECT theta_sketch_estimate(_FUNC_(col, 12)) FROM VALUES (1), (1), (2), (2), (3) tab(col); + 3 + """, + group = "agg_funcs", + since = "4.0.0") +// scalastyle:on line.size.limit +case class ThetaSketchAgg( + left: Expression, + right: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends TypedImperativeAggregate[ThetaSketchState] + with BinaryLike[Expression] + with ExpectsInputTypes { + + // ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation. 
+
+  /**
+   * The validated log-base-2 of the sketch's nominal entries, obtained by evaluating `right`
+   * without an input row.
+   *
+   * The exponent is range-checked before it is used as a shift count: `1 << n` only honours the
+   * low five bits of `n`, so an out-of-range value such as 36 (or -28) would otherwise wrap to
+   * `1 << 4` and be silently accepted as 4.
+   *
+   * @throws SketchesArgumentException
+   *   if the value is outside [ThetaUtil.MIN_LG_NOM_LONGS, ThetaUtil.MAX_LG_NOM_LONGS]
+   */
+  lazy val lgNomEntries: Int = {
+    val lgNomEntriesInput = right.eval().asInstanceOf[Int]
+    if (lgNomEntriesInput < ThetaUtil.MIN_LG_NOM_LONGS ||
+        lgNomEntriesInput > ThetaUtil.MAX_LG_NOM_LONGS) {
+      // Same exception type checkNomLongs raises, so existing callers see no new error class.
+      throw new SketchesArgumentException(
+        s"lgNomEntries must be between ${ThetaUtil.MIN_LG_NOM_LONGS} and " +
+          s"${ThetaUtil.MAX_LG_NOM_LONGS}, inclusive: $lgNomEntriesInput")
+    }
+    // The shift is now known not to wrap; checkNomLongs converts the size back to its log2.
+    ThetaUtil.checkNomLongs(1 << lgNomEntriesInput)
+  }
+
+  // Constructors
+
+  /** Uses the lgNomEntries derived from ThetaUtil.DEFAULT_NOMINAL_ENTRIES. */
+  def this(child: Expression) = {
+    this(child, Literal(ThetaUtil.checkNomLongs(ThetaUtil.DEFAULT_NOMINAL_ENTRIES)), 0, 0)
+  }
+
+  /** Takes lgNomEntries as an expression; both buffer offsets start at 0. */
+  def this(child: Expression, lgNomEntries: Expression) = {
+    this(child, lgNomEntries, 0, 0)
+  }
+
+  /** Takes lgNomEntries as a plain Int and wraps it in a Literal. */
+  def this(child: Expression, lgNomEntries: Int) = {
+    this(child, Literal(lgNomEntries), 0, 0)
+  }
+
+  // Copy constructors required by ImperativeAggregate
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ThetaSketchAgg =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ThetaSketchAgg =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): ThetaSketchAgg =
+    copy(left = newLeft, right = newRight)
+
+  // Overrides for TypedImperativeAggregate
+
+  override def prettyName: String = "theta_sketch_agg"
+
+  // First argument: the value to sketch; second argument: lgNomEntries.
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(
+      TypeCollection(
+        IntegerType,
+        LongType,
+        FloatType,
+        DoubleType,
+        StringTypeWithCollation(supportsTrimCollation = true),
+        BinaryType,
+        ArrayType(IntegerType),
+        ArrayType(LongType)),
+      IntegerType)
+
+  override def dataType: DataType = BinaryType
+
+  override def nullable: Boolean = false
+
+  /**
+   * Instantiate an UpdateSketch instance using the lgNomEntries param.
+   *
+   * @return
+   *   an UpdateSketch instance wrapped with UpdatableSketchBuffer
+   */
+  override def createAggregationBuffer(): ThetaSketchState = {
+    val builder = new UpdateSketchBuilder
+    builder.setLogNominalEntries(lgNomEntries)
+    UpdatableSketchBuffer(builder.build)
+  }
+
+  /**
+   * Evaluate the input row and update the UpdateSketch instance with the row's value.
The update + * function only supports a subset of Spark SQL types, and an exception will be thrown for + * unsupported types. + * + * @param sketchState + * The UpdateSketch instance wrapped with ThetaSketchState. + * @param input + * an input row + */ + override def update(sketchState: ThetaSketchState, input: InternalRow): ThetaSketchState = + sketchState match { + case UpdatableSketchBuffer(sketch) => + val v = left.eval(input) + if (v != null) { + left.dataType match { + case IntegerType => + sketch.update(v.asInstanceOf[Int].toLong) // Promote to long + case LongType => + sketch.update(v.asInstanceOf[Long]) + case DoubleType => + sketch.update(v.asInstanceOf[Double]) + case FloatType => + sketch.update(v.asInstanceOf[Float].toDouble) // Promote to double + case st: StringType => + val cKey = + CollationFactory.getCollationKey(v.asInstanceOf[UTF8String], st.collationId) + sketch.update(cKey.toString) + case BinaryType => + val bytes = v.asInstanceOf[Array[Byte]] + if (bytes.nonEmpty) sketch.update(bytes) + case ArrayType(IntegerType, _) => + val arr = v.asInstanceOf[ArrayData].toIntArray() + if (arr.nonEmpty) sketch.update(arr) + case ArrayType(LongType, _) => + val arr = v.asInstanceOf[ArrayData].toLongArray() + if (arr.nonEmpty) sketch.update(arr) + case _ => + throw new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_3121", + messageParameters = Map("dataType" -> left.dataType.toString)) + } + } + UpdatableSketchBuffer(sketch) // Return updated sketch wrapped again + case _ => + sketchState + } + + /** + * Merges an input CompactSketch into the sketch which is acting as the aggregation buffer. + * + * @param sketchState + * the UpdateSketch instance used to store the aggregation result wrapped with + * ThetaSketchState. 
+ * @param input + * an input CompactSketch instance wrapped with ThetaSketchState + */ + override def merge( + sketchState: ThetaSketchState, + input: ThetaSketchState): ThetaSketchState = { + val union = SetOperation.builder + .setLogNominalEntries(lgNomEntries) + .buildUnion + + // match all the possible Theta Sketch states possible in this class + (sketchState, input) match { + case (UpdatableSketchBuffer(sketch1), UpdatableSketchBuffer(sketch2)) => + union.union(sketch1.compact) + union.union(sketch2.compact) + case (UpdatableSketchBuffer(sketch1), FinalizedSketch(sketch2)) => + union.union(sketch1.compact) + union.union(sketch2) + case (FinalizedSketch(sketch1), UpdatableSketchBuffer(sketch2)) => + union.union(sketch1) + union.union(sketch2.compact) + case (FinalizedSketch(sketch1), FinalizedSketch(sketch2)) => + union.union(sketch1) + union.union(sketch2) + // Should never make it to here + case _ => throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + // Compact the result before returning for a more efficient serialized representation + FinalizedSketch(union.getResult) + } + + /** + * Returns a CompactSketch derived from the input column or expression + * + * @param sketchState + * UpdateSketch instance used as an aggregation buffer or a CompactSketch result wrapped with + * ThetaSketchState + * @return + * A binary compact sketch which can be evaluated or merged + */ + override def eval(sketchState: ThetaSketchState): Any = { + sketchState match { + case UpdatableSketchBuffer(s) => s.rebuild.compact.toByteArray + case FinalizedSketch(s) => s.toByteArray + case _ => throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + } + + /** Convert the underlying UpdateSketch into an compact byte array */ + override def serialize(sketchState: ThetaSketchState): Array[Byte] = { + sketchState match { + case UpdatableSketchBuffer(s) => s.rebuild.compact.toByteArray + case FinalizedSketch(s) => s.toByteArray + case _ => throw 
QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + } + + /** De-serializes byte array into a Compact Sketch instance wrapped with FinalizedSketch */ + override def deserialize(buffer: Array[Byte]): ThetaSketchState = { + FinalizedSketch(CompactSketch.wrap(Memory.wrap(buffer))) + } +} + +/** + * The ThetaUnionAgg function ingests and merges Datasketches ThetaSketch instances previously + * produced by the ThetaSketchAgg function, and outputs the merged ThetaSketch. + * + * See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html]] for more information. + * + * @param left + * Child expression against which unique counting will occur + * @param right + * the log-base-2 of nomEntries decides the number of buckets for the sketch + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(expr, lgNomEntries) - Returns the ThetaSketch's compact binary representation. + `lgNomEntries` (optional) the log-base-2 of Nominal Entries, with Nominal Entries deciding + the number buckets or slots for the ThetaSketch.""", + examples = """ + Examples: + > SELECT theta_sketch_estimate(_FUNC_(sketch)) FROM (SELECT theta_sketch_agg(col) as sketch FROM VALUES (1) tab(col) UNION ALL SELECT theta_sketch_agg(col, 20) as sketch FROM VALUES (1) tab(col)); + 1 + """, + group = "agg_funcs", + since = "4.0.0") +// scalastyle:on line.size.limit +case class ThetaUnionAgg( + left: Expression, + right: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends TypedImperativeAggregate[ThetaSketchState] + with BinaryLike[Expression] + with ExpectsInputTypes { + + // ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation. 
+
+  /**
+   * The validated log-base-2 of the union's nominal entries, obtained by evaluating `right`
+   * without an input row.
+   *
+   * The exponent is range-checked before it is used as a shift count: `1 << n` only honours the
+   * low five bits of `n`, so an out-of-range value such as 36 (or -28) would otherwise wrap to
+   * `1 << 4` and be silently accepted as 4.
+   *
+   * @throws SketchesArgumentException
+   *   if the value is outside [ThetaUtil.MIN_LG_NOM_LONGS, ThetaUtil.MAX_LG_NOM_LONGS]
+   */
+  lazy val lgNomEntries: Int = {
+    // Named differently from the member so the local does not shadow the lazy val being defined.
+    val lgNomEntriesInput = right.eval().asInstanceOf[Int]
+    if (lgNomEntriesInput < ThetaUtil.MIN_LG_NOM_LONGS ||
+        lgNomEntriesInput > ThetaUtil.MAX_LG_NOM_LONGS) {
+      // Same exception type checkNomLongs raises, so existing callers see no new error class.
+      throw new SketchesArgumentException(
+        s"lgNomEntries must be between ${ThetaUtil.MIN_LG_NOM_LONGS} and " +
+          s"${ThetaUtil.MAX_LG_NOM_LONGS}, inclusive: $lgNomEntriesInput")
+    }
+    // The shift is now known not to wrap; checkNomLongs converts the size back to its log2.
+    ThetaUtil.checkNomLongs(1 << lgNomEntriesInput)
+  }
+
+  // Constructors
+
+  /** Uses the lgNomEntries derived from ThetaUtil.DEFAULT_NOMINAL_ENTRIES. */
+  def this(child: Expression) = {
+    this(child, Literal(ThetaUtil.checkNomLongs(ThetaUtil.DEFAULT_NOMINAL_ENTRIES)), 0, 0)
+  }
+
+  /** Takes lgNomEntries as an expression; both buffer offsets start at 0. */
+  def this(child: Expression, lgNomEntries: Expression) = {
+    this(child, lgNomEntries, 0, 0)
+  }
+
+  /** Takes lgNomEntries as a plain Int and wraps it in a Literal. */
+  def this(child: Expression, lgNomEntries: Int) = {
+    this(child, Literal(lgNomEntries), 0, 0)
+  }
+
+  // Copy constructors required by ImperativeAggregate
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ThetaUnionAgg =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ThetaUnionAgg =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): ThetaUnionAgg =
+    copy(left = newLeft, right = newRight)
+
+  // Overrides for TypedImperativeAggregate
+
+  override def prettyName: String = "theta_union_agg"
+
+  // First argument: a binary sketch; second argument: lgNomEntries.
+  override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType, IntegerType)
+
+  override def dataType: DataType = BinaryType
+
+  override def nullable: Boolean = false
+
+  /**
+   * Instantiate an Union instance using the lgNomEntries param.
+   *
+   * @return
+   *   an Union instance wrapped with UnionAggregationBuffer
+   */
+  override def createAggregationBuffer(): ThetaSketchState = {
+    UnionAggregationBuffer(
+      SetOperation.builder
+        .setLogNominalEntries(lgNomEntries)
+        .buildUnion)
+  }
+
+  /**
+   * Update the Union instance with the Compact Sketch byte array obtained from the row.
+ * + * @param unionBuffer + * A previously initialized Union instance + * @param input + * An input row + */ + override def update(unionBuffer: ThetaSketchState, input: InternalRow): ThetaSketchState = { + val v = left.eval(input) + if (v != null) { + left.dataType match { + case BinaryType => + try { + val inputSketch = CompactSketch.wrap(Memory.wrap(v.asInstanceOf[Array[Byte]])) + + val union = unionBuffer match { + case UnionAggregationBuffer(existingUnionBuffer) => existingUnionBuffer + case _ => throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + + union.union(inputSketch) + UnionAggregationBuffer(union) + + } catch { + case _: SketchesArgumentException | _: java.lang.Error => + throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + + case _ => + throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + } else { + unionBuffer // Input is null, return buffer unchanged + } + } + + /** + * Merges an input Union into the union which is acting as the aggregation buffer. 
+ * + * @param unionBuffer + * The Union instance used to store the aggregation result + * @param input + * An input Union or Compact Sketch instance + */ + override def merge(unionBuffer: ThetaSketchState, input: ThetaSketchState): ThetaSketchState = { + (unionBuffer, input) match { + // Both are unions, merge them directly + case (UnionAggregationBuffer(unionLeft), UnionAggregationBuffer(unionRight)) => + unionLeft.union(unionRight.getResult) + FinalizedSketch(unionLeft.getResult) + + // One or both are immutable sketches (need to wrap in union) + case (FinalizedSketch(sketch1), FinalizedSketch(sketch2)) => + val union = SetOperation.builder.setLogNominalEntries(lgNomEntries).buildUnion + union.union(sketch1) + union.union(sketch2) + FinalizedSketch(union.getResult) + case (FinalizedSketch(sketch), UnionAggregationBuffer(union)) => + union.union(sketch) + FinalizedSketch(union.getResult) + case (UnionAggregationBuffer(union), FinalizedSketch(sketch)) => + union.union(sketch) + FinalizedSketch(union.getResult) + + // Should never make it here + case _ => throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + } + + /** + * Returns an Compact Sketch derived from the merged sketches + * + * @param sketchState + * Union instance used as an aggregation buffer or a compact sketch + * @return + * A binary sketch which can be evaluated or merged + */ + override def eval(sketchState: ThetaSketchState): Any = { + sketchState match { + case UnionAggregationBuffer(union) => union.getResult.toByteArray + case FinalizedSketch(s) => s.toByteArray + case _ => throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + } + + /** Convert the underlying Union into an compact byte array */ + override def serialize(sketchState: ThetaSketchState): Array[Byte] = { + sketchState match { + case UnionAggregationBuffer(union) => union.getResult.toByteArray + case FinalizedSketch(s) => s.toByteArray + case _ => throw 
QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + } + + /** Wrap the byte array into a compact sketch instance */ + override def deserialize(buffer: Array[Byte]): ThetaSketchState = { + if (buffer.nonEmpty) { + FinalizedSketch(CompactSketch.wrap(Memory.wrap(buffer))) + } else { + UnionAggregationBuffer( + SetOperation.builder + .setLogNominalEntries(lgNomEntries) + .buildUnion) + } + } +} + +/** + * The ThetaIntersectionAgg function ingests and merges Datasketches ThetaSketch instances + * previously produced by the ThetaSketchAgg function, and outputs the merged ThetaSketch. + * + * See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html]] for more information. + * + * @param left + * Child expression against which unique counting will occur + * @param right + * the log-base-2 of nomEntries decides the number of buckets for the sketch + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(expr, lgNomEntries) - Returns the ThetaSketch's compact binary representation. + `lgNomEntries` (optional) the log-base-2 of Nominal Entries, with Nominal Entries deciding + the number buckets or slots for the ThetaSketch.""", + examples = """ + Examples: + > SELECT theta_sketch_estimate(_FUNC_(sketch)) FROM (SELECT theta_sketch_agg(col) as sketch FROM VALUES (1) tab(col) UNION ALL SELECT theta_sketch_agg(col, 20) as sketch FROM VALUES (1) tab(col)); + 1 + """, + group = "agg_funcs", + since = "4.0.0") +// scalastyle:on line.size.limit +case class ThetaIntersectionAgg( + left: Expression, + right: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends TypedImperativeAggregate[ThetaSketchState] + with BinaryLike[Expression] + with ExpectsInputTypes { + + // ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation. 
+
+  /**
+   * The validated log-base-2 of the intersection's nominal entries, obtained by evaluating
+   * `right` without an input row.
+   *
+   * The exponent is range-checked before it is used as a shift count: `1 << n` only honours the
+   * low five bits of `n`, so an out-of-range value such as 36 (or -28) would otherwise wrap to
+   * `1 << 4` and be silently accepted as 4.
+   *
+   * @throws SketchesArgumentException
+   *   if the value is outside [ThetaUtil.MIN_LG_NOM_LONGS, ThetaUtil.MAX_LG_NOM_LONGS]
+   */
+  lazy val lgNomEntries: Int = {
+    // Named differently from the member so the local does not shadow the lazy val being defined.
+    val lgNomEntriesInput = right.eval().asInstanceOf[Int]
+    if (lgNomEntriesInput < ThetaUtil.MIN_LG_NOM_LONGS ||
+        lgNomEntriesInput > ThetaUtil.MAX_LG_NOM_LONGS) {
+      // Same exception type checkNomLongs raises, so existing callers see no new error class.
+      throw new SketchesArgumentException(
+        s"lgNomEntries must be between ${ThetaUtil.MIN_LG_NOM_LONGS} and " +
+          s"${ThetaUtil.MAX_LG_NOM_LONGS}, inclusive: $lgNomEntriesInput")
+    }
+    // The shift is now known not to wrap; checkNomLongs converts the size back to its log2.
+    ThetaUtil.checkNomLongs(1 << lgNomEntriesInput)
+  }
+
+  // Constructors
+
+  /** Uses the lgNomEntries derived from ThetaUtil.DEFAULT_NOMINAL_ENTRIES. */
+  def this(child: Expression) = {
+    this(child, Literal(ThetaUtil.checkNomLongs(ThetaUtil.DEFAULT_NOMINAL_ENTRIES)), 0, 0)
+  }
+
+  /** Takes lgNomEntries as an expression; both buffer offsets start at 0. */
+  def this(child: Expression, lgNomEntries: Expression) = {
+    this(child, lgNomEntries, 0, 0)
+  }
+
+  /** Takes lgNomEntries as a plain Int and wraps it in a Literal. */
+  def this(child: Expression, lgNomEntries: Int) = {
+    this(child, Literal(lgNomEntries), 0, 0)
+  }
+
+  // Copy constructors required by ImperativeAggregate
+
+  override def withNewMutableAggBufferOffset(
+      newMutableAggBufferOffset: Int): ThetaIntersectionAgg =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ThetaIntersectionAgg =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): ThetaIntersectionAgg =
+    copy(left = newLeft, right = newRight)
+
+  // Overrides for TypedImperativeAggregate
+
+  override def prettyName: String = "theta_intersection_agg"
+
+  // First argument: a binary sketch; second argument: lgNomEntries.
+  override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType, IntegerType)
+
+  override def dataType: DataType = BinaryType
+
+  override def nullable: Boolean = false
+
+  /**
+   * Instantiate an Intersection instance using the lgNomEntries param.
+   *
+   * @return
+   *   an Intersection instance wrapped with IntersectionAggregationBuffer
+   */
+  override def createAggregationBuffer(): ThetaSketchState = {
+    IntersectionAggregationBuffer(
+      SetOperation.builder
+        .setLogNominalEntries(lgNomEntries)
+        .buildIntersection)
+  }
+
+  /**
+   * Update the Intersection instance with the Compact Sketch byte array obtained from the row.
+ * + * @param intersectionBuffer + * A previously initialized Intersection instance + * @param input + * An input row + */ + override def update( + intersectionBuffer: ThetaSketchState, + input: InternalRow): ThetaSketchState = { + val v = left.eval(input) + if (v != null) { + left.dataType match { + case BinaryType => + try { + val inputSketch = CompactSketch.wrap(Memory.wrap(v.asInstanceOf[Array[Byte]])) + + val intersection = intersectionBuffer match { + case IntersectionAggregationBuffer(existingIntersection) => existingIntersection + case _ => throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + + intersection.intersect(inputSketch) + IntersectionAggregationBuffer(intersection) + + } catch { + case _: SketchesArgumentException | _: java.lang.Error => + throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + + case _ => + throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + } else { + intersectionBuffer // Input is null, return buffer unchanged + } + } + + /** + * Merges an input Intersection into the union which is acting as the aggregation buffer. 
+ * + * @param intersectionBuffer + * The Intersection instance used to store the aggregation result + * @param input + * An input Union or Compact Sketch instance + */ + override def merge( + intersectionBuffer: ThetaSketchState, + input: ThetaSketchState): ThetaSketchState = { + (intersectionBuffer, input) match { + // Both are unions, merge them directly + case ( + IntersectionAggregationBuffer(intersectLeft), + IntersectionAggregationBuffer(intersectRight)) => + intersectLeft.intersect(intersectRight.getResult) + FinalizedSketch(intersectLeft.getResult) + + // One or both are immutable sketches (need to wrap in union) + case (FinalizedSketch(sketch1), FinalizedSketch(sketch2)) => + val intersection = + SetOperation.builder.setLogNominalEntries(lgNomEntries).buildIntersection + intersection.intersect(sketch1) + intersection.intersect(sketch2) + FinalizedSketch(intersection.getResult) + case (FinalizedSketch(sketch), IntersectionAggregationBuffer(intersection)) => + intersection.intersect(sketch) + FinalizedSketch(intersection.getResult) + case (IntersectionAggregationBuffer(intersection), FinalizedSketch(sketch)) => + intersection.intersect(sketch) + FinalizedSketch(intersection.getResult) + + // Should never make it here + case _ => throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + } + + /** + * Returns an Compact Sketch derived from the merged sketches + * + * @param sketchState + * Intersection instance used as an aggregation buffer or a compact sketch + * @return + * A binary sketch which can be evaluated or merged + */ + override def eval(sketchState: ThetaSketchState): Any = { + sketchState match { + case IntersectionAggregationBuffer(intersection) => intersection.getResult.toByteArray + case FinalizedSketch(s) => s.toByteArray + case _ => throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + } + + /** Convert the underlying Intersection into an compact byte array */ + override def serialize(sketchState: 
ThetaSketchState): Array[Byte] = { + sketchState match { + case IntersectionAggregationBuffer(intersection) => intersection.getResult.toByteArray + case FinalizedSketch(s) => s.toByteArray + case _ => throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + } + + /** Wrap the byte array into a compact sketch instance */ + override def deserialize(buffer: Array[Byte]): ThetaSketchState = { + if (buffer.nonEmpty) { + FinalizedSketch(CompactSketch.wrap(Memory.wrap(buffer))) + } else { + IntersectionAggregationBuffer( + SetOperation.builder + .setLogNominalEntries(lgNomEntries) + .buildIntersection) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/thetasketchesExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/thetasketchesExpressions.scala new file mode 100644 index 0000000000000..8347ceb24015a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/thetasketchesExpressions.scala @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.datasketches.common.SketchesArgumentException +import org.apache.datasketches.memory.Memory +import org.apache.datasketches.theta.{CompactSketch, SetOperation} +import org.apache.datasketches.thetacommon.ThetaUtil + +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, IntegerType, LongType} + +@ExpressionDescription( + usage = """ + _FUNC_(expr) - Returns the estimated number of unique values given the binary representation + of a Datasketches ThetaSketch. """, + examples = """ + Examples: + > SELECT _FUNC_(theta_sketch_agg(col)) FROM VALUES (1), (1), (2), (2), (3) tab(col); + 3 + """, + group = "misc_funcs", + since = "4.0.0") +case class ThetaSketchEstimate(child: Expression) + extends UnaryExpression + with CodegenFallback + with ExpectsInputTypes { + override def nullIntolerant: Boolean = true + + override protected def withNewChildInternal(newChild: Expression): ThetaSketchEstimate = + copy(child = newChild) + + override def prettyName: String = "theta_sketch_estimate" + + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) + + override def dataType: DataType = LongType + + override def nullSafeEval(input: Any): Any = { + val buffer = input.asInstanceOf[Array[Byte]] + try { + Math.round(CompactSketch.wrap(Memory.wrap(buffer)).getEstimate) + } catch { + case _: SketchesArgumentException | _: java.lang.Error => + throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(first, second, lgNomEntries) - Merges two binary representations of + Datasketches ThetaSketch objects, using a Datasketches Union 
object. Set + lgNomEntries to a value between 4 and 26 to find the unions of sketches with different + union buffer sizes values (defaults to 12). """, + examples = """ + Examples: + > SELECT theta_sketch_estimate(_FUNC_(theta_sketch_agg(col1), theta_sketch_agg(col2))) FROM VALUES (1, 4), (1, 4), (2, 5), (2, 5), (3, 6) tab(col1, col2); + 6 + """, + group = "misc_funcs", + since = "4.0.0") +// scalastyle:on line.size.limit +case class ThetaUnion(first: Expression, second: Expression, third: Expression) + extends TernaryExpression + with CodegenFallback + with ExpectsInputTypes { + override def nullIntolerant: Boolean = true + + def this(first: Expression, second: Expression) = { + this(first, second, Literal(ThetaUtil.checkNomLongs(ThetaUtil.DEFAULT_NOMINAL_ENTRIES))) + } + + def this(first: Expression, second: Expression, third: Int) = { + this(first, second, Literal(third)) + } + + override protected def withNewChildrenInternal( + newFirst: Expression, + newSecond: Expression, + newThird: Expression): ThetaUnion = + copy(first = newFirst, second = newSecond, third = newThird) + + override def prettyName: String = "theta_union" + + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType, BinaryType, IntegerType) + + override def dataType: DataType = BinaryType + + override def nullSafeEval(value1: Any, value2: Any, value3: Any): Any = { + val sketch1 = + try { + CompactSketch.wrap(Memory.wrap(value1.asInstanceOf[Array[Byte]])) + } catch { + case _: SketchesArgumentException | _: java.lang.Error => + throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + val sketch2 = + try { + CompactSketch.wrap(Memory.wrap(value2.asInstanceOf[Array[Byte]])) + } catch { + case _: SketchesArgumentException | _: java.lang.Error => + throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + val logNominalEntries = value3.asInstanceOf[Int] + val union = SetOperation.builder + .setLogNominalEntries(logNominalEntries) + .buildUnion + 
.union(sketch1, sketch2)
+
+    union.toByteArray
+  }
+}
+
+/**
+ * Scalar expression computing the set difference (A and not B) of two serialized Theta sketches.
+ * Both sketch inputs are binary values; the result is the serialized difference sketch.
+ *
+ * @param first binary expression holding sketch A
+ * @param second binary expression holding sketch B, whose entries are removed from A
+ * @param third integer expression handed to the set-operation builder as lgNomEntries
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(first, second, lgNomEntries) - Subtracts two binary representations of
+    Datasketches ThetaSketch objects, using a Datasketches AnotB object. Set
+    lgNomEntries to a value between 4 and 26 to find the difference of sketches with different
+    AnotB buffer sizes (defaults to 12). """,
+  examples = """
+    Examples:
+      > SELECT theta_sketch_estimate(_FUNC_(theta_sketch_agg(col1), theta_sketch_agg(col2))) FROM VALUES (5, 4), (1, 4), (2, 5), (2, 5), (3, 1) tab(col1, col2);
+       2
+  """,
+  group = "misc_funcs",
+  since = "4.1.0")
+// scalastyle:on line.size.limit
+case class ThetaDifference(first: Expression, second: Expression, third: Expression)
+    extends TernaryExpression
+    with CodegenFallback
+    with ExpectsInputTypes {
+  override def nullIntolerant: Boolean = true
+
+  // Two-argument SQL form: falls back to the library default. checkNomLongs presumably converts
+  // the default entry count into its log-base-2 form (12 according to the usage text).
+  def this(first: Expression, second: Expression) = {
+    this(first, second, Literal(ThetaUtil.checkNomLongs(ThetaUtil.DEFAULT_NOMINAL_ENTRIES)))
+  }
+
+  def this(first: Expression, second: Expression, third: Int) = {
+    this(first, second, Literal(third))
+  }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression,
+      newSecond: Expression,
+      newThird: Expression): ThetaDifference =
+    copy(first = newFirst, second = newSecond, third = newThird)
+
+  override def prettyName: String = "theta_difference"
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType, BinaryType, IntegerType)
+
+  override def dataType: DataType = BinaryType
+
+  /**
+   * Wraps both sketch buffers, subtracts the second from the first and serializes the result.
+   * Throws the THETA_INVALID_INPUT_SKETCH_BUFFER error if either buffer cannot be wrapped.
+   */
+  override def nullSafeEval(value1: Any, value2: Any, value3: Any): Any = {
+    val sketchA = wrapSketch(value1)
+    val sketchB = wrapSketch(value2)
+    val logNominalEntries = value3.asInstanceOf[Int]
+    // NOTE(review): logNominalEntries is not validated here; presumably the builder rejects
+    // out-of-range values with its own exception, which would propagate as-is -- confirm.
+    val difference = SetOperation.builder
+      .setLogNominalEntries(logNominalEntries)
+      .buildANotB
+      .aNotB(sketchA, sketchB)
+
+    difference.toByteArray
+  }
+
+  /** Wraps a binary input as a CompactSketch, mapping wrap failures to the Spark error. */
+  private def wrapSketch(buffer: Any): CompactSketch =
+    try {
+      CompactSketch.wrap(Memory.wrap(buffer.asInstanceOf[Array[Byte]]))
+    } catch {
+      case _: SketchesArgumentException | _: java.lang.Error =>
+        throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName)
+    }
+}
+
+/**
+ * Scalar expression computing the intersection of two serialized Theta sketches. Both sketch
+ * inputs are binary values; the result is the serialized intersection sketch.
+ *
+ * @param first binary expression holding the first sketch
+ * @param second binary expression holding the second sketch
+ * @param third integer expression handed to the set-operation builder as lgNomEntries
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(first, second, lgNomEntries) - Intersects two binary representations of
+    Datasketches ThetaSketch objects, using a Datasketches Intersection object. Set
+    lgNomEntries to a value between 4 and 26 to find the intersection of sketches with different
+    intersection buffer sizes (defaults to 12). """,
+  examples = """
+    Examples:
+      > SELECT theta_sketch_estimate(_FUNC_(theta_sketch_agg(col1), theta_sketch_agg(col2))) FROM VALUES (5, 4), (1, 4), (2, 5), (2, 5), (3, 1) tab(col1, col2);
+       2
+  """,
+  group = "misc_funcs",
+  since = "4.1.0")
+// scalastyle:on line.size.limit
+case class ThetaIntersection(first: Expression, second: Expression, third: Expression)
+    extends TernaryExpression
+    with CodegenFallback
+    with ExpectsInputTypes {
+  override def nullIntolerant: Boolean = true
+
+  // Two-argument SQL form: falls back to the library default. checkNomLongs presumably converts
+  // the default entry count into its log-base-2 form (12 according to the usage text).
+  def this(first: Expression, second: Expression) = {
+    this(first, second, Literal(ThetaUtil.checkNomLongs(ThetaUtil.DEFAULT_NOMINAL_ENTRIES)))
+  }
+
+  def this(first: Expression, second: Expression, third: Int) = {
+    this(first, second, Literal(third))
+  }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression,
+      newSecond: Expression,
+      newThird: Expression): ThetaIntersection =
+    copy(first = newFirst, second = newSecond, third = newThird)
+
+  override def prettyName: String = "theta_intersection"
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType, BinaryType, IntegerType)
+
+  override def dataType: DataType = BinaryType
+
+  /**
+   * Wraps both sketch buffers, intersects them and serializes the result.
+   * Throws the THETA_INVALID_INPUT_SKETCH_BUFFER error if either buffer cannot be wrapped.
+   */
+  override def nullSafeEval(value1: Any, value2: Any, value3: Any): Any = {
+    val left = wrapSketch(value1)
+    val right = wrapSketch(value2)
+    val logNominalEntries = value3.asInstanceOf[Int]
+    // NOTE(review): logNominalEntries is not validated here; presumably the builder rejects
+    // out-of-range values with its own exception, which would propagate as-is -- confirm.
+    val intersection = SetOperation.builder
+      .setLogNominalEntries(logNominalEntries)
+      .buildIntersection
+      .intersect(left, right)
+
+    intersection.toByteArray
+  }
+
+  /** Wraps a binary input as a CompactSketch, mapping wrap failures to the Spark error. */
+  private def wrapSketch(buffer: Any): CompactSketch =
+    try {
+      CompactSketch.wrap(Memory.wrap(buffer.asInstanceOf[Array[Byte]]))
+    } catch {
+      case _: SketchesArgumentException | _: java.lang.Error =>
+        throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName)
+    }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 7f623039778ab..b2e6dc445c835 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -3071,4 +3071,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
       )
     )
   }
+
+  /**
+   * Error for a Theta sketch function whose binary input is not a valid sketch buffer.
+   *
+   * @param function name of the SQL function that received the invalid buffer
+   */
+  def thetaInvalidInputSketchBuffer(function: String): Throwable = {
+    new SparkRuntimeException(
+      errorClass = "THETA_INVALID_INPUT_SKETCH_BUFFER",
+      messageParameters = Map("function" -> toSQLId(function)))
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ThetasketchesAggSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ThetasketchesAggSuite.scala
new file mode 100644
index 0000000000000..0386b2642da36
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ThetasketchesAggSuite.scala
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import scala.collection.immutable.NumericRange
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, ThetaSketchEstimate}
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, DoubleType, FloatType, IntegerType, LongType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Unit tests for the Theta sketch aggregates, driving [[ThetaSketchAgg]], [[ThetaUnionAgg]] and
+ * [[ThetaIntersectionAgg]] directly through their aggregation-buffer methods.
+ */
+class ThetasketchesAggSuite extends SparkFunSuite {
+
+  /**
+   * Spreads `input` at random over `numSketches` independent [[ThetaSketchAgg]] buffers, then
+   * folds the buffers into one, round-tripping both sides through serialize/deserialize before
+   * each merge.
+   *
+   * @param dataType Catalyst type of the values in `input`
+   * @param input values fed to the aggregate, one per row
+   * @param numSketches number of independent aggregation buffers to spread the input over
+   * @return the estimate computed by [[ThetaSketchEstimate]] for the merged sketch, and the
+   *         inclusive range spanned by the merged sketch's bounds at 3 standard deviations
+   */
+  def simulateUpdateMerge(
+      dataType: DataType,
+      input: Seq[Any],
+      numSketches: Integer = 5): (Long, NumericRange[Long]) = {
+
+    // create a map of agg function instances
+    val aggFunctionMap = Seq
+      .tabulate(numSketches)(index => {
+        val sketch = new ThetaSketchAgg(BoundReference(0, dataType, nullable = true))
+        index -> (sketch, sketch.createAggregationBuffer())
+      })
+      .toMap
+
+    // randomly update agg function instances; foreach, since update is called only for its
+    // side effect on the buffer
+    input.foreach { value =>
+      val (aggFunction, aggBuffer) = aggFunctionMap(Random.nextInt(numSketches))
+      aggFunction.update(aggBuffer, InternalRow(value))
+    }
+
+    def serializeDeserialize(
+        tuple: (ThetaSketchAgg, ThetaSketchState)): (ThetaSketchAgg, ThetaSketchState) = {
+      val (agg, buf) = tuple
+      val serialized = agg.serialize(buf)
+      (agg, agg.deserialize(serialized))
+    }
+
+    // simulate serialization -> deserialization -> merge
+    // NOTE(review): the pattern below requires the folded state to be a FinalizedSketch; with
+    // numSketches == 1 no merge runs, which presumably fails the match -- confirm.
+    val mapValues = aggFunctionMap.values
+    val (_, FinalizedSketch(mergedBuf)) =
+      mapValues.tail.foldLeft(mapValues.head)((prev, cur) => {
+        val (prevAgg, prevBuf) = serializeDeserialize(prev)
+        val (_, curBuf) = serializeDeserialize(cur)
+
+        (prevAgg, prevAgg.merge(prevBuf, curBuf))
+      })
+
+    val estimator = ThetaSketchEstimate(BoundReference(0, BinaryType, nullable = true))
+    val estimate = estimator.eval(InternalRow(mergedBuf.toByteArray)).asInstanceOf[Long]
+    (estimate, mergedBuf.getLowerBound(3).toLong to mergedBuf.getUpperBound(3).toLong)
+  }
+
+  test("SPARK-52407: Test min/max values of supported datatypes") {
+    val intRange = Integer.MIN_VALUE to Integer.MAX_VALUE by 10000000
+    val (intEstimate, intEstimateRange) = simulateUpdateMerge(IntegerType, intRange)
+    assert(intEstimate == intRange.size || intEstimateRange.contains(intRange.size.toLong))
+
+    val longRange = Long.MinValue to Long.MaxValue by 1000000000000000L
+    val (longEstimate, longEstimateRange) = simulateUpdateMerge(LongType, longRange)
+    assert(longEstimate == longRange.size || longEstimateRange.contains(longRange.size.toLong))
+
+    val stringRange = Seq.tabulate(1000)(i => UTF8String.fromString(Random.nextString(i + 1)))
+    val (stringEstimate, stringEstimateRange) = simulateUpdateMerge(StringType, stringRange)
+    assert(
+      stringEstimate == stringRange.size ||
+        stringEstimateRange.contains(stringRange.size.toLong))
+
+    val binaryRange =
+      Seq.tabulate(1000)(i => UTF8String.fromString(Random.nextString(i + 1)).getBytes)
+    val (binaryEstimate, binaryEstimateRange) = simulateUpdateMerge(BinaryType, binaryRange)
+    assert(
+      binaryEstimate == binaryRange.size ||
+        binaryEstimateRange.contains(binaryRange.size.toLong))
+
+    val floatRange = (1 to 1000).map(_.toFloat)
+    val (floatEstimate, floatRangeEst) = simulateUpdateMerge(FloatType, floatRange)
+    assert(floatEstimate == floatRange.size || floatRangeEst.contains(floatRange.size.toLong))
+
+    val doubleRange = (1 to 1000).map(_.toDouble)
+    val (doubleEstimate, doubleRangeEst) = simulateUpdateMerge(DoubleType, doubleRange)
+    assert(doubleEstimate == doubleRange.size || doubleRangeEst.contains(doubleRange.size.toLong))
+
+    val arrayIntRange = (1 to 500).map(i => ArrayData.toArrayData(Array(i, i + 1)))
+    val (arrayIntEstimate, arrayIntRangeEst) =
+      simulateUpdateMerge(ArrayType(IntegerType), arrayIntRange)
+    assert(
+      arrayIntEstimate == arrayIntRange.size ||
+        arrayIntRangeEst.contains(arrayIntRange.size.toLong))
+
+    val arrayLongRange = (1 to 500).map(i => ArrayData.toArrayData(Array(i.toLong, (i + 1).toLong)))
+    val (arrayLongEstimate, arrayLongRangeEst) =
+      simulateUpdateMerge(ArrayType(LongType), arrayLongRange)
+    assert(
+      arrayLongEstimate == arrayLongRange.size ||
+        arrayLongRangeEst.contains(arrayLongRange.size.toLong))
+  }
+
+  test("SPARK-52407: Test lgNomEntries results in downsampling sketches during Union") {
+    // Create sketch with larger configuration (more precise)
+    val aggFunc1 = new ThetaSketchAgg(BoundReference(0, IntegerType, nullable = true), 12)
+    val sketch1 = aggFunc1.createAggregationBuffer()
+    (0 to 100).foreach(i => aggFunc1.update(sketch1, InternalRow(i)))
+    val binary1 = aggFunc1.eval(sketch1)
+
+    // Create sketch with smaller configuration (less precise)
+    val aggFunc2 = new ThetaSketchAgg(BoundReference(0, IntegerType, nullable = true), 10)
+    val sketch2 = aggFunc2.createAggregationBuffer()
+    (0 to 100).foreach(i => aggFunc2.update(sketch2, InternalRow(i)))
+    val binary2 = aggFunc2.eval(sketch2)
+
+    // Union the sketches
+    val unionAgg = new ThetaUnionAgg(BoundReference(0, BinaryType, nullable = true), 12)
+    val union = unionAgg.createAggregationBuffer()
+    unionAgg.update(union, InternalRow(binary1))
+    unionAgg.update(union, InternalRow(binary2))
+    val unionResult = unionAgg.eval(union)
+
+    // Verify the estimate is still accurate despite different configurations
+    val estimate = ThetaSketchEstimate(BoundReference(0, BinaryType, nullable = true))
+      .eval(InternalRow(unionResult))
+    assert(estimate.asInstanceOf[Long] >= 95 && estimate.asInstanceOf[Long] <= 105)
+  }
+
+  test("SPARK-52407: Test lgNomEntries results in downsampling sketches during intersection") {
+    // Create sketch with larger configuration (more precise)
+    val aggFunc1 = new ThetaSketchAgg(BoundReference(0, IntegerType, nullable = true), 12)
+    val sketch1 = aggFunc1.createAggregationBuffer()
+    (0 to 150).foreach(i => aggFunc1.update(sketch1, InternalRow(i)))
+    val binary1 = aggFunc1.eval(sketch1)
+
+    // Create sketch with smaller configuration (less precise)
+    val aggFunc2 = new ThetaSketchAgg(BoundReference(0, IntegerType, nullable = true), 10)
+    val sketch2 = aggFunc2.createAggregationBuffer()
+    (50 to 200).foreach(i => aggFunc2.update(sketch2, InternalRow(i)))
+    val binary2 = aggFunc2.eval(sketch2)
+
+    // Intersect the sketches
+    val intersectionAgg =
+      new ThetaIntersectionAgg(BoundReference(0, BinaryType, nullable = true), 12)
+    val intersection = intersectionAgg.createAggregationBuffer()
+    intersectionAgg.update(intersection, InternalRow(binary1))
+    intersectionAgg.update(intersection, InternalRow(binary2))
+    val intersectionResult = intersectionAgg.eval(intersection)
+
+    // Verify the estimate is still accurate despite different configurations
+    // Should be around 101 (overlap from 50 to 150)
+    val estimate = ThetaSketchEstimate(BoundReference(0, BinaryType, nullable = true))
+      .eval(InternalRow(intersectionResult))
+    assert(estimate.asInstanceOf[Long] >= 95 && estimate.asInstanceOf[Long] <= 105)
+  }
+}