@@ -145,6 +145,135 @@ def fused_head_and_loss_fn(
 class FusedHeadAndCrossEntropy(PyLayer):
     """Fuse LM Head and CrossEntropyLoss into one module."""

+    @staticmethod
+    def _fused_head_and_loss_fn_npu(
+        hidden_states,
+        labels,
+        loss_mask,
+        lm_head_bias,
+        lm_head_weight_cast,
+        lm_head_bias_cast,
+        grad_lm_head_weight,
+        grad_lm_head_bias,
+        indices,
+        divisor,
+        n_tokens,
+        loop_chunk_size,
+        tensor_parallel_degree,
+        tensor_parallel_output,
+        model_parallel_group,
+        return_token_loss,
+        transpose_y,
+        dtype,
+    ):
+        token_loss_list = []
+        grad_hidden_states_list = []
+        grad_hidden_states = None  # only populated below when hidden_states.stop_gradient is True
+
+        token_idx_section = [loop_chunk_size for _ in range(0, n_tokens, loop_chunk_size)]
+        token_idx_section[-1] = -1
+        hidden_states_chunk = hidden_states.split(token_idx_section, axis=0)
+        labels_chunk = labels.split(token_idx_section, axis=0)
+        loss_mask_chunk = loss_mask.split(token_idx_section, axis=0)
+
+        for i in range(len(token_idx_section)):
+            # logits calculations
+            logits_chunk_cast = paddle.matmul(
+                hidden_states_chunk[i],
+                lm_head_weight_cast,
+                transpose_y=transpose_y,
+            )
+            if lm_head_bias is not None:
+                logits_chunk_cast += lm_head_bias_cast
+            if tensor_parallel_degree > 1 and not tensor_parallel_output:
+                logits_chunk_cast_lst = []
+                dist.all_gather(
+                    logits_chunk_cast_lst,
+                    logits_chunk_cast,
+                    group=model_parallel_group,
+                )
+                logits_chunk_cast = paddle.concat(logits_chunk_cast_lst, axis=-1)
+            logits_chunk = logits_chunk_cast.astype("float32")
+
+            # log softmax
+            max_logits = paddle.max(logits_chunk, axis=-1, keepdim=True)
+            if tensor_parallel_degree > 1 and tensor_parallel_output:
+                dist.all_reduce(max_logits, op=dist.ReduceOp.MAX, group=model_parallel_group)
+            normalized_logits = logits_chunk - max_logits
+            exp_logits = paddle.exp(normalized_logits)
+            sum_exp_logits = paddle.sum(exp_logits, axis=-1, keepdim=True)
+            if tensor_parallel_degree > 1 and tensor_parallel_output:
+                dist.all_reduce(
+                    sum_exp_logits,
+                    op=dist.ReduceOp.SUM,
+                    group=model_parallel_group,
+                )
+            log_sum_exp_logits = paddle.log(sum_exp_logits)
+
+            # cross entropy
+            labels_one_hot = labels_chunk[i].unsqueeze(1) == indices
+            label_logits = paddle.sum(
+                paddle.where(
+                    labels_one_hot,
+                    normalized_logits,
+                    paddle.zeros_like(normalized_logits),
+                ),
+                axis=-1,
+                keepdim=True,
+            )
+            if tensor_parallel_degree > 1 and tensor_parallel_output:
+                dist.all_reduce(
+                    label_logits,
+                    op=dist.ReduceOp.SUM,
+                    group=model_parallel_group,
+                )
+            token_loss_chunk = (log_sum_exp_logits - label_logits).squeeze(1) / divisor
+            cond = loss_mask_chunk[i].astype("bool")
+            token_loss_chunk = paddle.where(cond, token_loss_chunk, paddle.zeros_like(token_loss_chunk))
+            token_loss_list.append(token_loss_chunk * loss_mask_chunk[i])
+
+            # gradients calculations
+            if not return_token_loss:
+                if tensor_parallel_degree > 1 and not tensor_parallel_output:
+                    exp_logits = exp_logits.split(model_parallel_group.nranks, axis=-1)[model_parallel_group.rank]
+                    labels_one_hot = labels_one_hot.split(model_parallel_group.nranks, axis=-1)[
+                        model_parallel_group.rank
+                    ]
+                grad_logits_chunk = (exp_logits / sum_exp_logits - labels_one_hot.astype("float32")) / divisor
+                grad_logits_chunk = grad_logits_chunk.astype(dtype)
+                grad_logits_chunk = paddle.where(
+                    cond.unsqueeze(1),
+                    grad_logits_chunk,
+                    paddle.zeros_like(grad_logits_chunk),
+                )
+                if hidden_states.stop_gradient:
+                    grad_hidden_states_list.append(
+                        paddle.matmul(
+                            grad_logits_chunk,
+                            lm_head_weight_cast,
+                            transpose_y=not transpose_y,
+                        )
+                    )
+                if grad_lm_head_weight is not None:
+                    if transpose_y:
+                        grad_lm_head_weight += paddle.matmul(
+                            grad_logits_chunk,
+                            hidden_states_chunk[i],
+                            transpose_x=True,
+                        )
+                    else:
+                        grad_lm_head_weight += paddle.matmul(
+                            hidden_states_chunk[i],
+                            grad_logits_chunk,
+                            transpose_x=True,
+                        )
+                if grad_lm_head_bias is not None:
+                    grad_lm_head_bias += grad_logits_chunk.astype("float32").sum(axis=0).astype(dtype)
+
+        token_loss = paddle.concat(token_loss_list, axis=0)
+        if hidden_states.stop_gradient:
+            grad_hidden_states = paddle.concat(grad_hidden_states_list, axis=0)
+        return token_loss, grad_lm_head_weight, grad_lm_head_bias, grad_hidden_states
+
     @staticmethod
     def forward(
         ctx,
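The NPU helper above pre-splits `hidden_states`, `labels`, and `loss_mask` along the token dimension instead of slicing inside the loop; the trailing `-1` in `token_idx_section` lets the last chunk absorb whatever tokens remain when `n_tokens` is not a multiple of `loop_chunk_size`. A minimal sketch of that split behavior, with made-up sizes and no tensor parallelism:

```python
import paddle

# Illustration only: n_tokens = 10 tokens with loop_chunk_size = 4
# yields the section list [4, 4, -1]; paddle.split treats the trailing -1
# as "whatever remains", i.e. a final chunk of 2 tokens.
n_tokens, loop_chunk_size, hidden_size = 10, 4, 8
hidden_states = paddle.randn([n_tokens, hidden_size])

token_idx_section = [loop_chunk_size for _ in range(0, n_tokens, loop_chunk_size)]
token_idx_section[-1] = -1

chunks = paddle.split(hidden_states, token_idx_section, axis=0)
print([tuple(chunk.shape) for chunk in chunks])  # [(4, 8), (4, 8), (2, 8)]
```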
@@ -239,111 +368,141 @@ def forward(
         else:
             grad_lm_head_bias = None
         if hidden_states.stop_gradient:
-            grad_hidden_states = paddle.zeros_like(hidden_states)
+            if get_env_device() != "npu":
+                grad_hidden_states = paddle.zeros_like(hidden_states)
+            else:
+                grad_hidden_states = None
         else:
             grad_hidden_states = None

-        # initialize outputs
-        token_loss = paddle.empty((n_tokens,), dtype=hidden_states.dtype)
-
-        # blockwise calculations
-        for i in range(0, n_tokens, loop_chunk_size):
-            token_start_idx = i
-            token_end_idx = min(i + loop_chunk_size, n_tokens)
-            hidden_states_chunk = hidden_states[token_start_idx:token_end_idx]
-            labels_chunk = labels[token_start_idx:token_end_idx]
-
-            # logits calculations
-            logits_chunk_cast = paddle.matmul(
-                hidden_states_chunk,
+        if get_env_device() == "npu":
+            (
+                token_loss,
+                grad_lm_head_weight,
+                grad_lm_head_bias,
+                grad_hidden_states,
+            ) = FusedHeadAndCrossEntropy._fused_head_and_loss_fn_npu(
+                hidden_states,
+                labels,
+                loss_mask,
+                lm_head_bias,
                 lm_head_weight_cast,
-                transpose_y=transpose_y,
-            )
-            if lm_head_bias is not None:
-                logits_chunk_cast += lm_head_bias_cast
-            if tensor_parallel_degree > 1 and not tensor_parallel_output:
-                logits_chunk_cast_lst = []
-                dist.all_gather(
-                    logits_chunk_cast_lst,
-                    logits_chunk_cast,
-                    group=model_parallel_group,
-                )
-                logits_chunk_cast = paddle.concat(logits_chunk_cast_lst, axis=-1)
-            logits_chunk = logits_chunk_cast.astype("float32")
-
-            # log softmax
-            max_logits = paddle.max(logits_chunk, axis=-1, keepdim=True)
-            if tensor_parallel_degree > 1 and tensor_parallel_output:
-                dist.all_reduce(max_logits, op=dist.ReduceOp.MAX, group=model_parallel_group)
-            normalized_logits = logits_chunk - max_logits
-            exp_logits = paddle.exp(normalized_logits)
-            sum_exp_logits = paddle.sum(exp_logits, axis=-1, keepdim=True)
-            if tensor_parallel_degree > 1 and tensor_parallel_output:
-                dist.all_reduce(
-                    sum_exp_logits,
-                    op=dist.ReduceOp.SUM,
-                    group=model_parallel_group,
-                )
-            log_sum_exp_logits = paddle.log(sum_exp_logits)
-
-            # cross entropy
-            labels_one_hot = labels_chunk.unsqueeze(1) == indices
-            label_logits = paddle.sum(
-                paddle.where(
-                    labels_one_hot,
-                    normalized_logits,
-                    paddle.zeros_like(normalized_logits),
-                ),
-                axis=-1,
-                keepdim=True,
+                lm_head_bias_cast,
+                grad_lm_head_weight,
+                grad_lm_head_bias,
+                indices,
+                divisor,
+                n_tokens,
+                loop_chunk_size,
+                tensor_parallel_degree,
+                tensor_parallel_output,
+                model_parallel_group,
+                return_token_loss,
+                transpose_y,
+                dtype,
             )
-            if tensor_parallel_degree > 1 and tensor_parallel_output:
-                dist.all_reduce(
-                    label_logits,
-                    op=dist.ReduceOp.SUM,
-                    group=model_parallel_group,
+        else:
+            # initialize outputs
+            token_loss = paddle.empty((n_tokens,), dtype=hidden_states.dtype)
+
+            # blockwise calculations
+            for i in range(0, n_tokens, loop_chunk_size):
+                token_start_idx = i
+                token_end_idx = min(i + loop_chunk_size, n_tokens)
+                hidden_states_chunk = hidden_states[token_start_idx:token_end_idx]
+                labels_chunk = labels[token_start_idx:token_end_idx]
+
+                # logits calculations
+                logits_chunk_cast = paddle.matmul(
+                    hidden_states_chunk,
+                    lm_head_weight_cast,
+                    transpose_y=transpose_y,
                 )
-            token_loss_chunk = (log_sum_exp_logits - label_logits).squeeze(1) / divisor
-            cond = loss_mask[token_start_idx:token_end_idx].astype("bool")
-            token_loss_chunk = paddle.where(cond, token_loss_chunk, paddle.zeros_like(token_loss_chunk))
-            token_loss[token_start_idx:token_end_idx] = token_loss_chunk * loss_mask[token_start_idx:token_end_idx]
-
-            # gradients calculations
-            if not return_token_loss:
+                if lm_head_bias is not None:
+                    logits_chunk_cast += lm_head_bias_cast
                 if tensor_parallel_degree > 1 and not tensor_parallel_output:
-                    exp_logits = exp_logits.split(model_parallel_group.nranks, axis=-1)[model_parallel_group.rank]
-                    labels_one_hot = labels_one_hot.split(model_parallel_group.nranks, axis=-1)[
-                        model_parallel_group.rank
-                    ]
-                grad_logits_chunk = (exp_logits / sum_exp_logits - labels_one_hot.astype("float32")) / divisor
-                grad_logits_chunk = grad_logits_chunk.astype(dtype)
-                grad_logits_chunk = paddle.where(
-                    cond.unsqueeze(1),
-                    grad_logits_chunk,
-                    paddle.zeros_like(grad_logits_chunk),
+                    logits_chunk_cast_lst = []
+                    dist.all_gather(
+                        logits_chunk_cast_lst,
+                        logits_chunk_cast,
+                        group=model_parallel_group,
+                    )
+                    logits_chunk_cast = paddle.concat(logits_chunk_cast_lst, axis=-1)
+                logits_chunk = logits_chunk_cast.astype("float32")
+
+                # log softmax
+                max_logits = paddle.max(logits_chunk, axis=-1, keepdim=True)
+                if tensor_parallel_degree > 1 and tensor_parallel_output:
+                    dist.all_reduce(max_logits, op=dist.ReduceOp.MAX, group=model_parallel_group)
+                normalized_logits = logits_chunk - max_logits
+                exp_logits = paddle.exp(normalized_logits)
+                sum_exp_logits = paddle.sum(exp_logits, axis=-1, keepdim=True)
+                if tensor_parallel_degree > 1 and tensor_parallel_output:
+                    dist.all_reduce(
+                        sum_exp_logits,
+                        op=dist.ReduceOp.SUM,
+                        group=model_parallel_group,
+                    )
+                log_sum_exp_logits = paddle.log(sum_exp_logits)
+
+                # cross entropy
+                labels_one_hot = labels_chunk.unsqueeze(1) == indices
+                label_logits = paddle.sum(
+                    paddle.where(
+                        labels_one_hot,
+                        normalized_logits,
+                        paddle.zeros_like(normalized_logits),
+                    ),
+                    axis=-1,
+                    keepdim=True,
                 )
-
-            if grad_hidden_states is not None:
-                grad_hidden_states[token_start_idx:token_end_idx] = paddle.matmul(
+                if tensor_parallel_degree > 1 and tensor_parallel_output:
+                    dist.all_reduce(
+                        label_logits,
+                        op=dist.ReduceOp.SUM,
+                        group=model_parallel_group,
+                    )
+                token_loss_chunk = (log_sum_exp_logits - label_logits).squeeze(1) / divisor
+                cond = loss_mask[token_start_idx:token_end_idx].astype("bool")
+                token_loss_chunk = paddle.where(cond, token_loss_chunk, paddle.zeros_like(token_loss_chunk))
+                token_loss[token_start_idx:token_end_idx] = token_loss_chunk * loss_mask[token_start_idx:token_end_idx]

+                # gradients calculations
+                if not return_token_loss:
+                    if tensor_parallel_degree > 1 and not tensor_parallel_output:
+                        exp_logits = exp_logits.split(model_parallel_group.nranks, axis=-1)[model_parallel_group.rank]
+                        labels_one_hot = labels_one_hot.split(model_parallel_group.nranks, axis=-1)[
+                            model_parallel_group.rank
+                        ]
+                    grad_logits_chunk = (exp_logits / sum_exp_logits - labels_one_hot.astype("float32")) / divisor
+                    grad_logits_chunk = grad_logits_chunk.astype(dtype)
+                    grad_logits_chunk = paddle.where(
+                        cond.unsqueeze(1),
                         grad_logits_chunk,
-                    lm_head_weight_cast,
-                    transpose_y=not transpose_y,
+                        paddle.zeros_like(grad_logits_chunk),
                     )
-                if grad_lm_head_weight is not None:
-                    if transpose_y:
-                        grad_lm_head_weight += paddle.matmul(
-                            grad_logits_chunk,
-                            hidden_states_chunk,
-                            transpose_x=True,
-                        )
-                    else:
-                        grad_lm_head_weight += paddle.matmul(
-                            hidden_states_chunk,
+
+                    if hidden_states.stop_gradient:
+                        grad_hidden_states[token_start_idx:token_end_idx] = paddle.matmul(
                             grad_logits_chunk,
-                            transpose_x=True,
+                            lm_head_weight_cast,
+                            transpose_y=not transpose_y,
                         )
-                if grad_lm_head_bias is not None:
-                    grad_lm_head_bias += grad_logits_chunk.astype("float32").sum(axis=0).astype(dtype)
+                    if grad_lm_head_weight is not None:
+                        if transpose_y:
+                            grad_lm_head_weight += paddle.matmul(
+                                grad_logits_chunk,
+                                hidden_states_chunk,
+                                transpose_x=True,
+                            )
+                        else:
+                            grad_lm_head_weight += paddle.matmul(
+                                hidden_states_chunk,
+                                grad_logits_chunk,
+                                transpose_x=True,
+                            )
+                    if grad_lm_head_bias is not None:
+                        grad_lm_head_bias += grad_logits_chunk.astype("float32").sum(axis=0).astype(dtype)

         if return_token_loss:
             loss = token_loss.reshape(original_shape[:-1])
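For reference, the per-token math that both the NPU helper and the blockwise loop implement is the numerically stable log-sum-exp form of cross entropy, with the logits gradient given in closed form as `softmax(logits) - one_hot(labels)` (before the loss-mask and `divisor` scaling). Below is a single-card sketch with illustrative shapes and no chunking or tensor parallelism; the names `logits`, `labels`, and `indices` stand in for one chunk of the real inputs:

```python
import paddle
import paddle.nn.functional as F

paddle.seed(0)
n_tokens, vocab_size = 6, 11
logits = paddle.randn([n_tokens, vocab_size], dtype="float32")
labels = paddle.randint(0, vocab_size, shape=[n_tokens])
indices = paddle.arange(vocab_size).unsqueeze(0)  # vocab ids, broadcast against labels

# forward: loss = logsumexp(logits) - logits[label], computed stably
max_logits = paddle.max(logits, axis=-1, keepdim=True)
normalized_logits = logits - max_logits
exp_logits = paddle.exp(normalized_logits)
sum_exp_logits = paddle.sum(exp_logits, axis=-1, keepdim=True)
labels_one_hot = labels.unsqueeze(1) == indices
label_logits = paddle.sum(
    paddle.where(labels_one_hot, normalized_logits, paddle.zeros_like(normalized_logits)),
    axis=-1,
    keepdim=True,
)
token_loss = (paddle.log(sum_exp_logits) - label_logits).squeeze(1)

# should agree with Paddle's reference cross entropy
ref = F.cross_entropy(logits, labels, reduction="none").reshape([-1])
print(paddle.allclose(token_loss, ref))  # prints a bool Tensor equal to True

# backward w.r.t. logits: softmax(logits) - one_hot(labels)
grad_logits = exp_logits / sum_exp_logits - labels_one_hot.astype("float32")
```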