diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index 4e9a92ecb78..15c23e23c9d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -28,6 +28,7 @@ import org.apache.sysds.runtime.instructions.ooc.CentralMomentOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.CtableOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.OOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.ParameterizedBuiltinOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction; @@ -75,6 +76,8 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str return CentralMomentOOCInstruction.parseInstruction(str); case Ctable: return CtableOOCInstruction.parseInstruction(str); + case ParameterizedBuiltin: + return ParameterizedBuiltinOOCInstruction.parseInstruction(str); default: throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index 1fdd5cd9657..7abf593aba0 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -34,6 +34,7 @@ import org.apache.sysds.runtime.util.OOCJoin; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -53,7 +54,7 @@ public abstract class OOCInstruction extends Instruction { private static final AtomicInteger nextStreamId 
= new AtomicInteger(0); public enum OOCType { - Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM, Ctable, MatrixIndexing + Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin } protected final OOCInstruction.OOCType _ooctype; @@ -208,6 +209,8 @@ protected CompletableFuture submitOOCTasks(final List> qu final AtomicInteger globalTaskCtr = new AtomicInteger(0); final CompletableFuture globalFuture = CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)); + if (_outQueues == null) + _outQueues = Collections.emptySet(); final Runnable oocFinalizer = oocTask(finalizer, null, Stream.concat(_outQueues.stream(), _inQueues.stream()).toArray(OOCStream[]::new)); final Object globalLock = new Object(); @@ -278,7 +281,14 @@ protected CompletableFuture submitOOCTasks(final List> qu globalFuture.whenComplete((res, e) -> { if (globalFuture.isCancelled() || globalFuture.isCompletedExceptionally()) - futures.forEach(f -> f.cancel(true)); + futures.forEach(f -> { + if (!f.isDone()) { + if (globalFuture.isCancelled() || globalFuture.isCompletedExceptionally()) + f.cancel(true); + else + f.complete(null); + } + }); boolean runFinalizer; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java new file mode 100644 index 00000000000..e56d32e4401 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin; +import org.apache.sysds.runtime.functionobjects.ValueFunction; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.BooleanObject; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.cp.Data; +import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction; +import org.apache.sysds.runtime.instructions.cp.ScalarObject; +import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.matrix.operators.SimpleOperator; + +import java.util.LinkedHashMap; +import java.util.concurrent.CompletableFuture; +import 
java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * Out-of-core (OOC) instruction for parameterized builtin functions.
+ *
+ * <p>Two opcodes are supported: {@code replace}, which maps every block of the target stream to a
+ * block with the pattern substituted, and {@code contains}, which scans the blocks of the target
+ * stream for a scalar pattern and produces a boolean scalar. Any other opcode is rejected at
+ * parse time.
+ */
+public class ParameterizedBuiltinOOCInstruction extends ComputationOOCInstruction {
+
+	/** Named parameters of the builtin (e.g. target, pattern, replacement) in instruction order. */
+	protected final LinkedHashMap<String, String> params;
+
+	protected ParameterizedBuiltinOOCInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out,
+		String opcode, String istr) {
+		super(OOCInstruction.OOCType.ParameterizedBuiltin, op, null, null, out, opcode, istr);
+		params = paramsMap;
+	}
+
+	/**
+	 * Parses the string representation of a parameterized builtin OOC instruction.
+	 *
+	 * @param str instruction string; its first part is the opcode and its last part the output operand
+	 * @return the parsed instruction
+	 * @throws NotImplementedException if the opcode is neither replace nor contains
+	 */
+	public static ParameterizedBuiltinOOCInstruction parseInstruction(String str) {
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		// first part is always the opcode
+		String opcode = parts[0];
+		// last part is always the output
+		CPOperand out = new CPOperand(parts[parts.length - 1]);
+
+		// process remaining parts and build a hash map
+		LinkedHashMap<String, String> paramsMap = ParameterizedBuiltinCPInstruction.constructParameterMap(parts);
+
+		if(opcode.equalsIgnoreCase(Opcodes.REPLACE.toString())) {
+			// replace carries its value function in a simple operator
+			ValueFunction func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
+			return new ParameterizedBuiltinOOCInstruction(new SimpleOperator(func), paramsMap, out, opcode, str);
+		}
+		if(opcode.equalsIgnoreCase(Opcodes.CONTAINS.toString())) {
+			// contains is evaluated directly on the blocks and needs no operator
+			return new ParameterizedBuiltinOOCInstruction(null, paramsMap, out, opcode, str);
+		}
+
+		throw new NotImplementedException(
+			"Parameterized builtin '" + opcode + "' is not supported as OOC instruction."); // TODO
+	}
+
+	/**
+	 * Executes the instruction for the opcode it was parsed with.
+	 *
+	 * @param ec execution context holding the target, the (optional) pattern variable and the output
+	 */
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		if(instOpcode.equalsIgnoreCase(Opcodes.REPLACE.toString()))
+			processReplace(ec);
+		else if(instOpcode.equalsIgnoreCase(Opcodes.CONTAINS.toString()))
+			processContains(ec);
+	}
+
+	/** Streams the target through a block-wise replace and attaches the result stream to the output. */
+	private void processReplace(ExecutionContext ec) {
+		String target = params.get("target");
+		if(ec.isFrameObject(target))
+			throw new NotImplementedException("OOC " + instOpcode + " is not supported for frame target '" + target + "'.");
+
+		MatrixObject targetObj = ec.getMatrixObject(target);
+		OOCStream<IndexedMatrixValue> qIn = targetObj.getStreamHandle();
+		OOCStream<IndexedMatrixValue> qOut = createWritableStream();
+
+		double pattern = Double.parseDouble(params.get("pattern"));
+		double replacement = Double.parseDouble(params.get("replacement"));
+
+		mapOOC(qIn, qOut, tmp -> new IndexedMatrixValue(tmp.getIndexes(),
+			tmp.getValue().replaceOperations(new MatrixBlock(), pattern, replacement)));
+
+		ec.getMatrixObject(output).setStreamHandle(qOut);
+	}
+
+	/** Scans the blocks of the target for the scalar pattern and writes the boolean result to the output. */
+	private void processContains(ExecutionContext ec) {
+		MatrixObject targetObj = ec.getMatrixObject(params.get("target"));
+		OOCStream<IndexedMatrixValue> qIn = targetObj.getStreamHandle();
+
+		String patternParam = params.get("pattern");
+		Data pattern = ec.getVariable(patternParam);
+		if(pattern == null) // literal
+			pattern = ScalarObjectFactory.createScalarObject(Types.ValueType.FP64, patternParam);
+
+		if(!pattern.getDataType().isScalar())
+			throw new NotImplementedException("OOC " + instOpcode + " supports only scalar patterns.");
+
+		final ScalarObject scalarPattern = (ScalarObject) pattern;
+		final AtomicBoolean found = new AtomicBoolean(false);
+
+		// The tasks need a handle on the future that only exists once they have been submitted.
+		final MutableObject<CompletableFuture<Void>> futureRef = new MutableObject<>();
+		CompletableFuture<Void> future = submitOOCTasks(qIn, tmp -> {
+			if(((MatrixBlock) tmp.getValue()).containsValue(scalarPattern.getDoubleValue())) {
+				found.set(true);
+
+				// Now we may complete the future (best effort, the reference may not be set yet)
+				CompletableFuture<Void> pending = futureRef.getValue();
+				if(pending != null)
+					pending.complete(null);
+			}
+		}, () -> {});
+		futureRef.setValue(future);
+
+		// A task may have found the pattern before the reference above was set; finish early here
+		// instead of waiting for the remaining blocks.
+		if(found.get())
+			future.complete(null);
+
+		try {
+			future.get();
+		}
+		catch(InterruptedException e) {
+			// keep the interrupt visible to the caller's thread
+			Thread.currentThread().interrupt();
+			throw new DMLRuntimeException(e);
+		}
+		catch(ExecutionException e) {
+			throw new DMLRuntimeException(e);
+		}
+
+		ec.setScalarOutput(output.getName(), new BooleanObject(found.get()));
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/ContainsTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/ContainsTest.java
new file mode 100644
index 00000000000..059763f816a
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/ContainsTest.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+
+/**
+ * Verifies that {@code contains} is executed as an out-of-core instruction and that its result
+ * equals the result of a second run of the same script without the {@code -ooc} flag.
+ */
+public class ContainsTest extends AutomatedTestBase {
+	private final static String TEST_NAME1 = "Contains";
+	private final static String TEST_DIR = "functions/ooc/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + ContainsTest.class.getSimpleName() + "/";
+	private final static double eps = 1e-8;
+	private static final String INPUT_NAME_1 = "X";
+	private static final String OUTPUT_NAME = "res";
+
+	private final static int rows = 1500;
+	private final static int cols = 1200;
+	private final static int maxVal = 2;
+	private final static double sparsity1 = 1;
+	private final static double sparsity2 = 0.05;
+
+	/** Block size used for writing the input and reading the results. */
+	private final static int blockSize = 1000;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1));
+	}
+
+	@Test
+	public void testContainsDense() {
+		runContainsTest(false);
+	}
+
+	@Test
+	public void testContainsSparse() {
+		runContainsTest(true);
+	}
+
+	/**
+	 * Runs the script with and without {@code -ooc} on the same input and compares both outputs.
+	 *
+	 * @param sparse if true the input is generated with sparsity 0.05, otherwise with sparsity 1
+	 */
+	private void runContainsTest(boolean sparse) {
+		final Types.ExecMode previousMode = setExecMode(Types.ExecMode.SINGLE_NODE);
+
+		try {
+			getAndLoadTestConfiguration(TEST_NAME1);
+			fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME1 + ".dml";
+
+			// random input X whose last cell is forced to -1
+			double sparsity = sparsity1;
+			if(sparse)
+				sparsity = sparsity2;
+			double[][] values = getRandomMatrix(rows, cols, 0, maxVal, sparsity, 7);
+			values[rows-1][cols-1] = -1;
+			writeBinaryInput(DataConverter.convertToMatrixBlock(values));
+
+			// first run: out-of-core
+			programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME_1), output(OUTPUT_NAME)};
+			runTest(true, false, null, -1);
+
+			// the contains instruction must have been compiled to its OOC variant
+			Assert.assertTrue("OOC wasn't used for contains",
+				heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.CONTAINS));
+
+			// second run: same script and input without the ooc flag
+			programArgs = new String[] {"-explain", "-stats", "-args", input(INPUT_NAME_1), output(OUTPUT_NAME + "_target")};
+			runTest(true, false, null, -1);
+
+			// both runs write a 1x1 matrix; compare them
+			MatrixBlock oocResult = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME),
+				Types.FileFormat.BINARY, 1, 1, blockSize);
+			MatrixBlock referenceResult = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"),
+				Types.FileFormat.BINARY, 1, 1, blockSize);
+			TestUtils.compareMatrices(oocResult, referenceResult, eps);
+		}
+		catch(IOException e) {
+			throw new RuntimeException(e);
+		}
+		finally {
+			resetExecMode(previousMode);
+		}
+	}
+
+	/** Writes the block as binary input X together with its metadata file. */
+	private void writeBinaryInput(MatrixBlock block) throws IOException {
+		MatrixWriter binaryWriter = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+		binaryWriter.writeMatrixToHDFS(block, input(INPUT_NAME_1), rows, cols, blockSize, block.getNonZeros());
+
+		MatrixCharacteristics meta = new MatrixCharacteristics(rows, cols, blockSize, block.getNonZeros());
+		HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 + ".mtd"), Types.ValueType.FP64, meta, Types.FileFormat.BINARY);
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/ReplaceTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/ReplaceTest.java
new file mode 100644
index 00000000000..a3a2ba9f698
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/ReplaceTest.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+
+/**
+ * Verifies that {@code replace} is executed as an out-of-core instruction and that its result
+ * equals the result of a second run of the same script without the {@code -ooc} flag.
+ */
+public class ReplaceTest extends AutomatedTestBase {
+	private final static String TEST_NAME1 = "Replace";
+	private final static String TEST_DIR = "functions/ooc/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + ReplaceTest.class.getSimpleName() + "/";
+	private final static double eps = 1e-8;
+	private static final String INPUT_NAME_1 = "X";
+	private static final String OUTPUT_NAME = "res";
+
+	private final static int rows = 1500;
+	private final static int cols = 1200;
+	private final static int maxVal = 2;
+	private final static double sparsity1 = 1;
+	private final static double sparsity2 = 0.05;
+
+	/** Block size used for writing the input and reading the results. */
+	private final static int blockSize = 1000;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1));
+	}
+
+	@Test
+	public void testReplaceDense() {
+		runReplaceTest(false);
+	}
+
+	@Test
+	public void testReplaceSparse() {
+		runReplaceTest(true);
+	}
+
+	/**
+	 * Runs the script with and without {@code -ooc} on the same input and compares both outputs.
+	 *
+	 * @param sparse if true the input is generated with sparsity 0.05, otherwise with sparsity 1
+	 */
+	private void runReplaceTest(boolean sparse) {
+		final Types.ExecMode previousMode = setExecMode(Types.ExecMode.SINGLE_NODE);
+
+		try {
+			getAndLoadTestConfiguration(TEST_NAME1);
+			fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME1 + ".dml";
+
+			// random input X, rounded to whole numbers
+			double sparsity = sparsity1;
+			if(sparse)
+				sparsity = sparsity2;
+			double[][] values = getRandomMatrix(rows, cols, 0, maxVal, sparsity, 7);
+			UnaryOperator round = new UnaryOperator(Builtin.getBuiltinFnObject(Builtin.BuiltinCode.ROUND));
+			MatrixBlock rounded = DataConverter.convertToMatrixBlock(values).unaryOperations(round);
+			writeBinaryInput(rounded);
+
+			// first run: out-of-core
+			programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME_1), output(OUTPUT_NAME)};
+			runTest(true, false, null, -1);
+
+			// the replace instruction must have been compiled to its OOC variant
+			Assert.assertTrue("OOC wasn't used for replacement",
+				heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.REPLACE));
+
+			// second run: same script and input without the ooc flag
+			programArgs = new String[] {"-explain", "-stats", "-args", input(INPUT_NAME_1), output(OUTPUT_NAME + "_target")};
+			runTest(true, false, null, -1);
+
+			// compare the full result matrices of both runs
+			MatrixBlock oocResult = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME),
+				Types.FileFormat.BINARY, rows, cols, blockSize);
+			MatrixBlock referenceResult = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"),
+				Types.FileFormat.BINARY, rows, cols, blockSize);
+			TestUtils.compareMatrices(oocResult, referenceResult, eps);
+		}
+		catch(IOException e) {
+			throw new RuntimeException(e);
+		}
+		finally {
+			resetExecMode(previousMode);
+		}
+	}
+
+	/** Writes the block as binary input X together with its metadata file. */
+	private void writeBinaryInput(MatrixBlock block) throws IOException {
+		MatrixWriter binaryWriter = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+		binaryWriter.writeMatrixToHDFS(block, input(INPUT_NAME_1), rows, cols, blockSize, block.getNonZeros());
+
+		MatrixCharacteristics meta = new MatrixCharacteristics(rows, cols, blockSize, block.getNonZeros());
+		HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 + ".mtd"), Types.ValueType.FP64, meta, Types.FileFormat.BINARY);
+	}
+}
diff --git a/src/test/scripts/functions/ooc/Contains.dml b/src/test/scripts/functions/ooc/Contains.dml
new file mode 100644
index 00000000000..27514f8c76b
--- /dev/null
+++ b/src/test/scripts/functions/ooc/Contains.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read the input matrix X from the path given as first argument ($1)
+X = read($1);
+
+res = as.matrix(contains(target=X, pattern=-1)); # search X for the value -1 and wrap the result in a matrix
+
+write(res, $2, format="binary"); # store the result at the path given as second argument ($2)
diff --git a/src/test/scripts/functions/ooc/Replace.dml b/src/test/scripts/functions/ooc/Replace.dml
new file mode 100644
index 00000000000..5108d272860
--- /dev/null
+++ b/src/test/scripts/functions/ooc/Replace.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read the input matrix X from the path given as first argument ($1)
+X = read($1);
+
+res = replace(target=X, pattern=1, replacement=-1); # replace every 1 in X with -1
+
+write(res, $2, format="binary"); # store the result at the path given as second argument ($2)