Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions ldm/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,23 +174,27 @@ def forward(self, x, context=None, mask=None):
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
del context, x

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # (8, 4096, 40)
del q, k

if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
del mask

# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
# attention, what we cannot get enough of, by halves
sim[4:] = sim[4:].softmax(dim=-1)
sim[:4] = sim[:4].softmax(dim=-1)

out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
sim = einsum('b i j, b j d -> b i d', sim, v)
sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
return self.to_out(sim)


class BasicTransformerBlock(nn.Module):
Expand Down Expand Up @@ -258,4 +262,4 @@ def forward(self, x, context=None):
x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = self.proj_out(x)
return x + x_in
return x + x_in
246 changes: 124 additions & 122 deletions optimizedSD/inpaint_gradio.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
import argparse
import os
import re
import time
from contextlib import nullcontext
from itertools import islice
from random import randint

import gradio as gr
import numpy as np
import torch
from torchvision.utils import make_grid
import os, re
from PIL import Image
import torch
import numpy as np
from random import randint
from einops import rearrange, repeat
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from einops import rearrange, repeat
from contextlib import nullcontext
from ldm.util import instantiate_from_config
from torchvision.utils import make_grid
from tqdm import tqdm, trange
from transformers import logging
import pandas as pd

from ldm.util import instantiate_from_config
from optimUtils import split_weighted_subprompts, logger

logging.set_verbosity_error()
import mimetypes

mimetypes.init()
mimetypes.add_type("application/javascript", ".js")

Expand All @@ -43,7 +43,6 @@ def load_model_from_config(ckpt, verbose=False):


def load_img(image, h0, w0):

image = image.convert("RGB")
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
Expand All @@ -60,96 +59,59 @@ def load_img(image, h0, w0):
return 2.0 * image - 1.0


def load_mask(mask, h0, w0, newH, newW, invert=False):
    """Turn the painted mask into a float tensor of size ``newH`` x ``newW``.

    Args:
        mask: PIL image containing the mask.
        h0: Requested image height. Accepted so existing callers keep
            working; the output size is governed by ``newH``/``newW`` only.
        w0: Requested image width; see ``h0``.
        newH: Height of the returned mask (the caller passes the latent
            height so the mask lines up with the latent tensor).
        newW: Width of the returned mask.
        invert: When True, invert the mask (``255 - value`` per pixel).

    Returns:
        A float32 ``torch.Tensor`` of shape ``(1, 3, newH, newW)`` with
        values in ``[0, 1]``.
    """
    image = mask.convert("RGB")
    w, h = image.size
    print(f"loaded input mask of size ({w}, {h})")

    # Report the size the mask is really resized to, not a separately
    # rounded copy of the requested image size.
    print(f"New mask size ({newW}, {newH})")
    image = image.resize((newW, newH), resample=Image.LANCZOS)
    image = np.array(image)  # uint8, laid out as (H, W, C)

    if invert:
        print("inverted")
        # LANCZOS resampling leaves intermediate grey levels along the mask
        # edge; invert every level instead of swapping only pure 0 and 255,
        # otherwise the soft edge ends up with the wrong polarity.
        image = 255 - image
    image = image.astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # (H, W, C) -> (1, C, H, W)
    return torch.from_numpy(image)


# Module-level model set-up: load the checkpoint, rename the UNet weights to
# the model1./model2. prefixes, then build the three sub-models.
config = "optimizedSD/v1-inference.yaml"
ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"
sd = load_model_from_config(f"{ckpt}")
# li: keys to move under "model1."; lo: keys to move under "model2.".
li, lo = [], []
# Collect the keys first; the dict is only modified in the loops further down.
for key, v_ in sd.items():
sp = key.split(".")
if (sp[0]) == "model":
if "input_blocks" in sp:
li.append(key)
elif "middle_block" in sp:
li.append(key)
elif "time_embed" in sp:
li.append(key)
else:
lo.append(key)
# key[6:] drops the leading "model." before the new prefix is prepended.
for key in li:
sd["model1." + key[6:]] = sd.pop(key)
for key in lo:
sd["model2." + key[6:]] = sd.pop(key)

# From here on `config` is the loaded OmegaConf object, no longer the path.
config = OmegaConf.load(f"{config}")

# The same state dict is offered to all three sub-models with strict=False,
# so keys that do not belong to a given sub-model do not raise.
model = instantiate_from_config(config.modelUNet)
_, _ = model.load_state_dict(sd, strict=False)
model.eval()

modelCS = instantiate_from_config(config.modelCondStage)
_, _ = modelCS.load_state_dict(sd, strict=False)
modelCS.eval()

modelFS = instantiate_from_config(config.modelFirstStage)
_, _ = modelFS.load_state_dict(sd, strict=False)
modelFS.eval()
# Drop the reference to the checkpoint dict once all sub-models are loaded.
del sd

def generate(
image,
prompt,
strength,
ddim_steps,
n_iter,
batch_size,
Height,
Width,
scale,
ddim_eta,
unet_bs,
device,
seed,
outdir,
img_format,
turbo,
full_precision,
image,
prompt,
strength,
ddim_steps,
n_iter,
batch_size,
Height,
Width,
scale,
ddim_eta,
unet_bs,
device,
seed,
outdir,
img_format,
turbo,
full_precision,
):

if seed == "":
seed = randint(0, 1000000)
seed = int(seed)
seed_everything(seed)
sampler = "ddim"

# Logging
logger(locals(), log_csv = "logs/inpaint_gradio_logs.csv")
logger(locals(), log_csv="logs/inpaint_gradio_logs.csv")

init_image = load_img(image['image'], Height, Width).to(device)
mask = load_mask(image['mask'], Height, Width, True).to(device)

model.unet_bs = unet_bs
model.turbo = turbo
Expand All @@ -161,10 +123,7 @@ def generate(
modelCS.half()
modelFS.half()
init_image = init_image.half()
mask.half()

mask = mask[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
# mask.half()

tic = time.time()
os.makedirs(outdir, exist_ok=True)
Expand All @@ -182,6 +141,10 @@ def generate(
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
init_latent = repeat(init_latent, "1 ... -> b ...", b=batch_size)

mask = load_mask(image['mask'], Height, Width, init_latent.shape[2], init_latent.shape[3], True).to(device)
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
mask = repeat(mask, '1 ... -> b ...', b=batch_size)

if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelFS.to("cpu")
Expand Down Expand Up @@ -237,23 +200,22 @@ def generate(
z_enc = model.stochastic_encode(
init_latent, torch.tensor([t_enc] * batch_size).to(device),
seed, ddim_eta, ddim_steps)

# decode it
samples_ddim = model.sample(
t_enc,
c,
z_enc,
unconditional_guidance_scale=scale,
unconditional_conditioning=uc,
mask = mask,
x_T = init_latent,
sampler = sampler,
mask=mask,
x_T=init_latent,
sampler=sampler,
)

modelFS.to(device)
print("saving images")
for i in range(batch_size):

x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
all_samples.append(x_sample.to("cpu"))
Expand Down Expand Up @@ -284,37 +246,77 @@ def generate(
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()

txt = (
"Samples finished in "
+ str(round(time_taken, 3))
+ " minutes and exported to \n"
+ sample_path
+ "\nSeeds used = "
+ seeds[:-1]
"Samples finished in "
+ str(round(time_taken, 3))
+ " minutes and exported to \n"
+ sample_path
+ "\nSeeds used = "
+ seeds[:-1]
)
return Image.fromarray(grid.astype(np.uint8)), image['mask'], txt


def remap_unet_keys(sd):
    """Rename UNet weights in ``sd`` in place to ``model1.``/``model2.`` prefixes.

    Keys whose first dotted component is ``model`` are moved to
    ``model1.<rest>`` when any component is ``input_blocks``,
    ``middle_block`` or ``time_embed``, and to ``model2.<rest>`` otherwise.
    All other keys are left untouched.

    Args:
        sd: Mutable mapping of parameter names to values.

    Returns:
        The same mapping object, after renaming.
    """
    first_half = ("input_blocks", "middle_block", "time_embed")
    to_model1, to_model2 = [], []
    # Collect first: a dict must not change size while it is being iterated.
    for key in sd:
        parts = key.split(".")
        if parts[0] != "model":
            continue
        if any(name in parts for name in first_half):
            to_model1.append(key)
        else:
            to_model2.append(key)
    # key[6:] strips the leading "model." before the new prefix is prepended.
    for key in to_model1:
        sd["model1." + key[6:]] = sd.pop(key)
    for key in to_model2:
        sd["model2." + key[6:]] = sd.pop(key)
    return sd


if __name__ == '__main__':
    # Command-line options: where to find the model config and the checkpoint.
    parser = argparse.ArgumentParser(description='inpaint using gradio')
    parser.add_argument('--config_path', default="optimizedSD/v1-inference.yaml", type=str, help='config path')
    parser.add_argument('--ckpt_path', default="models/ldm/stable-diffusion-v1/model.ckpt", type=str, help='ckpt path')
    args = parser.parse_args()
    ckpt = args.ckpt_path

    sd = remap_unet_keys(load_model_from_config(ckpt))
    config = OmegaConf.load(args.config_path)

    # The same state dict is offered to all three sub-models; strict=False
    # keeps keys that belong to the other sub-models from raising.
    model = instantiate_from_config(config.modelUNet)
    model.load_state_dict(sd, strict=False)
    model.eval()

    modelCS = instantiate_from_config(config.modelCondStage)
    modelCS.load_state_dict(sd, strict=False)
    modelCS.eval()

    modelFS = instantiate_from_config(config.modelFirstStage)
    modelFS.load_state_dict(sd, strict=False)
    modelFS.eval()
    # Drop the reference to the checkpoint dict once all sub-models are loaded.
    del sd

    # Inputs are listed in the order of generate()'s parameters.
    demo = gr.Interface(
        fn=generate,
        inputs=[
            gr.Image(tool="sketch", type="pil"),        # image (dict with 'image' and 'mask')
            "text",                                     # prompt
            gr.Slider(0, 0.99, value=0.99, step=0.01),  # strength
            gr.Slider(1, 1000, value=50),               # ddim_steps
            gr.Slider(1, 100, step=1),                  # n_iter
            gr.Slider(1, 100, step=1),                  # batch_size
            gr.Slider(64, 4096, value=512, step=64),    # Height
            gr.Slider(64, 4096, value=512, step=64),    # Width
            gr.Slider(0, 50, value=7.5, step=0.1),      # scale
            gr.Slider(0, 1, step=0.01),                 # ddim_eta
            gr.Slider(1, 2, value=1, step=1),           # unet_bs
            gr.Text(value="cuda"),                      # device
            "text",                                     # seed ("" picks a random one)
            gr.Text(value="outputs/inpaint-samples"),   # outdir
            gr.Radio(["png", "jpg"], value='png'),      # img_format
            "checkbox",                                 # turbo
            "checkbox",                                 # full_precision
        ],
        # generate() returns (result grid image, mask image, status text).
        outputs=["image", "image", "text"],
    )
return Image.fromarray(grid.astype(np.uint8)), image['mask'],txt


# Gradio UI for generate(); the inputs are listed in the order of
# generate()'s parameters.
demo = gr.Interface(
fn=generate,
inputs=[
gr.Image(tool="sketch", type="pil"),  # image (dict with 'image' and 'mask')
"text",  # prompt
gr.Slider(0, 0.99, value=0.99, step = 0.01),  # strength
gr.Slider(1, 1000, value=50),  # ddim_steps
gr.Slider(1, 100, step=1),  # n_iter
gr.Slider(1, 100, step=1),  # batch_size
gr.Slider(64, 4096, value=512, step=64),  # Height
gr.Slider(64, 4096, value=512, step=64),  # Width
gr.Slider(0, 50, value=7.5, step=0.1),  # scale
gr.Slider(0, 1, step=0.01),  # ddim_eta
gr.Slider(1, 2, value=1, step=1),  # unet_bs
gr.Text(value="cuda"),  # device
"text",  # seed ("" picks a random one)
gr.Text(value="outputs/inpaint-samples"),  # outdir
gr.Radio(["png", "jpg"], value='png'),  # img_format
"checkbox",  # turbo
"checkbox",  # full_precision
],
# generate() returns (result grid image, mask image, status text).
outputs=["image", "image", "text"],
)
# NOTE(review): launch() appears twice here; looks like diff residue — confirm.
demo.launch()
demo.launch()