|
15 | 15 | from animatediff.models.unet import UNet3DConditionModel |
16 | 16 | from animatediff.pipelines.pipeline_animation import AnimationPipeline |
17 | 17 | from animatediff.utils.util import save_videos_grid |
18 | | -from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint |
19 | | -from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora |
| 18 | +from animatediff.utils.util import load_weights |
20 | 19 | from diffusers.utils.import_utils import is_xformers_available |
21 | 20 |
|
22 | 21 | from einops import rearrange, repeat |
23 | 22 |
|
24 | 23 | import csv, pdb, glob |
25 | | -from safetensors import safe_open |
26 | 24 | import math |
27 | 25 | from pathlib import Path |
28 | 26 |
|
@@ -60,50 +58,16 @@ def main(args): |
60 | 58 | scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), |
61 | 59 | ).to("cuda") |
62 | 60 |
|
63 | | - # 1. unet ckpt |
64 | | - # 1.1 motion module |
65 | | - motion_module_state_dict = torch.load(motion_module, map_location="cpu") |
66 | | - if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]}) |
67 | | - missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False) |
68 | | - assert len(unexpected) == 0 |
69 | | - |
70 | | - # 1.2 T2I |
71 | | - if model_config.path != "": |
72 | | - if model_config.path.endswith(".ckpt"): |
73 | | - state_dict = torch.load(model_config.path) |
74 | | - pipeline.unet.load_state_dict(state_dict) |
75 | | - |
76 | | - elif model_config.path.endswith(".safetensors"): |
77 | | - state_dict = {} |
78 | | - with safe_open(model_config.path, framework="pt", device="cpu") as f: |
79 | | - for key in f.keys(): |
80 | | - state_dict[key] = f.get_tensor(key) |
81 | | - |
82 | | - is_lora = all("lora" in k for k in state_dict.keys()) |
83 | | - if not is_lora: |
84 | | - base_state_dict = state_dict |
85 | | - else: |
86 | | - base_state_dict = {} |
87 | | - with safe_open(model_config.base, framework="pt", device="cpu") as f: |
88 | | - for key in f.keys(): |
89 | | - base_state_dict[key] = f.get_tensor(key) |
90 | | - |
91 | | - # vae |
92 | | - converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config) |
93 | | - pipeline.vae.load_state_dict(converted_vae_checkpoint) |
94 | | - # unet |
95 | | - converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config) |
96 | | - pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) |
97 | | - # text_model |
98 | | - pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict) |
99 | | - |
100 | | - # import pdb |
101 | | - # pdb.set_trace() |
102 | | - if is_lora: |
103 | | - pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha) |
104 | | - |
105 | | - pipeline.to("cuda") |
106 | | - ### <<< create validation pipeline <<< ### |
| 61 | + pipeline = load_weights( |
| 62 | + pipeline, |
| 63 | + # motion module |
| 64 | + motion_module_path = motion_module, |
| 65 | + motion_module_lora_configs = model_config.get("motion_module_lora_configs", []), |
| 66 | + # image layers |
| 67 | + dreambooth_model_path = model_config.get("dreambooth_path", ""), |
| 68 | + lora_model_path = model_config.get("lora_model_path", ""), |
| 69 | + lora_alpha = model_config.get("lora_alpha", 0.8), |
| 70 | + ).to("cuda") |
107 | 71 |
|
108 | 72 | prompts = model_config.prompt |
109 | 73 | n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt |
|
0 commit comments