
Commit f667aef

Pull request #4: Spark 3.5 support
Merge in PDS-GTB/elasticsearch-hadoop from spark-35 to db-feature/spark-35

* commit 'fc4f33b6c3d609f18b820b25cc1435a2c4c5ead8': Spark 3.5 support
2 parents (b68e3f4 + fc4f33b), commit f667aef

10 files changed (+267 -125 lines)


buildSrc/src/main/groovy/org/elasticsearch/hadoop/gradle/BuildPlugin.groovy

Lines changed: 1 addition & 1 deletion
@@ -594,7 +594,6 @@ class BuildPlugin implements Plugin<Project> {

        // Main variant needs the least configuration on its own, since it is the default publication created above.
        sparkVariants.defaultVariant { SparkVariant variant ->
-           project.publishing.publications.main.setAlias(true)
            updateVariantArtifactId(project, project.publishing.publications.main, variant)
        }

@@ -651,6 +650,7 @@ class BuildPlugin implements Plugin<Project> {
                }
                configurePom(project, variantPublication)
                updateVariantArtifactId(project, variantPublication, variant)
+               variantPublication.setAlias(true)
            }
        }
        if (signingKey.isPresent()) {

buildSrc/src/main/java/org/elasticsearch/hadoop/gradle/scala/SparkVariantPlugin.java

Lines changed: 19 additions & 0 deletions
@@ -72,6 +72,8 @@

public class SparkVariantPlugin implements Plugin<Project> {

+   public static final String ITEST_SOURCE_SET_NAME = "itest";
+
    public static class SparkVariant {

        private final CharSequence name;

@@ -311,6 +313,23 @@ private static void configureDefaultVariant(Project project, SparkVariant sparkV
        runtimeElements.getOutgoing().capability(capability);

        configureScalaJarClassifiers(project, sparkVariant);
+       // Extend main and test source set for the main variant - this enables the possibility of having diverging code between variants
+       SourceSetContainer sourceSets = javaPluginExtension.getSourceSets();
+       ScalaSourceDirectorySet scalaSourceSet = getScalaSourceSet(sourceSets.getByName(MAIN_SOURCE_SET_NAME));
+       scalaSourceSet.setSrcDirs(Arrays.asList(
+               "src/" + MAIN_SOURCE_SET_NAME + "/scala",
+               "src/" + MAIN_SOURCE_SET_NAME + "/" + sparkVariant.getName()
+       ));
+       ScalaSourceDirectorySet scalaTestSourceSet = getScalaSourceSet(sourceSets.getByName(TEST_SOURCE_SET_NAME));
+       scalaTestSourceSet.setSrcDirs(Arrays.asList(
+               "src/" + TEST_SOURCE_SET_NAME + "/scala",
+               "src/" + TEST_SOURCE_SET_NAME + "/" + sparkVariant.getName()
+       ));
+       ScalaSourceDirectorySet scalaITestSourceSet = getScalaSourceSet(sourceSets.getByName(ITEST_SOURCE_SET_NAME));
+       scalaITestSourceSet.setSrcDirs(Arrays.asList(
+               "src/" + ITEST_SOURCE_SET_NAME + "/scala",
+               "src/" + ITEST_SOURCE_SET_NAME + "/" + sparkVariant.getName()
+       ));
    }

    private static void configureVariant(Project project, SparkVariant sparkVariant, JavaPluginExtension javaPluginExtension) {
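
The extra source directories registered here sit next to the shared src/&lt;sourceSet&gt;/scala roots, so a variant can carry its own Scala files where the Spark APIs diverge. As a hedged illustration (the file, package, and object below are hypothetical, not part of this commit), the default spark30scala213 variant could keep version-specific code in a directory named after it:

// Hypothetical file: src/main/spark30scala213/org/elasticsearch/spark/example/VariantShim.scala
// Compiled only for the spark30scala213 variant; another variant (e.g. a spark35 one)
// could ship its own VariantShim with a Spark 3.5 flavored body under its own directory.
package org.elasticsearch.spark.example

object VariantShim {
  // Keep anything whose API differs between Spark releases behind a variant-specific
  // object like this, so the shared code under src/main/scala stays version agnostic.
  def sparkBinaryVersion: String = "3.0"
}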

gradle.properties

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ spark20Version = 2.3.0
spark22Version = 2.2.3
spark24Version = 2.4.4
spark30Version = 3.4.3
+spark35Version = 3.5.6

# same as Spark's
scala210Version = 2.10.7

spark/core/build.gradle

Lines changed: 27 additions & 0 deletions
@@ -14,6 +14,8 @@ sparkVariants {
    // We should maybe move these to a separate config file that can be read from both this file and the pipeline script in the future if it creates issues
    setCoreDefaultVariant "spark30scala213", spark30Version, scala213Version
    addCoreFeatureVariant "spark30scala212", spark30Version, scala212Version
+   addCoreFeatureVariant "spark35scala212", spark35Version, scala212Version
+   addCoreFeatureVariant "spark35scala213", spark35Version, scala213Version

    all { SparkVariantPlugin.SparkVariant variant ->

@@ -44,6 +46,11 @@
        add(variant.configuration('api'), "org.apache.spark:spark-core_${variant.scalaMajorVersion}:${variant.sparkVersion}") {
            exclude group: 'org.apache.hadoop'
        }
+       if (variant.sparkVersion >= "3.5.0") {
+           add(variant.configuration('implementation'), "org.apache.spark:spark-common-utils_${variant.scalaMajorVersion}:$variant.sparkVersion") {
+               exclude group: 'org.apache.hadoop'
+           }
+       }

        add(variant.configuration('implementation'), project(":elasticsearch-hadoop-mr"))
        add(variant.configuration('implementation'), "commons-logging:commons-logging:1.1.1")

@@ -126,3 +133,23 @@ tasks.withType(ScalaCompile) { ScalaCompile task ->
    task.targetCompatibility = project.ext.minimumRuntimeVersion
    task.options.forkOptions.executable = new File(project.ext.runtimeJavaHome, 'bin/java').absolutePath
}
+
+tasks.register('copyPoms', Copy) {
+   from(tasks.named('generatePomFileForMainPublication')) {
+       rename 'pom-default.xml', "elasticsearch-spark-30_2.13-${project.getVersion()}.pom"
+   }
+   from(tasks.named('generatePomFileForSpark30scala212Publication')) {
+       rename 'pom-default.xml', "elasticsearch-spark-30_2.12-${project.getVersion()}.pom"
+   }
+   from(tasks.named('generatePomFileForSpark35scala212Publication')) {
+       rename 'pom-default.xml', "elasticsearch-spark-35_2.12-${project.getVersion()}.pom"
+   }
+   from(tasks.named('generatePomFileForSpark35scala213Publication')) {
+       rename 'pom-default.xml', "elasticsearch-spark-35_2.13-${project.getVersion()}.pom"
+   }
+   into(new File(project.buildDir, 'distributions'))
+}
+
+tasks.named('distribution').configure {
+   dependsOn 'copyPoms'
+}
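
spark-common-utils is a module that first appears in Spark 3.5, which is presumably why it is only added when variant.sparkVersion is at least 3.5.0. A hedged sketch of code relying on a class that, by my understanding, now ships in that module (the wrapper object itself is illustrative only, not from this commit):

package org.elasticsearch.spark.example

// Assumption: as of Spark 3.5, SparkException is published in spark-common-utils,
// so the explicit implementation dependency above keeps it on the compile classpath.
import org.apache.spark.SparkException

object CommonUtilsUsage {
  // Wrap a low-level failure the way callers usually see Spark surface it.
  def wrap(cause: Throwable): SparkException =
    new SparkException("error while talking to Elasticsearch", cause)
}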

spark/sql-30/build.gradle

Lines changed: 18 additions & 2 deletions
@@ -9,8 +9,10 @@ apply plugin: 'spark.variants'

sparkVariants {
    capabilityGroup 'org.elasticsearch.spark.sql.variant'
-   setDefaultVariant "spark30scala213", spark30Version, scala213Version
-   addFeatureVariant "spark30scala212", spark30Version, scala212Version
+   setCoreDefaultVariant "spark30scala213", spark30Version, scala213Version
+   addCoreFeatureVariant "spark30scala212", spark30Version, scala212Version
+   addCoreFeatureVariant "spark35scala213", spark35Version, scala213Version
+   addCoreFeatureVariant "spark35scala212", spark35Version, scala212Version

    all { SparkVariantPlugin.SparkVariant variant ->
        String scalaCompileTaskName = project.sourceSets

@@ -58,6 +60,14 @@
            exclude group: 'javax.servlet'
            exclude group: 'org.apache.hadoop'
        }
+       if (variant.sparkVersion >= "3.5.0") {
+           add(variant.configuration('implementation'), "org.apache.spark:spark-common-utils_${variant.scalaMajorVersion}:$variant.sparkVersion") {
+               exclude group: 'org.apache.hadoop'
+           }
+           add(variant.configuration('implementation'), "org.apache.spark:spark-sql-api_${variant.scalaMajorVersion}:$variant.sparkVersion") {
+               exclude group: 'org.apache.hadoop'
+           }
+       }

        add(variant.configuration('implementation'), "org.apache.spark:spark-sql_${variant.scalaMajorVersion}:$variant.sparkVersion") {
            exclude group: 'org.apache.hadoop'

@@ -198,6 +208,12 @@ tasks.register('copyPoms', Copy) {
    from(tasks.named('generatePomFileForSpark30scala212Publication')) {
        rename 'pom-default.xml', "elasticsearch-spark-30_2.12-${project.getVersion()}.pom"
    }
+   from(tasks.named('generatePomFileForSpark35scala212Publication')) {
+       rename 'pom-default.xml', "elasticsearch-spark-35_2.12-${project.getVersion()}.pom"
+   }
+   from(tasks.named('generatePomFileForSpark35scala213Publication')) {
+       rename 'pom-default.xml', "elasticsearch-spark-35_2.13-${project.getVersion()}.pom"
+   }
    into(new File(project.buildDir, 'distributions'))
}
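
The sql-30 variants additionally pick up spark-sql-api, the other module introduced by Spark 3.5's refactoring; my understanding (an assumption, not stated in this commit) is that public SQL surface types such as Row and the org.apache.spark.sql.types classes moved there. A small illustrative sketch that touches only that surface:

package org.elasticsearch.spark.example

// Assumption: Row and the types API are provided by spark-sql-api in Spark 3.5.
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object SqlApiUsage {
  // A one-column schema and a matching row, built purely from the spark-sql-api surface.
  val schema: StructType = StructType(Seq(StructField("name", StringType, nullable = true)))

  def sampleRow(name: String): Row = Row(name)
}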

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
package org.elasticsearch.spark.sql.streaming

import org.apache.spark.sql.streaming.StreamingQueryListener
import org.elasticsearch.hadoop.util.unit.TimeValue
import org.junit.Assert

import java.util.UUID
import java.util.concurrent.{CountDownLatch, TimeUnit}

// Listener to 1) ensure no more than a single stream is running at a time, 2) know when we're done processing inputs
// and 3) to capture any Exceptions encountered during the execution of the stream.
class StreamingQueryLifecycleListener extends StreamingQueryListener {

  private var uuid: Option[UUID] = None

  private var inputsRequired = 0L
  private var inputsSeen = 0L

  private var expectingToThrow: Option[Class[_]] = None
  private var foundExpectedException: Boolean = false
  private var encounteredException: Option[String] = None

  private var latch = new CountDownLatch(1) // expects just a single batch

  def incrementExpected(): Unit = inputsRequired = inputsRequired + 1

  def setExpectedException(clazz: Class[_]): Unit = {
    expectingToThrow match {
      case Some(cls) => throw new IllegalArgumentException(s"Already expecting exception of type [$cls]!")
      case None => expectingToThrow = Some(clazz)
    }
  }

  // Make sure we only ever watch one query at a time.
  private def captureQueryID(eventId: UUID): Unit = {
    uuid match {
      case Some(id) if eventId != id => throw new IllegalStateException("Multiple queries are not supported")
      case None => uuid = Some(eventId)
      case _ => // No problem for now
    }
  }

  override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
    captureQueryID(event.id)
  }

  override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
    captureQueryID(event.progress.id)

    // keep track of the number of input rows seen. When we reach the number of expected inputs,
    // wait for two 0 values to pass before signaling to close

    val rows = event.progress.numInputRows
    inputsSeen = inputsSeen + rows

    if (inputsSeen == inputsRequired) {
      if (rows == 0) {
        // Don't close after meeting the first input level. Wait to make sure we get
        // one last pass with no new rows processed before signalling.
        latch.countDown()
      }
    } else if (inputsSeen > inputsRequired) {
      throw new IllegalStateException("Too many inputs encountered. Expected [" + inputsRequired +
        "] but found [" + inputsSeen + "]")
    }
  }

  protected def onQueryIdle(eventId: UUID): Unit = {
    captureQueryID(eventId)
    if (inputsSeen == inputsRequired) {
      latch.countDown()
    }
  }

  override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
    try {
      captureQueryID(event.id)

      encounteredException = event.exception match {
        case Some(value) =>
          // This is a whole trace, get everything after the enclosing SparkException (i.e. the original exception + trace)
          val messageParts = value.split("\\): ")
          if (messageParts.size > 1) {
            val nonSparkMessage = messageParts(1)
            // Ditch the full trace after the first newline
            val removedNewLine = nonSparkMessage.substring(0, nonSparkMessage.indexOf("\n"))
            // Split the class name from the exception message and take the class name
            Some(removedNewLine.substring(0, removedNewLine.indexOf(":")))
          } else {
            // Return the original framework error
            Some(value.substring(0, value.indexOf(":")))
          }
        case None => None
      }

      val expectedExceptionName = expectingToThrow.map(_.getCanonicalName).getOrElse("None")

      foundExpectedException = encounteredException.exists(_.equals(expectedExceptionName))
    } finally {
      // signal no matter what to avoid deadlock
      latch.countDown()
    }
  }

  def waitOnComplete(timeValue: TimeValue): Boolean = latch.await(timeValue.millis, TimeUnit.MILLISECONDS)

  def expectAnotherBatch(): Unit = {
    latch = new CountDownLatch(1)
  }

  def assertExpectedExceptions(message: Option[String]): Unit = {
    expectingToThrow match {
      case Some(exceptionClass) =>
        if (!foundExpectedException) {
          encounteredException match {
            case Some(s) => Assert.fail(s"Expected ${exceptionClass.getCanonicalName} but got $s")
            case None => Assert.fail(message.getOrElse(s"Expected ${exceptionClass.getCanonicalName} but no Exceptions were thrown"))
          }
        }
      case None =>
        encounteredException match {
          case Some(exception) => Assert.fail(s"Expected no exceptions but got $exception")
          case None => ()
        }
    }
  }
}
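
For context, a test would register this listener with the session's StreamingQueryManager, tell it how many input rows to expect, and block on waitOnComplete before asserting. The protected onQueryIdle(eventId) helper also looks intended to let a Spark 3.5 specific subclass forward the QueryIdleEvent callback added in 3.5, though that wiring is not part of this file. A hedged sketch of the registration and assertion flow (the MemoryStream source, console sink, and timeout are placeholders, not code from this commit):

// Sketch only: the source, sink, and timeout below are placeholders.
import java.util.concurrent.TimeUnit

import org.apache.spark.sql.{SQLContext, SparkSession}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.elasticsearch.hadoop.util.unit.TimeValue
import org.elasticsearch.spark.sql.streaming.StreamingQueryLifecycleListener

object StreamingQueryLifecycleListenerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("listener-sketch").getOrCreate()
    import spark.implicits._
    implicit val sqlContext: SQLContext = spark.sqlContext

    val listener = new StreamingQueryLifecycleListener
    spark.streams.addListener(listener) // events from the single query below are routed to the listener

    val input = MemoryStream[String]
    input.addData("doc-1")
    listener.incrementExpected() // one input row must be reported before the latch can open

    val query = input.toDF().writeStream.format("console").start()
    try {
      // Block until the expected rows (plus a trailing zero-row progress report) are seen,
      // assuming TimeValue exposes a (duration, TimeUnit) constructor.
      listener.waitOnComplete(new TimeValue(2, TimeUnit.MINUTES))
      listener.assertExpectedExceptions(None) // fail if the query terminated with an exception
    } finally {
      query.stop()
      spark.streams.removeListener(listener)
      spark.stop()
    }
  }
}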
