diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index c1225f9e5b502..3f9fba185ca4f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -40,9 +40,11 @@ import scala.util.Try
 import org.apache.commons.codec.binary.{Hex => ApacheHex}
 import org.json4s.JsonAST._
 
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, FunctionIdentifier, InternalRow, ScalaReflection}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, UnresolvedFunction}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.trees.TreePattern
 import org.apache.spark.sql.catalyst.trees.TreePattern.{LITERAL, NULL_LITERAL, TRUE_OR_FALSE_LITERAL}
 import org.apache.spark.sql.catalyst.types._
@@ -265,6 +267,23 @@ object Literal {
       s"Literal must have a corresponding value to ${dataType.catalogString}, " +
         s"but class ${Utils.getSimpleName(value.getClass)} found.")
   }
+
+  def fromSQL(sql: String): Expression = {
+    CatalystSqlParser.parseExpression(sql).transformUp {
+      case u: UnresolvedFunction =>
+        assert(u.nameParts.length == 1)
+        assert(!u.isDistinct)
+        assert(u.filter.isEmpty)
+        assert(!u.ignoreNulls)
+        assert(u.orderingWithinGroup.isEmpty)
+        assert(!u.isInternal)
+        FunctionRegistry.builtin.lookupFunction(FunctionIdentifier(u.nameParts.head), u.arguments)
+    } match {
+      case c: Cast if c.needsTimeZone =>
+        c.withTimeZone(SQLConf.get.sessionLocalTimeZone)
+      case e: Expression => e
+    }
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
index 693ac8d94dbcf..6499e5c40049d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException}
+import org.apache.spark.{SparkThrowable, SparkUnsupportedOperationException}
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys._
 import org.apache.spark.sql.AnalysisException
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Literal => ExprLiteral}
 import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, Optimizer}
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -320,6 +319,29 @@ object ResolveDefaultColumns extends QueryErrorsBase
     coerceDefaultValue(analyzed, dataType, statementType, colName, defaultSQL)
   }
 
+  /**
+   * Analyze EXISTS_DEFAULT value. This skips some steps of analyze as most of the
+   * analysis has been done before.
+   */
+  private def analyzeExistenceDefaultValue(field: StructField): Expression = {
+    val defaultSQL = field.metadata.getString(EXISTS_DEFAULT_COLUMN_METADATA_KEY)
+
+    // Parse the expression.
+    val expr = Literal.fromSQL(defaultSQL)
+
+    // Check invariants
+    if (expr.containsPattern(PLAN_EXPRESSION)) {
+      throw QueryCompilationErrors.defaultValuesMayNotContainSubQueryExpressions(
+        "", field.name, defaultSQL)
+    }
+    if (!expr.resolved) {
+      throw QueryCompilationErrors.defaultValuesUnresolvedExprError(
+        "", field.name, defaultSQL, null)
+    }
+
+    coerceDefaultValue(expr, field.dataType, "", field.name, defaultSQL)
+  }
+
   /**
    * If the provided default value is a literal of a wider type than the target column,
    * but the literal value fits within the narrower type, just coerce it for convenience.
@@ -405,19 +427,9 @@ object ResolveDefaultColumns extends QueryErrorsBase
   def getExistenceDefaultValues(schema: StructType): Array[Any] = {
     schema.fields.map { field: StructField =>
       val defaultValue: Option[String] = field.getExistenceDefaultValue()
-      defaultValue.map { text: String =>
-        val expr = try {
-          val expr = analyze(field, "", EXISTS_DEFAULT_COLUMN_METADATA_KEY)
-          expr match {
-            case _: ExprLiteral | _: Cast => expr
-          }
-        } catch {
-          // AnalysisException thrown from analyze is already formatted, throw it directly.
-          case ae: AnalysisException => throw ae
-          case _: MatchError =>
-            throw SparkException.internalError(s"parse existence default as literal err," +
-              s" field name: ${field.name}, value: $text")
-        }
+      defaultValue.map { _: String =>
+        val expr = analyzeExistenceDefaultValue(field)
+
         // The expression should be a literal value by this point, possibly wrapped in a cast
         // function. This is enforced by the execution of commands that assign default values.
         expr.eval()
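
For illustration, here is a minimal sketch of what the new Literal.fromSQL fast path accepts, e.g. in a REPL session on a build that includes this change. The default-value strings below are hypothetical examples, not values taken from the patch; the sketch only relies on the behavior visible above (parse the expression, resolve built-in functions against FunctionRegistry.builtin, and attach the session timezone to timezone-dependent casts):

    import org.apache.spark.sql.catalyst.expressions.Literal

    // A plain literal parses directly to a resolved Literal node.
    val plain = Literal.fromSQL("42")

    // A cast over a literal stays a Cast; fromSQL attaches the session-local
    // timezone when the cast needs one, so the result is still resolved.
    val casted = Literal.fromSQL("CAST('2020-10-31 01:02:03' AS TIMESTAMP)")

    // A built-in function call is resolved against FunctionRegistry.builtin
    // without running the full analyzer.
    val fromBuiltin = Literal.fromSQL("array(1, 2, 3)")

    assert(plain.resolved && casted.resolved && fromBuiltin.resolved)

Because fromSQL only handles single-part, built-in function names (enforced by the asserts on UnresolvedFunction), analyzeExistenceDefaultValue can turn a stored EXISTS_DEFAULT string back into an evaluable expression without re-running the full analysis, which is what lets getExistenceDefaultValues drop the old analyze-and-catch path.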