mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 23:55:00 +00:00
Add callback and img_callback for DDIM decode() and DPM solver
This commit is contained in:
parent
cc77f2300d
commit
4923394f5c
3 changed files with 15 additions and 4 deletions
|
@ -315,7 +315,7 @@ class DDIMSampler(object):
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||||
use_original_steps=False, callback=None):
|
use_original_steps=False, callback=None, img_callback=None):
|
||||||
|
|
||||||
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
||||||
timesteps = timesteps[:t_start]
|
timesteps = timesteps[:t_start]
|
||||||
|
@ -333,4 +333,5 @@ class DDIMSampler(object):
|
||||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
unconditional_conditioning=unconditional_conditioning)
|
unconditional_conditioning=unconditional_conditioning)
|
||||||
if callback: callback(i)
|
if callback: callback(i)
|
||||||
|
if img_callback: img_callback(x_dec, i)
|
||||||
return x_dec
|
return x_dec
|
|
@ -876,7 +876,7 @@ class DPM_Solver:
|
||||||
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
|
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
|
||||||
|
|
||||||
def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
|
def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
|
||||||
solver_type='dpm_solver'):
|
solver_type='dpm_solver', callback=None, img_callback=None):
|
||||||
"""
|
"""
|
||||||
The adaptive step size solver based on singlestep DPM-Solver.
|
The adaptive step size solver based on singlestep DPM-Solver.
|
||||||
Args:
|
Args:
|
||||||
|
@ -931,6 +931,8 @@ class DPM_Solver:
|
||||||
s = t
|
s = t
|
||||||
x_prev = x_lower
|
x_prev = x_lower
|
||||||
lambda_s = ns.marginal_lambda(s)
|
lambda_s = ns.marginal_lambda(s)
|
||||||
|
if callback: callback(nfe)
|
||||||
|
if img_callback: img_callback(x, nfe)
|
||||||
h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
|
h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
|
||||||
nfe += order
|
nfe += order
|
||||||
print('adaptive solver nfe', nfe)
|
print('adaptive solver nfe', nfe)
|
||||||
|
@ -939,6 +941,7 @@ class DPM_Solver:
|
||||||
def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
|
def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
|
||||||
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
|
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
|
||||||
atol=0.0078, rtol=0.05,
|
atol=0.0078, rtol=0.05,
|
||||||
|
callback=None, img_callback=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
|
Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
|
||||||
|
@ -1040,7 +1043,8 @@ class DPM_Solver:
|
||||||
if method == 'adaptive':
|
if method == 'adaptive':
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
|
x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
|
||||||
solver_type=solver_type)
|
solver_type=solver_type,
|
||||||
|
callback=callback, img_callback=img_callback)
|
||||||
elif method == 'multistep':
|
elif method == 'multistep':
|
||||||
assert steps >= order
|
assert steps >= order
|
||||||
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
||||||
|
@ -1065,6 +1069,8 @@ class DPM_Solver:
|
||||||
step_order = order
|
step_order = order
|
||||||
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
|
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
|
||||||
solver_type=solver_type)
|
solver_type=solver_type)
|
||||||
|
if callback: callback(step)
|
||||||
|
if img_callback: img_callback(x, step)
|
||||||
for i in range(order - 1):
|
for i in range(order - 1):
|
||||||
t_prev_list[i] = t_prev_list[i + 1]
|
t_prev_list[i] = t_prev_list[i + 1]
|
||||||
model_prev_list[i] = model_prev_list[i + 1]
|
model_prev_list[i] = model_prev_list[i + 1]
|
||||||
|
@ -1092,8 +1098,12 @@ class DPM_Solver:
|
||||||
r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
|
r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
|
||||||
r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
|
r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
|
||||||
x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
|
x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
|
||||||
|
if callback: callback(i)
|
||||||
|
if img_callback: img_callback(x, i)
|
||||||
if denoise_to_zero:
|
if denoise_to_zero:
|
||||||
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
|
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
|
||||||
|
if callback: callback(len(orders)-1)
|
||||||
|
if img_callback: img_callback(x, len(orders)-1)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -82,6 +82,6 @@ class DPMSolverSampler(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
|
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
|
||||||
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
|
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True, callback=callback, img_callback=img_callback)
|
||||||
|
|
||||||
return x.to(device), None
|
return x.to(device), None
|
Loading…
Reference in a new issue