diff --git a/.gitignore b/.gitignore index bf707424f..117c6c902 100644 --- a/.gitignore +++ b/.gitignore @@ -22,5 +22,8 @@ save* .log *.pid *.ipynb* +model/ +output_* +datasets/ .venv/ -*.sh \ No newline at end of file +*.sh diff --git a/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml b/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml new file mode 100644 index 000000000..48218b91e --- /dev/null +++ b/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml @@ -0,0 +1,54 @@ +base: + seed: &seed 42 +model: + type: Wan2T2V + path: /path/to/wan_t2v + torch_dtype: auto + use_cpu_to_save_cuda_mem_for_catcher: True +calib: + name: t2v + download: False + path: ./assets/wan_t2v/calib/ + sample_steps: 20 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 4.0 + guidance_scale_2: 3.0 + seed: *seed +eval: + eval_pos: [] + type: video_gen + name: t2v + download: False + path: ./assets/wan_t2v/calib/ + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 4.0 + guidance_scale_2: 3.0 + output_video_path: ./output_videos_awq/ +quant: + video_gen: + method: Awq + weight: + quant_type: int-quant + bit: 4 + symmetric: True + granularity: per_channel + group_size: -1 + act: + quant_type: int-quant + bit: 4 + symmetric: True + granularity: per_token + special: + trans: True + trans_version: v2 + weight_clip: True + clip_sym: True +save: + save_lightx2v: True + save_path: /path/to/x2v/ diff --git a/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml b/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml index 680fab43b..262d68520 100755 --- a/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml +++ b/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml @@ -46,4 +46,4 @@ quant: clip_sym: True save: save_lightx2v: True - save_path: /path/to/x2v/ + save_path: /path/to/x2v/ \ No newline at end of file diff --git a/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml b/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml 
index 14d05479d..59e35dd4e 100755 --- a/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml +++ b/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml @@ -46,4 +46,4 @@ quant: clip_sym: True save: save_lightx2v: True - save_path: /path/to/x2v/ + save_path: /path/to/x2v/ \ No newline at end of file diff --git a/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml b/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml old mode 100755 new mode 100644 index b6a53b0e0..844b62214 --- a/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml +++ b/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml @@ -29,4 +29,4 @@ quant: granularity: per_token save: save_lightx2v: True - save_path: /path/to/x2v/ + save_path: /path/to/x2v/ \ No newline at end of file diff --git a/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml b/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml index 7d65f31fc..122d31f79 100755 --- a/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml +++ b/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml @@ -42,4 +42,4 @@ quant: alpha: 0.7 save: save_lightx2v: True - save_path: /path/to/x2v/ + save_path: /path/to/x2v/ \ No newline at end of file diff --git a/docs/wan2.1_quantization_guide.md b/docs/wan2.1_quantization_guide.md new file mode 100644 index 000000000..eeef5ac63 --- /dev/null +++ b/docs/wan2.1_quantization_guide.md @@ -0,0 +1,288 @@ +# Wan2.1 视频生成模型量化指南 + +## 概述 + +llmc 框架现已全面支持 Wan2.1 系列视频生成模型的量化,并提供真正量化的 INT8/FP8 权重导出,与 lightx2v 推理框架兼容。 + +## 支持的模型类型 + +- **WanI2V**: Image-to-Video (图像到视频) +- **WanT2V**: Text-to-Video (文本到视频) + +## 支持的量化方法 + +### FP8 量化 (推荐) + +**配置文件**: `configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8.yaml` + +**特点**: +- 使用 E4M3 FP8 格式 (8-bit 浮点数,4位指数,3位尾数) +- SmoothQuant 算法,平衡权重和激活的量化难度 +- 适合 GPU 推理,性能损失小 + +**量化配置**: +```yaml +quant: + video_gen: + method: SmoothQuant + weight: + quant_type: float-quant + bit: e4m3 # FP8 E4M3 格式 + symmetric: True + granularity: per_channel + use_qtorch: True 
+ act: + quant_type: float-quant + bit: e4m3 # FP8 E4M3 格式 + symmetric: True + granularity: per_token + use_qtorch: True + special: + alpha: 0.75 # SmoothQuant 平衡参数 +``` + +### INT8 量化 + +#### 1. RTN (Round-to-Nearest) +**配置文件**: `configs/quantization/video_gen/wan_i2v/rtn_w_a.yaml` + +**特点**: +- 最简单的量化方法 +- 直接四舍五入到最近的量化级别 +- 速度快,精度略低 + +#### 2. AWQ (Activation-aware Weight Quantization) +**配置文件**: `configs/quantization/video_gen/wan_i2v/awq_w_a.yaml` + +**特点**: +- 基于激活分布优化权重量化 +- 保护重要通道,减少精度损失 +- 需要校准数据 + +#### 3. SmoothQuant +**配置文件**: `configs/quantization/video_gen/wan_i2v/smoothquant_w_a.yaml` + +**特点**: +- 平衡权重和激活的量化难度 +- 数学上等价于平滑激活异常值 +- 通常提供最佳精度 + +### LoRA 模型量化 + +支持对 LoRA 适配器模型的量化: +- `smoothquant_w_a_int8_lora.yaml` +- `rtn_w_a_lora.yaml` + +## 运行步骤 + +### 1. 准备环境 + +```bash +# 设置 llmc 路径 +export llmc=/path/to/llmc +export PYTHONPATH=$llmc:$PYTHONPATH + +# 设置 GPU +export CUDA_VISIBLE_DEVICES=0 +``` + +### 2. 准备校准数据 + +为 I2V 模型准备校准数据: +``` +assets/wan_i2v/calib/ +├── image_1.jpg +├── image_2.jpg +└── ... +``` + +为 T2V 模型准备校准数据: +``` +assets/wan_t2v/calib/ +├── prompt_1.txt +├── prompt_2.txt +└── ... +``` + +### 3. 修改配置文件 + +编辑对应的 YAML 配置文件,设置: +- `model.path`: Wan2.1 模型路径 +- `calib.path`: 校准数据路径 +- `save.save_path`: 量化模型保存路径 + +**示例 (FP8 量化)**: +```yaml +base: + seed: 42 +model: + type: WanI2V + path: /path/to/wan2.1-i2v-model # 修改为你的模型路径 + torch_dtype: auto +calib: + name: i2v + download: False + path: /path/to/calibration/data # 修改为校准数据路径 + sample_steps: 40 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 +save: + save_lightx2v: True + save_path: /path/to/save/quantized/model # 修改为保存路径 +``` + +### 4. 
运行量化 + +#### 使用脚本运行 (推荐) + +```bash +# 运行 FP8 量化 (I2V) +./run_llmc.sh wan_i2v_fp8 + +# 运行 INT8 RTN 量化 (I2V) +./run_llmc.sh wan_i2v_int8_rtn + +# 运行 INT8 AWQ 量化 (I2V) +./run_llmc.sh wan_i2v_int8_awq + +# 运行 INT8 SmoothQuant 量化 (I2V) +./run_llmc.sh wan_i2v_int8_smoothquant + +# 运行 T2V 模型量化 +./run_llmc.sh wan_t2v_int8_rtn +./run_llmc.sh wan_t2v_int8_awq +./run_llmc.sh wan_t2v_int8_smoothquant +``` + +#### 直接运行命令 + +```bash +torchrun \ +--nnodes 1 \ +--nproc_per_node 1 \ +--rdzv_id $RANDOM \ +--rdzv_backend c10d \ +--rdzv_endpoint 127.0.0.1:29500 \ +${llmc}/llmc/__main__.py \ +--config configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8.yaml \ +--task_id my_quant_task +``` + +### 5. 监控进度 + +```bash +# 查看日志 +tail -f wan_i2v_fp8.log + +# 查看进程 +ps aux | grep __main__.py +``` + +### 6. 停止任务 + +```bash +# 使用保存的 PID 文件 +xargs kill -9 < wan_i2v_fp8.pid +``` + +## 配置参数说明 + +### 模型配置 +- `type`: 模型类型 (`WanI2V` 或 `WanT2V`) +- `path`: 模型权重路径 +- `torch_dtype`: 数据类型 (`auto`, `bfloat16`, `float32`) + +### 校准配置 +- `sample_steps`: 采样步数 (通常 20-40) +- `bs`: 批大小 (通常 1,视频生成显存占用大) +- `target_height`: 目标视频高度 (默认 480) +- `target_width`: 目标视频宽度 (默认 832) +- `num_frames`: 视频帧数 (默认 81) +- `guidance_scale`: CFG 引导强度 (默认 5.0) + +### 量化配置 +- `method`: 量化方法 (`RTN`, `Awq`, `SmoothQuant`) +- `weight.bit`: 权重位宽 (8, e4m3) +- `act.bit`: 激活位宽 (8, e4m3) +- `granularity`: 量化粒度 (`per_channel`, `per_token`) +- `special.alpha`: SmoothQuant 平衡参数 (0.5-1.0) + +## 在 lightx2v 中使用量化模型 + +### 1. 配置 lightx2v + +编辑 `lightx2v/configs/quantization/wan_i2v.json`: +```json +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "dit_quantized_ckpt": "/path/to/quantized/model", + "dit_quantized": true, + "dit_quant_scheme": "int8-vllm" +} +``` + +对于 FP8 模型,设置 `"dit_quant_scheme": "fp8"`。 + +### 2. 
运行推理 + +```bash +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path /path/to/original/model \ +--config_json configs/quantization/wan_i2v.json \ +--prompt "Your prompt here" \ +--image_path /path/to/input/image.jpg \ +--save_result_path output.mp4 +``` + +## 性能建议 + +1. **FP8 vs INT8**: + - FP8: 精度更高,适合对质量要求高的场景 + - INT8: 压缩率更高,适合对速度要求高的场景 + +2. **量化方法选择**: + - 快速原型: RTN + - 平衡精度和速度: SmoothQuant + - 最高精度: AWQ + +3. **校准数据**: + - 使用 10-50 个样本 + - 覆盖典型使用场景 + - I2V: 使用多样化图像 + - T2V: 使用多样化文本描述 + +4. **资源需求**: + - GPU: 建议 24GB+ 显存 + - 校准时间: 30分钟 - 2小时 (取决于数据量) + - 存储空间: 量化后模型约原模型 25-50% 大小 + +## 故障排除 + +### 显存不足 +- 减小 `bs` 到 1 +- 减小 `num_frames` +- 减小 `target_height` 和 `target_width` + +### 量化精度损失过大 +- 尝试 SmoothQuant 方法 +- 增加校准数据数量 +- 调整 `alpha` 参数 (0.5-1.0) + +### lightx2v 兼容性问题 +- 确保使用 `save_lightx2v: True` +- 检查 `dit_quant_scheme` 设置 +- 确认量化模型路径正确 + +## 参考 + +- lightx2v 文档: [lightx2v 项目地址] +- llmc 框架: [llmc 项目地址] +- Wan2.1 模型: [模型地址] diff --git a/docs/wan2.2_quantization_guide.md b/docs/wan2.2_quantization_guide.md new file mode 100644 index 000000000..8a5633275 --- /dev/null +++ b/docs/wan2.2_quantization_guide.md @@ -0,0 +1,136 @@ +# Wan2.2 视频生成模型量化指南 + +## 概述 + +本仓库为 **Wan2.2-T2V** 提供的现成示例是 **4-bit AWQ 模拟量化**(`configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml`)。 + +Wan2.2 为 **MoE 双专家**:高噪声(`transformer`)与低噪声(`transformer_2`),校准与块级量化会覆盖两条支路。保存侧默认示例为 `save_fake`,推理对接需按你的推理栈自行对齐。 + +**模型示例(原生 checkpoint 布局)**:[Wan-AI/Wan2.2-T2V-A14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B) + +## Wan2.2 相对 Wan2.1 的要点 + +| 项目 | 说明 | +|------|------| +| 注册名 | `Wan2T2V` | +| 结构 | 双专家 MoE,非单路 DiT | +| 推理后端 | 优先官方 `wan` + 原生目录;可按 YAML 注释回退 Diffusers | +| CFG | `guidance_scale`(高噪声)与 `guidance_scale_2`(低噪声),与官方双引导一致 | + +## 量化配置示例 + +`awq_w_a.yaml` 中 `quant` 段与仓库一致,例如: + +```yaml +quant: + video_gen: + method: Awq + weight: + quant_type: hif4 + bit: 4 + symmetric: True + granularity: per_channel + group_size: -1 + act: + quant_type: hif4 + bit: 4 + 
symmetric: True + granularity: per_token + special: + trans: True + trans_version: v2 + weight_clip: True + clip_sym: True +``` + +## 运行步骤 + +### 1. 环境 + +```bash +export llmc=/path/to/LightCompress +export PYTHONPATH=$llmc:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0 +``` + +原生布局需要能 `import wan`,通常: + +```bash +pip install -e /path/to/Wan2.2 +``` + +或在 YAML 里设置 `wan2_repo_path: /path/to/Wan2.2`。 + +### 2. 校准数据 + +与 Wan2.1 T2V 相同,文本 prompt 文件目录,例如: + +``` +assets/wan_t2v/calib/ +├── prompt_1.txt +├── prompt_2.txt +└── ... +``` + +配置中 `calib.name: t2v`,`calib.path` 指向该目录。 + +### 3. 修改 `awq_w_a.yaml` + +必改: + +- `model.path`:Wan2.2 权重路径 +- `calib.path` / `eval.path`:校准与评估数据 +- `save.save_path`:输出目录 + +可选(见 YAML 注释): + +- `use_cpu_to_save_cuda_mem_for_catcher: True`:校准显存紧张时减轻峰值 +- `allow_diffusers_fallback: True`:无法用官方后端时回退 Diffusers + +双引导示例: + +```yaml +calib: + guidance_scale: 4.0 # high_noise + guidance_scale_2: 3.0 # low_noise +eval: + guidance_scale: 4.0 + guidance_scale_2: 3.0 +``` + +### 4. 
启动量化 + +```bash +torchrun \ + --nnodes 1 \ + --nproc_per_node 1 \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint 127.0.0.1:29500 \ + ${llmc}/llmc/__main__.py \ + --config ${llmc}/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml \ + --task_id wan22_awq_int4 +``` + +`scripts/run_llmc.sh` 中把 `model_name=wan2_2_t2v`、`task_name=awq_w_a` 等与上述 YAML 对齐即可(需按本机修改脚本里的 Python 路径等)。 + +## 参数速查 + +| 区域 | 说明 | +|------|------| +| `model.type` | `Wan2T2V` | +| `quant.video_gen.method` | `Awq` | +| `weight` / `act` | `bit: 4`(具体 `quant_type` 以 YAML 为准) | +| `save` | 示例 `save_fake: True` 与 `save_path` | + +## 常见问题 + +- **OOM**:减小 `sample_steps`、`num_frames`、分辨率;`bs: 1`;可开 `use_cpu_to_save_cuda_mem_for_catcher`。 +- **无法 `import wan`**:安装官方仓库或配置 `wan2_repo_path`。 +- **画质下降**:增加/多样化校准 prompt;在支持范围内微调 `special` 与校准规模。 + +## 参考 + +- `configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml` +- `llmc/models/wan2_2_t2v.py` +- 其它精度(如 FP8、INT8)可参考 `docs/wan2.1_quantization_guide.md` 的思路,自行新增 `wan2_2_t2v` 下 YAML 并替换 `model.type` 与路径。 diff --git a/llmc/compression/blockwise_optimization.py b/llmc/compression/blockwise_optimization.py index 72823d1bd..380e8f42c 100644 --- a/llmc/compression/blockwise_optimization.py +++ b/llmc/compression/blockwise_optimization.py @@ -31,11 +31,15 @@ def __init__(self, model, compress_config, input, padding_mask, config): def run_block_loop(self): for i in range(len(self.blocks)): self.block_idx = i + if self.input and hasattr(self.model, 'get_blockwise_input'): + self.input = self.model.get_blockwise_input(self.block_idx, self.input) logger.info( f'\nblock index: {self.block_idx}/{len(self.blocks)} ' f'\nblock: {self.blocks[self.block_idx]}' ) self.block_opt(self.blocks[self.block_idx]) + if self.input and hasattr(self.model, 'set_blockwise_input'): + self.model.set_blockwise_input(self.block_idx, self.input) if hasattr(self, 'save_scale') and self.save_scale: os.makedirs(self.scale_path, exist_ok=True) diff --git 
a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py index cfcebd4e1..2df4f8c93 100755 --- a/llmc/compression/quantization/base_blockwise_quantization.py +++ b/llmc/compression/quantization/base_blockwise_quantization.py @@ -3,6 +3,7 @@ import gc import os import re +import shutil from collections import defaultdict from functools import partial @@ -34,7 +35,11 @@ _TRANSFORMERS_LN_TYPES_, EffcientFakeQuantLinear, FakeQuantLinear, LlmcActFn, OriginFloatLinear, RotateLinear) -from .quant import FloatQuantizer, IntegerQuantizer, Weight48IntegerQuantizer +from .quant import ( + FloatQuantizer, + IntegerQuantizer, + Weight48IntegerQuantizer, +) class BaseBlockwiseQuantization(BlockwiseOpt): @@ -450,9 +455,21 @@ def run(self, block, input_feat, handles): h.remove() torch.cuda.empty_cache() - self.block_transform(block, input_feat, self.input['kwargs']) + if not self._is_ignored_block(self.block_idx): + self.block_transform(block, input_feat, self.input['kwargs']) + else: + logger.info( + f'Block {self.block_idx} is in ignored_block_ids, ' + f'skipping block_transform.' + ) else: - self.block_transform(block) + if not self._is_ignored_block(self.block_idx): + self.block_transform(block) + else: + logger.info( + f'Block {self.block_idx} is in ignored_block_ids, ' + f'skipping block_transform.' 
+ ) if not self.data_free and self.quant_out: self.model.replace_module_block( @@ -913,27 +930,45 @@ def set_non_linear_mode(self, quant_format, module, mode): if getattr(m, 'calib', None) is not None: m.calib = mode + def _get_ignored_block_ids_set(self): + if not hasattr(self, '_ignored_block_ids_set_cache'): + expanded = [] + for item in self.ignored_block_ids: + match = re.match(r'(\d+)-(\d+)', str(item)) + if match: + start, end = int(match.group(1)), int(match.group(2)) + expanded.extend(range(start, end + 1)) + else: + expanded.append(int(item)) + self._ignored_block_ids_set_cache = set(expanded) + return self._ignored_block_ids_set_cache + + def _is_ignored_block(self, block_idx): + if not self.mixed_precision or not self.ignored_block_ids: + return False + return block_idx in self._get_ignored_block_ids_set() + def set_no_quant_layer(self): if self.ignored_speical_names: assert hasattr(self.model, 'block_name_prefix'), \ 'block_name_prefix missing in model' - ignored_block_ids = [] - for item in self.ignored_block_ids: - match = re.match(r'(\d+)-(\d+)', str(item)) - if match: - start, end = int(match.group(1)), int(match.group(2)) - ignored_block_ids.extend(range(start, end + 1)) - else: - ignored_block_ids.append(int(item)) + ignored_block_ids = self._get_ignored_block_ids_set() + # If no layer_names specified, skip all linear layers in the ignored blocks + skip_all_linears = not self.ignored_layer_names for idx, block in enumerate(self.blocks): for n, m in block.named_modules(): - if idx in ignored_block_ids and n in self.ignored_layer_names: - m.register_buffer('no_quant', torch.tensor(True)) - else: - layer_name = f'{self.model.block_name_prefix}.{idx}.{n}' - if layer_name in self.ignored_speical_names: + if idx in ignored_block_ids: + if skip_all_linears: + if isinstance(m, tuple(_LLMC_LINEAR_TYPES_ + _TRANSFORMERS_LINEAR_TYPES_)): + m.register_buffer('no_quant', torch.tensor(True)) + elif n in self.ignored_layer_names: m.register_buffer('no_quant', 
torch.tensor(True)) + else: + if self.ignored_speical_names: + layer_name = f'{self.model.block_name_prefix}.{idx}.{n}' + if layer_name in self.ignored_speical_names: + m.register_buffer('no_quant', torch.tensor(True)) @torch.no_grad() def deploy(self, quant_format, keep_device=False): @@ -1009,6 +1044,18 @@ def contiguous_params(self): if not param.is_contiguous(): param.data = param.data.contiguous() + if ( + self.config.model.type in ['Wan2T2V'] + and hasattr(self.model.Pipeline, 'transformer_2') + and self.model.Pipeline.transformer_2 is not None + ): + for name, param in self.model.Pipeline.transformer_2.named_parameters(): + if not param.is_contiguous(): + param.data = param.data.contiguous() + for name, param in self.model.Pipeline.transformer_2.named_buffers(): + if not param.is_contiguous(): + param.data = param.data.contiguous() + @torch.no_grad() def save_model(self, path): if int(os.environ['RANK']) != 0: @@ -1029,6 +1076,8 @@ def save_model(self, path): self.model.avlm_model.save_pretrained(path) logger.info('save model done --') self.copy_tokenizer(path) + elif self.config.model.type in ['Wan2T2V']: + self.model.save_wan2_2_pretrained(path) else: self.model.get_model().save_pretrained(path) logger.info('save model done --') diff --git a/llmc/compression/quantization/quant.py b/llmc/compression/quantization/quant.py index d06fd7479..b2d2f2aea 100755 --- a/llmc/compression/quantization/quant.py +++ b/llmc/compression/quantization/quant.py @@ -1,4 +1,6 @@ import gc +import os +import sys import torch from loguru import logger @@ -1225,7 +1227,6 @@ def __repr__(self): f'kwargs={self.kwargs}, qmin={self.qmin}, qmax={self.qmax})' ) - class Weight48IntegerQuantizer(BaseQuantizer): # flake8: noqa def __init__(self, bit, bit4, bit8, **kwargs): diff --git a/llmc/eval/eval_video_generate.py b/llmc/eval/eval_video_generate.py index 0f99ff6c9..726187c0b 100755 --- a/llmc/eval/eval_video_generate.py +++ b/llmc/eval/eval_video_generate.py @@ -23,6 +23,7 @@ def 
__init__(self, model, config): self.target_width = self.eval_cfg.get('target_width', 832) self.num_frames = self.eval_cfg.get('num_frames', 81) self.guidance_scale = self.eval_cfg.get('guidance_scale', 5.0) + self.guidance_scale_2 = self.eval_cfg.get('guidance_scale_2', None) self.fps = self.eval_cfg.get('fps', 15) @torch.no_grad() @@ -56,14 +57,17 @@ def t2v_eval(self, model, testenc, bs, eval_pos): assert bs == 1, 'Only support eval bs=1' for i, data in enumerate(testenc): - output = model.Pipeline( - prompt=data['prompt'], - negative_prompt=data['negative_prompt'], - height=self.target_height, - width=self.target_width, - num_frames=self.num_frames, - guidance_scale=self.guidance_scale, - ).frames[0] + pipe_kw = { + 'prompt': data['prompt'], + 'negative_prompt': data['negative_prompt'], + 'height': self.target_height, + 'width': self.target_width, + 'num_frames': self.num_frames, + 'guidance_scale': self.guidance_scale, + } + if self.guidance_scale_2 is not None: + pipe_kw['guidance_scale_2'] = self.guidance_scale_2 + output = model.Pipeline(**pipe_kw).frames[0] export_to_video( output, os.path.join(self.output_video_path, f'{eval_pos}_output_{i}.mp4'), @@ -77,15 +81,18 @@ def i2v_eval(self, model, testenc, bs, eval_pos): for i, data in enumerate(testenc): image, width, height = self.pre_process(model, data['image']) - output = model.Pipeline( - image=image, - prompt=data['prompt'], - negative_prompt=data['negative_prompt'], - height=height, - width=width, - num_frames=self.num_frames, - guidance_scale=self.guidance_scale, - ).frames[0] + pipe_kw = { + 'image': image, + 'prompt': data['prompt'], + 'negative_prompt': data['negative_prompt'], + 'height': height, + 'width': width, + 'num_frames': self.num_frames, + 'guidance_scale': self.guidance_scale, + } + if self.guidance_scale_2 is not None: + pipe_kw['guidance_scale_2'] = self.guidance_scale_2 + output = model.Pipeline(**pipe_kw).frames[0] export_to_video( output, @@ -98,9 +105,9 @@ def i2v_eval(self, model, 
testenc, bs, eval_pos): @torch.no_grad() def eval_func(self, model, testenc, bs, eval_pos): assert bs == 1, 'Evaluation only supports batch size = 1.' - assert self.model_type in ['WanT2V', 'WanI2V'], ( + assert self.model_type in ['WanT2V', 'WanI2V', 'Wan2T2V'], ( f"Unsupported model type '{self.model_type}'.\n" - 'Only Wan2.1 video generation models (WanT2V, WanI2V) are supported.' + 'Only Wan video generation models (WanT2V, WanI2V, Wan2T2V) are supported.' ) if self.eval_dataset_name == 't2v': return self.t2v_eval(model, testenc, bs, eval_pos) diff --git a/llmc/models/__init__.py b/llmc/models/__init__.py index 83d746254..7351995df 100755 --- a/llmc/models/__init__.py +++ b/llmc/models/__init__.py @@ -37,3 +37,4 @@ from .vit import Vit from .wan_i2v import WanI2V from .wan_t2v import WanT2V +from .wan2_2_t2v import Wan2T2V diff --git a/llmc/models/base_model.py b/llmc/models/base_model.py index 4d7dda2ae..25393a871 100755 --- a/llmc/models/base_model.py +++ b/llmc/models/base_model.py @@ -119,7 +119,7 @@ def has_bias(self): pass def build_tokenizer(self): - if self.model_type not in ['Vit', 'WanT2V', 'WanI2V']: + if self.model_type not in ['Vit', 'WanT2V', 'WanI2V', 'Wan2T2V']: assert self.tokenizer_mode in ['fast', 'slow'] self.tokenizer = AutoTokenizer.from_pretrained( self.model_path, use_fast=self.tokenizer_mode, trust_remote_code=True @@ -129,7 +129,7 @@ def build_tokenizer(self): if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token else: - self.tokenizer = None + self.tokenizer = None def get_tokenizer(self): return self.tokenizer diff --git a/llmc/models/wan2_2_t2v.py b/llmc/models/wan2_2_t2v.py new file mode 100755 index 000000000..c3db088da --- /dev/null +++ b/llmc/models/wan2_2_t2v.py @@ -0,0 +1,744 @@ +import gc +import copy +import inspect +import os +import shutil +import sys +from collections import defaultdict +from types import SimpleNamespace + +import torch +import torch.nn as nn +from diffusers import 
AutoencoderKLWan, WanPipeline +from loguru import logger + +from llmc.compression.quantization.module_utils import LlmcWanTransformerBlock +from llmc.utils.registry_factory import MODEL_REGISTRY + +from .base_model import BaseModel + + +class WanOfficialPipelineAdapter: + """Adapter that exposes Wan-Video/Wan2.2 official t2v runtime as a Pipeline-like interface.""" + + def __init__( + self, + runner, + sample_solver='unipc', + sampling_steps=40, + sample_shift=12.0, + offload_model=True, + ): + self.runner = runner + # Keep the same expert naming semantics as existing LLMC Wan2.2 flow: + # transformer -> high-noise expert, transformer_2 -> low-noise expert. + self.transformer = runner.high_noise_model + self.transformer_2 = runner.low_noise_model + self.sample_solver = sample_solver + self.sampling_steps = sampling_steps + self.sample_shift = sample_shift + self.offload_model = offload_model + self._is_wan_official = True + + @staticmethod + def _tensor_to_frames(video): + if video is None: + return [] + if not torch.is_tensor(video): + return video + + video = video.detach().cpu() + if video.dim() != 4: + raise ValueError(f'Unexpected official Wan video shape: {tuple(video.shape)}') + + # Accept [C, F, H, W] and convert to [F, C, H, W]. + if video.shape[0] in (1, 3): + video = video.permute(1, 0, 2, 3) + + if video.dtype.is_floating_point: + if video.min().item() < 0: + video = (video.clamp(-1, 1) + 1.0) / 2.0 + else: + video = video.clamp(0, 1) + video = (video * 255).round().to(torch.uint8) + elif video.dtype != torch.uint8: + video = video.to(torch.uint8) + + return [frame.permute(1, 2, 0).contiguous().numpy() for frame in video] + + def to(self, device): # noqa: ARG002 + # Keep the same API as diffusers pipeline; official runner manages model movement itself. 
+ return self + + def __call__( + self, + prompt, + negative_prompt='', + height=480, + width=832, + num_frames=81, + guidance_scale=5.0, + guidance_scale_2=None, + **kwargs, + ): + if isinstance(prompt, (list, tuple)): + prompt = prompt[0] + if isinstance(negative_prompt, (list, tuple)): + negative_prompt = negative_prompt[0] + + # Official Wan2.2 guide_scale order: (low_noise, high_noise). + guide_scale_low = guidance_scale if guidance_scale_2 is None else guidance_scale_2 + guide_scale_high = guidance_scale + + sampling_steps = kwargs.get( + 'num_inference_steps', + kwargs.get('sampling_steps', self.sampling_steps) + ) + sample_shift = kwargs.get('sample_shift', self.sample_shift) + sample_solver = kwargs.get('sample_solver', self.sample_solver) + seed = kwargs.get('seed', -1) + offload_model = kwargs.get('offload_model', self.offload_model) + + video = self.runner.generate( + input_prompt=prompt, + size=(width, height), + frame_num=num_frames, + shift=sample_shift, + sample_solver=sample_solver, + sampling_steps=sampling_steps, + guide_scale=(guide_scale_low, guide_scale_high), + n_prompt=negative_prompt if negative_prompt is not None else '', + seed=seed, + offload_model=offload_model, + ) + return SimpleNamespace(frames=[self._tensor_to_frames(video)]) + + +@MODEL_REGISTRY +class Wan2T2V(BaseModel): + """Wan2.2-T2V with MoE: two experts (high-noise + low-noise), same block structure as Wan2.1.""" + + def __init__(self, config, device_map=None, use_cache=False): + super().__init__(config, device_map, use_cache) + if 'calib' in config: + self.calib_bs = config.calib.bs + self.sample_steps = config.calib.sample_steps + self.target_height = config.calib.get('target_height', 480) + self.target_width = config.calib.get('target_width', 832) + self.num_frames = config.calib.get('num_frames', 81) + self.guidance_scale = config.calib.get('guidance_scale', 5.0) + self.guidance_scale_2 = config.calib.get('guidance_scale_2', 3.0) + else: + self.sample_steps = None + + 
@staticmethod + def _normalize_hf_repo_path(model_path): + hf_prefix = 'https://huggingface.co/' + if not isinstance(model_path, str) or not model_path.startswith(hf_prefix): + return model_path + repo_path = model_path[len(hf_prefix):].strip('/') + for marker in ['/tree/', '/blob/', '/resolve/']: + if marker in repo_path: + repo_path = repo_path.split(marker, maxsplit=1)[0] + return repo_path + + @staticmethod + def _has_diffusers_layout(model_path): + if not isinstance(model_path, str): + return False + return ( + os.path.isdir(model_path) + and os.path.isfile(os.path.join(model_path, 'model_index.json')) + and os.path.isdir(os.path.join(model_path, 'transformer')) + and os.path.isdir(os.path.join(model_path, 'vae')) + ) + + @staticmethod + def _has_wan22_native_layout(model_path): + if not isinstance(model_path, str): + return False + return ( + os.path.isdir(model_path) + and os.path.isfile(os.path.join(model_path, 'configuration.json')) + and os.path.isdir(os.path.join(model_path, 'high_noise_model')) + and os.path.isdir(os.path.join(model_path, 'low_noise_model')) + ) + + @staticmethod + def _is_wan22_native_repo_id(model_path): + if not isinstance(model_path, str): + return False + return model_path.rstrip('/\\') == 'Wan-AI/Wan2.2-T2V-A14B' + + def _should_require_official_backend(self, normalized_model_path): + if self.config.model.get('force_diffusers', False): + return False + if self.config.model.get('diffusers_path', None): + return False + if self.config.model.get('allow_diffusers_fallback', False): + return False + return ( + self._has_wan22_native_layout(normalized_model_path) + or self._is_wan22_native_repo_id(normalized_model_path) + ) + + def _import_official_wan(self): + def _import_impl(): + from wan.configs import t2v_A14B + from wan.text2video import WanT2V as WanOfficialT2V + + return t2v_A14B, WanOfficialT2V + + try: + return _import_impl() + except Exception as e: + repo_path = self.config.model.get('wan2_repo_path', None) + if repo_path 
and os.path.isdir(repo_path): + if repo_path not in sys.path: + sys.path.insert(0, repo_path) + try: + return _import_impl() + except Exception as e2: + logger.warning( + f'Failed to import official Wan2.2 from wan2_repo_path={repo_path}: {e2}' + ) + logger.warning( + 'Failed to import official Wan2.2 runtime (wan package). ' + 'Diffusers fallback depends on model.allow_diffusers_fallback/model.force_diffusers. ' + f'import_error={e}' + ) + return None, None + + def _try_build_official_wan_pipeline(self): + normalized_model_path = self._normalize_hf_repo_path(self.model_path) + if not self._has_wan22_native_layout(normalized_model_path): + return False + if self.config.model.get('force_diffusers', False): + logger.info('force_diffusers=True, skip official Wan2.2 import backend.') + return False + + t2v_A14B, WanOfficialT2V = self._import_official_wan() + if t2v_A14B is None or WanOfficialT2V is None: + return False + + wan_config = copy.deepcopy(t2v_A14B) + # Keep official defaults unless explicitly overridden by llmc config. 
        # NOTE(review): continuation of _try_build_official_wan_pipeline() — its
        # `def` line is above this view. Optional config overrides are copied onto
        # the official Wan config only when explicitly set by the user.
        if self.config.model.get('sample_steps', None) is not None:
            wan_config.sample_steps = self.config.model.sample_steps
        if self.config.model.get('sample_shift', None) is not None:
            wan_config.sample_shift = self.config.model.sample_shift
        if self.config.model.get('boundary', None) is not None:
            wan_config.boundary = self.config.model.boundary

        # Build the official Wan2.2 runner; rank/device come from the torchrun env.
        runner = WanOfficialT2V(
            config=wan_config,
            checkpoint_dir=normalized_model_path,
            device_id=int(os.environ.get('LOCAL_RANK', 0)),
            rank=int(os.environ.get('RANK', 0)),
            t5_fsdp=False,
            dit_fsdp=False,
            use_sp=False,
            t5_cpu=self.config.model.get('t5_cpu', False),
            init_on_cpu=self.config.model.get('init_on_cpu', True),
            convert_model_dtype=self.config.model.get('convert_model_dtype', False),
        )
        # Adapter exposes the official runner through the diffusers-like Pipeline
        # interface the rest of this class expects.
        self.Pipeline = WanOfficialPipelineAdapter(
            runner=runner,
            sample_solver=self.config.model.get('sample_solver', 'unipc'),
            sampling_steps=self.config.model.get(
                'sampling_steps', getattr(wan_config, 'sample_steps', 40)
            ),
            sample_shift=self.config.model.get(
                'sample_shift', getattr(wan_config, 'sample_shift', 12.0)
            ),
            offload_model=self.config.model.get('offload_model', True),
        )
        self.pipeline_model_path = normalized_model_path
        self.pipeline_source = 'wan_official'
        self.use_official_wan = True
        logger.info(
            f'Loaded Wan2.2 via official Wan runtime from native checkpoint: {normalized_model_path}'
        )
        return True

    def _resolve_pipeline_model_path(self):
        """Resolve the path/repo-id to load the diffusers pipeline from.

        Resolution order:
        1. explicit ``model.diffusers_path`` from the config;
        2. the (normalized) model path itself, if it already has diffusers layout;
        3. a sibling ``<path>-Diffusers`` directory next to a native checkpoint;
        4. the official ``Wan-AI/Wan2.2-T2V-A14B-Diffusers`` hub repo as fallback.
        """
        explicit_diffusers_path = self.config.model.get('diffusers_path', None)
        if explicit_diffusers_path is not None:
            resolved_path = self._normalize_hf_repo_path(explicit_diffusers_path)
            logger.info(f'Use explicit Wan2.2 diffusers_path: {resolved_path}')
            return resolved_path

        raw_model_path = self.model_path
        normalized_path = self._normalize_hf_repo_path(raw_model_path)

        if normalized_path != raw_model_path:
            logger.info(
                f'Normalize Wan2.2 model path from URL to repo id: {normalized_path}'
            )

        if self._has_diffusers_layout(normalized_path):
            return normalized_path

        if self._has_wan22_native_layout(normalized_path):
            # Native checkpoint: prefer a local "<path>-Diffusers" export if present.
            local_diffusers_candidate = normalized_path.rstrip('/\\') + '-Diffusers'
            if self._has_diffusers_layout(local_diffusers_candidate):
                logger.info(
                    'Detected native Wan2.2 checkpoint. '
                    f'Use local diffusers directory: {local_diffusers_candidate}'
                )
                return local_diffusers_candidate
            logger.warning(
                'Detected native Wan2.2 checkpoint layout '
                f'({normalized_path}) but no local diffusers export found. '
                'Fallback to official diffusers repo: Wan-AI/Wan2.2-T2V-A14B-Diffusers. '
                'You can set model.diffusers_path to override this behavior.'
            )
            return 'Wan-AI/Wan2.2-T2V-A14B-Diffusers'

        if normalized_path.rstrip('/\\').endswith('Wan2.2-T2V-A14B'):
            # Repo-id of the native model: map to the diffusers companion repo.
            mapped_path = normalized_path.rstrip('/\\') + '-Diffusers'
            logger.info(
                f'Map Wan2.2 native repo/path to diffusers pipeline source: {mapped_path}'
            )
            return mapped_path

        return normalized_path

    def build_model(self):
        """Load Wan2.2, preferring the official runtime, else a diffusers pipeline.

        After loading, every transformer block (of both MoE experts, if present)
        is wrapped with ``LlmcWanTransformerBlock`` so llmc can quantize per-block.
        """
        self.use_official_wan = False
        normalized_model_path = self._normalize_hf_repo_path(self.model_path)
        require_official_backend = self._should_require_official_backend(normalized_model_path)

        if self._try_build_official_wan_pipeline():
            self.find_llmc_model()
            self.find_blocks()
            logger.info(
                'Wan2.2 MoE official backend loaded: blocks=%s(+%s)',
                len(self.Pipeline.transformer.blocks),
                (
                    len(self.Pipeline.transformer_2.blocks)
                    if hasattr(self.Pipeline, 'transformer_2')
                    and self.Pipeline.transformer_2 is not None
                    else 0
                ),
            )
            logger.info('Model: %s', self.model)
            return

        if require_official_backend:
            # Native checkpoint detected but the official Wan code is missing:
            # fail loudly instead of silently degrading to the diffusers path.
            raise RuntimeError(
                'Detected Wan2.2 native source '
                f'({normalized_model_path}) but official Wan runtime is unavailable. '
                'Please install/prepare official Wan2.2 code (pip install -e /path/to/Wan2.2 '
                'or set model.wan2_repo_path). '
                'If you intentionally want Diffusers fallback, set '
                'model.allow_diffusers_fallback=True or model.force_diffusers=True.'
            )

        self.pipeline_model_path = self._resolve_pipeline_model_path()
        # VAE is kept in fp32 while the transformer runs in bf16.
        vae = AutoencoderKLWan.from_pretrained(
            self.pipeline_model_path,
            subfolder='vae',
            torch_dtype=torch.float32,
            use_safetensors=True,
        )
        # Wan2.2: one pipeline, two transformer experts (transformer + transformer_2).
        # Pipeline switches by SNR; both use WanTransformer3DModel with same block layout as Wan2.1.
        self.Pipeline = WanPipeline.from_pretrained(
            self.pipeline_model_path,
            vae=vae,
            torch_dtype=torch.bfloat16,
            use_safetensors=True,
        )
        self.find_llmc_model()
        # Wrap both experts with LlmcWanTransformerBlock (same as Wan2.1 per-block layout).
        for block_idx, block in enumerate(self.Pipeline.transformer.blocks):
            new_block = LlmcWanTransformerBlock.new(block)
            self.Pipeline.transformer.blocks[block_idx] = new_block
        if hasattr(self.Pipeline, 'transformer_2') and self.Pipeline.transformer_2 is not None:
            for block_idx, block in enumerate(self.Pipeline.transformer_2.blocks):
                new_block = LlmcWanTransformerBlock.new(block)
                self.Pipeline.transformer_2.blocks[block_idx] = new_block
            # num_transformer_blocks counts only the first expert; it is the
            # boundary used by _expert_name_from_block_idx below.
            self.num_transformer_blocks = len(self.Pipeline.transformer.blocks)
            self.blocks = list(self.Pipeline.transformer.blocks) + list(self.Pipeline.transformer_2.blocks)
            logger.info(
                'Wan2.2 MoE: both experts wrapped (high-noise + low-noise, 80 blocks total).'
            )
        else:
            self.blocks = list(self.Pipeline.transformer.blocks)
            self.num_transformer_blocks = len(self.blocks)
            logger.info('Wan2.2: single transformer wrapped (40 blocks).')
        logger.info('Model: %s', self.model)

    def find_llmc_model(self):
        """Expose the primary (high-noise) transformer as the llmc model handle."""
        self.model = self.Pipeline.transformer

    def find_blocks(self):
        """Collect quantizable blocks from both experts (second expert optional)."""
        self.blocks = list(self.Pipeline.transformer.blocks)
        self.num_transformer_blocks = len(self.blocks)
        if hasattr(self.Pipeline, 'transformer_2') and self.Pipeline.transformer_2 is not None:
            self.blocks += list(self.Pipeline.transformer_2.blocks)

    def _expert_name_from_block_idx(self, block_idx):
        """Map a flat block index to its owning expert name."""
        if block_idx < self.num_transformer_blocks:
            return 'transformer'
        return 'transformer_2'

    def get_blockwise_input(self, block_idx, fallback_input):
        """Return per-expert calibration input; fall back before calibration ran."""
        if not hasattr(self, 'blockwise_inputs'):
            return fallback_input
        return self.blockwise_inputs[self._expert_name_from_block_idx(block_idx)]

    def set_blockwise_input(self, block_idx, block_input):
        """Store per-expert calibration input (no-op before calibration ran)."""
        if not hasattr(self, 'blockwise_inputs'):
            return
        self.blockwise_inputs[self._expert_name_from_block_idx(block_idx)] = block_input

    def get_catcher(self, first_block_input):
        """Return a Catcher class that records first-block inputs.

        The Catcher appends positional arg 0 and the remaining args/kwargs into
        ``first_block_input`` and aborts the diffusion loop with ``ValueError``
        once ``self.sample_steps`` captures are collected.
        """
        sample_steps = self.sample_steps

        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module
                self.signature = inspect.signature(module.forward)
                self.step = 0

            def forward(self, *args, **kwargs):
                # Re-key extra positional args by the wrapped forward's
                # parameter names so they replay correctly later.
                params = list(self.signature.parameters.keys())
                capture_kwargs = dict(kwargs)
                for i, arg in enumerate(args):
                    if i > 0:
                        capture_kwargs[params[i]] = arg
                first_block_input['data'].append(args[0])
                first_block_input['kwargs'].append(capture_kwargs)
                self.step += 1
                if self.step == sample_steps:
                    # Deliberate abort; the calibration driver catches it.
                    raise ValueError
                else:
                    return self.module(*args, **kwargs)

        return Catcher

    @torch.no_grad()
    def collect_first_block_input(self, calib_data, padding_mask=None):
        """Run calibration prompts and capture first-block inputs per expert."""
        first_block_input = {
            'transformer': defaultdict(list),
            'transformer_2': defaultdict(list),
        }
sample_steps = self.sample_steps + + class Catcher(nn.Module): + def __init__(self, module, expert_name): + super().__init__() + self.module = module + self.signature = inspect.signature(module.forward) + self.expert_name = expert_name + + def _to_cpu(self, x): + if torch.is_tensor(x): + return x.detach().cpu() + if isinstance(x, tuple): + return tuple(self._to_cpu(t) for t in x) + return x + + def forward(self, *args, **kwargs): + params = list(self.signature.parameters.keys()) + capture_kwargs = dict(kwargs) + for i, arg in enumerate(args): + if i > 0: + capture_kwargs[params[i]] = arg + cur_num = len(first_block_input[self.expert_name]['data']) + if cur_num < sample_steps: + first_block_input[self.expert_name]['data'].append( + args[0].detach().cpu() if torch.is_tensor(args[0]) else args[0] + ) + first_block_input[self.expert_name]['kwargs'].append( + {k: self._to_cpu(v) for k, v in capture_kwargs.items()} + ) + if all(len(first_block_input[name]['data']) >= sample_steps for name in first_block_input): + raise ValueError + return self.module(*args, **kwargs) + + first_block = self.Pipeline.transformer.blocks[0] + self.Pipeline.transformer.blocks[0] = Catcher(first_block, 'transformer') + first_block_2 = None + if hasattr(self.Pipeline, 'transformer_2') and self.Pipeline.transformer_2 is not None: + first_block_2 = self.Pipeline.transformer_2.blocks[0] + self.Pipeline.transformer_2.blocks[0] = Catcher(first_block_2, 'transformer_2') + + self.Pipeline.to('cuda') + for data in calib_data: + try: + pipe_kw = { + 'prompt': data['prompt'], + 'negative_prompt': data['negative_prompt'], + 'height': self.target_height, + 'width': self.target_width, + 'num_frames': self.num_frames, + 'guidance_scale': self.guidance_scale, + } + if hasattr(self, 'guidance_scale_2'): + pipe_kw['guidance_scale_2'] = self.guidance_scale_2 + self.Pipeline(**pipe_kw) + except ValueError: + pass + gc.collect() + torch.cuda.empty_cache() + + self.Pipeline.transformer.blocks[0] = 
self.Pipeline.transformer.blocks[0].module + if first_block_2 is not None: + self.Pipeline.transformer_2.blocks[0] = self.Pipeline.transformer_2.blocks[0].module + self.Pipeline.to('cpu') + + assert len(first_block_input['transformer']['data']) > 0, 'Catch transformer input data failed.' + if hasattr(self.Pipeline, 'transformer_2') and self.Pipeline.transformer_2 is not None: + assert len(first_block_input['transformer_2']['data']) > 0, \ + 'Catch transformer_2 input data failed.' + + self.blockwise_inputs = first_block_input + self.first_block_input = self.blockwise_inputs['transformer'] + self.n_samples = sum(len(v['data']) for v in self.blockwise_inputs.values()) + logger.info( + 'Retrieved Wan2.2 calibration samples: transformer=%s, transformer_2=%s.', + len(self.blockwise_inputs['transformer']['data']), + len(self.blockwise_inputs['transformer_2']['data']), + ) + + def get_padding_mask(self): + return None + + def has_bias(self): + return True + + def __str__(self): + return '\nWan2.2 MoE Model:\n%s\nTotal params: ~27B (14B active per step)' % ( + str(self.model), + ) + + def get_layernorms_in_block(self, block): + if hasattr(block, 'affine_norm1'): + return { + 'affine_norm1': block.affine_norm1, + 'norm2': block.norm2, + 'affine_norm3': block.affine_norm3, + } + return { + 'norm1': block.norm1, + 'norm3': block.norm3, + 'norm2': block.norm2, + } + + def get_subsets_in_block(self, block): + if not hasattr(block, 'attn1'): + # Official Wan2.2 native block layout: + # self_attn/qkv/o, cross_attn/qkv/o, ffn[0|2], modulation. + return [ + { + 'layers': { + 'self_attn.q': block.self_attn.q, + 'self_attn.k': block.self_attn.k, + 'self_attn.v': block.self_attn.v, + }, + # Official Wan2.2 uses non-affine norm1/norm2 by default. + # Skip trans-based scale folding to avoid invalid ln.weight operations. 
                    'prev_op': [None],
                    'input': ['self_attn.q'],
                    'inspect': block.self_attn,
                    'has_kwargs': True,
                    'do_trans': False,
                    'sub_keys': {
                        'seq_lens': 'seq_lens',
                        'grid_sizes': 'grid_sizes',
                        'freqs': 'freqs',
                    },
                },
                {
                    'layers': {
                        'cross_attn.q': block.cross_attn.q,
                    },
                    'prev_op': [None],
                    'input': ['cross_attn.q'],
                    'inspect': block.cross_attn,
                    'has_kwargs': True,
                    'do_trans': False,
                    'sub_keys': {
                        'context': 'context',
                        'context_lens': 'context_lens',
                    },
                },
                {
                    'layers': {
                        'ffn.0': block.ffn[0],
                    },
                    'prev_op': [None],
                    'input': ['ffn.0'],
                    'inspect': block.ffn,
                    'has_kwargs': False,
                    'do_trans': False,
                },
            ]
        # Diffusers (wrapped) block layout: attn1/attn2 + ffn.net, with affine
        # norms available as prev_op for scale folding.
        return [
            {
                'layers': {
                    'attn1.to_q': block.attn1.to_q,
                    'attn1.to_k': block.attn1.to_k,
                    'attn1.to_v': block.attn1.to_v,
                },
                'prev_op': [block.affine_norm1],
                'input': ['attn1.to_q'],
                'inspect': block.attn1,
                'has_kwargs': True,
                'sub_keys': {'rotary_emb': 'rotary_emb'},
            },
            {
                'layers': {
                    'attn2.to_q': block.attn2.to_q,
                },
                'prev_op': [block.norm2],
                'input': ['attn2.to_q'],
                'inspect': block.attn2,
                'has_kwargs': True,
                'sub_keys': {'encoder_hidden_states': 'encoder_hidden_states'},
            },
            {
                'layers': {
                    'ffn.net.0.proj': block.ffn.net[0].proj,
                },
                'prev_op': [block.affine_norm3],
                'input': ['ffn.net.0.proj'],
                'inspect': block.ffn,
                'has_kwargs': True,
            },
        ]

    def find_embed_layers(self):
        # Not applicable for the video-gen wrapper; kept for interface parity.
        pass

    def get_embed_layers(self):
        # Not applicable for the video-gen wrapper; kept for interface parity.
        pass

    def get_layers_except_blocks(self):
        # Not applicable for the video-gen wrapper; kept for interface parity.
        pass

    @staticmethod
    def copy_native_checkpoint(src, dst):
        """Copy full Wan2.2 native checkpoint tree before overwriting expert safetensors."""
        if not isinstance(src, str) or not os.path.isdir(src):
            raise RuntimeError(
                'Wan2.2 official save expects a local native checkpoint directory, '
                f'but got src={src!r}.'
            )
        if os.path.abspath(src) == os.path.abspath(dst):
            raise RuntimeError(
                'Wan2.2 official save path must differ from source checkpoint path '
                f'(src=dst={src}).'
            )
        # Destructive: any existing destination tree is removed first.
        if os.path.exists(dst):
            shutil.rmtree(dst)
        shutil.copytree(src, dst)
        logger.info(f'Copied original Wan2.2 native checkpoint from {src} to {dst}')

    @staticmethod
    def validate_native_save_structure(save_path, source_path=None):
        """Verify saved directory has Wan2.2 native layout (experts + copied non-expert assets)."""
        if not os.path.isdir(save_path):
            raise RuntimeError(f'Wan2.2 saved path is not a directory: {save_path}')

        required_entries = ['configuration.json', 'high_noise_model', 'low_noise_model']
        missing_required = [
            name for name in required_entries
            if not os.path.exists(os.path.join(save_path, name))
        ]
        if missing_required:
            raise RuntimeError(
                'Wan2.2 saved structure is incomplete. Missing required entries: '
                f'{missing_required}. save_path={save_path}'
            )

        # When the source is available, also check that every non-expert
        # file/directory from the original checkpoint survived the copy.
        if isinstance(source_path, str) and os.path.isdir(source_path):
            source_entries = set(os.listdir(source_path))
            source_non_expert_entries = sorted(
                name for name in source_entries
                if name not in {'high_noise_model', 'low_noise_model'}
            )
            missing_non_expert = [
                name for name in source_non_expert_entries
                if not os.path.exists(os.path.join(save_path, name))
            ]
            if missing_non_expert:
                raise RuntimeError(
                    'Wan2.2 saved structure lost original non-expert files/directories: '
                    f'{missing_non_expert}. source_path={source_path}, save_path={save_path}'
                )

        logger.info(
            f'Wan2.2 native save structure verified. '
            f'top-level entries={sorted(os.listdir(save_path))}'
        )

    def save_wan2_2_pretrained(self, path):
        """Wan2.2-specific save: supports both the official native layout and the
        non-official (diffusers) Pipeline layout.

        This logic originally lived in the Wan2T2V branch of
        llmc/compression/quantization/base_blockwise_quantization.py.
        Only rank 0 writes; other ranks return immediately.
        """
        if int(os.environ.get('RANK', '0')) != 0:
            return

        if getattr(self.Pipeline, '_is_wan_official', False):
            # Official native layout: copy the whole checkpoint, then overwrite
            # the two expert subdirectories with quantized weights.
            src = getattr(self, 'pipeline_model_path', self.model_path)
            self.copy_native_checkpoint(src, path)

            self.Pipeline.transformer.save_pretrained(
                os.path.join(path, 'high_noise_model')
            )
            logger.info('save Wan2.2 high_noise_model done --')
            if (
                hasattr(self.Pipeline, 'transformer_2')
                and self.Pipeline.transformer_2 is not None
            ):
                self.Pipeline.transformer_2.save_pretrained(
                    os.path.join(path, 'low_noise_model')
                )
                logger.info('save Wan2.2 low_noise_model done --')

            self.validate_native_save_structure(path, source_path=src)
            return

        # Copy the full original pipeline (VAE, text encoder, tokenizer, scheduler, etc.)
        # so that non-quantized components are preserved.
        src = getattr(self, 'pipeline_model_path', self.model_path)
        copied_from_source = False
        if isinstance(src, str) and os.path.isdir(src) and os.path.abspath(src) != os.path.abspath(path):
            if os.path.exists(path):
                shutil.rmtree(path)
            shutil.copytree(src, path)
            logger.info(f'Copied original pipeline from {src} to {path}')
            copied_from_source = True

        if not copied_from_source:
            if os.path.exists(path):
                shutil.rmtree(path)
            # Fallback for remote repo-id sources: materialize all non-quantized components first.
            self.Pipeline.save_pretrained(path, safe_serialization=True)
            logger.info(
                'save Wan2.2 full pipeline done via Pipeline.save_pretrained '
                f'(source={src}) --'
            )

        # Overwrite transformer subfolder with quantized weights.
        self.Pipeline.transformer.save_pretrained(
            os.path.join(path, 'transformer')
        )
        logger.info('save Wan2.2 transformer done --')
        if (
            hasattr(self.Pipeline, 'transformer_2')
            and self.Pipeline.transformer_2 is not None
        ):
            self.Pipeline.transformer_2.save_pretrained(
                os.path.join(path, 'transformer_2')
            )
            logger.info('save Wan2.2 transformer_2 done --')

    def skip_layer_name(self):
        # Not applicable for the video-gen wrapper; kept for interface parity.
        pass
diff --git a/llmc/models/wan_t2v.py b/llmc/models/wan_t2v.py
index 885bccda3..59696686d 100755
--- a/llmc/models/wan_t2v.py
+++ b/llmc/models/wan_t2v.py
@@ -162,4 +162,4 @@
     def get_layers_except_blocks(self):
         pass

     def skip_layer_name(self):
-        pass
+        pass
\ No newline at end of file
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index 5869fa8d0..8fd082be7 100755
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -6,6 +6,7 @@
 loguru
 transformers>=4.45.2
 lmms-eval==0.3.0
 huggingface-hub
+safetensors
 sentencepiece
 protobuf
 accelerate>=0.26.0