From 4b71f18cfca7ea73d1ebe80a785e500f9c4d03f2 Mon Sep 17 00:00:00 2001 From: Robin Rombach Date: Mon, 30 Jan 2023 11:13:15 +0100 Subject: [PATCH] support dpm --- ldm/models/diffusion/dpm_solver/dpm_solver.py | 10 ++++---- ldm/models/diffusion/dpm_solver/sampler.py | 24 ++++++++++++------- scripts/streamlit/stableunclip.py | 12 +++++----- 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ldm/models/diffusion/dpm_solver/dpm_solver.py index fbe7aac..da8d41f 100644 --- a/ldm/models/diffusion/dpm_solver/dpm_solver.py +++ b/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -308,13 +308,13 @@ def model_wrapper( x_in = torch.cat([x] * 2) t_in = torch.cat([t_continuous] * 2) if isinstance(condition, dict): - assert isinstance(unconditional_conditioning, dict) + assert isinstance(unconditional_condition, dict) c_in = dict() - for k in c: - if isinstance(c[k], list): - c_in[k] = [torch.cat([unconditional_conditioning[k][i], c[k][i]]) for i in range(len(c[k]))] + for k in condition: + if isinstance(condition[k], list): + c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))] else: - c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) + c_in[k] = torch.cat([unconditional_condition[k], condition[k]]) else: c_in = torch.cat([unconditional_condition, condition]) noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py index 7d137b8..3c1e021 100644 --- a/ldm/models/diffusion/dpm_solver/sampler.py +++ b/ldm/models/diffusion/dpm_solver/sampler.py @@ -3,7 +3,6 @@ import torch from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver - MODEL_TYPES = { "eps": "noise", "v": "v" @@ -50,12 +49,20 @@ class DPMSolverSampler(object): ): if conditioning is not None: if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + if isinstance(ctmp, torch.Tensor): + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + elif isinstance(conditioning, list): + for ctmp in conditioning: + if ctmp.shape[0] != batch_size: + print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}") else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + if isinstance(conditioning, torch.Tensor): + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") # sampling C, H, W = shape @@ -82,6 +89,7 @@ class DPMSolverSampler(object): ) dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) - x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) + x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, + lower_order_final=True) - return x.to(device), None \ No newline at end of file + return x.to(device), None diff --git a/scripts/streamlit/stableunclip.py b/scripts/streamlit/stableunclip.py index 892a809..c193434 100644 --- a/scripts/streamlit/stableunclip.py +++ b/scripts/streamlit/stableunclip.py @@ -99,7 +99,8 @@ def sample( ): batch_size = n_samples precision_scope = autocast if not use_full_precision else nullcontext - if use_full_precision: st.warning(f"Sampling {model.__class__.__name__} at full precision.") + #decoderscope = autocast if not use_full_precision else nullcontext + if use_full_precision: st.warning(f"Running {model.__class__.__name__} at full precision.") if isinstance(prompt, str): prompt = [prompt] prompts = batch_size * prompt @@ -146,7 +147,6 @@ def sample( callback=callback, ucg_schedule=ucg_schedule ) - x_samples = model.decode_first_stage(samples_ddim) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) @@ -277,14 +277,14 @@ if __name__ == "__main__": use_karlo = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False) state = init(version=version, load_karlo_prior=use_karlo) st.info(state["msg"]) - prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse") + prompt = st.text_input("Prompt", "a professional photograph") negative_prompt = st.text_input("Negative Prompt", "") scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.) number_rows = st.number_input("num rows", value=2, min_value=1, max_value=10) number_cols = st.number_input("num cols", value=2, min_value=1, max_value=10) - steps = st.sidebar.number_input("steps", value=40, min_value=1, max_value=1000) + steps = st.sidebar.number_input("steps", value=20, min_value=1, max_value=1000) eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.) - force_full_precision = st.sidebar.checkbox("Force FP32", False) + force_full_precision = st.sidebar.checkbox("Force FP32", False) # TODO: check if/where things break. if version != "Full Karlo": H = st.sidebar.number_input("H", value=VERSION2SPECS[version]["H"], min_value=64, max_value=2048) W = st.sidebar.number_input("W", value=VERSION2SPECS[version]["W"], min_value=64, max_value=2048) @@ -330,7 +330,7 @@ if __name__ == "__main__": c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat( torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols)) adm_cond = torch.cat((c_adm, noise_level_emb), 1) - adm_uc = torch.zeros_like(karlo_prediction) + adm_uc = torch.zeros_like(adm_cond) else: init_img = get_init_img(batch_size=number_cols)