From 65ae3c9afb289536f9f74d56a8a3c3f7b483c2e4 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 12 Dec 2023 20:23:48 +0000
Subject: [PATCH] [Unity][Transform] Implement relax.transform.ReorderTakeAfterMatmul

If `R.matmul(x, R.take(weights, indices))` occurs, with `R.take`
selecting along the output feature dimension, it can be rearranged to
`R.take(R.matmul(x, weights), indices)`.

---
 python/tvm/relax/transform/__init__.py        |   1 +
 python/tvm/relax/transform/transform.py       |  15 ++
 .../transform/reorder_take_after_matmul.cc    | 164 +++++++++++++++
 ...est_transform_reorder_take_after_matmul.py | 186 ++++++++++++++++++
 4 files changed, 366 insertions(+)
 create mode 100644 src/relax/transform/reorder_take_after_matmul.cc
 create mode 100644 tests/python/relax/test_transform_reorder_take_after_matmul.py

diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index eeac5f82c8c0..7efe144c5062 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -63,6 +63,7 @@
     RemovePurityChecking,
     RemoveUnusedParameters,
     RemoveUnusedOutputs,
+    ReorderTakeAfterMatmul,
     RewriteCUDAGraph,
     RewriteDataflowReshape,
     RunCodegen,
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 03d0878810c9..1f390adb2e16 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1302,6 +1302,21 @@ def ExpandMatmulOfSum():
     return _ffi_api.ExpandMatmulOfSum()  # type: ignore
 
 
+def ReorderTakeAfterMatmul():
+    """Reorder `matmul(x, take(weights, indices))` to `take(matmul(x, weights), indices)`
+
+    Useful for optimizing LoRA computations, where several LoRAs may
+    be batched together.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The corresponding pass.
+    """
+
+    return _ffi_api.ReorderTakeAfterMatmul()  # type: ignore
+
+
 def CombineParallelMatmul(check=None):
     """Combine multiple matmul operators sharing the same LHS matrix into one,
     followed by slicing. When all matmul branches in a tree have the same set of fused ops,
diff --git a/src/relax/transform/reorder_take_after_matmul.cc b/src/relax/transform/reorder_take_after_matmul.cc
new file mode 100644
index 000000000000..9e037f05f0dd
--- /dev/null
+++ b/src/relax/transform/reorder_take_after_matmul.cc
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/reorder_take_after_matmul.cc
+ * \brief Reorder `matmul(x, take(weights, indices))` to `take(matmul(x, weights), indices)`
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/dataflow_matcher.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+#include <optional>
+#include <unordered_set>
+#include <vector>
+
+#include "../op/tensor/index.h"
+#include "../op/tensor/linear_algebra.h"
+#include "../op/tensor/manipulate.h"
+
+namespace tvm {
+namespace relax {
+
+namespace {
+std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>> CreatePatterns() {
+  auto pat_lhs = WildcardPattern();
+
+  auto pat_weights = WildcardPattern();
+  auto pat_indices = WildcardPattern();
+  auto pat_rhs = IsOp("relax.take")(pat_weights, pat_indices);
+
+  auto pat_matmul = IsOp("relax.matmul")(pat_lhs, pat_rhs);
+
+  auto rewriter = [=](Expr expr, Map<DFPattern, Expr> matches) -> Expr {
+    auto lhs = matches[pat_lhs];
+    auto weights = matches[pat_weights];
+    auto indices = matches[pat_indices];
+
+    const auto* take_call = matches[pat_rhs].as<CallNode>();
+    ICHECK(take_call) << "InternalError: "
+                      << "Match of relax.take operator should produce Call, "
+                      << "but instead produces " << matches[pat_rhs] << " with type "
+                      << matches[pat_rhs]->GetTypeKey();
+    const auto* attrs = take_call->attrs.as<TakeAttrs>();
+    ICHECK(attrs) << "InternalError: "
+                  << "Attributes for relax.take operator should be TakeAttrs, "
+                  << "but were instead " << take_call->attrs << " with type "
+                  << take_call->attrs->GetTypeKey();
+
+    const auto* lhs_sinfo = lhs->struct_info_.as<TensorStructInfoNode>();
+    if (!lhs_sinfo) return expr;
+
+    const auto* weights_sinfo = weights->struct_info_.as<TensorStructInfoNode>();
+    if (!weights_sinfo) return expr;
+
+    const auto* indices_sinfo = indices->struct_info_.as<TensorStructInfoNode>();
+    if (!indices_sinfo) return expr;
+
+    const auto* matmul_sinfo = expr->struct_info_.as<TensorStructInfoNode>();
+    if (!matmul_sinfo) return expr;
+
+    if (!attrs->axis.defined()) return expr;
+    auto axis = attrs->axis.value()->value;
+
+    if (lhs_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim() ||
+        matmul_sinfo->IsUnknownNdim() || weights_sinfo->IsUnknownNdim()) {
+      return expr;
+    }
+
+    if (indices_sinfo->ndim == 1 && axis + 1 == weights_sinfo->ndim) {
+      // Simpler case.  The activations may have batch dimensions, but
+      // the weights do not.
+
+      // lhs.shape = [*batch, infeatures]
+      // weights.shape = [infeatures, table_size]
+      // indices.shape = [outfeatures]
+
+      // out_table.shape = [*batch, table_size]
+      auto out_table = matmul(lhs, weights, DataType::Void());
+      // new_output.shape = [*batch, outfeatures]
+      auto new_output = take(out_table, indices, Integer(matmul_sinfo->ndim - 1));
+
+      return new_output;
+    } else if (lhs_sinfo->ndim == 3 && weights_sinfo->ndim == 3 && indices_sinfo->ndim == 1 &&
+               axis == 0 && weights_sinfo->GetShape().defined() &&
+               lhs_sinfo->GetShape().defined()) {
+      // More complicated case, used for batched LoRA.  The conditions
+      // on the argument dimensions can probably be relaxed, but doing
+      // so would require removing the use of the einsum operator.
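+      //
+      // Illustrative walk-through (editorial note; the shapes here
+      // are assumptions chosen for exposition, not values required by
+      // the pass): with lhs = [128, 1, 16],
+      // weights = [num_loras, 16, 32], and indices = [128], the
+      // rewrite below produces
+      //
+      //   reordered_weight : [16, num_loras, 32]
+      //   fused_weight     : [16, num_loras * 32]
+      //   fused_output     : [128, 1, num_loras * 32]
+      //   indexed_output   : [128, 1, num_loras, 32]
+      //
+      // after which the take + einsum pair selects, for each batch
+      // row i, the slice indexed_output[i, :, indices[i], :].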
+
+      auto lhs_shape = lhs_sinfo->GetShape().value();
+      auto weight_shape = weights_sinfo->GetShape().value();
+
+      // lhs.shape = [batch1, batch2, infeatures]
+      // weights.shape = [table_size, infeatures, outfeatures]
+      // indices.shape = [batch1]
+
+      // reordered_weight.shape = [infeatures, table_size, outfeatures]
+      auto reordered_weight =
+          permute_dims(weights, Array<Integer>{Integer(1), Integer(0), Integer(2)});
+      // fused_weight.shape = [infeatures, table_size * outfeatures]
+      auto fused_weight = reshape(reordered_weight,
+                                  ShapeExpr({weight_shape[1], weight_shape[0] * weight_shape[2]}));
+      // fused_output.shape = [batch1, batch2, table_size * outfeatures]
+      auto fused_output = matmul(lhs, fused_weight, DataType::Void());
+      // indexed_output.shape = [batch1, batch2, table_size, outfeatures]
+      auto indexed_output = reshape(
+          fused_output, ShapeExpr({lhs_shape[0], lhs_shape[1], weight_shape[0], weight_shape[2]}));
+
+      // TODO(Lunderberg): Find a better way to express these last two
+      // steps.  For an output at [i,j,k], the value is
+      // `indexed_output[i, j, indices[i], k]`, but there doesn't seem
+      // to be a good way to express that in relax.  It could be
+      // written using `call_te`, but that would prevent later
+      // optimizations from recognizing the high-level relax
+      // operations.
+
+      // duplicated_output.shape = [batch1, batch2, batch1, outfeatures]
+      auto duplicated_output = take(indexed_output, indices, Integer(2));
+      // new_output.shape = [batch1, batch2, outfeatures]
+      auto new_output = einsum(Tuple({duplicated_output}), "ijik->ijk");
+
+      return new_output;
+    } else {
+      return expr;
+    }
+  };
+
+  return {pat_matmul, rewriter};
+}
+
+}  // namespace
+
+namespace transform {
+Pass ReorderTakeAfterMatmul() {
+  auto pass_func = [=](Function func, IRModule mod, PassContext pc) {
+    auto [pattern, rewriter] = CreatePatterns();
+    return RewriteCall(pattern, rewriter, func);
+  };
+  return CreateFunctionPass(pass_func, 1, "ReorderTakeAfterMatmul", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.ReorderTakeAfterMatmul")
+    .set_body_typed(ReorderTakeAfterMatmul);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_transform_reorder_take_after_matmul.py b/tests/python/relax/test_transform_reorder_take_after_matmul.py
new file mode 100644
index 000000000000..bf969fb3fedb
--- /dev/null
+++ b/tests/python/relax/test_transform_reorder_take_after_matmul.py
@@ -0,0 +1,186 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
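+
+# Editorial note: each test case below pairs a `Before` module with the
+# `Expected` result of applying relax.transform.ReorderTakeAfterMatmul.
+# The shared Base.test_compare harness runs the pass and compares the
+# result against Expected with structural equality (or, when Expected
+# is an exception type, checks that the pass raises it).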
+
+import inspect
+
+import pytest
+
+import tvm.testing
+from tvm import relax
+from tvm.script import ir as I, relax as R, tir as T
+
+
+class Base:
+    def test_compare(self):
+        transform = relax.transform.ReorderTakeAfterMatmul()
+
+        if inspect.isclass(self.Expected) and issubclass(self.Expected, Exception):
+            with pytest.raises(self.Expected):
+                transform(self.Before)
+        else:
+            after = transform(self.Before)
+            tvm.ir.assert_structural_equal(self.Expected, after)
+
+
+class TestSimple(Base):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([1, 16], "float32"),
+            weight_table: R.Tensor([16, "weight_table_size"], "float32"),
+            routing_table: R.Tensor([32], "int64"),
+        ) -> R.Tensor([1, 32], "float32"):
+            weight_table_size = T.int64()
+            with R.dataflow():
+                weight: R.Tensor([16, 32], "float32") = R.take(weight_table, routing_table, axis=1)
+                out: R.Tensor([1, 32], "float32") = R.matmul(x, weight)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([1, 16], "float32"),
+            weight_table: R.Tensor([16, "weight_table_size"], "float32"),
+            routing_table: R.Tensor([32], "int64"),
+        ) -> R.Tensor([1, 32], "float32"):
+            weight_table_size = T.int64()
+            with R.dataflow():
+                out_table: R.Tensor([1, weight_table_size], "float32") = R.matmul(x, weight_table)
+                out: R.Tensor([1, 32], "float32") = R.take(out_table, routing_table, axis=1)
+                R.output(out)
+            return out
+
+
+class TestBatchedActivations(Base):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 1, 16], "float32"),
+            weight_table: R.Tensor([16, "weight_table_size"], "float32"),
+            routing_table: R.Tensor([32], "int64"),
+        ) -> R.Tensor(["batch_size", 1, 32], "float32"):
+            batch_size = T.int64()
+            weight_table_size = T.int64()
+            with R.dataflow():
+                weight: R.Tensor([16, 32], "float32") = R.take(weight_table, routing_table, axis=1)
+                out: R.Tensor([batch_size, 1, 32], "float32") = R.matmul(x, weight)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 1, 16], "float32"),
+            weight_table: R.Tensor([16, "weight_table_size"], "float32"),
+            routing_table: R.Tensor([32], "int64"),
+        ) -> R.Tensor(["batch_size", 1, 32], "float32"):
+            batch_size = T.int64()
+            weight_table_size = T.int64()
+            with R.dataflow():
+                out_table: R.Tensor([batch_size, 1, weight_table_size], "float32") = R.matmul(
+                    x, weight_table
+                )
+                out: R.Tensor([batch_size, 1, 32], "float32") = R.take(
+                    out_table, routing_table, axis=2
+                )
+                R.output(out)
+            return out
+
+
+class TestStaticBatchedActivationsAndWeights(Base):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([128, 1, 16], "float32"),
+            weight_table: R.Tensor(["routing_table_size", 16, 32], "float32"),
+            routing_table: R.Tensor([128], "int64"),
+        ) -> R.Tensor([128, 1, 32], "float32"):
+            batch_size = T.int64()
+            routing_table_size = T.int64()
+            with R.dataflow():
+                weight = R.take(weight_table, routing_table, axis=0)
+                out = R.matmul(x, weight)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([128, 1, 16], "float32"),
+            weight_table: R.Tensor(["routing_table_size", 16, 32], "float32"),
+            routing_table: R.Tensor([128], "int64"),
+        ) -> R.Tensor([128, 1, 32], "float32"):
+            batch_size = T.int64()
+            routing_table_size = T.int64()
+            with R.dataflow():
+                reordered_weight = R.permute_dims(weight_table, [1, 0, 2])
+                fused_weight = R.reshape(reordered_weight, [16, routing_table_size * 32])
+                fused_output = R.matmul(x, fused_weight)
+                reordered_output = R.reshape(fused_output, [128, 1, routing_table_size, 32])
+                tabular_output = R.take(reordered_output, routing_table, axis=2)
+                out = R.einsum([tabular_output], "ijik->ijk")
+                R.output(out)
+            return out
+
+
+class TestDynamicBatchedActivationsAndWeights(Base):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 1, 16], "float32"),
+            weight_table: R.Tensor(["routing_table_size", 16, 32], "float32"),
+            routing_table: R.Tensor(["batch_size"], "int64"),
+        ) -> R.Tensor(["batch_size", 1, 32], "float32"):
+            batch_size = T.int64()
+            routing_table_size = T.int64()
+            with R.dataflow():
+                weight = R.take(weight_table, routing_table, axis=0)
+                out = R.matmul(x, weight)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 1, 16], "float32"),
+            weight_table: R.Tensor(["routing_table_size", 16, 32], "float32"),
+            routing_table: R.Tensor(["batch_size"], "int64"),
+        ) -> R.Tensor(["batch_size", 1, 32], "float32"):
+            batch_size = T.int64()
+            routing_table_size = T.int64()
+            with R.dataflow():
+                reordered_weight = R.permute_dims(weight_table, [1, 0, 2])
+                fused_weight = R.reshape(reordered_weight, [16, routing_table_size * 32])
+                fused_output = R.matmul(x, fused_weight)
+                reordered_output = R.reshape(fused_output, [batch_size, 1, routing_table_size, 32])
+                tabular_output = R.take(reordered_output, routing_table, axis=2)
+                out = R.einsum([tabular_output], "ijik->ijk")
+                R.output(out)
+            return out
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
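
Editorial usage sketch (not part of the patch): once the patch is
applied, the pass is exposed as `relax.transform.ReorderTakeAfterMatmul`
and can be run on an IRModule directly, just as the tests above do. The
module below mirrors `TestSimple.Before`; everything else here is an
illustration under that assumption, not additional tested behavior.

    import tvm
    from tvm import relax
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(
            x: R.Tensor([1, 16], "float32"),
            weight_table: R.Tensor([16, "weight_table_size"], "float32"),
            routing_table: R.Tensor([32], "int64"),
        ) -> R.Tensor([1, 32], "float32"):
            with R.dataflow():
                # Before the pass: take() selects columns of the weight
                # table, and the matmul runs on the gathered weights.
                weight = R.take(weight_table, routing_table, axis=1)
                out = R.matmul(x, weight)
                R.output(out)
            return out

    # After the pass, the matmul runs against the full weight table and
    # the column selection moves to the matmul's output.
    rewritten = relax.transform.ReorderTakeAfterMatmul()(Module)
    rewritten.show()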