@@ -549,7 +549,7 @@ def _prepare_adapter_config(peft_config, model_config):

     def _unload_and_optionally_merge(self, merge=True, progressbar: bool = False):
         if merge:
-            if getattr(self.model, "is_loaded_in_8bit", False) or getattr(self.model, "is_loaded_in_4bit", False):
+            if getattr(self.model, "is_loaded_in_8bit", False):
                 raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit mode")
             if getattr(self.model, "quantization_method", None) == "gptq":
                 raise ValueError("Cannot merge LORA layers when the model is gptq quantized")
@@ -573,6 +573,17 @@ def _unload_and_optionally_merge(self, merge=True, progressbar: bool = False):
                         padding=target.padding,
                         dilation=target.dilation,
                     )
+                elif is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
+                    bias = target.bias is not None
+                    new_module = bnb.nn.Linear4bit(
+                        target.in_features,
+                        target.out_features,
+                        bias=bias,
+                        compute_dtype=target.compute_dtype,
+                        compress_statistics=target.weight.compress_statistics,
+                        quant_type=target.weight.quant_type,
+                        device=target.weight.device,
+                    )
                 else:
                     bias = target.bias is not None
                     if getattr(target, "is_target_conv_1d_layer", False):
@@ -1193,8 +1204,49 @@ def __init__(
             self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
             self.active_adapter = adapter_name

+        def merge(self):
+            if self.active_adapter not in self.lora_A.keys():
+                return
+            if self.merged:
+                warnings.warn("Already merged. Nothing to do.")
+                return
+            if self.r[self.active_adapter] > 0:
+                warnings.warn(
+                    "Merge lora module to 4-bit linear may get different generations due to rounding errors."
+                )
+                # Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
+                kwargs = self.weight.__dict__
+                lora_data = self.get_delta_weight(self.active_adapter)
+                w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) + lora_data
+                self.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(self.weight.device)
+                self.merged = True
+
+        def unmerge(self):
+            if self.active_adapter not in self.lora_A.keys():
+                return
+            if not self.merged:
+                warnings.warn("Already unmerged. Nothing to do.")
+                return
+            if self.r[self.active_adapter] > 0:
+                warnings.warn(
+                    "Unmerge lora module to 4-bit linear may get different generations due to rounding errors."
+                )
+                kwargs = self.weight.__dict__
+                lora_data = self.get_delta_weight(self.active_adapter)
+                w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) - lora_data
+                self.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(self.weight.device)
+                self.merged = False
+
+        def get_delta_weight(self, adapter):
+            return (
+                transpose(
+                    self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
+                    False,
+                )
+                * self.scaling[adapter]
+            )
+
         def forward(self, x: torch.Tensor) -> torch.Tensor:
-            # note: logic differs from default Linear because merging is not supported
             result = super().forward(x)

             if (
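For context, here is a minimal sketch of how this new 4-bit merge path would typically be exercised from user code, assuming a PEFT build that includes this change and bitsandbytes installed; the model name and `target_modules` below are illustrative, not part of this diff.

```python
# Usage sketch (assumptions: PEFT with this patch, bitsandbytes available,
# a CUDA device; "facebook/opt-350m" and target_modules are illustrative).
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# The base model is loaded with bnb.nn.Linear4bit layers.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", quantization_config=bnb_config)

# Wrap the quantized linear layers with LoRA adapters.
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM")
model = get_peft_model(model, lora_config)

# ... train the adapter ...

# With this patch, merge_and_unload() dequantizes each 4-bit weight, adds the
# LoRA delta (B @ A * scaling), re-quantizes the result into a fresh Params4bit,
# and replaces the LoRA wrapper with a plain bnb.nn.Linear4bit module.
merged_model = model.merge_and_unload()
```

As the new warning in `merge()` notes, the merged 4-bit weights go through a dequantize/re-quantize round trip, so generations can differ slightly from the unmerged adapter due to rounding errors.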