@@ -29,13 +29,12 @@ import java.io.File
2929
3030import org .apache .spark .sql .catalyst .InternalRow
3131import org .apache .spark .sql .{DataFrame , Row , SaveMode , SparkSession , functions => fn }
32- import org .scalatest .BeforeAndAfterAll
3332import org .scalatest .matchers .should .Matchers
3433import org .scalatest .wordspec .AnyWordSpec
3534
3635import scala .util .Random
3736
38- class OsmPbfFormatSpec extends AnyWordSpec with Matchers with BeforeAndAfterAll {
37+ class OsmPbfFormatSpec extends AnyWordSpec with Matchers with SparkSessionBeforeAfterAll {
3938
4039 def withTemporalFolder (testCode : File => Any ): Unit =
4140 testCode(
@@ -44,42 +43,28 @@ class OsmPbfFormatSpec extends AnyWordSpec with Matchers with BeforeAndAfterAll
4443 )
4544 )
4645
47- val cores = 4
4846 val madridPath = " core/src/test/resources/com/acervera/osm4scala/Madrid.bbbike.osm.pbf"
4947 val monacoPath = " core/src/test/resources/com/acervera/osm4scala/monaco-latest.osm.pbf"
5048
51- val sparkSession = SparkSession
52- .builder()
53- .master(s " local[ $cores] " )
54- .getOrCreate()
55-
56- import sparkSession .implicits ._
57-
58- val sqlContext = sparkSession.sqlContext
59-
60- def loadOsmPbf (path : String , tableName : Option [String ] = None ): DataFrame = {
61- val df = sqlContext.read
49+ def loadOsmPbf (spark : SparkSession , path : String , tableName : Option [String ] = None ): DataFrame = {
50+ val df = spark.sqlContext.read
6251 .format(" osm.pbf" )
6352 .load(path)
6453 .repartition(cores * 2 )
65- tableName.foreach(df.createTempView )
54+ tableName.foreach(df.createOrReplaceTempView )
6655 df
6756 }
6857
69- override protected def afterAll (): Unit = {
70- sparkSession.close()
71- }
72-
7358 " OsmPbfFormat" should {
7459
7560 " parsing all only one time" in {
76- val entitiesCount = loadOsmPbf(madridPath).count()
61+ val entitiesCount = loadOsmPbf(spark, madridPath).count()
7762 entitiesCount shouldBe 2677227
7863 }
7964
8065 " parser correctly" when {
8166 " is parsing nodes" in {
82- val node171946 = loadOsmPbf(madridPath).filter(" id == 171946" ).collect()(0 )
67+ val node171946 = loadOsmPbf(spark, madridPath).filter(" id == 171946" ).collect()(0 )
8368 node171946.getAs[Long ](" id" ) shouldBe 171946L
8469 node171946.getAs[Byte ](" type" ) shouldBe 0
8570 node171946.getAs[Double ](" latitude" ) shouldBe (40.42125 +- 0.001 )
@@ -92,7 +77,7 @@ class OsmPbfFormatSpec extends AnyWordSpec with Matchers with BeforeAndAfterAll
9277 }
9378
9479 " is parsing ways" in {
95- val way3996192 = loadOsmPbf(madridPath).filter(" id == 3996192" ).collect()(0 )
80+ val way3996192 = loadOsmPbf(spark, madridPath).filter(" id == 3996192" ).collect()(0 )
9681 way3996192.getAs[Long ](" id" ) shouldBe 3996192L
9782 way3996192.getAs[Byte ](" type" ) shouldBe 1
9883 way3996192.getAs[AnyRef ](" latitude" ) should be(null )
@@ -109,7 +94,7 @@ class OsmPbfFormatSpec extends AnyWordSpec with Matchers with BeforeAndAfterAll
10994 }
11095
11196 " is parsing relations" in {
112- val relation55799 = loadOsmPbf(madridPath).filter(" id == 55799" ).collect()(0 )
97+ val relation55799 = loadOsmPbf(spark, madridPath).filter(" id == 55799" ).collect()(0 )
11398 relation55799.getAs[Long ](" id" ) shouldBe 55799
11499 relation55799.getAs[Byte ](" type" ) shouldBe 2
115100 relation55799.getAs[AnyRef ](" latitude" ) should be(null )
@@ -127,7 +112,7 @@ class OsmPbfFormatSpec extends AnyWordSpec with Matchers with BeforeAndAfterAll
127112 }
128113
129114 " export to other formats" in withTemporalFolder { tmpFolder =>
130- val threeExamples = loadOsmPbf(madridPath)
115+ val threeExamples = loadOsmPbf(spark, madridPath)
131116 .filter(" id == 55799 || id == 3996192 || id == 171946" )
132117 .orderBy(" id" )
133118
@@ -136,7 +121,7 @@ class OsmPbfFormatSpec extends AnyWordSpec with Matchers with BeforeAndAfterAll
136121 .format(" orc" )
137122 .save(s " ${tmpFolder}/madrid/three " )
138123
139- val readFromOrc = sqlContext.read
124+ val readFromOrc = spark. sqlContext.read
140125 .format(" orc" )
141126 .load(s " ${tmpFolder}/madrid/three " )
142127 .orderBy(" id" )
@@ -151,7 +136,10 @@ class OsmPbfFormatSpec extends AnyWordSpec with Matchers with BeforeAndAfterAll
151136 " execute complex queries" when {
152137 " using dsl" should {
153138 " count arrays and filter" in {
154- loadOsmPbf(madridPath)
139+ val sparkStable = spark
140+ import sparkStable .implicits ._
141+
142+ loadOsmPbf(spark, madridPath)
155143 .withColumn(" no_of_nodes" , fn.size($" nodes" ))
156144 .withColumn(" no_of_relations" , fn.size($" relations" ))
157145 .withColumn(" no_of_tags" , fn.size($" tags" ))
@@ -161,33 +149,37 @@ class OsmPbfFormatSpec extends AnyWordSpec with Matchers with BeforeAndAfterAll
161149 }
162150 }
163151 " using SQL" should {
164- loadOsmPbf(madridPath, Some (" madrid_shows" ))
165- loadOsmPbf(monacoPath, Some (" monaco_shows" ))
152+
166153 " count all zebras" in {
167- sqlContext
154+ loadOsmPbf(spark, madridPath, Some (" madrid_shows" ))
155+ spark.sqlContext
168156 .sql(" select count(*) from madrid_shows where array_contains(map_values(tags), 'zebra')" )
169157 .show()
170158 }
171159 " extract all keys used in tags" in {
172- sqlContext
160+ loadOsmPbf(spark, madridPath, Some (" madrid_shows" ))
161+ spark.sqlContext
173162 .sql(" select distinct explode(map_keys(tags)) as tag from madrid_shows where size(tags) > 0 order by tag" )
174163 .show()
175164 }
176165
177166 " extract unique list of types" in {
178- sqlContext
167+ loadOsmPbf(spark, monacoPath, Some (" monaco_shows" ))
168+ spark.sqlContext
179169 .sql(" select distinct(type) as unique_types from monaco_shows order by unique_types" )
180170 .show()
181171 }
182172
183173 " extract ways with more nodes" in {
184- sqlContext
174+ loadOsmPbf(spark, monacoPath, Some (" monaco_shows" ))
175+ spark.sqlContext
185176 .sql(" select id, size(nodes) as size_nodes from monaco_shows where type == 1 order by size_nodes desc" )
186177 .show()
187178 }
188179
189180 " extract relations" in {
190- sqlContext
181+ loadOsmPbf(spark, monacoPath, Some (" monaco_shows" ))
182+ spark.sqlContext
191183 .sql(" select id, relations from monaco_shows where type == 2" )
192184 .show()
193185 }
0 commit comments