From a84527911bb32c4adbae9dbc1173af7da2ae5ba3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 12 Dec 2022 22:32:56 +0900 Subject: [PATCH 1/2] Fix slice axis in combine dense --- .../transforms/combine_parallel_dense.cc | 43 +++++++++++++------ 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc index 7cf102b5bcab..e5f7e0b975f4 100644 --- a/src/relay/transforms/combine_parallel_dense.cc +++ b/src/relay/transforms/combine_parallel_dense.cc @@ -195,23 +195,40 @@ class ParallelDenseToDenseCombiner : public ParallelOpCombiner { void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap* subst_map) { int index = 0; + const auto dense_op = Op::Get("nn.dense"); for (const auto& branch : branches) { const CallNode* call = branch[depth]; auto& out_shape = call->type_as<TensorTypeNode>()->shape; - auto out_dims = tir::as_const_int(out_shape[out_shape.size() - 1]); - ICHECK(out_dims != nullptr); - Array<Integer> begin; - Array<Integer> end; - Array<Integer> strides; - for (size_t k = 0; k < out_shape.size() - 1; ++k) { - begin.push_back(0); - end.push_back(-1); - strides.push_back(1); + + const CallNode* dense = branch[0]; + ICHECK(dense->op.same_as(dense_op)); + auto& dense_shape = dense->type_as<TensorTypeNode>()->shape; + auto dense_out_dims = tir::as_const_int(dense_shape[1]); + ICHECK(dense_out_dims != nullptr); + + // dense can be followed by shape-changing operations, so the slicing axis is + // not necessarily the last one. + // TODO(masahi): The following logic is incorrect if (1) there is no axis in + // out_shape[i] that directly corresponds to the output channel of dense or (2) there + // is another axis that happens to have the same size as the output channel of dense. + // Such cases might arise due to reshape / transpose / split etc. Revisit this logic + // when we encounter them in practice. 
+ auto slice_axis = -1; + for (int i = static_cast<int>(out_shape.size()) - 1; i >= 0; --i) { + ICHECK(tir::as_const_int(out_shape[i])); + if (*tir::as_const_int(out_shape[i]) == *dense_out_dims) { + slice_axis = i; + break; + } } - begin.push_back(index); - end.push_back(*out_dims); - strides.push_back(1); - index += *out_dims; + ICHECK(slice_axis != -1); + + Array<Integer> begin(out_shape.size(), 0); + Array<Integer> end(out_shape.size(), -1); + Array<Integer> strides(out_shape.size(), 1); + begin.Set(slice_axis, index); + end.Set(slice_axis, *dense_out_dims); + index += *dense_out_dims; auto slice = MakeStridedSlice(data, begin, end, strides, "size"); subst_map->insert({GetRef<Expr>(branch[depth]), slice}); } From fed49d10c712662618dc7d5ce391ab2022e3febc Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 12 Dec 2022 23:22:13 +0900 Subject: [PATCH 2/2] add test --- .../relay/test_pass_combine_parallel_dense.py | 51 ++++++++++++++++--- 1 file changed, 44 insertions(+), 7 deletions(-) diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py index cd946ab593bf..2494c1a550cd 100644 --- a/tests/python/relay/test_pass_combine_parallel_dense.py +++ b/tests/python/relay/test_pass_combine_parallel_dense.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import tvm -from tvm import te +import tvm.testing from tvm import relay from tvm.relay import transform @@ -359,10 +359,47 @@ def check(i, j, k, scale1, scale2, newshape1, newshape2): check(100, 200, 300, 0.5, 0.25, (1, 1, 20000), (1, 1, 40000)) +def test_combine_parallel_dense_expand_dims(): + """Verify that the correct slice axis is selected after the combined dense.""" + + def before(x, w1, w2): + args = [x, w1, w2] + y1 = relay.nn.dense(x, w1) + y1 = relay.expand_dims(y1, axis=2) + + y2 = relay.nn.dense(x, w2) + y2 = relay.expand_dims(y2, axis=2) + + y = relay.Tuple((y1, y2)) + return relay.Function(args, y) + + def expected(x, w1, w2): + args = [x, w1, w2] + w_stacked = relay.concatenate((w1, w2), axis=0) + y = relay.nn.dense(x, w_stacked, units=24) + y = relay.expand_dims(y, axis=2) + + strides = [1, 1, 1] + y1 = relay.strided_slice( + y, begin=[0, 0, 0], end=[-1, 16, -1], strides=strides, slice_mode="size" + ) + y2 = relay.strided_slice( + y, begin=[0, 16, 0], end=[-1, 8, -1], strides=strides, slice_mode="size" + ) + y = relay.Tuple((y1, y2)) + return relay.Function(args, y) + + x = relay.var("x", shape=(2, 32)) + w1 = relay.var("w1", shape=(16, 32)) + w2 = relay.var("w2", shape=(8, 32)) + + y_before = before(x, w1, w2) + combine_pass = transform.CombineParallelDense(min_num_branches=2, to_batch=False) + y = run_opt_pass(y_before, combine_pass) + y_expected = expected(x, w1, w2) + y_expected = run_opt_pass(y_expected, transform.InferType()) + tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True) + + if __name__ == "__main__": - test_combine_parallel_dense() - test_combine_parallel_dense_biasadd() - test_combine_parallel_dense_biasadd_scale_reshape() - test_combine_parallel_dense_flat() - test_combine_parallel_dense_flat_biasadd() - test_combine_parallel_dense_flat_biasadd_scale_reshape() + tvm.testing.main()