Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction;

public class OOCInstructionParser extends InstructionParser {
protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName());
Expand All @@ -51,6 +52,8 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
return ReblockOOCInstruction.parseInstruction(str);
case AggregateUnary:
return AggregateUnaryOOCInstruction.parseInstruction(str);
case Unary:
return UnaryOOCInstruction.parseInstruction(str);
case Binary:
return BinaryOOCInstruction.parseInstruction(str);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public abstract class OOCInstruction extends Instruction {
protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName());

public enum OOCType {
Reblock, AggregateUnary, Binary
Reblock, AggregateUnary, Binary, Unary
}

protected final OOCInstruction.OOCType _ooctype;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.instructions.ooc;
Comment thread
j143 marked this conversation as resolved.

import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

public class UnaryOOCInstruction extends ComputationOOCInstruction {
private UnaryOperator _uop = null;

protected UnaryOOCInstruction(OOCType type, UnaryOperator op, CPOperand in1, CPOperand out, String opcode, String istr) {
super(type, op, in1, out, opcode, istr);

_uop = op;
}

public static UnaryOOCInstruction parseInstruction(String str) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
InstructionUtils.checkNumFields(parts, 2);
String opcode = parts[0];
CPOperand in1 = new CPOperand(parts[1]);
CPOperand out = new CPOperand(parts[2]);

UnaryOperator uopcode = InstructionUtils.parseUnaryOperator(opcode);
return new UnaryOOCInstruction(OOCType.Unary, uopcode, in1, out, opcode, str);
}

public void processInstruction( ExecutionContext ec ) {
UnaryOperator uop = (UnaryOperator) _uop;
// Create thread and process the unary operation
MatrixObject min = ec.getMatrixObject(input1);
LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
ec.getMatrixObject(output).setStreamHandle(qOut);


ExecutorService pool = CommonThreadPool.get();
try {
Future<?> task =pool.submit(() -> {
IndexedMatrixValue tmp = null;
try {
while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
tmpOut.set(tmp.getIndexes(),
tmp.getValue().unaryOperations(uop, new MatrixBlock()));
qOut.enqueueTask(tmpOut);
}
qOut.closeInput();
}
catch(Exception ex) {
throw new DMLRuntimeException(ex);
}
});
task.get();
} catch (ExecutionException | InterruptedException e) {
throw new RuntimeException(e);
} finally {
pool.shutdown();
}
}
}
114 changes: 114 additions & 0 deletions src/test/java/org/apache/sysds/test/functions/ooc/UnaryTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.test.functions.ooc;

import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.io.MatrixWriter;
import org.apache.sysds.runtime.io.MatrixWriterFactory;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
import org.junit.Test;

import java.util.HashMap;

public class UnaryTest extends AutomatedTestBase {

private static final String TEST_NAME = "Unary";
private static final String TEST_DIR = "functions/ooc/";
private static final String TEST_CLASS_DIR = TEST_DIR + UnaryTest.class.getSimpleName() + "/";
private static final String INPUT_NAME = "X";
private static final String OUTPUT_NAME = "res";

@Override
public void setUp() {
TestUtils.clearAssertionInformation();
TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME);
addTestConfiguration(TEST_NAME, config);
}

/**
* Test the sum of scalar multiplication, "sum(X*7)", with OOC backend.
*/
@Test
public void testUnary() {
testUnaryOperation(false);
}


public void testUnaryOperation(boolean rewrite)
{
Types.ExecMode platformOld = rtplatform;
rtplatform = Types.ExecMode.SINGLE_NODE;
boolean oldRewrite = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;

try {
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-explain", "-stats", "-ooc",
"-args", input(INPUT_NAME), output(OUTPUT_NAME)};

int rows = 1000, cols = 1000;
MatrixBlock mb = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 7);
MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(FileFormat.BINARY);
writer.writeMatrixToHDFS(mb, input(INPUT_NAME), rows, cols, 1000, rows*cols);
HDFSTool.writeMetaDataFile(input(INPUT_NAME+".mtd"), ValueType.FP64,
new MatrixCharacteristics(rows,cols,1000,rows*cols), FileFormat.BINARY);

runTest(true, false, null, -1);

HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir(OUTPUT_NAME);
Double result = dmlfile.get(new MatrixValue.CellIndex(1, 1));
double expected = 0.0;
for(int i = 0; i < rows; i++) {
for(int j = 0; j < cols; j++) {
expected += Math.ceil(mb.get(i, j));
}
}

Assert.assertEquals(expected, result, 1e-10);

String prefix = Instruction.OOC_INST_PREFIX;
Assert.assertTrue("OOC wasn't used for RBLK",
heavyHittersContainsString(prefix + Opcodes.RBLK));
Assert.assertTrue("OOC wasn't used for CEIL",
heavyHittersContainsString(prefix + Opcodes.CEIL));
}
catch(Exception ex) {
Assert.fail(ex.getMessage());
}
finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrite;
resetExecMode(platformOld);
}
}
}
29 changes: 29 additions & 0 deletions src/test/scripts/functions/ooc/Unary.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

# Read input matrix and operator from command line args
X = read($1);
#print(toString(X))
Y = ceil(X);
#print(toString(Y))
res = as.matrix(sum(Y));
# Write the final matrix result
write(res, $2);
Loading