diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index 0e5b3f1f512..a744b5d8136 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -27,6 +27,7 @@ import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.OOCInstruction; import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction; public class OOCInstructionParser extends InstructionParser { protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName()); @@ -51,6 +52,8 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str return ReblockOOCInstruction.parseInstruction(str); case AggregateUnary: return AggregateUnaryOOCInstruction.parseInstruction(str); + case Unary: + return UnaryOOCInstruction.parseInstruction(str); case Binary: return BinaryOOCInstruction.parseInstruction(str); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index fe73e57fd24..db3d2da8b11 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -30,7 +30,7 @@ public abstract class OOCInstruction extends Instruction { protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName()); public enum OOCType { - Reblock, AggregateUnary, Binary + Reblock, AggregateUnary, Binary, Unary } protected final OOCInstruction.OOCType _ooctype; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java new file mode 100644 index 00000000000..d2fccd5fd6a --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.UnaryOperator; +import org.apache.sysds.runtime.util.CommonThreadPool; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; + +public class UnaryOOCInstruction extends ComputationOOCInstruction { + private UnaryOperator _uop = null; + + protected UnaryOOCInstruction(OOCType type, UnaryOperator op, CPOperand in1, CPOperand out, String opcode, String istr) { + super(type, op, in1, out, opcode, istr); + + _uop = op; + } + + public static UnaryOOCInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + InstructionUtils.checkNumFields(parts, 2); + String opcode = parts[0]; + CPOperand in1 = new CPOperand(parts[1]); + CPOperand out = new CPOperand(parts[2]); + + UnaryOperator uopcode = InstructionUtils.parseUnaryOperator(opcode); + return new UnaryOOCInstruction(OOCType.Unary, uopcode, in1, out, opcode, str); + } + + public void processInstruction( ExecutionContext ec ) { + UnaryOperator uop = (UnaryOperator) _uop; + // Create thread and process the unary operation + MatrixObject min = ec.getMatrixObject(input1); + LocalTaskQueue qIn = min.getStreamHandle(); + LocalTaskQueue qOut = new LocalTaskQueue<>(); + ec.getMatrixObject(output).setStreamHandle(qOut); + + + ExecutorService pool = CommonThreadPool.get(); + try { + Future task =pool.submit(() -> { + IndexedMatrixValue tmp = null; + try { + while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + IndexedMatrixValue tmpOut = new IndexedMatrixValue(); + tmpOut.set(tmp.getIndexes(), + tmp.getValue().unaryOperations(uop, new MatrixBlock())); + qOut.enqueueTask(tmpOut); + } + qOut.closeInput(); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + }); + task.get(); + } catch (ExecutionException | InterruptedException e) { + throw new RuntimeException(e); + } finally { + pool.shutdown(); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/UnaryTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/UnaryTest.java new file mode 100644 index 00000000000..90f5f8ff0cb --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/UnaryTest.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.FileFormat; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class UnaryTest extends AutomatedTestBase { + + private static final String TEST_NAME = "Unary"; + private static final String TEST_DIR = "functions/ooc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + UnaryTest.class.getSimpleName() + "/"; + private static final String INPUT_NAME = "X"; + private static final String OUTPUT_NAME = "res"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME); + addTestConfiguration(TEST_NAME, config); + } + + /** + * Test the sum of scalar multiplication, "sum(X*7)", with OOC backend. + */ + @Test + public void testUnary() { + testUnaryOperation(false); + } + + + public void testUnaryOperation(boolean rewrite) + { + Types.ExecMode platformOld = rtplatform; + rtplatform = Types.ExecMode.SINGLE_NODE; + boolean oldRewrite = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite; + + try { + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-explain", "-stats", "-ooc", + "-args", input(INPUT_NAME), output(OUTPUT_NAME)}; + + int rows = 1000, cols = 1000; + MatrixBlock mb = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 7); + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(FileFormat.BINARY); + writer.writeMatrixToHDFS(mb, input(INPUT_NAME), rows, cols, 1000, rows*cols); + HDFSTool.writeMetaDataFile(input(INPUT_NAME+".mtd"), ValueType.FP64, + new MatrixCharacteristics(rows,cols,1000,rows*cols), FileFormat.BINARY); + + runTest(true, false, null, -1); + + HashMap dmlfile = readDMLMatrixFromOutputDir(OUTPUT_NAME); + Double result = dmlfile.get(new MatrixValue.CellIndex(1, 1)); + double expected = 0.0; + for(int i = 0; i < rows; i++) { + for(int j = 0; j < cols; j++) { + expected += Math.ceil(mb.get(i, j)); + } + } + + Assert.assertEquals(expected, result, 1e-10); + + String prefix = Instruction.OOC_INST_PREFIX; + Assert.assertTrue("OOC wasn't used for RBLK", + heavyHittersContainsString(prefix + Opcodes.RBLK)); + Assert.assertTrue("OOC wasn't used for CEIL", + heavyHittersContainsString(prefix + Opcodes.CEIL)); + } + catch(Exception ex) { + Assert.fail(ex.getMessage()); + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrite; + resetExecMode(platformOld); + } + } +} diff --git a/src/test/scripts/functions/ooc/Unary.dml b/src/test/scripts/functions/ooc/Unary.dml new file mode 100644 index 00000000000..6d34e8fd763 --- /dev/null +++ b/src/test/scripts/functions/ooc/Unary.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read input matrix and operator from command line args +X = read($1); +#print(toString(X)) +Y = ceil(X); +#print(toString(Y)) +res = as.matrix(sum(Y)); +# Write the final matrix result +write(res, $2);