@@ -709,20 +709,20 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T, int step,
709709 // now do the forward pass
710710 #ifdef ENABLE_CUDNN
711711 float * l_att = (float *)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor
712- matmul_forward_cublaslt (l_qkvr, l_ln1, l_qkvw, l_qkvb, B, T, C, 3 *C, NULL , 0 ., main_stream);
712+ matmul_forward_cublaslt (l_qkvr, l_ln1, l_qkvw, l_qkvb, B, T, C, 3 *C, main_stream);
713713 attention_forward_cudnn (l_atty, (float *)l_att, l_qkvr, B, T, NH, C, model->use_mup , model->mup_base_attn_mult , main_stream);
714714 #else
715715 floatX* l_att = acts.att + l * B * NH * T * T;
716716 // these are only needed as scratchpads for the forward pass, but
717717 // need not be stored for backward
718- matmul_forward_cublaslt (scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3 *C, main_stream);
718+ matmul_forward_cublaslt (scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3 *C, main_stream, coord_check_data, &cc_cnt );
719719 attention_forward (l_atty, l_qkvr, l_att, scratch, B, T, C, NH, model->use_mup , model->mup_base_attn_mult , coord_check_data, cc_cnt++, main_stream);
720720 #endif
721721
722- matmul_forward_cublaslt (scratch, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream);
722+ matmul_forward_cublaslt (scratch, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream, coord_check_data, &cc_cnt );
723723 fused_residual_forward5 (l_residual2, l_ln2, l_ln2_mean, l_ln2_rstd, residual, scratch, l_ln2w, l_ln2b, B*T, C, 0 ., 0 ., coord_check_data, cc_cnt++, main_stream);
724- matmul_forward_cublaslt (l_fch_gelu, l_ln2, l_fcw, l_fcb, B, T, C, 4 *C, main_stream, l_fch, model->gelu_fusion );
725- matmul_forward_cublaslt (scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4 *C, C, main_stream);
724+ matmul_forward_cublaslt (l_fch_gelu, l_ln2, l_fcw, l_fcb, B, T, C, 4 *C, main_stream, coord_check_data, &cc_cnt, l_fch, model->gelu_fusion );
725+ matmul_forward_cublaslt (scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4 *C, C, main_stream, coord_check_data, &cc_cnt );
726726 // OK, fusion across blocks.
727727 if (l+1 != L) {
728728 floatX* l_ln1 = (model->recompute < 2 ) ? acts.ln1 + (l + 1 ) * B * T * C : acts.lnf ;
@@ -739,7 +739,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T, int step,
739739 }
740740 }
741741
742- matmul_forward_cublaslt (acts.output , acts.lnf , params.wte , NULL , B, T, C, Vp, main_stream);
742+ matmul_forward_cublaslt (acts.output , acts.lnf , params.wte , NULL , B, T, C, Vp, main_stream, coord_check_data, &cc_cnt );
743743 cudaCheck (cudaDeviceSynchronize ());
744744}
745745
0 commit comments