Skip to content

Commit 22cbb96

Browse files
chenhao-db authored and cloud-fan committed
[SPARK-50746][SQL] Replace Either with VariantPathSegment
### What changes were proposed in this pull request?

It replaces `type PathSegment = Either[String, Int]` with a dedicated class `VariantPathSegment`. There is no semantic change, but the code has clearer naming.

### Why are the changes needed?

To make the code easier to understand.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49385 from chenhao-db/VariantPathSegment.

Authored-by: Chenhao Li <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent b3182e5 commit 22cbb96

File tree

3 files changed

+26
-23
lines changed

3 files changed

+26
-23
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala

Lines changed: 17 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -184,33 +184,37 @@ case class ToVariantObject(child: Expression)
184184
}
185185
}
186186

187-
object VariantPathParser extends RegexParsers {
188-
// A path segment in the `VariantGet` expression represents either an object key access or an
189-
// array index access.
190-
type PathSegment = Either[String, Int]
187+
// A path segment in the `VariantGet` expression represents either an object key access or an array
188+
// index access.
189+
sealed abstract class VariantPathSegment extends Serializable
190+
191+
case class ObjectExtraction(key: String) extends VariantPathSegment
191192

193+
case class ArrayExtraction(index: Int) extends VariantPathSegment
194+
195+
object VariantPathParser extends RegexParsers {
192196
private def root: Parser[Char] = '$'
193197

194198
// Parse index segment like `[123]`.
195-
private def index: Parser[PathSegment] =
199+
private def index: Parser[VariantPathSegment] =
196200
for {
197201
index <- '[' ~> "\\d+".r <~ ']'
198202
} yield {
199-
scala.util.Right(index.toInt)
203+
ArrayExtraction(index.toInt)
200204
}
201205

202206
// Parse key segment like `.name`, `['name']`, or `["name"]`.
203-
private def key: Parser[PathSegment] =
207+
private def key: Parser[VariantPathSegment] =
204208
for {
205209
key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
206210
"[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
207211
} yield {
208-
scala.util.Left(key)
212+
ObjectExtraction(key)
209213
}
210214

211-
private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
215+
private val parser: Parser[List[VariantPathSegment]] = phrase(root ~> rep(key | index))
212216

213-
def parse(str: String): Option[Array[PathSegment]] = {
217+
def parse(str: String): Option[Array[VariantPathSegment]] = {
214218
this.parseAll(parser, str) match {
215219
case Success(result, _) => Some(result.toArray)
216220
case _ => None
@@ -349,14 +353,14 @@ case object VariantGet {
349353
/** The actual implementation of the `VariantGet` expression. */
350354
def variantGet(
351355
input: VariantVal,
352-
parsedPath: Array[VariantPathParser.PathSegment],
356+
parsedPath: Array[VariantPathSegment],
353357
dataType: DataType,
354358
castArgs: VariantCastArgs): Any = {
355359
var v = new Variant(input.getValue, input.getMetadata)
356360
for (path <- parsedPath) {
357361
v = path match {
358-
case scala.util.Left(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
359-
case scala.util.Right(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
362+
case ObjectExtraction(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
363+
case ArrayExtraction(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
360364
case _ => null
361365
}
362366
if (v == null) return null

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ import scala.collection.mutable.HashMap
2121

2222
import org.apache.spark.sql.catalyst.InternalRow
2323
import org.apache.spark.sql.catalyst.expressions._
24-
import org.apache.spark.sql.catalyst.expressions.variant.{VariantGet, VariantPathParser}
24+
import org.apache.spark.sql.catalyst.expressions.variant._
2525
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
2626
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project, Subquery}
2727
import org.apache.spark.sql.catalyst.rules.Rule
@@ -54,7 +54,7 @@ case class VariantMetadata(
5454
.build()
5555
).build()
5656

57-
def parsedPath(): Array[VariantPathParser.PathSegment] = {
57+
def parsedPath(): Array[VariantPathSegment] = {
5858
VariantPathParser.parse(path).getOrElse {
5959
val name = if (failOnError) "variant_get" else "try_variant_get"
6060
throw QueryExecutionErrors.invalidVariantGetPath(path, name)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions._
2525
import org.apache.spark.sql.catalyst.expressions.codegen._
2626
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2727
import org.apache.spark.sql.catalyst.expressions.variant._
28-
import org.apache.spark.sql.catalyst.expressions.variant.VariantPathParser.PathSegment
2928
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
3029
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
3130
import org.apache.spark.sql.execution.RowToColumnConverter
@@ -56,9 +55,9 @@ case class SparkShreddedRow(row: SpecializedGetters) extends ShreddingUtils.Shre
5655
override def numElements(): Int = row.asInstanceOf[ArrayData].numElements()
5756
}
5857

59-
// The search result of a `PathSegment` in a `VariantSchema`.
58+
// The search result of a `VariantPathSegment` in a `VariantSchema`.
6059
case class SchemaPathSegment(
61-
rawPath: PathSegment,
60+
rawPath: VariantPathSegment,
6261
// Whether this path segment is an object or array extraction.
6362
isObject: Boolean,
6463
// `schema.typedIdx`, if the path exists in the schema (for object extraction, the schema
@@ -714,11 +713,11 @@ case object SparkShreddingUtils {
714713
// found at a certain level of the file type, then `typedIdx` will be -1 starting from
715714
// this position, and the final `schema` will be null.
716715
for (i <- rawPath.indices) {
717-
val isObject = rawPath(i).isLeft
716+
val isObject = rawPath(i).isInstanceOf[ObjectExtraction]
718717
var typedIdx = -1
719718
var extractionIdx = -1
720719
rawPath(i) match {
721-
case scala.util.Left(key) if schema != null && schema.objectSchema != null =>
720+
case ObjectExtraction(key) if schema != null && schema.objectSchema != null =>
722721
val fieldIdx = schema.objectSchemaMap.get(key)
723722
if (fieldIdx != null) {
724723
typedIdx = schema.typedIdx
@@ -727,7 +726,7 @@ case object SparkShreddingUtils {
727726
} else {
728727
schema = null
729728
}
730-
case scala.util.Right(index) if schema != null && schema.arraySchema != null =>
729+
case ArrayExtraction(index) if schema != null && schema.arraySchema != null =>
731730
typedIdx = schema.typedIdx
732731
extractionIdx = index
733732
schema = schema.arraySchema
@@ -770,8 +769,8 @@ case object SparkShreddingUtils {
770769
var v = new Variant(row.getBinary(variantIdx), topLevelMetadata)
771770
while (pathIdx < pathLen) {
772771
v = pathList(pathIdx).rawPath match {
773-
case scala.util.Left(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
774-
case scala.util.Right(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
772+
case ObjectExtraction(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
773+
case ArrayExtraction(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
775774
case _ => null
776775
}
777776
if (v == null) return null

0 commit comments

Comments
 (0)