"""Streamlit demo for sampling from the "Stable Karlo" and "Full Karlo" models.

Individual samples and sample grids are written below ``SAVE_PATH``.
"""
import importlib
import streamlit as st
import torch
import cv2
import numpy as np
import PIL
from omegaconf import OmegaConf
from PIL import Image
from tqdm import trange
import io
import os
from torch import autocast
from einops import rearrange, repeat
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from contextlib import nullcontext

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler

# This script only runs inference; disable autograd globally.
torch.set_grad_enabled(False)

PROMPTS_ROOT = "scripts/prompts/"
SAVE_PATH = "outputs/demo/stable-karlo/"

# Per-version sampling specs: output height/width, and C/f such that the
# sampled latent has shape [C, H // f, W // f] (see `sample`). "Full Karlo"
# is not sampled through `sample`, so it has no specs.
VERSION2SPECS = {
    "Stable Karlo": {"H": 768, "W": 768, "C": 4, "f": 8},
    "Full Karlo": {},
}

# Directory holding the pretrained Karlo checkpoints; shared by the full
# Karlo pipeline and the stand-alone CLIP prior so both load from one place.
KARLO_ROOT_DIR = "checkpoints/karlo_models"


def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"pkg.module.Name"`` to the named object.

    Args:
        string: fully qualified ``module.attribute`` path.
        reload: if true, re-import the containing module before lookup.
    """
    module, cls = string.rsplit(".", 1)
    importlib.invalidate_caches()
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    """Instantiate ``config["target"]`` with keyword args from ``config["params"]``.

    Raises:
        KeyError: if ``config`` has no ``target`` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_interactive_image():
    """Show an upload widget; return the uploaded image as an RGB PIL image, or None."""
    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
    if image is not None:
        image = Image.open(image)
        if image.mode != "RGB":
            image = image.convert("RGB")
    return image


def load_img(display=True):
    """Load the uploaded image as a (1, 3, h, w) float tensor scaled to [-1, 1].

    Width and height are rounded down to multiples of 64.

    Args:
        display: also render the uploaded image in the app.

    Returns:
        The image tensor, or None when nothing has been uploaded yet.
    """
    image = get_interactive_image()
    if image is None:
        return None
    if display:
        st.image(image)
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")
    w, h = map(lambda x: x - x % 64, (w, h))  # round down to multiples of 64
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # (h, w, 3) -> (1, 3, h, w)
    image = torch.from_numpy(image)
    return 2. * image - 1.


def get_init_img(batch_size=1):
    """Return the uploaded image on CUDA repeated ``batch_size`` times, or None if absent."""
    init_image = load_img()
    if init_image is None:
        return None
    init_image = repeat(init_image.cuda(), '1 ... -> b ...', b=batch_size)
    return init_image


def sample(
        model,
        prompt,
        n_runs=3,
        n_samples=2,
        H=512,
        W=512,
        C=4,
        f=8,
        scale=10.0,
        ddim_steps=50,
        ddim_eta=0.0,
        callback=None,
        skip_single_save=False,
        save_grid=True,
        ucg_schedule=None,
        negative_prompt="",
        adm_cond=None,
        adm_uc=None,
        use_full_precision=False,
        only_adm_cond=False,
        sampler=None,
):
    """Run ``n_runs`` sampling rounds of ``n_samples`` images each and show a grid.

    Args:
        model: diffusion model exposing ``ema_scope``, ``get_learned_conditioning``
            and ``decode_first_stage``.
        prompt: a prompt string, or a list of prompts repeated ``n_samples`` times.
        n_runs: number of sampling rounds (grid rows).
        n_samples: batch size per round (grid columns).
        H, W, C, f: output size and latent spec; latents have shape [C, H // f, W // f].
        scale: classifier-free guidance scale; 1.0 skips the negative-prompt embedding.
        ddim_steps, ddim_eta: forwarded to ``sampler.sample`` as ``S`` and ``eta``.
        callback: per-step callback forwarded to the sampler.
        skip_single_save: do not write individual samples to ``SAVE_PATH/samples``.
        save_grid: write the grid image to ``SAVE_PATH``.
        ucg_schedule: optional per-step guidance schedule forwarded to the sampler.
        negative_prompt: text used for the unconditional embedding.
        adm_cond, adm_uc: optional conditional/unconditional ADM embeddings; a
            leading batch dim of 1 is broadcast to ``n_samples``.
        use_full_precision: disable CUDA autocast.
        only_adm_cond: condition on ``adm_cond``/``adm_uc`` alone, without text.
        sampler: object with a ``sample`` method (DDIM/PLMS/DPM). Defaults to the
            module-level ``sampler`` assigned by the script entry point.

    Returns:
        Decoded samples of the last run, clamped to [0, 1].

    Raises:
        NameError: if no sampler is given and no module-level ``sampler`` exists.
        ValueError: if ``only_adm_cond`` is set without ``adm_cond``.
    """
    if sampler is None:
        # Backward compatibility: older callers relied on the global `sampler`.
        sampler = globals().get("sampler")
        if sampler is None:
            raise NameError("sample() requires a `sampler` argument or a module-level `sampler`.")
    if only_adm_cond and adm_cond is None:
        raise ValueError("only_adm_cond=True requires adm_cond.")

    batch_size = n_samples
    precision_scope = autocast if not use_full_precision else nullcontext
    if use_full_precision:
        st.warning(f"Sampling {model.__class__.__name__} at full precision.")
    if isinstance(prompt, str):
        prompt = [prompt]
    prompts = batch_size * prompt

    outputs = st.empty()

    with precision_scope("cuda"):
        with model.ema_scope():
            all_samples = list()
            for _ in trange(n_runs, desc="Sampling"):
                shape = [C, H // f, W // f]
                if not only_adm_cond:
                    uc = None
                    if scale != 1.0:
                        uc = model.get_learned_conditioning(batch_size * [negative_prompt])
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    c = model.get_learned_conditioning(prompts)

                if adm_cond is not None:
                    # Broadcast a single ADM embedding across the batch.
                    if adm_cond.shape[0] == 1:
                        adm_cond = repeat(adm_cond, '1 ... -> b ...', b=batch_size)
                    if adm_uc is None:
                        st.warning("Not guiding via c_adm")
                        adm_uc = adm_cond
                    elif adm_uc.shape[0] == 1:
                        adm_uc = repeat(adm_uc, '1 ... -> b ...', b=batch_size)
                    if not only_adm_cond:
                        c = {"c_crossattn": [c], "c_adm": adm_cond}
                        uc = {"c_crossattn": [uc], "c_adm": adm_uc}
                    else:
                        c = adm_cond
                        uc = adm_uc

                samples_ddim, _ = sampler.sample(
                    S=ddim_steps,
                    conditioning=c,
                    batch_size=batch_size,
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=scale,
                    unconditional_conditioning=uc,
                    eta=ddim_eta,
                    x_T=None,
                    callback=callback,
                    ucg_schedule=ucg_schedule,
                )
                x_samples = model.decode_first_stage(samples_ddim)
                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                if not skip_single_save:
                    # Continue numbering after the files already in the folder.
                    base_count = len(os.listdir(os.path.join(SAVE_PATH, "samples")))
                    for x_sample in x_samples:
                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                        Image.fromarray(x_sample.astype(np.uint8)).save(
                            os.path.join(SAVE_PATH, "samples", f"{base_count:09}.png"))
                        base_count += 1

                all_samples.append(x_samples)

            # get grid of all samples: one row per run, one column per sample
            grid = torch.stack(all_samples, 0)
            grid = rearrange(grid, 'n b c h w -> (n h) (b w) c')
            outputs.image(grid.cpu().numpy())

            # additionally, save grid
            grid = Image.fromarray((255. * grid.cpu().numpy()).astype(np.uint8))
            if save_grid:
                # "- 1" discounts the "samples" subdirectory inside SAVE_PATH.
                grid_count = len(os.listdir(SAVE_PATH)) - 1
                grid.save(os.path.join(SAVE_PATH, f'grid-{grid_count:06}.png'))

    return x_samples


def make_oscillating_guidance_schedule(num_steps, max_weight=15., min_weight=1.):
    """Build a per-step guidance weight list of length ``num_steps``.

    The first 10% of steps use ``max_weight``; afterwards even steps use
    ``min_weight`` and odd steps use ``max_weight``.
    """
    schedule = list()
    for i in range(num_steps):
        if i / num_steps < 0.1:
            schedule.append(max_weight)
        elif i % 2 == 0:
            schedule.append(min_weight)
        else:
            schedule.append(max_weight)
    print(f"OSCILLATING GUIDANCE SCHEDULE: \n {schedule}")
    return schedule


def torch2np(x):
    """Map a (b, c, h, w) tensor from [-1, 1] to a uint8 numpy array of shape (b, h, w, c)."""
    x = ((x + 1.0) * 127.5).clamp(0, 255).to(dtype=torch.uint8)
    x = x.permute(0, 2, 3, 1).detach().cpu().numpy()
    return x


@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def init(version="Stable Karlo", load_karlo_prior=False):
    """Load the models for ``version``; Streamlit caches the result across reruns.

    Returns:
        A dict with ``msg`` plus, for "Full Karlo", ``karlo_prior`` (the full
        T2I sampler); for "Stable Karlo", ``model``/``ckpt``/``config`` and,
        when ``load_karlo_prior`` is set, ``karlo_prior`` (the CLIP prior).

    Raises:
        ValueError: for an unknown ``version``.
    """
    state = dict()
    if version == "Stable Karlo":
        config = "configs/stable-diffusion/v2-1-stable-karlo-inference.yaml"
        ckpt = "checkpoints/v2-1-stable-unclip-ft.ckpt"
    elif version == "Full Karlo":
        from ldm.modules.karlo.kakao.sampler import T2ISampler
        st.info("Loading full KARLO..")
        karlo = T2ISampler.from_pretrained(
            root_dir=KARLO_ROOT_DIR,
            clip_model_path="ViT-L-14.pt",
            clip_stat_path="ViT-L-14_stats.th",
            sampling_type="default",
        )
        state["karlo_prior"] = karlo
        state["msg"] = "loaded full Karlo"
        return state
    else:
        raise ValueError(f"version {version} unknown!")

    config = OmegaConf.load(config)
    model, msg = load_model_from_config(config, ckpt, vae_sd=None)
    state["msg"] = msg

    if load_karlo_prior:
        from ldm.modules.karlo.kakao.sampler import PriorSampler
        st.info("Loading KARLO CLIP prior...")
        karlo_prior = PriorSampler.from_pretrained(
            root_dir=KARLO_ROOT_DIR,
            clip_model_path="ViT-L-14.pt",
            clip_stat_path="ViT-L-14_stats.th",
            sampling_type="default",
        )
        state["karlo_prior"] = karlo_prior
    state["model"] = model
    state["ckpt"] = ckpt
    state["config"] = config
    return state


def load_model_from_config(config, ckpt, verbose=False, vae_sd=None):
    """Instantiate ``config.model`` and load weights from the checkpoint at ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: checkpoint path; it must contain a ``state_dict`` entry.
            ``torch.load`` unpickles it, so only load trusted files.
        verbose: print missing/unexpected keys from the non-strict load.
        vae_sd: optional state dict whose entries replace every ``first_stage``
            weight (keys without the ``first_stage_model.`` prefix).

    Returns:
        ``(model, msg)``: the model on CUDA in eval mode, and a status message
        about the checkpoint's global step / EMA updates, or None if the
        checkpoint has no ``global_step``.
    """
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    msg = None
    if "global_step" in pl_sd:
        msg = f"This is global step {pl_sd['global_step']}. "
        if "model_ema.num_updates" in pl_sd["state_dict"]:
            msg += f"And we got {pl_sd['state_dict']['model_ema.num_updates']} EMA updates."
    global_step = pl_sd.get("global_step", "?")
    sd = pl_sd["state_dict"]
    if vae_sd is not None:
        for k in sd.keys():
            if "first_stage" in k:
                sd[k] = vae_sd[k[len("first_stage_model."):]]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    print(f"Loaded global step {global_step}")
    return model, msg


if __name__ == "__main__":
    st.title("Stable Karlo")
    mode = "txt2img"
    version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
    use_karlo = st.checkbox("Use KARLO prior", False)
    state = init(version=version, load_karlo_prior=use_karlo)
    if state["msg"] is not None:
        st.info(state["msg"])
    prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse")
    negative_prompt = st.text_input("Negative Prompt", "")
    scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
    number_rows = st.number_input("num rows", value=2, min_value=1, max_value=10)
    number_cols = st.number_input("num cols", value=2, min_value=1, max_value=10)
    steps = st.sidebar.number_input("steps", value=40, min_value=1, max_value=1000)
    eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
    force_full_precision = st.sidebar.checkbox("Force FP32", False)
    if version != "Full Karlo":
        H = st.sidebar.number_input("H", value=VERSION2SPECS[version]["H"], min_value=64, max_value=2048)
        W = st.sidebar.number_input("W", value=VERSION2SPECS[version]["W"], min_value=64, max_value=2048)
        C = VERSION2SPECS[version]["C"]
        f = VERSION2SPECS[version]["f"]

    SAVE_PATH = os.path.join(SAVE_PATH, version)
    os.makedirs(os.path.join(SAVE_PATH, "samples"), exist_ok=True)

    seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
    seed_everything(seed)

    ucg_schedule = None
    sampler = st.sidebar.selectbox("Sampler", ["DDIM", "PLMS", "DPM"], 0)
    if version == "Full Karlo":
        pass
    else:
        if sampler == "PLMS":
            st.warning("NOTE: Some models (such as v-pred) currently only support DDIM/DPM sampling here")
            sampler = PLMSSampler(state["model"])
        elif sampler == "DPM":
            st.warning("NOTE: Using DPM sampler with default sampling parameters (DPM-2)")
            sampler = DPMSolverSampler(state["model"])
        elif sampler == "DDIM":
            sampler = DDIMSampler(state["model"])
            if st.checkbox("Try oscillating guidance?", False):
                ucg_schedule = make_oscillating_guidance_schedule(num_steps=steps, max_weight=scale, min_weight=1.)
        else:
            raise ValueError(f"unknown sampler {sampler}!")

    adm_cond, adm_uc = None, None
    # ADM conditioning is only consumed by `sample`, which "Full Karlo" does
    # not use; for that version `state` has no "model" either.
    if version != "Full Karlo":
        if use_karlo:
            # uses the prior
            karlo_sampler = state["karlo_prior"]
            with torch.no_grad():
                karlo_prediction = next(iter(
                    karlo_sampler(
                        prompt=prompt,
                        bsz=number_cols,
                        progressive_mode="final",
                    )
                ))
            adm_cond = karlo_prediction
            adm_uc = torch.zeros_like(karlo_prediction)
        else:
            init_img = get_init_img(batch_size=number_cols)
            if init_img is None:
                # Nothing uploaded yet: stop this run instead of crashing.
                st.info("Upload an input image (or enable the KARLO prior) to sample.")
                st.stop()
            with torch.no_grad():
                adm_cond = state["model"].embedder(init_img)
            adm_uc = torch.zeros_like(adm_cond)

    if st.button("Sample"):
        print("running prompt:", prompt)
        st.text("Sampling")
        t_progress = st.progress(0)
        result = st.empty()

        def t_callback(t):
            t_progress.progress(min((t + 1) / steps, 1.))

        if version == "Full Karlo":
            outputs = st.empty()
            karlo_sampler = state["karlo_prior"]
            all_samples = list()
            with torch.no_grad():
                for _ in range(number_rows):
                    karlo_prediction = next(iter(
                        karlo_sampler(
                            prompt=prompt,
                            bsz=number_cols,
                            progressive_mode="final",
                        )
                    ))
                    all_samples.append(karlo_prediction)
            # NOTE(review): assumes each prediction is a (b, c, h, w) tensor in
            # [0, 1] — confirm against T2ISampler's output format.
            grid = torch.stack(all_samples, 0)
            grid = rearrange(grid, 'n b c h w -> (n h) (b w) c')
            outputs.image(grid.cpu().numpy())
        else:
            samples = sample(
                state["model"],
                prompt,
                n_runs=number_rows,
                n_samples=number_cols,
                H=H, W=W, C=C, f=f,
                scale=scale,
                ddim_steps=steps,
                ddim_eta=eta,
                callback=t_callback,
                ucg_schedule=ucg_schedule,
                negative_prompt=negative_prompt,
                adm_cond=adm_cond, adm_uc=adm_uc,
                use_full_precision=force_full_precision,
                only_adm_cond=False,
                sampler=sampler,
            )