From 929625ac9a5986b981da92adf3a7f68615686774 Mon Sep 17 00:00:00 2001 From: Robin Rombach Date: Sat, 14 Jan 2023 14:05:53 +0100 Subject: [PATCH] make it work --- .../stable-diffusion/v2-1-stable-karlo-inference.yaml | 1 + scripts/streamlit/stablekarlo.py | 11 ++++------- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/configs/stable-diffusion/v2-1-stable-karlo-inference.yaml b/configs/stable-diffusion/v2-1-stable-karlo-inference.yaml index da867b4..5aeb176 100644 --- a/configs/stable-diffusion/v2-1-stable-karlo-inference.yaml +++ b/configs/stable-diffusion/v2-1-stable-karlo-inference.yaml @@ -16,6 +16,7 @@ model: conditioning_key: crossattn-adm scale_factor: 0.18215 monitor: val/loss_simple_ema + use_ema: False embedder_config: target: ldm.modules.encoders.modules.ClipImageEmbedder diff --git a/scripts/streamlit/stablekarlo.py b/scripts/streamlit/stablekarlo.py index e57500c..905e62f 100644 --- a/scripts/streamlit/stablekarlo.py +++ b/scripts/streamlit/stablekarlo.py @@ -117,8 +117,6 @@ def sample( if isinstance(prompts, tuple): prompts = list(prompts) c = model.get_learned_conditioning(prompts) - if isinstance(model, Txt2ImgDiffusionWithPooledInput): - c, uc = c[0], uc[0] if adm_cond is not None: if adm_cond.shape[0] == 1: @@ -272,15 +270,14 @@ if __name__ == "__main__": mode = "txt2img" version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0) use_karlo = st.checkbox("Use KARLO prior", False) - state = init(version=version, vae_version=vae_version, load_karlo_prior=use_karlo) + state = init(version=version, load_karlo_prior=use_karlo) st.info(state["msg"]) prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse") negative_prompt = st.text_input("Negative Prompt", "") scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.) number_rows = st.number_input("num rows", value=2, min_value=1, max_value=10) number_cols = st.number_input("num cols", value=2, min_value=1, max_value=10) - default_steps = 25 - steps = st.sidebar.number_input("steps", value=default_steps, min_value=1, max_value=1000) + steps = st.sidebar.number_input("steps", value=40, min_value=1, max_value=1000) eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.) force_full_precision = st.sidebar.checkbox("Force FP32", False) if version != "Full Karlo": @@ -289,14 +286,14 @@ if __name__ == "__main__": C = VERSION2SPECS[version]["C"] f = VERSION2SPECS[version]["f"] - SAVE_PATH = os.path.join(SAVE_PATH, version + "_" + vae_version + "-decoder") + SAVE_PATH = os.path.join(SAVE_PATH, version) os.makedirs(os.path.join(SAVE_PATH, "samples"), exist_ok=True) seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9)) seed_everything(seed) ucg_schedule = None - sampler = st.sidebar.selectbox("Sampler", ["DDIM", "PLMS", "DPM"], 2) + sampler = st.sidebar.selectbox("Sampler", ["DDIM", "PLMS", "DPM"], 0) if version == "Full Karlo": pass else: