Commit ccfb64b

wengh and allisonwang-db authored and committed
[SPARK-51439][SQL] Support SQL UDF with DEFAULT argument
Continuing allisonwang-db's work on #50373 and #49471, this PR adds support for DEFAULT arguments in SQL UDFs. Examples:

```sql
CREATE FUNCTION foo1d1(a INT DEFAULT 10) RETURNS INT RETURN a;
SELECT foo1d1();    -- 10
SELECT foo1d1(20);  -- 20

CREATE FUNCTION foo1d6(a INT, b INT DEFAULT 7) RETURNS TABLE(a INT, b INT) RETURN SELECT a, b;
SELECT * FROM foo1d6(5);     -- 5, 7
SELECT * FROM foo1d6(5, 2);  -- 5, 2
```

See sql-udf.sql for more valid and invalid examples.

The change is needed to support default arguments in SQL UDFs.

This is a user-facing change: SQL UDFs now support DEFAULT arguments. A side effect of the grammar change is that some invalid function parameter definitions are no longer rejected by the grammar but by the parser logic instead, so they now produce different error messages. Examples:

```sql
-- multiple COMMENT or multiple NOT NULL
CREATE TEMPORARY FUNCTION foo(a INT COMMENT 'hello' COMMENT 'world') RETURNS INT RETURN a;

-- before:
-- [PARSE_SYNTAX_ERROR] Syntax error at or near 'COMMENT'. SQLSTATE: 42601
-- == SQL (line 2, position 1) ==
-- CREATE TEMPORARY FUNCTION foo(a INT COMMENT 'hello' COMMENT 'world') RETURNS INT RETURN a;
-- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-- after:
-- [CREATE_TABLE_COLUMN_DESCRIPTOR_DUPLICATE] CREATE TABLE column a specifies descriptor "COMMENT" more than once, which is invalid. SQLSTATE: 42710
-- == SQL (line 1, position 1) ==
-- CREATE TEMPORARY FUNCTION foo(a INT COMMENT 'hello' COMMENT 'world')...
-- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```

```sql
-- GENERATED ALWAYS AS
CREATE TEMPORARY FUNCTION foo(a INT GENERATED ALWAYS AS (1)) RETURNS INT RETURN a;

-- before:
-- [PARSE_SYNTAX_ERROR] Syntax error at or near 'GENERATED'. SQLSTATE: 42601
-- == SQL (line 2, position 1) ==
-- CREATE TEMPORARY FUNCTION foo(a INT GENERATED ALWAYS AS (1)) RETURNS INT RETURN a;
-- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-- after:
-- [INVALID_SQL_SYNTAX.CREATE_FUNC_WITH_GENERATED_COLUMNS_AS_PARAMETERS] Invalid SQL syntax: CREATE FUNCTION with generated columns as parameters is not allowed. SQLSTATE: 42000
-- == SQL (line 2, position 1) ==
-- CREATE TEMPORARY FUNCTION foo(a INT GENERATED ALWAYS AS (1)) RETURNS INT RETURN a;
-- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```

This doesn't change the behavior of existing valid SQL.

Tested with end-to-end regression tests in `sql-udf.sql` and simple tests in `SQLFunctionSuite`. Not authored using generative AI tooling.

Closes #50408 from wengh/sql-udf-default.

Lead-authored-by: Haoyu Weng <[email protected]>
Co-authored-by: Allison Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
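As a quick illustration of the new behavior, a minimal sketch (not part of the commit) that exercises the commit-message examples from Scala, assuming a local `SparkSession`:

```scala
import org.apache.spark.sql.SparkSession

object SqlUdfDefaultDemo extends App {
  val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()

  // A scalar SQL UDF whose single argument has a DEFAULT.
  spark.sql("CREATE FUNCTION foo1d1(a INT DEFAULT 10) RETURNS INT RETURN a")
  spark.sql("SELECT foo1d1()").show()   // 10: the default fills the omitted argument
  spark.sql("SELECT foo1d1(20)").show() // 20: an explicit argument wins

  spark.stop()
}
```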
1 parent 1bd4e4e commit ccfb64b

File tree: 20 files changed, +958 −52 lines

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 10 additions & 0 deletions
```diff
@@ -3366,6 +3366,16 @@
           "ANALYZE TABLE(S) ... COMPUTE STATISTICS ... <ctx> must be either NOSCAN or empty."
         ]
       },
+      "CREATE_FUNC_WITH_COLUMN_CONSTRAINTS" : {
+        "message" : [
+          "CREATE FUNCTION with constraints on parameters is not allowed."
+        ]
+      },
+      "CREATE_FUNC_WITH_GENERATED_COLUMNS_AS_PARAMETERS" : {
+        "message" : [
+          "CREATE FUNCTION with generated columns as parameters is not allowed."
+        ]
+      },
       "CREATE_ROUTINE_WITH_IF_NOT_EXISTS_AND_REPLACE" : {
         "message" : [
           "Cannot create a routine with both IF NOT EXISTS and REPLACE specified."
```

sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4

Lines changed: 4 additions & 0 deletions
```diff
@@ -177,6 +177,10 @@ singleTableSchema
     : colTypeList EOF
     ;
 
+singleRoutineParamList
+    : colDefinitionList EOF
+    ;
+
 statement
     : query #statementDefault
     | executeImmediate #visitExecuteImmediate
```

sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala

Lines changed: 8 additions & 0 deletions
```diff
@@ -127,6 +127,14 @@ trait SparkParserUtils {
     }
   }
 
+  /** Get the code that creates the given node. */
+  def source(ctx: ParserRuleContext): String = {
+    // Note: `exprCtx.getText` returns a string without spaces, so we need to
+    // get the text from the underlying char stream instead.
+    val stream = ctx.getStart.getInputStream
+    stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex))
+  }
+
   /** Convert a string token into a string. */
   def string(token: Token): String = unescapeSQLString(token.getText)
```
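The reason for reading from the char stream rather than calling `getText` on the context is that the lexer skips whitespace, so `getText` concatenates token texts with no separators. A standalone sketch of the same `Interval`-based extraction (not from the commit; the hard-coded offsets stand in for what a real `ParserRuleContext` supplies via `getStart`/`getStop`):

```scala
import org.antlr.v4.runtime.CharStreams
import org.antlr.v4.runtime.misc.Interval

object SourceSketch extends App {
  // For "a INT DEFAULT 1 + 2", a context covering the default expression
  // would span character offsets 14..18 of the input.
  val stream = CharStreams.fromString("a INT DEFAULT 1 + 2")
  // Reading the raw span preserves the user's spacing ("1 + 2"), whereas
  // getText on the parsed context would yield "1+2".
  println(stream.getText(Interval.of(14, 18))) // prints: 1 + 2
}
```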

sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala

Lines changed: 10 additions & 0 deletions
```diff
@@ -656,6 +656,16 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
       ctx)
   }
 
+  def createFuncWithGeneratedColumnsError(ctx: ParserRuleContext): Throwable = {
+    new ParseException(
+      errorClass = "INVALID_SQL_SYNTAX.CREATE_FUNC_WITH_GENERATED_COLUMNS_AS_PARAMETERS",
+      ctx)
+  }
+
+  def createFuncWithConstraintError(ctx: ParserRuleContext): Throwable = {
+    new ParseException(errorClass = "INVALID_SQL_SYNTAX.CREATE_FUNC_WITH_COLUMN_CONSTRAINTS", ctx)
+  }
+
   def defineTempFuncWithIfNotExistsError(ctx: ParserRuleContext): Throwable = {
     new ParseException(errorClass = "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_IF_NOT_EXISTS", ctx)
   }
```
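A hedged sketch of how these new parse errors surface to users (assuming a `SparkSession` named `spark` is in scope), matching the before/after examples in the commit message:

```scala
import org.apache.spark.sql.catalyst.parser.ParseException

// Generated-column syntax on a UDF parameter is now caught by the parser
// logic and reported with a dedicated error condition instead of a bare
// PARSE_SYNTAX_ERROR.
try {
  spark.sql(
    "CREATE TEMPORARY FUNCTION foo(a INT GENERATED ALWAYS AS (1)) RETURNS INT RETURN a")
} catch {
  case e: ParseException =>
    // Expected to mention
    // INVALID_SQL_SYNTAX.CREATE_FUNC_WITH_GENERATED_COLUMNS_AS_PARAMETERS.
    println(e.getMessage)
}
```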

sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala

Lines changed: 10 additions & 3 deletions
```diff
@@ -150,10 +150,13 @@ case class StructField(
   /**
    * Return the default value of this StructField. This is used for storing the default value of a
    * function parameter.
+   *
+   * It is present when the field represents a function parameter with a default value, such as
+   * `CREATE FUNCTION f(arg INT DEFAULT 42) RETURN ...`.
    */
   private[sql] def getDefault(): Option[String] = {
-    if (metadata.contains("default")) {
-      Option(metadata.getString("default"))
+    if (metadata.contains(StructType.SQL_FUNCTION_DEFAULT_METADATA_KEY)) {
+      Option(metadata.getString(StructType.SQL_FUNCTION_DEFAULT_METADATA_KEY))
     } else {
       None
     }
@@ -183,6 +186,9 @@ case class StructField(
 
   /**
    * Return the current default value of this StructField.
+   *
+   * It is present only when the field represents a table column with a default value, such as:
+   * `ALTER TABLE t ALTER COLUMN c SET DEFAULT 42`.
    */
   def getCurrentDefaultValue(): Option[String] = {
     if (metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY)) {
@@ -214,7 +220,8 @@ case class StructField(
     }
   }
 
-  private def getDDLDefault = getCurrentDefaultValue()
+  private def getDDLDefault = getDefault()
+    .orElse(getCurrentDefaultValue())
     .map(" DEFAULT " + _)
     .getOrElse("")
```
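A short sketch (not from the commit) of the round trip this enables, using only public `StructField`/`MetadataBuilder` APIs: the parser stores the parameter default's original SQL text under the `"default"` metadata key, and `toDDL` now renders it, falling back to the table-column current default when absent.

```scala
import org.apache.spark.sql.types._

object DefaultMetadataSketch extends App {
  // How the parser records `a INT DEFAULT 10` for a SQL UDF parameter: the
  // original SQL text of the default goes into the field metadata under
  // StructType.SQL_FUNCTION_DEFAULT_METADATA_KEY, whose value is "default".
  val metadata = new MetadataBuilder().putString("default", "10").build()
  val param = StructField("a", IntegerType, nullable = true, metadata)
  // With this change, toDDL consults getDefault() first, so the rendered DDL
  // should include the default clause, e.g. "a INT DEFAULT 10".
  println(param.toDDL)
}
```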

sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala

Lines changed: 1 addition & 0 deletions
```diff
@@ -521,6 +521,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
  */
 @Stable
 object StructType extends AbstractDataType {
+  private[sql] val SQL_FUNCTION_DEFAULT_METADATA_KEY = "default"
 
   override private[sql] def defaultConcreteType: DataType = new StructType
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -207,7 +207,7 @@ object SQLFunction {
     val returnType = parseReturnTypeText(props(RETURN_TYPE), isTableFunc, parser)
     SQLFunction(
       name = function.identifier,
-      inputParam = props.get(INPUT_PARAM).map(parseTableSchema(_, parser)),
+      inputParam = props.get(INPUT_PARAM).map(parseRoutineParam(_, parser)),
       returnType = returnType.get,
       exprText = props.get(EXPRESSION),
       queryText = props.get(QUERY),
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala

Lines changed: 5 additions & 0 deletions
```diff
@@ -86,6 +86,11 @@ object UserDefinedFunction {
   // The default Hive Metastore SQL schema length for function resource uri.
   private val HIVE_FUNCTION_RESOURCE_URI_LENGTH_THRESHOLD: Int = 4000
 
+  def parseRoutineParam(text: String, parser: ParserInterface): StructType = {
+    val parsed = parser.parseRoutineParam(text)
+    CharVarcharUtils.failIfHasCharVarchar(parsed).asInstanceOf[StructType]
+  }
+
   def parseTableSchema(text: String, parser: ParserInterface): StructType = {
     val parsed = parser.parseTableSchema(text)
     CharVarcharUtils.failIfHasCharVarchar(parsed).asInstanceOf[StructType]
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSqlParser.scala

Lines changed: 8 additions & 0 deletions
```diff
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{CompoundPlanStatement, Logic
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.errors.QueryParsingErrors
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
 
 /**
  * Base class for all ANTLR4 [[ParserInterface]] implementations.
@@ -102,6 +103,13 @@ abstract class AbstractSqlParser extends AbstractParser with ParserInterface {
     }
   }
 
+  override def parseRoutineParam(sqlText: String): StructType = parse(sqlText) { parser =>
+    val ctx = parser.singleRoutineParamList()
+    withErrorHandling(ctx, Some(sqlText)) {
+      astBuilder.visitSingleRoutineParamList(ctx)
+    }
+  }
+
   def withErrorHandling[T](ctx: ParserRuleContext, sqlText: Option[String])(toResult: => T): T = {
     withOrigin(ctx, sqlText) {
       try {
```
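A hedged usage sketch of the new entry point (assuming the `CatalystSqlParser` singleton, which extends `AbstractSqlParser`, picks up this override): parsing a persisted parameter-list string back into a `StructType`, with defaults and comments landing in field metadata.

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

object ParseRoutineParamSketch extends App {
  // Round-trips the parameter text that the catalog persists for a SQL UDF.
  val params = CatalystSqlParser.parseRoutineParam(
    "a INT DEFAULT 10, b STRING COMMENT 'name of the user'")
  params.foreach { field =>
    // Expected metadata: {"default":"10"} for a, {"comment":"..."} for b.
    println(s"${field.name}: ${field.dataType.simpleString} ${field.metadata.json}")
  }
}
```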

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 39 additions & 9 deletions
```diff
@@ -25,7 +25,6 @@ import scala.jdk.CollectionConverters._
 import scala.util.{Left, Right}
 
 import org.antlr.v4.runtime.{ParserRuleContext, RuleContext, Token}
-import org.antlr.v4.runtime.misc.Interval
 import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}
 
 import org.apache.spark.{SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkThrowable, SparkThrowableHelper}
@@ -117,6 +116,21 @@ class AstBuilder extends DataTypeAstBuilder
     }
   }
 
+  /**
+   * Retrieves the original input text for a given parser context, preserving all whitespace and
+   * formatting.
+   *
+   * ANTLR's default getText method removes whitespace because lexer rules typically skip it.
+   * This utility method extracts the exact text from the original input stream, using token
+   * indices.
+   *
+   * @param ctx The parser context to retrieve original text from.
+   * @return The original input text, including all whitespaces and formatting.
+   */
+  private def getOriginalText(ctx: ParserRuleContext): String = {
+    SparkParserUtils.source(ctx)
+  }
+
   /**
    * Override the default behavior for all visit methods. This will only return a non-null result
    * when the context has only one child. This is done because there is no generic method to
@@ -3848,6 +3862,29 @@ class AstBuilder extends DataTypeAstBuilder
    * DataType parsing
    * ******************************************************************************************** */
 
+  override def visitSingleRoutineParamList(
+      ctx: SingleRoutineParamListContext): StructType = withOrigin(ctx) {
+    val cols = visitColDefinitionList(ctx.colDefinitionList())
+    // Generated columns should have been rejected by the parser.
+    for (col <- cols) {
+      assert(col.generationExpression.isEmpty)
+      assert(col.identityColumnSpec.isEmpty)
+    }
+    // Build fields from the columns, converting comments and default values
+    val fields = for (col <- cols) yield {
+      val metadataBuilder = new MetadataBuilder().withMetadata(col.metadata)
+      col.comment.foreach { c =>
+        metadataBuilder.putString("comment", c)
+      }
+      col.defaultValue.foreach { default =>
+        metadataBuilder.putString(
+          StructType.SQL_FUNCTION_DEFAULT_METADATA_KEY, default.originalSQL)
+      }
+      StructField(col.name, col.dataType, col.nullable, metadataBuilder.build())
+    }
+    StructType(fields.toArray)
+  }
+
   /**
    * Create top level table schema.
    */
@@ -3950,14 +3987,7 @@ class AstBuilder extends DataTypeAstBuilder
     if (expr.containsPattern(PARAMETER)) {
       throw QueryParsingErrors.parameterMarkerNotAllowed(place, expr.origin)
    }
-    // Extract the raw expression text so that we can save the user provided text. We don't
-    // use `Expression.sql` to avoid storing incorrect text caused by bugs in any expression's
-    // `sql` method. Note: `exprCtx.getText` returns a string without spaces, so we need to
-    // get the text from the underlying char stream instead.
-    val start = exprCtx.getStart.getStartIndex
-    val end = exprCtx.getStop.getStopIndex
-    val originalSQL = exprCtx.getStart.getInputStream.getText(new Interval(start, end))
-    DefaultValueExpression(expr, originalSQL)
+    DefaultValueExpression(expr, getOriginalText(exprCtx))
   }
 
   /**
```
