Skip to content

Commit

Permalink
[SPARK-30082][SQL] Depend on Scala type coercion when building replac…
Browse files Browse the repository at this point in the history
…e query (#628)

apache#26738
apache#26749

### What changes were proposed in this pull request?
Depend on type coercion when building the replace query. This would solve an edge case where when trying to replace `NaN`s, `0`s would get replace too.

### Why are the changes needed?
This Scala code snippet:
```
import scala.math;

println(Double.NaN.toLong)
```
returns `0` which is problematic as if you run the following Spark code, `0`s get replaced as well:
```
>>> df = spark.createDataFrame([(1.0, 0), (0.0, 3), (float('nan'), 0)], ("index", "value"))
>>> df.show()
+-----+-----+
|index|value|
+-----+-----+
|  1.0|    0|
|  0.0|    3|
|  NaN|    0|
+-----+-----+
>>> df.replace(float('nan'), 2).show()
+-----+-----+
|index|value|
+-----+-----+
|  1.0|    2|
|  0.0|    3|
|  2.0|    2|
+-----+-----+ 
```

### Does this PR introduce any user-facing change?
Yes, after the PR, running the same above code snippet returns the correct expected results:
```
>>> df = spark.createDataFrame([(1.0, 0), (0.0, 3), (float('nan'), 0)], ("index", "value"))
>>> df.show()
+-----+-----+
|index|value|
+-----+-----+
|  1.0|    0|
|  0.0|    3|
|  NaN|    0|
+-----+-----+

>>> df.replace(float('nan'), 2).show()
+-----+-----+
|index|value|
+-----+-----+
|  1.0|    0|
|  0.0|    3|
|  2.0|    0|
+-----+-----+
```
And additionally, query results are changed as a result of the change in depending on scala's type coercion rules.

### How was this patch tested?
<!--
If tests were added, say they were added here. Please make sure to add some test cases that check the changes thoroughly including negative and positive cases if possible.
If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future.
If tests were not added, please describe why they were not added and/or why it was difficult to add.
-->
Added unit tests to verify replacing `NaN` only affects columns of type `Float` and `Double`.
  • Loading branch information
johnhany97 authored and bulldozer-bot[bot] committed Jan 15, 2020
1 parent f722525 commit f3200ec
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
val keyExpr = df.col(col.name).expr
def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
val branches = replacementMap.flatMap { case (source, target) =>
Seq(buildExpr(source), buildExpr(target))
Seq(Literal(source), buildExpr(target))
}.toSeq
new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
).toDF("name", "age", "height")
}

def createNaNDF(): DataFrame = {
Seq[(java.lang.Integer, java.lang.Long, java.lang.Short,
java.lang.Byte, java.lang.Float, java.lang.Double)](
(1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0),
(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN)
).toDF("int", "long", "short", "byte", "float", "double")
}

test("drop") {
val input = createDF()
val rows = input.collect()
Expand Down Expand Up @@ -305,4 +313,40 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
)).na.drop("name" :: Nil).select("name"),
Row("Alice") :: Row("David") :: Nil)
}

test("replace nan with float") {
checkAnswer(
createNaNDF().na.replace("*", Map(
Float.NaN -> 10.0f
)),
Row(1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0) ::
Row(0, 0L, 0.toShort, 0.toByte, 10.0f, 10.0) :: Nil)
}

test("replace nan with double") {
checkAnswer(
createNaNDF().na.replace("*", Map(
Double.NaN -> 10.0
)),
Row(1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0) ::
Row(0, 0L, 0.toShort, 0.toByte, 10.0f, 10.0) :: Nil)
}

test("replace float with nan") {
checkAnswer(
createNaNDF().na.replace("*", Map(
1.0f -> Float.NaN
)),
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
}

test("replace double with nan") {
checkAnswer(
createNaNDF().na.replace("*", Map(
1.0 -> Double.NaN
)),
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
}
}

0 comments on commit f3200ec

Please sign in to comment.