mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 15:44:58 +00:00
Add callback and img_callback for DDIM decode() and DPM solver
This commit is contained in:
parent
cc77f2300d
commit
4923394f5c
3 changed files with 15 additions and 4 deletions
|
@ -315,7 +315,7 @@ class DDIMSampler(object):
|
|||
|
||||
@torch.no_grad()
|
||||
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||
use_original_steps=False, callback=None):
|
||||
use_original_steps=False, callback=None, img_callback=None):
|
||||
|
||||
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
||||
timesteps = timesteps[:t_start]
|
||||
|
@ -333,4 +333,5 @@ class DDIMSampler(object):
|
|||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning)
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(x_dec, i)
|
||||
return x_dec
|
|
@ -876,7 +876,7 @@ class DPM_Solver:
|
|||
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
|
||||
|
||||
def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
|
||||
solver_type='dpm_solver'):
|
||||
solver_type='dpm_solver', callback=None, img_callback=None):
|
||||
"""
|
||||
The adaptive step size solver based on singlestep DPM-Solver.
|
||||
Args:
|
||||
|
@ -931,6 +931,8 @@ class DPM_Solver:
|
|||
s = t
|
||||
x_prev = x_lower
|
||||
lambda_s = ns.marginal_lambda(s)
|
||||
if callback: callback(nfe)
|
||||
if img_callback: img_callback(x, nfe)
|
||||
h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
|
||||
nfe += order
|
||||
print('adaptive solver nfe', nfe)
|
||||
|
@ -939,6 +941,7 @@ class DPM_Solver:
|
|||
def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
|
||||
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
|
||||
atol=0.0078, rtol=0.05,
|
||||
callback=None, img_callback=None,
|
||||
):
|
||||
"""
|
||||
Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
|
||||
|
@ -1040,7 +1043,8 @@ class DPM_Solver:
|
|||
if method == 'adaptive':
|
||||
with torch.no_grad():
|
||||
x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
|
||||
solver_type=solver_type)
|
||||
solver_type=solver_type,
|
||||
callback=callback, img_callback=img_callback)
|
||||
elif method == 'multistep':
|
||||
assert steps >= order
|
||||
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
||||
|
@ -1065,6 +1069,8 @@ class DPM_Solver:
|
|||
step_order = order
|
||||
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
|
||||
solver_type=solver_type)
|
||||
if callback: callback(step)
|
||||
if img_callback: img_callback(x, step)
|
||||
for i in range(order - 1):
|
||||
t_prev_list[i] = t_prev_list[i + 1]
|
||||
model_prev_list[i] = model_prev_list[i + 1]
|
||||
|
@ -1092,8 +1098,12 @@ class DPM_Solver:
|
|||
r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
|
||||
r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
|
||||
x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(x, i)
|
||||
if denoise_to_zero:
|
||||
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
|
||||
if callback: callback(len(orders)-1)
|
||||
if img_callback: img_callback(x, len(orders)-1)
|
||||
return x
|
||||
|
||||
|
||||
|
|
|
@ -82,6 +82,6 @@ class DPMSolverSampler(object):
|
|||
)
|
||||
|
||||
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
|
||||
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
|
||||
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True, callback=callback, img_callback=img_callback)
|
||||
|
||||
return x.to(device), None
|
Loading…
Reference in a new issue