diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index 4feab311c76..f53e6c91d0a 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -154,6 +154,7 @@ public enum Builtins { GARCH("garch", true), GAUSSIAN_CLASSIFIER("gaussianClassifier", true), GET_ACCURACY("getAccuracy", true), + GET_CATEGORICAL_MASK("getCategoricalMask", false), GLM("glm", true), GLM_PREDICT("glmPredict", true), GLOVE("glove", true), diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index 251f773a18c..0419a0e056b 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -197,6 +197,8 @@ public enum Opcodes { TRANSFORMMETA("transformmeta", InstructionType.ParameterizedBuiltin), TRANSFORMENCODE("transformencode", InstructionType.MultiReturnParameterizedBuiltin, InstructionType.MultiReturnBuiltin), + GET_CATEGORICAL_MASK("get_categorical_mask", InstructionType.Binary), + //Ternary instruction opcodes PM("+*", InstructionType.Ternary), MINUSMULT("-*", InstructionType.Ternary), diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index 2e3543882d2..c2832aeb8cd 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -639,6 +639,7 @@ public enum OpOp2 { MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=)) LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5) MINUS1_MULT(false), //1-X*Y + GET_CATEGORICAL_MASK(false), // get transformation mask QUANTIZE_COMPRESS(false), //quantization-fused compression UNION_DISTINCT(false); diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index a3ddb45ea6d..73e3c5fac86 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -763,8 +763,8 @@ protected ExecType optFindExecType(boolean transitive) { checkAndSetForcedPlatform(); - DataType dt1 = getInput().get(0).getDataType(); - DataType dt2 = getInput().get(1).getDataType(); + final DataType dt1 = getInput(0).getDataType(); + final DataType dt2 = getInput(1).getDataType(); if( _etypeForced != null ) { setExecType(_etypeForced); @@ -812,18 +812,28 @@ else if ( dt1 == DataType.SCALAR && dt2 == DataType.MATRIX ) { checkAndSetInvalidCPDimsAndSize(); } - //spark-specific decision refinement (execute unary scalar w/ spark input and + // spark-specific decision refinement (execute unary scalar w/ spark input and // single parent also in spark because it's likely cheap and reduces intermediates) - if(transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP && _etypeForced != ExecType.FED && - getDataType().isMatrix() // output should be a matrix - && (dt1.isScalar() || dt2.isScalar()) // one side should be scalar - && supportsMatrixScalarOperations() // scalar operations - && !(getInput().get(dt1.isScalar() ? 1 : 0) instanceof DataOp) // input is not checkpoint - && getInput().get(dt1.isScalar() ? 1 : 0).getParent().size() == 1 // unary scalar is only parent - && !HopRewriteUtils.isSingleBlock(getInput().get(dt1.isScalar() ? 1 : 0)) // single block triggered exec - && getInput().get(dt1.isScalar() ? 1 : 0).optFindExecType() == ExecType.SPARK) { - // pull unary scalar operation into spark - _etype = ExecType.SPARK; + if(transitive // we allow transitive Spark operations. continue sequences of spark operations + && _etype == ExecType.CP // The instruction is currently in CP + && _etypeForced != ExecType.CP // not forced CP + && _etypeForced != ExecType.FED // not federated + && (getDataType().isMatrix() || getDataType().isFrame()) // output should be a matrix or frame + ) { + final boolean v1 = getInput(0).isScalarOrVectorBellowBlockSize(); + final boolean v2 = getInput(1).isScalarOrVectorBellowBlockSize(); + final boolean left = v1 == true; // left side is the vector or scalar + final Hop sparkIn = getInput(left ? 1 : 0); + if((v1 ^ v2) // XOR only one side is allowed to be a vector or a scalar. + && (supportsMatrixScalarOperations() || op == OpOp2.APPLY_SCHEMA) // supported operation + && sparkIn.getParent().size() == 1 // only one parent + && !HopRewriteUtils.isSingleBlock(sparkIn) // single block triggered exec + && sparkIn.optFindExecType() == ExecType.SPARK // input was spark op. + && !(sparkIn instanceof DataOp) // input is not checkpoint + ) { + // pull operation into spark + _etype = ExecType.SPARK; + } } if( OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE && @@ -853,7 +863,10 @@ else if( (op == OpOp2.CBIND && getDataType().isList()) || (op == OpOp2.RBIND && getDataType().isList())) { _etype = ExecType.CP; } - + + if( op == OpOp2.GET_CATEGORICAL_MASK) + _etype = ExecType.CP; + //mark for recompile (forever) setRequiresRecompileIfNecessary(); diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java index 86749d44c1c..675fbb380a1 100644 --- a/src/main/java/org/apache/sysds/hops/Hop.java +++ b/src/main/java/org/apache/sysds/hops/Hop.java @@ -1045,6 +1045,12 @@ public final String toString() { // ======================================================================================== + protected boolean isScalarOrVectorBellowBlockSize(){ + return getDataType().isScalar() || (dimsKnown() && + (( _dc.getRows() == 1 && _dc.getCols() < ConfigurationManager.getBlocksize()) + || _dc.getCols() == 1 && _dc.getRows() < ConfigurationManager.getBlocksize())); + } + protected boolean isVector() { return (dimsKnown() && (_dc.getRows() == 1 || _dc.getCols() == 1) ); } @@ -1629,6 +1635,11 @@ protected void setMemoryAndComputeEstimates(Lop lop) { lop.setComputeEstimate(ComputeCost.getHOPComputeCost(this)); } + protected boolean hasSparkOutput(){ + return (this.optFindExecType() == ExecType.SPARK + || (this instanceof DataOp && ((DataOp)this).hasOnlyRDD())); + } + /** * Set parse information. * diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java index 34da36dd13c..e16896b869b 100644 --- a/src/main/java/org/apache/sysds/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java @@ -366,7 +366,11 @@ protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) } else { sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz); } - return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity, getDataType()); + + if(getDataType() == DataType.FRAME) + return OptimizerUtils.estimateSizeExactFrame(dim1, dim2); + else + return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity); } @Override @@ -463,6 +467,13 @@ public boolean isMetadataOperation() { || _op == OpOp1.CAST_AS_LIST; } + private boolean isDisallowedSparkOps(){ + return isCumulativeUnaryOperation() + || isCastUnaryOperation() + || _op==OpOp1.MEDIAN + || _op==OpOp1.IQM; + } + @Override protected ExecType optFindExecType(boolean transitive) { @@ -493,19 +504,22 @@ else if ( getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVecto checkAndSetInvalidCPDimsAndSize(); } + //spark-specific decision refinement (execute unary w/ spark input and //single parent also in spark because it's likely cheap and reduces intermediates) - if( _etype == ExecType.CP && _etypeForced != ExecType.CP - && getInput().get(0).optFindExecType() == ExecType.SPARK - && getDataType().isMatrix() - && !isCumulativeUnaryOperation() && !isCastUnaryOperation() - && _op!=OpOp1.MEDIAN && _op!=OpOp1.IQM - && !(getInput().get(0) instanceof DataOp) //input is not checkpoint - && getInput().get(0).getParent().size()==1 ) //unary is only parent - { + if(_etype == ExecType.CP // currently CP instruction + && _etype != ExecType.SPARK /// currently not SP. + && _etypeForced != ExecType.CP // not forced as CP instruction + && getInput(0).hasSparkOutput() // input is a spark instruction + && (getDataType().isMatrix() || getDataType().isFrame()) // output is a matrix or frame + && !isDisallowedSparkOps() // is invalid spark instruction + // && !(getInput().get(0) instanceof DataOp) // input is not checkpoint + // && getInput(0).getParent().size() <= 1// unary is only parent + ) { //pull unary operation into spark _etype = ExecType.SPARK; } + //mark for recompile (forever) setRequiresRecompileIfNecessary(); @@ -520,7 +534,7 @@ && getInput().get(0).getParent().size()==1 ) //unary is only parent } else { setRequiresRecompileIfNecessary(); } - + return _etype; } diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index 28f6949f722..ab0c7993b4e 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -2018,6 +2018,15 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV else raiseValidateError("The compress or decompress instruction is not allowed in dml scripts"); break; + case GET_CATEGORICAL_MASK: + checkNumParameters(2); + checkFrameParam(getFirstExpr()); + checkScalarParam(getSecondExpr()); + output.setDataType(DataType.MATRIX); + output.setDimensions(1, -1); + output.setBlocksize( id.getBlocksize()); + output.setValueType(ValueType.FP64); + break; case QUANTIZE_COMPRESS: if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND) { checkNumParameters(2); @@ -2383,6 +2392,13 @@ protected void checkMatrixFrameParam(Expression e) { //always unconditional raiseValidateError("Expecting matrix or frame parameter for function "+ getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS); } } + + protected void checkFrameParam(Expression e) { + if(e.getOutput().getDataType() != DataType.FRAME) { + raiseValidateError("Expecting frame parameter for function " + getOpCode(), false, + LanguageErrorCodes.UNSUPPORTED_PARAMETERS); + } + } protected void checkMatrixScalarParam(Expression e) { //always unconditional if (e.getOutput().getDataType() != DataType.MATRIX && e.getOutput().getDataType() != DataType.SCALAR) { diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index 092fbffe36d..949e67a62cc 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2821,6 +2821,9 @@ else if ( in.length == 2 ) DataType.MATRIX, target.getValueType(), AggOp.COUNT_DISTINCT, Direction.Col, expr); break; + case GET_CATEGORICAL_MASK: + currBuiltinOp = new BinaryOp(target.getName(), DataType.MATRIX, ValueType.FP64, OpOp2.GET_CATEGORICAL_MASK, expr, expr2); + break; default: throw new ParseException("Unsupported builtin function type: "+source.getOpCode()); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java index 48637595741..ef5d2630390 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java @@ -58,12 +58,14 @@ import org.apache.sysds.runtime.compress.lib.CLALibMMChain; import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult; import org.apache.sysds.runtime.compress.lib.CLALibMerge; -import org.apache.sysds.runtime.compress.lib.CLALibReplace; +import org.apache.sysds.runtime.compress.lib.CLALibRemoveEmpty; import org.apache.sysds.runtime.compress.lib.CLALibReorg; +import org.apache.sysds.runtime.compress.lib.CLALibReplace; import org.apache.sysds.runtime.compress.lib.CLALibReshape; import org.apache.sysds.runtime.compress.lib.CLALibRexpand; import org.apache.sysds.runtime.compress.lib.CLALibScalar; import org.apache.sysds.runtime.compress.lib.CLALibSlice; +import org.apache.sysds.runtime.compress.lib.CLALibSort; import org.apache.sysds.runtime.compress.lib.CLALibSquash; import org.apache.sysds.runtime.compress.lib.CLALibTSMM; import org.apache.sysds.runtime.compress.lib.CLALibTernaryOp; @@ -101,6 +103,7 @@ import org.apache.sysds.runtime.util.IndexRange; import org.apache.sysds.utils.DMLCompressionStatistics; import org.apache.sysds.utils.stats.InfrastructureAnalyzer; +import org.apache.sysds.utils.stats.Timing; public class CompressedMatrixBlock extends MatrixBlock { private static final Log LOG = LogFactory.getLog(CompressedMatrixBlock.class.getName()); @@ -475,16 +478,20 @@ public void readFields(DataInput in) throws IOException { } public static CompressedMatrixBlock read(DataInput in) throws IOException { + Timing t = new Timing(); int rlen = in.readInt(); int clen = in.readInt(); long nonZeros = in.readLong(); boolean overlappingColGroups = in.readBoolean(); List groups = ColGroupIO.readGroups(in, rlen); - return new CompressedMatrixBlock(rlen, clen, nonZeros, overlappingColGroups, groups); + CompressedMatrixBlock ret = new CompressedMatrixBlock(rlen, clen, nonZeros, overlappingColGroups, groups); + LOG.debug("Compressed read serialization time: " + t.stop()); + return ret; } @Override public void write(DataOutput out) throws IOException { + Timing t = new Timing(); final long estimateUncompressed = nonZeros > 0 ? MatrixBlock.estimateSizeOnDisk(rlen, clen, nonZeros) : Long.MAX_VALUE; final long estDisk = nonZeros > 0 ? getExactSizeOnDisk() : Long.MAX_VALUE; @@ -512,6 +519,7 @@ public void write(DataOutput out) throws IOException { out.writeLong(nonZeros); out.writeBoolean(overlappingColGroups); ColGroupIO.writeGroups(out, _colGroups); + LOG.debug("Compressed write serialization time: " + t.stop()); } /** @@ -611,14 +619,6 @@ public MatrixBlock aggregateUnaryOperations(AggregateUnaryOperator op, MatrixVal public MatrixBlock transposeSelfMatrixMultOperations(MatrixBlock out, MMTSJType tstype, int k) { // check for transpose type if(tstype == MMTSJType.LEFT) { - if(isEmpty()) - return new MatrixBlock(clen, clen, true); - // create output matrix block - if(out == null) - out = new MatrixBlock(clen, clen, false); - else - out.reset(clen, clen, false); - out.allocateDenseBlock(); CLALibTSMM.leftMultByTransposeSelf(this, out, k); return out; } @@ -846,9 +846,8 @@ public CM_COV_Object covOperations(COVOperator op, MatrixBlock that, MatrixBlock } @Override - public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result) { - MatrixBlock right = getUncompressed(weights); - return getUncompressed("sortOperations").sortOperations(right, result); + public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result, int k) { + return CLALibSort.sort(this, weights, result, k); } @Override @@ -871,9 +870,7 @@ public MatrixBlock groupedAggOperations(MatrixValue tgt, MatrixValue wghts, Matr @Override public MatrixBlock removeEmptyOperations(MatrixBlock ret, boolean rows, boolean emptyReturn, MatrixBlock select) { - printDecompressWarning("removeEmptyOperations"); - MatrixBlock tmp = getUncompressed(); - return tmp.removeEmptyOperations(ret, rows, emptyReturn, select); + return CLALibRemoveEmpty.rmempty(this, ret, rows, emptyReturn, select); } @Override @@ -1202,8 +1199,8 @@ public void examSparsity(boolean allowCSR, int k) { } @Override - public void sparseToDense(int k) { - // do nothing + public MatrixBlock sparseToDense(int k) { + return this; // do nothing } @Override @@ -1236,16 +1233,6 @@ public double interQuartileMean() { return getUncompressed("interQuartileMean").interQuartileMean(); } - @Override - public MatrixBlock pickValues(MatrixValue quantiles, MatrixValue ret) { - return getUncompressed("pickValues").pickValues(quantiles, ret); - } - - @Override - public double pickValue(double quantile, boolean average) { - return getUncompressed("pickValue").pickValue(quantile, average); - } - @Override public double sumWeightForQuantile() { return getUncompressed("sumWeightForQuantile").sumWeightForQuantile(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java index 4c48effb4df..f082d1ffc3d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java @@ -64,6 +64,8 @@ public class CompressedMatrixBlockFactory { private static final Log LOG = LogFactory.getLog(CompressedMatrixBlockFactory.class.getName()); + private static final Object asyncCompressLock = new Object(); + /** Timing object to measure the time of each phase in the compression */ private final Timing time = new Timing(true); /** Compression statistics gathered throughout the compression */ @@ -181,21 +183,23 @@ public static Future compressAsync(ExecutionContext ec, String varName) { } public static Future compressAsync(ExecutionContext ec, String varName, InstructionTypeCounter ins) { - LOG.debug("Compressing Async"); final ExecutorService pool = CommonThreadPool.get(); // We have to guarantee that a thread pool is allocated. return CompletableFuture.runAsync(() -> { // method call or code to be async try { CacheableData data = ec.getCacheableData(varName); - if(data instanceof MatrixObject) { - MatrixObject mo = (MatrixObject) data; - MatrixBlock mb = mo.acquireReadAndRelease(); - MatrixBlock mbc = CompressedMatrixBlockFactory.compress(mo.acquireReadAndRelease(), ins).getLeft(); - if(mbc instanceof CompressedMatrixBlock) { - ExecutionContext.createCacheableData(mb); - mo.acquireModify(mbc); - mo.release(); - mbc.sum(); // calculate sum to forcefully materialize counts + synchronized(asyncCompressLock){ // synchronize on the data object to not allow multiple compressions of the same matrix. + if(data instanceof MatrixObject) { + LOG.debug("Compressing Async"); + MatrixObject mo = (MatrixObject) data; + MatrixBlock mb = mo.acquireReadAndRelease(); + MatrixBlock mbc = CompressedMatrixBlockFactory.compress(mb, ins).getLeft(); + if(mbc instanceof CompressedMatrixBlock) { + ExecutionContext.createCacheableData(mb); + mo.acquireModify(mbc); + mo.release(); + mbc.sum(); // calculate sum to forcefully materialize counts + } } } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java index ec502d6d122..fd59447b3d9 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java @@ -29,9 +29,9 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.CompressionSettingsBuilder; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex.SliceResult; @@ -41,6 +41,7 @@ import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; import org.apache.sysds.runtime.compress.lib.CLALibCombineGroups; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -401,8 +402,9 @@ public final AColGroup rightMultByMatrix(MatrixBlock right) { * @param cru The right hand side column upper * @param nRows The number of rows in this column group */ - public void rightDecompressingMult(MatrixBlock right, MatrixBlock ret, int rl, int ru, int nRows, int crl, int cru){ - throw new NotImplementedException("not supporting right Decompressing Multiply on class: " + this.getClass().getSimpleName()); + public void rightDecompressingMult(MatrixBlock right, MatrixBlock ret, int rl, int ru, int nRows, int crl, int cru) { + throw new NotImplementedException( + "not supporting right Decompressing Multiply on class: " + this.getClass().getSimpleName()); } /** @@ -806,7 +808,7 @@ public final void selectionMultiply(MatrixBlock selection, P[] points, MatrixBlo else denseSelection(selection, points, ret, rl, ru); } - + /** * Get an approximate sparsity of this column group * @@ -972,6 +974,15 @@ public AColGroup[] splitReshapePushDown(final int multiplier, final int nRow, fi return splitReshape(multiplier, nRow, nColOrg); } + /** + * Sort the values of the column group according to double < > operations and return as another compressed group. + * + * This sorting assumes that the column group is sorted independently of everything else. + * + * @return The sorted group + */ + public abstract AColGroup sort(); + @Override public String toString() { StringBuilder sb = new StringBuilder(); @@ -981,4 +992,70 @@ public String toString() { sb.append(_colIndexes); return sb.toString(); } + + /** + * Return a new column group containing only the selected rows in the given boolean vector. + * + * Whenever possible only modify the index structure, not the dictionary of the column groups. + * + * @param selectV The selection vector + * @param rOut The number of rows in the output + * @return The new column group + */ + public abstract AColGroup removeEmptyRows(boolean[] selectV, int rOut); + + /** + * Return a new column group containing only the selected columns in the given boolean vector. + * + * Whenever possible only modify the column index, and reduce the dictionaries of the column groups. + * + * @param selectV The selection vector + * @return The new column group + */ + public AColGroup removeEmptyCols(boolean[] selectV) { + if(!inSelection(selectV)) + return null; + + final IntArrayList selectedColumns = new IntArrayList(); + final IntArrayList newIDs = new IntArrayList(); + int idx = 0; + int idxOwn = 0; + final int end = Math.min(selectV.length, _colIndexes.get(_colIndexes.size() - 1) + 1); + for(int i = 0; i < end; i++) { + + if(i == _colIndexes.get(idxOwn)) { + if(selectV[i]) { + selectedColumns.appendValue(idxOwn); + newIDs.appendValue(idx); + } + idxOwn++; + } + if(selectV[i]) + idx++; + } + + final IColIndex newColumnIDs = ColIndexFactory.create(newIDs); + if(newColumnIDs.size() == _colIndexes.size()) + return copyAndSet(newColumnIDs); + else + return removeEmptyColsSubset(newColumnIDs, selectedColumns); + } + + /** + * Using the selection of columns, slice out those and return in a new column group with the given column indexes. + * Ideally this method should only modify the dictionaries. + * + * @param newColumnIDs the new column indexes + * @param selectedColumns The selected columns of this column group (guaranteed < current number of columns) + * @return A new Column group + */ + protected abstract AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns); + + private boolean inSelection(boolean[] selection) { + for(int i = 0; i < _colIndexes.size(); i++) { + if(selection[_colIndexes.get(i)]) + return true; + } + return false; + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java index 0cde289b30f..4f53d8b912b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java @@ -59,8 +59,6 @@ public int getNumValues() { * produce an overhead in cases where the count is calculated, but the overhead will be limited to number of distinct * tuples in the dictionary. * - * The returned counts always contains the number of zero tuples as well if there are some contained, even if they - * are not materialized. * * @return The count of each value in the MatrixBlock. */ @@ -212,6 +210,7 @@ public void clear() { counts = null; } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java index 8f2f0b46055..d114f029df8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java @@ -402,4 +402,5 @@ protected IDictionary combineDictionaries(int nCol, List right) { public double getSparsity() { return _dict.getSparsity(); } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java index 3de98a1c23f..30de5e120c5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java @@ -203,6 +203,22 @@ private final void leftMultByMatrixNoPreAggRowsDense(MatrixBlock mb, double[] re */ protected abstract void multiplyScalar(double v, double[] resV, int offRet, AIterator it); + public void decompressToSparseBlock(SparseBlock sb, int rl, int ru, int offR, int offC, AIterator it) { + if(_dict instanceof MatrixBlockDictionary) { + final MatrixBlockDictionary md = (MatrixBlockDictionary) _dict; + final MatrixBlock mb = md.getMatrixBlock(); + // The dictionary is never empty. + if(mb.isInSparseFormat()) + // TODO make sparse decompression where the iterator is known in argument + decompressToSparseBlockSparseDictionary(sb, rl, ru, offR, offC, mb.getSparseBlock()); + else + decompressToSparseBlockDenseDictionaryWithProvidedIterator(sb, rl, ru, offR, offC, mb.getDenseBlockValues(), + it); + } + else + decompressToSparseBlockDenseDictionaryWithProvidedIterator(sb, rl, ru, offR, offC, _dict.getValues(), it); + } + public void decompressToDenseBlock(DenseBlock db, int rl, int ru, int offR, int offC, AIterator it) { if(_dict instanceof MatrixBlockDictionary) { final MatrixBlockDictionary md = (MatrixBlockDictionary) _dict; @@ -223,6 +239,9 @@ public void decompressToDenseBlockDenseDictionary(DenseBlock db, int rl, int ru, decompressToDenseBlockDenseDictionaryWithProvidedIterator(db, rl, ru, offR, offC, _dict.getValues(), it); } + public abstract void decompressToSparseBlockDenseDictionaryWithProvidedIterator(SparseBlock db, int rl, int ru, + int offR, int offC, double[] values, AIterator it); + public abstract void decompressToDenseBlockDenseDictionaryWithProvidedIterator(DenseBlock db, int rl, int ru, int offR, int offC, double[] values, AIterator it); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java index 21c6a0e1d80..cd0f52ebb8d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java @@ -46,6 +46,7 @@ import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -527,7 +528,7 @@ public CM_COV_Object centralMoment(CMOperator op, int nRows) { @Override public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) { IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size()); - if(d == null){ + if(d == null) { if(max <= 0) return null; return ColGroupEmpty.create(max); @@ -758,4 +759,19 @@ public AColGroup combineWithSameIndex(int nRow, int nCol, List right) protected boolean allowShallowIdentityRightMult() { return true; } + + @Override + public AColGroup sort() { + return this; + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + return this; + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + return ColGroupConst.create(newColumnIDs, _dict.sliceColumns(selectedColumns, getNumCols())); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java index fc82c58e16b..40b5d41e3ea 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java @@ -26,8 +26,6 @@ import java.util.List; import java.util.concurrent.ExecutorService; -import jdk.incubator.vector.DoubleVector; -import jdk.incubator.vector.VectorSpecies; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; @@ -52,6 +50,7 @@ import org.apache.sysds.runtime.compress.estim.EstimationFactors; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -67,6 +66,9 @@ import org.apache.sysds.runtime.matrix.operators.UnaryOperator; import org.jboss.netty.handler.codec.compression.CompressionException; +import jdk.incubator.vector.DoubleVector; +import jdk.incubator.vector.VectorSpecies; + /** * Class to encapsulate information about a column group that is encoded with dense dictionary encoding (DDC). */ @@ -668,7 +670,8 @@ private void defaultRightDecompressingMult(MatrixBlock right, MatrixBlock ret, i } } - final void vectMM(double aa, double[] b, double[] c, int endT, int jd, int crl, int cru, int offOut, int k, int vLen, DoubleVector vVec) { + final void vectMM(double aa, double[] b, double[] c, int endT, int jd, int crl, int cru, int offOut, int k, int vLen, + DoubleVector vVec) { vVec = vVec.broadcast(aa); final int offj = k * jd; final int end = endT + offj; @@ -1091,6 +1094,41 @@ public AColGroup[] splitReshapePushDown(int multiplier, int nRow, int nColOrg, E return res; } + @Override + public AColGroup sort() { + // TODO restore support for run length encoding to exploit the runs + + int[] counts = getCounts(); + // get the sort index + int[] r = _dict.sort(); + + AMapToData m = MapToFactory.create(_data.size(), counts.length); + int off = 0; + for(int i = 0; i < counts.length; i++) { + for(int j = 0; j < counts[r[i]]; j++) { + m.set(off++, r[i]); + } + } + + return ColGroupDDC.create(_colIndexes, _dict, m, counts); + + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + return ColGroupDDC.create(_colIndexes, _dict, _data.removeEmpty(selectV, rOut), null); + } + + @Override + protected boolean allowShallowIdentityRightMult() { + return true; + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + return ColGroupDDC.create(newColumnIDs, _dict.sliceColumns(selectedColumns, getNumCols()), _data, null); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); @@ -1100,9 +1138,4 @@ public String toString() { return sb.toString(); } - @Override - protected boolean allowShallowIdentityRightMult() { - return true; - } - } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java index 70191a27936..3f30dc8f0aa 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java @@ -40,6 +40,7 @@ import org.apache.sysds.runtime.compress.estim.EstimationFactors; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; @@ -546,6 +547,40 @@ protected boolean allowShallowIdentityRightMult() { return false; } + @Override + public AColGroup sort() { + // TODO restore support for run length encoding. + + int[] counts = getCounts(); + // get the sort index + int[] r = _dict.sort(); + + AMapToData m = MapToFactory.create(_data.size(), counts.length); + int off = 0; + for(int i = 0; i < counts.length; i++) { + for(int j = 0; j < counts[r[i]]; j++) { + m.set(off++, r[i]); + } + } + + return ColGroupDDCFOR.create(_colIndexes, _dict, m, counts, _reference); + + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + return ColGroupDDCFOR.create(_colIndexes, _dict, _data.removeEmpty(selectV, rOut), null, _reference); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + double[] ref = new double[selectedColumns.size()]; + for(int i = 0; i < selectedColumns.size(); i++) { + ref[i] = _reference[selectedColumns.get(i)]; + } + return ColGroupDDCFOR.create(newColumnIDs, _dict.sliceColumns(selectedColumns, getNumCols()), _data, null, ref); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java index ba547a8d7aa..aa4d8428dd1 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java @@ -44,6 +44,7 @@ import org.apache.sysds.runtime.compress.estim.EstimationFactors; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -476,4 +477,20 @@ public AColGroup combineWithSameIndex(int nRow, int nCol, List right) return new ColGroupEmpty(combinedIndex); } + + @Override + public AColGroup sort(){ + return this; + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut){ + return this; + } + + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns){ + return new ColGroupEmpty(newColumnIDs); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java index 91442281317..1091ae36890 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java @@ -94,9 +94,7 @@ public static long getExactSizeOnDisk(List colGroups) { } ret += grp.getExactSizeOnDisk(); } - if(LOG.isWarnEnabled()) - LOG.warn(" duplicate dicts on exact Size on Disk : " + (colGroups.size() - dicts.size()) ); - + return ret; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java index 45b4fbeb026..ed4a8d03030 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java @@ -32,6 +32,7 @@ import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -740,4 +741,18 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { throw new NotImplementedException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup sort() { + throw new NotImplementedException("Unimplemented method 'sort'"); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + throw new NotImplementedException("Unimplemented method 'removeEmptyRows'"); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns){ + throw new NotImplementedException("Unimplemented method 'removeEmptyColumns'"); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java index ea6d0f34c2a..dded0e9f520 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java @@ -26,15 +26,16 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.bitmap.ABitmap; -import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -731,5 +732,18 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { throw new NotImplementedException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup sort() { + throw new NotImplementedException(); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + throw new NotImplementedException("Unimplemented method 'removeEmptyRows'"); + } + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns){ + throw new NotImplementedException("Unimplemented method 'removeEmptyColumns'"); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java index 2b4b23792e3..560af40bcf4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java @@ -39,6 +39,7 @@ import org.apache.sysds.runtime.compress.colgroup.scheme.RLEScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -1190,4 +1191,18 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { throw new NotImplementedException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup sort() { + throw new NotImplementedException(); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + throw new NotImplementedException("Unimplemented method 'removeEmptyRows'"); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns){ + throw new NotImplementedException("Unimplemented method 'removeEmptyColumns'"); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java index 1270823bfdc..8e4d23baaa8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java @@ -42,6 +42,7 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.RemoveEmptyOffsetsTmp; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; @@ -508,10 +509,10 @@ protected static AColGroup rexpandCols(int max, boolean ignore, boolean cast, in AOffset indexes, AMapToData data, int[] counts, int def, int nVal) { if(d == null) { - if(def <= 0){ + if(def <= 0) { if(max > 0) return ColGroupEmpty.create(max); - else + else return null; } else if(def > max && max > 0) @@ -873,6 +874,69 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { return res; } + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + // TODO restore support for run length encoding. + + final int[] counts = getCounts(); + // get the sort index + final int[] r = _dict.sort(); + + // find default value position. + // todo use binary search for minor improvements. + final double def = _defaultTuple[0]; + int defIdx = counts.length; + for(int i = 0; i < r.length; i++) { + if(_dict.getValue(r[i], 0, 1) >= def) { + defIdx = i; + break; + } + } + + int nondefault = _data.size(); + int defaultLength = _numRows - nondefault; + AMapToData m = MapToFactory.create(nondefault, counts.length); + int[] offsets = new int[nondefault]; + + int off = 0; + for(int i = 0; i < counts.length; i++) { + if(i < defIdx) { + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off; + m.set(off++, r[i]); + } + } + else {// if( i >= defIdx){ + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off + defaultLength; + m.set(off++, r[i]); + } + } + } + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDC.create(_colIndexes, _numRows, _dict, _defaultTuple, o, m, counts); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + final RemoveEmptyOffsetsTmp offsetTmp = _indexes.removeEmptyRows(selectV, rOut); + final AMapToData nm = _data.removeEmpty(offsetTmp.select); + return ColGroupSDC.create(_colIndexes, rOut, _dict, _defaultTuple, offsetTmp.retOffset, nm, null); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + double[] ref = new double[selectedColumns.size()]; + for(int i = 0; i < selectedColumns.size(); i++) { + ref[i] = _defaultTuple[selectedColumns.get(i)]; + } + return ColGroupSDC.create(newColumnIDs, _numRows, _dict.sliceColumns(selectedColumns, getNumCols()), ref, + _indexes, _data, null); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java index 41fb7ac5709..15661e86ad0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java @@ -39,6 +39,7 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.RemoveEmptyOffsetsTmp; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; @@ -620,6 +621,68 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { return res; } + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + // TODO restore support for run length encoding. + + final int[] counts = getCounts(); + // get the sort index + final int[] r = _dict.sort(); + + // find default value position. + // todo use binary search for minor improvements. + int defIdx = counts.length; + for(int i = 0; i < r.length; i++) { + if(_dict.getValue(r[i], 0, 1) >= 0) { + defIdx = i; + break; + } + } + + int nondefault = _data.size(); + int defaultLength = _numRows - nondefault; + AMapToData m = MapToFactory.create(nondefault, counts.length); + int[] offsets = new int[nondefault]; + + int off = 0; + for(int i = 0; i < counts.length; i++) { + if(i < defIdx) { + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off; + m.set(off++, r[i]); + } + } + else {// if( i >= defIdx){ + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off + defaultLength; + m.set(off++, r[i]); + } + } + } + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDCFOR.create(_colIndexes, _numRows, _dict, o, m, counts, _reference); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + final RemoveEmptyOffsetsTmp offsetTmp = _indexes.removeEmptyRows(selectV, rOut); + final AMapToData nm = _data.removeEmpty(offsetTmp.select); + return ColGroupSDCFOR.create(_colIndexes, rOut, _dict, offsetTmp.retOffset, nm, null, _reference); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + double[] ref = new double[selectedColumns.size()]; + for(int i = 0; i < selectedColumns.size(); i++) { + ref[i] = _reference[selectedColumns.get(i)]; + } + return ColGroupSDCFOR.create(newColumnIDs, _numRows, _dict.sliceColumns(selectedColumns, getNumCols()), _indexes, _data, null, + ref); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java index fa5772c0c3e..ed0cafd07b4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java @@ -40,6 +40,7 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.RemoveEmptyOffsetsTmp; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetEmpty; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; @@ -469,10 +470,10 @@ public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) { IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size()); final int def = (int) _defaultTuple[0]; if(d == null) { - if(def <= 0){ + if(def <= 0) { if(max > 0) return ColGroupEmpty.create(max); - else + else return null; } else if(def > max && max > 0) @@ -718,6 +719,66 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { return res; } + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + // TODO restore support for run length encoding. + + final int[] counts = getCounts(); + // get the sort index + final int[] r = _dict.sort(); + + // find default value position. + // todo use binary search for minor improvements. + final double def = _defaultTuple[0]; + int defIdx = counts.length; + int nondefault = 0; + for(int i = 0; i < r.length; i++) { + if(defIdx == counts.length && _dict.getValue(r[i], 0, 1) >= def) { + defIdx = i; + } + nondefault += counts[i]; + } + + int defaultLength = _numRows - nondefault; + int[] offsets = new int[nondefault]; + + int off = 0; + for(int i = 0; i < counts.length; i++) { + if(i < defIdx) { + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off; + } + } + else {// if( i >= defIdx){ + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off + defaultLength; + } + } + } + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDCSingle.create(_colIndexes, _numRows, _dict, _defaultTuple, o, counts); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + // TODO optimize by not constructing boolean array. + final RemoveEmptyOffsetsTmp offsetTmp = _indexes.removeEmptyRows(selectV, rOut); + return ColGroupSDCSingle.create(_colIndexes, rOut, _dict, _defaultTuple, offsetTmp.retOffset, null); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + double[] ref = new double[selectedColumns.size()]; + for(int i = 0; i < selectedColumns.size(); i++) { + ref[i] = _defaultTuple[selectedColumns.get(i)]; + } + return ColGroupSDCSingle.create(newColumnIDs, _numRows, _dict.sliceColumns(selectedColumns, getNumCols()), ref, + _indexes, null); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java index 9efd0c41098..79db48492bf 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java @@ -40,6 +40,7 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.RemoveEmptyOffsetsTmp; import org.apache.sysds.runtime.compress.colgroup.offset.AOffsetIterator; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetEmpty; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; @@ -109,10 +110,8 @@ protected void decompressToDenseBlockDenseDictionary(DenseBlock db, int rl, int return; else if(it.value() >= ru) return; - // _indexes.cacheIterator(it, ru); else { decompressToDenseBlockDenseDictionaryWithProvidedIterator(db, rl, ru, offR, offC, values, it); - // _indexes.cacheIterator(it, ru); } } @@ -238,7 +237,8 @@ protected void decompressToSparseBlockSparseDictionary(SparseBlock ret, int rl, if(it == null) return; else if(it.value() >= ru) - _indexes.cacheIterator(it, ru); + return; + // _indexes.cacheIterator(it, ru); else if(ru > last) { final int apos = sb.pos(0); final int alen = sb.size(0) + apos; @@ -277,8 +277,15 @@ protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, i if(it == null) return; else if(it.value() >= ru) - _indexes.cacheIterator(it, ru); - else if(ru > _indexes.getOffsetToLast()) { + return; + else + decompressToSparseBlockDenseDictionaryWithProvidedIterator(ret, rl, ru, offR, offC, values, it); + } + + @Override + public void decompressToSparseBlockDenseDictionaryWithProvidedIterator(SparseBlock ret, int rl, int ru, int offR, + int offC, double[] values, final AIterator it) { + if(ru > _indexes.getOffsetToLast()) { final int nCol = _colIndexes.size(); final int lastOff = _indexes.getOffsetToLast(); int row = offR + it.value(); @@ -963,7 +970,7 @@ protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock re protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { throw new NotImplementedException(); } - + protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, SparseBlock sb) { throw new NotImplementedException(); } @@ -1043,6 +1050,62 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { return res; } + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + // TODO restore support for run length encoding. + + final int[] counts = getCounts(); + // get the sort index + final int[] r = _dict.sort(); + + // find default value position. + // todo use binary search for minor improvements. + int defIdx = counts.length; + int nondefault = 0; + for(int i = 0; i < r.length; i++) { + if(defIdx == counts.length && _dict.getValue(r[i], 0, 1) >= 0) { + defIdx = i; + } + nondefault += counts[i]; + } + + int defaultLength = _numRows - nondefault; + int[] offsets = new int[nondefault]; + + int off = 0; + for(int i = 0; i < counts.length; i++) { + if(i < defIdx) { + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off; + } + } + else {// if( i >= defIdx){ + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off + defaultLength; + } + } + } + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDCSingleZeros.create(_colIndexes, _numRows, _dict, o, counts); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + // TODO optimize by not constructing boolean array. + final RemoveEmptyOffsetsTmp offsetTmp = _indexes.removeEmptyRows(selectV, rOut); + return ColGroupSDCSingleZeros.create(_colIndexes, rOut, _dict, offsetTmp.retOffset, null); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + + return ColGroupSDCSingleZeros.create(newColumnIDs, _numRows, _dict.sliceColumns(selectedColumns, getNumCols()), + _indexes, null); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java index 69e0f776383..f3c82684f99 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java @@ -45,6 +45,7 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.RemoveEmptyOffsetsTmp; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; @@ -184,8 +185,7 @@ private final void decompressToDenseBlockDenseDictionaryPostAllCols(DenseBlock d final double[] c = db.values(idx); final int off = db.pos(idx); final int offDict = _data.getIndex(it.getDataIndex()) * nCol; - for(int j = 0; j < nCol; j++) - c[off + j] += values[offDict + j]; + decompressSingleRow(values, nCol, c, off, offDict); if(it.value() == lastOff) return; it.next(); @@ -301,13 +301,19 @@ private void decompressToDenseBlockDenseDictionaryPreAllCols(DenseBlock db, int final double[] c = db.values(idx); final int off = db.pos(idx) + offC; final int offDict = _data.getIndex(it.getDataIndex()) * nCol; - for(int j = 0; j < nCol; j++) - c[off + j] += values[offDict + j]; + decompressSingleRow(values, nCol, c, off, offDict); it.next(); } } + private static void decompressSingleRow(double[] values, final int nCol, final double[] c, final int off, + final int offDict) { + final int end = nCol + off; + for(int j = off, k = offDict; j < end; j++, k++) + c[j] += values[k]; + } + @Override protected void decompressToDenseBlockSparseDictionary(DenseBlock db, int rl, int ru, int offR, int offC, SparseBlock sb) { @@ -438,8 +444,16 @@ protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, i if(it == null) return; else if(it.value() >= ru) - _indexes.cacheIterator(it, ru); - else if(ru > _indexes.getOffsetToLast()) { + return; + else + decompressToSparseBlockDenseDictionaryWithProvidedIterator(ret, rl, ru, offR, offC, values, it); + + } + + @Override + public void decompressToSparseBlockDenseDictionaryWithProvidedIterator(SparseBlock ret, int rl, int ru, int offR, + int offC, double[] values, final AIterator it) { + if(ru > _indexes.getOffsetToLast()) { final int lastOff = _indexes.getOffsetToLast(); final int nCol = _colIndexes.size(); while(true) { @@ -467,7 +481,6 @@ else if(ru > _indexes.getOffsetToLast()) { } _indexes.cacheIterator(it, ru); } - } @Override @@ -899,7 +912,6 @@ public AColGroup morph(CompressionType ct, int nRow) { return super.morph(ct, nRow); } - @Override public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { final SparseBlock sr = ret.getSparseBlock(); @@ -942,14 +954,14 @@ protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret of = it.next(); } else if(points[c].o < of) - c++; + c++; else of = it.next(); - } - // increment the c pointer until it is pointing at least to last point or is done. - while(c < points.length && points[c].o < last) - c++; - c = processRowDense(points, dr, nCol, c, of, _data.getIndex(it.getDataIndex())); + } + // increment the c pointer until it is pointing at least to last point or is done. + while(c < points.length && points[c].o < last) + c++; + c = processRowDense(points, dr, nCol, c, of, _data.getIndex(it.getDataIndex())); } private int processRowSparse(P[] points, final SparseBlock sr, final int nCol, int c, int of, final int did) { @@ -1078,6 +1090,64 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { return res; } + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + // TODO restore support for run length encoding. + + final int[] counts = getCounts(); + // get the sort index + final int[] r = _dict.sort(); + + // find default value position. + // todo use binary search for minor improvements. + int defIdx = counts.length; + for(int i = 0; i < r.length; i++) { + if(_dict.getValue(r[i], 0, 1) >= 0) { + defIdx = i; + break; + } + } + + int nondefault = _data.size(); + int defaultLength = _numRows - nondefault; + AMapToData m = MapToFactory.create(nondefault, counts.length); + int[] offsets = new int[nondefault]; + + int off = 0; + for(int i = 0; i < counts.length; i++) { + if(i < defIdx) { + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off; + m.set(off++, r[i]); + } + } + else {// if( i >= defIdx){ + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off + defaultLength; + m.set(off++, r[i]); + } + } + } + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDCZeros.create(_colIndexes, _numRows, _dict, o, m, counts); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + final RemoveEmptyOffsetsTmp offsetTmp = _indexes.removeEmptyRows(selectV, rOut); + final AMapToData nm = _data.removeEmpty(offsetTmp.select); + return ColGroupSDCZeros.create(_colIndexes, rOut, _dict, offsetTmp.retOffset, nm, null); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + return ColGroupSDCZeros.create(newColumnIDs, _numRows, _dict.sliceColumns(selectedColumns, getNumCols()), + _indexes, _data, null); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java index 1c3bce2e16c..8efffc4878c 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java @@ -43,6 +43,7 @@ import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; import org.apache.sysds.runtime.compress.estim.EstimationFactors; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; @@ -82,7 +83,8 @@ public class ColGroupUncompressed extends AColGroup { /** * Do not use this constructor of column group uncompressed, instead use the create constructor. - * @param mb The contained data. + * + * @param mb The contained data. * @param colIndexes Column indexes for this Columngroup */ protected ColGroupUncompressed(MatrixBlock mb, IColIndex colIndexes) { @@ -92,14 +94,15 @@ protected ColGroupUncompressed(MatrixBlock mb, IColIndex colIndexes) { /** * Do not use this constructor of column group quantization-fused uncompressed, instead use the create constructor. - * @param mb The contained data. - * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix - * @param colIndexes Column indexes for this Columngroup + * + * @param mb The contained data. + * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix + * @param colIndexes Column indexes for this Columngroup */ protected ColGroupUncompressed(MatrixBlock mb, IColIndex colIndexes, double[] scaleFactors) { super(colIndexes); - // Apply scaling and flooring - // TODO: Use internal matrix prod + // Apply scaling and flooring + // TODO: Use internal matrix prod for(int r = 0; r < mb.getNumRows(); r++) { double scaleFactor = scaleFactors.length == 1 ? scaleFactors[0] : scaleFactors[r]; for(int c = 0; c < mb.getNumColumns(); c++) { @@ -108,7 +111,8 @@ protected ColGroupUncompressed(MatrixBlock mb, IColIndex colIndexes, double[] sc } } _data = mb; - } + } + /** * Create an Uncompressed Matrix Block, where the columns are offset by col indexes. * @@ -130,9 +134,9 @@ public static AColGroup create(MatrixBlock mb, IColIndex colIndexes) { * * It is assumed that the size of the colIndexes and number of columns in mb is matching. * - * @param mb The MB / data to contain in the uncompressed column - * @param colIndexes The column indexes for the group - * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix + * @param mb The MB / data to contain in the uncompressed column + * @param colIndexes The column indexes for the group + * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix * @return An Uncompressed Column group */ public static AColGroup createQuantized(MatrixBlock mb, IColIndex colIndexes, double[] scaleFactors) { @@ -147,14 +151,15 @@ public static AColGroup createQuantized(MatrixBlock mb, IColIndex colIndexes, do /** * Main constructor for a quantization-fused uncompressed ColGroup. * - * @param colIndexes Indices (relative to the current block) of the columns that this column group represents. - * @param rawBlock The uncompressed block; uncompressed data must be present at the time that the constructor is - * called - * @param transposed Says if the input matrix raw block have been transposed. - * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix + * @param colIndexes Indices (relative to the current block) of the columns that this column group represents. + * @param rawBlock The uncompressed block; uncompressed data must be present at the time that the constructor is + * called + * @param transposed Says if the input matrix raw block have been transposed. + * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix * @return AColGroup. */ - public static AColGroup createQuantized(IColIndex colIndexes, MatrixBlock rawBlock, boolean transposed, double[] scaleFactors) { + public static AColGroup createQuantized(IColIndex colIndexes, MatrixBlock rawBlock, boolean transposed, + double[] scaleFactors) { // special cases if(rawBlock.isEmptyBlock(false)) // empty input @@ -187,22 +192,24 @@ else if(!transposed && colIndexes.size() == rawBlock.getNumColumns()) final int n = colIndexes.size(); if(transposed) { - if (scaleFactors.length == 1) { + if(scaleFactors.length == 1) { for(int i = 0; i < m; i++) for(int j = 0; j < n; j++) mb.appendValue(i, j, Math.floor(rawBlock.get(i, colIndexes.get(j)) * scaleFactors[0])); - } else { + } + else { for(int i = 0; i < m; i++) for(int j = 0; j < n; j++) mb.appendValue(i, j, Math.floor(rawBlock.get(i, colIndexes.get(j)) * scaleFactors[j])); } } else { - if (scaleFactors.length == 1) { + if(scaleFactors.length == 1) { for(int i = 0; i < m; i++) for(int j = 0; j < n; j++) mb.appendValue(i, j, Math.floor(rawBlock.get(i, colIndexes.get(j)) * scaleFactors[0])); - } else { + } + else { for(int i = 0; i < m; i++) for(int j = 0; j < n; j++) mb.appendValue(i, j, Math.floor(rawBlock.get(i, colIndexes.get(j)) * scaleFactors[i])); @@ -1075,7 +1082,6 @@ public AColGroup morph(CompressionType ct, int nRow) { return comp.get(0).copyAndSet(_colIndexes); } - @Override public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { if(_data.isInSparseFormat()) @@ -1092,7 +1098,6 @@ protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret denseSelectionDenseColumnGroup(selection, ret, rl, ru); } - private void sparseSelectionSparseColumnGroup(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { final SparseBlock sb = selection.getSparseBlock(); @@ -1192,7 +1197,7 @@ public AColGroup reduceCols() { else return new ColGroupUncompressed(mb, ColIndexFactory.createI(0)); } - + @Override public void decompressToDenseBlockTransposed(DenseBlock db, int rl, int ru) { if(_data.isInSparseFormat()) @@ -1289,11 +1294,30 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { for(int i = 0; i < multiplier; i++) for(int j = 0; j < s; j++) newColumns[i * s + j] = _colIndexes.get(j) + nColOrg * i; - MatrixBlock newData = _data.reshape(nRow/ multiplier, s * multiplier, true); - return new AColGroup[]{create(newData,ColIndexFactory.create(newColumns))}; + MatrixBlock newData = _data.reshape(nRow / multiplier, s * multiplier, true); + return new AColGroup[] {create(newData, ColIndexFactory.create(newColumns))}; // throw new NotImplementedException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup sort() { + return new ColGroupUncompressed(_data.sortOperations(), _colIndexes); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + MatrixBlock tmp = new MatrixBlock(); + tmp = LibMatrixReorg.removeEmptyRows(_data, tmp, false, false, selectV, rOut); + return ColGroupUncompressed.create(_colIndexes, tmp, false); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + double[] vals = MatrixBlockDictionary.sliceColumns(_data, selectedColumns); + MatrixBlock ret = new MatrixBlock(_data.getNumRows(), selectedColumns.size(), vals); + return ColGroupUncompressed.create(newColumnIDs, ret, false); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java index 31e29341645..a0a9dd46306 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java @@ -19,11 +19,13 @@ package org.apache.sysds.runtime.compress.colgroup; +import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -282,4 +284,18 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { throw new UnsupportedOperationException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup sort(){ + throw new NotImplementedException("Unimplemented method 'sort'"); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + throw new NotImplementedException("Unimplemented method 'removeEmptyRows'"); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns){ + throw new NotImplementedException("Unimplemented method 'removeEmptyColumns'"); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/AIdentityDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/AIdentityDictionary.java index 17b382f06ad..a7e715b59b8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/AIdentityDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/AIdentityDictionary.java @@ -19,6 +19,7 @@ package org.apache.sysds.runtime.compress.colgroup.dictionary; +import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.DMLCompressionException; public abstract class AIdentityDictionary extends ACachingMBDictionary { @@ -74,4 +75,9 @@ public double[] productAllRowsToDoubleWithDefault(double[] defaultTuple) { ret[ret.length - 1] *= defaultTuple[i]; return ret; } + + @Override + public int[] sort(){ + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java index d67ab95f824..7bb6132deb4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java @@ -23,6 +23,7 @@ import java.io.IOException; import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.functionobjects.Divide; import org.apache.sysds.runtime.functionobjects.Minus; import org.apache.sysds.runtime.functionobjects.Multiply; @@ -121,4 +122,14 @@ public boolean equals(IDictionary o) { public IDictionary clone() { throw new NotImplementedException(); } + + @Override + public int[] sort(){ + throw new NotImplementedException(); + } + + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol){ + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java index 939b48bf424..bd7dc98fe55 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java @@ -28,10 +28,13 @@ import java.util.HashSet; import java.util.Set; +import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.compress.utils.Util; +import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.Multiply; @@ -1341,4 +1344,73 @@ public IDictionary append(double[] row) { return new Dictionary(retV); } + @Override + public int[] sort() { + return sort(_values); + } + + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol) { + // TODO: make specialized version for this. + return getMBDict(nCol).sliceColumns(selectedColumns, nCol); + } + + protected static int[] sort(double[] values) { + int[] indices = new int[values.length]; + for(int i = 0; i < indices.length; i++) { + indices[i] = i; + } + + // quicksort with stack + int[] stack = new int[values.length]; + + int top = -1; + stack[++top] = 0; + stack[++top] = values.length - 1; + + while(top >= 0) { + int high = stack[top--]; + int low = stack[top--]; + + if(low < high) { + + int pivotIndex = partition(indices, values, low, high); + // Left side + if(pivotIndex - 1 > low) { + stack[++top] = low; + stack[++top] = pivotIndex - 1; + } + + // Right side + if(pivotIndex + 1 < high) { + stack[++top] = pivotIndex + 1; + stack[++top] = high; + } + } + } + + return indices; + } + + private static int partition(int[] indices, double[] values, int low, int high) { + double pivotValue = values[indices[high]]; + int i = low - 1; + + for(int j = low; j < high; j++) { + if(values[indices[j]] <= pivotValue) { + i++; + swap(indices, i, j); + } + } + + swap(indices, i + 1, high); + return i + 1; + } + + private static void swap(int[] arr, int i, int j) { + int tmp = arr[i]; + arr[i] = arr[j]; + arr[j] = tmp; + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java index dddea0eec7a..28ec8ebf207 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java @@ -25,6 +25,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; @@ -1051,4 +1052,23 @@ public IDictionary rightMMPreAggSparse(int numVals, SparseBlock b, IColIndex thi * @return The nonzero count of each column in the dictionary. */ public int[] countNNZZeroColumns(int[] counts); + + /** + * Sort the values of this dictionary via an index of how the values mapped previously. + * + * In practice this design means we can reuse the previous dictionary for the resulting column group + * + * @return The sorted index. + */ + public int[] sort(); + + /** + * Slice out the selected columns given of this encoded group. + * + * @param selectedColumns The columns to slice out and return as a new matrix. + * @param nCol The number of columns in this dictionary. + * @return The returned matrix + */ + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol); + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java index 40e1b065653..c2540de959a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java @@ -27,6 +27,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockFactory; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -540,9 +541,13 @@ public String getString(int colIndexes) { return "IdentityMatrix of size: " + nRowCol + " with empty: " + withEmpty; } + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol){ + return getMBDict().sliceColumns(selectedColumns, nCol); + } + @Override public String toString() { return "IdentityMatrix of size: " + nRowCol + " with empty: " + withEmpty; } - } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java index df702524d55..c7f642edfd0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java @@ -27,6 +27,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -310,6 +311,11 @@ public String getString(int colIndexes) { return toString(); } + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol){ + return getMBDict().sliceColumns(selectedColumns, nCol); + } + @Override public String toString() { return "IdentityMatrixSlice of size: " + nRowCol + " l " + l + " u " + u; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java index 24776f3adc4..ef0fc4aacd5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java @@ -27,8 +27,6 @@ import java.util.Arrays; import java.util.Set; -import jdk.incubator.vector.DoubleVector; -import jdk.incubator.vector.VectorSpecies; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.indexes.ArrayIndex; @@ -36,6 +34,7 @@ import org.apache.sysds.runtime.compress.colgroup.indexes.RangeIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.SingleIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.TwoIndex; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.DenseBlockFP64; @@ -61,6 +60,9 @@ import org.apache.sysds.runtime.matrix.operators.ScalarOperator; import org.apache.sysds.runtime.matrix.operators.UnaryOperator; +import jdk.incubator.vector.DoubleVector; +import jdk.incubator.vector.VectorSpecies; + public class MatrixBlockDictionary extends ADictionary { private static final long serialVersionUID = 2535887782150955098L; @@ -2801,4 +2803,41 @@ private void SparseAdd(int sPos, int sEnd, double[] ret, int offOut, int[] sIdx, } } + @Override + public int[] sort() { + if(_data.getNumColumns() > 1) + throw new RuntimeException("Not supported sort on multicolumn dictionaries"); + _data.sparseToDense(); + + return Dictionary.sort(_data.getDenseBlockValues()); + } + + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol) { + + final double[] ret = sliceColumns(_data, selectedColumns); + + return new Dictionary(ret); + } + + public static double[] sliceColumns(MatrixBlock mb, IntArrayList selectedColumns) { + //TODO: Optimize to allow sparse outputs. and change output type to MatrixBlock. + final int outC = selectedColumns.size(); + if((long) mb.getNumRows() * outC > (long) Integer.MAX_VALUE) + throw new NotImplementedException("Not supported large output blocks for slicing dictionary columns"); + mb.sparseToDense(); + final DenseBlock db = mb.getDenseBlock(); + final double[] ret = new double[mb.getNumRows() * outC]; + + for(int i = 0; i < mb.getNumRows(); i++) { + double[] vals = db.values(i); + int offIn = db.pos(i); + int offOut = i * outC; + for(int j = 0; j < outC; j++) { + ret[offOut + j] = vals[offIn + selectedColumns.get(j)]; + } + } + return ret; + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java index f5746647a37..ec3ead4a68b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java @@ -23,6 +23,7 @@ import java.io.DataOutput; import java.io.IOException; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.io.IOUtilFunctions; public class PlaceHolderDict extends ADictionary { @@ -101,4 +102,14 @@ public DictType getDictType() { throw new RuntimeException("invalid to get dictionary type for PlaceHolderDict"); } + @Override + public int[] sort() { + throw new RuntimeException("Invalid call"); + } + + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol) { + throw new RuntimeException("Invalid call"); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java index 6802d920b49..9e2fa4bf1d7 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java @@ -23,6 +23,8 @@ import java.io.DataOutput; import java.io.IOException; +import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.utils.MemoryEstimates; @@ -277,4 +279,13 @@ public MatrixBlockDictionary createMBDict(int nCol) { return new MatrixBlockDictionary(mb); } + @Override + public int[] sort() { + throw new NotImplementedException(); + } + + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol) { + return getMBDict().sliceColumns(selectedColumns, nCol); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java index 5fc2acaea7a..79cc219f2e2 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java @@ -30,6 +30,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.IMapToDataGroup; @@ -39,6 +40,7 @@ import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffsetIterator; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -1041,4 +1043,40 @@ public String toString() { sb.append("]"); return sb.toString(); } + + public AMapToData removeEmpty(final boolean[] selectV, final int rOut) { + try{ + + final AMapToData ret = MapToFactory.create(rOut, getUnique()); + final int s = size(); + int t = 0; + for(int i = 0; i < s; i++) + if(selectV[i] == true) + ret.set(t++, getIndex(i)); + + return ret; + } + catch(ArrayIndexOutOfBoundsException e){ + + int trueCount = 0; + for(boolean a : selectV){ + if(a) trueCount ++; + } + throw new DMLRuntimeException("actual number of true values " + trueCount + " vs argument " + rOut,e); + } + } + + /** + * Use the offsets of the select vector to choose which values to keep. + * + * @param select The row indexes to keep + * @return A New MapToData + */ + public AMapToData removeEmpty(IntArrayList select) { + final int s = select.size(); + final AMapToData ret = MapToFactory.create(s, getUnique()); + for(int i = 0; i < s; i++) + ret.set(i, getIndex(select.get(i))); + return ret; + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AIterator.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AIterator.java index 45c78dd3abd..a809afccd3d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AIterator.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AIterator.java @@ -71,8 +71,8 @@ public boolean isNotOver(int ub) { /** * Get the current data index associated with the index returned from value. * - * This index points to a position int the mapToData object, that then inturn can be used to lookup the dictionary - * entry in ADictionary. + * This index points to a position in the AMapToData object, that can be used to lookup the dictionary entry in + * ADictionary. * * @return The Data Index. */ diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java index 8930074eb0e..bae6ae57cad 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java @@ -569,7 +569,7 @@ public OffsetSliceInfo slice(int l, int u) { else return new OffsetSliceInfo(0, s, moveIndex(l)); } - else if (u < first) + else if(u < first) return EMPTY_SLICE; final AIterator it = getIteratorSkipCache(l); @@ -764,6 +764,43 @@ public AOffset reverse(int numRows) { return OffsetFactory.createOffset(newOff); } + public RemoveEmptyOffsetsTmp removeEmptyRows(boolean[] selectV, int rOut) { + IntArrayList newOff = new IntArrayList(); + IntArrayList selectMTmp = new IntArrayList(); + + final AIterator it = getIterator(); + final int last = getOffsetToLast(); + int t = 0; + int o = 0; + while(it.value() < last) { + while(t < it.value()) { + if(selectV[t]) + o++; + t++; + } + if(selectV[it.value()]) { + newOff.appendValue(o); + selectMTmp.appendValue(it.getDataIndex()); + o++; + t++; + } + it.next(); + } + while(t < last) { + if(selectV[t]) + o++; + t++; + } + if(selectV[last]) { + newOff.appendValue(o); + selectMTmp.appendValue(it.getDataIndex()); + } + + // throw new RuntimeException("\n\n\n" + Arrays.toString(selectV) + " \n\n " + this + "\n\n " + newOff + " \n " + + // selectMTmp + "\n\n " + "\n\n "); + return new RemoveEmptyOffsetsTmp(OffsetFactory.createOffset(newOff), selectMTmp); + } + /** * Offset slice info containing the start and end index an offset that contains the slice, and an new AOffset * containing only the sliced elements @@ -793,6 +830,16 @@ public String toString() { } + public static final class RemoveEmptyOffsetsTmp { + public final AOffset retOffset; + public final IntArrayList select; + + protected RemoveEmptyOffsetsTmp(AOffset retOffset, IntArrayList select) { + this.retOffset = retOffset; + this.select = select; + } + } + private static class OffsetCache { private final AIterator it; private final int row; @@ -824,4 +871,5 @@ public String toString() { return "r" + row + " d " + dataIndex + " o " + offIndex + "\n"; } } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetEmpty.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetEmpty.java index 73264c84767..5c410026587 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetEmpty.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetEmpty.java @@ -76,6 +76,10 @@ public int getOffsetToLast() { public long getInMemorySize() { return estimateInMemorySize(); } + @Override + public boolean equals(AOffset b) { + return b instanceof OffsetEmpty; + } public static long estimateInMemorySize() { return 16; // object header diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java index ce52bcd23fd..26779215306 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java @@ -48,6 +48,7 @@ import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; +import org.apache.sysds.runtime.compress.utils.HashMapIntToInt; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.DenseBlockFP64; import org.apache.sysds.runtime.data.SparseBlock; @@ -55,7 +56,6 @@ import org.apache.sysds.runtime.data.SparseRow; import org.apache.sysds.runtime.data.SparseRowScalar; import org.apache.sysds.runtime.data.SparseRowVector; -import org.apache.sysds.runtime.frame.data.columns.HashMapToInt; import org.apache.sysds.runtime.functionobjects.Divide; import org.apache.sysds.runtime.functionobjects.Minus; import org.apache.sysds.runtime.functionobjects.Multiply; @@ -77,7 +77,7 @@ public final class CLALibBinaryCellOp { private static final Log LOG = LogFactory.getLog(CLALibBinaryCellOp.class.getName()); - public static final int DECOMPRESSION_BLEN = 16384; + public static final int DECOMPRESSION_BLEN = 16384 / 2; private CLALibBinaryCellOp() { // empty private constructor. @@ -86,7 +86,7 @@ private CLALibBinaryCellOp() { public static MatrixBlock binaryOperationsRight(BinaryOperator op, CompressedMatrixBlock m1, MatrixBlock that) { try { - op = LibMatrixBincell.replaceOpWithSparseSafeIfApplicable(m1, that, op); + op = LibMatrixBincell.replaceOpWithSparseSafeIfApplicable(m1, that, op); if((that.getNumRows() == 1 && that.getNumColumns() == 1) || that.isEmpty()) { ScalarOperator sop = new RightScalarOperator(op.fn, that.get(0, 0), op.getNumThreads()); @@ -113,7 +113,7 @@ public static MatrixBlock binaryOperationsLeft(BinaryOperator op, CompressedMatr return selectProcessingBasedOnAccessType(op, m1, that, atype, true); } catch(Exception e) { - throw new DMLRuntimeException("Failed Left Binary Compressed Operation", e); + throw new DMLRuntimeException("Failed Left Binary Compressed Operation: " + op, e); } } @@ -122,8 +122,8 @@ private static MatrixBlock binaryOperationsRightFiltered(BinaryOperator op, Comp BinaryAccessType atype = LibMatrixBincell.getBinaryAccessTypeExtended(m1, that); if(isDoubleCompressedOpApplicable(m1, that)) return doubleCompressedBinaryOp(op, m1, (CompressedMatrixBlock) that); - if(that instanceof CompressedMatrixBlock && that.getNumColumns() == m1.getNumColumns() - && that.getInMemorySize() < m1.getInMemorySize() ) { + if(that instanceof CompressedMatrixBlock && that.getNumColumns() == m1.getNumColumns() && + that.getInMemorySize() < m1.getInMemorySize()) { MatrixBlock m1uc = CompressedMatrixBlock.getUncompressed(m1, "Decompressing left side in BinaryOps"); return selectProcessingBasedOnAccessType(op, (CompressedMatrixBlock) that, m1uc, atype, true); } @@ -135,16 +135,15 @@ private static MatrixBlock binaryOperationsRightFiltered(BinaryOperator op, Comp } private static boolean isDoubleCompressedOpApplicable(CompressedMatrixBlock m1, MatrixBlock that) { - return that instanceof CompressedMatrixBlock - && !m1.isOverlapping() - && m1.getColGroups().get(0) instanceof ColGroupDDC - && !((CompressedMatrixBlock) that).isOverlapping() - && ((CompressedMatrixBlock) that).getColGroups().get(0) instanceof ColGroupDDC - && ((IMapToDataGroup) m1.getColGroups().get(0)).getMapToData() == - ((IMapToDataGroup) ((CompressedMatrixBlock) that).getColGroups().get(0)).getMapToData(); + return that instanceof CompressedMatrixBlock && !m1.isOverlapping() && + m1.getColGroups().get(0) instanceof ColGroupDDC && !((CompressedMatrixBlock) that).isOverlapping() && + ((CompressedMatrixBlock) that).getColGroups().get(0) instanceof ColGroupDDC && + ((IMapToDataGroup) m1.getColGroups().get(0)) + .getMapToData() == ((IMapToDataGroup) ((CompressedMatrixBlock) that).getColGroups().get(0)).getMapToData(); } - private static CompressedMatrixBlock doubleCompressedBinaryOp(BinaryOperator op, CompressedMatrixBlock m1, CompressedMatrixBlock m2) { + private static CompressedMatrixBlock doubleCompressedBinaryOp(BinaryOperator op, CompressedMatrixBlock m1, + CompressedMatrixBlock m2) { LOG.debug("Double Compressed BinaryOp"); AColGroup left = m1.getColGroups().get(0); AColGroup right = m2.getColGroups().get(0); @@ -201,6 +200,7 @@ private static MatrixBlock mvCol(BinaryOperator op, CompressedMatrixBlock m1, Ma // Column vector access MatrixBlock d_compressed = m1.getCachedDecompressed(); if(d_compressed != null) { + LOG.debug("Using cached decompressed for Matrix column vector compressed operation"); if(left) throw new NotImplementedException("Binary row op left is not supported for Uncompressed Matrix, " + "Implement support for VMr in MatrixBlock Binary Cell operations"); @@ -416,17 +416,24 @@ private static MatrixBlock mvColCompressed(CompressedMatrixBlock m1, MatrixBlock Pair tuple = evaluateSparsityMVCol(m1, m2, op, left); double estSparsity = tuple.getKey(); double estNnzPerRow = tuple.getValue(); - boolean shouldBeSparseOut = MatrixBlock.evalSparseFormatInMemory(nRows, nCols, (long) (estSparsity * nRows * nCols)); + boolean shouldBeSparseOut = MatrixBlock.evalSparseFormatInMemory(nRows, nCols, + (long) (estSparsity * nRows * nCols)); // currently also jump into that case if estNnzPerRow == 0 - if(estNnzPerRow <= 2 && nCols <= 31 && op.fn instanceof ValueComparisonFunction){ - return k <= 1 ? binaryMVComparisonColSingleThreadCompressed(m1, m2, op, left) : - binaryMVComparisonColMultiCompressed(m1, m2, op, left); + if(estNnzPerRow <= 2 && nCols <= 31 && op.fn instanceof ValueComparisonFunction) { + return k <= 1 ? binaryMVComparisonColSingleThreadCompressed(m1, m2, op, + left) : binaryMVComparisonColMultiCompressed(m1, m2, op, left); } MatrixBlock ret = new MatrixBlock(nRows, nCols, shouldBeSparseOut, -1).allocateBlock(); if(shouldBeSparseOut) { - if(k <= 1) + if(!m1.isOverlapping() && MatrixBlock.evalSparseFormatInMemory(nRows, nCols, m1.getNonZeros())) { + if(k <= 1) + nnz = binaryMVColSingleThreadSparseSparse(m1, m2, op, left, ret); + else + nnz = binaryMVColMultiThreadSparseSparse(m1, m2, op, left, ret); + } + else if(k <= 1) nnz = binaryMVColSingleThreadSparse(m1, m2, op, left, ret); else nnz = binaryMVColMultiThreadSparse(m1, m2, op, left, ret); @@ -438,7 +445,7 @@ private static MatrixBlock mvColCompressed(CompressedMatrixBlock m1, MatrixBlock nnz = binaryMVColMultiThreadDense(m1, m2, op, left, ret); } - if(op.fn instanceof ValueComparisonFunction) { + if(op.fn instanceof ValueComparisonFunction) { // potentially empty or filled. if(nnz == (long) nRows * nCols)// all was 1 return CompressedMatrixBlockFactory.createConstant(nRows, nCols, 1.0); else if(nnz == 0) // all was 0 -> return empty. @@ -452,19 +459,19 @@ else if(nnz == 0) // all was 0 -> return empty. } private static MatrixBlock binaryMVComparisonColSingleThreadCompressed(CompressedMatrixBlock m1, MatrixBlock m2, - BinaryOperator op, boolean left) { + BinaryOperator op, boolean left) { final int nRows = m1.getNumRows(); final int nCols = m1.getNumColumns(); // get indicators (one-hot-encoded comparison results) - BinaryMVColTaskCompressed task = new BinaryMVColTaskCompressed(m1, m2, 0, nRows, op, left); + BinaryMVColTaskCompressed task = new BinaryMVColTaskCompressed(m1, m2, 0, nRows, op, left); long nnz = task.call(); int[] indicators = task._ret; // map each unique indicator to an index - HashMapToInt hm = new HashMapToInt<>(nCols*3); + HashMapIntToInt hm = new HashMapIntToInt(nCols * 3); int[] colMap = new int[nRows]; - for(int i = 0; i < m1.getNumRows(); i++){ + for(int i = 0; i < m1.getNumRows(); i++) { int nextId = hm.size(); int id = hm.putIfAbsentI(indicators[i], nextId); colMap[i] = id == -1 ? nextId : id; @@ -477,37 +484,39 @@ private static MatrixBlock binaryMVComparisonColSingleThreadCompressed(Compresse return getCompressedMatrixBlock(m1, colMap, hm.size(), outMb, nRows, nCols, nnz); } - private static void fillSparseBlockFromIndicatorFromIndicatorInt(int numCol, Integer indicator, Integer rix, SparseBlockMCSR out) { + private static void fillSparseBlockFromIndicatorFromIndicatorInt(int numCol, Integer indicator, Integer rix, + SparseBlockMCSR out) { ArrayList colIndices = new ArrayList<>(8); - for (int c = numCol - 1; c >= 0; c--) { + for(int c = numCol - 1; c >= 0; c--) { if(indicator <= 0) break; - if(indicator % 2 == 1){ + if(indicator % 2 == 1) { colIndices.add(c); } indicator = indicator >> 1; } SparseRow row = null; - if(colIndices.size() > 1){ + if(colIndices.size() > 1) { double[] vals = new double[colIndices.size()]; Arrays.fill(vals, 1); int[] indices = new int[colIndices.size()]; - for (int i = 0, j = colIndices.size() - 1; i < colIndices.size(); i++, j--) + for(int i = 0, j = colIndices.size() - 1; i < colIndices.size(); i++, j--) indices[i] = colIndices.get(j); row = new SparseRowVector(vals, indices); - } else if(colIndices.size() == 1){ + } + else if(colIndices.size() == 1) { row = new SparseRowScalar(colIndices.get(0), 1.0); } out.set(rix, row, false); } private static MatrixBlock binaryMVComparisonColMultiCompressed(CompressedMatrixBlock m1, MatrixBlock m2, - BinaryOperator op, boolean left) throws Exception { + BinaryOperator op, boolean left) throws Exception { final int nRows = m1.getNumRows(); final int nCols = m1.getNumColumns(); final int k = op.getNumThreads(); - final int blkz = nRows / k; + final int blkz = Math.max((nRows + k) / k, 1000); // get indicators (one-hot-encoded comparison results) long nnz = 0; @@ -518,14 +527,11 @@ private static MatrixBlock binaryMVComparisonColMultiCompressed(CompressedMatrix tasks.add(new BinaryMVColTaskCompressed(m1, m2, i, Math.min(nRows, i + blkz), op, left)); } List> futures = pool.invokeAll(tasks); - HashMapToInt hm = new HashMapToInt<>(nCols*2); + HashMapIntToInt hm = new HashMapIntToInt(nCols * 2); int[] colMap = new int[nRows]; - for(Future f : futures) - nnz += f.get(); - // map each unique indicator to an index - mergeMVColTaskResults(tasks, blkz, hm, colMap); + nnz = mergeMVColTaskResults(futures, tasks, blkz, hm, colMap); // decode the unique indicator ints to SparseVectors MatrixBlock outMb = getMCSRMatrixBlock(hm, nCols); @@ -539,48 +545,53 @@ private static MatrixBlock binaryMVComparisonColMultiCompressed(CompressedMatrix } - private static void mergeMVColTaskResults(ArrayList tasks, int blkz, HashMapToInt hm, int[] colMap) { - + private static long mergeMVColTaskResults(List> futures, ArrayList tasks, + int blkz, HashMapIntToInt hm, int[] colMap) throws InterruptedException, ExecutionException { + long nnz = 0; for(int j = 0; j < tasks.size(); j++) { + nnz += futures.get(j).get(); // ensure task was finished. int[] indicators = tasks.get(j)._ret; - int offset = j* blkz; - - final int remainders = indicators.length % 8; - final int endVecLen = indicators.length - remainders; - for (int i = 0; i < endVecLen; i+= 8) { - colMap[offset + i] = hm.putIfAbsentReturnVal(indicators[i], hm.size()); - colMap[offset + i + 1] = hm.putIfAbsentReturnVal(indicators[i + 1], hm.size()); - colMap[offset + i + 2] = hm.putIfAbsentReturnVal(indicators[i + 2], hm.size()); - colMap[offset + i + 3] = hm.putIfAbsentReturnVal(indicators[i + 3], hm.size()); - colMap[offset + i + 4] = hm.putIfAbsentReturnVal(indicators[i + 4], hm.size()); - colMap[offset + i + 5] = hm.putIfAbsentReturnVal(indicators[i + 5], hm.size()); - colMap[offset + i + 6] = hm.putIfAbsentReturnVal(indicators[i + 6], hm.size()); - colMap[offset + i + 7] = hm.putIfAbsentReturnVal(indicators[i + 7], hm.size()); + int offset = j * blkz; - } - for (int i = 0; i < remainders; i++) { - colMap[offset + endVecLen + i] = hm.putIfAbsentReturnVal(indicators[endVecLen + i], hm.size()); - } + mergeMVColUnrolled(hm, colMap, indicators, offset); } + return nnz; } + private static void mergeMVColUnrolled(HashMapIntToInt hm, int[] colMap, int[] indicators, int offset) { + final int remainders = indicators.length % 8; + final int endVecLen = indicators.length - remainders; + for(int i = 0; i < endVecLen; i += 8) { + colMap[offset + i] = hm.putIfAbsentReturnVal(indicators[i], hm.size()); + colMap[offset + i + 1] = hm.putIfAbsentReturnVal(indicators[i + 1], hm.size()); + colMap[offset + i + 2] = hm.putIfAbsentReturnVal(indicators[i + 2], hm.size()); + colMap[offset + i + 3] = hm.putIfAbsentReturnVal(indicators[i + 3], hm.size()); + colMap[offset + i + 4] = hm.putIfAbsentReturnVal(indicators[i + 4], hm.size()); + colMap[offset + i + 5] = hm.putIfAbsentReturnVal(indicators[i + 5], hm.size()); + colMap[offset + i + 6] = hm.putIfAbsentReturnVal(indicators[i + 6], hm.size()); + colMap[offset + i + 7] = hm.putIfAbsentReturnVal(indicators[i + 7], hm.size()); - private static CompressedMatrixBlock getCompressedMatrixBlock(CompressedMatrixBlock m1, int[] colMap, - int mapSize, MatrixBlock outMb, int nRows, int nCols, long nnz) { + } + for(int i = 0; i < remainders; i++) { + colMap[offset + endVecLen + i] = hm.putIfAbsentReturnVal(indicators[endVecLen + i], hm.size()); + } + } + + private static CompressedMatrixBlock getCompressedMatrixBlock(CompressedMatrixBlock m1, int[] colMap, int mapSize, + MatrixBlock outMb, int nRows, int nCols, long nnz) { final IColIndex i = ColIndexFactory.create(0, m1.getNumColumns()); final AMapToData map = MapToFactory.create(m1.getNumRows(), colMap, mapSize); final AColGroup rgroup = ColGroupDDC.create(i, MatrixBlockDictionary.create(outMb), map, null); final ArrayList groups = new ArrayList<>(1); groups.add(rgroup); - return new CompressedMatrixBlock(nRows, nCols, nnz, false, groups); + return new CompressedMatrixBlock(nRows, nCols, nnz, false, groups); } - private static MatrixBlock getMCSRMatrixBlock(HashMapToInt hm, int nCols) { + private static MatrixBlock getMCSRMatrixBlock(HashMapIntToInt hm, int nCols) { // decode the unique indicator ints to SparseVectors SparseBlockMCSR out = new SparseBlockMCSR(hm.size()); - hm.forEach((indicator, rix) -> - fillSparseBlockFromIndicatorFromIndicatorInt(nCols, indicator, rix, out)); - return new MatrixBlock(hm.size(), nCols, -1, out); + hm.forEach((indicator, rix) -> fillSparseBlockFromIndicatorFromIndicatorInt(nCols, indicator, rix, out)); + return new MatrixBlock(hm.size(), nCols, -1, out); } private static long binaryMVColSingleThreadDense(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, @@ -599,6 +610,14 @@ private static long binaryMVColSingleThreadSparse(CompressedMatrixBlock m1, Matr return nnz; } + private static long binaryMVColSingleThreadSparseSparse(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, + boolean left, MatrixBlock ret) { + final int nRows = m1.getNumRows(); + long nnz = 0; + nnz += new BinaryMVColTaskSparseSparse(m1, m2, ret, 0, nRows, op, left).call(); + return nnz; + } + private static long binaryMVColMultiThreadDense(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, boolean left, MatrixBlock ret) throws Exception { final int nRows = m1.getNumRows(); @@ -641,6 +660,27 @@ private static long binaryMVColMultiThreadSparse(CompressedMatrixBlock m1, Matri return nnz; } + private static long binaryMVColMultiThreadSparseSparse(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, + boolean left, MatrixBlock ret) throws Exception { + final int nRows = m1.getNumRows(); + final int k = op.getNumThreads(); + final int blkz = Math.max(nRows / k, 64); + long nnz = 0; + final ExecutorService pool = CommonThreadPool.get(op.getNumThreads()); + try { + final ArrayList> tasks = new ArrayList<>(); + for(int i = 0; i < nRows; i += blkz) { + tasks.add(new BinaryMVColTaskSparseSparse(m1, m2, ret, i, Math.min(nRows, i + blkz), op, left)); + } + for(Future f : pool.invokeAll(tasks)) + nnz += f.get(); + } + finally { + pool.shutdown(); + } + return nnz; + } + private static MatrixBlock mmCompressed(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, boolean left) throws Exception { final int nCols = m1.getNumColumns(); @@ -724,8 +764,8 @@ private static class BinaryMVColTaskCompressed implements Callable { private MatrixBlock tmp; - protected BinaryMVColTaskCompressed(CompressedMatrixBlock m1, MatrixBlock m2, int rl, int ru, - BinaryOperator op, boolean left) { + protected BinaryMVColTaskCompressed(CompressedMatrixBlock m1, MatrixBlock m2, int rl, int ru, BinaryOperator op, + boolean left) { _m1 = m1; _m2 = m2; _op = op; @@ -738,21 +778,21 @@ protected BinaryMVColTaskCompressed(CompressedMatrixBlock m1, MatrixBlock m2, in @Override public Long call() { - tmp = allocateTempUncompressedBlock(_m1.getNumColumns()); - final int _blklen = tmp.getNumRows(); + final int _blklen = Math.max(DECOMPRESSION_BLEN / _m1.getNumColumns(), 64); + tmp = allocateTempUncompressedBlock(_blklen, _m1.getNumColumns()); final List groups = _m1.getColGroups(); final AIterator[] its = getIterators(groups, _rl); long nnz = 0; if(!_left) - for (int rl = _rl, retIxOff = 0; rl < _ru; rl += _blklen, retIxOff += _blklen){ + for(int rl = _rl, retIxOff = 0; rl < _ru; rl += _blklen, retIxOff += _blklen) { int ru = Math.min(rl + _blklen, _ru); decompressToTmpBlock(rl, ru, tmp.getDenseBlock(), groups, its); nnz += processDense(rl, ru, retIxOff); tmp.reset(); } else - for (int rl = _rl, retIxOff = 0; rl < _ru; rl += _blklen, retIxOff += _blklen){ + for(int rl = _rl, retIxOff = 0; rl < _ru; rl += _blklen, retIxOff += _blklen) { int ru = Math.min(rl + _blklen, _ru); decompressToTmpBlock(rl, ru, tmp.getDenseBlock(), groups, its); nnz += processDenseLeft(rl, ru, retIxOff); @@ -770,18 +810,24 @@ private final long processDense(final int rl, final int ru, final int retIxOffse for(int row = rl, retIx = retIxOffset; row < ru; row++, retIx++) { final double vr = _m2Dense[row]; final int tmpOff = (row - rl) * nCol; - int indicatorVector = 0; - for(int col = 0; col < nCol; col++) { - indicatorVector = indicatorVector << 1; - int indicator = _compFn.compare(_tmpDense[tmpOff + col], vr) ? 1 : 0; - indicatorVector += indicator; - nnz += indicator; - } - _ret[retIx] = indicatorVector; + nnz = processRow(nCol, _tmpDense, nnz, retIx, vr, tmpOff); } return nnz; } + private final long processRow(final int nCol, final double[] _tmpDense, long nnz, int retIx, final double vr, + final int tmpOff) { + int indicatorVector = 0; + for(int col = tmpOff; col < nCol + tmpOff; col++) { + indicatorVector = indicatorVector << 1; + int indicator = _compFn.compare(_tmpDense[col], vr) ? 1 : 0; + indicatorVector += indicator; + nnz += indicator; + } + _ret[retIx] = indicatorVector; + return nnz; + } + private final long processDenseLeft(final int rl, final int ru, final int retIxOffset) { final int nCol = _m1.getNumColumns(); final double[] _tmpDense = tmp.getDenseBlockValues(); @@ -847,7 +893,8 @@ private final void processBlock(final int rl, final int ru, final List groups, final AIterator[] its) { + private final void processBlockLeft(final int rl, final int ru, final List groups, + final AIterator[] its) { // unsafe decompress, since we count nonzeros afterwards. final DenseBlock db = _ret.getDenseBlock(); decompressToSubBlock(rl, ru, db, groups, its); @@ -887,7 +934,7 @@ private void processRow(final int ncol, final double[] ret, final int posR, fina private void processRowLeft(final int ncol, final double[] ret, final int posR, final double vr) { for(int col = 0; col < ncol; col++) - ret[posR + col] = _op.fn.execute(vr,ret[posR + col]); + ret[posR + col] = _op.fn.execute(vr, ret[posR + col]); } } @@ -917,8 +964,8 @@ protected BinaryMVColTaskSparse(CompressedMatrixBlock m1, MatrixBlock m2, Matrix @Override public Long call() { - tmp = allocateTempUncompressedBlock(_m1.getNumColumns()); - final int _blklen = tmp.getNumRows(); + final int _blklen = Math.max(DECOMPRESSION_BLEN / _m1.getNumColumns(), 64); + tmp = allocateTempUncompressedBlock(_blklen, _m1.getNumColumns()); final List groups = _m1.getColGroups(); final AIterator[] its = getIterators(groups, _rl); if(!_left) @@ -936,7 +983,8 @@ private final void processBlock(final int rl, final int ru, final List groups, final AIterator[] its) { + private final void processBlockLeft(final int rl, final int ru, final List groups, + final AIterator[] its) { decompressToTmpBlock(rl, ru, tmp.getDenseBlock(), groups, its); processDenseLeft(rl, ru); tmp.reset(); @@ -971,8 +1019,107 @@ private final void processDenseLeft(final int rl, final int ru) { } } - private static MatrixBlock allocateTempUncompressedBlock(int cols) { - MatrixBlock out = new MatrixBlock(Math.max(DECOMPRESSION_BLEN / cols, 64), cols, false); + private static class BinaryMVColTaskSparseSparse implements Callable { + private final int _rl; + private final int _ru; + private final CompressedMatrixBlock _m1; + private final MatrixBlock _m2; + private final MatrixBlock _ret; + private final BinaryOperator _op; + + private MatrixBlock tmp; + + private boolean _left; + + protected BinaryMVColTaskSparseSparse(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, + BinaryOperator op, boolean left) { + _m1 = m1; + _m2 = m2; + _ret = ret; + _op = op; + _rl = rl; + _ru = ru; + _left = left; + } + + @Override + public Long call() { + final int _blklen = Math.max(DECOMPRESSION_BLEN / _m1.getNumColumns(), 64); + tmp = allocateTempUncompressedBlockSparse(_blklen, _m1.getNumColumns()); + final List groups = _m1.getColGroups(); + final AIterator[] its = getIterators(groups, _rl); + if(!_left) + for(int r = _rl; r < _ru; r += _blklen) + processBlock(r, Math.min(r + _blklen, _ru), groups, its); + else + for(int r = _rl; r < _ru; r += _blklen) + processBlockLeft(r, Math.min(r + _blklen, _ru), groups, its); + return _ret.recomputeNonZeros(_rl, _ru - 1); + } + + private final void processBlock(final int rl, final int ru, final List groups, final AIterator[] its) { + decompressToTmpBlock(rl, ru, tmp.getSparseBlock(), groups, its); + processDense(rl, ru); + tmp.reset(); + } + + private final void processBlockLeft(final int rl, final int ru, final List groups, + final AIterator[] its) { + decompressToTmpBlock(rl, ru, tmp.getSparseBlock(), groups, its); + processDenseLeft(rl, ru); + tmp.reset(); + } + + private final void processDense(final int rl, final int ru) { + final SparseBlock sb = _ret.getSparseBlock(); + final SparseBlock _tmpSparse = tmp.getSparseBlock(); + final double[] _m2Dense = _m2.getDenseBlockValues(); + for(int row = rl; row < ru; row++) { + final double vr = _m2Dense[row]; + final int tmpOff = (row - rl); + if(!_tmpSparse.isEmpty(tmpOff)){ + int[] aoff = _tmpSparse.indexes(tmpOff); + double[] aval = _tmpSparse.values(tmpOff); + int apos = _tmpSparse.pos(tmpOff); + int alen = apos + _tmpSparse.size(tmpOff); + + for(int j = apos; j < alen; j++){ + sb.append(row, aoff[j], _op.fn.execute(aval[j], vr)); + } + } + + } + } + + private final void processDenseLeft(final int rl, final int ru) { + final int nCol = _m1.getNumColumns(); + final SparseBlock sb = _ret.getSparseBlock(); + final SparseBlock _tmpSparse = tmp.getSparseBlock(); + final double[] _m2Dense = _m2.getDenseBlockValues(); + for(int row = rl; row < ru; row++) { + final double vr = _m2Dense[row]; + final int tmpOff = (row - rl) * nCol; + if(!_tmpSparse.isEmpty(tmpOff)){ + int[] aoff = _tmpSparse.indexes(tmpOff); + double[] aval = _tmpSparse.values(tmpOff); + int apos = _tmpSparse.pos(tmpOff); + int alen = apos + _tmpSparse.size(tmpOff); + for(int j = apos; j < alen; j++){ + sb.append(row, aoff[j], _op.fn.execute(vr,aval[j])); + } + } + } + } + } + + private static MatrixBlock allocateTempUncompressedBlock(int blklen, int cols) { + MatrixBlock out = new MatrixBlock(blklen, cols, false); + out.allocateBlock(); + return out; + } + + private static MatrixBlock allocateTempUncompressedBlockSparse(int blklen, int cols) { + MatrixBlock out = new MatrixBlock(blklen, cols, true); out.allocateBlock(); return out; } @@ -1199,6 +1346,25 @@ protected static void decompressToTmpBlock(final int rl, final int ru, final Den } } + protected static void decompressToTmpBlock(final int rl, final int ru, final SparseBlock db, + final List groups, final AIterator[] its) { + Timing time = new Timing(true); + for(int i = 0; i < groups.size(); i++) { + final AColGroup g = groups.get(i); + if(g.getCompType() == CompressionType.SDC) + ((ASDCZero) g).decompressToSparseBlock(db, rl, ru, -rl, 0, its[i]); + else + g.decompressToSparseBlock(db, rl, ru, -rl, 0); + } + + if(DMLScript.STATISTICS) { + final double t = time.stop(); + DMLCompressionStatistics.addDecompressToBlockTime(t, 1); + if(LOG.isTraceEnabled()) + LOG.trace("decompressed block w/ k=" + 1 + " in " + t + "ms."); + } + } + protected static AIterator[] getIterators(final List groups, final int rl) { final AIterator[] its = new AIterator[groups.size()]; for(int i = 0; i < groups.size(); i++) { @@ -1210,8 +1376,8 @@ protected static AIterator[] getIterators(final List groups, final in return its; } - private static Pair evaluateSparsityMVCol(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, - boolean left) { + private static Pair evaluateSparsityMVCol(CompressedMatrixBlock m1, MatrixBlock m2, + BinaryOperator op, boolean left) { final List groups = m1.getColGroups(); final int nCol = m1.getNumColumns(); final int nRow = m1.getNumRows(); @@ -1247,7 +1413,7 @@ private static Pair evaluateSparsityMVCol(CompressedMatrixBlock for(int r = 0; r < sampleRow; r++) { final double m = m2v[r]; final int off = r * sampleCol; - for(int c = 0; c < sampleCol; c++){ + for(int c = 0; c < sampleCol; c++) { int outVal = op.fn.execute(dv[off + c], m) != 0 ? 1 : 0; nnz += outVal; nnzPerRow[r] += outVal; diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java index 99693635a9b..948a78f96af 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java @@ -486,7 +486,7 @@ private static List> generateUnaryAggregateOverlappingFuture final ArrayList tasks = new ArrayList<>(); final int nCol = m1.getNumColumns(); final int nRow = m1.getNumRows(); - final int blklen = Math.max(64, nRow / k); + final int blklen = Math.max(64, (nRow + k) / k); final List groups = m1.getColGroups(); final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups); if(shouldFilter) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java index d82d58e323e..cc7953f8c5d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java @@ -30,6 +30,7 @@ import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.matrix.data.LibMatrixBincell; import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -95,6 +96,11 @@ public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock v, Matrix if(x.isEmpty()) return returnEmpty(x, out); + if(ctype == ChainType.XtXv && x.getColGroups().size() < 5 && x.getNumColumns()> 30){ + MatrixBlock tmp = CLALibTSMM.leftMultByTransposeSelf(x, k); + return tmp.aggregateBinaryOperations(tmp, v, out, InstructionUtils.getMatMultOperator(k)); + } + // Morph the columns to efficient types for the operation. x = filterColGroups(x); double preFilterTime = t.stop(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRemoveEmpty.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRemoveEmpty.java new file mode 100644 index 00000000000..89f639e4ab8 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRemoveEmpty.java @@ -0,0 +1,107 @@ +package org.apache.sysds.runtime.compress.lib; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; + +public class CLALibRemoveEmpty { + protected static final Log LOG = LogFactory.getLog(CLALibRemoveEmpty.class.getName()); + + /** + * CP rmempty operation (single input, single output matrix) + * + * @param in The input matrix + * @param ret The output matrix + * @param rows If we are removing based on rows, or columns. + * @param emptyReturn Return row/column of zeros for empty input. + * @param select An optional selection vector, to remove based on rather than empty rows or columns + * @return The result MatrixBlock, can be a different object that the caller used. + */ + public static MatrixBlock rmempty(CompressedMatrixBlock in, MatrixBlock ret, boolean rows, boolean emptyReturn, + MatrixBlock select) { + if(ret == null) + ret = new MatrixBlock(); + MatrixBlock ret2 = LibMatrixReorg.rmemptyEarlyAbort(in, ret, rows, emptyReturn, select); + if(ret2 != null) + return ret2; + + if(rows) + return rmEmptyRows(in, ret, emptyReturn, select); + else + return rmEmptyCols(in, ret, emptyReturn, select); + } + + private static MatrixBlock rmEmptyCols(CompressedMatrixBlock in, MatrixBlock ret, boolean emptyReturn, + MatrixBlock select) { + if(select == null) + return fallback(in, false, emptyReturn, select, ret); + + int cOut = (int) select.getNonZeros(); + if(cOut == -1) + cOut = (int) select.recomputeNonZeros(); + if(cOut == 0){ + ret.reset(in.getNumRows(), !emptyReturn ? 0 : 1); + return ret; + } + + final boolean[] selectV = DataConverter + .convertToBooleanVector(CompressedMatrixBlock.getUncompressed(select, "decompressing selection in rmempty")); + + final List inG = in.getColGroups(); + final List retG = new ArrayList<>(inG.size()); + for(int i = 0; i < inG.size(); i++) { + AColGroup tmp = inG.get(i).removeEmptyCols(selectV); + if(tmp != null) + retG.add(tmp); + } + return new CompressedMatrixBlock(in.getNumRows(), cOut, -1, in.isOverlapping(), retG); + + } + + private static MatrixBlock rmEmptyRows(CompressedMatrixBlock in, MatrixBlock ret, boolean emptyReturn, + MatrixBlock select) { + if(select == null) + return fallback(in, true, emptyReturn, select, ret); + + select = CompressedMatrixBlock.getUncompressed(select, "decompressing selection in rmempty"); + + int rOut = (int) select.getNonZeros(); + if(rOut == -1) + rOut = (int) select.recomputeNonZeros(); + if(rOut == 0){ + ret.reset(!emptyReturn ? 0 : 1, in.getNumColumns()); + return ret; + } + + // TODO: add optimization to avoid linear scan and make selectV indexes, if selection is small relative to number + // of rows + // TODO: add decompress to boolean vector. + final boolean[] selectV = DataConverter.convertToBooleanVector(select); + + + + final List inG = in.getColGroups(); + final List retG = new ArrayList<>(inG.size()); + for(int i = 0; i < inG.size(); i++) { + retG.add(inG.get(i).removeEmptyRows(selectV, rOut)); + } + + return new CompressedMatrixBlock(rOut, in.getNumColumns(), -1, in.isOverlapping(), retG); + } + + private static MatrixBlock fallback(CompressedMatrixBlock in, boolean rows, boolean emptyReturn, MatrixBlock select, + MatrixBlock ret) { + LOG.warn("Decompressing because: removeEmptyOperations with select: " + (select != null) + " rows: " + rows); + MatrixBlock tmp = CompressedMatrixBlock.getUncompressed(in); + MatrixBlock select2 = CompressedMatrixBlock.getUncompressed(select); + return LibMatrixReorg.rmemptyUnsafe(tmp, ret, rows, emptyReturn, select2); + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java index f14d6833d95..ce06262b9a5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java @@ -31,6 +31,8 @@ import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.ASDC; +import org.apache.sysds.runtime.compress.colgroup.ASDCZero; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; @@ -71,10 +73,10 @@ public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBloc if(m2 instanceof CompressedMatrixBlock) m2 = ((CompressedMatrixBlock) m2).getUncompressed("Uncompressed right side of right MM", k); - if(betterIfDecompressed(m1)) { - // perform uncompressed multiplication. - return decompressingMatrixMult(m1, m2, k); - } + // if(betterIfDecompressed(m1)) { + // // perform uncompressed multiplication. + // return decompressingMatrixMult(m1, m2, k); + // } if(!allowOverlap) { LOG.trace("Overlapping output not allowed in call to Right MM"); @@ -143,7 +145,9 @@ private static MatrixBlock decompressingMatrixMult(CompressedMatrixBlock m1, Mat private static boolean betterIfDecompressed(CompressedMatrixBlock m) { for(AColGroup g : m.getColGroups()) { - if(!(g instanceof ColGroupUncompressed) && g.getNumValues() * 2 >= m.getNumRows()) { + // TODO add subpport for decompressing RMM to ASDC and ASDCZero + if(!(g instanceof ColGroupUncompressed || g instanceof ASDC || g instanceof ASDCZero) && + g.getNumValues() * 2 >= m.getNumRows()) { return true; } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSort.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSort.java new file mode 100644 index 00000000000..c793e84ebef --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSort.java @@ -0,0 +1,37 @@ +package org.apache.sysds.runtime.compress.lib; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixValue; + +public class CLALibSort { + + public static MatrixBlock sort(CompressedMatrixBlock mb, MatrixValue weights, MatrixBlock result, int k) { + // force uncompressed weights + weights = CompressedMatrixBlock.getUncompressed(weights); + + if(mb.getNumColumns() == 1 && mb.getColGroups().size() == 1 && weights == null) { + return sortSingleCol(mb, k); + } + + // fallback to uncompressed. + return CompressedMatrixBlock// + .getUncompressed(mb, "sortOperations")// + .sortOperations(weights, result); + } + + private static MatrixBlock sortSingleCol(CompressedMatrixBlock mb, int k) { + + AColGroup g = mb.getColGroups().get(0); + + AColGroup r = g.sort(); + + List rg = new ArrayList<>(); + rg.add(r); + return new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns(), mb.getNonZeros(), false, rg); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java index a1d47a9b150..d0396b63810 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java @@ -31,6 +31,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.CommonThreadPool; @@ -42,6 +43,10 @@ private CLALibTSMM() { // private constructor } + public static MatrixBlock leftMultByTransposeSelf(CompressedMatrixBlock cmb, int k) { + return leftMultByTransposeSelf(cmb, new MatrixBlock(), k); + } + /** * Self left Matrix multiplication (tsmm) * @@ -51,24 +56,32 @@ private CLALibTSMM() { * @param ret The output matrix to put the result into * @param k The parallelization degree allowed */ - public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) { + public static MatrixBlock leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) { + final int numColumns = cmb.getNumColumns(); + final int numRows = cmb.getNumRows(); + if(cmb.isEmpty()) + return new MatrixBlock(numColumns, numColumns, true); + // create output matrix block + if(ret == null) + ret = new MatrixBlock(numColumns, numColumns, false); + else + ret.reset(numColumns, numColumns, false); + ret.allocateDenseBlock(); final List groups = cmb.getColGroups(); - final int numColumns = cmb.getNumColumns(); - if(groups.size() >= numColumns) { + if(groups.size() >= numColumns || containsUncompressedColGroup(groups)) { MatrixBlock m = cmb.getUncompressed("TSMM to many columngroups", k); LibMatrixMult.matrixMultTransposeSelf(m, ret, true, k); - return; + return ret; } - final int numRows = cmb.getNumRows(); final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups); final boolean overlapping = cmb.isOverlapping(); if(shouldFilter) { final double[] constV = new double[numColumns]; final List filteredGroups = CLALibUtils.filterGroups(groups, constV); tsmmColGroups(filteredGroups, ret, numRows, overlapping, k); - addCorrectionLayer(filteredGroups, ret, numRows, numColumns, constV); + addCorrectionLayer(filteredGroups, ret, numRows, numColumns, constV, k); } else { @@ -77,17 +90,23 @@ public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBloc ret.setNonZeros(LibMatrixMult.copyUpperToLowerTriangle(ret)); ret.examSparsity(); + return ret; + } + + private static boolean containsUncompressedColGroup(List groups) { + for(AColGroup g : groups) + if(g instanceof ColGroupUncompressed) + return true; + return false; } private static void addCorrectionLayer(List filteredGroups, MatrixBlock result, int nRows, int nCols, - double[] constV) { + double[] constV, int k) { final double[] retV = result.getDenseBlockValues(); final double[] filteredColSum = CLALibUtils.getColSum(filteredGroups, nCols, nRows); addCorrectionLayer(constV, filteredColSum, nRows, retV); } - - private static void tsmmColGroups(List groups, MatrixBlock ret, int nRows, boolean overlapping, int k) { if(k <= 1) tsmmColGroupsSingleThread(groups, ret, nRows); @@ -136,12 +155,12 @@ private static void tsmmColGroupsMultiThread(List groups, MatrixBlock public static void addCorrectionLayer(double[] constV, double[] filteredColSum, int nRow, double[] ret) { final int nColRow = constV.length; - for(int row = 0; row < nColRow; row++){ + for(int row = 0; row < nColRow; row++) { int offOut = nColRow * row; final double v1l = constV[row]; final double v2l = filteredColSum[row] + constV[row] * nRow; - for(int col = row; col < nColRow; col++){ - ret[offOut + col] += v1l * filteredColSum[col] + v2l * constV[col]; + for(int col = row; col < nColRow; col++) { + ret[offOut + col] += v1l * filteredColSum[col] + v2l * constV[col]; } } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/HashMapIntToInt.java b/src/main/java/org/apache/sysds/runtime/compress/utils/HashMapIntToInt.java new file mode 100644 index 00000000000..29650048509 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/utils/HashMapIntToInt.java @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.utils; + +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Iterator; +import java.util.Map; +import java.util.Set; +import java.util.function.BiConsumer; + +public class HashMapIntToInt implements Map { + + static final int DEFAULT_INITIAL_CAPACITY = 1 << 4; + static final float DEFAULT_LOAD_FACTOR = 0.75f; + + protected Node[] buckets; + + protected int size; + + public HashMapIntToInt(int capacity) { + alloc(Math.max(capacity, DEFAULT_INITIAL_CAPACITY)); + } + + protected void alloc(int size) { + Node[] tmp = (Node[]) new Node[size]; + buckets = tmp; + } + + @Override + public int size() { + return size; + } + + @Override + public boolean isEmpty() { + return size == 0; + } + + @Override + public boolean containsKey(Object key) { + return getI((Integer) key) != -1; + } + + @Override + public boolean containsValue(Object value) { + if(value instanceof Integer) { + for(Entry v : this.entrySet()) { + if(v.getValue().equals(value)) + return true; + } + } + return false; + + } + + @Override + public Integer get(Object key) { + final int i = getI((Integer) key); + if(i != -1) + return i; + else + return null; + } + + public int getI(int key) { + + final int ix = hash(key); + Node b = buckets[ix]; + if(b != null) { + do { + if(key == b.key) + return b.value; + } + while((b = b.next) != null); + } + return -1; + + } + + public int hash(int key) { + return Math.abs(Integer.hashCode(key) % buckets.length); + } + + @Override + public Integer put(Integer key, Integer value) { + int i = putI(key, value); + if(i != -1) + return i; + else + return null; + } + + @Override + public Integer putIfAbsent(Integer key, Integer value) { + int i = putIfAbsentI(key, value); + if(i != -1) + return i; + else + return null; + } + + public int putIfAbsentI(int key, int value) { + + final int ix = hash(key); + Node b = buckets[ix]; + if(b == null) + return createBucket(ix, key, value); + else + return putIfAbsentBucket(ix, key, value); + + } + + public int putIfAbsentReturnVal(int key, int value) { + final int ix = hash(key); + Node b = buckets[ix]; + if(b == null) + return createBucketReturnVal(ix, key, value); + else + return putIfAbsentBucketReturnval(ix, key, value); + } + + public int putIfAbsentReturnValHash(int key, int value) { + + final int ix = hash(key); + Node b = buckets[ix]; + if(b == null) + return createBucketReturnVal(ix, key, value); + else + return putIfAbsentBucketReturnval(ix, key, value); + + } + + private int putIfAbsentBucket(int ix, int key, int value) { + Node b = buckets[ix]; + while(true) { + if(b.key == key) + return b.value; + if(b.next == null) { + b.setNext(new Node(key, value, null)); + size++; + resize(); + return -1; + } + b = b.next; + } + } + + private int putIfAbsentBucketReturnval(int ix, int key, int value) { + Node b = buckets[ix]; + while(true) { + if(b.key == key) + return b.value; + if(b.next == null) { + b.setNext(new Node(key, value, null)); + size++; + resize(); + return value; + } + b = b.next; + } + } + + public int putI(int key, int value) { + + final int ix = hash(key); + Node b = buckets[ix]; + if(b == null) + return createBucket(ix, key, value); + else + return addToBucket(ix, key, value); + + } + + private int createBucket(int ix, int key, int value) { + buckets[ix] = new Node(key, value, null); + size++; + return -1; + } + + private int createBucketReturnVal(int ix, int key, int value) { + buckets[ix] = new Node(key, value, null); + size++; + return value; + } + + private int addToBucket(int ix, int key, int value) { + Node b = buckets[ix]; + while(true) { + if(key == b.key) { + int tmp = b.getValue(); + b.setValue(value); + return tmp; + } + if(b.next == null) { + b.setNext(new Node(key, value, null)); + size++; + resize(); + return -1; + } + b = b.next; + } + } + + private void resize() { + if(size > buckets.length * DEFAULT_LOAD_FACTOR) { + + Node[] tmp = (Node[]) new Node[buckets.length * 2]; + Node[] oldBuckets = buckets; + buckets = tmp; + size = 0; + for(Node n : oldBuckets) { + if(n != null) + do { + put(n.key, n.value); + } + while((n = n.next) != null); + } + + } + } + + @Override + public Integer remove(Object key) { + throw new UnsupportedOperationException("Unimplemented method 'remove'"); + } + + @Override + public void putAll(Map m) { + throw new UnsupportedOperationException("Unimplemented method 'putAll'"); + } + + @Override + public void clear() { + throw new UnsupportedOperationException("Unimplemented method 'clear'"); + } + + @Override + public Set keySet() { + throw new UnsupportedOperationException("Unimplemented method 'keySet'"); + } + + @Override + public Collection values() { + throw new UnsupportedOperationException("Unimplemented method 'values'"); + } + + @Override + public Set> entrySet() { + return new EntrySet(); + } + + @Override + public void forEach(BiConsumer action) { + + for(Node n : buckets) { + if(n != null) { + do { + action.accept(n.key, n.value); + } + while((n = n.next) != null); + } + } + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(size() * 3); + this.forEach((k, v) -> { + sb.append("(" + k + "→" + v + ")"); + }); + return sb.toString(); + } + + private static class Node implements Entry { + final int key; + int value; + Node next; + + Node(int key, int value, Node next) { + this.key = key; + this.value = value; + this.next = next; + } + + public final void setNext(Node n) { + next = n; + } + + @Override + public Integer getKey() { + return key; + } + + @Override + public Integer getValue() { + return value; + } + + @Override + public Integer setValue(Integer value) { + return this.value = value; + } + } + + private final class EntrySet extends AbstractSet> { + + @Override + public int size() { + return size; + } + + @Override + public Iterator> iterator() { + return new EntryIterator(); + } + + } + + private final class EntryIterator implements Iterator> { + Node next; + int bucketId = 0; + + protected EntryIterator() { + + for(; bucketId < buckets.length; bucketId++) { + if(buckets[bucketId] != null) { + next = buckets[bucketId]; + break; + } + } + + } + + @Override + public boolean hasNext() { + return next != null; + } + + @Override + public Entry next() { + + Node e = next; + + if(e.next != null) + next = e.next; + else { + for(; ++bucketId < buckets.length; bucketId++) { + if(buckets[bucketId] != null) { + next = buckets[bucketId]; + break; + } + } + if(bucketId >= buckets.length) + next = null; + } + + return e; + } + + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java index fc0aa3b1a29..4940dd801b3 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java @@ -27,9 +27,18 @@ import org.apache.sysds.runtime.compress.cost.InstructionTypeCounter; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.functionobjects.IndexFunction; +import org.apache.sysds.runtime.functionobjects.KahanPlus; +import org.apache.sysds.runtime.functionobjects.Mean; +import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.functionobjects.ReduceCol; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction; +import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction; import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction; +import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction; +import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; +import org.apache.sysds.runtime.matrix.operators.Operator; public class FederatedWorkloadAnalyzer { protected static final Log LOG = LogFactory.getLog(FederatedWorkloadAnalyzer.class.getName()); @@ -55,7 +64,7 @@ public void incrementWorkload(ExecutionContext ec, long tid, Instruction ins) { } public void compressRun(ExecutionContext ec, long tid) { - if(counter >= compressRunFrequency ){ + if(counter >= compressRunFrequency) { counter = 0; get(tid).forEach((K, V) -> CompressedMatrixBlockFactory.compressAsync(ec, Long.toString(K), V)); } @@ -68,6 +77,7 @@ private void incrementWorkload(ExecutionContext ec, long tid, ComputationCPInstr public void incrementWorkload(ExecutionContext ec, ConcurrentHashMap mm, ComputationCPInstruction cpIns) { // TODO: Count transitive closure via lineage + // TODO: add more operations if(cpIns instanceof AggregateBinaryCPInstruction) { final String n1 = cpIns.input1.getName(); MatrixObject d1 = (MatrixObject) ec.getCacheableData(n1); @@ -81,15 +91,48 @@ public void incrementWorkload(ExecutionContext ec, ConcurrentHashMap mm, long id) { @@ -117,8 +160,8 @@ private static boolean validSize(int nRow, int nCol) { return nRow > 90 && nRow >= nCol; } - @Override - public String toString(){ + @Override + public String toString() { StringBuilder sb = new StringBuilder(); sb.append(this.getClass().getSimpleName()); sb.append(" Counter: "); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java index 99cce9f9e97..972a2893fd8 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java @@ -377,7 +377,7 @@ public static double parseDouble(String value) { return Double.POSITIVE_INFINITY; else if(len == 4 && value.compareToIgnoreCase("-Inf") == 0) return Double.NEGATIVE_INFINITY; - throw new DMLRuntimeException(e); + throw e; } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java index b26695e5797..84e4e89a420 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java @@ -174,6 +174,29 @@ public int putIfAbsentReturnVal(K key, int value) { } + + public int putIfAbsentReturnValHash(K key, int value) { + + if(key == null) { + if(nullV == -1) { + size++; + nullV = value; + return -1; + } + else + return nullV; + } + else { + final int ix = hash(key); + Node b = buckets[ix]; + if(b == null) + return createBucketReturnVal(ix, key, value); + else + return putIfAbsentBucketReturnval(ix, key, value); + } + + } + private int putIfAbsentBucket(int ix, K key, int value) { Node b = buckets[ix]; while(true) { diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index 1fc582924e4..292fcb52bf5 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -607,7 +607,6 @@ public double getAsNaNDouble(int i) { private static double getAsDouble(String s) { try { - return DoubleArray.parseDouble(s); } catch(Exception e) { @@ -617,7 +616,8 @@ private static double getAsDouble(String s) { else if(ls.equals("false") || ls.equals("f")) return 0; else - throw new DMLRuntimeException("Unable to change to double: " + s, e); + throw e; // for efficiency + // throw new DMLRuntimeException("Unable to change to double: " + s, e); } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java index 032afe2cd7c..987d14106ac 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java @@ -32,11 +32,17 @@ import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.utils.stats.InfrastructureAnalyzer; -public interface MatrixBlockFromFrame { +public class MatrixBlockFromFrame { public static final Log LOG = LogFactory.getLog(MatrixBlockFromFrame.class.getName()); public static final int blocksizeIJ = 32; + public static Boolean WARNED_FOR_FAILED_CAST = false; + + private MatrixBlockFromFrame(){ + // private constructor for code coverage. + } + /** * Converts a frame block with arbitrary schema into a matrix block. Since matrix block only supports value type * double, we do a best effort conversion of non-double types which might result in errors for non-numerical data. @@ -94,10 +100,25 @@ else if(ret.getNumRows() != m || ret.getNumColumns() != n || ret.isInSparseForma } private static long convert(FrameBlock frame, MatrixBlock mb, int n, int rl, int ru) { - if(mb.getDenseBlock().isContiguous()) - return convertContiguous(frame, mb, n, rl, ru); - else - return convertGeneric(frame, mb, n, rl, ru); + try { + + if(mb.getDenseBlock().isContiguous()) + return convertContiguous(frame, mb, n, rl, ru); + else + return convertGeneric(frame, mb, n, rl, ru); + } + catch(NumberFormatException | DMLRuntimeException e) { + synchronized(WARNED_FOR_FAILED_CAST){ + if(!WARNED_FOR_FAILED_CAST) { + LOG.error( + "Failed to convert to Matrix because of number format errors, falling back to NaN on incompatible cells", + e); + WARNED_FOR_FAILED_CAST = true; + } + } + return convertSafeCast(frame, mb, n, rl, ru); + + } } private static long convertParallel(FrameBlock frame, MatrixBlock mb, int m, int n, int k) throws Exception { @@ -169,4 +190,37 @@ private static long convertBlockGeneric(final FrameBlock frame, long lnnz, final } return lnnz; } + + private static long convertSafeCast(final FrameBlock frame, final MatrixBlock mb, final int n, final int rl, + final int ru) { + final DenseBlock c = mb.getDenseBlock(); + long lnnz = 0; + for(int bi = rl; bi < ru; bi += blocksizeIJ) { + for(int bj = 0; bj < n; bj += blocksizeIJ) { + int bimin = Math.min(bi + blocksizeIJ, ru); + int bjmin = Math.min(bj + blocksizeIJ, n); + lnnz = convertBlockSafeCast(frame, lnnz, c, bi, bj, bimin, bjmin); + } + } + return lnnz; + } + + private static long convertBlockSafeCast(final FrameBlock frame, long lnnz, final DenseBlock c, final int rl, + final int cl, final int ru, final int cu) { + for(int i = rl; i < ru; i++) { + final double[] cvals = c.values(i); + final int cpos = c.pos(i); + for(int j = cl; j < cu; j++) { + try { + lnnz += (cvals[cpos + j] = frame.getDoubleNaN(i, j)) != 0 ? 1 : 0; + } + catch(NumberFormatException | DMLRuntimeException e) { + lnnz += 1; + cvals[cpos + j] = Double.NaN; + } + } + } + return lnnz; + } + } diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java index 39735be62e0..eed2c58f78c 100644 --- a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java +++ b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java @@ -54,7 +54,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX, STOP, CEIL, FLOOR, CUMSUM, ROWCUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST, TYPEOF, APPLY_SCHEMA, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE, - DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE, + DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE, GET_CATEGORICAL_MASK, MAP, COUNT_DISTINCT, COUNT_DISTINCT_APPROX, UNIQUE} private static final VectorSpecies SPECIES = DoubleVector.SPECIES_PREFERRED; @@ -120,6 +120,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, String2BuiltinCode.put( "_map", BuiltinCode.MAP); String2BuiltinCode.put( "valueSwap", BuiltinCode.VALUE_SWAP); String2BuiltinCode.put( "applySchema", BuiltinCode.APPLY_SCHEMA); + String2BuiltinCode.put( "get_categorical_mask", BuiltinCode.GET_CATEGORICAL_MASK); } protected Builtin(BuiltinCode bf) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java index 28b8775ebd5..86184f47be6 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java @@ -59,6 +59,8 @@ else if (in1.getDataType() == DataType.TENSOR && in2.getDataType() == DataType.T return new BinaryTensorTensorCPInstruction(operator, in1, in2, out, opcode, str); else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.FRAME) return new BinaryFrameFrameCPInstruction(operator, in1, in2, out, opcode, str); + else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.SCALAR) + return new BinaryFrameScalarCPInstruction(operator, in1, in2, out, opcode, str); else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.MATRIX) return new BinaryFrameMatrixCPInstruction(operator, in1, in2, out, opcode, str); else diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java new file mode 100644 index 00000000000..99b3c1a3b13 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.cp; + +import java.util.Arrays; + +import org.apache.sysds.common.Builtins; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator; +import org.apache.sysds.runtime.transform.TfUtils.TfMethod; +import org.apache.sysds.runtime.util.UtilFunctions; +import org.apache.wink.json4j.JSONArray; +import org.apache.wink.json4j.JSONObject; + +public class BinaryFrameScalarCPInstruction extends BinaryCPInstruction { + // private static final Log LOG = LogFactory.getLog(BinaryFrameFrameCPInstruction.class.getName()); + + protected BinaryFrameScalarCPInstruction(MultiThreadedOperator op, CPOperand in1, CPOperand in2, CPOperand out, + String opcode, String istr) { + super(CPType.Binary, op, in1, in2, out, opcode, istr); + } + + @Override + public void processInstruction(ExecutionContext ec) { + // get input frames + FrameBlock inBlock1 = ec.getFrameInput(input1.getName()); + ScalarObject spec = ec.getScalarInput(input2.getName(), ValueType.STRING, true); + if(getOpcode().equals(Builtins.GET_CATEGORICAL_MASK.toString().toLowerCase())) { + processGetCategorical(ec, inBlock1, spec); + } + else { + throw new DMLRuntimeException("Unsupported operation"); + } + + // Release the memory occupied by input frames + ec.releaseFrameInput(input1.getName()); + } + + public void processGetCategorical(ExecutionContext ec, FrameBlock f, ScalarObject spec) { + try { + + // MatrixBlock ret = new MatrixBlock(); + int nCol = f.getNumColumns(); + + JSONObject jSpec = new JSONObject(spec.getStringValue()); + + if(!jSpec.containsKey("ids") && jSpec.getBoolean("ids")) { + throw new DMLRuntimeException("not supported non ID based spec for get_categorical_mask"); + } + + String recode = TfMethod.RECODE.toString(); + String dummycode = TfMethod.DUMMYCODE.toString(); + + int[] lengths = new int[nCol]; + // assume all columns encode to at least one column. + Arrays.fill(lengths, 1); + boolean[] categorical = new boolean[nCol]; + + if(jSpec.containsKey(recode)) { + JSONArray a = jSpec.getJSONArray(recode); + for(Object aa : a) { + int av = (Integer) aa - 1; + categorical[av] = true; + } + } + + if(jSpec.containsKey(dummycode)) { + JSONArray a = jSpec.getJSONArray(dummycode); + for(Object aa : a) { + int av = (Integer) aa - 1; + ColumnMetadata d = f.getColumnMetadata()[av]; + String v = f.getString(0, av); + int ndist; + if(v.length() > 1 && v.charAt(0) == '¿') { + ndist = UtilFunctions.parseToInt(v.substring(1)); + } + else { + ndist = d.isDefault() ? 0 : (int) d.getNumDistinct(); + } + lengths[av] = ndist; + categorical[av] = true; + } + } + + // get total size after mapping + + int sumLengths = 0; + for(int i : lengths) { + sumLengths += i; + } + + MatrixBlock ret = new MatrixBlock(1, sumLengths, false); + ret.allocateDenseBlock(); + int off = 0; + for(int i = 0; i < lengths.length; i++) { + for(int j = 0; j < lengths[i]; j++) { + ret.set(0, off++, categorical[i] ? 1 : 0); + } + } + + ec.setMatrixOutput(output.getName(), ret); + + } + catch(Exception e) { + throw new DMLRuntimeException(e); + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java index 2ec23037385..d76dbe0d45e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java @@ -80,8 +80,15 @@ public void processInstruction(ExecutionContext ec) { retBlock = inBlock1; } else { - if(LibCommonsMath.isSupportedMatrixMatrixOperation(getOpcode()) && !compressedLeft && !compressedRight) + if(LibCommonsMath.isSupportedMatrixMatrixOperation(getOpcode()) ){ + if(compressedLeft) + inBlock1 = CompressedMatrixBlock.getUncompressed(inBlock1, getOpcode()); + + if(compressedRight) + inBlock2 = CompressedMatrixBlock.getUncompressed(inBlock2, getOpcode()); + retBlock = LibCommonsMath.matrixMatrixOperations(inBlock1, inBlock2, getOpcode()); + } else { // Perform computation using input matrices, and produce the result matrix BinaryOperator bop = (BinaryOperator) _optr; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java index 119589a3033..e53958ac4b8 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java @@ -352,7 +352,7 @@ else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMDECODE.toString())) { // compute transformdecode Decoder decoder = DecoderFactory .createDecoder(getParameterMap().get("spec"), colnames, null, meta, data.getNumColumns()); - FrameBlock fbout = decoder.decode(data, new FrameBlock(decoder.getSchema())); + FrameBlock fbout = decoder.decode(data, new FrameBlock(decoder.getSchema()), InfrastructureAnalyzer.getLocalParallelism()); fbout.setColumnNames(Arrays.copyOfRange(colnames, 0, fbout.getNumColumns())); // release locks diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java index 5dd8e55e821..93390cc686d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java @@ -44,6 +44,7 @@ import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.data.TensorBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.lib.MatrixBlockFromFrame; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.io.FileFormatProperties; @@ -918,7 +919,7 @@ private void processCastAsMatrixVariableInstruction(ExecutionContext ec) { switch( getInput1().getDataType() ) { case FRAME: { FrameBlock fin = ec.getFrameInput(getInput1().getName()); - MatrixBlock out = DataConverter.convertToMatrixBlock(fin); + MatrixBlock out = MatrixBlockFromFrame.convertToMatrixBlock(fin, k); ec.releaseFrameInput(getInput1().getName()); ec.setMatrixOutput(output.getName(), out); break; diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibAggregateUnarySpecialization.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibAggregateUnarySpecialization.java new file mode 100644 index 00000000000..79f08cb353a --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibAggregateUnarySpecialization.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.matrix.data; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.CorrectionLocationType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.instructions.cp.KahanObject; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; + +public class LibAggregateUnarySpecialization { + protected static final Log LOG = LogFactory.getLog(LibAggregateUnarySpecialization.class.getName()); + + public static void aggregateUnary(final MatrixBlock mb, AggregateUnaryOperator op, MatrixBlock result, int blen, + MatrixIndexes indexesIn) { + if(op.sparseSafe) + sparseAggregateUnaryHelp(mb, op, result, blen, indexesIn); + else + denseAggregateUnaryHelp(mb, op, result, blen, indexesIn); + } + + private static void sparseAggregateUnaryHelp(final MatrixBlock mb, AggregateUnaryOperator op, MatrixBlock result, + int blen, MatrixIndexes indexesIn) { + // initialize result + if(op.aggOp.initialValue != 0) + result.reset(result.rlen, result.clen, op.aggOp.initialValue); + CellIndex tempCellIndex = new CellIndex(-1, -1); + KahanObject buffer = new KahanObject(0, 0); + + if(mb.sparse && mb.sparseBlock != null) { + SparseBlock a = mb.sparseBlock; + for(int r = 0; r < Math.min(mb.rlen, a.numRows()); r++) { + if(a.isEmpty(r)) + continue; + int apos = a.pos(r); + int alen = a.size(r); + int[] aix = a.indexes(r); + double[] aval = a.values(r); + for(int i = apos; i < apos + alen; i++) { + tempCellIndex.set(r, aix[i]); + op.indexFn.execute(tempCellIndex, tempCellIndex); + incrementalAggregateUnaryHelp(op.aggOp, result, tempCellIndex.row, tempCellIndex.column, aval[i], + buffer); + } + } + } + else if(!mb.sparse && mb.denseBlock != null) { + DenseBlock a = mb.getDenseBlock(); + for(int i = 0; i < mb.rlen; i++) + for(int j = 0; j < mb.clen; j++) { + tempCellIndex.set(i, j); + op.indexFn.execute(tempCellIndex, tempCellIndex); + incrementalAggregateUnaryHelp(op.aggOp, result, tempCellIndex.row, tempCellIndex.column, a.get(i, j), + buffer); + } + } + } + + private static void denseAggregateUnaryHelp(MatrixBlock mb, AggregateUnaryOperator op, MatrixBlock result, int blen, + MatrixIndexes indexesIn) { + if(op.aggOp.initialValue != 0) + result.reset(result.rlen, result.clen, op.aggOp.initialValue); + CellIndex tempCellIndex = new CellIndex(-1, -1); + KahanObject buffer = new KahanObject(0, 0); + for(int i = 0; i < mb.rlen; i++) + for(int j = 0; j < mb.clen; j++) { + tempCellIndex.set(i, j); + op.indexFn.execute(tempCellIndex, tempCellIndex); + incrementalAggregateUnaryHelp(op.aggOp, result, tempCellIndex.row, tempCellIndex.column, + mb.get(i, j), buffer); + } + } + + private static void incrementalAggregateUnaryHelp(AggregateOperator aggOp, MatrixBlock result, int row, int column, + double newvalue, KahanObject buffer) { + if(aggOp.existsCorrection()) { + if(aggOp.correction == CorrectionLocationType.LASTROW || + aggOp.correction == CorrectionLocationType.LASTCOLUMN) { + int corRow = row, corCol = column; + if(aggOp.correction == CorrectionLocationType.LASTROW)// extra row + corRow++; + else if(aggOp.correction == CorrectionLocationType.LASTCOLUMN) + corCol++; + else + throw new DMLRuntimeException("unrecognized correctionLocation: " + aggOp.correction); + + buffer._sum = result.get(row, column); + buffer._correction = result.get(corRow, corCol); + buffer = (KahanObject) aggOp.increOp.fn.execute(buffer, newvalue); + result.set(row, column, buffer._sum); + result.set(corRow, corCol, buffer._correction); + } + else if(aggOp.correction == CorrectionLocationType.NONE) { + throw new DMLRuntimeException("unrecognized correctionLocation: " + aggOp.correction); + } + else// for mean + { + int corRow = row, corCol = column; + int countRow = row, countCol = column; + if(aggOp.correction == CorrectionLocationType.LASTTWOROWS) { + countRow++; + corRow += 2; + } + else if(aggOp.correction == CorrectionLocationType.LASTTWOCOLUMNS) { + countCol++; + corCol += 2; + } + else + throw new DMLRuntimeException("unrecognized correctionLocation: " + aggOp.correction); + buffer._sum = result.get(row, column); + buffer._correction = result.get(corRow, corCol); + double count = result.get(countRow, countCol) + 1.0; + buffer = (KahanObject) aggOp.increOp.fn.execute(buffer, newvalue, count); + result.set(row, column, buffer._sum); + result.set(corRow, corCol, buffer._correction); + result.set(countRow, countCol, count); + } + + } + else { + newvalue = aggOp.increOp.fn.execute(result.get(row, column), newvalue); + result.set(row, column, newvalue); + } + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java index af702cb7fad..3113850ec80 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java @@ -3234,6 +3234,11 @@ private static void matrixMultWDivMMDense(MatrixBlock mW, MatrixBlock mU, Matrix DenseBlock x = (mX==null) ? null : mX.getDenseBlock(); DenseBlock c = ret.getDenseBlock(); + if(c == null){ + ret.allocateDenseBlock(); + c = ret.getDenseBlock(); + } + //approach: iterate over non-zeros of w, selective mm computation //cache-conscious blocking: due to blocksize constraint (default 1000), //a blocksize of 16 allows to fit blocks of UV into L2 cache (256KB) @@ -3380,6 +3385,11 @@ private static void matrixMultWDivMMGeneric(MatrixBlock mW, MatrixBlock mU, Matr //output always in dense representation DenseBlock c = ret.getDenseBlock(); + + if(c == null){ + ret.allocateDenseBlock(); + c = ret.getDenseBlock(); + } //approach: iterate over non-zeros of w, selective mm computation if( mW.sparse ) //SPARSE diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java index 90ea445be8d..1c0535f3b36 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java @@ -844,16 +844,34 @@ public static List reshape(IndexedMatrixValue in, DataCharac } /** - * CP rmempty operation (single input, single output matrix) + * CP rmempty operation (single input, single output matrix) * - * @param in input matrix - * @param ret output matrix - * @param rows ? - * @param emptyReturn return row/column of zeros for empty input - * @param select ? - * @return matrix block + * @param in The input matrix + * @param ret The output matrix + * @param rows If we are removing based on rows, or columns. + * @param emptyReturn Return row/column of zeros for empty input + * @param select An optional selection vector, to remove based on rather than empty rows or columns + * @return The result MatrixBlock */ public static MatrixBlock rmempty(MatrixBlock in, MatrixBlock ret, boolean rows, boolean emptyReturn, MatrixBlock select) { + if(ret == null) + ret = new MatrixBlock(); + MatrixBlock ret2 = rmemptyEarlyAbort(in, ret, rows, emptyReturn, select); + if(ret2 != null ) + return ret2; + // core removeEmpty + return rmemptyUnsafe(in, ret, rows, emptyReturn, select); + } + + public static MatrixBlock rmemptyUnsafe(MatrixBlock in, MatrixBlock ret, boolean rows, boolean emptyReturn, + MatrixBlock select) { + if( rows ) + return removeEmptyRows(in, ret, select, emptyReturn); + else // cols + return removeEmptyColumns(in, ret, select, emptyReturn); + } + + public static MatrixBlock rmemptyEarlyAbort(MatrixBlock in, MatrixBlock ret, boolean rows, boolean emptyReturn, MatrixBlock select){ //check for empty inputs //(the semantics of removeEmpty are that for an empty m-by-n matrix, the output //is an empty 1-by-n or m-by-1 matrix because we don't allow matrices with dims 0) @@ -870,12 +888,8 @@ public static MatrixBlock rmempty(MatrixBlock in, MatrixBlock ret, boolean rows, if( select != null && (select.nonZeros == (rows?in.rlen:in.clen)) ) { return in; } - - // core removeEmpty - if( rows ) - return removeEmptyRows(in, ret, select, emptyReturn); - else //cols - return removeEmptyColumns(in, ret, select, emptyReturn); + + return null; } /** @@ -3500,6 +3514,25 @@ private static MatrixBlock removeEmptyRows(MatrixBlock in, MatrixBlock ret, Matr rlen2 = (int)select.getNonZeros(); } + return removeEmptyRows(in, ret, emptyReturn, select == null, flags, rlen2); + } + + /** + * Remove selected rows, based on the boolean array given. Note this function is internal use only, and require a + * boolean vector to be constructed first. + * + * @param in Input to remove rows from + * @param ret Output to assign the result into + * @param emptyReturn If the output is allowed to be empty. + * @param selectNull If the original caller did not have a selection matrix. + * @param flags The boolean selection vector to specify which rows to keep. + * @param rlen2 The number of true values in the flags argument. + * @return Another reference to the ret matrix input argument. + */ + public static MatrixBlock removeEmptyRows(MatrixBlock in, MatrixBlock ret, boolean emptyReturn, boolean selectNull, + boolean[] flags, int rlen2) { + final int m = in.rlen; + final int n = in.clen; //Step 2: reset result and copy rows //dense stays dense if correct input representation (but robust for any input), //sparse might be dense/sparse @@ -3509,7 +3542,7 @@ private static MatrixBlock removeEmptyRows(MatrixBlock in, MatrixBlock ret, Matr if( in.isEmptyBlock(false) ) return ret; - if( SHALLOW_COPY_REORG && m == rlen2 && select == null ) { + if( SHALLOW_COPY_REORG && m == rlen2 && selectNull ) { // the condition m==rlen2 is not enough with non-empty 1-row input but empty // 1-row select vector because if emptyReturn should output a single empty row ret.sparse = in.sparse; @@ -3552,7 +3585,7 @@ else if( !in.sparse && !ret.sparse ) //DENSE <- DENSE } //check sparsity - ret.nonZeros = (select==null) ? + ret.nonZeros = (selectNull) ? in.nonZeros : ret.recomputeNonZeros(); ret.examSparsity(); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index 3dd8b2ad3b4..56095ab3d05 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -1315,7 +1315,7 @@ public void examSparsity(boolean allowCSR, int k) { else if( !sparse && sparseDst ) denseToSparse(allowCSR, k); } - + public static boolean evalSparseFormatInMemory(DataCharacteristics dc) { return evalSparseFormatInMemory(dc.getRows(), dc.getCols(), dc.getNonZeros()); } @@ -1387,12 +1387,13 @@ public void denseToSparse(boolean allowCSR, int k){ LibMatrixDenseToSparse.denseToSparse(this, allowCSR, k); } - public final void sparseToDense() { - sparseToDense(1); + public final MatrixBlock sparseToDense() { + return sparseToDense(1); } - public void sparseToDense(int k) { + public MatrixBlock sparseToDense(int k) { LibMatrixSparseToDense.sparseToDense(this, k); + return this; } /** @@ -2954,13 +2955,14 @@ public boolean isShallowSerialize(boolean inclConvert) { boolean sparseDst = evalSparseFormatOnDisk(); return !sparse || !sparseDst || (sparse && sparseBlock instanceof SparseBlockCSR) - || (sparse && sparseBlock instanceof SparseBlockMCSR - && getInMemorySize() / MAX_SHALLOW_SERIALIZE_OVERHEAD - <= getExactSerializedSize()) - || (sparse && sparseBlock instanceof SparseBlockMCSR - && nonZeros < Integer.MAX_VALUE //CSR constraint - && inclConvert && CONVERT_MCSR_TO_CSR_ON_DEEP_SERIALIZE - && !isUltraSparseSerialize(sparseDst)); + || (sparse && sparseBlock instanceof SparseBlockMCSR); + // || (sparse && sparseBlock instanceof SparseBlockMCSR + // && getInMemorySize() / MAX_SHALLOW_SERIALIZE_OVERHEAD + // <= getExactSerializedSize()) + // || (sparse && sparseBlock instanceof SparseBlockMCSR + // && nonZeros < Integer.MAX_VALUE //CSR constraint + // && inclConvert && CONVERT_MCSR_TO_CSR_ON_DEEP_SERIALIZE + // && !isUltraSparseSerialize(sparseDst)); } @Override @@ -4650,7 +4652,7 @@ public final MatrixBlock sortOperations(MatrixValue weights){ return sortOperations(weights, null); } - public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result) { + public final MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result) { return sortOperations(weights, result, 1); } @@ -4754,7 +4756,17 @@ public static double computeIQMCorrection(double sum, double sum_wt, return (sum + q25Part*q25Val - q75Part*q75Val) / (sum_wt*0.5); } - public MatrixBlock pickValues(MatrixValue quantiles, MatrixValue ret) { + /** + * Pick the quantiles out of this matrix. If this matrix contains two columns it is weighted quantile picking. + * If a single column it is unweighted. + * + * Note the values are assumed to be sorted + * + * @param quantiles The quantiles to pick + * @param ret The result matrix + * @return The result matrix + */ + public final MatrixBlock pickValues(MatrixValue quantiles, MatrixValue ret) { return pickValues(quantiles, ret, false); } @@ -4778,17 +4790,56 @@ public MatrixBlock pickValues(MatrixValue quantiles, MatrixValue ret, boolean av return output; } - + + /** + * Pick the median quantile from this matrix. if this matrix is two columns, it is weighted picking else it is unweighted. + * + * Note the values are assumed to be sorted + * + * @param quantile The quantile to pick + * @return The quantile + */ public double median() { double sum_wt = sumWeightForQuantile(); return pickValue(0.5, sum_wt%2==0); } - + + /** + * Pick a specific quantile from this matrix. if this matrix is two columns, it is weighted picking else it is unweighted. + * + * Note the values are assumed to be sorted + * + * @param quantile The quantile to pick + * @return The quantile + */ public final double pickValue(double quantile){ return pickValue(quantile, false); } - public double pickValue(double quantile, boolean average) { + /** + * Pick a specific quantile from this matrix. if this matrix is two columns, it is weighted picking else it is unweighted. + * + * Note the values are assumed to be sorted + * + * @param quantile The quantile to pick + * @param average If the quantile is averaged. + * @return The quantile + */ + public final double pickValue(double quantile, boolean average) { + if(this.getNumColumns() == 1) + return pickUnweightedValue(quantile, average); + return pickWeightedValue(quantile, average); + } + + private double pickUnweightedValue(double quantile, boolean average) { + double pos = quantile * rlen; + if(average && (int) pos != pos) + return (get((int) Math.floor(pos), 0) + get(Math.min(rlen - 1, (int) Math.ceil(pos)), 0)) / 2; + else + return get(Math.min(rlen - 1, (int) Math.round(pos)), 0); + } + + private double pickWeightedValue(double quantile, boolean average) { double sum_wt = sumWeightForQuantile(); // do averaging only if it is asked for; and sum_wt is even @@ -5342,8 +5393,8 @@ public MatrixBlock ctableSeqOperations(MatrixValue thatMatrix, double thatScalar * (i1,j1,v2) from input2 (that) * (w) from scalar_input3 (scalarThat2) * - * @param thatMatrix matrix value - * @param thatScalar scalar double + * @param thatMatrix matrix value, the vector to encode via table + * @param thatScalar scalar double, w, that is the weight to multiply on the encoded values * @param resultBlock result matrix block * @return resultBlock */ diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java index 724af1be630..70834675ded 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java @@ -23,6 +23,10 @@ import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -30,6 +34,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.CommonThreadPool; /** * Base class for all transform decoders providing both a row and block @@ -77,8 +82,31 @@ public String[] getColnames() { * @param k Parallelization degree * @return returns the given output frame block for convenience */ - public FrameBlock decode(MatrixBlock in, FrameBlock out, int k) { - return decode(in, out); + public FrameBlock decode(final MatrixBlock in, final FrameBlock out, final int k) { + if(k <= 1) + return decode(in, out); + final ExecutorService pool = CommonThreadPool.get(k); + out.ensureAllocatedColumns(in.getNumRows()); + try { + final List> tasks = new ArrayList<>(); + int blz = Math.max((in.getNumRows() + k) / k, 1000); + + for(int i = 0; i < in.getNumRows(); i += blz){ + final int start = i; + final int end = Math.min(in.getNumRows(), i + blz); + tasks.add(pool.submit(() -> decode(in, out, start, end))); + } + + for(Future f : tasks) + f.get(); + return out; + } + catch(Exception e) { + throw new RuntimeException(e); + } + finally { + pool.shutdown(); + } } /** diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java index edee095f612..c9fcc23990a 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java @@ -28,6 +28,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.UtilFunctions; @@ -43,15 +44,18 @@ public class DecoderBin extends Decoder { // a) column bin boundaries private int[] _numBins; + private int[] _dcCols = null; + private int[] _srcCols = null; private double[][] _binMins = null; private double[][] _binMaxs = null; - public DecoderBin() { - super(null, null); - } + // public DecoderBin() { + // super(null, null); + // } - protected DecoderBin(ValueType[] schema, int[] binCols) { + protected DecoderBin(ValueType[] schema, int[] binCols, int[] dcCols) { super(schema, binCols); + _dcCols = dcCols; } @Override @@ -66,14 +70,28 @@ public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { for( int i=rl; i< ru; i++ ) { for( int j=0; j<_colList.length; j++ ) { final Array a = out.getColumn(_colList[j] - 1); - final double val = in.get(i, _colList[j] - 1); + final double val = in.get(i, _srcCols[j] - 1); if(!Double.isNaN(val)){ - final int key = (int) Math.round(val); - double bmin = _binMins[j][key - 1]; - double bmax = _binMaxs[j][key - 1]; - double oval = bmin + (bmax - bmin) / 2 // bin center - + (val - key) * (bmax - bmin); // bin fractions - a.set(i, oval); + try{ + + final int key = (int) Math.round(val); + if(key == 0){ + a.set(i, _binMins[j][key]); + } + else{ + double bmin = _binMins[j][key - 1]; + double bmax = _binMaxs[j][key - 1]; + double oval = bmin + (bmax - bmin) / 2 // bin center + + (val - key) * (bmax - bmin); // bin fractions + a.set(i, oval); + } + } + catch(Exception e){ + LOG.error(a); + LOG.error(in.slice(0, in.getNumRows()-1, _colList[j]-1,_colList[j]-1)); + LOG.error( val); + throw e; + } } else a.set(i, val); // NaN @@ -111,6 +129,34 @@ public void initMetaData(FrameBlock meta) { _binMaxs[j][i] = Double.parseDouble(parts[1]); } } + + + if( _dcCols.length > 0 ) { + //prepare source column id mapping w/ dummy coding + _srcCols = new int[_colList.length]; + int ix1 = 0, ix2 = 0, off = 0; + while( ix1<_colList.length ) { + if( ix2>=_dcCols.length || _colList[ix1] < _dcCols[ix2] ) { + _srcCols[ix1] = _colList[ix1] + off; + ix1 ++; + } + else { //_colList[ix1] > _dcCols[ix2] + ColumnMetadata d =meta.getColumnMetadata()[_dcCols[ix2]-1]; + String v = meta.getString(0, _dcCols[ix2]-1); + if(v.length() > 1 && v.charAt(0) == '¿'){ + off += UtilFunctions.parseToLong(v.substring(1)) -1; + } + else { + off += d.isDefault() ? -1 : d.getNumDistinct() - 1; + } + ix2 ++; + } + } + } + else { + //prepare direct source column mapping + _srcCols = _colList; + } } @Override diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java index f4bc9f8b216..dff85e72dc6 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java @@ -25,13 +25,10 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Future; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.util.CommonThreadPool; /** * Simple composite decoder that applies a list of decoders @@ -50,7 +47,7 @@ protected DecoderComposite(ValueType[] schema, List decoders) { _decoders = decoders; } - public DecoderComposite() { super(null, null); } + // public DecoderComposite() { super(null, null); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { @@ -59,33 +56,6 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { return out; } - - @Override - public FrameBlock decode(final MatrixBlock in, final FrameBlock out, final int k) { - final ExecutorService pool = CommonThreadPool.get(k); - out.ensureAllocatedColumns(in.getNumRows()); - try { - final List> tasks = new ArrayList<>(); - int blz = Math.max(in.getNumRows() / k, 1000); - for(Decoder decoder : _decoders){ - for(int i = 0; i < in.getNumRows(); i += blz){ - final int start = i; - final int end = Math.min(in.getNumRows(), i + blz); - tasks.add(pool.submit(() -> decoder.decode(in, out, start, end))); - } - } - for(Future f : tasks) - f.get(); - return out; - } - catch(Exception e) { - throw new RuntimeException(e); - } - finally { - pool.shutdown(); - } - } - @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru){ for( Decoder decoder : _decoders ) diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java index 0c4c6b42690..debce027680 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java @@ -27,31 +27,30 @@ import java.util.List; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.UtilFunctions; /** - * Simple atomic decoder for dummycoded columns. This decoder builds internally - * inverted column mappings from the given frame meta data. - * + * Simple atomic decoder for dummycoded columns. This decoder builds internally inverted column mappings from the given + * frame meta data. + * */ -public class DecoderDummycode extends Decoder -{ +public class DecoderDummycode extends Decoder { private static final long serialVersionUID = 4758831042891032129L; - + private int[] _clPos = null; private int[] _cuPos = null; - + protected DecoderDummycode(ValueType[] schema, int[] dcCols) { - //dcCols refers to column IDs in output (non-dc) + // dcCols refers to column IDs in output (non-dc) super(schema, dcCols); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { - //TODO perf (exploit sparse representation for better asymptotic behavior) out.ensureAllocatedColumns(in.getNumRows()); decode(in, out, 0, in.getNumRows()); return out; @@ -59,59 +58,98 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { - //TODO perf (exploit sparse representation for better asymptotic behavior) - // out.ensureAllocatedColumns(in.getNumRows()); - for( int i=rl; i= low && aix[h] < high) { + int k = aix[h]; + int col = _colList[j] - 1; + out.getColumn(col).set(i, k - _clPos[j] + 1); + } + // limit the binary search. + apos = h; + } + + } + @Override public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) { List dcList = new ArrayList<>(); List clPosList = new ArrayList<>(); List cuPosList = new ArrayList<>(); - + // get the column IDs for the sub range of the dummycode columns and their destination positions, // where they will be decoded to - for( int j=0; j<_colList.length; j++ ) { + for(int j = 0; j < _colList.length; j++) { int colID = _colList[j]; - if (colID >= colStart && colID < colEnd) { + if(colID >= colStart && colID < colEnd) { dcList.add(colID - (colStart - 1)); clPosList.add(_clPos[j] - dummycodedOffset); cuPosList.add(_cuPos[j] - dummycodedOffset); } } - if (dcList.isEmpty()) + if(dcList.isEmpty()) return null; // create sub-range decoder int[] colList = dcList.stream().mapToInt(i -> i).toArray(); - DecoderDummycode subRangeDecoder = new DecoderDummycode( - Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1), colList); + DecoderDummycode subRangeDecoder = new DecoderDummycode(Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1), + colList); subRangeDecoder._clPos = clPosList.stream().mapToInt(i -> i).toArray(); subRangeDecoder._cuPos = cuPosList.stream().mapToInt(i -> i).toArray(); return subRangeDecoder; } - + @Override public void updateIndexRanges(long[] beginDims, long[] endDims) { if(_colList == null) return; - + long lowerColDest = beginDims[1]; long upperColDest = endDims[1]; for(int i = 0; i < _colList.length; i++) { long numDistinct = _cuPos[i] - _clPos[i]; - + if(_cuPos[i] <= beginDims[1] + 1) if(numDistinct > 0) lowerColDest -= numDistinct - 1; - + if(_cuPos[i] <= endDims[1] + 1) if(numDistinct > 0) upperColDest -= numDistinct - 1; @@ -119,16 +157,25 @@ public void updateIndexRanges(long[] beginDims, long[] endDims) { beginDims[1] = lowerColDest; endDims[1] = upperColDest; } - + @Override public void initMetaData(FrameBlock meta) { - _clPos = new int[_colList.length]; //col lower pos - _cuPos = new int[_colList.length]; //col upper pos - for( int j=0, off=0; j<_colList.length; j++ ) { + _clPos = new int[_colList.length]; // col lower pos + _cuPos = new int[_colList.length]; // col upper pos + for(int j = 0, off = 0; j < _colList.length; j++) { int colID = _colList[j]; - ColumnMetadata d = meta.getColumnMetadata()[colID-1]; - int ndist = d.isDefault() ? 0 : (int)d.getNumDistinct(); - ndist = ndist < -1 ? 0: ndist; + ColumnMetadata d = meta.getColumnMetadata()[colID - 1]; + String v = meta.getString(0, colID - 1); + int ndist; + if(v.length() > 1 && v.charAt(0) == '¿') { + ndist = UtilFunctions.parseToInt(v.substring(1)); + } + else { + ndist = d.isDefault() ? 0 : (int) d.getNumDistinct(); + } + + ndist = ndist < -1 ? 0 : ndist; // safety if all values was null. + _clPos[j] = off + colID; _cuPos[j] = _clPos[j] + ndist; off += ndist - 1; diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java index 0a400e6da92..12ba2968877 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java @@ -64,34 +64,52 @@ public static Decoder createDecoder(String spec, String[] colnames, ValueType[] try { //parse transform specification JSONObject jSpec = new JSONObject(spec); - List ldecoders = new ArrayList<>(); - //create decoders 'bin', 'recode', 'dummy' and 'pass-through' + //create decoders 'bin', 'recode', 'hash', 'dummy', and 'pass-through' List binIDs = TfMetaUtils.parseBinningColIDs(jSpec, colnames, minCol, maxCol); List rcIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.RECODE.toString(), minCol, maxCol))); + List hcIDs = Arrays.asList(ArrayUtils.toObject( + TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.HASH.toString(), minCol, maxCol))); List dcIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol))); + // only specially treat the columns with both recode and dictionary rcIDs = unionDistinct(rcIDs, dcIDs); + // remove hash recoded. // todo potentially wrong and remove? + rcIDs = except(rcIDs, hcIDs); + int len = dcIDs.isEmpty() ? Math.min(meta.getNumColumns(), clen) : meta.getNumColumns(); - List ptIDs = except(except(UtilFunctions.getSeqList(1, len, 1), rcIDs), binIDs); - + + // set the remaining columns to passthrough. + List ptIDs = UtilFunctions.getSeqList(1, len, 1); + // except recoded columns + ptIDs = except(ptIDs, rcIDs); + // binned columns + ptIDs = except(ptIDs, binIDs); + // hashed columns + ptIDs = except(ptIDs, hcIDs); // remove hashed columns + //create default schema if unspecified (with double columns for pass-through) if( schema == null ) { schema = UtilFunctions.nCopies(len, ValueType.STRING); for( Integer col : ptIDs ) schema[col-1] = ValueType.FP64; } + + // collect all the decoders in one list. + List ldecoders = new ArrayList<>(); if( !binIDs.isEmpty() ) { ldecoders.add(new DecoderBin(schema, - ArrayUtils.toPrimitive(binIDs.toArray(new Integer[0])))); + ArrayUtils.toPrimitive(binIDs.toArray(new Integer[0])), + ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); } if( !dcIDs.isEmpty() ) { ldecoders.add(new DecoderDummycode(schema, ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); } if( !rcIDs.isEmpty() ) { + // todo figure out if we need to handle rc columns with regards to dictionary offsets. ldecoders.add(new DecoderRecode(schema, !dcIDs.isEmpty(), ArrayUtils.toPrimitive(rcIDs.toArray(new Integer[0])))); } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java index 5b6bf7a093e..c2de3ec1df3 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java @@ -49,7 +49,7 @@ protected DecoderPassThrough(ValueType[] schema, int[] ptCols, int[] dcCols) { _dcCols = dcCols; } - public DecoderPassThrough() { super(null, null); } + // public DecoderPassThrough() { super(null, null); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { @@ -61,13 +61,12 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { int clen = Math.min(_colList.length, out.getNumColumns()); - for( int i=rl; i _dcCols[ix2] ColumnMetadata d =meta.getColumnMetadata()[_dcCols[ix2]-1]; - off += d.isDefault() ? -1 : d.getNumDistinct() - 1; + String v = meta.getString( 0,_dcCols[ix2]-1); + if(v.length() > 1 && v.charAt(0) == '¿'){ + off += UtilFunctions.parseToLong(v.substring(1)) -1; + } + else { + off += d.isDefault() ? -1 : d.getNumDistinct() - 1; + } ix2 ++; } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java index 33459a1c4f9..1cf0b7c4b3f 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java @@ -29,6 +29,7 @@ import java.util.Map.Entry; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.Pair; @@ -46,12 +47,11 @@ public class DecoderRecode extends Decoder private static final long serialVersionUID = -3784249774608228805L; private HashMap[] _rcMaps = null; - private Object[][] _rcMapsDirect = null; private boolean _onOut = false; - public DecoderRecode() { - super(null, null); - } + // public DecoderRecode() { + // super(null, null); + // } protected DecoderRecode(ValueType[] schema, boolean onOut, int[] rcCols) { super(schema, rcCols); @@ -59,8 +59,7 @@ protected DecoderRecode(ValueType[] schema, boolean onOut, int[] rcCols) { } public Object getRcMapValue(int i, long key) { - return (_rcMapsDirect != null && key > 0) ? - _rcMapsDirect[i][(int)key-1] : _rcMaps[i].get(key); + return _rcMaps[i].get(key); } @Override @@ -129,27 +128,33 @@ public void initMetaData(FrameBlock meta) { for( int j=0; j<_colList.length; j++ ) { HashMap map = new HashMap<>(); for( int i=0; i v < Integer.MAX_VALUE) ) { - _rcMapsDirect = new Object[_rcMaps.length][]; - for( int i=0; i<_rcMaps.length; i++ ) { - Object[] arr = new Object[(int)max[i]]; - for(Entry e1 : _rcMaps[i].entrySet()) - arr[e1.getKey().intValue()-1] = e1.getValue(); - _rcMapsDirect[i] = arr; - } - } + // if( Arrays.stream(max).allMatch(v -> v < Integer.MAX_VALUE) ) { + // _rcMapsDirect = new Object[_rcMaps.length][]; + // for( int i=0; i<_rcMaps.length; i++ ) { + // Object[] arr = new Object[(int)max[i]]; + // for(Entry e1 : _rcMaps[i].entrySet()) + // arr[e1.getKey().intValue()-1] = e1.getValue(); + // _rcMapsDirect[i] = arr; + // } + // } } /** diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java index 400b7f64ffc..361c9c52135 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java @@ -146,7 +146,9 @@ public FrameBlock getMetaData(FrameBlock meta) { return meta; meta.ensureAllocatedColumns(1); - meta.set(0, _colID - 1, String.valueOf(_K)); + // set metadata of hash columns to magical hash value + k + meta.set(0, _colID - 1, String.format("¿%d" , _K)); + return meta; } @@ -154,7 +156,7 @@ public FrameBlock getMetaData(FrameBlock meta) { public void initMetaData(FrameBlock meta) { if(meta == null || meta.getNumRows() <= 0) return; - _K = UtilFunctions.parseToLong(meta.get(0, _colID - 1).toString()); + _K = UtilFunctions.parseToLong(meta.getString(0, _colID - 1).substring(1)); } @Override diff --git a/src/main/java/org/apache/sysds/utils/DoubleParser.java b/src/main/java/org/apache/sysds/utils/DoubleParser.java index 9c77a3e95c8..c0122f8061f 100644 --- a/src/main/java/org/apache/sysds/utils/DoubleParser.java +++ b/src/main/java/org/apache/sysds/utils/DoubleParser.java @@ -184,7 +184,7 @@ public interface DoubleParser { 0x8e679c2f5e44ff8fL}; public static double parseFloatingPointLiteral(String str, int offset, int endIndex) { - if(endIndex > 100) + if(endIndex > 100)// long string return Double.parseDouble(str); // Skip leading whitespace int index = skipWhitespace(str, offset, endIndex); @@ -197,9 +197,10 @@ public static double parseFloatingPointLiteral(String str, int offset, int endIn } // Parse NaN or Infinity (this occurs rarely) - if(ch >= 'I') - return Double.parseDouble(str); - else if(str.charAt(endIndex - 1) >= 'a') + // : is the first character after numbers. + // 0 is the first number. + // we use the last position, since this is not allowed to be other values than a number. + if(str.charAt(endIndex - 1) > '9' || str.charAt(endIndex - 1) < '0') return Double.parseDouble(str); final double val = parseDecFloatLiteral(str, index, offset, endIndex); diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java index e470dd82539..07abdbaff26 100644 --- a/src/test/java/org/apache/sysds/test/TestUtils.java +++ b/src/test/java/org/apache/sysds/test/TestUtils.java @@ -32,6 +32,7 @@ import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.FileReader; +import java.io.FileWriter; import java.io.IOException; import java.io.InputStreamReader; import java.io.OutputStreamWriter; @@ -2941,6 +2942,25 @@ public static void writeTestScalar(String file, double value) { } } + + /** + * Write scalar to file + * + * @param file File to write to + * @param value Value to write + */ + public static void writeTestScalar(String file, String value) { + try { + DataOutputStream out = new DataOutputStream(new FileOutputStream(file)); + try(PrintWriter pw = new PrintWriter(out)) { + pw.println(value); + } + } + catch(IOException e) { + fail("unable to write test scalar (" + file + "): " + e.getMessage()); + } + } + /** * Write scalar to file * diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java index d36c6167cf7..9d5976a8903 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java @@ -687,4 +687,54 @@ public void toRDDAndBack(int blen) { fail(e.getMessage()); } } + + @Test + public void removeEmptyOperationsBase1() { + removeEmptyOperations(false, false, null); + } + + @Test + public void removeEmptyOperationsBase2() { + removeEmptyOperations(true, false, null); + } + + @Test + public void removeEmptyOperationsBase3() { + removeEmptyOperations(false, true, null); + } + + @Test + public void removeEmptyOperationsBase4() { + removeEmptyOperations(true, true, null); + } + + @Test + public void removeEmptyOperationsSelect1() { + if(rows < 5000) { + MatrixBlock s = TestUtils.generateTestMatrixBlock(rows, 1, 1, 1, 0.05, 321); + removeEmptyOperations(true, false, s); + } + } + + @Test + public void removeEmptyOperationsSelect2() { + if(rows < 5000) { + MatrixBlock s = TestUtils.generateTestMatrixBlock(1, cols, 1, 1, 0.5, 321); + removeEmptyOperations(false, false, s); + } + } + + public void removeEmptyOperations(boolean rows, boolean emptyReturn, MatrixBlock select) { + try { + MatrixBlock a = cmb.removeEmptyOperations(null, rows, emptyReturn, select); + MatrixBlock b = mb.removeEmptyOperations(null, rows, emptyReturn, select); + compareResultMatrices(b, a, 0); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + + } + } diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java index c3efeea4014..36dab4191ee 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java @@ -49,6 +49,7 @@ import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -468,6 +469,24 @@ public AColGroup[] splitReshapePushDown(int multiplier, int nRow, int nColOrg, E // TODO Auto-generated method stub throw new UnsupportedOperationException("Unimplemented method 'splitReshapePushDown'"); } + + @Override + public AColGroup sort() { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'sort'"); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'removeEmptyRows'"); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'removeEmptyColsSubset'"); + } } private class FakeDictBasedColGroup extends ADictBasedColGroup { @@ -777,5 +796,23 @@ public AColGroup[] splitReshapePushDown(int multiplier, int nRow, int nColOrg, E // TODO Auto-generated method stub throw new UnsupportedOperationException("Unimplemented method 'splitReshapePushDown'"); } + + @Override + public AColGroup sort() { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'sort'"); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'removeEmptyRows'"); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'removeEmptyColsSubset'"); + } } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleMultiColTest.java b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleMultiColTest.java index 194f581121a..a5bd3cebfb0 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleMultiColTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleMultiColTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.fail; +import org.apache.sysds.runtime.compress.CompressionSettingsBuilder; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.junit.Test; @@ -115,6 +116,8 @@ public void testJoinWithSecondSubpartLeft() { private void partJoinVerification(IEncode er) { boolean incorrectUnique = e.getUnique() != er.getUnique(); + er.extractFacts(10000, 1.0, 1.0, new CompressionSettingsBuilder().create()); + if(incorrectUnique) { StringBuilder sb = new StringBuilder(); sb.append("\nFailed joining sub parts to recreate whole."); diff --git a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUnbalancedTest.java b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUnbalancedTest.java index 182bd7fa37e..5a298f145ec 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUnbalancedTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUnbalancedTest.java @@ -81,6 +81,10 @@ public static Collection data() { // Both Sparse and end dense joined tests.add(createT(1, 0.2, 10, 10, 0.1, 2, 1000, 1231521)); + + tests.add(createT(1, 1.0, 100, 1, 1.0, 10, 10000, 132)); + tests.add(createT(1, 1.0, 1000, 1, 1.0, 10, 10000, 132)); + return tests; } diff --git a/src/test/java/org/apache/sysds/test/component/compress/offset/CustomOffsetTest.java b/src/test/java/org/apache/sysds/test/component/compress/offset/CustomOffsetTest.java index 2e901eeb14d..3755365c018 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/offset/CustomOffsetTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/offset/CustomOffsetTest.java @@ -28,13 +28,14 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.RemoveEmptyOffsetsTmp; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.junit.Test; public class CustomOffsetTest { protected static final Log LOG = LogFactory.getLog(CustomOffsetTest.class.getName()); - static{ + static { CompressedMatrixBlock.debug = true; } @@ -96,4 +97,95 @@ public void printCache() { String s = off.toString(); assertTrue(s.contains("CacheRow")); } + + @Test + public void removeEmptyRows1() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, true, false, false, false, false}, 0); + assertEquals(1, t.select.size()); + assertEquals(0, t.select.get(0)); + assertEquals(1, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {0}), t.retOffset); + } + + @Test + public void removeEmptyRows2() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, false, true, false, false, false}, 0); + assertEquals(1, t.select.size()); + assertEquals(1, t.select.get(0)); + assertEquals(1, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {0}), t.retOffset); + } + + @Test + public void removeEmptyRows3() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, true, true, false, false, false}, 0); + assertEquals(2, t.select.size()); + assertEquals(0, t.select.get(0)); + assertEquals(1, t.select.get(1)); + assertEquals(2, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {0, 1}), t.retOffset); + } + + @Test + public void removeEmptyRows4() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, true, true, false, false, true}, 0); + assertEquals(3, t.select.size()); + assertEquals(0, t.select.get(0)); + assertEquals(1, t.select.get(1)); + assertEquals(4, t.select.get(2)); + assertEquals(3, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {0, 1, 2}), t.retOffset); + } + + @Test + public void removeEmptyRows5() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, false, false, false, false, true}, 0); + assertEquals(1, t.select.size()); + assertEquals(4, t.select.get(0)); + assertEquals(1, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {0}), t.retOffset); + } + + @Test + public void removeEmptyRows6() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, false, false, true, true, true}, 0); + assertEquals(1, t.select.size()); + assertEquals(2, t.select.get(0)); + assertEquals(1, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {2}), t.retOffset); + } + + @Test + public void removeEmptyRows7() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {true, false, false, true, true, true}, 0); + assertEquals(1, t.select.size()); + assertEquals(2, t.select.get(0)); + assertEquals(1, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {3}), t.retOffset); + } + + @Test + public void removeEmptyRows8() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {true, false, false, false, false, true}, 0); + assertEquals(1, t.select.size()); + assertEquals(4, t.select.get(0)); + assertEquals(1, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {1}), t.retOffset); + } + + @Test + public void removeEmptyRowsEmpty() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, false, false, false, false, false}, 0); + assertEquals(0, t.select.size()); + assertEquals(OffsetFactory.createOffset(new int[] {}), t.retOffset); + } } diff --git a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java index c6d52a70a51..872ec79c1f1 100644 --- a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java +++ b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java @@ -49,7 +49,7 @@ protected String getTestDir() { @Test public void testTranspose_CP() { - runTest(1500, 20, 1, 1, ExecType.CP, "transpose"); + runTest(1500, 20, 2, 1, ExecType.CP, "transpose"); } @Test diff --git a/src/test/java/org/apache/sysds/test/functions/misc/ToStringTest.java b/src/test/java/org/apache/sysds/test/functions/misc/ToStringTest.java index ee6a2953980..18ca2fbc454 100644 --- a/src/test/java/org/apache/sysds/test/functions/misc/ToStringTest.java +++ b/src/test/java/org/apache/sysds/test/functions/misc/ToStringTest.java @@ -270,4 +270,96 @@ protected void toStringTestHelper(ExecMode platform, String testName, String exp DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } } + + @Test + public void testPrintWithDecimal(){ + String testName = "ToString12"; + + String decimalPoints = "2"; + String value = "22"; + String expectedOutput = "22.00\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + + @Test + public void testPrintWithDecimal2(){ + String testName = "ToString12"; + + String decimalPoints = "2"; + String value = "5.244058388023880"; + String expectedOutput = "5.24\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + + @Test + public void testPrintWithDecimal3(){ + String testName = "ToString12"; + + String decimalPoints = "10"; + String value = "5.244058388023880"; + String expectedOutput = "5.2440583880\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + + @Test + public void testPrintWithDecimal4(){ + String testName = "ToString12"; + + String decimalPoints = "4"; + String value = "5.244058388023880"; + String expectedOutput = "5.2441\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + + @Test + public void testPrintWithDecimal5(){ + String testName = "ToString12"; + + String decimalPoints = "10"; + String value = "0.000000008023880"; + String expectedOutput = "0.0000000080\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + protected void toStringTestHelper2(ExecMode platform, String testName, String expectedOutput, String decimalPoints, String value) { + ExecMode platformOld = rtplatform; + + rtplatform = platform; + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if (rtplatform == ExecMode.SPARK) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + try { + // Create and load test configuration + getAndLoadTestConfiguration(testName); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testName + ".dml"; + programArgs = new String[]{"-args", output(OUTPUT_NAME), value, decimalPoints}; + + // Run DML and R scripts + runTest(true, false, null, -1); + + // Compare output strings + String output = TestUtils.readDMLString(output(OUTPUT_NAME)); + TestUtils.compareScalars(expectedOutput, output); + } + finally { + // Reset settings + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } } diff --git a/src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java b/src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java new file mode 100644 index 00000000000..30681f373e4 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.transform; + +import static org.junit.Assert.fail; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.FileFormat; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +public class GetCategoricalMaskTest extends AutomatedTestBase { + protected static final Log LOG = LogFactory.getLog(GetCategoricalMaskTest.class.getName()); + + private final static String TEST_NAME1 = "GetCategoricalMaskTest"; + private final static String TEST_DIR = "functions/transform/"; + private final static String TEST_CLASS_DIR = TEST_DIR + TransformFrameEncodeApplyTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"y"})); + } + + @Test + public void testRecode() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(10, new ValueType[] {ValueType.UINT8}, 32); + MatrixBlock expected = new MatrixBlock(1, 1, 1.0); + String spec = "{\"ids\": true, \"recode\": [1]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testRecode2() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(10, new ValueType[] {ValueType.UINT8, ValueType.UINT8}, 32); + MatrixBlock expected = new MatrixBlock(1, 2, new double[] {0, 1}); + + String spec = "{\"ids\": true, \"recode\": [2]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testDummy1() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(5, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 6, new double[] {0, 1, 1, 1, 1, 1}); + + String spec = "{\"ids\": true, \"dummycode\": [2]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testDummy2() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(5, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 6, new double[] {1, 1, 1, 1, 1, 0}); + + String spec = "{\"ids\": true, \"dummycode\": [1]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testHash1() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(5, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 4, new double[] {1, 1, 1, 0}); + + String spec = "{\"ids\": true, \"dummycode\": [1], \"hash\": [1], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testHash2() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 4, new double[] {1, 1, 1, 0}); + + String spec = "{\"ids\": true, \"dummycode\": [1], \"hash\": [1], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testHash3() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.INT64,ValueType.UINT8}, 32); + MatrixBlock expected = new MatrixBlock(1, 7, new double[] {1, 1, 1, 0, 1, 1, 1}); + + String spec = "{\"ids\": true, \"dummycode\": [1,3], \"hash\": [1,3], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + + @Test + public void testHybrid1() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.INT64,ValueType.UINT8, ValueType.BOOLEAN}, 32); + MatrixBlock expected = new MatrixBlock(1, 9, new double[] {1, 1, 1, 0, 1, 1, 1,1,1}); + + String spec = "{\"ids\": true, \"dummycode\": [1,3,4], \"hash\": [1,3], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testHybrid2() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.BOOLEAN,ValueType.UINT8, ValueType.BOOLEAN}, 32); + MatrixBlock expected = new MatrixBlock(1, 10, new double[] {1, 1, 1, 1,1, 1, 1, 1,1,1}); + + String spec = "{\"ids\": true, \"dummycode\": [1,2,3,4], \"hash\": [1,3], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + private void runTransformTest(FrameBlock fb, String spec, MatrixBlock expected) throws Exception { + try { + + getAndLoadTestConfiguration(TEST_NAME1); + + String inF = input("F-In"); + String inS = input("spec"); + + TestUtils.writeTestFrame(inF, fb, fb.getSchema(), FileFormat.CSV); + TestUtils.writeTestScalar(input("spec"), spec); + + String out = output("ret"); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[] {"-args", inF, inS, out, expected.getNumColumns() + ""}; + + runTest(true, false, null, -1); + + MatrixBlock result = TestUtils.readBinary(out); + + TestUtils.compareMatrices(expected, result, 0.0); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + +} diff --git a/src/test/scripts/functions/misc/ToString12.dml b/src/test/scripts/functions/misc/ToString12.dml new file mode 100644 index 00000000000..4f120630b75 --- /dev/null +++ b/src/test/scripts/functions/misc/ToString12.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix($2, rows=1, cols=1) +str = toString(X, rows=3, cols=3, decimal=$3) +write(str, $1) diff --git a/src/test/scripts/functions/transform/GetCategoricalMaskTest.dml b/src/test/scripts/functions/transform/GetCategoricalMaskTest.dml new file mode 100644 index 00000000000..5d7bb35a250 --- /dev/null +++ b/src/test/scripts/functions/transform/GetCategoricalMaskTest.dml @@ -0,0 +1,37 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +F1 = read($1, data_type="frame", format="csv"); + +jspec = read($2, data_type="scalar", value_type="string"); + +[X, M] = transformencode(target=F1, spec=jspec); + +Cm = getCategoricalMask(M, jspec) +expectedColumns = $4 +if(ncol(Cm) != expectedColumns){ + stop("Wrong number of metadata columns in categorical mask") +} +# print mean to verify that Cm is a matrix, not a Frame according to compiler +print(mean(Cm)) + +write(Cm, $3, format="csv"); +