2 changes: 1 addition & 1 deletion src/main/scala/org/apache/spark/sql/OapExtensions.scala
@@ -29,6 +29,6 @@ class OapExtensions extends (SparkSessionExtensions => Unit) {
extensions.injectPlannerStrategy(_ => OapGroupAggregateStrategy)
extensions.injectPlannerStrategy(_ => OapFileSourceStrategy)
// OAP custom SQL parser.
extensions.injectParser((session, _) => new OapSparkSqlParser(session))
extensions.injectParser((session, delegate) => new OapSparkSqlParser(session, delegate))
}
}
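For context, a minimal usage sketch (object name illustrative, OAP assumed on the classpath; not part of this diff). Registering the extension via `spark.sql.extensions` makes Spark build its default parser first and hand it to `OapSparkSqlParser` as the `delegate` above:

import org.apache.spark.sql.SparkSession

object OapExtensionsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .config("spark.sql.extensions", "org.apache.spark.sql.OapExtensions")
      .getOrCreate()

    // Statements the OAP grammar does not recognize fall through to the
    // delegate (Spark's default parser), so ordinary SQL keeps working.
    spark.sql("SELECT 1").show()
  }
}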
237 changes: 221 additions & 16 deletions src/main/scala/org/apache/spark/sql/execution/OapSparkSqlParser.scala
@@ -17,32 +17,109 @@

package org.apache.spark.sql.execution

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.parser._
import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
import org.apache.spark.sql.catalyst.plans.logical._
import scala.collection.JavaConverters._

import org.antlr.v4.runtime._
import org.antlr.v4.runtime.atn.PredictionMode
import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException}
import org.antlr.v4.runtime.tree._

import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.{OapSqlBaseBaseListener, OapSqlBaseBaseVisitor, OapSqlBaseLexer, OapSqlBaseParser, ParseErrorListener, ParseException, ParserInterface}
import org.apache.spark.sql.catalyst.parser.OapSqlBaseParser._
import org.apache.spark.sql.catalyst.parser.ParserUtils._
import org.apache.spark.sql.catalyst.parser.ParserUtils.{string, withOrigin}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.execution.datasources.oap.index._
import org.apache.spark.sql.internal.{SQLConf, VariableSubstitution}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

/**
* Concrete parser for Spark SQL statements.
* A SQL parser that first tries to parse OAP commands; if it fails to parse the given SQL
* text, it forwards the call to `delegate`.
*/
class OapSparkSqlParser(session: SparkSession) extends AbstractSqlParser {
lazy val conf = session.sessionState.conf
lazy val astBuilder = new OapSparkSqlAstBuilder(conf)
class OapSparkSqlParser(session: SparkSession, delegate: ParserInterface) extends ParserInterface {

lazy val conf: SQLConf = session.sessionState.conf
lazy val builder = new OapSqlBaseAstBuilder(conf)


/**
* Tries the OAP grammar first. A statement the OAP grammar does not recognize
* matches its pass-through rule and comes back as null rather than a
* `LogicalPlan`, so the match below falls back to the delegate parser.
*/
override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
builder.visitSingleStatement(parser.singleStatement()) match {
case plan: LogicalPlan => plan
case _ => delegate.parsePlan(sqlText)
}
}
// scalastyle:off line.size.limit
/**
* Fork from `org.apache.spark.sql.catalyst.parser.AbstractSqlParser#parse(java.lang.String, scala.Function1)`.
*
* @see https://github.com/apache/spark/blob/v2.4.4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala#L81
*/
// scalastyle:on
protected def parse[T](command: String)(toResult: OapSqlBaseParser => T): T = {
val lexer = new OapSqlBaseLexer(
new UpperCaseCharStream(CharStreams.fromString(command)))
lexer.removeErrorListeners()
lexer.addErrorListener(ParseErrorListener)

private lazy val substitutor = new VariableSubstitution(conf)
val tokenStream = new CommonTokenStream(lexer)
val parser = new OapSqlBaseParser(tokenStream)
parser.addParseListener(PostProcessor)
parser.removeErrorListeners()
parser.addErrorListener(ParseErrorListener)

protected override def parse[T](command: String)(toResult: SqlBaseParser => T): T = {
super.parse(substitutor.substitute(command))(toResult)
try {
try {
// first, try parsing with potentially faster SLL mode
parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
toResult(parser)
} catch {
case e: ParseCancellationException =>
// if we fail, parse with LL mode
tokenStream.seek(0) // rewind input stream
parser.reset()

// Try again.
parser.getInterpreter.setPredictionMode(PredictionMode.LL)
toResult(parser)
}
} catch {
case e: ParseException if e.command.isDefined =>
throw e
case e: ParseException =>
throw e.withCommand(command)
case e: AnalysisException =>
val position = Origin(e.line, e.startPosition)
throw new ParseException(Option(command), e.message, position, position)
}
}
}

override def parseExpression(sqlText: String): Expression = delegate.parseExpression(sqlText)

override def parseTableIdentifier(sqlText: String): TableIdentifier =
delegate.parseTableIdentifier(sqlText)

override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
delegate.parseFunctionIdentifier(sqlText)

override def parseTableSchema(sqlText: String): StructType = delegate.parseTableSchema(sqlText)

override def parseDataType(sqlText: String): DataType = delegate.parseDataType(sqlText)

}
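// The try-SLL-then-LL sequence in parse() above is ANTLR's standard two-stage
// parsing recipe, normally paired with BailErrorStrategy so that SLL failures
// surface as ParseCancellationException. A condensed, grammar-independent
// sketch of the idea, reusing the ANTLR imports above (example object, not
// part of the original change):
private[execution] object TwoStageParseSketch {
  def parse[T](parser: Parser, tokens: BufferedTokenStream)(toResult: => T): T = {
    try {
      // SLL prediction resolves almost all real-world input and is much faster.
      parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
      toResult
    } catch {
      case _: ParseCancellationException =>
        // SLL can reject valid input; rewind and retry with full LL prediction,
        // which is slower but complete.
        tokens.seek(0)
        parser.reset()
        parser.getInterpreter.setPredictionMode(PredictionMode.LL)
        toResult
    }
  }
}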

/**
* Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier.
* Defines how to convert an AST generated from `OapSqlBase.g4` into a `LogicalPlan`. The
* parent class `OapSqlBaseBaseVisitor` declares all the visitXXX methods generated from the
* `#` alternative labels in `OapSqlBase.g4` (such as `#oapCreateIndex`).
*/
class OapSparkSqlAstBuilder(conf: SQLConf) extends SparkSqlAstBuilder(conf) {
import org.apache.spark.sql.catalyst.parser.ParserUtils._
class OapSqlBaseAstBuilder(conf: SQLConf) extends OapSqlBaseBaseVisitor[AnyRef] {

// The grammar's catch-all rule: non-OAP statements match passThrough, and the
// null returned here is what makes OapSparkSqlParser.parsePlan delegate them.
override def visitPassThrough(ctx: PassThroughContext): LogicalPlan = null

/**
* Create an index. This builds a [[CreateIndexCommand]].
@@ -124,4 +201,132 @@ class OapSparkSqlAstBuilder(conf: SQLConf) extends SparkSqlAstBuilder(conf) {
override def visitOapEnableIndex(ctx: OapEnableIndexContext): LogicalPlan =
OapEnableIndexCommand(ctx.IDENTIFIER.getText)

override def visitTableIdentifier(ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) {
TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText))
}

protected def visitNonOptionalPartitionSpec(
ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) {
visitPartitionSpec(ctx).map {
case (key, None) => throw new ParseException(s"Found an empty partition key '$key'.", ctx)
case (key, Some(value)) => key -> value
}
}

override def visitPartitionSpec(
ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) {
val parts = ctx.partitionVal.asScala.map { pVal =>
val name = pVal.identifier.getText
val value = Option(pVal.constant).map(visitStringConstant)
name -> value
}
// Before calling `toMap`, we check for duplicate keys to avoid silently ignoring partition
// values in a partition spec like PARTITION(a='1', b='2', a='3'). The real semantic check
// for partition columns is done in the analyzer (a small illustration follows this method).
checkDuplicateKeys(parts, ctx)
parts.toMap
}
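// Illustration of the duplicate-key guard above (hypothetical helper, not part
// of the original change): for a spec like PARTITION(a='1', b='2', a='3'),
// toMap alone would silently keep only the last value for `a`, which is why
// checkDuplicateKeys runs first.
private def duplicateKeyExample(): Unit = {
  val parts = Seq("a" -> Some("1"), "b" -> Some("2"), "a" -> Some("3"))
  val duplicated = parts.groupBy(_._1).filter(_._2.size > 1).keys
  assert(duplicated.toSeq == Seq("a"))
  assert(parts.toMap == Map("a" -> Some("3"), "b" -> Some("2")))
}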

override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
visit(ctx.statement).asInstanceOf[LogicalPlan]
}

protected def visitStringConstant(ctx: ConstantContext): String = withOrigin(ctx) {
ctx match {
case s: StringLiteralContext => createString(s)
case o => o.getText
}
}

/**
* Concatenates the STRING tokens of a string literal. When
* `spark.sql.parser.escapedStringLiterals` is enabled, backslash escapes are
* kept verbatim (Spark 1.6 behavior) instead of being unescaped.
*/
private def createString(ctx: StringLiteralContext): String = {
if (conf.escapedStringLiterals) {
ctx.STRING().asScala.map(stringWithoutUnescape).mkString
} else {
ctx.STRING().asScala.map(string).mkString
}
}


protected def typedVisit[T](ctx: ParseTree): T = {
ctx.accept(this).asInstanceOf[T]
}
}

// scalastyle:off line.size.limit
/**
* Fork from `org.apache.spark.sql.catalyst.parser.UpperCaseCharStream`.
*
* @see https://github.com/apache/spark/blob/v2.4.4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala#L157
*/
// scalastyle:on
class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream {
override def consume(): Unit = wrapped.consume
override def getSourceName(): String = wrapped.getSourceName
override def index(): Int = wrapped.index
override def mark(): Int = wrapped.mark
override def release(marker: Int): Unit = wrapped.release(marker)
override def seek(where: Int): Unit = wrapped.seek(where)
override def size(): Int = wrapped.size

override def getText(interval: Interval): String = {
// ANTLR 4.7's CodePointCharStream implementations have bugs when
// getText() is called with an empty stream, or intervals where
// the start > end. See
// https://github.com/antlr/antlr4/commit/ac9f7530 for one fix
// that is not yet in a released ANTLR artifact.
if (size() > 0 && (interval.b - interval.a >= 0)) {
wrapped.getText(interval)
} else {
""
}
}

override def LA(i: Int): Int = {
val la = wrapped.LA(i)
if (la == 0 || la == IntStream.EOF) la
else Character.toUpperCase(la)
}
}
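// Quick behavioral check (example object, not part of the original change,
// assuming the ANTLR 4.7 runtime imported above): LA() feeds the lexer
// upper-cased code points so keywords match case-insensitively, while
// getText() still returns the original input unchanged.
private[execution] object UpperCaseCharStreamCheck {
  def demo(): Unit = {
    val stream = new UpperCaseCharStream(CharStreams.fromString("select 1"))
    assert(stream.LA(1) == 'S'.toInt) // the lexer sees "SELECT 1"
    assert(stream.getText(new Interval(0, 5)) == "select") // source preserved
  }
}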

// scalastyle:off line.size.limit
/**
* Fork from `org.apache.spark.sql.catalyst.parser.PostProcessor`.
*
* @see https://github.com/apache/spark/blob/v2.4.4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala#L248
*/
// scalastyle:on
case object PostProcessor extends OapSqlBaseBaseListener {

/** Remove the back ticks from an Identifier. */
override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = {
replaceTokenByIdentifier(ctx, 1) { token =>
// Remove the double back ticks in the string.
token.setText(token.getText.replace("``", "`"))
token
}
}

/** Treat non-reserved keywords as Identifiers. */
override def exitNonReserved(ctx: NonReservedContext): Unit = {
replaceTokenByIdentifier(ctx, 0)(identity)
}

private def replaceTokenByIdentifier(
ctx: ParserRuleContext,
stripMargins: Int)(
f: CommonToken => CommonToken = identity): Unit = {
val parent = ctx.getParent
parent.removeLastChild()
val token = ctx.getChild(0).getPayload.asInstanceOf[Token]
val newToken = new CommonToken(
new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream),
OapSqlBaseParser.IDENTIFIER,
token.getChannel,
token.getStartIndex + stripMargins,
token.getStopIndex - stripMargins)
parent.addChild(new TerminalNodeImpl(f(newToken)))
}
}
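// Worked example (not part of the original change): a quoted identifier written
// as `my``col` reaches exitQuotedIdentifier spanning the outer back ticks; the
// stripMargins offsets drop those, and setText collapses the doubled ticks, so
// downstream visitors see a plain IDENTIFIER token whose text is my`col --
// i.e. "my``col".replace("``", "`") == "my`col".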


