mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 15:44:58 +00:00
support dpm
This commit is contained in:
parent
c81b231008
commit
4b71f18cfc
3 changed files with 27 additions and 19 deletions
|
@ -308,13 +308,13 @@ def model_wrapper(
|
||||||
x_in = torch.cat([x] * 2)
|
x_in = torch.cat([x] * 2)
|
||||||
t_in = torch.cat([t_continuous] * 2)
|
t_in = torch.cat([t_continuous] * 2)
|
||||||
if isinstance(condition, dict):
|
if isinstance(condition, dict):
|
||||||
assert isinstance(unconditional_conditioning, dict)
|
assert isinstance(unconditional_condition, dict)
|
||||||
c_in = dict()
|
c_in = dict()
|
||||||
for k in c:
|
for k in condition:
|
||||||
if isinstance(c[k], list):
|
if isinstance(condition[k], list):
|
||||||
c_in[k] = [torch.cat([unconditional_conditioning[k][i], c[k][i]]) for i in range(len(c[k]))]
|
c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))]
|
||||||
else:
|
else:
|
||||||
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
|
c_in[k] = torch.cat([unconditional_condition[k], condition[k]])
|
||||||
else:
|
else:
|
||||||
c_in = torch.cat([unconditional_condition, condition])
|
c_in = torch.cat([unconditional_condition, condition])
|
||||||
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
|
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
|
||||||
|
|
|
@ -3,7 +3,6 @@ import torch
|
||||||
|
|
||||||
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
|
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
|
||||||
|
|
||||||
|
|
||||||
MODEL_TYPES = {
|
MODEL_TYPES = {
|
||||||
"eps": "noise",
|
"eps": "noise",
|
||||||
"v": "v"
|
"v": "v"
|
||||||
|
@ -50,10 +49,18 @@ class DPMSolverSampler(object):
|
||||||
):
|
):
|
||||||
if conditioning is not None:
|
if conditioning is not None:
|
||||||
if isinstance(conditioning, dict):
|
if isinstance(conditioning, dict):
|
||||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||||
|
while isinstance(ctmp, list): ctmp = ctmp[0]
|
||||||
|
if isinstance(ctmp, torch.Tensor):
|
||||||
|
cbs = ctmp.shape[0]
|
||||||
if cbs != batch_size:
|
if cbs != batch_size:
|
||||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||||
|
elif isinstance(conditioning, list):
|
||||||
|
for ctmp in conditioning:
|
||||||
|
if ctmp.shape[0] != batch_size:
|
||||||
|
print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
|
||||||
else:
|
else:
|
||||||
|
if isinstance(conditioning, torch.Tensor):
|
||||||
if conditioning.shape[0] != batch_size:
|
if conditioning.shape[0] != batch_size:
|
||||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||||
|
|
||||||
|
@ -82,6 +89,7 @@ class DPMSolverSampler(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
|
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
|
||||||
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
|
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
|
||||||
|
lower_order_final=True)
|
||||||
|
|
||||||
return x.to(device), None
|
return x.to(device), None
|
|
@ -99,7 +99,8 @@ def sample(
|
||||||
):
|
):
|
||||||
batch_size = n_samples
|
batch_size = n_samples
|
||||||
precision_scope = autocast if not use_full_precision else nullcontext
|
precision_scope = autocast if not use_full_precision else nullcontext
|
||||||
if use_full_precision: st.warning(f"Sampling {model.__class__.__name__} at full precision.")
|
#decoderscope = autocast if not use_full_precision else nullcontext
|
||||||
|
if use_full_precision: st.warning(f"Running {model.__class__.__name__} at full precision.")
|
||||||
if isinstance(prompt, str):
|
if isinstance(prompt, str):
|
||||||
prompt = [prompt]
|
prompt = [prompt]
|
||||||
prompts = batch_size * prompt
|
prompts = batch_size * prompt
|
||||||
|
@ -146,7 +147,6 @@ def sample(
|
||||||
callback=callback,
|
callback=callback,
|
||||||
ucg_schedule=ucg_schedule
|
ucg_schedule=ucg_schedule
|
||||||
)
|
)
|
||||||
|
|
||||||
x_samples = model.decode_first_stage(samples_ddim)
|
x_samples = model.decode_first_stage(samples_ddim)
|
||||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
@ -277,14 +277,14 @@ if __name__ == "__main__":
|
||||||
use_karlo = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
|
use_karlo = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
|
||||||
state = init(version=version, load_karlo_prior=use_karlo)
|
state = init(version=version, load_karlo_prior=use_karlo)
|
||||||
st.info(state["msg"])
|
st.info(state["msg"])
|
||||||
prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse")
|
prompt = st.text_input("Prompt", "a professional photograph")
|
||||||
negative_prompt = st.text_input("Negative Prompt", "")
|
negative_prompt = st.text_input("Negative Prompt", "")
|
||||||
scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
|
scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
|
||||||
number_rows = st.number_input("num rows", value=2, min_value=1, max_value=10)
|
number_rows = st.number_input("num rows", value=2, min_value=1, max_value=10)
|
||||||
number_cols = st.number_input("num cols", value=2, min_value=1, max_value=10)
|
number_cols = st.number_input("num cols", value=2, min_value=1, max_value=10)
|
||||||
steps = st.sidebar.number_input("steps", value=40, min_value=1, max_value=1000)
|
steps = st.sidebar.number_input("steps", value=20, min_value=1, max_value=1000)
|
||||||
eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
|
eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
|
||||||
force_full_precision = st.sidebar.checkbox("Force FP32", False)
|
force_full_precision = st.sidebar.checkbox("Force FP32", False) # TODO: check if/where things break.
|
||||||
if version != "Full Karlo":
|
if version != "Full Karlo":
|
||||||
H = st.sidebar.number_input("H", value=VERSION2SPECS[version]["H"], min_value=64, max_value=2048)
|
H = st.sidebar.number_input("H", value=VERSION2SPECS[version]["H"], min_value=64, max_value=2048)
|
||||||
W = st.sidebar.number_input("W", value=VERSION2SPECS[version]["W"], min_value=64, max_value=2048)
|
W = st.sidebar.number_input("W", value=VERSION2SPECS[version]["W"], min_value=64, max_value=2048)
|
||||||
|
@ -330,7 +330,7 @@ if __name__ == "__main__":
|
||||||
c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
|
c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
|
||||||
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
|
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
|
||||||
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
|
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
|
||||||
adm_uc = torch.zeros_like(karlo_prediction)
|
adm_uc = torch.zeros_like(adm_cond)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
init_img = get_init_img(batch_size=number_cols)
|
init_img = get_init_img(batch_size=number_cols)
|
||||||
|
|
Loading…
Reference in a new issue