import importlib

import streamlit as st
import torch
import cv2
import numpy as np
import PIL
from omegaconf import OmegaConf
from PIL import Image
from tqdm import trange
import io, os
from torch import autocast
from einops import rearrange, repeat
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from contextlib import nullcontext

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler

torch.set_grad_enabled(False)

PROMPTS_ROOT = "scripts/prompts/"
SAVE_PATH = "outputs/demo/stable-unclip/"

VERSION2SPECS = {
    "Stable unCLIP-L": {"H": 768, "W": 768, "C": 4, "f": 8},
    "Stable unOpenCLIP-H": {"H": 768, "W": 768, "C": 4, "f": 8},
    "Full Karlo": {},
}
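
# In each spec, H/W are the default output resolution, C the number of latent
# channels, and f the first-stage downsampling factor; sampling below operates
# on latents of shape [C, H // f, W // f].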


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    importlib.invalidate_caches()
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)
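
# For example, get_obj_from_str("torch.nn.Linear") imports torch.nn and returns
# the Linear class itself (not an instance).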


def instantiate_from_config(config):
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
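
# A minimal sketch of the config-driven instantiation used throughout this
# script; the target class and params below are illustrative, not taken from a
# real model config:
#
#   cfg = OmegaConf.create({"target": "torch.nn.Linear",
#                           "params": {"in_features": 4, "out_features": 2}})
#   layer = instantiate_from_config(cfg)  # equivalent to torch.nn.Linear(4, 2)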


def get_interactive_image():
    # returns None until the user has uploaded a file
    image = st.file_uploader("Input", type=["jpg", "jpeg", "png"])
    if image is not None:
        image = Image.open(image)
        if image.mode != "RGB":
            image = image.convert("RGB")
    return image


def load_img(display=True):
    image = get_interactive_image()
    if image is None:
        # nothing uploaded yet: halt this Streamlit run instead of crashing below
        st.stop()
    if display:
        st.image(image)
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")
    # round down to a multiple of 64 so the size is compatible with the model
    w, h = map(lambda x: x - x % 64, (w, h))
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2. * image - 1.
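
# load_img returns a [1, 3, H, W] float tensor rescaled from [0, 255] to
# [-1, 1], the value range the first-stage autoencoder works in.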


def get_init_img(batch_size=1):
    init_image = load_img().cuda()
    init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
    return init_image


def sample(
        model,
        prompt,
        n_runs=3,
        n_samples=2,
        H=512,
        W=512,
        C=4,
        f=8,
        scale=10.0,
        ddim_steps=50,
        ddim_eta=0.0,
        callback=None,
        skip_single_save=False,
        save_grid=True,
        ucg_schedule=None,
        negative_prompt="",
        adm_cond=None,
        adm_uc=None,
        use_full_precision=False,
        only_adm_cond=False
):
    batch_size = n_samples
    precision_scope = autocast if not use_full_precision else nullcontext
    # decoderscope = autocast if not use_full_precision else nullcontext
    if use_full_precision:
        st.warning(f"Running {model.__class__.__name__} at full precision.")
    if isinstance(prompt, str):
        prompt = [prompt]
    prompts = batch_size * prompt

    outputs = st.empty()

    with precision_scope("cuda"):
        with model.ema_scope():
            all_samples = list()
            for n in trange(n_runs, desc="Sampling"):
                shape = [C, H // f, W // f]
                if not only_adm_cond:
                    uc = None
                    if scale != 1.0:
                        uc = model.get_learned_conditioning(batch_size * [negative_prompt])
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    c = model.get_learned_conditioning(prompts)

                if adm_cond is not None:
                    if adm_cond.shape[0] == 1:
                        adm_cond = repeat(adm_cond, '1 ... -> b ...', b=batch_size)
                    if adm_uc is None:
                        st.warning("Not guiding via c_adm")
                        adm_uc = adm_cond
                    elif adm_uc.shape[0] == 1:
                        adm_uc = repeat(adm_uc, '1 ... -> b ...', b=batch_size)
                    if not only_adm_cond:
                        c = {"c_crossattn": [c], "c_adm": adm_cond}
                        uc = {"c_crossattn": [uc], "c_adm": adm_uc}
                    else:
                        c = adm_cond
                        uc = adm_uc

                # NOTE: `sampler` is the module-level sampler selected in
                # __main__, not a parameter of this function.
                samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                 conditioning=c,
                                                 batch_size=batch_size,
                                                 shape=shape,
                                                 verbose=False,
                                                 unconditional_guidance_scale=scale,
                                                 unconditional_conditioning=uc,
                                                 eta=ddim_eta,
                                                 x_T=None,
                                                 callback=callback,
                                                 ucg_schedule=ucg_schedule
                                                 )
                x_samples = model.decode_first_stage(samples_ddim)
                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                if not skip_single_save:
                    base_count = len(os.listdir(os.path.join(SAVE_PATH, "samples")))
                    for x_sample in x_samples:
                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                        Image.fromarray(x_sample.astype(np.uint8)).save(
                            os.path.join(SAVE_PATH, "samples", f"{base_count:09}.png"))
                        base_count += 1

                all_samples.append(x_samples)

                # get grid of all samples
                grid = torch.stack(all_samples, 0)
                grid = rearrange(grid, 'n b c h w -> (n h) (b w) c')
                outputs.image(grid.cpu().numpy())

            # additionally, save grid
            grid = Image.fromarray((255. * grid.cpu().numpy()).astype(np.uint8))
            if save_grid:
                grid_count = len(os.listdir(SAVE_PATH)) - 1
                grid.save(os.path.join(SAVE_PATH, f'grid-{grid_count:06}.png'))

    return x_samples
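
# The displayed grid stacks runs vertically and batch samples horizontally
# ('n b c h w -> (n h) (b w) c'), so it grows to n_runs * H by n_samples * W.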


def make_oscillating_guidance_schedule(num_steps, max_weight=15., min_weight=1.):
    schedule = list()
    for i in range(num_steps):
        if float(i / num_steps) < 0.1:
            schedule.append(max_weight)
        elif i % 2 == 0:
            schedule.append(min_weight)
        else:
            schedule.append(max_weight)
    print(f"OSCILLATING GUIDANCE SCHEDULE: \n {schedule}")
    return schedule
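
# The schedule holds max_weight for the first 10% of steps, then alternates;
# e.g. make_oscillating_guidance_schedule(10) yields
# [15.0, 15.0, 1.0, 15.0, 1.0, 15.0, 1.0, 15.0, 1.0, 15.0].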


def torch2np(x):
    x = ((x + 1.0) * 127.5).clamp(0, 255).to(dtype=torch.uint8)
    x = x.permute(0, 2, 3, 1).detach().cpu().numpy()
    return x
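
# torch2np inverts the [-1, 1] normalization from load_img: a [b, c, h, w]
# tensor becomes a [b, h, w, c] uint8 array in [0, 255].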


@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def init(version="Stable unCLIP-L", load_karlo_prior=False):
    state = dict()
    if "model" not in state:
        if version == "Stable unCLIP-L":
            config = "configs/stable-diffusion/v2-1-stable-unclip-l-inference.yaml"
            ckpt = "checkpoints/v2-1-stable-unclip-l-ft.ckpt"

        elif version == "Stable unOpenCLIP-H":
            config = "configs/stable-diffusion/v2-1-stable-unclip-h-inference.yaml"
            ckpt = "checkpoints/v2-1-stable-unclip-h-ft.ckpt"

        elif version == "Full Karlo":
            from ldm.modules.karlo.kakao.sampler import T2ISampler
            st.info("Loading full KARLO..")
            karlo = T2ISampler.from_pretrained(
                root_dir="checkpoints/karlo_models",
                clip_model_path="ViT-L-14.pt",
                clip_stat_path="ViT-L-14_stats.th",
                sampling_type="default",
            )
            state["karlo_prior"] = karlo
            state["msg"] = "loaded full Karlo"
            # full Karlo runs its own pipeline, so no LDM checkpoint is loaded
            return state

        else:
            raise ValueError(f"version {version} unknown!")

        config = OmegaConf.load(config)
        model, msg = load_model_from_config(config, ckpt, vae_sd=None)
        state["msg"] = msg

        if load_karlo_prior:
            from ldm.modules.karlo.kakao.sampler import PriorSampler
            st.info("Loading KARLO CLIP prior...")
            karlo_prior = PriorSampler.from_pretrained(
                root_dir="checkpoints/karlo_models",
                clip_model_path="ViT-L-14.pt",
                clip_stat_path="ViT-L-14_stats.th",
                sampling_type="default",
            )
            state["karlo_prior"] = karlo_prior
        state["model"] = model
        state["ckpt"] = ckpt
        state["config"] = config
    return state


def load_model_from_config(config, ckpt, verbose=False, vae_sd=None):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    msg = None
    if "global_step" in pl_sd:
        msg = f"This is global step {pl_sd['global_step']}. "
        if "model_ema.num_updates" in pl_sd["state_dict"]:
            msg += f"And we got {pl_sd['state_dict']['model_ema.num_updates']} EMA updates."
    global_step = pl_sd.get("global_step", "?")
    sd = pl_sd["state_dict"]
    if vae_sd is not None:
        # swap in an externally provided autoencoder state dict for the first stage
        for k in sd.keys():
            if "first_stage" in k:
                sd[k] = vae_sd[k[len("first_stage_model."):]]

    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    print(f"Loaded global step {global_step}")
    return model, msg
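
# load_state_dict(strict=False) tolerates partial checkpoints; pass
# verbose=True to print any missing or unexpected keys.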


if __name__ == "__main__":
    st.title("Stable unCLIP")
    mode = "txt2img"
    version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
    use_karlo = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
    state = init(version=version, load_karlo_prior=use_karlo)
    st.info(state["msg"])
    prompt = st.text_input("Prompt", "a professional photograph")
    negative_prompt = st.text_input("Negative Prompt", "")
    scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
    number_rows = st.number_input("num rows", value=2, min_value=1, max_value=10)
    number_cols = st.number_input("num cols", value=2, min_value=1, max_value=10)
    steps = st.sidebar.number_input("steps", value=20, min_value=1, max_value=1000)
    eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
    force_full_precision = st.sidebar.checkbox("Force FP32", False)  # TODO: check if/where things break.
    if version != "Full Karlo":
        H = st.sidebar.number_input("H", value=VERSION2SPECS[version]["H"], min_value=64, max_value=2048)
        W = st.sidebar.number_input("W", value=VERSION2SPECS[version]["W"], min_value=64, max_value=2048)
        C = VERSION2SPECS[version]["C"]
        f = VERSION2SPECS[version]["f"]

    SAVE_PATH = os.path.join(SAVE_PATH, version)
    os.makedirs(os.path.join(SAVE_PATH, "samples"), exist_ok=True)

    seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
    seed_everything(seed)

    ucg_schedule = None
    sampler = st.sidebar.selectbox("Sampler", ["DDIM", "DPM"], 0)

    if version == "Full Karlo":
        pass
    else:
        # replace the sampler name chosen in the UI with the sampler object
        if sampler == "DPM":
            sampler = DPMSolverSampler(state["model"])
        elif sampler == "DDIM":
            sampler = DDIMSampler(state["model"])
        else:
            raise ValueError(f"unknown sampler {sampler}!")

    adm_cond, adm_uc = None, None
    if use_karlo:
        # uses the prior
        karlo_sampler = state["karlo_prior"]
        noise_level = None
        if state["model"].noise_augmentor is not None:
            noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0,
                                          max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0)
        with torch.no_grad():
            karlo_prediction = next(
                karlo_sampler(
                    prompt=prompt,
                    bsz=number_cols,
                    progressive_mode="final",
                )
            )
        adm_cond = karlo_prediction
        if noise_level is not None:
            c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
                torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
            adm_cond = torch.cat((c_adm, noise_level_emb), 1)
        adm_uc = torch.zeros_like(adm_cond)
    elif version != "Full Karlo":
        # condition on the CLIP image embedding of an uploaded init image;
        # the full Karlo pipeline needs no adm conditioning
        init_img = get_init_img(batch_size=number_cols)
        with torch.no_grad():
            adm_cond = state["model"].embedder(init_img)
        if state["model"].noise_augmentor is not None:
            noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0,
                                          max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0)
            c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
                torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
            # assume this gives embeddings of noise levels
            adm_cond = torch.cat((c_adm, noise_level_emb), 1)
        adm_uc = torch.zeros_like(adm_cond)
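
    # In both branches adm_uc is an all-zero embedding, which acts as the
    # unconditional image conditioning (c_adm) for classifier-free guidance.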

    if st.button("Sample"):
        print("running prompt:", prompt)
        st.text("Sampling")
        t_progress = st.progress(0)
        result = st.empty()

        def t_callback(t):
            # streamlit progress bars expect a fraction in [0, 1]
            t_progress.progress(min((t + 1) / steps, 1.))

        if version == "Full Karlo":
            outputs = st.empty()
            karlo_sampler = state["karlo_prior"]
            all_samples = list()
            with torch.no_grad():
                for _ in range(number_rows):
                    karlo_prediction = next(
                        karlo_sampler(
                            prompt=prompt,
                            bsz=number_cols,
                            progressive_mode="final",
                        )
                    )
                    all_samples.append(karlo_prediction)
                grid = torch.stack(all_samples, 0)
                grid = rearrange(grid, 'n b c h w -> (n h) (b w) c')
                outputs.image(grid.cpu().numpy())

        else:
            samples = sample(
                state["model"],
                prompt,
                n_runs=number_rows,
                n_samples=number_cols,
                H=H, W=W, C=C, f=f,
                scale=scale,
                ddim_steps=steps,
                ddim_eta=eta,
                callback=t_callback,
                ucg_schedule=ucg_schedule,
                negative_prompt=negative_prompt,
                adm_cond=adm_cond, adm_uc=adm_uc,
                use_full_precision=force_full_precision,
                only_adm_cond=False
            )