mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 15:44:58 +00:00
fix a bug for DDIM inversion (DDIMSampler.encode())
fix a bug of the incorrect timestep in DDIMSampler.encode() for DDIM inversion t is incorrectly set as the index of timesteps rather than the timestep in the original code.
This commit is contained in:
parent
cf1d67a6fd
commit
85e9d294c6
1 changed files with 4 additions and 3 deletions
|
@ -254,7 +254,8 @@ class DDIMSampler(object):
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
|
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
|
||||||
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
|
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
|
||||||
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
|
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
||||||
|
num_reference_steps = timesteps.shape[0]
|
||||||
|
|
||||||
assert t_enc <= num_reference_steps
|
assert t_enc <= num_reference_steps
|
||||||
num_steps = t_enc
|
num_steps = t_enc
|
||||||
|
@ -270,7 +271,7 @@ class DDIMSampler(object):
|
||||||
intermediates = []
|
intermediates = []
|
||||||
inter_steps = []
|
inter_steps = []
|
||||||
for i in tqdm(range(num_steps), desc='Encoding Image'):
|
for i in tqdm(range(num_steps), desc='Encoding Image'):
|
||||||
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
|
t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long)
|
||||||
if unconditional_guidance_scale == 1.:
|
if unconditional_guidance_scale == 1.:
|
||||||
noise_pred = self.model.apply_model(x_next, t, c)
|
noise_pred = self.model.apply_model(x_next, t, c)
|
||||||
else:
|
else:
|
||||||
|
@ -334,4 +335,4 @@ class DDIMSampler(object):
|
||||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
unconditional_conditioning=unconditional_conditioning)
|
unconditional_conditioning=unconditional_conditioning)
|
||||||
if callback: callback(i)
|
if callback: callback(i)
|
||||||
return x_dec
|
return x_dec
|
||||||
|
|
Loading…
Reference in a new issue