diff --git a/diffsynth_engine/distributed/parallel_state.py b/diffsynth_engine/distributed/parallel_state.py index d404d75b..91b31cce 100644 --- a/diffsynth_engine/distributed/parallel_state.py +++ b/diffsynth_engine/distributed/parallel_state.py @@ -214,12 +214,29 @@ def get_ranks(self, token): return ranks +# WORLD +def is_world_group_initialized() -> bool: + return _WORLD is not None + + def get_world_group() -> GroupCoordinator: assert _WORLD is not None, "world group is not initialized" return _WORLD +def get_global_world_size(): + return get_world_group().world_size + + +def get_global_rank(): + return get_world_group().rank_in_group + + # TP +def is_tp_group_initialized() -> bool: + return _TP is not None + + def get_tp_group() -> GroupCoordinator: assert _TP is not None, "tensor model parallel group is not initialized" return _TP @@ -236,6 +253,10 @@ def get_tensor_model_parallel_rank(): # SP +def is_sp_group_initialized() -> bool: + return _SP is not None + + def get_sp_group() -> SequenceParallelGroupCoordinator: assert _SP is not None, "pipeline model parallel group is not initialized" return _SP @@ -268,6 +289,10 @@ def get_ring_parallel_rank(): # PP +def is_pp_group_initialized() -> bool: + return _PP is not None + + def get_pp_group() -> PipelineGroupCoordinator: assert _PP is not None, "pipeline model parallel group is not initialized" return _PP @@ -294,6 +319,10 @@ def is_pipeline_last_stage(): # CFG +def is_cfg_group_initialized() -> bool: + return _CFG is not None + + def get_cfg_group() -> GroupCoordinator: assert _CFG is not None, "classifier_free_guidance parallel group is not initialized" return _CFG @@ -310,6 +339,10 @@ def get_classifier_free_guidance_rank(): # DP +def is_dp_group_initialized() -> bool: + return _DP is not None + + def get_dp_group() -> GroupCoordinator: assert _DP is not None, "pipeline model parallel group is not initialized" return _DP @@ -346,6 +379,10 @@ def get_dit_world_size(): # VAE +def is_vae_group_initialized() 
-> bool: + return _VAE is not None + + def get_vae_parallel_group() -> GroupCoordinator: assert _VAE is not None, "VAE parallel group is not initialized" return _VAE @@ -491,6 +528,10 @@ def init_dit_group( _DIT = torch.distributed.new_group(ranks=list(range(dit_parallel_size)), backend=backend) +def is_dit_group_initialized() -> bool: + return _DIT is not None + + def get_dit_group(): assert _DIT is not None, "DIT group is not initialized" return _DIT diff --git a/diffsynth_engine/engine.py b/diffsynth_engine/engine.py index 97eb7764..60cf6f18 100644 --- a/diffsynth_engine/engine.py +++ b/diffsynth_engine/engine.py @@ -139,6 +139,10 @@ def shutdown(self): self.workers = None self.conns = None + + if self.pipeline is not None: + del self.pipeline + self.pipeline = None def start_profile(self, path: str = ".", profile_rank0_only: bool = True): if self.workers is not None: diff --git a/diffsynth_engine/layers/attention/layer.py b/diffsynth_engine/layers/attention/layer.py index 4df7bc3d..f98678b2 100644 --- a/diffsynth_engine/layers/attention/layer.py +++ b/diffsynth_engine/layers/attention/layer.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 import torch -import torch.distributed as dist import torch.nn as nn from diffsynth_engine.distributed.comm import SeqAllToAll4D @@ -11,6 +10,7 @@ get_ring_parallel_world_size, get_sp_group, get_ulysses_parallel_world_size, + is_sp_group_initialized, ) from diffsynth_engine.forward_context import ForwardContext, get_forward_context from diffsynth_engine.layers.attention.backends.abstract import AttentionType @@ -139,8 +139,8 @@ def forward( attn_kwargs = {"attn_metadata": attn_metadata} attn_kwargs.update(kwargs) - ulysses_parallel_world_size = get_ulysses_parallel_world_size() if dist.is_initialized() else 1 - ring_parallel_world_size = get_ring_parallel_world_size() if dist.is_initialized() else 1 + ulysses_parallel_world_size = get_ulysses_parallel_world_size() if is_sp_group_initialized() else 1 + 
ring_parallel_world_size = get_ring_parallel_world_size() if is_sp_group_initialized() else 1 if ulysses_parallel_world_size > 1: q = SeqAllToAll4D.apply(get_sp_group().ulysses_group, q, self.scatter_idx, self.gather_idx) diff --git a/diffsynth_engine/models/base.py b/diffsynth_engine/models/base.py index 43b90672..a21405a7 100644 --- a/diffsynth_engine/models/base.py +++ b/diffsynth_engine/models/base.py @@ -15,6 +15,13 @@ class DiffusionModel(nn.Module, ConfigMixin): config_name = CONFIG_NAME + @property + def dtype(self) -> torch.dtype: + param = next(self.parameters(), None) + if param is None: + raise RuntimeError(f"{type(self).__name__} has no parameters, cannot determine dtype") + return param.dtype + @classmethod def from_pretrained( cls, diff --git a/diffsynth_engine/models/qwen_image/autoencoder_kl_qwenimage.py b/diffsynth_engine/models/qwen_image/autoencoder_kl_qwenimage.py index 554986b4..632f365e 100644 --- a/diffsynth_engine/models/qwen_image/autoencoder_kl_qwenimage.py +++ b/diffsynth_engine/models/qwen_image/autoencoder_kl_qwenimage.py @@ -21,7 +21,6 @@ # - Paper: https://huggingface.co/papers/2503.20314 import torch -import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from diffusers.configuration_utils import register_to_config @@ -1167,8 +1166,8 @@ def parallel_tiled_encode(self, x: torch.Tensor) -> torch.Tensor: self.clear_cache() - dist.all_reduce(values, group=group) - dist.all_reduce(weight, group=group) + group.all_reduce(values) + group.all_reduce(weight) enc = values / weight return enc @@ -1247,8 +1246,8 @@ def parallel_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> De self.clear_cache() - dist.all_reduce(values, group=group) - dist.all_reduce(weight, group=group) + group.all_reduce(values) + group.all_reduce(weight) dec = values / weight dec = torch.clamp(dec, min=-1.0, max=1.0) diff --git a/diffsynth_engine/pipelines/base.py b/diffsynth_engine/pipelines/base.py index b409b00b..55d48f2b 
100644 --- a/diffsynth_engine/pipelines/base.py +++ b/diffsynth_engine/pipelines/base.py @@ -1,8 +1,19 @@ +from typing import Type + import torch -import torch.distributed as dist +import torch.nn as nn +from accelerate import init_empty_weights +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict from tqdm import tqdm from diffsynth_engine.configs import PipelineConfig +from diffsynth_engine.distributed.parallel_state import get_global_rank, is_world_group_initialized +from diffsynth_engine.forward_context import set_forward_context +from diffsynth_engine.utils import logging +from diffsynth_engine.utils.load_utils import fix_state_dict_key, load_model_weights + +logger = logging.get_logger(__name__) class Pipeline: @@ -17,6 +28,122 @@ def from_pretrained(cls, model_path_or_config: str | PipelineConfig): def __call__(self, *args, **kwargs): raise NotImplementedError() + @staticmethod + def init_transformer(model_cls: Type[nn.Module], pipeline_config: PipelineConfig, empty_weights: bool = False): + use_fsdp = pipeline_config.use_fsdp and is_world_group_initialized() + + with set_forward_context(attn_type=pipeline_config.attn_type): + with init_empty_weights(): + config = model_cls.load_config( + pipeline_config.model_path, + subfolder="transformer", + local_files_only=True, + ) + model = model_cls.from_config(config) + + if empty_weights: + return model + + if use_fsdp: + for block in model.transformer_blocks: + fully_shard(block) + fully_shard(model) + + state_dict = load_model_weights( + pipeline_config.model_path, + subfolder="transformer", + device="cpu" if use_fsdp else pipeline_config.device, + dtype=pipeline_config.model_dtype, + broadcast_from_rank0=not use_fsdp, + ) + + if use_fsdp: + set_model_state_dict( + model, + state_dict, + options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), + ) + else: + model.load_state_dict(state_dict, 
strict=True, assign=True) + model.to(device=pipeline_config.device) + + del state_dict + return model + + @staticmethod + def init_text_encoder( + model_cls: Type[nn.Module], + pipeline_config: PipelineConfig, + key_mapping: dict = None, + empty_weights: bool = False, + ): + use_fsdp = pipeline_config.use_fsdp and is_world_group_initialized() + + with init_empty_weights(): + config = model_cls.config_class.from_pretrained( + pipeline_config.model_path, + subfolder="text_encoder", + local_files_only=True, + ) + model = model_cls(config) + + if empty_weights: + return model + + if use_fsdp: + for layer in model.model.language_model.layers: + fully_shard(layer) + fully_shard(model) + + state_dict = load_model_weights( + pipeline_config.model_path, + subfolder="text_encoder", + device="cpu" if use_fsdp else pipeline_config.device, + dtype=pipeline_config.text_encoder_dtype, + broadcast_from_rank0=not use_fsdp, + ) + + if key_mapping: + state_dict = fix_state_dict_key(state_dict, key_mapping) + + if use_fsdp: + set_model_state_dict( + model, + state_dict, + options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), + ) + else: + model.load_state_dict(state_dict, strict=True, assign=True) + model.to(device=pipeline_config.device) + + del state_dict + return model + + @staticmethod + def init_vae(model_cls: Type[nn.Module], pipeline_config: PipelineConfig, empty_weights: bool = False): + with init_empty_weights(): + config = model_cls.load_config( + pipeline_config.model_path, + subfolder="vae", + local_files_only=True, + ) + model = model_cls.from_config(config) + + if empty_weights: + return model + + state_dict = load_model_weights( + pipeline_config.model_path, + subfolder="vae", + device=pipeline_config.device, + dtype=pipeline_config.vae_dtype, + ) + model.load_state_dict(state_dict, strict=True, assign=True) + model.to(device=pipeline_config.device) + + del state_dict + return model + @torch.compiler.disable def progress_bar(self, iterable=None, 
total=None): if not hasattr(self, "_progress_bar_config"): @@ -28,7 +155,7 @@ def progress_bar(self, iterable=None, total=None): progress_bar_config = dict(self._progress_bar_config) if "disable" not in progress_bar_config: - is_rank_zero = not dist.is_initialized() or dist.get_rank() == 0 + is_rank_zero = not is_world_group_initialized() or get_global_rank() == 0 progress_bar_config["disable"] = not is_rank_zero if iterable is not None: diff --git a/diffsynth_engine/pipelines/qwen_image/pipeline_qwenimage.py b/diffsynth_engine/pipelines/qwen_image/pipeline_qwenimage.py index feec5281..ec2498a9 100644 --- a/diffsynth_engine/pipelines/qwen_image/pipeline_qwenimage.py +++ b/diffsynth_engine/pipelines/qwen_image/pipeline_qwenimage.py @@ -20,22 +20,19 @@ import numpy as np import torch -from accelerate import init_empty_weights from diffusers.image_processor import VaeImageProcessor -from diffusers.models import AutoencoderKLQwenImage from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils.torch_utils import randn_tensor -from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer from diffsynth_engine.configs import QwenImagePipelineConfig from diffsynth_engine.distributed.parallel_state import get_cfg_group, model_parallel_is_initialized from diffsynth_engine.forward_context import set_forward_context from diffsynth_engine.layers.attention import get_attn_backend -from diffsynth_engine.models.qwen_image import QwenImageTransformer2DModel +from diffsynth_engine.models.qwen_image import QwenImageTransformer2DModel, AutoencoderKLQwenImage from diffsynth_engine.pipelines.base import Pipeline from diffsynth_engine.utils import logging -from diffsynth_engine.utils.load_utils import fix_state_dict_key, load_model_weights logger = 
logging.get_logger(__name__) @@ -184,7 +181,7 @@ def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig): raise FileNotFoundError(f"Model path not found: {pipeline_config.model_path}") # Load transformer - transformer = cls.init_transformer(pipeline_config) + transformer = cls.init_transformer(QwenImageTransformer2DModel, pipeline_config) # Load scheduler scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( @@ -193,10 +190,14 @@ def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig): ) # Load VAE - vae = cls.init_vae(pipeline_config) + vae = cls.init_vae(AutoencoderKLQwenImage, pipeline_config) # Load text encoder - text_encoder = cls.init_text_encoder(pipeline_config) + text_encoder = cls.init_text_encoder( + Qwen2_5_VLForConditionalGeneration, + pipeline_config, + key_mapping=getattr(Qwen2_5_VLForConditionalGeneration, "_checkpoint_conversion_mapping", None), + ) # Load tokenizer tokenizer = Qwen2Tokenizer.from_pretrained( @@ -214,77 +215,6 @@ def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig): transformer=transformer, ) - @staticmethod - def init_transformer(pipeline_config: QwenImagePipelineConfig, empty_weights: bool = False): - logger.info("Initializing transformer...") - with set_forward_context(attn_type=pipeline_config.attn_type): - if empty_weights: - with init_empty_weights(): - config_dict = QwenImageTransformer2DModel.load_config( - pipeline_config.model_path, - subfolder="transformer", - local_files_only=True, - ) - model = QwenImageTransformer2DModel.from_config(config_dict) - else: - model = QwenImageTransformer2DModel.from_pretrained( - pipeline_config.model_path, - subfolder="transformer", - device=pipeline_config.device, - dtype=pipeline_config.model_dtype, - ) - return model - - @staticmethod - def init_text_encoder(pipeline_config: QwenImagePipelineConfig, empty_weights: bool = False): - logger.info("Initializing text encoder...") - with init_empty_weights(): - 
config = Qwen2_5_VLConfig.from_pretrained( - pipeline_config.model_path, - subfolder="text_encoder", - local_files_only=True, - ) - model = Qwen2_5_VLForConditionalGeneration(config) - - if empty_weights: - return model - - state_dict = load_model_weights( - pipeline_config.model_path, - subfolder="text_encoder", - device=pipeline_config.device, - dtype=pipeline_config.text_encoder_dtype, - ) - if key_mapping := getattr(model, "_checkpoint_conversion_mapping", None): - state_dict = fix_state_dict_key(state_dict, key_mapping) - model.load_state_dict(state_dict, strict=True, assign=True) - model.to(device=pipeline_config.device) - return model - - @staticmethod - def init_vae(pipeline_config: QwenImagePipelineConfig, empty_weights: bool = False): - logger.info("Initializing VAE...") - with init_empty_weights(): - config_dict = AutoencoderKLQwenImage.load_config( - pipeline_config.model_path, - subfolder="vae", - local_files_only=True, - ) - model = AutoencoderKLQwenImage.from_config(config_dict) - - if empty_weights: - return model - - state_dict = load_model_weights( - pipeline_config.model_path, - subfolder="vae", - device=pipeline_config.device, - dtype=pipeline_config.vae_dtype, - ) - model.load_state_dict(state_dict, strict=True, assign=True) - model.to(device=pipeline_config.device) - return model - def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): bool_mask = mask.bool() valid_lengths = bool_mask.sum(dim=1) @@ -300,7 +230,7 @@ def _get_qwen_prompt_embeds( dtype: Optional[torch.dtype] = None, ): device = device or self.device - dtype = dtype or self.text_encoder.dtype + dtype = dtype or self.pipeline_config.text_encoder_dtype prompt = [prompt] if isinstance(prompt, str) else prompt @@ -838,7 +768,7 @@ def __call__( image = latents else: latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) - latents = latents.to(self.vae.dtype) + latents = latents.to(self.pipeline_config.vae_dtype) latents_mean = ( 
torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) diff --git a/diffsynth_engine/pipelines/qwen_image/pipeline_qwenimage_edit.py b/diffsynth_engine/pipelines/qwen_image/pipeline_qwenimage_edit.py index 8520885f..597b2e4f 100644 --- a/diffsynth_engine/pipelines/qwen_image/pipeline_qwenimage_edit.py +++ b/diffsynth_engine/pipelines/qwen_image/pipeline_qwenimage_edit.py @@ -21,22 +21,22 @@ import numpy as np import torch -from accelerate import init_empty_weights from diffusers.image_processor import PipelineImageInput, VaeImageProcessor -from diffusers.models import AutoencoderKLQwenImage from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils.torch_utils import randn_tensor -from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor from diffsynth_engine.configs.qwen_image import QwenImagePipelineConfig -from diffsynth_engine.distributed.parallel_state import get_cfg_group, model_parallel_is_initialized +from diffsynth_engine.distributed.parallel_state import ( + get_cfg_group, + model_parallel_is_initialized, +) from diffsynth_engine.forward_context import set_forward_context from diffsynth_engine.layers.attention import get_attn_backend -from diffsynth_engine.models.qwen_image import QwenImageTransformer2DModel +from diffsynth_engine.models.qwen_image import QwenImageTransformer2DModel, AutoencoderKLQwenImage from diffsynth_engine.pipelines.base import Pipeline from diffsynth_engine.utils import logging -from diffsynth_engine.utils.load_utils import fix_state_dict_key, load_model_weights logger = logging.get_logger(__name__) @@ -218,7 +218,7 @@ def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig): raise FileNotFoundError(f"Model path not found: 
{pipeline_config.model_path}") # Load transformer - transformer = cls.init_transformer(pipeline_config) + transformer = cls.init_transformer(QwenImageTransformer2DModel, pipeline_config) # Load scheduler scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( @@ -227,10 +227,14 @@ def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig): ) # Load VAE - vae = cls.init_vae(pipeline_config) + vae = cls.init_vae(AutoencoderKLQwenImage, pipeline_config) # Load text encoder - text_encoder = cls.init_text_encoder(pipeline_config) + text_encoder = cls.init_text_encoder( + Qwen2_5_VLForConditionalGeneration, + pipeline_config, + key_mapping=getattr(Qwen2_5_VLForConditionalGeneration, "_checkpoint_conversion_mapping", None), + ) # Load tokenizer tokenizer = Qwen2Tokenizer.from_pretrained( @@ -254,76 +258,6 @@ def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig): transformer=transformer, ) - @staticmethod - def init_transformer(pipeline_config: QwenImagePipelineConfig, empty_weights: bool = False): - with set_forward_context(attn_type=pipeline_config.attn_type): - if empty_weights: - with init_empty_weights(): - config_dict = QwenImageTransformer2DModel.load_config( - pipeline_config.model_path, - subfolder="transformer", - local_files_only=True, - ) - model = QwenImageTransformer2DModel.from_config(config_dict) - else: - model = QwenImageTransformer2DModel.from_pretrained( - pipeline_config.model_path, - subfolder="transformer", - device=pipeline_config.device, - dtype=pipeline_config.model_dtype, - ) - return model - - @staticmethod - def init_text_encoder(pipeline_config: QwenImagePipelineConfig, empty_weights: bool = False): - logger.info("Initializing text encoder...") - with init_empty_weights(): - config = Qwen2_5_VLConfig.from_pretrained( - pipeline_config.model_path, - subfolder="text_encoder", - local_files_only=True, - ) - model = Qwen2_5_VLForConditionalGeneration(config) - - if empty_weights: - return model - - 
state_dict = load_model_weights( - pipeline_config.model_path, - subfolder="text_encoder", - device=pipeline_config.device, - dtype=pipeline_config.text_encoder_dtype, - ) - if key_mapping := getattr(model, "_checkpoint_conversion_mapping", None): - state_dict = fix_state_dict_key(state_dict, key_mapping) - model.load_state_dict(state_dict, strict=True, assign=True) - model.to(device=pipeline_config.device) - return model - - @staticmethod - def init_vae(pipeline_config: QwenImagePipelineConfig, empty_weights: bool = False): - logger.info("Initializing VAE...") - with init_empty_weights(): - config_dict = AutoencoderKLQwenImage.load_config( - pipeline_config.model_path, - subfolder="vae", - local_files_only=True, - ) - model = AutoencoderKLQwenImage.from_config(config_dict) - - if empty_weights: - return model - - state_dict = load_model_weights( - pipeline_config.model_path, - subfolder="vae", - device=pipeline_config.device, - dtype=pipeline_config.vae_dtype, - ) - model.load_state_dict(state_dict, strict=True, assign=True) - model.to(device=pipeline_config.device) - return model - def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): bool_mask = mask.bool() valid_lengths = bool_mask.sum(dim=1) @@ -339,7 +273,7 @@ def _get_qwen_prompt_embeds( dtype: Optional[torch.dtype] = None, ): device = device or self.device - dtype = dtype or self.text_encoder.dtype + dtype = dtype or self.pipeline_config.text_encoder_dtype prompt = [prompt] if isinstance(prompt, str) else prompt @@ -499,7 +433,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor): return latents def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): - image = image.to(dtype=self.vae.dtype) + image = image.to(dtype=self.pipeline_config.vae_dtype) if isinstance(generator, list): image_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") @@ -951,7 +885,7 @@ def __call__( image = latents 
else: latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) - latents = latents.to(self.vae.dtype) + latents = latents.to(self.pipeline_config.vae_dtype) latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) diff --git a/diffsynth_engine/pipelines/qwen_image/pipeline_qwenimage_edit_plus.py b/diffsynth_engine/pipelines/qwen_image/pipeline_qwenimage_edit_plus.py index fd53084e..1e4482eb 100644 --- a/diffsynth_engine/pipelines/qwen_image/pipeline_qwenimage_edit_plus.py +++ b/diffsynth_engine/pipelines/qwen_image/pipeline_qwenimage_edit_plus.py @@ -21,22 +21,22 @@ import numpy as np import torch -from accelerate import init_empty_weights from diffusers.image_processor import PipelineImageInput, VaeImageProcessor -from diffusers.models import AutoencoderKLQwenImage from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils.torch_utils import randn_tensor -from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor from diffsynth_engine.configs.qwen_image import QwenImagePipelineConfig -from diffsynth_engine.distributed.parallel_state import get_cfg_group, model_parallel_is_initialized +from diffsynth_engine.distributed.parallel_state import ( + get_cfg_group, + model_parallel_is_initialized, +) from diffsynth_engine.forward_context import set_forward_context from diffsynth_engine.layers.attention import get_attn_backend -from diffsynth_engine.models.qwen_image import QwenImageTransformer2DModel +from diffsynth_engine.models.qwen_image import QwenImageTransformer2DModel, AutoencoderKLQwenImage from diffsynth_engine.pipelines.base import Pipeline from diffsynth_engine.utils import logging -from diffsynth_engine.utils.load_utils import 
fix_state_dict_key, load_model_weights logger = logging.get_logger(__name__) @@ -224,7 +224,7 @@ def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig): raise FileNotFoundError(f"Model path not found: {pipeline_config.model_path}") # Load transformer - transformer = cls.init_transformer(pipeline_config) + transformer = cls.init_transformer(QwenImageTransformer2DModel, pipeline_config) # Load scheduler scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( @@ -233,10 +233,14 @@ def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig): ) # Load VAE - vae = cls.init_vae(pipeline_config) + vae = cls.init_vae(AutoencoderKLQwenImage, pipeline_config) # Load text encoder - text_encoder = cls.init_text_encoder(pipeline_config) + text_encoder = cls.init_text_encoder( + Qwen2_5_VLForConditionalGeneration, + pipeline_config, + key_mapping=getattr(Qwen2_5_VLForConditionalGeneration, "_checkpoint_conversion_mapping", None), + ) # Load tokenizer tokenizer = Qwen2Tokenizer.from_pretrained( @@ -261,76 +265,6 @@ def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig): transformer=transformer, ) - @staticmethod - def init_transformer(pipeline_config: QwenImagePipelineConfig, empty_weights: bool = False): - with set_forward_context(attn_type=pipeline_config.attn_type): - if empty_weights: - with init_empty_weights(): - config_dict = QwenImageTransformer2DModel.load_config( - pipeline_config.model_path, - subfolder="transformer", - local_files_only=True, - ) - model = QwenImageTransformer2DModel.from_config(config_dict) - else: - model = QwenImageTransformer2DModel.from_pretrained( - pipeline_config.model_path, - subfolder="transformer", - device=pipeline_config.device, - dtype=pipeline_config.model_dtype, - ) - return model - - @staticmethod - def init_text_encoder(pipeline_config: QwenImagePipelineConfig, empty_weights: bool = False): - logger.info("Initializing text encoder...") - with init_empty_weights(): - 
config = Qwen2_5_VLConfig.from_pretrained( - pipeline_config.model_path, - subfolder="text_encoder", - local_files_only=True, - ) - model = Qwen2_5_VLForConditionalGeneration(config) - - if empty_weights: - return model - - state_dict = load_model_weights( - pipeline_config.model_path, - subfolder="text_encoder", - device=pipeline_config.device, - dtype=pipeline_config.text_encoder_dtype, - ) - if key_mapping := getattr(model, "_checkpoint_conversion_mapping", None): - state_dict = fix_state_dict_key(state_dict, key_mapping) - model.load_state_dict(state_dict, strict=True, assign=True) - model.to(device=pipeline_config.device) - return model - - @staticmethod - def init_vae(pipeline_config: QwenImagePipelineConfig, empty_weights: bool = False): - logger.info("Initializing VAE...") - with init_empty_weights(): - config_dict = AutoencoderKLQwenImage.load_config( - pipeline_config.model_path, - subfolder="vae", - local_files_only=True, - ) - model = AutoencoderKLQwenImage.from_config(config_dict) - - if empty_weights: - return model - - state_dict = load_model_weights( - pipeline_config.model_path, - subfolder="vae", - device=pipeline_config.device, - dtype=pipeline_config.vae_dtype, - ) - model.load_state_dict(state_dict, strict=True, assign=True) - model.to(device=pipeline_config.device) - return model - def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): bool_mask = mask.bool() valid_lengths = bool_mask.sum(dim=1) @@ -358,7 +292,7 @@ def _get_qwen_prompt_embeds( dtype: Target dtype """ device = device or self.device - dtype = dtype or self.text_encoder.dtype + dtype = dtype or self.pipeline_config.text_encoder_dtype prompt = [prompt] if isinstance(prompt, str) else prompt @@ -538,7 +472,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor): return latents def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): - image = image.to(self.vae.dtype) + image = image.to(self.pipeline_config.vae_dtype) if 
isinstance(generator, list): image_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") @@ -1040,7 +974,7 @@ def __call__( image = latents else: latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) - latents = latents.to(self.vae.dtype) + latents = latents.to(self.pipeline_config.vae_dtype) latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) diff --git a/diffsynth_engine/pipelines/qwen_image/pipeline_qwenimage_layered.py b/diffsynth_engine/pipelines/qwen_image/pipeline_qwenimage_layered.py index 3e98748f..212e3ea5 100644 --- a/diffsynth_engine/pipelines/qwen_image/pipeline_qwenimage_layered.py +++ b/diffsynth_engine/pipelines/qwen_image/pipeline_qwenimage_layered.py @@ -21,22 +21,22 @@ import numpy as np import torch -from accelerate import init_empty_weights from diffusers.image_processor import PipelineImageInput, VaeImageProcessor -from diffusers.models import AutoencoderKLQwenImage from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils.torch_utils import randn_tensor -from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor from diffsynth_engine.configs.qwen_image import QwenImagePipelineConfig -from diffsynth_engine.distributed.parallel_state import get_cfg_group, model_parallel_is_initialized +from diffsynth_engine.distributed.parallel_state import ( + get_cfg_group, + model_parallel_is_initialized, +) from diffsynth_engine.forward_context import set_forward_context from diffsynth_engine.layers.attention import get_attn_backend -from diffsynth_engine.models.qwen_image import QwenImageTransformer2DModel +from diffsynth_engine.models.qwen_image import 
QwenImageTransformer2DModel, AutoencoderKLQwenImage from diffsynth_engine.pipelines.base import Pipeline from diffsynth_engine.utils import logging -from diffsynth_engine.utils.load_utils import fix_state_dict_key, load_model_weights logger = logging.get_logger(__name__) @@ -240,7 +240,7 @@ def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig): raise FileNotFoundError(f"Model path not found: {pipeline_config.model_path}") # Load transformer - transformer = cls.init_transformer(pipeline_config) + transformer = cls.init_transformer(QwenImageTransformer2DModel, pipeline_config) # Load scheduler scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( @@ -249,10 +249,14 @@ def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig): ) # Load VAE - vae = cls.init_vae(pipeline_config) + vae = cls.init_vae(AutoencoderKLQwenImage, pipeline_config) # Load text encoder - text_encoder = cls.init_text_encoder(pipeline_config) + text_encoder = cls.init_text_encoder( + Qwen2_5_VLForConditionalGeneration, + pipeline_config, + key_mapping=getattr(Qwen2_5_VLForConditionalGeneration, "_checkpoint_conversion_mapping", None), + ) # Load tokenizer tokenizer = Qwen2Tokenizer.from_pretrained( @@ -277,76 +281,6 @@ def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig): transformer=transformer, ) - @staticmethod - def init_transformer(pipeline_config: QwenImagePipelineConfig, empty_weights: bool = False): - with set_forward_context(attn_type=pipeline_config.attn_type): - if empty_weights: - with init_empty_weights(): - config_dict = QwenImageTransformer2DModel.load_config( - pipeline_config.model_path, - subfolder="transformer", - local_files_only=True, - ) - model = QwenImageTransformer2DModel.from_config(config_dict) - else: - model = QwenImageTransformer2DModel.from_pretrained( - pipeline_config.model_path, - subfolder="transformer", - device=pipeline_config.device, - dtype=pipeline_config.model_dtype, - ) - return 
model - - @staticmethod - def init_text_encoder(pipeline_config: QwenImagePipelineConfig, empty_weights: bool = False): - logger.info("Initializing text encoder...") - with init_empty_weights(): - config = Qwen2_5_VLConfig.from_pretrained( - pipeline_config.model_path, - subfolder="text_encoder", - local_files_only=True, - ) - model = Qwen2_5_VLForConditionalGeneration(config) - - if empty_weights: - return model - - state_dict = load_model_weights( - pipeline_config.model_path, - subfolder="text_encoder", - device=pipeline_config.device, - dtype=pipeline_config.text_encoder_dtype, - ) - if key_mapping := getattr(model, "_checkpoint_conversion_mapping", None): - state_dict = fix_state_dict_key(state_dict, key_mapping) - model.load_state_dict(state_dict, strict=True, assign=True) - model.to(device=pipeline_config.device) - return model - - @staticmethod - def init_vae(pipeline_config: QwenImagePipelineConfig, empty_weights: bool = False): - logger.info("Initializing VAE...") - with init_empty_weights(): - config_dict = AutoencoderKLQwenImage.load_config( - pipeline_config.model_path, - subfolder="vae", - local_files_only=True, - ) - model = AutoencoderKLQwenImage.from_config(config_dict) - - if empty_weights: - return model - - state_dict = load_model_weights( - pipeline_config.model_path, - subfolder="vae", - device=pipeline_config.device, - dtype=pipeline_config.vae_dtype, - ) - model.load_state_dict(state_dict, strict=True, assign=True) - model.to(device=pipeline_config.device) - return model - def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): bool_mask = mask.bool() valid_lengths = bool_mask.sum(dim=1) @@ -362,7 +296,7 @@ def _get_qwen_prompt_embeds( dtype: Optional[torch.dtype] = None, ): device = device or self.device - dtype = dtype or self.text_encoder.dtype + dtype = dtype or self.pipeline_config.text_encoder_dtype prompt = [prompt] if isinstance(prompt, str) else prompt @@ -536,7 +470,7 @@ def _unpack_latents(latents, 
height, width, layers, vae_scale_factor): return latents def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): - image = image.to(self.vae.dtype) + image = image.to(self.pipeline_config.vae_dtype) if isinstance(generator, list): image_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") @@ -906,7 +840,7 @@ def __call__( prompt_image = image image = self.image_processor.preprocess(image, calculated_height, calculated_width) image = image.unsqueeze(2) - image = image.to(dtype=self.text_encoder.dtype) + image = image.to(dtype=self.pipeline_config.text_encoder_dtype) if prompt is None or prompt == "" or prompt == " ": prompt = self.get_image_caption(prompt_image, use_en_prompt=use_en_prompt, device=device) @@ -1053,7 +987,7 @@ def __call__( image = latents else: latents = self._unpack_latents(latents, height, width, layers, self.vae_scale_factor) - latents = latents.to(self.vae.dtype) + latents = latents.to(self.pipeline_config.vae_dtype) latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) diff --git a/diffsynth_engine/utils/load_utils.py b/diffsynth_engine/utils/load_utils.py index 20d2d8e6..9c2ab043 100644 --- a/diffsynth_engine/utils/load_utils.py +++ b/diffsynth_engine/utils/load_utils.py @@ -5,8 +5,8 @@ from typing import Any, Dict, Optional import torch -import torch.distributed as dist +from diffsynth_engine.distributed.parallel_state import get_global_rank, get_world_group, is_world_group_initialized from diffsynth_engine.utils import logging from diffsynth_engine.utils.constants import ( DIFFUSION_SAFETENSORS_INDEX_NAME, @@ -28,7 +28,7 @@ def load_safetensors(path: str, device: str = "cpu") -> Dict[str, Any]: - is_rank_zero = not dist.is_initialized() or dist.get_rank() == 0 + is_rank_zero = not is_world_group_initialized() or get_global_rank() == 0 start_time = time.perf_counter() if FAST_SAFETENSORS_AVAILABLE: if is_rank_zero: @@ 
def load_model_weights(
    model_path: str,
    subfolder: Optional[str] = None,
    device: Optional[str] = None,
    dtype: Optional[torch.dtype] = None,
    broadcast_from_rank0: bool = True,
) -> Dict[str, Any]:
    """Load a safetensors checkpoint (single file or sharded) into a state dict.

    The weights directory is ``model_path`` joined with ``subfolder`` when one
    is given. Inside it, the first existing file is used, in this order:
    diffusion index, diffusion weights, generic index, generic weights. An
    index file is JSON whose ``"weight_map"`` values name the shard files.

    When a world group is initialized and ``broadcast_from_rank0`` is true,
    only global rank 0 reads from disk and the result is handed to
    ``world_group.broadcast_tensor_dict(..., src=0)``. In every other case
    (no world group, or ``broadcast_from_rank0=False``) the calling rank reads
    the files itself.

    Args:
        model_path: Root directory of the model.
        subfolder: Optional sub-directory of ``model_path`` holding the weights.
        device: Target device passed to ``Tensor.to`` for every tensor.
        dtype: Target dtype passed to ``Tensor.to`` for every tensor.
        broadcast_from_rank0: Share rank 0's state dict through the world
            group instead of reading the files on every rank.

    Returns:
        The state dict: loaded locally, or the value returned by
        ``broadcast_tensor_dict`` when broadcasting.

    Raises:
        FileNotFoundError: If the weights directory, or any recognised
            index/weights file inside it, does not exist on a reading rank.
    """
    world_group = get_world_group() if is_world_group_initialized() else None
    is_rank_zero = world_group is None or get_global_rank() == 0
    use_broadcast = broadcast_from_rank0 and world_group is not None

    state_dict: Dict[str, Any] = {}

    # A rank must read from disk unless it is going to receive the weights
    # through the broadcast below. Gating on rank alone would leave non-zero
    # ranks with an empty dict whenever broadcast_from_rank0 is False.
    if is_rank_zero or not use_broadcast:
        weights_dir = model_path if subfolder is None else os.path.join(model_path, subfolder)
        if not os.path.exists(weights_dir):
            raise FileNotFoundError(f"Model path not found: {weights_dir}")

        # (file name, is an index file) in lookup priority order.
        candidates = (
            (DIFFUSION_SAFETENSORS_INDEX_NAME, True),
            (DIFFUSION_SAFETENSORS_WEIGHTS_NAME, False),
            (SAFETENSORS_INDEX_NAME, True),
            (SAFETENSORS_WEIGHTS_NAME, False),
        )
        for file_name, is_index in candidates:
            found_file = os.path.join(weights_dir, file_name)
            if os.path.exists(found_file):
                break
        else:
            raise FileNotFoundError(f"Safetensors index or weights file not found in {weights_dir}")

        if is_index:
            with open(found_file, "r", encoding="utf-8") as f:
                weight_map = json.load(f)["weight_map"]
            # Several tensors map to the same shard; read each shard once,
            # in a deterministic order.
            shard_paths = [os.path.join(weights_dir, name) for name in sorted(set(weight_map.values()))]
        else:
            shard_paths = [found_file]

        for shard_path in shard_paths:
            shard_dict = load_safetensors(shard_path)
            # Cast shard by shard, replacing values in place (the key set is
            # unchanged, so mutating during iteration is safe). Presumably
            # this keeps peak host memory to about one shard; not measured.
            for key, tensor in shard_dict.items():
                shard_dict[key] = tensor.to(device=device, dtype=dtype, non_blocking=True)
            state_dict.update(shard_dict)

    # NOTE(review): if rank 0 raises above, the other ranks still enter this
    # collective and will presumably block; confirm how callers handle that.
    if use_broadcast:
        state_dict = world_group.broadcast_tensor_dict(state_dict, src=0)

    return state_dict