Skip to content

Commit 0cee60d

Browse files
fused_moe configs for MI325X (#300)
* corrected types for strides in triton FA (#274) (#276) Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> (cherry picked from commit 9a46e97) * fused_moe configs for MI325X New fused_moe configs for Mixtral-8x7B and Mixtral-8x22B with TP=1,2,4,8 for both FP8 and FP16 on the recently announced MI325X. --------- Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com>
1 parent bb14866 commit 0cee60d

16 files changed

+2912
-0
lines changed
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
{
2+
"1": {
3+
"BLOCK_SIZE_M": 16,
4+
"BLOCK_SIZE_N": 64,
5+
"BLOCK_SIZE_K": 256,
6+
"GROUP_SIZE_M": 1,
7+
"num_warps": 4,
8+
"num_stages": 0,
9+
"waves_per_eu": 0
10+
},
11+
"2": {
12+
"BLOCK_SIZE_M": 16,
13+
"BLOCK_SIZE_N": 64,
14+
"BLOCK_SIZE_K": 256,
15+
"GROUP_SIZE_M": 1,
16+
"num_warps": 4,
17+
"num_stages": 0,
18+
"waves_per_eu": 0
19+
},
20+
"4": {
21+
"BLOCK_SIZE_M": 16,
22+
"BLOCK_SIZE_N": 32,
23+
"BLOCK_SIZE_K": 128,
24+
"GROUP_SIZE_M": 1,
25+
"num_warps": 2,
26+
"num_stages": 0,
27+
"waves_per_eu": 0
28+
},
29+
"8": {
30+
"BLOCK_SIZE_M": 16,
31+
"BLOCK_SIZE_N": 64,
32+
"BLOCK_SIZE_K": 256,
33+
"GROUP_SIZE_M": 1,
34+
"num_warps": 4,
35+
"num_stages": 0,
36+
"waves_per_eu": 0
37+
},
38+
"16": {
39+
"BLOCK_SIZE_M": 16,
40+
"BLOCK_SIZE_N": 64,
41+
"BLOCK_SIZE_K": 256,
42+
"GROUP_SIZE_M": 1,
43+
"num_warps": 4,
44+
"num_stages": 0,
45+
"waves_per_eu": 0
46+
},
47+
"24": {
48+
"BLOCK_SIZE_M": 16,
49+
"BLOCK_SIZE_N": 64,
50+
"BLOCK_SIZE_K": 256,
51+
"GROUP_SIZE_M": 1,
52+
"num_warps": 4,
53+
"num_stages": 0,
54+
"waves_per_eu": 0
55+
},
56+
"32": {
57+
"BLOCK_SIZE_M": 16,
58+
"BLOCK_SIZE_N": 32,
59+
"BLOCK_SIZE_K": 256,
60+
"GROUP_SIZE_M": 4,
61+
"num_warps": 1,
62+
"num_stages": 0,
63+
"waves_per_eu": 0
64+
},
65+
"48": {
66+
"BLOCK_SIZE_M": 32,
67+
"BLOCK_SIZE_N": 64,
68+
"BLOCK_SIZE_K": 256,
69+
"GROUP_SIZE_M": 1,
70+
"num_warps": 2,
71+
"num_stages": 0,
72+
"waves_per_eu": 0
73+
},
74+
"64": {
75+
"BLOCK_SIZE_M": 32,
76+
"BLOCK_SIZE_N": 64,
77+
"BLOCK_SIZE_K": 256,
78+
"GROUP_SIZE_M": 4,
79+
"num_warps": 2,
80+
"num_stages": 0,
81+
"waves_per_eu": 0
82+
},
83+
"96": {
84+
"BLOCK_SIZE_M": 32,
85+
"BLOCK_SIZE_N": 64,
86+
"BLOCK_SIZE_K": 256,
87+
"GROUP_SIZE_M": 1,
88+
"num_warps": 2,
89+
"num_stages": 0,
90+
"waves_per_eu": 0
91+
},
92+
"128": {
93+
"BLOCK_SIZE_M": 64,
94+
"BLOCK_SIZE_N": 64,
95+
"BLOCK_SIZE_K": 256,
96+
"GROUP_SIZE_M": 4,
97+
"num_warps": 4,
98+
"num_stages": 0,
99+
"waves_per_eu": 0
100+
},
101+
"256": {
102+
"BLOCK_SIZE_M": 128,
103+
"BLOCK_SIZE_N": 128,
104+
"BLOCK_SIZE_K": 256,
105+
"GROUP_SIZE_M": 4,
106+
"num_warps": 8,
107+
"num_stages": 0,
108+
"waves_per_eu": 0
109+
},
110+
"512": {
111+
"BLOCK_SIZE_M": 256,
112+
"BLOCK_SIZE_N": 128,
113+
"BLOCK_SIZE_K": 128,
114+
"GROUP_SIZE_M": 4,
115+
"num_warps": 8,
116+
"num_stages": 0,
117+
"waves_per_eu": 0
118+
},
119+
"1024": {
120+
"BLOCK_SIZE_M": 128,
121+
"BLOCK_SIZE_N": 128,
122+
"BLOCK_SIZE_K": 256,
123+
"GROUP_SIZE_M": 1,
124+
"num_warps": 8,
125+
"num_stages": 0,
126+
"waves_per_eu": 0
127+
},
128+
"1536": {
129+
"BLOCK_SIZE_M": 256,
130+
"BLOCK_SIZE_N": 128,
131+
"BLOCK_SIZE_K": 128,
132+
"GROUP_SIZE_M": 1,
133+
"num_warps": 8,
134+
"num_stages": 0,
135+
"waves_per_eu": 0
136+
},
137+
"2048": {
138+
"BLOCK_SIZE_M": 128,
139+
"BLOCK_SIZE_N": 256,
140+
"BLOCK_SIZE_K": 128,
141+
"GROUP_SIZE_M": 1,
142+
"num_warps": 8,
143+
"num_stages": 0,
144+
"waves_per_eu": 0
145+
},
146+
"3072": {
147+
"BLOCK_SIZE_M": 256,
148+
"BLOCK_SIZE_N": 256,
149+
"BLOCK_SIZE_K": 64,
150+
"GROUP_SIZE_M": 1,
151+
"num_warps": 8,
152+
"num_stages": 0,
153+
"waves_per_eu": 0
154+
},
155+
"4096": {
156+
"BLOCK_SIZE_M": 256,
157+
"BLOCK_SIZE_N": 256,
158+
"BLOCK_SIZE_K": 64,
159+
"GROUP_SIZE_M": 1,
160+
"num_warps": 8,
161+
"num_stages": 0,
162+
"waves_per_eu": 0
163+
}
164+
}
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
{
2+
"1": {
3+
"BLOCK_SIZE_M": 16,
4+
"BLOCK_SIZE_N": 16,
5+
"BLOCK_SIZE_K": 256,
6+
"GROUP_SIZE_M": 1,
7+
"num_warps": 2,
8+
"num_stages": 0,
9+
"waves_per_eu": 0,
10+
"matrix_instr_nonkdim": 16,
11+
"kpack": 1
12+
},
13+
"2": {
14+
"BLOCK_SIZE_M": 16,
15+
"BLOCK_SIZE_N": 16,
16+
"BLOCK_SIZE_K": 256,
17+
"GROUP_SIZE_M": 1,
18+
"num_warps": 4,
19+
"num_stages": 0,
20+
"waves_per_eu": 0,
21+
"matrix_instr_nonkdim": 16,
22+
"kpack": 2
23+
},
24+
"4": {
25+
"BLOCK_SIZE_M": 16,
26+
"BLOCK_SIZE_N": 16,
27+
"BLOCK_SIZE_K": 128,
28+
"GROUP_SIZE_M": 1,
29+
"num_warps": 1,
30+
"num_stages": 0,
31+
"waves_per_eu": 0,
32+
"matrix_instr_nonkdim": 16,
33+
"kpack": 2
34+
},
35+
"8": {
36+
"BLOCK_SIZE_M": 16,
37+
"BLOCK_SIZE_N": 16,
38+
"BLOCK_SIZE_K": 256,
39+
"GROUP_SIZE_M": 1,
40+
"num_warps": 1,
41+
"num_stages": 0,
42+
"waves_per_eu": 0,
43+
"matrix_instr_nonkdim": 16,
44+
"kpack": 2
45+
},
46+
"16": {
47+
"BLOCK_SIZE_M": 16,
48+
"BLOCK_SIZE_N": 64,
49+
"BLOCK_SIZE_K": 64,
50+
"GROUP_SIZE_M": 1,
51+
"num_warps": 2,
52+
"num_stages": 0,
53+
"waves_per_eu": 0,
54+
"matrix_instr_nonkdim": 16,
55+
"kpack": 2
56+
},
57+
"24": {
58+
"BLOCK_SIZE_M": 16,
59+
"BLOCK_SIZE_N": 16,
60+
"BLOCK_SIZE_K": 256,
61+
"GROUP_SIZE_M": 1,
62+
"num_warps": 1,
63+
"num_stages": 0,
64+
"waves_per_eu": 0,
65+
"matrix_instr_nonkdim": 16,
66+
"kpack": 1
67+
},
68+
"32": {
69+
"BLOCK_SIZE_M": 16,
70+
"BLOCK_SIZE_N": 64,
71+
"BLOCK_SIZE_K": 256,
72+
"GROUP_SIZE_M": 4,
73+
"num_warps": 4,
74+
"num_stages": 0,
75+
"waves_per_eu": 0,
76+
"matrix_instr_nonkdim": 16,
77+
"kpack": 1
78+
},
79+
"48": {
80+
"BLOCK_SIZE_M": 16,
81+
"BLOCK_SIZE_N": 64,
82+
"BLOCK_SIZE_K": 128,
83+
"GROUP_SIZE_M": 4,
84+
"num_warps": 1,
85+
"num_stages": 0,
86+
"waves_per_eu": 0,
87+
"matrix_instr_nonkdim": 16,
88+
"kpack": 1
89+
},
90+
"64": {
91+
"BLOCK_SIZE_M": 32,
92+
"BLOCK_SIZE_N": 64,
93+
"BLOCK_SIZE_K": 128,
94+
"GROUP_SIZE_M": 4,
95+
"num_warps": 8,
96+
"num_stages": 0,
97+
"waves_per_eu": 0,
98+
"matrix_instr_nonkdim": 16,
99+
"kpack": 1
100+
},
101+
"96": {
102+
"BLOCK_SIZE_M": 32,
103+
"BLOCK_SIZE_N": 64,
104+
"BLOCK_SIZE_K": 256,
105+
"GROUP_SIZE_M": 4,
106+
"num_warps": 8,
107+
"num_stages": 0,
108+
"waves_per_eu": 0,
109+
"matrix_instr_nonkdim": 16,
110+
"kpack": 1
111+
},
112+
"128": {
113+
"BLOCK_SIZE_M": 64,
114+
"BLOCK_SIZE_N": 64,
115+
"BLOCK_SIZE_K": 128,
116+
"GROUP_SIZE_M": 4,
117+
"num_warps": 4,
118+
"num_stages": 0,
119+
"waves_per_eu": 0,
120+
"matrix_instr_nonkdim": 16,
121+
"kpack": 2
122+
},
123+
"256": {
124+
"BLOCK_SIZE_M": 128,
125+
"BLOCK_SIZE_N": 128,
126+
"BLOCK_SIZE_K": 128,
127+
"GROUP_SIZE_M": 4,
128+
"num_warps": 8,
129+
"num_stages": 0,
130+
"waves_per_eu": 0,
131+
"matrix_instr_nonkdim": 16,
132+
"kpack": 2
133+
},
134+
"512": {
135+
"BLOCK_SIZE_M": 256,
136+
"BLOCK_SIZE_N": 128,
137+
"BLOCK_SIZE_K": 64,
138+
"GROUP_SIZE_M": 4,
139+
"num_warps": 8,
140+
"num_stages": 0,
141+
"waves_per_eu": 0,
142+
"matrix_instr_nonkdim": 16,
143+
"kpack": 2
144+
},
145+
"1024": {
146+
"BLOCK_SIZE_M": 128,
147+
"BLOCK_SIZE_N": 256,
148+
"BLOCK_SIZE_K": 64,
149+
"GROUP_SIZE_M": 1,
150+
"num_warps": 8,
151+
"num_stages": 0,
152+
"waves_per_eu": 0,
153+
"matrix_instr_nonkdim": 16,
154+
"kpack": 2
155+
},
156+
"1536": {
157+
"BLOCK_SIZE_M": 128,
158+
"BLOCK_SIZE_N": 128,
159+
"BLOCK_SIZE_K": 64,
160+
"GROUP_SIZE_M": 1,
161+
"num_warps": 8,
162+
"num_stages": 0,
163+
"waves_per_eu": 0,
164+
"matrix_instr_nonkdim": 16,
165+
"kpack": 2
166+
},
167+
"2048": {
168+
"BLOCK_SIZE_M": 128,
169+
"BLOCK_SIZE_N": 128,
170+
"BLOCK_SIZE_K": 64,
171+
"GROUP_SIZE_M": 1,
172+
"num_warps": 8,
173+
"num_stages": 0,
174+
"waves_per_eu": 0,
175+
"matrix_instr_nonkdim": 16,
176+
"kpack": 2
177+
},
178+
"3072": {
179+
"BLOCK_SIZE_M": 128,
180+
"BLOCK_SIZE_N": 128,
181+
"BLOCK_SIZE_K": 64,
182+
"GROUP_SIZE_M": 1,
183+
"num_warps": 8,
184+
"num_stages": 0,
185+
"waves_per_eu": 0,
186+
"matrix_instr_nonkdim": 16,
187+
"kpack": 2
188+
},
189+
"4096": {
190+
"BLOCK_SIZE_M": 128,
191+
"BLOCK_SIZE_N": 256,
192+
"BLOCK_SIZE_K": 64,
193+
"GROUP_SIZE_M": 1,
194+
"num_warps": 8,
195+
"num_stages": 0,
196+
"waves_per_eu": 0,
197+
"matrix_instr_nonkdim": 16,
198+
"kpack": 1
199+
}
200+
}

0 commit comments

Comments
 (0)