support dpm

Robin Rombach 2023-01-30 11:13:15 +01:00
parent c81b231008
commit 4b71f18cfc
3 changed files with 27 additions and 19 deletions


@@ -308,13 +308,13 @@ def model_wrapper(
                 x_in = torch.cat([x] * 2)
                 t_in = torch.cat([t_continuous] * 2)
                 if isinstance(condition, dict):
-                    assert isinstance(unconditional_conditioning, dict)
+                    assert isinstance(unconditional_condition, dict)
                     c_in = dict()
-                    for k in c:
-                        if isinstance(c[k], list):
-                            c_in[k] = [torch.cat([unconditional_conditioning[k][i], c[k][i]]) for i in range(len(c[k]))]
+                    for k in condition:
+                        if isinstance(condition[k], list):
+                            c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))]
                         else:
-                            c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
+                            c_in[k] = torch.cat([unconditional_condition[k], condition[k]])
                 else:
                     c_in = torch.cat([unconditional_condition, condition])
                 noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
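
Note: this hunk only renames variables (c -> condition, unconditional_conditioning -> unconditional_condition) inside the classifier-free-guidance branch of model_wrapper. For context, the wrapper batches the unconditional and conditional passes into one forward call and splits the result; a minimal sketch of that pattern (noise_pred_fn and guidance_scale stand in for the wrapper's internals):

    import torch

    def cfg_noise(noise_pred_fn, x, t, cond, uncond, guidance_scale):
        # One batched forward pass: first half unconditional, second half conditional.
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)
        c_in = torch.cat([uncond, cond])
        noise_uncond, noise_cond = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
        # Classifier-free guidance: extrapolate away from the unconditional prediction.
        return noise_uncond + guidance_scale * (noise_cond - noise_uncond)
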


@@ -3,7 +3,6 @@ import torch
 
 from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
 
-
 MODEL_TYPES = {
     "eps": "noise",
     "v": "v"
@@ -50,10 +49,18 @@ class DPMSolverSampler(object):
                ):
         if conditioning is not None:
             if isinstance(conditioning, dict):
-                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
-                if cbs != batch_size:
-                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list): ctmp = ctmp[0]
+                if isinstance(ctmp, torch.Tensor):
+                    cbs = ctmp.shape[0]
+                    if cbs != batch_size:
+                        print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            elif isinstance(conditioning, list):
+                for ctmp in conditioning:
+                    if ctmp.shape[0] != batch_size:
+                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
             else:
-                if conditioning.shape[0] != batch_size:
-                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+                if isinstance(conditioning, torch.Tensor):
+                    if conditioning.shape[0] != batch_size:
+                        print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
 
@@ -82,6 +89,7 @@ class DPMSolverSampler(object):
         )
 
         dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
-        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
+        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
+                              lower_order_final=True)
 
         return x.to(device), None
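
Note: this is the multistep, second-order DPM-Solver call; wrapping the long argument list is the only change in this hunk. A sketch of the surrounding call sequence, assuming a latent-diffusion model that exposes alphas_cumprod and an apply_model(x, t, c) noise predictor:

    ns = NoiseScheduleVP('discrete', alphas_cumprod=model.alphas_cumprod)
    model_fn = model_wrapper(
        lambda x, t, c: model.apply_model(x, t, c),
        ns,
        model_type="noise",            # eps-parameterized model
        guidance_type="classifier-free",
        condition=conditioning,
        unconditional_condition=unconditional_conditioning,
        guidance_scale=unconditional_guidance_scale,
    )
    dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
    x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep",
                          order=2, lower_order_final=True)
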


@@ -99,7 +99,8 @@ def sample(
     ):
     batch_size = n_samples
     precision_scope = autocast if not use_full_precision else nullcontext
-    if use_full_precision: st.warning(f"Sampling {model.__class__.__name__} at full precision.")
+    #decoderscope = autocast if not use_full_precision else nullcontext
+    if use_full_precision: st.warning(f"Running {model.__class__.__name__} at full precision.")
     if isinstance(prompt, str):
         prompt = [prompt]
     prompts = batch_size * prompt
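
Note: precision_scope picks between fp16 autocast and a no-op context at call time; the commented-out decoderscope line suggests a separate scope for the decoder was considered. A minimal, self-contained sketch of the pattern (assuming a CUDA device):

    from contextlib import nullcontext
    from torch import autocast

    use_full_precision = False
    precision_scope = autocast if not use_full_precision else nullcontext
    with precision_scope("cuda"):
        pass  # forward passes run under fp16 autocast unless FP32 is forced
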
@@ -146,7 +147,6 @@ def sample(
                                         callback=callback,
                                         ucg_schedule=ucg_schedule
                                         )
-
        x_samples = model.decode_first_stage(samples_ddim)
        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
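
Note: the first-stage decoder returns images in [-1, 1], and the clamp line maps them into [0, 1] for display. A self-contained illustration, continuing to 8-bit conversion (the last line is hypothetical, not part of this diff):

    import torch

    x_samples = torch.tanh(torch.randn(1, 3, 64, 64))                   # decoder output in [-1, 1]
    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)  # -> [0, 1]
    images = (255.0 * x_samples.permute(0, 2, 3, 1).cpu().numpy()).astype("uint8")
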
@@ -277,14 +277,14 @@ if __name__ == "__main__":
     use_karlo = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
     state = init(version=version, load_karlo_prior=use_karlo)
     st.info(state["msg"])
-    prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse")
+    prompt = st.text_input("Prompt", "a professional photograph")
     negative_prompt = st.text_input("Negative Prompt", "")
     scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
     number_rows = st.number_input("num rows", value=2, min_value=1, max_value=10)
     number_cols = st.number_input("num cols", value=2, min_value=1, max_value=10)
-    steps = st.sidebar.number_input("steps", value=40, min_value=1, max_value=1000)
+    steps = st.sidebar.number_input("steps", value=20, min_value=1, max_value=1000)
     eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
-    force_full_precision = st.sidebar.checkbox("Force FP32", False)
+    force_full_precision = st.sidebar.checkbox("Force FP32", False)  # TODO: check if/where things break.
     if version != "Full Karlo":
         H = st.sidebar.number_input("H", value=VERSION2SPECS[version]["H"], min_value=64, max_value=2048)
         W = st.sidebar.number_input("W", value=VERSION2SPECS[version]["W"], min_value=64, max_value=2048)
@@ -330,7 +330,7 @@ if __name__ == "__main__":
                 c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
                     torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
                 adm_cond = torch.cat((c_adm, noise_level_emb), 1)
-                adm_uc = torch.zeros_like(karlo_prediction)
+                adm_uc = torch.zeros_like(adm_cond)
             else:
                 init_img = get_init_img(batch_size=number_cols)
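
Note: the zeros_like fix matters because adm_cond is the concatenation of the image embedding and the noise-level embedding, so it is wider than the raw Karlo prediction; the unconditional embedding must match adm_cond's full shape. A hypothetical illustration (sizes invented):

    import torch

    c_adm = torch.randn(2, 768)            # e.g. Karlo/CLIP image embedding
    noise_level_emb = torch.randn(2, 256)  # e.g. noise-augmentation embedding
    adm_cond = torch.cat((c_adm, noise_level_emb), 1)  # shape (2, 1024)
    adm_uc = torch.zeros_like(adm_cond)                # matches (2, 1024), not (2, 768)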