diff --git a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java index 3373205fc35..b4296e227d3 100644 --- a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java +++ b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java @@ -884,6 +884,23 @@ private static String dfFormat(DecimalFormat df, double value) { } } + /** + * Creates a non-grouping {@link DecimalFormat} for printing values. When {@code decimal >= 0} + * both the minimum and maximum fraction digits are pinned to {@code decimal}, so values are + * printed with exactly that many decimals; otherwise the {@link DecimalFormat} defaults apply. + * @param decimal number of decimal places to print, -1 for default + * @return a configured {@link DecimalFormat} + */ + private static DecimalFormat createDecimalFormat(int decimal) { + DecimalFormat df = new DecimalFormat(); + df.setGroupingUsed(false); + if (decimal >= 0) { + df.setMinimumFractionDigits(decimal); + df.setMaximumFractionDigits(decimal); + } + return df; + } + public static String toString(MatrixBlock mb) { return toString(mb, false, " ", "\n", mb.getNumRows(), mb.getNumColumns(), 3); } @@ -913,11 +930,7 @@ public static String toString(MatrixBlock mb, boolean sparse, String separator, if (colsToPrint >= 0) colLength = colsToPrint < clen ? colsToPrint : clen; - DecimalFormat df = new DecimalFormat(); - df.setGroupingUsed(false); - if (decimal >= 0){ - df.setMinimumFractionDigits(decimal); - } + DecimalFormat df = createDecimalFormat(decimal); if (sparse){ // Sparse Print Format if (mb.isInSparseFormat()){ // Block is in sparse format @@ -997,11 +1010,7 @@ public static String toString(TensorBlock tb, boolean sparse, String separator, if (colsToPrint >= 0) colLength = Math.min(colsToPrint, clen); - DecimalFormat df = new DecimalFormat(); - df.setGroupingUsed(false); - if (decimal >= 0){ - df.setMinimumFractionDigits(decimal); - } + DecimalFormat df = createDecimalFormat(decimal); if (sparse){ // Sparse Print Format // TODO use sparse iterator for sparse block @@ -1147,10 +1156,7 @@ public static String toString(FrameBlock fb, boolean sparse, String separator, S sb.append(lineseparator); //print data - DecimalFormat df = new DecimalFormat(); - df.setGroupingUsed(false); - if (decimal >= 0) - df.setMinimumFractionDigits(decimal); + DecimalFormat df = createDecimalFormat(decimal); Iterator iter = IteratorFactory.getObjectRowIterator(fb, 0, rowLength); while( iter.hasNext() ) { diff --git a/src/test/java/org/apache/sysds/test/component/frame/FrameToStringTest.java b/src/test/java/org/apache/sysds/test/component/frame/FrameToStringTest.java index 2b29214b591..60587bf51a2 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/FrameToStringTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/FrameToStringTest.java @@ -19,6 +19,7 @@ package org.apache.sysds.test.component.frame; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import org.apache.sysds.common.Types.ValueType; @@ -38,6 +39,42 @@ public void test100x100() { FrameBlock f = createFrameBlock(); assertTrue(DataConverter.toString(f, false, " ", "\n", 100, 100, 3).length() < 75); } + + @Test + public void testDecimalClampsFractionDigits() { + FrameBlock f = new FrameBlock(new ValueType[]{ValueType.FP64}, new String[]{"C1"}); + f.ensureAllocatedColumns(1); + f.set(0, 0, 5.244058388023880); + // decimal=2 must print exactly two fraction digits, not DecimalFormat's default max of 3 + String out = DataConverter.toString(f, false, " ", "\n", 1, 1, 2); + assertTrue("expected value clamped to 5.24, got: " + out, out.contains("5.24\n")); + assertFalse("decimal=2 must not print three digits: " + out, out.contains("5.244")); + } + + @Test + public void testDecimalPadsAndRounds() { + FrameBlock f = new FrameBlock(new ValueType[]{ValueType.FP64}, new String[]{"C1"}); + f.ensureAllocatedColumns(2); + f.set(0, 0, 22.0); // integer-valued: padded up to the requested digits + f.set(1, 0, 5.244058388023880); // rounded at the last requested digit + String out = DataConverter.toString(f, false, " ", "\n", 2, 1, 4); + assertTrue("expected 22.0000 padded: " + out, out.contains("22.0000\n")); + assertTrue("expected 5.2441 rounded: " + out, out.contains("5.2441\n")); + } + + @Test + public void testNegativeDecimalUsesDefaultFormatting() { + FrameBlock f = new FrameBlock(new ValueType[]{ValueType.FP64}, new String[]{"C1"}); + f.ensureAllocatedColumns(2); + f.set(0, 0, 22.0); // integer-valued: no fraction digits when unconstrained + f.set(1, 0, 5.244058388023880); // default cap of three fraction digits + // decimal < 0 leaves DecimalFormat unconstrained (no min/max fraction digits set) + String out = DataConverter.toString(f, false, " ", "\n", 2, 1, -1); + assertTrue("expected unpadded 22: " + out, out.contains("22\n")); + assertFalse("integer value must not be padded: " + out, out.contains("22.0")); + assertTrue("expected default 5.244: " + out, out.contains("5.244\n")); + assertFalse("must not print a fourth digit: " + out, out.contains("5.2441")); + } private FrameBlock createFrameBlock() { FrameBlock f = new FrameBlock(new ValueType[]{ValueType.STRING, ValueType.STRING}); diff --git a/src/test/java/org/apache/sysds/test/component/tensor/TensorToStringTest.java b/src/test/java/org/apache/sysds/test/component/tensor/TensorToStringTest.java new file mode 100644 index 00000000000..5c9ed821e78 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/tensor/TensorToStringTest.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.tensor; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.data.TensorBlock; +import org.apache.sysds.runtime.util.DataConverter; +import org.junit.Test; + +public class TensorToStringTest { + @Test + public void testDecimalClampsFractionDigits() { + TensorBlock tb = new TensorBlock(ValueType.FP64, new int[]{1, 1}); + tb.allocateBlock(); + tb.set(0, 0, 5.244058388023880); + // decimal=2 must print exactly two fraction digits, not DecimalFormat's default max of 3 + String out = DataConverter.toString(tb, false, " ", "\n", "[", "]", 1, 1, 2); + assertTrue("expected value clamped to 5.24, got: " + out, out.contains("5.24")); + assertFalse("decimal=2 must not print three digits: " + out, out.contains("5.244")); + } + + @Test + public void testDecimalPadsAndRounds() { + TensorBlock tb = new TensorBlock(ValueType.FP64, new int[]{1, 2}); + tb.allocateBlock(); + tb.set(0, 0, 22.0); // integer-valued: padded up to the requested digits + tb.set(0, 1, 5.244058388023880); // rounded at the last requested digit + String out = DataConverter.toString(tb, false, " ", "\n", "[", "]", 1, 2, 4); + assertTrue("expected 22.0000 padded: " + out, out.contains("22.0000")); + assertTrue("expected 5.2441 rounded: " + out, out.contains("5.2441")); + } + + @Test + public void testNegativeDecimalUsesDefaultFormatting() { + TensorBlock tb = new TensorBlock(ValueType.FP64, new int[]{1, 2}); + tb.allocateBlock(); + tb.set(0, 0, 22.0); // integer-valued: no fraction digits when unconstrained + tb.set(0, 1, 5.244058388023880); // default cap of three fraction digits + // decimal < 0 leaves DecimalFormat unconstrained (no min/max fraction digits set) + String out = DataConverter.toString(tb, false, " ", "\n", "[", "]", 1, 2, -1); + assertTrue("expected unpadded 22: " + out, out.contains("22")); + assertFalse("integer value must not be padded: " + out, out.contains("22.0")); + assertTrue("expected default 5.244: " + out, out.contains("5.244")); + assertFalse("must not print a fourth digit: " + out, out.contains("5.2441")); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/misc/ToStringTest.java b/src/test/java/org/apache/sysds/test/functions/misc/ToStringTest.java index ee6a2953980..18ca2fbc454 100644 --- a/src/test/java/org/apache/sysds/test/functions/misc/ToStringTest.java +++ b/src/test/java/org/apache/sysds/test/functions/misc/ToStringTest.java @@ -270,4 +270,96 @@ protected void toStringTestHelper(ExecMode platform, String testName, String exp DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } } + + @Test + public void testPrintWithDecimal(){ + String testName = "ToString12"; + + String decimalPoints = "2"; + String value = "22"; + String expectedOutput = "22.00\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + + @Test + public void testPrintWithDecimal2(){ + String testName = "ToString12"; + + String decimalPoints = "2"; + String value = "5.244058388023880"; + String expectedOutput = "5.24\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + + @Test + public void testPrintWithDecimal3(){ + String testName = "ToString12"; + + String decimalPoints = "10"; + String value = "5.244058388023880"; + String expectedOutput = "5.2440583880\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + + @Test + public void testPrintWithDecimal4(){ + String testName = "ToString12"; + + String decimalPoints = "4"; + String value = "5.244058388023880"; + String expectedOutput = "5.2441\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + + @Test + public void testPrintWithDecimal5(){ + String testName = "ToString12"; + + String decimalPoints = "10"; + String value = "0.000000008023880"; + String expectedOutput = "0.0000000080\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + protected void toStringTestHelper2(ExecMode platform, String testName, String expectedOutput, String decimalPoints, String value) { + ExecMode platformOld = rtplatform; + + rtplatform = platform; + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if (rtplatform == ExecMode.SPARK) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + try { + // Create and load test configuration + getAndLoadTestConfiguration(testName); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testName + ".dml"; + programArgs = new String[]{"-args", output(OUTPUT_NAME), value, decimalPoints}; + + // Run DML and R scripts + runTest(true, false, null, -1); + + // Compare output strings + String output = TestUtils.readDMLString(output(OUTPUT_NAME)); + TestUtils.compareScalars(expectedOutput, output); + } + finally { + // Reset settings + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } } diff --git a/src/test/scripts/functions/misc/ToString12.dml b/src/test/scripts/functions/misc/ToString12.dml new file mode 100644 index 00000000000..4f120630b75 --- /dev/null +++ b/src/test/scripts/functions/misc/ToString12.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix($2, rows=1, cols=1) +str = toString(X, rows=3, cols=3, decimal=$3) +write(str, $1)