diff --git a/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ldm/models/diffusion/dpm_solver/dpm_solver.py index 79390c0..c7c2e09 100644 --- a/ldm/models/diffusion/dpm_solver/dpm_solver.py +++ b/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -131,7 +131,7 @@ class NoiseScheduleVP: self.log_alpha_array.to(t.device), ).reshape((-1)) elif self.schedule == "linear": - return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 elif self.schedule == "cosine": log_alpha_fn = lambda s: torch.log( torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0) @@ -169,7 +169,7 @@ class NoiseScheduleVP: * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) ) - Delta = self.beta_0 ** 2 + tmp + Delta = self.beta_0**2 + tmp return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) elif self.schedule == "discrete": log_alpha = -0.5 * torch.logaddexp( @@ -522,15 +522,21 @@ class DPM_Solver: if order == 3: K = steps // 3 + 1 if steps % 3 == 0: - orders = [3, ] * ( + orders = [ + 3, + ] * ( K - 2 ) + [2, 1] elif steps % 3 == 1: - orders = [3, ] * ( + orders = [ + 3, + ] * ( K - 1 ) + [1] else: - orders = [3, ] * ( + orders = [ + 3, + ] * ( K - 1 ) + [2] elif order == 2: @@ -541,7 +547,9 @@ class DPM_Solver: ] * K else: K = steps // 2 + 1 - orders = [2, ] * ( + orders = [ + 2, + ] * ( K - 1 ) + [1] elif order == 1: @@ -1000,7 +1008,7 @@ class DPM_Solver: - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0 + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1 - expand_dims( - alpha_t * ((torch.exp(-h) - 1.0 + h) / h ** 2 - 0.5), dims + alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5), dims ) * D2 ) @@ -1009,7 +1017,7 @@ class DPM_Solver: expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0 - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1 - - expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h ** 2 - 0.5), dims) + - expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5), dims) * D2 ) return x_t