Add callback and img_callback for DDIM decode() and DPM solver

This commit is contained in:
cmdr2 2022-12-16 16:12:58 +05:30
parent cc77f2300d
commit 4923394f5c
3 changed files with 15 additions and 4 deletions

View file

@ -315,7 +315,7 @@ class DDIMSampler(object):
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False, callback=None):
use_original_steps=False, callback=None, img_callback=None):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
@ -333,4 +333,5 @@ class DDIMSampler(object):
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
if callback: callback(i)
if img_callback: img_callback(x_dec, i)
return x_dec

View file

@ -876,7 +876,7 @@ class DPM_Solver:
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
solver_type='dpm_solver'):
solver_type='dpm_solver', callback=None, img_callback=None):
"""
The adaptive step size solver based on singlestep DPM-Solver.
Args:
@ -931,6 +931,8 @@ class DPM_Solver:
s = t
x_prev = x_lower
lambda_s = ns.marginal_lambda(s)
if callback: callback(nfe)
if img_callback: img_callback(x, nfe)
h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
nfe += order
print('adaptive solver nfe', nfe)
@ -939,6 +941,7 @@ class DPM_Solver:
def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
atol=0.0078, rtol=0.05,
callback=None, img_callback=None,
):
"""
Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
@ -1040,7 +1043,8 @@ class DPM_Solver:
if method == 'adaptive':
with torch.no_grad():
x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
solver_type=solver_type)
solver_type=solver_type,
callback=callback, img_callback=img_callback)
elif method == 'multistep':
assert steps >= order
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
@ -1065,6 +1069,8 @@ class DPM_Solver:
step_order = order
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
solver_type=solver_type)
if callback: callback(step)
if img_callback: img_callback(x, step)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
@ -1092,8 +1098,12 @@ class DPM_Solver:
r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
if callback: callback(i)
if img_callback: img_callback(x, i)
if denoise_to_zero:
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
if callback: callback(len(orders)-1)
if img_callback: img_callback(x, len(orders)-1)
return x

View file

@ -82,6 +82,6 @@ class DPMSolverSampler(object):
)
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True, callback=callback, img_callback=img_callback)
return x.to(device), None