@@ -2096,87 +2096,43 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
20962096 if (offset < 0 ) {
20972097 return NullOpt;
20982098 }
2099- // We align the block and desc block's bindings from the right side
2100- // block (v0=..., v1=..., v2=...)
2101- // ^ i_block
2102- // desc_block( v1=..., v2=...)
2103- // ^ i_desc
2104-
2105- std::vector<IterVarType> iter_types = GetBlockVarTypes (block_sref);
2106- ICHECK (block_loops.size () == iter_types.size ());
2107-
2108- for (int i_desc = 0 , i_block = offset; i_desc < n_desc_vars; ++i_desc, ++i_block) {
2109- // For each block var binding, we find
2110- const PrimExpr& block_bind = block->iter_values [i_block];
2111- const PrimExpr& desc_bind = desc_block->iter_values [i_desc];
2112- LOG (INFO) << " block bind: " << block_bind;
2113- LOG (INFO) << " desc bind: " << desc_bind;
2114- // Step 4.1. Find the corresponding loop of the i-th block var of block
2115- const tir::ForNode* block_loop = nullptr ;
2116- for (int i = block_loops.size () - 1 ; i >= 0 ; --i) {
2117- // Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars
2118- PrimExpr r = analyzer.Simplify (block_bind - block_loops[i]->loop_var );
2119- const auto * int_block_extent = block_loops[i]->extent .as <IntImmNode>();
2120- const auto * int_desc_extent = desc_loops[i_desc]->extent .as <IntImmNode>();
2121-
2122- if (i_desc != n_desc_vars - 1 && iter_types[i] == IterVarType::kCommReduce ) continue ;
2123-
2124- // if (int_block_extent->value == int_desc_extent->value) {
2125- if (!tir::UsesVar (r, [&block_loop_vars](const tir::VarNode* var) {
2126- return block_loop_vars.count (var);
2127- })) {
2128- block_loop = block_loops[i];
2129- LOG (INFO) << " Selected " << i << " th block loop " << block_loops[i]->loop_var << " , "
2130- << block_loop->extent ;
2131- break ;
2132- } else {
2133- LOG (INFO) << i << " th block loop not ok "
2134- << " , " << block_loops[i]->loop_var << " , " << block_loops[i]->extent ;
2135- }
2136- }
2137- if (block_loop == nullptr ) {
2138- return NullOpt;
2139- }
2140- // Step 4.2. Find the corresponding loop of the i-th block var of desc
2141- const tir::ForNode* desc_loop = nullptr ;
2142- for (int i = 0 , n = desc_loops.size (); i < n; ++i) {
2143- // Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars
2144- PrimExpr r = analyzer.Simplify (desc_bind - desc_loops[i]->loop_var );
2145- if (!tir::UsesVar (r, [&desc_loop_vars](const tir::VarNode* var) {
2146- return desc_loop_vars.count (var);
2147- })) {
2148- desc_loop = desc_loops[i];
2149- LOG (INFO) << " Selected " << i << " th desc loop " << desc_loop->extent ;
2150- ;
2099+
2100+ std::vector<IterVarType> iter_types_desc;
2101+ for (const IterVar& iter_var : desc_block->block ->iter_vars ) {
2102+ iter_types_desc.push_back (iter_var->iter_type );
2103+ }
2104+
2105+ std::vector<IterVarType> iter_types_block = GetBlockVarTypes (block_sref);
2106+ ICHECK (block_loops.size () == iter_types_block.size ());
2107+
2108+ int next_block_ind = block_loops.size () - 1 ;
2109+ for (int i_desc = n_desc_vars - 1 ; i_desc >= 0 ; --i_desc) {
2110+ const tir::ForNode* desc_loop = desc_loops[i_desc];
2111+ const auto * int_desc_extent = desc_loop->extent .as <IntImmNode>();
2112+
2113+ for (int i_block = next_block_ind; i_block >= 0 ; --i_block) {
2114+ const tir::ForNode* block_loop = block_loops[i_block];
2115+ const auto * int_block_extent = block_loop->extent .as <IntImmNode>();
2116+
2117+ LOG (INFO) << i_desc << " , " << i_block << " , " << iter_types_block[i_block] << " , "
2118+ << iter_types_desc[i_desc] << " , " << int_block_extent->value << " , "
2119+ << int_desc_extent->value ;
2120+
2121+ if (int_block_extent->value % int_desc_extent->value != 0 ) continue ;
2122+ if (iter_types_block[i_block] != iter_types_desc[i_desc]) continue ;
2123+
2124+ const tir::StmtSRef& block_loop_sref = self->stmt2ref [block_loop];
2125+ auto it = ret->loop_map .find (block_loop_sref);
2126+ if (it == ret->loop_map .end ()) {
2127+ ret->loop_map .Set (block_loop_sref, GetRef<tir::For>(desc_loop));
2128+ next_block_ind = i_block - 1 ;
2129+ LOG (INFO) << " Selected " << i_block << " th block loop " << block_loops[i_block]->loop_var
2130+ << " , " << block_loop->extent << " for i_desc = " << i_desc;
21512131 break ;
21522132 }
21532133 }
2154- if (desc_loop == nullptr ) {
2155- return NullOpt;
2156- }
2157- // Step 4.3. Check divisibility of loop extents
2158- PrimExpr block_extent = analyzer.Simplify (block_loop->extent );
2159- PrimExpr desc_extent = analyzer.Simplify (desc_loop->extent );
2160- if (const auto * int_block_extent = block_extent.as <IntImmNode>()) {
2161- if (const auto * int_desc_extent = desc_extent.as <IntImmNode>()) {
2162- if (int_block_extent->value % int_desc_extent->value != 0 ) {
2163- return NullOpt;
2164- }
2165- } else {
2166- return NullOpt;
2167- }
2168- } else {
2169- return NullOpt;
2170- }
2171- // Step 4.4. Maps the result of Step 4.1 to Step 4.2
2172- const tir::StmtSRef& block_loop_sref = self->stmt2ref [block_loop];
2173- auto it = ret->loop_map .find (block_loop_sref);
2174- if (it == ret->loop_map .end ()) {
2175- ret->loop_map .Set (block_loop_sref, GetRef<tir::For>(desc_loop));
2176- } else if ((*it).second .get () != desc_loop) {
2177- return NullOpt;
2178- }
21792134 }
2135+
21802136 for (int i = 0 , n = desc_loops.size (); i < n; ++i) {
21812137 ret->desc_loop_indexer .Set (GetRef<tir::For>(desc_loops[i]), Integer (i));
21822138 }
0 commit comments