From 4851f34bd58777a7b4605ed2997b3302eddd169a Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 6 Sep 2023 10:41:28 +0200 Subject: [PATCH] [MINOR] Frame Float detection refinement This commit refines the detection and selection of float values in the schema detection algorithm located in: src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java To improve performance, a custom matcher has been added to avoid using regexes where possible, and we try to avoid parsing the float value from the string representation if at all possible. The implementation is not completely foolproof and does not consider many adversarial inputs. On a small test case of 1000 string values that are all FP32, the implementation improves performance from 0.5 ms to 0.02 ms. --- .../sysds/runtime/frame/data/FrameBlock.java | 91 +++++++++++-------- .../frame/data/columns/ABooleanArray.java | 5 + .../runtime/frame/data/columns/Array.java | 2 + .../runtime/frame/data/columns/CharArray.java | 5 + .../runtime/frame/data/columns/DDCArray.java | 6 ++ .../frame/data/columns/DoubleArray.java | 5 + .../frame/data/columns/FloatArray.java | 5 + .../frame/data/columns/IntegerArray.java | 5 + .../runtime/frame/data/columns/LongArray.java | 5 + .../frame/data/columns/OptionalArray.java | 5 + .../frame/data/columns/RaggedArray.java | 5 + .../frame/data/columns/StringArray.java | 12 ++- .../frame/data/lib/FrameLibApplySchema.java | 28 ++++-- .../runtime/frame/data/lib/FrameUtil.java | 35 ++++++- .../performance/simple/DetectTypeArray.java | 56 ++++++++++++ .../test/component/frame/FrameUtilTest.java | 6 +- .../frame/array/FrameArrayTests.java | 16 +++- 17 files changed, 232 insertions(+), 60 deletions(-) create mode 100644 src/test/java/org/apache/sysds/performance/simple/DetectTypeArray.java diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java index 513b788b605..737922397e5 100644 --- 
a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java @@ -31,9 +31,11 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import java.util.concurrent.ThreadLocalRandom; import java.util.function.Function; @@ -56,6 +58,7 @@ import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; import org.apache.sysds.runtime.frame.data.lib.FrameFromMatrixBlock; import org.apache.sysds.runtime.frame.data.lib.FrameLibAppend; +import org.apache.sysds.runtime.frame.data.lib.FrameLibApplySchema; import org.apache.sysds.runtime.frame.data.lib.FrameLibDetectSchema; import org.apache.sysds.runtime.frame.data.lib.FrameLibRemoveEmpty; import org.apache.sysds.runtime.frame.data.lib.FrameUtil; @@ -214,8 +217,8 @@ public FrameBlock(Array[] data) { if(debug) { for(int i = 0; i < data.length; i++) { if(data[i].size() != getNumRows()) - throw new DMLRuntimeException( - "Invalid Frame allocation with different size arrays " + data[i].size() + " vs " + getNumRows()); + throw new DMLRuntimeException("Invalid Frame allocation with different size arrays " + + data[i].size() + " vs " + getNumRows()); } } } @@ -239,8 +242,8 @@ public FrameBlock(Array[] data, String[] colnames) { if(debug) { for(int i = 0; i < data.length; i++) { if(data[i].size() != getNumRows()) - throw new DMLRuntimeException( - "Invalid Frame allocation with different size arrays " + data[i].size() + " vs " + getNumRows()); + throw new DMLRuntimeException("Invalid Frame allocation with different size arrays " + + data[i].size() + " vs " + getNumRows()); } } } @@ -400,8 +403,8 @@ public Map getColumnNameIDMap() { } /** - * Allocate column data structures if necessary, i.e., if schema specified but not all column data structures 
created - * yet. + * Allocate column data structures if necessary, i.e., if schema specified but not all column data structures + * created yet. * * @param numRows number of rows */ @@ -640,8 +643,8 @@ public void appendColumn(int[] col) { } /** - * Append a column of value type LONG as the last column of the data frame. The given array is wrapped but not copied - * and hence might be updated in the future. + * Append a column of value type LONG as the last column of the data frame. The given array is wrapped but not + * copied and hence might be updated in the future. * * @param col array of longs */ @@ -701,7 +704,9 @@ public void appendColumns(double[][] cols) { Array[] tmpData = new Array[ncol]; for(int j = 0; j < ncol; j++) tmpData[j] = ArrayFactory.create(cols[j]); - _colnames = empty ? null : ArrayUtils.addAll(getColumnNames(), createColNames(getNumColumns(), ncol)); // before schema modification + _colnames = empty ? null : ArrayUtils.addAll(getColumnNames(), createColNames(getNumColumns(), ncol)); // before + // schema + // modification _schema = empty ? tmpSchema : ArrayUtils.addAll(_schema, tmpSchema); _coldata = empty ? 
tmpData : ArrayUtils.addAll(_coldata, tmpData); _nRow = cols[0].length; @@ -859,17 +864,22 @@ private double arraysSizeInMemory() { if(rlen > 1000 && clen > 10 && ConfigurationManager.isParallelIOEnabled()) { final ExecutorService pool = CommonThreadPool.get(); try { - size += pool.submit(() -> { - return Arrays.stream(_coldata).parallel() // parallel columns - .map(x ->x.getInMemorySize()).reduce(0L, (a,x) -> a + x); - }).get(); + List> f = new ArrayList<>(clen); + for(int i = 0; i < clen; i++) { + final int j = i; + f.add(pool.submit(() -> _coldata[j].getInMemorySize())); + } + + for(Future e : f) { + size += e.get(); + } } catch(InterruptedException | ExecutionException e) { LOG.error(e); for(Array aa : _coldata) size += aa.getInMemorySize(); } - finally{ + finally { pool.shutdown(); } } @@ -1012,11 +1022,11 @@ public FrameBlock leftIndexingOperations(FrameBlock rhsFrame, IndexRange ixrange public FrameBlock leftIndexingOperations(FrameBlock rhsFrame, int rl, int ru, int cl, int cu, FrameBlock ret) { // check the validity of bounds - if(rl < 0 || rl >= getNumRows() || ru < rl || ru >= getNumRows() || cl < 0 || cu >= getNumColumns() || cu < cl || - cu >= getNumColumns()) { + if(rl < 0 || rl >= getNumRows() || ru < rl || ru >= getNumRows() || cl < 0 || cu >= getNumColumns() || + cu < cl || cu >= getNumColumns()) { throw new DMLRuntimeException( - "Invalid values for frame indexing: [" + (rl + 1) + ":" + (ru + 1) + "," + (cl + 1) + ":" + (cu + 1) + "] " - + "must be within frame dimensions [" + getNumRows() + "," + getNumColumns() + "]."); + "Invalid values for frame indexing: [" + (rl + 1) + ":" + (ru + 1) + "," + (cl + 1) + ":" + (cu + 1) + + "] " + "must be within frame dimensions [" + getNumRows() + "," + getNumColumns() + "]."); } if((ru - rl + 1) < rhsFrame.getNumRows() || (cu - cl + 1) < rhsFrame.getNumColumns()) { @@ -1132,11 +1142,11 @@ public FrameBlock slice(int rl, int ru, int cl, int cu, boolean deep, FrameBlock } protected void 
validateSliceArgument(int rl, int ru, int cl, int cu) { - if(rl < 0 || rl >= getNumRows() || ru < rl || ru >= getNumRows() || cl < 0 || cu >= getNumColumns() || cu < cl || - cu >= getNumColumns()) { + if(rl < 0 || rl >= getNumRows() || ru < rl || ru >= getNumRows() || cl < 0 || cu >= getNumColumns() || + cu < cl || cu >= getNumColumns()) { throw new DMLRuntimeException( - "Invalid values for frame indexing: [" + (rl + 1) + ":" + (ru + 1) + "," + (cl + 1) + ":" + (cu + 1) + "] " - + "must be within frame dimensions [" + getNumRows() + "," + getNumColumns() + "]"); + "Invalid values for frame indexing: [" + (rl + 1) + ":" + (ru + 1) + "," + (cl + 1) + ":" + (cu + 1) + + "] " + "must be within frame dimensions [" + getNumRows() + "," + getNumColumns() + "]"); } } @@ -1338,6 +1348,14 @@ public final FrameBlock detectSchema(double sampleFraction, int k) { return FrameLibDetectSchema.detectSchema(this, sampleFraction, k); } + public final FrameBlock applySchema(FrameBlock schema) { + return FrameLibApplySchema.applySchema(this, schema); + } + + public final FrameBlock applySchema(FrameBlock schema, int k) { + return FrameLibApplySchema.applySchema(this, schema, k); + } + /** * Drop the cell value which does not confirms to the data type of its column * @@ -1347,8 +1365,8 @@ public final FrameBlock detectSchema(double sampleFraction, int k) { public FrameBlock dropInvalidType(FrameBlock schema) { // sanity checks if(this.getNumColumns() != schema.getNumColumns()) - throw new DMLException("mismatch in number of columns in frame and its schema " + this.getNumColumns() + " != " - + schema.getNumColumns()); + throw new DMLException("mismatch in number of columns in frame and its schema " + this.getNumColumns() + + " != " + schema.getNumColumns()); // extract the schema in String array String[] schemaString = IteratorFactory.getStringRowIterator(schema).next(); @@ -1375,8 +1393,8 @@ else if(schemaCol.contains("STRING")) if(!dataType.toString().contains(type) && !(dataType == 
ValueType.BOOLEAN && type.equals("INT")) && !(dataType == ValueType.BOOLEAN && type.equals("FP"))) { - LOG.warn("Datatype detected: " + dataType + " where expected: " + schemaString[i] + " col: " + (i + 1) - + ", row:" + (j + 1)); + LOG.warn("Datatype detected: " + dataType + " where expected: " + schemaString[i] + " col: " + + (i + 1) + ", row:" + (j + 1)); this.set(j, i, null); } @@ -1554,12 +1572,9 @@ public FrameBlock map(FrameMapFunction lambdaExpr, long margin) { else if(margin == 2) { // Execute map function on columns for(int j = 0; j < getNumColumns(); j++) { - String[] actualColumn = Arrays.copyOfRange((String[]) getColumnData(j), 0, getNumRows()); // since more rows - // can be - // allocated, - // mutable array + // since more rows can be allocated, mutable array + String[] actualColumn = Arrays.copyOfRange((String[]) getColumnData(j), 0, getNumRows()); String[] outColumn = lambdaExpr.apply(actualColumn); - for(int i = 0; i < getNumRows(); i++) output[i][j] = outColumn[i]; } @@ -1615,7 +1630,8 @@ public static FrameMapFunction getCompiledFunction(String lambdaExpr, long margi sb.append(" return String.valueOf(" + expr + "); }}\n"); } else if(varname.length == 2) { - sb.append("public String apply(String " + varname[0].trim() + ", String " + varname[1].trim() + ") {\n"); + sb.append( + "public String apply(String " + varname[0].trim() + ", String " + varname[1].trim() + ") {\n"); sb.append(" return String.valueOf(" + expr + "); }}\n"); } } @@ -1651,11 +1667,12 @@ public FrameBlock replaceOperations(String pattern, String replacement) { boolean NaNp = "NaN".equals(pattern); boolean NaNr = "NaN".equals(replacement); - ValueType patternType = UtilFunctions.isBoolean(pattern) ? ValueType.BOOLEAN : (NumberUtils.isCreatable(pattern) | - NaNp ? (UtilFunctions.isIntegerNumber(pattern) ? ValueType.INT64 : ValueType.FP64) : ValueType.STRING); - ValueType replacementType = UtilFunctions - .isBoolean(replacement) ? 
ValueType.BOOLEAN : (NumberUtils.isCreatable(replacement) | - NaNr ? (UtilFunctions.isIntegerNumber(replacement) ? ValueType.INT64 : ValueType.FP64) : ValueType.STRING); + ValueType patternType = UtilFunctions + .isBoolean(pattern) ? ValueType.BOOLEAN : (NumberUtils.isCreatable(pattern) | + NaNp ? (UtilFunctions.isIntegerNumber(pattern) ? ValueType.INT64 : ValueType.FP64) : ValueType.STRING); + ValueType replacementType = UtilFunctions.isBoolean(replacement) ? ValueType.BOOLEAN : (NumberUtils + .isCreatable(replacement) | + NaNr ? (UtilFunctions.isIntegerNumber(replacement) ? ValueType.INT64 : ValueType.FP64) : ValueType.STRING); if(patternType != replacementType || !ValueType.isSameTypeString(patternType, replacementType)) throw new DMLRuntimeException( diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java index 5b7c84605dd..206a0722d7b 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java @@ -38,4 +38,9 @@ public ABooleanArray(int size) { @Override public abstract ABooleanArray select(boolean[] select, int nTrue); + + @Override + public boolean possiblyContainsNaN(){ + return false; + } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java index ff6c5d3d5f5..7f6698ef18c 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java @@ -394,6 +394,8 @@ public boolean containsNull() { return false; } + public abstract boolean possiblyContainsNaN(); + public Array changeTypeWithNulls(ValueType t) { final ABooleanArray nulls = getNulls(); if(nulls == null) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java 
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java index d87cf39666e..47ec83c7884 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java @@ -335,6 +335,11 @@ public boolean equals(Array other){ return false; } + @Override + public boolean possiblyContainsNaN(){ + return false; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_data.length * 2 + 15); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java index c362c6e8007..9b49839721a 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java @@ -315,6 +315,12 @@ public boolean equals(Array other) { return false; } + @Override + public boolean possiblyContainsNaN(){ + return dict.possiblyContainsNaN(); + } + + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java index 575313a450d..7abee26fdc7 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java @@ -387,6 +387,11 @@ public boolean equals(Array other) { return false; } + @Override + public boolean possiblyContainsNaN(){ + return true; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_data.length * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java index 5f463bb0ce6..6eb7885a4d5 100644 --- 
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java @@ -339,6 +339,11 @@ public boolean equals(Array other) { return false; } + @Override + public boolean possiblyContainsNaN(){ + return true; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_data.length * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java index 2c6d3e80f4f..6ebd3d9a844 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java @@ -344,6 +344,11 @@ public boolean equals(Array other) { return false; } + @Override + public boolean possiblyContainsNaN(){ + return false; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_data.length * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java index 02fa0386f6f..b46d86da1e6 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java @@ -346,6 +346,11 @@ public boolean equals(Array other) { return false; } + @Override + public boolean possiblyContainsNaN(){ + return false; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_data.length * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java index 772b07af8b0..99444015d43 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java @@ -456,6 
+456,11 @@ public boolean equals(Array other) { return false; } + @Override + public boolean possiblyContainsNaN(){ + return true; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_size + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java index a2745df32a9..450a9efa456 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java @@ -395,6 +395,11 @@ public boolean containsNull() { return (_a.size() < super._size) || _a.containsNull(); } + @Override + public boolean possiblyContainsNaN(){ + return true; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_size + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index e24815aebaf..9f0a3596440 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -37,6 +37,9 @@ import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode; import org.apache.sysds.utils.MemoryEstimates; +import ch.randelshofer.fastdoubleparser.JavaDoubleParser; +import ch.randelshofer.fastdoubleparser.JavaFloatParser; + public class StringArray extends Array { private String[] _data; @@ -466,7 +469,7 @@ protected Array changeTypeDouble() { for(int i = 0; i < size(); i++) { final String s = _data[i]; if(s != null) - ret[i] = Double.parseDouble(s); + ret[i] = JavaDoubleParser.parseDouble(s); } return new DoubleArray(ret); } @@ -482,7 +485,7 @@ protected Array changeTypeFloat() { for(int i = 0; i < size(); i++) { final String s = _data[i]; if(s != null) - ret[i] = Float.parseFloat(s); + ret[i] = JavaFloatParser.parseFloat(s); } return 
new FloatArray(ret); } @@ -678,6 +681,11 @@ public boolean equals(Array other) { return false; } + @Override + public boolean possiblyContainsNaN(){ + return true; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_size * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java index cddaa2313dd..92372ecab23 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java @@ -19,9 +19,11 @@ package org.apache.sysds.runtime.frame.data.lib; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; -import java.util.stream.IntStream; +import java.util.concurrent.Future; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -45,6 +47,10 @@ public class FrameLibApplySchema { private final int k; + public static FrameBlock applySchema(FrameBlock fb, FrameBlock schema) { + return applySchema(fb, schema, 1); + } + /** * Method to create a new FrameBlock where the given schema is applied, k is the parallelization degree. 
* @@ -136,18 +142,22 @@ private void apply(int i) { private void applyMultiThread() { final ExecutorService pool = CommonThreadPool.get(k); try { - - pool.submit(() -> { - IntStream.rangeClosed(0, nCol - 1).parallel() // parallel columns - .forEach(x -> apply(x)); - }).get(); - - pool.shutdown(); + List> f = new ArrayList<>(nCol ); + for(int i = 0; i < nCol ; i ++){ + final int j = i; + f.add(pool.submit(() -> apply(j))); + } + + for( Future e : f) + e.get(); } catch(InterruptedException | ExecutionException e) { - pool.shutdown(); throw new DMLRuntimeException("Failed to combine column groups", e); } + finally{ + pool.shutdown(); + + } } private void verifySize() { diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java index 7f66ecfe6d3..d9b1f739ba1 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java @@ -98,19 +98,44 @@ else if(integerFloatPattern.matcher(val).matches()) { } public static ValueType isFloatType(final String val, final int len) { + if(len <= 25 && (simpleFloatMatch(val, len) || floatPattern.matcher(val).matches())) { + if(len <= 7 || (len == 8 && val.charAt(0) == '-')) + return ValueType.FP32; + else if(len >= 13) + return ValueType.FP64; - if(len <= 25 && floatPattern.matcher(val).matches()) { final double d = Double.parseDouble(val); - if(same(d, (float) d)) + if(d >= 10000 || d < 0.00001) + return ValueType.FP64; // just to be safe. + else if(same(d, (float) d)) return ValueType.FP32; else return ValueType.FP64; } else if(val.equals("infinity") || val.equals("-infinity") || val.equals("nan")) - return ValueType.FP64; + return ValueType.FP32; return null; } + private static boolean simpleFloatMatch(final String val, final int len) { + // a simple float matcher to avoid using the Regex. 
+ boolean encounteredDot = false; + int start = val.charAt(0) == '-' && len > 1 ? 1 : 0; + for(int i = start; i < len; i++) { + final char c = val.charAt(i); + if(c >= '0' && c <= '9') + continue; + else if(c == '.' || c == ',') + if(encounteredDot == true) + return false; + else + encounteredDot = true; + else + return false; + } + return true; + } + private static boolean same(double d, float f) { // parse float and double, // and make back to string if equivalent use float. @@ -208,7 +233,8 @@ public static FrameBlock mergeSchema(FrameBlock temp1, FrameBlock temp2) { String[] rowTemp2 = IteratorFactory.getStringRowIterator(temp2).next(); if(rowTemp1.length != rowTemp2.length) - throw new DMLRuntimeException("Schema dimension " + "mismatch: " + rowTemp1.length + " vs " + rowTemp2.length); + throw new DMLRuntimeException( + "Schema dimension " + "mismatch: " + rowTemp1.length + " vs " + rowTemp2.length); for(int i = 0; i < rowTemp1.length; i++) { // modify schema1 if necessary (different schema2) @@ -234,7 +260,6 @@ else if(rowTemp1[i].equals("INT32") || rowTemp2[i].equals("CHARACTER")) return mergedFrame; } - public static boolean isDefault(String v, ValueType t) { if(v == null) return true; diff --git a/src/test/java/org/apache/sysds/performance/simple/DetectTypeArray.java b/src/test/java/org/apache/sysds/performance/simple/DetectTypeArray.java new file mode 100644 index 00000000000..f9fdf1b9547 --- /dev/null +++ b/src/test/java/org/apache/sysds/performance/simple/DetectTypeArray.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.performance.simple; + +import java.util.Random; + +import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; +import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.ArrayFactory; + +public class DetectTypeArray { + + public static void main(String[] args) { + Array a = ArrayFactory.create(generateRandomFloatString(1000, 134)); + + Timing t = new Timing(); + t.start(); + int N = 10000; + for(int i = 0; i < N; i++) + a.analyzeValueType(); + + System.out.println(t.stop() / N); + + } + + public static String[] generateRandomFloatString(int size, int seed) { + Random r = new Random(seed); + String[] ret = new String[size]; + for(int i = 0; i < size; i++) { + int e = r.nextInt(999); + int a = r.nextInt(999); + + ret[i] = String.format("%d.%03d", e, a); + } + + return ret; + } + +} diff --git a/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java b/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java index 562b164a6f7..340f385d88f 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java @@ -129,17 +129,17 @@ public void testIsIntLongString() { @Test public void testInfinite() { - assertEquals(ValueType.FP64, FrameUtil.isType("infinity")); + assertEquals(ValueType.FP32, FrameUtil.isType("infinity")); } @Test public void testMinusInfinite() { - assertEquals(ValueType.FP64, FrameUtil.isType("-infinity")); + 
assertEquals(ValueType.FP32, FrameUtil.isType("-infinity")); } @Test public void testNan() { - assertEquals(ValueType.FP64, FrameUtil.isType("nan")); + assertEquals(ValueType.FP32, FrameUtil.isType("nan")); } @Test diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java index c712789ab71..91a17d2d2cb 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java @@ -189,7 +189,9 @@ public void testGet() { for(int i = 0; i < size; i++) { Object av = a.get(i); Object sv = s.get(i); - if(!(av == null && sv == null)) + if((av == null && sv != null) || (sv == null && av != null)) + fail("not both null"); + else if(av != null && sv != null) assertTrue(av.toString().equals(sv)); } } @@ -1720,7 +1722,9 @@ protected static void compare(Array a, Array b) { for(int i = 0; i < size; i++) { final Object av = a.get(i); final Object bv = b.get(i); - if(!(av == null && bv == null)) + if((av == null && bv != null) || (bv == null && av != null)) + fail("not both null"); + else if(av != null && bv != null) assertTrue(err, av.toString().equals(bv.toString())); } } @@ -1730,7 +1734,9 @@ protected static void compare(Array sub, Array b, int off) { for(int i = 0; i < size; i++) { final Object av = sub.get(i); final Object bv = b.get(i + off); - if(!(av == null && bv == null)) + if((av == null && bv != null) || (bv == null && av != null)) + fail("not both null"); + else if(av != null && bv != null) assertTrue(av.toString().equals(bv.toString())); } } @@ -1739,7 +1745,9 @@ protected static void compareSetSubRange(Array out, Array in, int rl, int for(int i = rl; i <= ru; i++, off++) { Object av = out.get(i); Object bv = in.get(off); - if(!(av == null && bv == null)) { + if((av == null && bv != null) || (bv == null && av != null)) + fail("not both null"); + else 
if(av != null && bv != null){ String v1 = av.toString(); String v2 = bv.toString(); assertEquals("i: " + i + " args: " + rl + " " + ru + " " + (off - i) + " " + out.size(), v1, v2);