Skip to content

Commit 05fdf47

Browse files
Author: Yuwei Guo — committed: "optimize memory cost"
1 parent 41a698a commit 05fdf47

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

animatediff/pipelines/pipeline_animation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import numpy as np
88
import torch
9+
from tqdm import tqdm
910

1011
from diffusers.utils import is_accelerate_available
1112
from packaging import version
@@ -239,7 +240,11 @@ def decode_latents(self, latents):
239240
video_length = latents.shape[2]
240241
latents = 1 / 0.18215 * latents
241242
latents = rearrange(latents, "b c f h w -> (b f) c h w")
242-
video = self.vae.decode(latents).sample
243+
# video = self.vae.decode(latents).sample
244+
video = []
245+
for frame_idx in tqdm(range(latents.shape[0])):
246+
video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
247+
video = torch.cat(video)
243248
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
244249
video = (video / 2 + 0.5).clamp(0, 1)
245250
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16

scripts/animate.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from animatediff.utils.util import save_videos_grid
1818
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
1919
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
20+
from diffusers.utils.import_utils import is_xformers_available
2021

2122
from einops import rearrange, repeat
2223

@@ -51,6 +52,9 @@ def main(args):
5152
vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
5253
unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
5354

55+
if is_xformers_available(): unet.enable_xformers_memory_efficient_attention()
56+
else: assert False
57+
5458
pipeline = AnimationPipeline(
5559
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
5660
scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),

0 commit comments

Comments (0)