Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ namespace turbomind {

using namespace attention;

template void invokeAttention<
typename AttentionConfig<arch::Sm80, nv_bfloat16, 576, CacheType::kLinear>::Kernel>(
template void invokeAttention<typename AttentionConfig<arch::Sm80, nv_bfloat16, 576, CacheType::kLinear>::Kernel>(
const AttentionParams<nv_bfloat16>& params);

} // namespace turbomind
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ namespace turbomind {

using namespace attention;

template bool invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 1, 576>>(
const AttentionParams<nv_bfloat16>& params);
template bool
invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 1, 576>>(const AttentionParams<nv_bfloat16>& params);

template bool invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 2, 576>>(
const AttentionParams<nv_bfloat16>& params);
template bool
invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 2, 576>>(const AttentionParams<nv_bfloat16>& params);

template bool invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 8, 576>>(
const AttentionParams<nv_bfloat16>& params);
template bool
invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 8, 576>>(const AttentionParams<nv_bfloat16>& params);

template bool invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 16, 576>>(
const AttentionParams<nv_bfloat16>& params);
template bool
invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 16, 576>>(const AttentionParams<nv_bfloat16>& params);

} // namespace turbomind
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ namespace turbomind {

using namespace attention;

template bool invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, uint4_t, 8, 576>>(
const AttentionParams<nv_bfloat16>&);
template bool invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, uint4_t, 8, 576>>(const AttentionParams<nv_bfloat16>&);

template bool invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, uint4_t, 16, 576>>(
const AttentionParams<nv_bfloat16>&);
template bool invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, uint4_t, 16, 576>>(const AttentionParams<nv_bfloat16>&);

} // namespace turbomind
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ namespace turbomind {

using namespace attention;

template bool invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, uint8_t, 8, 576>>(
const AttentionParams<nv_bfloat16>&);
template bool invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, uint8_t, 8, 576>>(const AttentionParams<nv_bfloat16>&);

template bool invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, uint8_t, 16, 576>>(
const AttentionParams<nv_bfloat16>&);
template bool invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, uint8_t, 16, 576>>(const AttentionParams<nv_bfloat16>&);

} // namespace turbomind
95 changes: 43 additions & 52 deletions src/turbomind/kernels/gemm/moe_utils_v2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -691,37 +691,37 @@ void invokeMoeGate_V2(int* f2n, // [e*n] -> n
// noaux_tc: scores = scoring_func(logits), scores_for_choice = scores + correction_bias,
// top-k on scores_for_choice, weights from scores; renormalize; apply routed_scale.
// Threading: one token per block, threads cooperate over expert dimension.
__global__ void MoeGateNoAuxTCKernel(float* scales, // [top_k, tokens]
int8_t* masks, // [experts, tokens_padded]
int* accum, // [experts, tiles]
const float* logits, // [tokens, experts]
const float* bias, // [experts] or nullptr
int tokens,
int tokens_padded,
int experts,
int top_k,
bool norm_topk,
float routed_scale,
int log_tile,
int tiles,
bool use_sigmoid)
__global__ void MoeGateNoAuxTCKernel(float* scales, // [top_k, tokens]
int8_t* masks, // [experts, tokens_padded]
int* accum, // [experts, tiles]
const float* logits, // [tokens, experts]
const float* bias, // [experts] or nullptr
int tokens,
int tokens_padded,
int experts,
int top_k,
bool norm_topk,
float routed_scale,
int log_tile,
int tiles,
bool use_sigmoid)
{
const int ti = blockIdx.x; // one token per block
if (ti >= tokens) {
return;
}

extern __shared__ char smem[];
float* scores = (float*)smem;
float* scores_for_choice = scores + experts;
float* scores = (float*)smem;
float* scores_for_choice = scores + experts;

const float* row = logits + ti * experts;

if (use_sigmoid) {
// Sigmoid scoring: scores[e] = 1 / (1 + exp(-logit[e]))
for (int e = threadIdx.x; e < experts; e += blockDim.x) {
float s = 1.0f / (1.0f + expf(-row[e]));
scores[e] = s;
float s = 1.0f / (1.0f + expf(-row[e]));
scores[e] = s;
scores_for_choice[e] = s + (bias ? bias[e] : 0.f);
}
}
Expand All @@ -739,16 +739,16 @@ __global__ void MoeGateNoAuxTCKernel(float* scales, // [top_k, tokens]

float sum_exp = 0.f;
for (int e = threadIdx.x; e < experts; e += blockDim.x) {
float s = expf(row[e] - max_logit);
float s = expf(row[e] - max_logit);
scores[e] = s;
sum_exp += s;
}
sum_exp = blockReduceSum<float>(sum_exp);
__syncthreads();

for (int e = threadIdx.x; e < experts; e += blockDim.x) {
float s = scores[e] / (sum_exp + 1e-20f);
scores[e] = s;
float s = scores[e] / (sum_exp + 1e-20f);
scores[e] = s;
scores_for_choice[e] = s + (bias ? bias[e] : 0.f);
}
}
Expand Down Expand Up @@ -784,7 +784,7 @@ __global__ void MoeGateNoAuxTCKernel(float* scales, // [top_k, tokens]
}
}
if (best_e < 0) {
best_e = 0;
best_e = 0;
topk_val[k] = 0.f;
}
else {
Expand Down Expand Up @@ -819,29 +819,29 @@ __global__ void MoeGateNoAuxTCKernel(float* scales, // [top_k, tokens]
}

void invokeMoeGate_NoAuxTC(int* f2n,
int* f2E,
int* en2f,
int* offsets,
float* scales,
void* masks,
int* accum,
const float* logits,
const float* correction_bias,
int tokens,
int tokens_padded,
int experts,
int exp_per_tok,
bool norm_topk_prob,
float routed_scale,
bool use_sigmoid,
cudaStream_t st)
int* f2E,
int* en2f,
int* offsets,
float* scales,
void* masks,
int* accum,
const float* logits,
const float* correction_bias,
int tokens,
int tokens_padded,
int experts,
int exp_per_tok,
bool norm_topk_prob,
float routed_scale,
bool use_sigmoid,
cudaStream_t st)
{
TM_CHECK(exp_per_tok > 0);
TM_CHECK_LE(exp_per_tok, 32);
TM_CHECK_LE(exp_per_tok, experts);

constexpr int base_log_tile = 9;
int log_tile = base_log_tile;
int log_tile = base_log_tile;
while (((tokens_padded + (1 << log_tile) - 1) >> log_tile) > kMoeGateMaxTiles) {
++log_tile;
}
Expand All @@ -855,8 +855,8 @@ void invokeMoeGate_NoAuxTC(int* f2n,
while (block_dim < experts && block_dim < 256) {
block_dim *= 2; // next power of 2
}
const int blocks = tokens;
const size_t smem = sizeof(float) * experts * 2;
const int blocks = tokens;
const size_t smem = sizeof(float) * experts * 2;

MoeGateNoAuxTCKernel<<<blocks, block_dim, smem, st>>>(scales,
(int8_t*)masks,
Expand All @@ -875,17 +875,8 @@ void invokeMoeGate_NoAuxTC(int* f2n,

constexpr int scan_threads = (1 << base_log_tile) / kMoeGateVecSize;
const dim3 scan_blocks(tiles, experts + 1);
MoeScanKernel_v2<scan_threads><<<scan_blocks, scan_threads, 0, st>>>(f2n,
f2E,
en2f,
offsets,
(int8_t*)masks,
accum,
log_tile,
tiles,
tokens,
tokens_padded,
experts);
MoeScanKernel_v2<scan_threads><<<scan_blocks, scan_threads, 0, st>>>(
f2n, f2E, en2f, offsets, (int8_t*)masks, accum, log_tile, tiles, tokens, tokens_padded, experts);
}

template<int vec_size, int block_dim, class T>
Expand Down
5 changes: 2 additions & 3 deletions src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,8 @@ LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(
// ffn_weight_type for their shared experts (int4 for mixed AWQ,
// bfloat16 for GptOss mxfp4, same as weight_type otherwise).
if (inter_size_) {
const bool is_moe_layer = layer_id < (int)moe_param.expert_num.size()
&& moe_param.expert_num[layer_id];
const DataType ffn_wtype = is_moe_layer ? model.ffn_weight_type : weight_type_;
const bool is_moe_layer = layer_id < (int)moe_param.expert_num.size() && moe_param.expert_num[layer_id];
const DataType ffn_wtype = is_moe_layer ? model.ffn_weight_type : weight_type_;
const bool is_cublas_gemm = byte_size(ffn_wtype, 8) == 16;
ffn_weights.reset(new LlamaFfnWeight{
hidden_units_,
Expand Down
6 changes: 3 additions & 3 deletions src/turbomind/models/llama/llama_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ struct ModelParam {
// Full AWQ int4 int4 int4
// Mixed AWQ float16 int4 int4
// GptOss mxfp4 bfloat16 bfloat16 e2m1
DataType weight_type; // attention weights
DataType expert_weight_type; // MoE routed expert weights
DataType ffn_weight_type; // dense FFN / shared expert weights
DataType weight_type; // attention weights
DataType expert_weight_type; // MoE routed expert weights
DataType ffn_weight_type; // dense FFN / shared expert weights

int group_size;
MLAParam mla;
Expand Down
30 changes: 15 additions & 15 deletions src/turbomind/models/llama/moe_ffn_layer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,21 +130,21 @@ void MoeFfnLayer::Forward(ForwardParam& p)

/// TODO: fix illegal memory access even if NaN are present in logits
invokeMoeGate_V2(f2n_.data(),
f2E_.data(),
en2f_.data(),
offsets_.data(),
scales_.data(),
masks_.data(),
accum_.data(),
logits.data(),
tokens,
padded,
expert_num,
param_.experts_per_token,
softmax,
param_.norm_topk_prob,
param_.routed_scale,
st);
f2E_.data(),
en2f_.data(),
offsets_.data(),
scales_.data(),
masks_.data(),
accum_.data(),
logits.data(),
tokens,
padded,
expert_num,
param_.experts_per_token,
softmax,
param_.norm_topk_prob,
param_.routed_scale,
st);
}
sync_check_cuda_error();

Expand Down
4 changes: 2 additions & 2 deletions src/turbomind/turbomind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -466,8 +466,8 @@ TurboMind::Impl::Impl(string model_dir, string config, FFICtxFactory ffi_ctx_fac

model_param_.weight_type = data_type_from_string(model["weight_type"].as<std::string>());
model_param_.expert_weight_type = data_type_from_string(model["expert_weight_type"].as<std::string>());
model_param_.ffn_weight_type = data_type_from_string(
model["ffn_weight_type"].as<std::string>(model["weight_type"].as<std::string>()));
model_param_.ffn_weight_type =
data_type_from_string(model["ffn_weight_type"].as<std::string>(model["weight_type"].as<std::string>()));

if (auto method = get_moe_method()) {
moe_param_.method = *method;
Expand Down