From 4e89f578ea14b1dce4442f51c82b488af0cdc07a Mon Sep 17 00:00:00 2001
From: Robin Rombach
Date: Thu, 23 Feb 2023 12:45:18 +0100
Subject: [PATCH] support image mixing in streamlit

---
 scripts/streamlit/stableunclip.py | 73 +++++++++++++++++++++----------
 1 file changed, 49 insertions(+), 24 deletions(-)

diff --git a/scripts/streamlit/stableunclip.py b/scripts/streamlit/stableunclip.py
index 6dd4bb7..f52622a 100644
--- a/scripts/streamlit/stableunclip.py
+++ b/scripts/streamlit/stableunclip.py
@@ -18,7 +18,6 @@
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
 from ldm.models.diffusion.dpm_solver import DPMSolverSampler
-
 torch.set_grad_enabled(False)
 
 PROMPTS_ROOT = "scripts/prompts/"
@@ -46,8 +45,8 @@ def instantiate_from_config(config):
     return get_obj_from_str(config["target"])(**config.get("params", dict()))
 
 
-def get_interactive_image():
-    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
+def get_interactive_image(key=None):
+    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
     if image is not None:
         image = Image.open(image)
         if not image.mode == "RGB":
@@ -55,8 +54,8 @@ def get_interactive_image():
     return image
 
 
-def load_img(display=True):
-    image = get_interactive_image()
+def load_img(display=True, key=None):
+    image = get_interactive_image(key=key)
     if display:
         st.image(image)
     w, h = image.size
@@ -69,8 +68,8 @@ def load_img(display=True):
     return 2. * image - 1.
 
 
-def get_init_img(batch_size=1):
-    init_image = load_img().cuda()
+def get_init_img(batch_size=1, key=None):
+    init_image = load_img(key=key).cuda()
     init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
     return init_image
 
@@ -99,7 +98,7 @@ def sample(
 ):
     batch_size = n_samples
     precision_scope = autocast if not use_full_precision else nullcontext
-    #decoderscope = autocast if not use_full_precision else nullcontext
+    # decoderscope = autocast if not use_full_precision else nullcontext
     if use_full_precision: st.warning(f"Running {model.__class__.__name__} at full precision.")
     if isinstance(prompt, str):
         prompt = [prompt]
@@ -274,8 +273,8 @@ if __name__ == "__main__":
     st.title("Stable unCLIP")
     mode = "txt2img"
     version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
-    use_karlo = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
-    state = init(version=version, load_karlo_prior=use_karlo)
+    use_karlo_prior = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
+    state = init(version=version, load_karlo_prior=use_karlo_prior)
     prompt = st.text_input("Prompt", "a professional photograph")
     negative_prompt = st.text_input("Negative Prompt", "")
     scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
@@ -309,7 +308,7 @@ if __name__ == "__main__":
         raise ValueError(f"unknown sampler {sampler}!")
 
     adm_cond, adm_uc = None, None
-    if use_karlo:
+    if use_karlo_prior:
         # uses the prior
         karlo_sampler = state["karlo_prior"]
         noise_level = None
@@ -330,19 +329,44 @@ if __name__ == "__main__":
                 torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
             adm_cond = torch.cat((c_adm, noise_level_emb), 1)
             adm_uc = torch.zeros_like(adm_cond)
-
+    elif version == "Full Karlo":
+        pass
     else:
-        init_img = get_init_img(batch_size=number_cols)
-        with torch.no_grad():
-            adm_cond = state["model"].embedder(init_img)
-            if state["model"].noise_augmentor is not None:
-                noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0,
-                                              max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0)
-                c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
-                    torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
-                # assume this gives embeddings of noise levels
+        num_inputs = st.number_input("Number of Input Images", 1)
+
+
+        def make_conditionings_from_input(num=1, key=None):
+            init_img = get_init_img(batch_size=number_cols, key=key)
+            with torch.no_grad():
+                adm_cond = state["model"].embedder(init_img)
+            weight = st.slider(f"Weight for Input {num}", min_value=-10., max_value=10., value=1.)
+            if state["model"].noise_augmentor is not None:
+                noise_level = st.number_input(f"Noise Augmentation for CLIP embedding of input #{num}", min_value=0,
+                                              max_value=state["model"].noise_augmentor.max_noise_level - 1,
+                                              value=0, )
+                c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
+                    torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
+                adm_cond = torch.cat((c_adm, noise_level_emb), 1) * weight
+            adm_uc = torch.zeros_like(adm_cond)
+            return adm_cond, adm_uc, weight
+
+
+        adm_inputs = list()
+        weights = list()
+        for n in range(num_inputs):
+            adm_cond, adm_uc, w = make_conditionings_from_input(num=n + 1, key=n)
+            weights.append(w)
+            adm_inputs.append(adm_cond)
+        adm_cond = torch.stack(adm_inputs).sum(0) / sum(weights)
+        if num_inputs > 1:
+            if st.checkbox("Apply Noise to Embedding Mix", True):
+                noise_level = st.number_input(f"Noise Augmentation for averaged CLIP embeddings", min_value=0,
+                                              max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0, )
+                c_adm, noise_level_emb = state["model"].noise_augmentor(
+                    adm_cond[:, :state["model"].noise_augmentor.time_embed.dim],
+                    noise_level=repeat(
+                        torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
                 adm_cond = torch.cat((c_adm, noise_level_emb), 1)
-        adm_uc = torch.zeros_like(adm_cond)
 
     if st.button("Sample"):
         print("running prompt:", prompt)
@@ -350,10 +374,12 @@ if __name__ == "__main__":
         t_progress = st.progress(0)
         result = st.empty()
 
+
         def t_callback(t):
             t_progress.progress(min((t + 1) / steps, 1.))
 
-        if version == "KARLO":
+
+        if version == "Full Karlo":
            outputs = st.empty()
            karlo_sampler = state["karlo_prior"]
            all_samples = list()
@@ -388,4 +414,3 @@ if __name__ == "__main__":
                 use_full_precision=force_full_precision,
                 only_adm_cond=False
             )
-
\ No newline at end of file
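
A note on the mixing rule (a minimal sketch, not part of the patch): the new
code path conditions the model on a weighted combination of per-image CLIP
embeddings. Each input's (optionally noise-augmented) embedding is scaled by
its slider weight inside make_conditionings_from_input, then the stack is
summed and renormalized by the total weight, i.e. mix = sum_i(w_i * e_i) /
sum_i(w_i). In the sketch below, mix_embeddings and the 768-dim stand-in
tensors are illustrative assumptions, not code from this commit:

import torch

def mix_embeddings(embeddings, weights):
    # Scale each conditioning embedding by its weight, then sum the stack and
    # renormalize by the total weight -- the same rule the patch applies to the
    # per-image adm_cond tensors via torch.stack(adm_inputs).sum(0) / sum(weights).
    weighted = [e * w for e, w in zip(embeddings, weights)]
    return torch.stack(weighted).sum(0) / sum(weights)

# Usage: two stand-in CLIP image embeddings for a batch of 2.
a = torch.randn(2, 768)
b = torch.randn(2, 768)
mixed = mix_embeddings([a, b], weights=[1.0, 0.5])
print(mixed.shape)  # torch.Size([2, 768])

Note that a weight may be negative (the slider allows -10..10), so sum(weights)
can reach zero; as in the patch, the sketch does not guard against that case.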