Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.sysds.runtime.instructions.ooc.CentralMomentOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.CtableOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.ParameterizedBuiltinOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction;
Expand Down Expand Up @@ -75,6 +76,8 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
return CentralMomentOOCInstruction.parseInstruction(str);
case Ctable:
return CtableOOCInstruction.parseInstruction(str);
case ParameterizedBuiltin:
return ParameterizedBuiltinOOCInstruction.parseInstruction(str);

default:
throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.sysds.runtime.util.OOCJoin;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand All @@ -53,7 +54,7 @@ public abstract class OOCInstruction extends Instruction {
private static final AtomicInteger nextStreamId = new AtomicInteger(0);

public enum OOCType {
Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM, Ctable, MatrixIndexing
Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin
}

protected final OOCInstruction.OOCType _ooctype;
Expand Down Expand Up @@ -208,6 +209,8 @@ protected <T> CompletableFuture<Void> submitOOCTasks(final List<OOCStream<T>> qu

final AtomicInteger globalTaskCtr = new AtomicInteger(0);
final CompletableFuture<Void> globalFuture = CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new));
if (_outQueues == null)
_outQueues = Collections.emptySet();
final Runnable oocFinalizer = oocTask(finalizer, null, Stream.concat(_outQueues.stream(), _inQueues.stream()).toArray(OOCStream[]::new));
final Object globalLock = new Object();

Expand Down Expand Up @@ -278,7 +281,14 @@ protected <T> CompletableFuture<Void> submitOOCTasks(final List<OOCStream<T>> qu

globalFuture.whenComplete((res, e) -> {
if (globalFuture.isCancelled() || globalFuture.isCompletedExceptionally())
futures.forEach(f -> f.cancel(true));
futures.forEach(f -> {
if (!f.isDone()) {
if (globalFuture.isCancelled() || globalFuture.isCompletedExceptionally())
f.cancel(true);
else
f.complete(null);
}
});

boolean runFinalizer;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.instructions.ooc;

import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.mutable.MutableObject;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.BooleanObject;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.SimpleOperator;

import java.util.LinkedHashMap;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * Out-of-core (OOC) variant of the parameterized builtin instruction.
 *
 * <p>Only the {@code replace} and {@code contains} opcodes are handled here; any other opcode is
 * rejected by {@link #parseInstruction(String)}.
 */
public class ParameterizedBuiltinOOCInstruction extends ComputationOOCInstruction {

	/** Named instruction parameters (e.g. "target", "pattern", "replacement") in parse order. */
	protected final LinkedHashMap<String, String> params;

	/**
	 * Creates the instruction. The two input operands of the parent class are passed as
	 * {@code null} because all inputs are resolved by name through {@code paramsMap}.
	 *
	 * @param op        operator, or {@code null} if the opcode does not need one
	 * @param paramsMap named parameters of the builtin
	 * @param out       output operand
	 * @param opcode    instruction opcode
	 * @param istr      full instruction string
	 */
	protected ParameterizedBuiltinOOCInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out,
		String opcode, String istr) {
		super(OOCInstruction.OOCType.ParameterizedBuiltin, op, null, null, out, opcode, istr);
		params = paramsMap;
	}

	/**
	 * Parses an instruction string into a {@link ParameterizedBuiltinOOCInstruction}.
	 *
	 * @param str instruction string; the first part is the opcode and the last part is the output
	 * @return the parsed instruction
	 * @throws NotImplementedException if the opcode is neither {@code replace} nor {@code contains}
	 */
	public static ParameterizedBuiltinOOCInstruction parseInstruction(String str) {
		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
		// first part is always the opcode
		String opcode = parts[0];
		// last part is always the output
		CPOperand out = new CPOperand(parts[parts.length - 1]);

		// process remaining parts and build a hash map
		LinkedHashMap<String, String> paramsMap = ParameterizedBuiltinCPInstruction.constructParameterMap(parts);

		if(opcode.equalsIgnoreCase(Opcodes.REPLACE.toString())) {
			ValueFunction func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
			return new ParameterizedBuiltinOOCInstruction(new SimpleOperator(func), paramsMap, out, opcode, str);
		}
		if(opcode.equalsIgnoreCase(Opcodes.CONTAINS.toString())) {
			// contains does not need an operator
			return new ParameterizedBuiltinOOCInstruction(null, paramsMap, out, opcode, str);
		}
		throw new NotImplementedException(
			"Parameterized builtin opcode not supported in OOC: " + opcode); // TODO
	}

	/**
	 * Dispatches to the handler of the instruction's opcode. Opcodes other than {@code replace}
	 * and {@code contains} are ignored here, as before; {@link #parseInstruction(String)} never
	 * creates such instances.
	 *
	 * @param ec execution context used to resolve inputs and bind the output
	 */
	@Override
	public void processInstruction(ExecutionContext ec) {
		if(instOpcode.equalsIgnoreCase(Opcodes.REPLACE.toString()))
			processReplace(ec);
		else if(instOpcode.equalsIgnoreCase(Opcodes.CONTAINS.toString()))
			processContains(ec);
	}

	/** Handles {@code replace}: maps every input block through {@code replaceOperations} into a new stream. */
	private void processReplace(ExecutionContext ec) {
		if(ec.isFrameObject(params.get("target")))
			throw new NotImplementedException("OOC replace is not supported for frame targets");

		MatrixObject targetObj = ec.getMatrixObject(params.get("target"));
		OOCStream<IndexedMatrixValue> qIn = targetObj.getStreamHandle();
		OOCStream<IndexedMatrixValue> qOut = createWritableStream();

		double pattern = Double.parseDouble(params.get("pattern"));
		double replacement = Double.parseDouble(params.get("replacement"));

		mapOOC(qIn, qOut, tmp -> new IndexedMatrixValue(tmp.getIndexes(), tmp.getValue().replaceOperations(new MatrixBlock(), pattern, replacement)));

		ec.getMatrixObject(output).setStreamHandle(qOut);
	}

	/**
	 * Handles {@code contains}: checks the blocks of the target stream for a scalar pattern and
	 * binds a boolean scalar as output. The wait is cut short once a block reports a match.
	 */
	private void processContains(ExecutionContext ec) {
		MatrixObject targetObj = ec.getMatrixObject(params.get("target"));
		OOCStream<IndexedMatrixValue> qIn = targetObj.getStreamHandle();
		Data pattern = ec.getVariable(params.get("pattern"));

		if(pattern == null) // literal
			pattern = ScalarObjectFactory.createScalarObject(Types.ValueType.FP64, params.get("pattern"));

		if(!pattern.getDataType().isScalar())
			throw new NotImplementedException("OOC contains is only supported for scalar patterns");

		// Resolve the scalar once instead of casting and unwrapping it for every block.
		final double patternValue = ((ScalarObject) pattern).getDoubleValue();

		AtomicBoolean found = new AtomicBoolean(false);

		// The tasks need a handle on the future returned by the very call that creates them.
		MutableObject<CompletableFuture<Void>> futureRef = new MutableObject<>();
		CompletableFuture<Void> future = submitOOCTasks(qIn, tmp -> {
			if(((MatrixBlock) tmp.getValue()).containsValue(patternValue)) {
				found.set(true);

				// Now we may complete the future
				CompletableFuture<Void> pending = futureRef.getValue();
				if(pending != null)
					pending.complete(null);
			}
		}, () -> {});
		futureRef.setValue(future);

		// A task may have found the value before the handle was published above; it could not
		// complete the future itself, so finish the early exit here.
		if(found.get())
			future.complete(null);

		try {
			future.get();
		}
		catch(InterruptedException e) {
			// Restore the interrupt flag so callers further up can still observe the interruption.
			Thread.currentThread().interrupt();
			throw new DMLRuntimeException(e);
		}
		catch(ExecutionException e) {
			throw new DMLRuntimeException(e);
		}

		ec.setScalarOutput(output.getName(), new BooleanObject(found.get()));
	}
}
121 changes: 121 additions & 0 deletions src/test/java/org/apache/sysds/test/functions/ooc/ContainsTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.test.functions.ooc;

import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.io.MatrixWriter;
import org.apache.sysds.runtime.io.MatrixWriterFactory;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
import org.junit.Test;

import java.io.IOException;

/**
 * Runs the {@code Contains} DML script once with the {@code -ooc} flag and once without it, and
 * checks that the out-of-core contains instruction was used and that both runs agree.
 */
public class ContainsTest extends AutomatedTestBase {
	private final static String TEST_NAME1 = "Contains";
	private final static String TEST_DIR = "functions/ooc/";
	private final static String TEST_CLASS_DIR = TEST_DIR + ContainsTest.class.getSimpleName() + "/";
	private final static double eps = 1e-8;
	private static final String INPUT_NAME_1 = "X";
	private static final String OUTPUT_NAME = "res";

	private final static int rows = 1500;
	private final static int cols = 1200;
	private final static int maxVal = 2;
	private final static double sparsity1 = 1;
	private final static double sparsity2 = 0.05;
	/** Block size used for writing the input and for reading the results. */
	private final static int BLOCK_SIZE = 1000;

	@Override
	public void setUp() {
		TestUtils.clearAssertionInformation();
		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1));
	}

	@Test
	public void testContainsDense() {
		runContainsTest(false);
	}

	@Test
	public void testContainsSparse() {
		runContainsTest(true);
	}

	/**
	 * Executes the script in OOC and regular mode on the same input and compares the outputs.
	 *
	 * @param sparse whether the generated input uses the sparse ({@code sparsity2}) or the dense
	 *               ({@code sparsity1}) setting
	 */
	private void runContainsTest(boolean sparse) {
		final Types.ExecMode previousMode = setExecMode(Types.ExecMode.SINGLE_NODE);

		try {
			getAndLoadTestConfiguration(TEST_NAME1);

			fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME1 + ".dml";
			programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME_1),
				output(OUTPUT_NAME)};

			writeInputMatrix(sparse);

			// first run: with the ooc flag
			runTest(true, false, null, -1);

			// the contains operation must show up as an OOC instruction in the statistics
			Assert.assertTrue("OOC wasn't used for contains",
				heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.CONTAINS));

			// second run: same script and input without the ooc flag, written to a separate output
			final String referenceName = OUTPUT_NAME + "_target";
			programArgs = new String[] {"-explain", "-stats", "-args", input(INPUT_NAME_1), output(referenceName)};
			runTest(true, false, null, -1);

			// both runs must produce the same result
			MatrixBlock oocResult = readResult(OUTPUT_NAME);
			MatrixBlock referenceResult = readResult(referenceName);
			TestUtils.compareMatrices(oocResult, referenceResult, eps);
		}
		catch(IOException ex) {
			throw new RuntimeException(ex);
		}
		finally {
			resetExecMode(previousMode);
		}
	}

	/**
	 * Generates the random input matrix, sets its last cell to -1, and writes it together with
	 * its metadata file in binary format.
	 */
	private void writeInputMatrix(boolean sparse) throws IOException {
		double density;
		if(sparse)
			density = sparsity2;
		else
			density = sparsity1;

		double[][] values = getRandomMatrix(rows, cols, 0, maxVal, density, 7);
		// value outside the generated range [0, maxVal] — presumably the one the script looks for
		values[rows - 1][cols - 1] = -1;

		MatrixBlock block = DataConverter.convertToMatrixBlock(values);

		MatrixWriter binaryWriter = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
		binaryWriter.writeMatrixToHDFS(block, input(INPUT_NAME_1), rows, cols, BLOCK_SIZE, block.getNonZeros());

		MatrixCharacteristics meta = new MatrixCharacteristics(rows, cols, BLOCK_SIZE, block.getNonZeros());
		HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 + ".mtd"), Types.ValueType.FP64, meta,
			Types.FileFormat.BINARY);
	}

	/** Reads the 1x1 binary result written by the script under the given output name. */
	private MatrixBlock readResult(String name) throws IOException {
		return DataConverter.readMatrixFromHDFS(output(name), Types.FileFormat.BINARY, 1, 1, BLOCK_SIZE);
	}
}
Loading
Loading