diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d681f00cc2..fb1deb823f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,5 +21,5 @@ repos: - id: spotless-apply name: spotless (apply) language: system - entry: mvn --batch-mode spotless:apply + entry: mvn -q --batch-mode spotless:apply pass_filenames: false diff --git a/src/main/cpp/main/CMakeLists.txt b/src/main/cpp/main/CMakeLists.txt index 7e4dce4816..b9f6de2614 100644 --- a/src/main/cpp/main/CMakeLists.txt +++ b/src/main/cpp/main/CMakeLists.txt @@ -23,8 +23,7 @@ set(VELOX4J_SOURCES velox4j/eval/Evaluation.cc velox4j/eval/Evaluator.cc velox4j/iterator/BlockingQueue.cc - velox4j/vector/Vectors.cc - velox4j/shuffle/HashPartitioner.cc) + velox4j/vector/Vectors.cc) set(VELOX4J_INCLUDES ${CMAKE_CURRENT_LIST_DIR} ${velox_SOURCE_DIR} ${JniHelpersLib_SOURCE_DIR}) set(VELOX4J_DEPENDENCIES velox JniHelpers JNI::JNI) diff --git a/src/main/cpp/main/velox4j/jni/JniWrapper.cc b/src/main/cpp/main/velox4j/jni/JniWrapper.cc index 4845fe3456..4b71b7feec 100644 --- a/src/main/cpp/main/velox4j/jni/JniWrapper.cc +++ b/src/main/cpp/main/velox4j/jni/JniWrapper.cc @@ -31,7 +31,6 @@ #include "velox4j/jni/JniError.h" #include "velox4j/lifecycle/Session.h" #include "velox4j/query/QueryExecutor.h" -#include "velox4j/shuffle/HashPartitioner.h" #include "velox4j/vector/Vectors.h" namespace velox4j { @@ -340,32 +339,68 @@ jlongArray rowVectorPartitionByKeys( JNI_METHOD_END(nullptr) } -jlongArray rowVectorPartitionByKeyHashes( +jlongArray baseVectorWrapPartitions( JNIEnv* env, jobject javaThis, - jlong vid, - jintArray jKeyChannels, + jlong vectorId, + jintArray jPartitions, jint numPartitions) { JNI_METHOD_START auto session = sessionOf(env, javaThis); auto pool = session->memoryManager()->getVeloxPool( - "Hash Partition Memory Pool", memory::MemoryPool::Kind::kLeaf); - const auto inputRowVector = ObjectStore::retrieve(vid); + "Wrap Partitions Memory Pool", memory::MemoryPool::Kind::kLeaf); + VectorPtr vector = ObjectStore::retrieve(vectorId); + flattenVector(vector, vector->size()); + const auto inputNumRows = vector->size(); + auto safeArray = getIntArrayElementsSafe(env, jPartitions); - auto safeArray = getIntArrayElementsSafe(env, jKeyChannels); - std::vector keyChannels(safeArray.length()); - for (jsize i = 0; i < safeArray.length(); ++i) { - keyChannels[i] = safeArray.elems()[i]; + std::vector outVector(numPartitions, 0); + VELOX_USER_CHECK_EQ( + safeArray.length(), + inputNumRows, + "Expected one partition id per input row"); + + std::vector partitionSizes(numPartitions); + std::vector partitionRows(numPartitions); + std::vector rawPartitionRows(numPartitions); + std::fill(partitionSizes.begin(), partitionSizes.end(), 0); + + for (int row = 0; row < inputNumRows; ++row) { + const auto partitionId = static_cast(safeArray.elems()[row]); + VELOX_USER_CHECK_GE(partitionId, 0, "partition id must be non-negative"); + VELOX_USER_CHECK_LT( + partitionId, + numPartitions, + "partition id {} is out of range for {} partitions", + partitionId, + numPartitions); + ++partitionSizes[partitionId]; + } + + for (int partitionId = 0; partitionId < numPartitions; ++partitionId) { + partitionRows[partitionId] = + allocateIndices(partitionSizes[partitionId], pool); + rawPartitionRows[partitionId] = + partitionRows[partitionId]->asMutable(); } - HashPartitioner partitioner(std::move(keyChannels), numPartitions, pool); - auto partitions = partitioner.partition(inputRowVector); + std::vector partitionNextRowOffset(numPartitions); + std::fill(partitionNextRowOffset.begin(), partitionNextRowOffset.end(), 0); + for (int row = 0; row < inputNumRows; ++row) { + const auto partitionId = static_cast(safeArray.elems()[row]); + rawPartitionRows[partitionId][partitionNextRowOffset[partitionId]] = row; + ++partitionNextRowOffset[partitionId]; + } - std::vector outVector(numPartitions, 0); - for (int pid = 0; pid < numPartitions; ++pid) { - if (partitions[pid] != nullptr) { - outVector[pid] = session->objectStore()->save(partitions[pid]); + for (int partitionId = 0; partitionId < numPartitions; ++partitionId) { + const vector_size_t partitionSize = partitionSizes[partitionId]; + if (partitionSize == 0) { + continue; } + VectorPtr partitionVector = partitionSize == inputNumRows + ? vector + : wrapInDictionary(partitionSize, partitionRows[partitionId], vector); + outVector[partitionId] = session->objectStore()->save(partitionVector); } const jlongArray out = env->NewLongArray(outVector.size()); @@ -375,6 +410,54 @@ jlongArray rowVectorPartitionByKeyHashes( JNI_METHOD_END(nullptr) } +jlong createPartitionFunction( + JNIEnv* env, + jobject javaThis, + jstring specJson, + jint numPartitions, + jboolean localExchange) { + JNI_METHOD_START + auto session = sessionOf(env, javaThis); + auto serdePool = session->memoryManager()->getVeloxPool( + "Partition Function Serde Memory Pool", memory::MemoryPool::Kind::kLeaf); + spotify::jni::JavaString jSpecJson{env, specJson}; + auto dynamic = folly::parseJson(jSpecJson.get()); + auto spec = ISerializable::deserialize( + dynamic, serdePool); + auto function = std::shared_ptr( + spec->create(numPartitions, static_cast(localExchange)).release()); + return session->objectStore()->save(function); + JNI_METHOD_END(-1) +} + +jintArray partitionFunctionPartition( + JNIEnv* env, + jobject javaThis, + jlong partitionFunctionId, + jlong rowVectorId) { + JNI_METHOD_START + auto function = + ObjectStore::retrieve(partitionFunctionId); + const auto inputRowVector = ObjectStore::retrieve(rowVectorId); + + std::vector partitions; + auto singlePartition = function->partition(*inputRowVector, partitions); + std::vector outVector; + if (singlePartition.has_value()) { + outVector.assign( + inputRowVector->size(), static_cast(singlePartition.value())); + } else { + outVector.reserve(partitions.size()); + for (const auto partition : partitions) { + outVector.push_back(static_cast(partition)); + } + } + const jintArray out = env->NewIntArray(outVector.size()); + env->SetIntArrayRegion(out, 0, outVector.size(), outVector.data()); + return out; + JNI_METHOD_END(nullptr) +} + jlong createSelectivityVector(JNIEnv* env, jobject javaThis, jint length) { JNI_METHOD_START auto vector = @@ -606,6 +689,29 @@ void JniWrapper::initialize(JNIEnv* env) { kTypeArray(kTypeInt), kTypeInt, nullptr); + addNativeMethod( + "baseVectorWrapPartitions", + (void*)baseVectorWrapPartitions, + kTypeArray(kTypeLong), + kTypeLong, + kTypeArray(kTypeInt), + kTypeInt, + nullptr); + addNativeMethod( + "createPartitionFunction", + (void*)createPartitionFunction, + kTypeLong, + kTypeString, + kTypeInt, + kTypeBool, + nullptr); + addNativeMethod( + "partitionFunctionPartition", + (void*)partitionFunctionPartition, + kTypeArray(kTypeInt), + kTypeLong, + kTypeLong, + nullptr); addNativeMethod( "createSelectivityVector", (void*)createSelectivityVector, @@ -633,14 +739,6 @@ void JniWrapper::initialize(JNIEnv* env) { kTypeString, kTypeString, nullptr); - addNativeMethod( - "rowVectorPartitionByKeyHashes", - (void*)rowVectorPartitionByKeyHashes, - kTypeArray(kTypeLong), - kTypeLong, - kTypeArray(kTypeInt), - kTypeInt, - nullptr); addNativeMethod( "createUpIteratorWithExternalStream", (void*)createUpIteratorWithExternalStream, diff --git a/src/main/cpp/main/velox4j/vector/Vectors.cc b/src/main/cpp/main/velox4j/vector/Vectors.cc index 32fffbd368..c2179e8f3e 100644 --- a/src/main/cpp/main/velox4j/vector/Vectors.cc +++ b/src/main/cpp/main/velox4j/vector/Vectors.cc @@ -14,6 +14,7 @@ #include "velox4j/vector/Vectors.h" +#include #include #include @@ -75,4 +76,15 @@ void flattenVector(VectorPtr& vector, vector_size_t targetSize) { vector = vector->slice(0, targetSize); } } + +VectorPtr wrapInDictionary( + vector_size_t size, + const BufferPtr& indices, + const VectorPtr& vector) { + VELOX_CHECK_NOT_NULL(vector); + if (auto rowVector = std::dynamic_pointer_cast(vector)) { + return exec::wrap(size, indices, rowVector); + } + return BaseVector::wrapInDictionary(nullptr, indices, size, vector); +} } // namespace velox4j diff --git a/src/main/cpp/main/velox4j/vector/Vectors.h b/src/main/cpp/main/velox4j/vector/Vectors.h index a916cdce1f..3402f56686 100644 --- a/src/main/cpp/main/velox4j/vector/Vectors.h +++ b/src/main/cpp/main/velox4j/vector/Vectors.h @@ -27,4 +27,9 @@ namespace velox4j { void flattenVector( facebook::velox::VectorPtr& vector, facebook::velox::vector_size_t targetSize); + +facebook::velox::VectorPtr wrapInDictionary( + facebook::velox::vector_size_t size, + const facebook::velox::BufferPtr& indices, + const facebook::velox::VectorPtr& vector); } // namespace velox4j diff --git a/src/main/java/org/boostscale/velox4j/data/RowVectors.java b/src/main/java/org/boostscale/velox4j/data/RowVectors.java index d7e9cde920..9225915bc6 100644 --- a/src/main/java/org/boostscale/velox4j/data/RowVectors.java +++ b/src/main/java/org/boostscale/velox4j/data/RowVectors.java @@ -14,12 +14,13 @@ package org.boostscale.velox4j.data; import java.util.List; +import java.util.stream.Collectors; import com.google.common.base.Preconditions; import org.boostscale.velox4j.jni.JniApi; -import org.boostscale.velox4j.plan.partition.HashPartitionFunctionSpec; -import org.boostscale.velox4j.plan.partition.PartitionFunctionSpec; +import org.boostscale.velox4j.partition.PartitionFunction; +import org.boostscale.velox4j.partition.PartitionFunctionSpec; public class RowVectors { private final JniApi jniApi; @@ -36,28 +37,27 @@ public List partitionByKeys(RowVector rowVector, List keyCha return jniApi.rowVectorPartitionByKeys(rowVector, keyChannels, 128); } - /** - * Partitions the input RowVector into a list of RowVectors where each one has the same keys - * defined by the key indices of `keyChannels`, with a configurable maximum number of partitions. - */ - public List partitionByKeys( - RowVector rowVector, List keyChannels, int maxPartitions) { - Preconditions.checkArgument( - maxPartitions > 0, "maxPartitions must be positive, got %s", maxPartitions); - return jniApi.rowVectorPartitionByKeys(rowVector, keyChannels, maxPartitions); - } - /** * Partitions a RowVector into numPartitions groups using the given partition function spec. * Returns a list of size numPartitions where index i contains rows for partition i (null if * empty). - * - *

Currently only {@link HashPartitionFunctionSpec} is supported. */ public List partitionBySpec( RowVector rowVector, PartitionFunctionSpec spec, int numPartitions) { Preconditions.checkArgument( numPartitions > 0, "numPartitions must be positive, got %s", numPartitions); - return jniApi.rowVectorPartitionBySpec(rowVector, spec, numPartitions); + try (PartitionFunction partitionFunction = + jniApi.createPartitionFunction(spec, numPartitions, false)) { + final int[] partitions = jniApi.partitionFunctionPartition(partitionFunction, rowVector); + return jniApi.baseVectorWrapPartitions(rowVector, partitions, numPartitions).stream() + .map( + vector -> { + if (vector == null) { + return null; + } + return vector.asRowVector(); + }) + .collect(Collectors.toList()); + } } } diff --git a/src/main/java/org/boostscale/velox4j/jni/JniApi.java b/src/main/java/org/boostscale/velox4j/jni/JniApi.java index 6fc89a7782..4a71ef08c9 100644 --- a/src/main/java/org/boostscale/velox4j/jni/JniApi.java +++ b/src/main/java/org/boostscale/velox4j/jni/JniApi.java @@ -18,7 +18,6 @@ import java.util.stream.Collectors; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; import org.apache.arrow.c.ArrowArray; import org.apache.arrow.c.ArrowSchema; @@ -30,8 +29,8 @@ import org.boostscale.velox4j.iterator.DownIterator; import org.boostscale.velox4j.iterator.GenericUpIterator; import org.boostscale.velox4j.iterator.UpIterator; -import org.boostscale.velox4j.plan.partition.HashPartitionFunctionSpec; -import org.boostscale.velox4j.plan.partition.PartitionFunctionSpec; +import org.boostscale.velox4j.partition.PartitionFunction; +import org.boostscale.velox4j.partition.PartitionFunctionSpec; import org.boostscale.velox4j.query.Query; import org.boostscale.velox4j.query.QueryExecutor; import org.boostscale.velox4j.query.SerialTask; @@ -130,10 +129,6 @@ public BaseVector baseVectorSlice(BaseVector vector, int offset, int length) { return baseVectorWrap(jni.baseVectorSlice(vector.id(), offset, length)); } - public List rowVectorPartitionByKeys(RowVector vector, List keyChannels) { - return rowVectorPartitionByKeys(vector, keyChannels, 128); - } - public List rowVectorPartitionByKeys( RowVector vector, List keyChannels, int maxPartitions) { final int[] keyChannelArray = keyChannels.stream().mapToInt(i -> i).toArray(); @@ -144,23 +139,28 @@ public List rowVectorPartitionByKeys( .collect(Collectors.toList()); } - public List rowVectorPartitionBySpec( - RowVector vector, PartitionFunctionSpec spec, int numPartitions) { - Preconditions.checkArgument( - spec instanceof HashPartitionFunctionSpec, - "Only HashPartitionFunctionSpec is supported, got %s", - spec.getClass().getSimpleName()); - HashPartitionFunctionSpec hashSpec = (HashPartitionFunctionSpec) spec; - final int[] keyChannelArray = hashSpec.getKeyChannels().stream().mapToInt(i -> i).toArray(); - final long[] vids = - jni.rowVectorPartitionByKeyHashes(vector.id(), keyChannelArray, numPartitions); + public PartitionFunction createPartitionFunction( + PartitionFunctionSpec spec, int numPartitions, boolean localExchange) { + final String specJson = Serde.toPrettyJson(spec); + return new PartitionFunction( + this, jni.createPartitionFunction(specJson, numPartitions, localExchange)); + } + + public int[] partitionFunctionPartition( + PartitionFunction partitionFunction, RowVector rowVector) { + return jni.partitionFunctionPartition(partitionFunction.id(), rowVector.id()); + } + + public List baseVectorWrapPartitions( + BaseVector vector, int[] partitions, int numPartitions) { + final long[] vids = jni.baseVectorWrapPartitions(vector.id(), partitions, numPartitions); return Arrays.stream(vids) .mapToObj( vid -> { if (vid == 0) { return null; } - return baseVectorWrap(vid).asRowVector(); + return baseVectorWrap(vid); }) .collect(Collectors.toList()); } diff --git a/src/main/java/org/boostscale/velox4j/jni/JniWrapper.java b/src/main/java/org/boostscale/velox4j/jni/JniWrapper.java index 4290e07f08..7ede093af3 100644 --- a/src/main/java/org/boostscale/velox4j/jni/JniWrapper.java +++ b/src/main/java/org/boostscale/velox4j/jni/JniWrapper.java @@ -75,7 +75,11 @@ public long sessionId() { native long[] rowVectorPartitionByKeys(long id, int[] keyChannels, int maxPartitions); - native long[] rowVectorPartitionByKeyHashes(long id, int[] keyChannels, int numPartitions); + native long createPartitionFunction(String specJson, int numPartitions, boolean localExchange); + + native int[] partitionFunctionPartition(long partitionFunctionId, long rowVectorId); + + native long[] baseVectorWrapPartitions(long vectorId, int[] partitions, int numPartitions); native long createSelectivityVector(int length); diff --git a/src/main/java/org/boostscale/velox4j/jni/LocalSession.java b/src/main/java/org/boostscale/velox4j/jni/LocalSession.java index a8063e44b1..3c599d76f9 100644 --- a/src/main/java/org/boostscale/velox4j/jni/LocalSession.java +++ b/src/main/java/org/boostscale/velox4j/jni/LocalSession.java @@ -19,6 +19,7 @@ import org.boostscale.velox4j.data.RowVectors; import org.boostscale.velox4j.data.SelectivityVectors; import org.boostscale.velox4j.eval.Evaluations; +import org.boostscale.velox4j.partition.PartitionFunctions; import org.boostscale.velox4j.query.Queries; import org.boostscale.velox4j.serializable.ISerializables; import org.boostscale.velox4j.session.Session; @@ -91,4 +92,9 @@ public ISerializables iSerializableOps() { public Variants variantOps() { return new Variants(jniApi()); } + + @Override + public PartitionFunctions partitionFunctionOps() { + return new PartitionFunctions(jniApi()); + } } diff --git a/src/main/java/org/boostscale/velox4j/plan/partition/GatherPartitionFunctionSpec.java b/src/main/java/org/boostscale/velox4j/partition/GatherPartitionFunctionSpec.java similarity index 94% rename from src/main/java/org/boostscale/velox4j/plan/partition/GatherPartitionFunctionSpec.java rename to src/main/java/org/boostscale/velox4j/partition/GatherPartitionFunctionSpec.java index c48f469b43..6d1beb21f2 100644 --- a/src/main/java/org/boostscale/velox4j/plan/partition/GatherPartitionFunctionSpec.java +++ b/src/main/java/org/boostscale/velox4j/partition/GatherPartitionFunctionSpec.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.boostscale.velox4j.plan.partition; +package org.boostscale.velox4j.partition; import com.fasterxml.jackson.annotation.JsonCreator; diff --git a/src/main/java/org/boostscale/velox4j/plan/partition/HashPartitionFunctionSpec.java b/src/main/java/org/boostscale/velox4j/partition/HashPartitionFunctionSpec.java similarity index 93% rename from src/main/java/org/boostscale/velox4j/plan/partition/HashPartitionFunctionSpec.java rename to src/main/java/org/boostscale/velox4j/partition/HashPartitionFunctionSpec.java index 96ce11636e..876a9b3dba 100644 --- a/src/main/java/org/boostscale/velox4j/plan/partition/HashPartitionFunctionSpec.java +++ b/src/main/java/org/boostscale/velox4j/partition/HashPartitionFunctionSpec.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.boostscale.velox4j.plan.partition; +package org.boostscale.velox4j.partition; import java.util.Collections; import java.util.List; @@ -19,6 +19,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonGetter; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.base.Preconditions; import org.boostscale.velox4j.expression.ConstantTypedExpr; import org.boostscale.velox4j.type.RowType; @@ -39,7 +40,7 @@ public HashPartitionFunctionSpec( @JsonProperty("constants") List constants) { this.inputType = inputType; this.keyChannels = keyChannels; - this.constants = constants; + this.constants = Preconditions.checkNotNull(constants); } public HashPartitionFunctionSpec(RowType inputType, List keyChannels) { diff --git a/src/main/java/org/boostscale/velox4j/partition/PartitionFunction.java b/src/main/java/org/boostscale/velox4j/partition/PartitionFunction.java new file mode 100644 index 0000000000..fcb1cefd76 --- /dev/null +++ b/src/main/java/org/boostscale/velox4j/partition/PartitionFunction.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.boostscale.velox4j.partition; + +import org.boostscale.velox4j.data.RowVector; +import org.boostscale.velox4j.jni.CppObject; +import org.boostscale.velox4j.jni.JniApi; + +/** Runtime Velox partition function created from a {@link PartitionFunctionSpec}. */ +public class PartitionFunction implements CppObject { + private final JniApi jniApi; + private final long id; + + public PartitionFunction(JniApi jniApi, long id) { + this.jniApi = jniApi; + this.id = id; + } + + @Override + public long id() { + return id; + } + + public int[] partition(RowVector rowVector) { + return jniApi.partitionFunctionPartition(this, rowVector); + } +} diff --git a/src/main/java/org/boostscale/velox4j/plan/partition/PartitionFunctionSpec.java b/src/main/java/org/boostscale/velox4j/partition/PartitionFunctionSpec.java similarity index 94% rename from src/main/java/org/boostscale/velox4j/plan/partition/PartitionFunctionSpec.java rename to src/main/java/org/boostscale/velox4j/partition/PartitionFunctionSpec.java index 5abd3bccb9..200c58e8ee 100644 --- a/src/main/java/org/boostscale/velox4j/plan/partition/PartitionFunctionSpec.java +++ b/src/main/java/org/boostscale/velox4j/partition/PartitionFunctionSpec.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.boostscale.velox4j.plan.partition; +package org.boostscale.velox4j.partition; import org.boostscale.velox4j.plan.LocalPartitionNode; import org.boostscale.velox4j.serializable.ISerializable; diff --git a/src/main/java/org/boostscale/velox4j/partition/PartitionFunctions.java b/src/main/java/org/boostscale/velox4j/partition/PartitionFunctions.java new file mode 100644 index 0000000000..0a7c22e1a7 --- /dev/null +++ b/src/main/java/org/boostscale/velox4j/partition/PartitionFunctions.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.boostscale.velox4j.partition; + +import org.boostscale.velox4j.jni.JniApi; + +public class PartitionFunctions { + private final JniApi jniApi; + + public PartitionFunctions(JniApi jniApi) { + this.jniApi = jniApi; + } + + public PartitionFunction create( + PartitionFunctionSpec spec, int numPartitions, boolean localExchange) { + return jniApi.createPartitionFunction(spec, numPartitions, localExchange); + } +} diff --git a/src/main/java/org/boostscale/velox4j/plan/partition/RoundRobinPartitionFunctionSpec.java b/src/main/java/org/boostscale/velox4j/partition/RoundRobinPartitionFunctionSpec.java similarity index 94% rename from src/main/java/org/boostscale/velox4j/plan/partition/RoundRobinPartitionFunctionSpec.java rename to src/main/java/org/boostscale/velox4j/partition/RoundRobinPartitionFunctionSpec.java index 9ead7cb7f5..f2465804fd 100644 --- a/src/main/java/org/boostscale/velox4j/plan/partition/RoundRobinPartitionFunctionSpec.java +++ b/src/main/java/org/boostscale/velox4j/partition/RoundRobinPartitionFunctionSpec.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.boostscale.velox4j.plan.partition; +package org.boostscale.velox4j.partition; import com.fasterxml.jackson.annotation.JsonCreator; diff --git a/src/main/java/org/boostscale/velox4j/plan/LocalPartitionNode.java b/src/main/java/org/boostscale/velox4j/plan/LocalPartitionNode.java index e1dc2e8cab..e1d49906b1 100644 --- a/src/main/java/org/boostscale/velox4j/plan/LocalPartitionNode.java +++ b/src/main/java/org/boostscale/velox4j/plan/LocalPartitionNode.java @@ -21,7 +21,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonValue; -import org.boostscale.velox4j.plan.partition.PartitionFunctionSpec; +import org.boostscale.velox4j.partition.PartitionFunctionSpec; /** * In-process repartitioning plan node. Partitions data using the specified partition function. Can diff --git a/src/main/java/org/boostscale/velox4j/serializable/ISerializableRegistry.java b/src/main/java/org/boostscale/velox4j/serializable/ISerializableRegistry.java index a4ff421800..01924f3825 100644 --- a/src/main/java/org/boostscale/velox4j/serializable/ISerializableRegistry.java +++ b/src/main/java/org/boostscale/velox4j/serializable/ISerializableRegistry.java @@ -26,6 +26,9 @@ import org.boostscale.velox4j.expression.InputTypedExpr; import org.boostscale.velox4j.expression.LambdaTypedExpr; import org.boostscale.velox4j.filter.AlwaysTrue; +import org.boostscale.velox4j.partition.GatherPartitionFunctionSpec; +import org.boostscale.velox4j.partition.HashPartitionFunctionSpec; +import org.boostscale.velox4j.partition.RoundRobinPartitionFunctionSpec; import org.boostscale.velox4j.plan.AggregationNode; import org.boostscale.velox4j.plan.FilterNode; import org.boostscale.velox4j.plan.HashJoinNode; @@ -37,9 +40,6 @@ import org.boostscale.velox4j.plan.TableWriteNode; import org.boostscale.velox4j.plan.ValuesNode; import org.boostscale.velox4j.plan.WindowNode; -import org.boostscale.velox4j.plan.partition.GatherPartitionFunctionSpec; -import org.boostscale.velox4j.plan.partition.HashPartitionFunctionSpec; -import org.boostscale.velox4j.plan.partition.RoundRobinPartitionFunctionSpec; import org.boostscale.velox4j.query.Query; import org.boostscale.velox4j.serde.Serde; import org.boostscale.velox4j.serde.SerdeRegistry; diff --git a/src/main/java/org/boostscale/velox4j/session/Session.java b/src/main/java/org/boostscale/velox4j/session/Session.java index 85fcc3f13f..3920fe6c4d 100644 --- a/src/main/java/org/boostscale/velox4j/session/Session.java +++ b/src/main/java/org/boostscale/velox4j/session/Session.java @@ -20,6 +20,7 @@ import org.boostscale.velox4j.data.SelectivityVectors; import org.boostscale.velox4j.eval.Evaluations; import org.boostscale.velox4j.jni.CppObject; +import org.boostscale.velox4j.partition.PartitionFunctions; import org.boostscale.velox4j.query.Queries; import org.boostscale.velox4j.serializable.ISerializables; import org.boostscale.velox4j.variant.Variants; @@ -68,4 +69,7 @@ public interface Session extends CppObject { /** APIs in relation to {@link org.boostscale.velox4j.variant.Variant}. */ Variants variantOps(); + + /** APIs in relation to {@link org.boostscale.velox4j.partition.PartitionFunction}. */ + PartitionFunctions partitionFunctionOps(); } diff --git a/src/test/java/org/boostscale/velox4j/data/HashPartitionTest.java b/src/test/java/org/boostscale/velox4j/data/RowVectorPartitionTest.java similarity index 98% rename from src/test/java/org/boostscale/velox4j/data/HashPartitionTest.java rename to src/test/java/org/boostscale/velox4j/data/RowVectorPartitionTest.java index bd3522fd6a..f758224717 100644 --- a/src/test/java/org/boostscale/velox4j/data/HashPartitionTest.java +++ b/src/test/java/org/boostscale/velox4j/data/RowVectorPartitionTest.java @@ -21,12 +21,12 @@ import org.boostscale.velox4j.Velox4j; import org.boostscale.velox4j.memory.BytesAllocationListener; import org.boostscale.velox4j.memory.MemoryManager; -import org.boostscale.velox4j.plan.partition.HashPartitionFunctionSpec; +import org.boostscale.velox4j.partition.HashPartitionFunctionSpec; import org.boostscale.velox4j.session.Session; import org.boostscale.velox4j.test.Velox4jTests; import org.boostscale.velox4j.type.RowType; -public class HashPartitionTest { +public class RowVectorPartitionTest { private static BytesAllocationListener allocationListener; private static MemoryManager memoryManager; private static Session session; diff --git a/src/test/java/org/boostscale/velox4j/partition/PartitionFunctionTest.java b/src/test/java/org/boostscale/velox4j/partition/PartitionFunctionTest.java new file mode 100644 index 0000000000..3e8ddefc01 --- /dev/null +++ b/src/test/java/org/boostscale/velox4j/partition/PartitionFunctionTest.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.boostscale.velox4j.partition; + +import com.google.common.collect.ImmutableList; +import org.junit.*; + +import org.boostscale.velox4j.Velox4j; +import org.boostscale.velox4j.data.BaseVectorTests; +import org.boostscale.velox4j.data.RowVector; +import org.boostscale.velox4j.memory.BytesAllocationListener; +import org.boostscale.velox4j.memory.MemoryManager; +import org.boostscale.velox4j.session.Session; +import org.boostscale.velox4j.test.Velox4jTests; +import org.boostscale.velox4j.type.RowType; + +public class PartitionFunctionTest { + private static BytesAllocationListener allocationListener; + private static MemoryManager memoryManager; + private static Session session; + + @BeforeClass + public static void beforeClass() throws Exception { + Velox4jTests.ensureInitialized(); + allocationListener = new BytesAllocationListener(); + memoryManager = Velox4j.newMemoryManager(allocationListener); + } + + @AfterClass + public static void afterClass() throws Exception { + memoryManager.close(); + Assert.assertEquals(0, allocationListener.currentBytes()); + } + + @Before + public void setUp() throws Exception { + session = Velox4j.newSession(memoryManager); + } + + @After + public void tearDown() throws Exception { + session.close(); + } + + @Test + public void testHashPartitionFunctionCreateAndPartition() { + final RowVector input = BaseVectorTests.newSampleRowVector(session); + final PartitionFunctionSpec spec = + new HashPartitionFunctionSpec((RowType) input.getType(), ImmutableList.of(0)); + + final int[] result1; + final int[] result2; + try (PartitionFunction function1 = session.partitionFunctionOps().create(spec, 4, false); + PartitionFunction function2 = session.partitionFunctionOps().create(spec, 4, false)) { + result1 = function1.partition(input); + result2 = function2.partition(input); + } + + Assert.assertEquals(input.getSize(), result1.length); + Assert.assertArrayEquals(result1, result2); + } + + @Test + public void testRoundRobinPartitionFunctionCreateAndPartition() { + final RowVector input = BaseVectorTests.newSampleRowVector(session); + + final int[] result; + try (PartitionFunction function = + session.partitionFunctionOps().create(new RoundRobinPartitionFunctionSpec(), 4, true)) { + result = function.partition(input); + } + + Assert.assertTrue(result.length == 1 || result.length == input.getSize()); + for (int partition : result) { + Assert.assertTrue(partition >= 0); + Assert.assertTrue(partition < 4); + } + } +} diff --git a/src/test/java/org/boostscale/velox4j/query/QueryTest.java b/src/test/java/org/boostscale/velox4j/query/QueryTest.java index 5ad41b30ac..512564b4f6 100644 --- a/src/test/java/org/boostscale/velox4j/query/QueryTest.java +++ b/src/test/java/org/boostscale/velox4j/query/QueryTest.java @@ -45,8 +45,8 @@ import org.boostscale.velox4j.join.JoinType; import org.boostscale.velox4j.memory.BytesAllocationListener; import org.boostscale.velox4j.memory.MemoryManager; +import org.boostscale.velox4j.partition.GatherPartitionFunctionSpec; import org.boostscale.velox4j.plan.*; -import org.boostscale.velox4j.plan.partition.GatherPartitionFunctionSpec; import org.boostscale.velox4j.serde.Serde; import org.boostscale.velox4j.session.Session; import org.boostscale.velox4j.sort.SortOrder; diff --git a/src/test/java/org/boostscale/velox4j/query/ShuffleJoinTest.java b/src/test/java/org/boostscale/velox4j/query/ShuffleJoinTest.java index 78bbd7012b..e1a487adb9 100644 --- a/src/test/java/org/boostscale/velox4j/query/ShuffleJoinTest.java +++ b/src/test/java/org/boostscale/velox4j/query/ShuffleJoinTest.java @@ -34,9 +34,9 @@ import org.boostscale.velox4j.join.JoinType; import org.boostscale.velox4j.memory.BytesAllocationListener; import org.boostscale.velox4j.memory.MemoryManager; +import org.boostscale.velox4j.partition.HashPartitionFunctionSpec; import org.boostscale.velox4j.plan.HashJoinNode; import org.boostscale.velox4j.plan.TableScanNode; -import org.boostscale.velox4j.plan.partition.HashPartitionFunctionSpec; import org.boostscale.velox4j.session.Session; import org.boostscale.velox4j.test.Velox4jTests; import org.boostscale.velox4j.test.dataset.TestDataFile; diff --git a/src/test/java/org/boostscale/velox4j/serde/PartitionFunctionSpecTest.java b/src/test/java/org/boostscale/velox4j/serde/PartitionFunctionSpecTest.java index e7c31506ca..a5c5a61441 100644 --- a/src/test/java/org/boostscale/velox4j/serde/PartitionFunctionSpecTest.java +++ b/src/test/java/org/boostscale/velox4j/serde/PartitionFunctionSpecTest.java @@ -17,9 +17,9 @@ import org.junit.BeforeClass; import org.junit.Test; -import org.boostscale.velox4j.plan.partition.GatherPartitionFunctionSpec; -import org.boostscale.velox4j.plan.partition.HashPartitionFunctionSpec; -import org.boostscale.velox4j.plan.partition.RoundRobinPartitionFunctionSpec; +import org.boostscale.velox4j.partition.GatherPartitionFunctionSpec; +import org.boostscale.velox4j.partition.HashPartitionFunctionSpec; +import org.boostscale.velox4j.partition.RoundRobinPartitionFunctionSpec; import org.boostscale.velox4j.test.Velox4jTests; import org.boostscale.velox4j.type.IntegerType; import org.boostscale.velox4j.type.RowType; diff --git a/src/test/java/org/boostscale/velox4j/serde/PlanNodeSerdeTest.java b/src/test/java/org/boostscale/velox4j/serde/PlanNodeSerdeTest.java index 935e247c2a..157ba47476 100644 --- a/src/test/java/org/boostscale/velox4j/serde/PlanNodeSerdeTest.java +++ b/src/test/java/org/boostscale/velox4j/serde/PlanNodeSerdeTest.java @@ -26,6 +26,9 @@ import org.boostscale.velox4j.join.JoinType; import org.boostscale.velox4j.memory.BytesAllocationListener; import org.boostscale.velox4j.memory.MemoryManager; +import org.boostscale.velox4j.partition.GatherPartitionFunctionSpec; +import org.boostscale.velox4j.partition.HashPartitionFunctionSpec; +import org.boostscale.velox4j.partition.RoundRobinPartitionFunctionSpec; import org.boostscale.velox4j.plan.AggregationNode; import org.boostscale.velox4j.plan.FilterNode; import org.boostscale.velox4j.plan.HashJoinNode; @@ -37,9 +40,6 @@ import org.boostscale.velox4j.plan.TableWriteNode; import org.boostscale.velox4j.plan.ValuesNode; import org.boostscale.velox4j.plan.WindowNode; -import org.boostscale.velox4j.plan.partition.GatherPartitionFunctionSpec; -import org.boostscale.velox4j.plan.partition.HashPartitionFunctionSpec; -import org.boostscale.velox4j.plan.partition.RoundRobinPartitionFunctionSpec; import org.boostscale.velox4j.session.Session; import org.boostscale.velox4j.sort.SortOrder; import org.boostscale.velox4j.test.Velox4jTests;