src/relay/transforms/combine_parallel_dense.cc (30 additions, 13 deletions)
@@ -195,23 +195,40 @@ class ParallelDenseToDenseCombiner : public ParallelOpCombiner {
   void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
                          ExprSubstMap* subst_map) {
     int index = 0;
+    const auto dense_op = Op::Get("nn.dense");
     for (const auto& branch : branches) {
       const CallNode* call = branch[depth];
       auto& out_shape = call->type_as<TensorTypeNode>()->shape;
-      auto out_dims = tir::as_const_int(out_shape[out_shape.size() - 1]);
-      ICHECK(out_dims != nullptr);
-      Array<Integer> begin;
-      Array<Integer> end;
-      Array<Integer> strides;
-      for (size_t k = 0; k < out_shape.size() - 1; ++k) {
-        begin.push_back(0);
-        end.push_back(-1);
-        strides.push_back(1);
+
+      const CallNode* dense = branch[0];
+      ICHECK(dense->op.same_as(dense_op));
+      auto& dense_shape = dense->type_as<TensorTypeNode>()->shape;
+      auto dense_out_dims = tir::as_const_int(dense_shape[1]);
+      ICHECK(dense_out_dims != nullptr);
+
+      // dense can be followed by shape-changing operations, so the slicing axis is
+      // not necessarily the last one.
+      // TODO(masahi): The following logic is incorrect if (1) there is no axis in
+      // out_shape[i] that directly corresponds to the output channel of dense or (2) there
+      // is another axis that happens to have the same size as the output channel of dense.
+      // Such cases might arise due to reshape / transpose / split etc. Revisit this logic
+      // when we encounter them in practice.
+      auto slice_axis = -1;
+      for (int i = static_cast<int>(out_shape.size()) - 1; i >= 0; --i) {
+        ICHECK(tir::as_const_int(out_shape[i]));
+        if (*tir::as_const_int(out_shape[i]) == *dense_out_dims) {
+          slice_axis = i;
+          break;
+        }
       }
-      begin.push_back(index);
-      end.push_back(*out_dims);
-      strides.push_back(1);
-      index += *out_dims;
+      ICHECK(slice_axis != -1);
+
+      Array<Integer> begin(out_shape.size(), 0);
+      Array<Integer> end(out_shape.size(), -1);
+      Array<Integer> strides(out_shape.size(), 1);
+      begin.Set(slice_axis, index);
+      end.Set(slice_axis, *dense_out_dims);
+      index += *dense_out_dims;
       auto slice = MakeStridedSlice(data, begin, end, strides, "size");
       subst_map->insert({GetRef<Expr>(branch[depth]), slice});
     }
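The new logic scans out_shape from the last axis backward and picks the first axis whose extent equals the output channels of the original dense op. A rough Python sketch of that search (an illustration only, not part of the patch or of TVM's API; find_slice_axis is a hypothetical stand-in for the C++ loop) also shows the failure mode the TODO warns about:

def find_slice_axis(out_shape, dense_out_dims):
    # Scan backward, since the channel axis is usually near the end.
    for i in range(len(out_shape) - 1, -1, -1):
        if out_shape[i] == dense_out_dims:
            return i
    return -1

# After expand_dims(axis=2), the combined 24 output channels sit on axis 1:
assert find_slice_axis((2, 24, 1), 24) == 1
# Caveat from the TODO: the scan returns the last matching axis, which may
# not be the true channel axis if another axis has the same extent:
assert find_slice_axis((2, 24, 24), 24) == 2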
tests/python/relay/test_pass_combine_parallel_dense.py (44 additions, 7 deletions)
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import te
+import tvm.testing
 from tvm import relay
 from tvm.relay import transform
 
@@ -359,10 +359,47 @@ def check(i, j, k, scale1, scale2, newshape1, newshape2):
     check(100, 200, 300, 0.5, 0.25, (1, 1, 20000), (1, 1, 40000))
 
 
+def test_combine_parallel_dense_expand_dims():
+    """Verify that the correct slice axis is selected after the combined dense."""
+
+    def before(x, w1, w2):
+        args = [x, w1, w2]
+        y1 = relay.nn.dense(x, w1)
+        y1 = relay.expand_dims(y1, axis=2)
+
+        y2 = relay.nn.dense(x, w2)
+        y2 = relay.expand_dims(y2, axis=2)
+
+        y = relay.Tuple((y1, y2))
+        return relay.Function(args, y)
+
+    def expected(x, w1, w2):
+        args = [x, w1, w2]
+        w_stacked = relay.concatenate((w1, w2), axis=0)
+        y = relay.nn.dense(x, w_stacked, units=24)
+        y = relay.expand_dims(y, axis=2)
+
+        strides = [1, 1, 1]
+        y1 = relay.strided_slice(
+            y, begin=[0, 0, 0], end=[-1, 16, -1], strides=strides, slice_mode="size"
+        )
+        y2 = relay.strided_slice(
+            y, begin=[0, 16, 0], end=[-1, 8, -1], strides=strides, slice_mode="size"
+        )
+        y = relay.Tuple((y1, y2))
+        return relay.Function(args, y)
+
+    x = relay.var("x", shape=(2, 32))
+    w1 = relay.var("w1", shape=(16, 32))
+    w2 = relay.var("w2", shape=(8, 32))
+
+    y_before = before(x, w1, w2)
+    combine_pass = transform.CombineParallelDense(min_num_branches=2, to_batch=False)
+    y = run_opt_pass(y_before, combine_pass)
+    y_expected = expected(x, w1, w2)
+    y_expected = run_opt_pass(y_expected, transform.InferType())
+    tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
+
+
 if __name__ == "__main__":
-    test_combine_parallel_dense()
-    test_combine_parallel_dense_biasadd()
-    test_combine_parallel_dense_biasadd_scale_reshape()
-    test_combine_parallel_dense_flat()
-    test_combine_parallel_dense_flat_biasadd()
-    test_combine_parallel_dense_flat_biasadd_scale_reshape()
+    tvm.testing.main()
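Two details in the expected graph are worth spelling out. With slice_mode="size", begin gives the start index per axis and end gives the slice length, with -1 meaning all remaining elements (strides are ignored in this mode); so begin=[0, 16, 0], end=[-1, 8, -1] takes the 8 channels starting at channel 16. The combination itself is valid because nn.dense computes x @ w.T, so concatenating the weight matrices along axis 0 concatenates the outputs along the channel axis. A quick NumPy sketch of that identity, using the same shapes as the test (illustration only, not part of the patch):

import numpy as np

x = np.random.randn(2, 32).astype("float32")
w1 = np.random.randn(16, 32).astype("float32")
w2 = np.random.randn(8, 32).astype("float32")

# dense(x, w) == x @ w.T, so stacking w1 and w2 on axis 0 stacks
# the two dense outputs on the channel axis.
combined = x @ np.concatenate((w1, w2), axis=0).T  # shape (2, 24)
np.testing.assert_allclose(combined[:, :16], x @ w1.T, rtol=1e-5)
np.testing.assert_allclose(combined[:, 16:24], x @ w2.T, rtol=1e-5)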