-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmerge_lora.py
More file actions
166 lines (129 loc) · 5.2 KB
/
merge_lora.py
File metadata and controls
166 lines (129 loc) · 5.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
#!/usr/bin/env python3
"""
Merge MLX LoRA adapters with base model and export to HuggingFace format.
This script:
1. Loads the base Qwen3 model (MLX 4-bit)
2. Loads the LoRA adapters
3. Fuses them together
4. Exports to a format compatible with HuggingFace/Bumblebee
Usage:
python merge_lora.py
"""
import argparse
import json
import shutil
from pathlib import Path
def fuse_with_mlx_lm():
    """Fuse the LoRA adapters into the base model via ``mlx_lm.fuse``.

    Runs ``mlx_lm.fuse`` as a subprocess using the current interpreter
    (``sys.executable``) so it executes in the same environment/venv as
    this script, even when a bare ``python`` command is missing from PATH
    or resolves to a different installation.

    Returns:
        Path to the fused model directory on success, or ``None`` if the
        fuse subprocess exited with a non-zero status.
    """
    import subprocess
    import sys

    base_model = "lmstudio-community/Qwen3-8B-MLX-4bit"
    adapter_path = Path(__file__).parent / "adapters_qwen3_4bit"
    output_path = Path(__file__).parent / "fused_model"

    print(f"Fusing LoRA adapters from: {adapter_path}")
    print(f"Base model: {base_model}")
    print(f"Output: {output_path}")

    # sys.executable, not "python": guarantees the fuse step runs with the
    # interpreter (and site-packages) that launched this script.
    cmd = [
        sys.executable, "-m", "mlx_lm.fuse",
        "--model", base_model,
        "--adapter-path", str(adapter_path),
        "--save-path", str(output_path),
    ]
    print(f"\nRunning: {' '.join(cmd)}")
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        print(f"Error: {result.stderr}")
        return None
    print(result.stdout)
    print(f"\nFused model saved to: {output_path}")
    return output_path
def export_to_huggingface(mlx_model_path: Path, output_path: Path):
    """Export a fused MLX model directory to HuggingFace layout.

    MLX stores weights as safetensors with HF-compatible tensor names for
    Qwen models, so the export is a copy plus a config cleanup:

    1. ``config.json`` is copied with MLX-specific fields stripped.
    2. Tokenizer files are copied verbatim.
    3. All ``*.safetensors`` weight shards are copied verbatim.

    Args:
        mlx_model_path: Directory containing the fused MLX model.
        output_path: Destination directory (created if missing).

    Returns:
        ``output_path`` on success, or ``None`` if no model weights were
        found in ``mlx_model_path``.
    """
    print(f"\nConverting MLX model to HuggingFace format...")
    print(f"Input: {mlx_model_path}")
    print(f"Output: {output_path}")

    # Bail out early when there are neither single-file nor sharded weights.
    # (Previously the sharded-weights list was computed but never used.)
    single_file = mlx_model_path / "model.safetensors"
    if not single_file.exists() and not list(mlx_model_path.glob("model-*.safetensors")):
        print(f"Error: No model weights found in {mlx_model_path}")
        return None

    output_path.mkdir(parents=True, exist_ok=True)

    # config.json: strip MLX-only fields so standard HF loaders accept it.
    config_src = mlx_model_path / "config.json"
    if config_src.exists():
        with open(config_src) as f:
            config = json.load(f)
        # "quantization" is MLX quantization metadata, not part of HF config.
        config.pop("quantization", None)
        with open(output_path / "config.json", "w") as f:
            json.dump(config, f, indent=2)
        print("Copied and cleaned config.json")

    # Tokenizer files are format-compatible; copy verbatim.
    for tok_file in ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"]:
        src = mlx_model_path / tok_file
        if src.exists():
            shutil.copy(src, output_path / tok_file)
            print(f"Copied {tok_file}")

    # Weight shards: safetensors files are byte-compatible with HF.
    for shard in mlx_model_path.glob("*.safetensors"):
        shutil.copy(shard, output_path / shard.name)
        print(f"Copied {shard.name}")

    print(f"\nHuggingFace model exported to: {output_path}")
    return output_path
def dequantize_and_export(mlx_model_path: Path, output_path: Path):
    """Explain the dequantization trade-off for the 4-bit model.

    Bumblebee cannot load 4-bit quantized weights, so a fully compatible
    export would require dequantizing first (roughly quadrupling the model
    size). This helper only prints the trade-offs; no conversion is done.

    Args:
        mlx_model_path: Directory of the fused (quantized) MLX model; unused.
        output_path: Intended export destination; unused.
    """
    notes = (
        "\nNote: The model is 4-bit quantized.",
        "For full Bumblebee compatibility, we'd need to dequantize.",
        "This would increase model size from ~4GB to ~16GB.",
        "\nAlternative: Use the quantized model with custom loading code.",
    )
    for note in notes:
        print(note)
def main():
    """Command-line entry point: fuse LoRA adapters, then export to HF format.

    Returns:
        0 on success; 1 when fusing fails or (with ``--skip-fuse``) no
        previously fused model exists.
    """
    parser = argparse.ArgumentParser(description="Merge LoRA adapters and export to HF format")
    parser.add_argument("--skip-fuse", action="store_true", help="Skip fusing, use existing fused model")
    parser.add_argument("--output", type=Path, default=Path("qwen3_finetuned_hf"), help="Output directory")
    args = parser.parse_args()

    project_dir = Path(__file__).parent
    fused_path = project_dir / "fused_model"

    if not args.skip_fuse:
        fused_path = fuse_with_mlx_lm()
        if fused_path is None:
            print("Fusing failed!")
            return 1

    if fused_path.exists():
        # Absolute --output paths win: pathlib's "/" ignores project_dir then.
        export_to_huggingface(fused_path, project_dir / args.output)
        dequantize_and_export(fused_path, project_dir / args.output)
    else:
        # Previously this case (e.g. --skip-fuse with no prior fuse run) fell
        # through silently and the script still reported success.
        print(f"Error: fused model not found at {fused_path}")
        return 1

    print("\n" + "="*60)
    print("NEXT STEPS:")
    print("="*60)
    print("""
1. The fused model is in MLX 4-bit quantized format
2. Bumblebee currently doesn't support 4-bit quantization
3. Options:
a) Implement quantized model loading in Bumblebee (contribute!)
b) Dequantize the model (increases size to ~16GB)
c) Use Ollama for inference, Elixir for bot logic
To test with Ollama:
ollama create qwen3-custom -f Modelfile
To proceed with Bumblebee, we need to either:
- Implement quantization support
- Or use the full-precision HF Qwen3 model
""")
    return 0
if __name__ == "__main__":
    # raise SystemExit instead of the site-provided exit() helper, which is
    # meant for interactive sessions and is absent when run with python -S.
    raise SystemExit(main())