diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala
index 438110d7acc4c..da0d5d6521bd1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala
@@ -155,7 +155,7 @@ case class XmlToStructsEvaluator(
     val xsdSchema = Option(parsedOptions.rowValidationXSDPath).map(ValidatorUtil.getSchema)
 
     new FailureSafeParser[String](
-      input => rawParser.doParseColumn(input, mode, xsdSchema),
+      input => rawParser.doParseColumn(input, xsdSchema),
       mode,
       schema,
       parsedOptions.columnNameOfCorruptRecord)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXMLRecordReader.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXMLRecordReader.scala
new file mode 100644
index 0000000000000..951da0959499e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXMLRecordReader.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.xml
+
+import java.io.InputStream
+import javax.xml.stream.{XMLEventReader, XMLStreamConstants}
+import javax.xml.stream.events.{EndDocument, StartElement, XMLEvent}
+import javax.xml.transform.stax.StAXSource
+import javax.xml.validation.Schema
+
+import scala.util.control.NonFatal
+
+import org.apache.commons.lang3.exception.ExceptionUtils
+import org.apache.hadoop.shaded.com.ctc.wstx.exc.WstxEOFException
+
+import org.apache.spark.internal.Logging
+
+/**
+ * XML record reader that reads the next XML record in the underlying XML stream. It supports
+ * XSD schema validation by maintaining a separate XML reader and keeping it in sync with the
+ * primary XML reader.
+ */
+case class StaxXMLRecordReader(inputStream: () => InputStream, options: XmlOptions)
+  extends XMLEventReader
+  with Logging {
+  // Reader for the XML record parsing.
+  private val in1 = inputStream()
+  private val primaryEventReader = StaxXmlParserUtils.filteredEventReader(in1, options)
+  // Reader for the XSD validation, if an XSD schema is provided.
+  private val in2 = Option(options.rowValidationXSDPath).map(_ => inputStream())
+  private val streamReaderForXSDValidation =
+    in2.map(in => StaxXmlParserUtils.filteredStreamReader(in, options))
+  private val eventReaderForXSDValidation =
+    streamReaderForXSDValidation.map(StaxXmlParserUtils.filteredEventReader)
+
+  final var hasMoreRecord: Boolean = true
+
+  /**
+   * Skip through the XML stream until we find the next row start element.
+   * Returns true if a row start element is found, false if the end of the stream is reached.
+   */
+  def skipToNextRecord(): Boolean = {
+    hasMoreRecord = skipToNextRowStart(primaryEventReader) && eventReaderForXSDValidation.forall(
+      skipToNextRowStart
+    )
+    if (!hasMoreRecord) {
+      closeAllReaders()
+    }
+    hasMoreRecord
+  }
+
+  /**
+   * Skip through the XML stream until we find the next row start element.
+   */
+  private def skipToNextRowStart(reader: XMLEventReader): Boolean = {
+    val rowTagName = options.rowTag
+    try {
+      while (reader.hasNext) {
+        val event = reader.peek()
+        event match {
+          case startElement: StartElement =>
+            val elementName = StaxXmlParserUtils.getName(startElement.getName, options)
+            if (elementName == rowTagName) {
+              return true
+            }
+          case _: EndDocument =>
+            return false
+          case _ =>
+            // Continue searching
+        }
+        // If it is not the event we want, advance the reader.
+        reader.nextEvent()
+      }
+      false
+    } catch {
+      case NonFatal(e) if ExceptionUtils.getRootCause(e).isInstanceOf[WstxEOFException] =>
+        logWarning("Reached end of file while looking for next row start element.")
+        false
+    }
+  }
+
+  def validateXSDSchema(schema: Schema): Unit = {
+    streamReaderForXSDValidation match {
+      case Some(p) =>
+        try {
+          // StAXSource requires the stream reader to start with the START_DOCUMENT or
+          // START_ELEMENT events.
+          def rowTagStarted: Boolean =
+            p.getEventType == XMLStreamConstants.START_ELEMENT &&
+              StaxXmlParserUtils.getName(p.getName, options) == options.rowTag
+          while (!rowTagStarted && p.hasNext) {
+            p.next()
+          }
+          schema.newValidator().validate(new StAXSource(p))
+        } catch {
+          case NonFatal(e) =>
+            try {
+              // If the validation fails, skip the current record by advancing the primary
+              // event reader so that the parser can continue with the next record.
+              primaryEventReader.next()
+            } finally {
+              throw e
+            }
+        }
+      case None => throw new IllegalStateException("XSD validation parser is not initialized")
+    }
+  }
+
+  def closeAllReaders(): Unit = {
+    primaryEventReader.close()
+    streamReaderForXSDValidation.foreach(_.close())
+    eventReaderForXSDValidation.foreach(_.close())
+    in1.close()
+    in2.foreach(_.close())
+    hasMoreRecord = false
+  }
+
+  override def nextEvent(): XMLEvent = primaryEventReader.nextEvent()
+  override def hasNext: Boolean = primaryEventReader.hasNext
+  override def peek(): XMLEvent = primaryEventReader.peek()
+  override def getElementText: String = primaryEventReader.getElementText
+  override def nextTag(): XMLEvent = primaryEventReader.nextTag()
+  override def getProperty(name: String): AnyRef = primaryEventReader.getProperty(name)
+  override def close(): Unit = {}
+  override def next(): AnyRef = primaryEventReader.next()
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
index 00497c1c31f35..34e126f689ed3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
@@ -16,14 +16,13 @@
  */
 package org.apache.spark.sql.catalyst.xml
 
-import java.io.{BufferedReader, CharConversionException, FileNotFoundException, InputStream, InputStreamReader, IOException, StringReader}
-import java.nio.charset.{Charset, MalformedInputException}
+import java.io.{CharConversionException, FileNotFoundException, InputStream, IOException}
+import java.nio.charset.MalformedInputException
 import java.text.NumberFormat
 import java.util
 import java.util.Locale
 import javax.xml.stream.{XMLEventReader, XMLStreamException}
 import javax.xml.stream.events._
-import javax.xml.transform.stream.StreamSource
 import javax.xml.validation.Schema
 
 import scala.collection.mutable.ArrayBuffer
@@ -33,15 +32,14 @@ import scala.util.control.Exception.allCatch
 import scala.util.control.NonFatal
 import scala.xml.SAXException
 
-import org.apache.commons.lang3.exception.ExceptionUtils
-import org.apache.hadoop.hdfs.BlockMissingException
-import org.apache.hadoop.security.AccessControlException
+import com.google.common.io.ByteStreams
+import org.apache.hadoop.shaded.org.apache.commons.lang3.exception.ExceptionUtils
 
 import org.apache.spark.{SparkIllegalArgumentException, SparkUpgradeException}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericInternalRow}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, DateFormatter, DropMalformedMode, FailureSafeParser, GenericArrayData, MapData, ParseMode, PartialResultArrayException, PartialResultException, PermissiveMode, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, DateFormatter, FailureSafeParser, GenericArrayData, MapData, PartialResultArrayException, PartialResultException, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
 import org.apache.spark.sql.catalyst.xml.StaxXmlParser.convertStream
 import org.apache.spark.sql.errors.QueryExecutionErrors
@@ -51,6 +49,7 @@ import org.apache.spark.types.variant.{Variant, VariantBuilder}
 import org.apache.spark.types.variant.VariantBuilder.FieldEntry
 import org.apache.spark.types.variant.VariantUtil
import org.apache.spark.unsafe.types.{UTF8String, VariantVal} +import org.apache.spark.util.Utils class StaxXmlParser( schema: StructType, @@ -90,7 +89,7 @@ class StaxXmlParser( (_: String) => Some(InternalRow.empty) } else { val xsdSchema = Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema) - (input: String) => doParseColumn(input, options.parseMode, xsdSchema) + (input: String) => doParseColumn(input, xsdSchema) } } @@ -102,68 +101,81 @@ class StaxXmlParser( } } - def parseStream( - inputStream: InputStream, - schema: StructType): Iterator[InternalRow] = { + /** + * XML stream parser that reads XML records from the input file stream sequentially without + * loading each individual XML record string into memory. + */ + def parseStream(inputStream: () => InputStream, schema: StructType): Iterator[InternalRow] = { val xsdSchema = Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema) - val safeParser = new FailureSafeParser[String]( - input => doParseColumn(input, options.parseMode, xsdSchema), + val streamLiteral = () => + Utils.tryWithResource( + inputStream() + ) { is => + UTF8String.fromBytes(ByteStreams.toByteArray(is)) + } + val safeParser = new FailureSafeParser[StaxXMLRecordReader]( + input => doParseColumn(input, xsdSchema, shouldSkipToNextRecord = true, streamLiteral), options.parseMode, schema, - options.columnNameOfCorruptRecord) + options.columnNameOfCorruptRecord + ) - val xmlTokenizer = new XmlTokenizer(inputStream, options) - convertStream(xmlTokenizer) { tokens => - safeParser.parse(tokens) + convertStream(inputStream, options) { reader => + safeParser.parse(reader) }.flatten } - def parseColumn(xml: String, schema: StructType): InternalRow = { - // The user=specified schema from from_xml, etc will typically not include a - // "corrupted record" column. In PERMISSIVE mode, which puts bad records in - // such a column, this would cause an error. In this mode, if such a column - // is not manually specified, then fall back to DROPMALFORMED, which will return - // null column values where parsing fails. - val parseMode = - if (options.parseMode == PermissiveMode && - !schema.fields.exists(_.name == options.columnNameOfCorruptRecord)) { - DropMalformedMode - } else { - options.parseMode + /** + * Parse a single XML record string and return an InternalRow. + */ + def doParseColumn(xml: String, xsdSchema: Option[Schema]): Option[InternalRow] = { + val parser = StaxXmlParserUtils.staxXMLRecordReader(xml, options) + try { + doParseColumn( + parser, xsdSchema, shouldSkipToNextRecord = false, () => UTF8String.fromString(xml) + ) + } finally { + parser.close() } - val xsdSchema = Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema) - doParseColumn(xml, parseMode, xsdSchema).orNull } - def doParseColumn(xml: String, - parseMode: ParseMode, - xsdSchema: Option[Schema]): Option[InternalRow] = { - lazy val xmlRecord = UTF8String.fromString(xml) + /** + * Parse the next XML record from the XML event stream. + * Note that the method will **NOT** close the XML event stream as there could have more XML + * records to parse. It's the caller's responsibility to close the stream. + * + * @param parser The XML event reader. + * @param shouldSkipToNextRecord If true, we will skip to the next XML record in the parser + * @param xmlLiteral A function that returns the entire XML file content as a UTF8String. Used + * to create a BadRecordException in case of parsing errors. + * TODO: Only include the file content starting with the current record. 
+ */ + def doParseColumn( + parser: StaxXMLRecordReader, + xsdSchema: Option[Schema] = None, + shouldSkipToNextRecord: Boolean, + xmlLiteral: () => UTF8String): Option[InternalRow] = { try { + if (shouldSkipToNextRecord && !parser.skipToNextRecord()) { + return None + } + xsdSchema.foreach { schema => - schema.newValidator().validate(new StreamSource(new StringReader(xml))) + parser.validateXSDSchema(schema) } options.singleVariantColumn match { case Some(_) => - // If the singleVariantColumn is specified, parse the entire xml string as a Variant - val v = StaxXmlParser.parseVariant(xml, options) + // If the singleVariantColumn is specified, parse the entire xml record as a Variant + val v = StaxXmlParser.parseVariant(parser, options) Some(InternalRow(v)) case _ => - // Otherwise, parse the xml string as Structs - val parser = StaxXmlParserUtils.filteredReader(xml) + // Otherwise, parse the xml record as Structs val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser) val result = Some(convertObject(parser, schema, rootAttributes)) - parser.close() result } } catch { case e: SparkUpgradeException => throw e - case e@(_: RuntimeException | _: XMLStreamException | _: MalformedInputException - | _: SAXException) => - // XML parser currently doesn't support partial results for corrupted records. - // For such records, all fields other than the field configured by - // `columnNameOfCorruptRecord` are set to `null`. - throw BadRecordException(() => xmlRecord, () => Array.empty, e) case e: CharConversionException if options.charset.isEmpty => val msg = """XML parser cannot handle a character in its input. @@ -171,18 +183,34 @@ class StaxXmlParser( |""".stripMargin + e.getMessage val wrappedCharException = new CharConversionException(msg) wrappedCharException.initCause(e) - throw BadRecordException(() => xmlRecord, () => Array.empty, + throw BadRecordException(xmlLiteral, () => Array.empty, wrappedCharException) case PartialResultException(row, cause) => throw BadRecordException( - record = () => xmlRecord, + record = xmlLiteral, partialResults = () => Array(row), cause) case PartialResultArrayException(rows, cause) => - throw BadRecordException( - record = () => xmlRecord, - partialResults = () => rows, - cause) + throw BadRecordException(record = xmlLiteral, partialResults = () => rows, cause) + case e: Throwable => + ExceptionUtils.getRootCause(e) match { + case _: FileNotFoundException if options.ignoreMissingFiles => + logWarning("Skipped missing file", e) + parser.closeAllReaders() + None + case _: IOException | _: RuntimeException if options.ignoreCorruptFiles => + logWarning("Skipped the rest of the content in the corrupted file", e) + parser.closeAllReaders() + None + case _: XMLStreamException | _: MalformedInputException | _: SAXException => + // Skip rest of the content in the parser and put the whole XML file in the + // BadRecordException. + parser.closeAllReaders() + // XML parser currently doesn't support partial results for corrupted records. + // For such records, all fields other than the field configured by + // `columnNameOfCorruptRecord` are set to `null`. + throw BadRecordException(xmlLiteral, () => Array.empty, e) + } } } @@ -637,291 +665,18 @@ class StaxXmlParser( } } -/** - * XMLRecordReader class to read through a given xml document to output xml blocks as records - * as specified by the start tag and end tag. - * - * This implementation is ultimately loosely based on LineRecordReader in Hadoop. 
- */ -class XmlTokenizer( - inputStream: InputStream, - options: XmlOptions) extends Logging { - private var reader = new BufferedReader( - new InputStreamReader(inputStream, Charset.forName(options.charset))) - private var currentStartTag: String = _ - private var buffer = new StringBuilder() - private val startTag = s"<${options.rowTag}>" - private val endTag = s"" - private val commentStart = s"" - private val cdataStart = s"" - - /** - * Finds the start of the next record. - * It treats data from `startTag` and `endTag` as a record. - * - * @param key the current key that will be written - * @param value the object that will be written - * @return whether it reads successfully - */ - def next(): Option[String] = { - var nextString: Option[String] = None - try { - if (readUntilStartElement()) { - buffer.append(currentStartTag) - // Don't check whether the end element was found. Even if not, return everything - // that was read, which will invariably cause a parse error later - readUntilEndElement(currentStartTag.endsWith(">")) - nextString = Some(buffer.toString()) - buffer = new StringBuilder() - } - } catch { - case e: FileNotFoundException if options.ignoreMissingFiles => - logWarning( - "Skipping the rest of" + - " the content in the missing file during schema inference", - e) - case NonFatal(e) => - ExceptionUtils.getRootCause(e) match { - case _: AccessControlException | _: BlockMissingException => - reader.close() - reader = null - throw e - case _: RuntimeException | _: IOException if options.ignoreCorruptFiles => - logWarning( - "Skipping the rest of" + - " the content in the corrupted file during schema inference", - e) - case e: Throwable => - reader.close() - reader = null - throw e - } - } finally { - if (nextString.isEmpty && reader != null) { - reader.close() - reader = null - } - } - nextString - } - - private def readUntilMatch(end: String): Boolean = { - var i = 0 - while (true) { - val cOrEOF = reader.read() - if (cOrEOF == -1) { - // End of file. - return false - } - val c = cOrEOF.toChar - if (c == end(i)) { - i += 1 - if (i >= end.length) { - // Found the end string. - return true - } - } else { - i = 0 - } - } - // Unreachable. - false - } - - private def readUntilStartElement(): Boolean = { - currentStartTag = startTag - var i = 0 - var commentIdx = 0 - var cdataIdx = 0 - - while (true) { - val cOrEOF = reader.read() - if (cOrEOF == -1) { // || (i == 0 && getFilePosition() > end)) { - // End of file or end of split. - return false - } - val c = cOrEOF.toChar - - if (c == commentStart(commentIdx)) { - if (commentIdx >= commentStart.length - 1) { - // If a comment beigns we must ignore all character until its end - commentIdx = 0 - readUntilMatch(commentEnd) - } else { - commentIdx += 1 - } - } else { - commentIdx = 0 - } - - if (c == cdataStart(cdataIdx)) { - if (cdataIdx >= cdataStart.length - 1) { - // If a CDATA beigns we must ignore all character until its end - cdataIdx = 0 - readUntilMatch(cdataEnd) - } else { - cdataIdx += 1 - } - } else { - cdataIdx = 0 - } - - if (c == startTag(i)) { - if (i >= startTag.length - 1) { - // Found start tag. - return true - } - // else in start tag - i += 1 - } else { - // if doesn't match the closing angle bracket, check if followed by attributes - if (i == (startTag.length - 1) && Character.isWhitespace(c)) { - // Found start tag with attributes. 
Remember to write with following whitespace - // char, not angle bracket - currentStartTag = startTag.dropRight(1) + c - return true - } - // else not in start tag - i = 0 - } - } - // Unreachable. - false - } - - private def readUntilEndElement(startTagClosed: Boolean): Boolean = { - // Index into the start or end tag that has matched so far - var si = 0 - var ei = 0 - // Index into the start of a comment tag that matched so far - var commentIdx = 0 - // Index into the start of a CDATA tag that matched so far - var cdataIdx = 0 - // How many other start tags enclose the one that's started already? - var depth = 0 - // Previously read character - var prevC = '\u0000' - - // The current start tag already found may or may not have terminated with - // a '>' as it may have attributes we read here. If not, we search for - // a self-close tag, but only until a non-self-closing end to the start - // tag is found - var canSelfClose = !startTagClosed - - while (true) { - - val cOrEOF = reader.read() - if (cOrEOF == -1) { - // End of file (ignore end of split). - return false - } - - val c = cOrEOF.toChar - buffer.append(c) - - if (c == commentStart(commentIdx)) { - if (commentIdx >= commentStart.length - 1) { - // If a comment beigns we must ignore everything until its end - buffer.setLength(buffer.length - commentStart.length) - commentIdx = 0 - readUntilMatch(commentEnd) - } else { - commentIdx += 1 - } - } else { - commentIdx = 0 - } - - if (c == '>' && prevC != '/') { - canSelfClose = false - } - - // Still matching a start tag? - if (c == startTag(si)) { - // Still also matching an end tag? - if (c == endTag(ei)) { - // In start tag or end tag. - si += 1 - ei += 1 - } else { - if (si >= startTag.length - 1) { - // Found start tag. - si = 0 - ei = 0 - depth += 1 - } else { - // In start tag. - si += 1 - ei = 0 - } - } - } else if (c == endTag(ei)) { - if (ei >= endTag.length - 1) { - if (depth == 0) { - // Found closing end tag. - return true - } - // else found nested end tag. - si = 0 - ei = 0 - depth -= 1 - } else { - // In end tag. - si = 0 - ei += 1 - } - } else if (c == '>' && prevC == '/' && canSelfClose) { - if (depth == 0) { - // found a self-closing tag (end tag) - return true - } - // else found self-closing nested tag (end tag) - si = 0 - ei = 0 - depth -= 1 - } else if (si == (startTag.length - 1) && Character.isWhitespace(c)) { - // found a start tag with attributes - si = 0 - ei = 0 - depth += 1 - } else { - // Not in start tag or end tag. - si = 0 - ei = 0 - } - prevC = c - } - // Unreachable. - false - } -} - object StaxXmlParser { - /** - * Parses a stream that contains CSV strings and turns it into an iterator of tokens. 
- */ - def tokenizeStream(inputStream: InputStream, options: XmlOptions): Iterator[String] = { - val xmlTokenizer = new XmlTokenizer(inputStream, options) - convertStream(xmlTokenizer)(tokens => tokens) - } - - private def convertStream[T]( - xmlTokenizer: XmlTokenizer)( - convert: String => T) = new Iterator[T] { - - private var nextRecord = xmlTokenizer.next() + def convertStream[T](inputStream: () => InputStream, options: XmlOptions)( + convert: StaxXMLRecordReader => T): Iterator[T] = new Iterator[T] { + private val reader = StaxXMLRecordReader(inputStream, options) - override def hasNext: Boolean = nextRecord.nonEmpty + override def hasNext: Boolean = reader.hasMoreRecord override def next(): T = { if (!hasNext) { throw QueryExecutionErrors.endOfStreamError() } - val curRecord = convert(nextRecord.get) - nextRecord = xmlTokenizer.next() - curRecord + convert(reader) } } @@ -929,11 +684,18 @@ object StaxXmlParser { * Parse the input XML string as a Variant value */ def parseVariant(xml: String, options: XmlOptions): VariantVal = { - val parser = StaxXmlParserUtils.filteredReader(xml) + val parser = StaxXmlParserUtils.staxXMLRecordReader(xml, options) + try { + parseVariant(parser, options) + } finally { + parser.close() + } + } + + def parseVariant(parser: XMLEventReader, options: XmlOptions): VariantVal = { val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser) val v = convertVariant(parser, rootAttributes, options) - parser.close() - v + new VariantVal(v.getValue, v.getMetadata) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala index 5d267143b06c9..455fd8dd7918b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala @@ -16,14 +16,15 @@ */ package org.apache.spark.sql.catalyst.xml -import java.io.StringReader import javax.xml.namespace.QName -import javax.xml.stream.{EventFilter, XMLEventReader, XMLInputFactory, XMLStreamConstants} +import javax.xml.stream.{EventFilter, StreamFilter, XMLEventReader, XMLInputFactory, XMLStreamConstants, XMLStreamReader} import javax.xml.stream.events._ import scala.annotation.tailrec import scala.jdk.CollectionConverters._ +import org.apache.commons.io.input.BOMInputStream + object StaxXmlParserUtils { private[sql] val factory: XMLInputFactory = { @@ -35,29 +36,55 @@ object StaxXmlParserUtils { factory } - def filteredReader(xml: String): XMLEventReader = { + private val eventTypeFilter: Int => Boolean = { + // Ignore comments and processing instructions + case XMLStreamConstants.COMMENT | + XMLStreamConstants.PROCESSING_INSTRUCTION => false + // unsupported events + case XMLStreamConstants.DTD | + XMLStreamConstants.ENTITY_DECLARATION | + XMLStreamConstants.ENTITY_REFERENCE | + XMLStreamConstants.NOTATION_DECLARATION => false + case _ => true + } + + def filteredStreamReader( + inputStream: java.io.InputStream, + options: XmlOptions): XMLStreamReader = { + val filter = new StreamFilter { + override def accept(event: XMLStreamReader): Boolean = eventTypeFilter(event.getEventType) + } + val bomInputStreamBuilder = new BOMInputStream.Builder + bomInputStreamBuilder.setInputStream(inputStream) + val streamReader = factory.createXMLStreamReader(bomInputStreamBuilder.get(), options.charset) + factory.createFilteredReader(streamReader, filter) + } + + def 
filteredEventReader(inputStream: java.io.InputStream, options: XmlOptions): XMLEventReader = { + val streamReader = filteredStreamReader(inputStream, options) + filteredEventReader(streamReader) + } + + def filteredEventReader(streamReader: XMLStreamReader): XMLEventReader = { val filter = new EventFilter { - override def accept(event: XMLEvent): Boolean = - event.getEventType match { - // Ignore comments and processing instructions - case XMLStreamConstants.COMMENT | XMLStreamConstants.PROCESSING_INSTRUCTION => false - // unsupported events - case XMLStreamConstants.DTD | - XMLStreamConstants.ENTITY_DECLARATION | - XMLStreamConstants.ENTITY_REFERENCE | - XMLStreamConstants.NOTATION_DECLARATION => false - case _ => true - } + override def accept(event: XMLEvent): Boolean = eventTypeFilter(event.getEventType) } - // It does not have to skip for white space, since `XmlInputFormat` - // always finds the root tag without a heading space. - val eventReader = factory.createXMLEventReader(new StringReader(xml)) + val eventReader = factory.createXMLEventReader(streamReader) factory.createFilteredReader(eventReader, filter) } + def staxXMLRecordReader(xml: String, options: XmlOptions): StaxXMLRecordReader = { + val inputStream = () => new java.io.ByteArrayInputStream(xml.getBytes(options.charset)) + StaxXMLRecordReader(inputStream, options) + } + def gatherRootAttributes(parser: XMLEventReader): Array[Attribute] = { - val rootEvent = - StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.START_ELEMENT) + val rootEvent = parser.peek() match { + case _: StartElement => + parser.nextEvent() + case _ => + StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.START_ELEMENT) + } rootEvent.asStartElement.getAttributes.asScala.toArray } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala index ecde7c1715bd5..056bf0accdb0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala @@ -16,13 +16,11 @@ */ package org.apache.spark.sql.catalyst.xml -import java.io.{CharConversionException, FileNotFoundException, IOException, StringReader} +import java.io.{CharConversionException, FileNotFoundException, IOException} import java.nio.charset.MalformedInputException import java.util.Locale import javax.xml.stream.{XMLEventReader, XMLStreamException} import javax.xml.stream.events._ -import javax.xml.transform.stream.StreamSource -import javax.xml.validation.Schema import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ @@ -32,6 +30,7 @@ import scala.xml.SAXException import org.apache.hadoop.hdfs.BlockMissingException import org.apache.hadoop.security.AccessControlException +import org.apache.hadoop.shaded.org.apache.commons.lang3.exception.ExceptionUtils import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.internal.Logging @@ -94,18 +93,27 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) * 3. 
Replace any remaining null fields with string, the top type */ def infer(xml: RDD[String]): StructType = { - val schemaData = if (options.samplingRatio < 1.0) { - xml.sample(withReplacement = false, options.samplingRatio, 1) + val inferredTypesRdd = xml.mapPartitions { iter => + iter.flatMap { xml => + val parser = StaxXmlParserUtils.staxXMLRecordReader(xml, options) + val inferredType = infer(parser, shouldSkipToNextReader = false) + parser.close() + inferredType + } + } + + mergeType(inferredTypesRdd) + } + + def mergeType(inferredTypes: RDD[DataType]): StructType = { + val sampledRdd = if (options.samplingRatio < 1.0) { + inferredTypes.sample(withReplacement = false, options.samplingRatio, 1) } else { - xml + inferredTypes } // perform schema inference on each row and merge afterwards - val mergedTypesFromPartitions = schemaData.mapPartitions { iter => - val xsdSchema = Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema) - - iter.flatMap { xml => - infer(xml, xsdSchema) - }.reduceOption(compatibleType(caseSensitive, options.valueTag)).iterator + val mergedTypesFromPartitions = sampledRdd.mapPartitions { iter => + iter.reduceOption(compatibleType(caseSensitive, options.valueTag)).iterator } // Here we manually submit a fold-like Spark job, so that we can set the SQLConf when running @@ -119,7 +127,7 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) compatibleType(caseSensitive, options.valueTag)(rootType, taskResult) } } - xml.sparkContext.runJob(mergedTypesFromPartitions, foldPartition, mergeResult) + sampledRdd.sparkContext.runJob(mergedTypesFromPartitions, foldPartition, mergeResult) canonicalizeType(rootType) match { case Some(st: StructType) => st @@ -130,21 +138,37 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) } } - def infer(xml: String, xsdSchema: Option[Schema] = None): Option[DataType] = { - var parser: XMLEventReader = null + /** + * Infer the schema of the single XML record string + */ + def infer(xml: String): Option[DataType] = { + val parser = StaxXmlParserUtils.staxXMLRecordReader(xml, options) try { - val xsd = xsdSchema.orElse(Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema)) + infer(parser, shouldSkipToNextReader = false) + } finally { + parser.close() + } + } + + /** + * Infer the schema of the next XML record in the XML event stream. + * Note that the method will **NOT** close the XML event stream as there could have more XML + * records to parse. It's the caller's responsibility to close the stream. + */ + def infer(parser: StaxXMLRecordReader, shouldSkipToNextReader: Boolean): Option[DataType] = { + try { + if (shouldSkipToNextReader && !parser.skipToNextRecord()) { + return None + } + + val xsd = Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema) xsd.foreach { schema => - schema.newValidator().validate(new StreamSource(new StringReader(xml))) + parser.validateXSDSchema(schema) } - parser = StaxXmlParserUtils.filteredReader(xml) val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser) val schema = Some(inferObject(parser, rootAttributes)) - parser.close() schema } catch { - case e @ (_: XMLStreamException | _: MalformedInputException | _: SAXException) => - handleXmlErrorsByParseMode(options.parseMode, options.columnNameOfCorruptRecord, e) case e: CharConversionException if options.charset.isEmpty => val msg = """XML parser cannot handle a character in its input. 
@@ -160,16 +184,22 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) logWarning("Skipped missing file", e) Some(StructType(Nil)) case e: FileNotFoundException if !options.ignoreMissingFiles => throw e - case e @ (_ : AccessControlException | _ : BlockMissingException) => throw e - case e @ (_: IOException | _: RuntimeException) if options.ignoreCorruptFiles => - logWarning("Skipped the rest of the content in the corrupted file", e) - Some(StructType(Nil)) case NonFatal(e) => - handleXmlErrorsByParseMode(options.parseMode, options.columnNameOfCorruptRecord, e) - } finally { - if (parser != null) { - parser.close() - } + ExceptionUtils.getRootCause(e) match { + case _: XMLStreamException | _: MalformedInputException | _: SAXException => + logWarning("Malformed XML record found", e) + // Close the XML event stream from the first malformed XML record + parser.closeAllReaders() + handleXmlErrorsByParseMode(options.parseMode, options.columnNameOfCorruptRecord, e) + case _: AccessControlException | _: BlockMissingException => throw e + case _: IOException | _: RuntimeException if options.ignoreCorruptFiles => + logWarning("Skipped the rest of the content in the corrupted file", e) + parser.closeAllReaders() + Some(StructType(Nil)) + case _ => + logWarning("Failed to infer schema from XML record", e) + handleXmlErrorsByParseMode(options.parseMode, options.columnNameOfCorruptRecord, e) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala index 23bca35725397..9328dd471f755 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat -import org.apache.spark.sql.types.{StructField, StructType, VariantType} +import org.apache.spark.sql.types.{DataType, StructField, StructType, VariantType} import org.apache.spark.util.Utils /** @@ -176,8 +176,9 @@ object MultiLineXmlDataSource extends XmlDataSource { parser: StaxXmlParser, requiredSchema: StructType): Iterator[InternalRow] = { parser.parseStream( - CodecStreams.createInputStreamWithCloseResource(conf, file.toPath), - requiredSchema) + () => CodecStreams.createInputStreamWithCloseResource(conf, file.toPath), + requiredSchema + ) } override def infer( @@ -186,33 +187,39 @@ object MultiLineXmlDataSource extends XmlDataSource { parsedOptions: XmlOptions): StructType = { val xml = createBaseRdd(sparkSession, inputPaths, parsedOptions) - val tokenRDD: RDD[String] = - xml.flatMap { portableDataStream => + val xmlInferSchema = + new XmlInferSchema(parsedOptions, sparkSession.sessionState.conf.caseSensitiveAnalysis) + + SQLExecution.withSQLConfPropagated(sparkSession) { + val inferredTypeRdd: RDD[DataType] = xml.flatMap { portableDataStream => try { - StaxXmlParser.tokenizeStream( - CodecStreams.createInputStreamWithCloseResource( - portableDataStream.getConfiguration, - new Path(portableDataStream.getPath())), - parsedOptions) + val inputStream = () => CodecStreams.createInputStreamWithCloseResource( + portableDataStream.getConfiguration, + new Path(portableDataStream.getPath()) + ) + + // XML 
tokenizer for parsing XML records + StaxXmlParser + .convertStream(inputStream, parsedOptions) { reader => + xmlInferSchema.infer(reader, shouldSkipToNextReader = true) + } + .flatten } catch { case e: FileNotFoundException if parsedOptions.ignoreMissingFiles => logWarning("Skipped missing file", e) - Iterator.empty[String] + Iterator.empty[DataType] case NonFatal(e) => ExceptionUtils.getRootCause(e) match { case e @ (_ : AccessControlException | _ : BlockMissingException) => throw e case _: RuntimeException | _: IOException if parsedOptions.ignoreCorruptFiles => logWarning("Skipped the rest of the content in the corrupted file", e) - Iterator.empty[String] + Iterator.empty[DataType] case o => throw o } } } - SQLExecution.withSQLConfPropagated(sparkSession) { - val schema = - new XmlInferSchema(parsedOptions, sparkSession.sessionState.conf.caseSensitiveAnalysis) - .infer(tokenRDD) - schema + + xmlInferSchema.mergeType(inferredTypeRdd) } } diff --git a/sql/core/src/test/resources/test-data/xml-resources/books-malformed-attributes.xml b/sql/core/src/test/resources/test-data/xml-resources/books-malformed-attributes.xml index e9830d55d3da7..10b94f956eb32 100644 --- a/sql/core/src/test/resources/test-data/xml-resources/books-malformed-attributes.xml +++ b/sql/core/src/test/resources/test-data/xml-resources/books-malformed-attributes.xml @@ -1,5 +1,15 @@ + + O'Brien, Tim + MSXML3: A Comprehensive Guide + Computer + 36.95 + 2000-12-01 + The Microsoft MSXML3 parser is covered in + detail, with attention to XML DOM interfaces, XSLT processing, + SAX and more. + Kress, Peter Paradox Lost @@ -19,16 +29,6 @@ Microsoft's .NET initiative is explored in detail in this deep programmer's reference. - - O'Brien, Tim - MSXML3: A Comprehensive Guide - Computer - 36.95 - 2000-12-01 - The Microsoft MSXML3 parser is covered in - detail, with attention to XML DOM interfaces, XSLT processing, - SAX and more. 
- Galos, Mike Visual Studio 7: A Comprehensive Guide diff --git a/sql/core/src/test/resources/test-data/xml-resources/cars-malformed.xml b/sql/core/src/test/resources/test-data/xml-resources/cars-malformed.xml index 3859f04fbe199..18b782c0603b4 100644 --- a/sql/core/src/test/resources/test-data/xml-resources/cars-malformed.xml +++ b/sql/core/src/test/resources/test-data/xml-resources/cars-malformed.xml @@ -1,5 +1,10 @@ + + 2015 + Chevy + Volt + 2012 Tesla @@ -12,9 +17,4 @@ E350model> Go get one now they are going fast - - 2015 - Chevy - Volt - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala index 618127fb6e615..be159492197f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala @@ -20,7 +20,6 @@ import java.io.File import java.nio.file.Files import java.util.UUID -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Encoders, QueryTest, Row} import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf @@ -407,9 +406,13 @@ class XmlInferSchemaSuite } test("XML with partitions") { - def makePartition(rdd: RDD[String], parent: File, partName: String, partValue: Any): File = { + def makePartition(rows: Seq[String], parent: File, partName: String, partValue: Any): File = { + val spark = this.spark + import spark.implicits._ val p = new File(parent, s"$partName=${partValue.toString}") - rdd.saveAsTextFile(p.getCanonicalPath) + (Seq("") ++ rows ++ Seq("")).toDF("data") + .coalesce(1) + .write.text(p.getAbsolutePath) p } @@ -418,7 +421,7 @@ class XmlInferSchemaSuite val d1 = new File(root, "d1=1") // root/d1=1/col1=abc makePartition( - sparkContext.parallelize(2 to 5).map(i => s"""1str$i"""), + (2 to 5).map(i => s"""1str$i"""), d1, "col1", "abc" @@ -426,7 +429,7 @@ class XmlInferSchemaSuite // root/d1=1/col1=abd makePartition( - sparkContext.parallelize(6 to 10).map(i => s"""1str$i"""), + (6 to 10).map(i => s"""1str$i"""), d1, "col1", "abd" @@ -649,7 +652,14 @@ trait XmlSchemaInferenceCaseSensitivityTests extends QueryTest { dir: File, multiline: Boolean = true, fileName: String = UUID.randomUUID().toString): String = { - val bytes = if (multiline) xmlString.getBytes() else xmlString.filter(_ >= ' ').getBytes + val xmlStringWithRootTag = + s""" + | + |$xmlString + |""".stripMargin + val bytes = + if (multiline) xmlStringWithRootTag.getBytes() + else xmlStringWithRootTag.filter(_ >= ' ').getBytes Files.write(new File(dir, fileName).toPath, bytes) dir.getCanonicalPath + s"/$fileName" } @@ -657,6 +667,7 @@ trait XmlSchemaInferenceCaseSensitivityTests extends QueryTest { private val valueTagCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = { val caseSensitiveValueTag = """ + | | | | 1 @@ -668,6 +679,7 @@ trait XmlSchemaInferenceCaseSensitivityTests extends QueryTest { | 3 | | + | |""".stripMargin XmlSchemaInferenceCaseSensitiveTestCase( "value tag", @@ -692,6 +704,7 @@ trait XmlSchemaInferenceCaseSensitivityTests extends QueryTest { private val arrayComplexCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = { val caseSensitiveArrayType = """ + | | | | 1 @@ -705,6 +718,7 @@ trait XmlSchemaInferenceCaseSensitivityTests extends QueryTest { | | 5 | + | |""".stripMargin XmlSchemaInferenceCaseSensitiveTestCase( "array 
type - simple", @@ -742,6 +756,7 @@ trait XmlSchemaInferenceCaseSensitivityTests extends QueryTest { private val arraySimpleCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = { val caseSensitiveArrayType = """ + | | | | 1 @@ -752,6 +767,7 @@ trait XmlSchemaInferenceCaseSensitivityTests extends QueryTest { | 4 | | + | |""".stripMargin XmlSchemaInferenceCaseSensitiveTestCase( "array type - complex", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlPartitioningSuite.scala index c08f2d6c329bb..0fc7c84230576 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlPartitioningSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.execution.datasources.xml import org.scalatest.BeforeAndAfterAll import org.scalatest.matchers.should.Matchers - -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.SparkSession /** @@ -36,9 +35,31 @@ final class XmlPartitioningSuite extends SparkFunSuite with Matchers with Before try { val fileName = s"test-data/xml-resources/fias_house${if (large) ".large" else ""}.xml$suffix" val xmlFile = getClass.getClassLoader.getResource(fileName).getFile - val results = spark.read.option("rowTag", "House").option("mode", "FAILFAST").xml(xmlFile) - // Test file has 37 records; large file is 20x the records - assert(results.count() === (if (large) 740 else 37)) + if (large) { + // The large file is invalid because it concatenates several XML files together and thus + // there are more one root tags, and each one has a BOM character at the beginning. 
+ + // In FAILFAST mode, we should throw an exception + val error = intercept[SparkException] { + spark.read.option("rowTag", "House").option("mode", "FAILFAST").xml(xmlFile) + } + checkError( + exception = error, + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + parameters = Map("badRecord" -> "_corrupt_record", "failFastMode" -> "FAILFAST") + ) + + // In PERMISSIVE mode, we should read the records in the first root tag and ignore the rest + // of the content + val results = spark.read.option("rowTag", "House").option("mode", "PERMISSIVE").xml(xmlFile) + // There should be 38 records: 37 valid records in the first root tag and the rest of the + // content in the _corrupt_record column + assert(results.count() === 38) + } else { + val results = spark.read.option("rowTag", "House").option("mode", "FAILFAST").xml(xmlFile) + // Test file has 37 records; large file is 20x the records + assert(results.count() === 37) + } } finally { spark.stop() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index be08ce5bd7db9..12183d0e113d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -59,7 +59,7 @@ class XmlSuite with TestXmlData { import testImplicits._ - private val resDir = "test-data/xml-resources/" + protected val resDir = "test-data/xml-resources/" private var tempDir: Path = _ @@ -238,15 +238,6 @@ class XmlSuite assert(spark.sql("SELECT year FROM carsTable2").collect().length === 3) } - test("DSL test for parsing a malformed XML file") { - val results = spark.read - .option("rowTag", "ROW") - .option("mode", DropMalformedMode.name) - .xml(getTestResourcePath(resDir + "cars-malformed.xml")) - - assert(results.count() === 1) - } - test("DSL test for dropping malformed rows") { val cars = spark.read .option("rowTag", "ROW") @@ -333,21 +324,11 @@ class XmlSuite .option("columnNameOfCorruptRecord", "_malformed_records") .xml(getTestResourcePath(resDir + "cars-malformed.xml")) val cars = carsDf.collect() - assert(cars.length === 3) - val malformedRowOne = carsDf.cache().select("_malformed_records").first().get(0).toString - val malformedRowTwo = carsDf.cache().select("_malformed_records").take(2).last.get(0).toString - val expectedMalformedRowOne = "2012Tesla>S" + - "No comment" - val expectedMalformedRowTwo = "FordE350model>" + - "Go get one now they are going fast" - - assert(malformedRowOne.replaceAll("\\s", "") === expectedMalformedRowOne.replaceAll("\\s", "")) - assert(malformedRowTwo.replaceAll("\\s", "") === expectedMalformedRowTwo.replaceAll("\\s", "")) - assert(cars(2)(0) === null) - assert(cars(0).toSeq.takeRight(3) === Seq(null, null, null)) - assert(cars(1).toSeq.takeRight(3) === Seq(null, null, null)) - assert(cars(2).toSeq.takeRight(3) === Seq("Chevy", "Volt", 2015)) + // There are two records and the second one is malformed. 
+ assert(cars.length === 2) + assert(carsDf.cache().filter("_malformed_records is not null").count() === 1) + assert(cars(0).toSeq.takeRight(3) === Seq("Chevy", "Volt", 2015)) } test("DSL test with empty file and known schema") { @@ -1073,9 +1054,8 @@ class XmlSuite .xml(getTestResourcePath(resDir + "books-malformed-attributes.xml")) .collect() - assert(results.length === 2) + assert(results.length === 1) assert(results(0)(0) === "bk111") - assert(results(1)(0) === "bk112") } test("read utf-8 encoded file with empty tag") { @@ -1256,11 +1236,12 @@ class XmlSuite .option("mode", "PERMISSIVE") .option("columnNameOfCorruptRecord", "_malformed_records") .xml(getTestResourcePath(resDir + "basket_invalid.xml")).cache() + // The first record is valid and the second is invalid, the whole document is put in the + // _malformed_records column for the second record. assert(basketDF.filter($"_malformed_records".isNotNull).count() == 1) assert(basketDF.filter($"_malformed_records".isNull).count() == 1) val rec = basketDF.select("_malformed_records").collect()(1).getString(0) - assert(rec.startsWith("") && rec.indexOf("123") != -1 && - rec.endsWith("")) + assert(rec.startsWith("") && rec.endsWith("")) } test("test XSD validation with addFile() with validation error") { @@ -1272,11 +1253,12 @@ class XmlSuite .option("mode", "PERMISSIVE") .option("columnNameOfCorruptRecord", "_malformed_records") .xml(getTestResourcePath(resDir + "basket_invalid.xml")).cache() + // The first record is valid and the second is invalid, the whole document is put in the + // _malformed_records column for the second record. assert(basketDF.filter($"_malformed_records".isNotNull).count() == 1) assert(basketDF.filter($"_malformed_records".isNull).count() == 1) val rec = basketDF.select("_malformed_records").collect()(1).getString(0) - assert(rec.startsWith("") && rec.indexOf("123") != -1 && - rec.endsWith("")) + assert(rec.startsWith("") && rec.endsWith("")) } test("test xmlDataset") { @@ -2445,10 +2427,12 @@ class XmlSuite test("Timestamp type inference for a mix of TIMESTAMP_NTZ and TIMESTAMP_LTZ") { withTempPath { path => Seq( + "", "2020-12-12T12:12:12.000", "2020-12-12T17:12:12.000Z", "2020-12-12T17:12:12.000+05:00", - "2020-12-12T12:12:12.000" + "2020-12-12T12:12:12.000", + "" ).toDF("data") .coalesce(1) .write.text(path.getAbsolutePath) @@ -2480,10 +2464,12 @@ class XmlSuite test("Malformed records when reading TIMESTAMP_LTZ as TIMESTAMP_NTZ") { withTempPath { path => Seq( + "", "2020-12-12T12:12:12.000", "2020-12-12T12:12:12.000Z", "2020-12-12T12:12:12.000+05:00", - "2020-12-12T12:12:12.000" + "2020-12-12T12:12:12.000", + "" ).toDF("data") .coalesce(1) .write.text(path.getAbsolutePath) @@ -2690,15 +2676,20 @@ class XmlSuite .option("rowTag", "ROW") .load(getTestResourcePath(resDir + "cdata-ending-eof.xml")) - val expectedResults2 = Seq.range(1, 18).map(Row(_)) - checkAnswer(results2, expectedResults2) + assert( + results2.schema == new StructType().add("_corrupt_record", StringType).add("a", LongType)) + + // The last row is null because the last CDATA at eof is invalid + val expectedResults2 = Seq.range(1, 18).map(Row(_)) ++ Seq(Row(null)) + checkAnswer(results2.selectExpr("a"), expectedResults2) val results3 = spark.read.format("xml") .option("rowTag", "ROW") .load(getTestResourcePath(resDir + "cdata-no-close.xml")) - val expectedResults3 = Seq.range(1, 18).map(Row(_)) - checkAnswer(results3, expectedResults3) + // Similar to the previous test, the last row is null because the last CDATA section is invalid + val 
expectedResults3 = Seq.range(1, 18).map(Row(_)) ++ Seq(Row(null)) + checkAnswer(results3.select("a"), expectedResults3) val results4 = spark.read.format("xml") .option("rowTag", "ROW") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlVariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlVariantSuite.scala index 63b816ad6b53a..003462e7b101f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlVariantSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlVariantSuite.scala @@ -507,7 +507,7 @@ class XmlVariantSuite extends QueryTest with SharedSparkSession with TestXmlData ) checkAnswer( df.select(variant_get(col("var"), "$.year", "int")), - Seq(Row(2015), Row(null), Row(null)) + Seq(Row(2015), Row(null)) ) // DROPMALFORMED mode