@@ -247,6 +247,15 @@ def create_weights(
247247 def process_weights_after_loading (self , layer : Module ) -> None :
248248 # Block quant doesn't need to process weights after loading
249249 if self .block_quant :
250+ if current_platform .is_rocm ():
251+ weight , weight_scale , _ = \
252+ normalize_e4m3fn_to_e4m3fnuz (
253+ weight = layer .weight ,
254+ weight_scale = layer .weight_scale_inv ,
255+ input_scale = layer .input_scale )
256+ layer .weight = Parameter (weight , requires_grad = False )
257+ layer .weight_scale_inv = Parameter (weight_scale ,
258+ requires_grad = False )
250259 return
251260 layer .weight = torch .nn .Parameter (layer .weight .data ,
252261 requires_grad = False )
@@ -495,6 +504,30 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
495504 def process_weights_after_loading (self , layer : Module ) -> None :
496505 # Block quant doesn't need to process weights after loading
497506 if self .block_quant :
507+ if current_platform .is_rocm ():
508+ w13_weight , w13_weight_scale_inv , w13_input_scale = \
509+ normalize_e4m3fn_to_e4m3fnuz (
510+ layer .w13_weight , layer .w13_weight_scale_inv ,
511+ layer .w13_input_scale )
512+ w2_weight , w2_weight_scale_inv , w2_input_scale = \
513+ normalize_e4m3fn_to_e4m3fnuz (
514+ layer .w2_weight , layer .w2_weight_scale_inv ,
515+ layer .w2_input_scale )
516+ # Reset the parameter
517+ layer .w13_weight = torch .nn .Parameter (w13_weight ,
518+ requires_grad = False )
519+ layer .w13_weight_scale_inv = torch .nn .Parameter (
520+ w13_weight_scale_inv , requires_grad = False )
521+ if w13_input_scale is not None :
522+ layer .w13_input_scale = torch .nn .Parameter (
523+ w13_input_scale , requires_grad = False )
524+ layer .w2_weight = torch .nn .Parameter (w2_weight ,
525+ requires_grad = False )
526+ layer .w2_weight_scale_inv = torch .nn .Parameter (
527+ w2_weight_scale_inv , requires_grad = False )
528+ if w2_input_scale is not None :
529+ layer .w2_input_scale = torch .nn .Parameter (
530+ w2_input_scale , requires_grad = False )
498531 return
499532 # If checkpoint is fp16, quantize in place.
500533 if not self .quant_config .is_checkpoint_fp8_serialized :
0 commit comments