diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java index 513b788b605..737922397e5 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java @@ -31,9 +31,11 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import java.util.concurrent.ThreadLocalRandom; import java.util.function.Function; @@ -56,6 +58,7 @@ import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; import org.apache.sysds.runtime.frame.data.lib.FrameFromMatrixBlock; import org.apache.sysds.runtime.frame.data.lib.FrameLibAppend; +import org.apache.sysds.runtime.frame.data.lib.FrameLibApplySchema; import org.apache.sysds.runtime.frame.data.lib.FrameLibDetectSchema; import org.apache.sysds.runtime.frame.data.lib.FrameLibRemoveEmpty; import org.apache.sysds.runtime.frame.data.lib.FrameUtil; @@ -214,8 +217,8 @@ public FrameBlock(Array[] data) { if(debug) { for(int i = 0; i < data.length; i++) { if(data[i].size() != getNumRows()) - throw new DMLRuntimeException( - "Invalid Frame allocation with different size arrays " + data[i].size() + " vs " + getNumRows()); + throw new DMLRuntimeException("Invalid Frame allocation with different size arrays " + + data[i].size() + " vs " + getNumRows()); } } } @@ -239,8 +242,8 @@ public FrameBlock(Array[] data, String[] colnames) { if(debug) { for(int i = 0; i < data.length; i++) { if(data[i].size() != getNumRows()) - throw new DMLRuntimeException( - "Invalid Frame allocation with different size arrays " + data[i].size() + " vs " + getNumRows()); + throw new DMLRuntimeException("Invalid Frame allocation with different size arrays " + + data[i].size() + " vs " + 
getNumRows()); } } } @@ -400,8 +403,8 @@ public Map getColumnNameIDMap() { } /** - * Allocate column data structures if necessary, i.e., if schema specified but not all column data structures created - * yet. + * Allocate column data structures if necessary, i.e., if schema specified but not all column data structures + * created yet. * * @param numRows number of rows */ @@ -640,8 +643,8 @@ public void appendColumn(int[] col) { } /** - * Append a column of value type LONG as the last column of the data frame. The given array is wrapped but not copied - * and hence might be updated in the future. + * Append a column of value type LONG as the last column of the data frame. The given array is wrapped but not + * copied and hence might be updated in the future. * * @param col array of longs */ @@ -701,7 +704,9 @@ public void appendColumns(double[][] cols) { Array[] tmpData = new Array[ncol]; for(int j = 0; j < ncol; j++) tmpData[j] = ArrayFactory.create(cols[j]); - _colnames = empty ? null : ArrayUtils.addAll(getColumnNames(), createColNames(getNumColumns(), ncol)); // before schema modification + _colnames = empty ? null : ArrayUtils.addAll(getColumnNames(), createColNames(getNumColumns(), ncol)); // before + // schema + // modification _schema = empty ? tmpSchema : ArrayUtils.addAll(_schema, tmpSchema); _coldata = empty ? 
tmpData : ArrayUtils.addAll(_coldata, tmpData);
 		_nRow = cols[0].length;
@@ -859,17 +864,22 @@ private double arraysSizeInMemory() {
 		if(rlen > 1000 && clen > 10 && ConfigurationManager.isParallelIOEnabled()) {
 			final ExecutorService pool = CommonThreadPool.get();
 			try {
-				size += pool.submit(() -> {
-					return Arrays.stream(_coldata).parallel() // parallel columns
-						.map(x ->x.getInMemorySize()).reduce(0L, (a,x) -> a + x);
-				}).get();
+				final List<Future<Long>> f = new ArrayList<>(clen);
+				for(int i = 0; i < clen; i++) {
+					final int j = i;
+					f.add(pool.submit(() -> _coldata[j].getInMemorySize()));
+				}
+				long parallelSize = 0; // kept local: a failing future must not leave a partial sum in size
+				for(Future<Long> e : f)
+					parallelSize += e.get();
+				size += parallelSize; // only reached when every column was measured, so the fallback never double counts
 			}
 			catch(InterruptedException | ExecutionException e) {
 				LOG.error(e);
 				for(Array aa : _coldata)
 					size += aa.getInMemorySize();
 			}
-			finally{
+			finally {
 				pool.shutdown();
 			}
 		}
@@ -1012,11 +1022,11 @@ public FrameBlock leftIndexingOperations(FrameBlock rhsFrame, IndexRange ixrange
 	public FrameBlock leftIndexingOperations(FrameBlock rhsFrame, int rl, int ru, int cl, int cu, FrameBlock ret) {

 		// check the validity of bounds
-		if(rl < 0 || rl >= getNumRows() || ru < rl || ru >= getNumRows() || cl < 0 || cu >= getNumColumns() || cu < cl ||
-			cu >= getNumColumns()) {
+		if(rl < 0 || rl >= getNumRows() || ru < rl || ru >= getNumRows() || cl < 0 || cu >= getNumColumns() ||
+			cu < cl || cu >= getNumColumns()) {
 			throw new DMLRuntimeException(
-				"Invalid values for frame indexing: [" + (rl + 1) + ":" + (ru + 1) + "," + (cl + 1) + ":" + (cu + 1) + "] "
-					+ "must be within frame dimensions [" + getNumRows() + "," + getNumColumns() + "].");
+				"Invalid values for frame indexing: [" + (rl + 1) + ":" + (ru + 1) + "," + (cl + 1) + ":" + (cu + 1)
+					+ "] " + "must be within frame dimensions [" + getNumRows() + "," + getNumColumns() + "].");
 		}

 		if((ru - rl + 1) < rhsFrame.getNumRows() || (cu - cl + 1) < rhsFrame.getNumColumns()) {
@@ -1132,11 +1142,11 @@ public FrameBlock slice(int rl, int ru, int cl, int cu, boolean deep, FrameBlock
 	}

 	protected void 
validateSliceArgument(int rl, int ru, int cl, int cu) { - if(rl < 0 || rl >= getNumRows() || ru < rl || ru >= getNumRows() || cl < 0 || cu >= getNumColumns() || cu < cl || - cu >= getNumColumns()) { + if(rl < 0 || rl >= getNumRows() || ru < rl || ru >= getNumRows() || cl < 0 || cu >= getNumColumns() || + cu < cl || cu >= getNumColumns()) { throw new DMLRuntimeException( - "Invalid values for frame indexing: [" + (rl + 1) + ":" + (ru + 1) + "," + (cl + 1) + ":" + (cu + 1) + "] " - + "must be within frame dimensions [" + getNumRows() + "," + getNumColumns() + "]"); + "Invalid values for frame indexing: [" + (rl + 1) + ":" + (ru + 1) + "," + (cl + 1) + ":" + (cu + 1) + + "] " + "must be within frame dimensions [" + getNumRows() + "," + getNumColumns() + "]"); } } @@ -1338,6 +1348,14 @@ public final FrameBlock detectSchema(double sampleFraction, int k) { return FrameLibDetectSchema.detectSchema(this, sampleFraction, k); } + public final FrameBlock applySchema(FrameBlock schema) { + return FrameLibApplySchema.applySchema(this, schema); + } + + public final FrameBlock applySchema(FrameBlock schema, int k) { + return FrameLibApplySchema.applySchema(this, schema, k); + } + /** * Drop the cell value which does not confirms to the data type of its column * @@ -1347,8 +1365,8 @@ public final FrameBlock detectSchema(double sampleFraction, int k) { public FrameBlock dropInvalidType(FrameBlock schema) { // sanity checks if(this.getNumColumns() != schema.getNumColumns()) - throw new DMLException("mismatch in number of columns in frame and its schema " + this.getNumColumns() + " != " - + schema.getNumColumns()); + throw new DMLException("mismatch in number of columns in frame and its schema " + this.getNumColumns() + + " != " + schema.getNumColumns()); // extract the schema in String array String[] schemaString = IteratorFactory.getStringRowIterator(schema).next(); @@ -1375,8 +1393,8 @@ else if(schemaCol.contains("STRING")) if(!dataType.toString().contains(type) && !(dataType == 
ValueType.BOOLEAN && type.equals("INT")) && !(dataType == ValueType.BOOLEAN && type.equals("FP"))) { - LOG.warn("Datatype detected: " + dataType + " where expected: " + schemaString[i] + " col: " + (i + 1) - + ", row:" + (j + 1)); + LOG.warn("Datatype detected: " + dataType + " where expected: " + schemaString[i] + " col: " + + (i + 1) + ", row:" + (j + 1)); this.set(j, i, null); } @@ -1554,12 +1572,9 @@ public FrameBlock map(FrameMapFunction lambdaExpr, long margin) { else if(margin == 2) { // Execute map function on columns for(int j = 0; j < getNumColumns(); j++) { - String[] actualColumn = Arrays.copyOfRange((String[]) getColumnData(j), 0, getNumRows()); // since more rows - // can be - // allocated, - // mutable array + // since more rows can be allocated, mutable array + String[] actualColumn = Arrays.copyOfRange((String[]) getColumnData(j), 0, getNumRows()); String[] outColumn = lambdaExpr.apply(actualColumn); - for(int i = 0; i < getNumRows(); i++) output[i][j] = outColumn[i]; } @@ -1615,7 +1630,8 @@ public static FrameMapFunction getCompiledFunction(String lambdaExpr, long margi sb.append(" return String.valueOf(" + expr + "); }}\n"); } else if(varname.length == 2) { - sb.append("public String apply(String " + varname[0].trim() + ", String " + varname[1].trim() + ") {\n"); + sb.append( + "public String apply(String " + varname[0].trim() + ", String " + varname[1].trim() + ") {\n"); sb.append(" return String.valueOf(" + expr + "); }}\n"); } } @@ -1651,11 +1667,12 @@ public FrameBlock replaceOperations(String pattern, String replacement) { boolean NaNp = "NaN".equals(pattern); boolean NaNr = "NaN".equals(replacement); - ValueType patternType = UtilFunctions.isBoolean(pattern) ? ValueType.BOOLEAN : (NumberUtils.isCreatable(pattern) | - NaNp ? (UtilFunctions.isIntegerNumber(pattern) ? ValueType.INT64 : ValueType.FP64) : ValueType.STRING); - ValueType replacementType = UtilFunctions - .isBoolean(replacement) ? 
ValueType.BOOLEAN : (NumberUtils.isCreatable(replacement) | - NaNr ? (UtilFunctions.isIntegerNumber(replacement) ? ValueType.INT64 : ValueType.FP64) : ValueType.STRING); + ValueType patternType = UtilFunctions + .isBoolean(pattern) ? ValueType.BOOLEAN : (NumberUtils.isCreatable(pattern) | + NaNp ? (UtilFunctions.isIntegerNumber(pattern) ? ValueType.INT64 : ValueType.FP64) : ValueType.STRING); + ValueType replacementType = UtilFunctions.isBoolean(replacement) ? ValueType.BOOLEAN : (NumberUtils + .isCreatable(replacement) | + NaNr ? (UtilFunctions.isIntegerNumber(replacement) ? ValueType.INT64 : ValueType.FP64) : ValueType.STRING); if(patternType != replacementType || !ValueType.isSameTypeString(patternType, replacementType)) throw new DMLRuntimeException( diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java index 5b7c84605dd..206a0722d7b 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java @@ -38,4 +38,9 @@ public ABooleanArray(int size) { @Override public abstract ABooleanArray select(boolean[] select, int nTrue); + + @Override + public boolean possiblyContainsNaN(){ + return false; + } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java index ff6c5d3d5f5..7f6698ef18c 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java @@ -394,6 +394,8 @@ public boolean containsNull() { return false; } + public abstract boolean possiblyContainsNaN(); + public Array changeTypeWithNulls(ValueType t) { final ABooleanArray nulls = getNulls(); if(nulls == null) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java 
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java index d87cf39666e..47ec83c7884 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java @@ -335,6 +335,11 @@ public boolean equals(Array other){ return false; } + @Override + public boolean possiblyContainsNaN(){ + return false; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_data.length * 2 + 15); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java index c362c6e8007..9b49839721a 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java @@ -315,6 +315,12 @@ public boolean equals(Array other) { return false; } + @Override + public boolean possiblyContainsNaN(){ + return dict.possiblyContainsNaN(); + } + + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java index 575313a450d..7abee26fdc7 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java @@ -387,6 +387,11 @@ public boolean equals(Array other) { return false; } + @Override + public boolean possiblyContainsNaN(){ + return true; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_data.length * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java index 5f463bb0ce6..6eb7885a4d5 100644 --- 
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java @@ -339,6 +339,11 @@ public boolean equals(Array other) { return false; } + @Override + public boolean possiblyContainsNaN(){ + return true; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_data.length * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java index 2c6d3e80f4f..6ebd3d9a844 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java @@ -344,6 +344,11 @@ public boolean equals(Array other) { return false; } + @Override + public boolean possiblyContainsNaN(){ + return false; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_data.length * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java index 02fa0386f6f..b46d86da1e6 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java @@ -346,6 +346,11 @@ public boolean equals(Array other) { return false; } + @Override + public boolean possiblyContainsNaN(){ + return false; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_data.length * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java index 772b07af8b0..99444015d43 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java @@ -456,6 
+456,11 @@ public boolean equals(Array other) { return false; } + @Override + public boolean possiblyContainsNaN(){ + return true; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_size + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java index a2745df32a9..450a9efa456 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java @@ -395,6 +395,11 @@ public boolean containsNull() { return (_a.size() < super._size) || _a.containsNull(); } + @Override + public boolean possiblyContainsNaN(){ + return true; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_size + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index e24815aebaf..9f0a3596440 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -37,6 +37,9 @@ import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode; import org.apache.sysds.utils.MemoryEstimates; +import ch.randelshofer.fastdoubleparser.JavaDoubleParser; +import ch.randelshofer.fastdoubleparser.JavaFloatParser; + public class StringArray extends Array { private String[] _data; @@ -466,7 +469,7 @@ protected Array changeTypeDouble() { for(int i = 0; i < size(); i++) { final String s = _data[i]; if(s != null) - ret[i] = Double.parseDouble(s); + ret[i] = JavaDoubleParser.parseDouble(s); } return new DoubleArray(ret); } @@ -482,7 +485,7 @@ protected Array changeTypeFloat() { for(int i = 0; i < size(); i++) { final String s = _data[i]; if(s != null) - ret[i] = Float.parseFloat(s); + ret[i] = JavaFloatParser.parseFloat(s); } return 
new FloatArray(ret); } @@ -678,6 +681,11 @@ public boolean equals(Array other) { return false; } + @Override + public boolean possiblyContainsNaN(){ + return true; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_size * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java index cddaa2313dd..92372ecab23 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java @@ -19,9 +19,11 @@ package org.apache.sysds.runtime.frame.data.lib; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; -import java.util.stream.IntStream; +import java.util.concurrent.Future; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -45,6 +47,10 @@ public class FrameLibApplySchema { private final int k; + public static FrameBlock applySchema(FrameBlock fb, FrameBlock schema) { + return applySchema(fb, schema, 1); + } + /** * Method to create a new FrameBlock where the given schema is applied, k is the parallelization degree. 
* 
@@ -136,18 +142,21 @@ private void apply(int i) {
 	private void applyMultiThread() {
 		final ExecutorService pool = CommonThreadPool.get(k);
 		try {
-
-			pool.submit(() -> {
-				IntStream.rangeClosed(0, nCol - 1).parallel() // parallel columns
-					.forEach(x -> apply(x));
-			}).get();
-
-			pool.shutdown();
+			// one task per column; waiting on every future surfaces the first failure as ExecutionException
+			final List<Future<?>> f = new ArrayList<>(nCol);
+			for(int i = 0; i < nCol; i++) {
+				final int j = i;
+				f.add(pool.submit(() -> apply(j)));
+			}
+			for(Future<?> e : f)
+				e.get();
 		}
 		catch(InterruptedException | ExecutionException e) {
-			pool.shutdown();
-			throw new DMLRuntimeException("Failed to combine column groups", e);
+			throw new DMLRuntimeException("Failed to apply schema", e);
 		}
+		finally {
+			pool.shutdown();
+		}
 	}

 	private void verifySize() {
diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java
index 7f66ecfe6d3..d9b1f739ba1 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java
@@ -98,19 +98,63 @@ else if(integerFloatPattern.matcher(val).matches()) {
 	}

+	/**
+	 * Detect whether a string holds a floating point value and which precision is needed to represent it.
+	 * 
+	 * @param val the string to classify
+	 * @param len the number of characters of val to consider, presumably val.length() (TODO confirm with callers)
+	 * @return FP32 or FP64 if val is a float, FP32 for the literals "infinity", "-infinity" and "nan", otherwise null
+	 */
 	public static ValueType isFloatType(final String val, final int len) {
+		// the length shortcut is only sound for plain decimals; anything admitted by floatPattern alone is parsed
+		final boolean simple = len <= 25 && simpleFloatMatch(val, len);
+		if(simple || (len <= 25 && floatPattern.matcher(val).matches())) {
+			if(simple && (len <= 7 || (len == 8 && val.charAt(0) == '-')))
+				return ValueType.FP32;
+			else if(len >= 13)
+				return ValueType.FP64;

-		if(len <= 25 && floatPattern.matcher(val).matches()) {
 			final double d = Double.parseDouble(val);
-			if(same(d, (float) d))
+			if(d >= 10000 || d < 0.00001)
+				return ValueType.FP64; // just to be safe.
+			else if(same(d, (float) d))
 				return ValueType.FP32;
 			else
 				return ValueType.FP64;
 		}
 		else if(val.equals("infinity") || val.equals("-infinity") || val.equals("nan"))
-			return ValueType.FP64;
+			return ValueType.FP32;
 		return null;
 	}

+	/**
+	 * Cheap matcher for plain decimal literals, used to avoid the regex in the common case.
+	 * 
+	 * Accepts an optional leading '-' followed by digits with at most one '.', and requires at least one digit. A
+	 * ',' is rejected because Double.parseDouble, which isFloatType applies to matched values, does not accept it.
+	 * 
+	 * @param val the string to check
+	 * @param len the number of leading characters of val to check
+	 * @return true if the checked characters form a plain decimal literal
+	 */
+	private static boolean simpleFloatMatch(final String val, final int len) {
+		if(len <= 0)
+			return false; // nothing to match, and charAt(0) below would fail on an empty string
+		boolean encounteredDot = false;
+		boolean encounteredDigit = false;
+		final int start = val.charAt(0) == '-' && len > 1 ? 1 : 0;
+		for(int i = start; i < len; i++) {
+			final char c = val.charAt(i);
+			if(c >= '0' && c <= '9')
+				encounteredDigit = true;
+			else if(c == '.' && !encounteredDot)
+				encounteredDot = true;
+			else
+				return false;
+		}
+		// "." and "-." contain no digit and are not parsable numbers
+		return encounteredDigit;
+	}
+
 	private static boolean same(double d, float f) {
 		// parse float and double,
 		// and make back to string if equivalent use float.
@@ -208,7 +252,8 @@ public static FrameBlock mergeSchema(FrameBlock temp1, FrameBlock temp2) {
 		String[] rowTemp2 = IteratorFactory.getStringRowIterator(temp2).next();

 		if(rowTemp1.length != rowTemp2.length)
-			throw new DMLRuntimeException("Schema dimension " + "mismatch: " + rowTemp1.length + " vs " + rowTemp2.length);
+			throw new DMLRuntimeException(
+				"Schema dimension " + "mismatch: " + rowTemp1.length + " vs " + rowTemp2.length);

 		for(int i = 0; i < rowTemp1.length; i++) {
 			// modify schema1 if necessary (different schema2)
@@ -234,7 +279,6 @@ else if(rowTemp1[i].equals("INT32") || rowTemp2[i].equals("CHARACTER"))
 		return mergedFrame;
 	}

-
 	public static boolean isDefault(String v, ValueType t) {
 		if(v == null)
 			return true;
diff --git a/src/test/java/org/apache/sysds/performance/simple/DetectTypeArray.java b/src/test/java/org/apache/sysds/performance/simple/DetectTypeArray.java
new file mode 100644
index 00000000000..f9fdf1b9547
--- /dev/null
+++ b/src/test/java/org/apache/sysds/performance/simple/DetectTypeArray.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.performance.simple; + +import java.util.Random; + +import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; +import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.ArrayFactory; + +public class DetectTypeArray { + + public static void main(String[] args) { + Array a = ArrayFactory.create(generateRandomFloatString(1000, 134)); + + Timing t = new Timing(); + t.start(); + int N = 10000; + for(int i = 0; i < N; i++) + a.analyzeValueType(); + + System.out.println(t.stop() / N); + + } + + public static String[] generateRandomFloatString(int size, int seed) { + Random r = new Random(seed); + String[] ret = new String[size]; + for(int i = 0; i < size; i++) { + int e = r.nextInt(999); + int a = r.nextInt(999); + + ret[i] = String.format("%d.%03d", e, a); + } + + return ret; + } + +} diff --git a/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java b/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java index 562b164a6f7..340f385d88f 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java @@ -129,17 +129,17 @@ public void testIsIntLongString() { @Test public void testInfinite() { - assertEquals(ValueType.FP64, FrameUtil.isType("infinity")); + assertEquals(ValueType.FP32, FrameUtil.isType("infinity")); } @Test public void testMinusInfinite() { - assertEquals(ValueType.FP64, FrameUtil.isType("-infinity")); + 
assertEquals(ValueType.FP32, FrameUtil.isType("-infinity")); } @Test public void testNan() { - assertEquals(ValueType.FP64, FrameUtil.isType("nan")); + assertEquals(ValueType.FP32, FrameUtil.isType("nan")); } @Test diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java index c712789ab71..91a17d2d2cb 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java @@ -189,7 +189,9 @@ public void testGet() { for(int i = 0; i < size; i++) { Object av = a.get(i); Object sv = s.get(i); - if(!(av == null && sv == null)) + if((av == null && sv != null) || (sv == null && av != null)) + fail("not both null"); + else if(av != null && sv != null) assertTrue(av.toString().equals(sv)); } } @@ -1720,7 +1722,9 @@ protected static void compare(Array a, Array b) { for(int i = 0; i < size; i++) { final Object av = a.get(i); final Object bv = b.get(i); - if(!(av == null && bv == null)) + if((av == null && bv != null) || (bv == null && av != null)) + fail("not both null"); + else if(av != null && bv != null) assertTrue(err, av.toString().equals(bv.toString())); } } @@ -1730,7 +1734,9 @@ protected static void compare(Array sub, Array b, int off) { for(int i = 0; i < size; i++) { final Object av = sub.get(i); final Object bv = b.get(i + off); - if(!(av == null && bv == null)) + if((av == null && bv != null) || (bv == null && av != null)) + fail("not both null"); + else if(av != null && bv != null) assertTrue(av.toString().equals(bv.toString())); } } @@ -1739,7 +1745,9 @@ protected static void compareSetSubRange(Array out, Array in, int rl, int for(int i = rl; i <= ru; i++, off++) { Object av = out.get(i); Object bv = in.get(off); - if(!(av == null && bv == null)) { + if((av == null && bv != null) || (bv == null && av != null)) + fail("not both null"); + else 
if(av != null && bv != null){ String v1 = av.toString(); String v2 = bv.toString(); assertEquals("i: " + i + " args: " + rl + " " + ru + " " + (off - i) + " " + out.size(), v1, v2);