diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index 27ead0e..66cf745 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -315,7 +315,7 @@ class DDIMSampler(object):
 
     @torch.no_grad()
     def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
-               use_original_steps=False, callback=None):
+               use_original_steps=False, callback=None, img_callback=None):
 
         timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
         timesteps = timesteps[:t_start]
@@ -333,4 +333,5 @@ class DDIMSampler(object):
                                           unconditional_guidance_scale=unconditional_guidance_scale,
                                           unconditional_conditioning=unconditional_conditioning)
             if callback: callback(i)
+            if img_callback: img_callback(x_dec, i)
         return x_dec
\ No newline at end of file
diff --git a/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ldm/models/diffusion/dpm_solver/dpm_solver.py
index 095e5ba..357d092 100644
--- a/ldm/models/diffusion/dpm_solver/dpm_solver.py
+++ b/ldm/models/diffusion/dpm_solver/dpm_solver.py
@@ -876,7 +876,7 @@ class DPM_Solver:
             raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
 
     def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
-                            solver_type='dpm_solver'):
+                            solver_type='dpm_solver', callback=None, img_callback=None):
         """
         The adaptive step size solver based on singlestep DPM-Solver.
         Args:
@@ -931,6 +931,8 @@ class DPM_Solver:
                 s = t
                 x_prev = x_lower
                 lambda_s = ns.marginal_lambda(s)
+                if callback: callback(nfe)
+                if img_callback: img_callback(x, nfe)
             h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
             nfe += order
         print('adaptive solver nfe', nfe)
@@ -939,6 +941,7 @@ class DPM_Solver:
     def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
               method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
               atol=0.0078, rtol=0.05,
+              callback=None, img_callback=None,
               ):
         """
         Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
@@ -1040,7 +1043,8 @@ class DPM_Solver:
         if method == 'adaptive':
             with torch.no_grad():
                 x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
-                                             solver_type=solver_type)
+                                             solver_type=solver_type,
+                                             callback=callback, img_callback=img_callback)
         elif method == 'multistep':
             assert steps >= order
             timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
@@ -1065,6 +1069,8 @@ class DPM_Solver:
                     step_order = order
                 x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
                                                      solver_type=solver_type)
+                if callback: callback(step)
+                if img_callback: img_callback(x, step)
                 for i in range(order - 1):
                     t_prev_list[i] = t_prev_list[i + 1]
                     model_prev_list[i] = model_prev_list[i + 1]
@@ -1092,8 +1098,12 @@ class DPM_Solver:
                 r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
                 r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
                 x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1,
                                                       r2=r2)
+                if callback: callback(i)
+                if img_callback: img_callback(x, i)
 
         if denoise_to_zero:
             x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+            if callback: callback(len(orders)-1)
+            if img_callback: img_callback(x, len(orders)-1)
         return x
diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py
index 7d137b8..9aa3edb 100644
--- a/ldm/models/diffusion/dpm_solver/sampler.py
+++ b/ldm/models/diffusion/dpm_solver/sampler.py
@@ -82,6 +82,6 @@ class DPMSolverSampler(object):
         )
 
         dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
-        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
+        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True, callback=callback, img_callback=img_callback)
 
         return x.to(device), None
\ No newline at end of file
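
For context on how the new hook can be consumed, here is a minimal sketch (not part of this patch) that passes an `img_callback` through `DPMSolverSampler.sample`, decoding each intermediate latent to image space and writing it to disk. It assumes a loaded `LatentDiffusion` model, a `DPMSolverSampler(model)` instance, and conditioning tensors `c` / `uc` prepared as in `scripts/txt2img.py`; the `save_intermediate` helper and the `intermediates/` output directory are illustrative names only.

```python
import os

import numpy as np
import torch
from PIL import Image

# Assumed setup (not shown here): `model` is a loaded LatentDiffusion checkpoint,
# `sampler = DPMSolverSampler(model)`, and `c` / `uc` are the conditional and
# unconditional embeddings, e.g. from model.get_learned_conditioning(...).
os.makedirs("intermediates", exist_ok=True)

def save_intermediate(x, step):
    # `x` is the current latent handed to img_callback by this patch; decode it
    # to pixel space and dump one PNG per batch element and step.
    with torch.no_grad():
        imgs = model.decode_first_stage(x)
    imgs = torch.clamp((imgs + 1.0) / 2.0, min=0.0, max=1.0)
    for b, img in enumerate(imgs):
        arr = (255.0 * img.permute(1, 2, 0).cpu().numpy()).astype(np.uint8)
        Image.fromarray(arr).save(f"intermediates/step_{step:03d}_{b:02d}.png")

samples, _ = sampler.sample(S=20,
                            batch_size=uc.shape[0],
                            shape=[4, 64, 64],  # latent C, H/8, W/8 for 512x512 output
                            conditioning=c,
                            unconditional_guidance_scale=7.5,
                            unconditional_conditioning=uc,
                            eta=0.0,
                            x_T=None,
                            callback=lambda i: print(f"step {i}"),
                            img_callback=save_intermediate)
```

With `method="multistep"` (as wired up in `sampler.py`), the callbacks fire once per solver step with the step index; in the adaptive branch they receive the running NFE count instead.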