forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathiter_affine_map.cc
More file actions
2249 lines (2053 loc) · 88.3 KB
/
iter_affine_map.cc
File metadata and controls
2249 lines (2053 loc) · 88.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/arith/iter_affine_map.cc
*/
#include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <utility>
#include "../support/utils.h"
#include "const_fold.h"
#include "pattern_match.h"
namespace tvm {
namespace arith {
using namespace tir;
/*!
 * \brief Build an IterMark node from a source expression and its extent.
 * \param source The expression stored as the mark's source.
 * \param extent The extent stored on the mark.
 */
IterMark::IterMark(PrimExpr source, PrimExpr extent) {
  auto node = make_object<IterMarkNode>();
  // Both arguments are taken by value, so they are moved into the node.
  node->extent = std::move(extent);
  node->source = std::move(source);
  data_ = std::move(node);
}
// Expose the IterMark constructor through the global function registry
// under the name "arith.IterMark".
TVM_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) {
  IterMark mark(source, extent);
  return mark;
});
// Register the node type.
TVM_REGISTER_NODE_TYPE(IterMarkNode);
// Printer: IterMark(<source>, extent=<extent>)
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<IterMarkNode>([](const ObjectRef& ref, ReprPrinter* printer) {
      const auto* mark = static_cast<const IterMarkNode*>(ref.get());
      printer->stream << "IterMark(" << mark->source << ", extent=" << mark->extent << ")";
    });
/*!
 * \brief Build a split that covers the whole mark: lower_factor and scale are
 *        both the constant 1 (in the dtype of the mark's source) and the extent
 *        is the mark's extent.
 * \param source The iter mark being split.
 */
IterSplitExpr::IterSplitExpr(IterMark source) {
  // Read everything needed from `source` before it is moved into the node.
  const auto dtype = source->source->dtype;
  const auto unit = make_const(dtype, 1);
  auto node = make_object<IterSplitExprNode>();
  node->dtype = dtype;
  node->extent = source->extent;
  node->lower_factor = unit;
  node->scale = unit;
  node->source = std::move(source);
  data_ = std::move(node);
}
/*!
 * \brief Build a split that covers the whole mark with a caller-supplied scale.
 *        lower_factor is the constant 1 (in the dtype of the mark's source) and
 *        the extent is the mark's extent.
 * \param source The iter mark being split.
 * \param scale The scale stored on the split.
 */
IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr scale) {
  // Read everything needed from `source` before it is moved into the node.
  const auto dtype = source->source->dtype;
  auto node = make_object<IterSplitExprNode>();
  node->dtype = dtype;
  node->lower_factor = make_const(dtype, 1);
  node->extent = source->extent;
  node->scale = std::move(scale);
  node->source = std::move(source);
  data_ = std::move(node);
}
/*!
 * \brief Build a split with every field given explicitly.
 * \param source The iter mark being split.
 * \param lower_factor The lower factor stored on the split.
 * \param extent The extent stored on the split.
 * \param scale The scale stored on the split.
 */
IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent,
                             PrimExpr scale) {
  auto node = make_object<IterSplitExprNode>();
  // The dtype is taken from the mark's source; read it before `source` is moved.
  node->dtype = source->source->dtype;
  node->scale = std::move(scale);
  node->extent = std::move(extent);
  node->lower_factor = std::move(lower_factor);
  node->source = std::move(source);
  data_ = std::move(node);
}
// Expose the four-argument IterSplitExpr constructor through the global
// function registry under the name "arith.IterSplitExpr".
TVM_REGISTER_GLOBAL("arith.IterSplitExpr")
    .set_body_typed([](IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) {
      IterSplitExpr split(source, lower_factor, extent, scale);
      return split;
    });
// Register the node type.
TVM_REGISTER_NODE_TYPE(IterSplitExprNode);
// Printer: IterSplit(<source>, lower_factor=..., extent=..., scale=...)
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<IterSplitExprNode>([](const ObjectRef& ref, ReprPrinter* printer) {
      const auto* split = static_cast<const IterSplitExprNode*>(ref.get());
      printer->stream << "IterSplit(" << split->source << ", lower_factor=" << split->lower_factor
                      << ", extent=" << split->extent << ", scale=" << split->scale << ")";
    });
/*!
 * \brief Build an IterSumExpr from its split terms and a base offset.
 *        The dtype of the sum is the dtype of `base`.
 * \param args The split terms of the sum.
 * \param base The base offset of the sum.
 */
IterSumExpr::IterSumExpr(Array<IterSplitExpr> args, PrimExpr base) {
  auto node = make_object<IterSumExprNode>();
  // The dtype must be read before `base` is moved away.
  node->dtype = base->dtype;
  node->base = std::move(base);
  node->args = std::move(args);
  data_ = std::move(node);
}
// Expose the IterSumExpr constructor through the global function registry
// under the name "arith.IterSumExpr".
TVM_REGISTER_GLOBAL("arith.IterSumExpr")
    .set_body_typed([](Array<IterSplitExpr> args, PrimExpr base) {
      IterSumExpr sum(args, base);
      return sum;
    });
// Register the node type.
TVM_REGISTER_NODE_TYPE(IterSumExprNode);
// Printer: IterSum(<args>, <base>)
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<IterSumExprNode>([](const ObjectRef& ref, ReprPrinter* printer) {
      const auto* sum = static_cast<const IterSumExprNode*>(ref.get());
      printer->stream << "IterSum(" << sum->args << ", " << sum->base << ")";
    });
/*!
* \brief Collector that collects the outgoing split reference of each IterMark.
*
* These out-going splits can then be used to check if the iterators are independent.
*/
class IterMarkSplitCollector {
public:
// mark all IterMarks that are visited.
std::unordered_set<IterMark, ObjectPtrHash, ObjectPtrEqual> visited_;
// each iter mark to its outgoing splits that are referenced.
std::unordered_map<IterMark, std::vector<IterSplitExpr>, ObjectPtrHash, ObjectPtrEqual>
mark2splits_;
/*!
* \brief Collect all mark2splits recursively from indices.
* \param indices The iterator of interest.
*/
void Collect(const Array<IterSumExpr>& indices) {
for (IterSumExpr sum_expr : indices) {
for (IterSplitExpr split : sum_expr->args) {
this->CollectInternal(split->source);
mark2splits_[split->source].push_back(split);
}
}
}
void CollectInternal(const IterMark& mark) {
if (visited_.count(mark)) return;
visited_.insert(mark);
if (auto* op = mark->source.as<IterSumExprNode>()) {
for (IterSplitExpr split : op->args) {
this->CollectInternal(split->source);
mark2splits_[split->source].push_back(split);
}
}
}
};
/*! \brief Record form of IterMark(x, extent) + offset */
struct IterMarkWithOffset {
  // The iter mark.
  IterMark mark;
  // The offset added to the mark; 0 unless set through the two-argument constructor.
  PrimExpr offset{0};
  IterMarkWithOffset() = default;
  // Intentionally non-explicit so that `{mark, offset}` brace-initialization keeps working.
  // Both parameters are sinks taken by value, so they are moved into place.
  IterMarkWithOffset(IterMark mark, PrimExpr offset)
      : mark(std::move(mark)), offset(std::move(offset)) {}
};
/*! \brief Rewriter to rewrite PrimExpr to IterMapExpr when possible */
class IterMapRewriter : public ExprMutator {
public:
using Parent = ExprMutator;
/*!
 * \brief Set up the rewriter and seed var_map_ from the input iterators.
 * \param analyzer The analyzer, stored as a raw pointer.
 * \param input_iters Map from each input variable to its range.
 * \param check_level The iter map check level.
 * \param simplify_trivial_iterators If true, a variable whose extent is one is
 *        mapped to an IterSumExpr holding only its range minimum.
 * \param errors Destination for error messages; dereferenced here and kept as a reference.
 */
explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters,
                         IterMapLevel check_level, bool simplify_trivial_iterators,
                         Array<String>* errors)
    : analyzer_(analyzer),
      check_level_(check_level),
      errors_(*errors),
      padding_predicate_(const_false()) {
  for (auto entry : input_iters) {
    const Var& iter_var = entry.first;
    const Range& dom = entry.second;
    // Extent-one iterator: no mark is created, only the constant minimum is kept.
    if (simplify_trivial_iterators && is_one(dom->extent)) {
      var_map_[iter_var] = IterSumExpr({}, dom->min);
      continue;
    }
    // Zero-based iterator: the variable itself is the source of the mark.
    if (is_zero(dom->min)) {
      IterMark zero_based(iter_var, dom->extent);
      var_map_[iter_var] = IterSplitExpr(zero_based);
      input_marks_.push_back(zero_based);
      continue;
    }
    // Otherwise mark (var - min) and carry min as the base of the sum.
    IterMark shifted(iter_var - dom->min, dom->extent);
    IterSumExpr with_base = ToIterSumExpr(IterSplitExpr(shifted));
    with_base.CopyOnWrite()->base = dom->min;
    var_map_[iter_var] = with_base;
    input_marks_.push_back(shifted);
  }
}
// Returns the current padding predicate (initialized to const_false() by the constructor).
PrimExpr padding_predicate() const { return padding_predicate_; }
// Returns the requires_padding_ flag (false by default).
bool requires_padding() const { return requires_padding_; }
/*!
 * \brief Rewrite an expression into a normalized iterator-with-offset sum.
 * \param expr The expression to rewrite.
 * \return The result of NormalizeToIterWithOffset on the mutated expression.
 */
IterSumExpr Rewrite(const PrimExpr& expr) {
  PrimExpr mutated = DirectMutate(expr);
  IterSumExpr as_sum = ToIterSumExpr(mutated);
  return NormalizeToIterWithOffset(std::move(as_sum));
}
/*!
 * \brief Same as Rewrite, but with update_iterator_padding_ set to true for
 *        the duration of the call; the flag is set back to false afterwards.
 * \param expr The expression to rewrite.
 * \return The rewritten expression.
 */
IterSumExpr RewriteAndUpdatePadding(const PrimExpr& expr) {
  update_iterator_padding_ = true;
  IterSumExpr rewritten = Rewrite(expr);
  update_iterator_padding_ = false;
  return rewritten;
}
/*!
 * \brief Rewrite the expression of an iterator constraint and normalize it
 *        against the bounds induced by the predicate.
 * \param expr The constrained expression.
 * \param predicate_induced_min Lower bound from the predicate, may be undefined.
 * \param predicate_induced_max Upper bound from the predicate, may be undefined.
 * \return The result of NormalizeToIterOnBoundExpr.
 */
IterSumExpr RewriteIterConstraint(const PrimExpr& expr,
                                  const Optional<PrimExpr>& predicate_induced_min,
                                  const Optional<PrimExpr>& predicate_induced_max) {
  IterSumExpr as_sum = ToIterSumExpr(DirectMutate(expr));
  return NormalizeToIterOnBoundExpr(std::move(as_sum), predicate_induced_min,
                                    predicate_induced_max);
}
/*!
* \brief If require bijective mapping, this function checks two conditions:
* - C0: Each iter mark should be fully covered by non-overlapping splits.
* - C1: All of the input iterators are used.
* Example: given x in [0, 8) y in [0, 6)
* - bindings = [x, x + 1, y] won't pass because x and x+1 contribute
* two splits that overlaps with each other.
* - bindings = [x / 4, x % 4, y] will pass because x / 4 and x % 4
* contribute two non-overlapping splits that covers x.
* - bindings = [x / 4, x % 4] won't pass because y is not used.
*
* If only require surjective mapping, this function checks one condition:
* - C0: Each iter mark has a chance to be fully covered by non-overlapping splits.
* Example: given x in [0, 8) y in [0, 6)
* - bindings = [x / 4] will pass because x / 4 can be one split of x
* - bindings = [x / 4, x % 4] will pass because x / 4 and x % 4
* contribute two non-overlapping splits that covers x.
* - bindings = [x / 3] will not pass because x / 3 can not be one split of x
* \return whether the bindings are valid
*/
bool CheckMapping(const Array<IterSumExpr>& bindings, IterMapLevel check_level) {
IterMarkSplitCollector collector;
// We can check that for each iter mark:
// All the splits that refers to the iter_mark covers its extent.
// The splits do not overlap with each other.
collector.Collect(bindings);
for (const IterMark& mark : collector.visited_) {
if (TryNormalizeSplits(mark, collector.mark2splits_[mark], check_level).empty()) {
return false;
}
}
if (check_level == IterMapLevel::Bijective) {
// all input marks must be visited
for (const IterMark& mark : input_marks_) {
if (collector.visited_.count(mark) == 0 && !is_one(mark->extent)) {
return false;
}
}
}
return true;
}
/*!
* \brief Check the validity of iterator constraints
* The flattened forms of two different iterator constraints
* either 1) follow inclusion relation or 2) have no intersection
*
* For Example, x = i0*30 + i1*15 + i2*3 + i3,
* 1) [i0*2 + i1 < 3, i2*3 + i3 < 5] is valid, since {i0, i1} \\intersect {i2, i3} = empty set.
* 2) [i0*2 + i1 < 3, i1*5 + i2 < 5] is not valid,
* since {i0, i1} \\intersect {i1, i2} = {i1}, i0 \\in {i0, i1}, i0 \\notin {i1, i2}
* \return whether the predicates are valid;
*/
bool CheckConstraints() const {
// the constrained_iters_flattened_ are in the order of shorter to longer
// since we visit the predicates in the order of size
for (size_t i = 0; i < constrained_iters_flattened_.size(); ++i) {
for (size_t j = i + 1; j < constrained_iters_flattened_.size(); ++j) {
// state: 0(start), -1(no intersection), 1(inclusion)
int state = 0;
for (const IterSplitExpr& arg1 : constrained_iters_flattened_[i]->args) {
bool found = false;
for (const IterSplitExpr& arg2 : constrained_iters_flattened_[j]->args) {
if (IterSplitEqual(arg1, arg2)) {
found = true;
break;
}
}
// Check either it is inclusion or intersection, but not both
if (state == 0) {
state = found ? 1 : -1;
} else if ((state == -1 && found) || (state == 1 && !found)) {
return false;
}
}
}
}
return true;
}
// Override of the generic mutate entry point: delegates to ExprMutator and
// logs an error when the result is an IterMapExpr, which is only expected
// from calls that go through DirectMutate.
PrimExpr VisitExpr(const PrimExpr& input_expr) final {
  PrimExpr result = ExprMutator::VisitExpr(input_expr);
  const bool leaked_iter_map_expr = result->IsInstance<IterMapExprNode>();
  if (leaked_iter_map_expr) {
    ErrorLogger(this) << "IterMapExpr or subclasses should only result from calls in "
                      << "IterMapRewriter using DirectMutate. "
                      << "Indirect return occurred in " << input_expr;
  }
  return result;
}
// Normal mutation without normalization.
// Calls ExprMutator::VisitExpr directly, so the IterMapExpr check performed by
// this class's own VisitExpr override is not applied to the result.
PrimExpr DirectMutate(const PrimExpr& expr) { return ExprMutator::VisitExpr(expr); }
// Rewrite rules for the node kinds handled by this rewriter: Var, Add, Sub,
// Mul, FloorDiv and FloorMod. Only the declarations appear here; the
// definitions are not part of this section of the file.
PrimExpr VisitExpr_(const VarNode* op) final;
PrimExpr VisitExpr_(const AddNode* op) final;
PrimExpr VisitExpr_(const SubNode* op) final;
PrimExpr VisitExpr_(const MulNode* op) final;
PrimExpr VisitExpr_(const FloorDivNode* op) final;
PrimExpr VisitExpr_(const FloorModNode* op) final;
private:
/* \brief Preprocessing common to both FloorDiv and FloorMod
*
* \param dividend The dividend to be manipulated.
* \param original_dividend NOTE(review): presumably the dividend expression as it was
*        before being rewritten into `dividend`; the definition is outside this
*        section of the file -- confirm.
* \return The preprocessed dividend as an IterSumExpr.
*/
IterSumExpr PreprocessDividend(IterMapExpr dividend, PrimExpr original_dividend);
// Create an iterator that represents the expression (split+base), with
// padding such that the iterator's extents are evenly divisible by
// `divisor`.
//
// If iterators can have padding added through UpdatePadding, pad a
// dividend out to be evenly divisible. Otherwise, validate that the
// padding previously defined for the split using UpdatePadding can be
// used. If no such previous padding exists, return an empty
// IterMark.
//
// Returns a pair of an IterSplit that represents (split+base) in a
// form that can be divided by `divisor`, and a PrimExpr that
// represents the left padding applied to split.
std::pair<IterSplitExpr, PrimExpr> PadDividendToDivisor(IterSplitExpr split, PrimExpr base,
PrimExpr divisor);
friend struct ErrorLogger;
/* \brief Utility class for logging errors.
*
* It is not an error for IterMapRewriter to receive an expression that
* cannot be represented as an IterSumExpr. In these cases,
* IterMapRewriter returns the unrepresentable portions of the TIR graph
* without modification. As a result, the usual ICHECK or LOG(FATAL)
* macros cannot be used. Instead, ErrorLogger(this) can be used to
* report an unrepresentable TIR graph, which may be used in error
* messages at the calling scope.
*/
class ErrorLogger {
public:
explicit ErrorLogger(IterMapRewriter* rewriter) : rewriter(rewriter) {}
~ErrorLogger() { rewriter->errors_.push_back(os.str()); }
template <typename T>
ErrorLogger& operator<<(T&& t) {
os << std::forward<T>(t);
return *this;
}
private:
IterMapRewriter* rewriter;
std::ostringstream os;
};
// Padding bookkeeping for one iter mark; this is the mapped value type of padded_iter_map_.
struct IterPaddingInfo {
// GCD of padding factor collected during first pass
PrimExpr padding_factor{1};
// NOTE(review): presumably the amount of padding placed before (left) and after
// (right) the original iteration range; both default to 0 -- confirm against
// PadDividendToDivisor.
PrimExpr left_pad{0};
PrimExpr right_pad{0};
// Padded form of original iter mark
IterMark padded;
};
// temp hash for de-duplication purposes.
struct IterSumHash {
size_t operator()(const IterSumExpr& value) const {
// for now only hash on source index.
size_t hash = value->args.size();
for (const IterSplitExpr& arg : value->args) {
hash = support::HashCombine(hash, std::hash<const Object*>()(arg->source.get()));
}
return hash;
}
};
static bool IterSplitEqual(const IterSplitExpr& lhs, const IterSplitExpr& rhs,
bool check_scale = true) {
tir::ExprDeepEqual equal;
if (!lhs->source.same_as(rhs->source)) return false;
if (!equal(lhs->lower_factor, rhs->lower_factor)) return false;
if (check_scale && !equal(lhs->scale, rhs->scale)) return false;
if (!equal(lhs->extent, rhs->extent)) return false;
return true;
}
struct IterSumEqual {
bool operator()(const IterSumExpr& lhs, const IterSumExpr& rhs) const {
tir::ExprDeepEqual equal;
if (lhs->args.size() != rhs->args.size()) return false;
if (!equal(lhs->base, rhs->base)) return false;
for (size_t i = 0; i < lhs->args.size(); ++i) {
if (!IterSplitEqual(lhs->args[i], rhs->args[i])) return false;
}
return true;
}
};
// Internal analyzer (raw pointer received by the constructor).
Analyzer* analyzer_;
// Iter map check level
IterMapLevel check_level_;
// Error messages for each unresolved expression.
// Reference to the array passed to the constructor; ErrorLogger appends to it.
Array<String>& errors_;
// The var map: input variable -> its iter map form, filled in by the constructor.
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;
// input iter marks
std::vector<IterMark> input_marks_;
// Map from an iter mark to the padded iterator information for
// it. This is necessary for introducing the same padding in all
// usage of an input iterator. (e.g. (i-1) occurring in the
// expressions [(i-1)%8, ((i-1)//8)%4, (i-1)//32] should be
// left-padded by 31 for each occurrence.)
std::unordered_map<IterMark, IterPaddingInfo, StructuralHash, StructuralEqual> padded_iter_map_;
// Map from a padded iter mark to its origin mark
std::unordered_map<IterMark, IterMark, StructuralHash, StructuralEqual> padded_origin_map_;
/* If update_iterator_padding_ is true, allow the extents of the IterMap to be
* padded beyond the original iterators.
*
* For example, if update_iterator_padding_ is true, the expressions i//4 and
* i%4, where i is on the range [0,18), would be represented as
* IterSplit(i, lower_factor=4, extent=5) and IterSplit(i, extent=4).
* This representation would be forbidden if update_iterator_padding_ is false,
* because lower_factor=4 does not evenly divide the original extent of
* 18.
*/
bool update_iterator_padding_{false};
/* A boolean expression that is true for any padding that has been
* introduced, and false otherwise. If update_iterator_padding_ is false,
* padding_predicate_ will always be false.
*
* Example: [i//4, i%4], i in range [0,16)
* padding_predicate_ will be false
*
* Example: [i//4, i%4], i in range [0,18)
* padding_predicate_ will be `(i//4 == 3) && (i%4 >= 2)`
*
* Example: [i//4, i%4], i in range [0,N)
* padding_predicate_ will be `(N%4!=0) && (i//4 == (N+3)//4-1) && (i%4 >= N%4)`
*/
PrimExpr padding_predicate_;
/* A boolean flag that denotes whether padding iterations were detected
* in the first round of indices rewriting.
*/
bool requires_padding_{false};
// The map for sum that maps flattened form to IterMark with normal form and extent (and possibly
// an extra offset)
// Example(1): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
//             predicate: j*2 + k < 9
// Then, flattened form = IterSum(IterSplit(i, scale=9),
//                                IterSplit(j, scale=2),
//                                IterSplit(k, scale=1))
//       normal form    = IterSum(IterSplit(i, scale=9),
//                                IterSplit(IterMark(IterSum(IterSplit(j, scale=2),
//                                                           IterSplit(k, scale=1)),
//                                                   extent=9)
//                                          scale=1))
// Example(2): expr = i*8 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
//             predicate: 1 <= j*2 + k < 9
// Then, flattened form = IterSum(IterSplit(i, scale=8),
//                                IterSplit(j, scale=2),
//                                IterSplit(k, scale=1))
//       normal form    = IterSum(IterSplit(i, scale=8),
//                                IterSplit(IterMark(IterSum(IterSplit(j, scale=2),
//                                                           IterSplit(k, scale=1), base=-1),
//                                                   extent=9-1)
//                                          scale=1),
//                                base=1)
std::unordered_map<IterSumExpr, IterMarkWithOffset, IterSumHash, IterSumEqual> sum_fuse_map_;
// The map for sum that maps normal form to flattened form
std::unordered_map<IterSumExpr, IterSumExpr, IterSumHash, IterSumEqual> flattened_map_;
// The flattened forms of constrained iters, in the order they were recorded
std::vector<IterSumExpr> constrained_iters_flattened_;
/*!
 * \brief Among the unused splits, find the one whose lower_factor is smallest.
 *        Lower factors are compared through divisibility rather than ordering.
 *
 * The search is abandoned as soon as an unused split is found whose
 * lower_factor cannot be proven divisible by expected_lower_factor.
 *
 * \param splits The split array to search in.
 * \param used Flags marking the splits that are already consumed.
 * \param expected_lower_factor The lower factor that is being skipped.
 * \return The index of the selected split, or splits.size() if not found.
 */
size_t SearchSkipLowerFactor(const std::vector<IterSplitExpr>& splits,
                             const std::vector<bool>& used,
                             const PrimExpr& expected_lower_factor) {
  const size_t not_found = splits.size();
  size_t best = not_found;
  for (size_t idx = 0; idx < splits.size(); ++idx) {
    if (used[idx]) {
      continue;
    }
    // All the remaining unused splits should have their lower factor divisible.
    if (!CanProveDivisible(splits[idx]->lower_factor, expected_lower_factor)) {
      return not_found;
    }
    // Keep the split with the smaller lower factor.
    const bool improves =
        best == not_found ||
        CanProveDivisible(splits[best]->lower_factor, splits[idx]->lower_factor);
    if (improves) {
      best = idx;
    }
  }
  return best;
}
/*!
 * \brief If bijective is required, verify that splits fully cover mark in a non-overlapping
 *  fashion. If not, verify that splits are valid and compatible for the mark.
 *  If verification passes, return splits from outermost to innermost order.
 *  If not, return an empty array.
 * \param mark The iterator of interest.
 * \param splits The splits to be verified.
 * \param check_level Iteration mapping's check level.
 * \return The normalized splits, or an empty array on failure. Most failure paths after
 *  the chaining loop also record a message through ErrorLogger.
 */
Array<IterSplitExpr> TryNormalizeSplits(const IterMark& mark,
                                        const std::vector<IterSplitExpr>& splits,
                                        IterMapLevel check_level) {
  std::vector<bool> used(splits.size(), false);
  std::vector<IterSplitExpr> iters;
  iters.reserve(splits.size());
  // Chain the splits from innermost to outermost: each step looks for an unused split
  // whose lower_factor equals lower_factor * extent of the previously chosen one.
  PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1);
  for (size_t i = 0; i < splits.size(); ++i) {
    size_t j = 0;
    for (; j < splits.size(); ++j) {
      if (used[j]) continue;
      if (analyzer_->CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) {
        break;
      }
    }
    if (j == splits.size()) {
      // we do not allow incomplete split if the bindings should be bijective
      if (check_level == IterMapLevel::Bijective) {
        return Array<IterSplitExpr>();
      }
      // look for the next split skipping this lower factor
      // For example, y \in [0, 24) has 3 splits [y / 6, (y / 2) % 6, y % 2]
      // It is valid to only have [y / 6, y % 2] if bijective is not required
      // We can skip (y / 2) % 6
      j = SearchSkipLowerFactor(splits, used, expected_lower_factor);
      // split not found
      if (j == splits.size()) {
        return Array<IterSplitExpr>();
      }
    }
    used[j] = true;
    iters.push_back(splits[j]);
    expected_lower_factor = splits[j]->lower_factor * splits[j]->extent;
  }
  // Extract iteration mark info before padding
  auto pad_mark_it = padded_origin_map_.find(mark);
  const bool has_padding = pad_mark_it != padded_origin_map_.end();
  const bool match_full_iter = analyzer_->CanProveEqual(expected_lower_factor, mark->extent);
  // True when the mark's extent is divisible by the accumulated lower factor.
  const bool match_iter_divisor =
      match_full_iter || CanProveDivisible(mark->extent, expected_lower_factor);
  // Case 1. bijective is required.
  //    We check the extent we calculate is consistent with the extent of the mark and
  //    iteration mark's padding is not allowed.
  //
  // Case 2. bijective is not required and there is no padding.
  //    We check the extent we calculate is a factor of the extent of the mark
  //    For example, y \in [0, 24) [(y / 2) % 6, y % 2] is valid, but y \in [0, 25) is not.
  //
  // Case 3. bijective is not required and there exists padding. We check either
  //   (3.1) The extent we calculate is consistent with the extent of the padded mark and it is
  //         the single split for the iter mark.
  //         For example, padded iter p in [0, 24), [(p / 12)] is valid because it is surjective
  //         according to how we pad the original iteration mark.
  //   (3.2) The extent we calculate is a factor of the extent of the padded mark, and the extent
  //         before padding is greater or equal than the extent we calculate.
  //         For example, the original extent is 14, [(p % 12)] is valid, with p padded to 24.
  //
  if (check_level == IterMapLevel::Bijective) {
    if (has_padding) {
      ErrorLogger(this) << "Bijective mapping should not take iter paddings";
      return Array<IterSplitExpr>();
    } else if (!match_full_iter) {
      ErrorLogger(this) << "The iterations do not traverse full iter space";
      return Array<IterSplitExpr>();
    }
  } else if (!has_padding) {
    if (!match_iter_divisor) {
      ErrorLogger(this) << "The full iter space extent is not divisible by the lower factor";
      return Array<IterSplitExpr>();
    }
  } else if (check_level == IterMapLevel::Surjective) {
    PrimExpr extent_before_padding = pad_mark_it->second->extent;
    if (match_full_iter) {
      if (splits.size() != 1) {
        ErrorLogger(this) << "Dependent iterations on padding iter space";
        return Array<IterSplitExpr>();
      } else if (analyzer_->CanProveEqual(splits[0]->extent, expected_lower_factor) &&
                 !analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) {
        ErrorLogger(this) << "Split on padding iteration is not surjective "
                          << "if the split extent equals to the full iter space extent";
        return Array<IterSplitExpr>();
      }
    } else if (match_iter_divisor) {
      if (!analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) {
        ErrorLogger(this) << "The extent before padding is less than lower factor";
        return Array<IterSplitExpr>();
      }
    } else {
      ErrorLogger(this) << "The full iter space extent is not divisible by the lower factor";
      return Array<IterSplitExpr>();
    }
  }
  // iters was built innermost-first; callers expect outermost-first.
  return Array<IterSplitExpr>(iters.rbegin(), iters.rend());
}
/*!
* \brief Normalize the iter expression with constraint (min <= expr < max)
*
* On success the returned sum holds a single split over the fused mark, and the
* mark's extent (and, when needed, its offset) is narrowed to the constrained
* range. sum_fuse_map_, flattened_map_ and constrained_iters_flattened_ are
* updated as a side effect. On failure an error is logged and the expression is
* returned with its base reset to 0 (if the base was non-zero) but otherwise unfused.
*
* \param expr The iter expression.
* \param predicate_induced_min Closed lower bound from iter constraint, maybe undefined.
* \param predicate_induced_max Open upper bound from iter constraint, maybe undefined.
* \return The Normalized expression.
*/
IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, Optional<PrimExpr> predicate_induced_min,
Optional<PrimExpr> predicate_induced_max) {
// normalize to zero base: the base is removed from the expression and
// subtracted from both bounds so the constraint still describes the same set
PrimExpr base = expr->base;
if (!is_zero(base)) {
expr.CopyOnWrite()->base = 0;
if (predicate_induced_min.defined())
predicate_induced_min = predicate_induced_min.value() - base;
if (predicate_induced_max.defined())
predicate_induced_max = predicate_induced_max.value() - base;
}
Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
// a successful fuse is expected to yield exactly one split
ICHECK(!opt.defined() || opt.value()->args.size() == 1);
// scale should be 1
if (opt.defined() && is_one(opt.value()->args[0]->scale)) {
const IterSplitExpr split = opt.value()->args[0];
IterSumExpr structured_form = Downcast<IterSumExpr>(split->source->source);
// get the flattened form
auto it = flattened_map_.find(structured_form);
ICHECK(it != flattened_map_.end());
IterSumExpr flattened_form = it->second;
// get the mark and offset of the structured_form
auto it_mark = sum_fuse_map_.find(flattened_form);
ICHECK(it_mark != sum_fuse_map_.end());
IterMark mark = it_mark->second.mark;
PrimExpr mark_offset = it_mark->second.offset;
// start from the range [offset, offset + extent) recorded for the mark,
// then tighten it with whichever predicate bounds are defined
PrimExpr iter_min = mark_offset;
PrimExpr iter_max = iter_min + mark->extent;
if (predicate_induced_min.defined()) {
iter_min = max(predicate_induced_min.value(), iter_min);
}
if (predicate_induced_max.defined()) {
iter_max = min(predicate_induced_max.value(), iter_max);
}
if (!is_zero(iter_min)) {
// structured form's offset should be updated
// The entry is erased before the base changes and re-inserted afterwards:
// IterSumEqual compares the base, so the key must not be mutated in place.
flattened_map_.erase(structured_form);
structured_form.CopyOnWrite()->base = -iter_min;
mark.CopyOnWrite()->source = structured_form;
flattened_map_[structured_form] = flattened_form;
}
mark.CopyOnWrite()->extent = iter_max - iter_min;
sum_fuse_map_[flattened_form] = {mark, iter_min};
// we need to note down the flattened form of constrained iterators
// to check the validity of constraints, see also CheckConstraints()
constrained_iters_flattened_.push_back(flattened_form);
expr.CopyOnWrite()->args = Array<IterSplitExpr>({split});
// restore the base removed above, plus the new lower bound
expr.CopyOnWrite()->base = base + iter_min;
return expr;
}
ErrorLogger(this) << "Could not normalize iterators using the constraints given.";
return expr;
}
/*!
 * \brief Normalize expr to an iterator + offset.
 *
 * A sum without any split is returned as is. Otherwise the splits are fused
 * with TryFuseIters; if that fails, an error is logged and the input is
 * returned unchanged.
 *
 * \param expr The input expression.
 * \return The normalized expression.
 */
IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
  // Nothing to fuse when the sum has no split.
  if (expr->args.size() < 1) {
    return expr;
  }
  const Optional<IterSumExpr> fused = TryFuseIters(expr, check_level_);
  if (!fused.defined()) {
    ErrorLogger(this) << "Could not normalize iterators";
    return expr;
  }
  return fused.value();
}
/*!
 * \brief Create an IterSumExpr from expr.
 *
 *  - An IterSumExpr is returned as is.
 *  - An IterSplitExpr becomes a sum with that single split and a zero base.
 *  - Any other expression becomes a sum with no split and expr as its base;
 *    it must not be any other kind of IterMapExpr.
 *
 * \param expr The input expr.
 * \return The transformed IterSumExpr.
 */
static IterSumExpr ToIterSumExpr(const PrimExpr& expr) {
  if (const auto* sum = expr.as<IterSumExprNode>()) {
    return GetRef<IterSumExpr>(sum);
  }
  if (const auto* split = expr.as<IterSplitExprNode>()) {
    return IterSumExpr({GetRef<IterSplitExpr>(split)}, make_zero(expr->dtype));
  }
  ICHECK(!expr->IsInstance<IterMapExprNode>());
  return IterSumExpr({}, expr);
}
/*!
* \brief IterSum = x1*c1 + x2*c2 + ... + xn*cn + base
* = (x1*s1 + x2*s2 + ... + xn)*cn + base
* = y*cn (IterMark y => x1*s1 + x2*s2 + ... + xn) + base
* = [IterSplit(IterMark(y), scale=cn)] + base
* return a corresponding IterSumExpr with extra offset if needed.
* Try to normalize IterSum into a fused IterMark
*
* The function returns NullOpt when
* - no split in expr->args has a constant integer (IntImm) scale,
* - no unvisited split can be matched against the currently expected scale,
* - a matched constraint contains a split that cannot be located in expr, or
* - an already registered fused iter was recorded with an inconsistent offset.
*
* On success with a not-yet-registered pattern, new entries are written to
* sum_fuse_map_ and flattened_map_.
*
* \param expr The input sum.
* \param check_level The check level of iter mapping.
* \return The sum with the fused IterMark and extra offset if succeed.
*/
Optional<IterSumExpr> TryFuseIters(IterSumExpr expr, IterMapLevel check_level) {
// select the iterators in order
// visited[k] records whether expr->args[k] has already been consumed by the matching below.
std::vector<bool> visited(expr->args.size(), false);
// Both vectors are filled in processing order and reversed when the forms are built:
// flattened_iters holds the individual splits of expr, while grouped_iters holds one entry per
// matched constraint (or the single split when no constraint applies).
std::vector<IterSplitExpr> flattened_iters, grouped_iters;
// canonicalize the expression into two different forms: flattened form and structured form
// step0. find the base scale first: the smallest constant (IntImm) scale among the splits.
// The comparison is strict, so on ties the earliest index is kept.
Optional<IntImm> base_scale = NullOpt;
size_t base_index = 0;
for (size_t i = 0; i < expr->args.size(); ++i) {
if (const auto* op = expr->args[i]->scale.as<IntImmNode>()) {
if (!base_scale || op->value < base_scale.value()->value) {
base_scale = GetRef<IntImm>(op);
base_index = i;
}
}
}
if (!base_scale) {
// no split carries a constant scale, nothing to anchor the fusion on
return NullOpt;
}
// check if it can be remapped into a fused pattern.
// expected_extra_base: sum of (constraint offset * matched scale) over matched constraints.
// tail_extent: sum of the gaps (expected_scale - matched_scale) over inexact matches.
// expected_scale: the scale the next split to be matched is expected to have.
PrimExpr expected_extra_base = 0;
PrimExpr tail_extent = 0;
PrimExpr expected_scale = base_scale.value();
// i counts how many splits have been consumed; it is advanced inside the loop body, by one
// for a plain split or by the number of splits of a matched constraint.
for (size_t i = 0; i < expr->args.size();) {
// find position such that expr->args[j] match expected scale
// The first round starts scanning at base_index, later rounds at the last split; the scan
// always walks towards index 0.
int j = i == 0 ? base_index : expr->args.size() - 1;
// matched_pos == expr->args.size() means "nothing matched yet"
size_t matched_pos = expr->args.size();
PrimExpr matched_scale{nullptr};
bool is_exact_match{false};
for (; j >= 0; --j) {
if (visited[j]) {
continue;
}
const PrimExpr& cur_scale = expr->args[j]->scale;
// for bijective mapping, the matched scale must equal to expected scale
// An exact match ends the scan immediately.
if (analyzer_->CanProveEqual(cur_scale, expected_scale)) {
matched_pos = j;
matched_scale = cur_scale;
is_exact_match = true;
break;
}
// Inexact matches are only considered for non-bijective checks and a base scale of 1.
if (check_level != IterMapLevel::Bijective && base_scale.value()->value == 1) {
// find the closest scale which is less or equal to expected scale
// A candidate must provably satisfy 0 <= cur_scale <= expected_scale.
if (analyzer_->CanProveGreaterEqual(expected_scale - cur_scale, 0) &&
analyzer_->CanProveGreaterEqual(cur_scale, 0)) {
// keep the candidate if it is the first one or provably larger than the current best
if (matched_pos == expr->args.size() ||
analyzer_->CanProveLess(matched_scale - cur_scale, 0)) {
matched_pos = j;
matched_scale = cur_scale;
}
}
}
}
if (matched_pos == expr->args.size()) {
// no unvisited split fits the expected scale
return NullOpt;
}
// look for the longest constrained iter started from expr->args[matched_pos]
// Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
// predicate: j*2 + k < 9
// We need to match the predicate in expr and adjust the expected scale,
// otherwise we expect the scale of i to be 2*5=10
Optional<IterSumExpr> constraint_to_match;
for (const IterSumExpr& iter : constrained_iters_flattened_) {
// NOTE(review): the trailing `false` presumably makes IterSplitEqual ignore the scale —
// confirm against its definition, which is outside this view.
if (IterSplitEqual(expr->args[matched_pos], iter->args.back(), false)) {
// find a predicate started from match position
// prefer the constraint with the most splits; on ties the first one found is kept
if (!constraint_to_match ||
constraint_to_match.value()->args.size() < iter->args.size()) {
constraint_to_match = iter;
}
}
}
if (constraint_to_match) {
// match the predicate and mark the iterators in the constraint_to_match as visited
// Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
// predicate = j*2 + k < 9
// then j*2 + k matches the lower two splits of expr
// Walk the constraint's splits from the last one to the first one.
for (auto it = constraint_to_match.value()->args.rbegin();
it != constraint_to_match.value()->args.rend(); ++it) {
// look for an unvisited split of expr that equals *it and whose scale is the
// constraint split's scale multiplied by matched_scale
size_t k = 0;
for (; k < expr->args.size(); ++k) {
if (!visited[k] && IterSplitEqual(expr->args[k], *it, false)) {
if (analyzer_->CanProveEqual((*it)->scale * matched_scale, expr->args[k]->scale))
break;
}
}
if (k == expr->args.size()) {
// the constraint cannot be fully located inside expr
return NullOpt;
}
visited[k] = true;
flattened_iters.push_back(expr->args[k]);
}
// every flattened constraint is required to have an entry in sum_fuse_map_
auto iter = sum_fuse_map_.find(constraint_to_match.value());
ICHECK(iter != sum_fuse_map_.end());
const IterMarkWithOffset& iter_matched = iter->second;
// the whole constraint is represented by a single split over its registered mark,
// with the scale expressed relative to the base scale
grouped_iters.emplace_back(iter_matched.mark, div(matched_scale, base_scale.value()));
expected_extra_base += iter_matched.offset * matched_scale;
if (!is_exact_match) {
// record the gap between the expected scale and the scale actually matched
tail_extent += expected_scale - matched_scale;
}
// the next split is expected to stride over the full extent of the matched mark
expected_scale = matched_scale * iter_matched.mark->extent;
// move forward
i += constraint_to_match.value()->args.size();
} else {
// constraint_to_match not found, skip this iterator
visited[matched_pos] = true;
IterSplitExpr arg = expr->args[matched_pos];
// rescale the (copied) split relative to the base scale
arg.CopyOnWrite()->scale = analyzer_->Simplify(div(arg->scale, base_scale.value()));
flattened_iters.push_back(arg);
grouped_iters.push_back(arg);
if (!is_exact_match) {
// record the gap between the expected scale and the scale actually matched
tail_extent += expected_scale - matched_scale;
}
// the next split is expected to stride over the extent of this split
expected_scale = matched_scale * expr->args[matched_pos]->extent;
++i;
}
}
// Get the flattened form and structured form
// both forms have splits from outermost to innermost
// (the vectors were filled in processing order, hence the reverse iteration below);
// both forms get a zero base.
IterSumExpr structured_form = expr, flattened_form = expr;
flattened_form.CopyOnWrite()->args =
Array<IterSplitExpr>(flattened_iters.rbegin(), flattened_iters.rend());
flattened_form.CopyOnWrite()->base = make_const(expr.dtype(), 0);
structured_form.CopyOnWrite()->args =
Array<IterSplitExpr>(grouped_iters.rbegin(), grouped_iters.rend());
structured_form.CopyOnWrite()->base = make_const(expr.dtype(), 0);
// the flattened form is the lookup key for previously fused iterators
auto it = sum_fuse_map_.find(flattened_form);
if (it != sum_fuse_map_.end()) {
// old iter
if (!analyzer_->CanProveEqual(expected_extra_base, it->second.offset * base_scale.value())) {
// the extra offset is not consistent with old
return NullOpt;
}
// reuse the registered mark, scaled by the base scale, and fold the extra offset into the base
return IterSumExpr({IterSplitExpr(it->second.mark, base_scale.value())},
expr->base + expected_extra_base);
} else {
// new iter, form a new mark
// its extent is the final expected scale relative to the base scale plus the accumulated gaps
IterMark mark =
IterMark(structured_form, div(expected_scale, base_scale.value()) + tail_extent);
// register the new mark with a zero offset and remember structured -> flattened
sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0);
flattened_map_[structured_form] = flattened_form;
return IterSumExpr({IterSplitExpr(mark, base_scale.value())},
expr->base + expected_extra_base);
}
}
// The three helpers below are only declared here (no body in this class definition).
// NOTE(review): judging by the name, presumably reports whether `lhs` can be proven to be
// divisible by `rhs` — confirm against the definition.
bool CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs);
// NOTE(review): judging by the name, presumably rewrites floordiv(lhs + base, rhs) for a split
// `lhs` — confirm against the definition.
PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs);
// NOTE(review): judging by the name, presumably rewrites floormod(lhs + base, rhs) for a split
// `lhs` — confirm against the definition.
PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs);
/*!
 * \brief Add (sign > 0) or subtract (otherwise) a single split into the sum `lhs`.
 *
 * If `lhs` already holds a split with the same source, lower_factor and extent, the two scales
 * are combined and that entry is replaced; otherwise the split is appended, with its scale
 * negated when subtracting.
 *
 * \param lhs The sum node that is updated in place.
 * \param rhs The split to merge in.
 * \param sign Positive to add, anything else to subtract.
 */
static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
  const bool is_addition = sign > 0;
  tir::ExprDeepEqual deep_equal;
  const size_t num_args = lhs->args.size();
  for (size_t idx = 0; idx != num_args; ++idx) {
    const IterSplitExpr existing = lhs->args[idx];
    const bool same_split = existing->source.same_as(rhs->source) &&
                            deep_equal(existing->lower_factor, rhs->lower_factor) &&
                            deep_equal(existing->extent, rhs->extent);
    if (!same_split) {
      continue;
    }
    // Same underlying split: fold the two scales together and overwrite the entry.
    PrimExpr merged_scale =
        is_addition ? existing->scale + rhs->scale : existing->scale - rhs->scale;
    rhs.CopyOnWrite()->scale = merged_scale;
    lhs->args.Set(idx, rhs);
    return;
  }
  // No counterpart in lhs: append, flipping the scale first when subtracting.
  if (!is_addition) {
    rhs.CopyOnWrite()->scale = make_zero(rhs->scale.dtype()) - rhs->scale;
  }
  lhs->args.push_back(rhs);
}
/*!
 * \brief Add (sign > 0) or subtract (otherwise) a whole sum into the sum `lhs`.
 * \param lhs The sum node that is updated in place.
 * \param rhs The sum whose splits and base are merged in.
 * \param sign Positive to add, anything else to subtract.
 */
static void AddToLhs(IterSumExprNode* lhs, const IterSumExpr& rhs, int sign) {
  // Merge the splits one by one through the single-split overload.
  for (const IterSplitExpr& split : rhs->args) {
    AddToLhs(lhs, split, sign);
  }
  // Then combine the base terms with the same sign.
  if (!(sign > 0)) {
    lhs->base -= rhs->base;
    return;
  }
  lhs->base += rhs->base;
}
/*!
 * \brief Multiply the sum `lhs` by `rhs`: every split scale and the base are scaled.
 * \param lhs The sum node that is updated in place.
 * \param rhs The multiplier.
 */
static void MulToLhs(IterSumExprNode* lhs, const PrimExpr& rhs) {
  const size_t num_splits = lhs->args.size();
  size_t pos = 0;
  while (pos < num_splits) {
    // Work on a local handle, scale it, then store it back at the same position.
    IterSplitExpr scaled = lhs->args[pos];
    auto* split_node = scaled.CopyOnWrite();
    split_node->scale *= rhs;
    lhs->args.Set(pos, scaled);
    ++pos;
  }
  lhs->base *= rhs;
}
};
/*! \brief An internal struct to represent range extent on iterators(iter < upper_bound). */
struct IterConstraint {
// The expr of the iter
PrimExpr iter;
// The expr of the lower_bound, maybe undefined
Optional<PrimExpr> lower_bound;
// The expr of the upper_bound, maybe undefined
Optional<PrimExpr> upper_bound;
// The size of the iter, which is the number of nodes
size_t expr_size = 0;
IterConstraint(PrimExpr iter, Optional<PrimExpr> lower_bound, Optional<PrimExpr> upper_bound,
size_t size)
: iter(std::move(iter)),
lower_bound(std::move(lower_bound)),
upper_bound(std::move(upper_bound)),
expr_size(size) {}
};
/*!
* \brief Split the predicate into `(a < b) && (c < d) && ...`
* \param pred The predicate to be split.
* \param input_iters The input iterators.
* \param result The result of predicate split.
* \return false if the predicate cannot be split into comparisons; the extracted constraints are reported through \p result.
*/
bool MatchBoundConstraints(PrimExpr pred, Map<Var, Range>* input_iters,
std::vector<IterConstraint>* result) {
arith::PVar<PrimExpr> lhs, rhs, rest;
for (;;) {
// try extract comparisions
bool is_finish = false;
bool is_greater = false;
bool is_equal = false;
if ((rest && (lhs < rhs)).Match(pred) || ((lhs < rhs) && rest).Match(pred)) {
// pass
} else if ((lhs < rhs).Match(pred)) {
is_finish = true;
} else if ((rest && (lhs <= rhs)).Match(pred) || ((lhs <= rhs) && rest).Match(pred)) {
is_equal = true;
} else if ((lhs <= rhs).Match(pred)) {
is_equal = true;
is_finish = true;
} else if ((rest && (lhs > rhs)).Match(pred) || ((lhs > rhs) && rest).Match(pred)) {
is_greater = true;
} else if ((lhs > rhs).Match(pred)) {
is_greater = true;
is_finish = true;
} else if ((rest && (lhs >= rhs)).Match(pred) || ((lhs >= rhs) && rest).Match(pred)) {
is_greater = true;
is_equal = true;
} else if ((lhs >= rhs).Match(pred)) {
is_greater = true;
is_equal = true;
is_finish = true;
} else {
return false;