From 85e9d294c61559910bfe69566e31aced80d16212 Mon Sep 17 00:00:00 2001 From: Shuai Yang <596836482@qq.com> Date: Thu, 30 Mar 2023 10:09:54 +0800 Subject: [PATCH] fix a bug for DDIM inversion (DDIMSampler.encode()) Fix a bug with an incorrect timestep in DDIMSampler.encode() for DDIM inversion: in the original code, t was incorrectly set to the index into the timestep schedule rather than the actual timestep value. --- ldm/models/diffusion/ddim.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index c6cfd57..7c5eded 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -254,7 +254,8 @@ class DDIMSampler(object): @torch.no_grad() def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): - num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + num_reference_steps = timesteps.shape[0] assert t_enc <= num_reference_steps num_steps = t_enc @@ -270,7 +271,7 @@ class DDIMSampler(object): intermediates = [] inter_steps = [] for i in tqdm(range(num_steps), desc='Encoding Image'): - t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) + t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long) if unconditional_guidance_scale == 1.: noise_pred = self.model.apply_model(x_next, t, c) else: @@ -334,4 +335,4 @@ class DDIMSampler(object): unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning) if callback: callback(i) - return x_dec \ No newline at end of file + return x_dec