
Commit 140a69b

Support merge lora module for 4bit and 8bit linear (#851)
* support merge lora module for 4bit and 8bit linear
* add tests for merging lora module to 8bit and 4bit model
* state should reset grad
* add prepare output before and after merge lora
* fix format
* fix format 2
* fix format 3
* add warning
* fix parameter format
* remove 8bit merge
* remove 8bit linear merge
* add comment for 4bit merge
1 parent 8c17d55 commit 140a69b
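
At its core, the 4-bit merge added in this commit dequantizes the frozen base weight, adds the LoRA delta B @ A scaled by lora_alpha / r, and re-quantizes the result. A minimal standalone sketch of that round trip, with illustrative shapes and names that are not taken from the commit:

    import torch
    import bitsandbytes as bnb

    # Illustrative dimensions; in PEFT these come from the wrapped Linear4bit layer.
    out_features, in_features, r = 64, 64, 8
    scaling = 2.0  # lora_alpha / r

    w = torch.randn(out_features, in_features, dtype=torch.float32, device="cuda")
    lora_A = torch.randn(r, in_features, device="cuda")
    lora_B = torch.randn(out_features, r, device="cuda")

    # Quantize the base weight to 4 bit, as bitsandbytes does at load time.
    w_4bit, quant_state = bnb.functional.quantize_4bit(w)

    # Merge: dequantize, add the LoRA delta, re-quantize.
    delta = (lora_B @ lora_A) * scaling
    w_merged = bnb.functional.dequantize_4bit(w_4bit, quant_state) + delta
    w_4bit_merged, merged_state = bnb.functional.quantize_4bit(w_merged)

Because the quantize/dequantize round trip is lossy, a merged model's outputs can differ slightly from the unmerged ones, which is why the new merge() emits a warning.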

File tree

2 files changed: +90 -2 lines

src/peft/tuners/lora.py

Lines changed: 54 additions & 2 deletions
@@ -549,7 +549,7 @@ def _prepare_adapter_config(peft_config, model_config):
 
     def _unload_and_optionally_merge(self, merge=True, progressbar: bool = False):
         if merge:
-            if getattr(self.model, "is_loaded_in_8bit", False) or getattr(self.model, "is_loaded_in_4bit", False):
+            if getattr(self.model, "is_loaded_in_8bit", False):
                 raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit mode")
             if getattr(self.model, "quantization_method", None) == "gptq":
                 raise ValueError("Cannot merge LORA layers when the model is gptq quantized")
@@ -573,6 +573,17 @@ def _unload_and_optionally_merge(self, merge=True, progressbar: bool = False):
                         padding=target.padding,
                         dilation=target.dilation,
                     )
+                elif is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
+                    bias = target.bias is not None
+                    new_module = bnb.nn.Linear4bit(
+                        target.in_features,
+                        target.out_features,
+                        bias=bias,
+                        compute_dtype=target.compute_dtype,
+                        compress_statistics=target.weight.compress_statistics,
+                        quant_type=target.weight.quant_type,
+                        device=target.weight.device,
+                    )
                 else:
                     bias = target.bias is not None
                     if getattr(target, "is_target_conv_1d_layer", False):
@@ -1193,8 +1204,49 @@ def __init__(
             self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
             self.active_adapter = adapter_name
 
+        def merge(self):
+            if self.active_adapter not in self.lora_A.keys():
+                return
+            if self.merged:
+                warnings.warn("Already merged. Nothing to do.")
+                return
+            if self.r[self.active_adapter] > 0:
+                warnings.warn(
+                    "Merge lora module to 4-bit linear may get different generations due to rounding errors."
+                )
+                # Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
+                kwargs = self.weight.__dict__
+                lora_data = self.get_delta_weight(self.active_adapter)
+                w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) + lora_data
+                self.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(self.weight.device)
+                self.merged = True
+
+        def unmerge(self):
+            if self.active_adapter not in self.lora_A.keys():
+                return
+            if not self.merged:
+                warnings.warn("Already unmerged. Nothing to do.")
+                return
+            if self.r[self.active_adapter] > 0:
+                warnings.warn(
+                    "Unmerge lora module to 4-bit linear may get different generations due to rounding errors."
+                )
+                kwargs = self.weight.__dict__
+                lora_data = self.get_delta_weight(self.active_adapter)
+                w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) - lora_data
+                self.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(self.weight.device)
+                self.merged = False
+
+        def get_delta_weight(self, adapter):
+            return (
+                transpose(
+                    self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
+                    False,
+                )
+                * self.scaling[adapter]
+            )
+
         def forward(self, x: torch.Tensor) -> torch.Tensor:
-            # note: logic differs from default Linear because merging is not supported
             result = super().forward(x)
 
             if (
tests/test_common_gpu.py

Lines changed: 36 additions & 0 deletions
@@ -34,6 +34,8 @@
 
 
 if is_bnb_available():
+    import bitsandbytes as bnb
+
     from peft.tuners.lora import Linear8bitLt
 
 if is_bnb_4bit_available():
@@ -356,3 +358,37 @@ def test_modules_to_save_grad(self):
         self.assertTrue(modules_to_save.weight.requires_grad is True)
         self.assertTrue(original_module.weight.grad is None)
         self.assertTrue(modules_to_save.weight.grad is not None)
+
+    @require_torch_gpu
+    @pytest.mark.single_gpu_tests
+    @require_bitsandbytes
+    def test_4bit_merge_lora(self):
+        torch.manual_seed(3000)
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_use_double_quant=False,
+            bnb_4bit_compute_dtype=torch.float32,
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            "facebook/opt-125m",
+            quantization_config=bnb_config,
+            torch_dtype=torch.float32,
+        )
+        config = LoraConfig(
+            r=8,
+            init_lora_weights=False,
+        )
+        model = get_peft_model(model, config)
+
+        random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
+        with torch.inference_mode():
+            out_before_merge = model.generate(random_input, max_new_tokens=1)
+
+        model.merge_and_unload("default")
+        with torch.inference_mode():
+            out_after_merge = model.generate(random_input, max_new_tokens=1)
+
+        self.assertTrue(torch.equal(out_before_merge, out_after_merge))
+        self.assertTrue(isinstance(model, PeftModel))
+        self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, bnb.nn.Linear4bit))
+        self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, bnb.nn.Linear4bit))