Fix a bug

This commit is contained in:
qcloud 2023-11-26 17:05:47 +08:00
parent cf1d67a6fd
commit 0b44aad16e

View file

@ -70,7 +70,9 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
# select alphas for computing the variance schedule # select alphas for computing the variance schedule
alphas = alphacums[ddim_timesteps] alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
# alphas_prev should start with 1.
alphas_prev = np.append(1., alphas[:-1])
# according the the formula provided in https://arxiv.org/abs/2010.02502 # according the the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))