mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 15:44:58 +00:00
Fix encode inside ddim.py
This commit is contained in:
parent
cf1d67a6fd
commit
11ceee2da2
1 changed files with 9 additions and 2 deletions
|
@ -266,11 +266,18 @@ class DDIMSampler(object):
|
||||||
alphas_next = self.ddim_alphas[:num_steps]
|
alphas_next = self.ddim_alphas[:num_steps]
|
||||||
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
|
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
|
||||||
|
|
||||||
|
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
||||||
|
timesteps = timesteps[:num_steps]
|
||||||
|
|
||||||
|
total_steps = timesteps.shape[0]
|
||||||
|
iterator = tqdm(timesteps, desc='Encoding image', total=total_steps)
|
||||||
|
|
||||||
x_next = x0
|
x_next = x0
|
||||||
intermediates = []
|
intermediates = []
|
||||||
inter_steps = []
|
inter_steps = []
|
||||||
for i in tqdm(range(num_steps), desc='Encoding Image'):
|
for i, step in enumerate(iterator):
|
||||||
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
|
|
||||||
|
t = torch.full((x0.shape[0],), step, device=self.model.device, dtype=torch.long)
|
||||||
if unconditional_guidance_scale == 1.:
|
if unconditional_guidance_scale == 1.:
|
||||||
noise_pred = self.model.apply_model(x_next, t, c)
|
noise_pred = self.model.apply_model(x_next, t, c)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in a new issue