-
Notifications
You must be signed in to change notification settings - Fork 3.9k
Expand file tree
/
Copy pathtest_dataflow_pattern.py
More file actions
2037 lines (1674 loc) · 67.9 KB
/
test_dataflow_pattern.py
File metadata and controls
2037 lines (1674 loc) · 67.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# ruff: noqa: F403, F405, F841
import functools
import math
import pytest
import tvm.testing
from tvm import relax as rx
from tvm import tirx
from tvm.relax.analysis import get_var2val
from tvm.relax.dpl import *
from tvm.script import relax as R
from tvm.script import tirx as T
# Shared fixture for the node-wise and graph-wise tests below: three TIR kernels and a
# Relax `main` whose single dataflow block binds, in order:
#   lv0 = call_tir(tir_matmul, (x, w))
#   lv1 = call_tir(tir_relu, lv0)
#   lv2 = call_tir(tir_zeros, [], tir_vars=[32])
#   gv  = (lv1, lv2)
@tvm.script.ir_module
class Module:
    @T.prim_func
    def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
        # C[i, j] = sum over k of A[i, k] * B[j, k] on fixed 32x32 buffers
        # (B is indexed as B[j, k], i.e. read transposed).
        T.func_attr({"global_symbol": "tir_matmul"})
        # NOTE(review): this `k` is never read; the name is rebound by T.axis.remap below.
        k = T.int32()
        A = T.match_buffer(x, (32, 32))
        B = T.match_buffer(y, (32, 32))
        C = T.match_buffer(z, (32, 32))
        for i0, j0, k0 in T.grid(32, 32, 32):
            with T.sblock():
                # "SSR": i and j are spatial axes, k is the reduction axis.
                i, j, k = T.axis.remap("SSR", [i0, j0, k0])
                with T.init():
                    C[i, j] = 0.0
                C[i, j] += A[i, k] * B[j, k]

    @T.prim_func
    def tir_relu(x: T.handle, y: T.handle):
        # Elementwise B = max(A, 0) over fixed 32x32 buffers.
        T.func_attr({"global_symbol": "tir_relu"})
        A = T.match_buffer(x, (32, 32))
        B = T.match_buffer(y, (32, 32))
        for i, j in T.grid(32, 32):
            with T.sblock():
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = T.max(A[vi, vj], 0.0)

    @T.prim_func
    def tir_zeros(x: T.handle, n: T.int64):
        # Despite its name, this writes 1.0 into each of the n elements of A.
        # n is a scalar argument, supplied through `tir_vars` at the call site in main.
        T.func_attr({"global_symbol": "tir_zeros"})
        A = T.match_buffer(x, [n])
        for i in range(n):
            with T.sblock():
                vi = T.axis.remap("S", [i])
                A[vi] = 1.0

    @R.function
    def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tuple:
        cls = Module
        with R.dataflow():
            lv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32"))
            # NOTE: `(lv0)` is just a parenthesized lv0, not a 1-tuple.
            lv1 = R.call_tir(cls.tir_relu, (lv0), R.Tensor((32, 32), dtype="float32"))
            # No tensor inputs; the scalar n=32 is passed via tir_vars.
            lv2 = R.call_tir(
                cls.tir_zeros, [], R.Tensor((32,), dtype="float32"), tir_vars=R.ShapeExpr([32])
            )
            gv = (lv1, lv2)
            R.output(gv)
        return gv
# The Relax entry function of `Module`, and the bindings of its single dataflow block
# (lv0, lv1, lv2, gv in order), reused by the node-wise tests below.
main_fn = Module["main"]
bindings = main_fn.body.blocks[0].bindings
## Node-wise Matching
def test_expr_pattern():
    """is_expr wraps a concrete Relax expression in an ExprPattern."""
    wrapped = rx.Var("x")
    expr_pat = is_expr(wrapped)
    assert isinstance(expr_pat, ExprPattern)
    assert isinstance(expr_pat.expr, rx.Var)
def test_var_pattern():
    """VarPattern matches local Vars (including DataflowVars), optionally by name."""
    named = is_var("x")
    assert isinstance(named, VarPattern)
    assert named.name == "x"
    assert named.match(rx.Var("x"))

    anonymous = is_var()
    assert anonymous.match(rx.Var("x"))
    # DataflowVar is also a Var, so the unnamed pattern accepts it too.
    assert anonymous.match(rx.DataflowVar("x"))

    # A GlobalVar is not a local Var, even with the same name.
    assert not named.match(rx.GlobalVar("x"))
def test_dataflow_var_pattern():
    """DataflowVarPattern matches DataflowVars, optionally by name."""
    dfv_pat = is_dfv("x")
    assert isinstance(dfv_pat, DataflowVarPattern)
    assert dfv_pat.name == "x"
    assert dfv_pat.match(rx.DataflowVar("x"))
    assert not dfv_pat.match(rx.GlobalVar("x"))

    # Bindings inside Module.main's dataflow block produce DataflowVars.
    first_bound_var = bindings[0].var
    assert is_dfv().match(first_bound_var)
def test_global_var_pattern():
    """GlobalVarPattern matches GlobalVars, optionally by exact name."""
    gv_x = rx.GlobalVar("x")
    assert is_gv("x").match(gv_x)
    # TODO: disabled as regex is not supported due to
    # symbol conflict with PyTorch
    # assert is_gv("x.*").match(rx.GlobalVar("x_2"))
    assert is_gv().match(gv_x)

    assert not is_gv("x").match(rx.GlobalVar("y"))
    # A local Var with the same name is not a GlobalVar.
    assert not is_gv("x").match(rx.Var("x"))
def test_constant_pattern():
    """ConstantPattern matches a Relax constant regardless of its value."""
    const_pat = is_const()
    assert isinstance(const_pat, ConstantPattern)
    matrix = rx.const([[0.1, 1.1, 2.1], [3.1, 4.1, 5.1]])
    assert const_pat.match(matrix)
def test_wildcard_pattern():
    """A wildcard pattern accepts any expression."""
    anything = wildcard()
    assert isinstance(anything, WildcardPattern)
    assert anything.match(rx.Var("x"))
def test_call_pattern():
    """Applying an op pattern to argument patterns yields a CallPattern."""
    lhs, rhs = wildcard(), wildcard()
    add_pat = is_op("relax.add")(lhs, rhs)
    assert isinstance(add_pat, CallPattern)
    for idx in (0, 1):
        assert isinstance(add_pat.args[idx], WildcardPattern)
    assert add_pat.match(rx.op.add(rx.Var("x"), rx.Var("y")))
def test_function_pattern():
    """FunctionPattern matches a function by its parameter and body patterns."""
    p0, p1 = wildcard(), wildcard()
    fn_pat = FunctionPattern([p0, p1], is_op("relax.add")(p0, p1))
    assert isinstance(fn_pat, FunctionPattern)
    for idx in (0, 1):
        assert isinstance(fn_pat.params[idx], WildcardPattern)
    assert isinstance(fn_pat.body, CallPattern)
    for idx in (0, 1):
        assert isinstance(fn_pat.body.args[idx], WildcardPattern)

    x = rx.Var("x", R.Tensor("float32"))
    y = rx.Var("y", R.Tensor("float32"))
    add_fn = rx.Function([x, y], rx.op.add(x, y), ret_struct_info=R.Tensor("float32"))
    mul_fn = rx.Function([x, y], rx.op.multiply(x, y), ret_struct_info=R.Tensor("float32"))
    assert fn_pat.match(add_fn)
    # The body pattern requires relax.add, so a multiply body is rejected.
    assert not fn_pat.match(mul_fn)
def test_tuple_pattern():
    """TuplePattern matches field-by-field; indexing it yields TupleGetItem patterns."""
    tup_pat = is_tuple([wildcard(), is_dfv()])
    assert isinstance(tup_pat, TuplePattern)
    assert isinstance(tup_pat.fields[0], WildcardPattern)
    assert isinstance(tup_pat.fields[1], DataflowVarPattern)
    assert tup_pat.match(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]))
    assert not tup_pat.match(rx.Tuple([rx.DataflowVar("x"), rx.GlobalVar("y")]))
    assert not tup_pat.match(rx.Tuple([]))

    def get_item(index):
        # A fresh TupleGetItem over the (GlobalVar x, DataflowVar y) tuple.
        return rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), index)

    assert tup_pat[0].match(get_item(0))
    assert tup_pat[1].match(get_item(1))
    # Negative index is also allowed
    assert tup_pat[-1].match(get_item(1))
    # None means any index.
    assert tup_pat[None].match(get_item(0))
    assert tup_pat[None].match(get_item(1))
    with pytest.raises(IndexError):
        tup_pat[2]  # index cannot be greater than or equal to the tuple size.
def test_unordered_tuple_pattern():
    """An unordered TuplePattern matches its fields in any order."""
    unordered = is_tuple([is_const(), is_dfv()], unordered=True)
    assert isinstance(unordered, UnorderedTuplePattern)
    assert isinstance(unordered.fields[0], ConstantPattern)
    assert isinstance(unordered.fields[1], DataflowVarPattern)

    assert unordered.match(rx.Tuple([rx.const([]), rx.DataflowVar("x")]))
    assert unordered.match(rx.Tuple([rx.DataflowVar("x"), rx.const([])]))
    # With two DataflowVars, nothing satisfies the constant field.
    assert not unordered.match(rx.Tuple([rx.DataflowVar("x"), rx.DataflowVar("y")]))
    assert not unordered.match(rx.Tuple([]))
def test_tuple_get_item_pattern():
    """is_tuple_get_item matches a TupleGetItem with the same index and a matching tuple.

    Both field indices are exercised, plus a negative case where the pattern's
    index differs from the expression's.
    """
    for index in (0, 1):
        tuple_pat = is_tuple([is_gv("x"), is_dfv("y")])
        tuple_expr = rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")])
        assert is_tuple_get_item(tuple_pat, index).match(rx.TupleGetItem(tuple_expr, index))

    # A concrete index in the pattern must agree with the expression's index.
    tuple_pat = is_tuple([is_gv("x"), is_dfv("y")])
    tuple_expr = rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")])
    assert not is_tuple_get_item(tuple_pat, 0).match(rx.TupleGetItem(tuple_expr, 1))
def test_or_pattern():
    """`a | b` matches whatever either operand pattern matches."""
    either = is_dfv("x") | is_gv("x")
    assert isinstance(either, OrPattern)
    for accepted in (rx.DataflowVar("x"), rx.GlobalVar("x")):
        assert either.match(accepted)
    # A plain Var is neither a DataflowVar nor a GlobalVar, and names must agree.
    for rejected in (rx.Var("x"), rx.DataflowVar("y"), rx.GlobalVar("y")):
        assert not either.match(rejected)
def test_and_pattern():
    """`a & b` requires both sides: here, a float32 tensor of shape (2, 3, 3)."""
    both = wildcard().has_shape((2, 3, 3)) & has_dtype("float32")
    assert isinstance(both, AndPattern)
    assert both.match(rx.Var("x", R.Tensor((2, 3, 3), "float32")))
    # Wrong leading dimension.
    assert not both.match(rx.Var("x", R.Tensor((3, 3, 3), "float32")))
    # Only the rank is known, so the concrete shape constraint is not satisfied.
    assert not both.match(rx.Var("x", R.Tensor("float32", ndim=3)))
def test_not_pattern():
    """`~p` matches exactly the expressions `p` rejects."""
    not_233 = ~wildcard().has_shape((2, 3, 3))
    assert isinstance(not_233, NotPattern)
    other_shape = rx.Var("x", R.Tensor((3, 3, 3), "float32"))
    same_shape = rx.Var("x", R.Tensor((2, 3, 3), "float32"))
    assert not_233.match(other_shape)
    assert not not_233.match(same_shape)
def test_dtype_pattern():
    """DataTypePattern stores its dtype and matches tensors of that dtype."""
    fp16 = "float16"
    fp16_pat = has_dtype(fp16)
    assert isinstance(fp16_pat, DataTypePattern)
    assert fp16_pat.dtype == fp16
    # Module.main's first binding is declared as a float32 tensor.
    assert has_dtype("float32").match(bindings[0].var)
def test_shape_pattern():
    """ShapePattern constrains a tensor's shape, including symbolic (tirx) dimensions."""
    shape = [32, 32]
    pattern = wildcard().has_shape(shape)
    assert isinstance(pattern, ShapePattern)
    # The stored shape must equal the requested one. Compare as Python ints so the
    # check does not depend on how the FFI converts the Python list.
    assert [int(dim) for dim in pattern.shape] == shape
    assert pattern.match(bindings[0].var)
    assert wildcard().has_shape([32, 32]).match(bindings[0].var)

    n, m = tirx.Var("n", dtype="int64"), tirx.Var("m", dtype="int64")
    symsh_var = rx.Var("x", R.Tensor([n, m, n + m], "float32"))
    assert wildcard().has_shape([n, m, n + m]).match(symsh_var)
    assert wildcard().has_shape([n, m, m + n]).match(symsh_var)  # + is commutative.
    assert not wildcard().has_shape([1, 2, 3]).match(symsh_var)
    assert not wildcard().has_shape([m, n, n + m]).match(symsh_var)
def test_prim_arr_pattern():
    """
    is_shape versus has_shape:
    1) is_shape matches a shape expression itself (e.g. when it appears as an argument);
    2) has_shape matches a tensor and constrains that tensor's shape.
    """
    shape_pat = is_shape([32, 32])
    assert shape_pat[0] == 32
    assert shape_pat[1] == 32
    assert isinstance(shape_pat, PrimArrPattern)
    assert shape_pat.match(rx.get_shape_of(bindings[0].var))

    n = tirx.Var("n", dtype="int64")
    m = tirx.Var("m", dtype="int64")
    symbolic_shape = rx.ShapeExpr([n, m, n + m])
    assert is_shape([n, m, n + m]).match(symbolic_shape)
    # n * m is not the same expression as n + m.
    assert not is_shape([n, m, n * m]).match(symbolic_shape)
def test_extern_fn_pattern():
    """ExternFuncPattern matches an ExternFunc with the same symbol name."""
    symbol = "test.blockbuilder.nop"
    extern_pat = ExternFuncPattern(symbol)
    assert extern_pat.match(rx.ExternFunc(symbol))
def test_op_attr():
    """has_attr on an op call pattern constrains the call's attributes."""
    x = rx.Var("x", R.Tensor("float32"))
    y = rx.Var("y", R.Tensor("float32"))
    conv2d = rx.op.nn.conv2d(x, y, strides=(3, 3))
    conv_pat = is_op("relax.nn.conv2d")(is_var("x"), is_var("y"))
    assert conv_pat.has_attr({"strides": [3, 3]}).match(conv2d)
    assert not conv_pat.has_attr({"strides": [4, 3]}).match(conv2d)
def test_match_call_attr():
    """has_attr on a FunctionPattern checks the listed attributes against the function's."""
    x = rx.Var("x", R.Tensor("float32"))
    y = rx.Var("y", R.Tensor("float32"))
    fn = rx.Function([x, y], rx.op.add(x, y), ret_struct_info=R.Tensor("float32"))
    annotated_fn = fn.with_attr({"Codegen": "test-codegen", "global_symbol": "test-symbol"})
    xp, yp = is_var("x"), is_var("y")
    root_pattern = FunctionPattern([xp, yp], is_op("relax.add")(xp, yp))

    def matches(attrs):
        return root_pattern.has_attr(attrs).match(annotated_fn)

    assert matches({"Codegen": "test-codegen", "global_symbol": "test-symbol"})
    # A subset of the function's attributes is enough ...
    assert matches({"Codegen": "test-codegen"})
    # ... but an attribute the function does not carry is not.
    assert not matches({"ping": "pong"})
    # An empty attribute constraint is trivially satisfied.
    assert matches({})
def test_is_call_tir():
    """is_call_tir matches by callee name and, given var2val, by what its arguments are bound to."""
    relu_call = bindings[1].value
    zeros_call = bindings[2].value
    var2val = get_var2val(Module["main"])

    assert is_call_tir("tir_relu").match(relu_call)
    # relu's argument is lv0, which var2val maps to the matmul call_tir.
    matmul_arg = [is_call_tir("tir_matmul")]
    relu_arg = [is_call_tir("tir_relu")]
    assert is_call_tir("tir_relu", matmul_arg).match(relu_call, var2val=var2val)
    assert not is_call_tir("tir_relu", relu_arg).match(relu_call, var2val=var2val)
    assert is_call_tir("tir_zeros", wildcard(), wildcard()).match(zeros_call, var2val=var2val)
# A non-pure Relax function whose single binding is a call_packed to the extern
# "test.vm.mul" with arguments (x, w); used by the call_packed pattern tests below.
@R.function(pure=False)
def simple_call_packed(
    x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")
) -> R.Tensor:
    gv0 = R.call_packed("test.vm.mul", x, w, sinfo_args=(R.Tensor(ndim=2, dtype="float32")))
    return gv0
def test_varg_default_wildcard():
    """varg_default_wildcard=True lets a call pattern accept any argument list."""
    call = simple_call_packed.body.blocks[0].bindings[0].value
    explicit_two_args = ExternFuncPattern("test.vm.mul")(wildcard(), wildcard())
    implicit_any_args = ExternFuncPattern("test.vm.mul")(varg_default_wildcard=True)
    one_arg_only = ExternFuncPattern("test.vm.mul")(wildcard())

    assert explicit_two_args.match(call)
    assert implicit_any_args.match(call)
    # Without the flag, a single wildcard does not match the two-argument call.
    assert not one_arg_only.match(call)
def test_simple_call_packed():
    """is_call_packed matches by extern name, optionally also by argument patterns."""
    call = simple_call_packed.body.blocks[0].bindings[0].value
    assert is_call_packed("test.vm.mul").match(call)
    arg_pats = [is_var("x"), is_var("w")]
    assert is_call_packed("test.vm.mul", arg_pats).match(call)
## Graph-wise Matching
def test_simple_used_by():
    """`a ^ b` (used-by) matches the function argument x and its first consumer."""
    dfb = main_fn.body.blocks[0]
    with PatternContext() as ctx:
        producer = is_var("x")  # x is a free var (fn arg)
        consumer = wildcard()
        producer ^ consumer
        matched = ctx.match_dfb(dfb)
        assert matched
        assert matched[producer] == main_fn.params[0]
        assert matched[consumer] == dfb.bindings[0].var
def test_simple_call_tir_edge():
    """used_by links the matmul call_tir to the relu call_tir consuming it."""
    dfb = main_fn.body.blocks[0]
    with PatternContext() as ctx:
        matmul = is_call_tir("tir_matmul")
        relu = is_call_tir("tir_relu")
        matmul.used_by(relu)
        matched = ctx.match_dfb(dfb)
        assert matched
        assert matched[matmul] == dfb.bindings[0].var
        assert matched[relu] == dfb.bindings[1].var
def test_simple_oub():
    """`a >> b` (only-used-by) matches matmul, whose sole consumer is relu."""
    dfb = main_fn.body.blocks[0]
    with PatternContext() as ctx:
        matmul = is_call_tir("tir_matmul")
        relu = is_call_tir("tir_relu")
        matmul >> relu
        matched = ctx.match_dfb(dfb)
        assert matched
        assert matched[matmul] == dfb.bindings[0].var
        assert matched[relu] == dfb.bindings[1].var
def test_counter_syntax_match():
    """Neither `>>` nor `^` matches when the named externs do not occur in the block."""
    dfb = main_fn.body.blocks[0]
    edge_builders = (lambda a, b: a >> b, lambda a, b: a ^ b)
    for connect in edge_builders:
        with PatternContext() as ctx:
            connect(
                is_call_dps_packed("extern_matmul"), is_call_dps_packed("extern_impossible")
            )
            assert not ctx.match_dfb(dfb)
# Fixture: a fork/join diamond of call_dps_packed externs. lv0 (matmul) is consumed by
# both lv1 (relu) and lv2 (sigmoid), whose results are joined by lv3 (add).
@tvm.script.ir_module
class Diamond:
    @R.function
    def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor:
        with R.dataflow():
            # matmul
            # / \
            # relu sigmoid
            # \ /
            # add
            lv0 = R.call_dps_packed("extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32"))
            lv1 = R.call_dps_packed("extern_relu", (lv0,), R.Tensor((32, 32), dtype="float32"))
            # NOTE: `(lv0)` is a parenthesized lv0, not a 1-tuple like `(lv0,)` above.
            lv2 = R.call_dps_packed("extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32"))
            lv3 = R.call_dps_packed("extern_add", (lv1, lv2), R.Tensor((32, 32), dtype="float32"))
            R.output(lv3)
        return lv3
def test_diamond():
    """Match the Diamond fixture, first with explicit edges, then using fork_to."""
    dfb = Diamond["main"].body.blocks[0]

    with PatternContext() as ctx:
        matmul = is_call_dps_packed("extern_matmul")
        relu = is_call_dps_packed("extern_relu")
        sigmoid = is_call_dps_packed("extern_sigmoid")
        add = is_call_dps_packed("extern_add")
        matmul ^ relu
        matmul ^ sigmoid
        relu >> add
        sigmoid >> add
        assert ctx.match_dfb(dfb)

    # The same diamond, with fork_to building both matmul edges in one call.
    with PatternContext() as ctx:
        relu = is_call_dps_packed("extern_relu")
        sigmoid = is_call_dps_packed("extern_sigmoid")
        add = is_call_dps_packed("extern_add")
        is_call_dps_packed("extern_matmul").fork_to(relu, sigmoid)
        relu >> add
        sigmoid >> add
        assert ctx.match_dfb(dfb)
def test_diamond_counter_oub():
    """matmul feeds both relu and sigmoid, so only-used-by edges from it cannot match."""
    dfb = Diamond["main"].body.blocks[0]
    with PatternContext() as ctx:
        matmul = is_call_dps_packed("extern_matmul")
        relu = is_call_dps_packed("extern_relu")
        sigmoid = is_call_dps_packed("extern_sigmoid")
        add = is_call_dps_packed("extern_add")
        matmul >> relu
        matmul >> sigmoid
        relu >> add
        sigmoid >> add
        assert not ctx.match_dfb(dfb)
# Fixture: one relu (lv0) whose result is passed as BOTH operands of a single add (lv1).
@tvm.script.ir_module
class SmallDiamond:
    @R.function
    def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
        with R.dataflow():
            # relu
            # / \
            # \ /
            # add
            lv0 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32"))
            lv1 = R.call_dps_packed("my_add", (lv0, lv0), R.Tensor((32, 32), dtype="float32"))
            R.output(lv1)
        return lv1
# Fixture: two separate relu calls on the same input (lv0, lv1), joined by one add (lv2).
# Contrast with SmallDiamond, where a single relu feeds both add operands.
@tvm.script.ir_module
class SmallParallel:
    @R.function
    def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
        with R.dataflow():
            # relu relu
            # \ /
            # add
            lv0 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32"))
            lv1 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32"))
            lv2 = R.call_dps_packed("my_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32"))
            R.output(lv2)
        return lv2
def test_distinguish_diamond_and_parallel():
    """Graph patterns distinguish SmallDiamond (one relu) from SmallParallel (two relus)."""
    diamond = SmallDiamond["main"].body.blocks[0]
    parallel = SmallParallel["main"].body.blocks[0]

    with PatternContext() as ctx:
        # Diamond: a single relu pattern constrained, via only_used_by, to feed
        # the add at argument index 0 and at argument index 1.
        relu = is_call_dps_packed("my_relu")
        add = is_call_dps_packed("my_add")
        relu.only_used_by(add, index=0)
        relu.only_used_by(add, index=1)
        assert ctx.match_dfb(diamond)
        assert not ctx.match_dfb(parallel)

    with PatternContext() as ctx:
        # Parallel: matching is one-to-one, and each is_call_dps_packed("my_relu")
        # call creates a distinct pattern object, so the two relu patterns must
        # bind two different relu calls.
        add = is_call_dps_packed("my_add")
        is_call_dps_packed("my_relu") >> add
        is_call_dps_packed("my_relu") >> add
        assert ctx.match_dfb(parallel)
        assert not ctx.match_dfb(diamond)
# Fixture: two conv1x1 -> bias_add -> my_relu chains (lv0-lv2 and lv3-lv5) that both
# start from input x and are joined by a concat (lv6).
@tvm.script.ir_module
class CBRx2:
    @R.function
    def main(
        x: R.Tensor((32, 32), "float32"),
        w0: R.Tensor((1, 1), "float32"),
        bias0: R.Tensor((32, 32), "float32"),
        w1: R.Tensor((1, 1), "float32"),
        bias1: R.Tensor((32, 32), "float32"),
    ) -> R.Tensor:
        # R.TensorRT's CBR Optimization Pattern
        # input
        # / \
        # cbr0 cbr1
        # \ /
        # concat
        with R.dataflow():
            lv0 = R.call_dps_packed("conv1x1", (x, w0), R.Tensor((32, 32), dtype="float32"))
            lv1 = R.call_dps_packed("bias_add", (lv0, bias0), R.Tensor((32, 32), dtype="float32"))
            lv2 = R.call_dps_packed("my_relu", (lv1), R.Tensor((32, 32), dtype="float32"))
            lv3 = R.call_dps_packed("conv1x1", (x, w1), R.Tensor((32, 32), dtype="float32"))
            lv4 = R.call_dps_packed("bias_add", (lv3, bias1), R.Tensor((32, 32), dtype="float32"))
            lv5 = R.call_dps_packed("my_relu", (lv4), R.Tensor((32, 32), dtype="float32"))
            lv6 = R.call_dps_packed("concat", (lv2, lv5), R.Tensor((32, 64), dtype="float32"))
            R.output(lv6)
        return lv6
def test_nested_context():
    """Nested PatternContexts solve independently; current() tracks the innermost open one."""
    dfb = CBRx2["main"].body.blocks[0]
    with PatternContext() as outer:
        (
            is_call_dps_packed("conv1x1")
            >> is_call_dps_packed("bias_add")
            >> is_call_dps_packed("my_relu")
        )
        with PatternContext() as middle:
            # conv1x1 is never directly followed by my_relu, so this must miss.
            is_call_dps_packed("conv1x1") >> is_call_dps_packed("my_relu")
            with PatternContext() as inner:
                is_call_dps_packed("bias_add") >> is_call_dps_packed("my_relu")
                assert inner.match_dfb(dfb)
                assert PatternContext.current() == inner
            assert not middle.match_dfb(dfb)
            assert PatternContext.current() == middle
        assert outer.match_dfb(dfb)
        assert PatternContext.current() == outer
def test_two_cbr():
    """Two dup()'d conv->bias->relu chains forking from x match CBRx2; forking from y does not."""
    dfb = CBRx2["main"].body.blocks[0]

    def conv_bias_relu():
        # Built inside the caller's PatternContext, so the edges land there.
        return (
            is_call_dps_packed("conv1x1")
            >> is_call_dps_packed("bias_add")
            >> is_call_dps_packed("my_relu")
        )

    with PatternContext() as ctx:
        cbr0 = conv_bias_relu()
        cbr1 = cbr0.dup()
        # dup() must produce fresh pattern nodes rather than aliasing the originals.
        for idx in range(3):
            assert cbr0.patterns[idx] != cbr1.patterns[idx]
        is_var("x").fork_to(cbr0, cbr1)
        assert ctx.match_dfb(dfb)

    with PatternContext() as ctx:
        # Deny the pattern: CBRx2 has no input named y to fork from.
        cbr0 = conv_bias_relu()
        cbr1 = cbr0.dup()
        is_var("y").fork_to(cbr0, cbr1)
        assert not ctx.match_dfb(dfb)
def test_two_matmul():
    """A two-matmul chain matches two-link patterns (with or without shapes) but not three."""

    # Same as Figure 2(a) in TASO paper.
    @tvm.script.ir_module
    class MatMul2:
        @R.function
        def main(
            a: R.Tensor((32, 16), "float32"),
            b: R.Tensor((16, 48), "float32"),
            c: R.Tensor((48, 32), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                lv0 = R.call_dps_packed("matmul", (a, b), R.Tensor((32, 48), dtype="float32"))
                lv1 = R.call_dps_packed("matmul", (lv0, c), R.Tensor((32, 32), dtype="float32"))
                R.output(lv1)
            return lv1

    dfb = MatMul2["main"].body.blocks[0]

    with PatternContext() as ctx:
        is_call_dps_packed("matmul") >> is_call_dps_packed("matmul")
        assert ctx.match_dfb(dfb)

    with PatternContext() as ctx:
        first = is_call_dps_packed("matmul").has_shape([32, 48])
        second = is_call_dps_packed("matmul").has_shape([32, 32])
        first >> second
        assert ctx.match_dfb(dfb)

    with PatternContext() as ctx:
        is_call_dps_packed("matmul") >> is_call_dps_packed("matmul") >> is_call_dps_packed("matmul")
        # The block holds only two matmuls, so a three-long chain cannot match.
        assert not ctx.match_dfb(dfb)
def test_concat_mm_split():
    """Match concat -> matmul -> split, then split's two tuple items joined by an add."""

    # Same as Figure 2(b) in TASO paper.
    @tvm.script.ir_module
    class CMS:
        @R.function
        def main(
            a: R.Tensor((32, 32), "float32"),
            b: R.Tensor((16, 32), "float32"),
            c: R.Tensor((16, 32), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                lv0 = R.call_dps_packed("my_concat", (b, c), R.Tensor((32, 32), dtype="float32"))
                lv1 = R.call_dps_packed("my_matmul", (a, lv0), R.Tensor((32, 32), dtype="float32"))
                lv2 = R.call_dps_packed(
                    "my_split",
                    (lv1,),
                    [R.Tensor((16, 32), dtype="float32"), R.Tensor((16, 32), dtype="float32")],
                )
                lv3 = R.TupleGetItem(lv2, 0)
                lv4 = R.TupleGetItem(lv2, 1)
                lv5 = R.add(lv3, lv4)
                R.output(lv5)
            return lv5

    dfb = CMS["main"].body.blocks[0]

    with PatternContext() as ctx:
        (
            is_call_dps_packed("my_concat")
            >> is_call_dps_packed("my_matmul")
            >> is_call_dps_packed("my_split")
        )
        assert ctx.match_dfb(dfb)

    with PatternContext() as ctx:
        split = is_call_dps_packed("my_split")
        first_half = TupleGetItemPattern(split, 0).has_shape([16, 32])
        second_half = TupleGetItemPattern(split, 1).has_shape([16, 32])
        split.fork_to(first_half, second_half)
        add = is_op("relax.add")(first_half, second_half)
        # TODO(@ganler): simplify this through implicit graph pattern.
        first_half >> add
        second_half >> add
        assert ctx.match_dfb(dfb)
def test_self_attention():
    """Three dup()'d fc -> transpose branches forking from x match the Q/K/V projections."""

    # The example comes from.
    # https://developer.nvidia.com/blog/nlu-with-tensorrt-bert/
    @tvm.script.ir_module
    class SelfAttention:
        @R.function
        def main(
            x: R.Tensor(("b", "s", "n", "h"), "float32"),
            wq: R.Tensor(("h", "h"), "float32"),
            wk: R.Tensor(("h", "h"), "float32"),
            wv: R.Tensor(("h", "h"), "float32"),
        ) -> R.Tensor:
            b, s, n, h = T.int64(), T.int64(), T.int64(), T.int64()
            with R.dataflow():
                fcq = R.call_dps_packed("my_fc", (x, wq), R.Tensor((b, s, n, h), dtype="float32"))
                tpq = R.call_dps_packed(
                    "my_transpose", (fcq,), R.Tensor((b, s, h, n), dtype="float32")
                )
                fck = R.call_dps_packed("my_fc", (x, wk), R.Tensor((b, s, n, h), dtype="float32"))
                tpk = R.call_dps_packed(
                    "my_transpose", (fck,), R.Tensor((b, s, h, n), dtype="float32")
                )
                mul = R.multiply(tpq, tpk)
                scale = R.multiply(mul, R.const(1.1, "float32"))
                softmax = R.call_dps_packed(
                    "softmax", (scale,), R.Tensor((b, s, n, h), dtype="float32")
                )
                fcv = R.call_dps_packed("my_fc", (x, wv), R.Tensor((b, s, n, h), dtype="float32"))
                tpv = R.call_dps_packed(
                    "my_transpose", (fcv,), R.Tensor((b, s, h, n), dtype="float32")
                )
                out = R.multiply(softmax, tpv)
                R.output(out)
            return out

    dfb = SelfAttention["main"].body.blocks[0]

    with PatternContext() as ctx:
        q_branch = is_call_dps_packed("my_fc") >> is_call_dps_packed("my_transpose")
        k_branch = q_branch.dup()
        v_branch = q_branch.dup()
        is_var("x").fork_to(q_branch, k_branch, v_branch)
        assert ctx.match_dfb(dfb)
def test_nested_diamond():
    """Match the half-diamonds and the inner add diamond inside DiamondInDiamond."""

    @tvm.script.ir_module
    class DiamondInDiamond:
        @R.function
        def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor:
            with R.dataflow():
                # matmul0 matmul1
                # / \ / \
                # sigmoid2 add4 sigmoid3
                # \ / \ /
                # add5 add6
                # \ /
                # add7
                lv0 = R.call_dps_packed(
                    "extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32")
                )
                lv1 = R.call_dps_packed(
                    "extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32")
                )
                lv2 = R.call_dps_packed(
                    "extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")
                )
                lv3 = R.call_dps_packed(
                    "extern_sigmoid", (lv1), R.Tensor((32, 32), dtype="float32")
                )
                lv4 = R.call_dps_packed(
                    "extern_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32")
                )
                lv5 = R.call_dps_packed(
                    "extern_add", (lv2, lv4), R.Tensor((32, 32), dtype="float32")
                )
                lv6 = R.call_dps_packed(
                    "extern_add", (lv3, lv4), R.Tensor((32, 32), dtype="float32")
                )
                lv7 = R.call_dps_packed(
                    "extern_add", (lv5, lv6), R.Tensor((32, 32), dtype="float32")
                )
                R.output(lv7)
            return lv7

    dfb = DiamondInDiamond["main"].body.blocks[0]

    def half_diamond_matches(join_with_used_by):
        """Match matmul -> {sigmoid, add} -> add; the add -> add edge is `^` or `>>`."""
        with PatternContext() as ctx:
            sigmoid = is_call_dps_packed("extern_sigmoid")
            shared_add = is_call_dps_packed("extern_add")
            is_call_dps_packed("extern_matmul").fork_to(sigmoid, shared_add)
            join = is_call_dps_packed("extern_add")
            sigmoid >> join
            if join_with_used_by:
                shared_add ^ join
            else:
                shared_add >> join  # not only-used-by relation
            return ctx.match_dfb(dfb)

    # match matmul0 diamond: matmul0 -> {sigmoid2, add4} -> add5
    assert half_diamond_matches(join_with_used_by=True)
    # counter case: add4 also feeds add6, so add4 >> add5 (only-used-by) cannot hold
    assert not half_diamond_matches(join_with_used_by=False)
    # match matmul1 diamond: matmul1 -> {sigmoid3, add4} -> add6 (same pattern shape)
    assert half_diamond_matches(join_with_used_by=True)

    # match add-4-5-6-7
    with PatternContext() as ctx:
        add5 = is_call_dps_packed("extern_add")
        add6 = is_call_dps_packed("extern_add")
        add7 = is_call_dps_packed("extern_add")
        is_call_dps_packed("extern_add").fork_to(add5, add6)  # add4
        add5 >> add7
        add6 >> add7
        assert ctx.match_dfb(dfb)
def test_incremental_solving():
    """An incremental PatternContext also enforces the constraints of its enclosing context."""

    @R.function
    def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
        with R.dataflow():
            # relu -> sigmoid -> neg
            lv0 = R.call_dps_packed("extern_relu", (x), R.Tensor((32, 32), dtype="float32"))
            lv1 = R.call_dps_packed("extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32"))
            lv2 = R.call_dps_packed("extern_neg", (lv1), R.Tensor((32, 32), dtype="float32"))
            R.output(lv2)
        return lv2

    relu = is_call_dps_packed("extern_relu")
    sigmoid = is_call_dps_packed("extern_sigmoid")
    neg = is_call_dps_packed("extern_neg")
    block = simple_chain.body.blocks[0]

    with PatternContext() as base:
        relu >> sigmoid
        with PatternContext(incremental=True) as extended:
            # Incremental solving keeps relu >> sigmoid from the enclosing context,
            # so the total constraint here is relu >> sigmoid >> neg.
            sigmoid >> neg
            assert extended.match_dfb(block)
        # match relu -> sigmoid
        assert base.match_dfb(block)
def test_incremental_solving_counter():
    """An unsatisfiable outer constraint only matters when the inner context is incremental."""

    @R.function
    def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
        with R.dataflow():
            # sigmoid -> neg
            lv0 = R.call_dps_packed("extern_sigmoid", (x), R.Tensor((32, 32), dtype="float32"))
            lv1 = R.call_dps_packed("extern_neg", (lv0), R.Tensor((32, 32), dtype="float32"))
            R.output(lv1)
        return lv1

    relu = is_call_dps_packed("extern_relu")
    sigmoid = is_call_dps_packed("extern_sigmoid")
    neg = is_call_dps_packed("extern_neg")
    block = simple_chain.body.blocks[0]

    with PatternContext() as base:
        relu >> sigmoid  # cannot match: the chain has no relu
        with PatternContext(incremental=False) as fresh:
            # total constraint: sigmoid >> neg
            sigmoid >> neg
            assert fresh.match_dfb(block)
        with PatternContext(incremental=True) as extended:
            # total constraint: relu >> sigmoid >> neg
            sigmoid >> neg
            assert not extended.match_dfb(block)
def test_rewrite_simple():
    """rewrite_call replaces matched add(x, x) subgraphs with multiplies, or leaves them."""

    @R.function
    def main(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16), "float32"):
        with R.dataflow():
            x2 = R.add(x, x)
            x4 = R.add(x2, x2)
            R.output(x4)
        return x4

    @R.function
    def expected1(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(2, "float32"))
            x4: R.Tensor((16, 16), dtype="float32") = R.multiply(lv, R.const(2, "float32"))
            R.output(x4)
        return x4

    @R.function
    def expected2(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
        with R.dataflow():
            x4: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(4, "float32"))
            R.output(x4)
        return x4

    operand = wildcard()

    # Step 1: every add(a, a) becomes multiply(a, 2).
    def double(_, matchings):
        return R.multiply(matchings[operand], R.const(2, "float32"))

    doubled = rewrite_call(is_op("relax.add")(operand, operand), double, main)
    tvm.ir.assert_structural_equal(doubled, expected1.with_attr("global_symbol", "main"))

    # Step 2: add(add(a, a), add(a, a)) collapses into a single multiply(a, 4).
    inner_add = is_op("relax.add")(operand, operand)
    quad_pattern = is_op("relax.add")(inner_add, inner_add)

    def quadruple(_, matchings):
        return R.multiply(matchings[operand], R.const(4, "float32"))

    quadrupled = rewrite_call(quad_pattern, quadruple, main)
    tvm.ir.assert_structural_equal(quadrupled, expected2.with_attr("global_symbol", "main"))

    # No rewriting, return the original call node as is
    def keep_original(orig, _):
        return orig

    unchanged = rewrite_call(quad_pattern, keep_original, main)
    tvm.ir.assert_structural_equal(unchanged, main)
def test_rewrite_attention():
    """Rewrite an unfused BSNH attention subgraph into a single R.nn.attention(Q, K, V).

    `main` computes softmax((Q @ K^T) * scale) @ V on 3-D (B*N, S, H) views of
    its BSNH inputs. The pattern therefore places K, not V, under the transpose
    in the first matmul, and V as the right operand of the second matmul. The
    expected rewrite is attention with the natural (query, key, value) order.
    """

    @R.function
    def main(
        Q: R.Tensor((2, 4096, 8, 40), "float32"),
        K: R.Tensor((2, 4096, 8, 40), "float32"),
        V: R.Tensor((2, 4096, 8, 40), "float32"),
    ) -> R.Tensor((2, 4096, 8, 40), "float32"):
        with R.dataflow():
            lv58 = R.permute_dims(Q, axes=[0, 2, 1, 3])
            lv59 = R.reshape(lv58, R.shape([16, 4096, 40]))
            lv61 = R.permute_dims(K, axes=[0, 2, 1, 3])
            lv62 = R.reshape(lv61, R.shape([16, 4096, 40]))
            lv64 = R.permute_dims(V, axes=[0, 2, 1, 3])
            lv65 = R.reshape(lv64, R.shape([16, 4096, 40]))
            lv62_transposed = R.permute_dims(lv62, axes=[0, 2, 1])
            lv3_1 = R.matmul(lv59, lv62_transposed)
            lv68 = R.multiply(lv3_1, R.const(0.15811388194561005, "float32"))
            lv69 = R.nn.softmax(lv68, axis=-1)
            lv_3 = R.matmul(lv69, lv65)
            lv71 = R.reshape(lv_3, R.shape([2, 8, 4096, 40]))
            lv72 = R.permute_dims(lv71, axes=[0, 2, 1, 3])
            R.output(lv72)
        return lv72

    @R.function
    def expected(
        Q: R.Tensor((2, 4096, 8, 40), dtype="float32"),
        K: R.Tensor((2, 4096, 8, 40), dtype="float32"),
        V: R.Tensor((2, 4096, 8, 40), dtype="float32"),
    ) -> R.Tensor((2, 4096, 8, 40), dtype="float32"):
        with R.dataflow():
            lv72: R.Tensor((2, 4096, 8, 40), dtype="float32") = R.nn.attention(Q, K, V)
            R.output(lv72)
        return lv72

    def BSNH_to_BSH(tensor):
        # permute_dims followed by reshape: a 4-D BSNH input viewed as 3-D.
        return is_op("relax.reshape")(is_op("relax.permute_dims")(tensor), wildcard())

    def BSH_to_BSNH(tensor):
        # reshape followed by permute_dims: the 3-D result restored to 4-D.
        return is_op("relax.permute_dims")(is_op("relax.reshape")(tensor, wildcard()))

    Q = wildcard()
    K = wildcard()
    V = wildcard()

    Q_3D = BSNH_to_BSH(Q)
    K_3D = BSNH_to_BSH(K)
    V_3D = BSNH_to_BSH(V)

    # Scores: Q @ K^T. The key is the operand that gets transposed.
    matmul1 = is_op("relax.matmul")(Q_3D, is_op("relax.permute_dims")(K_3D))
    # NOTE(review): the scale constant is matched by is_const() and then dropped by the
    # rewrite. Here it is 0.158... ~= 1/sqrt(40), presumably R.nn.attention's default
    # scale for head dim 40 -- confirm if reusing this rewrite elsewhere.
    multiply = is_op("relax.multiply")(matmul1, is_const())
    softmax = is_op("relax.nn.softmax")(multiply)
    # Weighted sum of values: softmax(...) @ V.
    matmul2 = is_op("relax.matmul")(softmax, V_3D)

    pattern = BSH_to_BSNH(matmul2)

    def rewriter(_, matchings):
        return R.nn.attention(matchings[Q], matchings[K], matchings[V])

    rewritten = rewrite_call(pattern, rewriter, main)
    tvm.ir.assert_structural_equal(rewritten, expected.with_attr("global_symbol", "main"))
def test_attention_qkv():
@tvm.script.ir_module
class QKV_proj:
@R.function
def main(
x: R.Tensor((2, 1024, 640), "float32"),
w0: R.Tensor((640, 640), "float32"),
w1: R.Tensor((640, 640), "float32"),
w2: R.Tensor((640, 640), "float32"),
) -> R.Tensor:
with R.dataflow():
lv0 = R.matmul(x, w0)
lv1 = R.matmul(x, w1)
lv2 = R.matmul(x, w2)
out = (lv0, lv1, lv2)
R.output(out)
return out
with PatternContext() as ctx: