Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/common/Builtins.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ public enum Builtins {
COLVAR("colVars", false),
COMPONENTS("components", true),
COMPRESS("compress", false, ReturnType.MULTI_RETURN),
QUANTIZE_COMPRESS("quantize_compress", false, ReturnType.MULTI_RETURN),
CONFUSIONMATRIX("confusionMatrix", true),
CONV2D("conv2d", false),
CONV2D_BACKWARD_FILTER("conv2d_backward_filter", false),
Expand Down
139 changes: 70 additions & 69 deletions src/main/java/org/apache/sysds/common/InstructionType.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,75 +19,76 @@
package org.apache.sysds.common;

public enum InstructionType {
AggregateBinary,
AggregateTernary,
AggregateUnary,
UaggOuterChain,
Binary,
Unary,
Builtin,
Ternary,
BuiltinNary,
ParameterizedBuiltin,
MultiReturnParameterizedBuiltin,
Variable,
Reorg,
Reshape,
Dnn,
Quaternary,
FCall,
Append,
Rand,
StringInit,
Ctable,
CentralMoment,
Covariance,
QSort,
QPick,
MatrixIndexing,
MultiReturnBuiltin,
MultiReturnComplexMatrixBuiltin,
Partition,
Compression,
DeCompression,
SpoofFused,
Prefetch,
EvictLineageCache,
Broadcast,
TrigRemote,
Local,
Sql,
MMTSJ,
PMMJ,
MMChain,
AggregateBinary,
AggregateTernary,
AggregateUnary,
UaggOuterChain,
Binary,
Unary,
Builtin,
Ternary,
BuiltinNary,
ParameterizedBuiltin,
MultiReturnParameterizedBuiltin,
Variable,
Reorg,
Reshape,
Dnn,
Quaternary,
FCall,
Append,
Rand,
StringInit,
Ctable,
CentralMoment,
Covariance,
QSort,
QPick,
MatrixIndexing,
MultiReturnBuiltin,
MultiReturnComplexMatrixBuiltin,
Partition,
Compression,
DeCompression,
QuantizeCompression,
SpoofFused,
Prefetch,
EvictLineageCache,
Broadcast,
TrigRemote,
Local,
Sql,
MMTSJ,
PMMJ,
MMChain,

//SP Types
MAPMM,
MAPMMCHAIN,
TSMM2,
CPMM,
RMM,
ZIPMM,
PMAPMM,
Reblock,
CSVReblock,
LIBSVMReblock,
Checkpoint,
MAppend,
RAppend,
GAppend,
GAlignedAppend,
CumsumAggregate,
CumsumOffset,
BinUaggChain,
Cast,
TSMM,
AggregateUnarySketch,
PMM,
MatrixReshape,
Write,
Init,
//SP Types
MAPMM,
MAPMMCHAIN,
TSMM2,
CPMM,
RMM,
ZIPMM,
PMAPMM,
Reblock,
CSVReblock,
LIBSVMReblock,
Checkpoint,
MAppend,
RAppend,
GAppend,
GAlignedAppend,
CumsumAggregate,
CumsumOffset,
BinUaggChain,
Cast,
TSMM,
AggregateUnarySketch,
PMM,
MatrixReshape,
Write,
Init,

//FED
Tsmm;
//FED
Tsmm;
}
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/common/Opcodes.java
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ public enum Opcodes {
PARTITION("partition", InstructionType.Partition),
COMPRESS(Compression.OPCODE, InstructionType.Compression, InstructionType.Compression),
DECOMPRESS(DeCompression.OPCODE, InstructionType.DeCompression, InstructionType.DeCompression),
QUANTIZE_COMPRESS("quantize_compress", InstructionType.QuantizeCompression),
SPOOF("spoof", InstructionType.SpoofFused),
PREFETCH("prefetch", InstructionType.Prefetch),
EVICT("_evict", InstructionType.EvictLineageCache),
Expand Down
5 changes: 3 additions & 2 deletions src/main/java/org/apache/sysds/common/Types.java
Original file line number Diff line number Diff line change
Expand Up @@ -634,8 +634,9 @@ public enum OpOp2 {
//fused ML-specific operators for performance
MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=))
LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5)
MINUS1_MULT(false); //1-X*Y

MINUS1_MULT(false), //1-X*Y
QUANTIZE_COMPRESS(false); //quantization-fused compression

private final boolean _validOuter;

private OpOp2(boolean outer) {
Expand Down
9 changes: 9 additions & 0 deletions src/main/java/org/apache/sysds/hops/OptimizerUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,15 @@ public enum MemoryManager {
*/
public static boolean ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND = true;

/**
* This variable allows the user to insert quantize_compress commands in the DML script.
*/
public static boolean ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND = true;

/**
* Boolean specifying whether the quantization-fused compression rewrite is allowed.
*/
public static boolean ALLOW_QUANTIZE_COMPRESS_REWRITE = true;

/**
* Boolean specifying if compression rewrites is allowed. This is disabled at run time if the IPA for Workload aware compression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION )
_dagRuleSet.add( new RewriteIndexingVectorization() ); //dependency: cse, simplifications
_dagRuleSet.add( new RewriteInjectSparkPReadCheckpointing() ); //dependency: reblock

if( OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE )
_dagRuleSet.add( new RewriteQuantizationFusedCompression() );

//add statement block rewrite rules
if( OptimizerUtils.ALLOW_BRANCH_REMOVAL )
_sbRuleSet.add( new RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map.Entry;

import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.hops.BinaryOp;

import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;

import org.apache.sysds.hops.Hop;

/**
 * Rule: RewriteQuantizationFusedCompression. Detects the sequence {@code M2 = floor(M * S)} followed by
 * {@code C = compress(M2)} and fuses it into a single {@code QUANTIZE_COMPRESS} binary operation over {@code M} and
 * {@code S}. This rewrite improves performance by avoiding intermediate results. All consumers of the matched compress
 * hop are rewired to read from the fused hop instead.
 * <p>
 * Floor and compress hops are matched by name: a compress is a candidate if the name of its first input equals the
 * name of a collected floor hop.
 */
public class RewriteQuantizationFusedCompression extends HopRewriteRule {
	/**
	 * Scans all given DAG roots for floor and compress hops and replaces every compress-after-floor-of-multiplication
	 * pattern with a fused quantize-compress hop.
	 *
	 * @param roots root hops of the DAGs to rewrite, may be null
	 * @param state program rewrite status (not used by this rule)
	 * @return the same list of roots that was passed in (null if {@code roots} is null)
	 */
	@Override
	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
		if(roots == null)
			return null;

		// traverse the HOP DAG and collect floor hops (by own name) and compress hops (by input name)
		HashMap<String, Hop> floors = new HashMap<>();
		HashMap<String, Hop> compresses = new HashMap<>();
		for(Hop h : roots)
			collectFloorCompressSequences(h, floors, compresses);

		Hop.resetVisitStatus(roots);

		// check compresses for compress-after-floor pattern
		for(Entry<String, Hop> e : compresses.entrySet()) {
			Hop compresshop = e.getValue();
			Hop floorhop = floors.get(e.getKey()); // single lookup; null if no floor with that name

			if(floorhop == null || !occursBefore(floorhop, compresshop))
				continue;

			// the input of the floor operation must be a matrix ...
			Hop floorInput = floorhop.getInput().get(0);
			if(floorInput.getDataType() != DataType.MATRIX)
				continue;

			// ... produced by a multiplication
			if(!(floorInput instanceof BinaryOp) || ((BinaryOp) floorInput).getOp() != OpOp2.MULT)
				continue;

			// NOTE(review): the first operand of the multiplication is taken as the matrix and the second as the
			// scale factor; an expression written as floor(S * M) would pass them swapped - confirm this is intended.
			Hop initialMatrix = floorInput.getInput().get(0);
			Hop sf = floorInput.getInput().get(1);

			// create fused hop; it takes over the name of the compress hop it replaces
			// (previously a hard-coded placeholder name was used)
			BinaryOp fusedhop = new BinaryOp(compresshop.getName(), DataType.MATRIX, ValueType.FP64,
				OpOp2.QUANTIZE_COMPRESS, initialMatrix, sf);

			// rewire compress consumers to fusedhop; iterate over a copy because the
			// parent list of compresshop is modified while replacing references
			List<Hop> parents = new ArrayList<>(compresshop.getParent());
			for(Hop p : parents)
				HopRewriteUtils.replaceChildReference(p, compresshop, fusedhop);
		}
		return roots;
	}

	/**
	 * Single-root variant; performs no rewrite.
	 *
	 * @param root  root hop of the DAG
	 * @param state program rewrite status (not used)
	 * @return the unmodified root
	 */
	@Override
	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
		// do nothing, floor/compress do not occur in predicates
		return root;
	}

	/**
	 * Decides from source positions whether the floor hop is located before the compress hop.
	 *
	 * @param floorhop    the floor hop
	 * @param compresshop the compress hop
	 * @return true if the floor begins or ends on an earlier line than the compress, or, on matching lines, begins
	 *         at an earlier column
	 */
	private static boolean occursBefore(Hop floorhop, Hop compresshop) {
		if(floorhop.getBeginLine() < compresshop.getBeginLine())
			return true;
		if(floorhop.getEndLine() < compresshop.getEndLine())
			return true;
		// NOTE(review): the end line of the floor is compared against the BEGIN line of the compress here;
		// kept as in the original condition - confirm whether getEndLine() was meant.
		return floorhop.getBeginLine() == compresshop.getBeginLine()
			&& floorhop.getEndLine() == compresshop.getBeginLine()
			&& floorhop.getBeginColumn() < compresshop.getBeginColumn();
	}

	/**
	 * Recursively walks the DAG below {@code hop} (children first) and records every floor and compress unary
	 * operation. Already visited hops are skipped and each processed hop is marked as visited.
	 *
	 * @param hop        current hop
	 * @param floors     output map: name of the floor hop to the floor hop
	 * @param compresses output map: name of the compress hop's first input to the compress hop
	 */
	private void collectFloorCompressSequences(Hop hop, HashMap<String, Hop> floors, HashMap<String, Hop> compresses) {
		if(hop.isVisited())
			return;

		// process children (an empty input list simply yields no iterations)
		for(Hop c : hop.getInput())
			collectFloorCompressSequences(c, floors, compresses);

		// process current hop
		if(hop instanceof UnaryOp) {
			UnaryOp uop = (UnaryOp) hop;
			if(uop.getOp() == OpOp1.FLOOR)
				floors.put(uop.getName(), uop);
			else if(uop.getOp() == OpOp1.COMPRESS)
				compresses.put(uop.getInput(0).getName(), uop);
		}
		hop.setVisited();
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ else if(((ConstIdentifier) getThirdExpr().getOutput())
else
raiseValidateError("Compress/DeCompress instruction not allowed in dml script");
break;

default: //always unconditional
raiseValidateError("Unknown Builtin Function opcode: " + _opcode, false);
}
Expand Down Expand Up @@ -2011,8 +2011,38 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV
output.setValueType(id.getValueType());
}
else
raiseValidateError("Compress/DeCompress instruction not allowed in dml script");
raiseValidateError("The compress or decompress instruction is not allowed in dml scripts");
break;
case QUANTIZE_COMPRESS:
if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND) {
checkNumParameters(2);
Expression firstExpr = getFirstExpr();
Expression secondExpr = getSecondExpr();

checkMatrixParam(getFirstExpr());

if(secondExpr != null) {
// check if scale factor is a scalar, vector or matrix
checkMatrixScalarParam(secondExpr);
// if scale factor is a vector or matrix, make sure it has an appropriate shape
if(secondExpr.getOutput().getDataType() != DataType.SCALAR) {
if(is1DMatrix(secondExpr)) {
long vectorLength = secondExpr.getOutput().getDim1();
if(vectorLength != firstExpr.getOutput().getDim1()) {
raiseValidateError(
"The length of the row-wise scale factor vector must match the number of rows in the matrix.");
}
}
else {
checkMatchingDimensions(firstExpr, secondExpr);
}
}
}
}
else
raiseValidateError("The quantize_compress instruction not allowed in dml scripts");
break;

case ROW_COUNT_DISTINCT:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/apache/sysds/parser/DMLTranslator.java
Original file line number Diff line number Diff line change
Expand Up @@ -2585,6 +2585,9 @@ else if ( sop.equalsIgnoreCase(Opcodes.NOTEQUAL.toString()) )
case DECOMPRESS:
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.FP64, OpOp1.DECOMPRESS, expr);
break;
case QUANTIZE_COMPRESS:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.valueOf(source.getOpCode().name()), expr, expr2);
break;

// Boolean binary
case XOR:
Expand Down
Loading