diff --git a/pom.xml b/pom.xml
index 56310a9d7604b..c2c5327d7f6ad 100644
--- a/pom.xml
+++ b/pom.xml
@@ -793,6 +793,13 @@
${project.version}
+
+ com.facebook.presto
+ presto-tests
+ ${project.version}
+ test-jar
+
+
com.facebook.presto
presto-benchmark
@@ -1024,6 +1031,13 @@
${project.version}
+
+ com.facebook.presto
+ presto-native-execution
+ ${project.version}
+ test-jar
+
+
com.facebook.hive
hive-dwrf
diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java b/presto-main/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java
index 0c07b99aaab4e..cc2dee1bd61d4 100644
--- a/presto-main/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java
+++ b/presto-main/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java
@@ -23,6 +23,18 @@
public class HandleJsonModule
implements Module
{
+ private final HandleResolver handleResolver;
+
+ public HandleJsonModule()
+ {
+ this(null);
+ }
+
+ public HandleJsonModule(HandleResolver handleResolver)
+ {
+ this.handleResolver = handleResolver;
+ }
+
@Override
public void configure(Binder binder)
{
@@ -38,6 +50,11 @@ public void configure(Binder binder)
jsonBinder(binder).addModuleBinding().to(FunctionHandleJacksonModule.class);
jsonBinder(binder).addModuleBinding().to(MetadataUpdateJacksonModule.class);
- binder.bind(HandleResolver.class).in(Scopes.SINGLETON);
+ if (handleResolver == null) {
+ binder.bind(HandleResolver.class).in(Scopes.SINGLETON);
+ }
+ else {
+ binder.bind(HandleResolver.class).toInstance(handleResolver);
+ }
}
}
diff --git a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java
index 3071bd0e5d0f3..4e5075b812ba4 100644
--- a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java
+++ b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java
@@ -148,6 +148,7 @@
import com.facebook.presto.spi.NodeManager;
import com.facebook.presto.spi.PageIndexerFactory;
import com.facebook.presto.spi.PageSorter;
+import com.facebook.presto.spi.RowExpressionSerde;
import com.facebook.presto.spi.analyzer.ViewDefinition;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.plan.SimplePlanFragment;
@@ -155,6 +156,7 @@
import com.facebook.presto.spi.relation.DeterminismEvaluator;
import com.facebook.presto.spi.relation.DomainTranslator;
import com.facebook.presto.spi.relation.PredicateCompiler;
+import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spi.session.WorkerSessionPropertyProvider;
import com.facebook.presto.spiller.FileSingleStreamSpillerFactory;
@@ -195,6 +197,7 @@
import com.facebook.presto.sql.analyzer.QueryExplainer;
import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager;
import com.facebook.presto.sql.expressions.ExpressionOptimizerManager;
+import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde;
import com.facebook.presto.sql.gen.ExpressionCompiler;
import com.facebook.presto.sql.gen.JoinCompiler;
import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler;
@@ -362,6 +365,7 @@ else if (serverConfig.isCoordinator()) {
// expression manager
binder.bind(ExpressionOptimizerManager.class).in(Scopes.SINGLETON);
+ binder.bind(RowExpressionSerde.class).to(JsonCodecRowExpressionSerde.class).in(Scopes.SINGLETON);
// schema properties
binder.bind(SchemaPropertyManager.class).in(Scopes.SINGLETON);
@@ -555,6 +559,7 @@ public ListeningExecutorService createResourceManagerExecutor(ResourceManagerCon
jsonCodecBinder(binder).bindJsonCodec(SqlInvokedFunction.class);
jsonCodecBinder(binder).bindJsonCodec(TaskSource.class);
jsonCodecBinder(binder).bindJsonCodec(TableWriteInfo.class);
+ jsonCodecBinder(binder).bindJsonCodec(RowExpression.class);
smileCodecBinder(binder).bindSmileCodec(TaskStatus.class);
smileCodecBinder(binder).bindSmileCodec(TaskInfo.class);
thriftCodecBinder(binder).bindThriftCodec(TaskStatus.class);
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java b/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java
index 0d200e5555a92..2a8bd6e391ca8 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java
@@ -19,6 +19,7 @@
import com.facebook.presto.nodeManager.PluginNodeManager;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.NodeManager;
+import com.facebook.presto.spi.RowExpressionSerde;
import com.facebook.presto.spi.relation.ExpressionOptimizer;
import com.facebook.presto.spi.relation.ExpressionOptimizerProvider;
import com.facebook.presto.spi.sql.planner.ExpressionOptimizerContext;
@@ -53,6 +54,7 @@ public class ExpressionOptimizerManager
private final NodeManager nodeManager;
private final FunctionAndTypeManager functionAndTypeManager;
+ private final RowExpressionSerde rowExpressionSerde;
private final FunctionResolution functionResolution;
private final File configurationDirectory;
@@ -60,16 +62,17 @@ public class ExpressionOptimizerManager
private final Map expressionOptimizers = new ConcurrentHashMap<>();
@Inject
- public ExpressionOptimizerManager(PluginNodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager)
+ public ExpressionOptimizerManager(PluginNodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager, RowExpressionSerde rowExpressionSerde)
{
- this(nodeManager, functionAndTypeManager, EXPRESSION_MANAGER_CONFIGURATION_DIRECTORY);
+ this(nodeManager, functionAndTypeManager, rowExpressionSerde, EXPRESSION_MANAGER_CONFIGURATION_DIRECTORY);
}
- public ExpressionOptimizerManager(PluginNodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager, File configurationDirectory)
+ public ExpressionOptimizerManager(PluginNodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager, RowExpressionSerde rowExpressionSerde, File configurationDirectory)
{
requireNonNull(nodeManager, "nodeManager is null");
this.nodeManager = requireNonNull(nodeManager, "nodeManager is null");
this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
+ this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null");
this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
this.configurationDirectory = requireNonNull(configurationDirectory, "configurationDirectory is null");
expressionOptimizers.put(DEFAULT_EXPRESSION_OPTIMIZER_NAME, new RowExpressionOptimizer(functionAndTypeManager));
@@ -89,7 +92,7 @@ public void loadExpressionOptimizerFactories()
}
}
- private void loadExpressionOptimizerFactory(File configurationFile)
+ public void loadExpressionOptimizerFactory(File configurationFile)
throws IOException
{
String name = getNameWithoutExtension(configurationFile.getName());
@@ -99,13 +102,19 @@ private void loadExpressionOptimizerFactory(File configurationFile)
Map properties = new HashMap<>(loadProperties(configurationFile));
String factoryName = properties.remove(EXPRESSION_MANAGER_FACTORY_NAME);
checkArgument(!isNullOrEmpty(factoryName), "%s does not contain %s", configurationFile, EXPRESSION_MANAGER_FACTORY_NAME);
+ loadExpressionOptimizerFactory(factoryName, name, properties);
+ }
+
+ public void loadExpressionOptimizerFactory(String factoryName, String expressionOptimizerName, Map properties)
+ {
+ requireNonNull(factoryName, "factoryName is null");
checkArgument(expressionOptimizerFactories.containsKey(factoryName),
"ExpressionOptimizerFactory %s is not registered, registered factories: ", factoryName, expressionOptimizerFactories.keySet());
ExpressionOptimizer optimizer = expressionOptimizerFactories.get(factoryName).createOptimizer(
properties,
- new ExpressionOptimizerContext(nodeManager, functionAndTypeManager, functionResolution));
- expressionOptimizers.put(name, optimizer);
+ new ExpressionOptimizerContext(nodeManager, rowExpressionSerde, functionAndTypeManager, functionResolution));
+ expressionOptimizers.put(expressionOptimizerName, optimizer);
}
public void addExpressionOptimizerFactory(ExpressionOptimizerFactory expressionOptimizerFactory)
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/expressions/JsonCodecRowExpressionSerde.java b/presto-main/src/main/java/com/facebook/presto/sql/expressions/JsonCodecRowExpressionSerde.java
new file mode 100644
index 0000000000000..20cfcf151df6c
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/sql/expressions/JsonCodecRowExpressionSerde.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.expressions;
+
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.presto.spi.RowExpressionSerde;
+import com.facebook.presto.spi.relation.RowExpression;
+
+import javax.inject.Inject;
+
+import java.nio.charset.StandardCharsets;
+
+import static java.util.Objects.requireNonNull;
+
+public class JsonCodecRowExpressionSerde
+ implements RowExpressionSerde
+{
+ private final JsonCodec codec;
+
+ @Inject
+ public JsonCodecRowExpressionSerde(JsonCodec codec)
+ {
+ this.codec = requireNonNull(codec, "codec is null");
+ }
+
+ @Override
+ public String serialize(RowExpression expression)
+ {
+ return new String(codec.toBytes(expression), StandardCharsets.UTF_8);
+ }
+
+ @Override
+ public RowExpression deserialize(String data)
+ {
+ return codec.fromBytes(data.getBytes(StandardCharsets.UTF_8));
+ }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java
index 8d2a9ac22e5bb..0360d5b951932 100644
--- a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java
+++ b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java
@@ -149,6 +149,7 @@
import com.facebook.presto.spi.plan.SimplePlanFragment;
import com.facebook.presto.spi.plan.StageExecutionDescriptor;
import com.facebook.presto.spi.plan.TableScanNode;
+import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spiller.FileSingleStreamSpillerFactory;
import com.facebook.presto.spiller.GenericPartitioningSpillerFactory;
import com.facebook.presto.spiller.GenericSpillerFactory;
@@ -175,6 +176,7 @@
import com.facebook.presto.sql.analyzer.QueryExplainer;
import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager;
import com.facebook.presto.sql.expressions.ExpressionOptimizerManager;
+import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde;
import com.facebook.presto.sql.gen.ExpressionCompiler;
import com.facebook.presto.sql.gen.JoinCompiler;
import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler;
@@ -462,7 +464,7 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig,
this.pageIndexerFactory = new GroupByHashPageIndexerFactory(joinCompiler);
NodeInfo nodeInfo = new NodeInfo("test");
- expressionOptimizerManager = new ExpressionOptimizerManager(new PluginNodeManager(nodeManager, nodeInfo.getEnvironment()), getFunctionAndTypeManager());
+ expressionOptimizerManager = new ExpressionOptimizerManager(new PluginNodeManager(nodeManager, nodeInfo.getEnvironment()), getFunctionAndTypeManager(), new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class)));
this.statsNormalizer = new StatsNormalizer();
this.scalarStatsCalculator = new ScalarStatsCalculator(metadata, expressionOptimizerManager);
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/expressions/TestExpressionOptimizerManager.java b/presto-main/src/test/java/com/facebook/presto/sql/expressions/TestExpressionOptimizerManager.java
index 913e0db26adc0..7340e45b42c1d 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/expressions/TestExpressionOptimizerManager.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/expressions/TestExpressionOptimizerManager.java
@@ -34,6 +34,7 @@
import java.util.Map;
import java.util.Properties;
+import static com.facebook.airlift.json.JsonCodec.jsonCodec;
import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
@@ -65,6 +66,7 @@ public void setUp()
manager = new ExpressionOptimizerManager(
pluginNodeManager,
METADATA.getFunctionAndTypeManager(),
+ new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class)),
directory);
}
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java
index cea4abb4e0787..a2f535e4d5c2a 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java
@@ -23,9 +23,11 @@
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
+import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.security.AccessControl;
import com.facebook.presto.sql.Optimizer;
import com.facebook.presto.sql.expressions.ExpressionOptimizerManager;
+import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde;
import com.facebook.presto.sql.planner.Plan;
import com.facebook.presto.sql.planner.RuleStatsRecorder;
import com.facebook.presto.sql.planner.TypeProvider;
@@ -49,6 +51,7 @@
import java.util.function.Consumer;
import java.util.function.Function;
+import static com.facebook.airlift.json.JsonCodec.jsonCodec;
import static com.facebook.presto.sql.planner.assertions.PlanAssert.assertPlan;
import static com.facebook.presto.sql.planner.assertions.PlanAssert.assertPlanDoesNotMatch;
import static com.facebook.presto.transaction.TransactionBuilder.transaction;
@@ -177,7 +180,8 @@ private List getMinimalOptimizers()
metadata,
new ExpressionOptimizerManager(
new PluginNodeManager(new InMemoryNodeManager()),
- queryRunner.getFunctionAndTypeManager())).rules()));
+ queryRunner.getFunctionAndTypeManager(),
+ new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class)))).rules()));
}
private void inTransaction(Function transactionSessionConsumer)
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java
index 7817b311053df..78ff684386afc 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java
@@ -25,6 +25,7 @@
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.sql.TestingRowExpressionTranslator;
import com.facebook.presto.sql.expressions.ExpressionOptimizerManager;
+import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.tree.Expression;
@@ -38,6 +39,7 @@
import java.util.stream.IntStream;
import java.util.stream.Stream;
+import static com.facebook.airlift.json.JsonCodec.jsonCodec;
import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager;
@@ -185,7 +187,7 @@ private static void assertSimplifies(String expression, String rowExpressionExpe
Expression actualExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expression));
InMemoryNodeManager nodeManager = new InMemoryNodeManager();
- ExpressionOptimizerManager expressionOptimizerManager = new ExpressionOptimizerManager(new PluginNodeManager(nodeManager), METADATA.getFunctionAndTypeManager());
+ ExpressionOptimizerManager expressionOptimizerManager = new ExpressionOptimizerManager(new PluginNodeManager(nodeManager), METADATA.getFunctionAndTypeManager(), new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class)));
TestingRowExpressionTranslator translator = new TestingRowExpressionTranslator(METADATA);
RowExpression actualRowExpression = translator.translate(actualExpression, TypeProvider.viewOf(TYPES));
diff --git a/presto-native-execution/presto_cpp/main/CMakeLists.txt b/presto-native-execution/presto_cpp/main/CMakeLists.txt
index c06e00edf834c..7eaf9f39177ce 100644
--- a/presto-native-execution/presto_cpp/main/CMakeLists.txt
+++ b/presto-native-execution/presto_cpp/main/CMakeLists.txt
@@ -51,6 +51,7 @@ target_link_libraries(
presto_http
presto_operators
presto_velox_conversion
+ presto_expression_optimizer
velox_abfs
velox_aggregates
velox_caching
diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp
index 0a422c66e16bd..0fa4b04bd1dfe 100644
--- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp
+++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp
@@ -1551,6 +1551,13 @@ void PrestoServer::registerSidecarEndpoints() {
proxygen::ResponseHandler* downstream) {
http::sendOkResponse(downstream, getFunctionsMetadata());
});
+ httpServer_->registerPost(
+ "/v1/expressions",
+ [&](proxygen::HTTPMessage* message,
+ const std::vector>& body,
+ proxygen::ResponseHandler* downstream) {
+ optimizeExpressions(*message, body, downstream);
+ });
httpServer_->registerPost(
"/v1/velox/plan",
[server = this](
@@ -1599,4 +1606,25 @@ protocol::NodeStatus PrestoServer::fetchNodeStatus() {
return nodeStatus;
}
+void PrestoServer::optimizeExpressions(
+ const proxygen::HTTPMessage& message,
+ const std::vector>& body,
+ proxygen::ResponseHandler* downstream) {
+ json::array_t inputRowExpressions =
+ json::parse(util::extractMessageBody(body));
+ auto rowExpressionOptimizer =
+ std::make_unique(
+ nativeWorkerPool_.get());
+ auto result = rowExpressionOptimizer->optimize(
+ message.getHeaders(), inputRowExpressions);
+ if (result.second) {
+ VELOX_CHECK(
+ result.first.is_array(),
+ "The output json is not an array of RowExpressions");
+ http::sendOkResponse(downstream, result.first);
+ } else {
+ http::sendErrorResponse(downstream, result.first);
+ }
+}
+
} // namespace facebook::presto
diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.h b/presto-native-execution/presto_cpp/main/PrestoServer.h
index 9fa1301c1bf1d..f535f83547b85 100644
--- a/presto-native-execution/presto_cpp/main/PrestoServer.h
+++ b/presto-native-execution/presto_cpp/main/PrestoServer.h
@@ -25,6 +25,7 @@
#include "presto_cpp/main/PeriodicHeartbeatManager.h"
#include "presto_cpp/main/PrestoExchangeSource.h"
#include "presto_cpp/main/PrestoServerOperations.h"
+#include "presto_cpp/main/types/RowExpressionOptimizer.h"
#include "presto_cpp/main/types/VeloxPlanValidator.h"
#include "velox/common/caching/AsyncDataCache.h"
#include "velox/common/memory/MemoryAllocator.h"
@@ -217,6 +218,11 @@ class PrestoServer {
protocol::NodeStatus fetchNodeStatus();
+ void optimizeExpressions(
+ const proxygen::HTTPMessage& message,
+ const std::vector>& body,
+ proxygen::ResponseHandler* downstream);
+
void populateMemAndCPUInfo();
// Periodically yield tasks if there are tasks queued.
diff --git a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt
index 5841728512238..93d8d4abddccc 100644
--- a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt
+++ b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt
@@ -34,6 +34,12 @@ add_library(presto_velox_conversion OBJECT VeloxPlanConversion.cpp)
target_link_libraries(presto_velox_conversion velox_type)
+add_library(presto_expression_optimizer RowExpressionConverter.cpp
+ RowExpressionOptimizer.cpp)
+
+target_link_libraries(presto_expression_optimizer presto_type_converter
+ presto_types presto_protocol)
+
if(PRESTO_ENABLE_TESTING)
add_subdirectory(tests)
endif()
diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp
index 83d359f6d6aad..f3e0a9c766069 100644
--- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp
+++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp
@@ -14,7 +14,6 @@
#include "presto_cpp/main/types/PrestoToVeloxExpr.h"
#include
-#include "presto_cpp/main/common/Configs.h"
#include "presto_cpp/presto_protocol/Base64Util.h"
#include "velox/common/base/Exceptions.h"
#include "velox/functions/prestosql/types/JsonType.h"
@@ -34,59 +33,9 @@ std::string toJsonString(const T& value) {
}
std::string mapScalarFunction(const std::string& name) {
- static const std::string prestoDefaultNamespacePrefix =
- SystemConfig::instance()->prestoDefaultNamespacePrefix();
- static const std::unordered_map kFunctionNames = {
- // Operator overrides: com.facebook.presto.common.function.OperatorType
- {"presto.default.$operator$add",
- util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "plus")},
- {"presto.default.$operator$between",
- util::addDefaultNamespacePrefix(
- prestoDefaultNamespacePrefix, "between")},
- {"presto.default.$operator$divide",
- util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "divide")},
- {"presto.default.$operator$equal",
- util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "eq")},
- {"presto.default.$operator$greater_than",
- util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "gt")},
- {"presto.default.$operator$greater_than_or_equal",
- util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "gte")},
- {"presto.default.$operator$is_distinct_from",
- util::addDefaultNamespacePrefix(
- prestoDefaultNamespacePrefix, "distinct_from")},
- {"presto.default.$operator$less_than",
- util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "lt")},
- {"presto.default.$operator$less_than_or_equal",
- util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "lte")},
- {"presto.default.$operator$modulus",
- util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "mod")},
- {"presto.default.$operator$multiply",
- util::addDefaultNamespacePrefix(
- prestoDefaultNamespacePrefix, "multiply")},
- {"presto.default.$operator$negation",
- util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "negate")},
- {"presto.default.$operator$not_equal",
- util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "neq")},
- {"presto.default.$operator$subtract",
- util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "minus")},
- {"presto.default.$operator$subscript",
- util::addDefaultNamespacePrefix(
- prestoDefaultNamespacePrefix, "subscript")},
- {"presto.default.$operator$xx_hash_64",
- util::addDefaultNamespacePrefix(
- prestoDefaultNamespacePrefix, "xxhash64_internal")},
- {"presto.default.combine_hash",
- util::addDefaultNamespacePrefix(
- prestoDefaultNamespacePrefix, "combine_hash_internal")},
- // Special form function overrides.
- {"presto.default.in", "in"},
- };
-
std::string lowerCaseName = boost::to_lower_copy(name);
-
- auto it = kFunctionNames.find(lowerCaseName);
- if (it != kFunctionNames.end()) {
- return it->second;
+ if (prestoOperatorMap().find(lowerCaseName) != prestoOperatorMap().end()) {
+ return prestoOperatorMap().at(lowerCaseName);
}
return lowerCaseName;
@@ -385,6 +334,67 @@ std::optional tryConvertLiteralArray(
}
} // namespace
+const std::unordered_map prestoOperatorMap() {
+ static const std::string prestoDefaultNamespacePrefix =
+ SystemConfig::instance()->prestoDefaultNamespacePrefix();
+ static const std::unordered_map kPrestoOperatorMap =
+ {
+ // Operator overrides:
+ // com.facebook.presto.common.function.OperatorType
+ {"presto.default.$operator$add",
+ util::addDefaultNamespacePrefix(
+ prestoDefaultNamespacePrefix, "plus")},
+ {"presto.default.$operator$between",
+ util::addDefaultNamespacePrefix(
+ prestoDefaultNamespacePrefix, "between")},
+ {"presto.default.$operator$divide",
+ util::addDefaultNamespacePrefix(
+ prestoDefaultNamespacePrefix, "divide")},
+ {"presto.default.$operator$equal",
+ util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "eq")},
+ {"presto.default.$operator$greater_than",
+ util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "gt")},
+ {"presto.default.$operator$greater_than_or_equal",
+ util::addDefaultNamespacePrefix(
+ prestoDefaultNamespacePrefix, "gte")},
+ {"presto.default.$operator$is_distinct_from",
+ util::addDefaultNamespacePrefix(
+ prestoDefaultNamespacePrefix, "distinct_from")},
+ {"presto.default.$operator$less_than",
+ util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "lt")},
+ {"presto.default.$operator$less_than_or_equal",
+ util::addDefaultNamespacePrefix(
+ prestoDefaultNamespacePrefix, "lte")},
+ {"presto.default.$operator$modulus",
+ util::addDefaultNamespacePrefix(
+ prestoDefaultNamespacePrefix, "mod")},
+ {"presto.default.$operator$multiply",
+ util::addDefaultNamespacePrefix(
+ prestoDefaultNamespacePrefix, "multiply")},
+ {"presto.default.$operator$negation",
+ util::addDefaultNamespacePrefix(
+ prestoDefaultNamespacePrefix, "negate")},
+ {"presto.default.$operator$not_equal",
+ util::addDefaultNamespacePrefix(
+ prestoDefaultNamespacePrefix, "neq")},
+ {"presto.default.$operator$subtract",
+ util::addDefaultNamespacePrefix(
+ prestoDefaultNamespacePrefix, "minus")},
+ {"presto.default.$operator$subscript",
+ util::addDefaultNamespacePrefix(
+ prestoDefaultNamespacePrefix, "subscript")},
+ {"presto.default.$operator$xx_hash_64",
+ util::addDefaultNamespacePrefix(
+ prestoDefaultNamespacePrefix, "xxhash64_internal")},
+ {"presto.default.combine_hash",
+ util::addDefaultNamespacePrefix(
+ prestoDefaultNamespacePrefix, "combine_hash_internal")},
+ // Special form function overrides.
+ {"presto.default.in", "in"},
+ };
+ return kPrestoOperatorMap;
+}
+
std::optional VeloxExprConverter::tryConvertDate(
const protocol::CallExpression& pexpr) const {
static const std::string prestoDefaultNamespacePrefix =
diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h
index 8526f6a8c9638..76181d4e092b1 100644
--- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h
+++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h
@@ -22,6 +22,8 @@
namespace facebook::presto {
+const std::unordered_map prestoOperatorMap();
+
class VeloxExprConverter {
public:
VeloxExprConverter(velox::memory::MemoryPool* pool, TypeParser* typeParser)
diff --git a/presto-native-execution/presto_cpp/main/types/RowExpressionConverter.cpp b/presto-native-execution/presto_cpp/main/types/RowExpressionConverter.cpp
new file mode 100644
index 0000000000000..9454d24fdfaa7
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/RowExpressionConverter.cpp
@@ -0,0 +1,429 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "presto_cpp/main/types/RowExpressionConverter.h"
+#include "presto_cpp/main/types/PrestoToVeloxExpr.h"
+#include "velox/expression/FieldReference.h"
+
+using namespace facebook::presto;
+using namespace facebook::velox;
+
+namespace facebook::presto::expression {
+
+namespace {
+const std::string kConstant = "constant";
+const std::string kBoolean = "boolean";
+const std::string kVariable = "variable";
+const std::string kCall = "call";
+const std::string kStatic = "$static";
+const std::string kSpecial = "special";
+const std::string kCoalesce = "COALESCE";
+const std::string kRowConstructor = "ROW_CONSTRUCTOR";
+const std::string kSwitch = "SWITCH";
+const std::string kWhen = "WHEN";
+
+protocol::TypeSignature getTypeSignature(const TypePtr& type) {
+ std::string typeSignature;
+ if (type->isPrimitiveType()) {
+ typeSignature = type->toString();
+ boost::algorithm::to_lower(typeSignature);
+ } else {
+ std::string complexTypeString;
+ std::vector childTypes;
+ if (type->isRow()) {
+ complexTypeString = "row";
+ childTypes = asRowType(type)->children();
+ } else if (type->isArray()) {
+ complexTypeString = "array";
+ childTypes = type->asArray().children();
+ } else if (type->isMap()) {
+ complexTypeString = "map";
+ const auto mapType = type->asMap();
+ childTypes = {mapType.keyType(), mapType.valueType()};
+ } else {
+ VELOX_USER_FAIL("Invalid type {}", type->toString());
+ }
+
+ typeSignature = complexTypeString + "(";
+ if (!childTypes.empty()) {
+ auto numChildren = childTypes.size();
+ for (auto i = 0; i < numChildren - 1; i++) {
+ typeSignature += fmt::format("{},", getTypeSignature(childTypes[i]));
+ }
+ typeSignature += getTypeSignature(childTypes[numChildren - 1]);
+ }
+ typeSignature += ")";
+ }
+
+ return typeSignature;
+}
+
+json toVariableReferenceExpression(
+ const std::shared_ptr& field) {
+ protocol::VariableReferenceExpression vexpr;
+ vexpr.name = field->name();
+ vexpr._type = kVariable;
+ vexpr.type = getTypeSignature(field->type());
+ json result;
+ protocol::to_json(result, vexpr);
+
+ return result;
+}
+
+bool isPrestoSpecialForm(const std::string& name) {
+ static const std::unordered_set kPrestoSpecialForms = {
+ "if",
+ "null_if",
+ "switch",
+ "when",
+ "is_null",
+ "coalesce",
+ "in",
+ "and",
+ "or",
+ "dereference",
+ "row_constructor",
+ "bind"};
+ return kPrestoSpecialForms.count(name) != 0;
+}
+
+json getWhenSpecialForm(const TypePtr& type, const json::array_t& whenArgs) {
+ json when;
+ when["@type"] = kSpecial;
+ when["form"] = kWhen;
+ when["arguments"] = whenArgs;
+ when["returnType"] = getTypeSignature(type);
+ return when;
+}
+
+std::vector getRowExpressionArguments(
+ const RowExpressionPtr& input) {
+ std::vector arguments;
+ if (input->_type == kSpecial) {
+ auto inputSpecialForm =
+ dynamic_cast(input.get());
+ VELOX_CHECK_NOT_NULL(inputSpecialForm);
+ arguments = inputSpecialForm->arguments;
+ } else if (input->_type == kCall) {
+ auto inputCall = dynamic_cast(input.get());
+ VELOX_CHECK_NOT_NULL(inputCall);
+ arguments = inputCall->arguments;
+ } else {
+ VELOX_USER_FAIL(
+ "Input should be a SpecialFormExpression or CallExpression");
+ }
+ return arguments;
+}
+
+const std::unordered_map& veloxToPrestoOperatorMap() {
+ static std::unordered_map veloxToPrestoOperatorMap =
+ {{"cast", "presto.default.$operator$cast"}};
+ for (const auto& entry : prestoOperatorMap()) {
+ veloxToPrestoOperatorMap[entry.second] = entry.first;
+ }
+ return veloxToPrestoOperatorMap;
+}
+} // namespace
+
+std::string RowExpressionConverter::getValueBlock(
+ const VectorPtr& vector) const {
+ std::ostringstream output;
+ serde_->serializeSingleColumn(vector, nullptr, pool_, &output);
+ const auto serialized = output.str();
+ const auto serializedSize = serialized.size();
+ return encoding::Base64::encode(serialized.c_str(), serializedSize);
+}
+
+json RowExpressionConverter::getConstantRowExpression(
+ const std::shared_ptr& constantExpr) {
+ protocol::ConstantExpression cexpr;
+ cexpr.type = getTypeSignature(constantExpr->type());
+ cexpr.valueBlock.data = getValueBlock(constantExpr->value());
+ json result;
+ protocol::to_json(result, cexpr);
+ return result;
+}
+
+RowExpressionConverter::SwitchFormArguments
+RowExpressionConverter::getSimpleSwitchFormArgs(
+ const exec::SwitchExpr* switchExpr,
+ const std::vector& arguments) {
+ SwitchFormArguments result;
+ // Consider the following Presto query with simple form switch expression:
+ // SELECT CASE 1 WHEN 2 THEN 31 - 1 WHEN 1 THEN 32 + 1 WHEN orderkey THEN 34
+ // ELSE 35 END FROM orders;
+ // When this Presto expression is converted to velox switch expression, the
+ // inputs to the velox expression contain the following `equal` expressions:
+ // eq(1, 2), eq(1, 1), eq(1, orderkey). The resultant Presto simple form
+ // switch expression requires the `expression` to be the first `argument`,
+ // please refer to the syntax of Presto simple form switch expression here:
+ // https://prestodb.io/docs/current/functions/conditional.html#case). It is
+ // not possible to get the value of `expression` from velox, since these
+ // equal expressions could all evaluate to either `true` or `false` during
+ // constant folding. Hence, this `expression` is obtained from the input
+ // Presto switch expression.
+ result.arguments.emplace_back(arguments[0]);
+ const auto& switchInputs = switchExpr->inputs();
+ const auto numInputs = switchInputs.size();
+
+ for (auto i = 0; i < numInputs - 1; i += 2) {
+ json::array_t resultWhenArgs;
+ const vector_size_t argsIdx = i / 2 + 1;
+ const auto& caseValue = switchInputs[i + 1];
+ auto inputWhenArgs = getRowExpressionArguments(arguments[argsIdx]);
+
+ if (switchInputs[i]->isConstant()) {
+ auto constantExpr =
+ std::dynamic_pointer_cast(switchInputs[i]);
+ if (auto constVector =
+ constantExpr->value()->as>()) {
+ // If this is the first switch case that evaluates to true, return the
+ // expression corresponding to this case. From the aforementioned
+ // example, `eq(1, 1)` evaluates to true, so the value corresponding to
+ // the WHEN clause (`CASE 1 WHEN 1 THEN 32 + 1`), `33` is returned.
+ if (constVector->valueAt(0) && result.arguments.size() == 1) {
+ return {
+ true,
+ veloxToPrestoRowExpression(caseValue, inputWhenArgs[1]),
+ json::array()};
+ } else {
+ // Skip switch cases that evaluate to false. From the aforementioned
+ // example, `eq(1, 2)` evaluates to false, so the corresponding WHEN
+ // clause, `CASE 1 WHEN 2 THEN 31 - 1`, is not included in the output
+ // switch expression's arguments.
+ continue;
+ }
+ } else {
+ resultWhenArgs.emplace_back(getConstantRowExpression(constantExpr));
+ }
+ } else {
+ VELOX_CHECK(!switchInputs[i]->inputs().empty());
+ const auto& matchExpr = switchInputs[i]->inputs().back();
+ resultWhenArgs.emplace_back(
+ veloxToPrestoRowExpression(matchExpr, inputWhenArgs[0]));
+ }
+
+ resultWhenArgs.emplace_back(
+ veloxToPrestoRowExpression(caseValue, inputWhenArgs[1]));
+ result.arguments.emplace_back(
+ getWhenSpecialForm(switchInputs[i + 1]->type(), resultWhenArgs));
+ }
+
+ // Else clause.
+ if (numInputs % 2 != 0) {
+ result.arguments.emplace_back(veloxToPrestoRowExpression(
+ switchInputs[numInputs - 1], arguments.back()));
+ }
+ return result;
+}
+
+RowExpressionConverter::SwitchFormArguments
+RowExpressionConverter::getSpecialSwitchFormArgs(
+ const exec::SwitchExpr* switchExpr,
+ const std::vector& arguments) {
+ // The searched form of CASE conditional expression in Presto needs to be
+ // handled differently from the simple form (please refer to:
+ // https://prestodb.io/docs/current/functions/conditional.html#case). The
+ // searched form can be detected by the presence of a boolean value in the
+ // first argument. This default boolean argument is not present in the velox
+ // switch expression, so it is added to the arguments of output switch
+ // expression unchanged.
+ if (arguments[0]->_type == kConstant) {
+ if (auto constantRowExpr =
+ dynamic_cast(arguments[0].get())) {
+ if (constantRowExpr->type == kBoolean) {
+ SwitchFormArguments result;
+ const auto& switchInputs = switchExpr->inputs();
+ const auto numInputs = switchInputs.size();
+ result.arguments = {arguments[0]};
+ for (auto i = 0; i < numInputs - 1; i += 2) {
+ const vector_size_t argsIdx = i / 2 + 1;
+ std::vector inputWhenArgs =
+ getRowExpressionArguments(arguments[argsIdx]);
+ json::array_t resultWhenArgs;
+ resultWhenArgs.emplace_back(
+ veloxToPrestoRowExpression(switchInputs[i], inputWhenArgs[0]));
+ resultWhenArgs.emplace_back(veloxToPrestoRowExpression(
+ switchInputs[i + 1], inputWhenArgs[1]));
+ result.arguments.emplace_back(
+ getWhenSpecialForm(switchInputs[i + 1]->type(), resultWhenArgs));
+ }
+
+ // Else clause.
+ if (numInputs % 2 != 0) {
+ result.arguments.emplace_back(veloxToPrestoRowExpression(
+ switchInputs[numInputs - 1], arguments.back()));
+ }
+ return result;
+ }
+ }
+ }
+
+ return getSimpleSwitchFormArgs(switchExpr, arguments);
+}
+
+json RowExpressionConverter::getSpecialForm(
+ const exec::ExprPtr& expr,
+ const RowExpressionPtr& input) {
+ json result;
+ result["@type"] = kSpecial;
+ result["returnType"] = getTypeSignature(expr->type());
+ auto form = expr->name();
+ // Presto requires the field form to be in upper case.
+ std::transform(form.begin(), form.end(), form.begin(), ::toupper);
+ result["form"] = form;
+ std::vector inputArguments =
+ getRowExpressionArguments(input);
+
+ // Arguments for switch expression include 'WHEN' special form expression(s)
+ // so they are constructed separately. If the switch expression evaluation
+ // found a case that always evaluates to `true`, the field 'isSimplified' in
+ // the result is `true` and the field 'caseExpression' contains the value
+ // corresponding to the simplified switch case. Otherwise, 'isSimplified' is
+ // false and the field 'arguments' contains the 'when' clauses needed by the
+ // Presto switch SpecialFormExpression.
+ if (form == kSwitch) {
+ auto switchExpr = dynamic_cast(expr.get());
+ VELOX_CHECK_NOT_NULL(switchExpr);
+ auto switchResult = getSpecialSwitchFormArgs(switchExpr, inputArguments);
+ if (switchResult.isSimplified) {
+ return switchResult.caseExpression;
+ } else {
+ result["arguments"] = switchResult.arguments;
+ }
+ } else {
+ // Presto special form expressions that are not of type `SWITCH`, such as
+ // `IN`, `AND`, `OR` etc,. are handled in this clause. The list of Presto
+ // special form expressions can be found in `kPrestoSpecialForms` in the
+ // helper function `isPrestoSpecialForm`.
+ auto exprInputs = expr->inputs();
+ const auto numInputs = exprInputs.size();
+ if (form == kCoalesce) {
+ VELOX_CHECK_LE(numInputs, inputArguments.size());
+ } else {
+ VELOX_CHECK_EQ(numInputs, inputArguments.size());
+ }
+ result["arguments"] = json::array();
+ for (auto i = 0; i < numInputs; i++) {
+ result["arguments"].push_back(
+ veloxToPrestoRowExpression(exprInputs[i], inputArguments[i]));
+ }
+ }
+
+ return result;
+}
+
+json RowExpressionConverter::getRowConstructorSpecialForm(
+ std::shared_ptr& constantExpr) {
+ json result;
+ result["@type"] = kSpecial;
+ result["form"] = kRowConstructor;
+ result["returnType"] = getTypeSignature(constantExpr->type());
+ auto value = constantExpr->value();
+ auto* constVector = value->as>();
+ auto* rowVector = constVector->valueVector()->as();
+ auto type = asRowType(constantExpr->type());
+ auto size = rowVector->children().size();
+
+ protocol::ConstantExpression cexpr;
+ json j;
+ result["arguments"] = json::array();
+ for (auto i = 0; i < size; i++) {
+ cexpr.type = getTypeSignature(type->childAt(i));
+ cexpr.valueBlock.data = getValueBlock(rowVector->childAt(i));
+ protocol::to_json(j, cexpr);
+ result["arguments"].push_back(j);
+ }
+ return result;
+}
+
+json RowExpressionConverter::toCallRowExpression(
+ const exec::ExprPtr& expr,
+ const RowExpressionPtr& input) {
+ json result;
+ result["@type"] = kCall;
+ protocol::Signature signature;
+ std::string exprName = expr->name();
+ if (veloxToPrestoOperatorMap().find(exprName) !=
+ veloxToPrestoOperatorMap().end()) {
+ exprName = veloxToPrestoOperatorMap().at(exprName);
+ }
+ signature.name = exprName;
+ result["displayName"] = exprName;
+ signature.kind = protocol::FunctionKind::SCALAR;
+ signature.typeVariableConstraints = {};
+ signature.longVariableConstraints = {};
+ signature.returnType = getTypeSignature(expr->type());
+
+ std::vector argumentTypes;
+ auto exprInputs = expr->inputs();
+ auto numArgs = exprInputs.size();
+ argumentTypes.reserve(numArgs);
+ for (auto i = 0; i < numArgs; i++) {
+ argumentTypes.emplace_back(getTypeSignature(exprInputs[i]->type()));
+ }
+ signature.argumentTypes = argumentTypes;
+ signature.variableArity = false;
+
+ protocol::BuiltInFunctionHandle builtInFunctionHandle;
+ builtInFunctionHandle._type = kStatic;
+ builtInFunctionHandle.signature = signature;
+ result["functionHandle"] = builtInFunctionHandle;
+ result["returnType"] = getTypeSignature(expr->type());
+ result["arguments"] = json::array();
+ for (const auto& exprInput : exprInputs) {
+ result["arguments"].push_back(veloxToPrestoRowExpression(exprInput, input));
+ }
+
+ return result;
+}
+
+json RowExpressionConverter::veloxToPrestoRowExpression(
+ const exec::ExprPtr& expr,
+ const RowExpressionPtr& input) {
+ if (expr->isConstant()) {
+ if (expr->inputs().empty()) {
+ auto constantExpr =
+ std::dynamic_pointer_cast(expr);
+ VELOX_CHECK_NOT_NULL(constantExpr);
+ // Constant velox expressions of ROW type map to ROW_CONSTRUCTOR special
+ // form expression in Presto.
+ if (expr->type()->isRow()) {
+ return getRowConstructorSpecialForm(constantExpr);
+ }
+ return getConstantRowExpression(constantExpr);
+ } else {
+ // Expressions such as 'divide(0, 0)' are not constant folded during
+ // compilation in velox, since they throw an exception (Divide by zero in
+ // this example) during evaluation (see function `tryFoldIfConstant` in
+ // `velox/expression/ExprCompiler.cpp`). The input expression is returned
+ // unchanged in such cases.
+ return input;
+ }
+ }
+
+ if (auto field =
+ std::dynamic_pointer_cast(expr)) {
+ return toVariableReferenceExpression(field);
+ }
+
+ // Check if special form expression or call expression.
+ auto exprName = expr->name();
+ boost::algorithm::to_lower(exprName);
+ if (isPrestoSpecialForm(exprName)) {
+ return getSpecialForm(expr, input);
+ }
+ return toCallRowExpression(expr, input);
+}
+
+} // namespace facebook::presto::expression
diff --git a/presto-native-execution/presto_cpp/main/types/RowExpressionConverter.h b/presto-native-execution/presto_cpp/main/types/RowExpressionConverter.h
new file mode 100644
index 0000000000000..c6cd379188e17
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/RowExpressionConverter.h
@@ -0,0 +1,122 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "presto_cpp/external/json/nlohmann/json.hpp"
+#include "presto_cpp/presto_protocol/presto_protocol.h"
+#include "velox/expression/ConstantExpr.h"
+#include "velox/expression/Expr.h"
+#include "velox/expression/SwitchExpr.h"
+#include "velox/serializers/PrestoSerializer.h"
+
+using namespace facebook::velox;
+
+using RowExpressionPtr =
+ std::shared_ptr;
+using SpecialFormExpressionPtr =
+ std::shared_ptr;
+
+namespace facebook::presto::expression {
+
+/// Helper class to convert a velox expression to its corresponding Presto
+/// protocol RowExpression. The function `veloxToPrestoRowExpression` is used in
+/// RowExpressionOptimizer to convert the constant folded velox expression to
+/// Presto protocol RowExpression:
+/// 1. A velox constant expression of type exec::ConstantExpr without any inputs
+/// is converted to Presto protocol expression of type ConstantExpression.
+/// If the velox constant expression is of ROW type, it is converted to a
+/// Presto protocol expression of type SpecialFormExpression (with Form as
+/// ROW_CONSTRUCTOR).
+/// 2. Velox expression representing a variable is of type exec::FieldReference,
+/// it is converted to a Presto protocol expression of type
+/// VariableReferenceExpression.
+/// 3. Special form expressions and expressions with a vector function in velox
+/// map either to a Presto protocol SpecialFormExpression or to a Presto
+/// protocol CallExpression. This is because special form expressions in
+/// velox and in Presto do not have a one to one mapping. If the velox
+/// expression name belongs to the set of possible Presto protocol
+/// SpecialFormExpression names, it is converted to a Presto protocol
+/// SpecialFormExpression; else it is converted to a Presto protocol
+/// CallExpression.
+///
+/// The function `getConstantRowExpression` is used in RowExpressionOptimizer to
+/// convert a velox constant expression to Presto protocol ConstantExpression.
+class RowExpressionConverter {
+ public:
+ explicit RowExpressionConverter(memory::MemoryPool* pool) : pool_(pool) {}
+
+ /// Converts a velox constant expression `constantExpr` to a Presto protocol
+ /// ConstantExpression.
+ json getConstantRowExpression(
+ const std::shared_ptr& constantExpr);
+
+ /// Converts a velox expression `expr` to a Presto protocol RowExpression.
+ /// Argument `inputRowExpr` is the input Presto protocol RowExpression before
+ /// it is constant folded in velox.
+ json veloxToPrestoRowExpression(
+ const exec::ExprPtr& expr,
+ const RowExpressionPtr& inputRowExpr);
+
+ private:
+ /// When 'isSimplified' is true, the cases (or arguments) in switch expression
+ /// have been simplified when the expression was constant folded in velox. In
+ /// this case, the expression corresponding to the first switch case that
+ /// always evaluates to true is returned in 'caseExpression'. When
+ /// 'isSimplified' is false, the cases in switch expression have not been
+ /// simplified, and the switch expression arguments required by Presto are
+ /// returned in 'arguments'.
+ struct SwitchFormArguments {
+ bool isSimplified = false;
+ json caseExpression;
+ json::array_t arguments;
+ };
+
+ /// ValueBlock in Presto protocol ConstantExpression requires only the column
+ /// from the serialized PrestoPage without the page header. This function is
+ /// used to serialize a velox vector to ValueBlock.
+ std::string getValueBlock(const VectorPtr& vector) const;
+
+ /// Helper function to get arguments for Presto protocol SpecialFormExpression
+ /// of type SWITCH when the CASE expression is of 'simple' form (please refer
+ /// to: https://prestodb.io/docs/current/functions/conditional.html#case).
+ SwitchFormArguments getSimpleSwitchFormArgs(
+ const exec::SwitchExpr* switchExpr,
+ const std::vector& inputArgs);
+
+ /// Helper function to get arguments for Presto protocol SpecialFormExpression
+ /// of type SWITCH from a velox switch expression `switchExpr`.
+ SwitchFormArguments getSpecialSwitchFormArgs(
+ const exec::SwitchExpr* switchExpr,
+ const std::vector& inputArgs);
+
+ /// Helper function to construct a Presto protocol SpecialFormExpression from
+ /// a velox expression `expr`.
+ json getSpecialForm(const exec::ExprPtr& expr, const RowExpressionPtr& input);
+
+ /// Helper function to construct a Presto protocol SpecialFormExpression of
+ /// type ROW_CONSTRUCTOR from a velox constant expression `constantExpr`.
+ json getRowConstructorSpecialForm(
+ std::shared_ptr& constantExpr);
+
+ /// Helper function to construct a Presto protocol CallExpression from a velox
+ /// expression `expr`.
+ json toCallRowExpression(
+ const exec::ExprPtr& expr,
+ const RowExpressionPtr& input);
+
+ memory::MemoryPool* pool_;
+ const std::unique_ptr serde_ =
+ std::make_unique();
+};
+} // namespace facebook::presto::expression
diff --git a/presto-native-execution/presto_cpp/main/types/RowExpressionOptimizer.cpp b/presto-native-execution/presto_cpp/main/types/RowExpressionOptimizer.cpp
new file mode 100644
index 0000000000000..1181fa3e03e9d
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/RowExpressionOptimizer.cpp
@@ -0,0 +1,325 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "presto_cpp/main/types/RowExpressionOptimizer.h"
+#include "velox/expression/ExprCompiler.h"
+
+using namespace facebook::presto;
+using namespace facebook::velox;
+
+namespace facebook::presto::expression {
+
+namespace {
+
+const std::string kTimezoneHeader = "X-Presto-Time-Zone";
+const std::string kOptimizerLevelHeader = "X-Presto-Expression-Optimizer-Level";
+const std::string kEvaluated = "EVALUATED";
+
+template
+std::shared_ptr getConstantExpr(
+ const TypePtr& type,
+ const DecodedVector& decoded,
+ memory::MemoryPool* pool) {
+ if constexpr (
+ KIND == TypeKind::ROW || KIND == TypeKind::UNKNOWN ||
+ KIND == TypeKind::ARRAY || KIND == TypeKind::MAP) {
+ VELOX_USER_FAIL("Invalid result type {}", type->toString());
+ } else {
+ using T = typename TypeTraits::NativeType;
+ auto constVector = std::make_shared>(
+ pool, decoded.size(), decoded.isNullAt(0), type, decoded.valueAt(0));
+ return std::make_shared(constVector);
+ }
+}
+} // namespace
+
+exec::ExprPtr RowExpressionOptimizer::compileExpression(
+ const std::shared_ptr& inputRowExpr) {
+ auto typedExpr = veloxExprConverter_.toVeloxExpr(inputRowExpr);
+ exec::ExprSet exprSet{{typedExpr}, execCtx_.get()};
+ auto compiledExprs =
+ exec::compileExpressions({typedExpr}, execCtx_.get(), &exprSet, true);
+ return compiledExprs[0];
+}
+
+RowExpressionPtr RowExpressionOptimizer::optimizeAndSpecialForm(
+ const SpecialFormExpressionPtr& specialFormExpr) {
+ auto left = specialFormExpr->arguments[0];
+ auto right = specialFormExpr->arguments[1];
+ auto leftExpr = compileExpression(left);
+ bool isLeftNull = false;
+ bool leftValue = false;
+
+ if (auto constantExpr =
+ std::dynamic_pointer_cast(leftExpr)) {
+ isLeftNull = constantExpr->value()->isNullAt(0);
+ if (!isLeftNull) {
+ if (auto constVector =
+ constantExpr->value()->as>()) {
+ if (!constVector->valueAt(0)) {
+ return rowExpressionConverter_.getConstantRowExpression(constantExpr);
+ } else {
+ leftValue = true;
+ }
+ }
+ }
+ }
+
+ auto rightExpr = compileExpression(right);
+ if (auto constantExpr =
+ std::dynamic_pointer_cast(rightExpr)) {
+ if (constantExpr->value()->isNullAt(0) && (isLeftNull || leftValue)) {
+ return rowExpressionConverter_.getConstantRowExpression(constantExpr);
+ }
+ if (auto constVector = constantExpr->value()->as>()) {
+ if (constVector->valueAt(0)) {
+ return left;
+ } else if (isLeftNull || leftValue) {
+ return rowExpressionConverter_.getConstantRowExpression(constantExpr);
+ }
+ }
+ }
+
+ return specialFormExpr;
+}
+
+RowExpressionPtr RowExpressionOptimizer::optimizeIfSpecialForm(
+ const SpecialFormExpressionPtr& specialFormExpr) {
+ auto condition = specialFormExpr->arguments[0];
+ auto expr = compileExpression(condition);
+
+ if (auto constantExpr =
+ std::dynamic_pointer_cast(expr)) {
+ if (auto constVector = constantExpr->value()->as>()) {
+ if (constVector->valueAt(0)) {
+ return specialFormExpr->arguments[1];
+ }
+ return specialFormExpr->arguments[2];
+ }
+ }
+
+ return specialFormExpr;
+}
+
+RowExpressionPtr RowExpressionOptimizer::optimizeIsNullSpecialForm(
+ const SpecialFormExpressionPtr& specialFormExpr) {
+ auto expr = compileExpression(specialFormExpr);
+ if (auto constantExpr =
+ std::dynamic_pointer_cast(expr)) {
+ if (constantExpr->value()->isNullAt(0)) {
+ return rowExpressionConverter_.getConstantRowExpression(constantExpr);
+ }
+ }
+
+ return specialFormExpr;
+}
+
+RowExpressionPtr RowExpressionOptimizer::optimizeOrSpecialForm(
+ const SpecialFormExpressionPtr& specialFormExpr) {
+ auto left = specialFormExpr->arguments[0];
+ auto right = specialFormExpr->arguments[1];
+ auto leftExpr = compileExpression(left);
+ bool isLeftNull = false;
+ bool leftValue = true;
+
+ if (auto constantExpr =
+ std::dynamic_pointer_cast(leftExpr)) {
+ isLeftNull = constantExpr->value()->isNullAt(0);
+ if (!isLeftNull) {
+ if (auto constVector =
+ constantExpr->value()->as>()) {
+ if (constVector->valueAt(0)) {
+ return rowExpressionConverter_.getConstantRowExpression(constantExpr);
+ } else {
+ leftValue = false;
+ }
+ }
+ }
+ }
+
+ auto rightExpr = compileExpression(right);
+ if (auto constantExpr =
+ std::dynamic_pointer_cast(rightExpr)) {
+ if (constantExpr->value()->isNullAt(0) && (isLeftNull || !leftValue)) {
+ return rowExpressionConverter_.getConstantRowExpression(constantExpr);
+ }
+ if (auto constVector = constantExpr->value()->as>()) {
+ if (!constVector->valueAt(0)) {
+ return left;
+ } else if (isLeftNull || !leftValue) {
+ return rowExpressionConverter_.getConstantRowExpression(constantExpr);
+ }
+ }
+ }
+
+ return specialFormExpr;
+}
+
+RowExpressionPtr RowExpressionOptimizer::optimizeCoalesceSpecialForm(
+ const SpecialFormExpressionPtr& specialFormExpr) {
+ auto argsNoNulls = specialFormExpr->arguments;
+ argsNoNulls.erase(
+ std::remove_if(
+ argsNoNulls.begin(),
+ argsNoNulls.end(),
+ [&](const auto& arg) {
+ auto compiledExpr = compileExpression(arg);
+ if (auto constantExpr =
+ std::dynamic_pointer_cast(
+ compiledExpr)) {
+ return constantExpr->value()->isNullAt(0);
+ }
+ return false;
+ }),
+ argsNoNulls.end());
+
+ if (argsNoNulls.empty()) {
+ return specialFormExpr->arguments[0];
+ }
+ specialFormExpr->arguments = argsNoNulls;
+ return specialFormExpr;
+}
+
+RowExpressionPtr RowExpressionOptimizer::optimizeSpecialForm(
+ const std::shared_ptr& specialFormExpr) {
+ switch (specialFormExpr->form) {
+ case protocol::Form::IF:
+ return optimizeIfSpecialForm(specialFormExpr);
+ case protocol::Form::NULL_IF:
+ VELOX_USER_FAIL("NULL_IF specialForm not supported");
+ break;
+ case protocol::Form::IS_NULL:
+ return optimizeIsNullSpecialForm(specialFormExpr);
+ case protocol::Form::AND:
+ return optimizeAndSpecialForm(specialFormExpr);
+ case protocol::Form::OR:
+ return optimizeOrSpecialForm(specialFormExpr);
+ case protocol::Form::COALESCE:
+ return optimizeCoalesceSpecialForm(specialFormExpr);
+ case protocol::Form::IN:
+ case protocol::Form::DEREFERENCE:
+ case protocol::Form::SWITCH:
+ case protocol::Form::WHEN:
+ case protocol::Form::ROW_CONSTRUCTOR:
+ case protocol::Form::BIND:
+ default:
+ break;
+ }
+
+ return specialFormExpr;
+}
+
+json RowExpressionOptimizer::evaluateNonDeterministicConstantExpr(
+ const exec::ExprPtr& expr,
+ exec::ExprSet& exprSet) {
+ VELOX_CHECK(!expr->isDeterministic());
+ std::vector compiledExprInputTypes;
+ std::vector compiledExprInputs;
+ for (const auto& exprInput : expr->inputs()) {
+ VELOX_CHECK(
+ exprInput->isConstant(),
+ "Inputs to non-deterministic expression to be evaluated must be constant");
+ const auto inputAsConstExpr =
+ std::dynamic_pointer_cast(exprInput);
+ compiledExprInputs.emplace_back(inputAsConstExpr->value());
+ compiledExprInputTypes.emplace_back(exprInput->type());
+ }
+
+ const auto inputVector = std::make_shared(
+ pool_,
+ ROW(std::move(compiledExprInputTypes)),
+ nullptr,
+ 1,
+ compiledExprInputs);
+ exec::EvalCtx evalCtx(execCtx_.get(), &exprSet, inputVector.get());
+ std::vector results(1);
+ SelectivityVector rows(1);
+ exprSet.eval(rows, evalCtx, results);
+ auto res = results.front();
+ DecodedVector decoded(*res, rows);
+ const auto constExpr = VELOX_DYNAMIC_TYPE_DISPATCH(
+ getConstantExpr, res->typeKind(), res->type(), decoded, pool_);
+ return rowExpressionConverter_.getConstantRowExpression(constExpr);
+}
+
+json::array_t RowExpressionOptimizer::optimizeExpressions(
+ const json::array_t& input,
+ const std::string& optimizerLevel) {
+ const auto numExpr = input.size();
+ json::array_t output = json::array();
+ for (auto i = 0; i < numExpr; i++) {
+ std::shared_ptr inputRowExpr = input[i];
+ if (const auto special =
+ std::dynamic_pointer_cast(
+ inputRowExpr)) {
+ inputRowExpr = optimizeSpecialForm(special);
+ }
+ auto typedExpr = veloxExprConverter_.toVeloxExpr(inputRowExpr);
+ exec::ExprSet exprSet{{typedExpr}, execCtx_.get()};
+ auto compiledExprs =
+ exec::compileExpressions({typedExpr}, execCtx_.get(), &exprSet, true);
+ auto compiledExpr = compiledExprs[0];
+ json resultJson;
+
+ if (optimizerLevel == kEvaluated) {
+ if (compiledExpr->isConstant()) {
+ resultJson = rowExpressionConverter_.veloxToPrestoRowExpression(
+ compiledExpr, input[i]);
+ } else {
+ // Velox does not evaluate expressions that are non-deterministic during
+ // compilation with constant folding enabled. Presto might require such
+ // non-deterministic expressions to be evaluated as well, this would be
+ // indicated by the header field 'X-Presto-Expression-Optimizer-Level'
+ // in the http request made to the native sidecar. When this field is
+ // set to 'EVALUATED', non-deterministic expressions with constant
+ // inputs are also evaluated.
+ resultJson =
+ evaluateNonDeterministicConstantExpr(compiledExpr, exprSet);
+ }
+ } else {
+ resultJson = rowExpressionConverter_.veloxToPrestoRowExpression(
+ compiledExpr, input[i]);
+ }
+
+ output.push_back(resultJson);
+ }
+ return output;
+}
+
+std::pair RowExpressionOptimizer::optimize(
+ const proxygen::HTTPHeaders& httpHeaders,
+ const json::array_t& input) {
+ try {
+ auto timezone = httpHeaders.getSingleOrEmpty(kTimezoneHeader);
+ auto optimizerLevel = httpHeaders.getSingleOrEmpty(kOptimizerLevelHeader);
+ std::unordered_map config(
+ {{core::QueryConfig::kSessionTimezone, timezone},
+ {core::QueryConfig::kAdjustTimestampToTimezone, "true"}});
+ auto queryCtx =
+ core::QueryCtx::create(nullptr, core::QueryConfig{std::move(config)});
+ execCtx_ = std::make_unique(pool_, queryCtx.get());
+
+ return {optimizeExpressions(input, optimizerLevel), true};
+ } catch (const VeloxUserError& e) {
+ VLOG(1) << "VeloxUserError during expression evaluation: " << e.what();
+ return {e.what(), false};
+ } catch (const VeloxException& e) {
+ VLOG(1) << "VeloxException during expression evaluation: " << e.what();
+ return {e.what(), false};
+ } catch (const std::exception& e) {
+ VLOG(1) << "std::exception during expression evaluation: " << e.what();
+ return {e.what(), false};
+ }
+}
+
+} // namespace facebook::presto::expression
diff --git a/presto-native-execution/presto_cpp/main/types/RowExpressionOptimizer.h b/presto-native-execution/presto_cpp/main/types/RowExpressionOptimizer.h
new file mode 100644
index 0000000000000..737ee35d6af87
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/RowExpressionOptimizer.h
@@ -0,0 +1,102 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include
+#include "presto_cpp/main/types/PrestoToVeloxExpr.h"
+#include "presto_cpp/main/types/RowExpressionConverter.h"
+
+using namespace facebook::velox;
+
+namespace facebook::presto::expression {
+
+/// Helper class to optimize Presto protocol RowExpressions sent along the REST
+/// endpoint 'v1/expressions' to the sidecar.
+class RowExpressionOptimizer {
+ public:
+ explicit RowExpressionOptimizer(memory::MemoryPool* pool)
+ : pool_(pool),
+ veloxExprConverter_(pool, &typeParser_),
+ rowExpressionConverter_(RowExpressionConverter(pool)) {}
+
+ /// Optimizes all expressions from the input json array. If the expression
+ /// optimization fails for any of the input expressions, the second value in
+ /// the returned pair is set to false and the returned json contains the
+ /// exception. Otherwise, the returned json is an array of optimized Presto
+ /// protocol RowExpressions.
+ std::pair optimize(
+ const proxygen::HTTPHeaders& httpHeaders,
+ const json::array_t& input);
+
+ private:
+ /// Converts Presto protocol RowExpression into a velox expression with
+ /// constant folding enabled during velox expression compilation.
+ exec::ExprPtr compileExpression(const RowExpressionPtr& inputRowExpr);
+
+ /// Evaluates Presto protocol SpecialFormExpression `AND` when the inputs can
+ /// be constant folded to `true`, `false`, or `NULL`.
+ RowExpressionPtr optimizeAndSpecialForm(
+ const SpecialFormExpressionPtr& specialFormExpr);
+
+ /// Evaluates Presto protocol SpecialFormExpression `IF` when the condition
+ /// can be constant folded to `true` or `false`.
+ RowExpressionPtr optimizeIfSpecialForm(
+ const SpecialFormExpressionPtr& specialFormExpr);
+
+ /// Evaluates Presto protocol SpecialFormExpression `IS_NULL` when the input
+ /// can be constant folded.
+ RowExpressionPtr optimizeIsNullSpecialForm(
+ const SpecialFormExpressionPtr& specialFormExpr);
+
+ /// Evaluates Presto protocol SpecialFormExpression `OR` when the inputs can
+ /// be constant folded to `true`, `false`, or `NULL`.
+ RowExpressionPtr optimizeOrSpecialForm(
+ const SpecialFormExpressionPtr& specialFormExpr);
+
+ /// Optimizes Presto protocol SpecialFormExpression `COALESCE` by removing
+ /// `NULL`s from the argument list.
+ RowExpressionPtr optimizeCoalesceSpecialForm(
+ const SpecialFormExpressionPtr& specialFormExpr);
+
+ /// Optimizes Presto protocol SpecialFormExpressions using optimization rules
+ /// borrowed from Presto function visitSpecialForm() in
+ /// RowExpressionInterpreter.java.
+ RowExpressionPtr optimizeSpecialForm(
+ const SpecialFormExpressionPtr& specialFormExpr);
+
+ /// Evaluates non-deterministic expressions with constant inputs.
+ json evaluateNonDeterministicConstantExpr(
+ const exec::ExprPtr& expr,
+ exec::ExprSet& exprSet);
+
+ /// Optimizes and constant folds each expression from input json array and
+ /// returns an array of expressions that are optimized and constant folded.
+ /// Each expression in the input array is optimized with helper functions
+ /// `optimizeSpecialForm` (applicable only for special form expressions) and
+ /// `optimizeExpression`. The optimized expression is also evaluated if the
+ /// optimization level in the header of http request made to 'v1/expressions'
+ /// is 'EVALUATED'. `optimizeExpression` uses `RowExpressionConverter` to
+ /// convert velox expression(s) to their corresponding Presto protocol
+ /// RowExpression(s).
+ json::array_t optimizeExpressions(
+ const json::array_t& input,
+ const std::string& optimizationLevel);
+
+ memory::MemoryPool* pool_;
+ std::unique_ptr execCtx_;
+ TypeParser typeParser_;
+ VeloxExprConverter veloxExprConverter_;
+ RowExpressionConverter rowExpressionConverter_;
+};
+} // namespace facebook::presto::expression
diff --git a/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt
index 28f73aff40b80..3d1a50e463450 100644
--- a/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt
+++ b/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt
@@ -131,3 +131,36 @@ target_link_libraries(
velox_exec_test_lib
GTest::gtest
GTest::gtest_main)
+
+add_executable(presto_expression_converter_test RowExpressionConverterTest.cpp)
+
+add_test(
+ NAME presto_expression_converter_test
+ COMMAND presto_expression_converter_test
+ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
+
+target_link_libraries(
+ presto_expression_converter_test
+ presto_expression_optimizer
+ presto_http
+ presto_type_test_utils
+ velox_exec_test_lib
+ GTest::gtest
+ GTest::gtest_main)
+
+add_executable(presto_expression_optimizer_test RowExpressionOptimizerTest.cpp)
+
+add_test(
+ NAME presto_expression_optimizer_test
+ COMMAND presto_expression_optimizer_test
+ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
+
+target_link_libraries(
+ presto_expression_optimizer_test
+ presto_expression_optimizer
+ presto_http
+ presto_type_test_utils
+ velox_exec_test_lib
+ velox_parse_utils
+ GTest::gtest
+ GTest::gtest_main)
diff --git a/presto-native-execution/presto_cpp/main/types/tests/RowExpressionConverterTest.cpp b/presto-native-execution/presto_cpp/main/types/tests/RowExpressionConverterTest.cpp
new file mode 100644
index 0000000000000..c227157c06fc0
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/tests/RowExpressionConverterTest.cpp
@@ -0,0 +1,176 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "presto_cpp/main/types/RowExpressionConverter.h"
+#include
+#include "presto_cpp/main/common/tests/test_json.h"
+#include "presto_cpp/main/http/tests/HttpTestBase.h"
+#include "presto_cpp/main/types/tests/TestUtils.h"
+#include "velox/expression/CastExpr.h"
+#include "velox/expression/ConjunctExpr.h"
+#include "velox/expression/FieldReference.h"
+#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
+#include "velox/parse/TypeResolver.h"
+#include "velox/vector/VectorStream.h"
+#include "velox/vector/tests/utils/VectorTestBase.h"
+
+using namespace facebook::presto;
+using namespace facebook::velox;
+
+// RowExpressionConverter is used only by RowExpressionOptimizer so unit tests
+// are only added for simple expressions here. End-to-end tests to validate
+// velox expression to Presto RowExpression conversion for different types of
+// expressions can be found in TestDelegatingExpressionOptimizer.java in
+// presto-native-sidecar-plugin.
+class RowExpressionConverterTest
+ : public ::testing::Test,
+ public facebook::velox::test::VectorTestBase {
+ protected:
+ static void SetUpTestCase() {
+ memory::MemoryManager::testingSetInstance({});
+ }
+
+ void SetUp() override {
+ rowExpressionConverter_ =
+ std::make_unique(pool_.get());
+ }
+
+ void testFile(
+ const std::vector& inputs,
+ const std::string& fileName) {
+ json::array_t expected = json::parse(
+ slurp(facebook::presto::test::utils::getDataPath(fileName)));
+ auto numExpr = inputs.size();
+ for (auto i = 0; i < numExpr; i++) {
+ RowExpressionPtr inputRowExpr;
+ // The input row expression is provided by RowExpressionOptimizer in the
+ // sidecar. Since RowExpressionConverter is being tested here, we use the
+ // expected row expression to mock the input. The input row expression is
+ // only used for conversion to special form and call expressions.
+ protocol::from_json(expected[i], inputRowExpr);
+ EXPECT_EQ(
+ rowExpressionConverter_->veloxToPrestoRowExpression(
+ inputs[i], inputRowExpr),
+ expected[i]);
+ }
+ }
+
+ void testConstantExpressionConversion(
+ const std::vector& inputs) {
+ json::array_t expected =
+ json::parse(slurp(facebook::presto::test::utils::getDataPath(
+ "ConstantExpressionConversion.json")));
+ auto numExpr = inputs.size();
+ for (auto i = 0; i < numExpr; i++) {
+ auto constantExpr =
+ std::dynamic_pointer_cast(inputs[i]);
+ EXPECT_EQ(
+ rowExpressionConverter_->getConstantRowExpression(constantExpr),
+ expected[i]);
+ }
+ }
+
+ std::unique_ptr rowExpressionConverter_;
+};
+
+TEST_F(RowExpressionConverterTest, constant) {
+ auto inputs = std::vector{
+ std::make_shared(makeConstant(true, 1)),
+ std::make_shared(makeConstant(12, 1)),
+ std::make_shared(makeConstant(123, 1)),
+ std::make_shared(makeConstant(1234, 1)),
+ std::make_shared(makeConstant(12345, 1)),
+ std::make_shared(
+ makeConstant(std::numeric_limits::max(), 1)),
+ std::make_shared(makeConstant(123.456, 1)),
+ std::make_shared(makeConstant(1234.5678, 1)),
+ std::make_shared(
+ makeConstant(std::numeric_limits::quiet_NaN(), 1)),
+ std::make_shared(makeConstant("abcd", 1)),
+ std::make_shared(
+ makeNullConstant(TypeKind::BIGINT, 1)),
+ std::make_shared(
+ makeNullConstant(TypeKind::VARCHAR, 1))};
+ testFile(inputs, "ConstantExpressionConversion.json");
+ testConstantExpressionConversion(inputs);
+}
+
+TEST_F(RowExpressionConverterTest, rowConstructor) {
+ auto inputs = std::vector{
+ std::make_shared(makeConstantRow(
+ ROW({BIGINT(), REAL(), VARCHAR()}),
+ variant::row(
+ {static_cast(123456),
+ static_cast(12.345),
+ "abc"}),
+ 1)),
+ std::make_shared(makeConstantRow(
+ ROW({REAL(), VARCHAR(), INTEGER()}),
+ variant::row(
+ {std::numeric_limits::quiet_NaN(),
+ "",
+ std::numeric_limits::max()}),
+ 1))};
+ testFile(inputs, "RowConstructorExpressionConversion.json");
+}
+
+TEST_F(RowExpressionConverterTest, variable) {
+ auto inputs = std::vector{
+ std::make_shared(
+ VARCHAR(), std::vector{}, "c0"),
+ std::make_shared(
+ ARRAY(BIGINT()), std::vector{}, "c1"),
+ std::make_shared(
+ MAP(INTEGER(), REAL()), std::vector{}, "c2"),
+ std::make_shared(
+ ROW({SMALLINT(), VARCHAR(), DOUBLE()}),
+ std::vector{},
+ "c3"),
+ std::make_shared(
+ MAP(ARRAY(TINYINT()), BOOLEAN()), std::vector{}, "c4"),
+ };
+ testFile(inputs, "VariableExpressionConversion.json");
+}
+
+TEST_F(RowExpressionConverterTest, special) {
+ auto castExpr = std::make_shared();
+ auto orExpr = std::make_shared(false);
+ auto andExpr = std::make_shared(true);
+ auto inputs = std::vector{
+ castExpr->constructSpecialForm(
+ VARCHAR(),
+ {std::make_shared(
+ BIGINT(), std::vector{}, "c0")},
+ false,
+ core::QueryConfig({})),
+ orExpr->constructSpecialForm(
+ BOOLEAN(),
+ {std::make_shared(
+ BOOLEAN(), std::vector{}, "c1"),
+ std::make_shared(makeConstant(true, 1))},
+ false,
+ core::QueryConfig({})),
+ andExpr->constructSpecialForm(
+ BOOLEAN(),
+ {castExpr->constructSpecialForm(
+ BOOLEAN(),
+ {std::make_shared(
+ INTEGER(), std::vector{}, "c2")},
+ false,
+ core::QueryConfig({})),
+ std::make_shared(
+ BOOLEAN(), std::vector{}, "c3")},
+ false,
+ core::QueryConfig({}))};
+ testFile(inputs, "CallAndSpecialFormExpressionConversion.json");
+}
diff --git a/presto-native-execution/presto_cpp/main/types/tests/RowExpressionOptimizerTest.cpp b/presto-native-execution/presto_cpp/main/types/tests/RowExpressionOptimizerTest.cpp
new file mode 100644
index 0000000000000..ace751c552117
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/tests/RowExpressionOptimizerTest.cpp
@@ -0,0 +1,143 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "presto_cpp/main/types/RowExpressionOptimizer.h"
+#include
+#include
+#include "presto_cpp/main/common/tests/test_json.h"
+#include "presto_cpp/main/http/tests/HttpTestBase.h"
+#include "presto_cpp/main/types/tests/TestUtils.h"
+#include "velox/expression/FieldReference.h"
+#include "velox/expression/RegisterSpecialForm.h"
+#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
+#include "velox/parse/TypeResolver.h"
+#include "velox/vector/VectorStream.h"
+#include "velox/vector/tests/utils/VectorTestBase.h"
+
+using namespace facebook::presto;
+
+// RowExpressionOptimizerTest only tests for basic expression optimization.
+// End-to-end tests for different types of expressions can be found in
+// TestDelegatingExpressionOptimizer.java in presto-native-sidecar-plugin.
+class RowExpressionOptimizerTest
+ : public ::testing::Test,
+ public facebook::velox::test::VectorTestBase {
+ protected:
+ static void SetUpTestCase() {
+ memory::MemoryManager::testingSetInstance({});
+ }
+
+ void SetUp() override {
+ parse::registerTypeResolver();
+ functions::prestosql::registerAllScalarFunctions("presto.default.");
+ exec::registerFunctionCallToSpecialForms();
+ rowExpressionOptimizer_ =
+ std::make_unique(pool());
+ }
+
+ void testFile(const std::string& testName) {
+ auto input = slurp(facebook::presto::test::utils::getDataPath(
+ fmt::format("{}Input.json", testName)));
+ json::array_t inputExpressions = json::parse(input);
+ proxygen::HTTPMessage httpMessage;
+ httpMessage.getHeaders().set(
+ "X-Presto-Time-Zone", "America/Bahia_Banderas");
+ httpMessage.getHeaders().set(
+ "X-Presto-Expression-Optimizer-Level", "OPTIMIZED");
+ auto result = rowExpressionOptimizer_->optimize(
+ httpMessage.getHeaders(), inputExpressions);
+
+ EXPECT_EQ(result.second, true);
+ json resultExpressions = result.first;
+ EXPECT_EQ(resultExpressions.is_array(), true);
+ auto expected = slurp(facebook::presto::test::utils::getDataPath(
+ fmt::format("{}Expected.json", testName)));
+ json::array_t expectedExpressions = json::parse(expected);
+ auto numExpressions = resultExpressions.size();
+ EXPECT_EQ(numExpressions, expectedExpressions.size());
+ for (auto i = 0; i < numExpressions; i++) {
+ EXPECT_EQ(resultExpressions.at(i), expectedExpressions.at(i));
+ }
+ }
+
+ std::unique_ptr rowExpressionOptimizer_;
+};
+
+TEST_F(RowExpressionOptimizerTest, simple) {
+ // Files SimpleExpressions{Input|Expected}.json contain the input and expected
+ // JSON representing the RowExpressions resulting from the following queries:
+ // 1. select 1 + 2;
+ // 2. select abs(-11) + ceil(cast(3.4 as double)) + floor(cast(5.6 as
+ // double));
+ // 3. select 2 between 1 and 3;
+ // Simple expression evaluation with constant folding is verified here.
+ testFile("SimpleExpressions");
+}
+
+TEST_F(RowExpressionOptimizerTest, switchSpecialForm) {
+ // Files SwitchSpecialFormExpressions{Input|Expected}.json contain the input
+ // and expected JSON representing the RowExpressions resulting from the
+ // following queries:
+ // Simple form:
+ // 1. select case 1 when 1 then 32 + 1 when 1 then 34 end;
+ // 2. select case null when true then 32 else 33 end;
+ // 3. select case 33 when 0 then 0 when 33 then 1 when unbound_long then 2
+ // else 1 end from tmp;
+ // Searched form:
+ // 1. select case when true then 32 + 1 end;
+ // 2. select case when false then 1 else 34 - 1 end;
+ // 3. select case when ARRAY[CAST(1 AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)]
+ // then 'matched' else 'not_matched' end;
+ // Evaluation of both simple and searched forms of `SWITCH` special form
+ // expression is verified here.
+ testFile("SwitchSpecialFormExpressions");
+}
+
+TEST_F(RowExpressionOptimizerTest, inSpecialForm) {
+ // Files InSpecialFormExpressions{Input|Expected}.json contain the input
+ // and expected JSON representing the RowExpressions resulting from the
+ // following queries:
+ // 1. select 3 in (2, null, 3, 5);
+ // 2. select 'foo' in ('bar', 'baz', 'buz', 'blah');
+ // 3. select null in (2, null, 3, 5);
+ // 4. select ROW(1) IN (ROW(2), ROW(1), ROW(2));
+ // 5. select MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2],
+ // ARRAY[2, null]), null);
+ // 6. select MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 3],
+ // ARRAY[1, null]));
+ // Evaluation of `IN` special form expression for primitive and complex types
+ // is verified here.
+ testFile("InSpecialFormExpressions");
+}
+
+TEST_F(RowExpressionOptimizerTest, specialFormRewrites) {
+ // Files SpecialFormExpressionRewrites{Input|Expected}.json contain the input
+ // and expected JSON representing the RowExpressions resulting from the
+ // following queries:
+ // 1. select if(1 < 2, 2, 3);
+ // 2. select (1 < 2) and (2 < 3);
+ // 3. select (1 < 2) or (2 < 3);
+ // Special form expression rewrites are verified here.
+ testFile("SpecialFormExpressionRewrites");
+}
+
+TEST_F(RowExpressionOptimizerTest, betweenCallExpression) {
+ // Files BetweenCallExpressions{Input|Expected}.json contain the input and
+ // expected JSON representing the RowExpressions resulting from the following
+ // queries:
+ // 1. select 2 between 3 and 4;
+ // 2. select null between 2 and 4;
+ // 3. select 'cc' between 'b' and 'd';
+ // Evaluation of `BETWEEN` call expression is verified here.
+ testFile("BetweenCallExpressions");
+}
diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/BetweenCallExpressionsExpected.json b/presto-native-execution/presto_cpp/main/types/tests/data/BetweenCallExpressionsExpected.json
new file mode 100644
index 0000000000000..e39a9b35eaf6b
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/tests/data/BetweenCallExpressionsExpected.json
@@ -0,0 +1,17 @@
+[
+ {
+ "@type": "constant",
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAA="
+ },
+ {
+ "@type": "constant",
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAY0="
+ },
+ {
+ "@type": "constant",
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE="
+ }
+]
diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/BetweenCallExpressionsInput.json b/presto-native-execution/presto_cpp/main/types/tests/data/BetweenCallExpressionsInput.json
new file mode 100644
index 0000000000000..a65e8488ceb36
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/tests/data/BetweenCallExpressionsInput.json
@@ -0,0 +1,191 @@
+[
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAABAAAAA=="
+ }
+ ],
+ "displayName": "BETWEEN",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$between",
+ "returnType": "boolean",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "boolean"
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "sourceLocation": {
+ "column": 1,
+ "line": 1
+ },
+ "type": "unknown",
+ "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA=="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "unknown"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "integer",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 1,
+ "line": 1
+ }
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAABAAAAA=="
+ }
+ ],
+ "displayName": "BETWEEN",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$between",
+ "returnType": "boolean",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "boolean",
+ "sourceLocation": {
+ "column": 1,
+ "line": 1
+ }
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "varchar(2)",
+ "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAIAAAAAAgAAAGNj"
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "varchar(1)",
+ "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAEAAAAAAQAAAGI="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "varchar(1)"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "varchar(2)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "varchar(2)"
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "varchar(1)",
+ "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAEAAAAAAQAAAGQ="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "varchar(1)"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "varchar(2)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "varchar(2)"
+ }
+ ],
+ "displayName": "BETWEEN",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "varchar(2)",
+ "varchar(2)",
+ "varchar(2)"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$between",
+ "returnType": "boolean",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "boolean"
+ }
+]
diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/CallAndSpecialFormExpressionConversion.json b/presto-native-execution/presto_cpp/main/types/tests/data/CallAndSpecialFormExpressionConversion.json
new file mode 100644
index 0000000000000..ddf74d986ed48
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/tests/data/CallAndSpecialFormExpressionConversion.json
@@ -0,0 +1,83 @@
+[
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "variable",
+ "name": "c0",
+ "type": "bigint"
+ }
+ ],
+ "displayName": "presto.default.$operator$cast",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "bigint"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "varchar",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "varchar"
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "variable",
+ "name": "c1",
+ "type": "boolean"
+ },
+ {
+ "@type": "constant",
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE="
+ }
+ ],
+ "form": "OR",
+ "returnType": "boolean"
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "variable",
+ "name": "c2",
+ "type": "integer"
+ }
+ ],
+ "displayName": "presto.default.$operator$cast",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "boolean",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "boolean"
+ },
+ {
+ "@type": "variable",
+ "name": "c3",
+ "type": "boolean"
+ }
+ ],
+ "form": "AND",
+ "returnType": "boolean"
+ }
+]
diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/ConstantExpressionConversion.json b/presto-native-execution/presto_cpp/main/types/tests/data/ConstantExpressionConversion.json
new file mode 100644
index 0000000000000..0e962d0153dc6
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/tests/data/ConstantExpressionConversion.json
@@ -0,0 +1,62 @@
+[
+ {
+ "@type": "constant",
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE="
+ },
+ {
+ "@type": "constant",
+ "type": "tinyint",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAw="
+ },
+ {
+ "@type": "constant",
+ "type": "smallint",
+ "valueBlock": "CwAAAFNIT1JUX0FSUkFZAQAAAAB7AA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAA0gQAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "bigint",
+ "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAADkwAAAAAAAA"
+ },
+ {
+ "@type": "constant",
+ "type": "bigint",
+ "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAAP////////9/"
+ },
+ {
+ "@type": "constant",
+ "type": "real",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAeen2Qg=="
+ },
+ {
+ "@type": "constant",
+ "type": "double",
+ "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAAK36XG1FSpNA"
+ },
+ {
+ "@type": "constant",
+ "type": "double",
+ "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAAAAAAAAAAPh/"
+ },
+ {
+ "@type": "constant",
+ "type": "varchar",
+ "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAQAAAAABAAAAGFiY2Q="
+ },
+ {
+ "@type": "constant",
+ "type": "bigint",
+ "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAAaA="
+ },
+ {
+ "@type": "constant",
+ "type": "varchar",
+ "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAAAAAABhQAAAAA="
+ }
+]
diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/InSpecialFormExpressionsExpected.json b/presto-native-execution/presto_cpp/main/types/tests/data/InSpecialFormExpressionsExpected.json
new file mode 100644
index 0000000000000..2c3009fd58d53
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/tests/data/InSpecialFormExpressionsExpected.json
@@ -0,0 +1,32 @@
+[
+ {
+ "@type": "constant",
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE="
+ },
+ {
+ "@type": "constant",
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAA="
+ },
+ {
+ "@type": "constant",
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAYU="
+ },
+ {
+ "@type": "constant",
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE="
+ },
+ {
+ "@type": "constant",
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAYU="
+ },
+ {
+ "@type": "constant",
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAA="
+ }
+]
diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/InSpecialFormExpressionsInput.json b/presto-native-execution/presto_cpp/main/types/tests/data/InSpecialFormExpressionsInput.json
new file mode 100644
index 0000000000000..6ed1e7ebeab1a
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/tests/data/InSpecialFormExpressionsInput.json
@@ -0,0 +1,895 @@
+[
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "sourceLocation": {
+ "column": 10,
+ "line": 1
+ },
+ "type": "unknown",
+ "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA=="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "unknown"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "integer",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 10,
+ "line": 1
+ }
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAABQAAAA=="
+ }
+ ],
+ "form": "IN",
+ "returnType": "boolean",
+ "sourceLocation": {
+ "column": 10,
+ "line": 1
+ }
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "varchar(3)",
+ "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAMAAAAAAwAAAGZvbw=="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "varchar(3)"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "varchar(4)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "varchar(4)"
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "varchar(3)",
+ "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAMAAAAAAwAAAGJhcg=="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "varchar(3)"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "varchar(4)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "varchar(4)"
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "varchar(3)",
+ "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAMAAAAAAwAAAGJheg=="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "varchar(3)"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "varchar(4)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "varchar(4)"
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "varchar(3)",
+ "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAMAAAAAAwAAAGJ1eg=="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "varchar(3)"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "varchar(4)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "varchar(4)"
+ },
+ {
+ "@type": "constant",
+ "type": "varchar(4)",
+ "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAQAAAAABAAAAGJsYWg="
+ }
+ ],
+ "form": "IN",
+ "returnType": "boolean"
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "sourceLocation": {
+ "column": 1,
+ "line": 1
+ },
+ "type": "unknown",
+ "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA=="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "unknown"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "integer",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 1,
+ "line": 1
+ }
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "sourceLocation": {
+ "column": 13,
+ "line": 1
+ },
+ "type": "unknown",
+ "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA=="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "unknown"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "integer",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 13,
+ "line": 1
+ }
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAABQAAAA=="
+ }
+ ],
+ "form": "IN",
+ "returnType": "boolean",
+ "sourceLocation": {
+ "column": 1,
+ "line": 1
+ }
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ }
+ ],
+ "form": "ROW_CONSTRUCTOR",
+ "returnType": "row(integer)"
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ }
+ ],
+ "form": "ROW_CONSTRUCTOR",
+ "returnType": "row(integer)"
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ }
+ ],
+ "form": "ROW_CONSTRUCTOR",
+ "returnType": "row(integer)"
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ }
+ ],
+ "form": "ROW_CONSTRUCTOR",
+ "returnType": "row(integer)"
+ }
+ ],
+ "form": "IN",
+ "returnType": "boolean"
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ }
+ ],
+ "displayName": "ARRAY",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.array_constructor",
+ "returnType": "array(integer)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "array(integer)"
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "sourceLocation": {
+ "column": 27,
+ "line": 1
+ },
+ "type": "unknown",
+ "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA=="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "unknown"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "integer",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 27,
+ "line": 1
+ }
+ }
+ ],
+ "displayName": "ARRAY",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.array_constructor",
+ "returnType": "array(integer)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "array(integer)",
+ "sourceLocation": {
+ "column": 27,
+ "line": 1
+ }
+ }
+ ],
+ "displayName": "map",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "array(integer)",
+ "array(integer)"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.map",
+ "returnType": "map(integer,integer)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "map(integer,integer)",
+ "sourceLocation": {
+ "column": 27,
+ "line": 1
+ }
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ }
+ ],
+ "displayName": "ARRAY",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.array_constructor",
+ "returnType": "array(integer)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "array(integer)"
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "sourceLocation": {
+ "column": 64,
+ "line": 1
+ },
+ "type": "unknown",
+ "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA=="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "unknown"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "integer",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 64,
+ "line": 1
+ }
+ }
+ ],
+ "displayName": "ARRAY",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.array_constructor",
+ "returnType": "array(integer)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "array(integer)",
+ "sourceLocation": {
+ "column": 64,
+ "line": 1
+ }
+ }
+ ],
+ "displayName": "map",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "array(integer)",
+ "array(integer)"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.map",
+ "returnType": "map(integer,integer)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "map(integer,integer)",
+ "sourceLocation": {
+ "column": 64,
+ "line": 1
+ }
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "sourceLocation": {
+ "column": 72,
+ "line": 1
+ },
+ "type": "unknown",
+ "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA=="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "unknown"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "map(integer,integer)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "map(integer,integer)",
+ "sourceLocation": {
+ "column": 72,
+ "line": 1
+ }
+ }
+ ],
+ "form": "IN",
+ "returnType": "boolean",
+ "sourceLocation": {
+ "column": 27,
+ "line": 1
+ }
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ }
+ ],
+ "displayName": "ARRAY",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.array_constructor",
+ "returnType": "array(integer)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "array(integer)"
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "sourceLocation": {
+ "column": 27,
+ "line": 1
+ },
+ "type": "unknown",
+ "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA=="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "unknown"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "integer",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 27,
+ "line": 1
+ }
+ }
+ ],
+ "displayName": "ARRAY",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.array_constructor",
+ "returnType": "array(integer)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "array(integer)",
+ "sourceLocation": {
+ "column": 27,
+ "line": 1
+ }
+ }
+ ],
+ "displayName": "map",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "array(integer)",
+ "array(integer)"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.map",
+ "returnType": "map(integer,integer)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "map(integer,integer)",
+ "sourceLocation": {
+ "column": 27,
+ "line": 1
+ }
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA=="
+ }
+ ],
+ "displayName": "ARRAY",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.array_constructor",
+ "returnType": "array(integer)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "array(integer)"
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "sourceLocation": {
+ "column": 64,
+ "line": 1
+ },
+ "type": "unknown",
+ "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA=="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "unknown"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "integer",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 64,
+ "line": 1
+ }
+ }
+ ],
+ "displayName": "ARRAY",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.array_constructor",
+ "returnType": "array(integer)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "array(integer)",
+ "sourceLocation": {
+ "column": 64,
+ "line": 1
+ }
+ }
+ ],
+ "displayName": "map",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "array(integer)",
+ "array(integer)"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.map",
+ "returnType": "map(integer,integer)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "map(integer,integer)",
+ "sourceLocation": {
+ "column": 64,
+ "line": 1
+ }
+ }
+ ],
+ "displayName": "=",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "map(integer,integer)",
+ "map(integer,integer)"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$equal",
+ "returnType": "boolean",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "boolean",
+ "sourceLocation": {
+ "column": 27,
+ "line": 1
+ }
+ }
+]
diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/RowConstructorExpressionConversion.json b/presto-native-execution/presto_cpp/main/types/tests/data/RowConstructorExpressionConversion.json
new file mode 100644
index 0000000000000..8ff38f4150f41
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/tests/data/RowConstructorExpressionConversion.json
@@ -0,0 +1,46 @@
+[
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "bigint",
+ "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAAEDiAQAAAAAA"
+ },
+ {
+ "@type": "constant",
+ "type": "real",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAH4VFQQ=="
+ },
+ {
+ "@type": "constant",
+ "type": "varchar",
+ "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAMAAAAAAwAAAGFiYw=="
+ }
+ ],
+ "form": "ROW_CONSTRUCTOR",
+ "returnType": "row(bigint,real,varchar)"
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "real",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAADAfw=="
+ },
+ {
+ "@type": "constant",
+ "type": "varchar",
+ "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAAAAAAAAAAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAA////fw=="
+ }
+ ],
+ "form": "ROW_CONSTRUCTOR",
+ "returnType": "row(real,varchar,integer)"
+ }
+]
diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/SimpleExpressionsExpected.json b/presto-native-execution/presto_cpp/main/types/tests/data/SimpleExpressionsExpected.json
new file mode 100644
index 0000000000000..e9b014fecd3b2
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/tests/data/SimpleExpressionsExpected.json
@@ -0,0 +1,17 @@
+[
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "double",
+ "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAAAAAAAAAADRA"
+ },
+ {
+ "@type": "constant",
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE="
+ }
+]
diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/SimpleExpressionsInput.json b/presto-native-execution/presto_cpp/main/types/tests/data/SimpleExpressionsInput.json
new file mode 100644
index 0000000000000..dd2a0497703eb
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/tests/data/SimpleExpressionsInput.json
@@ -0,0 +1,278 @@
+[
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ }
+ ],
+ "displayName": "ADD",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$add",
+ "returnType": "integer",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "integer"
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAACwAAAA=="
+ }
+ ],
+ "displayName": "NEGATION",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$negation",
+ "returnType": "integer",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "integer"
+ }
+ ],
+ "displayName": "abs",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.abs",
+ "returnType": "integer",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "integer"
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "double",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "double"
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "decimal(2,1)",
+ "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAACIAAAAAAAAA"
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "decimal(2,1)"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "double",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "double"
+ }
+ ],
+ "displayName": "ceil",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "double"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.ceil",
+ "returnType": "double",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "double"
+ }
+ ],
+ "displayName": "ADD",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "double",
+ "double"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$add",
+ "returnType": "double",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "double"
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "decimal(2,1)",
+ "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAADgAAAAAAAAA"
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "decimal(2,1)"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "double",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "double"
+ }
+ ],
+ "displayName": "floor",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "double"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.floor",
+ "returnType": "double",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "double"
+ }
+ ],
+ "displayName": "ADD",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "double",
+ "double"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$add",
+ "returnType": "double",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "double"
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA=="
+ }
+ ],
+ "displayName": "BETWEEN",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$between",
+ "returnType": "boolean",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "boolean"
+ }
+]
diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/SpecialFormExpressionRewritesExpected.json b/presto-native-execution/presto_cpp/main/types/tests/data/SpecialFormExpressionRewritesExpected.json
new file mode 100644
index 0000000000000..2ce6acb1ab46e
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/tests/data/SpecialFormExpressionRewritesExpected.json
@@ -0,0 +1,17 @@
+[
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE="
+ },
+ {
+ "@type": "constant",
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE="
+ }
+]
diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/SpecialFormExpressionRewritesInput.json b/presto-native-execution/presto_cpp/main/types/tests/data/SpecialFormExpressionRewritesInput.json
new file mode 100644
index 0000000000000..77802722541b0
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/tests/data/SpecialFormExpressionRewritesInput.json
@@ -0,0 +1,193 @@
+[
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ }
+ ],
+ "displayName": "LESS_THAN",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$less_than",
+ "returnType": "boolean",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "boolean"
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA=="
+ }
+ ],
+ "form": "IF",
+ "returnType": "integer"
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ }
+ ],
+ "displayName": "LESS_THAN",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$less_than",
+ "returnType": "boolean",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "boolean"
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA=="
+ }
+ ],
+ "displayName": "LESS_THAN",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$less_than",
+ "returnType": "boolean",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "boolean"
+ }
+ ],
+ "form": "AND",
+ "returnType": "boolean"
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ }
+ ],
+ "displayName": "LESS_THAN",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$less_than",
+ "returnType": "boolean",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "boolean"
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA=="
+ }
+ ],
+ "displayName": "LESS_THAN",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$less_than",
+ "returnType": "boolean",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "boolean"
+ }
+ ],
+ "form": "OR",
+ "returnType": "boolean"
+ }
+]
diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/SwitchSpecialFormExpressionsExpected.json b/presto-native-execution/presto_cpp/main/types/tests/data/SwitchSpecialFormExpressionsExpected.json
new file mode 100644
index 0000000000000..40667e2976858
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/tests/data/SwitchSpecialFormExpressionsExpected.json
@@ -0,0 +1,32 @@
+[
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIQAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIQAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIQAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIQAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "varchar",
+ "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAcAAAAABwAAAG1hdGNoZWQ="
+ }
+]
diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/SwitchSpecialFormExpressionsInput.json b/presto-native-execution/presto_cpp/main/types/tests/data/SwitchSpecialFormExpressionsInput.json
new file mode 100644
index 0000000000000..5290b996ce15c
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/tests/data/SwitchSpecialFormExpressionsInput.json
@@ -0,0 +1,596 @@
+[
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIAAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ }
+ ],
+ "displayName": "ADD",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$add",
+ "returnType": "integer",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "integer"
+ }
+ ],
+ "form": "WHEN",
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 8,
+ "line": 1
+ }
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIgAAAA=="
+ }
+ ],
+ "form": "WHEN",
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 27,
+ "line": 1
+ }
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "AwAAAFJMRQEAAAAJAAAASU5UX0FSUkFZAQAAAAGA"
+ }
+ ],
+ "form": "SWITCH",
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 8,
+ "line": 1
+ }
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "sourceLocation": {
+ "column": 6,
+ "line": 1
+ },
+ "type": "unknown",
+ "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA=="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "unknown"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "boolean",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "boolean",
+ "sourceLocation": {
+ "column": 6,
+ "line": 1
+ }
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIAAAAA=="
+ }
+ ],
+ "form": "WHEN",
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 11,
+ "line": 1
+ }
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIQAAAA=="
+ }
+ ],
+ "form": "SWITCH",
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 6,
+ "line": 1
+ }
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIQAAAA=="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "bigint",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "bigint"
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAAAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAAAAAA=="
+ }
+ ],
+ "form": "WHEN",
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 9,
+ "line": 1
+ }
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIQAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ }
+ ],
+ "form": "WHEN",
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 23,
+ "line": 1
+ }
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "variable",
+ "name": "unbound_long",
+ "type": "bigint"
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA=="
+ }
+ ],
+ "form": "WHEN",
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 43,
+ "line": 1
+ }
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ }
+ ],
+ "form": "SWITCH",
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 9,
+ "line": 1
+ }
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "constant",
+ "sourceLocation": {
+ "column": 1,
+ "line": 1
+ },
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE="
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE="
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIAAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ }
+ ],
+ "displayName": "ADD",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$add",
+ "returnType": "integer",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "integer"
+ }
+ ],
+ "form": "WHEN",
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 6,
+ "line": 1
+ }
+ },
+ {
+ "@type": "constant",
+ "sourceLocation": {
+ "column": 1,
+ "line": 1
+ },
+ "type": "integer",
+ "valueBlock": "AwAAAFJMRQEAAAAJAAAASU5UX0FSUkFZAQAAAAGA"
+ }
+ ],
+ "form": "SWITCH",
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 1,
+ "line": 1
+ }
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "constant",
+ "sourceLocation": {
+ "column": 1,
+ "line": 1
+ },
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE="
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAA="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ }
+ ],
+ "form": "WHEN",
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 6,
+ "line": 1
+ }
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIgAAAA=="
+ },
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ }
+ ],
+ "displayName": "SUBTRACT",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer",
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$subtract",
+ "returnType": "integer",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "integer"
+ }
+ ],
+ "form": "SWITCH",
+ "returnType": "integer",
+ "sourceLocation": {
+ "column": 1,
+ "line": 1
+ }
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "constant",
+ "sourceLocation": {
+ "column": 83,
+ "line": 1
+ },
+ "type": "boolean",
+ "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE="
+ },
+ {
+ "@type": "special",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "bigint",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "bigint"
+ }
+ ],
+ "displayName": "ARRAY",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "bigint"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.array_constructor",
+ "returnType": "array(bigint)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "array(bigint)"
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "integer",
+ "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA=="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "integer"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "bigint",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "bigint"
+ }
+ ],
+ "displayName": "ARRAY",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "bigint"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.array_constructor",
+ "returnType": "array(bigint)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "array(bigint)"
+ }
+ ],
+ "displayName": "EQUAL",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "array(bigint)",
+ "array(bigint)"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$equal",
+ "returnType": "boolean",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "boolean"
+ },
+ {
+ "@type": "call",
+ "arguments": [
+ {
+ "@type": "constant",
+ "type": "varchar(7)",
+ "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAcAAAAABwAAAG1hdGNoZWQ="
+ }
+ ],
+ "displayName": "CAST",
+ "functionHandle": {
+ "@type": "$static",
+ "signature": {
+ "argumentTypes": [
+ "varchar(7)"
+ ],
+ "kind": "SCALAR",
+ "longVariableConstraints": [],
+ "name": "presto.default.$operator$cast",
+ "returnType": "varchar(11)",
+ "typeVariableConstraints": [],
+ "variableArity": false
+ }
+ },
+ "returnType": "varchar(11)"
+ }
+ ],
+ "form": "WHEN",
+ "returnType": "varchar(11)",
+ "sourceLocation": {
+ "column": 36,
+ "line": 1
+ }
+ },
+ {
+ "@type": "constant",
+ "type": "varchar(11)",
+ "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAsAAAAACwAAAG5vdF9tYXRjaGVk"
+ }
+ ],
+ "form": "SWITCH",
+ "returnType": "varchar(11)",
+ "sourceLocation": {
+ "column": 83,
+ "line": 1
+ }
+ }
+]
diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/VariableExpressionConversion.json b/presto-native-execution/presto_cpp/main/types/tests/data/VariableExpressionConversion.json
new file mode 100644
index 0000000000000..75f524842788e
--- /dev/null
+++ b/presto-native-execution/presto_cpp/main/types/tests/data/VariableExpressionConversion.json
@@ -0,0 +1,27 @@
+[
+ {
+ "@type": "variable",
+ "name": "c0",
+ "type": "varchar"
+ },
+ {
+ "@type": "variable",
+ "name": "c1",
+ "type": "array(bigint)"
+ },
+ {
+ "@type": "variable",
+ "name": "c2",
+ "type": "map(integer,real)"
+ },
+ {
+ "@type": "variable",
+ "name": "c3",
+ "type": "row(smallint,varchar,double)"
+ },
+ {
+ "@type": "variable",
+ "name": "c4",
+ "type": "map(array(tinyint),boolean)"
+ }
+]
diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java
index 825fbb1bc4387..ab9ca3568a404 100644
--- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java
+++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java
@@ -39,6 +39,7 @@
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
+import java.net.ServerSocket;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
@@ -496,6 +497,11 @@ public static NativeQueryRunnerParameters getNativeQueryRunnerParameters()
}
public static Optional> getExternalWorkerLauncher(String catalogName, String prestoServerPath, int cacheMaxSize, Optional remoteFunctionServerUds, Boolean failOnNestedLoopJoin, boolean isCoordinatorSidecarEnabled)
+ {
+ return getExternalWorkerLauncher(catalogName, prestoServerPath, OptionalInt.empty(), cacheMaxSize, remoteFunctionServerUds, failOnNestedLoopJoin, isCoordinatorSidecarEnabled);
+ }
+
+ public static Optional> getExternalWorkerLauncher(String catalogName, String prestoServerPath, OptionalInt port, int cacheMaxSize, Optional remoteFunctionServerUds, Boolean failOnNestedLoopJoin, boolean isCoordinatorSidecarEnabled)
{
return
Optional.of((workerIndex, discoveryUri) -> {
@@ -509,7 +515,8 @@ public static Optional> getExternalWorkerLaunc
String configProperties = format("discovery.uri=%s%n" +
"presto.version=testversion%n" +
"system-memory-gb=4%n" +
- "http-server.http.port=0%n", discoveryUri);
+ "native-sidecar=true%n" +
+ "http-server.http.port=%d", discoveryUri, port.orElse(0));
if (isCoordinatorSidecarEnabled) {
configProperties = format("%s%n" +
@@ -573,6 +580,45 @@ public static Optional> getExternalWorkerLaunc
});
}
+ public static Process getNativeSidecarProcess(URI discoveryUri, int port)
+ throws IOException
+ {
+ NativeQueryRunnerParameters nativeQueryRunnerParameters = getNativeQueryRunnerParameters();
+ return getNativeSidecarProcess(nativeQueryRunnerParameters.serverBinary.toString(), discoveryUri, port);
+ }
+
+ public static Process getNativeSidecarProcess(String prestoServerPath, URI discoveryUri, int port)
+ throws IOException
+ {
+ Path tempDirectoryPath = Files.createTempDirectory(PrestoNativeQueryRunnerUtils.class.getSimpleName());
+ log.info("Temp directory for Sidecar: %s", tempDirectoryPath.toString());
+
+ // Write config file
+ String configProperties = format("discovery.uri=%s%n" +
+ "presto.version=testversion%n" +
+ "system-memory-gb=4%n" +
+ "native-sidecar=true%n" +
+ "http-server.http.port=%d", discoveryUri, port);
+
+ Files.write(tempDirectoryPath.resolve("config.properties"), configProperties.getBytes());
+ Files.write(tempDirectoryPath.resolve("node.properties"),
+ format("node.id=%s%n" +
+ "node.internal-address=127.0.0.1%n" +
+ "node.environment=testing%n" +
+ "node.location=test-location", UUID.randomUUID()).getBytes());
+
+ // TODO: sidecars require that a catalog directory exist
+ Path catalogDirectoryPath = tempDirectoryPath.resolve("catalog");
+ Files.createDirectory(catalogDirectoryPath);
+
+ return new ProcessBuilder(prestoServerPath)
+ .directory(tempDirectoryPath.toFile())
+ .redirectErrorStream(true)
+ .redirectOutput(ProcessBuilder.Redirect.to(tempDirectoryPath.resolve("sidecar.out").toFile()))
+ .redirectError(ProcessBuilder.Redirect.to(tempDirectoryPath.resolve("sidecar.out").toFile()))
+ .start();
+ }
+
public static class NativeQueryRunnerParameters
{
public final Path serverBinary;
@@ -619,4 +665,12 @@ private static Table createHiveSymlinkTable(String databaseName, String tableNam
Optional.empty(),
Optional.empty());
}
+
+ public static int findRandomPortForWorker()
+ throws IOException
+ {
+ try (ServerSocket socket = new ServerSocket(0)) {
+ return socket.getLocalPort();
+ }
+ }
}
diff --git a/presto-native-sidecar-plugin/pom.xml b/presto-native-sidecar-plugin/pom.xml
index 6b8f1badd289d..4a62cab3f77a5 100644
--- a/presto-native-sidecar-plugin/pom.xml
+++ b/presto-native-sidecar-plugin/pom.xml
@@ -200,6 +200,12 @@
test
+
+ javax.ws.rs
+ javax.ws.rs-api
+ test
+
+
com.facebook.presto
presto-client
diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/ForSidecarInfo.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/ForSidecarInfo.java
new file mode 100644
index 0000000000000..8bf10b20223d2
--- /dev/null
+++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/ForSidecarInfo.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.session.sql.expressions;
+
+import com.google.inject.BindingAnnotation;
+
+import java.lang.annotation.Retention;
+
+import static java.lang.annotation.RetentionPolicy.RUNTIME;
+
+@Retention(RUNTIME)
+@BindingAnnotation
+public @interface ForSidecarInfo
+{
+}
diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizer.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizer.java
new file mode 100644
index 0000000000000..1c7dfd30b38e0
--- /dev/null
+++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizer.java
@@ -0,0 +1,353 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.session.sql.expressions;
+
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.spi.ConnectorSession;
+import com.facebook.presto.spi.SourceLocation;
+import com.facebook.presto.spi.function.FunctionMetadataManager;
+import com.facebook.presto.spi.function.StandardFunctionResolution;
+import com.facebook.presto.spi.relation.CallExpression;
+import com.facebook.presto.spi.relation.ConstantExpression;
+import com.facebook.presto.spi.relation.ExpressionOptimizer;
+import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
+import com.facebook.presto.spi.relation.RowExpression;
+import com.facebook.presto.spi.relation.RowExpressionVisitor;
+import com.facebook.presto.spi.relation.SpecialFormExpression;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.ArrayDeque;
+import java.util.HashMap;
+import java.util.IdentityHashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.function.Function;
+
+import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.EVALUATED;
+import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static java.util.Collections.newSetFromMap;
+import static java.util.Objects.requireNonNull;
+import static java.util.stream.Collectors.toMap;
+
+public class NativeExpressionOptimizer
+ implements ExpressionOptimizer
+{
+ private final FunctionMetadataManager functionMetadataManager;
+ private final StandardFunctionResolution resolution;
+ private final NativeSidecarExpressionInterpreter rowExpressionInterpreterService;
+
+ public NativeExpressionOptimizer(
+ NativeSidecarExpressionInterpreter rowExpressionInterpreterService,
+ FunctionMetadataManager functionMetadataManager,
+ StandardFunctionResolution resolution)
+ {
+ this.rowExpressionInterpreterService = requireNonNull(rowExpressionInterpreterService, "rowExpressionInterpreterService is null");
+ this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null");
+ this.resolution = requireNonNull(resolution, "resolution is null");
+ }
+
+ @Override
+ public RowExpression optimize(RowExpression expression, Level level, ConnectorSession session, Function variableResolver)
+ {
+ // Collect expressions to optimize
+ CollectingVisitor collectingVisitor = new CollectingVisitor(functionMetadataManager, level, resolution);
+ expression.accept(collectingVisitor, variableResolver);
+ List expressionsToOptimize = collectingVisitor.getExpressionsToOptimize();
+ Map expressions = expressionsToOptimize.stream()
+ .collect(toMap(
+ Function.identity(),
+ rowExpression -> rowExpression.accept(
+ new ReplacingVisitor(
+ variable -> toRowExpression(variable.getSourceLocation(), variableResolver.apply(variable), variable.getType())),
+ null),
+ (a, b) -> a));
+ if (expressions.isEmpty()) {
+ return expression;
+ }
+
+ // Constants can be trivially replaced without invoking the interpreter. Move them into a separate map.
+ Map constants = new HashMap<>();
+ Iterator> entries = expressions.entrySet().iterator();
+ while (entries.hasNext()) {
+ Map.Entry entry = entries.next();
+ if (entry.getValue() instanceof ConstantExpression) {
+ constants.put(entry.getKey(), entry.getValue());
+ entries.remove();
+ }
+ }
+
+ // Optimize the expressions using the sidecar interpreter
+ Map replacements = new HashMap<>();
+ if (!expressions.isEmpty()) {
+ replacements.putAll(rowExpressionInterpreterService.optimizeBatch(session, expressions, level));
+ }
+
+ // Add back in the constants
+ replacements.putAll(constants);
+
+ // Replace all the expressions in the original expression with the optimized expressions
+ return toRowExpression(expression.getSourceLocation(), expression.accept(new ReplacingVisitor(replacements), null), expression.getType());
+ }
+
+ /**
+ * This visitor collects expressions that can be optimized by the sidecar interpreter.
+ */
+ private static class CollectingVisitor
+ implements RowExpressionVisitor
+ {
+ private final FunctionMetadataManager functionMetadataManager;
+ private final Level optimizationLevel;
+ private final StandardFunctionResolution resolution;
+ private final Set expressionsToOptimize = newSetFromMap(new IdentityHashMap<>());
+ private final Set hasOptimizedChildren = newSetFromMap(new IdentityHashMap<>());
+
+ public CollectingVisitor(FunctionMetadataManager functionMetadataManager, Level optimizationLevel, StandardFunctionResolution resolution)
+ {
+ this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null");
+ this.optimizationLevel = requireNonNull(optimizationLevel, "optimizationLevel is null");
+ this.resolution = requireNonNull(resolution, "resolution is null");
+ }
+
+ @Override
+ public Void visitExpression(RowExpression node, Object context)
+ {
+ visitNode(node, false);
+ return null;
+ }
+
+ @Override
+ public Void visitConstant(ConstantExpression node, Object context)
+ {
+ visitNode(node, true);
+ return null;
+ }
+
+ @Override
+ public Void visitVariableReference(VariableReferenceExpression node, Object context)
+ {
+ Object value = ((Function) context).apply(node);
+ if (value == null || value instanceof RowExpression) {
+ visitNode(node, false);
+ return null;
+ }
+ visitNode(node, true);
+ return null;
+ }
+
+ @Override
+ public Void visitCall(CallExpression node, Object context)
+ {
+ // If the optimization level is not EVALUATED, then we cannot optimize non-deterministic functions
+ boolean isDeterministic = functionMetadataManager.getFunctionMetadata(node.getFunctionHandle()).isDeterministic();
+ boolean canBeEvaluated = (optimizationLevel.ordinal() < EVALUATED.ordinal() && isDeterministic) ||
+ optimizationLevel.ordinal() == EVALUATED.ordinal();
+
+ // All arguments must be optimizable in order to evaluate the function
+ boolean allConstantFoldable = node.getArguments().stream()
+ .peek(child -> child.accept(this, context))
+ .reduce(true, (a, b) -> canBeOptimized(b) && a, (a, b) -> a && b);
+ if (canBeEvaluated && allConstantFoldable) {
+ visitNode(node, true);
+ return null;
+ }
+
+ // If it's a cast and the type is already the same, then it's constant foldable
+ if (resolution.isCastFunction(node.getFunctionHandle())
+ && node.getArguments().size() == 1
+ && node.getType().equals(node.getArguments().get(0).getType())) {
+ visitNode(node, true);
+ return null;
+ }
+ visitNode(node, false);
+ return null;
+ }
+
+ @Override
+ public Void visitSpecialForm(SpecialFormExpression node, Object context)
+ {
+ // Most special form expressions short circuit, meaning that they potentially don't evaluate all arguments. For example, the AND expression
+ // will stop evaluating arguments as soon as it finds a false argument. Because a sub-expression could be simplified into a constant, and this
+ // constant could cause the expression to short circuit, if there is at least one argument which is optimizable, then the entire expression should
+ // be sent to the sidecar to be optimized.
+ boolean anyArgumentsOptimizable = node.getArguments().stream()
+ .peek(child -> child.accept(this, context))
+ .reduce(false, (a, b) -> canBeOptimized(b) || a, (a, b) -> a || b);
+
+ // If all arguments are constant foldable, then the whole expression is constant foldable
+ if (anyArgumentsOptimizable) {
+ visitNode(node, true);
+ return null;
+ }
+
+ // If the special form is COALESCE, then we can optimize it if there are any duplicate arguments
+ if (node.getForm() == COALESCE) {
+ ImmutableSet.Builder builder = ImmutableSet.builder();
+ // Check if there's any duplicate arguments, these can be de-duplicated
+ for (RowExpression argument : node.getArguments()) {
+ // The duplicate argument must either be a leaf (variable reference) or constant foldable
+ if (canBeOptimized(argument)) {
+ builder.add(argument);
+ }
+ }
+ // If there were any duplicates, or if there's no arguments (cancel out), or if there's only one argument (just return it),
+ // then it's also constant foldable
+ boolean canBeOptimized = builder.build().size() <= node.getArguments().size() || node.getArguments().size() <= 1;
+ if (canBeOptimized) {
+ visitNode(node, true);
+ return null;
+ }
+ }
+ visitNode(node, false);
+ return null;
+ }
+
+ @Override
+ public Void visitLambda(LambdaDefinitionExpression node, Object context)
+ {
+ node.getBody().accept(this, (Function) variable -> variable);
+ if (canBeOptimized(node.getBody())) {
+ visitNode(node, true);
+ return null;
+ }
+ visitNode(node, false);
+ return null;
+ }
+
+ public boolean canBeOptimized(RowExpression rowExpression)
+ {
+ return expressionsToOptimize.contains(rowExpression);
+ }
+
+ private void visitNode(RowExpression node, boolean canBeOptimized)
+ {
+ requireNonNull(node, "node is null");
+ // If the present node can be optimized, then we send the whole expression. Because an expression may consist of many
+ // sub-expressions, we need to ensure that we don't send the sub-expression along with its parent expression. For example,
+ // if we have the expression (a + b) + c, and we can optimize a + b, then we don't want to send a + b to the sidecar, because
+ // it will be optimized twice. Instead, we want to send (a + b) + c to the sidecar, and then remove a + b from the list of
+ // expressions to optimize.
+ // We need to traverse the entire subtree of possible expressions to optimize because some special form expressions may
+ // short circuit, and we need to ensure that we don't send the sub-expression to the sidecar if the parent expression is
+ // constant foldable. For example, consider the expression false AND (true OR a). Although the expression true OR a is
+ // constant foldable, the parent expression is also constant foldable, and we don't want to send both the parent expression
+ // and the sub-expression to the sidecar because the entire expression can be constant folded in one pass.
+ if (canBeOptimized) {
+ ArrayDeque queue = new ArrayDeque<>(node.getChildren());
+ while (!queue.isEmpty()) {
+ RowExpression expression = queue.poll();
+ if (hasOptimizedChildren.remove(expression)) {
+ expressionsToOptimize.remove(expression);
+ queue.addAll(expression.getChildren());
+ }
+ }
+ expressionsToOptimize.add(node);
+ hasOptimizedChildren.add(node);
+ }
+ else if (node.getChildren().stream().anyMatch(hasOptimizedChildren::contains)) {
+ hasOptimizedChildren.add(node);
+ }
+ }
+
+ public List getExpressionsToOptimize()
+ {
+ return ImmutableList.copyOf(expressionsToOptimize);
+ }
+ }
+
+ /**
+ * This visitor replaces expressions with their optimized versions.
+ */
+ private static class ReplacingVisitor
+ implements RowExpressionVisitor
+ {
+ private final Function resolver;
+
+ public ReplacingVisitor(Map replacements)
+ {
+ requireNonNull(replacements, "replacements is null");
+ this.resolver = i -> replacements.getOrDefault(i, i);
+ }
+
+ public ReplacingVisitor(Function variableResolver)
+ {
+ requireNonNull(variableResolver, "variableResolver is null");
+ this.resolver = i -> i instanceof VariableReferenceExpression ? variableResolver.apply((VariableReferenceExpression) i) : i;
+ }
+
+ private boolean canBeReplaced(RowExpression rowExpression)
+ {
+ return resolver.apply(rowExpression) != rowExpression;
+ }
+
+ @Override
+ public RowExpression visitExpression(RowExpression originalExpression, Void context)
+ {
+ return resolver.apply(originalExpression);
+ }
+
+ @Override
+ public RowExpression visitLambda(LambdaDefinitionExpression lambda, Void context)
+ {
+ if (canBeReplaced(lambda.getBody())) {
+ return new LambdaDefinitionExpression(
+ lambda.getSourceLocation(),
+ lambda.getArgumentTypes(),
+ lambda.getArguments(),
+ toRowExpression(lambda.getSourceLocation(), resolver.apply(lambda.getBody()), lambda.getBody().getType()));
+ }
+ return lambda;
+ }
+
+ @Override
+ public RowExpression visitCall(CallExpression call, Void context)
+ {
+ if (canBeReplaced(call)) {
+ return resolver.apply(call);
+ }
+ List updatedArguments = call.getArguments().stream()
+ .map(argument -> toRowExpression(argument.getSourceLocation(), argument.accept(this, context), argument.getType()))
+ .collect(toImmutableList());
+ return new CallExpression(call.getSourceLocation(), call.getDisplayName(), call.getFunctionHandle(), call.getType(), updatedArguments);
+ }
+
+ @Override
+ public RowExpression visitSpecialForm(SpecialFormExpression specialForm, Void context)
+ {
+ if (canBeReplaced(specialForm)) {
+ return resolver.apply(specialForm);
+ }
+ List updatedArguments = specialForm.getArguments().stream()
+ .map(argument -> toRowExpression(argument.getSourceLocation(), argument.accept(this, context), argument.getType()))
+ .collect(toImmutableList());
+ return new SpecialFormExpression(specialForm.getSourceLocation(), specialForm.getForm(), specialForm.getType(), updatedArguments);
+ }
+ }
+
+ private static RowExpression toRowExpression(Optional sourceLocation, Object object, Type type)
+ {
+ requireNonNull(type, "type is null");
+
+ if (object instanceof RowExpression) {
+ return (RowExpression) object;
+ }
+
+ return new ConstantExpression(sourceLocation, object, type);
+ }
+}
diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerFactory.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerFactory.java
new file mode 100644
index 0000000000000..4157cffe2c839
--- /dev/null
+++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerFactory.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.session.sql.expressions;
+
+import com.facebook.airlift.bootstrap.Bootstrap;
+import com.facebook.airlift.json.JsonModule;
+import com.facebook.presto.spi.classloader.ThreadContextClassLoader;
+import com.facebook.presto.spi.relation.ExpressionOptimizer;
+import com.facebook.presto.spi.sql.planner.ExpressionOptimizerContext;
+import com.facebook.presto.spi.sql.planner.ExpressionOptimizerFactory;
+import com.google.inject.Injector;
+
+import java.util.Map;
+
+import static java.util.Objects.requireNonNull;
+
+public class NativeExpressionOptimizerFactory
+ implements ExpressionOptimizerFactory
+{
+ private final ClassLoader classLoader;
+
+ public NativeExpressionOptimizerFactory(ClassLoader classLoader)
+ {
+ this.classLoader = requireNonNull(classLoader, "classLoader is null");
+ }
+
+ @Override
+ public ExpressionOptimizer createOptimizer(Map config, ExpressionOptimizerContext context)
+ {
+ requireNonNull(context, "context is null");
+
+ try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) {
+ Bootstrap app = new Bootstrap(
+ new JsonModule(),
+ new NativeExpressionsCommunicationModule(),
+ new NativeExpressionsModule(context.getNodeManager(), context.getRowExpressionSerde(), context.getFunctionMetadataManager(), context.getFunctionResolution()));
+
+ Injector injector = app
+ .noStrictConfig()
+ .doNotInitializeLogging()
+ .setRequiredConfigurationProperties(config)
+ .quiet()
+ .initialize();
+ return injector.getInstance(NativeExpressionOptimizerProvider.class).createOptimizer();
+ }
+ }
+
+ @Override
+ public String getName()
+ {
+ return "native";
+ }
+}
diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerProvider.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerProvider.java
new file mode 100644
index 0000000000000..8efe3b8a77eb5
--- /dev/null
+++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerProvider.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.session.sql.expressions;
+
+import com.facebook.presto.spi.function.FunctionMetadataManager;
+import com.facebook.presto.spi.function.StandardFunctionResolution;
+import com.facebook.presto.spi.relation.ExpressionOptimizer;
+
+import javax.inject.Inject;
+
+import static java.util.Objects.requireNonNull;
+
+public class NativeExpressionOptimizerProvider
+{
+ private final NativeSidecarExpressionInterpreter expressionInterpreterService;
+ private final FunctionMetadataManager functionMetadataManager;
+ private final StandardFunctionResolution resolution;
+
+ @Inject
+ public NativeExpressionOptimizerProvider(NativeSidecarExpressionInterpreter expressionInterpreterService, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution resolution)
+ {
+ this.expressionInterpreterService = requireNonNull(expressionInterpreterService, "expressionInterpreterService is null");
+ this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null");
+ this.resolution = requireNonNull(resolution, "resolution is null");
+ }
+
+ public ExpressionOptimizer createOptimizer()
+ {
+ return new NativeExpressionOptimizer(expressionInterpreterService, functionMetadataManager, resolution);
+ }
+}
diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsCommunicationModule.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsCommunicationModule.java
new file mode 100644
index 0000000000000..fdc6da53f91bc
--- /dev/null
+++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsCommunicationModule.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.session.sql.expressions;
+
+import com.google.inject.Binder;
+import com.google.inject.Module;
+
+import static com.facebook.airlift.http.client.HttpClientBinder.httpClientBinder;
+
+public class NativeExpressionsCommunicationModule
+ implements Module
+{
+ @Override
+ public void configure(Binder binder)
+ {
+ httpClientBinder(binder).bindHttpClient("sidecar", ForSidecarInfo.class);
+ }
+}
diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsModule.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsModule.java
new file mode 100644
index 0000000000000..046b4e60eac7a
--- /dev/null
+++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsModule.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.session.sql.expressions;
+
+import com.facebook.airlift.json.JsonModule;
+import com.facebook.presto.spi.NodeManager;
+import com.facebook.presto.spi.RowExpressionSerde;
+import com.facebook.presto.spi.function.FunctionMetadataManager;
+import com.facebook.presto.spi.function.StandardFunctionResolution;
+import com.facebook.presto.spi.relation.RowExpression;
+import com.google.inject.Binder;
+import com.google.inject.Module;
+import com.google.inject.Scopes;
+
+import static com.facebook.airlift.json.JsonBinder.jsonBinder;
+import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder;
+import static java.util.Objects.requireNonNull;
+
+public class NativeExpressionsModule
+ implements Module
+{
+ private final NodeManager nodeManager;
+ private final RowExpressionSerde rowExpressionSerde;
+ private final FunctionMetadataManager functionMetadataManager;
+ private final StandardFunctionResolution functionResolution;
+
+ public NativeExpressionsModule(NodeManager nodeManager, RowExpressionSerde rowExpressionSerde, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution functionResolution)
+ {
+ this.nodeManager = requireNonNull(nodeManager, "nodeManager is null");
+ this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null");
+ this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null");
+ this.functionResolution = requireNonNull(functionResolution, "functionResolution is null");
+ }
+
+ @Override
+ public void configure(Binder binder)
+ {
+ // Core dependencies
+ binder.bind(NodeManager.class).toInstance(nodeManager);
+ binder.bind(RowExpressionSerde.class).toInstance(rowExpressionSerde);
+ binder.bind(FunctionMetadataManager.class).toInstance(functionMetadataManager);
+ binder.bind(StandardFunctionResolution.class).toInstance(functionResolution);
+
+ // JSON dependencies and setup
+ binder.install(new JsonModule());
+ jsonBinder(binder).addDeserializerBinding(RowExpression.class).to(RowExpressionDeserializer.class).in(Scopes.SINGLETON);
+ jsonBinder(binder).addSerializerBinding(RowExpression.class).to(RowExpressionSerializer.class).in(Scopes.SINGLETON);
+ jsonCodecBinder(binder).bindListJsonCodec(RowExpression.class);
+
+ binder.bind(NativeSidecarExpressionInterpreter.class).in(Scopes.SINGLETON);
+
+ // The main service provider
+ binder.bind(NativeExpressionOptimizerProvider.class).in(Scopes.SINGLETON);
+ }
+}
diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeSidecarExpressionInterpreter.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeSidecarExpressionInterpreter.java
new file mode 100644
index 0000000000000..53d3530156d23
--- /dev/null
+++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeSidecarExpressionInterpreter.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.session.sql.expressions;
+
+import com.facebook.airlift.http.client.HttpClient;
+import com.facebook.airlift.http.client.HttpUriBuilder;
+import com.facebook.airlift.http.client.Request;
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.presto.spi.ConnectorSession;
+import com.facebook.presto.spi.Node;
+import com.facebook.presto.spi.NodeManager;
+import com.facebook.presto.spi.relation.ExpressionOptimizer;
+import com.facebook.presto.spi.relation.RowExpression;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+
+import javax.inject.Inject;
+
+import java.net.URI;
+import java.util.List;
+import java.util.Map;
+
+import static com.facebook.airlift.http.client.JsonBodyGenerator.jsonBodyGenerator;
+import static com.facebook.airlift.http.client.JsonResponseHandler.createJsonResponseHandler;
+import static com.facebook.airlift.http.client.Request.Builder.preparePost;
+import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level;
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.net.HttpHeaders.ACCEPT;
+import static com.google.common.net.HttpHeaders.CONTENT_TYPE;
+import static com.google.common.net.MediaType.JSON_UTF_8;
+import static java.util.Objects.requireNonNull;
+
+public class NativeSidecarExpressionInterpreter
+{
+ private static final String PRESTO_TIME_ZONE_HEADER = "X-Presto-Time-Zone";
+ private static final String PRESTO_USER_HEADER = "X-Presto-User";
+ private static final String PRESTO_EXPRESSION_OPTIMIZER_LEVEL_HEADER = "X-Presto-Expression-Optimizer-Level";
+ private static final String SIDECAR_ENDPOINT = "/v1/expressions";
+ private final NodeManager nodeManager;
+ private final HttpClient httpClient;
+ private final JsonCodec> rowExpressionSerde;
+
+ @Inject
+ public NativeSidecarExpressionInterpreter(NodeManager nodeManager, @ForSidecarInfo HttpClient httpClient, JsonCodec> rowExpressionSerde)
+ {
+ this.nodeManager = requireNonNull(nodeManager, "nodeManager is null");
+ this.httpClient = requireNonNull(httpClient, "httpClient is null");
+ this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null");
+ }
+
+ public Map optimizeBatch(ConnectorSession session, Map expressions, ExpressionOptimizer.Level level)
+ {
+ ImmutableList.Builder originalExpressionsBuilder = ImmutableList.builder();
+ ImmutableList.Builder resolvedExpressionsBuilder = ImmutableList.builder();
+ for (Map.Entry entry : expressions.entrySet()) {
+ originalExpressionsBuilder.add(entry.getKey());
+ resolvedExpressionsBuilder.add(entry.getValue());
+ }
+ List originalExpressions = originalExpressionsBuilder.build();
+ List resolvedExpressions = resolvedExpressionsBuilder.build();
+
+ List optimizedExpressions = httpClient.execute(
+ getSidecarRequest(session, level, resolvedExpressions),
+ createJsonResponseHandler(rowExpressionSerde));
+ checkArgument(
+ optimizedExpressions.size() == resolvedExpressions.size(),
+ "Expected %s optimized expressions, but got %s",
+ resolvedExpressions.size(),
+ optimizedExpressions.size());
+
+ ImmutableMap.Builder result = ImmutableMap.builder();
+ for (int i = 0; i < optimizedExpressions.size(); i++) {
+ result.put(originalExpressions.get(i), optimizedExpressions.get(i));
+ }
+ return result.build();
+ }
+
+ private Request getSidecarRequest(ConnectorSession session, Level level, List resolvedExpressions)
+ {
+ return preparePost()
+ .setUri(getLocation())
+ .setBodyGenerator(jsonBodyGenerator(rowExpressionSerde, resolvedExpressions))
+ .setHeader(CONTENT_TYPE, JSON_UTF_8.toString())
+ .setHeader(ACCEPT, JSON_UTF_8.toString())
+ .setHeader(PRESTO_TIME_ZONE_HEADER, session.getSqlFunctionProperties().getTimeZoneKey().getId())
+ .setHeader(PRESTO_USER_HEADER, session.getUser())
+ .setHeader(PRESTO_EXPRESSION_OPTIMIZER_LEVEL_HEADER, level.name())
+ .build();
+ }
+
+ private URI getLocation()
+ {
+ Node sidecarNode = nodeManager.getSidecarNode();
+ return HttpUriBuilder.uriBuilder()
+ .scheme("http")
+ .host(sidecarNode.getHost())
+ .port(sidecarNode.getHostAndPort().getPort())
+ .appendPath(SIDECAR_ENDPOINT)
+ .build();
+ }
+}
diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionDeserializer.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionDeserializer.java
new file mode 100644
index 0000000000000..ffc7039d4568b
--- /dev/null
+++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionDeserializer.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.session.sql.expressions;
+
+import com.facebook.presto.spi.RowExpressionSerde;
+import com.facebook.presto.spi.relation.RowExpression;
+import com.fasterxml.jackson.core.JsonParser;
+import com.fasterxml.jackson.databind.DeserializationContext;
+import com.fasterxml.jackson.databind.JsonDeserializer;
+import com.fasterxml.jackson.databind.jsontype.TypeDeserializer;
+import com.google.inject.Inject;
+
+import java.io.IOException;
+
+import static java.util.Objects.requireNonNull;
+
+public final class RowExpressionDeserializer
+ extends JsonDeserializer
+{
+ private final RowExpressionSerde rowExpressionSerde;
+
+ @Inject
+ public RowExpressionDeserializer(RowExpressionSerde rowExpressionSerde)
+ {
+ this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null");
+ }
+
+ @Override
+ public RowExpression deserialize(JsonParser jsonParser, DeserializationContext context)
+ throws IOException
+ {
+ return rowExpressionSerde.deserialize(jsonParser.readValueAsTree().toString());
+ }
+
+ @Override
+ public RowExpression deserializeWithType(JsonParser jsonParser, DeserializationContext context, TypeDeserializer typeDeserializer)
+ throws IOException
+ {
+ return deserialize(jsonParser, context);
+ }
+}
diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionSerializer.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionSerializer.java
new file mode 100644
index 0000000000000..6ca5dcc4354f4
--- /dev/null
+++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionSerializer.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.session.sql.expressions;
+
+import com.facebook.presto.spi.RowExpressionSerde;
+import com.facebook.presto.spi.relation.RowExpression;
+import com.fasterxml.jackson.core.JsonGenerator;
+import com.fasterxml.jackson.databind.JsonSerializer;
+import com.fasterxml.jackson.databind.SerializerProvider;
+import com.fasterxml.jackson.databind.jsontype.TypeSerializer;
+import com.google.inject.Inject;
+
+import java.io.IOException;
+
+import static java.util.Objects.requireNonNull;
+
+public final class RowExpressionSerializer
+ extends JsonSerializer
+{
+ private final RowExpressionSerde rowExpressionSerde;
+
+ @Inject
+ public RowExpressionSerializer(RowExpressionSerde rowExpressionSerde)
+ {
+ this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null");
+ }
+
+ @Override
+ public void serialize(RowExpression rowExpression, JsonGenerator jsonGenerator, SerializerProvider serializerProvider)
+ throws IOException
+ {
+ jsonGenerator.writeRawValue(rowExpressionSerde.serialize(rowExpression));
+ }
+
+ @Override
+ public void serializeWithType(RowExpression rowExpression, JsonGenerator jsonGenerator, SerializerProvider serializerProvider, TypeSerializer typeSerializer)
+ throws IOException
+ {
+ serialize(rowExpression, jsonGenerator, serializerProvider);
+ }
+}
diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/NativeSidecarPlugin.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/NativeSidecarPlugin.java
index 7da7ba4d4fbb9..491d93d6f812f 100644
--- a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/NativeSidecarPlugin.java
+++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/NativeSidecarPlugin.java
@@ -13,6 +13,7 @@
*/
package com.facebook.presto.sidecar;
+import com.facebook.presto.session.sql.expressions.NativeExpressionOptimizerFactory;
import com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManagerFactory;
import com.facebook.presto.sidecar.nativechecker.NativePlanCheckerProviderFactory;
import com.facebook.presto.sidecar.sessionpropertyproviders.NativeSystemSessionPropertyProviderFactory;
@@ -20,6 +21,7 @@
import com.facebook.presto.spi.function.FunctionNamespaceManagerFactory;
import com.facebook.presto.spi.plan.PlanCheckerProviderFactory;
import com.facebook.presto.spi.session.WorkerSessionPropertyProviderFactory;
+import com.facebook.presto.spi.sql.planner.ExpressionOptimizerFactory;
import com.google.common.collect.ImmutableList;
public class NativeSidecarPlugin
@@ -51,4 +53,9 @@ private static ClassLoader getClassLoader()
}
return classLoader;
}
+ @Override
+ public Iterable getExpressionOptimizerFactories()
+ {
+ return ImmutableList.of(new NativeExpressionOptimizerFactory(getClassLoader()));
+ }
}
diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressionOptimizer.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressionOptimizer.java
new file mode 100644
index 0000000000000..b47ffa16a6f20
--- /dev/null
+++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressionOptimizer.java
@@ -0,0 +1,1706 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.session.sql.expressions;
+
+import com.facebook.presto.common.block.Block;
+import com.facebook.presto.common.block.BlockEncodingManager;
+import com.facebook.presto.common.block.BlockEncodingSerde;
+import com.facebook.presto.common.block.BlockSerdeUtil;
+import com.facebook.presto.common.type.ArrayType;
+import com.facebook.presto.common.type.Decimals;
+import com.facebook.presto.common.type.SqlTimestampWithTimeZone;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.functionNamespace.json.JsonFileBasedFunctionNamespaceManagerFactory;
+import com.facebook.presto.metadata.FunctionAndTypeManager;
+import com.facebook.presto.metadata.HandleResolver;
+import com.facebook.presto.metadata.InMemoryNodeManager;
+import com.facebook.presto.metadata.Metadata;
+import com.facebook.presto.nodeManager.PluginNodeManager;
+import com.facebook.presto.operator.scalar.FunctionAssertions;
+import com.facebook.presto.spi.PrestoException;
+import com.facebook.presto.spi.relation.CallExpression;
+import com.facebook.presto.spi.relation.ConstantExpression;
+import com.facebook.presto.spi.relation.ExpressionOptimizer;
+import com.facebook.presto.spi.relation.InputReferenceExpression;
+import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
+import com.facebook.presto.spi.relation.RowExpression;
+import com.facebook.presto.spi.relation.SpecialFormExpression;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.TestingRowExpressionTranslator;
+import com.facebook.presto.sql.parser.ParsingOptions;
+import com.facebook.presto.sql.parser.SqlParser;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.TypeProvider;
+import com.facebook.presto.sql.relational.FunctionResolution;
+import com.facebook.presto.sql.tree.Expression;
+import com.google.common.base.Joiner;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import io.airlift.slice.DynamicSliceOutput;
+import io.airlift.slice.SliceOutput;
+import io.airlift.slice.Slices;
+import org.intellij.lang.annotations.Language;
+import org.joda.time.DateTime;
+import org.joda.time.DateTimeZone;
+import org.joda.time.LocalDate;
+import org.joda.time.LocalTime;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import java.math.BigInteger;
+import java.net.URI;
+import java.util.Optional;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.IntStream;
+
+import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
+import static com.facebook.presto.common.type.BigintType.BIGINT;
+import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
+import static com.facebook.presto.common.type.DateType.DATE;
+import static com.facebook.presto.common.type.DecimalType.createDecimalType;
+import static com.facebook.presto.common.type.DoubleType.DOUBLE;
+import static com.facebook.presto.common.type.IntegerType.INTEGER;
+import static com.facebook.presto.common.type.TimeType.TIME;
+import static com.facebook.presto.common.type.TimeZoneKey.getTimeZoneKey;
+import static com.facebook.presto.common.type.TimestampType.TIMESTAMP;
+import static com.facebook.presto.common.type.VarbinaryType.VARBINARY;
+import static com.facebook.presto.common.type.VarcharType.VARCHAR;
+import static com.facebook.presto.common.type.VarcharType.createVarcharType;
+import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager;
+import static com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils.findRandomPortForWorker;
+import static com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils.getNativeSidecarProcess;
+import static com.facebook.presto.operator.scalar.ApplyFunction.APPLY_FUNCTION;
+import static com.facebook.presto.session.sql.expressions.TestNativeExpressions.getExpressionOptimizer;
+import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.EVALUATED;
+import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED;
+import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.SERIALIZABLE;
+import static com.facebook.presto.sql.ExpressionFormatter.formatExpression;
+import static com.facebook.presto.type.IntervalDayTimeType.INTERVAL_DAY_TIME;
+import static com.facebook.presto.util.AnalyzerUtil.createParsingOptions;
+import static com.facebook.presto.util.DateTimeZoneIndex.getDateTimeZone;
+import static io.airlift.slice.Slices.utf8Slice;
+import static java.lang.String.format;
+import static java.util.Locale.ENGLISH;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertThrows;
+import static org.testng.Assert.assertTrue;
+
+public class TestNativeExpressionOptimizer
+{
+ private static final int TEST_VARCHAR_TYPE_LENGTH = 17;
+ private static final TypeProvider SYMBOL_TYPES = TypeProvider.viewOf(ImmutableMap.builder()
+ .put("bound_integer", INTEGER)
+ .put("bound_long", BIGINT)
+ .put("bound_string", createVarcharType(TEST_VARCHAR_TYPE_LENGTH))
+ .put("bound_varbinary", VARBINARY)
+ .put("bound_double", DOUBLE)
+ .put("bound_boolean", BOOLEAN)
+ .put("bound_date", DATE)
+ .put("bound_time", TIME)
+ .put("bound_timestamp", TIMESTAMP)
+ .put("bound_pattern", VARCHAR)
+ .put("bound_null_string", VARCHAR)
+ .put("bound_decimal_short", createDecimalType(5, 2))
+ .put("bound_decimal_long", createDecimalType(23, 3))
+ .put("time", BIGINT) // for testing reserved identifiers
+ .put("unbound_integer", INTEGER)
+ .put("unbound_long", BIGINT)
+ .put("unbound_long2", BIGINT)
+ .put("unbound_long3", BIGINT)
+ .put("unbound_string", VARCHAR)
+ .put("unbound_double", DOUBLE)
+ .put("unbound_boolean", BOOLEAN)
+ .put("unbound_date", DATE)
+ .put("unbound_time", TIME)
+ .put("unbound_array", new ArrayType(BIGINT))
+ .put("unbound_timestamp", TIMESTAMP)
+ .put("unbound_interval", INTERVAL_DAY_TIME)
+ .put("unbound_pattern", VARCHAR)
+ .put("unbound_null_string", VARCHAR)
+ .build());
+
+ private static final SqlParser SQL_PARSER = new SqlParser();
+ private static final Metadata METADATA = createTestMetadataManager();
+ private static final TestingRowExpressionTranslator TRANSLATOR = new TestingRowExpressionTranslator(METADATA);
+ private static final BlockEncodingSerde BLOCK_ENCODING_SERDE = new BlockEncodingManager();
+
+ private ExpressionOptimizer expressionOptimizer;
+ private Process sidecar;
+
+ public TestNativeExpressionOptimizer()
+ {
+ METADATA.getFunctionAndTypeManager().registerBuiltInFunctions(ImmutableList.of(APPLY_FUNCTION));
+ setupJsonFunctionNamespaceManager(METADATA.getFunctionAndTypeManager());
+ }
+
+ @BeforeClass
+ public void setup()
+ throws Exception
+ {
+ int port = findRandomPortForWorker();
+ URI sidecarUri = URI.create(format("http://127.0.0.1:%s", port));
+ sidecar = getNativeSidecarProcess(sidecarUri, port);
+
+ expressionOptimizer = getExpressionOptimizer(METADATA, new HandleResolver(), sidecarUri);
+ }
+
+ @AfterClass
+ public void tearDown()
+ {
+ sidecar.destroyForcibly();
+ }
+
+ @Test
+ public void testAnd()
+ {
+ assertOptimizedEquals("true and false", "false");
+ assertOptimizedEquals("false and true", "false");
+ assertOptimizedEquals("false and false", "false");
+
+ assertOptimizedEquals("true and null", "null");
+ assertOptimizedEquals("false and null", "false");
+ assertOptimizedEquals("null and true", "null");
+ assertOptimizedEquals("null and false", "false");
+ assertOptimizedEquals("null and null", "null");
+
+ assertOptimizedEquals("unbound_string='z' and true", "unbound_string='z'");
+ assertOptimizedEquals("unbound_string='z' and false", "false");
+ assertOptimizedEquals("true and unbound_string='z'", "unbound_string='z'");
+ assertOptimizedEquals("false and unbound_string='z'", "false");
+
+ assertOptimizedEquals("bound_string='z' and bound_long=1+1", "bound_string='z' and bound_long=2");
+ assertOptimizedEquals("random() > 0 and random() > 0", "random() > 0 and random() > 0");
+ }
+
+ @Test
+ public void testOr()
+ {
+ assertOptimizedEquals("true or true", "true");
+ assertOptimizedEquals("true or false", "true");
+ assertOptimizedEquals("false or true", "true");
+ assertOptimizedEquals("false or false", "false");
+
+ assertOptimizedEquals("true or null", "true");
+ assertOptimizedEquals("null or true", "true");
+ assertOptimizedEquals("null or null", "null");
+
+ assertOptimizedEquals("false or null", "null");
+ assertOptimizedEquals("null or false", "null");
+
+ assertOptimizedEquals("bound_string='z' or true", "true");
+ assertOptimizedEquals("bound_string='z' or false", "bound_string='z'");
+ assertOptimizedEquals("true or bound_string='z'", "true");
+ assertOptimizedEquals("false or bound_string='z'", "bound_string='z'");
+
+ assertOptimizedEquals("bound_string='z' or bound_long=1+1", "bound_string='z' or bound_long=2");
+ assertOptimizedEquals("random() > 0 or random() > 0", "random() > 0 or random() > 0");
+ }
+
+ @Test
+ public void testComparison()
+ {
+ assertOptimizedEquals("null = null", "null");
+
+ assertOptimizedEquals("'a' = 'b'", "false");
+ assertOptimizedEquals("'a' = 'a'", "true");
+ assertOptimizedEquals("'a' = null", "null");
+ assertOptimizedEquals("null = 'a'", "null");
+ assertOptimizedEquals("bound_integer = 1234", "true");
+ assertOptimizedEquals("bound_integer = 12340000000", "false");
+ assertOptimizedEquals("bound_long = BIGINT '1234'", "true");
+ assertOptimizedEquals("bound_long = 1234", "true");
+ assertOptimizedEquals("bound_double = 12.34", "true");
+ assertOptimizedEquals("bound_string = 'hello'", "true");
+ assertOptimizedEquals("bound_long = unbound_long", "1234 = unbound_long");
+
+ assertOptimizedEquals("10151082135029368 = 10151082135029369", "false");
+
+ assertOptimizedEquals("bound_varbinary = X'a b'", "true");
+ assertOptimizedEquals("bound_varbinary = X'a d'", "false");
+
+ assertOptimizedEquals("1.1 = 1.1", "true");
+ assertOptimizedEquals("9876543210.9874561203 = 9876543210.9874561203", "true");
+ assertOptimizedEquals("bound_decimal_short = 123.45", "true");
+ assertOptimizedEquals("bound_decimal_long = 12345678901234567890.123", "true");
+ }
+
+ @Test
+ public void testIsDistinctFrom()
+ {
+ assertOptimizedEquals("null is distinct from null", "false");
+
+ assertOptimizedEquals("3 is distinct from 4", "true");
+ assertOptimizedEquals("3 is distinct from BIGINT '4'", "true");
+ assertOptimizedEquals("3 is distinct from 4000000000", "true");
+ assertOptimizedEquals("3 is distinct from 3", "false");
+ assertOptimizedEquals("3 is distinct from null", "true");
+ assertOptimizedEquals("null is distinct from 3", "true");
+
+ assertOptimizedEquals("10151082135029368 is distinct from 10151082135029369", "true");
+
+ assertOptimizedEquals("1.1 is distinct from 1.1", "false");
+ assertOptimizedEquals("9876543210.9874561203 is distinct from NULL", "true");
+ assertOptimizedEquals("bound_decimal_short is distinct from NULL", "true");
+ assertOptimizedEquals("bound_decimal_long is distinct from 12345678901234567890.123", "false");
+ }
+
+ @Test
+ public void testIsNull()
+ {
+ assertOptimizedEquals("null is null", "true");
+ assertOptimizedEquals("1 is null", "false");
+ assertOptimizedEquals("10000000000 is null", "false");
+ assertOptimizedEquals("BIGINT '1' is null", "false");
+ assertOptimizedEquals("1.0 is null", "false");
+ assertOptimizedEquals("'a' is null", "false");
+ assertOptimizedEquals("true is null", "false");
+ assertOptimizedEquals("null+1 is null", "true");
+ assertOptimizedEquals("unbound_string is null", "unbound_string is null");
+ assertOptimizedEquals("unbound_long+(1+1) is null", "unbound_long+2 is null");
+ assertOptimizedEquals("1.1 is null", "false");
+ assertOptimizedEquals("9876543210.9874561203 is null", "false");
+ assertOptimizedEquals("bound_decimal_short is null", "false");
+ assertOptimizedEquals("bound_decimal_long is null", "false");
+ }
+
+ @Test
+ public void testIsNotNull()
+ {
+ assertOptimizedEquals("null is not null", "false");
+ assertOptimizedEquals("1 is not null", "true");
+ assertOptimizedEquals("10000000000 is not null", "true");
+ assertOptimizedEquals("BIGINT '1' is not null", "true");
+ assertOptimizedEquals("1.0 is not null", "true");
+ assertOptimizedEquals("'a' is not null", "true");
+ assertOptimizedEquals("true is not null", "true");
+ assertOptimizedEquals("null+1 is not null", "false");
+ assertOptimizedEquals("unbound_string is not null", "unbound_string is not null");
+ assertOptimizedEquals("unbound_long+(1+1) is not null", "unbound_long+2 is not null");
+ assertOptimizedEquals("1.1 is not null", "true");
+ assertOptimizedEquals("9876543210.9874561203 is not null", "true");
+ assertOptimizedEquals("bound_decimal_short is not null", "true");
+ assertOptimizedEquals("bound_decimal_long is not null", "true");
+ }
+
+ // TODO: NULL_IF special form is unsupported in Presto native.
+ @Test(enabled = false)
+ public void testNullIf()
+ {
+ assertOptimizedEquals("nullif(true, true)", "null");
+ assertOptimizedEquals("nullif(true, false)", "true");
+ assertOptimizedEquals("nullif(null, false)", "null");
+ assertOptimizedEquals("nullif(true, null)", "true");
+
+ assertOptimizedEquals("nullif('a', 'a')", "null");
+ assertOptimizedEquals("nullif('a', 'b')", "'a'");
+ assertOptimizedEquals("nullif(null, 'b')", "null");
+ assertOptimizedEquals("nullif('a', null)", "'a'");
+
+ assertOptimizedEquals("nullif(1, 1)", "null");
+ assertOptimizedEquals("nullif(1, 2)", "1");
+ assertOptimizedEquals("nullif(1, BIGINT '2')", "1");
+ assertOptimizedEquals("nullif(1, 20000000000)", "1");
+ assertOptimizedEquals("nullif(1.0E0, 1)", "null");
+ assertOptimizedEquals("nullif(10000000000.0E0, 10000000000)", "null");
+ assertOptimizedEquals("nullif(1.1E0, 1)", "1.1E0");
+ assertOptimizedEquals("nullif(1.1E0, 1.1E0)", "null");
+ assertOptimizedEquals("nullif(1, 2-1)", "null");
+ assertOptimizedEquals("nullif(null, null)", "null");
+ assertOptimizedEquals("nullif(1, null)", "1");
+ assertOptimizedEquals("nullif(unbound_long, 1)", "nullif(unbound_long, 1)");
+ assertOptimizedEquals("nullif(unbound_long, unbound_long2)", "nullif(unbound_long, unbound_long2)");
+ assertOptimizedEquals("nullif(unbound_long, unbound_long2+(1+1))", "nullif(unbound_long, unbound_long2+2)");
+
+ assertOptimizedEquals("nullif(1.1, 1.2)", "1.1");
+ assertOptimizedEquals("nullif(9876543210.9874561203, 9876543210.9874561203)", "null");
+ assertOptimizedEquals("nullif(bound_decimal_short, 123.45)", "null");
+ assertOptimizedEquals("nullif(bound_decimal_long, 12345678901234567890.123)", "null");
+ assertOptimizedEquals("nullif(ARRAY[CAST(1 AS BIGINT)], ARRAY[CAST(1 AS BIGINT)]) IS NULL", "true");
+ assertOptimizedEquals("nullif(ARRAY[CAST(1 AS BIGINT)], ARRAY[CAST(NULL AS BIGINT)]) IS NULL", "false");
+ assertOptimizedEquals("nullif(ARRAY[CAST(NULL AS BIGINT)], ARRAY[CAST(NULL AS BIGINT)]) IS NULL", "false");
+ }
+
+ @Test
+ public void testNegative()
+ {
+ assertOptimizedEquals("-(1)", "-1");
+ assertOptimizedEquals("-(BIGINT '1')", "BIGINT '-1'");
+ assertOptimizedEquals("-(unbound_long+1)", "-(unbound_long+1)");
+ assertOptimizedEquals("-(1+1)", "-2");
+ assertOptimizedEquals("-(1+ BIGINT '1')", "BIGINT '-2'");
+ assertOptimizedEquals("-(CAST(NULL AS BIGINT))", "null");
+ assertOptimizedEquals("-(unbound_long+(1+1))", "-(unbound_long+2)");
+ assertOptimizedEquals("-(1.1+1.2)", "-2.3");
+ assertOptimizedEquals("-(9876543210.9874561203-9876543210.9874561203)", "CAST(0 AS DECIMAL(20,10))");
+ assertOptimizedEquals("-(bound_decimal_short+123.45)", "-246.90");
+ assertOptimizedEquals("-(bound_decimal_long-12345678901234567890.123)", "CAST(0 AS DECIMAL(20,10))");
+ }
+
+ @Test
+ public void testNot()
+ {
+ assertOptimizedEquals("not true", "false");
+ assertOptimizedEquals("not false", "true");
+ assertOptimizedEquals("not null", "null");
+ assertOptimizedEquals("not 1=1", "false");
+ assertOptimizedEquals("not 1=BIGINT '1'", "false");
+ assertOptimizedEquals("not 1!=1", "true");
+ assertOptimizedEquals("not unbound_long=1", "not unbound_long=1");
+ assertOptimizedEquals("not unbound_long=(1+1)", "not unbound_long=2");
+ }
+
+ @Test
+ public void testFunctionCall()
+ {
+ assertOptimizedEquals("abs(-5)", "5");
+ assertOptimizedEquals("abs(-10-5)", "15");
+ assertOptimizedEquals("abs(-bound_integer + 1)", "1233");
+ assertOptimizedEquals("abs(-bound_long + 1)", "1233");
+ assertOptimizedEquals("abs(-bound_long + BIGINT '1')", "1233");
+ assertOptimizedEquals("abs(-bound_long)", "1234");
+ assertOptimizedEquals("abs(unbound_long)", "abs(unbound_long)");
+ assertOptimizedEquals("abs(unbound_long + 1)", "abs(unbound_long + 1)");
+ assertOptimizedEquals("cast(json_parse(unbound_string) as map(varchar, varchar))", "cast(json_parse(unbound_string) as map(varchar, varchar))");
+ assertOptimizedEquals("cast(json_parse(unbound_string) as array(varchar))", "cast(json_parse(unbound_string) as array(varchar))");
+ assertOptimizedEquals("cast(json_parse(unbound_string) as row(bigint, varchar))", "cast(json_parse(unbound_string) as row(bigint, varchar))");
+ }
+
+ // TODO: evaluated is not working on current sidecar implementation
+ @Test(enabled = false)
+ public void testNonDeterministicFunctionCall()
+ {
+ // optimize should do nothing
+ assertOptimizedEquals("random()", "random()");
+
+ // evaluate should execute
+ RowExpression value = evaluate("random()", false);
+ assertTrue(value instanceof ConstantExpression);
+ Object innerValue = ((ConstantExpression) value).getValue();
+ assertTrue(innerValue instanceof Double);
+ double randomValue = (double) innerValue;
+ assertTrue(0 <= randomValue && randomValue < 1);
+ }
+
+ // Run this method exactly once.
+ private void setupJsonFunctionNamespaceManager(FunctionAndTypeManager functionAndTypeManager)
+ {
+ functionAndTypeManager.addFunctionNamespaceFactory(new JsonFileBasedFunctionNamespaceManagerFactory());
+ functionAndTypeManager.loadFunctionNamespaceManager(
+ JsonFileBasedFunctionNamespaceManagerFactory.NAME,
+ "json",
+ ImmutableMap.of("supported-function-languages", "CPP", "function-implementation-type", "CPP"),
+ new PluginNodeManager(new InMemoryNodeManager()));
+ }
+
+ @Test
+ public void testBetween()
+ {
+ assertOptimizedEquals("3 between 2 and 4", "true");
+ assertOptimizedEquals("2 between 3 and 4", "false");
+ assertOptimizedEquals("null between 2 and 4", "null");
+ assertOptimizedEquals("3 between null and 4", "null");
+ assertOptimizedEquals("3 between 2 and null", "null");
+
+ assertOptimizedEquals("'cc' between 'b' and 'd'", "true");
+ assertOptimizedEquals("'b' between 'cc' and 'd'", "false");
+ assertOptimizedEquals("null between 'b' and 'd'", "null");
+ assertOptimizedEquals("'cc' between null and 'd'", "null");
+ assertOptimizedEquals("'cc' between 'b' and null", "null");
+
+ assertOptimizedEquals("bound_integer between 1000 and 2000", "true");
+ assertOptimizedEquals("bound_integer between 3 and 4", "false");
+ assertOptimizedEquals("bound_long between 1000 and 2000", "true");
+ assertOptimizedEquals("bound_long between 3 and 4", "false");
+ assertOptimizedEquals("bound_long between bound_integer and (bound_long + 1)", "true");
+ assertOptimizedEquals("bound_string between 'e' and 'i'", "true");
+ assertOptimizedEquals("bound_string between 'a' and 'b'", "false");
+
+ assertOptimizedEquals("bound_long between unbound_long and 2000 + 1", "1234 between unbound_long and 2001");
+ assertOptimizedEquals(
+ "bound_string between unbound_string and 'bar'",
+ format("CAST('hello' AS VARCHAR(%s)) between unbound_string and 'bar'", TEST_VARCHAR_TYPE_LENGTH));
+
+ assertOptimizedEquals("1.15 between 1.1 and 1.2", "true");
+ assertOptimizedEquals("9876543210.98745612035 between 9876543210.9874561203 and 9876543210.9874561204", "true");
+ assertOptimizedEquals("123.455 between bound_decimal_short and 123.46", "true");
+ assertOptimizedEquals("12345678901234567890.1235 between bound_decimal_long and 12345678901234567890.123", "false");
+ }
+
+ @Test
+ public void testExtract()
+ {
+ DateTime dateTime = new DateTime(2001, 8, 22, 3, 4, 5, 321, getDateTimeZone(TEST_SESSION.getTimeZoneKey()));
+ double seconds = dateTime.getMillis() / 1000.0;
+
+ assertOptimizedEquals("extract (YEAR from from_unixtime(" + seconds + "))", "2001");
+ assertOptimizedEquals("extract (QUARTER from from_unixtime(" + seconds + "))", "3");
+ assertOptimizedEquals("extract (MONTH from from_unixtime(" + seconds + "))", "8");
+ assertOptimizedEquals("extract (WEEK from from_unixtime(" + seconds + "))", "34");
+ assertOptimizedEquals("extract (DOW from from_unixtime(" + seconds + "))", "3");
+ assertOptimizedEquals("extract (DOY from from_unixtime(" + seconds + "))", "234");
+ assertOptimizedEquals("extract (DAY from from_unixtime(" + seconds + "))", "22");
+ assertOptimizedEquals("extract (HOUR from from_unixtime(" + seconds + "))", "3");
+ assertOptimizedEquals("extract (MINUTE from from_unixtime(" + seconds + "))", "4");
+ assertOptimizedEquals("extract (SECOND from from_unixtime(" + seconds + "))", "5");
+ assertOptimizedEquals("extract (TIMEZONE_HOUR from from_unixtime(" + seconds + ", 7, 9))", "7");
+ assertOptimizedEquals("extract (TIMEZONE_MINUTE from from_unixtime(" + seconds + ", 7, 9))", "9");
+
+ assertOptimizedEquals("extract (YEAR from bound_timestamp)", "2001");
+ assertOptimizedEquals("extract (QUARTER from bound_timestamp)", "3");
+ assertOptimizedEquals("extract (MONTH from bound_timestamp)", "8");
+ assertOptimizedEquals("extract (WEEK from bound_timestamp)", "34");
+ assertOptimizedEquals("extract (DOW from bound_timestamp)", "2");
+ assertOptimizedEquals("extract (DOY from bound_timestamp)", "233");
+ assertOptimizedEquals("extract (DAY from bound_timestamp)", "21");
+ assertOptimizedEquals("extract (HOUR from bound_timestamp)", "16");
+ assertOptimizedEquals("extract (MINUTE from bound_timestamp)", "4");
+ assertOptimizedEquals("extract (SECOND from bound_timestamp)", "5");
+ // todo reenable when cast as timestamp with time zone is implemented
+ // todo add bound timestamp with time zone
+ //assertOptimizedEquals("extract (TIMEZONE_HOUR from bound_timestamp)", "0");
+ //assertOptimizedEquals("extract (TIMEZONE_MINUTE from bound_timestamp)", "0");
+
+ assertOptimizedEquals("extract (YEAR from unbound_timestamp)", "extract (YEAR from unbound_timestamp)");
+ assertOptimizedEquals("extract (SECOND from bound_timestamp + INTERVAL '3' SECOND)", "8");
+ }
+
+ @Test
+ public void testIn()
+ {
+ assertOptimizedEquals("3 in (2, 4, 3, 5)", "true");
+ assertOptimizedEquals("3 in (2, 4, 9, 5)", "false");
+ assertOptimizedEquals("3 in (2, null, 3, 5)", "true");
+
+ assertOptimizedEquals("'foo' in ('bar', 'baz', 'foo', 'blah')", "true");
+ assertOptimizedEquals("'foo' in ('bar', 'baz', 'buz', 'blah')", "false");
+ assertOptimizedEquals("'foo' in ('bar', null, 'foo', 'blah')", "true");
+
+ assertOptimizedEquals("null in (2, null, 3, 5)", "null");
+ assertOptimizedEquals("3 in (2, null)", "null");
+
+ assertOptimizedEquals("bound_integer in (2, 1234, 3, 5)", "true");
+ assertOptimizedEquals("bound_integer in (2, 4, 3, 5)", "false");
+ assertOptimizedEquals("1234 in (2, bound_integer, 3, 5)", "true");
+ assertOptimizedEquals("99 in (2, bound_integer, 3, 5)", "false");
+ assertOptimizedEquals("bound_integer in (2, bound_integer, 3, 5)", "true");
+
+ assertOptimizedEquals("bound_long in (2, 1234, 3, 5)", "true");
+ assertOptimizedEquals("bound_long in (2, 4, 3, 5)", "false");
+ assertOptimizedEquals("1234 in (2, bound_long, 3, 5)", "true");
+ assertOptimizedEquals("99 in (2, bound_long, 3, 5)", "false");
+ assertOptimizedEquals("bound_long in (2, bound_long, 3, 5)", "true");
+
+ assertOptimizedEquals("bound_string in ('bar', 'hello', 'foo', 'blah')", "true");
+ assertOptimizedEquals("bound_string in ('bar', 'baz', 'foo', 'blah')", "false");
+ assertOptimizedEquals("'hello' in ('bar', bound_string, 'foo', 'blah')", "true");
+ assertOptimizedEquals("'baz' in ('bar', bound_string, 'foo', 'blah')", "false");
+
+ assertOptimizedEquals("bound_long in (2, 1234, unbound_long, 5)", "true");
+ assertOptimizedEquals("bound_string in ('bar', 'hello', unbound_string, 'blah')", "true");
+
+ assertOptimizedEquals("bound_long in (2, 4, unbound_long, unbound_long2, 9)", "1234 in (unbound_long, unbound_long2)");
+ assertOptimizedEquals("unbound_long in (2, 4, bound_long, unbound_long2, 5)", "unbound_long in (2, 4, 1234, unbound_long2, 5)");
+
+ assertOptimizedEquals("1.15 in (1.1, 1.2, 1.3, 1.15)", "true");
+ assertOptimizedEquals("9876543210.98745612035 in (9876543210.9874561203, 9876543210.9874561204, 9876543210.98745612035)", "true");
+ assertOptimizedEquals("bound_decimal_short in (123.455, 123.46, 123.45)", "true");
+ assertOptimizedEquals("bound_decimal_long in (12345678901234567890.123, 9876543210.9874561204, 9876543210.98745612035)", "true");
+ assertOptimizedEquals("bound_decimal_long in (9876543210.9874561204, null, 9876543210.98745612035)", "null");
+ }
+
+ @Test
+ public void testInComplexTypes()
+ {
+ assertEvaluatedEquals("ARRAY[null] IN (ARRAY[null])", "null");
+ assertEvaluatedEquals("ARRAY[1] IN (ARRAY[null])", "null");
+ assertEvaluatedEquals("ARRAY[null] IN (ARRAY[1])", "null");
+ assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[1, null])", "null");
+ assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[2, null])", "false");
+ assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[1, null], ARRAY[2, null])", "null");
+ assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[1, null], ARRAY[2, null], ARRAY[1, null])", "null");
+ assertEvaluatedEquals("ARRAY[ARRAY[1, 2], ARRAY[3, 4]] in (ARRAY[ARRAY[1, 2], ARRAY[3, NULL]])", "null");
+
+ assertEvaluatedEquals("ROW(1) IN (ROW(1))", "true");
+ assertEvaluatedEquals("ROW(1) IN (ROW(2))", "false");
+ assertEvaluatedEquals("ROW(1) IN (ROW(2), ROW(1), ROW(2))", "true");
+ assertEvaluatedEquals("ROW(1) IN (null)", "null");
+ assertEvaluatedEquals("ROW(1) IN (null, ROW(1))", "true");
+ assertEvaluatedEquals("ROW(1, null) IN (ROW(2, null), null)", "null");
+ assertEvaluatedEquals("ROW(null) IN (ROW(null))", "null");
+ assertEvaluatedEquals("ROW(1) IN (ROW(null))", "null");
+ assertEvaluatedEquals("ROW(null) IN (ROW(1))", "null");
+ assertEvaluatedEquals("ROW(1, null) IN (ROW(1, null))", "null");
+ assertEvaluatedEquals("ROW(1, null) IN (ROW(2, null))", "false");
+ assertEvaluatedEquals("ROW(1, null) IN (ROW(1, null), ROW(2, null))", "null");
+ assertEvaluatedEquals("ROW(1, null) IN (ROW(1, null), ROW(2, null), ROW(1, null))", "null");
+
+ assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (MAP(ARRAY[1], ARRAY[1]))", "true");
+ assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (null)", "null");
+ assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (null, MAP(ARRAY[1], ARRAY[1]))", "true");
+ assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]))", "false");
+ assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[2, null]), null)", "null");
+ assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]))", "null");
+ assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 3], ARRAY[1, null]))", "false");
+ assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[null]) IN (MAP(ARRAY[1], ARRAY[null]))", "null");
+ assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (MAP(ARRAY[1], ARRAY[null]))", "null");
+ assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[null]) IN (MAP(ARRAY[1], ARRAY[1]))", "null");
+ assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]))", "null");
+ assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 3], ARRAY[1, null]))", "false");
+ assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[2, null]))", "false");
+ assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]), MAP(ARRAY[1, 2], ARRAY[2, null]))", "null");
+ assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]), MAP(ARRAY[1, 2], ARRAY[2, null]), MAP(ARRAY[1, 2], ARRAY[1, null]))", "null");
+ }
+
+ // TODO: current timestamp returns the session timestamp, which is not supported by this test
+ @Test(enabled = false)
+ public void testCurrentTimestamp()
+ {
+ double current = TEST_SESSION.getStartTime() / 1000.0;
+ assertOptimizedEquals("current_timestamp = from_unixtime(" + current + ")", "true");
+ double future = current + TimeUnit.MINUTES.toSeconds(1);
+ assertOptimizedEquals("current_timestamp > from_unixtime(" + future + ")", "false");
+ }
+
+ @Test
+ public void testCurrentUser()
+ {
+ assertOptimizedEquals("current_user", "'" + TEST_SESSION.getUser() + "'");
+ }
+
+ @Test
+ public void testCastToString()
+ {
+ // integer
+ assertOptimizedEquals("cast(123 as VARCHAR(20))", "'123'");
+ assertOptimizedEquals("cast(-123 as VARCHAR(20))", "'-123'");
+
+ // bigint
+ assertOptimizedEquals("cast(BIGINT '123' as VARCHAR)", "'123'");
+ assertOptimizedEquals("cast(12300000000 as VARCHAR)", "'12300000000'");
+ assertOptimizedEquals("cast(-12300000000 as VARCHAR)", "'-12300000000'");
+
+ // double
+ assertOptimizedEquals("cast(123.0E0 as VARCHAR)", "'123.0'");
+ assertOptimizedEquals("cast(-123.0E0 as VARCHAR)", "'-123.0'");
+ assertOptimizedEquals("cast(123.456E0 as VARCHAR)", "'123.456'");
+ assertOptimizedEquals("cast(-123.456E0 as VARCHAR)", "'-123.456'");
+
+ // boolean
+ assertOptimizedEquals("cast(true as VARCHAR)", "'true'");
+ assertOptimizedEquals("cast(false as VARCHAR)", "'false'");
+
+ // string
+ assertOptimizedEquals("cast('xyz' as VARCHAR)", "'xyz'");
+ assertOptimizedEquals("cast(cast('abcxyz' as VARCHAR(3)) as VARCHAR(5))", "'abc'");
+
+ // null
+ assertOptimizedEquals("cast(null as VARCHAR)", "null");
+
+ // decimal
+ assertOptimizedEquals("cast(1.1 as VARCHAR)", "'1.1'");
+ // TODO enabled when DECIMAL is default for literal: assertOptimizedEquals("cast(12345678901234567890.123 as VARCHAR)", "'12345678901234567890.123'");
+ }
+
+ @Test
+ public void testCastBigintToBoundedVarchar()
+ {
+ assertEvaluatedEquals("CAST(12300000000 AS varchar(11))", "'12300000000'");
+ assertEvaluatedEquals("CAST(12300000000 AS varchar(50))", "'12300000000'");
+
+ // TODO: Velox permits this cast, but Presto does not
+// try {
+// evaluate("CAST(12300000000 AS varchar(3))", true);
+// fail("Expected to throw an INVALID_CAST_ARGUMENT exception");
+// }
+// catch (PrestoException e) {
+// try {
+// assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode());
+// assertEquals(e.getMessage(), "Value 12300000000 cannot be represented as varchar(3)");
+// }
+// catch (Throwable failure) {
+// failure.addSuppressed(e);
+// throw failure;
+// }
+// }
+//
+// try {
+// evaluate("CAST(-12300000000 AS varchar(3))", true);
+// }
+// catch (PrestoException e) {
+// try {
+// assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode());
+// assertEquals(e.getMessage(), "Value -12300000000 cannot be represented as varchar(3)");
+// }
+// catch (Throwable failure) {
+// failure.addSuppressed(e);
+// throw failure;
+// }
+// }
+ }
+
+ @Test
+ public void testCastToBoolean()
+ {
+ // integer
+ assertOptimizedEquals("cast(123 as BOOLEAN)", "true");
+ assertOptimizedEquals("cast(-123 as BOOLEAN)", "true");
+ assertOptimizedEquals("cast(0 as BOOLEAN)", "false");
+
+ // bigint
+ assertOptimizedEquals("cast(12300000000 as BOOLEAN)", "true");
+ assertOptimizedEquals("cast(-12300000000 as BOOLEAN)", "true");
+ assertOptimizedEquals("cast(BIGINT '0' as BOOLEAN)", "false");
+
+ // boolean
+ assertOptimizedEquals("cast(true as BOOLEAN)", "true");
+ assertOptimizedEquals("cast(false as BOOLEAN)", "false");
+
+ // string
+ assertOptimizedEquals("cast('true' as BOOLEAN)", "true");
+ assertOptimizedEquals("cast('false' as BOOLEAN)", "false");
+ assertOptimizedEquals("cast('t' as BOOLEAN)", "true");
+ assertOptimizedEquals("cast('f' as BOOLEAN)", "false");
+ assertOptimizedEquals("cast('1' as BOOLEAN)", "true");
+ assertOptimizedEquals("cast('0' as BOOLEAN)", "false");
+
+ // null
+ assertOptimizedEquals("cast(null as BOOLEAN)", "null");
+
+ // double
+ assertOptimizedEquals("cast(123.45E0 as BOOLEAN)", "true");
+ assertOptimizedEquals("cast(-123.45E0 as BOOLEAN)", "true");
+ assertOptimizedEquals("cast(0.0E0 as BOOLEAN)", "false");
+
+ // decimal
+ assertOptimizedEquals("cast(0.00 as BOOLEAN)", "false");
+ assertOptimizedEquals("cast(7.8 as BOOLEAN)", "true");
+ assertOptimizedEquals("cast(12345678901234567890.123 as BOOLEAN)", "true");
+ assertOptimizedEquals("cast(00000000000000000000.000 as BOOLEAN)", "false");
+ }
+
+ @Test
+ public void testCastToBigint()
+ {
+ // integer
+ assertOptimizedEquals("cast(0 as BIGINT)", "0");
+ assertOptimizedEquals("cast(123 as BIGINT)", "123");
+ assertOptimizedEquals("cast(-123 as BIGINT)", "-123");
+
+ // bigint
+ assertOptimizedEquals("cast(BIGINT '0' as BIGINT)", "0");
+ assertOptimizedEquals("cast(BIGINT '123' as BIGINT)", "123");
+ assertOptimizedEquals("cast(BIGINT '-123' as BIGINT)", "-123");
+
+ // double
+ assertOptimizedEquals("cast(123.0E0 as BIGINT)", "123");
+ assertOptimizedEquals("cast(-123.0E0 as BIGINT)", "-123");
+ assertOptimizedEquals("cast(123.456E0 as BIGINT)", "123");
+ assertOptimizedEquals("cast(-123.456E0 as BIGINT)", "-123");
+
+ // boolean
+ assertOptimizedEquals("cast(true as BIGINT)", "1");
+ assertOptimizedEquals("cast(false as BIGINT)", "0");
+
+ // string
+ assertOptimizedEquals("cast('123' as BIGINT)", "123");
+ assertOptimizedEquals("cast('-123' as BIGINT)", "-123");
+
+ // null
+ assertOptimizedEquals("cast(null as BIGINT)", "null");
+
+ // decimal
+ assertOptimizedEquals("cast(DECIMAL '1.01' as BIGINT)", "1");
+ assertOptimizedEquals("cast(DECIMAL '7.8' as BIGINT)", "8");
+ assertOptimizedEquals("cast(DECIMAL '1234567890.123' as BIGINT)", "1234567890");
+ assertOptimizedEquals("cast(DECIMAL '00000000000000000000.000' as BIGINT)", "0");
+ }
+
+ @Test
+ public void testCastToInteger()
+ {
+ // integer
+ assertOptimizedEquals("cast(0 as INTEGER)", "0");
+ assertOptimizedEquals("cast(123 as INTEGER)", "123");
+ assertOptimizedEquals("cast(-123 as INTEGER)", "-123");
+
+ // bigint
+ assertOptimizedEquals("cast(BIGINT '0' as INTEGER)", "0");
+ assertOptimizedEquals("cast(BIGINT '123' as INTEGER)", "123");
+ assertOptimizedEquals("cast(BIGINT '-123' as INTEGER)", "-123");
+
+ // double
+ assertOptimizedEquals("cast(123.0E0 as INTEGER)", "123");
+ assertOptimizedEquals("cast(-123.0E0 as INTEGER)", "-123");
+ assertOptimizedEquals("cast(123.456E0 as INTEGER)", "123");
+ assertOptimizedEquals("cast(-123.456E0 as INTEGER)", "-123");
+
+ // boolean
+ assertOptimizedEquals("cast(true as INTEGER)", "1");
+ assertOptimizedEquals("cast(false as INTEGER)", "0");
+
+ // string
+ assertOptimizedEquals("cast('123' as INTEGER)", "123");
+ assertOptimizedEquals("cast('-123' as INTEGER)", "-123");
+
+ // null
+ assertOptimizedEquals("cast(null as INTEGER)", "null");
+ }
+
+ @Test
+ public void testCastToDouble()
+ {
+ // integer
+ assertOptimizedEquals("cast(0 as DOUBLE)", "0.0E0");
+ assertOptimizedEquals("cast(123 as DOUBLE)", "123.0E0");
+ assertOptimizedEquals("cast(-123 as DOUBLE)", "-123.0E0");
+
+ // bigint
+ assertOptimizedEquals("cast(BIGINT '0' as DOUBLE)", "0.0E0");
+ assertOptimizedEquals("cast(12300000000 as DOUBLE)", "12300000000.0E0");
+ assertOptimizedEquals("cast(-12300000000 as DOUBLE)", "-12300000000.0E0");
+
+ // double
+ assertOptimizedEquals("cast(123.0E0 as DOUBLE)", "123.0E0");
+ assertOptimizedEquals("cast(-123.0E0 as DOUBLE)", "-123.0E0");
+ assertOptimizedEquals("cast(123.456E0 as DOUBLE)", "123.456E0");
+ assertOptimizedEquals("cast(-123.456E0 as DOUBLE)", "-123.456E0");
+
+ // string
+ assertOptimizedEquals("cast('0' as DOUBLE)", "0.0E0");
+ assertOptimizedEquals("cast('123' as DOUBLE)", "123.0E0");
+ assertOptimizedEquals("cast('-123' as DOUBLE)", "-123.0E0");
+ assertOptimizedEquals("cast('123.0E0' as DOUBLE)", "123.0E0");
+ assertOptimizedEquals("cast('-123.0E0' as DOUBLE)", "-123.0E0");
+ assertOptimizedEquals("cast('123.456E0' as DOUBLE)", "123.456E0");
+ assertOptimizedEquals("cast('-123.456E0' as DOUBLE)", "-123.456E0");
+
+ // null
+ assertOptimizedEquals("cast(null as DOUBLE)", "null");
+
+ // boolean
+ assertOptimizedEquals("cast(true as DOUBLE)", "1.0E0");
+ assertOptimizedEquals("cast(false as DOUBLE)", "0.0E0");
+
+ // decimal
+ assertOptimizedEquals("cast(1.01 as DOUBLE)", "DOUBLE '1.01'");
+ assertOptimizedEquals("cast(7.8 as DOUBLE)", "DOUBLE '7.8'");
+ assertOptimizedEquals("cast(1234567890.123 as DOUBLE)", "DOUBLE '1234567890.123'");
+ assertOptimizedEquals("cast(00000000000000000000.000 as DOUBLE)", "DOUBLE '0.0'");
+ }
+
+ @Test
+ public void testCastToDecimal()
+ {
+ // long
+ assertOptimizedEquals("cast(0 as DECIMAL(1,0))", "DECIMAL '0'");
+ assertOptimizedEquals("cast(123 as DECIMAL(3,0))", "DECIMAL '123'");
+ assertOptimizedEquals("cast(-123 as DECIMAL(3,0))", "DECIMAL '-123'");
+ assertOptimizedEquals("cast(-123 as DECIMAL(20,10))", "cast(-123 as DECIMAL(20,10))");
+
+ // double
+ assertOptimizedEquals("cast(0E0 as DECIMAL(1,0))", "DECIMAL '0'");
+ assertOptimizedEquals("cast(123.2E0 as DECIMAL(4,1))", "DECIMAL '123.2'");
+ assertOptimizedEquals("cast(-123.0E0 as DECIMAL(3,0))", "DECIMAL '-123'");
+ assertOptimizedEquals("cast(-123.55E0 as DECIMAL(20,10))", "cast(-123.55 as DECIMAL(20,10))");
+
+ // string
+ assertOptimizedEquals("cast('0' as DECIMAL(1,0))", "DECIMAL '0'");
+ assertOptimizedEquals("cast('123.2' as DECIMAL(4,1))", "DECIMAL '123.2'");
+ assertOptimizedEquals("cast('-123.0' as DECIMAL(3,0))", "DECIMAL '-123'");
+ assertOptimizedEquals("cast('-123.55' as DECIMAL(20,10))", "cast(-123.55 as DECIMAL(20,10))");
+
+ // null
+ assertOptimizedEquals("cast(null as DECIMAL(1,0))", "null");
+ assertOptimizedEquals("cast(null as DECIMAL(20,10))", "null");
+
+ // boolean
+ assertOptimizedEquals("cast(true as DECIMAL(1,0))", "DECIMAL '1'");
+ assertOptimizedEquals("cast(false as DECIMAL(4,1))", "DECIMAL '000.0'");
+ assertOptimizedEquals("cast(true as DECIMAL(3,0))", "DECIMAL '001'");
+ assertOptimizedEquals("cast(false as DECIMAL(20,10))", "cast(0 as DECIMAL(20,10))");
+
+ // decimal
+ assertOptimizedEquals("cast(0.0 as DECIMAL(1,0))", "DECIMAL '0'");
+ assertOptimizedEquals("cast(123.2 as DECIMAL(4,1))", "DECIMAL '123.2'");
+ assertOptimizedEquals("cast(-123.0 as DECIMAL(3,0))", "DECIMAL '-123'");
+ assertOptimizedEquals("cast(-123.55 as DECIMAL(20,10))", "cast(-123.55 as DECIMAL(20,10))");
+ }
+
+ @Test
+ public void testCastOptimization()
+ {
+ assertOptimizedEquals("cast(unbound_string as VARCHAR)", "cast(unbound_string as VARCHAR)");
+ assertOptimizedMatches("cast(unbound_string as VARCHAR)", "unbound_string");
+ assertOptimizedMatches("cast(unbound_integer as INTEGER)", "unbound_integer");
+ assertOptimizedMatches("cast(unbound_string as VARCHAR(10))", "cast(unbound_string as VARCHAR(10))");
+ }
+
+ @Test
+ public void testTryCast()
+ {
+ assertOptimizedEquals("try_cast(null as BIGINT)", "null");
+ assertOptimizedEquals("try_cast(123 as BIGINT)", "123");
+ assertOptimizedEquals("try_cast(null as INTEGER)", "null");
+ assertOptimizedEquals("try_cast(123 as INTEGER)", "123");
+ assertOptimizedEquals("try_cast('foo' as VARCHAR)", "'foo'");
+ assertOptimizedEquals("try_cast('foo' as BIGINT)", "null");
+ assertOptimizedEquals("try_cast(unbound_string as BIGINT)", "try_cast(unbound_string as BIGINT)");
+ assertOptimizedEquals("try_cast('foo' as DECIMAL(2,1))", "null");
+ }
+
+ @Test
+ public void testReservedWithDoubleQuotes()
+ {
+ assertOptimizedEquals("\"time\"", "\"time\"");
+ }
+
+ @Test
+ public void testSearchCase()
+ {
+ assertOptimizedEquals("case " +
+ "when true then 33 " +
+ "end",
+ "33");
+ assertOptimizedEquals("case " +
+ "when false then 1 " +
+ "else 33 " +
+ "end",
+ "33");
+
+ assertOptimizedEquals("case " +
+ "when false then 10000000000 " +
+ "else 33 " +
+ "end",
+ "33");
+
+ assertOptimizedEquals("case " +
+ "when bound_long = 1234 then 33 " +
+ "end",
+ "33");
+ assertOptimizedEquals("case " +
+ "when true then bound_long " +
+ "end",
+ "1234");
+ assertOptimizedEquals("case " +
+ "when false then 1 " +
+ "else bound_long " +
+ "end",
+ "1234");
+
+ assertOptimizedEquals("case " +
+ "when bound_integer = 1234 then 33 " +
+ "end",
+ "33");
+ assertOptimizedEquals("case " +
+ "when true then bound_integer " +
+ "end",
+ "1234");
+ assertOptimizedEquals("case " +
+ "when false then 1 " +
+ "else bound_integer " +
+ "end",
+ "1234");
+
+ assertOptimizedEquals("case " +
+ "when bound_long = 1234 then 33 " +
+ "else unbound_long " +
+ "end",
+ "33");
+ assertOptimizedEquals("case " +
+ "when true then bound_long " +
+ "else unbound_long " +
+ "end",
+ "1234");
+ assertOptimizedEquals("case " +
+ "when false then unbound_long " +
+ "else bound_long " +
+ "end",
+ "1234");
+
+ assertOptimizedEquals("case " +
+ "when bound_integer = 1234 then 33 " +
+ "else unbound_integer " +
+ "end",
+ "33");
+ assertOptimizedEquals("case " +
+ "when true then bound_integer " +
+ "else unbound_integer " +
+ "end",
+ "1234");
+ assertOptimizedEquals("case " +
+ "when false then unbound_integer " +
+ "else bound_integer " +
+ "end",
+ "1234");
+
+ assertOptimizedEquals("case " +
+ "when unbound_long = 1234 then 33 " +
+ "else 1 " +
+ "end",
+ "case " +
+ "when unbound_long = 1234 then 33 " +
+ "else 1 " +
+ "end");
+
+ assertOptimizedMatches("if(false, 1, 0 / 0)", "cast(fail(8, 'ignored failure message') as integer)");
+
+ assertOptimizedEquals("case " +
+ "when false then 2.2 " +
+ "when true then 2.2 " +
+ "end",
+ "2.2");
+
+ assertOptimizedEquals("case " +
+ "when false then 1234567890.0987654321 " +
+ "when true then 3.3 " +
+ "end",
+ "CAST(3.3 AS DECIMAL(20,10))");
+
+ assertOptimizedEquals("case " +
+ "when false then 1 " +
+ "when true then 2.2 " +
+ "end",
+ "2.2");
+
+ assertOptimizedEquals("case when ARRAY[CAST(1 AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'matched'");
+ assertOptimizedEquals("case when ARRAY[CAST(2 AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'");
+ assertOptimizedEquals("case when ARRAY[CAST(null AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'");
+ }
+
+ @Test
+ public void testSimpleCase()
+ {
+ assertOptimizedEquals("case 1 " +
+ "when 1 then 32 + 1 " +
+ "when 1 then 34 " +
+ "end",
+ "33");
+
+ assertOptimizedEquals("case null " +
+ "when true then 33 " +
+ "end",
+ "null");
+ assertOptimizedEquals("case null " +
+ "when true then 33 " +
+ "else 33 " +
+ "end",
+ "33");
+ assertOptimizedEquals("case 33 " +
+ "when null then 1 " +
+ "else 33 " +
+ "end",
+ "33");
+
+ assertOptimizedEquals("case null " +
+ "when true then 3300000000 " +
+ "end",
+ "null");
+ assertOptimizedEquals("case null " +
+ "when true then 3300000000 " +
+ "else 3300000000 " +
+ "end",
+ "3300000000");
+ assertOptimizedEquals("case 33 " +
+ "when null then 3300000000 " +
+ "else 33 " +
+ "end",
+ "33");
+
+ assertOptimizedEquals("case true " +
+ "when true then 33 " +
+ "end",
+ "33");
+ assertOptimizedEquals("case true " +
+ "when false then 1 " +
+ "else 33 end",
+ "33");
+
+ assertOptimizedEquals("case bound_long " +
+ "when 1234 then 33 " +
+ "end",
+ "33");
+ assertOptimizedEquals("case 1234 " +
+ "when bound_long then 33 " +
+ "end",
+ "33");
+ assertOptimizedEquals("case true " +
+ "when true then bound_long " +
+ "end",
+ "1234");
+ assertOptimizedEquals("case true " +
+ "when false then 1 " +
+ "else bound_long " +
+ "end",
+ "1234");
+
+ assertOptimizedEquals("case bound_integer " +
+ "when 1234 then 33 " +
+ "end",
+ "33");
+ assertOptimizedEquals("case 1234 " +
+ "when bound_integer then 33 " +
+ "end",
+ "33");
+ assertOptimizedEquals("case true " +
+ "when true then bound_integer " +
+ "end",
+ "1234");
+ assertOptimizedEquals("case true " +
+ "when false then 1 " +
+ "else bound_integer " +
+ "end",
+ "1234");
+
+ assertOptimizedEquals("case bound_long " +
+ "when 1234 then 33 " +
+ "else unbound_long " +
+ "end",
+ "33");
+ assertOptimizedEquals("case true " +
+ "when true then bound_long " +
+ "else unbound_long " +
+ "end",
+ "1234");
+ assertOptimizedEquals("case true " +
+ "when false then unbound_long " +
+ "else bound_long " +
+ "end",
+ "1234");
+
+ assertOptimizedEquals("case unbound_long " +
+ "when 1234 then 33 " +
+ "else 1 " +
+ "end",
+ "case unbound_long " +
+ "when 1234 then 33 " +
+ "else 1 " +
+ "end");
+
+ assertOptimizedEquals("case 33 " +
+ "when 0 then 0 " +
+ "when 33 then unbound_long " +
+ "else 1 " +
+ "end",
+ "unbound_long");
+ assertOptimizedEquals("case 33 " +
+ "when 0 then 0 " +
+ "when 33 then 1 " +
+ "when unbound_long then 2 " +
+ "else 1 " +
+ "end",
+ "1");
+ assertOptimizedEquals("case 33 " +
+ "when unbound_long then 0 " +
+ "when 1 then 1 " +
+ "when 33 then 2 " +
+ "else 0 " +
+ "end",
+ "case 33 " +
+ "when unbound_long then 0 " +
+ "else 2 " +
+ "end");
+ assertOptimizedEquals("case 33 " +
+ "when 0 then 0 " +
+ "when 1 then 1 " +
+ "else unbound_long " +
+ "end",
+ "unbound_long");
+ assertOptimizedEquals("case 33 " +
+ "when unbound_long then 0 " +
+ "when 1 then 1 " +
+ "when unbound_long2 then 2 " +
+ "else 3 " +
+ "end",
+ "case 33 " +
+ "when unbound_long then 0 " +
+ "when unbound_long2 then 2 " +
+ "else 3 " +
+ "end");
+
+ assertOptimizedMatches("case true " +
+ "when unbound_long = 1 then 1 " +
+ "when 0 / 0 = 0 then 2 " +
+ "else 33 end",
+ "case true " +
+ "when unbound_long = BIGINT '1' then 1 " +
+ "when CAST(fail(8, 'ignored failure message') AS boolean) then 2 else 33 " +
+ "end");
+
+ assertOptimizedEquals("case bound_long " +
+ "when 123 * 10 + unbound_long then 1 = 1 " +
+ "else 1 = 2 " +
+ "end",
+ "case bound_long when 1230 + unbound_long then true " +
+ "else false " +
+ "end");
+
+ assertOptimizedEquals("case bound_long " +
+ "when unbound_long then 2 + 2 " +
+ "end",
+ "case bound_long " +
+ "when unbound_long then 4 " +
+ "end");
+
+ assertOptimizedEquals("case bound_long " +
+ "when unbound_long then 2 + 2 " +
+ "when 1 then null " +
+ "when 2 then null " +
+ "end",
+ "case bound_long " +
+ "when unbound_long then 4 " +
+ "end");
+
+ assertOptimizedMatches("case 1 " +
+ "when 0 / 0 then 1 " +
+ "when 0 / 0 then 2 " +
+ "else 1 " +
+ "end",
+ "case 1 " +
+ "when cast(fail(8, 'ignored failure message') as integer) then 1 " +
+ "when cast(fail(8, 'ignored failure message') as integer) then 2 " +
+ "else 1 " +
+ "end");
+
+ assertOptimizedEquals("case true " +
+ "when false then 2.2 " +
+ "when true then 2.2 " +
+ "end",
+ "2.2");
+
+ // TODO enabled when DECIMAL is default for literal:
+// assertOptimizedEquals("case true " +
+// "when false then 1234567890.0987654321 " +
+// "when true then 3.3 " +
+// "end",
+// "CAST(3.3 AS DECIMAL(20,10))");
+
+ assertOptimizedEquals("case true " +
+ "when false then 1 " +
+ "when true then 2.2 " +
+ "end",
+ "2.2");
+
+ assertOptimizedEquals("case ARRAY[CAST(1 AS BIGINT)] when ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'matched'");
+ assertOptimizedEquals("case ARRAY[CAST(2 AS BIGINT)] when ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'");
+ assertOptimizedEquals("case ARRAY[CAST(null AS BIGINT)] when ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'");
+ }
+
+ @Test
+ public void testCoalesce()
+ {
+ assertOptimizedEquals("coalesce(null, null)", "coalesce(null, null)");
+ assertOptimizedEquals("coalesce(2 * 3 * unbound_long, 1 - 1, null)", "coalesce(6 * unbound_long, 0)");
+ assertOptimizedEquals("coalesce(2 * 3 * unbound_long, 1.0E0/2.0E0, null)", "coalesce(6 * unbound_long, 0.5E0)");
+ assertOptimizedEquals("coalesce(unbound_long, 2, 1.0E0/2.0E0, 12.34E0, null)", "coalesce(unbound_long, 2.0E0, 0.5E0, 12.34E0)");
+ assertOptimizedEquals("coalesce(2 * 3 * unbound_integer, 1 - 1, null)", "coalesce(6 * unbound_integer, 0)");
+ assertOptimizedEquals("coalesce(2 * 3 * unbound_integer, 1.0E0/2.0E0, null)", "coalesce(6 * unbound_integer, 0.5E0)");
+ assertOptimizedEquals("coalesce(unbound_integer, 2, 1.0E0/2.0E0, 12.34E0, null)", "coalesce(unbound_integer, 2.0E0, 0.5E0, 12.34E0)");
+ assertOptimizedMatches("coalesce(0 / 0 > 1, unbound_boolean, 0 / 0 = 0)",
+ "coalesce(cast(fail(8, 'ignored failure message') as boolean), unbound_boolean)");
+ assertOptimizedMatches("coalesce(unbound_long, unbound_long)", "unbound_long");
+ assertOptimizedMatches("coalesce(2 * unbound_long, 2 * unbound_long)", "BIGINT '2' * unbound_long");
+ assertOptimizedMatches("coalesce(unbound_long, unbound_long2, unbound_long)", "coalesce(unbound_long, unbound_long2)");
+ assertOptimizedMatches("coalesce(unbound_long, unbound_long2, unbound_long, unbound_long3)", "coalesce(unbound_long, unbound_long2, unbound_long3)");
+ assertOptimizedEquals("coalesce(6, unbound_long2, unbound_long, unbound_long3)", "6");
+ assertOptimizedEquals("coalesce(2 * 3, unbound_long2, unbound_long, unbound_long3)", "6");
+ assertOptimizedMatches("coalesce(unbound_long, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '1')");
+ assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long, 1)), unbound_long2)", "coalesce(unbound_long, BIGINT '1')");
+ assertOptimizedMatches("coalesce(unbound_long, 2, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '2')");
+ assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long2, unbound_long3)), 1)", "coalesce(unbound_long, unbound_long2, unbound_long3, BIGINT '1')");
+ assertOptimizedMatches("coalesce(unbound_double, coalesce(random(), unbound_double))", "coalesce(unbound_double, random())");
+ assertOptimizedMatches("coalesce(random(), random(), 5)", "coalesce(random(), random(), 5E0)");
+ assertOptimizedMatches("coalesce(unbound_long, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '1')");
+ assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long, 1)), unbound_long2)", "coalesce(unbound_long, BIGINT '1')");
+ assertOptimizedMatches("coalesce(unbound_long, 2, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '2')");
+ assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long2, unbound_long3)), 1)", "coalesce(unbound_long, unbound_long2, unbound_long3, BIGINT '1')");
+ assertOptimizedMatches("coalesce(unbound_double, coalesce(random(), unbound_double))", "coalesce(unbound_double, random())");
+ }
+
+ @Test
+ public void testIf()
+ {
+ assertOptimizedEquals("IF(2 = 2, 3, 4)", "3");
+ assertOptimizedEquals("IF(1 = 2, 3, 4)", "4");
+ assertOptimizedEquals("IF(1 = 2, BIGINT '3', 4)", "4");
+ assertOptimizedEquals("IF(1 = 2, 3000000000, 4)", "4");
+
+ assertOptimizedEquals("IF(true, 3, 4)", "3");
+ assertOptimizedEquals("IF(false, 3, 4)", "4");
+ assertOptimizedEquals("IF(null, 3, 4)", "4");
+
+ assertOptimizedEquals("IF(true, 3, null)", "3");
+ assertOptimizedEquals("IF(false, 3, null)", "null");
+ assertOptimizedEquals("IF(true, null, 4)", "null");
+ assertOptimizedEquals("IF(false, null, 4)", "4");
+ assertOptimizedEquals("IF(true, null, null)", "null");
+ assertOptimizedEquals("IF(false, null, null)", "null");
+
+ assertOptimizedEquals("IF(true, 3.5E0, 4.2E0)", "3.5E0");
+ assertOptimizedEquals("IF(false, 3.5E0, 4.2E0)", "4.2E0");
+
+ assertOptimizedEquals("IF(true, 'foo', 'bar')", "'foo'");
+ assertOptimizedEquals("IF(false, 'foo', 'bar')", "'bar'");
+
+ assertOptimizedEquals("IF(true, 1.01, 1.02)", "1.01");
+ assertOptimizedEquals("IF(false, 1.01, 1.02)", "1.02");
+ assertOptimizedEquals("IF(true, 1234567890.123, 1.02)", "1234567890.123");
+ assertOptimizedEquals("IF(false, 1.01, 1234567890.123)", "1234567890.123");
+ }
+
+ // TODO: Pending on native function namespace manager.
+ @Test(enabled = false)
+ public void testLike()
+ {
+ assertOptimizedEquals("'a' LIKE 'a'", "true");
+ assertOptimizedEquals("'' LIKE 'a'", "false");
+ assertOptimizedEquals("'abc' LIKE 'a'", "false");
+
+ assertOptimizedEquals("'a' LIKE '_'", "true");
+ assertOptimizedEquals("'' LIKE '_'", "false");
+ assertOptimizedEquals("'abc' LIKE '_'", "false");
+
+ assertOptimizedEquals("'a' LIKE '%'", "true");
+ assertOptimizedEquals("'' LIKE '%'", "true");
+ assertOptimizedEquals("'abc' LIKE '%'", "true");
+
+ assertOptimizedEquals("'abc' LIKE '___'", "true");
+ assertOptimizedEquals("'ab' LIKE '___'", "false");
+ assertOptimizedEquals("'abcd' LIKE '___'", "false");
+
+ assertOptimizedEquals("'abc' LIKE 'abc'", "true");
+ assertOptimizedEquals("'xyz' LIKE 'abc'", "false");
+ assertOptimizedEquals("'abc0' LIKE 'abc'", "false");
+ assertOptimizedEquals("'0abc' LIKE 'abc'", "false");
+
+ assertOptimizedEquals("'abc' LIKE 'abc%'", "true");
+ assertOptimizedEquals("'abc0' LIKE 'abc%'", "true");
+ assertOptimizedEquals("'0abc' LIKE 'abc%'", "false");
+
+ assertOptimizedEquals("'abc' LIKE '%abc'", "true");
+ assertOptimizedEquals("'0abc' LIKE '%abc'", "true");
+ assertOptimizedEquals("'abc0' LIKE '%abc'", "false");
+
+ assertOptimizedEquals("'abc' LIKE '%abc%'", "true");
+ assertOptimizedEquals("'0abc' LIKE '%abc%'", "true");
+ assertOptimizedEquals("'abc0' LIKE '%abc%'", "true");
+ assertOptimizedEquals("'0abc0' LIKE '%abc%'", "true");
+ assertOptimizedEquals("'xyzw' LIKE '%abc%'", "false");
+
+ assertOptimizedEquals("'abc' LIKE '%ab%c%'", "true");
+ assertOptimizedEquals("'0abc' LIKE '%ab%c%'", "true");
+ assertOptimizedEquals("'abc0' LIKE '%ab%c%'", "true");
+ assertOptimizedEquals("'0abc0' LIKE '%ab%c%'", "true");
+ assertOptimizedEquals("'ab01c' LIKE '%ab%c%'", "true");
+ assertOptimizedEquals("'0ab01c' LIKE '%ab%c%'", "true");
+ assertOptimizedEquals("'ab01c0' LIKE '%ab%c%'", "true");
+ assertOptimizedEquals("'0ab01c0' LIKE '%ab%c%'", "true");
+
+ assertOptimizedEquals("'xyzw' LIKE '%ab%c%'", "false");
+
+ // ensure regex chars are escaped
+ assertOptimizedEquals("'' LIKE ''", "true");
+ assertOptimizedEquals("'.*' LIKE '.*'", "true");
+ assertOptimizedEquals("'[' LIKE '['", "true");
+ assertOptimizedEquals("']' LIKE ']'", "true");
+ assertOptimizedEquals("'{' LIKE '{'", "true");
+ assertOptimizedEquals("'}' LIKE '}'", "true");
+ assertOptimizedEquals("'?' LIKE '?'", "true");
+ assertOptimizedEquals("'+' LIKE '+'", "true");
+ assertOptimizedEquals("'(' LIKE '('", "true");
+ assertOptimizedEquals("')' LIKE ')'", "true");
+ assertOptimizedEquals("'|' LIKE '|'", "true");
+ assertOptimizedEquals("'^' LIKE '^'", "true");
+ assertOptimizedEquals("'$' LIKE '$'", "true");
+
+ assertOptimizedEquals("null LIKE '%'", "null");
+ assertOptimizedEquals("'a' LIKE null", "null");
+ assertOptimizedEquals("'a' LIKE '%' ESCAPE null", "null");
+ assertOptimizedEquals("'a' LIKE unbound_string ESCAPE null", "null");
+
+ assertOptimizedEquals("'%' LIKE 'z%' ESCAPE 'z'", "true");
+
+ assertRowExpressionEquals(SERIALIZABLE, "'%' LIKE 'z%' ESCAPE 'z'", "true");
+ assertRowExpressionEquals(SERIALIZABLE, "'%' LIKE 'z%'", "false");
+ }
+
+ // TODO: Pending on native function namespace manager.
+ @Test
+ public void testLikeOptimization()
+ {
+ assertOptimizedEquals("unbound_string LIKE 'abc'", "unbound_string = CAST('abc' AS VARCHAR)");
+
+ assertOptimizedEquals("unbound_string LIKE '' ESCAPE '#'", "unbound_string LIKE '' ESCAPE '#'");
+ assertOptimizedEquals("unbound_string LIKE 'abc' ESCAPE '#'", "unbound_string = CAST('abc' AS VARCHAR)");
+ assertOptimizedEquals("unbound_string LIKE 'a#_b' ESCAPE '#'", "unbound_string = CAST('a_b' AS VARCHAR)");
+ assertOptimizedEquals("unbound_string LIKE 'a#%b' ESCAPE '#'", "unbound_string = CAST('a%b' AS VARCHAR)");
+ assertOptimizedEquals("unbound_string LIKE 'a#_##b' ESCAPE '#'", "unbound_string = CAST('a_#b' AS VARCHAR)");
+ assertOptimizedMatches("unbound_string LIKE 'a#__b' ESCAPE '#'", "unbound_string LIKE 'a#__b' ESCAPE '#'");
+ assertOptimizedMatches("unbound_string LIKE 'a##%b' ESCAPE '#'", "unbound_string LIKE 'a##%b' ESCAPE '#'");
+
+ assertOptimizedEquals("bound_string LIKE bound_pattern", "true");
+ assertOptimizedEquals("'abc' LIKE bound_pattern", "false");
+
+ assertDoNotOptimize("unbound_string LIKE 'abc%'", SERIALIZABLE);
+
+ assertOptimizedMatches("unbound_string LIKE unbound_pattern ESCAPE unbound_string", "unbound_string LIKE unbound_pattern ESCAPE unbound_string");
+ }
+
+ // TODO: Pending on native function namespace manager.
+ @Test(enabled = false)
+ public void testInvalidLike()
+ {
+ assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE 'abc' ESCAPE ''"));
+ assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE 'abc' ESCAPE 'bc'"));
+ assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE '#' ESCAPE '#'"));
+ assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE '#abc' ESCAPE '#'"));
+ assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE 'ab#' ESCAPE '#'"));
+ }
+
+ @Test
+ public void testLambda()
+ {
+ assertDoNotOptimize("transform(unbound_array, x -> x + x)", OPTIMIZED);
+ assertOptimizedEquals("transform(ARRAY[1, 5], x -> x + x)", "transform(ARRAY[1, 5], x -> x + x)");
+ assertOptimizedEquals("transform(sequence(1, 5), x -> x + x)", "transform(sequence(1, 5), x -> x + x)");
+ }
+
+ @Test
+ public void testFailedExpressionOptimization()
+ {
+ assertOptimizedMatches("CASE unbound_long WHEN 1 THEN 1 WHEN 0 / 0 THEN 2 END",
+ "CASE unbound_long WHEN BIGINT '1' THEN 1 WHEN cast(fail(8, 'ignored failure message') as bigint) THEN 2 END");
+
+ assertOptimizedMatches("CASE unbound_boolean WHEN true THEN 1 ELSE 0 / 0 END",
+ "CASE unbound_boolean WHEN true THEN 1 ELSE cast(fail(8, 'ignored failure message') as integer) END");
+
+ assertOptimizedMatches("CASE bound_long WHEN unbound_long THEN 1 WHEN 0 / 0 THEN 2 ELSE 1 END",
+ "CASE BIGINT '1234' WHEN unbound_long THEN 1 WHEN cast(fail(8, 'ignored failure message') as bigint) THEN 2 ELSE 1 END");
+
+ assertOptimizedMatches("case when unbound_boolean then 1 when 0 / 0 = 0 then 2 end",
+ "case when unbound_boolean then 1 when cast(fail(8, 'ignored failure message') as boolean) then 2 end");
+
+ assertOptimizedMatches("case when unbound_boolean then 1 else 0 / 0 end",
+ "case when unbound_boolean then 1 else cast(fail(8, 'ignored failure message') as integer) end");
+
+ assertOptimizedMatches("case when unbound_boolean then 0 / 0 else 1 end",
+ "case when unbound_boolean then cast(fail(8, 'ignored failure message') as integer) else 1 end");
+ }
+
+ @Test(expectedExceptions = PrestoException.class)
+ public void testOptimizeDivideByZero()
+ {
+ optimize("0 / 0");
+ }
+
+ @Test
+ public void testMassiveArray()
+ {
+ assertDoNotOptimize("SEQUENCE(1, 999)", SERIALIZABLE);
+ assertDoNotOptimize("SEQUENCE(1, 1000)", SERIALIZABLE);
+ optimize(format("ARRAY [%s]", Joiner.on(", ").join(IntStream.range(0, 10_000).mapToObj(i -> "(bound_long + " + i + ")").iterator())));
+ optimize(format("ARRAY [%s]", Joiner.on(", ").join(IntStream.range(0, 10_000).mapToObj(i -> "(bound_integer + " + i + ")").iterator())));
+ optimize(format("ARRAY [%s]", Joiner.on(", ").join(IntStream.range(0, 10_000).mapToObj(i -> "'" + i + "'").iterator())));
+ optimize(format("ARRAY [%s]", Joiner.on(", ").join(IntStream.range(0, 10_000).mapToObj(i -> "ARRAY['" + i + "']").iterator())));
+ }
+
+ @Test
+ public void testArrayConstructor()
+ {
+ optimize("ARRAY []");
+ assertOptimizedEquals("ARRAY [(unbound_long + 0), (unbound_long + 1), (unbound_long + 2)]",
+ "array_constructor((unbound_long + 0), (unbound_long + 1), (unbound_long + 2))");
+ assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), (bound_long + 2)]",
+ "array_constructor((bound_long + 0), (unbound_long + 1), (bound_long + 2))");
+ assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), NULL]",
+ "array_constructor((bound_long + 0), (unbound_long + 1), NULL)");
+ }
+
+ @Test
+ public void testRowConstructor()
+ {
+ optimize("ROW(NULL)");
+ optimize("ROW(1)");
+ optimize("ROW(unbound_long + 0)");
+ optimize("ROW(unbound_long + unbound_long2, unbound_string, unbound_double)");
+ optimize("ROW(unbound_boolean, FALSE, ARRAY[unbound_long, unbound_long2], unbound_null_string, unbound_interval)");
+ optimize("ARRAY [ROW(unbound_string, unbound_double), ROW(unbound_string, 0.0E0)]");
+ optimize("ARRAY [ROW('string', unbound_double), ROW('string', bound_double)]");
+ optimize("ROW(ROW(NULL), ROW(ROW(ROW(ROW('rowception')))))");
+ optimize("ROW(unbound_string, bound_string)");
+
+ optimize("ARRAY [ROW(unbound_string, unbound_double), ROW(CAST(bound_string AS VARCHAR), 0.0E0)]");
+ optimize("ARRAY [ROW(CAST(bound_string AS VARCHAR), 0.0E0), ROW(unbound_string, unbound_double)]");
+
+ optimize("ARRAY [ROW(unbound_string, unbound_double), CAST(NULL AS ROW(VARCHAR, DOUBLE))]");
+ optimize("ARRAY [CAST(NULL AS ROW(VARCHAR, DOUBLE)), ROW(unbound_string, unbound_double)]");
+ }
+
+ @Test
+ public void testDereference()
+ {
+ optimize("ARRAY []");
+ assertOptimizedEquals("ARRAY [(unbound_long + 0), (unbound_long + 1), (unbound_long + 2)]",
+ "array_constructor((unbound_long + 0), (unbound_long + 1), (unbound_long + 2))");
+ assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), (bound_long + 2)]",
+ "array_constructor((bound_long + 0), (unbound_long + 1), (bound_long + 2))");
+ assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), NULL]",
+ "array_constructor((bound_long + 0), (unbound_long + 1), NULL)");
+ }
+
+ @Test
+ public void testRowDereference()
+ {
+ optimize("CAST(null AS ROW(a VARCHAR, b BIGINT)).a");
+ }
+
+ @Test
+ public void testRowSubscript()
+ {
+ assertOptimizedEquals("ROW (1, 'a', true)[3]", "true");
+ assertOptimizedEquals("ROW (1, 'a', ROW (2, 'b', ROW (3, 'c')))[3][3][2]", "'c'");
+ }
+
+ @Test(expectedExceptions = PrestoException.class)
+ public void testArraySubscriptConstantNegativeIndex()
+ {
+ optimize("ARRAY [1, 2, 3][-1]");
+ }
+
+ @Test(expectedExceptions = PrestoException.class)
+ public void testArraySubscriptConstantZeroIndex()
+ {
+ optimize("ARRAY [1, 2, 3][0]");
+ }
+
+ @Test
+ public void testMapSubscriptConstantIndexes()
+ {
+ optimize("MAP(ARRAY [1, 2], ARRAY [3, 4])[1]");
+ optimize("MAP(ARRAY [BIGINT '1', 2], ARRAY [3, 4])[1]");
+ optimize("MAP(ARRAY [1, 2], ARRAY [3, 4])[2]");
+ optimize("MAP(ARRAY [ARRAY[1,1]], ARRAY['a'])[ARRAY[1,1]]");
+ }
+
+ @Test
+ public void testLiterals()
+ {
+ optimize("date '2013-04-03' + unbound_interval");
+ optimize("timestamp '2013-04-03 03:04:05.321' + unbound_interval");
+ optimize("timestamp '2013-04-03 03:04:05.321 UTC' + unbound_interval");
+
+ optimize("interval '3' day * unbound_long");
+ // TODO: Pending on velox PR: https://github.com/facebookincubator/velox/pull/11612.
+// optimize("interval '3' year * unbound_integer");
+ }
+
+ @Test
+ public void assertLikeOptimizations()
+ {
+ assertOptimizedMatches("unbound_string LIKE bound_pattern", "unbound_string LIKE CAST('%el%' AS varchar)");
+ }
+
+ private RowExpression evaluate(String expression, boolean deterministic)
+ {
+ assertRoundTrip(expression);
+ RowExpression rowExpression = sqlToRowExpression(expression);
+ return optimizeRowExpression(rowExpression, EVALUATED);
+ }
+
+ private RowExpression optimize(@Language("SQL") String expression)
+ {
+ assertRoundTrip(expression);
+ RowExpression parsedExpression = sqlToRowExpression(expression);
+ return optimizeRowExpression(parsedExpression, OPTIMIZED);
+ }
+
+ private void assertOptimizedEquals(@Language("SQL") String actual, @Language("SQL") String expected)
+ {
+ RowExpression optimizedActual = optimize(actual);
+ RowExpression optimizedExpected = optimize(expected);
+ assertRowExpressionEvaluationEquals(optimizedActual, optimizedExpected);
+ }
+
+ private void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected)
+ {
+ RowExpression actualOptimized = optimize(actual);
+ RowExpression expectedOptimized = optimize(expected);
+ assertRowExpressionEvaluationEquals(
+ actualOptimized,
+ expectedOptimized);
+ }
+
+ private void assertDoNotOptimize(@Language("SQL") String expression, ExpressionOptimizer.Level optimizationLevel)
+ {
+ assertRoundTrip(expression);
+ RowExpression rowExpression = sqlToRowExpression(expression);
+ RowExpression rowExpressionResult = optimizeRowExpression(rowExpression, optimizationLevel);
+ assertRowExpressionEvaluationEquals(rowExpressionResult, rowExpression);
+ }
+
+ private RowExpression sqlToRowExpression(String expression)
+ {
+ Expression parsedExpression = FunctionAssertions.createExpression(expression, METADATA, SYMBOL_TYPES);
+ return TRANSLATOR.translate(parsedExpression, SYMBOL_TYPES);
+ }
+
+ private Object symbolConstant(Symbol symbol)
+ {
+ switch (symbol.getName().toLowerCase(ENGLISH)) {
+ case "bound_integer":
+ case "bound_long":
+ return 1234L;
+ case "bound_string":
+ return utf8Slice("hello");
+ case "bound_double":
+ return 12.34;
+ case "bound_date":
+ return new LocalDate(2001, 8, 22).toDateMidnight(DateTimeZone.UTC).getMillis();
+ case "bound_time":
+ return new LocalTime(3, 4, 5, 321).toDateTime(new DateTime(0, DateTimeZone.UTC)).getMillis();
+ case "bound_timestamp":
+ return new DateTime(2001, 8, 22, 3, 4, 5, 321, DateTimeZone.UTC).getMillis();
+ case "bound_pattern":
+ return utf8Slice("%el%");
+ case "bound_timestamp_with_timezone":
+ return new SqlTimestampWithTimeZone(new DateTime(1970, 1, 1, 1, 0, 0, 999, DateTimeZone.UTC).getMillis(), getTimeZoneKey("Z"));
+ case "bound_varbinary":
+ return Slices.wrappedBuffer((byte) 0xab);
+ case "bound_decimal_short":
+ return 12345L;
+ case "bound_decimal_long":
+ return Decimals.encodeUnscaledValue(new BigInteger("12345678901234567890123"));
+ }
+ return null;
+ }
+
+ /**
+ * Assert the evaluation result of two row expressions equivalent
+ * no matter they are constants or remaining row expressions.
+ */
+ private void assertRowExpressionEvaluationEquals(RowExpression left, RowExpression right)
+ {
+ if (left instanceof ConstantExpression) {
+ if (isRemovableCast(right)) {
+ assertRowExpressionEvaluationEquals(left, ((CallExpression) right).getArguments().get(0));
+ return;
+ }
+ assertTrue(right instanceof ConstantExpression);
+ assertConstantsEqual(((ConstantExpression) left), ((ConstantExpression) left));
+ }
+ else if (left instanceof InputReferenceExpression || left instanceof VariableReferenceExpression) {
+ assertEquals(left, right);
+ }
+ else if (left instanceof CallExpression && ((CallExpression) left).getFunctionHandle().getName().contains("fail")) {
+ assertTrue(right instanceof CallExpression && ((CallExpression) right).getFunctionHandle().getName().contains("fail"));
+ assertEquals(((CallExpression) left).getArguments().size(), ((CallExpression) right).getArguments().size());
+ for (int i = 0; i < ((CallExpression) left).getArguments().size(); i++) {
+ assertRowExpressionEvaluationEquals(((CallExpression) left).getArguments().get(i), ((CallExpression) right).getArguments().get(i));
+ }
+ }
+ else if (left instanceof CallExpression) {
+ assertTrue(right instanceof CallExpression);
+ assertEquals(((CallExpression) left).getFunctionHandle(), ((CallExpression) right).getFunctionHandle());
+ assertEquals(((CallExpression) left).getArguments().size(), ((CallExpression) right).getArguments().size());
+ for (int i = 0; i < ((CallExpression) left).getArguments().size(); i++) {
+ assertRowExpressionEvaluationEquals(((CallExpression) left).getArguments().get(i), ((CallExpression) right).getArguments().get(i));
+ }
+ }
+ else if (left instanceof SpecialFormExpression) {
+ assertTrue(right instanceof SpecialFormExpression);
+ assertEquals(((SpecialFormExpression) left).getForm(), ((SpecialFormExpression) right).getForm());
+ assertEquals(((SpecialFormExpression) left).getArguments().size(), ((SpecialFormExpression) right).getArguments().size());
+ for (int i = 0; i < ((SpecialFormExpression) left).getArguments().size(); i++) {
+ assertRowExpressionEvaluationEquals(((SpecialFormExpression) left).getArguments().get(i), ((SpecialFormExpression) right).getArguments().get(i));
+ }
+ }
+ else {
+ assertTrue(left instanceof LambdaDefinitionExpression);
+ assertTrue(right instanceof LambdaDefinitionExpression);
+ assertEquals(((LambdaDefinitionExpression) left).getArguments(), ((LambdaDefinitionExpression) right).getArguments());
+ assertEquals(((LambdaDefinitionExpression) left).getArgumentTypes(), ((LambdaDefinitionExpression) right).getArgumentTypes());
+ assertRowExpressionEvaluationEquals(((LambdaDefinitionExpression) left).getBody(), ((LambdaDefinitionExpression) right).getBody());
+ }
+ }
+
+ private void assertConstantsEqual(ConstantExpression left, ConstantExpression right)
+ {
+ if (left.getValue() instanceof Block) {
+ assertTrue(right.getValue() instanceof Block);
+ assertBlockEquals((Block) left.getValue(), (Block) right.getValue());
+ }
+ else {
+ assertEquals(left.getValue(), right.getValue());
+ }
+ }
+
+ private static void assertBlockEquals(Block left, Block right)
+ {
+ SliceOutput sliceOutput = new DynamicSliceOutput(1000);
+ BlockSerdeUtil.writeBlock(BLOCK_ENCODING_SERDE, sliceOutput, right);
+ SliceOutput sliceOutput1 = new DynamicSliceOutput(1000);
+ BlockSerdeUtil.writeBlock(BLOCK_ENCODING_SERDE, sliceOutput1, left);
+ assertEquals(sliceOutput1.slice(), sliceOutput.slice());
+ }
+
+ private boolean isRemovableCast(RowExpression value)
+ {
+ if (value instanceof CallExpression &&
+ new FunctionResolution(METADATA.getFunctionAndTypeManager().getFunctionAndTypeResolver()).isCastFunction(((CallExpression) value).getFunctionHandle())) {
+ Type targetType = value.getType();
+ Type sourceType = ((CallExpression) value).getArguments().get(0).getType();
+ return METADATA.getFunctionAndTypeManager().canCoerce(sourceType, targetType);
+ }
+ return false;
+ }
+
+ private void assertEvaluatedEquals(@Language("SQL") String actual, @Language("SQL") String expected)
+ {
+ assertRowExpressionEvaluationEquals(evaluate(actual, true), evaluate(expected, true));
+ }
+
+ private void assertRoundTrip(String expression)
+ {
+ ParsingOptions parsingOptions = createParsingOptions(TEST_SESSION);
+ assertEquals(SQL_PARSER.createExpression(expression, parsingOptions),
+ SQL_PARSER.createExpression(formatExpression(SQL_PARSER.createExpression(expression, parsingOptions), Optional.empty()), parsingOptions));
+ }
+
+ private void assertRowExpressionEquals(ExpressionOptimizer.Level level, @Language("SQL") String actual, @Language("SQL") String expected)
+ {
+ Object actualResult = optimizeRowExpression(sqlToRowExpression(actual), level);
+ Object expectedResult = optimizeRowExpression(sqlToRowExpression(expected), level);
+ if (actualResult instanceof Block && expectedResult instanceof Block) {
+ assertBlockEquals((Block) actualResult, (Block) expectedResult);
+ return;
+ }
+ assertEquals(actualResult, expectedResult);
+ }
+
+ private RowExpression optimizeRowExpression(RowExpression expression, ExpressionOptimizer.Level level)
+ {
+ return expressionOptimizer.optimize(
+ expression,
+ level,
+ TEST_SESSION.toConnectorSession(),
+ variable -> {
+ Symbol symbol = new Symbol(variable.getName());
+ Object value = symbolConstant(symbol);
+ if (value == null) {
+ return new VariableReferenceExpression(Optional.empty(), symbol.getName(), SYMBOL_TYPES.get(symbol.toSymbolReference()));
+ }
+ return value;
+ });
+ }
+}
diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressions.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressions.java
new file mode 100644
index 0000000000000..d9ecbf8d8a8eb
--- /dev/null
+++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressions.java
@@ -0,0 +1,177 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.session.sql.expressions;
+
+import com.facebook.airlift.bootstrap.Bootstrap;
+import com.facebook.airlift.jaxrs.JsonMapper;
+import com.facebook.airlift.json.JsonModule;
+import com.facebook.presto.block.BlockJsonSerde;
+import com.facebook.presto.client.NodeVersion;
+import com.facebook.presto.common.block.Block;
+import com.facebook.presto.common.block.BlockEncoding;
+import com.facebook.presto.common.block.BlockEncodingManager;
+import com.facebook.presto.common.block.BlockEncodingSerde;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.common.type.TypeManager;
+import com.facebook.presto.metadata.FunctionAndTypeManager;
+import com.facebook.presto.metadata.HandleJsonModule;
+import com.facebook.presto.metadata.HandleResolver;
+import com.facebook.presto.metadata.InMemoryNodeManager;
+import com.facebook.presto.metadata.InternalNode;
+import com.facebook.presto.metadata.InternalNodeManager;
+import com.facebook.presto.metadata.Metadata;
+import com.facebook.presto.metadata.MetadataManager;
+import com.facebook.presto.nodeManager.PluginNodeManager;
+import com.facebook.presto.spi.ConnectorId;
+import com.facebook.presto.spi.NodeManager;
+import com.facebook.presto.spi.RowExpressionSerde;
+import com.facebook.presto.spi.relation.ExpressionOptimizer;
+import com.facebook.presto.spi.relation.RowExpression;
+import com.facebook.presto.sql.TestingRowExpressionTranslator;
+import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde;
+import com.facebook.presto.sql.relational.FunctionResolution;
+import com.facebook.presto.type.TypeDeserializer;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.inject.Injector;
+import com.google.inject.Module;
+import com.google.inject.Scopes;
+import org.testng.annotations.Test;
+
+import java.net.URI;
+
+import static com.facebook.airlift.json.JsonBinder.jsonBinder;
+import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder;
+import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
+import static com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils.findRandomPortForWorker;
+import static com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils.getNativeSidecarProcess;
+import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED;
+import static com.facebook.presto.sql.planner.LiteralEncoder.toRowExpression;
+import static com.google.inject.multibindings.Multibinder.newSetBinder;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+import static org.testng.Assert.assertEquals;
+
+public class TestNativeExpressions
+{
+ private static final Metadata METADATA = MetadataManager.createTestMetadataManager();
+ private static final TestingRowExpressionTranslator TRANSLATOR = new TestingRowExpressionTranslator(METADATA);
+ private Process sidecar;
+
+ @Test
+ public void testLoadPlugin()
+ throws Exception
+ {
+ try {
+ int port = findRandomPortForWorker();
+ URI sidecarUri = URI.create(format("http://127.0.0.1:%s", port));
+ sidecar = getNativeSidecarProcess(sidecarUri, port);
+ ExpressionOptimizer interpreterService = getExpressionOptimizer(METADATA, null, sidecarUri);
+
+ // Test the native row expression interpreter service with some simple expressions
+ RowExpression simpleAddition = compileExpression("1+1");
+ RowExpression unnecessaryCoalesce = compileExpression("coalesce(1, 2)");
+
+ // Assert simple optimizations are performed
+ assertEquals(interpreterService.optimize(simpleAddition, OPTIMIZED, TEST_SESSION.toConnectorSession()), toRowExpression(2L, simpleAddition.getType()));
+ assertEquals(interpreterService.optimize(unnecessaryCoalesce, OPTIMIZED, TEST_SESSION.toConnectorSession()), toRowExpression(1L, unnecessaryCoalesce.getType()));
+ }
+ finally {
+ if (sidecar != null) {
+ sidecar.destroy();
+ }
+ }
+ }
+
+ private static RowExpression compileExpression(String expression)
+ {
+ return TRANSLATOR.translate(expression, ImmutableMap.of());
+ }
+
+ protected static ExpressionOptimizer getExpressionOptimizer(Metadata metadata, HandleResolver handleResolver, URI sidecarUri)
+ {
+ // Set up dependencies in main for this module
+ InMemoryNodeManager nodeManager = getNodeManagerWithSidecar(sidecarUri);
+ Injector prestoMainInjector = getPrestoMainInjector(metadata, handleResolver);
+ RowExpressionSerde rowExpressionSerde = prestoMainInjector.getInstance(RowExpressionSerde.class);
+ FunctionAndTypeManager functionMetadataManager = prestoMainInjector.getInstance(FunctionAndTypeManager.class);
+
+ // Create the native row expression interpreter service
+ return createExpressionOptimizer(nodeManager, rowExpressionSerde, functionMetadataManager);
+ }
+
+ private static InMemoryNodeManager getNodeManagerWithSidecar(URI sidecarUri)
+ {
+ InMemoryNodeManager nodeManager = new InMemoryNodeManager();
+ nodeManager.addNode(new ConnectorId("test"), new InternalNode("test", sidecarUri, NodeVersion.UNKNOWN, false, false, false, true));
+ return nodeManager;
+ }
+
+ private static ExpressionOptimizer createExpressionOptimizer(InternalNodeManager internalNodeManager, RowExpressionSerde rowExpressionSerde, FunctionAndTypeManager functionMetadataManager)
+ {
+ requireNonNull(internalNodeManager, "inMemoryNodeManager is null");
+ NodeManager nodeManager = new PluginNodeManager(internalNodeManager);
+ FunctionResolution functionResolution = new FunctionResolution(functionMetadataManager.getFunctionAndTypeResolver());
+
+ Bootstrap app = new Bootstrap(
+ // Specially use a testing HTTP client instead of a real one
+ new NativeExpressionsCommunicationModule(),
+ // Otherwise use the exact same module as the native row expression interpreter service
+ new NativeExpressionsModule(nodeManager, rowExpressionSerde, functionMetadataManager, functionResolution));
+
+ Injector injector = app
+ .noStrictConfig()
+ .doNotInitializeLogging()
+ .setRequiredConfigurationProperties(ImmutableMap.of())
+ .quiet()
+ .initialize();
+ return injector.getInstance(NativeExpressionOptimizerProvider.class).createOptimizer();
+ }
+
+ private static Injector getPrestoMainInjector(Metadata metadata, HandleResolver handleResolver)
+ {
+ Module module = binder -> {
+ // Installs the JSON codec
+ binder.install(new JsonModule());
+ // Required to deserialize function handles
+ binder.install(new HandleJsonModule(handleResolver));
+ // Required for this test in the JaxrsTestingHttpProcessor because the underlying object mapper
+ // must be the same as all other object mappers
+ binder.bind(JsonMapper.class);
+
+ // These dependencies are needed to serialize and deserialize types (found in expressions)
+ FunctionAndTypeManager functionAndTypeManager = metadata.getFunctionAndTypeManager();
+ binder.bind(FunctionAndTypeManager.class).toInstance(functionAndTypeManager);
+ binder.bind(TypeManager.class).toInstance(functionAndTypeManager);
+ jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class);
+ newSetBinder(binder, Type.class);
+
+ // These dependencies are needed to serialize and deserialize blocks (found in constant values of expressions)
+ binder.bind(BlockEncodingSerde.class).to(BlockEncodingManager.class).in(Scopes.SINGLETON);
+ newSetBinder(binder, BlockEncoding.class);
+ jsonBinder(binder).addSerializerBinding(Block.class).to(BlockJsonSerde.Serializer.class);
+ jsonBinder(binder).addDeserializerBinding(Block.class).to(BlockJsonSerde.Deserializer.class);
+
+ // Create the serde which is used by the plugin to serialize and deserialize expressions
+ jsonCodecBinder(binder).bindJsonCodec(RowExpression.class);
+ binder.bind(RowExpressionSerde.class).to(JsonCodecRowExpressionSerde.class).in(Scopes.SINGLETON);
+ };
+ Bootstrap app = new Bootstrap(ImmutableList.of(module));
+ Injector injector = app
+ .doNotInitializeLogging()
+ .quiet()
+ .initialize();
+ return injector;
+ }
+}
diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java
index 994d824195600..e518c40bb966a 100644
--- a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java
+++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java
@@ -13,16 +13,21 @@
*/
package com.facebook.presto.sidecar;
+import com.facebook.presto.metadata.InMemoryNodeManager;
+import com.facebook.presto.nodeManager.PluginNodeManager;
import com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManagerFactory;
import com.facebook.presto.sidecar.sessionpropertyproviders.NativeSystemSessionPropertyProviderFactory;
import com.facebook.presto.testing.QueryRunner;
import com.google.common.collect.ImmutableMap;
+import java.io.IOException;
+
public class NativeSidecarPluginQueryRunnerUtils
{
private NativeSidecarPluginQueryRunnerUtils() {};
public static void setupNativeSidecarPlugin(QueryRunner queryRunner)
+ throws IOException
{
queryRunner.installCoordinatorPlugin(new NativeSidecarPlugin());
queryRunner.loadSessionPropertyProvider(NativeSystemSessionPropertyProviderFactory.NAME);
@@ -32,5 +37,7 @@ public static void setupNativeSidecarPlugin(QueryRunner queryRunner)
ImmutableMap.of(
"supported-function-languages", "CPP",
"function-implementation-type", "CPP"));
+ queryRunner.getExpressionManager().loadExpressionOptimizerFactory("native", "native", ImmutableMap.of());
+ queryRunner.getPlanCheckerProviderManager().loadPlanCheckerProviders(new PluginNodeManager(new InMemoryNodeManager()));
}
}
diff --git a/presto-openapi/src/main/resources/expressions.yaml b/presto-openapi/src/main/resources/expressions.yaml
new file mode 100644
index 0000000000000..8d1e3ef5d9379
--- /dev/null
+++ b/presto-openapi/src/main/resources/expressions.yaml
@@ -0,0 +1,173 @@
+openapi: 3.0.0
+info:
+ title: Presto Expression API
+ description: API for evaluating and simplifying row expressions in Presto
+ version: "1"
+servers:
+ - url: http://localhost:8080
+ description: Presto endpoint when running locally
+paths:
+ /v1/expressions:
+ post:
+ summary: Simplify the list of row expressions
+ description: This endpoint takes in a list of row expressions and attempts to simplify them to their simplest logical equivalent expression.
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/RowExpressions'
+ required: true
+ responses:
+ '200':
+ description: Results
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/RowExpressions'
+components:
+ schemas:
+ RowExpressions:
+ type: array
+ maxItems: 100
+ items:
+ $ref: "#/components/schemas/RowExpression"
+ RowExpression:
+ oneOf:
+ - $ref: "#/components/schemas/ConstantExpression"
+ - $ref: "#/components/schemas/VariableReferenceExpression"
+ - $ref: "#/components/schemas/InputReferenceExpression"
+ - $ref: "#/components/schemas/LambdaDefinitionExpression"
+ - $ref: "#/components/schemas/SpecialFormExpression"
+ - $ref: "#/components/schemas/CallExpression"
+ RowExpressionParent:
+ type: object
+ properties:
+ sourceLocation:
+ $ref: "#/components/schemas/SourceLocation"
+ SourceLocation:
+ description: The source location of the row expression in the original query, referencing the line and the column of the query.
+ type: object
+ properties:
+ line:
+ type: integer
+ column:
+ type: integer
+ ConstantExpression:
+ description: A constant expression is a row expression that represents a constant value. The value attribute is the constant value.
+ allOf:
+ - $ref: "#/components/schemas/RowExpressionParent"
+ - type: object
+ properties:
+ "@type":
+ type: string
+ enum : ["constant"]
+ typeSignature:
+ type: string
+ valueBlock:
+ type: string
+ VariableReferenceExpression:
+ description: A variable reference expression is a row expression that represents a reference to a variable. The name attribute indicates the name of the variable.
+ allOf:
+ - $ref: "#/components/schemas/RowExpressionParent"
+ - type: object
+ properties:
+ "@type":
+ type: string
+ enum : ["variable"]
+ typeSignature:
+ type: string
+ name:
+ type: string
+ InputReferenceExpression:
+ description: >
+ An input reference expression is a row expression that represents a reference to a column in the input schema. The field attribute indicates the index of the column in the
+ input schema.
+ allOf:
+ - $ref: "#/components/schemas/RowExpressionParent"
+ - type: object
+ properties:
+ "@type":
+ type: string
+ enum : ["input"]
+ typeSignature:
+ type: string
+ field:
+ type: integer
+ LambdaDefinitionExpression:
+ description: >
+ A lambda definition expression is a row expression that represents a lambda function. The lambda function is defined by a list of argument types, a list of argument names,
+ and a body expression.
+ allOf:
+ - $ref: "#/components/schemas/RowExpressionParent"
+ - type: object
+ properties:
+ "@type":
+ type: string
+ enum : ["lambda"]
+ argumentTypeSignatures:
+ type: array
+ items:
+ type: string
+ arguments:
+ type: array
+ items:
+ type: string
+ body:
+ $ref: "#/components/schemas/RowExpression"
+ SpecialFormExpression:
+ description: >
+ A special form expression is a row expression that represents a special language construct. The form attribute indicates the specific form of the special form,
+ which is a well known list, and with each having special semantics. The arguments attribute is a list of row expressions that are the arguments to the special form, with
+ each form taking in a specific number of arguments.
+ allOf:
+ - $ref: "#/components/schemas/RowExpressionParent"
+ - type: object
+ properties:
+ "@type":
+ type: string
+ enum : ["special"]
+ form:
+ type: string
+ enum: ["IF","NULL_IF","SWITCH","WHEN","IS_NULL","COALESCE","IN","AND","OR","DEREFERENCE","ROW_CONSTRUCTOR","BIND"]
+ returnTypeSignature:
+ type: string
+ arguments:
+ type: array
+ items:
+ $ref: "#/components/schemas/RowExpression"
+ CallExpression:
+ description: >
+ A call expression is a row expression that represents a call to a function. The functionHandle attribute is an opaque handle to the function that is being called.
+ The arguments attribute is a list of row expressions that are the arguments to the function.
+ allOf:
+ - $ref: "#/components/schemas/RowExpressionParent"
+ - type: object
+ properties:
+ "@type":
+ type: string
+ enum : ["call"]
+ displayName:
+ type: string
+ functionHandle:
+ $ref: "#/components/schemas/FunctionHandle"
+ returnTypeSignature:
+ type: string
+ arguments:
+ type: array
+ items:
+ $ref: "#/components/schemas/RowExpression"
+ FunctionHandle:
+ description: An opaque handle to a function that may be invoked. This is interpreted by the registered function namespace manager.
+ anyOf:
+ - $ref: "#/components/schemas/OpaqueFunctionHandle"
+ - $ref: "#/components/schemas/SqlFunctionHandle"
+ OpaqueFunctionHandle:
+ type: object
+ properties: {} # any opaque object may be passed and interpreted by a function namespace manager
+ SqlFunctionHandle:
+ type: object
+ properties:
+ functionId:
+ type: string
+ version:
+ type: string
diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java
index e9d03e52a461f..7b10fe37c78e7 100644
--- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java
+++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java
@@ -137,6 +137,7 @@
import com.facebook.presto.spi.NodeManager;
import com.facebook.presto.spi.PageIndexerFactory;
import com.facebook.presto.spi.PageSorter;
+import com.facebook.presto.spi.RowExpressionSerde;
import com.facebook.presto.spi.analyzer.ViewDefinition;
import com.facebook.presto.spi.memory.ClusterMemoryPoolManager;
import com.facebook.presto.spi.plan.SimplePlanFragment;
@@ -144,6 +145,7 @@
import com.facebook.presto.spi.relation.DeterminismEvaluator;
import com.facebook.presto.spi.relation.DomainTranslator;
import com.facebook.presto.spi.relation.PredicateCompiler;
+import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spiller.GenericPartitioningSpillerFactory;
import com.facebook.presto.spiller.GenericSpillerFactory;
@@ -177,6 +179,7 @@
import com.facebook.presto.sql.analyzer.QueryExplainer;
import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager;
import com.facebook.presto.sql.expressions.ExpressionOptimizerManager;
+import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde;
import com.facebook.presto.sql.gen.ExpressionCompiler;
import com.facebook.presto.sql.gen.JoinCompiler;
import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler;
@@ -306,6 +309,7 @@ protected void setup(Binder binder)
jsonCodecBinder(binder).bindJsonCodec(BroadcastFileInfo.class);
jsonCodecBinder(binder).bindJsonCodec(SimplePlanFragment.class);
binder.bind(SimplePlanFragmentSerde.class).to(JsonCodecSimplePlanFragmentSerde.class).in(Scopes.SINGLETON);
+ jsonCodecBinder(binder).bindJsonCodec(RowExpression.class);
// smile codecs
smileCodecBinder(binder).bindSmileCodec(TaskSource.class);
@@ -352,6 +356,7 @@ protected void setup(Binder binder)
// expression manager
binder.bind(ExpressionOptimizerManager.class).in(Scopes.SINGLETON);
+ binder.bind(RowExpressionSerde.class).to(JsonCodecRowExpressionSerde.class).in(Scopes.SINGLETON);
// tracer provider managers
binder.bind(TracerProviderManager.class).in(Scopes.SINGLETON);
diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/RowExpressionSerde.java b/presto-spi/src/main/java/com/facebook/presto/spi/RowExpressionSerde.java
new file mode 100644
index 0000000000000..ab5381aa2556c
--- /dev/null
+++ b/presto-spi/src/main/java/com/facebook/presto/spi/RowExpressionSerde.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.spi;
+
+import com.facebook.presto.spi.relation.RowExpression;
+
+public interface RowExpressionSerde
+{
+ String serialize(RowExpression expression);
+
+ RowExpression deserialize(String value);
+}
diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java b/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java
index c9a6f84aaaa1d..51b7cce149d8f 100644
--- a/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java
+++ b/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java
@@ -14,6 +14,7 @@
package com.facebook.presto.spi.sql.planner;
import com.facebook.presto.spi.NodeManager;
+import com.facebook.presto.spi.RowExpressionSerde;
import com.facebook.presto.spi.function.FunctionMetadataManager;
import com.facebook.presto.spi.function.StandardFunctionResolution;
@@ -22,12 +23,14 @@
public class ExpressionOptimizerContext
{
private final NodeManager nodeManager;
+ private final RowExpressionSerde rowExpressionSerde;
private final FunctionMetadataManager functionMetadataManager;
private final StandardFunctionResolution functionResolution;
- public ExpressionOptimizerContext(NodeManager nodeManager, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution functionResolution)
+ public ExpressionOptimizerContext(NodeManager nodeManager, RowExpressionSerde rowExpressionSerde, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution functionResolution)
{
this.nodeManager = requireNonNull(nodeManager, "nodeManager is null");
+ this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null");
this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null");
this.functionResolution = requireNonNull(functionResolution, "functionResolution is null");
}
@@ -37,6 +40,11 @@ public NodeManager getNodeManager()
return nodeManager;
}
+ public RowExpressionSerde getRowExpressionSerde()
+ {
+ return rowExpressionSerde;
+ }
+
public FunctionMetadataManager getFunctionMetadataManager()
{
return functionMetadataManager;
diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java
index bedd15a5d9e10..09b03cfb0409b 100644
--- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java
+++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java
@@ -26,10 +26,12 @@
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.nodeManager.PluginNodeManager;
import com.facebook.presto.spi.WarningCollector;
+import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.security.AccessDeniedException;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.analyzer.QueryExplainer;
import com.facebook.presto.sql.expressions.ExpressionOptimizerManager;
+import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.PartitioningProviderManager;
import com.facebook.presto.sql.planner.Plan;
@@ -61,6 +63,7 @@
import java.util.OptionalLong;
import java.util.function.Consumer;
+import static com.facebook.airlift.json.JsonCodec.jsonCodec;
import static com.facebook.airlift.testing.Closeables.closeAllRuntimeException;
import static com.facebook.presto.sql.SqlFormatter.formatSql;
import static com.facebook.presto.transaction.TransactionBuilder.transaction;
@@ -576,7 +579,8 @@ private QueryExplainer getQueryExplainer()
featuresConfig,
new ExpressionOptimizerManager(
new PluginNodeManager(new InMemoryNodeManager()),
- queryRunner.getMetadata().getFunctionAndTypeManager()))
+ queryRunner.getMetadata().getFunctionAndTypeManager(),
+ new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class))))
.getPlanningTimeOptimizers();
return new QueryExplainer(
optimizers,