Fix encode inside ddim.py

This commit is contained in:
Alex Ergasti 2023-04-05 09:39:29 +02:00
parent cf1d67a6fd
commit 11ceee2da2

View file

@ -266,11 +266,18 @@ class DDIMSampler(object):
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:num_steps]
total_steps = timesteps.shape[0]
iterator = tqdm(timesteps, desc='Encoding image', total=total_steps)
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc='Encoding Image'):
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
for i, step in enumerate(iterator):
t = torch.full((x0.shape[0],), step, device=self.model.device, dtype=torch.long)
if unconditional_guidance_scale == 1.:
noise_pred = self.model.apply_model(x_next, t, c)
else: