Skip to content

Commit

Permalink
Simplify analyzeExistingDefault
Browse files Browse the repository at this point in the history
  • Loading branch information
szehon-ho committed Feb 8, 2025
1 parent 164bd6b commit 2fa0b0b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ import scala.util.Try
import org.apache.commons.codec.binary.{Hex => ApacheHex}
import org.json4s.JsonAST._

import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, FunctionIdentifier, InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, UnresolvedFunction}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.trees.TreePattern
import org.apache.spark.sql.catalyst.trees.TreePattern.{LITERAL, NULL_LITERAL, TRUE_OR_FALSE_LITERAL}
import org.apache.spark.sql.catalyst.types._
Expand Down Expand Up @@ -265,6 +267,15 @@ object Literal {
s"Literal must have a corresponding value to ${dataType.catalogString}, " +
s"but class ${Utils.getSimpleName(value.getClass)} found.")
}

def fromSQL(sql: String): Expression = {
CatalystSqlParser.parseExpression(sql).transformUp {
case u: UnresolvedFunction =>
assert(u.nameParts.length == 1)
assert(!u.isDistinct)
FunctionRegistry.builtin.lookupFunction(FunctionIdentifier(u.nameParts.head), u.arguments)
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.{Literal => ExprLiteral}
import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, Optimizer}
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.catalyst.plans.logical._
Expand Down Expand Up @@ -334,62 +333,24 @@ object ResolveDefaultColumns extends QueryErrorsBase
/**
* Analyze EXISTS_DEFAULT value. This skips some steps of analyze as most of the
* analysis has been done before.
*
* VisibleForTesting
*/
def analyzeExistingDefault(field: StructField,
analyzer: Analyzer = DefaultColumnAnalyzer): Expression = {
val colName = field.name
val dataType = field.dataType
def analyzeExistingDefault(field: StructField, defaultSQL: String): Expression = {
val defaultSQL = field.metadata.getString(EXISTS_DEFAULT_COLUMN_METADATA_KEY)

// Parse the expression.
lazy val parser = new CatalystSqlParser()
val parsed: Expression = try {
parser.parseExpression(defaultSQL)
} catch {
case ex: ParseException =>
throw QueryCompilationErrors.defaultValuesUnresolvedExprError(
"", colName, defaultSQL, ex)
}
// Check invariants before moving on to analysis.
if (parsed.containsPattern(PLAN_EXPRESSION)) {
throw QueryCompilationErrors.defaultValuesMayNotContainSubQueryExpressions(
"", colName, defaultSQL)
}

// Analyze the parse result.
val plan = try {
val analyzer: Analyzer = defaultColumnAnalyzer
val analyzed = analyzer.execute(Project(Seq(Alias(parsed, colName)()), OneRowRelation()))
analyzer.checkAnalysis(analyzed)
// Eagerly execute constant-folding rules before checking whether the
// expression is foldable and resolved.
ConstantFolding(analyzed)
} catch {
case ex: AnalysisException =>
throw QueryCompilationErrors.defaultValuesUnresolvedExprError(
"", colName, defaultSQL, ex)
}
val analyzed: Expression = plan.collectFirst {
case Project(Seq(a: Alias), OneRowRelation()) => a.child
}.get
val expr = Literal.fromSQL(defaultSQL)

// Extra check, expressions should already be resolved and foldable
if (!analyzed.foldable) {
throw QueryCompilationErrors.defaultValueNotConstantError(defaultSQL, colName, defaultSQL)
// Check invariants
if (expr.containsPattern(PLAN_EXPRESSION)) {
throw QueryCompilationErrors.defaultValuesMayNotContainSubQueryExpressions(
"", field.name, defaultSQL)
}

if (!analyzed.resolved) {
if (!expr.resolved) {
throw QueryCompilationErrors.defaultValuesUnresolvedExprError(
"",
colName,
defaultSQL,
cause = null)
"", field.name, defaultSQL, null)
}

// Perform implicit coercion from the provided expression type to the required column type.
coerceDefaultValue(analyzed, dataType, defaultSQL, colName, defaultSQL)
expr
}

/**
Expand Down Expand Up @@ -479,10 +440,7 @@ object ResolveDefaultColumns extends QueryErrorsBase
val defaultValue: Option[String] = field.getExistenceDefaultValue()
defaultValue.map { text: String =>
val expr = try {
val expr = analyzeExistingDefault(field)
expr match {
case _: ExprLiteral | _: Cast => expr
}
analyzeExistingDefault(field, text)
} catch {
// AnalysisException thrown from analyze is already formatted, throw it directly.
case ae: AnalysisException => throw ae
Expand Down

0 comments on commit 2fa0b0b

Please sign in to comment.