Skip to content

[SPARK-52582][SQL] Improve the memory usage of XML parser #51287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import org.apache.spark.{SparkIllegalArgumentException, SparkUpgradeException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericInternalRow}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, DateFormatter, DropMalformedMode, FailureSafeParser, GenericArrayData, MapData, ParseMode, PartialResultArrayException, PartialResultException, PermissiveMode, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, DateFormatter, FailureSafeParser, GenericArrayData, MapData, ParseMode, PartialResultArrayException, PartialResultException, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
import org.apache.spark.sql.catalyst.xml.StaxXmlParser.convertStream
import org.apache.spark.sql.errors.QueryExecutionErrors
Expand Down Expand Up @@ -118,23 +118,11 @@ class StaxXmlParser(
}.flatten
}

def parseColumn(xml: String, schema: StructType): InternalRow = {
  // The user-specified schema from from_xml, etc. will typically not include a
  // "corrupted record" column. In PERMISSIVE mode, which puts bad records in
  // such a column, this would cause an error. In this mode, if such a column
  // is not manually specified, then fall back to DROPMALFORMED, which will return
  // null column values where parsing fails.
  val parseMode =
    if (options.parseMode == PermissiveMode &&
      !schema.fields.exists(_.name == options.columnNameOfCorruptRecord)) {
      DropMalformedMode
    } else {
      options.parseMode
    }
  // Load the XSD schema for row validation only when a path has been configured.
  val xsdSchema = Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema)
  doParseColumn(xml, parseMode, xsdSchema).orNull
}

/**
* Parse the given XML string record as an InternalRow
* @param xml The single XML record string to parse
* @param xsdSchema The xsd schema to validate the XML against, if provided.
*/
def doParseColumn(xml: String,
parseMode: ParseMode,
xsdSchema: Option[Schema]): Option[InternalRow] = {
Expand Down Expand Up @@ -186,6 +174,94 @@ class StaxXmlParser(
}
}

/**
* The optimized version of the XML stream parser that reads XML records from the input file
* stream sequentially without loading each individual XML record string into memory.
*/
/**
 * The optimized version of the XML stream parser that reads XML records from the input file
 * stream sequentially without loading each individual XML record string into memory.
 *
 * @param inputStream the raw XML file stream.
 * @param schema the schema to parse each record into.
 * @param streamLiteral lazily produces the whole file content as a UTF8String; only invoked
 *                      when a record is malformed and must be attached to a
 *                      BadRecordException.
 * @return an iterator of parsed rows.
 */
def parseStreamOptimized(
    inputStream: InputStream,
    schema: StructType,
    streamLiteral: () => UTF8String): Iterator[InternalRow] = {
  // XSD validation would require converting to string first, which defeats the purpose.
  // For now, skip XSD validation in the optimized parsing mode to maintain memory efficiency.
  if (Option(options.rowValidationXSDPath).isDefined) {
    logWarning("XSD validation is not supported in streaming mode and will be skipped")
  }

  val safeParser = new FailureSafeParser[XMLEventReader](
    input => {
      // The first event is guaranteed to be a StartElement, so we can read attributes from it
      // without using StaxXmlParserUtils.skipUntil.
      val attributes = input.nextEvent().asStartElement().getAttributes.asScala.toArray
      doParseColumn(input, attributes, streamLiteral)
    },
    options.parseMode,
    schema,
    options.columnNameOfCorruptRecord
  )

  // The tokenizer positions the shared event reader at each successive row start element.
  val xmlTokenizer = new OptimizedXmlTokenizer(inputStream, options)
  StaxXmlParser.convertStream(xmlTokenizer) { tokens =>
    safeParser.parse(tokens)
  }.flatten
}

/**
 * Parse the next XML record from the event stream.
 * @param parser The XML event reader over the entire XML file stream. The first event has been
 *               advanced to the next record in the file.
 * @param rootAttributes The attributes of the record root element.
 * @param xmlLiteral A function that returns the entire XML file content as a UTF8String. Used
 *                   to create a BadRecordException in case of parsing errors.
 *                   TODO: Only include the file content starting with the current record.
 * @return Some(row) on success; parse failures are rethrown as BadRecordException for the
 *         surrounding FailureSafeParser to handle according to the parse mode.
 */
def doParseColumn(
    parser: XMLEventReader,
    rootAttributes: Array[Attribute],
    xmlLiteral: () => UTF8String): Option[InternalRow] = {
  try {
    options.singleVariantColumn match {
      case Some(_) =>
        // If the singleVariantColumn is specified, parse the entire xml record as a Variant
        val v = StaxXmlParser.parseVariant(parser, rootAttributes, options)
        Some(InternalRow(v))
      case _ =>
        // Otherwise, parse the xml record as Structs
        val result = Some(convertObject(parser, schema, rootAttributes))
        result
    }
  } catch {
    // Upgrade-related errors must surface to the user unchanged.
    case e: SparkUpgradeException => throw e
    case e@(_: RuntimeException | _: XMLStreamException | _: MalformedInputException
        | _: SAXException) =>
      // Skip rest of the content in the parser and put the whole XML file in the
      // BadRecordException.
      parser.close()
      // XML parser currently doesn't support partial results for corrupted records.
      // For such records, all fields other than the field configured by
      // `columnNameOfCorruptRecord` are set to `null`.
      throw BadRecordException(xmlLiteral, () => Array.empty, e)
    case e: CharConversionException if options.charset.isEmpty =>
      // No charset was set explicitly; hint that specifying one may resolve the issue.
      val msg =
        """XML parser cannot handle a character in its input.
          |Specifying encoding as an input option explicitly might help to resolve the issue.
          |""".stripMargin + e.getMessage
      val wrappedCharException = new CharConversionException(msg)
      wrappedCharException.initCause(e)
      throw BadRecordException(xmlLiteral, () => Array.empty,
        wrappedCharException)
    case PartialResultException(row, cause) =>
      // A single partially-parsed row is preserved for the corrupt-record column.
      throw BadRecordException(
        record = xmlLiteral,
        partialResults = () => Array(row),
        cause)
    case PartialResultArrayException(rows, cause) =>
      // Multiple partially-parsed rows are preserved for the corrupt-record column.
      throw BadRecordException(
        record = xmlLiteral,
        partialResults = () => rows,
        cause)
  }
}

/**
* Parse the current token (and related children) according to a desired schema
*/
Expand Down Expand Up @@ -637,6 +713,14 @@ class StaxXmlParser(
}
}

/**
 * Common interface for tokenizers that split an XML input stream into per-record tokens
 * of type `T` (e.g. record strings, or positioned event readers).
 */
trait XmlTokenizerBase[T] extends Logging {
  /**
   * Finds the next XML record in the stream.
   * @return an Option containing the next XML record token, or None if no more records
   */
  def next(): Option[T]
}

/**
* XMLRecordReader class to read through a given xml document to output xml blocks as records
* as specified by the start tag and end tag.
Expand All @@ -645,7 +729,7 @@ class StaxXmlParser(
*/
class XmlTokenizer(
inputStream: InputStream,
options: XmlOptions) extends Logging {
options: XmlOptions) extends XmlTokenizerBase[String] {
private var reader = new BufferedReader(
new InputStreamReader(inputStream, Charset.forName(options.charset)))
private var currentStartTag: String = _
Expand All @@ -665,7 +749,7 @@ class XmlTokenizer(
* @param value the object that will be written
* @return whether it reads successfully
*/
def next(): Option[String] = {
override def next(): Option[String] = {
var nextString: Option[String] = None
try {
if (readUntilStartElement()) {
Expand Down Expand Up @@ -898,6 +982,84 @@ class XmlTokenizer(
}
}

/**
 * Optimized XML tokenizer that avoids loading entire XML records into memory.
 * - Uses XMLEventReader to parse XML stream directly
 * - Never buffers complete XML records in memory
 * - Allows the parser to work directly with XML events
 */
class OptimizedXmlTokenizer(inputStream: InputStream, options: XmlOptions)
  extends XmlTokenizerBase[XMLEventReader] {
  // Filtered StAX reader over the whole stream; set to null once closed.
  private var reader = StaxXmlParserUtils.filteredReader(inputStream, options)

  /**
   * Returns the next XML record as a positioned XMLEventReader.
   * This avoids creating intermediate string representations.
   * @return Some(reader) peeked at the next row start element, or None when the stream is
   *         exhausted or a tolerated error ends consumption of this file.
   */
  override def next(): Option[XMLEventReader] = {
    var nextRecord: Option[XMLEventReader] = None
    try {
      // Skip to the next row start element
      if (skipToNextRowStart()) {
        nextRecord = Some(reader)
      }
    } catch {
      case e: FileNotFoundException if options.ignoreMissingFiles =>
        logWarning("Skipping the rest of the content in the missing file", e)
      case NonFatal(e) =>
        // Classify by root cause: storage-layer failures always propagate; parse/IO
        // failures are tolerated only when the corresponding option allows it.
        ExceptionUtils.getRootCause(e) match {
          case _: AccessControlException | _: BlockMissingException =>
            close()
            throw e
          case _: RuntimeException | _: IOException if options.ignoreCorruptFiles =>
            logWarning("Skipping the rest of the content in the corrupted file", e)
          case _: XMLStreamException =>
            // Malformed XML beyond this point; stop consuming the rest of this file.
            logWarning("Skipping the rest of the content in the corrupted file", e)
          case e: Throwable =>
            close()
            throw e
        }
    } finally {
      // No further records (or a swallowed error): release the reader and stream.
      if (nextRecord.isEmpty && reader != null) {
        close()
      }
    }
    nextRecord
  }

  // Idempotent: closes the event reader and underlying stream once, then nulls the reader.
  def close(): Unit = {
    if (reader != null) {
      reader.close()
      inputStream.close()
      reader = null
    }
  }

  /**
   * Skip through the XML stream until we find the next row start element.
   * @return true when the reader is positioned (via peek, not consumed) at a start element
   *         whose name matches options.rowTag; false when the document ends first.
   */
  private def skipToNextRowStart(): Boolean = {
    val rowTagName = options.rowTag
    while (reader.hasNext) {
      val event = reader.peek()
      event match {
        case startElement: StartElement =>
          val elementName = StaxXmlParserUtils.getName(startElement.getName, options)
          if (elementName == rowTagName) {
            return true
          }
        case _: EndDocument =>
          return false
        case _ =>
          // Continue searching
      }
      // if not the event we want, advance the reader
      reader.nextEvent()
    }
    false
  }
}

object StaxXmlParser {
/**
* Parses a stream that contains XML strings and turns it into an iterator of tokens.
Expand All @@ -907,15 +1069,22 @@ object StaxXmlParser {
convertStream(xmlTokenizer)(tokens => tokens)
}

private def convertStream[T](
xmlTokenizer: XmlTokenizer)(
convert: String => T) = new Iterator[T] {
/**
 * Produces an iterator of positioned XMLEventReaders, one per XML record in the input
 * stream, without materializing per-record strings.
 */
def tokenizeStreamOptimized(
    inputStream: InputStream,
    options: XmlOptions): Iterator[XMLEventReader] = {
  convertStream(new OptimizedXmlTokenizer(inputStream, options))(identity)
}

private def convertStream[TokenType, ResultType](
xmlTokenizer: XmlTokenizerBase[TokenType])(
convert: TokenType => ResultType) = new Iterator[ResultType] {

private var nextRecord = xmlTokenizer.next()

override def hasNext: Boolean = nextRecord.nonEmpty

override def next(): T = {
override def next(): ResultType = {
if (!hasNext) {
throw QueryExecutionErrors.endOfStreamError()
}
Expand All @@ -931,9 +1100,17 @@ object StaxXmlParser {
/**
 * Parses an entire XML record string as a single Variant value.
 * @param xml the XML record string.
 * @param options parsing options (charset, row tag, etc.).
 * @return the parsed VariantVal.
 */
def parseVariant(xml: String, options: XmlOptions): VariantVal = {
  val parser = StaxXmlParserUtils.filteredReader(xml)
  try {
    val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser)
    // The event-reader overload already returns a VariantVal; no re-wrapping needed.
    parseVariant(parser, rootAttributes, options)
  } finally {
    // Always release the reader, even when parsing throws.
    parser.close()
  }
}

/**
 * Builds a Variant from an already-positioned XML event reader.
 * @param parser event reader advanced to the record's root element.
 * @param rootAttributes attributes gathered from the record's root element.
 * @param options parsing options.
 * @return the converted value wrapped as a VariantVal.
 */
def parseVariant(
    parser: XMLEventReader,
    rootAttributes: Array[Attribute],
    options: XmlOptions): VariantVal = {
  val builtVariant = convertVariant(parser, rootAttributes, options)
  new VariantVal(builtVariant.getValue, builtVariant.getMetadata)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
*/
package org.apache.spark.sql.catalyst.xml

import java.io.StringReader
import java.io.{InputStreamReader, StringReader}
import java.nio.charset.Charset
import javax.xml.namespace.QName
import javax.xml.stream.{EventFilter, XMLEventReader, XMLInputFactory, XMLStreamConstants}
import javax.xml.stream.events._
Expand All @@ -35,26 +36,33 @@ object StaxXmlParserUtils {
factory
}

// Shared StAX event filter: drops events the parser either ignores (comments,
// processing instructions) or does not support (DTDs, entity declarations/references,
// notation declarations); every other event passes through.
val filter = new EventFilter {
  override def accept(event: XMLEvent): Boolean =
    event.getEventType match {
      // Ignore comments and processing instructions
      case XMLStreamConstants.COMMENT | XMLStreamConstants.PROCESSING_INSTRUCTION => false
      // unsupported events
      case XMLStreamConstants.DTD |
        XMLStreamConstants.ENTITY_DECLARATION |
        XMLStreamConstants.ENTITY_REFERENCE |
        XMLStreamConstants.NOTATION_DECLARATION => false
      case _ => true
    }
}

/**
 * Creates a filtered XMLEventReader over the given XML string, using the shared `filter`
 * to drop comments, processing instructions, and unsupported DTD/entity/notation events.
 */
def filteredReader(xml: String): XMLEventReader = {
  // It does not have to skip for white space, since `XmlInputFormat`
  // always finds the root tag without a heading space.
  val eventReader = factory.createXMLEventReader(new StringReader(xml))
  factory.createFilteredReader(eventReader, filter)
}

/**
 * Creates a filtered XMLEventReader directly over the input stream, decoding with the
 * charset configured in the options. Avoids materializing the whole stream as a String.
 * Note: XMLInputFactory.createXMLEventReader(InputStream, String) takes an encoding *name*,
 * so decoding via an InputStreamReader with an explicit Charset is kept here.
 */
def filteredReader(inputStream: java.io.InputStream, options: XmlOptions): XMLEventReader = {
  val inputStreamReader = new InputStreamReader(inputStream, Charset.forName(options.charset))
  val eventReader = factory.createXMLEventReader(inputStreamReader)
  factory.createFilteredReader(eventReader, filter)
}

def gatherRootAttributes(parser: XMLEventReader): Array[Attribute] = {
val rootEvent =
StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.START_ELEMENT)
Expand Down
Loading