Skip to content

Commit 928668b

Browse files
committed
Reworking GetTensorizeloopmapping
1 parent a80e639 commit 928668b

1 file changed

Lines changed: 33 additions & 77 deletions

File tree

src/tir/schedule/analysis/analysis.cc

Lines changed: 33 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -2096,87 +2096,43 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
20962096
if (offset < 0) {
20972097
return NullOpt;
20982098
}
2099-
// We align the block and desc block's bindings from the right side
2100-
// block (v0=..., v1=..., v2=...)
2101-
// ^ i_block
2102-
// desc_block( v1=..., v2=...)
2103-
// ^ i_desc
2104-
2105-
std::vector<IterVarType> iter_types = GetBlockVarTypes(block_sref);
2106-
ICHECK(block_loops.size() == iter_types.size());
2107-
2108-
for (int i_desc = 0, i_block = offset; i_desc < n_desc_vars; ++i_desc, ++i_block) {
2109-
// For each block var binding, we find
2110-
const PrimExpr& block_bind = block->iter_values[i_block];
2111-
const PrimExpr& desc_bind = desc_block->iter_values[i_desc];
2112-
LOG(INFO) << "block bind: " << block_bind;
2113-
LOG(INFO) << "desc bind: " << desc_bind;
2114-
// Step 4.1. Find the corresponding loop of the i-th block var of block
2115-
const tir::ForNode* block_loop = nullptr;
2116-
for (int i = block_loops.size() - 1; i >= 0; --i) {
2117-
// Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars
2118-
PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
2119-
const auto* int_block_extent = block_loops[i]->extent.as<IntImmNode>();
2120-
const auto* int_desc_extent = desc_loops[i_desc]->extent.as<IntImmNode>();
2121-
2122-
if (i_desc != n_desc_vars - 1 && iter_types[i] == IterVarType::kCommReduce) continue;
2123-
2124-
// if (int_block_extent->value == int_desc_extent->value) {
2125-
if (!tir::UsesVar(r, [&block_loop_vars](const tir::VarNode* var) {
2126-
return block_loop_vars.count(var);
2127-
})) {
2128-
block_loop = block_loops[i];
2129-
LOG(INFO) << "Selected " << i << " th block loop " << block_loops[i]->loop_var << ", "
2130-
<< block_loop->extent;
2131-
break;
2132-
} else {
2133-
LOG(INFO) << i << " th block loop not ok "
2134-
<< ", " << block_loops[i]->loop_var << ", " << block_loops[i]->extent;
2135-
}
2136-
}
2137-
if (block_loop == nullptr) {
2138-
return NullOpt;
2139-
}
2140-
// Step 4.2. Find the corresponding loop of the i-th block var of desc
2141-
const tir::ForNode* desc_loop = nullptr;
2142-
for (int i = 0, n = desc_loops.size(); i < n; ++i) {
2143-
// Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars
2144-
PrimExpr r = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var);
2145-
if (!tir::UsesVar(r, [&desc_loop_vars](const tir::VarNode* var) {
2146-
return desc_loop_vars.count(var);
2147-
})) {
2148-
desc_loop = desc_loops[i];
2149-
LOG(INFO) << "Selected " << i << " th desc loop " << desc_loop->extent;
2150-
;
2099+
2100+
std::vector<IterVarType> iter_types_desc;
2101+
for (const IterVar& iter_var : desc_block->block->iter_vars) {
2102+
iter_types_desc.push_back(iter_var->iter_type);
2103+
}
2104+
2105+
std::vector<IterVarType> iter_types_block = GetBlockVarTypes(block_sref);
2106+
ICHECK(block_loops.size() == iter_types_block.size());
2107+
2108+
int next_block_ind = block_loops.size() - 1;
2109+
for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) {
2110+
const tir::ForNode* desc_loop = desc_loops[i_desc];
2111+
const auto* int_desc_extent = desc_loop->extent.as<IntImmNode>();
2112+
2113+
for (int i_block = next_block_ind; i_block >= 0; --i_block) {
2114+
const tir::ForNode* block_loop = block_loops[i_block];
2115+
const auto* int_block_extent = block_loop->extent.as<IntImmNode>();
2116+
2117+
LOG(INFO) << i_desc << ", " << i_block << ", " << iter_types_block[i_block] << ", "
2118+
<< iter_types_desc[i_desc] << ", " << int_block_extent->value << ", "
2119+
<< int_desc_extent->value;
2120+
2121+
if (int_block_extent->value % int_desc_extent->value != 0) continue;
2122+
if (iter_types_block[i_block] != iter_types_desc[i_desc]) continue;
2123+
2124+
const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop];
2125+
auto it = ret->loop_map.find(block_loop_sref);
2126+
if (it == ret->loop_map.end()) {
2127+
ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
2128+
next_block_ind = i_block - 1;
2129+
LOG(INFO) << "Selected " << i_block << " th block loop " << block_loops[i_block]->loop_var
2130+
<< ", " << block_loop->extent << " for i_desc = " << i_desc;
21512131
break;
21522132
}
21532133
}
2154-
if (desc_loop == nullptr) {
2155-
return NullOpt;
2156-
}
2157-
// Step 4.3. Check divisibility of loop extents
2158-
PrimExpr block_extent = analyzer.Simplify(block_loop->extent);
2159-
PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent);
2160-
if (const auto* int_block_extent = block_extent.as<IntImmNode>()) {
2161-
if (const auto* int_desc_extent = desc_extent.as<IntImmNode>()) {
2162-
if (int_block_extent->value % int_desc_extent->value != 0) {
2163-
return NullOpt;
2164-
}
2165-
} else {
2166-
return NullOpt;
2167-
}
2168-
} else {
2169-
return NullOpt;
2170-
}
2171-
// Step 4.4. Maps the result of Step 4.1 to Step 4.2
2172-
const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop];
2173-
auto it = ret->loop_map.find(block_loop_sref);
2174-
if (it == ret->loop_map.end()) {
2175-
ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
2176-
} else if ((*it).second.get() != desc_loop) {
2177-
return NullOpt;
2178-
}
21792134
}
2135+
21802136
for (int i = 0, n = desc_loops.size(); i < n; ++i) {
21812137
ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i));
21822138
}

0 commit comments

Comments
 (0)