mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 07:34:58 +00:00
support dpm
This commit is contained in:
parent
c81b231008
commit
4b71f18cfc
3 changed files with 27 additions and 19 deletions
|
@ -308,13 +308,13 @@ def model_wrapper(
|
|||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t_continuous] * 2)
|
||||
if isinstance(condition, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
assert isinstance(unconditional_condition, dict)
|
||||
c_in = dict()
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [torch.cat([unconditional_conditioning[k][i], c[k][i]]) for i in range(len(c[k]))]
|
||||
for k in condition:
|
||||
if isinstance(condition[k], list):
|
||||
c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))]
|
||||
else:
|
||||
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
|
||||
c_in[k] = torch.cat([unconditional_condition[k], condition[k]])
|
||||
else:
|
||||
c_in = torch.cat([unconditional_condition, condition])
|
||||
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
|
||||
|
|
|
@ -3,7 +3,6 @@ import torch
|
|||
|
||||
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
|
||||
|
||||
|
||||
MODEL_TYPES = {
|
||||
"eps": "noise",
|
||||
"v": "v"
|
||||
|
@ -50,12 +49,20 @@ class DPMSolverSampler(object):
|
|||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list): ctmp = ctmp[0]
|
||||
if isinstance(ctmp, torch.Tensor):
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
elif isinstance(conditioning, list):
|
||||
for ctmp in conditioning:
|
||||
if ctmp.shape[0] != batch_size:
|
||||
print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
if isinstance(conditioning, torch.Tensor):
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
|
@ -82,6 +89,7 @@ class DPMSolverSampler(object):
|
|||
)
|
||||
|
||||
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
|
||||
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
|
||||
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
|
||||
lower_order_final=True)
|
||||
|
||||
return x.to(device), None
|
||||
return x.to(device), None
|
||||
|
|
|
@ -99,7 +99,8 @@ def sample(
|
|||
):
|
||||
batch_size = n_samples
|
||||
precision_scope = autocast if not use_full_precision else nullcontext
|
||||
if use_full_precision: st.warning(f"Sampling {model.__class__.__name__} at full precision.")
|
||||
#decoderscope = autocast if not use_full_precision else nullcontext
|
||||
if use_full_precision: st.warning(f"Running {model.__class__.__name__} at full precision.")
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
prompts = batch_size * prompt
|
||||
|
@ -146,7 +147,6 @@ def sample(
|
|||
callback=callback,
|
||||
ucg_schedule=ucg_schedule
|
||||
)
|
||||
|
||||
x_samples = model.decode_first_stage(samples_ddim)
|
||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
|
@ -277,14 +277,14 @@ if __name__ == "__main__":
|
|||
use_karlo = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
|
||||
state = init(version=version, load_karlo_prior=use_karlo)
|
||||
st.info(state["msg"])
|
||||
prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse")
|
||||
prompt = st.text_input("Prompt", "a professional photograph")
|
||||
negative_prompt = st.text_input("Negative Prompt", "")
|
||||
scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
|
||||
number_rows = st.number_input("num rows", value=2, min_value=1, max_value=10)
|
||||
number_cols = st.number_input("num cols", value=2, min_value=1, max_value=10)
|
||||
steps = st.sidebar.number_input("steps", value=40, min_value=1, max_value=1000)
|
||||
steps = st.sidebar.number_input("steps", value=20, min_value=1, max_value=1000)
|
||||
eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
|
||||
force_full_precision = st.sidebar.checkbox("Force FP32", False)
|
||||
force_full_precision = st.sidebar.checkbox("Force FP32", False) # TODO: check if/where things break.
|
||||
if version != "Full Karlo":
|
||||
H = st.sidebar.number_input("H", value=VERSION2SPECS[version]["H"], min_value=64, max_value=2048)
|
||||
W = st.sidebar.number_input("W", value=VERSION2SPECS[version]["W"], min_value=64, max_value=2048)
|
||||
|
@ -330,7 +330,7 @@ if __name__ == "__main__":
|
|||
c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
|
||||
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
|
||||
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
|
||||
adm_uc = torch.zeros_like(karlo_prediction)
|
||||
adm_uc = torch.zeros_like(adm_cond)
|
||||
|
||||
else:
|
||||
init_img = get_init_img(batch_size=number_cols)
|
||||
|
|
Loading…
Reference in a new issue