StableDiffusion/scripts/streamlit/stableunclip.py

392 lines
15 KiB
Python
Raw Normal View History

2023-01-14 12:48:28 +00:00
import importlib
import streamlit as st
import torch
import cv2
import numpy as np
import PIL
from omegaconf import OmegaConf
from PIL import Image
from tqdm import trange
import io, os
from torch import autocast
from einops import rearrange, repeat
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from contextlib import nullcontext
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
# This script only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

PROMPTS_ROOT = "scripts/prompts/"
SAVE_PATH = "outputs/demo/stable-unclip/"

# Per-version sampling specs, read by the UI code at the bottom of the file:
# H/W are the default output size offered in the sidebar, C and f are used to
# build the latent shape [C, H // f, W // f]. "Full Karlo" has no specs.
VERSION2SPECS = {
    "Stable unCLIP-L": dict(H=768, W=768, C=4, f=8),
    "Stable unOpenCLIP-H": dict(H=768, W=768, C=4, f=8),
    "Full Karlo": dict(),
}
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to the object.

    The text after the last dot is looked up as an attribute of the module
    named by everything before it.

    Args:
        string: Dotted import path; must contain at least one dot.
        reload: When true, the module is re-imported via ``importlib.reload``
            before the attribute is fetched.

    Returns:
        The attribute found on the imported module.
    """
    module_path, attr_name = string.rsplit(".", 1)
    importlib.invalidate_caches()
    if reload:
        importlib.reload(importlib.import_module(module_path))
    target_module = importlib.import_module(module_path, package=None)
    return getattr(target_module, attr_name)
def instantiate_from_config(config):
    """Build the object described by a ``target``/``params`` config mapping.

    ``config["target"]`` is a dotted import path resolved with
    ``get_obj_from_str``; the resolved callable is invoked with the entries
    of ``config["params"]`` (if any) as keyword arguments.

    Raises:
        KeyError: If ``config`` has no ``target`` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    factory = get_obj_from_str(config["target"])
    kwargs = config.get("params", dict())
    return factory(**kwargs)
def get_interactive_image():
    """Show a file-upload widget and return the uploaded picture, if any.

    Returns:
        The uploaded file opened as a PIL image and converted to RGB when
        its mode is anything else, or ``None`` when the uploader returned
        nothing.
    """
    uploaded = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
    if uploaded is None:
        return None
    picture = Image.open(uploaded)
    return picture if picture.mode == "RGB" else picture.convert("RGB")
def load_img(display=True):
    """Read the uploaded image and turn it into a normalised torch tensor.

    The image is resized so that both sides are multiples of 64 (rounding
    down), scaled from [0, 255] to [-1, 1] and laid out as
    ``(1, channels, height, width)``.

    If no image has been uploaded yet, or the image is smaller than 64 px on
    a side, a message is shown and the Streamlit script run is halted with
    ``st.stop()`` instead of failing with an attribute/resize error.

    Args:
        display: When true, echo the uploaded image in the app.

    Returns:
        A float32 tensor with values in [-1, 1].
    """
    image = get_interactive_image()
    if image is None:
        # The uploader returns nothing until the user picks a file; without
        # this guard the code below would crash on `None`.
        st.info("Please upload an input image to continue.")
        st.stop()
    if display:
        st.image(image)
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")
    # Round each side down to the nearest multiple of 64.
    w, h = (side - side % 64 for side in (w, h))
    if w == 0 or h == 0:
        # Rounding down a side shorter than 64 px yields 0, which cannot be
        # used as a resize target.
        st.error("The input image must be at least 64x64 pixels.")
        st.stop()
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    # HWC -> 1CHW
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2. * image - 1.
def get_init_img(batch_size=1):
    """Load the uploaded image onto the GPU and tile it along the batch axis.

    Args:
        batch_size: Number of copies of the single loaded image to return.

    Returns:
        The tensor from ``load_img()`` moved to CUDA, with its leading axis
        of size 1 repeated ``batch_size`` times.
    """
    single = load_img()
    on_gpu = single.cuda()
    return repeat(on_gpu, '1 ... -> b ...', b=batch_size)
def sample(
        model,
        prompt,
        n_runs=3,
        n_samples=2,
        H=512,
        W=512,
        C=4,
        f=8,
        scale=10.0,
        ddim_steps=50,
        ddim_eta=0.0,
        callback=None,
        skip_single_save=False,
        save_grid=True,
        ucg_schedule=None,
        negative_prompt="",
        adm_cond=None,
        adm_uc=None,
        use_full_precision=False,
        only_adm_cond=False
):
    """Run the sampler ``n_runs`` times, display the results and save them.

    Each run draws ``n_samples`` latents of shape ``[C, H // f, W // f]``,
    decodes them with ``model.decode_first_stage`` and rescales the result
    to [0, 1]. After every run the image grid shown in the app is rebuilt
    from all runs so far (one row of images per run).

    NOTE: the sampler is not a parameter. ``sampler`` is resolved as a
    module-level name at call time (the ``__main__`` section assigns it), so
    it must exist before this function is called.

    Args:
        model: Diffusion model providing ``ema_scope``,
            ``get_learned_conditioning`` and ``decode_first_stage``.
        prompt: A prompt string or a sequence of prompts; it is repeated
            ``n_samples`` times to form the batch.
        n_runs: Number of sampling runs; must be at least 1.
        n_samples: Batch size of each run.
        H, W, C, f: Used to compute the latent shape ``[C, H // f, W // f]``.
        scale: Unconditional guidance scale. When it equals 1.0 the negative
            prompt is not encoded.
        ddim_steps: Passed to the sampler as ``S``.
        ddim_eta: Passed to the sampler as ``eta``.
        callback: Passed through to the sampler.
        skip_single_save: When false, every image is written to
            ``SAVE_PATH/samples``.
        save_grid: When true, the final grid is written to ``SAVE_PATH``.
        ucg_schedule: Passed through to the sampler.
        negative_prompt: Text encoded for the unconditional branch.
        adm_cond: Optional extra conditioning tensor; a leading axis of size
            1 is repeated to the batch size.
        adm_uc: Unconditional counterpart of ``adm_cond``. If it is omitted
            while ``adm_cond`` is given, ``adm_cond`` is reused for it and a
            warning is shown.
        use_full_precision: Skip ``autocast`` when true.
        only_adm_cond: Condition on ``adm_cond`` alone, without text
            conditioning. Requires ``adm_cond``.

    Returns:
        The decoded samples of the LAST run only, with values in [0, 1].

    Raises:
        ValueError: If ``n_runs`` is smaller than 1, or if ``only_adm_cond``
            is set without ``adm_cond``.
    """
    # Fail early with a clear message; otherwise these cases surface later
    # as errors about unbound local variables.
    if n_runs < 1:
        raise ValueError(f"n_runs must be at least 1, got {n_runs}.")
    if only_adm_cond and adm_cond is None:
        raise ValueError("only_adm_cond=True requires adm_cond to be provided.")

    batch_size = n_samples
    precision_scope = nullcontext if use_full_precision else autocast
    if use_full_precision:
        st.warning(f"Running {model.__class__.__name__} at full precision.")
    if isinstance(prompt, str):
        prompt = [prompt]
    prompts = batch_size * prompt
    # The latent shape does not change between runs.
    shape = [C, H // f, W // f]

    # Placeholder that is overwritten with the growing grid after each run.
    outputs = st.empty()

    with precision_scope("cuda"), model.ema_scope():
        all_samples = list()
        for _ in trange(n_runs, desc="Sampling"):
            if not only_adm_cond:
                uc = None
                if scale != 1.0:
                    uc = model.get_learned_conditioning(batch_size * [negative_prompt])
                if isinstance(prompts, tuple):
                    prompts = list(prompts)
                c = model.get_learned_conditioning(prompts)

            if adm_cond is not None:
                if adm_cond.shape[0] == 1:
                    adm_cond = repeat(adm_cond, '1 ... -> b ...', b=batch_size)
                if adm_uc is None:
                    st.warning("Not guiding via c_adm")
                    adm_uc = adm_cond
                elif adm_uc.shape[0] == 1:
                    adm_uc = repeat(adm_uc, '1 ... -> b ...', b=batch_size)
                if not only_adm_cond:
                    c = {"c_crossattn": [c], "c_adm": adm_cond}
                    uc = {"c_crossattn": [uc], "c_adm": adm_uc}
                else:
                    c = adm_cond
                    uc = adm_uc

            # `sampler` is the module-level sampler object (see docstring).
            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                             conditioning=c,
                                             batch_size=batch_size,
                                             shape=shape,
                                             verbose=False,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=uc,
                                             eta=ddim_eta,
                                             x_T=None,
                                             callback=callback,
                                             ucg_schedule=ucg_schedule
                                             )
            x_samples = model.decode_first_stage(samples_ddim)
            # [-1, 1] -> [0, 1]
            x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

            if not skip_single_save:
                # Continue numbering after the files already in the folder.
                base_count = len(os.listdir(os.path.join(SAVE_PATH, "samples")))
                for x_sample in x_samples:
                    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                    Image.fromarray(x_sample.astype(np.uint8)).save(
                        os.path.join(SAVE_PATH, "samples", f"{base_count:09}.png"))
                    base_count += 1

            all_samples.append(x_samples)

            # Rebuild the grid from all runs so far: one row per run.
            grid = torch.stack(all_samples, 0)
            grid = rearrange(grid, 'n b c h w -> (n h) (b w) c')
            outputs.image(grid.cpu().numpy())

        # additionally, save grid
        if save_grid:
            grid_image = Image.fromarray((255. * grid.cpu().numpy()).astype(np.uint8))
            # presumably the -1 discounts the "samples" sub-directory -- TODO confirm
            grid_count = len(os.listdir(SAVE_PATH)) - 1
            grid_image.save(os.path.join(SAVE_PATH, f'grid-{grid_count:06}.png'))

    return x_samples
def make_oscillating_guidance_schedule(num_steps, max_weight=15., min_weight=1.):
    """Build a per-step guidance-weight list that alternates between two values.

    Steps in the first tenth of the schedule (``i / num_steps < 0.1``) always
    get ``max_weight``. After that, even step indices get ``min_weight`` and
    odd ones get ``max_weight``. The resulting list is also printed.

    Args:
        num_steps: Number of entries to generate.
        max_weight: Weight for the initial steps and for odd steps.
        min_weight: Weight for even steps after the initial phase.

    Returns:
        A list of ``num_steps`` weights.
    """
    def weight_at(step):
        in_initial_phase = float(step / num_steps) < 0.1
        if in_initial_phase or step % 2 != 0:
            return max_weight
        return min_weight

    schedule = [weight_at(step) for step in range(num_steps)]
    print(f"OSCILLATING GUIDANCE SCHEDULE: \n {schedule}")
    return schedule
def torch2np(x):
    """Convert a 4-D torch tensor to a uint8 NumPy array.

    Values are shifted by 1 and scaled by 127.5 (so -1 maps to 0 and 1 maps
    to 255), clamped to [0, 255] and cast to uint8. Axes are then permuted
    with (0, 2, 3, 1) -- presumably batch-channel-height-width input to
    batch-height-width-channel output.
    """
    scaled = (x + 1.0) * 127.5
    as_bytes = torch.clamp(scaled, 0, 255).to(dtype=torch.uint8)
    channels_last = as_bytes.permute(0, 2, 3, 1)
    return channels_last.detach().cpu().numpy()
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def init(version="Stable unCLIP-L", load_karlo_prior=False):
    """Load the model(s) for the selected version; cached by Streamlit.

    Args:
        version: One of "Stable unCLIP-L", "Stable unOpenCLIP-H" or
            "Full Karlo".
        load_karlo_prior: For the Stable unCLIP versions, additionally load
            the Karlo prior sampler.

    Returns:
        A dict. For "Full Karlo" it holds only ``karlo_prior`` (the full
        Karlo sampler) and ``msg``. Otherwise it holds ``msg``, ``model``,
        ``ckpt``, ``config`` and, when requested, ``karlo_prior``.

    Raises:
        ValueError: If ``version`` is not one of the known names.
    """
    if version == "Full Karlo":
        from ldm.modules.karlo.kakao.sampler import T2ISampler
        st.info("Loading full KARLO..")
        full_karlo = T2ISampler.from_pretrained(
            root_dir="checkpoints/karlo_models",
            clip_model_path="ViT-L-14.pt",
            clip_stat_path="ViT-L-14_stats.th",
            sampling_type="default",
        )
        return {"karlo_prior": full_karlo, "msg": "loaded full Karlo"}

    # version name -> (inference config, checkpoint)
    known_versions = {
        "Stable unCLIP-L": (
            "configs/stable-diffusion/v2-1-stable-unclip-l-inference.yaml",
            "checkpoints/v2-1-stable-unclip-l-ft.ckpt",
        ),
        "Stable unOpenCLIP-H": (
            "configs/stable-diffusion/v2-1-stable-unclip-h-inference.yaml",
            "checkpoints/v2-1-stable-unclip-h-ft.ckpt",
        ),
    }
    if version not in known_versions:
        raise ValueError(f"version {version} unknown!")
    config_path, ckpt = known_versions[version]

    config = OmegaConf.load(config_path)
    model, msg = load_model_from_config(config, ckpt, vae_sd=None)
    state = {"msg": msg}

    if load_karlo_prior:
        from ldm.modules.karlo.kakao.sampler import PriorSampler
        st.info("Loading KARLO CLIP prior...")
        state["karlo_prior"] = PriorSampler.from_pretrained(
            root_dir="checkpoints/karlo_models",
            clip_model_path="ViT-L-14.pt",
            clip_stat_path="ViT-L-14_stats.th",
            sampling_type="default",
        )

    state["model"] = model
    state["ckpt"] = ckpt
    state["config"] = config
    return state
def load_model_from_config(config, ckpt, verbose=False, vae_sd=None):
    """Instantiate ``config.model``, load checkpoint weights and move to GPU.

    Args:
        config: Config object whose ``model`` entry is passed to
            ``instantiate_from_config``.
        ckpt: Path of the checkpoint file; it must contain a ``state_dict``.
        verbose: Print the missing/unexpected keys reported by
            ``load_state_dict``.
        vae_sd: Optional replacement weights. Every checkpoint key containing
            ``"first_stage"`` is overwritten with the entry of ``vae_sd``
            whose name is that key minus the ``"first_stage_model."`` prefix.

    Returns:
        ``(model, msg)``: the model in eval mode on CUDA, and a short
        description of the checkpoint (global step and/or EMA update count),
        or ``None`` when the checkpoint carries neither.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # from trusted sources.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]

    # Collect the message parts independently: appending the EMA part to an
    # unset message would fail for checkpoints without a `global_step`.
    msg_parts = []
    if "global_step" in pl_sd:
        msg_parts.append(f"This is global step {pl_sd['global_step']}. ")
    if "model_ema.num_updates" in sd:
        msg_parts.append(f"And we got {sd['model_ema.num_updates']} EMA updates.")
    msg = "".join(msg_parts) if msg_parts else None
    global_step = pl_sd.get("global_step", "?")

    if vae_sd is not None:
        # Only values are replaced, so iterating the keys here is safe.
        for k in sd:
            if "first_stage" in k:
                sd[k] = vae_sd[k[len("first_stage_model."):]]

    model = instantiate_from_config(config.model)
    # strict=False: mismatching keys are reported instead of raising.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and len(missing) > 0:
        print("missing keys:")
        print(missing)
    if verbose and len(unexpected) > 0:
        print("unexpected keys:")
        print(unexpected)

    model.cuda()
    model.eval()
    print(f"Loaded global step {global_step}")
    return model, msg
if __name__ == "__main__":
    # Streamlit entry point: builds the UI, loads the selected model, prepares
    # the extra (CLIP-embedding) conditioning and samples when the button is
    # pressed.
    st.title("Stable unCLIP")
    mode = "txt2img"
    version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
    # "Full Karlo" loads no diffusion model (init() returns only the Karlo
    # sampler), so every step that needs state["model"] must be skipped for it.
    is_full_karlo = version == "Full Karlo"
    use_karlo = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
    state = init(version=version, load_karlo_prior=use_karlo)
    st.info(state["msg"])
    prompt = st.text_input("Prompt", "a professional photograph")
    negative_prompt = st.text_input("Negative Prompt", "")
    scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
    number_rows = st.number_input("num rows", value=2, min_value=1, max_value=10)
    number_cols = st.number_input("num cols", value=2, min_value=1, max_value=10)
    steps = st.sidebar.number_input("steps", value=20, min_value=1, max_value=1000)
    eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
    force_full_precision = st.sidebar.checkbox("Force FP32", False)  # TODO: check if/where things break.
    if not is_full_karlo:
        H = st.sidebar.number_input("H", value=VERSION2SPECS[version]["H"], min_value=64, max_value=2048)
        W = st.sidebar.number_input("W", value=VERSION2SPECS[version]["W"], min_value=64, max_value=2048)
        C = VERSION2SPECS[version]["C"]
        f = VERSION2SPECS[version]["f"]

    # Outputs are stored in a per-version sub-folder.
    SAVE_PATH = os.path.join(SAVE_PATH, version)
    os.makedirs(os.path.join(SAVE_PATH, "samples"), exist_ok=True)

    seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
    seed_everything(seed)

    ucg_schedule = None
    # `sampler` starts as the selected name and is replaced by the sampler
    # object below; sample() reads it as a module-level name.
    sampler = st.sidebar.selectbox("Sampler", ["DDIM", "DPM"], 0)
    if not is_full_karlo:
        if sampler == "DPM":
            sampler = DPMSolverSampler(state["model"])
        elif sampler == "DDIM":
            sampler = DDIMSampler(state["model"])
        else:
            raise ValueError(f"unknown sampler {sampler}!")

    adm_cond, adm_uc = None, None
    if use_karlo:
        # uses the prior
        karlo_sampler = state["karlo_prior"]
        noise_level = None
        if state["model"].noise_augmentor is not None:
            noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0,
                                          max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0)
        with torch.no_grad():
            # The sampler call is iterated; only its first yielded item is used.
            karlo_prediction = next(iter(
                karlo_sampler(
                    prompt=prompt,
                    bsz=number_cols,
                    progressive_mode="final",
                )
            ))
            adm_cond = karlo_prediction
            if noise_level is not None:
                c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
                    torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
                adm_cond = torch.cat((c_adm, noise_level_emb), 1)
            adm_uc = torch.zeros_like(adm_cond)
    elif not is_full_karlo:
        # Condition on the embedding of an uploaded image. Skipped for
        # "Full Karlo", which has no state["model"] to embed with.
        init_img = get_init_img(batch_size=number_cols)
        with torch.no_grad():
            adm_cond = state["model"].embedder(init_img)
            if state["model"].noise_augmentor is not None:
                noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0,
                                              max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0)
                c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
                    torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
                # assume this gives embeddings of noise levels
                adm_cond = torch.cat((c_adm, noise_level_emb), 1)
            adm_uc = torch.zeros_like(adm_cond)

    if st.button("Sample"):
        print("running prompt:", prompt)
        st.text("Sampling")
        t_progress = st.progress(0)
        result = st.empty()

        def t_callback(t):
            # Clamp at 1.0 in case the callback index runs past `steps`.
            t_progress.progress(min((t + 1) / steps, 1.))

        # Compare against the name actually offered in the version selectbox;
        # a test for "KARLO" can never match and left this branch unreachable.
        if is_full_karlo:
            outputs = st.empty()
            karlo_sampler = state["karlo_prior"]
            all_samples = list()
            with torch.no_grad():
                for _ in range(number_rows):
                    karlo_prediction = next(iter(
                        karlo_sampler(
                            prompt=prompt,
                            bsz=number_cols,
                            progressive_mode="final",
                        )
                    ))
                    all_samples.append(karlo_prediction)
                    # Refresh the displayed grid after each row.
                    grid = torch.stack(all_samples, 0)
                    grid = rearrange(grid, 'n b c h w -> (n h) (b w) c')
                    outputs.image(grid.cpu().numpy())
        else:
            samples = sample(
                state["model"],
                prompt,
                n_runs=number_rows,
                n_samples=number_cols,
                H=H, W=W, C=C, f=f,
                scale=scale,
                ddim_steps=steps,
                ddim_eta=eta,
                callback=t_callback,
                ucg_schedule=ucg_schedule,
                negative_prompt=negative_prompt,
                adm_cond=adm_cond, adm_uc=adm_uc,
                use_full_precision=force_full_precision,
                only_adm_cond=False
            )