|
26 | 26 | import org.apache.sysds.runtime.DMLRuntimeException; |
27 | 27 | import org.apache.sysds.runtime.DMLScriptException; |
28 | 28 |
|
| 29 | +import jdk.incubator.vector.DoubleVector; |
| 30 | +import jdk.incubator.vector.VectorSpecies; |
| 31 | + |
29 | 32 |
|
30 | 33 | /** |
31 | 34 | * Class with pre-defined set of objects. This class can not be instantiated elsewhere. |
|
46 | 49 | public class Builtin extends ValueFunction |
47 | 50 | { |
48 | 51 | private static final long serialVersionUID = 3836744687789840574L; |
49 | | - |
| 52 | + |
50 | 53 | public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN, |
51 | 54 | MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX, |
52 | 55 | STOP, CEIL, FLOOR, CUMSUM, ROWCUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST, |
53 | 56 | TYPEOF, APPLY_SCHEMA, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE, |
54 | 57 | DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE, |
55 | 58 | MAP, COUNT_DISTINCT, COUNT_DISTINCT_APPROX, UNIQUE} |
56 | 59 |
|
| 60 | + private static final VectorSpecies<Double> SPECIES = DoubleVector.SPECIES_PREFERRED; |
| 61 | + private static final int vLen = SPECIES.length(); |
| 62 | + |
57 | 63 |
|
58 | 64 | public BuiltinCode bFunc; |
59 | 65 |
|
@@ -197,6 +203,38 @@ else if (in < 0) |
197 | 203 | throw new DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc); |
198 | 204 | } |
199 | 205 | } |
| 206 | + |
| 207 | + public long execute (double[] a, double[] c, int start, int end) { |
| 208 | + long nnz = 0; |
| 209 | + |
| 210 | + //process rest or unsupported builtin codes |
| 211 | + final int end2 = (bFunc==BuiltinCode.ABS || bFunc==BuiltinCode.SQRT)? |
| 212 | + start+((end-start)%vLen) : end; |
| 213 | + for( int i = start; i < end2; i++) { |
| 214 | + c[i] = execute(a[i]); |
| 215 | + nnz += (c[i] != 0) ? 1 : 0; |
| 216 | + } |
| 217 | + |
| 218 | + nnz += (end-end2); |
| 219 | + if( bFunc == BuiltinCode.ABS) { |
| 220 | + for( int i = end2; i < end; i+=vLen ){ |
| 221 | + DoubleVector aVec = DoubleVector.fromArray(SPECIES, a, i); |
| 222 | + DoubleVector cVec = aVec.abs(); |
| 223 | + nnz -= cVec.eq(0).trueCount(); |
| 224 | + cVec.intoArray(c, i); |
| 225 | + } |
| 226 | + } |
| 227 | + else if(bFunc == BuiltinCode.SQRT ) { |
| 228 | + for( int i = end2; i < end; i+=vLen ){ |
| 229 | + DoubleVector aVec = DoubleVector.fromArray(SPECIES, a, i); |
| 230 | + DoubleVector cVec = aVec.sqrt(); |
| 231 | + nnz -= cVec.eq(0).trueCount(); |
| 232 | + cVec.intoArray(c, i); |
| 233 | + } |
| 234 | + } |
| 235 | + return nnz; |
| 236 | + } |
| 237 | + |
200 | 238 |
|
201 | 239 | @Override |
202 | 240 | public double execute (long in) { |
|
0 commit comments