Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions src/main/java/org/apache/sysds/runtime/util/DataConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,23 @@ private static String dfFormat(DecimalFormat df, double value) {
}
}

/**
* Creates a non-grouping {@link DecimalFormat} for printing values. When {@code decimal >= 0}
* both the minimum and maximum fraction digits are pinned to {@code decimal}, so values are
* printed with exactly that many decimals; otherwise the {@link DecimalFormat} defaults apply.
* @param decimal number of decimal places to print, -1 for default
* @return a configured {@link DecimalFormat}
*/
private static DecimalFormat createDecimalFormat(int decimal) {
DecimalFormat df = new DecimalFormat();
df.setGroupingUsed(false);
if (decimal >= 0) {
df.setMinimumFractionDigits(decimal);
df.setMaximumFractionDigits(decimal);
}
return df;
}

public static String toString(MatrixBlock mb) {
return toString(mb, false, " ", "\n", mb.getNumRows(), mb.getNumColumns(), 3);
}
Expand Down Expand Up @@ -913,11 +930,7 @@ public static String toString(MatrixBlock mb, boolean sparse, String separator,
if (colsToPrint >= 0)
colLength = colsToPrint < clen ? colsToPrint : clen;

DecimalFormat df = new DecimalFormat();
df.setGroupingUsed(false);
if (decimal >= 0){
df.setMinimumFractionDigits(decimal);
}
DecimalFormat df = createDecimalFormat(decimal);

if (sparse){ // Sparse Print Format
if (mb.isInSparseFormat()){ // Block is in sparse format
Expand Down Expand Up @@ -997,11 +1010,7 @@ public static String toString(TensorBlock tb, boolean sparse, String separator,
if (colsToPrint >= 0)
colLength = Math.min(colsToPrint, clen);

DecimalFormat df = new DecimalFormat();
df.setGroupingUsed(false);
if (decimal >= 0){
df.setMinimumFractionDigits(decimal);
}
DecimalFormat df = createDecimalFormat(decimal);

if (sparse){ // Sparse Print Format
// TODO use sparse iterator for sparse block
Expand Down Expand Up @@ -1147,10 +1156,7 @@ public static String toString(FrameBlock fb, boolean sparse, String separator, S
sb.append(lineseparator);

//print data
DecimalFormat df = new DecimalFormat();
df.setGroupingUsed(false);
if (decimal >= 0)
df.setMinimumFractionDigits(decimal);
DecimalFormat df = createDecimalFormat(decimal);

Iterator<Object[]> iter = IteratorFactory.getObjectRowIterator(fb, 0, rowLength);
while( iter.hasNext() ) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.sysds.test.component.frame;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import org.apache.sysds.common.Types.ValueType;
Expand All @@ -38,6 +39,42 @@ public void test100x100() {
FrameBlock f = createFrameBlock();
assertTrue(DataConverter.toString(f, false, " ", "\n", 100, 100, 3).length() < 75);
}

@Test
public void testDecimalClampsFractionDigits() {
FrameBlock f = new FrameBlock(new ValueType[]{ValueType.FP64}, new String[]{"C1"});
f.ensureAllocatedColumns(1);
f.set(0, 0, 5.244058388023880);
// decimal=2 must print exactly two fraction digits, not DecimalFormat's default max of 3
String out = DataConverter.toString(f, false, " ", "\n", 1, 1, 2);
assertTrue("expected value clamped to 5.24, got: " + out, out.contains("5.24\n"));
assertFalse("decimal=2 must not print three digits: " + out, out.contains("5.244"));
}

@Test
public void testDecimalPadsAndRounds() {
FrameBlock f = new FrameBlock(new ValueType[]{ValueType.FP64}, new String[]{"C1"});
f.ensureAllocatedColumns(2);
f.set(0, 0, 22.0); // integer-valued: padded up to the requested digits
f.set(1, 0, 5.244058388023880); // rounded at the last requested digit
String out = DataConverter.toString(f, false, " ", "\n", 2, 1, 4);
assertTrue("expected 22.0000 padded: " + out, out.contains("22.0000\n"));
assertTrue("expected 5.2441 rounded: " + out, out.contains("5.2441\n"));
}

@Test
public void testNegativeDecimalUsesDefaultFormatting() {
FrameBlock f = new FrameBlock(new ValueType[]{ValueType.FP64}, new String[]{"C1"});
f.ensureAllocatedColumns(2);
f.set(0, 0, 22.0); // integer-valued: no fraction digits when unconstrained
f.set(1, 0, 5.244058388023880); // default cap of three fraction digits
// decimal < 0 leaves DecimalFormat unconstrained (no min/max fraction digits set)
String out = DataConverter.toString(f, false, " ", "\n", 2, 1, -1);
assertTrue("expected unpadded 22: " + out, out.contains("22\n"));
assertFalse("integer value must not be padded: " + out, out.contains("22.0"));
assertTrue("expected default 5.244: " + out, out.contains("5.244\n"));
assertFalse("must not print a fourth digit: " + out, out.contains("5.2441"));
}

private FrameBlock createFrameBlock() {
FrameBlock f = new FrameBlock(new ValueType[]{ValueType.STRING, ValueType.STRING});
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.test.component.tensor;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.util.DataConverter;
import org.junit.Test;

public class TensorToStringTest {
@Test
public void testDecimalClampsFractionDigits() {
TensorBlock tb = new TensorBlock(ValueType.FP64, new int[]{1, 1});
tb.allocateBlock();
tb.set(0, 0, 5.244058388023880);
// decimal=2 must print exactly two fraction digits, not DecimalFormat's default max of 3
String out = DataConverter.toString(tb, false, " ", "\n", "[", "]", 1, 1, 2);
assertTrue("expected value clamped to 5.24, got: " + out, out.contains("5.24"));
assertFalse("decimal=2 must not print three digits: " + out, out.contains("5.244"));
}

@Test
public void testDecimalPadsAndRounds() {
TensorBlock tb = new TensorBlock(ValueType.FP64, new int[]{1, 2});
tb.allocateBlock();
tb.set(0, 0, 22.0); // integer-valued: padded up to the requested digits
tb.set(0, 1, 5.244058388023880); // rounded at the last requested digit
String out = DataConverter.toString(tb, false, " ", "\n", "[", "]", 1, 2, 4);
assertTrue("expected 22.0000 padded: " + out, out.contains("22.0000"));
assertTrue("expected 5.2441 rounded: " + out, out.contains("5.2441"));
}

@Test
public void testNegativeDecimalUsesDefaultFormatting() {
TensorBlock tb = new TensorBlock(ValueType.FP64, new int[]{1, 2});
tb.allocateBlock();
tb.set(0, 0, 22.0); // integer-valued: no fraction digits when unconstrained
tb.set(0, 1, 5.244058388023880); // default cap of three fraction digits
// decimal < 0 leaves DecimalFormat unconstrained (no min/max fraction digits set)
String out = DataConverter.toString(tb, false, " ", "\n", "[", "]", 1, 2, -1);
assertTrue("expected unpadded 22: " + out, out.contains("22"));
assertFalse("integer value must not be padded: " + out, out.contains("22.0"));
assertTrue("expected default 5.244: " + out, out.contains("5.244"));
assertFalse("must not print a fourth digit: " + out, out.contains("5.2441"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -270,4 +270,96 @@ protected void toStringTestHelper(ExecMode platform, String testName, String exp
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
}

@Test
public void testPrintWithDecimal(){
String testName = "ToString12";

String decimalPoints = "2";
String value = "22";
String expectedOutput = "22.00\n";

addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName));
toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value);
}


@Test
public void testPrintWithDecimal2(){
String testName = "ToString12";

String decimalPoints = "2";
String value = "5.244058388023880";
String expectedOutput = "5.24\n";

addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName));
toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value);
}


@Test
public void testPrintWithDecimal3(){
String testName = "ToString12";

String decimalPoints = "10";
String value = "5.244058388023880";
String expectedOutput = "5.2440583880\n";

addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName));
toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value);
}


@Test
public void testPrintWithDecimal4(){
String testName = "ToString12";

String decimalPoints = "4";
String value = "5.244058388023880";
String expectedOutput = "5.2441\n";

addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName));
toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value);
}


@Test
public void testPrintWithDecimal5(){
String testName = "ToString12";

String decimalPoints = "10";
String value = "0.000000008023880";
String expectedOutput = "0.0000000080\n";

addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName));
toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value);
}

protected void toStringTestHelper2(ExecMode platform, String testName, String expectedOutput, String decimalPoints, String value) {
ExecMode platformOld = rtplatform;

rtplatform = platform;
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
if (rtplatform == ExecMode.SPARK)
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
try {
// Create and load test configuration
getAndLoadTestConfiguration(testName);
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testName + ".dml";
programArgs = new String[]{"-args", output(OUTPUT_NAME), value, decimalPoints};

// Run DML and R scripts
runTest(true, false, null, -1);

// Compare output strings
String output = TestUtils.readDMLString(output(OUTPUT_NAME));
TestUtils.compareScalars(expectedOutput, output);
}
finally {
// Reset settings
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
}
}
24 changes: 24 additions & 0 deletions src/test/scripts/functions/misc/ToString12.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

X = matrix($2, rows=1, cols=1)
str = toString(X, rows=3, cols=3, decimal=$3)
write(str, $1)
Loading