Skip to content

Commit e8f1d28

Browse files
committed
Add coord check data collection to matmul fwd
1 parent b1014b4 commit e8f1d28

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

llmc/matmul.cuh

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ __global__ void reduce_add_sum_kernel(floatX* dst, const float* src, size_t n, s
109109
void matmul_cublaslt(floatX* d, const floatX* a, const floatX* b, const floatX* bias,
110110
int m, int n, int k, cudaStream_t stream=0, bool transA=true, bool transB=false,
111111
int batch_count=0, size_t strideA=0, size_t strideB=0, size_t strideOut=0,
112-
bool accumulate=false, floatX* pre_gelu=NULL, bool backward=false)
112+
bool accumulate=false, floatX* pre_gelu=NULL, bool backward=false, float* coord_check_data=NULL, int cc_cnt=0)
113113
{
114114
NVTX_RANGE_FN();
115115
bool has_bias = (bias != NULL);
@@ -225,19 +225,32 @@ void matmul_cublaslt(floatX* d, const floatX* a, const floatX* b, const floatX*
225225
cublasCheck(cublasLtMatrixLayoutDestroy(CLayout));
226226
cublasCheck(cublasLtMatrixLayoutDestroy(DLayout));
227227
cudaCheck(cudaGetLastError());
228+
229+
// data collection
230+
if (coord_check_data != NULL) {
231+
float sum = 0.0;
232+
float* sum_d;
233+
cudaMalloc(&sum_d, sizeof(float));
234+
cudaCheck(cudaMemsetAsync(sum_d, 0, sizeof(float), stream));
235+
abs_sum_kernel<<<n, WARP_SIZE, 0, stream>>>(d, n, m, sum_d);
236+
cudaCheck(cudaGetLastError());
237+
cudaCheck(cudaMemcpy(&sum, sum_d, sizeof(float), cudaMemcpyDeviceToHost));
238+
cudaCheck(cudaFree(sum_d));
239+
coord_check_data[cc_cnt] = sum / (n*m);
240+
}
228241
}
229242

230243
// small wrapper around matmul_cublaslt for the forward pass (keeping historical order of arguments)
//
// coord_check_data / cc_cnt are optional coord-check collection outputs; both default to NULL so
// pre-existing call sites (e.g. the ENABLE_CUDNN path) keep working unchanged. cc_cnt is advanced
// once per kernel that records a statistic: the matmul itself, plus gelu_forward when GELU is
// run unfused (gelu_fusion < 1 with a pre_gelu buffer).
void matmul_forward_cublaslt(floatX* out,
                     floatX* inp, floatX* weight, floatX* bias,
                     int B, int T, int C, int OC, cudaStream_t stream, float* coord_check_data=NULL, int* cc_cnt=NULL,
                     floatX* pre_gelu=NULL, int gelu_fusion=1) {
    // BUGFIX: cc_cnt defaults to NULL yet (*cc_cnt)++ below is unconditional, so any caller
    // relying on the defaults (e.g. the cuDNN attention path) dereferences a null pointer.
    // Route such callers through a local dummy counter; also force coord_check_data to NULL
    // so downstream code cannot record into a buffer with a meaningless slot index.
    int dummy_cnt = 0;
    if (cc_cnt == NULL) {
        cc_cnt = &dummy_cnt;
        coord_check_data = NULL;
    }
    // By default only fuse GELU for H100+ as cuBLAS seems to be inefficient for fused GELU on Ada/Ampere (?)
    if (gelu_fusion < 1 && pre_gelu) {
        // Unfused path: matmul into pre_gelu, then a separate GELU kernel into out.
        // Each of the two kernels consumes (and advances) one coord-check slot.
        matmul_cublaslt(pre_gelu, weight, inp, bias, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, NULL, false, coord_check_data, (*cc_cnt)++);
        gelu_forward(out, pre_gelu, B*T*OC, coord_check_data, (*cc_cnt)++, stream);
    } else {
        // Fused (or no-GELU) path: a single matmul, a single coord-check slot.
        matmul_cublaslt(out, weight, inp, bias, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, pre_gelu, false, coord_check_data, (*cc_cnt)++);
    }
}
243256

train_gpt2.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -709,20 +709,20 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T, int step,
709709
// now do the forward pass
710710
#ifdef ENABLE_CUDNN
711711
float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor
712-
matmul_forward_cublaslt(l_qkvr, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, NULL, 0., main_stream);
712+
matmul_forward_cublaslt(l_qkvr, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream);
713713
attention_forward_cudnn(l_atty, (float*)l_att, l_qkvr, B, T, NH, C, model->use_mup, model->mup_base_attn_mult, main_stream);
714714
#else
715715
floatX* l_att = acts.att + l * B * NH * T * T;
716716
// these are only needed as scratchpads for the forward pass, but
717717
// need not be stored for backward
718-
matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream);
718+
matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream, coord_check_data, &cc_cnt);
719719
attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH, model->use_mup, model->mup_base_attn_mult, coord_check_data, cc_cnt++, main_stream);
720720
#endif
721721

722-
matmul_forward_cublaslt(scratch, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream);
722+
matmul_forward_cublaslt(scratch, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream, coord_check_data, &cc_cnt);
723723
fused_residual_forward5(l_residual2, l_ln2, l_ln2_mean, l_ln2_rstd, residual, scratch, l_ln2w, l_ln2b, B*T, C, 0., 0., coord_check_data, cc_cnt++, main_stream);
724-
matmul_forward_cublaslt(l_fch_gelu, l_ln2, l_fcw, l_fcb, B, T, C, 4*C, main_stream, l_fch, model->gelu_fusion);
725-
matmul_forward_cublaslt(scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C, main_stream);
724+
matmul_forward_cublaslt(l_fch_gelu, l_ln2, l_fcw, l_fcb, B, T, C, 4*C, main_stream, coord_check_data, &cc_cnt, l_fch, model->gelu_fusion);
725+
matmul_forward_cublaslt(scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C, main_stream, coord_check_data, &cc_cnt);
726726
// OK, fusion across blocks.
727727
if(l+1 != L) {
728728
floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + (l + 1) * B * T * C : acts.lnf;
@@ -739,7 +739,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T, int step,
739739
}
740740
}
741741

742-
matmul_forward_cublaslt(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream);
742+
matmul_forward_cublaslt(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream, coord_check_data, &cc_cnt);
743743
cudaCheck(cudaDeviceSynchronize());
744744
}
745745

0 commit comments

Comments
 (0)