make it work

This commit is contained in:
Robin Rombach 2023-01-14 14:05:53 +01:00
parent 45287f9ed7
commit 929625ac9a
2 changed files with 5 additions and 7 deletions

View file

@ -16,6 +16,7 @@ model:
conditioning_key: crossattn-adm conditioning_key: crossattn-adm
scale_factor: 0.18215 scale_factor: 0.18215
monitor: val/loss_simple_ema monitor: val/loss_simple_ema
use_ema: False
embedder_config: embedder_config:
target: ldm.modules.encoders.modules.ClipImageEmbedder target: ldm.modules.encoders.modules.ClipImageEmbedder

View file

@ -117,8 +117,6 @@ def sample(
if isinstance(prompts, tuple): if isinstance(prompts, tuple):
prompts = list(prompts) prompts = list(prompts)
c = model.get_learned_conditioning(prompts) c = model.get_learned_conditioning(prompts)
if isinstance(model, Txt2ImgDiffusionWithPooledInput):
c, uc = c[0], uc[0]
if adm_cond is not None: if adm_cond is not None:
if adm_cond.shape[0] == 1: if adm_cond.shape[0] == 1:
@ -272,15 +270,14 @@ if __name__ == "__main__":
mode = "txt2img" mode = "txt2img"
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0) version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
use_karlo = st.checkbox("Use KARLO prior", False) use_karlo = st.checkbox("Use KARLO prior", False)
state = init(version=version, vae_version=vae_version, load_karlo_prior=use_karlo) state = init(version=version, load_karlo_prior=use_karlo)
st.info(state["msg"]) st.info(state["msg"])
prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse") prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse")
negative_prompt = st.text_input("Negative Prompt", "") negative_prompt = st.text_input("Negative Prompt", "")
scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.) scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
number_rows = st.number_input("num rows", value=2, min_value=1, max_value=10) number_rows = st.number_input("num rows", value=2, min_value=1, max_value=10)
number_cols = st.number_input("num cols", value=2, min_value=1, max_value=10) number_cols = st.number_input("num cols", value=2, min_value=1, max_value=10)
default_steps = 25 steps = st.sidebar.number_input("steps", value=40, min_value=1, max_value=1000)
steps = st.sidebar.number_input("steps", value=default_steps, min_value=1, max_value=1000)
eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.) eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
force_full_precision = st.sidebar.checkbox("Force FP32", False) force_full_precision = st.sidebar.checkbox("Force FP32", False)
if version != "Full Karlo": if version != "Full Karlo":
@ -289,14 +286,14 @@ if __name__ == "__main__":
C = VERSION2SPECS[version]["C"] C = VERSION2SPECS[version]["C"]
f = VERSION2SPECS[version]["f"] f = VERSION2SPECS[version]["f"]
SAVE_PATH = os.path.join(SAVE_PATH, version + "_" + vae_version + "-decoder") SAVE_PATH = os.path.join(SAVE_PATH, version)
os.makedirs(os.path.join(SAVE_PATH, "samples"), exist_ok=True) os.makedirs(os.path.join(SAVE_PATH, "samples"), exist_ok=True)
seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9)) seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
seed_everything(seed) seed_everything(seed)
ucg_schedule = None ucg_schedule = None
sampler = st.sidebar.selectbox("Sampler", ["DDIM", "PLMS", "DPM"], 2) sampler = st.sidebar.selectbox("Sampler", ["DDIM", "PLMS", "DPM"], 0)
if version == "Full Karlo": if version == "Full Karlo":
pass pass
else: else: