Commit 7544b94

Adapting npu for FusedHeadAndCrossEntropy (PaddlePaddle#9499)
* Adapting npu for FusedHeadAndCrossEntropy
* wrapper npu function
1 parent 5861dfc commit 7544b94
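
The change routes FusedHeadAndCrossEntropy.forward through a new static helper on NPU. The NPU helper avoids in-place slice assignment into preallocated tensors (token_loss[start:end] = ..., grad_hidden_states[start:end] = ...) and instead splits the token dimension into fixed-size sections, loops over them, and concatenates the per-chunk results. A minimal sketch of that sectioning, with illustrative sizes (the values of n_tokens, hidden_size and loop_chunk_size are made up for the example):

import paddle

# Illustrative sizes only; the real values come from the flattened hidden states.
n_tokens, hidden_size, loop_chunk_size = 10, 8, 4

# One section per loop iteration; the last entry is -1 so the split infers the
# remainder instead of requiring n_tokens to be a multiple of the chunk size.
token_idx_section = [loop_chunk_size for _ in range(0, n_tokens, loop_chunk_size)]
token_idx_section[-1] = -1  # -> [4, 4, -1], i.e. chunks of 4, 4 and 2 tokens

hidden_states = paddle.randn([n_tokens, hidden_size])
chunks = hidden_states.split(token_idx_section, axis=0)

# Per-chunk results are appended to a list and concatenated once at the end.
per_chunk = [chunk.sum(axis=-1) for chunk in chunks]  # stand-in for the real per-chunk loss
token_loss = paddle.concat(per_chunk, axis=0)
assert token_loss.shape == [n_tokens]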

1 file changed (+252 -93)

paddlenlp/transformers/tensor_parallel_utils.py

Lines changed: 252 additions & 93 deletions
@@ -145,6 +145,135 @@ def fused_head_and_loss_fn(
 class FusedHeadAndCrossEntropy(PyLayer):
     """Fuse LM Head and CrossEntropyLoss into one module."""
 
+    @staticmethod
+    def _fused_head_and_loss_fn_npu(
+        hidden_states,
+        labels,
+        loss_mask,
+        lm_head_bias,
+        lm_head_weight_cast,
+        lm_head_bias_cast,
+        grad_lm_head_weight,
+        grad_lm_head_bias,
+        indices,
+        divisor,
+        n_tokens,
+        loop_chunk_size,
+        tensor_parallel_degree,
+        tensor_parallel_output,
+        model_parallel_group,
+        return_token_loss,
+        transpose_y,
+        dtype,
+    ):
+        token_loss_list = []
+        grad_hidden_states_list = []
+
+        token_idx_section = [loop_chunk_size for _ in range(0, n_tokens, loop_chunk_size)]
+        token_idx_section[-1] = -1
+        hidden_states_chunk = hidden_states.split(token_idx_section, axis=0)
+        labels_chunk = labels.split(token_idx_section, axis=0)
+        loss_mask_chunk = loss_mask.split(token_idx_section, axis=0)
+
+        for i in range(len(token_idx_section)):
+            # logits calculations
+            logits_chunk_cast = paddle.matmul(
+                hidden_states_chunk[i],
+                lm_head_weight_cast,
+                transpose_y=transpose_y,
+            )
+            if lm_head_bias is not None:
+                logits_chunk_cast += lm_head_bias_cast
+            if tensor_parallel_degree > 1 and not tensor_parallel_output:
+                logits_chunk_cast_lst = []
+                dist.all_gather(
+                    logits_chunk_cast_lst,
+                    logits_chunk_cast,
+                    group=model_parallel_group,
+                )
+                logits_chunk_cast = paddle.concat(logits_chunk_cast_lst, axis=-1)
+            logits_chunk = logits_chunk_cast.astype("float32")
+
+            # log softmax
+            max_logits = paddle.max(logits_chunk, axis=-1, keepdim=True)
+            if tensor_parallel_degree > 1 and tensor_parallel_output:
+                dist.all_reduce(max_logits, op=dist.ReduceOp.MAX, group=model_parallel_group)
+            normalized_logits = logits_chunk - max_logits
+            exp_logits = paddle.exp(normalized_logits)
+            sum_exp_logits = paddle.sum(exp_logits, axis=-1, keepdim=True)
+            if tensor_parallel_degree > 1 and tensor_parallel_output:
+                dist.all_reduce(
+                    sum_exp_logits,
+                    op=dist.ReduceOp.SUM,
+                    group=model_parallel_group,
+                )
+            log_sum_exp_logits = paddle.log(sum_exp_logits)
+
+            # cross entropy
+            labels_one_hot = labels_chunk[i].unsqueeze(1) == indices
+            label_logits = paddle.sum(
+                paddle.where(
+                    labels_one_hot,
+                    normalized_logits,
+                    paddle.zeros_like(normalized_logits),
+                ),
+                axis=-1,
+                keepdim=True,
+            )
+            if tensor_parallel_degree > 1 and tensor_parallel_output:
+                dist.all_reduce(
+                    label_logits,
+                    op=dist.ReduceOp.SUM,
+                    group=model_parallel_group,
+                )
+            token_loss_chunk = (log_sum_exp_logits - label_logits).squeeze(1) / divisor
+            cond = loss_mask_chunk[i].astype("bool")
+            token_loss_chunk = paddle.where(cond, token_loss_chunk, paddle.zeros_like(token_loss_chunk))
+            token_loss_list.append((token_loss_chunk * loss_mask_chunk[i]))
+
+            # gradients calculations
+            if not return_token_loss:
+                if tensor_parallel_degree > 1 and not tensor_parallel_output:
+                    exp_logits = exp_logits.split(model_parallel_group.nranks, axis=-1)[model_parallel_group.rank]
+                    labels_one_hot = labels_one_hot.split(model_parallel_group.nranks, axis=-1)[
+                        model_parallel_group.rank
+                    ]
+                grad_logits_chunk = (exp_logits / sum_exp_logits - labels_one_hot.astype("float32")) / divisor
+                grad_logits_chunk = grad_logits_chunk.astype(dtype)
+                grad_logits_chunk = paddle.where(
+                    cond.unsqueeze(1),
+                    grad_logits_chunk,
+                    paddle.zeros_like(grad_logits_chunk),
+                )
+                if hidden_states.stop_gradient:
+                    grad_hidden_states_list.append(
+                        paddle.matmul(
+                            grad_logits_chunk,
+                            lm_head_weight_cast,
+                            transpose_y=not transpose_y,
+                        )
+                    )
+                if grad_lm_head_weight is not None:
+                    if transpose_y:
+                        grad_lm_head_weight += paddle.matmul(
+                            grad_logits_chunk,
+                            hidden_states_chunk[i],
+                            transpose_x=True,
+                        )
+                    else:
+                        grad_lm_head_weight += paddle.matmul(
+                            hidden_states_chunk[i],
+                            grad_logits_chunk,
+                            transpose_x=True,
+                        )
+                if grad_lm_head_bias is not None:
+                    grad_lm_head_bias += grad_logits_chunk.astype("float32").sum(axis=0).astype(dtype)
+
+        token_loss = paddle.concat(token_loss_list, axis=0)
+        if hidden_states.stop_gradient:
+            grad_hidden_states = paddle.concat(grad_hidden_states_list, axis=0)
+        return token_loss, grad_lm_head_weight, grad_lm_head_bias, grad_hidden_states
+
     @staticmethod
     def forward(
         ctx,
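
For reference, a minimal single-device sketch of the per-chunk loss the new helper computes (assuming tensor_parallel_degree == 1, so the all_gather/all_reduce branches above never fire; shapes are illustrative): a max-shifted log-softmax followed by the negative log-likelihood of the label column, selected through the same one-hot comparison against indices.

import paddle

# Illustrative shapes: a chunk of 4 tokens over a vocabulary of 16.
logits_chunk = paddle.randn([4, 16], dtype="float32")
labels_chunk = paddle.randint(0, 16, shape=[4])
indices = paddle.arange(16).unsqueeze(0)  # vocabulary ids the labels are compared against

# Numerically stable log-softmax pieces.
max_logits = paddle.max(logits_chunk, axis=-1, keepdim=True)
normalized_logits = logits_chunk - max_logits
sum_exp_logits = paddle.sum(paddle.exp(normalized_logits), axis=-1, keepdim=True)
log_sum_exp_logits = paddle.log(sum_exp_logits)

# Select the logit of the label column via a boolean one-hot, as the helper does.
labels_one_hot = labels_chunk.unsqueeze(1) == indices
label_logits = paddle.sum(
    paddle.where(labels_one_hot, normalized_logits, paddle.zeros_like(normalized_logits)),
    axis=-1,
    keepdim=True,
)

# Per-token cross entropy, before the loss mask and the divisor normalization are applied.
token_loss_chunk = (log_sum_exp_logits - label_logits).squeeze(1)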
@@ -239,111 +368,141 @@ def forward(
         else:
             grad_lm_head_bias = None
         if hidden_states.stop_gradient:
-            grad_hidden_states = paddle.zeros_like(hidden_states)
+            if get_env_device() != "npu":
+                grad_hidden_states = paddle.zeros_like(hidden_states)
+            else:
+                grad_hidden_states = None
         else:
             grad_hidden_states = None
 
-        # initialize outputs
-        token_loss = paddle.empty((n_tokens,), dtype=hidden_states.dtype)
-
-        # blockwise calculations
-        for i in range(0, n_tokens, loop_chunk_size):
-            token_start_idx = i
-            token_end_idx = min(i + loop_chunk_size, n_tokens)
-            hidden_states_chunk = hidden_states[token_start_idx:token_end_idx]
-            labels_chunk = labels[token_start_idx:token_end_idx]
-
-            # logits calculations
-            logits_chunk_cast = paddle.matmul(
-                hidden_states_chunk,
+        if get_env_device() == "npu":
+            (
+                token_loss,
+                grad_lm_head_weight,
+                grad_lm_head_bias,
+                grad_hidden_states,
+            ) = FusedHeadAndCrossEntropy._fused_head_and_loss_fn_npu(
+                hidden_states,
+                labels,
+                loss_mask,
+                lm_head_bias,
                 lm_head_weight_cast,
-                transpose_y=transpose_y,
-            )
-            if lm_head_bias is not None:
-                logits_chunk_cast += lm_head_bias_cast
-            if tensor_parallel_degree > 1 and not tensor_parallel_output:
-                logits_chunk_cast_lst = []
-                dist.all_gather(
-                    logits_chunk_cast_lst,
-                    logits_chunk_cast,
-                    group=model_parallel_group,
-                )
-                logits_chunk_cast = paddle.concat(logits_chunk_cast_lst, axis=-1)
-            logits_chunk = logits_chunk_cast.astype("float32")
-
-            # log softmax
-            max_logits = paddle.max(logits_chunk, axis=-1, keepdim=True)
-            if tensor_parallel_degree > 1 and tensor_parallel_output:
-                dist.all_reduce(max_logits, op=dist.ReduceOp.MAX, group=model_parallel_group)
-            normalized_logits = logits_chunk - max_logits
-            exp_logits = paddle.exp(normalized_logits)
-            sum_exp_logits = paddle.sum(exp_logits, axis=-1, keepdim=True)
-            if tensor_parallel_degree > 1 and tensor_parallel_output:
-                dist.all_reduce(
-                    sum_exp_logits,
-                    op=dist.ReduceOp.SUM,
-                    group=model_parallel_group,
-                )
-            log_sum_exp_logits = paddle.log(sum_exp_logits)
-
-            # cross entropy
-            labels_one_hot = labels_chunk.unsqueeze(1) == indices
-            label_logits = paddle.sum(
-                paddle.where(
-                    labels_one_hot,
-                    normalized_logits,
-                    paddle.zeros_like(normalized_logits),
-                ),
-                axis=-1,
-                keepdim=True,
+                lm_head_bias_cast,
+                grad_lm_head_weight,
+                grad_lm_head_bias,
+                indices,
+                divisor,
+                n_tokens,
+                loop_chunk_size,
+                tensor_parallel_degree,
+                tensor_parallel_output,
+                model_parallel_group,
+                return_token_loss,
+                transpose_y,
+                dtype,
             )
-            if tensor_parallel_degree > 1 and tensor_parallel_output:
-                dist.all_reduce(
-                    label_logits,
-                    op=dist.ReduceOp.SUM,
-                    group=model_parallel_group,
+        else:
+            # initialize outputs
+            token_loss = paddle.empty((n_tokens,), dtype=hidden_states.dtype)
+
+            # blockwise calculations
+            for i in range(0, n_tokens, loop_chunk_size):
+                token_start_idx = i
+                token_end_idx = min(i + loop_chunk_size, n_tokens)
+                hidden_states_chunk = hidden_states[token_start_idx:token_end_idx]
+                labels_chunk = labels[token_start_idx:token_end_idx]
+
+                # logits calculations
+                logits_chunk_cast = paddle.matmul(
+                    hidden_states_chunk,
+                    lm_head_weight_cast,
+                    transpose_y=transpose_y,
                 )
-            token_loss_chunk = (log_sum_exp_logits - label_logits).squeeze(1) / divisor
-            cond = loss_mask[token_start_idx:token_end_idx].astype("bool")
-            token_loss_chunk = paddle.where(cond, token_loss_chunk, paddle.zeros_like(token_loss_chunk))
-            token_loss[token_start_idx:token_end_idx] = token_loss_chunk * loss_mask[token_start_idx:token_end_idx]
-
-            # gradients calculations
-            if not return_token_loss:
+                if lm_head_bias is not None:
+                    logits_chunk_cast += lm_head_bias_cast
                 if tensor_parallel_degree > 1 and not tensor_parallel_output:
-                    exp_logits = exp_logits.split(model_parallel_group.nranks, axis=-1)[model_parallel_group.rank]
-                    labels_one_hot = labels_one_hot.split(model_parallel_group.nranks, axis=-1)[
-                        model_parallel_group.rank
-                    ]
-                grad_logits_chunk = (exp_logits / sum_exp_logits - labels_one_hot.astype("float32")) / divisor
-                grad_logits_chunk = grad_logits_chunk.astype(dtype)
-                grad_logits_chunk = paddle.where(
-                    cond.unsqueeze(1),
-                    grad_logits_chunk,
-                    paddle.zeros_like(grad_logits_chunk),
+                    logits_chunk_cast_lst = []
+                    dist.all_gather(
+                        logits_chunk_cast_lst,
+                        logits_chunk_cast,
+                        group=model_parallel_group,
+                    )
+                    logits_chunk_cast = paddle.concat(logits_chunk_cast_lst, axis=-1)
+                logits_chunk = logits_chunk_cast.astype("float32")
+
+                # log softmax
+                max_logits = paddle.max(logits_chunk, axis=-1, keepdim=True)
+                if tensor_parallel_degree > 1 and tensor_parallel_output:
+                    dist.all_reduce(max_logits, op=dist.ReduceOp.MAX, group=model_parallel_group)
+                normalized_logits = logits_chunk - max_logits
+                exp_logits = paddle.exp(normalized_logits)
+                sum_exp_logits = paddle.sum(exp_logits, axis=-1, keepdim=True)
+                if tensor_parallel_degree > 1 and tensor_parallel_output:
+                    dist.all_reduce(
+                        sum_exp_logits,
+                        op=dist.ReduceOp.SUM,
+                        group=model_parallel_group,
+                    )
+                log_sum_exp_logits = paddle.log(sum_exp_logits)
+
+                # cross entropy
+                labels_one_hot = labels_chunk.unsqueeze(1) == indices
+                label_logits = paddle.sum(
+                    paddle.where(
+                        labels_one_hot,
+                        normalized_logits,
+                        paddle.zeros_like(normalized_logits),
+                    ),
+                    axis=-1,
+                    keepdim=True,
                 )
-
-                if grad_hidden_states is not None:
-                    grad_hidden_states[token_start_idx:token_end_idx] = paddle.matmul(
+                if tensor_parallel_degree > 1 and tensor_parallel_output:
+                    dist.all_reduce(
+                        label_logits,
+                        op=dist.ReduceOp.SUM,
+                        group=model_parallel_group,
+                    )
+                token_loss_chunk = (log_sum_exp_logits - label_logits).squeeze(1) / divisor
+                cond = loss_mask[token_start_idx:token_end_idx].astype("bool")
+                token_loss_chunk = paddle.where(cond, token_loss_chunk, paddle.zeros_like(token_loss_chunk))
+                token_loss[token_start_idx:token_end_idx] = token_loss_chunk * loss_mask[token_start_idx:token_end_idx]
+
+                # gradients calculations
+                if not return_token_loss:
+                    if tensor_parallel_degree > 1 and not tensor_parallel_output:
+                        exp_logits = exp_logits.split(model_parallel_group.nranks, axis=-1)[model_parallel_group.rank]
+                        labels_one_hot = labels_one_hot.split(model_parallel_group.nranks, axis=-1)[
+                            model_parallel_group.rank
+                        ]
+                    grad_logits_chunk = (exp_logits / sum_exp_logits - labels_one_hot.astype("float32")) / divisor
+                    grad_logits_chunk = grad_logits_chunk.astype(dtype)
+                    grad_logits_chunk = paddle.where(
+                        cond.unsqueeze(1),
                         grad_logits_chunk,
-                    lm_head_weight_cast,
-                    transpose_y=not transpose_y,
+                        paddle.zeros_like(grad_logits_chunk),
                     )
-                if grad_lm_head_weight is not None:
-                    if transpose_y:
-                        grad_lm_head_weight += paddle.matmul(
-                            grad_logits_chunk,
-                            hidden_states_chunk,
-                            transpose_x=True,
-                        )
-                    else:
-                        grad_lm_head_weight += paddle.matmul(
-                            hidden_states_chunk,
+
+                    if hidden_states.stop_gradient:
+                        grad_hidden_states[token_start_idx:token_end_idx] = paddle.matmul(
                             grad_logits_chunk,
-                            transpose_x=True,
+                            lm_head_weight_cast,
+                            transpose_y=not transpose_y,
                         )
-                if grad_lm_head_bias is not None:
-                    grad_lm_head_bias += grad_logits_chunk.astype("float32").sum(axis=0).astype(dtype)
+                    if grad_lm_head_weight is not None:
+                        if transpose_y:
+                            grad_lm_head_weight += paddle.matmul(
+                                grad_logits_chunk,
+                                hidden_states_chunk,
+                                transpose_x=True,
+                            )
+                        else:
+                            grad_lm_head_weight += paddle.matmul(
+                                hidden_states_chunk,
+                                grad_logits_chunk,
+                                transpose_x=True,
+                            )
+                    if grad_lm_head_bias is not None:
+                        grad_lm_head_bias += grad_logits_chunk.astype("float32").sum(axis=0).astype(dtype)
 
         if return_token_loss:
             loss = token_loss.reshape(original_shape[:-1])
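
Both branches also accumulate the backward pass manually when return_token_loss is false. A minimal single-device sketch of that part (no loss mask, no tensor parallelism; the transpose_y=False weight layout and all shapes are assumptions for the example): the gradient of the mean cross entropy with respect to the logits is (softmax - one_hot) / divisor, and two matmuls push it back into the hidden states and the LM head weight.

import paddle

# Illustrative shapes: 4 tokens, hidden size 8, vocabulary 16.
divisor = 4.0  # stands in for the (masked) token count used to normalize the mean loss
hidden_states_chunk = paddle.randn([4, 8], dtype="float32")
lm_head_weight_cast = paddle.randn([8, 16], dtype="float32")  # assumes transpose_y=False layout
labels_chunk = paddle.randint(0, 16, shape=[4])
indices = paddle.arange(16).unsqueeze(0)

logits_chunk = paddle.matmul(hidden_states_chunk, lm_head_weight_cast)
normalized_logits = logits_chunk - paddle.max(logits_chunk, axis=-1, keepdim=True)
exp_logits = paddle.exp(normalized_logits)
sum_exp_logits = paddle.sum(exp_logits, axis=-1, keepdim=True)
labels_one_hot = labels_chunk.unsqueeze(1) == indices

# d(mean CE)/d(logits) = (softmax - one_hot) / divisor, the expression used in the diff.
grad_logits_chunk = (exp_logits / sum_exp_logits - labels_one_hot.astype("float32")) / divisor

# Backpropagate through the matmul: into the hidden states and into the LM head weight.
grad_hidden_states_chunk = paddle.matmul(grad_logits_chunk, lm_head_weight_cast, transpose_y=True)
grad_lm_head_weight = paddle.matmul(hidden_states_chunk, grad_logits_chunk, transpose_x=True)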
