diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComparisonFunctionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComparisonFunctionsTest.java index d15764d72..1174ca543 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ComparisonFunctionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ComparisonFunctionsTest.java @@ -1,5 +1,8 @@ package io.substrait.isthmus; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.isthmus.sql.SubstraitCreateStatementParser; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; @@ -38,6 +41,15 @@ void is_not_false() throws Exception { void is_distinct_from(String left, String right) throws Exception { String query = String.format("SELECT (%s IS DISTINCT FROM %s) FROM numbers", left, right); assertSqlSubstraitRelRoundTrip(query, CREATES); + + // Assert logical rewrite exists + io.substrait.plan.Plan plan = + toSubstraitPlan( + query, SubstraitCreateStatementParser.processCreateStatementsToCatalog(CREATES)); + String planString = plan.toString(); + assertTrue( + planString.contains("and") && planString.contains("or") && planString.contains("equal"), + "Expected Substrait plan to contain logical rewrite for IS DISTINCT FROM"); } @ParameterizedTest @@ -45,6 +57,15 @@ void is_distinct_from(String left, String right) throws Exception { void is_distinct_from_null_vs_col(String column) throws Exception { String query = String.format("SELECT (NULL IS DISTINCT FROM %s) FROM numbers", column); assertSqlSubstraitRelRoundTrip(query, CREATES); + + // Assert logical rewrite exists + io.substrait.plan.Plan plan = + toSubstraitPlan( + query, SubstraitCreateStatementParser.processCreateStatementsToCatalog(CREATES)); + String planString = plan.toString(); + assertTrue( + planString.contains("is_not_null"), + "Expected Substrait plan to contain logical rewrite for NULL IS DISTINCT FROM to IS NOT NULL"); } @ParameterizedTest @@ -71,7 +92,7 @@ void least(String args) throws Exception { }) void greatest(String args) throws Exception { String join_args = String.join(", ", args); - String query = String.format("SELECT LEAST(%s) FROM numbers", join_args); + String query = String.format("SELECT GREATEST(%s) FROM numbers", join_args); assertSqlSubstraitRelRoundTrip(query, CREATES); } }