-
Notifications
You must be signed in to change notification settings - Fork 90
Expand file tree
/
Copy pathmla_preprocess_kernel.cpp
More file actions
207 lines (190 loc) · 10.1 KB
/
mla_preprocess_kernel.cpp
File metadata and controls
207 lines (190 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
// Adapted from
// https://gitee.com/ascend/ascend-transformer-boost
//
// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
// This file is a part of the CANN Open Software.
// Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.
//
#include "kernel_operator.h"
#include "mla_preprocess_mix_fp16.hpp"
#include "mla_preprocess_mix_bf16.hpp"
#include "../op_host/tiling/mla_preprocess_tiling.h"
// Kernel entry point for the MLA preprocess operator.
//
// What this function does, in order:
//   1. Issues the per-launch setup calls (PRELOAD / SetAtomicnone / SetMasknorm, plus
//      SetPadding / SetNdpara when __DAV_C220_CUBE__ is defined).
//   2. Copies the tiling description from global memory (`tiling`) into a local
//      MlaTilingData instance, one scalar field at a time.
//   3. Derives five scratch pointers inside `workspace` from the s1..s5 offsets stored in the tiling.
//   4. Dispatches on the tiling key to one of four MLAOperation instantiations
//      (FP16 / BF16, each with two cache modes), calls Init, then ProcessCube under
//      ASCEND_IS_AIC and ProcessVector under ASCEND_IS_AIV.
//
// Parameters:
//   hiddenState ... keycacheOut2  Global-memory addresses forwarded to MLAOperation::Init unchanged and
//                                 in declaration order; this function never dereferences them itself.
//   workspace                     Base address of the scratch area that the s1..s5 offsets index into.
//   tiling                        Global-memory address reinterpreted as an MlaTilingData; it is also
//                                 handed to the MLAOperation constructor as-is.
//
// A tiling key that matches none of the four cases is silently ignored (nothing is computed).
extern "C" __global__ __aicore__ void mla_preprocess(
    GM_ADDR hiddenState, GM_ADDR gamma1, GM_ADDR beta1, GM_ADDR quantScale1, GM_ADDR quantOffset1, GM_ADDR wdqkv,
    GM_ADDR bias1, GM_ADDR gamma2, GM_ADDR beta2, GM_ADDR quantScale2, GM_ADDR quantOffset2, GM_ADDR gamma3,
    GM_ADDR sin1, GM_ADDR cos1, GM_ADDR sin2, GM_ADDR cos2, GM_ADDR keycache, GM_ADDR slotMapping, GM_ADDR wuq,
    GM_ADDR bias2, GM_ADDR wuk, GM_ADDR descale1, GM_ADDR descale2, GM_ADDR ctkvScale, GM_ADDR qnopeScale, GM_ADDR q,
    GM_ADDR keycacheOut, GM_ADDR q2, GM_ADDR keycacheOut2, GM_ADDR workspace, GM_ADDR tiling)
{
    // ---- Per-launch setup ------------------------------------------------------------------------------
    // NOTE(review): presumably these reset instruction preload, atomic mode and mask mode to known
    // defaults before any compute is issued -- confirm against the CANN intrinsics documentation.
    PRELOAD(2);
    SetAtomicnone();
    SetMasknorm();
#ifdef __DAV_C220_CUBE__
    // Only compiled when __DAV_C220_CUBE__ is defined (presumably the cube-core build of this kernel).
    SetPadding<uint64_t>(static_cast<uint64_t>(0));
    SetNdpara(1, 0, 0);
#endif

    // ---- Tiling: global memory -> local copy -----------------------------------------------------------
    // The copy is done member by member rather than as a whole-struct assignment.
    // NOTE(review): presumably because an aggregate copy out of __gm__ memory is not available/desired
    // here -- confirm before collapsing this into a single assignment.
    MlaTilingData localTiling;
    __gm__ MlaTilingData *gmTiling = reinterpret_cast<__gm__ MlaTilingData *>(tiling);

    localTiling.tilingKey = gmTiling->tilingKey;
    localTiling.n = gmTiling->n;

    // mm1 tiling
    localTiling.mm1.numBatch = gmTiling->mm1.numBatch;
    localTiling.mm1.m = gmTiling->mm1.m;
    localTiling.mm1.k = gmTiling->mm1.k;
    localTiling.mm1.n = gmTiling->mm1.n;
    localTiling.mm1.m0 = gmTiling->mm1.m0;
    localTiling.mm1.k0 = gmTiling->mm1.k0;
    localTiling.mm1.n0 = gmTiling->mm1.n0;
    localTiling.mm1.mLoop = gmTiling->mm1.mLoop;
    localTiling.mm1.kLoop = gmTiling->mm1.kLoop;
    localTiling.mm1.nLoop = gmTiling->mm1.nLoop;
    localTiling.mm1.coreLoop = gmTiling->mm1.coreLoop;
    localTiling.mm1.swizzleCount = gmTiling->mm1.swizzleCount;
    localTiling.mm1.enShuffleK = gmTiling->mm1.enShuffleK;
    localTiling.mm1.blockDim = gmTiling->mm1.blockDim;
    localTiling.mm1.enLoadAllAmat = gmTiling->mm1.enLoadAllAmat;
    localTiling.mm1.b0matPingPongBufferLen = gmTiling->mm1.b0matPingPongBufferLen;

    // mm2 tiling
    localTiling.mm2.numBatch = gmTiling->mm2.numBatch;
    localTiling.mm2.m = gmTiling->mm2.m;
    localTiling.mm2.k = gmTiling->mm2.k;
    localTiling.mm2.n = gmTiling->mm2.n;
    localTiling.mm2.m0 = gmTiling->mm2.m0;
    localTiling.mm2.k0 = gmTiling->mm2.k0;
    localTiling.mm2.n0 = gmTiling->mm2.n0;
    localTiling.mm2.mLoop = gmTiling->mm2.mLoop;
    localTiling.mm2.kLoop = gmTiling->mm2.kLoop;
    localTiling.mm2.nLoop = gmTiling->mm2.nLoop;
    localTiling.mm2.coreLoop = gmTiling->mm2.coreLoop;
    localTiling.mm2.swizzleCount = gmTiling->mm2.swizzleCount;
    localTiling.mm2.enShuffleK = gmTiling->mm2.enShuffleK;
    localTiling.mm2.blockDim = gmTiling->mm2.blockDim;
    localTiling.mm2.enLoadAllAmat = gmTiling->mm2.enLoadAllAmat;
    localTiling.mm2.b0matPingPongBufferLen = gmTiling->mm2.b0matPingPongBufferLen;

    // mm3 tiling
    // NOTE(review): unlike mm1/mm2, enLoadAllAmat and b0matPingPongBufferLen are NOT copied for mm3, so
    // in the local copy they keep whatever default-initialization of MlaTilingData leaves there.
    // Confirm that nothing consuming mm3 reads those two members.
    localTiling.mm3.numBatch = gmTiling->mm3.numBatch;
    localTiling.mm3.m = gmTiling->mm3.m;
    localTiling.mm3.k = gmTiling->mm3.k;
    localTiling.mm3.n = gmTiling->mm3.n;
    localTiling.mm3.m0 = gmTiling->mm3.m0;
    localTiling.mm3.k0 = gmTiling->mm3.k0;
    localTiling.mm3.n0 = gmTiling->mm3.n0;
    localTiling.mm3.mLoop = gmTiling->mm3.mLoop;
    localTiling.mm3.kLoop = gmTiling->mm3.kLoop;
    localTiling.mm3.nLoop = gmTiling->mm3.nLoop;
    localTiling.mm3.coreLoop = gmTiling->mm3.coreLoop;
    localTiling.mm3.swizzleCount = gmTiling->mm3.swizzleCount;
    localTiling.mm3.enShuffleK = gmTiling->mm3.enShuffleK;
    localTiling.mm3.blockDim = gmTiling->mm3.blockDim;

    // Task / core split
    localTiling.perTaskNum = gmTiling->perTaskNum;
    localTiling.resTaskNum = gmTiling->resTaskNum;
    localTiling.numCore = gmTiling->numCore;

    // rms* fields
    localTiling.rmsNumCore1 = gmTiling->rmsNumCore1;
    localTiling.rmsNumCol1 = gmTiling->rmsNumCol1;
    localTiling.rmsNumCore2 = gmTiling->rmsNumCore2;
    localTiling.rmsNumCol2 = gmTiling->rmsNumCol2;

    // Head / token geometry and per-core loop counts
    localTiling.hiddenSizeQ = gmTiling->hiddenSizeQ;
    localTiling.headNumQ = gmTiling->headNumQ;
    localTiling.headDim = gmTiling->headDim;
    localTiling.concatSize = gmTiling->concatSize;
    localTiling.rotaryCoeff = gmTiling->rotaryCoeff;
    localTiling.ntokens = gmTiling->ntokens;
    localTiling.realCore = gmTiling->realCore;
    localTiling.nlCoreRun = gmTiling->nlCoreRun;
    localTiling.lCoreRun = gmTiling->lCoreRun;
    localTiling.maxNPerLoopForUb = gmTiling->maxNPerLoopForUb;
    localTiling.preCoreLoopTime = gmTiling->preCoreLoopTime;
    localTiling.preCoreLoopNLast = gmTiling->preCoreLoopNLast;
    localTiling.lastCoreLoopTime = gmTiling->lastCoreLoopTime;
    localTiling.lastCoreLoopNLast = gmTiling->lastCoreLoopNLast;

    // esq* fields
    localTiling.esqFrontCore = gmTiling->esqFrontCore;
    localTiling.esqTailCore = gmTiling->esqTailCore;
    localTiling.esqFrontCoreBatch = gmTiling->esqFrontCoreBatch;
    localTiling.esqTailCoreBatch = gmTiling->esqTailCoreBatch;
    localTiling.esqHeadNum = gmTiling->esqHeadNum;
    localTiling.esqColNum = gmTiling->esqColNum;
    localTiling.esqUbHeadLoop = gmTiling->esqUbHeadLoop;
    localTiling.esqHeadPerLoop = gmTiling->esqHeadPerLoop;
    localTiling.esqHeadTail = gmTiling->esqHeadTail;
    localTiling.esqColLoop = gmTiling->esqColLoop;
    localTiling.esqColTail = gmTiling->esqColTail;

    // Offsets of the scratch regions inside `workspace`
    localTiling.s1Offset = gmTiling->s1Offset;
    localTiling.s2Offset = gmTiling->s2Offset;
    localTiling.s3Offset = gmTiling->s3Offset;
    localTiling.s4Offset = gmTiling->s4Offset;
    localTiling.s5Offset = gmTiling->s5Offset;

    // ---- Scratch regions -------------------------------------------------------------------------------
    // Each region starts at workspace + its offset. The FP16 paths use the first three regions only;
    // the BF16 paths use all five.
    GM_ADDR scratch1 = workspace + static_cast<uint64_t>(localTiling.s1Offset);
    GM_ADDR scratch2 = workspace + static_cast<uint64_t>(localTiling.s2Offset);
    GM_ADDR scratch3 = workspace + static_cast<uint64_t>(localTiling.s3Offset);
    GM_ADDR scratch4 = workspace + static_cast<uint64_t>(localTiling.s4Offset);
    GM_ADDR scratch5 = workspace + static_cast<uint64_t>(localTiling.s5Offset);

    // ---- Dispatch on tiling key ------------------------------------------------------------------------
    // Every case follows the same shape: construct the operation, Init it with all tensor addresses,
    // then run ProcessCube under ASCEND_IS_AIC and ProcessVector under ASCEND_IS_AIV.
    switch (localTiling.tilingKey) {
        case KEY_FP16_CACHEMODE_0_QUANTMODE_0: {
            using Op = MLAPO_FP16::MLAOperation<CACHE_MODE_KVCACHE, DataFormat::NZ, DataFormat::NZ, DataFormat::ND>;
            Op op(localTiling, tiling);
            op.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, quantScale2,
                    quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, bias2, wuk, descale1,
                    descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, scratch1, scratch2, scratch3);
            if ASCEND_IS_AIC {
                op.ProcessCube();
            }
            if ASCEND_IS_AIV {
                op.ProcessVector();
            }
            break;
        }
        case KEY_FP16_CACHEMODE_1_QUANTMODE_0: {
            using Op =
                MLAPO_FP16::MLAOperation<CACHE_MODE_KROPE_CTKV, DataFormat::NZ, DataFormat::NZ, DataFormat::ND>;
            Op op(localTiling, tiling);
            op.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, quantScale2,
                    quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, bias2, wuk, descale1,
                    descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, scratch1, scratch2, scratch3);
            if ASCEND_IS_AIC {
                op.ProcessCube();
            }
            if ASCEND_IS_AIV {
                op.ProcessVector();
            }
            break;
        }
        case KEY_BF16_CACHEMODE_0_QUANTMODE_0: {
            using Op = MLAPO_BF16::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
                                                QuantMode::PER_TENSOR_ASYMM_QUANT>;
            Op op(localTiling, tiling);
            op.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, quantScale2,
                    quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, bias2, wuk, descale1,
                    descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, scratch1, scratch2, scratch3,
                    scratch4, scratch5);
            if ASCEND_IS_AIC {
                op.ProcessCube();
            }
            if ASCEND_IS_AIV {
                op.ProcessVector();
            }
            break;
        }
        case KEY_BF16_CACHEMODE_1_QUANTMODE_0: {
            using Op = MLAPO_BF16::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
                                                QuantMode::PER_TENSOR_ASYMM_QUANT>;
            Op op(localTiling, tiling);
            op.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2, quantScale2,
                    quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq, bias2, wuk, descale1,
                    descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2, scratch1, scratch2, scratch3,
                    scratch4, scratch5);
            if ASCEND_IS_AIC {
                op.ProcessCube();
            }
            if ASCEND_IS_AIV {
                op.ProcessVector();
            }
            break;
        }
        default:
            // Unrecognized tiling key: intentionally a no-op, matching the existing behavior.
            break;
    }
}