fix a bug for DDIM inversion (DDIMSampler.encode())

fix a bug of the incorrect timestep in DDIMSampler.encode() for DDIM inversion

In the original code, `t` was incorrectly set to the loop index into the timestep schedule rather than the actual timestep value, so DDIM inversion used the wrong timesteps.
This commit is contained in:
Shuai Yang 2023-03-30 10:09:54 +08:00 committed by GitHub
parent cf1d67a6fd
commit 85e9d294c6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -254,7 +254,8 @@ class DDIMSampler(object):
     @torch.no_grad()
     def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
                unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
-        num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
+        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+        num_reference_steps = timesteps.shape[0]
         assert t_enc <= num_reference_steps
         num_steps = t_enc
@@ -270,7 +271,7 @@ class DDIMSampler(object):
         intermediates = []
         inter_steps = []
         for i in tqdm(range(num_steps), desc='Encoding Image'):
-            t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
+            t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long)
             if unconditional_guidance_scale == 1.:
                 noise_pred = self.model.apply_model(x_next, t, c)
             else: