Skip to content
200 changes: 197 additions & 3 deletions python_coreml_stable_diffusion/torch2coreml.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from copy import deepcopy
import coremltools as ct
from diffusers import StableDiffusionPipeline
from diffusers.models.vae import DiagonalGaussianDistribution
import gc

import logging
Expand All @@ -29,11 +30,22 @@
import torch.nn as nn
import torch.nn.functional as F

#from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op
#from coremltools.converters.mil.frontend.torch.ops import _get_inputs
#from coremltools.converters.mil import Builder as mb
#
#@register_torch_op
#def randn(context, node):
# inputs = _get_inputs(context, node, expected=5)
# shape = inputs[0]
#
# x = mb.random_normal(shape=shape, mean=0., stddev=1.)
# context.add(x, node.name)

torch.set_grad_enabled(False)

from types import MethodType


def _get_coreml_inputs(sample_inputs, args):
return [
ct.TensorType(
Expand All @@ -43,6 +55,23 @@ def _get_coreml_inputs(sample_inputs, args):
) for k, v in sample_inputs.items()
]

# Simpler version of `DiagonalGaussianDistribution` with only needed calculations
# as implemented in vae.py as part of the AutoencoderKL class
# This is because coreml tools does not support the `randn` operation, so we pass in a random tensor.
class CoreMLDiagonalGaussianDistribution(object):
def __init__(self, parameters, noise):
self.parameters = parameters
self.noise = noise
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.std = torch.exp(0.5 * self.logvar)

def sample(self) -> torch.FloatTensor:
device = self.parameters.device
# make sure sample is on the same device as the parameters and has same dtype
sample = self.noise.to(device=device, dtype=self.parameters.dtype)
x = self.mean + self.std * sample
return x

def compute_psnr(a, b):
""" Compute Peak-Signal-to-Noise-Ratio across two numpy.ndarray objects
Expand Down Expand Up @@ -140,7 +169,7 @@ def _convert_to_coreml(submodule_name, torchscript_module, sample_inputs,

def quantize_weights_to_8bits(args):
for model_name in [
"text_encoder", "vae_decoder", "unet", "unet_chunk1",
"text_encoder", "vae_decoder", "vae_encoder", "unet", "unet_chunk1",
"unet_chunk2", "safety_checker"
]:
out_path = _get_out_path(args, model_name)
Expand Down Expand Up @@ -190,6 +219,7 @@ def bundle_resources_for_swift_cli(args):
# Compile model using coremlcompiler (Significantly reduces the load time for unet)
for source_name, target_name in [("text_encoder", "TextEncoder"),
("vae_decoder", "VAEDecoder"),
("vae_encoder", "VAEEncoder"),
("unet", "Unet"),
("unet_chunk1", "UnetChunk1"),
("unet_chunk2", "UnetChunk2"),
Expand Down Expand Up @@ -453,6 +483,164 @@ def forward(self, z):
gc.collect()


def convert_vae_encoder(pipe, args):
""" Converts the VAE Encoder component of Stable Diffusion
"""
out_path = _get_out_path(args, "vae_encoder")
if os.path.exists(out_path):
logger.info(
f"`vae_encoder` already exists at {out_path}, skipping conversion."
)
return

if not hasattr(pipe, "unet"):
raise RuntimeError(
"convert_unet() deletes pipe.unet to save RAM. "
"Please use convert_vae_encoder() before convert_unet()")

sample_shape = (
1, # B
3, # C (RGB range from -1 to 1)
args.latent_h or pipe.unet.config.sample_size * 8, # H
args.latent_w or pipe.unet.config.sample_size * 8, # w
)

noise_shape = (
1, # B
4, # C
pipe.unet.config.sample_size, # H
pipe.unet.config.sample_size, # w
)

float_value_shape = (
1,
1,
)

sqrtAlphasCumprodTorchShape = torch.tensor([[0.2,]])
sqrtOneMinusAlphasCumprodTorchShape = torch.tensor([[0.8,]])

sample_vae_encoder_inputs = {
"sample": torch.rand(*sample_shape, dtype=torch.float16),
"diagonalNoise": torch.rand(*noise_shape, dtype=torch.float16),
"noise": torch.rand(*noise_shape, dtype=torch.float16),
"sqrtAlphasCumprod": torch.rand(*float_value_shape, dtype=torch.float16),
"sqrtOneMinusAlphasCumprod": torch.rand(*float_value_shape, dtype=torch.float16),
}

class VAEEncoder(nn.Module):
""" Wrapper nn.Module wrapper for pipe.encode() method
"""

def __init__(self):
super().__init__()
self.quant_conv = pipe.vae.quant_conv
self.alphas_cumprod = pipe.scheduler.alphas_cumprod
self.encoder = pipe.vae.encoder

# Because CoreMLTools does not support the torch.randn op, we pass in both
# the diagonal Noise for the `DiagonalGaussianDistribution` operation and
# the noise tensor combined with precalculated `sqrtAlphasCumprod` and `sqrtOneMinusAlphasCumprod`
# for faster computation.
def forward(self, sample, diagonalNoise, noise, sqrtAlphasCumprod, sqrtOneMinusAlphasCumprod):
h = self.encoder(sample)
moments = self.quant_conv(h)
diagonalNoise = diagonalNoise.to(sample.device)
# posterior = DiagonalGaussianDistribution(moments)
posterior = CoreMLDiagonalGaussianDistribution(moments, diagonalNoise)
posteriorSample = posterior.sample()

# Add the scaling operation and the latent noise for faster computation
init_latents = 0.18215 * posteriorSample
result = self.add_noise(init_latents, noise, sqrtAlphasCumprod, sqrtOneMinusAlphasCumprod)
return result

def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
sqrtAlphasCumprod: torch.FloatTensor,
sqrtOneMinusAlphasCumprod: torch.FloatTensor
) -> torch.FloatTensor:
noise = noise.to(original_samples.device)
sqrtAlphasCumprod = sqrtAlphasCumprod.to(original_samples.device)
sqrtOneMinusAlphasCumprod = sqrtOneMinusAlphasCumprod.to(original_samples.device)
noisy_samples = sqrtAlphasCumprod * original_samples + sqrtOneMinusAlphasCumprod * noise
return noisy_samples


baseline_encoder = VAEEncoder().eval()

# No optimization needed for the VAE Encoder as it is a pure ConvNet
traced_vae_encoder = torch.jit.trace(
baseline_encoder, (
sample_vae_encoder_inputs["sample"].to(torch.float32),
sample_vae_encoder_inputs["diagonalNoise"].to(torch.float32),
sample_vae_encoder_inputs["noise"].to(torch.float32),
sqrtAlphasCumprodTorchShape.to(torch.float32),
sqrtOneMinusAlphasCumprodTorchShape.to(torch.float32)
))

modify_coremltools_torch_frontend_badbmm()
coreml_vae_encoder, out_path = _convert_to_coreml(
"vae_encoder", traced_vae_encoder, sample_vae_encoder_inputs,
["latent_dist"], args)

# Set model metadata
coreml_vae_encoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
coreml_vae_encoder.license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
coreml_vae_encoder.version = args.model_version
coreml_vae_encoder.short_description = \
"Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. " \
"Please refer to https://arxiv.org/abs/2112.10752 for details."

# Set the input descriptions
coreml_vae_encoder.input_description["sample"] = \
"An image of the correct size to create the latent space with, image2image and in-painting."
coreml_vae_encoder.input_description["diagonalNoise"] = \
"Latent noise for `DiagonalGaussianDistribution` operation."
coreml_vae_encoder.input_description["noise"] = \
"Latent noise for use with strength parameter of image2image"
coreml_vae_encoder.input_description["sqrtAlphasCumprod"] = \
"Precalculated `sqrtAlphasCumprod` value based on strength and the current schedular's alphasCumprod values"
coreml_vae_encoder.input_description["sqrtOneMinusAlphasCumprod"] = \
"Precalculated `sqrtOneMinusAlphasCumprod` value based on strength and the current schedular's alphasCumprod values"

# Set the output descriptions
coreml_vae_encoder.output_description[
"latent_dist"] = "The latent embeddings from the unet model from the input image for image2image."

_save_mlpackage(coreml_vae_encoder, out_path)

logger.info(f"Saved vae_encoder into {out_path}")

# Parity check PyTorch vs CoreML
if args.check_output_correctness:
baseline_out = baseline_encoder(
sample=sample_vae_encoder_inputs["sample"].to(torch.float32),
diagonalNoise=sample_vae_encoder_inputs["diagonalNoise"].to(torch.float32),
noise=sample_vae_encoder_inputs["noise"].to(torch.float32),
sqrtAlphasCumprod=sqrtAlphasCumprodTorchShape,
sqrtOneMinusAlphasCumprod=sqrtOneMinusAlphasCumprodTorchShape,
).numpy(),

coreml_out = list(
coreml_vae_encoder.predict(
{
"sample": sample_vae_encoder_inputs["sample"].numpy(),
"diagonalNoise": sample_vae_encoder_inputs["diagonalNoise"].numpy(),
"noise": sample_vae_encoder_inputs["noise"].numpy(),
"sqrtAlphasCumprod": sqrtAlphasCumprodTorchShape.numpy(),
"sqrtOneMinusAlphasCumprod": sqrtOneMinusAlphasCumprodTorchShape.numpy()
}).values())

report_correctness(baseline_out[0], coreml_out[0],
"vae_encoder baseline PyTorch to baseline CoreML")

del traced_vae_encoder, pipe.vae.encoder, coreml_vae_encoder
gc.collect()


def convert_unet(pipe, args):
""" Converts the UNet component of Stable Diffusion
"""
Expand Down Expand Up @@ -801,7 +989,12 @@ def main(args):
logger.info("Converting vae_decoder")
convert_vae_decoder(pipe, args)
logger.info("Converted vae_decoder")


if args.convert_vae_encoder:
logger.info("Converting vae_encoder")
convert_vae_encoder(pipe, args)
logger.info("Converted vae_encoder")

if args.convert_unet:
logger.info("Converting unet")
convert_unet(pipe, args)
Expand Down Expand Up @@ -835,6 +1028,7 @@ def parser_spec():
# Select which models to export (All are needed for text-to-image pipeline to function)
parser.add_argument("--convert-text-encoder", action="store_true")
parser.add_argument("--convert-vae-decoder", action="store_true")
parser.add_argument("--convert-vae-encoder", action="store_true")
parser.add_argument("--convert-unet", action="store_true")
parser.add_argument("--convert-safety-checker", action="store_true")
parser.add_argument(
Expand Down
29 changes: 29 additions & 0 deletions swift/StableDiffusion/pipeline/AlphasCumprodCalculation.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.

import Foundation

public struct AlphasCumprodCalculation {
public var sqrtAlphasCumprod: Float
public var sqrtOneMinusAlphasCumprod: Float

public init(
sqrtAlphasCumprod: Float,
sqrtOneMinusAlphasCumprod: Float
) {
self.sqrtAlphasCumprod = sqrtAlphasCumprod
self.sqrtOneMinusAlphasCumprod = sqrtOneMinusAlphasCumprod
}

public init(
alphasCumprod: [Float],
timesteps: Int = 1_000,
steps: Int,
strength: Float
) {
let tEnc = Int(strength * Float(steps))
let initTimestep = min(max(0, timesteps - timesteps / steps * (steps - tEnc) + 1), timesteps - 1)
self.sqrtAlphasCumprod = alphasCumprod[initTimestep].squareRoot()
self.sqrtOneMinusAlphasCumprod = (1 - alphasCumprod[initTimestep]).squareRoot()
}
}
Loading