forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_linear.py
More file actions
1053 lines (931 loc) · 39.1 KB
/
test_linear.py
File metadata and controls
1053 lines (931 loc) · 39.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import unittest
from itertools import product
from typing import Callable, Dict, List, Optional, Tuple
import torch
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
XnnpackFloatingPointPartitioner,
XnnpackPartitioner,
)
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
)
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
QuantizationConfig,
)
from executorch.backends.xnnpack.test.tester import Quantize, Tester
from executorch.backends.xnnpack.test.tester.tester import (
Partition,
ToEdgeTransformAndLower,
)
from torch.export.graph_signature import ExportGraphSignature, InputKind
# torchao is an optional dependency: the groupwise (qb4w) tests below are
# skipped when it cannot be imported.
try:
    from torchao.quantization.quant_api import (
        int8_dynamic_activation_int4_weight,
        quantize_,
    )
    from torchao.utils import unwrap_tensor_subclass

    torchao_installed = True
except Exception:
    # Stay best-effort for any import-time failure, but unlike a bare
    # ``except`` let KeyboardInterrupt / SystemExit propagate.
    torchao_installed = False
# PyTorch modules used for testing
class BaseLinear(torch.nn.Module):
    """Single ``torch.nn.Linear`` layer plus a helper that builds matching inputs.

    Only ``torch.float`` and ``torch.half`` are accepted as ``dtype``.
    """

    def __init__(
        self,
        in_size: int = 2,
        input_channels: int = 4,
        output_channels: int = 4,
        dtype: torch.dtype = torch.float,
        use_bias: bool = False,
    ):
        super().__init__()
        layer = torch.nn.Linear(input_channels, output_channels, bias=use_bias)
        self.linear = layer.to(dtype=dtype)
        self.ic = input_channels
        self.oc = output_channels
        assert dtype in [torch.float, torch.half], "Unsupported op dtype"
        self.op_dtype = dtype
        self.in_size = in_size

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        return self.linear(x)

    def get_inputs(self, rank=3):
        """Return a 1-tuple holding a random ``(..., in_size, ic)`` tensor of ``rank`` dims.

        The default rank of 3 inflates the activation rank by one batch dim so
        that we don't specialize on 2D shapes.
        """
        base = torch.randn(self.in_size, self.ic).to(self.op_dtype)
        # Prepend (rank - 2) singleton dims in front of the 2D activation.
        leading_ones = (1,) * (rank - 2)
        inp = base.reshape(leading_ones + tuple(base.shape))
        assert inp.ndim == rank
        return (inp,)
class AddMMModule(torch.nn.Module):
    """Computes ``bias + x @ mat`` with a single ``torch.addmm`` call."""

    def __init__(self, in_size, out_size):
        super().__init__()
        # ``mat`` is drawn before ``bias`` so seeded runs see the same values.
        mat_init = torch.randn(in_size, out_size)
        bias_init = torch.randn(1, out_size)
        self.mat = torch.nn.Parameter(mat_init)
        self.bias = torch.nn.Parameter(bias_init)

    def forward(self, x):
        """Return ``addmm(bias, x, mat)``."""
        bias, weight = self.bias, self.mat
        return torch.addmm(bias, x, weight)
class LinearReluModule(torch.nn.Module):
    """``torch.nn.Linear`` followed by a functional ReLU.

    Args:
        in_size: Number of input features of the linear layer.
        out_size: Number of output features of the linear layer.
        use_bias: Whether the linear layer has a bias term.
        dtype: dtype the linear layer is converted to.
    """

    def __init__(self, in_size, out_size, use_bias, dtype=torch.float):
        super().__init__()
        self.dtype = dtype
        # Remembered so get_inputs() can build a tensor with a matching last dim.
        self.in_size = in_size
        self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias).to(dtype=dtype)

    def forward(self, x):
        """Return ``relu(linear(x))``."""
        return torch.nn.functional.relu(self.linear(x))

    def get_inputs(self, num_rows=2):
        """Return a 1-tuple with a random ``(1, num_rows, in_size)`` tensor in ``self.dtype``.

        The previous implementation read ``self.ic`` / ``self.op_dtype``, which
        this class never defines, so it always raised ``AttributeError``.
        """
        return (torch.randn(1, num_rows, self.in_size).to(self.dtype),)
class LinearParallelSequentialModule(torch.nn.Module):
    """Two parallel linears (on ``x`` and ``y``); the ``y`` branch feeds a third linear.

    Args:
        in_size: Number of rows of the tensors produced by ``get_inputs``.
        input_size: Feature size of both inputs.
        intermediate_size: Output size of linear1 / linear2.
        output_size: Output size of linear3.
        dtype: dtype of the parameters and of the generated inputs. It used to
            be accepted but ignored (everything was always ``torch.float``).
    """

    def __init__(
        self,
        in_size=2,
        input_size=4,
        intermediate_size=5,
        output_size=3,
        dtype=torch.float,
    ):
        super().__init__()

        def _param(*shape):
            # Draw in the default dtype, then cast: for the default
            # ``torch.float`` this consumes the RNG exactly as before.
            return torch.nn.Parameter(torch.rand(*shape).to(dtype))

        self.linear1_weight = _param(intermediate_size, input_size)
        self.linear1_bias = _param(intermediate_size)
        self.linear2_weight = _param(intermediate_size, input_size)
        self.linear2_bias = _param(intermediate_size)
        self.linear3_weight = _param(output_size, intermediate_size)
        self.linear3_bias = _param(output_size)
        self.in_size = in_size
        self.input_size = input_size
        self.dtype = dtype

    def forward(self, x, y):
        """Return ``(linear1(x), linear3(linear2(y)))``."""
        a = torch.nn.functional.linear(x, self.linear1_weight, self.linear1_bias)
        b = torch.nn.functional.linear(y, self.linear2_weight, self.linear2_bias)
        c = torch.nn.functional.linear(b, self.linear3_weight, self.linear3_bias)
        return (a, c)

    def get_inputs(self):
        """Return two random ``(in_size, input_size)`` tensors in ``self.dtype``."""
        return (
            torch.rand(self.in_size, self.input_size).to(self.dtype),
            torch.rand(self.in_size, self.input_size).to(self.dtype),
        )
class LinearSequential(torch.nn.Module):
    """Two functional linears applied back to back.

    Args:
        in_size: Number of rows of the tensor produced by ``get_inputs``.
        input_size: Feature size of the input.
        intermediate_size: Output size of the first linear.
        output_size: Output size of the second linear.
        dtype: dtype of the parameters and of the generated input. It used to
            be accepted but ignored (everything was always ``torch.float``).
    """

    def __init__(
        self,
        in_size=2,
        input_size=4,
        intermediate_size=5,
        output_size=3,
        dtype=torch.float,
    ):
        super().__init__()

        def _param(*shape):
            # Draw in the default dtype, then cast: for the default
            # ``torch.float`` this consumes the RNG exactly as before.
            return torch.nn.Parameter(torch.rand(*shape).to(dtype))

        self.linear1_weight = _param(intermediate_size, input_size)
        self.linear1_bias = _param(intermediate_size)
        self.linear2_weight = _param(output_size, intermediate_size)
        self.linear2_bias = _param(output_size)
        self.in_size = in_size
        self.input_size = input_size
        self.dtype = dtype

    def forward(self, x):
        """Return ``linear2(linear1(x))``."""
        a = torch.nn.functional.linear(x, self.linear1_weight, self.linear1_bias)
        b = torch.nn.functional.linear(a, self.linear2_weight, self.linear2_bias)
        return b

    def get_inputs(self):
        """Return a 1-tuple with a random ``(in_size, input_size)`` tensor in ``self.dtype``."""
        return (torch.rand(self.in_size, self.input_size).to(self.dtype),)
class ParallelLinear(torch.nn.Module):
    """Sum of two independent linears, one applied to each of two inputs."""

    def __init__(self, input_size, output_size):
        super().__init__()
        weight_shape = (output_size, input_size)
        # Draw order (w1, b1, w2, b2) is kept so seeded runs see the same values.
        self.linear1_weight = torch.nn.Parameter(torch.rand(*weight_shape))
        self.linear1_bias = torch.nn.Parameter(torch.rand(output_size))
        self.linear2_weight = torch.nn.Parameter(torch.rand(*weight_shape))
        self.linear2_bias = torch.nn.Parameter(torch.rand(output_size))

    def forward(self, x, y):
        """Return ``linear1(x) + linear2(y)``."""
        linear = torch.nn.functional.linear
        from_x = linear(x, self.linear1_weight, self.linear1_bias)
        from_y = linear(y, self.linear2_weight, self.linear2_bias)
        return from_x + from_y
class SharedDQChain(torch.nn.Module):
    """Sum of two linears that both consume the same input tensor."""

    def __init__(self, input_size, output_size):
        super().__init__()
        weight_shape = (output_size, input_size)
        # Draw order (w1, b1, w2, b2) is kept so seeded runs see the same values.
        self.linear1_weight = torch.nn.Parameter(torch.rand(*weight_shape))
        self.linear1_bias = torch.nn.Parameter(torch.rand(output_size))
        self.linear2_weight = torch.nn.Parameter(torch.rand(*weight_shape))
        self.linear2_bias = torch.nn.Parameter(torch.rand(output_size))

    def forward(self, x):
        """Return ``linear1(x) + linear2(x)``."""
        linear = torch.nn.functional.linear
        first = linear(x, self.linear1_weight, self.linear1_bias)
        second = linear(x, self.linear2_weight, self.linear2_bias)
        return first + second
class TestLinear(unittest.TestCase):
"""
Test Class for XNNPACK Linear Operators.
Notes:
- XNNPACK does not support per-tensor quantized weights with dynamic activations.
- XNNPACK only supports per-token activation quantization, so the dynamic per-tensor
quantization done by the default dynamic quantization flow is executed as per-token
activation quantization under the hood, while the torch.nn.Module performs per-tensor
quantization on the activation. This is sufficient because per-token quantization of
activations should produce strictly better results than per-tensor quantization.
"""
def setUp(self):
    """Reset TorchDynamo state before each test runs."""
    torch._dynamo.reset()
@staticmethod
def _get_4b_dqconfig() -> QuantizationConfig:
    """Return the per-channel, dynamic config whose weights use the range [-8, 7]."""
    # Signed 4-bit weight range.
    int4_weight_range = {"weight_qmin": -8, "weight_qmax": 7}
    return get_symmetric_quantization_config(
        is_per_channel=True,
        is_dynamic=True,
        **int4_weight_range,
    )
def _test_linear(
    self,
    make_module,
    uses_bias,
    num_batch_dims=1,
    quant_type=None,
    dtype: torch.dtype = torch.float,
    atol=1e-03,  # TODO(T212995726): Investigate right atol for rand[n] inputs
):
    """
    Helper function to test linear op with different configurations.

    Args:
        make_module: Callable ``(input_size, output_size) -> torch.nn.Module``.
        uses_bias: Selects which edge op (addmm vs. mm) must be absent after lowering.
        num_batch_dims: Number of leading dims (each of size ``in_size``) of the input;
            every one of them is exported as a dynamic dim.
        quant_type: ``None`` (no quantization), ``"per_channel"`` or ``"per_tensor"``
            (static symmetric quantization).
        dtype: dtype the module and the input are converted to.
        atol: Absolute tolerance used when comparing outputs.

    Raises:
        ValueError: If ``quant_type`` is not one of the supported values.
    """
    edge_op = (
        "executorch_exir_dialects_edge__ops_aten_addmm_default"
        if uses_bias
        else "executorch_exir_dialects_edge__ops_aten_mm_default"
    )

    in_sizes = [3, 4, 4]
    input_sizes = [4, 37, 17]
    output_sizes = [4, 17, 37]

    quant_config = None
    if quant_type is not None:
        if quant_type == "per_channel":
            quant_config = get_symmetric_quantization_config(
                is_per_channel=True,
                is_dynamic=False,
            )
        elif quant_type == "per_tensor":
            quant_config = get_symmetric_quantization_config(
                is_per_channel=False,
                is_dynamic=False,
            )
        else:
            raise ValueError(f"Unsupported quant type {quant_type}")

    # Note that torch.nn.Linear maps to aten.mm.default (no bias) or
    # aten.addmm.default (bias), which are then transformed into
    # aten.linear.default by the ConvertToLinear pass.
    for in_size, input_size, output_size in zip(
        in_sizes, input_sizes, output_sizes
    ):
        torch._dynamo.reset()
        input_shape = [in_size] * num_batch_dims + [input_size]
        module = make_module(input_size, output_size).eval().to(dtype)
        inputs = (torch.randn(input_shape).to(dtype),)

        # One dynamic dim per batch dim. A distinct loop name avoids the
        # shadowing of the outer loop index that the old ``for i`` had.
        dynamic_shape = (
            {
                dim: torch.export.Dim(f"batch{dim}", min=2, max=in_size)
                for dim in range(num_batch_dims)
            },
        )

        for legacy_mode in (True, False):
            tester = Tester(module, inputs, dynamic_shapes=dynamic_shape)
            if quant_config:
                tester.quantize(Quantize(quantization_config=quant_config))

            tester.export()
            if quant_config:
                tester.check(["torch.ops.quantized_decomposed"])

            if legacy_mode:
                tester.to_edge()
                tester.partition()
            else:
                tester.to_edge_transform_and_lower()

            tester.check_count(
                {"torch.ops.higher_order.executorch_call_delegate": 1}
            )
            tester.check_not([edge_op])

            if quant_config:
                tester.check_not(
                    [
                        "executorch_exir_dialects_edge__ops_aten_mm_default",
                        "executorch_exir_dialects_edge__ops_aten_addmm_default",
                    ]
                )

            tester.to_executorch()
            tester.serialize()
            tester.run_method_and_compare_outputs(
                qtol=bool(quant_config), atol=atol
            )
def _test_dqlinear(
    self,
    module,
    inputs,
    dynamic_shapes,
    linear_count=1,
    is_per_channel=False,
    uses_bias=False,
    qconfig: Optional[QuantizationConfig] = None,
    atol=5e-02,  # TODO(T212995726): Investigate right atol for rand[n] inputs
):
    """
    Lower a dynamically quantized linear model through both partitioning flows
    and compare its outputs.

    Every combination of (legacy partitioner, per-op mode) is exercised. In
    per-op mode ``linear_count`` delegate calls are expected, otherwise one.
    ``qconfig`` overrides the default symmetric dynamic config when given.
    """
    if qconfig:
        quant_config = qconfig
    else:
        quant_config = get_symmetric_quantization_config(
            is_per_channel=is_per_channel,
            is_dynamic=True,
        )

    non_delegated_ops = [
        "executorch_exir_dialects_edge__ops_aten_mm_default",
        "executorch_exir_dialects_edge__ops_aten_addmm_default",
    ]

    for legacy_partitioner, per_op_mode in product((True, False), repeat=2):
        dq_partitioner = XnnpackPartitioner(
            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
            per_op_mode=per_op_mode,
        )
        expected_delegates = linear_count if per_op_mode else 1

        tester = Tester(module, inputs, dynamic_shapes=dynamic_shapes)
        tester.quantize(Quantize(quantization_config=quant_config))
        tester.export()

        if legacy_partitioner:
            tester.to_edge()
            tester.partition(Partition(dq_partitioner))
        else:
            tester.to_edge_transform_and_lower(
                ToEdgeTransformAndLower([dq_partitioner])
            )

        tester.check_count(
            {"torch.ops.higher_order.executorch_call_delegate": expected_delegates}
        )
        tester.check_not(non_delegated_ops)

        tester.to_executorch()
        tester.serialize()
        tester.run_method_and_compare_outputs(atol=atol)
def _test_groupwise_dq_linear(
    self,
    mod: torch.nn.Module,
    inputs: Tuple[torch.Tensor],
    use_bias: bool = False,
    group_size: int = 8,
    num_linears: int = 1,
    atol: float = 5e-3,  # TODO(T212995726): Investigate right atol for rand[n] inputs
    rtol: float = 5e-3,  # TODO(T212995726): Investigate right rtol for rand[n] inputs
):
    """
    Quantize ``mod`` in place with torchao (int8 dynamic activation, int4
    groupwise weight), lower it with the per-op dynamic-quant XNNPACK
    partitioner, and compare outputs.
    """
    quantize_(mod, int8_dynamic_activation_int4_weight(group_size=group_size))
    unwrap_tensor_subclass(mod)

    dq_partitioner = XnnpackPartitioner(
        config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
        per_op_mode=True,
    )

    # Per linear: one choose_qparams, one quantize, two dequantize, one linear.
    expected_aten_counts = {
        "torch.ops.quant.choose_qparams_affine.default": num_linears,
        "torch.ops.quant.quantize_affine.default": num_linears,
        "torch.ops.quant.dequantize_affine.default": 2 * num_linears,
        "torch.ops.aten.linear.default": num_linears,
    }
    # None of these edge ops may remain outside the delegate after lowering.
    forbidden_edge_ops = [
        "executorch_exir_dialects_edge__ops_quant_choose_qparams_affine_default",
        "executorch_exir_dialects_edge__ops_quant_quantize_affine_default",
        "executorch_exir_dialects_edge__ops_quant_dequantize_affine_default",
        "executorch_exir_dialects_edge__ops_aten_mm_default",
        "executorch_exir_dialects_edge__ops_aten_addmm_default",
    ]

    tester = Tester(mod, inputs).export().check_count(expected_aten_counts)

    tester.to_edge_transform_and_lower(ToEdgeTransformAndLower([dq_partitioner]))

    (
        tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
        .check_not(forbidden_edge_ops)
        .to_executorch()
        .serialize()
        .run_method_and_compare_outputs(atol=atol, rtol=rtol)
    )
def _test_linear_overwrite_precision(
    self,
    make_module: Callable[[int, int], torch.nn.Module],
    uses_bias: bool,
    quant_type: str,
    quant_node_checks: List[Dict[str, int]],
    atol: float = 1e-03,  # TODO(T212995726): Investigate right atol for rand[n] inputs
):
    """
    Check the precision override of the linear op: a quantized linear model is
    partitioned and lowered with the FP32 partitioner.

    In legacy mode we verify that [add]mm is NOT partitioned, because
    (1) we can't assume that weights are always static (non param), and
    (2) when lowering [add]mm to xnn::bmm we can't support bias;
        (2)(a) lowering only non-bias [add]mm, which is only exposed on the
        legacy path, was deemed low ROI.
    """
    supported_quant_types = ["per_tensor", "per_channel", "per_channel_dynamic"]
    assert quant_type in supported_quant_types

    quant_config = get_symmetric_quantization_config(
        is_per_channel="per_channel" in quant_type,
        is_dynamic="dynamic" in quant_type,
    )
    # Using FP32 partitioner for this quantized graph
    partitioner = XnnpackFloatingPointPartitioner()

    def get_qnode_checks(quant_node_checks, dialect):
        # Prefix each op name with the quantized_decomposed namespace of the dialect.
        assert dialect in ["aten", "edge"]
        if dialect == "aten":
            names = [f"torch.ops.quantized_decomposed.{op}" for op in quant_node_checks]
        else:
            names = [
                f"executorch.exir.dialects.edge._ops.quantized_decomposed.{op}".replace(
                    ".", "_"
                )
                for op in quant_node_checks
            ]
        checks = dict(zip(names, quant_node_checks.values()))
        assert len(checks) == len(quant_node_checks)
        return checks

    if uses_bias:
        addmm_op_str = "executorch_exir_dialects_edge__ops_aten_addmm_default"
    else:
        addmm_op_str = "executorch_exir_dialects_edge__ops_aten_mm_default"
    linear_op_str = "executorch_exir_dialects_edge__ops_aten_linear_default"

    size_cases = zip([3, 4, 4], [4, 37, 17], [4, 17, 37])
    for in_size, input_size, output_size in size_cases:
        torch._dynamo.reset()
        module = make_module(input_size, output_size).eval()
        inputs = (torch.randn([in_size, input_size]),)

        for legacy_mode in (True, False):
            tester = (
                Tester(module, inputs)
                .quantize(Quantize(quantization_config=quant_config))
                .export()
                .dump_artifact()
                .check_count(get_qnode_checks(quant_node_checks, "aten"))
            )

            if legacy_mode:
                tester.to_edge()
                tester.partition(Partition(partitioner=partitioner))
                # We don't expect [add]mm to be partitioned
                tester.check([addmm_op_str])
            else:
                tester.to_edge_transform_and_lower(
                    ToEdgeTransformAndLower(partitioners=[partitioner])
                )
                # We do expect linear to be partitioned
                tester.check_not([linear_op_str])

            # For legacy mode, fp32 permute_copy gets partitioned. (just a side effect)
            # For new mode, fp32 linear gets partitioned.
            tester.check_count(
                {"torch.ops.higher_order.executorch_call_delegate": 1}
            )

            # Typically, we would not see any quantized ops in the graph.
            # But here we shouldn't partition these.
            tester.check_count(get_qnode_checks(quant_node_checks, "edge"))

            # TODO: Need to figure out how to load quantized ops in pybindings.
            # Until then the model is not run; the intended tail is:
            #   tester.to_executorch()
            #   tester.serialize()
            #   tester.run_method_and_compare_outputs(
            #       qtol=bool(quant_config), atol=atol
            #   )
def test_qd8_f32_per_channel_shared_dq_chain(self):
    # Two per-channel dynamically quantized linears consuming the same activation.
    for use_bias in (False, True):
        shared_chain = SharedDQChain(input_size=13, output_size=17)
        sample_inputs = (torch.randn(1, 2, 13),)
        self._test_dqlinear(
            shared_chain,
            sample_inputs,
            dynamic_shapes=None,
            is_per_channel=True,
            linear_count=2,
            uses_bias=use_bias,
        )
def _test_qd8_per_channel_linear(self, dtype: torch.dtype = torch.float):
    """Run the per-channel dynamic-quant linear check with and without bias."""
    linear_kwargs = {
        "in_size": 8,
        "input_channels": 13,
        "output_channels": 17,
        "dtype": dtype,
    }
    for uses_bias in (False, True):
        module = BaseLinear(use_bias=uses_bias, **linear_kwargs)
        inputs = module.get_inputs()
        # Dim 1 of the rank-3 input is exported as dynamic.
        dynamic_shapes = ({1: torch.export.Dim("batch", max=100)},)
        self._test_dqlinear(
            module,
            inputs,
            dynamic_shapes=dynamic_shapes,
            is_per_channel=True,
            uses_bias=uses_bias,
        )
def _test_qd8_linear_per_tensor_unsupported(self, dtype: torch.dtype = torch.float):
"""
Check that a per-tensor, dynamically quantized linear is not delegated.
A BaseLinear (13 -> 17 channels, in_size 8) in `dtype` is quantized with a
per-tensor dynamic config and lowered with the dynamic-quant XnnpackPartitioner:
- to_edge + partition path: the [add]mm edge op must still be in the graph.
- to_edge_transform_and_lower path: no executorch_call_delegate node may appear.
The model is never executed.
"""
for uses_bias in (False, True):
module = BaseLinear(
in_size=8,
input_channels=13,
output_channels=17,
dtype=dtype,
use_bias=uses_bias,
)
inputs = module.get_inputs()
# Dim 1 of the rank-3 input returned by get_inputs() is exported as dynamic.
dynamic_shapes = ({1: torch.export.Dim("batch", max=100)},)
# Per-tensor weights + dynamic activations: the combination under test.
quant_config = get_symmetric_quantization_config(
is_per_channel=False,
is_dynamic=True,
)
for legacy_partitioner in (True, False):
for per_op_mode in (True, False):
# Every combination should fail to partition Linear or [add]mm.
DynamicallyQuantizedPartitioner = XnnpackPartitioner(
config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
per_op_mode=per_op_mode,
)
tester = Tester(module, inputs, dynamic_shapes=dynamic_shapes)
tester.quantize(Quantize(quantization_config=quant_config))
tester.export()
if legacy_partitioner:
tester.to_edge()
tester.partition(
Partition(DynamicallyQuantizedPartitioner)
).dump_artifact()
# should have [add]mm node
if uses_bias:
tester.check(
[
"executorch_exir_dialects_edge__ops_aten_addmm_default",
]
)
else:
tester.check(
[
"executorch_exir_dialects_edge__ops_aten_mm_default",
]
)
else:
tester.to_edge_transform_and_lower(
ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
).dump_artifact()
# should not have a delegate node
tester.check_not(
[
"torch.ops.higher_order.executorch_call_delegate",
]
)
# No need to run the model, since it should fail to partition.
# NOTE(review): confirm the nesting of this `return`. If it sits inside the
# loops above, only the first (uses_bias, legacy_partitioner, per_op_mode)
# combination is exercised, which contradicts "Every combination should fail".
return
def _test_qd8_per_channel_4w_linear(self, dtype: torch.dtype = torch.float):
    """Sweep bias and channel sizes for per-channel 4-bit-weight dynamic-quant linear."""
    qconfig = self._get_4b_dqconfig()
    batch_size = 2
    # Same iteration order as before: bias, then input channels, then output channels.
    sweep = product([False, True], [2, 63], [1, 127])
    for bias, ipc, opc in sweep:
        module = BaseLinear(
            in_size=batch_size,
            input_channels=ipc,
            output_channels=opc,
            dtype=dtype,
            use_bias=bias,
        )
        inputs = module.get_inputs()
        dynamic_shapes = ({1: torch.export.Dim("batch", max=100)},)
        self._test_dqlinear(
            module,
            inputs,
            dynamic_shapes=dynamic_shapes,
            is_per_channel=True,
            uses_bias=bias,
            qconfig=qconfig,
            atol=5e-2,  # TODO(T212995726): Investigate right atol for rand[n] inputs
        )
def _test_qd8_per_token_weight_per_channel_group_int4(
    self, dtype: torch.dtype = torch.float
):
    """Sweep input rank, bias and (M, K, group size, N) for groupwise int4 linear."""
    # (M, K, group size, N)
    shape_cases = [
        (1, 32, 32, 2),
        (2, 32, 32, 17),
        (17, 64, 32, 92),
        (31, 128, 64, 128),
    ]
    # Half requires slightly higher atol, but if you look at error it is not that bad:
    # Difference: max: 0.00140380859375, abs: 0.00140380859375, mean abs error: 0.00042724609375.
    # -- Model vs. Reference --
    # Numel: 4, 4
    # Median: -0.05023193359375, -0.0516357421875
    # Mean: 0.2373046875, 0.237060546875
    # Max: 1.0078125, 1.0078125
    # Min: -0.08465576171875, -0.08441162109375
    # TODO(T212995726): Investigate right atol for rand[n] inputs
    atol = 1e-2 if dtype == torch.half else 5e-3

    for input_rank, use_bias in product(range(2, 4), [True, False]):
        for M, K, bl, N in shape_cases:
            lin_mod = BaseLinear(
                in_size=M,
                input_channels=K,
                output_channels=N,
                dtype=dtype,
                use_bias=use_bias,
            )
            inputs = lin_mod.get_inputs(rank=input_rank)
            self._test_groupwise_dq_linear(
                lin_mod, inputs, group_size=bl, use_bias=use_bias, atol=atol
            )
def test_fp16_linear(self):
    # fp16 torch.nn.Linear, with/without bias, for one and two batch dims.
    for use_bias, num_batch_dims in product((True, False), range(1, 3)):

        def make_linear(in_size, out_size, bias=use_bias):
            return torch.nn.Linear(in_size, out_size, bias=bias)

        self._test_linear(
            make_linear,
            num_batch_dims=num_batch_dims,
            uses_bias=use_bias,
            dtype=torch.float16,
            atol=5e-2,  # TODO(T212995726): Investigate right atol for rand[n] inputs
        )
def test_fp32_linear(self):
    # fp32 torch.nn.Linear, with/without bias, for one and two batch dims.
    for use_bias, num_batch_dims in product((True, False), range(1, 3)):

        def make_linear(in_size, out_size, bias=use_bias):
            return torch.nn.Linear(in_size, out_size, bias=bias)

        self._test_linear(
            make_linear,
            uses_bias=use_bias,
            num_batch_dims=num_batch_dims,
        )
def test_qc8_linear(self):
    # Statically quantized, per-channel torch.nn.Linear.
    for use_bias, num_batch_dims in product((True, False), range(1, 3)):

        def make_linear(in_size, out_size, bias=use_bias):
            return torch.nn.Linear(in_size, out_size, bias=bias)

        self._test_linear(
            make_linear,
            uses_bias=use_bias,
            quant_type="per_channel",
            num_batch_dims=num_batch_dims,
        )
def test_fp32_addmm(self):
    # Note that the ConvertToLinear pass requires the weight matrix to be transposed.
    # AddMMModule already takes (in_size, out_size), so it is passed as the factory.
    self._test_linear(AddMMModule, uses_bias=True)
def test_fp32_linear_fused_relu(self):
    # fp32 linear followed by relu, with/without bias, one and two batch dims.
    for use_bias, num_batch_dims in product((True, False), range(1, 3)):

        def make_linear_relu(in_size, out_size, bias=use_bias):
            return LinearReluModule(in_size, out_size, bias)

        self._test_linear(
            make_linear_relu,
            uses_bias=use_bias,
            num_batch_dims=num_batch_dims,
        )
def test_qs8_linear_fused_relu(self):
    # Statically quantized, per-tensor linear followed by relu.
    for use_bias, num_batch_dims in product((True, False), range(1, 3)):

        def make_linear_relu(in_size, out_size, bias=use_bias):
            return LinearReluModule(in_size, out_size, bias)

        self._test_linear(
            make_linear_relu,
            num_batch_dims=num_batch_dims,
            uses_bias=use_bias,
            quant_type="per_tensor",
        )
def test_qs8_linear(self):
    # Statically quantized, per-tensor torch.nn.Linear.
    for use_bias, num_batch_dims in product((True, False), range(1, 3)):

        def make_linear(in_size, out_size, bias=use_bias):
            return torch.nn.Linear(in_size, out_size, bias=bias)

        self._test_linear(
            make_linear,
            uses_bias=use_bias,
            num_batch_dims=num_batch_dims,
            quant_type="per_tensor",
        )
# Tests for q[dp]8-f16-qc8w
def test_qd8_f16_per_channel_linear(self):
    # fp16 variant of the per-channel dynamic-quant linear check.
    self._test_qd8_per_channel_linear(torch.half)
def test_qd8_f16_per_tensor_linear(self):
    """
    XNNPACK doesn't support per_tensor quantized weights for dynamic quantized linear op.
    This verifies, for fp16, that per_tensor quantized weights are not lowered
    as if they were per_channel quantized weights.
    """
    self._test_qd8_linear_per_tensor_unsupported(torch.half)
# Tests for q[dp]8-f32-qc8w
def test_qd8_f32_per_channel_linear(self):
    # fp32 variant of the per-channel dynamic-quant linear check.
    self._test_qd8_per_channel_linear(torch.float)
def test_qd8_f32_per_tensor_linear(self):
    """
    XNNPACK doesn't support per_tensor quantized weights for dynamic quantized linear op.
    This test is to verify that we can't lower per_tensor quantized weights to per_channel quantized weights.
    """
    # This is the f32 test: it previously passed torch.half, which only
    # repeated the f16 test and left fp32 uncovered.
    self._test_qd8_linear_per_tensor_unsupported(dtype=torch.float)
# Tests for q[dp]8-f16-qc4w
def test_linear_qd8_f16_per_channel_int4(self):
    # fp16 variant of the per-channel 4-bit-weight dynamic-quant sweep.
    self._test_qd8_per_channel_4w_linear(torch.half)
# Tests for q[dp]8-f32-qc4w
def test_linear_qd8_f32_per_channel_int4(self):
    # fp32 variant of the per-channel 4-bit-weight dynamic-quant sweep.
    self._test_qd8_per_channel_4w_linear(torch.float)
# Tests for q[dp]8-f16-qb4w
@unittest.skipIf(
    not torchao_installed, "Per Channel Group Quantization Required TorchAO"
)
def test_linear_qd8_f16_per_token_weight_per_channel_group_int4(self):
    # fp16 variant of the groupwise int4 sweep; needs torchao.
    self._test_qd8_per_token_weight_per_channel_group_int4(torch.half)
# Tests for q[dp]8-f32-qb4w
@unittest.skipIf(
    not torchao_installed, "Per Channel Group Quantization Required TorchAO"
)
def test_linear_qd8_f32_per_token_weight_per_channel_group_int4(self):
    # fp32 variant of the groupwise int4 sweep; needs torchao.
    self._test_qd8_per_token_weight_per_channel_group_int4(torch.float)
@unittest.skipIf(
    not torchao_installed, "Per Channel Group Quantization Required TorchAO"
)
def test_linear_qd8_per_token_groupwise_unsupported_groupsize(self):
    # groupsize must be multiple of 32
    expected_error = (
        "Delegation to XNNPACK requires group_size to be a multiple of 32, but got 30"
    )
    for dtype in [torch.float, torch.half]:
        lin_mod = BaseLinear(
            in_size=1,
            input_channels=60,
            output_channels=60,
            dtype=dtype,
            use_bias=True,
        )
        inputs = lin_mod.get_inputs()

        with self.assertRaisesRegex(AssertionError, expected_error):
            self._test_groupwise_dq_linear(
                lin_mod, inputs, group_size=30, use_bias=False, atol=1e-2
            )
def test_qd8_per_channel_linear_parallel(self):
    # Two parallel per-channel dynamic-quant linears sharing one dynamic batch dim.
    rows, in_features, out_features = 2, 4, 5

    for dtype in (torch.float, torch.half):
        first_input = torch.rand(rows, in_features, dtype=dtype)
        second_input = torch.rand(rows, in_features, dtype=dtype)
        batch_dim = torch.export.Dim("batch", max=100)
        # The module is created after the inputs, as before.
        module = ParallelLinear(
            input_size=in_features, output_size=out_features
        ).to(dtype)

        self._test_dqlinear(
            module,
            (first_input, second_input),
            dynamic_shapes=({0: batch_dim}, {0: batch_dim}),
            linear_count=2,
            is_per_channel=True,
            uses_bias=True,
        )
def test_qd8_per_channel_linear_with_two_batch(self):
    # Per-channel dynamic-quant linear on a 3D activation whose two leading
    # dims share one dynamic Dim.
    in_size = 2
    input_size = 14
    output_size = 15

    for dtype in (torch.float, torch.half):
        for use_bias in (False, True):
            linear = BaseLinear(
                in_size=in_size,
                input_channels=input_size,
                output_channels=output_size,
                dtype=dtype,
                use_bias=use_bias,
            )
            # Create inputs with two batch dimensions, i.e. 3D activation
            inputs = (torch.randn(in_size, in_size, input_size).to(dtype),)
            batch_dim = torch.export.Dim("batch", max=100)
            dynamic_shapes = ({0: batch_dim, 1: batch_dim},)

            self._test_dqlinear(
                linear,
                inputs,
                dynamic_shapes=dynamic_shapes,
                linear_count=1,
                is_per_channel=True,
                # Report the bias setting the module was actually built with
                # (this used to be hard-coded to True).
                uses_bias=use_bias,
            )
def test_qd8_per_channel_linear_sequential(self):
    # Two chained per-channel dynamic-quant linears with a dynamic batch dim.
    sequential = LinearSequential()
    sample_inputs = sequential.get_inputs()
    batch_dim = torch.export.Dim("batch", max=100)

    self._test_dqlinear(
        sequential,
        sample_inputs,
        dynamic_shapes=({0: batch_dim},),
        linear_count=2,
        is_per_channel=True,
        uses_bias=True,
        atol=1e-1,  # TODO(T212995726): Investigate right atol for rand[n] inputs
    )
def test_qd8_per_channel_linear_parallel_and_sequential(self):
    # Three per-channel dynamic-quant linears: two parallel, one chained.
    # Each input gets its own dynamic batch dim.
    module = LinearParallelSequentialModule()
    sample_inputs = module.get_inputs()
    first_batch = torch.export.Dim("batch", max=100)
    second_batch = torch.export.Dim("batch2", max=100)

    self._test_dqlinear(
        module,
        sample_inputs,
        dynamic_shapes=({0: first_batch}, {0: second_batch}),
        linear_count=3,
        is_per_channel=True,
        uses_bias=True,
        atol=1e-1,  # TODO(T212995726): Investigate right atol for rand[n] inputs
    )
def test_linear_qs8_as_fp32(self):
    # Per-tensor statically quantized linear, lowered with the FP32 partitioner.
    expected_qnodes = {
        "quantize_per_tensor.default": 2,  # 1: act, 1: output
        "dequantize_per_tensor.default": 3,  # 1: act, 1: weight, 1: output
    }
    for use_bias in (True, False):

        def make_linear(in_size, out_size, bias=use_bias):
            return torch.nn.Linear(in_size, out_size, bias=bias)

        self._test_linear_overwrite_precision(
            make_linear,
            use_bias,
            "per_tensor",
            quant_node_checks=expected_qnodes,
        )
def test_linear_qc8_as_fp32(self):
    # Per-channel statically quantized linear, lowered with the FP32 partitioner.
    expected_qnodes = {
        "quantize_per_tensor.default": 2,  # 1: act, 1: output
        "dequantize_per_tensor.default": 2,  # 1: act, 1: output
        "dequantize_per_channel.default": 1,  # 1: weight
    }
    for use_bias in (True, False):

        def make_linear(in_size, out_size, bias=use_bias):
            return torch.nn.Linear(in_size, out_size, bias=bias)

        self._test_linear_overwrite_precision(
            make_linear,
            use_bias,
            "per_channel",
            quant_node_checks=expected_qnodes,
        )
def test_linear_qd8_as_fp32(self):
    # Per-channel dynamically quantized linear, lowered with the FP32 partitioner.
    expected_qnodes = {
        "quantize_per_tensor.tensor": 1,  # 1: act
        "dequantize_per_tensor.tensor": 1,  # 1: act
        "dequantize_per_channel.default": 1,  # 1: weight
    }
    for use_bias in (True, False):

        def make_linear(in_size, out_size, bias=use_bias):
            return torch.nn.Linear(in_size, out_size, bias=bias)

        self._test_linear_overwrite_precision(
            make_linear,
            use_bias,
            "per_channel_dynamic",
            quant_node_checks=expected_qnodes,
        )
def test_linear_with_force_non_static_weights_for_f32_linear(self):
def check_signature(
signature: ExportGraphSignature,
force_flag: bool,
use_bias: bool,
legacy_mode: bool,
):
num_params = 0
if force_flag:
num_params = 1 # weight_param
if use_bias:
num_params += 1 # bias_param
sign_params: int = 0
input_specs = signature.input_specs