Skip to content

Commit d09e5cf

Browse files
author
Yuwei Guo
committed
update infer script
1 parent 4498f76 commit d09e5cf

File tree

1 file changed

+11
-47
lines changed

1 file changed

+11
-47
lines changed

scripts/animate.py

Lines changed: 11 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -15,14 +15,12 @@
1515
from animatediff.models.unet import UNet3DConditionModel
1616
from animatediff.pipelines.pipeline_animation import AnimationPipeline
1717
from animatediff.utils.util import save_videos_grid
18-
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
19-
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
18+
from animatediff.utils.util import load_weights
2019
from diffusers.utils.import_utils import is_xformers_available
2120

2221
from einops import rearrange, repeat
2322

2423
import csv, pdb, glob
25-
from safetensors import safe_open
2624
import math
2725
from pathlib import Path
2826

@@ -60,50 +58,16 @@ def main(args):
6058
scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
6159
).to("cuda")
6260

63-
# 1. unet ckpt
64-
# 1.1 motion module
65-
motion_module_state_dict = torch.load(motion_module, map_location="cpu")
66-
if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
67-
missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
68-
assert len(unexpected) == 0
69-
70-
# 1.2 T2I
71-
if model_config.path != "":
72-
if model_config.path.endswith(".ckpt"):
73-
state_dict = torch.load(model_config.path)
74-
pipeline.unet.load_state_dict(state_dict)
75-
76-
elif model_config.path.endswith(".safetensors"):
77-
state_dict = {}
78-
with safe_open(model_config.path, framework="pt", device="cpu") as f:
79-
for key in f.keys():
80-
state_dict[key] = f.get_tensor(key)
81-
82-
is_lora = all("lora" in k for k in state_dict.keys())
83-
if not is_lora:
84-
base_state_dict = state_dict
85-
else:
86-
base_state_dict = {}
87-
with safe_open(model_config.base, framework="pt", device="cpu") as f:
88-
for key in f.keys():
89-
base_state_dict[key] = f.get_tensor(key)
90-
91-
# vae
92-
converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
93-
pipeline.vae.load_state_dict(converted_vae_checkpoint)
94-
# unet
95-
converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
96-
pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
97-
# text_model
98-
pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)
99-
100-
# import pdb
101-
# pdb.set_trace()
102-
if is_lora:
103-
pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)
104-
105-
pipeline.to("cuda")
106-
### <<< create validation pipeline <<< ###
61+
pipeline = load_weights(
62+
pipeline,
63+
# motion module
64+
motion_module_path = motion_module,
65+
motion_module_lora_configs = model_config.get("motion_module_lora_configs", []),
66+
# image layers
67+
dreambooth_model_path = model_config.get("dreambooth_path", ""),
68+
lora_model_path = model_config.get("lora_model_path", ""),
69+
lora_alpha = model_config.get("lora_alpha", 0.8),
70+
).to("cuda")
10771

10872
prompts = model_config.prompt
10973
n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt

0 commit comments

Comments (0)