diff --git a/pom.xml b/pom.xml index 56310a9d7604b..c2c5327d7f6ad 100644 --- a/pom.xml +++ b/pom.xml @@ -793,6 +793,13 @@ ${project.version} + + com.facebook.presto + presto-tests + ${project.version} + test-jar + + com.facebook.presto presto-benchmark @@ -1024,6 +1031,13 @@ ${project.version} + + com.facebook.presto + presto-native-execution + ${project.version} + test-jar + + com.facebook.hive hive-dwrf diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java b/presto-main/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java index 0c07b99aaab4e..cc2dee1bd61d4 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java @@ -23,6 +23,18 @@ public class HandleJsonModule implements Module { + private final HandleResolver handleResolver; + + public HandleJsonModule() + { + this(null); + } + + public HandleJsonModule(HandleResolver handleResolver) + { + this.handleResolver = handleResolver; + } + @Override public void configure(Binder binder) { @@ -38,6 +50,11 @@ public void configure(Binder binder) jsonBinder(binder).addModuleBinding().to(FunctionHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(MetadataUpdateJacksonModule.class); - binder.bind(HandleResolver.class).in(Scopes.SINGLETON); + if (handleResolver == null) { + binder.bind(HandleResolver.class).in(Scopes.SINGLETON); + } + else { + binder.bind(HandleResolver.class).toInstance(handleResolver); + } } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java index 3071bd0e5d0f3..4e5075b812ba4 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java @@ -148,6 +148,7 @@ import com.facebook.presto.spi.NodeManager; import 
com.facebook.presto.spi.PageIndexerFactory; import com.facebook.presto.spi.PageSorter; +import com.facebook.presto.spi.RowExpressionSerde; import com.facebook.presto.spi.analyzer.ViewDefinition; import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.plan.SimplePlanFragment; @@ -155,6 +156,7 @@ import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.DomainTranslator; import com.facebook.presto.spi.relation.PredicateCompiler; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.session.WorkerSessionPropertyProvider; import com.facebook.presto.spiller.FileSingleStreamSpillerFactory; @@ -195,6 +197,7 @@ import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler; @@ -362,6 +365,7 @@ else if (serverConfig.isCoordinator()) { // expression manager binder.bind(ExpressionOptimizerManager.class).in(Scopes.SINGLETON); + binder.bind(RowExpressionSerde.class).to(JsonCodecRowExpressionSerde.class).in(Scopes.SINGLETON); // schema properties binder.bind(SchemaPropertyManager.class).in(Scopes.SINGLETON); @@ -555,6 +559,7 @@ public ListeningExecutorService createResourceManagerExecutor(ResourceManagerCon jsonCodecBinder(binder).bindJsonCodec(SqlInvokedFunction.class); jsonCodecBinder(binder).bindJsonCodec(TaskSource.class); jsonCodecBinder(binder).bindJsonCodec(TableWriteInfo.class); + jsonCodecBinder(binder).bindJsonCodec(RowExpression.class); smileCodecBinder(binder).bindSmileCodec(TaskStatus.class); 
smileCodecBinder(binder).bindSmileCodec(TaskInfo.class); thriftCodecBinder(binder).bindThriftCodec(TaskStatus.class); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java b/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java index 0d200e5555a92..2a8bd6e391ca8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java @@ -19,6 +19,7 @@ import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.RowExpressionSerde; import com.facebook.presto.spi.relation.ExpressionOptimizer; import com.facebook.presto.spi.relation.ExpressionOptimizerProvider; import com.facebook.presto.spi.sql.planner.ExpressionOptimizerContext; @@ -53,6 +54,7 @@ public class ExpressionOptimizerManager private final NodeManager nodeManager; private final FunctionAndTypeManager functionAndTypeManager; + private final RowExpressionSerde rowExpressionSerde; private final FunctionResolution functionResolution; private final File configurationDirectory; @@ -60,16 +62,17 @@ public class ExpressionOptimizerManager private final Map expressionOptimizers = new ConcurrentHashMap<>(); @Inject - public ExpressionOptimizerManager(PluginNodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager) + public ExpressionOptimizerManager(PluginNodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager, RowExpressionSerde rowExpressionSerde) { - this(nodeManager, functionAndTypeManager, EXPRESSION_MANAGER_CONFIGURATION_DIRECTORY); + this(nodeManager, functionAndTypeManager, rowExpressionSerde, EXPRESSION_MANAGER_CONFIGURATION_DIRECTORY); } - public ExpressionOptimizerManager(PluginNodeManager nodeManager, FunctionAndTypeManager 
functionAndTypeManager, File configurationDirectory) + public ExpressionOptimizerManager(PluginNodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager, RowExpressionSerde rowExpressionSerde, File configurationDirectory) { requireNonNull(nodeManager, "nodeManager is null"); this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null"); this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver()); this.configurationDirectory = requireNonNull(configurationDirectory, "configurationDirectory is null"); expressionOptimizers.put(DEFAULT_EXPRESSION_OPTIMIZER_NAME, new RowExpressionOptimizer(functionAndTypeManager)); @@ -89,7 +92,7 @@ public void loadExpressionOptimizerFactories() } } - private void loadExpressionOptimizerFactory(File configurationFile) + public void loadExpressionOptimizerFactory(File configurationFile) throws IOException { String name = getNameWithoutExtension(configurationFile.getName()); @@ -99,13 +102,19 @@ private void loadExpressionOptimizerFactory(File configurationFile) Map properties = new HashMap<>(loadProperties(configurationFile)); String factoryName = properties.remove(EXPRESSION_MANAGER_FACTORY_NAME); checkArgument(!isNullOrEmpty(factoryName), "%s does not contain %s", configurationFile, EXPRESSION_MANAGER_FACTORY_NAME); + loadExpressionOptimizerFactory(factoryName, name, properties); + } + + public void loadExpressionOptimizerFactory(String factoryName, String expressionOptimizerName, Map properties) + { + requireNonNull(factoryName, "factoryName is null"); checkArgument(expressionOptimizerFactories.containsKey(factoryName), "ExpressionOptimizerFactory %s is not registered, registered factories: ", factoryName, expressionOptimizerFactories.keySet()); ExpressionOptimizer optimizer = 
expressionOptimizerFactories.get(factoryName).createOptimizer( properties, - new ExpressionOptimizerContext(nodeManager, functionAndTypeManager, functionResolution)); - expressionOptimizers.put(name, optimizer); + new ExpressionOptimizerContext(nodeManager, rowExpressionSerde, functionAndTypeManager, functionResolution)); + expressionOptimizers.put(expressionOptimizerName, optimizer); } public void addExpressionOptimizerFactory(ExpressionOptimizerFactory expressionOptimizerFactory) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/expressions/JsonCodecRowExpressionSerde.java b/presto-main/src/main/java/com/facebook/presto/sql/expressions/JsonCodecRowExpressionSerde.java new file mode 100644 index 0000000000000..20cfcf151df6c --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/expressions/JsonCodecRowExpressionSerde.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.expressions; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.presto.spi.RowExpressionSerde; +import com.facebook.presto.spi.relation.RowExpression; + +import javax.inject.Inject; + +import java.nio.charset.StandardCharsets; + +import static java.util.Objects.requireNonNull; + +public class JsonCodecRowExpressionSerde + implements RowExpressionSerde +{ + private final JsonCodec codec; + + @Inject + public JsonCodecRowExpressionSerde(JsonCodec codec) + { + this.codec = requireNonNull(codec, "codec is null"); + } + + @Override + public String serialize(RowExpression expression) + { + return new String(codec.toBytes(expression), StandardCharsets.UTF_8); + } + + @Override + public RowExpression deserialize(String data) + { + return codec.fromBytes(data.getBytes(StandardCharsets.UTF_8)); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index 8d2a9ac22e5bb..0360d5b951932 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -149,6 +149,7 @@ import com.facebook.presto.spi.plan.SimplePlanFragment; import com.facebook.presto.spi.plan.StageExecutionDescriptor; import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spiller.FileSingleStreamSpillerFactory; import com.facebook.presto.spiller.GenericPartitioningSpillerFactory; import com.facebook.presto.spiller.GenericSpillerFactory; @@ -175,6 +176,7 @@ import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import 
com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler; @@ -462,7 +464,7 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, this.pageIndexerFactory = new GroupByHashPageIndexerFactory(joinCompiler); NodeInfo nodeInfo = new NodeInfo("test"); - expressionOptimizerManager = new ExpressionOptimizerManager(new PluginNodeManager(nodeManager, nodeInfo.getEnvironment()), getFunctionAndTypeManager()); + expressionOptimizerManager = new ExpressionOptimizerManager(new PluginNodeManager(nodeManager, nodeInfo.getEnvironment()), getFunctionAndTypeManager(), new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class))); this.statsNormalizer = new StatsNormalizer(); this.scalarStatsCalculator = new ScalarStatsCalculator(metadata, expressionOptimizerManager); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/expressions/TestExpressionOptimizerManager.java b/presto-main/src/test/java/com/facebook/presto/sql/expressions/TestExpressionOptimizerManager.java index 913e0db26adc0..7340e45b42c1d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/expressions/TestExpressionOptimizerManager.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/expressions/TestExpressionOptimizerManager.java @@ -34,6 +34,7 @@ import java.util.Map; import java.util.Properties; +import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; import static com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; @@ -65,6 +66,7 @@ public void setUp() manager = new ExpressionOptimizerManager( pluginNodeManager, METADATA.getFunctionAndTypeManager(), + new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class)), directory); } diff --git 
a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java index cea4abb4e0787..a2f535e4d5c2a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java @@ -23,9 +23,11 @@ import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.sql.Optimizer; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.TypeProvider; @@ -49,6 +51,7 @@ import java.util.function.Consumer; import java.util.function.Function; +import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.presto.sql.planner.assertions.PlanAssert.assertPlan; import static com.facebook.presto.sql.planner.assertions.PlanAssert.assertPlanDoesNotMatch; import static com.facebook.presto.transaction.TransactionBuilder.transaction; @@ -177,7 +180,8 @@ private List getMinimalOptimizers() metadata, new ExpressionOptimizerManager( new PluginNodeManager(new InMemoryNodeManager()), - queryRunner.getFunctionAndTypeManager())).rules())); + queryRunner.getFunctionAndTypeManager(), + new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class)))).rules())); } private void inTransaction(Function transactionSessionConsumer) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java 
b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java index 7817b311053df..78ff684386afc 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java @@ -25,6 +25,7 @@ import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.sql.TestingRowExpressionTranslator; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.tree.Expression; @@ -38,6 +39,7 @@ import java.util.stream.IntStream; import java.util.stream.Stream; +import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; @@ -185,7 +187,7 @@ private static void assertSimplifies(String expression, String rowExpressionExpe Expression actualExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expression)); InMemoryNodeManager nodeManager = new InMemoryNodeManager(); - ExpressionOptimizerManager expressionOptimizerManager = new ExpressionOptimizerManager(new PluginNodeManager(nodeManager), METADATA.getFunctionAndTypeManager()); + ExpressionOptimizerManager expressionOptimizerManager = new ExpressionOptimizerManager(new PluginNodeManager(nodeManager), METADATA.getFunctionAndTypeManager(), new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class))); TestingRowExpressionTranslator translator = new TestingRowExpressionTranslator(METADATA); RowExpression actualRowExpression = 
translator.translate(actualExpression, TypeProvider.viewOf(TYPES)); diff --git a/presto-native-execution/presto_cpp/main/CMakeLists.txt b/presto-native-execution/presto_cpp/main/CMakeLists.txt index c06e00edf834c..7eaf9f39177ce 100644 --- a/presto-native-execution/presto_cpp/main/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/CMakeLists.txt @@ -51,6 +51,7 @@ target_link_libraries( presto_http presto_operators presto_velox_conversion + presto_expression_optimizer velox_abfs velox_aggregates velox_caching diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index 0a422c66e16bd..0fa4b04bd1dfe 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -1551,6 +1551,13 @@ void PrestoServer::registerSidecarEndpoints() { proxygen::ResponseHandler* downstream) { http::sendOkResponse(downstream, getFunctionsMetadata()); }); + httpServer_->registerPost( + "/v1/expressions", + [&](proxygen::HTTPMessage* message, + const std::vector>& body, + proxygen::ResponseHandler* downstream) { + optimizeExpressions(*message, body, downstream); + }); httpServer_->registerPost( "/v1/velox/plan", [server = this]( @@ -1599,4 +1606,25 @@ protocol::NodeStatus PrestoServer::fetchNodeStatus() { return nodeStatus; } +void PrestoServer::optimizeExpressions( + const proxygen::HTTPMessage& message, + const std::vector>& body, + proxygen::ResponseHandler* downstream) { + json::array_t inputRowExpressions = + json::parse(util::extractMessageBody(body)); + auto rowExpressionOptimizer = + std::make_unique( + nativeWorkerPool_.get()); + auto result = rowExpressionOptimizer->optimize( + message.getHeaders(), inputRowExpressions); + if (result.second) { + VELOX_CHECK( + result.first.is_array(), + "The output json is not an array of RowExpressions"); + http::sendOkResponse(downstream, result.first); + } else { + 
http::sendErrorResponse(downstream, result.first); + } +} + } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.h b/presto-native-execution/presto_cpp/main/PrestoServer.h index 9fa1301c1bf1d..f535f83547b85 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.h +++ b/presto-native-execution/presto_cpp/main/PrestoServer.h @@ -25,6 +25,7 @@ #include "presto_cpp/main/PeriodicHeartbeatManager.h" #include "presto_cpp/main/PrestoExchangeSource.h" #include "presto_cpp/main/PrestoServerOperations.h" +#include "presto_cpp/main/types/RowExpressionOptimizer.h" #include "presto_cpp/main/types/VeloxPlanValidator.h" #include "velox/common/caching/AsyncDataCache.h" #include "velox/common/memory/MemoryAllocator.h" @@ -217,6 +218,11 @@ class PrestoServer { protocol::NodeStatus fetchNodeStatus(); + void optimizeExpressions( + const proxygen::HTTPMessage& message, + const std::vector>& body, + proxygen::ResponseHandler* downstream); + void populateMemAndCPUInfo(); // Periodically yield tasks if there are tasks queued. 
diff --git a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt index 5841728512238..93d8d4abddccc 100644 --- a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt @@ -34,6 +34,12 @@ add_library(presto_velox_conversion OBJECT VeloxPlanConversion.cpp) target_link_libraries(presto_velox_conversion velox_type) +add_library(presto_expression_optimizer RowExpressionConverter.cpp + RowExpressionOptimizer.cpp) + +target_link_libraries(presto_expression_optimizer presto_type_converter + presto_types presto_protocol) + if(PRESTO_ENABLE_TESTING) add_subdirectory(tests) endif() diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp index 83d359f6d6aad..f3e0a9c766069 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp @@ -14,7 +14,6 @@ #include "presto_cpp/main/types/PrestoToVeloxExpr.h" #include -#include "presto_cpp/main/common/Configs.h" #include "presto_cpp/presto_protocol/Base64Util.h" #include "velox/common/base/Exceptions.h" #include "velox/functions/prestosql/types/JsonType.h" @@ -34,59 +33,9 @@ std::string toJsonString(const T& value) { } std::string mapScalarFunction(const std::string& name) { - static const std::string prestoDefaultNamespacePrefix = - SystemConfig::instance()->prestoDefaultNamespacePrefix(); - static const std::unordered_map kFunctionNames = { - // Operator overrides: com.facebook.presto.common.function.OperatorType - {"presto.default.$operator$add", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "plus")}, - {"presto.default.$operator$between", - util::addDefaultNamespacePrefix( - prestoDefaultNamespacePrefix, "between")}, - {"presto.default.$operator$divide", - 
util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "divide")}, - {"presto.default.$operator$equal", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "eq")}, - {"presto.default.$operator$greater_than", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "gt")}, - {"presto.default.$operator$greater_than_or_equal", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "gte")}, - {"presto.default.$operator$is_distinct_from", - util::addDefaultNamespacePrefix( - prestoDefaultNamespacePrefix, "distinct_from")}, - {"presto.default.$operator$less_than", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "lt")}, - {"presto.default.$operator$less_than_or_equal", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "lte")}, - {"presto.default.$operator$modulus", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "mod")}, - {"presto.default.$operator$multiply", - util::addDefaultNamespacePrefix( - prestoDefaultNamespacePrefix, "multiply")}, - {"presto.default.$operator$negation", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "negate")}, - {"presto.default.$operator$not_equal", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "neq")}, - {"presto.default.$operator$subtract", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "minus")}, - {"presto.default.$operator$subscript", - util::addDefaultNamespacePrefix( - prestoDefaultNamespacePrefix, "subscript")}, - {"presto.default.$operator$xx_hash_64", - util::addDefaultNamespacePrefix( - prestoDefaultNamespacePrefix, "xxhash64_internal")}, - {"presto.default.combine_hash", - util::addDefaultNamespacePrefix( - prestoDefaultNamespacePrefix, "combine_hash_internal")}, - // Special form function overrides. 
- {"presto.default.in", "in"}, - }; - std::string lowerCaseName = boost::to_lower_copy(name); - - auto it = kFunctionNames.find(lowerCaseName); - if (it != kFunctionNames.end()) { - return it->second; + if (prestoOperatorMap().find(lowerCaseName) != prestoOperatorMap().end()) { + return prestoOperatorMap().at(lowerCaseName); } return lowerCaseName; @@ -385,6 +334,67 @@ std::optional tryConvertLiteralArray( } } // namespace +const std::unordered_map prestoOperatorMap() { + static const std::string prestoDefaultNamespacePrefix = + SystemConfig::instance()->prestoDefaultNamespacePrefix(); + static const std::unordered_map kPrestoOperatorMap = + { + // Operator overrides: + // com.facebook.presto.common.function.OperatorType + {"presto.default.$operator$add", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "plus")}, + {"presto.default.$operator$between", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "between")}, + {"presto.default.$operator$divide", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "divide")}, + {"presto.default.$operator$equal", + util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "eq")}, + {"presto.default.$operator$greater_than", + util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "gt")}, + {"presto.default.$operator$greater_than_or_equal", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "gte")}, + {"presto.default.$operator$is_distinct_from", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "distinct_from")}, + {"presto.default.$operator$less_than", + util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "lt")}, + {"presto.default.$operator$less_than_or_equal", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "lte")}, + {"presto.default.$operator$modulus", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "mod")}, + {"presto.default.$operator$multiply", + util::addDefaultNamespacePrefix( + 
prestoDefaultNamespacePrefix, "multiply")}, + {"presto.default.$operator$negation", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "negate")}, + {"presto.default.$operator$not_equal", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "neq")}, + {"presto.default.$operator$subtract", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "minus")}, + {"presto.default.$operator$subscript", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "subscript")}, + {"presto.default.$operator$xx_hash_64", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "xxhash64_internal")}, + {"presto.default.combine_hash", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "combine_hash_internal")}, + // Special form function overrides. + {"presto.default.in", "in"}, + }; + return kPrestoOperatorMap; +} + std::optional VeloxExprConverter::tryConvertDate( const protocol::CallExpression& pexpr) const { static const std::string prestoDefaultNamespacePrefix = diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h index 8526f6a8c9638..76181d4e092b1 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h @@ -22,6 +22,8 @@ namespace facebook::presto { +const std::unordered_map prestoOperatorMap(); + class VeloxExprConverter { public: VeloxExprConverter(velox::memory::MemoryPool* pool, TypeParser* typeParser) diff --git a/presto-native-execution/presto_cpp/main/types/RowExpressionConverter.cpp b/presto-native-execution/presto_cpp/main/types/RowExpressionConverter.cpp new file mode 100644 index 0000000000000..9454d24fdfaa7 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/RowExpressionConverter.cpp @@ -0,0 +1,429 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not 
use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/types/RowExpressionConverter.h" +#include "presto_cpp/main/types/PrestoToVeloxExpr.h" +#include "velox/expression/FieldReference.h" + +using namespace facebook::presto; +using namespace facebook::velox; + +namespace facebook::presto::expression { + +namespace { +const std::string kConstant = "constant"; +const std::string kBoolean = "boolean"; +const std::string kVariable = "variable"; +const std::string kCall = "call"; +const std::string kStatic = "$static"; +const std::string kSpecial = "special"; +const std::string kCoalesce = "COALESCE"; +const std::string kRowConstructor = "ROW_CONSTRUCTOR"; +const std::string kSwitch = "SWITCH"; +const std::string kWhen = "WHEN"; + +protocol::TypeSignature getTypeSignature(const TypePtr& type) { + std::string typeSignature; + if (type->isPrimitiveType()) { + typeSignature = type->toString(); + boost::algorithm::to_lower(typeSignature); + } else { + std::string complexTypeString; + std::vector childTypes; + if (type->isRow()) { + complexTypeString = "row"; + childTypes = asRowType(type)->children(); + } else if (type->isArray()) { + complexTypeString = "array"; + childTypes = type->asArray().children(); + } else if (type->isMap()) { + complexTypeString = "map"; + const auto mapType = type->asMap(); + childTypes = {mapType.keyType(), mapType.valueType()}; + } else { + VELOX_USER_FAIL("Invalid type {}", type->toString()); + } + + typeSignature = complexTypeString + "("; + if (!childTypes.empty()) { + auto numChildren = 
childTypes.size(); + for (auto i = 0; i < numChildren - 1; i++) { + typeSignature += fmt::format("{},", getTypeSignature(childTypes[i])); + } + typeSignature += getTypeSignature(childTypes[numChildren - 1]); + } + typeSignature += ")"; + } + + return typeSignature; +} + +json toVariableReferenceExpression( + const std::shared_ptr& field) { + protocol::VariableReferenceExpression vexpr; + vexpr.name = field->name(); + vexpr._type = kVariable; + vexpr.type = getTypeSignature(field->type()); + json result; + protocol::to_json(result, vexpr); + + return result; +} + +bool isPrestoSpecialForm(const std::string& name) { + static const std::unordered_set kPrestoSpecialForms = { + "if", + "null_if", + "switch", + "when", + "is_null", + "coalesce", + "in", + "and", + "or", + "dereference", + "row_constructor", + "bind"}; + return kPrestoSpecialForms.count(name) != 0; +} + +json getWhenSpecialForm(const TypePtr& type, const json::array_t& whenArgs) { + json when; + when["@type"] = kSpecial; + when["form"] = kWhen; + when["arguments"] = whenArgs; + when["returnType"] = getTypeSignature(type); + return when; +} + +std::vector getRowExpressionArguments( + const RowExpressionPtr& input) { + std::vector arguments; + if (input->_type == kSpecial) { + auto inputSpecialForm = + dynamic_cast(input.get()); + VELOX_CHECK_NOT_NULL(inputSpecialForm); + arguments = inputSpecialForm->arguments; + } else if (input->_type == kCall) { + auto inputCall = dynamic_cast(input.get()); + VELOX_CHECK_NOT_NULL(inputCall); + arguments = inputCall->arguments; + } else { + VELOX_USER_FAIL( + "Input should be a SpecialFormExpression or CallExpression"); + } + return arguments; +} + +const std::unordered_map& veloxToPrestoOperatorMap() { + static std::unordered_map veloxToPrestoOperatorMap = + {{"cast", "presto.default.$operator$cast"}}; + for (const auto& entry : prestoOperatorMap()) { + veloxToPrestoOperatorMap[entry.second] = entry.first; + } + return veloxToPrestoOperatorMap; +} +} // namespace + 
+std::string RowExpressionConverter::getValueBlock( + const VectorPtr& vector) const { + std::ostringstream output; + serde_->serializeSingleColumn(vector, nullptr, pool_, &output); + const auto serialized = output.str(); + const auto serializedSize = serialized.size(); + return encoding::Base64::encode(serialized.c_str(), serializedSize); +} + +json RowExpressionConverter::getConstantRowExpression( + const std::shared_ptr& constantExpr) { + protocol::ConstantExpression cexpr; + cexpr.type = getTypeSignature(constantExpr->type()); + cexpr.valueBlock.data = getValueBlock(constantExpr->value()); + json result; + protocol::to_json(result, cexpr); + return result; +} + +RowExpressionConverter::SwitchFormArguments +RowExpressionConverter::getSimpleSwitchFormArgs( + const exec::SwitchExpr* switchExpr, + const std::vector& arguments) { + SwitchFormArguments result; + // Consider the following Presto query with simple form switch expression: + // SELECT CASE 1 WHEN 2 THEN 31 - 1 WHEN 1 THEN 32 + 1 WHEN orderkey THEN 34 + // ELSE 35 END FROM orders; + // When this Presto expression is converted to velox switch expression, the + // inputs to the velox expression contain the following `equal` expressions: + // eq(1, 2), eq(1, 1), eq(1, orderkey). The resultant Presto simple form + // switch expression requires the `expression` to be the first `argument`, + // please refer to the syntax of Presto simple form switch expression here: + // https://prestodb.io/docs/current/functions/conditional.html#case). It is + // not possible to get the value of `expression` from velox, since these + // equal expressions could all evaluate to either `true` or `false` during + // constant folding. Hence, this `expression` is obtained from the input + // Presto switch expression. 
+ result.arguments.emplace_back(arguments[0]); + const auto& switchInputs = switchExpr->inputs(); + const auto numInputs = switchInputs.size(); + + for (auto i = 0; i < numInputs - 1; i += 2) { + json::array_t resultWhenArgs; + const vector_size_t argsIdx = i / 2 + 1; + const auto& caseValue = switchInputs[i + 1]; + auto inputWhenArgs = getRowExpressionArguments(arguments[argsIdx]); + + if (switchInputs[i]->isConstant()) { + auto constantExpr = + std::dynamic_pointer_cast(switchInputs[i]); + if (auto constVector = + constantExpr->value()->as>()) { + // If this is the first switch case that evaluates to true, return the + // expression corresponding to this case. From the aforementioned + // example, `eq(1, 1)` evaluates to true, so the value corresponding to + // the WHEN clause (`CASE 1 WHEN 1 THEN 32 + 1`), `33` is returned. + if (constVector->valueAt(0) && result.arguments.size() == 1) { + return { + true, + veloxToPrestoRowExpression(caseValue, inputWhenArgs[1]), + json::array()}; + } else { + // Skip switch cases that evaluate to false. From the aforementioned + // example, `eq(1, 2)` evaluates to false, so the corresponding WHEN + // clause, `CASE 1 WHEN 2 THEN 31 - 1`, is not included in the output + // switch expression's arguments. + continue; + } + } else { + resultWhenArgs.emplace_back(getConstantRowExpression(constantExpr)); + } + } else { + VELOX_CHECK(!switchInputs[i]->inputs().empty()); + const auto& matchExpr = switchInputs[i]->inputs().back(); + resultWhenArgs.emplace_back( + veloxToPrestoRowExpression(matchExpr, inputWhenArgs[0])); + } + + resultWhenArgs.emplace_back( + veloxToPrestoRowExpression(caseValue, inputWhenArgs[1])); + result.arguments.emplace_back( + getWhenSpecialForm(switchInputs[i + 1]->type(), resultWhenArgs)); + } + + // Else clause. 
+ if (numInputs % 2 != 0) { + result.arguments.emplace_back(veloxToPrestoRowExpression( + switchInputs[numInputs - 1], arguments.back())); + } + return result; +} + +RowExpressionConverter::SwitchFormArguments +RowExpressionConverter::getSpecialSwitchFormArgs( + const exec::SwitchExpr* switchExpr, + const std::vector& arguments) { + // The searched form of CASE conditional expression in Presto needs to be + // handled differently from the simple form (please refer to: + // https://prestodb.io/docs/current/functions/conditional.html#case). The + // searched form can be detected by the presence of a boolean value in the + // first argument. This default boolean argument is not present in the velox + // switch expression, so it is added to the arguments of output switch + // expression unchanged. + if (arguments[0]->_type == kConstant) { + if (auto constantRowExpr = + dynamic_cast(arguments[0].get())) { + if (constantRowExpr->type == kBoolean) { + SwitchFormArguments result; + const auto& switchInputs = switchExpr->inputs(); + const auto numInputs = switchInputs.size(); + result.arguments = {arguments[0]}; + for (auto i = 0; i < numInputs - 1; i += 2) { + const vector_size_t argsIdx = i / 2 + 1; + std::vector inputWhenArgs = + getRowExpressionArguments(arguments[argsIdx]); + json::array_t resultWhenArgs; + resultWhenArgs.emplace_back( + veloxToPrestoRowExpression(switchInputs[i], inputWhenArgs[0])); + resultWhenArgs.emplace_back(veloxToPrestoRowExpression( + switchInputs[i + 1], inputWhenArgs[1])); + result.arguments.emplace_back( + getWhenSpecialForm(switchInputs[i + 1]->type(), resultWhenArgs)); + } + + // Else clause. 
+ if (numInputs % 2 != 0) { + result.arguments.emplace_back(veloxToPrestoRowExpression( + switchInputs[numInputs - 1], arguments.back())); + } + return result; + } + } + } + + return getSimpleSwitchFormArgs(switchExpr, arguments); +} + +json RowExpressionConverter::getSpecialForm( + const exec::ExprPtr& expr, + const RowExpressionPtr& input) { + json result; + result["@type"] = kSpecial; + result["returnType"] = getTypeSignature(expr->type()); + auto form = expr->name(); + // Presto requires the field form to be in upper case. + std::transform(form.begin(), form.end(), form.begin(), ::toupper); + result["form"] = form; + std::vector inputArguments = + getRowExpressionArguments(input); + + // Arguments for switch expression include 'WHEN' special form expression(s) + // so they are constructed separately. If the switch expression evaluation + // found a case that always evaluates to `true`, the field 'isSimplified' in + // the result is `true` and the field 'caseExpression' contains the value + // corresponding to the simplified switch case. Otherwise, 'isSimplified' is + // false and the field 'arguments' contains the 'when' clauses needed by the + // Presto switch SpecialFormExpression. + if (form == kSwitch) { + auto switchExpr = dynamic_cast(expr.get()); + VELOX_CHECK_NOT_NULL(switchExpr); + auto switchResult = getSpecialSwitchFormArgs(switchExpr, inputArguments); + if (switchResult.isSimplified) { + return switchResult.caseExpression; + } else { + result["arguments"] = switchResult.arguments; + } + } else { + // Presto special form expressions that are not of type `SWITCH`, such as + // `IN`, `AND`, `OR` etc,. are handled in this clause. The list of Presto + // special form expressions can be found in `kPrestoSpecialForms` in the + // helper function `isPrestoSpecialForm`. 
+ auto exprInputs = expr->inputs(); + const auto numInputs = exprInputs.size(); + if (form == kCoalesce) { + VELOX_CHECK_LE(numInputs, inputArguments.size()); + } else { + VELOX_CHECK_EQ(numInputs, inputArguments.size()); + } + result["arguments"] = json::array(); + for (auto i = 0; i < numInputs; i++) { + result["arguments"].push_back( + veloxToPrestoRowExpression(exprInputs[i], inputArguments[i])); + } + } + + return result; +} + +json RowExpressionConverter::getRowConstructorSpecialForm( + std::shared_ptr& constantExpr) { + json result; + result["@type"] = kSpecial; + result["form"] = kRowConstructor; + result["returnType"] = getTypeSignature(constantExpr->type()); + auto value = constantExpr->value(); + auto* constVector = value->as>(); + auto* rowVector = constVector->valueVector()->as(); + auto type = asRowType(constantExpr->type()); + auto size = rowVector->children().size(); + + protocol::ConstantExpression cexpr; + json j; + result["arguments"] = json::array(); + for (auto i = 0; i < size; i++) { + cexpr.type = getTypeSignature(type->childAt(i)); + cexpr.valueBlock.data = getValueBlock(rowVector->childAt(i)); + protocol::to_json(j, cexpr); + result["arguments"].push_back(j); + } + return result; +} + +json RowExpressionConverter::toCallRowExpression( + const exec::ExprPtr& expr, + const RowExpressionPtr& input) { + json result; + result["@type"] = kCall; + protocol::Signature signature; + std::string exprName = expr->name(); + if (veloxToPrestoOperatorMap().find(exprName) != + veloxToPrestoOperatorMap().end()) { + exprName = veloxToPrestoOperatorMap().at(exprName); + } + signature.name = exprName; + result["displayName"] = exprName; + signature.kind = protocol::FunctionKind::SCALAR; + signature.typeVariableConstraints = {}; + signature.longVariableConstraints = {}; + signature.returnType = getTypeSignature(expr->type()); + + std::vector argumentTypes; + auto exprInputs = expr->inputs(); + auto numArgs = exprInputs.size(); + argumentTypes.reserve(numArgs); 
+ for (auto i = 0; i < numArgs; i++) { + argumentTypes.emplace_back(getTypeSignature(exprInputs[i]->type())); + } + signature.argumentTypes = argumentTypes; + signature.variableArity = false; + + protocol::BuiltInFunctionHandle builtInFunctionHandle; + builtInFunctionHandle._type = kStatic; + builtInFunctionHandle.signature = signature; + result["functionHandle"] = builtInFunctionHandle; + result["returnType"] = getTypeSignature(expr->type()); + result["arguments"] = json::array(); + for (const auto& exprInput : exprInputs) { + result["arguments"].push_back(veloxToPrestoRowExpression(exprInput, input)); + } + + return result; +} + +json RowExpressionConverter::veloxToPrestoRowExpression( + const exec::ExprPtr& expr, + const RowExpressionPtr& input) { + if (expr->isConstant()) { + if (expr->inputs().empty()) { + auto constantExpr = + std::dynamic_pointer_cast(expr); + VELOX_CHECK_NOT_NULL(constantExpr); + // Constant velox expressions of ROW type map to ROW_CONSTRUCTOR special + // form expression in Presto. + if (expr->type()->isRow()) { + return getRowConstructorSpecialForm(constantExpr); + } + return getConstantRowExpression(constantExpr); + } else { + // Expressions such as 'divide(0, 0)' are not constant folded during + // compilation in velox, since they throw an exception (Divide by zero in + // this example) during evaluation (see function `tryFoldIfConstant` in + // `velox/expression/ExprCompiler.cpp`). The input expression is returned + // unchanged in such cases. + return input; + } + } + + if (auto field = + std::dynamic_pointer_cast(expr)) { + return toVariableReferenceExpression(field); + } + + // Check if special form expression or call expression. 
+ auto exprName = expr->name(); + boost::algorithm::to_lower(exprName); + if (isPrestoSpecialForm(exprName)) { + return getSpecialForm(expr, input); + } + return toCallRowExpression(expr, input); +} + +} // namespace facebook::presto::expression diff --git a/presto-native-execution/presto_cpp/main/types/RowExpressionConverter.h b/presto-native-execution/presto_cpp/main/types/RowExpressionConverter.h new file mode 100644 index 0000000000000..c6cd379188e17 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/RowExpressionConverter.h @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "presto_cpp/external/json/nlohmann/json.hpp" +#include "presto_cpp/presto_protocol/presto_protocol.h" +#include "velox/expression/ConstantExpr.h" +#include "velox/expression/Expr.h" +#include "velox/expression/SwitchExpr.h" +#include "velox/serializers/PrestoSerializer.h" + +using namespace facebook::velox; + +using RowExpressionPtr = + std::shared_ptr; +using SpecialFormExpressionPtr = + std::shared_ptr; + +namespace facebook::presto::expression { + +/// Helper class to convert a velox expression to its corresponding Presto +/// protocol RowExpression. The function `veloxToPrestoRowExpression` is used in +/// RowExpressionOptimizer to convert the constant folded velox expression to +/// Presto protocol RowExpression: +/// 1. 
A velox constant expression of type exec::ConstantExpr without any inputs +/// is converted to Presto protocol expression of type ConstantExpression. +/// If the velox constant expression is of ROW type, it is converted to a +/// Presto protocol expression of type SpecialFormExpression (with Form as +/// ROW_CONSTRUCTOR). +/// 2. Velox expression representing a variable is of type exec::FieldReference, +/// it is converted to a Presto protocol expression of type +/// VariableReferenceExpression. +/// 3. Special form expressions and expressions with a vector function in velox +/// map either to a Presto protocol SpecialFormExpression or to a Presto +/// protocol CallExpression. This is because special form expressions in +/// velox and in Presto do not have a one to one mapping. If the velox +/// expression name belongs to the set of possible Presto protocol +/// SpecialFormExpression names, it is converted to a Presto protocol +/// SpecialFormExpression; else it is converted to a Presto protocol +/// CallExpression. +/// +/// The function `getConstantRowExpression` is used in RowExpressionOptimizer to +/// convert a velox constant expression to Presto protocol ConstantExpression. +class RowExpressionConverter { + public: + explicit RowExpressionConverter(memory::MemoryPool* pool) : pool_(pool) {} + + /// Converts a velox constant expression `constantExpr` to a Presto protocol + /// ConstantExpression. + json getConstantRowExpression( + const std::shared_ptr& constantExpr); + + /// Converts a velox expression `expr` to a Presto protocol RowExpression. + /// Argument `inputRowExpr` is the input Presto protocol RowExpression before + /// it is constant folded in velox. + json veloxToPrestoRowExpression( + const exec::ExprPtr& expr, + const RowExpressionPtr& inputRowExpr); + + private: + /// When 'isSimplified' is true, the cases (or arguments) in switch expression + /// have been simplified when the expression was constant folded in velox. 
In + /// this case, the expression corresponding to the first switch case that + /// always evaluates to true is returned in 'caseExpression'. When + /// 'isSimplified' is false, the cases in switch expression have not been + /// simplified, and the switch expression arguments required by Presto are + /// returned in 'arguments'. + struct SwitchFormArguments { + bool isSimplified = false; + json caseExpression; + json::array_t arguments; + }; + + /// ValueBlock in Presto protocol ConstantExpression requires only the column + /// from the serialized PrestoPage without the page header. This function is + /// used to serialize a velox vector to ValueBlock. + std::string getValueBlock(const VectorPtr& vector) const; + + /// Helper function to get arguments for Presto protocol SpecialFormExpression + /// of type SWITCH when the CASE expression is of 'simple' form (please refer + /// to: https://prestodb.io/docs/current/functions/conditional.html#case). + SwitchFormArguments getSimpleSwitchFormArgs( + const exec::SwitchExpr* switchExpr, + const std::vector& inputArgs); + + /// Helper function to get arguments for Presto protocol SpecialFormExpression + /// of type SWITCH from a velox switch expression `switchExpr`. + SwitchFormArguments getSpecialSwitchFormArgs( + const exec::SwitchExpr* switchExpr, + const std::vector& inputArgs); + + /// Helper function to construct a Presto protocol SpecialFormExpression from + /// a velox expression `expr`. + json getSpecialForm(const exec::ExprPtr& expr, const RowExpressionPtr& input); + + /// Helper function to construct a Presto protocol SpecialFormExpression of + /// type ROW_CONSTRUCTOR from a velox constant expression `constantExpr`. + json getRowConstructorSpecialForm( + std::shared_ptr& constantExpr); + + /// Helper function to construct a Presto protocol CallExpression from a velox + /// expression `expr`. 
+ json toCallRowExpression( + const exec::ExprPtr& expr, + const RowExpressionPtr& input); + + memory::MemoryPool* pool_; + const std::unique_ptr serde_ = + std::make_unique(); +}; +} // namespace facebook::presto::expression diff --git a/presto-native-execution/presto_cpp/main/types/RowExpressionOptimizer.cpp b/presto-native-execution/presto_cpp/main/types/RowExpressionOptimizer.cpp new file mode 100644 index 0000000000000..1181fa3e03e9d --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/RowExpressionOptimizer.cpp @@ -0,0 +1,325 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "presto_cpp/main/types/RowExpressionOptimizer.h" +#include "velox/expression/ExprCompiler.h" + +using namespace facebook::presto; +using namespace facebook::velox; + +namespace facebook::presto::expression { + +namespace { + +const std::string kTimezoneHeader = "X-Presto-Time-Zone"; +const std::string kOptimizerLevelHeader = "X-Presto-Expression-Optimizer-Level"; +const std::string kEvaluated = "EVALUATED"; + +template +std::shared_ptr getConstantExpr( + const TypePtr& type, + const DecodedVector& decoded, + memory::MemoryPool* pool) { + if constexpr ( + KIND == TypeKind::ROW || KIND == TypeKind::UNKNOWN || + KIND == TypeKind::ARRAY || KIND == TypeKind::MAP) { + VELOX_USER_FAIL("Invalid result type {}", type->toString()); + } else { + using T = typename TypeTraits::NativeType; + auto constVector = std::make_shared>( + pool, decoded.size(), decoded.isNullAt(0), type, decoded.valueAt(0)); + return std::make_shared(constVector); + } +} +} // namespace + +exec::ExprPtr RowExpressionOptimizer::compileExpression( + const std::shared_ptr& inputRowExpr) { + auto typedExpr = veloxExprConverter_.toVeloxExpr(inputRowExpr); + exec::ExprSet exprSet{{typedExpr}, execCtx_.get()}; + auto compiledExprs = + exec::compileExpressions({typedExpr}, execCtx_.get(), &exprSet, true); + return compiledExprs[0]; +} + +RowExpressionPtr RowExpressionOptimizer::optimizeAndSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr) { + auto left = specialFormExpr->arguments[0]; + auto right = specialFormExpr->arguments[1]; + auto leftExpr = compileExpression(left); + bool isLeftNull = false; + bool leftValue = false; + + if (auto constantExpr = + std::dynamic_pointer_cast(leftExpr)) { + isLeftNull = constantExpr->value()->isNullAt(0); + if (!isLeftNull) { + if (auto constVector = + constantExpr->value()->as>()) { + if (!constVector->valueAt(0)) { + return rowExpressionConverter_.getConstantRowExpression(constantExpr); + } else { + leftValue = true; + } + } + } + } + + auto 
rightExpr = compileExpression(right); + if (auto constantExpr = + std::dynamic_pointer_cast(rightExpr)) { + if (constantExpr->value()->isNullAt(0) && (isLeftNull || leftValue)) { + return rowExpressionConverter_.getConstantRowExpression(constantExpr); + } + if (auto constVector = constantExpr->value()->as>()) { + if (constVector->valueAt(0)) { + return left; + } else if (isLeftNull || leftValue) { + return rowExpressionConverter_.getConstantRowExpression(constantExpr); + } + } + } + + return specialFormExpr; +} + +RowExpressionPtr RowExpressionOptimizer::optimizeIfSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr) { + auto condition = specialFormExpr->arguments[0]; + auto expr = compileExpression(condition); + + if (auto constantExpr = + std::dynamic_pointer_cast(expr)) { + if (auto constVector = constantExpr->value()->as>()) { + if (constVector->valueAt(0)) { + return specialFormExpr->arguments[1]; + } + return specialFormExpr->arguments[2]; + } + } + + return specialFormExpr; +} + +RowExpressionPtr RowExpressionOptimizer::optimizeIsNullSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr) { + auto expr = compileExpression(specialFormExpr); + if (auto constantExpr = + std::dynamic_pointer_cast(expr)) { + if (constantExpr->value()->isNullAt(0)) { + return rowExpressionConverter_.getConstantRowExpression(constantExpr); + } + } + + return specialFormExpr; +} + +RowExpressionPtr RowExpressionOptimizer::optimizeOrSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr) { + auto left = specialFormExpr->arguments[0]; + auto right = specialFormExpr->arguments[1]; + auto leftExpr = compileExpression(left); + bool isLeftNull = false; + bool leftValue = true; + + if (auto constantExpr = + std::dynamic_pointer_cast(leftExpr)) { + isLeftNull = constantExpr->value()->isNullAt(0); + if (!isLeftNull) { + if (auto constVector = + constantExpr->value()->as>()) { + if (constVector->valueAt(0)) { + return 
rowExpressionConverter_.getConstantRowExpression(constantExpr); + } else { + leftValue = false; + } + } + } + } + + auto rightExpr = compileExpression(right); + if (auto constantExpr = + std::dynamic_pointer_cast(rightExpr)) { + if (constantExpr->value()->isNullAt(0) && (isLeftNull || !leftValue)) { + return rowExpressionConverter_.getConstantRowExpression(constantExpr); + } + if (auto constVector = constantExpr->value()->as>()) { + if (!constVector->valueAt(0)) { + return left; + } else if (isLeftNull || !leftValue) { + return rowExpressionConverter_.getConstantRowExpression(constantExpr); + } + } + } + + return specialFormExpr; +} + +RowExpressionPtr RowExpressionOptimizer::optimizeCoalesceSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr) { + auto argsNoNulls = specialFormExpr->arguments; + argsNoNulls.erase( + std::remove_if( + argsNoNulls.begin(), + argsNoNulls.end(), + [&](const auto& arg) { + auto compiledExpr = compileExpression(arg); + if (auto constantExpr = + std::dynamic_pointer_cast( + compiledExpr)) { + return constantExpr->value()->isNullAt(0); + } + return false; + }), + argsNoNulls.end()); + + if (argsNoNulls.empty()) { + return specialFormExpr->arguments[0]; + } + specialFormExpr->arguments = argsNoNulls; + return specialFormExpr; +} + +RowExpressionPtr RowExpressionOptimizer::optimizeSpecialForm( + const std::shared_ptr& specialFormExpr) { + switch (specialFormExpr->form) { + case protocol::Form::IF: + return optimizeIfSpecialForm(specialFormExpr); + case protocol::Form::NULL_IF: + VELOX_USER_FAIL("NULL_IF specialForm not supported"); + break; + case protocol::Form::IS_NULL: + return optimizeIsNullSpecialForm(specialFormExpr); + case protocol::Form::AND: + return optimizeAndSpecialForm(specialFormExpr); + case protocol::Form::OR: + return optimizeOrSpecialForm(specialFormExpr); + case protocol::Form::COALESCE: + return optimizeCoalesceSpecialForm(specialFormExpr); + case protocol::Form::IN: + case protocol::Form::DEREFERENCE: + case 
protocol::Form::SWITCH: + case protocol::Form::WHEN: + case protocol::Form::ROW_CONSTRUCTOR: + case protocol::Form::BIND: + default: + break; + } + + return specialFormExpr; +} + +json RowExpressionOptimizer::evaluateNonDeterministicConstantExpr( + const exec::ExprPtr& expr, + exec::ExprSet& exprSet) { + VELOX_CHECK(!expr->isDeterministic()); + std::vector compiledExprInputTypes; + std::vector compiledExprInputs; + for (const auto& exprInput : expr->inputs()) { + VELOX_CHECK( + exprInput->isConstant(), + "Inputs to non-deterministic expression to be evaluated must be constant"); + const auto inputAsConstExpr = + std::dynamic_pointer_cast(exprInput); + compiledExprInputs.emplace_back(inputAsConstExpr->value()); + compiledExprInputTypes.emplace_back(exprInput->type()); + } + + const auto inputVector = std::make_shared( + pool_, + ROW(std::move(compiledExprInputTypes)), + nullptr, + 1, + compiledExprInputs); + exec::EvalCtx evalCtx(execCtx_.get(), &exprSet, inputVector.get()); + std::vector results(1); + SelectivityVector rows(1); + exprSet.eval(rows, evalCtx, results); + auto res = results.front(); + DecodedVector decoded(*res, rows); + const auto constExpr = VELOX_DYNAMIC_TYPE_DISPATCH( + getConstantExpr, res->typeKind(), res->type(), decoded, pool_); + return rowExpressionConverter_.getConstantRowExpression(constExpr); +} + +json::array_t RowExpressionOptimizer::optimizeExpressions( + const json::array_t& input, + const std::string& optimizerLevel) { + const auto numExpr = input.size(); + json::array_t output = json::array(); + for (auto i = 0; i < numExpr; i++) { + std::shared_ptr inputRowExpr = input[i]; + if (const auto special = + std::dynamic_pointer_cast( + inputRowExpr)) { + inputRowExpr = optimizeSpecialForm(special); + } + auto typedExpr = veloxExprConverter_.toVeloxExpr(inputRowExpr); + exec::ExprSet exprSet{{typedExpr}, execCtx_.get()}; + auto compiledExprs = + exec::compileExpressions({typedExpr}, execCtx_.get(), &exprSet, true); + auto compiledExpr = 
compiledExprs[0]; + json resultJson; + + if (optimizerLevel == kEvaluated) { + if (compiledExpr->isConstant()) { + resultJson = rowExpressionConverter_.veloxToPrestoRowExpression( + compiledExpr, input[i]); + } else { + // Velox does not evaluate expressions that are non-deterministic during + // compilation with constant folding enabled. Presto might require such + // non-deterministic expressions to be evaluated as well, this would be + // indicated by the header field 'X-Presto-Expression-Optimizer-Level' + // in the http request made to the native sidecar. When this field is + // set to 'EVALUATED', non-deterministic expressions with constant + // inputs are also evaluated. + resultJson = + evaluateNonDeterministicConstantExpr(compiledExpr, exprSet); + } + } else { + resultJson = rowExpressionConverter_.veloxToPrestoRowExpression( + compiledExpr, input[i]); + } + + output.push_back(resultJson); + } + return output; +} + +std::pair RowExpressionOptimizer::optimize( + const proxygen::HTTPHeaders& httpHeaders, + const json::array_t& input) { + try { + auto timezone = httpHeaders.getSingleOrEmpty(kTimezoneHeader); + auto optimizerLevel = httpHeaders.getSingleOrEmpty(kOptimizerLevelHeader); + std::unordered_map config( + {{core::QueryConfig::kSessionTimezone, timezone}, + {core::QueryConfig::kAdjustTimestampToTimezone, "true"}}); + auto queryCtx = + core::QueryCtx::create(nullptr, core::QueryConfig{std::move(config)}); + execCtx_ = std::make_unique(pool_, queryCtx.get()); + + return {optimizeExpressions(input, optimizerLevel), true}; + } catch (const VeloxUserError& e) { + VLOG(1) << "VeloxUserError during expression evaluation: " << e.what(); + return {e.what(), false}; + } catch (const VeloxException& e) { + VLOG(1) << "VeloxException during expression evaluation: " << e.what(); + return {e.what(), false}; + } catch (const std::exception& e) { + VLOG(1) << "std::exception during expression evaluation: " << e.what(); + return {e.what(), false}; + } +} + +} // 
namespace facebook::presto::expression diff --git a/presto-native-execution/presto_cpp/main/types/RowExpressionOptimizer.h b/presto-native-execution/presto_cpp/main/types/RowExpressionOptimizer.h new file mode 100644 index 0000000000000..737ee35d6af87 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/RowExpressionOptimizer.h @@ -0,0 +1,102 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "presto_cpp/main/types/PrestoToVeloxExpr.h" +#include "presto_cpp/main/types/RowExpressionConverter.h" + +using namespace facebook::velox; + +namespace facebook::presto::expression { + +/// Helper class to optimize Presto protocol RowExpressions sent along the REST +/// endpoint 'v1/expressions' to the sidecar. +class RowExpressionOptimizer { + public: + explicit RowExpressionOptimizer(memory::MemoryPool* pool) + : pool_(pool), + veloxExprConverter_(pool, &typeParser_), + rowExpressionConverter_(RowExpressionConverter(pool)) {} + + /// Optimizes all expressions from the input json array. If the expression + /// optimization fails for any of the input expressions, the second value in + /// the returned pair is set to false and the returned json contains the + /// exception. Otherwise, the returned json is an array of optimized Presto + /// protocol RowExpressions. 
+ std::pair optimize( + const proxygen::HTTPHeaders& httpHeaders, + const json::array_t& input); + + private: + /// Converts Presto protocol RowExpression into a velox expression with + /// constant folding enabled during velox expression compilation. + exec::ExprPtr compileExpression(const RowExpressionPtr& inputRowExpr); + + /// Evaluates Presto protocol SpecialFormExpression `AND` when the inputs can + /// be constant folded to `true`, `false`, or `NULL`. + RowExpressionPtr optimizeAndSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr); + + /// Evaluates Presto protocol SpecialFormExpression `IF` when the condition + /// can be constant folded to `true` or `false`. + RowExpressionPtr optimizeIfSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr); + + /// Evaluates Presto protocol SpecialFormExpression `IS_NULL` when the input + /// can be constant folded. + RowExpressionPtr optimizeIsNullSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr); + + /// Evaluates Presto protocol SpecialFormExpression `OR` when the inputs can + /// be constant folded to `true`, `false`, or `NULL`. + RowExpressionPtr optimizeOrSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr); + + /// Optimizes Presto protocol SpecialFormExpression `COALESCE` by removing + /// `NULL`s from the argument list. + RowExpressionPtr optimizeCoalesceSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr); + + /// Optimizes Presto protocol SpecialFormExpressions using optimization rules + /// borrowed from Presto function visitSpecialForm() in + /// RowExpressionInterpreter.java. + RowExpressionPtr optimizeSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr); + + /// Evaluates non-deterministic expressions with constant inputs. 
+ json evaluateNonDeterministicConstantExpr( + const exec::ExprPtr& expr, + exec::ExprSet& exprSet); + + /// Optimizes and constant folds each expression from the input json array + /// and returns an array of the optimized, constant folded expressions. + /// Each input expression is first rewritten with `optimizeSpecialForm` + /// (applicable only for special form expressions) and then compiled into + /// a velox expression with constant folding enabled. If the optimization + /// level in the header of the http request made to 'v1/expressions' is + /// 'EVALUATED', the compiled expression is also evaluated. The result is + /// converted to its corresponding Presto protocol RowExpression with + /// `RowExpressionConverter`. + json::array_t optimizeExpressions( + const json::array_t& input, + const std::string& optimizationLevel); + + memory::MemoryPool* pool_; + std::unique_ptr execCtx_; + TypeParser typeParser_; + VeloxExprConverter veloxExprConverter_; + RowExpressionConverter rowExpressionConverter_; +}; +} // namespace facebook::presto::expression diff --git a/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt index 28f73aff40b80..3d1a50e463450 100644 --- a/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt @@ -131,3 +131,36 @@ target_link_libraries( velox_exec_test_lib GTest::gtest GTest::gtest_main) + +add_executable(presto_expression_converter_test RowExpressionConverterTest.cpp) + +add_test( + NAME presto_expression_converter_test + COMMAND presto_expression_converter_test + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + +target_link_libraries( + presto_expression_converter_test + presto_expression_optimizer + presto_http + presto_type_test_utils + velox_exec_test_lib + GTest::gtest + GTest::gtest_main) + +add_executable(presto_expression_optimizer_test
RowExpressionOptimizerTest.cpp) + +add_test( + NAME presto_expression_optimizer_test + COMMAND presto_expression_optimizer_test + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + +target_link_libraries( + presto_expression_optimizer_test + presto_expression_optimizer + presto_http + presto_type_test_utils + velox_exec_test_lib + velox_parse_utils + GTest::gtest + GTest::gtest_main) diff --git a/presto-native-execution/presto_cpp/main/types/tests/RowExpressionConverterTest.cpp b/presto-native-execution/presto_cpp/main/types/tests/RowExpressionConverterTest.cpp new file mode 100644 index 0000000000000..c227157c06fc0 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/tests/RowExpressionConverterTest.cpp @@ -0,0 +1,176 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "presto_cpp/main/types/RowExpressionConverter.h" +#include +#include "presto_cpp/main/common/tests/test_json.h" +#include "presto_cpp/main/http/tests/HttpTestBase.h" +#include "presto_cpp/main/types/tests/TestUtils.h" +#include "velox/expression/CastExpr.h" +#include "velox/expression/ConjunctExpr.h" +#include "velox/expression/FieldReference.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/parse/TypeResolver.h" +#include "velox/vector/VectorStream.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +using namespace facebook::presto; +using namespace facebook::velox; + +// RowExpressionConverter is used only by RowExpressionOptimizer so unit tests +// are only added for simple expressions here. End-to-end tests to validate +// velox expression to Presto RowExpression conversion for different types of +// expressions can be found in TestDelegatingExpressionOptimizer.java in +// presto-native-sidecar-plugin. +class RowExpressionConverterTest + : public ::testing::Test, + public facebook::velox::test::VectorTestBase { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance({}); + } + + void SetUp() override { + rowExpressionConverter_ = + std::make_unique(pool_.get()); + } + + void testFile( + const std::vector& inputs, + const std::string& fileName) { + json::array_t expected = json::parse( + slurp(facebook::presto::test::utils::getDataPath(fileName))); + auto numExpr = inputs.size(); + for (auto i = 0; i < numExpr; i++) { + RowExpressionPtr inputRowExpr; + // The input row expression is provided by RowExpressionOptimizer in the + // sidecar. Since RowExpressionConverter is being tested here, we use the + // expected row expression to mock the input. The input row expression is + // only used for conversion to special form and call expressions. 
+ protocol::from_json(expected[i], inputRowExpr); + EXPECT_EQ( + rowExpressionConverter_->veloxToPrestoRowExpression( + inputs[i], inputRowExpr), + expected[i]); + } + } + + void testConstantExpressionConversion( + const std::vector& inputs) { + json::array_t expected = + json::parse(slurp(facebook::presto::test::utils::getDataPath( + "ConstantExpressionConversion.json"))); + auto numExpr = inputs.size(); + for (auto i = 0; i < numExpr; i++) { + auto constantExpr = + std::dynamic_pointer_cast(inputs[i]); + EXPECT_EQ( + rowExpressionConverter_->getConstantRowExpression(constantExpr), + expected[i]); + } + } + + std::unique_ptr rowExpressionConverter_; +}; + +TEST_F(RowExpressionConverterTest, constant) { + auto inputs = std::vector{ + std::make_shared(makeConstant(true, 1)), + std::make_shared(makeConstant(12, 1)), + std::make_shared(makeConstant(123, 1)), + std::make_shared(makeConstant(1234, 1)), + std::make_shared(makeConstant(12345, 1)), + std::make_shared( + makeConstant(std::numeric_limits::max(), 1)), + std::make_shared(makeConstant(123.456, 1)), + std::make_shared(makeConstant(1234.5678, 1)), + std::make_shared( + makeConstant(std::numeric_limits::quiet_NaN(), 1)), + std::make_shared(makeConstant("abcd", 1)), + std::make_shared( + makeNullConstant(TypeKind::BIGINT, 1)), + std::make_shared( + makeNullConstant(TypeKind::VARCHAR, 1))}; + testFile(inputs, "ConstantExpressionConversion.json"); + testConstantExpressionConversion(inputs); +} + +TEST_F(RowExpressionConverterTest, rowConstructor) { + auto inputs = std::vector{ + std::make_shared(makeConstantRow( + ROW({BIGINT(), REAL(), VARCHAR()}), + variant::row( + {static_cast(123456), + static_cast(12.345), + "abc"}), + 1)), + std::make_shared(makeConstantRow( + ROW({REAL(), VARCHAR(), INTEGER()}), + variant::row( + {std::numeric_limits::quiet_NaN(), + "", + std::numeric_limits::max()}), + 1))}; + testFile(inputs, "RowConstructorExpressionConversion.json"); +} + +TEST_F(RowExpressionConverterTest, variable) { + 
auto inputs = std::vector{ + std::make_shared( + VARCHAR(), std::vector{}, "c0"), + std::make_shared( + ARRAY(BIGINT()), std::vector{}, "c1"), + std::make_shared( + MAP(INTEGER(), REAL()), std::vector{}, "c2"), + std::make_shared( + ROW({SMALLINT(), VARCHAR(), DOUBLE()}), + std::vector{}, + "c3"), + std::make_shared( + MAP(ARRAY(TINYINT()), BOOLEAN()), std::vector{}, "c4"), + }; + testFile(inputs, "VariableExpressionConversion.json"); +} + +TEST_F(RowExpressionConverterTest, special) { + auto castExpr = std::make_shared(); + auto orExpr = std::make_shared(false); + auto andExpr = std::make_shared(true); + auto inputs = std::vector{ + castExpr->constructSpecialForm( + VARCHAR(), + {std::make_shared( + BIGINT(), std::vector{}, "c0")}, + false, + core::QueryConfig({})), + orExpr->constructSpecialForm( + BOOLEAN(), + {std::make_shared( + BOOLEAN(), std::vector{}, "c1"), + std::make_shared(makeConstant(true, 1))}, + false, + core::QueryConfig({})), + andExpr->constructSpecialForm( + BOOLEAN(), + {castExpr->constructSpecialForm( + BOOLEAN(), + {std::make_shared( + INTEGER(), std::vector{}, "c2")}, + false, + core::QueryConfig({})), + std::make_shared( + BOOLEAN(), std::vector{}, "c3")}, + false, + core::QueryConfig({}))}; + testFile(inputs, "CallAndSpecialFormExpressionConversion.json"); +} diff --git a/presto-native-execution/presto_cpp/main/types/tests/RowExpressionOptimizerTest.cpp b/presto-native-execution/presto_cpp/main/types/tests/RowExpressionOptimizerTest.cpp new file mode 100644 index 0000000000000..ace751c552117 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/tests/RowExpressionOptimizerTest.cpp @@ -0,0 +1,143 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/types/RowExpressionOptimizer.h" +#include +#include +#include "presto_cpp/main/common/tests/test_json.h" +#include "presto_cpp/main/http/tests/HttpTestBase.h" +#include "presto_cpp/main/types/tests/TestUtils.h" +#include "velox/expression/FieldReference.h" +#include "velox/expression/RegisterSpecialForm.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/parse/TypeResolver.h" +#include "velox/vector/VectorStream.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +using namespace facebook::presto; + +// RowExpressionOptimizerTest only tests for basic expression optimization. +// End-to-end tests for different types of expressions can be found in +// TestDelegatingExpressionOptimizer.java in presto-native-sidecar-plugin. 
+class RowExpressionOptimizerTest + : public ::testing::Test, + public facebook::velox::test::VectorTestBase { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance({}); + } + + void SetUp() override { + parse::registerTypeResolver(); + functions::prestosql::registerAllScalarFunctions("presto.default."); + exec::registerFunctionCallToSpecialForms(); + rowExpressionOptimizer_ = + std::make_unique(pool()); + } + + void testFile(const std::string& testName) { + auto input = slurp(facebook::presto::test::utils::getDataPath( + fmt::format("{}Input.json", testName))); + json::array_t inputExpressions = json::parse(input); + proxygen::HTTPMessage httpMessage; + httpMessage.getHeaders().set( + "X-Presto-Time-Zone", "America/Bahia_Banderas"); + httpMessage.getHeaders().set( + "X-Presto-Expression-Optimizer-Level", "OPTIMIZED"); + auto result = rowExpressionOptimizer_->optimize( + httpMessage.getHeaders(), inputExpressions); + + EXPECT_EQ(result.second, true); + json resultExpressions = result.first; + EXPECT_EQ(resultExpressions.is_array(), true); + auto expected = slurp(facebook::presto::test::utils::getDataPath( + fmt::format("{}Expected.json", testName))); + json::array_t expectedExpressions = json::parse(expected); + auto numExpressions = resultExpressions.size(); + EXPECT_EQ(numExpressions, expectedExpressions.size()); + for (auto i = 0; i < numExpressions; i++) { + EXPECT_EQ(resultExpressions.at(i), expectedExpressions.at(i)); + } + } + + std::unique_ptr rowExpressionOptimizer_; +}; + +TEST_F(RowExpressionOptimizerTest, simple) { + // Files SimpleExpressions{Input|Expected}.json contain the input and expected + // JSON representing the RowExpressions resulting from the following queries: + // 1. select 1 + 2; + // 2. select abs(-11) + ceil(cast(3.4 as double)) + floor(cast(5.6 as + // double)); + // 3. select 2 between 1 and 3; + // Simple expression evaluation with constant folding is verified here. 
+ testFile("SimpleExpressions"); +} + +TEST_F(RowExpressionOptimizerTest, switchSpecialForm) { + // Files SwitchSpecialFormExpressions{Input|Expected}.json contain the input + // and expected JSON representing the RowExpressions resulting from the + // following queries: + // Simple form: + // 1. select case 1 when 1 then 32 + 1 when 1 then 34 end; + // 2. select case null when true then 32 else 33 end; + // 3. select case 33 when 0 then 0 when 33 then 1 when unbound_long then 2 + // else 1 end from tmp; + // Searched form: + // 1. select case when true then 32 + 1 end; + // 2. select case when false then 1 else 34 - 1 end; + // 3. select case when ARRAY[CAST(1 AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)] + // then 'matched' else 'not_matched' end; + // Evaluation of both simple and searched forms of `SWITCH` special form + // expression is verified here. + testFile("SwitchSpecialFormExpressions"); +} + +TEST_F(RowExpressionOptimizerTest, inSpecialForm) { + // Files InSpecialFormExpressions{Input|Expected}.json contain the input + // and expected JSON representing the RowExpressions resulting from the + // following queries: + // 1. select 3 in (2, null, 3, 5); + // 2. select 'foo' in ('bar', 'baz', 'buz', 'blah'); + // 3. select null in (2, null, 3, 5); + // 4. select ROW(1) IN (ROW(2), ROW(1), ROW(2)); + // 5. select MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], + // ARRAY[2, null]), null); + // 6. select MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 3], + // ARRAY[1, null])); + // Evaluation of `IN` special form expression for primitive and complex types + // is verified here. + testFile("InSpecialFormExpressions"); +} + +TEST_F(RowExpressionOptimizerTest, specialFormRewrites) { + // Files SpecialFormExpressionRewrites{Input|Expected}.json contain the input + // and expected JSON representing the RowExpressions resulting from the + // following queries: + // 1. select if(1 < 2, 2, 3); + // 2. select (1 < 2) and (2 < 3); + // 3. 
select (1 < 2) or (2 < 3); + // Special form expression rewrites are verified here. + testFile("SpecialFormExpressionRewrites"); +} + +TEST_F(RowExpressionOptimizerTest, betweenCallExpression) { + // Files BetweenCallExpressions{Input|Expected}.json contain the input and + // expected JSON representing the RowExpressions resulting from the following + // queries: + // 1. select 2 between 3 and 4; + // 2. select null between 2 and 4; + // 3. select 'cc' between 'b' and 'd'; + // Evaluation of `BETWEEN` call expression is verified here. + testFile("BetweenCallExpressions"); +} diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/BetweenCallExpressionsExpected.json b/presto-native-execution/presto_cpp/main/types/tests/data/BetweenCallExpressionsExpected.json new file mode 100644 index 0000000000000..e39a9b35eaf6b --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/tests/data/BetweenCallExpressionsExpected.json @@ -0,0 +1,17 @@ +[ + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAA=" + }, + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAY0=" + }, + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + } +] diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/BetweenCallExpressionsInput.json b/presto-native-execution/presto_cpp/main/types/tests/data/BetweenCallExpressionsInput.json new file mode 100644 index 0000000000000..a65e8488ceb36 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/tests/data/BetweenCallExpressionsInput.json @@ -0,0 +1,191 @@ +[ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": 
"CQAAAElOVF9BUlJBWQEAAAAABAAAAA==" + } + ], + "displayName": "BETWEEN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$between", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "sourceLocation": { + "column": 1, + "line": 1 + }, + "type": "unknown", + "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA==" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "unknown" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer", + "sourceLocation": { + "column": 1, + "line": 1 + } + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAABAAAAA==" + } + ], + "displayName": "BETWEEN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$between", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean", + "sourceLocation": { + "column": 1, + "line": 1 + } + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "varchar(2)", + "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAIAAAAAAgAAAGNj" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "varchar(1)", + "valueBlock": 
"DgAAAFZBUklBQkxFX1dJRFRIAQAAAAEAAAAAAQAAAGI=" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "varchar(1)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "varchar(2)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "varchar(2)" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "varchar(1)", + "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAEAAAAAAQAAAGQ=" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "varchar(1)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "varchar(2)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "varchar(2)" + } + ], + "displayName": "BETWEEN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "varchar(2)", + "varchar(2)", + "varchar(2)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$between", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + } +] diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/CallAndSpecialFormExpressionConversion.json b/presto-native-execution/presto_cpp/main/types/tests/data/CallAndSpecialFormExpressionConversion.json new file mode 100644 index 0000000000000..ddf74d986ed48 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/tests/data/CallAndSpecialFormExpressionConversion.json @@ -0,0 +1,83 @@ +[ + { + "@type": "call", + "arguments": [ + { + "@type": "variable", + "name": "c0", + "type": "bigint" + } + ], + "displayName": "presto.default.$operator$cast", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "bigint" + ], + 
"kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "varchar", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "varchar" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "variable", + "name": "c1", + "type": "boolean" + }, + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + } + ], + "form": "OR", + "returnType": "boolean" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "variable", + "name": "c2", + "type": "integer" + } + ], + "displayName": "presto.default.$operator$cast", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + }, + { + "@type": "variable", + "name": "c3", + "type": "boolean" + } + ], + "form": "AND", + "returnType": "boolean" + } +] diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/ConstantExpressionConversion.json b/presto-native-execution/presto_cpp/main/types/tests/data/ConstantExpressionConversion.json new file mode 100644 index 0000000000000..0e962d0153dc6 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/tests/data/ConstantExpressionConversion.json @@ -0,0 +1,62 @@ +[ + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + }, + { + "@type": "constant", + "type": "tinyint", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAw=" + }, + { + "@type": "constant", + "type": "smallint", + "valueBlock": "CwAAAFNIT1JUX0FSUkFZAQAAAAB7AA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAA0gQAAA==" + }, + { + "@type": "constant", + "type": "bigint", + "valueBlock": 
"CgAAAExPTkdfQVJSQVkBAAAAADkwAAAAAAAA" + }, + { + "@type": "constant", + "type": "bigint", + "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAAP////////9/" + }, + { + "@type": "constant", + "type": "real", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAeen2Qg==" + }, + { + "@type": "constant", + "type": "double", + "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAAK36XG1FSpNA" + }, + { + "@type": "constant", + "type": "double", + "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAAAAAAAAAAPh/" + }, + { + "@type": "constant", + "type": "varchar", + "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAQAAAAABAAAAGFiY2Q=" + }, + { + "@type": "constant", + "type": "bigint", + "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAAaA=" + }, + { + "@type": "constant", + "type": "varchar", + "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAAAAAABhQAAAAA=" + } +] diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/InSpecialFormExpressionsExpected.json b/presto-native-execution/presto_cpp/main/types/tests/data/InSpecialFormExpressionsExpected.json new file mode 100644 index 0000000000000..2c3009fd58d53 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/tests/data/InSpecialFormExpressionsExpected.json @@ -0,0 +1,32 @@ +[ + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + }, + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAA=" + }, + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAYU=" + }, + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + }, + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAYU=" + }, + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAA=" + } +] diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/InSpecialFormExpressionsInput.json b/presto-native-execution/presto_cpp/main/types/tests/data/InSpecialFormExpressionsInput.json 
new file mode 100644 index 0000000000000..6ed1e7ebeab1a --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/tests/data/InSpecialFormExpressionsInput.json @@ -0,0 +1,895 @@ +[ + { + "@type": "special", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "sourceLocation": { + "column": 10, + "line": 1 + }, + "type": "unknown", + "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA==" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "unknown" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer", + "sourceLocation": { + "column": 10, + "line": 1 + } + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAABQAAAA==" + } + ], + "form": "IN", + "returnType": "boolean", + "sourceLocation": { + "column": 10, + "line": 1 + } + }, + { + "@type": "special", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "varchar(3)", + "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAMAAAAAAwAAAGZvbw==" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "varchar(3)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "varchar(4)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "varchar(4)" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", 
+ "type": "varchar(3)", + "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAMAAAAAAwAAAGJhcg==" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "varchar(3)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "varchar(4)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "varchar(4)" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "varchar(3)", + "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAMAAAAAAwAAAGJheg==" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "varchar(3)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "varchar(4)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "varchar(4)" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "varchar(3)", + "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAMAAAAAAwAAAGJ1eg==" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "varchar(3)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "varchar(4)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "varchar(4)" + }, + { + "@type": "constant", + "type": "varchar(4)", + "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAQAAAAABAAAAGJsYWg=" + } + ], + "form": "IN", + "returnType": "boolean" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "sourceLocation": { + "column": 1, + "line": 1 + }, + "type": "unknown", + "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA==" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + 
"signature": { + "argumentTypes": [ + "unknown" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer", + "sourceLocation": { + "column": 1, + "line": 1 + } + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "sourceLocation": { + "column": 13, + "line": 1 + }, + "type": "unknown", + "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA==" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "unknown" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer", + "sourceLocation": { + "column": 13, + "line": 1 + } + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAABQAAAA==" + } + ], + "form": "IN", + "returnType": "boolean", + "sourceLocation": { + "column": 1, + "line": 1 + } + }, + { + "@type": "special", + "arguments": [ + { + "@type": "special", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + } + ], + "form": "ROW_CONSTRUCTOR", + "returnType": "row(integer)" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "form": "ROW_CONSTRUCTOR", + "returnType": "row(integer)" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + } + ], + "form": 
"ROW_CONSTRUCTOR", + "returnType": "row(integer)" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "form": "ROW_CONSTRUCTOR", + "returnType": "row(integer)" + } + ], + "form": "IN", + "returnType": "boolean" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "displayName": "ARRAY", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.array_constructor", + "returnType": "array(integer)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "array(integer)" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "sourceLocation": { + "column": 27, + "line": 1 + }, + "type": "unknown", + "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA==" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "unknown" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer", + "sourceLocation": { + "column": 27, + "line": 1 + } + } + ], + "displayName": "ARRAY", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": 
"presto.default.array_constructor", + "returnType": "array(integer)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "array(integer)", + "sourceLocation": { + "column": 27, + "line": 1 + } + } + ], + "displayName": "map", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "array(integer)", + "array(integer)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.map", + "returnType": "map(integer,integer)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "map(integer,integer)", + "sourceLocation": { + "column": 27, + "line": 1 + } + }, + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "displayName": "ARRAY", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.array_constructor", + "returnType": "array(integer)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "array(integer)" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "sourceLocation": { + "column": 64, + "line": 1 + }, + "type": "unknown", + "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA==" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "unknown" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "integer", + "typeVariableConstraints": [], + 
"variableArity": false + } + }, + "returnType": "integer", + "sourceLocation": { + "column": 64, + "line": 1 + } + } + ], + "displayName": "ARRAY", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.array_constructor", + "returnType": "array(integer)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "array(integer)", + "sourceLocation": { + "column": 64, + "line": 1 + } + } + ], + "displayName": "map", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "array(integer)", + "array(integer)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.map", + "returnType": "map(integer,integer)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "map(integer,integer)", + "sourceLocation": { + "column": 64, + "line": 1 + } + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "sourceLocation": { + "column": 72, + "line": 1 + }, + "type": "unknown", + "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA==" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "unknown" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "map(integer,integer)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "map(integer,integer)", + "sourceLocation": { + "column": 72, + "line": 1 + } + } + ], + "form": "IN", + "returnType": "boolean", + "sourceLocation": { + "column": 27, + "line": 1 + } + }, + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + 
"type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "displayName": "ARRAY", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.array_constructor", + "returnType": "array(integer)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "array(integer)" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "sourceLocation": { + "column": 27, + "line": 1 + }, + "type": "unknown", + "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA==" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "unknown" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer", + "sourceLocation": { + "column": 27, + "line": 1 + } + } + ], + "displayName": "ARRAY", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.array_constructor", + "returnType": "array(integer)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "array(integer)", + "sourceLocation": { + "column": 27, + "line": 1 + } + } + ], + "displayName": "map", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "array(integer)", + "array(integer)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.map", + "returnType": "map(integer,integer)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + 
"returnType": "map(integer,integer)", + "sourceLocation": { + "column": 27, + "line": 1 + } + }, + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + } + ], + "displayName": "ARRAY", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.array_constructor", + "returnType": "array(integer)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "array(integer)" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "sourceLocation": { + "column": 64, + "line": 1 + }, + "type": "unknown", + "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA==" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "unknown" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer", + "sourceLocation": { + "column": 64, + "line": 1 + } + } + ], + "displayName": "ARRAY", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.array_constructor", + "returnType": "array(integer)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "array(integer)", + "sourceLocation": { + "column": 64, + "line": 1 + } + } + ], + "displayName": "map", + 
"functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "array(integer)", + "array(integer)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.map", + "returnType": "map(integer,integer)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "map(integer,integer)", + "sourceLocation": { + "column": 64, + "line": 1 + } + } + ], + "displayName": "=", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "map(integer,integer)", + "map(integer,integer)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$equal", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean", + "sourceLocation": { + "column": 27, + "line": 1 + } + } +] diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/RowConstructorExpressionConversion.json b/presto-native-execution/presto_cpp/main/types/tests/data/RowConstructorExpressionConversion.json new file mode 100644 index 0000000000000..8ff38f4150f41 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/tests/data/RowConstructorExpressionConversion.json @@ -0,0 +1,46 @@ +[ + { + "@type": "special", + "arguments": [ + { + "@type": "constant", + "type": "bigint", + "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAAEDiAQAAAAAA" + }, + { + "@type": "constant", + "type": "real", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAH4VFQQ==" + }, + { + "@type": "constant", + "type": "varchar", + "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAMAAAAAAwAAAGFiYw==" + } + ], + "form": "ROW_CONSTRUCTOR", + "returnType": "row(bigint,real,varchar)" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "constant", + "type": "real", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAADAfw==" + }, + { + "@type": "constant", + "type": "varchar", + "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAAAAAAAAAAAAA==" + }, + { + "@type": "constant", 
+ "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAA////fw==" + } + ], + "form": "ROW_CONSTRUCTOR", + "returnType": "row(real,varchar,integer)" + } +] diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/SimpleExpressionsExpected.json b/presto-native-execution/presto_cpp/main/types/tests/data/SimpleExpressionsExpected.json new file mode 100644 index 0000000000000..e9b014fecd3b2 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/tests/data/SimpleExpressionsExpected.json @@ -0,0 +1,17 @@ +[ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + }, + { + "@type": "constant", + "type": "double", + "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAAAAAAAAAADRA" + }, + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + } +] diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/SimpleExpressionsInput.json b/presto-native-execution/presto_cpp/main/types/tests/data/SimpleExpressionsInput.json new file mode 100644 index 0000000000000..dd2a0497703eb --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/tests/data/SimpleExpressionsInput.json @@ -0,0 +1,278 @@ +[ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "displayName": "ADD", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$add", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": 
"call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAACwAAAA==" + } + ], + "displayName": "NEGATION", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$negation", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer" + } + ], + "displayName": "abs", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.abs", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "decimal(2,1)", + "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAACIAAAAAAAAA" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "decimal(2,1)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + } + ], + "displayName": "ceil", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "double" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.ceil", + "returnType": "double", + "typeVariableConstraints": [], + 
"variableArity": false + } + }, + "returnType": "double" + } + ], + "displayName": "ADD", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "double", + "double" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$add", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "decimal(2,1)", + "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAADgAAAAAAAAA" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "decimal(2,1)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + } + ], + "displayName": "floor", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "double" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.floor", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + } + ], + "displayName": "ADD", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "double", + "double" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$add", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": 
"CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + } + ], + "displayName": "BETWEEN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$between", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + } +] diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/SpecialFormExpressionRewritesExpected.json b/presto-native-execution/presto_cpp/main/types/tests/data/SpecialFormExpressionRewritesExpected.json new file mode 100644 index 0000000000000..2ce6acb1ab46e --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/tests/data/SpecialFormExpressionRewritesExpected.json @@ -0,0 +1,17 @@ +[ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + }, + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + } +] diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/SpecialFormExpressionRewritesInput.json b/presto-native-execution/presto_cpp/main/types/tests/data/SpecialFormExpressionRewritesInput.json new file mode 100644 index 0000000000000..77802722541b0 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/tests/data/SpecialFormExpressionRewritesInput.json @@ -0,0 +1,193 @@ +[ + { + "@type": "special", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", 
+ "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + } + ], + "form": "IF", + "returnType": "integer" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + } + ], + "form": "AND", + "returnType": "boolean" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + 
"@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + } + ], + "form": "OR", + "returnType": "boolean" + } +] diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/SwitchSpecialFormExpressionsExpected.json b/presto-native-execution/presto_cpp/main/types/tests/data/SwitchSpecialFormExpressionsExpected.json new file mode 100644 index 0000000000000..40667e2976858 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/tests/data/SwitchSpecialFormExpressionsExpected.json @@ -0,0 +1,32 @@ +[ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIQAAAA==" + }, + { + "@type": "constant", + 
"type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIQAAAA==" + }, + { + "@type": "constant", + "type": "varchar", + "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAcAAAAABwAAAG1hdGNoZWQ=" + } +] diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/SwitchSpecialFormExpressionsInput.json b/presto-native-execution/presto_cpp/main/types/tests/data/SwitchSpecialFormExpressionsInput.json new file mode 100644 index 0000000000000..5290b996ce15c --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/tests/data/SwitchSpecialFormExpressionsInput.json @@ -0,0 +1,596 @@ +[ + { + "@type": "special", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIAAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + } + ], + "displayName": "ADD", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$add", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer" + } + ], + "form": "WHEN", + "returnType": "integer", + "sourceLocation": { + "column": 8, + "line": 1 + } + }, + { + "@type": "special", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIgAAAA==" + } + ], + "form": "WHEN", + "returnType": "integer", + "sourceLocation": { + "column": 27, + "line": 1 + } + }, + { + "@type": "constant", + 
"type": "integer", + "valueBlock": "AwAAAFJMRQEAAAAJAAAASU5UX0FSUkFZAQAAAAGA" + } + ], + "form": "SWITCH", + "returnType": "integer", + "sourceLocation": { + "column": 8, + "line": 1 + } + }, + { + "@type": "special", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "sourceLocation": { + "column": 6, + "line": 1 + }, + "type": "unknown", + "valueBlock": "AwAAAFJMRQEAAAAKAAAAQllURV9BUlJBWQEAAAABgA==" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "unknown" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean", + "sourceLocation": { + "column": 6, + "line": 1 + } + }, + { + "@type": "special", + "arguments": [ + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIAAAAA==" + } + ], + "form": "WHEN", + "returnType": "integer", + "sourceLocation": { + "column": 11, + "line": 1 + } + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIQAAAA==" + } + ], + "form": "SWITCH", + "returnType": "integer", + "sourceLocation": { + "column": 6, + "line": 1 + } + }, + { + "@type": "special", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIQAAAA==" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "bigint", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "bigint" + }, + { + "@type": "special", + "arguments": [ + { + 
"@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAAAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAAAAAA==" + } + ], + "form": "WHEN", + "returnType": "integer", + "sourceLocation": { + "column": 9, + "line": 1 + } + }, + { + "@type": "special", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + } + ], + "form": "WHEN", + "returnType": "integer", + "sourceLocation": { + "column": 23, + "line": 1 + } + }, + { + "@type": "special", + "arguments": [ + { + "@type": "variable", + "name": "unbound_long", + "type": "bigint" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "form": "WHEN", + "returnType": "integer", + "sourceLocation": { + "column": 43, + "line": 1 + } + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + } + ], + "form": "SWITCH", + "returnType": "integer", + "sourceLocation": { + "column": 9, + "line": 1 + } + }, + { + "@type": "special", + "arguments": [ + { + "@type": "constant", + "sourceLocation": { + "column": 1, + "line": 1 + }, + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIAAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + } + ], + "displayName": "ADD", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + 
"name": "presto.default.$operator$add", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer" + } + ], + "form": "WHEN", + "returnType": "integer", + "sourceLocation": { + "column": 6, + "line": 1 + } + }, + { + "@type": "constant", + "sourceLocation": { + "column": 1, + "line": 1 + }, + "type": "integer", + "valueBlock": "AwAAAFJMRQEAAAAJAAAASU5UX0FSUkFZAQAAAAGA" + } + ], + "form": "SWITCH", + "returnType": "integer", + "sourceLocation": { + "column": 1, + "line": 1 + } + }, + { + "@type": "special", + "arguments": [ + { + "@type": "constant", + "sourceLocation": { + "column": 1, + "line": 1 + }, + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAA=" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + } + ], + "form": "WHEN", + "returnType": "integer", + "sourceLocation": { + "column": 6, + "line": 1 + } + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAIgAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + } + ], + "displayName": "SUBTRACT", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$subtract", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer" + } + ], + "form": "SWITCH", + "returnType": "integer", + "sourceLocation": { + "column": 1, + "line": 1 + } + }, + { + "@type": "special", + "arguments": [ + { + "@type": "constant", + "sourceLocation": { + "column": 83, + "line": 1 + }, + "type": "boolean", + "valueBlock": 
"CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "bigint", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "bigint" + } + ], + "displayName": "ARRAY", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "bigint" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.array_constructor", + "returnType": "array(bigint)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "array(bigint)" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "bigint", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "bigint" + } + ], + "displayName": "ARRAY", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "bigint" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.array_constructor", + "returnType": "array(bigint)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "array(bigint)" + } + ], + "displayName": "EQUAL", + "functionHandle": { + "@type": "$static", + "signature": { 
+ "argumentTypes": [ + "array(bigint)", + "array(bigint)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$equal", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "varchar(7)", + "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAcAAAAABwAAAG1hdGNoZWQ=" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "varchar(7)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "varchar(11)", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "varchar(11)" + } + ], + "form": "WHEN", + "returnType": "varchar(11)", + "sourceLocation": { + "column": 36, + "line": 1 + } + }, + { + "@type": "constant", + "type": "varchar(11)", + "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAsAAAAACwAAAG5vdF9tYXRjaGVk" + } + ], + "form": "SWITCH", + "returnType": "varchar(11)", + "sourceLocation": { + "column": 83, + "line": 1 + } + } +] diff --git a/presto-native-execution/presto_cpp/main/types/tests/data/VariableExpressionConversion.json b/presto-native-execution/presto_cpp/main/types/tests/data/VariableExpressionConversion.json new file mode 100644 index 0000000000000..75f524842788e --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/tests/data/VariableExpressionConversion.json @@ -0,0 +1,27 @@ +[ + { + "@type": "variable", + "name": "c0", + "type": "varchar" + }, + { + "@type": "variable", + "name": "c1", + "type": "array(bigint)" + }, + { + "@type": "variable", + "name": "c2", + "type": "map(integer,real)" + }, + { + "@type": "variable", + "name": "c3", + "type": "row(smallint,varchar,double)" + }, + { + "@type": "variable", + "name": "c4", + "type": "map(array(tinyint),boolean)" + } +] diff --git 
a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java index 825fbb1bc4387..ab9ca3568a404 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java @@ -39,6 +39,7 @@ import java.io.File; import java.io.IOException; import java.io.UncheckedIOException; +import java.net.ServerSocket; import java.net.URI; import java.nio.file.Files; import java.nio.file.Path; @@ -496,6 +497,11 @@ public static NativeQueryRunnerParameters getNativeQueryRunnerParameters() } public static Optional> getExternalWorkerLauncher(String catalogName, String prestoServerPath, int cacheMaxSize, Optional remoteFunctionServerUds, Boolean failOnNestedLoopJoin, boolean isCoordinatorSidecarEnabled) + { + return getExternalWorkerLauncher(catalogName, prestoServerPath, OptionalInt.empty(), cacheMaxSize, remoteFunctionServerUds, failOnNestedLoopJoin, isCoordinatorSidecarEnabled); + } + + public static Optional> getExternalWorkerLauncher(String catalogName, String prestoServerPath, OptionalInt port, int cacheMaxSize, Optional remoteFunctionServerUds, Boolean failOnNestedLoopJoin, boolean isCoordinatorSidecarEnabled) { return Optional.of((workerIndex, discoveryUri) -> { @@ -509,7 +515,8 @@ public static Optional> getExternalWorkerLaunc String configProperties = format("discovery.uri=%s%n" + "presto.version=testversion%n" + "system-memory-gb=4%n" + - "http-server.http.port=0%n", discoveryUri); + "native-sidecar=true%n" + + "http-server.http.port=%d", discoveryUri, port.orElse(0)); if (isCoordinatorSidecarEnabled) { configProperties = format("%s%n" + @@ -573,6 +580,45 @@ public static Optional> getExternalWorkerLaunc }); } + public static Process getNativeSidecarProcess(URI 
discoveryUri, int port) + throws IOException + { + NativeQueryRunnerParameters nativeQueryRunnerParameters = getNativeQueryRunnerParameters(); + return getNativeSidecarProcess(nativeQueryRunnerParameters.serverBinary.toString(), discoveryUri, port); + } + + public static Process getNativeSidecarProcess(String prestoServerPath, URI discoveryUri, int port) + throws IOException + { + Path tempDirectoryPath = Files.createTempDirectory(PrestoNativeQueryRunnerUtils.class.getSimpleName()); + log.info("Temp directory for Sidecar: %s", tempDirectoryPath.toString()); + + // Write config file + String configProperties = format("discovery.uri=%s%n" + + "presto.version=testversion%n" + + "system-memory-gb=4%n" + + "native-sidecar=true%n" + + "http-server.http.port=%d", discoveryUri, port); + + Files.write(tempDirectoryPath.resolve("config.properties"), configProperties.getBytes()); + Files.write(tempDirectoryPath.resolve("node.properties"), + format("node.id=%s%n" + + "node.internal-address=127.0.0.1%n" + + "node.environment=testing%n" + + "node.location=test-location", UUID.randomUUID()).getBytes()); + + // TODO: sidecars require that a catalog directory exist + Path catalogDirectoryPath = tempDirectoryPath.resolve("catalog"); + Files.createDirectory(catalogDirectoryPath); + + return new ProcessBuilder(prestoServerPath) + .directory(tempDirectoryPath.toFile()) + .redirectErrorStream(true) + .redirectOutput(ProcessBuilder.Redirect.to(tempDirectoryPath.resolve("sidecar.out").toFile())) + .redirectError(ProcessBuilder.Redirect.to(tempDirectoryPath.resolve("sidecar.out").toFile())) + .start(); + } + public static class NativeQueryRunnerParameters { public final Path serverBinary; @@ -619,4 +665,12 @@ private static Table createHiveSymlinkTable(String databaseName, String tableNam Optional.empty(), Optional.empty()); } + + public static int findRandomPortForWorker() + throws IOException + { + try (ServerSocket socket = new ServerSocket(0)) { + return socket.getLocalPort(); + } + } 
} diff --git a/presto-native-sidecar-plugin/pom.xml b/presto-native-sidecar-plugin/pom.xml index 6b8f1badd289d..4a62cab3f77a5 100644 --- a/presto-native-sidecar-plugin/pom.xml +++ b/presto-native-sidecar-plugin/pom.xml @@ -200,6 +200,12 @@ test + + javax.ws.rs + javax.ws.rs-api + test + + com.facebook.presto presto-client diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/ForSidecarInfo.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/ForSidecarInfo.java new file mode 100644 index 0000000000000..8bf10b20223d2 --- /dev/null +++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/ForSidecarInfo.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.session.sql.expressions; + +import com.google.inject.BindingAnnotation; + +import java.lang.annotation.Retention; + +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Retention(RUNTIME) +@BindingAnnotation +public @interface ForSidecarInfo +{ +} diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizer.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizer.java new file mode 100644 index 0000000000000..1c7dfd30b38e0 --- /dev/null +++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizer.java @@ -0,0 +1,353 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.SourceLocation; +import com.facebook.presto.spi.function.FunctionMetadataManager; +import com.facebook.presto.spi.function.StandardFunctionResolution; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.LambdaDefinitionExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.RowExpressionVisitor; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.ArrayDeque; +import java.util.HashMap; +import java.util.IdentityHashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; + +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.EVALUATED; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Collections.newSetFromMap; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toMap; + +public class NativeExpressionOptimizer + implements ExpressionOptimizer +{ + private final FunctionMetadataManager functionMetadataManager; + private final StandardFunctionResolution resolution; + private final NativeSidecarExpressionInterpreter rowExpressionInterpreterService; + + public NativeExpressionOptimizer( + NativeSidecarExpressionInterpreter rowExpressionInterpreterService, + 
FunctionMetadataManager functionMetadataManager, + StandardFunctionResolution resolution) + { + this.rowExpressionInterpreterService = requireNonNull(rowExpressionInterpreterService, "rowExpressionInterpreterService is null"); + this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); + this.resolution = requireNonNull(resolution, "resolution is null"); + } + + @Override + public RowExpression optimize(RowExpression expression, Level level, ConnectorSession session, Function variableResolver) + { + // Collect expressions to optimize + CollectingVisitor collectingVisitor = new CollectingVisitor(functionMetadataManager, level, resolution); + expression.accept(collectingVisitor, variableResolver); + List expressionsToOptimize = collectingVisitor.getExpressionsToOptimize(); + Map expressions = expressionsToOptimize.stream() + .collect(toMap( + Function.identity(), + rowExpression -> rowExpression.accept( + new ReplacingVisitor( + variable -> toRowExpression(variable.getSourceLocation(), variableResolver.apply(variable), variable.getType())), + null), + (a, b) -> a)); + if (expressions.isEmpty()) { + return expression; + } + + // Constants can be trivially replaced without invoking the interpreter. Move them into a separate map. 
+ Map constants = new HashMap<>(); + Iterator> entries = expressions.entrySet().iterator(); + while (entries.hasNext()) { + Map.Entry entry = entries.next(); + if (entry.getValue() instanceof ConstantExpression) { + constants.put(entry.getKey(), entry.getValue()); + entries.remove(); + } + } + + // Optimize the expressions using the sidecar interpreter + Map replacements = new HashMap<>(); + if (!expressions.isEmpty()) { + replacements.putAll(rowExpressionInterpreterService.optimizeBatch(session, expressions, level)); + } + + // Add back in the constants + replacements.putAll(constants); + + // Replace all the expressions in the original expression with the optimized expressions + return toRowExpression(expression.getSourceLocation(), expression.accept(new ReplacingVisitor(replacements), null), expression.getType()); + } + + /** + * This visitor collects expressions that can be optimized by the sidecar interpreter. + */ + private static class CollectingVisitor + implements RowExpressionVisitor + { + private final FunctionMetadataManager functionMetadataManager; + private final Level optimizationLevel; + private final StandardFunctionResolution resolution; + private final Set expressionsToOptimize = newSetFromMap(new IdentityHashMap<>()); + private final Set hasOptimizedChildren = newSetFromMap(new IdentityHashMap<>()); + + public CollectingVisitor(FunctionMetadataManager functionMetadataManager, Level optimizationLevel, StandardFunctionResolution resolution) + { + this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); + this.optimizationLevel = requireNonNull(optimizationLevel, "optimizationLevel is null"); + this.resolution = requireNonNull(resolution, "resolution is null"); + } + + @Override + public Void visitExpression(RowExpression node, Object context) + { + visitNode(node, false); + return null; + } + + @Override + public Void visitConstant(ConstantExpression node, Object context) + { + visitNode(node, 
true); + return null; + } + + @Override + public Void visitVariableReference(VariableReferenceExpression node, Object context) + { + Object value = ((Function) context).apply(node); + if (value == null || value instanceof RowExpression) { + visitNode(node, false); + return null; + } + visitNode(node, true); + return null; + } + + @Override + public Void visitCall(CallExpression node, Object context) + { + // If the optimization level is not EVALUATED, then we cannot optimize non-deterministic functions. NOTE(review): levels are compared through ordinal(), so this check depends on the declaration order of Level; confirm that order is stable + boolean isDeterministic = functionMetadataManager.getFunctionMetadata(node.getFunctionHandle()).isDeterministic(); + boolean canBeEvaluated = (optimizationLevel.ordinal() < EVALUATED.ordinal() && isDeterministic) || + optimizationLevel.ordinal() == EVALUATED.ordinal(); + + // All arguments must be optimizable in order to evaluate the function; peek visits each child so its status is recorded before reduce tests it + boolean allConstantFoldable = node.getArguments().stream() + .peek(child -> child.accept(this, context)) + .reduce(true, (a, b) -> canBeOptimized(b) && a, (a, b) -> a && b); + if (canBeEvaluated && allConstantFoldable) { + visitNode(node, true); + return null; + } + + // If it's a cast and the type is already the same, then it's constant foldable + if (resolution.isCastFunction(node.getFunctionHandle()) + && node.getArguments().size() == 1 + && node.getType().equals(node.getArguments().get(0).getType())) { + visitNode(node, true); + return null; + } + visitNode(node, false); + return null; + } + + @Override + public Void visitSpecialForm(SpecialFormExpression node, Object context) + { + // Most special form expressions short circuit, meaning that they potentially don't evaluate all arguments. For example, the AND expression + // will stop evaluating arguments as soon as it finds a false argument. 
Because a sub-expression could be simplified into a constant, and this + // constant could cause the expression to short circuit, if there is at least one argument which is optimizable, then the entire expression should + // be sent to the sidecar to be optimized. + boolean anyArgumentsOptimizable = node.getArguments().stream() + .peek(child -> child.accept(this, context)) + .reduce(false, (a, b) -> canBeOptimized(b) || a, (a, b) -> a || b); + + // If any argument is optimizable, then the whole expression is sent to the sidecar to be optimized + if (anyArgumentsOptimizable) { + visitNode(node, true); + return null; + } + + // If the special form is COALESCE, then we can optimize it if there are any duplicate arguments + if (node.getForm() == COALESCE) { + ImmutableSet.Builder builder = ImmutableSet.builder(); + // Check whether there are any duplicate arguments; these can be de-duplicated + for (RowExpression argument : node.getArguments()) { + // The duplicate argument must either be a leaf (variable reference) or constant foldable + if (canBeOptimized(argument)) { + builder.add(argument); + } + } + // If there were any duplicates, or if there's no arguments (cancel out), or if there's only one argument (just return it), + // then it's also constant foldable. NOTE(review): the set is built from a subset of the arguments, so its size can never exceed getArguments().size() and this condition is always true; a duplicate check would need a strict '<' over all arguments - confirm the intent + boolean canBeOptimized = builder.build().size() <= node.getArguments().size() || node.getArguments().size() <= 1; + if (canBeOptimized) { + visitNode(node, true); + return null; + } + } + visitNode(node, false); + return null; + } + + @Override + public Void visitLambda(LambdaDefinitionExpression node, Object context) + { + node.getBody().accept(this, (Function) variable -> variable); + if (canBeOptimized(node.getBody())) { + visitNode(node, true); + return null; + } + visitNode(node, false); + return null; + } + + public boolean canBeOptimized(RowExpression rowExpression) + { + return expressionsToOptimize.contains(rowExpression); + } + + private void visitNode(RowExpression node, boolean canBeOptimized) + { + 
requireNonNull(node, "node is null"); + // If the present node can be optimized, then we send the whole expression. Because an expression may consist of many + // sub-expressions, we need to ensure that we don't send the sub-expression along with its parent expression. For example, + // if we have the expression (a + b) + c, and we can optimize a + b, then we don't want to send a + b to the sidecar, because + // it will be optimized twice. Instead, we want to send (a + b) + c to the sidecar, and then remove a + b from the list of + // expressions to optimize. + // We need to traverse the entire subtree of possible expressions to optimize because some special form expressions may + // short circuit, and we need to ensure that we don't send the sub-expression to the sidecar if the parent expression is + // constant foldable. For example, consider the expression false AND (true OR a). Although the expression true OR a is + // constant foldable, the parent expression is also constant foldable, and we don't want to send both the parent expression + // and the sub-expression to the sidecar because the entire expression can be constant folded in one pass. + if (canBeOptimized) { + ArrayDeque queue = new ArrayDeque<>(node.getChildren()); + while (!queue.isEmpty()) { + RowExpression expression = queue.poll(); + if (hasOptimizedChildren.remove(expression)) { + expressionsToOptimize.remove(expression); + queue.addAll(expression.getChildren()); + } + } + expressionsToOptimize.add(node); + hasOptimizedChildren.add(node); + } + else if (node.getChildren().stream().anyMatch(hasOptimizedChildren::contains)) { + hasOptimizedChildren.add(node); + } + } + + public List getExpressionsToOptimize() + { + return ImmutableList.copyOf(expressionsToOptimize); + } + } + + /** + * This visitor replaces expressions with their optimized versions. 
+ */ + private static class ReplacingVisitor + implements RowExpressionVisitor + { + private final Function resolver; + + public ReplacingVisitor(Map replacements) + { + requireNonNull(replacements, "replacements is null"); + this.resolver = i -> replacements.getOrDefault(i, i); + } + + public ReplacingVisitor(Function variableResolver) + { + requireNonNull(variableResolver, "variableResolver is null"); + this.resolver = i -> i instanceof VariableReferenceExpression ? variableResolver.apply((VariableReferenceExpression) i) : i; + } + + private boolean canBeReplaced(RowExpression rowExpression) + { + return resolver.apply(rowExpression) != rowExpression; + } + + @Override + public RowExpression visitExpression(RowExpression originalExpression, Void context) + { + return resolver.apply(originalExpression); + } + + @Override + public RowExpression visitLambda(LambdaDefinitionExpression lambda, Void context) + { + if (canBeReplaced(lambda.getBody())) { + return new LambdaDefinitionExpression( + lambda.getSourceLocation(), + lambda.getArgumentTypes(), + lambda.getArguments(), + toRowExpression(lambda.getSourceLocation(), resolver.apply(lambda.getBody()), lambda.getBody().getType())); + } + return lambda; + } + + @Override + public RowExpression visitCall(CallExpression call, Void context) + { + if (canBeReplaced(call)) { + return resolver.apply(call); + } + List updatedArguments = call.getArguments().stream() + .map(argument -> toRowExpression(argument.getSourceLocation(), argument.accept(this, context), argument.getType())) + .collect(toImmutableList()); + return new CallExpression(call.getSourceLocation(), call.getDisplayName(), call.getFunctionHandle(), call.getType(), updatedArguments); + } + + @Override + public RowExpression visitSpecialForm(SpecialFormExpression specialForm, Void context) + { + if (canBeReplaced(specialForm)) { + return resolver.apply(specialForm); + } + List updatedArguments = specialForm.getArguments().stream() + .map(argument -> 
toRowExpression(argument.getSourceLocation(), argument.accept(this, context), argument.getType())) + .collect(toImmutableList()); + return new SpecialFormExpression(specialForm.getSourceLocation(), specialForm.getForm(), specialForm.getType(), updatedArguments); + } + } + + private static RowExpression toRowExpression(Optional sourceLocation, Object object, Type type) + { + requireNonNull(type, "type is null"); + + if (object instanceof RowExpression) { + return (RowExpression) object; + } + + return new ConstantExpression(sourceLocation, object, type); + } +} diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerFactory.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerFactory.java new file mode 100644 index 0000000000000..4157cffe2c839 --- /dev/null +++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerFactory.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.airlift.bootstrap.Bootstrap; +import com.facebook.airlift.json.JsonModule; +import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerContext; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerFactory; +import com.google.inject.Injector; + +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +public class NativeExpressionOptimizerFactory + implements ExpressionOptimizerFactory +{ + private final ClassLoader classLoader; + + public NativeExpressionOptimizerFactory(ClassLoader classLoader) + { + this.classLoader = requireNonNull(classLoader, "classLoader is null"); + } + + @Override + public ExpressionOptimizer createOptimizer(Map config, ExpressionOptimizerContext context) + { + requireNonNull(context, "context is null"); + + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + Bootstrap app = new Bootstrap( + new JsonModule(), + new NativeExpressionsCommunicationModule(), + new NativeExpressionsModule(context.getNodeManager(), context.getRowExpressionSerde(), context.getFunctionMetadataManager(), context.getFunctionResolution())); + + Injector injector = app + .noStrictConfig() + .doNotInitializeLogging() + .setRequiredConfigurationProperties(config) + .quiet() + .initialize(); + return injector.getInstance(NativeExpressionOptimizerProvider.class).createOptimizer(); + } + } + + @Override + public String getName() + { + return "native"; + } +} diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerProvider.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerProvider.java new file mode 100644 index 0000000000000..8efe3b8a77eb5 --- /dev/null +++ 
b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionOptimizerProvider.java @@ -0,0 +1,42 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.presto.spi.function.FunctionMetadataManager; +import com.facebook.presto.spi.function.StandardFunctionResolution; +import com.facebook.presto.spi.relation.ExpressionOptimizer; + +import javax.inject.Inject; + +import static java.util.Objects.requireNonNull; + +public class NativeExpressionOptimizerProvider +{ + private final NativeSidecarExpressionInterpreter expressionInterpreterService; + private final FunctionMetadataManager functionMetadataManager; + private final StandardFunctionResolution resolution; + + @Inject + public NativeExpressionOptimizerProvider(NativeSidecarExpressionInterpreter expressionInterpreterService, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution resolution) + { + this.expressionInterpreterService = requireNonNull(expressionInterpreterService, "expressionInterpreterService is null"); + this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); + this.resolution = requireNonNull(resolution, "resolution is null"); + } + + public ExpressionOptimizer createOptimizer() + { + return new NativeExpressionOptimizer(expressionInterpreterService, functionMetadataManager, resolution); + } +} diff 
--git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsCommunicationModule.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsCommunicationModule.java new file mode 100644 index 0000000000000..fdc6da53f91bc --- /dev/null +++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsCommunicationModule.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.session.sql.expressions; + +import com.google.inject.Binder; +import com.google.inject.Module; + +import static com.facebook.airlift.http.client.HttpClientBinder.httpClientBinder; + +public class NativeExpressionsCommunicationModule + implements Module +{ + @Override + public void configure(Binder binder) + { + httpClientBinder(binder).bindHttpClient("sidecar", ForSidecarInfo.class); + } +} diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsModule.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsModule.java new file mode 100644 index 0000000000000..046b4e60eac7a --- /dev/null +++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeExpressionsModule.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.airlift.json.JsonModule; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.RowExpressionSerde; +import com.facebook.presto.spi.function.FunctionMetadataManager; +import com.facebook.presto.spi.function.StandardFunctionResolution; +import com.facebook.presto.spi.relation.RowExpression; +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +import static com.facebook.airlift.json.JsonBinder.jsonBinder; +import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; +import static java.util.Objects.requireNonNull; + +public class NativeExpressionsModule + implements Module +{ + private final NodeManager nodeManager; + private final RowExpressionSerde rowExpressionSerde; + private final FunctionMetadataManager functionMetadataManager; + private final StandardFunctionResolution functionResolution; + + public NativeExpressionsModule(NodeManager nodeManager, RowExpressionSerde rowExpressionSerde, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution functionResolution) + { + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null"); + this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); + this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); + } + + @Override + public void configure(Binder binder) + { + // Core dependencies + binder.bind(NodeManager.class).toInstance(nodeManager); + binder.bind(RowExpressionSerde.class).toInstance(rowExpressionSerde); + binder.bind(FunctionMetadataManager.class).toInstance(functionMetadataManager); + binder.bind(StandardFunctionResolution.class).toInstance(functionResolution); + + // JSON dependencies and setup + binder.install(new JsonModule()); + 
jsonBinder(binder).addDeserializerBinding(RowExpression.class).to(RowExpressionDeserializer.class).in(Scopes.SINGLETON); + jsonBinder(binder).addSerializerBinding(RowExpression.class).to(RowExpressionSerializer.class).in(Scopes.SINGLETON); + jsonCodecBinder(binder).bindListJsonCodec(RowExpression.class); + + binder.bind(NativeSidecarExpressionInterpreter.class).in(Scopes.SINGLETON); + + // The main service provider + binder.bind(NativeExpressionOptimizerProvider.class).in(Scopes.SINGLETON); + } +} diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeSidecarExpressionInterpreter.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeSidecarExpressionInterpreter.java new file mode 100644 index 0000000000000..53d3530156d23 --- /dev/null +++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/NativeSidecarExpressionInterpreter.java @@ -0,0 +1,112 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.airlift.http.client.HttpClient; +import com.facebook.airlift.http.client.HttpUriBuilder; +import com.facebook.airlift.http.client.Request; +import com.facebook.airlift.json.JsonCodec; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.Node; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.RowExpression; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import javax.inject.Inject; + +import java.net.URI; +import java.util.List; +import java.util.Map; + +import static com.facebook.airlift.http.client.JsonBodyGenerator.jsonBodyGenerator; +import static com.facebook.airlift.http.client.JsonResponseHandler.createJsonResponseHandler; +import static com.facebook.airlift.http.client.Request.Builder.preparePost; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.net.HttpHeaders.ACCEPT; +import static com.google.common.net.HttpHeaders.CONTENT_TYPE; +import static com.google.common.net.MediaType.JSON_UTF_8; +import static java.util.Objects.requireNonNull; + +public class NativeSidecarExpressionInterpreter +{ + private static final String PRESTO_TIME_ZONE_HEADER = "X-Presto-Time-Zone"; + private static final String PRESTO_USER_HEADER = "X-Presto-User"; + private static final String PRESTO_EXPRESSION_OPTIMIZER_LEVEL_HEADER = "X-Presto-Expression-Optimizer-Level"; + private static final String SIDECAR_ENDPOINT = "/v1/expressions"; + private final NodeManager nodeManager; + private final HttpClient httpClient; + private final JsonCodec> rowExpressionSerde; + + @Inject + public NativeSidecarExpressionInterpreter(NodeManager nodeManager, @ForSidecarInfo HttpClient httpClient, JsonCodec> 
rowExpressionSerde) + { + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.httpClient = requireNonNull(httpClient, "httpClient is null"); + this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null"); + } + + public Map optimizeBatch(ConnectorSession session, Map expressions, ExpressionOptimizer.Level level) + { + ImmutableList.Builder originalExpressionsBuilder = ImmutableList.builder(); + ImmutableList.Builder resolvedExpressionsBuilder = ImmutableList.builder(); + for (Map.Entry entry : expressions.entrySet()) { + originalExpressionsBuilder.add(entry.getKey()); + resolvedExpressionsBuilder.add(entry.getValue()); + } + List originalExpressions = originalExpressionsBuilder.build(); + List resolvedExpressions = resolvedExpressionsBuilder.build(); + + List optimizedExpressions = httpClient.execute( + getSidecarRequest(session, level, resolvedExpressions), + createJsonResponseHandler(rowExpressionSerde)); + checkArgument( + optimizedExpressions.size() == resolvedExpressions.size(), + "Expected %s optimized expressions, but got %s", + resolvedExpressions.size(), + optimizedExpressions.size()); + + ImmutableMap.Builder result = ImmutableMap.builder(); + for (int i = 0; i < optimizedExpressions.size(); i++) { + result.put(originalExpressions.get(i), optimizedExpressions.get(i)); + } + return result.build(); + } + + private Request getSidecarRequest(ConnectorSession session, Level level, List resolvedExpressions) + { + return preparePost() + .setUri(getLocation()) + .setBodyGenerator(jsonBodyGenerator(rowExpressionSerde, resolvedExpressions)) + .setHeader(CONTENT_TYPE, JSON_UTF_8.toString()) + .setHeader(ACCEPT, JSON_UTF_8.toString()) + .setHeader(PRESTO_TIME_ZONE_HEADER, session.getSqlFunctionProperties().getTimeZoneKey().getId()) + .setHeader(PRESTO_USER_HEADER, session.getUser()) + .setHeader(PRESTO_EXPRESSION_OPTIMIZER_LEVEL_HEADER, level.name()) + .build(); + } + + private URI getLocation() + { + Node 
sidecarNode = nodeManager.getSidecarNode(); + return HttpUriBuilder.uriBuilder() + .scheme("http") + .host(sidecarNode.getHost()) + .port(sidecarNode.getHostAndPort().getPort()) + .appendPath(SIDECAR_ENDPOINT) + .build(); + } +} diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionDeserializer.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionDeserializer.java new file mode 100644 index 0000000000000..ffc7039d4568b --- /dev/null +++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionDeserializer.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.presto.spi.RowExpressionSerde; +import com.facebook.presto.spi.relation.RowExpression; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.jsontype.TypeDeserializer; +import com.google.inject.Inject; + +import java.io.IOException; + +import static java.util.Objects.requireNonNull; + +public final class RowExpressionDeserializer + extends JsonDeserializer +{ + private final RowExpressionSerde rowExpressionSerde; + + @Inject + public RowExpressionDeserializer(RowExpressionSerde rowExpressionSerde) + { + this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null"); + } + + @Override + public RowExpression deserialize(JsonParser jsonParser, DeserializationContext context) + throws IOException + { + return rowExpressionSerde.deserialize(jsonParser.readValueAsTree().toString()); + } + + @Override + public RowExpression deserializeWithType(JsonParser jsonParser, DeserializationContext context, TypeDeserializer typeDeserializer) + throws IOException + { + return deserialize(jsonParser, context); + } +} diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionSerializer.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionSerializer.java new file mode 100644 index 0000000000000..6ca5dcc4354f4 --- /dev/null +++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/session/sql/expressions/RowExpressionSerializer.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.presto.spi.RowExpressionSerde; +import com.facebook.presto.spi.relation.RowExpression; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.jsontype.TypeSerializer; +import com.google.inject.Inject; + +import java.io.IOException; + +import static java.util.Objects.requireNonNull; + +public final class RowExpressionSerializer + extends JsonSerializer +{ + private final RowExpressionSerde rowExpressionSerde; + + @Inject + public RowExpressionSerializer(RowExpressionSerde rowExpressionSerde) + { + this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null"); + } + + @Override + public void serialize(RowExpression rowExpression, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) + throws IOException + { + jsonGenerator.writeRawValue(rowExpressionSerde.serialize(rowExpression)); + } + + @Override + public void serializeWithType(RowExpression rowExpression, JsonGenerator jsonGenerator, SerializerProvider serializerProvider, TypeSerializer typeSerializer) + throws IOException + { + serialize(rowExpression, jsonGenerator, serializerProvider); + } +} diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/NativeSidecarPlugin.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/NativeSidecarPlugin.java index 
7da7ba4d4fbb9..491d93d6f812f 100644 --- a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/NativeSidecarPlugin.java +++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/NativeSidecarPlugin.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sidecar; +import com.facebook.presto.session.sql.expressions.NativeExpressionOptimizerFactory; import com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManagerFactory; import com.facebook.presto.sidecar.nativechecker.NativePlanCheckerProviderFactory; import com.facebook.presto.sidecar.sessionpropertyproviders.NativeSystemSessionPropertyProviderFactory; @@ -20,6 +21,7 @@ import com.facebook.presto.spi.function.FunctionNamespaceManagerFactory; import com.facebook.presto.spi.plan.PlanCheckerProviderFactory; import com.facebook.presto.spi.session.WorkerSessionPropertyProviderFactory; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerFactory; import com.google.common.collect.ImmutableList; public class NativeSidecarPlugin @@ -51,4 +53,9 @@ private static ClassLoader getClassLoader() } return classLoader; } + /** + * Returns a single-element list holding a NativeExpressionOptimizerFactory constructed with the class loader from getClassLoader(). + */ + @Override + public Iterable getExpressionOptimizerFactories() + { + return ImmutableList.of(new NativeExpressionOptimizerFactory(getClassLoader())); + } } diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressionOptimizer.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressionOptimizer.java new file mode 100644 index 0000000000000..b47ffa16a6f20 --- /dev/null +++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressionOptimizer.java @@ -0,0 +1,1706 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockEncodingManager; +import com.facebook.presto.common.block.BlockEncodingSerde; +import com.facebook.presto.common.block.BlockSerdeUtil; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.Decimals; +import com.facebook.presto.common.type.SqlTimestampWithTimeZone; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.functionNamespace.json.JsonFileBasedFunctionNamespaceManagerFactory; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.metadata.InMemoryNodeManager; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.nodeManager.PluginNodeManager; +import com.facebook.presto.operator.scalar.FunctionAssertions; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.InputReferenceExpression; +import com.facebook.presto.spi.relation.LambdaDefinitionExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.TestingRowExpressionTranslator; +import 
com.facebook.presto.sql.parser.ParsingOptions; +import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.sql.tree.Expression; +import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.airlift.slice.DynamicSliceOutput; +import io.airlift.slice.SliceOutput; +import io.airlift.slice.Slices; +import org.intellij.lang.annotations.Language; +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; +import org.joda.time.LocalDate; +import org.joda.time.LocalTime; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.math.BigInteger; +import java.net.URI; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; + +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.DateType.DATE; +import static com.facebook.presto.common.type.DecimalType.createDecimalType; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.TimeType.TIME; +import static com.facebook.presto.common.type.TimeZoneKey.getTimeZoneKey; +import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; +import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static 
com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; +import static com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils.findRandomPortForWorker; +import static com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils.getNativeSidecarProcess; +import static com.facebook.presto.operator.scalar.ApplyFunction.APPLY_FUNCTION; +import static com.facebook.presto.session.sql.expressions.TestNativeExpressions.getExpressionOptimizer; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.EVALUATED; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.SERIALIZABLE; +import static com.facebook.presto.sql.ExpressionFormatter.formatExpression; +import static com.facebook.presto.type.IntervalDayTimeType.INTERVAL_DAY_TIME; +import static com.facebook.presto.util.AnalyzerUtil.createParsingOptions; +import static com.facebook.presto.util.DateTimeZoneIndex.getDateTimeZone; +import static io.airlift.slice.Slices.utf8Slice; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; + +public class TestNativeExpressionOptimizer +{ + private static final int TEST_VARCHAR_TYPE_LENGTH = 17; + private static final TypeProvider SYMBOL_TYPES = TypeProvider.viewOf(ImmutableMap.builder() + .put("bound_integer", INTEGER) + .put("bound_long", BIGINT) + .put("bound_string", createVarcharType(TEST_VARCHAR_TYPE_LENGTH)) + .put("bound_varbinary", VARBINARY) + .put("bound_double", DOUBLE) + .put("bound_boolean", BOOLEAN) + .put("bound_date", DATE) + .put("bound_time", TIME) + .put("bound_timestamp", TIMESTAMP) + .put("bound_pattern", VARCHAR) + .put("bound_null_string", VARCHAR) + .put("bound_decimal_short", createDecimalType(5, 2)) + .put("bound_decimal_long", createDecimalType(23, 3)) 
+ .put("time", BIGINT) // for testing reserved identifiers + .put("unbound_integer", INTEGER) + .put("unbound_long", BIGINT) + .put("unbound_long2", BIGINT) + .put("unbound_long3", BIGINT) + .put("unbound_string", VARCHAR) + .put("unbound_double", DOUBLE) + .put("unbound_boolean", BOOLEAN) + .put("unbound_date", DATE) + .put("unbound_time", TIME) + .put("unbound_array", new ArrayType(BIGINT)) + .put("unbound_timestamp", TIMESTAMP) + .put("unbound_interval", INTERVAL_DAY_TIME) + .put("unbound_pattern", VARCHAR) + .put("unbound_null_string", VARCHAR) + .build()); + + private static final SqlParser SQL_PARSER = new SqlParser(); + private static final Metadata METADATA = createTestMetadataManager(); + private static final TestingRowExpressionTranslator TRANSLATOR = new TestingRowExpressionTranslator(METADATA); + private static final BlockEncodingSerde BLOCK_ENCODING_SERDE = new BlockEncodingManager(); + + private ExpressionOptimizer expressionOptimizer; + private Process sidecar; + + public TestNativeExpressionOptimizer() + { + METADATA.getFunctionAndTypeManager().registerBuiltInFunctions(ImmutableList.of(APPLY_FUNCTION)); + setupJsonFunctionNamespaceManager(METADATA.getFunctionAndTypeManager()); + } + + @BeforeClass + public void setup() + throws Exception + { + int port = findRandomPortForWorker(); + URI sidecarUri = URI.create(format("http://127.0.0.1:%s", port)); + sidecar = getNativeSidecarProcess(sidecarUri, port); + + expressionOptimizer = getExpressionOptimizer(METADATA, new HandleResolver(), sidecarUri); + } + + @AfterClass + public void tearDown() + { + sidecar.destroyForcibly(); + } + + @Test + public void testAnd() + { + assertOptimizedEquals("true and false", "false"); + assertOptimizedEquals("false and true", "false"); + assertOptimizedEquals("false and false", "false"); + + assertOptimizedEquals("true and null", "null"); + assertOptimizedEquals("false and null", "false"); + assertOptimizedEquals("null and true", "null"); + assertOptimizedEquals("null and 
false", "false"); + assertOptimizedEquals("null and null", "null"); + + assertOptimizedEquals("unbound_string='z' and true", "unbound_string='z'"); + assertOptimizedEquals("unbound_string='z' and false", "false"); + assertOptimizedEquals("true and unbound_string='z'", "unbound_string='z'"); + assertOptimizedEquals("false and unbound_string='z'", "false"); + + assertOptimizedEquals("bound_string='z' and bound_long=1+1", "bound_string='z' and bound_long=2"); + assertOptimizedEquals("random() > 0 and random() > 0", "random() > 0 and random() > 0"); + } + + @Test + public void testOr() + { + assertOptimizedEquals("true or true", "true"); + assertOptimizedEquals("true or false", "true"); + assertOptimizedEquals("false or true", "true"); + assertOptimizedEquals("false or false", "false"); + + assertOptimizedEquals("true or null", "true"); + assertOptimizedEquals("null or true", "true"); + assertOptimizedEquals("null or null", "null"); + + assertOptimizedEquals("false or null", "null"); + assertOptimizedEquals("null or false", "null"); + + assertOptimizedEquals("bound_string='z' or true", "true"); + assertOptimizedEquals("bound_string='z' or false", "bound_string='z'"); + assertOptimizedEquals("true or bound_string='z'", "true"); + assertOptimizedEquals("false or bound_string='z'", "bound_string='z'"); + + assertOptimizedEquals("bound_string='z' or bound_long=1+1", "bound_string='z' or bound_long=2"); + assertOptimizedEquals("random() > 0 or random() > 0", "random() > 0 or random() > 0"); + } + + @Test + public void testComparison() + { + assertOptimizedEquals("null = null", "null"); + + assertOptimizedEquals("'a' = 'b'", "false"); + assertOptimizedEquals("'a' = 'a'", "true"); + assertOptimizedEquals("'a' = null", "null"); + assertOptimizedEquals("null = 'a'", "null"); + assertOptimizedEquals("bound_integer = 1234", "true"); + assertOptimizedEquals("bound_integer = 12340000000", "false"); + assertOptimizedEquals("bound_long = BIGINT '1234'", "true"); + 
assertOptimizedEquals("bound_long = 1234", "true"); + assertOptimizedEquals("bound_double = 12.34", "true"); + assertOptimizedEquals("bound_string = 'hello'", "true"); + assertOptimizedEquals("bound_long = unbound_long", "1234 = unbound_long"); + + assertOptimizedEquals("10151082135029368 = 10151082135029369", "false"); + + assertOptimizedEquals("bound_varbinary = X'a b'", "true"); + assertOptimizedEquals("bound_varbinary = X'a d'", "false"); + + assertOptimizedEquals("1.1 = 1.1", "true"); + assertOptimizedEquals("9876543210.9874561203 = 9876543210.9874561203", "true"); + assertOptimizedEquals("bound_decimal_short = 123.45", "true"); + assertOptimizedEquals("bound_decimal_long = 12345678901234567890.123", "true"); + } + + @Test + public void testIsDistinctFrom() + { + assertOptimizedEquals("null is distinct from null", "false"); + + assertOptimizedEquals("3 is distinct from 4", "true"); + assertOptimizedEquals("3 is distinct from BIGINT '4'", "true"); + assertOptimizedEquals("3 is distinct from 4000000000", "true"); + assertOptimizedEquals("3 is distinct from 3", "false"); + assertOptimizedEquals("3 is distinct from null", "true"); + assertOptimizedEquals("null is distinct from 3", "true"); + + assertOptimizedEquals("10151082135029368 is distinct from 10151082135029369", "true"); + + assertOptimizedEquals("1.1 is distinct from 1.1", "false"); + assertOptimizedEquals("9876543210.9874561203 is distinct from NULL", "true"); + assertOptimizedEquals("bound_decimal_short is distinct from NULL", "true"); + assertOptimizedEquals("bound_decimal_long is distinct from 12345678901234567890.123", "false"); + } + + @Test + public void testIsNull() + { + assertOptimizedEquals("null is null", "true"); + assertOptimizedEquals("1 is null", "false"); + assertOptimizedEquals("10000000000 is null", "false"); + assertOptimizedEquals("BIGINT '1' is null", "false"); + assertOptimizedEquals("1.0 is null", "false"); + assertOptimizedEquals("'a' is null", "false"); + 
assertOptimizedEquals("true is null", "false"); + assertOptimizedEquals("null+1 is null", "true"); + assertOptimizedEquals("unbound_string is null", "unbound_string is null"); + assertOptimizedEquals("unbound_long+(1+1) is null", "unbound_long+2 is null"); + assertOptimizedEquals("1.1 is null", "false"); + assertOptimizedEquals("9876543210.9874561203 is null", "false"); + assertOptimizedEquals("bound_decimal_short is null", "false"); + assertOptimizedEquals("bound_decimal_long is null", "false"); + } + + @Test + public void testIsNotNull() + { + assertOptimizedEquals("null is not null", "false"); + assertOptimizedEquals("1 is not null", "true"); + assertOptimizedEquals("10000000000 is not null", "true"); + assertOptimizedEquals("BIGINT '1' is not null", "true"); + assertOptimizedEquals("1.0 is not null", "true"); + assertOptimizedEquals("'a' is not null", "true"); + assertOptimizedEquals("true is not null", "true"); + assertOptimizedEquals("null+1 is not null", "false"); + assertOptimizedEquals("unbound_string is not null", "unbound_string is not null"); + assertOptimizedEquals("unbound_long+(1+1) is not null", "unbound_long+2 is not null"); + assertOptimizedEquals("1.1 is not null", "true"); + assertOptimizedEquals("9876543210.9874561203 is not null", "true"); + assertOptimizedEquals("bound_decimal_short is not null", "true"); + assertOptimizedEquals("bound_decimal_long is not null", "true"); + } + + // TODO: NULL_IF special form is unsupported in Presto native. 
+ @Test(enabled = false) + public void testNullIf() + { + assertOptimizedEquals("nullif(true, true)", "null"); + assertOptimizedEquals("nullif(true, false)", "true"); + assertOptimizedEquals("nullif(null, false)", "null"); + assertOptimizedEquals("nullif(true, null)", "true"); + + assertOptimizedEquals("nullif('a', 'a')", "null"); + assertOptimizedEquals("nullif('a', 'b')", "'a'"); + assertOptimizedEquals("nullif(null, 'b')", "null"); + assertOptimizedEquals("nullif('a', null)", "'a'"); + + assertOptimizedEquals("nullif(1, 1)", "null"); + assertOptimizedEquals("nullif(1, 2)", "1"); + assertOptimizedEquals("nullif(1, BIGINT '2')", "1"); + assertOptimizedEquals("nullif(1, 20000000000)", "1"); + assertOptimizedEquals("nullif(1.0E0, 1)", "null"); + assertOptimizedEquals("nullif(10000000000.0E0, 10000000000)", "null"); + assertOptimizedEquals("nullif(1.1E0, 1)", "1.1E0"); + assertOptimizedEquals("nullif(1.1E0, 1.1E0)", "null"); + assertOptimizedEquals("nullif(1, 2-1)", "null"); + assertOptimizedEquals("nullif(null, null)", "null"); + assertOptimizedEquals("nullif(1, null)", "1"); + assertOptimizedEquals("nullif(unbound_long, 1)", "nullif(unbound_long, 1)"); + assertOptimizedEquals("nullif(unbound_long, unbound_long2)", "nullif(unbound_long, unbound_long2)"); + assertOptimizedEquals("nullif(unbound_long, unbound_long2+(1+1))", "nullif(unbound_long, unbound_long2+2)"); + + assertOptimizedEquals("nullif(1.1, 1.2)", "1.1"); + assertOptimizedEquals("nullif(9876543210.9874561203, 9876543210.9874561203)", "null"); + assertOptimizedEquals("nullif(bound_decimal_short, 123.45)", "null"); + assertOptimizedEquals("nullif(bound_decimal_long, 12345678901234567890.123)", "null"); + assertOptimizedEquals("nullif(ARRAY[CAST(1 AS BIGINT)], ARRAY[CAST(1 AS BIGINT)]) IS NULL", "true"); + assertOptimizedEquals("nullif(ARRAY[CAST(1 AS BIGINT)], ARRAY[CAST(NULL AS BIGINT)]) IS NULL", "false"); + assertOptimizedEquals("nullif(ARRAY[CAST(NULL AS BIGINT)], ARRAY[CAST(NULL AS BIGINT)]) IS NULL", 
"false"); + } + + @Test + public void testNegative() + { + assertOptimizedEquals("-(1)", "-1"); + assertOptimizedEquals("-(BIGINT '1')", "BIGINT '-1'"); + assertOptimizedEquals("-(unbound_long+1)", "-(unbound_long+1)"); + assertOptimizedEquals("-(1+1)", "-2"); + assertOptimizedEquals("-(1+ BIGINT '1')", "BIGINT '-2'"); + assertOptimizedEquals("-(CAST(NULL AS BIGINT))", "null"); + assertOptimizedEquals("-(unbound_long+(1+1))", "-(unbound_long+2)"); + assertOptimizedEquals("-(1.1+1.2)", "-2.3"); + assertOptimizedEquals("-(9876543210.9874561203-9876543210.9874561203)", "CAST(0 AS DECIMAL(20,10))"); + assertOptimizedEquals("-(bound_decimal_short+123.45)", "-246.90"); + assertOptimizedEquals("-(bound_decimal_long-12345678901234567890.123)", "CAST(0 AS DECIMAL(20,10))"); + } + + @Test + public void testNot() + { + assertOptimizedEquals("not true", "false"); + assertOptimizedEquals("not false", "true"); + assertOptimizedEquals("not null", "null"); + assertOptimizedEquals("not 1=1", "false"); + assertOptimizedEquals("not 1=BIGINT '1'", "false"); + assertOptimizedEquals("not 1!=1", "true"); + assertOptimizedEquals("not unbound_long=1", "not unbound_long=1"); + assertOptimizedEquals("not unbound_long=(1+1)", "not unbound_long=2"); + } + + @Test + public void testFunctionCall() + { + assertOptimizedEquals("abs(-5)", "5"); + assertOptimizedEquals("abs(-10-5)", "15"); + assertOptimizedEquals("abs(-bound_integer + 1)", "1233"); + assertOptimizedEquals("abs(-bound_long + 1)", "1233"); + assertOptimizedEquals("abs(-bound_long + BIGINT '1')", "1233"); + assertOptimizedEquals("abs(-bound_long)", "1234"); + assertOptimizedEquals("abs(unbound_long)", "abs(unbound_long)"); + assertOptimizedEquals("abs(unbound_long + 1)", "abs(unbound_long + 1)"); + assertOptimizedEquals("cast(json_parse(unbound_string) as map(varchar, varchar))", "cast(json_parse(unbound_string) as map(varchar, varchar))"); + assertOptimizedEquals("cast(json_parse(unbound_string) as array(varchar))", 
"cast(json_parse(unbound_string) as array(varchar))"); + assertOptimizedEquals("cast(json_parse(unbound_string) as row(bigint, varchar))", "cast(json_parse(unbound_string) as row(bigint, varchar))"); + } + + // TODO: evaluated is not working on current sidecar implementation + @Test(enabled = false) + public void testNonDeterministicFunctionCall() + { + // optimize should do nothing + assertOptimizedEquals("random()", "random()"); + + // evaluate should execute + RowExpression value = evaluate("random()", false); + assertTrue(value instanceof ConstantExpression); + Object innerValue = ((ConstantExpression) value).getValue(); + assertTrue(innerValue instanceof Double); + double randomValue = (double) innerValue; + assertTrue(0 <= randomValue && randomValue < 1); + } + + // Run this method exactly once. + private void setupJsonFunctionNamespaceManager(FunctionAndTypeManager functionAndTypeManager) + { + functionAndTypeManager.addFunctionNamespaceFactory(new JsonFileBasedFunctionNamespaceManagerFactory()); + functionAndTypeManager.loadFunctionNamespaceManager( + JsonFileBasedFunctionNamespaceManagerFactory.NAME, + "json", + ImmutableMap.of("supported-function-languages", "CPP", "function-implementation-type", "CPP"), + new PluginNodeManager(new InMemoryNodeManager())); + } + + @Test + public void testBetween() + { + assertOptimizedEquals("3 between 2 and 4", "true"); + assertOptimizedEquals("2 between 3 and 4", "false"); + assertOptimizedEquals("null between 2 and 4", "null"); + assertOptimizedEquals("3 between null and 4", "null"); + assertOptimizedEquals("3 between 2 and null", "null"); + + assertOptimizedEquals("'cc' between 'b' and 'd'", "true"); + assertOptimizedEquals("'b' between 'cc' and 'd'", "false"); + assertOptimizedEquals("null between 'b' and 'd'", "null"); + assertOptimizedEquals("'cc' between null and 'd'", "null"); + assertOptimizedEquals("'cc' between 'b' and null", "null"); + + assertOptimizedEquals("bound_integer between 1000 and 2000", "true"); + 
assertOptimizedEquals("bound_integer between 3 and 4", "false"); + assertOptimizedEquals("bound_long between 1000 and 2000", "true"); + assertOptimizedEquals("bound_long between 3 and 4", "false"); + assertOptimizedEquals("bound_long between bound_integer and (bound_long + 1)", "true"); + assertOptimizedEquals("bound_string between 'e' and 'i'", "true"); + assertOptimizedEquals("bound_string between 'a' and 'b'", "false"); + + assertOptimizedEquals("bound_long between unbound_long and 2000 + 1", "1234 between unbound_long and 2001"); + assertOptimizedEquals( + "bound_string between unbound_string and 'bar'", + format("CAST('hello' AS VARCHAR(%s)) between unbound_string and 'bar'", TEST_VARCHAR_TYPE_LENGTH)); + + assertOptimizedEquals("1.15 between 1.1 and 1.2", "true"); + assertOptimizedEquals("9876543210.98745612035 between 9876543210.9874561203 and 9876543210.9874561204", "true"); + assertOptimizedEquals("123.455 between bound_decimal_short and 123.46", "true"); + assertOptimizedEquals("12345678901234567890.1235 between bound_decimal_long and 12345678901234567890.123", "false"); + } + + @Test + public void testExtract() + { + DateTime dateTime = new DateTime(2001, 8, 22, 3, 4, 5, 321, getDateTimeZone(TEST_SESSION.getTimeZoneKey())); + double seconds = dateTime.getMillis() / 1000.0; + + assertOptimizedEquals("extract (YEAR from from_unixtime(" + seconds + "))", "2001"); + assertOptimizedEquals("extract (QUARTER from from_unixtime(" + seconds + "))", "3"); + assertOptimizedEquals("extract (MONTH from from_unixtime(" + seconds + "))", "8"); + assertOptimizedEquals("extract (WEEK from from_unixtime(" + seconds + "))", "34"); + assertOptimizedEquals("extract (DOW from from_unixtime(" + seconds + "))", "3"); + assertOptimizedEquals("extract (DOY from from_unixtime(" + seconds + "))", "234"); + assertOptimizedEquals("extract (DAY from from_unixtime(" + seconds + "))", "22"); + assertOptimizedEquals("extract (HOUR from from_unixtime(" + seconds + "))", "3"); + 
assertOptimizedEquals("extract (MINUTE from from_unixtime(" + seconds + "))", "4"); + assertOptimizedEquals("extract (SECOND from from_unixtime(" + seconds + "))", "5"); + assertOptimizedEquals("extract (TIMEZONE_HOUR from from_unixtime(" + seconds + ", 7, 9))", "7"); + assertOptimizedEquals("extract (TIMEZONE_MINUTE from from_unixtime(" + seconds + ", 7, 9))", "9"); + + assertOptimizedEquals("extract (YEAR from bound_timestamp)", "2001"); + assertOptimizedEquals("extract (QUARTER from bound_timestamp)", "3"); + assertOptimizedEquals("extract (MONTH from bound_timestamp)", "8"); + assertOptimizedEquals("extract (WEEK from bound_timestamp)", "34"); + assertOptimizedEquals("extract (DOW from bound_timestamp)", "2"); + assertOptimizedEquals("extract (DOY from bound_timestamp)", "233"); + assertOptimizedEquals("extract (DAY from bound_timestamp)", "21"); + assertOptimizedEquals("extract (HOUR from bound_timestamp)", "16"); + assertOptimizedEquals("extract (MINUTE from bound_timestamp)", "4"); + assertOptimizedEquals("extract (SECOND from bound_timestamp)", "5"); + // todo reenable when cast as timestamp with time zone is implemented + // todo add bound timestamp with time zone + //assertOptimizedEquals("extract (TIMEZONE_HOUR from bound_timestamp)", "0"); + //assertOptimizedEquals("extract (TIMEZONE_MINUTE from bound_timestamp)", "0"); + + assertOptimizedEquals("extract (YEAR from unbound_timestamp)", "extract (YEAR from unbound_timestamp)"); + assertOptimizedEquals("extract (SECOND from bound_timestamp + INTERVAL '3' SECOND)", "8"); + } + + @Test + public void testIn() + { + assertOptimizedEquals("3 in (2, 4, 3, 5)", "true"); + assertOptimizedEquals("3 in (2, 4, 9, 5)", "false"); + assertOptimizedEquals("3 in (2, null, 3, 5)", "true"); + + assertOptimizedEquals("'foo' in ('bar', 'baz', 'foo', 'blah')", "true"); + assertOptimizedEquals("'foo' in ('bar', 'baz', 'buz', 'blah')", "false"); + assertOptimizedEquals("'foo' in ('bar', null, 'foo', 'blah')", "true"); + + 
assertOptimizedEquals("null in (2, null, 3, 5)", "null"); + assertOptimizedEquals("3 in (2, null)", "null"); + + assertOptimizedEquals("bound_integer in (2, 1234, 3, 5)", "true"); + assertOptimizedEquals("bound_integer in (2, 4, 3, 5)", "false"); + assertOptimizedEquals("1234 in (2, bound_integer, 3, 5)", "true"); + assertOptimizedEquals("99 in (2, bound_integer, 3, 5)", "false"); + assertOptimizedEquals("bound_integer in (2, bound_integer, 3, 5)", "true"); + + assertOptimizedEquals("bound_long in (2, 1234, 3, 5)", "true"); + assertOptimizedEquals("bound_long in (2, 4, 3, 5)", "false"); + assertOptimizedEquals("1234 in (2, bound_long, 3, 5)", "true"); + assertOptimizedEquals("99 in (2, bound_long, 3, 5)", "false"); + assertOptimizedEquals("bound_long in (2, bound_long, 3, 5)", "true"); + + assertOptimizedEquals("bound_string in ('bar', 'hello', 'foo', 'blah')", "true"); + assertOptimizedEquals("bound_string in ('bar', 'baz', 'foo', 'blah')", "false"); + assertOptimizedEquals("'hello' in ('bar', bound_string, 'foo', 'blah')", "true"); + assertOptimizedEquals("'baz' in ('bar', bound_string, 'foo', 'blah')", "false"); + + assertOptimizedEquals("bound_long in (2, 1234, unbound_long, 5)", "true"); + assertOptimizedEquals("bound_string in ('bar', 'hello', unbound_string, 'blah')", "true"); + + assertOptimizedEquals("bound_long in (2, 4, unbound_long, unbound_long2, 9)", "1234 in (unbound_long, unbound_long2)"); + assertOptimizedEquals("unbound_long in (2, 4, bound_long, unbound_long2, 5)", "unbound_long in (2, 4, 1234, unbound_long2, 5)"); + + assertOptimizedEquals("1.15 in (1.1, 1.2, 1.3, 1.15)", "true"); + assertOptimizedEquals("9876543210.98745612035 in (9876543210.9874561203, 9876543210.9874561204, 9876543210.98745612035)", "true"); + assertOptimizedEquals("bound_decimal_short in (123.455, 123.46, 123.45)", "true"); + assertOptimizedEquals("bound_decimal_long in (12345678901234567890.123, 9876543210.9874561204, 9876543210.98745612035)", "true"); + 
assertOptimizedEquals("bound_decimal_long in (9876543210.9874561204, null, 9876543210.98745612035)", "null"); + } + + @Test + public void testInComplexTypes() + { + assertEvaluatedEquals("ARRAY[null] IN (ARRAY[null])", "null"); + assertEvaluatedEquals("ARRAY[1] IN (ARRAY[null])", "null"); + assertEvaluatedEquals("ARRAY[null] IN (ARRAY[1])", "null"); + assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[1, null])", "null"); + assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[2, null])", "false"); + assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[1, null], ARRAY[2, null])", "null"); + assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[1, null], ARRAY[2, null], ARRAY[1, null])", "null"); + assertEvaluatedEquals("ARRAY[ARRAY[1, 2], ARRAY[3, 4]] in (ARRAY[ARRAY[1, 2], ARRAY[3, NULL]])", "null"); + + assertEvaluatedEquals("ROW(1) IN (ROW(1))", "true"); + assertEvaluatedEquals("ROW(1) IN (ROW(2))", "false"); + assertEvaluatedEquals("ROW(1) IN (ROW(2), ROW(1), ROW(2))", "true"); + assertEvaluatedEquals("ROW(1) IN (null)", "null"); + assertEvaluatedEquals("ROW(1) IN (null, ROW(1))", "true"); + assertEvaluatedEquals("ROW(1, null) IN (ROW(2, null), null)", "null"); + assertEvaluatedEquals("ROW(null) IN (ROW(null))", "null"); + assertEvaluatedEquals("ROW(1) IN (ROW(null))", "null"); + assertEvaluatedEquals("ROW(null) IN (ROW(1))", "null"); + assertEvaluatedEquals("ROW(1, null) IN (ROW(1, null))", "null"); + assertEvaluatedEquals("ROW(1, null) IN (ROW(2, null))", "false"); + assertEvaluatedEquals("ROW(1, null) IN (ROW(1, null), ROW(2, null))", "null"); + assertEvaluatedEquals("ROW(1, null) IN (ROW(1, null), ROW(2, null), ROW(1, null))", "null"); + + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (MAP(ARRAY[1], ARRAY[1]))", "true"); + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (null)", "null"); + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (null, MAP(ARRAY[1], ARRAY[1]))", "true"); + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]))", 
"false"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[2, null]), null)", "null"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]))", "null"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 3], ARRAY[1, null]))", "false"); + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[null]) IN (MAP(ARRAY[1], ARRAY[null]))", "null"); + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (MAP(ARRAY[1], ARRAY[null]))", "null"); + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[null]) IN (MAP(ARRAY[1], ARRAY[1]))", "null"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]))", "null"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 3], ARRAY[1, null]))", "false"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[2, null]))", "false"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]), MAP(ARRAY[1, 2], ARRAY[2, null]))", "null"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]), MAP(ARRAY[1, 2], ARRAY[2, null]), MAP(ARRAY[1, 2], ARRAY[1, null]))", "null"); + } + + // TODO: current timestamp returns the session timestamp, which is not supported by this test + @Test(enabled = false) + public void testCurrentTimestamp() + { + double current = TEST_SESSION.getStartTime() / 1000.0; + assertOptimizedEquals("current_timestamp = from_unixtime(" + current + ")", "true"); + double future = current + TimeUnit.MINUTES.toSeconds(1); + assertOptimizedEquals("current_timestamp > from_unixtime(" + future + ")", "false"); + } + + @Test + public void testCurrentUser() + { + assertOptimizedEquals("current_user", "'" + TEST_SESSION.getUser() + "'"); + } + + @Test + public void testCastToString() + { + // integer + assertOptimizedEquals("cast(123 as VARCHAR(20))", "'123'"); + 
assertOptimizedEquals("cast(-123 as VARCHAR(20))", "'-123'"); + + // bigint + assertOptimizedEquals("cast(BIGINT '123' as VARCHAR)", "'123'"); + assertOptimizedEquals("cast(12300000000 as VARCHAR)", "'12300000000'"); + assertOptimizedEquals("cast(-12300000000 as VARCHAR)", "'-12300000000'"); + + // double + assertOptimizedEquals("cast(123.0E0 as VARCHAR)", "'123.0'"); + assertOptimizedEquals("cast(-123.0E0 as VARCHAR)", "'-123.0'"); + assertOptimizedEquals("cast(123.456E0 as VARCHAR)", "'123.456'"); + assertOptimizedEquals("cast(-123.456E0 as VARCHAR)", "'-123.456'"); + + // boolean + assertOptimizedEquals("cast(true as VARCHAR)", "'true'"); + assertOptimizedEquals("cast(false as VARCHAR)", "'false'"); + + // string + assertOptimizedEquals("cast('xyz' as VARCHAR)", "'xyz'"); + assertOptimizedEquals("cast(cast('abcxyz' as VARCHAR(3)) as VARCHAR(5))", "'abc'"); + + // null + assertOptimizedEquals("cast(null as VARCHAR)", "null"); + + // decimal + assertOptimizedEquals("cast(1.1 as VARCHAR)", "'1.1'"); + // TODO enabled when DECIMAL is default for literal: assertOptimizedEquals("cast(12345678901234567890.123 as VARCHAR)", "'12345678901234567890.123'"); + } + + @Test + public void testCastBigintToBoundedVarchar() + { + assertEvaluatedEquals("CAST(12300000000 AS varchar(11))", "'12300000000'"); + assertEvaluatedEquals("CAST(12300000000 AS varchar(50))", "'12300000000'"); + + // TODO: Velox permits this cast, but Presto does not +// try { +// evaluate("CAST(12300000000 AS varchar(3))", true); +// fail("Expected to throw an INVALID_CAST_ARGUMENT exception"); +// } +// catch (PrestoException e) { +// try { +// assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode()); +// assertEquals(e.getMessage(), "Value 12300000000 cannot be represented as varchar(3)"); +// } +// catch (Throwable failure) { +// failure.addSuppressed(e); +// throw failure; +// } +// } +// +// try { +// evaluate("CAST(-12300000000 AS varchar(3))", true); +// } +// catch (PrestoException e) { 
+// try { +// assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode()); +// assertEquals(e.getMessage(), "Value -12300000000 cannot be represented as varchar(3)"); +// } +// catch (Throwable failure) { +// failure.addSuppressed(e); +// throw failure; +// } +// } + } + + @Test + public void testCastToBoolean() + { + // integer + assertOptimizedEquals("cast(123 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(-123 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(0 as BOOLEAN)", "false"); + + // bigint + assertOptimizedEquals("cast(12300000000 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(-12300000000 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(BIGINT '0' as BOOLEAN)", "false"); + + // boolean + assertOptimizedEquals("cast(true as BOOLEAN)", "true"); + assertOptimizedEquals("cast(false as BOOLEAN)", "false"); + + // string + assertOptimizedEquals("cast('true' as BOOLEAN)", "true"); + assertOptimizedEquals("cast('false' as BOOLEAN)", "false"); + assertOptimizedEquals("cast('t' as BOOLEAN)", "true"); + assertOptimizedEquals("cast('f' as BOOLEAN)", "false"); + assertOptimizedEquals("cast('1' as BOOLEAN)", "true"); + assertOptimizedEquals("cast('0' as BOOLEAN)", "false"); + + // null + assertOptimizedEquals("cast(null as BOOLEAN)", "null"); + + // double + assertOptimizedEquals("cast(123.45E0 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(-123.45E0 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(0.0E0 as BOOLEAN)", "false"); + + // decimal + assertOptimizedEquals("cast(0.00 as BOOLEAN)", "false"); + assertOptimizedEquals("cast(7.8 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(12345678901234567890.123 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(00000000000000000000.000 as BOOLEAN)", "false"); + } + + @Test + public void testCastToBigint() + { + // integer + assertOptimizedEquals("cast(0 as BIGINT)", "0"); + assertOptimizedEquals("cast(123 as BIGINT)", "123"); + assertOptimizedEquals("cast(-123 as BIGINT)", 
"-123"); + + // bigint + assertOptimizedEquals("cast(BIGINT '0' as BIGINT)", "0"); + assertOptimizedEquals("cast(BIGINT '123' as BIGINT)", "123"); + assertOptimizedEquals("cast(BIGINT '-123' as BIGINT)", "-123"); + + // double + assertOptimizedEquals("cast(123.0E0 as BIGINT)", "123"); + assertOptimizedEquals("cast(-123.0E0 as BIGINT)", "-123"); + assertOptimizedEquals("cast(123.456E0 as BIGINT)", "123"); + assertOptimizedEquals("cast(-123.456E0 as BIGINT)", "-123"); + + // boolean + assertOptimizedEquals("cast(true as BIGINT)", "1"); + assertOptimizedEquals("cast(false as BIGINT)", "0"); + + // string + assertOptimizedEquals("cast('123' as BIGINT)", "123"); + assertOptimizedEquals("cast('-123' as BIGINT)", "-123"); + + // null + assertOptimizedEquals("cast(null as BIGINT)", "null"); + + // decimal + assertOptimizedEquals("cast(DECIMAL '1.01' as BIGINT)", "1"); + assertOptimizedEquals("cast(DECIMAL '7.8' as BIGINT)", "8"); + assertOptimizedEquals("cast(DECIMAL '1234567890.123' as BIGINT)", "1234567890"); + assertOptimizedEquals("cast(DECIMAL '00000000000000000000.000' as BIGINT)", "0"); + } + + @Test + public void testCastToInteger() + { + // integer + assertOptimizedEquals("cast(0 as INTEGER)", "0"); + assertOptimizedEquals("cast(123 as INTEGER)", "123"); + assertOptimizedEquals("cast(-123 as INTEGER)", "-123"); + + // bigint + assertOptimizedEquals("cast(BIGINT '0' as INTEGER)", "0"); + assertOptimizedEquals("cast(BIGINT '123' as INTEGER)", "123"); + assertOptimizedEquals("cast(BIGINT '-123' as INTEGER)", "-123"); + + // double + assertOptimizedEquals("cast(123.0E0 as INTEGER)", "123"); + assertOptimizedEquals("cast(-123.0E0 as INTEGER)", "-123"); + assertOptimizedEquals("cast(123.456E0 as INTEGER)", "123"); + assertOptimizedEquals("cast(-123.456E0 as INTEGER)", "-123"); + + // boolean + assertOptimizedEquals("cast(true as INTEGER)", "1"); + assertOptimizedEquals("cast(false as INTEGER)", "0"); + + // string + assertOptimizedEquals("cast('123' as INTEGER)", 
"123"); + assertOptimizedEquals("cast('-123' as INTEGER)", "-123"); + + // null + assertOptimizedEquals("cast(null as INTEGER)", "null"); + } + + @Test + public void testCastToDouble() + { + // integer + assertOptimizedEquals("cast(0 as DOUBLE)", "0.0E0"); + assertOptimizedEquals("cast(123 as DOUBLE)", "123.0E0"); + assertOptimizedEquals("cast(-123 as DOUBLE)", "-123.0E0"); + + // bigint + assertOptimizedEquals("cast(BIGINT '0' as DOUBLE)", "0.0E0"); + assertOptimizedEquals("cast(12300000000 as DOUBLE)", "12300000000.0E0"); + assertOptimizedEquals("cast(-12300000000 as DOUBLE)", "-12300000000.0E0"); + + // double + assertOptimizedEquals("cast(123.0E0 as DOUBLE)", "123.0E0"); + assertOptimizedEquals("cast(-123.0E0 as DOUBLE)", "-123.0E0"); + assertOptimizedEquals("cast(123.456E0 as DOUBLE)", "123.456E0"); + assertOptimizedEquals("cast(-123.456E0 as DOUBLE)", "-123.456E0"); + + // string + assertOptimizedEquals("cast('0' as DOUBLE)", "0.0E0"); + assertOptimizedEquals("cast('123' as DOUBLE)", "123.0E0"); + assertOptimizedEquals("cast('-123' as DOUBLE)", "-123.0E0"); + assertOptimizedEquals("cast('123.0E0' as DOUBLE)", "123.0E0"); + assertOptimizedEquals("cast('-123.0E0' as DOUBLE)", "-123.0E0"); + assertOptimizedEquals("cast('123.456E0' as DOUBLE)", "123.456E0"); + assertOptimizedEquals("cast('-123.456E0' as DOUBLE)", "-123.456E0"); + + // null + assertOptimizedEquals("cast(null as DOUBLE)", "null"); + + // boolean + assertOptimizedEquals("cast(true as DOUBLE)", "1.0E0"); + assertOptimizedEquals("cast(false as DOUBLE)", "0.0E0"); + + // decimal + assertOptimizedEquals("cast(1.01 as DOUBLE)", "DOUBLE '1.01'"); + assertOptimizedEquals("cast(7.8 as DOUBLE)", "DOUBLE '7.8'"); + assertOptimizedEquals("cast(1234567890.123 as DOUBLE)", "DOUBLE '1234567890.123'"); + assertOptimizedEquals("cast(00000000000000000000.000 as DOUBLE)", "DOUBLE '0.0'"); + } + + @Test + public void testCastToDecimal() + { + // long + assertOptimizedEquals("cast(0 as DECIMAL(1,0))", "DECIMAL '0'"); + 
assertOptimizedEquals("cast(123 as DECIMAL(3,0))", "DECIMAL '123'"); + assertOptimizedEquals("cast(-123 as DECIMAL(3,0))", "DECIMAL '-123'"); + assertOptimizedEquals("cast(-123 as DECIMAL(20,10))", "cast(-123 as DECIMAL(20,10))"); + + // double + assertOptimizedEquals("cast(0E0 as DECIMAL(1,0))", "DECIMAL '0'"); + assertOptimizedEquals("cast(123.2E0 as DECIMAL(4,1))", "DECIMAL '123.2'"); + assertOptimizedEquals("cast(-123.0E0 as DECIMAL(3,0))", "DECIMAL '-123'"); + assertOptimizedEquals("cast(-123.55E0 as DECIMAL(20,10))", "cast(-123.55 as DECIMAL(20,10))"); + + // string + assertOptimizedEquals("cast('0' as DECIMAL(1,0))", "DECIMAL '0'"); + assertOptimizedEquals("cast('123.2' as DECIMAL(4,1))", "DECIMAL '123.2'"); + assertOptimizedEquals("cast('-123.0' as DECIMAL(3,0))", "DECIMAL '-123'"); + assertOptimizedEquals("cast('-123.55' as DECIMAL(20,10))", "cast(-123.55 as DECIMAL(20,10))"); + + // null + assertOptimizedEquals("cast(null as DECIMAL(1,0))", "null"); + assertOptimizedEquals("cast(null as DECIMAL(20,10))", "null"); + + // boolean + assertOptimizedEquals("cast(true as DECIMAL(1,0))", "DECIMAL '1'"); + assertOptimizedEquals("cast(false as DECIMAL(4,1))", "DECIMAL '000.0'"); + assertOptimizedEquals("cast(true as DECIMAL(3,0))", "DECIMAL '001'"); + assertOptimizedEquals("cast(false as DECIMAL(20,10))", "cast(0 as DECIMAL(20,10))"); + + // decimal + assertOptimizedEquals("cast(0.0 as DECIMAL(1,0))", "DECIMAL '0'"); + assertOptimizedEquals("cast(123.2 as DECIMAL(4,1))", "DECIMAL '123.2'"); + assertOptimizedEquals("cast(-123.0 as DECIMAL(3,0))", "DECIMAL '-123'"); + assertOptimizedEquals("cast(-123.55 as DECIMAL(20,10))", "cast(-123.55 as DECIMAL(20,10))"); + } + + @Test + public void testCastOptimization() + { + assertOptimizedEquals("cast(unbound_string as VARCHAR)", "cast(unbound_string as VARCHAR)"); + assertOptimizedMatches("cast(unbound_string as VARCHAR)", "unbound_string"); + assertOptimizedMatches("cast(unbound_integer as INTEGER)", "unbound_integer"); + 
assertOptimizedMatches("cast(unbound_string as VARCHAR(10))", "cast(unbound_string as VARCHAR(10))"); + } + + @Test + public void testTryCast() + { + assertOptimizedEquals("try_cast(null as BIGINT)", "null"); + assertOptimizedEquals("try_cast(123 as BIGINT)", "123"); + assertOptimizedEquals("try_cast(null as INTEGER)", "null"); + assertOptimizedEquals("try_cast(123 as INTEGER)", "123"); + assertOptimizedEquals("try_cast('foo' as VARCHAR)", "'foo'"); + assertOptimizedEquals("try_cast('foo' as BIGINT)", "null"); + assertOptimizedEquals("try_cast(unbound_string as BIGINT)", "try_cast(unbound_string as BIGINT)"); + assertOptimizedEquals("try_cast('foo' as DECIMAL(2,1))", "null"); + } + + @Test + public void testReservedWithDoubleQuotes() + { + assertOptimizedEquals("\"time\"", "\"time\""); + } + + @Test + public void testSearchCase() + { + assertOptimizedEquals("case " + + "when true then 33 " + + "end", + "33"); + assertOptimizedEquals("case " + + "when false then 1 " + + "else 33 " + + "end", + "33"); + + assertOptimizedEquals("case " + + "when false then 10000000000 " + + "else 33 " + + "end", + "33"); + + assertOptimizedEquals("case " + + "when bound_long = 1234 then 33 " + + "end", + "33"); + assertOptimizedEquals("case " + + "when true then bound_long " + + "end", + "1234"); + assertOptimizedEquals("case " + + "when false then 1 " + + "else bound_long " + + "end", + "1234"); + + assertOptimizedEquals("case " + + "when bound_integer = 1234 then 33 " + + "end", + "33"); + assertOptimizedEquals("case " + + "when true then bound_integer " + + "end", + "1234"); + assertOptimizedEquals("case " + + "when false then 1 " + + "else bound_integer " + + "end", + "1234"); + + assertOptimizedEquals("case " + + "when bound_long = 1234 then 33 " + + "else unbound_long " + + "end", + "33"); + assertOptimizedEquals("case " + + "when true then bound_long " + + "else unbound_long " + + "end", + "1234"); + assertOptimizedEquals("case " + + "when false then unbound_long " + + "else 
bound_long " + + "end", + "1234"); + + assertOptimizedEquals("case " + + "when bound_integer = 1234 then 33 " + + "else unbound_integer " + + "end", + "33"); + assertOptimizedEquals("case " + + "when true then bound_integer " + + "else unbound_integer " + + "end", + "1234"); + assertOptimizedEquals("case " + + "when false then unbound_integer " + + "else bound_integer " + + "end", + "1234"); + + assertOptimizedEquals("case " + + "when unbound_long = 1234 then 33 " + + "else 1 " + + "end", + "case " + + "when unbound_long = 1234 then 33 " + + "else 1 " + + "end"); + + assertOptimizedMatches("if(false, 1, 0 / 0)", "cast(fail(8, 'ignored failure message') as integer)"); + + assertOptimizedEquals("case " + + "when false then 2.2 " + + "when true then 2.2 " + + "end", + "2.2"); + + assertOptimizedEquals("case " + + "when false then 1234567890.0987654321 " + + "when true then 3.3 " + + "end", + "CAST(3.3 AS DECIMAL(20,10))"); + + assertOptimizedEquals("case " + + "when false then 1 " + + "when true then 2.2 " + + "end", + "2.2"); + + assertOptimizedEquals("case when ARRAY[CAST(1 AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'matched'"); + assertOptimizedEquals("case when ARRAY[CAST(2 AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'"); + assertOptimizedEquals("case when ARRAY[CAST(null AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'"); + } + + @Test + public void testSimpleCase() + { + assertOptimizedEquals("case 1 " + + "when 1 then 32 + 1 " + + "when 1 then 34 " + + "end", + "33"); + + assertOptimizedEquals("case null " + + "when true then 33 " + + "end", + "null"); + assertOptimizedEquals("case null " + + "when true then 33 " + + "else 33 " + + "end", + "33"); + assertOptimizedEquals("case 33 " + + "when null then 1 " + + "else 33 " + + "end", + "33"); + + assertOptimizedEquals("case null " + + "when true then 3300000000 " + + "end", + "null"); + 
assertOptimizedEquals("case null " + + "when true then 3300000000 " + + "else 3300000000 " + + "end", + "3300000000"); + assertOptimizedEquals("case 33 " + + "when null then 3300000000 " + + "else 33 " + + "end", + "33"); + + assertOptimizedEquals("case true " + + "when true then 33 " + + "end", + "33"); + assertOptimizedEquals("case true " + + "when false then 1 " + + "else 33 end", + "33"); + + assertOptimizedEquals("case bound_long " + + "when 1234 then 33 " + + "end", + "33"); + assertOptimizedEquals("case 1234 " + + "when bound_long then 33 " + + "end", + "33"); + assertOptimizedEquals("case true " + + "when true then bound_long " + + "end", + "1234"); + assertOptimizedEquals("case true " + + "when false then 1 " + + "else bound_long " + + "end", + "1234"); + + assertOptimizedEquals("case bound_integer " + + "when 1234 then 33 " + + "end", + "33"); + assertOptimizedEquals("case 1234 " + + "when bound_integer then 33 " + + "end", + "33"); + assertOptimizedEquals("case true " + + "when true then bound_integer " + + "end", + "1234"); + assertOptimizedEquals("case true " + + "when false then 1 " + + "else bound_integer " + + "end", + "1234"); + + assertOptimizedEquals("case bound_long " + + "when 1234 then 33 " + + "else unbound_long " + + "end", + "33"); + assertOptimizedEquals("case true " + + "when true then bound_long " + + "else unbound_long " + + "end", + "1234"); + assertOptimizedEquals("case true " + + "when false then unbound_long " + + "else bound_long " + + "end", + "1234"); + + assertOptimizedEquals("case unbound_long " + + "when 1234 then 33 " + + "else 1 " + + "end", + "case unbound_long " + + "when 1234 then 33 " + + "else 1 " + + "end"); + + assertOptimizedEquals("case 33 " + + "when 0 then 0 " + + "when 33 then unbound_long " + + "else 1 " + + "end", + "unbound_long"); + assertOptimizedEquals("case 33 " + + "when 0 then 0 " + + "when 33 then 1 " + + "when unbound_long then 2 " + + "else 1 " + + "end", + "1"); + assertOptimizedEquals("case 33 " + + 
"when unbound_long then 0 " + + "when 1 then 1 " + + "when 33 then 2 " + + "else 0 " + + "end", + "case 33 " + + "when unbound_long then 0 " + + "else 2 " + + "end"); + assertOptimizedEquals("case 33 " + + "when 0 then 0 " + + "when 1 then 1 " + + "else unbound_long " + + "end", + "unbound_long"); + assertOptimizedEquals("case 33 " + + "when unbound_long then 0 " + + "when 1 then 1 " + + "when unbound_long2 then 2 " + + "else 3 " + + "end", + "case 33 " + + "when unbound_long then 0 " + + "when unbound_long2 then 2 " + + "else 3 " + + "end"); + + assertOptimizedMatches("case true " + + "when unbound_long = 1 then 1 " + + "when 0 / 0 = 0 then 2 " + + "else 33 end", + "case true " + + "when unbound_long = BIGINT '1' then 1 " + + "when CAST(fail(8, 'ignored failure message') AS boolean) then 2 else 33 " + + "end"); + + assertOptimizedEquals("case bound_long " + + "when 123 * 10 + unbound_long then 1 = 1 " + + "else 1 = 2 " + + "end", + "case bound_long when 1230 + unbound_long then true " + + "else false " + + "end"); + + assertOptimizedEquals("case bound_long " + + "when unbound_long then 2 + 2 " + + "end", + "case bound_long " + + "when unbound_long then 4 " + + "end"); + + assertOptimizedEquals("case bound_long " + + "when unbound_long then 2 + 2 " + + "when 1 then null " + + "when 2 then null " + + "end", + "case bound_long " + + "when unbound_long then 4 " + + "end"); + + assertOptimizedMatches("case 1 " + + "when 0 / 0 then 1 " + + "when 0 / 0 then 2 " + + "else 1 " + + "end", + "case 1 " + + "when cast(fail(8, 'ignored failure message') as integer) then 1 " + + "when cast(fail(8, 'ignored failure message') as integer) then 2 " + + "else 1 " + + "end"); + + assertOptimizedEquals("case true " + + "when false then 2.2 " + + "when true then 2.2 " + + "end", + "2.2"); + + // TODO enabled when DECIMAL is default for literal: +// assertOptimizedEquals("case true " + +// "when false then 1234567890.0987654321 " + +// "when true then 3.3 " + +// "end", +// "CAST(3.3 AS 
DECIMAL(20,10))"); + + assertOptimizedEquals("case true " + + "when false then 1 " + + "when true then 2.2 " + + "end", + "2.2"); + + assertOptimizedEquals("case ARRAY[CAST(1 AS BIGINT)] when ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'matched'"); + assertOptimizedEquals("case ARRAY[CAST(2 AS BIGINT)] when ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'"); + assertOptimizedEquals("case ARRAY[CAST(null AS BIGINT)] when ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'"); + } + + @Test + public void testCoalesce() + { + assertOptimizedEquals("coalesce(null, null)", "coalesce(null, null)"); + assertOptimizedEquals("coalesce(2 * 3 * unbound_long, 1 - 1, null)", "coalesce(6 * unbound_long, 0)"); + assertOptimizedEquals("coalesce(2 * 3 * unbound_long, 1.0E0/2.0E0, null)", "coalesce(6 * unbound_long, 0.5E0)"); + assertOptimizedEquals("coalesce(unbound_long, 2, 1.0E0/2.0E0, 12.34E0, null)", "coalesce(unbound_long, 2.0E0, 0.5E0, 12.34E0)"); + assertOptimizedEquals("coalesce(2 * 3 * unbound_integer, 1 - 1, null)", "coalesce(6 * unbound_integer, 0)"); + assertOptimizedEquals("coalesce(2 * 3 * unbound_integer, 1.0E0/2.0E0, null)", "coalesce(6 * unbound_integer, 0.5E0)"); + assertOptimizedEquals("coalesce(unbound_integer, 2, 1.0E0/2.0E0, 12.34E0, null)", "coalesce(unbound_integer, 2.0E0, 0.5E0, 12.34E0)"); + assertOptimizedMatches("coalesce(0 / 0 > 1, unbound_boolean, 0 / 0 = 0)", + "coalesce(cast(fail(8, 'ignored failure message') as boolean), unbound_boolean)"); + assertOptimizedMatches("coalesce(unbound_long, unbound_long)", "unbound_long"); + assertOptimizedMatches("coalesce(2 * unbound_long, 2 * unbound_long)", "BIGINT '2' * unbound_long"); + assertOptimizedMatches("coalesce(unbound_long, unbound_long2, unbound_long)", "coalesce(unbound_long, unbound_long2)"); + assertOptimizedMatches("coalesce(unbound_long, unbound_long2, unbound_long, unbound_long3)", "coalesce(unbound_long, 
unbound_long2, unbound_long3)"); + assertOptimizedEquals("coalesce(6, unbound_long2, unbound_long, unbound_long3)", "6"); + assertOptimizedEquals("coalesce(2 * 3, unbound_long2, unbound_long, unbound_long3)", "6"); + assertOptimizedMatches("coalesce(unbound_long, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '1')"); + assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long, 1)), unbound_long2)", "coalesce(unbound_long, BIGINT '1')"); + assertOptimizedMatches("coalesce(unbound_long, 2, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '2')"); + assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long2, unbound_long3)), 1)", "coalesce(unbound_long, unbound_long2, unbound_long3, BIGINT '1')"); + assertOptimizedMatches("coalesce(unbound_double, coalesce(random(), unbound_double))", "coalesce(unbound_double, random())"); + assertOptimizedMatches("coalesce(random(), random(), 5)", "coalesce(random(), random(), 5E0)"); + assertOptimizedMatches("coalesce(unbound_long, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '1')"); + assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long, 1)), unbound_long2)", "coalesce(unbound_long, BIGINT '1')"); + assertOptimizedMatches("coalesce(unbound_long, 2, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '2')"); + assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long2, unbound_long3)), 1)", "coalesce(unbound_long, unbound_long2, unbound_long3, BIGINT '1')"); + assertOptimizedMatches("coalesce(unbound_double, coalesce(random(), unbound_double))", "coalesce(unbound_double, random())"); + } + + @Test + public void testIf() + { + assertOptimizedEquals("IF(2 = 2, 3, 4)", "3"); + assertOptimizedEquals("IF(1 = 2, 3, 4)", "4"); + assertOptimizedEquals("IF(1 = 2, BIGINT '3', 4)", "4"); + assertOptimizedEquals("IF(1 = 2, 3000000000, 4)", "4"); + + assertOptimizedEquals("IF(true, 3, 4)", "3"); + 
assertOptimizedEquals("IF(false, 3, 4)", "4"); + assertOptimizedEquals("IF(null, 3, 4)", "4"); + + assertOptimizedEquals("IF(true, 3, null)", "3"); + assertOptimizedEquals("IF(false, 3, null)", "null"); + assertOptimizedEquals("IF(true, null, 4)", "null"); + assertOptimizedEquals("IF(false, null, 4)", "4"); + assertOptimizedEquals("IF(true, null, null)", "null"); + assertOptimizedEquals("IF(false, null, null)", "null"); + + assertOptimizedEquals("IF(true, 3.5E0, 4.2E0)", "3.5E0"); + assertOptimizedEquals("IF(false, 3.5E0, 4.2E0)", "4.2E0"); + + assertOptimizedEquals("IF(true, 'foo', 'bar')", "'foo'"); + assertOptimizedEquals("IF(false, 'foo', 'bar')", "'bar'"); + + assertOptimizedEquals("IF(true, 1.01, 1.02)", "1.01"); + assertOptimizedEquals("IF(false, 1.01, 1.02)", "1.02"); + assertOptimizedEquals("IF(true, 1234567890.123, 1.02)", "1234567890.123"); + assertOptimizedEquals("IF(false, 1.01, 1234567890.123)", "1234567890.123"); + } + + // TODO: Pending on native function namespace manager. 
+ @Test(enabled = false) + public void testLike() + { + assertOptimizedEquals("'a' LIKE 'a'", "true"); + assertOptimizedEquals("'' LIKE 'a'", "false"); + assertOptimizedEquals("'abc' LIKE 'a'", "false"); + + assertOptimizedEquals("'a' LIKE '_'", "true"); + assertOptimizedEquals("'' LIKE '_'", "false"); + assertOptimizedEquals("'abc' LIKE '_'", "false"); + + assertOptimizedEquals("'a' LIKE '%'", "true"); + assertOptimizedEquals("'' LIKE '%'", "true"); + assertOptimizedEquals("'abc' LIKE '%'", "true"); + + assertOptimizedEquals("'abc' LIKE '___'", "true"); + assertOptimizedEquals("'ab' LIKE '___'", "false"); + assertOptimizedEquals("'abcd' LIKE '___'", "false"); + + assertOptimizedEquals("'abc' LIKE 'abc'", "true"); + assertOptimizedEquals("'xyz' LIKE 'abc'", "false"); + assertOptimizedEquals("'abc0' LIKE 'abc'", "false"); + assertOptimizedEquals("'0abc' LIKE 'abc'", "false"); + + assertOptimizedEquals("'abc' LIKE 'abc%'", "true"); + assertOptimizedEquals("'abc0' LIKE 'abc%'", "true"); + assertOptimizedEquals("'0abc' LIKE 'abc%'", "false"); + + assertOptimizedEquals("'abc' LIKE '%abc'", "true"); + assertOptimizedEquals("'0abc' LIKE '%abc'", "true"); + assertOptimizedEquals("'abc0' LIKE '%abc'", "false"); + + assertOptimizedEquals("'abc' LIKE '%abc%'", "true"); + assertOptimizedEquals("'0abc' LIKE '%abc%'", "true"); + assertOptimizedEquals("'abc0' LIKE '%abc%'", "true"); + assertOptimizedEquals("'0abc0' LIKE '%abc%'", "true"); + assertOptimizedEquals("'xyzw' LIKE '%abc%'", "false"); + + assertOptimizedEquals("'abc' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'0abc' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'abc0' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'0abc0' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'ab01c' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'0ab01c' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'ab01c0' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'0ab01c0' LIKE '%ab%c%'", "true"); + + 
assertOptimizedEquals("'xyzw' LIKE '%ab%c%'", "false"); + + // ensure regex chars are escaped + assertOptimizedEquals("'' LIKE ''", "true"); + assertOptimizedEquals("'.*' LIKE '.*'", "true"); + assertOptimizedEquals("'[' LIKE '['", "true"); + assertOptimizedEquals("']' LIKE ']'", "true"); + assertOptimizedEquals("'{' LIKE '{'", "true"); + assertOptimizedEquals("'}' LIKE '}'", "true"); + assertOptimizedEquals("'?' LIKE '?'", "true"); + assertOptimizedEquals("'+' LIKE '+'", "true"); + assertOptimizedEquals("'(' LIKE '('", "true"); + assertOptimizedEquals("')' LIKE ')'", "true"); + assertOptimizedEquals("'|' LIKE '|'", "true"); + assertOptimizedEquals("'^' LIKE '^'", "true"); + assertOptimizedEquals("'$' LIKE '$'", "true"); + + assertOptimizedEquals("null LIKE '%'", "null"); + assertOptimizedEquals("'a' LIKE null", "null"); + assertOptimizedEquals("'a' LIKE '%' ESCAPE null", "null"); + assertOptimizedEquals("'a' LIKE unbound_string ESCAPE null", "null"); + + assertOptimizedEquals("'%' LIKE 'z%' ESCAPE 'z'", "true"); + + assertRowExpressionEquals(SERIALIZABLE, "'%' LIKE 'z%' ESCAPE 'z'", "true"); + assertRowExpressionEquals(SERIALIZABLE, "'%' LIKE 'z%'", "false"); + } + + // TODO: Pending on native function namespace manager. 
+    // Checks how LIKE predicates over unbound/bound symbols are expected to be rewritten at the OPTIMIZED level.
+    // NOTE(review): the TODO comment preceding this method says the test is pending on the native function
+    // namespace manager, yet unlike testLike/testInvalidLike this test is not disabled -- confirm which is intended.
+    @Test
+    public void testLikeOptimization()
+    {
+        // A pattern without wildcards is expected to become an equality comparison.
+        assertOptimizedEquals("unbound_string LIKE 'abc'", "unbound_string = CAST('abc' AS VARCHAR)");
+
+        // With ESCAPE '#': escaped wildcards ("#_", "#%") and an escaped escape ("##") are expected to be unescaped
+        // into a literal equality, while an empty pattern or a pattern that still holds an unescaped wildcard
+        // is expected to remain a LIKE.
+        assertOptimizedEquals("unbound_string LIKE '' ESCAPE '#'", "unbound_string LIKE '' ESCAPE '#'");
+        assertOptimizedEquals("unbound_string LIKE 'abc' ESCAPE '#'", "unbound_string = CAST('abc' AS VARCHAR)");
+        assertOptimizedEquals("unbound_string LIKE 'a#_b' ESCAPE '#'", "unbound_string = CAST('a_b' AS VARCHAR)");
+        assertOptimizedEquals("unbound_string LIKE 'a#%b' ESCAPE '#'", "unbound_string = CAST('a%b' AS VARCHAR)");
+        assertOptimizedEquals("unbound_string LIKE 'a#_##b' ESCAPE '#'", "unbound_string = CAST('a_#b' AS VARCHAR)");
+        assertOptimizedMatches("unbound_string LIKE 'a#__b' ESCAPE '#'", "unbound_string LIKE 'a#__b' ESCAPE '#'");
+        assertOptimizedMatches("unbound_string LIKE 'a##%b' ESCAPE '#'", "unbound_string LIKE 'a##%b' ESCAPE '#'");
+
+        // symbolConstant() binds bound_string to "hello" and bound_pattern to "%el%", so these fold to constants.
+        assertOptimizedEquals("bound_string LIKE bound_pattern", "true");
+        assertOptimizedEquals("'abc' LIKE bound_pattern", "false");
+
+        // A pattern with a trailing wildcard over an unbound value must come back unchanged at SERIALIZABLE level.
+        assertDoNotOptimize("unbound_string LIKE 'abc%'", SERIALIZABLE);
+
+        // Nothing is bound here, so the expression is expected to be returned as-is.
+        assertOptimizedMatches("unbound_string LIKE unbound_pattern ESCAPE unbound_string", "unbound_string LIKE unbound_pattern ESCAPE unbound_string");
+    }
+
+    // TODO: Pending on native function namespace manager.
+ @Test(enabled = false) + public void testInvalidLike() + { + assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE 'abc' ESCAPE ''")); + assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE 'abc' ESCAPE 'bc'")); + assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE '#' ESCAPE '#'")); + assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE '#abc' ESCAPE '#'")); + assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE 'ab#' ESCAPE '#'")); + } + + @Test + public void testLambda() + { + assertDoNotOptimize("transform(unbound_array, x -> x + x)", OPTIMIZED); + assertOptimizedEquals("transform(ARRAY[1, 5], x -> x + x)", "transform(ARRAY[1, 5], x -> x + x)"); + assertOptimizedEquals("transform(sequence(1, 5), x -> x + x)", "transform(sequence(1, 5), x -> x + x)"); + } + + @Test + public void testFailedExpressionOptimization() + { + assertOptimizedMatches("CASE unbound_long WHEN 1 THEN 1 WHEN 0 / 0 THEN 2 END", + "CASE unbound_long WHEN BIGINT '1' THEN 1 WHEN cast(fail(8, 'ignored failure message') as bigint) THEN 2 END"); + + assertOptimizedMatches("CASE unbound_boolean WHEN true THEN 1 ELSE 0 / 0 END", + "CASE unbound_boolean WHEN true THEN 1 ELSE cast(fail(8, 'ignored failure message') as integer) END"); + + assertOptimizedMatches("CASE bound_long WHEN unbound_long THEN 1 WHEN 0 / 0 THEN 2 ELSE 1 END", + "CASE BIGINT '1234' WHEN unbound_long THEN 1 WHEN cast(fail(8, 'ignored failure message') as bigint) THEN 2 ELSE 1 END"); + + assertOptimizedMatches("case when unbound_boolean then 1 when 0 / 0 = 0 then 2 end", + "case when unbound_boolean then 1 when cast(fail(8, 'ignored failure message') as boolean) then 2 end"); + + assertOptimizedMatches("case when unbound_boolean then 1 else 0 / 0 end", + "case when unbound_boolean then 1 else cast(fail(8, 'ignored failure message') as integer) end"); + + assertOptimizedMatches("case when unbound_boolean then 0 / 0 else 1 end", + "case 
when unbound_boolean then cast(fail(8, 'ignored failure message') as integer) else 1 end"); + } + + @Test(expectedExceptions = PrestoException.class) + public void testOptimizeDivideByZero() + { + optimize("0 / 0"); + } + + @Test + public void testMassiveArray() + { + assertDoNotOptimize("SEQUENCE(1, 999)", SERIALIZABLE); + assertDoNotOptimize("SEQUENCE(1, 1000)", SERIALIZABLE); + optimize(format("ARRAY [%s]", Joiner.on(", ").join(IntStream.range(0, 10_000).mapToObj(i -> "(bound_long + " + i + ")").iterator()))); + optimize(format("ARRAY [%s]", Joiner.on(", ").join(IntStream.range(0, 10_000).mapToObj(i -> "(bound_integer + " + i + ")").iterator()))); + optimize(format("ARRAY [%s]", Joiner.on(", ").join(IntStream.range(0, 10_000).mapToObj(i -> "'" + i + "'").iterator()))); + optimize(format("ARRAY [%s]", Joiner.on(", ").join(IntStream.range(0, 10_000).mapToObj(i -> "ARRAY['" + i + "']").iterator()))); + } + + @Test + public void testArrayConstructor() + { + optimize("ARRAY []"); + assertOptimizedEquals("ARRAY [(unbound_long + 0), (unbound_long + 1), (unbound_long + 2)]", + "array_constructor((unbound_long + 0), (unbound_long + 1), (unbound_long + 2))"); + assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), (bound_long + 2)]", + "array_constructor((bound_long + 0), (unbound_long + 1), (bound_long + 2))"); + assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), NULL]", + "array_constructor((bound_long + 0), (unbound_long + 1), NULL)"); + } + + @Test + public void testRowConstructor() + { + optimize("ROW(NULL)"); + optimize("ROW(1)"); + optimize("ROW(unbound_long + 0)"); + optimize("ROW(unbound_long + unbound_long2, unbound_string, unbound_double)"); + optimize("ROW(unbound_boolean, FALSE, ARRAY[unbound_long, unbound_long2], unbound_null_string, unbound_interval)"); + optimize("ARRAY [ROW(unbound_string, unbound_double), ROW(unbound_string, 0.0E0)]"); + optimize("ARRAY [ROW('string', unbound_double), ROW('string', bound_double)]"); + 
optimize("ROW(ROW(NULL), ROW(ROW(ROW(ROW('rowception')))))"); + optimize("ROW(unbound_string, bound_string)"); + + optimize("ARRAY [ROW(unbound_string, unbound_double), ROW(CAST(bound_string AS VARCHAR), 0.0E0)]"); + optimize("ARRAY [ROW(CAST(bound_string AS VARCHAR), 0.0E0), ROW(unbound_string, unbound_double)]"); + + optimize("ARRAY [ROW(unbound_string, unbound_double), CAST(NULL AS ROW(VARCHAR, DOUBLE))]"); + optimize("ARRAY [CAST(NULL AS ROW(VARCHAR, DOUBLE)), ROW(unbound_string, unbound_double)]"); + } + + @Test + public void testDereference() + { + optimize("ARRAY []"); + assertOptimizedEquals("ARRAY [(unbound_long + 0), (unbound_long + 1), (unbound_long + 2)]", + "array_constructor((unbound_long + 0), (unbound_long + 1), (unbound_long + 2))"); + assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), (bound_long + 2)]", + "array_constructor((bound_long + 0), (unbound_long + 1), (bound_long + 2))"); + assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), NULL]", + "array_constructor((bound_long + 0), (unbound_long + 1), NULL)"); + } + + @Test + public void testRowDereference() + { + optimize("CAST(null AS ROW(a VARCHAR, b BIGINT)).a"); + } + + @Test + public void testRowSubscript() + { + assertOptimizedEquals("ROW (1, 'a', true)[3]", "true"); + assertOptimizedEquals("ROW (1, 'a', ROW (2, 'b', ROW (3, 'c')))[3][3][2]", "'c'"); + } + + @Test(expectedExceptions = PrestoException.class) + public void testArraySubscriptConstantNegativeIndex() + { + optimize("ARRAY [1, 2, 3][-1]"); + } + + @Test(expectedExceptions = PrestoException.class) + public void testArraySubscriptConstantZeroIndex() + { + optimize("ARRAY [1, 2, 3][0]"); + } + + @Test + public void testMapSubscriptConstantIndexes() + { + optimize("MAP(ARRAY [1, 2], ARRAY [3, 4])[1]"); + optimize("MAP(ARRAY [BIGINT '1', 2], ARRAY [3, 4])[1]"); + optimize("MAP(ARRAY [1, 2], ARRAY [3, 4])[2]"); + optimize("MAP(ARRAY [ARRAY[1,1]], ARRAY['a'])[ARRAY[1,1]]"); + } + + @Test + public void 
testLiterals() + { + optimize("date '2013-04-03' + unbound_interval"); + optimize("timestamp '2013-04-03 03:04:05.321' + unbound_interval"); + optimize("timestamp '2013-04-03 03:04:05.321 UTC' + unbound_interval"); + + optimize("interval '3' day * unbound_long"); + // TODO: Pending on velox PR: https://github.com/facebookincubator/velox/pull/11612. +// optimize("interval '3' year * unbound_integer"); + } + + @Test + public void assertLikeOptimizations() + { + assertOptimizedMatches("unbound_string LIKE bound_pattern", "unbound_string LIKE CAST('%el%' AS varchar)"); + } + + private RowExpression evaluate(String expression, boolean deterministic) + { + assertRoundTrip(expression); + RowExpression rowExpression = sqlToRowExpression(expression); + return optimizeRowExpression(rowExpression, EVALUATED); + } + + private RowExpression optimize(@Language("SQL") String expression) + { + assertRoundTrip(expression); + RowExpression parsedExpression = sqlToRowExpression(expression); + return optimizeRowExpression(parsedExpression, OPTIMIZED); + } + + private void assertOptimizedEquals(@Language("SQL") String actual, @Language("SQL") String expected) + { + RowExpression optimizedActual = optimize(actual); + RowExpression optimizedExpected = optimize(expected); + assertRowExpressionEvaluationEquals(optimizedActual, optimizedExpected); + } + + private void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected) + { + RowExpression actualOptimized = optimize(actual); + RowExpression expectedOptimized = optimize(expected); + assertRowExpressionEvaluationEquals( + actualOptimized, + expectedOptimized); + } + + private void assertDoNotOptimize(@Language("SQL") String expression, ExpressionOptimizer.Level optimizationLevel) + { + assertRoundTrip(expression); + RowExpression rowExpression = sqlToRowExpression(expression); + RowExpression rowExpressionResult = optimizeRowExpression(rowExpression, optimizationLevel); + 
assertRowExpressionEvaluationEquals(rowExpressionResult, rowExpression); + } + + private RowExpression sqlToRowExpression(String expression) + { + Expression parsedExpression = FunctionAssertions.createExpression(expression, METADATA, SYMBOL_TYPES); + return TRANSLATOR.translate(parsedExpression, SYMBOL_TYPES); + } + + private Object symbolConstant(Symbol symbol) + { + switch (symbol.getName().toLowerCase(ENGLISH)) { + case "bound_integer": + case "bound_long": + return 1234L; + case "bound_string": + return utf8Slice("hello"); + case "bound_double": + return 12.34; + case "bound_date": + return new LocalDate(2001, 8, 22).toDateMidnight(DateTimeZone.UTC).getMillis(); + case "bound_time": + return new LocalTime(3, 4, 5, 321).toDateTime(new DateTime(0, DateTimeZone.UTC)).getMillis(); + case "bound_timestamp": + return new DateTime(2001, 8, 22, 3, 4, 5, 321, DateTimeZone.UTC).getMillis(); + case "bound_pattern": + return utf8Slice("%el%"); + case "bound_timestamp_with_timezone": + return new SqlTimestampWithTimeZone(new DateTime(1970, 1, 1, 1, 0, 0, 999, DateTimeZone.UTC).getMillis(), getTimeZoneKey("Z")); + case "bound_varbinary": + return Slices.wrappedBuffer((byte) 0xab); + case "bound_decimal_short": + return 12345L; + case "bound_decimal_long": + return Decimals.encodeUnscaledValue(new BigInteger("12345678901234567890123")); + } + return null; + } + + /** + * Assert the evaluation result of two row expressions equivalent + * no matter they are constants or remaining row expressions. 
+ */ + private void assertRowExpressionEvaluationEquals(RowExpression left, RowExpression right) + { + if (left instanceof ConstantExpression) { + if (isRemovableCast(right)) { + assertRowExpressionEvaluationEquals(left, ((CallExpression) right).getArguments().get(0)); + return; + } + assertTrue(right instanceof ConstantExpression); + assertConstantsEqual(((ConstantExpression) left), ((ConstantExpression) left)); + } + else if (left instanceof InputReferenceExpression || left instanceof VariableReferenceExpression) { + assertEquals(left, right); + } + else if (left instanceof CallExpression && ((CallExpression) left).getFunctionHandle().getName().contains("fail")) { + assertTrue(right instanceof CallExpression && ((CallExpression) right).getFunctionHandle().getName().contains("fail")); + assertEquals(((CallExpression) left).getArguments().size(), ((CallExpression) right).getArguments().size()); + for (int i = 0; i < ((CallExpression) left).getArguments().size(); i++) { + assertRowExpressionEvaluationEquals(((CallExpression) left).getArguments().get(i), ((CallExpression) right).getArguments().get(i)); + } + } + else if (left instanceof CallExpression) { + assertTrue(right instanceof CallExpression); + assertEquals(((CallExpression) left).getFunctionHandle(), ((CallExpression) right).getFunctionHandle()); + assertEquals(((CallExpression) left).getArguments().size(), ((CallExpression) right).getArguments().size()); + for (int i = 0; i < ((CallExpression) left).getArguments().size(); i++) { + assertRowExpressionEvaluationEquals(((CallExpression) left).getArguments().get(i), ((CallExpression) right).getArguments().get(i)); + } + } + else if (left instanceof SpecialFormExpression) { + assertTrue(right instanceof SpecialFormExpression); + assertEquals(((SpecialFormExpression) left).getForm(), ((SpecialFormExpression) right).getForm()); + assertEquals(((SpecialFormExpression) left).getArguments().size(), ((SpecialFormExpression) right).getArguments().size()); + for (int i 
= 0; i < ((SpecialFormExpression) left).getArguments().size(); i++) { + assertRowExpressionEvaluationEquals(((SpecialFormExpression) left).getArguments().get(i), ((SpecialFormExpression) right).getArguments().get(i)); + } + } + else { + assertTrue(left instanceof LambdaDefinitionExpression); + assertTrue(right instanceof LambdaDefinitionExpression); + assertEquals(((LambdaDefinitionExpression) left).getArguments(), ((LambdaDefinitionExpression) right).getArguments()); + assertEquals(((LambdaDefinitionExpression) left).getArgumentTypes(), ((LambdaDefinitionExpression) right).getArgumentTypes()); + assertRowExpressionEvaluationEquals(((LambdaDefinitionExpression) left).getBody(), ((LambdaDefinitionExpression) right).getBody()); + } + } + + private void assertConstantsEqual(ConstantExpression left, ConstantExpression right) + { + if (left.getValue() instanceof Block) { + assertTrue(right.getValue() instanceof Block); + assertBlockEquals((Block) left.getValue(), (Block) right.getValue()); + } + else { + assertEquals(left.getValue(), right.getValue()); + } + } + + private static void assertBlockEquals(Block left, Block right) + { + SliceOutput sliceOutput = new DynamicSliceOutput(1000); + BlockSerdeUtil.writeBlock(BLOCK_ENCODING_SERDE, sliceOutput, right); + SliceOutput sliceOutput1 = new DynamicSliceOutput(1000); + BlockSerdeUtil.writeBlock(BLOCK_ENCODING_SERDE, sliceOutput1, left); + assertEquals(sliceOutput1.slice(), sliceOutput.slice()); + } + + private boolean isRemovableCast(RowExpression value) + { + if (value instanceof CallExpression && + new FunctionResolution(METADATA.getFunctionAndTypeManager().getFunctionAndTypeResolver()).isCastFunction(((CallExpression) value).getFunctionHandle())) { + Type targetType = value.getType(); + Type sourceType = ((CallExpression) value).getArguments().get(0).getType(); + return METADATA.getFunctionAndTypeManager().canCoerce(sourceType, targetType); + } + return false; + } + + private void assertEvaluatedEquals(@Language("SQL") 
String actual, @Language("SQL") String expected) + { + assertRowExpressionEvaluationEquals(evaluate(actual, true), evaluate(expected, true)); + } + + private void assertRoundTrip(String expression) + { + ParsingOptions parsingOptions = createParsingOptions(TEST_SESSION); + assertEquals(SQL_PARSER.createExpression(expression, parsingOptions), + SQL_PARSER.createExpression(formatExpression(SQL_PARSER.createExpression(expression, parsingOptions), Optional.empty()), parsingOptions)); + } + + private void assertRowExpressionEquals(ExpressionOptimizer.Level level, @Language("SQL") String actual, @Language("SQL") String expected) + { + Object actualResult = optimizeRowExpression(sqlToRowExpression(actual), level); + Object expectedResult = optimizeRowExpression(sqlToRowExpression(expected), level); + if (actualResult instanceof Block && expectedResult instanceof Block) { + assertBlockEquals((Block) actualResult, (Block) expectedResult); + return; + } + assertEquals(actualResult, expectedResult); + } + + private RowExpression optimizeRowExpression(RowExpression expression, ExpressionOptimizer.Level level) + { + return expressionOptimizer.optimize( + expression, + level, + TEST_SESSION.toConnectorSession(), + variable -> { + Symbol symbol = new Symbol(variable.getName()); + Object value = symbolConstant(symbol); + if (value == null) { + return new VariableReferenceExpression(Optional.empty(), symbol.getName(), SYMBOL_TYPES.get(symbol.toSymbolReference())); + } + return value; + }); + } +} diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressions.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressions.java new file mode 100644 index 0000000000000..d9ecbf8d8a8eb --- /dev/null +++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/session/sql/expressions/TestNativeExpressions.java @@ -0,0 +1,177 @@ +/* + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.sql.expressions; + +import com.facebook.airlift.bootstrap.Bootstrap; +import com.facebook.airlift.jaxrs.JsonMapper; +import com.facebook.airlift.json.JsonModule; +import com.facebook.presto.block.BlockJsonSerde; +import com.facebook.presto.client.NodeVersion; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockEncoding; +import com.facebook.presto.common.block.BlockEncodingManager; +import com.facebook.presto.common.block.BlockEncodingSerde; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.HandleJsonModule; +import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.metadata.InMemoryNodeManager; +import com.facebook.presto.metadata.InternalNode; +import com.facebook.presto.metadata.InternalNodeManager; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.nodeManager.PluginNodeManager; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.RowExpressionSerde; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.sql.TestingRowExpressionTranslator; +import 
com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.type.TypeDeserializer; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.inject.Injector; +import com.google.inject.Module; +import com.google.inject.Scopes; +import org.testng.annotations.Test; + +import java.net.URI; + +import static com.facebook.airlift.json.JsonBinder.jsonBinder; +import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils.findRandomPortForWorker; +import static com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils.getNativeSidecarProcess; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; +import static com.facebook.presto.sql.planner.LiteralEncoder.toRowExpression; +import static com.google.inject.multibindings.Multibinder.newSetBinder; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static org.testng.Assert.assertEquals; + +public class TestNativeExpressions +{ + private static final Metadata METADATA = MetadataManager.createTestMetadataManager(); + private static final TestingRowExpressionTranslator TRANSLATOR = new TestingRowExpressionTranslator(METADATA); + private Process sidecar; + + @Test + public void testLoadPlugin() + throws Exception + { + try { + int port = findRandomPortForWorker(); + URI sidecarUri = URI.create(format("http://127.0.0.1:%s", port)); + sidecar = getNativeSidecarProcess(sidecarUri, port); + ExpressionOptimizer interpreterService = getExpressionOptimizer(METADATA, null, sidecarUri); + + // Test the native row expression interpreter service with some simple expressions + RowExpression simpleAddition = compileExpression("1+1"); + RowExpression 
unnecessaryCoalesce = compileExpression("coalesce(1, 2)"); + + // Assert simple optimizations are performed + assertEquals(interpreterService.optimize(simpleAddition, OPTIMIZED, TEST_SESSION.toConnectorSession()), toRowExpression(2L, simpleAddition.getType())); + assertEquals(interpreterService.optimize(unnecessaryCoalesce, OPTIMIZED, TEST_SESSION.toConnectorSession()), toRowExpression(1L, unnecessaryCoalesce.getType())); + } + finally { + if (sidecar != null) { + sidecar.destroy(); + } + } + } + + private static RowExpression compileExpression(String expression) + { + return TRANSLATOR.translate(expression, ImmutableMap.of()); + } + + protected static ExpressionOptimizer getExpressionOptimizer(Metadata metadata, HandleResolver handleResolver, URI sidecarUri) + { + // Set up dependencies in main for this module + InMemoryNodeManager nodeManager = getNodeManagerWithSidecar(sidecarUri); + Injector prestoMainInjector = getPrestoMainInjector(metadata, handleResolver); + RowExpressionSerde rowExpressionSerde = prestoMainInjector.getInstance(RowExpressionSerde.class); + FunctionAndTypeManager functionMetadataManager = prestoMainInjector.getInstance(FunctionAndTypeManager.class); + + // Create the native row expression interpreter service + return createExpressionOptimizer(nodeManager, rowExpressionSerde, functionMetadataManager); + } + + private static InMemoryNodeManager getNodeManagerWithSidecar(URI sidecarUri) + { + InMemoryNodeManager nodeManager = new InMemoryNodeManager(); + nodeManager.addNode(new ConnectorId("test"), new InternalNode("test", sidecarUri, NodeVersion.UNKNOWN, false, false, false, true)); + return nodeManager; + } + + private static ExpressionOptimizer createExpressionOptimizer(InternalNodeManager internalNodeManager, RowExpressionSerde rowExpressionSerde, FunctionAndTypeManager functionMetadataManager) + { + requireNonNull(internalNodeManager, "inMemoryNodeManager is null"); + NodeManager nodeManager = new PluginNodeManager(internalNodeManager); + 
FunctionResolution functionResolution = new FunctionResolution(functionMetadataManager.getFunctionAndTypeResolver()); + + Bootstrap app = new Bootstrap( + // Specially use a testing HTTP client instead of a real one + new NativeExpressionsCommunicationModule(), + // Otherwise use the exact same module as the native row expression interpreter service + new NativeExpressionsModule(nodeManager, rowExpressionSerde, functionMetadataManager, functionResolution)); + + Injector injector = app + .noStrictConfig() + .doNotInitializeLogging() + .setRequiredConfigurationProperties(ImmutableMap.of()) + .quiet() + .initialize(); + return injector.getInstance(NativeExpressionOptimizerProvider.class).createOptimizer(); + } + + private static Injector getPrestoMainInjector(Metadata metadata, HandleResolver handleResolver) + { + Module module = binder -> { + // Installs the JSON codec + binder.install(new JsonModule()); + // Required to deserialize function handles + binder.install(new HandleJsonModule(handleResolver)); + // Required for this test in the JaxrsTestingHttpProcessor because the underlying object mapper + // must be the same as all other object mappers + binder.bind(JsonMapper.class); + + // These dependencies are needed to serialize and deserialize types (found in expressions) + FunctionAndTypeManager functionAndTypeManager = metadata.getFunctionAndTypeManager(); + binder.bind(FunctionAndTypeManager.class).toInstance(functionAndTypeManager); + binder.bind(TypeManager.class).toInstance(functionAndTypeManager); + jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); + newSetBinder(binder, Type.class); + + // These dependencies are needed to serialize and deserialize blocks (found in constant values of expressions) + binder.bind(BlockEncodingSerde.class).to(BlockEncodingManager.class).in(Scopes.SINGLETON); + newSetBinder(binder, BlockEncoding.class); + jsonBinder(binder).addSerializerBinding(Block.class).to(BlockJsonSerde.Serializer.class); + 
jsonBinder(binder).addDeserializerBinding(Block.class).to(BlockJsonSerde.Deserializer.class); + + // Create the serde which is used by the plugin to serialize and deserialize expressions + jsonCodecBinder(binder).bindJsonCodec(RowExpression.class); + binder.bind(RowExpressionSerde.class).to(JsonCodecRowExpressionSerde.class).in(Scopes.SINGLETON); + }; + Bootstrap app = new Bootstrap(ImmutableList.of(module)); + Injector injector = app + .doNotInitializeLogging() + .quiet() + .initialize(); + return injector; + } +} diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java index 994d824195600..e518c40bb966a 100644 --- a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java +++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java @@ -13,16 +13,21 @@ */ package com.facebook.presto.sidecar; +import com.facebook.presto.metadata.InMemoryNodeManager; +import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManagerFactory; import com.facebook.presto.sidecar.sessionpropertyproviders.NativeSystemSessionPropertyProviderFactory; import com.facebook.presto.testing.QueryRunner; import com.google.common.collect.ImmutableMap; +import java.io.IOException; + public class NativeSidecarPluginQueryRunnerUtils { private NativeSidecarPluginQueryRunnerUtils() {}; public static void setupNativeSidecarPlugin(QueryRunner queryRunner) + throws IOException { queryRunner.installCoordinatorPlugin(new NativeSidecarPlugin()); queryRunner.loadSessionPropertyProvider(NativeSystemSessionPropertyProviderFactory.NAME); @@ -32,5 +37,7 @@ public static void setupNativeSidecarPlugin(QueryRunner queryRunner) ImmutableMap.of( 
"supported-function-languages", "CPP", "function-implementation-type", "CPP")); + queryRunner.getExpressionManager().loadExpressionOptimizerFactory("native", "native", ImmutableMap.of()); + queryRunner.getPlanCheckerProviderManager().loadPlanCheckerProviders(new PluginNodeManager(new InMemoryNodeManager())); } } diff --git a/presto-openapi/src/main/resources/expressions.yaml b/presto-openapi/src/main/resources/expressions.yaml new file mode 100644 index 0000000000000..8d1e3ef5d9379 --- /dev/null +++ b/presto-openapi/src/main/resources/expressions.yaml @@ -0,0 +1,173 @@ +openapi: 3.0.0 +info: + title: Presto Expression API + description: API for evaluating and simplifying row expressions in Presto + version: "1" +servers: + - url: http://localhost:8080 + description: Presto endpoint when running locally +paths: + /v1/expressions: + post: + summary: Simplify the list of row expressions + description: This endpoint takes in a list of row expressions and attempts to simplify them to their simplest logical equivalent expression. 
+ requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RowExpressions' + required: true + responses: + '200': + description: Results + content: + application/json: + schema: + $ref: '#/components/schemas/RowExpressions' +components: + schemas: + RowExpressions: + type: array + maxItems: 100 + items: + $ref: "#/components/schemas/RowExpression" + RowExpression: + oneOf: + - $ref: "#/components/schemas/ConstantExpression" + - $ref: "#/components/schemas/VariableReferenceExpression" + - $ref: "#/components/schemas/InputReferenceExpression" + - $ref: "#/components/schemas/LambdaDefinitionExpression" + - $ref: "#/components/schemas/SpecialFormExpression" + - $ref: "#/components/schemas/CallExpression" + RowExpressionParent: + type: object + properties: + sourceLocation: + $ref: "#/components/schemas/SourceLocation" + SourceLocation: + description: The source location of the row expression in the original query, referencing the line and the column of the query. + type: object + properties: + line: + type: integer + column: + type: integer + ConstantExpression: + description: A constant expression is a row expression that represents a constant value. The value attribute is the constant value. + allOf: + - $ref: "#/components/schemas/RowExpressionParent" + - type: object + properties: + "@type": + type: string + enum : ["constant"] + typeSignature: + type: string + valueBlock: + type: string + VariableReferenceExpression: + description: A variable reference expression is a row expression that represents a reference to a variable. The name attribute indicates the name of the variable. + allOf: + - $ref: "#/components/schemas/RowExpressionParent" + - type: object + properties: + "@type": + type: string + enum : ["variable"] + typeSignature: + type: string + name: + type: string + InputReferenceExpression: + description: > + An input reference expression is a row expression that represents a reference to a column in the input schema. 
The field attribute indicates the index of the column in the + input schema. + allOf: + - $ref: "#/components/schemas/RowExpressionParent" + - type: object + properties: + "@type": + type: string + enum : ["input"] + typeSignature: + type: string + field: + type: integer + LambdaDefinitionExpression: + description: > + A lambda definition expression is a row expression that represents a lambda function. The lambda function is defined by a list of argument types, a list of argument names, + and a body expression. + allOf: + - $ref: "#/components/schemas/RowExpressionParent" + - type: object + properties: + "@type": + type: string + enum : ["lambda"] + argumentTypeSignatures: + type: array + items: + type: string + arguments: + type: array + items: + type: string + body: + $ref: "#/components/schemas/RowExpression" + SpecialFormExpression: + description: > + A special form expression is a row expression that represents a special language construct. The form attribute indicates the specific form of the special form, + which is a well known list, and with each having special semantics. The arguments attribute is a list of row expressions that are the arguments to the special form, with + each form taking in a specific number of arguments. + allOf: + - $ref: "#/components/schemas/RowExpressionParent" + - type: object + properties: + "@type": + type: string + enum : ["special"] + form: + type: string + enum: ["IF","NULL_IF","SWITCH","WHEN","IS_NULL","COALESCE","IN","AND","OR","DEREFERENCE","ROW_CONSTRUCTOR","BIND"] + returnTypeSignature: + type: string + arguments: + type: array + items: + $ref: "#/components/schemas/RowExpression" + CallExpression: + description: > + A call expression is a row expression that represents a call to a function. The functionHandle attribute is an opaque handle to the function that is being called. + The arguments attribute is a list of row expressions that are the arguments to the function. 
+ allOf: + - $ref: "#/components/schemas/RowExpressionParent" + - type: object + properties: + "@type": + type: string + enum : ["call"] + displayName: + type: string + functionHandle: + $ref: "#/components/schemas/FunctionHandle" + returnTypeSignature: + type: string + arguments: + type: array + items: + $ref: "#/components/schemas/RowExpression" + FunctionHandle: + description: An opaque handle to a function that may be invoked. This is interpreted by the registered function namespace manager. + anyOf: + - $ref: "#/components/schemas/OpaqueFunctionHandle" + - $ref: "#/components/schemas/SqlFunctionHandle" + OpaqueFunctionHandle: + type: object + properties: {} # any opaque object may be passed and interpreted by a function namespace manager + SqlFunctionHandle: + type: object + properties: + functionId: + type: string + version: + type: string diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java index e9d03e52a461f..7b10fe37c78e7 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java @@ -137,6 +137,7 @@ import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.PageIndexerFactory; import com.facebook.presto.spi.PageSorter; +import com.facebook.presto.spi.RowExpressionSerde; import com.facebook.presto.spi.analyzer.ViewDefinition; import com.facebook.presto.spi.memory.ClusterMemoryPoolManager; import com.facebook.presto.spi.plan.SimplePlanFragment; @@ -144,6 +145,7 @@ import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.DomainTranslator; import com.facebook.presto.spi.relation.PredicateCompiler; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import 
com.facebook.presto.spiller.GenericPartitioningSpillerFactory; import com.facebook.presto.spiller.GenericSpillerFactory; @@ -177,6 +179,7 @@ import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler; @@ -306,6 +309,7 @@ protected void setup(Binder binder) jsonCodecBinder(binder).bindJsonCodec(BroadcastFileInfo.class); jsonCodecBinder(binder).bindJsonCodec(SimplePlanFragment.class); binder.bind(SimplePlanFragmentSerde.class).to(JsonCodecSimplePlanFragmentSerde.class).in(Scopes.SINGLETON); + jsonCodecBinder(binder).bindJsonCodec(RowExpression.class); // smile codecs smileCodecBinder(binder).bindSmileCodec(TaskSource.class); @@ -352,6 +356,7 @@ protected void setup(Binder binder) // expression manager binder.bind(ExpressionOptimizerManager.class).in(Scopes.SINGLETON); + binder.bind(RowExpressionSerde.class).to(JsonCodecRowExpressionSerde.class).in(Scopes.SINGLETON); // tracer provider managers binder.bind(TracerProviderManager.class).in(Scopes.SINGLETON); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/RowExpressionSerde.java b/presto-spi/src/main/java/com/facebook/presto/spi/RowExpressionSerde.java new file mode 100644 index 0000000000000..ab5381aa2556c --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/RowExpressionSerde.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi; + +import com.facebook.presto.spi.relation.RowExpression; + +public interface RowExpressionSerde +{ + String serialize(RowExpression expression); + + RowExpression deserialize(String value); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java b/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java index c9a6f84aaaa1d..51b7cce149d8f 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java @@ -14,6 +14,7 @@ package com.facebook.presto.spi.sql.planner; import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.RowExpressionSerde; import com.facebook.presto.spi.function.FunctionMetadataManager; import com.facebook.presto.spi.function.StandardFunctionResolution; @@ -22,12 +23,14 @@ public class ExpressionOptimizerContext { private final NodeManager nodeManager; + private final RowExpressionSerde rowExpressionSerde; private final FunctionMetadataManager functionMetadataManager; private final StandardFunctionResolution functionResolution; - public ExpressionOptimizerContext(NodeManager nodeManager, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution functionResolution) + public ExpressionOptimizerContext(NodeManager nodeManager, RowExpressionSerde rowExpressionSerde, FunctionMetadataManager functionMetadataManager, 
StandardFunctionResolution functionResolution) { this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null"); this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); } @@ -37,6 +40,11 @@ public NodeManager getNodeManager() return nodeManager; } + public RowExpressionSerde getRowExpressionSerde() + { + return rowExpressionSerde; + } + public FunctionMetadataManager getFunctionMetadataManager() { return functionMetadataManager; diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java index bedd15a5d9e10..09b03cfb0409b 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java @@ -26,10 +26,12 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.security.AccessDeniedException; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.PartitioningProviderManager; import com.facebook.presto.sql.planner.Plan; @@ -61,6 +63,7 @@ import java.util.OptionalLong; import java.util.function.Consumer; +import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static 
com.facebook.airlift.testing.Closeables.closeAllRuntimeException; import static com.facebook.presto.sql.SqlFormatter.formatSql; import static com.facebook.presto.transaction.TransactionBuilder.transaction; @@ -576,7 +579,8 @@ private QueryExplainer getQueryExplainer() featuresConfig, new ExpressionOptimizerManager( new PluginNodeManager(new InMemoryNodeManager()), - queryRunner.getMetadata().getFunctionAndTypeManager())) + queryRunner.getMetadata().getFunctionAndTypeManager(), + new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class)))) .getPlanningTimeOptimizers(); return new QueryExplainer( optimizers,