diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
index 9c0f0f2e0f4..4e9a92ecb78 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
@@ -26,6 +26,7 @@
 import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.CentralMomentOOCInstruction;
+import org.apache.sysds.runtime.instructions.ooc.CtableOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction;
@@ -72,7 +73,9 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
 				return TeeOOCInstruction.parseInstruction(str);
 			case CentralMoment:
 				return CentralMomentOOCInstruction.parseInstruction(str);
-			
+			case Ctable:
+				return CtableOOCInstruction.parseInstruction(str);
+
 			default:
 				throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype);
 		}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CtableOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CtableOOCInstruction.java
new file mode 100644
index 00000000000..d6ec115afe1
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CtableOOCInstruction.java
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import java.util.HashMap;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.lops.Ctable;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.CTableMap;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.LongLongDoubleHashMap;
+
+/**
+ * Out-of-core (OOC) contingency table instruction, i.e., table(A,B,W) and its
+ * scalar variants. Blocks of the first input are consumed from its stream
+ * handle; matrix-typed second and third inputs are consumed from their own
+ * streams and paired with the blocks of the first input by block index.
+ */
+public class CtableOOCInstruction extends ComputationOOCInstruction {
+	private final CPOperand _outDim1;
+	private final CPOperand _outDim2;
+	private final boolean _ignoreZeros;
+
+	protected CtableOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, CPOperand outDim1, CPOperand outDim2, boolean ignoreZeros, String opcode, String istr) {
+		super(type, op, in1, in2, in3, out, opcode, istr);
+		_ignoreZeros = ignoreZeros;
+		_outDim1 = outDim1;
+		_outDim2 = outDim2;
+	}
+
+	/**
+	 * Parses the string representation of an OOC ctable instruction.
+	 *
+	 * @param str instruction string with three inputs, two output dimension
+	 *            operands, the output, and the ignore-zeros flag
+	 * @return the parsed instruction
+	 */
+	public static CtableOOCInstruction parseInstruction(String str) {
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		InstructionUtils.checkNumFields(parts, 8);
+
+		String opcode = parts[0];
+		CPOperand in1 = new CPOperand(parts[1]);
+		CPOperand in2 = new CPOperand(parts[2]);
+		CPOperand in3 = new CPOperand(parts[3]);
+		CPOperand out = new CPOperand(parts[6]);
+
+		// dimension operands are split at LITERAL_PREFIX into name and a boolean flag
+		String[] dim1Fields = parts[4].split(Instruction.LITERAL_PREFIX);
+		String[] dim2Fields = parts[5].split(Instruction.LITERAL_PREFIX);
+		CPOperand outDim1 = new CPOperand(dim1Fields[0], Types.ValueType.FP64, Types.DataType.SCALAR, Boolean.parseBoolean(dim1Fields[1]));
+		CPOperand outDim2 = new CPOperand(dim2Fields[0], Types.ValueType.FP64, Types.DataType.SCALAR, Boolean.parseBoolean(dim2Fields[1]));
+
+		boolean ignoreZeros = Boolean.parseBoolean(parts[7]);
+
+		// does not require any op
+		return new CtableOOCInstruction(OOCType.Ctable, null, in1, in2, in3, out, outDim1, outDim2, ignoreZeros, opcode, str);
+	}
+
+	/**
+	 * Streams the blocks of the first input, applies the ctable variant selected
+	 * by the input data types to each block, and binds the aggregated result as
+	 * the instruction output.
+	 *
+	 * @param ec execution context providing inputs and receiving the output
+	 */
+	@Override
+	public void processInstruction( ExecutionContext ec ) {
+
+		MatrixObject in1 = ec.getMatrixObject(input1); // stream
+		LocalTaskQueue<IndexedMatrixValue> qIn1 = in1.getStreamHandle();
+		IndexedMatrixValue tmp1 = null;
+
+		long outputDim1 = ec.getScalarInput(_outDim1).getLongValue();
+		long outputDim2 = ec.getScalarInput(_outDim2).getLongValue();
+
+		long cols = in1.getDataCharacteristics().getNumColBlocks();
+		CTableMap map = new CTableMap(LongLongDoubleHashMap.EntryType.INT);
+
+		Ctable.OperationTypes ctableOp = findCtableOperation();
+		MatrixObject in2 = null, in3 = null;
+		LocalTaskQueue<IndexedMatrixValue> qIn2 = null, qIn3 = null;
+		double cst2 = 0, cst3 = 0;
+
+		// init vars based on ctableOp
+		if (ctableOp.hasSecondInput()){
+			in2 = ec.getMatrixObject(input2); // stream
+			qIn2 = in2.getStreamHandle();
+		} else
+			cst2 = ec.getScalarInput(input2).getDoubleValue();
+
+		if (ctableOp.hasThirdInput()){
+			in3 = ec.getMatrixObject(input3); // stream
+			qIn3 = in3.getStreamHandle();
+		} else
+			cst3 = ec.getScalarInput(input3).getDoubleValue();
+
+		// blocks of the second/third input that were dequeued ahead of their use
+		HashMap<Long, MatrixBlock> blocksIn2 = new HashMap<>(), blocksIn3 = new HashMap<>();
+		MatrixBlock block2, block3;
+
+		// only init result block if output dims known and dense
+		MatrixBlock result = null;
+		boolean outputDimsKnown = (outputDim1 != -1 && outputDim2 != -1);
+		if (outputDimsKnown){
+			long totalRows = in1.getDataCharacteristics().getRows();
+			long totalCols = in1.getDataCharacteristics().getCols();
+			boolean sparse = MatrixBlock.evalSparseFormatInMemory(outputDim1, outputDim2, totalRows*totalCols);
+			if(!sparse)
+				result = new MatrixBlock((int)outputDim1, (int)outputDim2, false);
+		}
+
+		try {
+			while((tmp1 = qIn1.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
+
+				MatrixBlock block1 = (MatrixBlock) tmp1.getValue();
+				long r = tmp1.getIndexes().getRowIndex();
+				long c = tmp1.getIndexes().getColumnIndex();
+				// linearized 0-based block index used to pair blocks across streams
+				long key = (r-1) * cols + (c-1);
+
+				switch(ctableOp) {
+					case CTABLE_TRANSFORM:
+						// ctable(A,B,W)
+						block2 = getOrDequeueBlock(key, cols, blocksIn2, qIn2);
+						block3 = getOrDequeueBlock(key, cols, blocksIn3, qIn3);
+						block1.ctableOperations(_optr, block2, block3, map, result);
+						break;
+					case CTABLE_TRANSFORM_SCALAR_WEIGHT:
+						// ctable(A,B) or ctable(A,B,1)
+						block2 = getOrDequeueBlock(key, cols, blocksIn2, qIn2);
+						block1.ctableOperations(_optr, block2, cst3, _ignoreZeros, map, result);
+						break;
+					case CTABLE_TRANSFORM_HISTOGRAM:
+						// ctable(A,1) or ctable(A,1,1)
+						block1.ctableOperations(_optr, cst2, cst3, map, result);
+						break;
+					case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM:
+						// ctable(A,1,W)
+						block3 = getOrDequeueBlock(key, cols, blocksIn3, qIn3);
+						block1.ctableOperations(_optr, cst2, block3, map, result);
+						break;
+
+					default:
+						throw new DMLRuntimeException("Encountered an invalid OOC ctable operation ("+ctableOp+") while executing instruction: "
+							+ this);
+				}
+			}
+			if (result == null){
+				if(outputDimsKnown)
+					result = DataConverter.convertToMatrixBlock(map, (int)outputDim1, (int)outputDim2);
+				else
+					result = DataConverter.convertToMatrixBlock(map);
+			}
+			else
+				result.examSparsity();
+
+			ec.setMatrixOutput(output.getName(), result);
+		}
+		catch(InterruptedException ex) {
+			// restore the interrupt status before propagating the failure
+			Thread.currentThread().interrupt();
+			throw new DMLRuntimeException(ex);
+		}
+		catch(Exception ex) {
+			throw new DMLRuntimeException(ex);
+		}
+	}
+
+	/**
+	 * Returns the block with the given linearized block index, either from the
+	 * cache of previously dequeued blocks or by draining the queue until it
+	 * appears. Blocks dequeued on the way are cached for later requests.
+	 *
+	 * @param key    linearized block index of the requested block
+	 * @param cols   number of column blocks used to linearize block indexes
+	 * @param blocks cache of dequeued but not yet consumed blocks
+	 * @param queue  stream to dequeue further blocks from
+	 * @return the requested block, removed from the cache if it was cached
+	 * @throws InterruptedException propagated from dequeuing the stream
+	 * @throws DMLRuntimeException if the stream ends without the requested block
+	 */
+	private MatrixBlock getOrDequeueBlock(long key, long cols, HashMap<Long, MatrixBlock> blocks, LocalTaskQueue<IndexedMatrixValue> queue)
+		throws InterruptedException {
+		// needed only once, hence remove on hit
+		MatrixBlock cached = blocks.remove(key);
+		if (cached != null)
+			return cached;
+
+		IndexedMatrixValue tmp;
+		// corresponding block still in queue, dequeue until found
+		while ((tmp = queue.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
+			MatrixBlock block = (MatrixBlock) tmp.getValue();
+			long r = tmp.getIndexes().getRowIndex();
+			long c = tmp.getIndexes().getColumnIndex();
+			long tmpKey = (r-1) * cols + (c-1);
+			// found corresponding block
+			if (tmpKey == key)
+				return block;
+			// store all dequeued blocks in cache that we don't need yet
+			blocks.put(tmpKey, block);
+		}
+
+		// fail loudly instead of handing back an unrelated block or null
+		throw new DMLRuntimeException("OOC ctable input stream ended before block "
+			+ key + " was found while executing instruction: " + this);
+	}
+
+	/** Determines the ctable variant from the data types of the three inputs. */
+	private Ctable.OperationTypes findCtableOperation() {
+		Types.DataType dt1 = input1.getDataType();
+		Types.DataType dt2 = input2.getDataType();
+		Types.DataType dt3 = input3.getDataType();
+		return Ctable.findCtableOperationByInputDataTypes(dt1, dt2, dt3);
+	}
+
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index 5b1c7666612..d55d1ee5948 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -33,7 +33,7 @@ public abstract class OOCInstruction extends Instruction {
 	protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName());
 
 	public enum OOCType {
-		Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM
+		Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM, Ctable
 	}
 
 	protected final OOCInstruction.OOCType _ooctype;
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/CTableTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/CTableTest.java
new file mode 100644
index 00000000000..9b7c739c904
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/CTableTest.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+/**
+ * Compares the results of table(v, w) and table(v, w, weights) executed with
+ * the -ooc flag against the corresponding R scripts.
+ */
+public class CTableTest extends AutomatedTestBase{
+	private static final String TEST_NAME1 = "CTableTest";
+	private static final String TEST_NAME2 = "WeightedCTableTest";
+	private static final String TEST_DIR = "functions/ooc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + CTableTest.class.getSimpleName() + "/";
+
+	private static final String INPUT_NAME1 = "v";
+	private static final String INPUT_NAME2 = "w";
+	private static final String INPUT_NAME3 = "weights";
+	private static final String OUTPUT_NAME = "res";
+
+	private final static double eps = 1e-10;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1));
+		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2));
+	}
+
+	@Test
+	public void testCTableSimple(){ testCTable(1372, 1012, 5, 5, false);}
+
+	@Test
+	public void testCTableValueSetDifferencesNonEmpty(){ testCTable(2000, 37, 4995, 5, false);}
+
+	@Test
+	public void testWeightedCTableSimple(){ testCTable(1372, 1012, 5, 5, true);}
+
+	@Test
+	public void testWeightedCTableValueSetDifferencesNonEmpty(){ testCTable(2000, 37, 4995, 5, true);}
+
+
+	/**
+	 * Writes random inputs in binary format, runs the DML and R scripts, and
+	 * compares their results.
+	 *
+	 * @param rows       number of rows of all inputs
+	 * @param cols       number of columns of all inputs
+	 * @param maxValV    upper bound for the generated values of v
+	 * @param maxValW    upper bound for the generated values of w and the weights
+	 * @param isWeighted whether to run the weighted variant with a third input
+	 */
+	public void testCTable(int rows, int cols, int maxValV, int maxValW, boolean isWeighted)
+	{
+		Types.ExecMode platformOld = rtplatform;
+		rtplatform = Types.ExecMode.SINGLE_NODE;
+
+		try {
+			String TEST_NAME = isWeighted? TEST_NAME2:TEST_NAME1;
+
+			getAndLoadTestConfiguration(TEST_NAME);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			if (isWeighted)
+				programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME1), input(INPUT_NAME2), input(INPUT_NAME3), output(OUTPUT_NAME)};
+			else
+				programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME1), input(INPUT_NAME2), output(OUTPUT_NAME)};
+
+			fullRScriptName = HOME + TEST_NAME + ".R";
+			rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();
+
+			// values <=0 invalid
+			double[][] v = TestUtils.floor(getRandomMatrix(rows, cols, 1, maxValV, 1.0, 7));
+			double[][] w = TestUtils.floor(getRandomMatrix(rows, cols, 1, maxValW, 1.0, 13));
+			double[][] weights = null;
+
+			MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+			writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(v), input(INPUT_NAME1), rows, cols, 1000, rows*cols);
+			writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(w), input(INPUT_NAME2), rows, cols, 1000, rows*cols);
+
+			HDFSTool.writeMetaDataFile(input(INPUT_NAME1+".mtd"), Types.ValueType.FP64, new MatrixCharacteristics(rows,cols,1000,rows*cols), Types.FileFormat.BINARY);
+			HDFSTool.writeMetaDataFile(input(INPUT_NAME2+".mtd"), Types.ValueType.FP64, new MatrixCharacteristics(rows,cols,1000,rows*cols), Types.FileFormat.BINARY);
+
+			// for RScript
+			writeInputMatrixWithMTD("vR", v, true);
+			writeInputMatrixWithMTD("wR", w, true);
+
+			if (isWeighted) {
+				weights = TestUtils.floor(getRandomMatrix(rows, cols, 1, maxValW, 1.0, 17));
+				writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(weights), input(INPUT_NAME3), rows, cols,
+					1000, rows * cols);
+				HDFSTool.writeMetaDataFile(input(INPUT_NAME3 + ".mtd"), Types.ValueType.FP64,
+					new MatrixCharacteristics(rows, cols, 1000, rows * cols), Types.FileFormat.BINARY);
+				writeInputMatrixWithMTD("weightsR", weights, true);
+			}
+
+			runTest(true, false, null, -1);
+			runRScript(true);
+
+			// compare matrices
+			HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("resR");
+			double[][] rRes = TestUtils.convertHashMapToDoubleArray(rfile);
+			double[][] dmlRes = DataConverter.convertToDoubleMatrix(DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME), Types.FileFormat.BINARY, rRes.length, rRes[0].length, 1000, 1000));
+
+			TestUtils.compareMatrices(rRes, dmlRes, eps);
+
+			String prefix = Instruction.OOC_INST_PREFIX;
+			Assert.assertTrue("OOC wasn't used for RBLK",
+				heavyHittersContainsString(prefix + Opcodes.RBLK));
+		}
+		catch(Exception ex) {
+			Assert.fail(ex.getMessage());
+		}
+		finally {
+			resetExecMode(platformOld);
+		}
+	}
+}
diff --git a/src/test/scripts/functions/ooc/CTableTest.R b/src/test/scripts/functions/ooc/CTableTest.R
new file mode 100644
index 00000000000..a2b76ea0d2c
--- /dev/null
+++ b/src/test/scripts/functions/ooc/CTableTest.R
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+
+library("Matrix")
+
+v <- as.matrix(readMM(paste(args[1], "vR.mtx", sep="")))
+w <- as.matrix(readMM(paste(args[1], "wR.mtx", sep="")))
+
+res = table (v, w);
+res = as.matrix(as.data.frame.matrix(res));
+
+writeMM(as(res, "CsparseMatrix"), paste(args[2], "resR", sep=""));
diff --git a/src/test/scripts/functions/ooc/CTableTest.dml b/src/test/scripts/functions/ooc/CTableTest.dml
new file mode 100644
index 00000000000..1d8ff90be21
--- /dev/null
+++ b/src/test/scripts/functions/ooc/CTableTest.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+v = read($1);
+w = read($2);
+res = table (v, w);
+
+write(res, $3, format="binary");
diff --git a/src/test/scripts/functions/ooc/WeightedCTableTest.R b/src/test/scripts/functions/ooc/WeightedCTableTest.R
new file mode 100644
index 00000000000..5e3a3f698d8
--- /dev/null
+++ b/src/test/scripts/functions/ooc/WeightedCTableTest.R
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+
+library("Matrix")
+
+v <- as.vector(readMM(paste(args[1], "vR.mtx", sep="")))
+w <- as.vector(readMM(paste(args[1], "wR.mtx", sep="")))
+weights <- as.vector(readMM(paste(args[1], "weightsR.mtx", sep="")))
+
+res = xtabs(weights ~ v + w)
+res = as.matrix(as.data.frame.matrix(res));
+
+writeMM(as(res, "CsparseMatrix"), paste(args[2], "resR", sep=""));
diff --git a/src/test/scripts/functions/ooc/WeightedCTableTest.dml b/src/test/scripts/functions/ooc/WeightedCTableTest.dml
new file mode 100644
index 00000000000..5af9d34d4ee
--- /dev/null
+++ b/src/test/scripts/functions/ooc/WeightedCTableTest.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+v = read($1);
+w = read($2);
+weights = read($3);
+res = table (v, w, weights);
+
+write(res, $4, format="binary");