Update dpm_solver.py

This commit is contained in:
Andres Caicedo 2023-05-08 14:43:25 +02:00
parent 462a9d3298
commit 683c14645e
No known key found for this signature in database
GPG key ID: 6E797C4F5A327624

View file

@ -131,7 +131,7 @@ class NoiseScheduleVP:
self.log_alpha_array.to(t.device), self.log_alpha_array.to(t.device),
).reshape((-1)) ).reshape((-1))
elif self.schedule == "linear": elif self.schedule == "linear":
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
elif self.schedule == "cosine": elif self.schedule == "cosine":
log_alpha_fn = lambda s: torch.log( log_alpha_fn = lambda s: torch.log(
torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0) torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)
@ -169,7 +169,7 @@ class NoiseScheduleVP:
* (self.beta_1 - self.beta_0) * (self.beta_1 - self.beta_0)
* torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
) )
Delta = self.beta_0 ** 2 + tmp Delta = self.beta_0**2 + tmp
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
elif self.schedule == "discrete": elif self.schedule == "discrete":
log_alpha = -0.5 * torch.logaddexp( log_alpha = -0.5 * torch.logaddexp(
@ -522,15 +522,21 @@ class DPM_Solver:
if order == 3: if order == 3:
K = steps // 3 + 1 K = steps // 3 + 1
if steps % 3 == 0: if steps % 3 == 0:
orders = [3, ] * ( orders = [
3,
] * (
K - 2 K - 2
) + [2, 1] ) + [2, 1]
elif steps % 3 == 1: elif steps % 3 == 1:
orders = [3, ] * ( orders = [
3,
] * (
K - 1 K - 1
) + [1] ) + [1]
else: else:
orders = [3, ] * ( orders = [
3,
] * (
K - 1 K - 1
) + [2] ) + [2]
elif order == 2: elif order == 2:
@ -541,7 +547,9 @@ class DPM_Solver:
] * K ] * K
else: else:
K = steps // 2 + 1 K = steps // 2 + 1
orders = [2, ] * ( orders = [
2,
] * (
K - 1 K - 1
) + [1] ) + [1]
elif order == 1: elif order == 1:
@ -1000,7 +1008,7 @@ class DPM_Solver:
- expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0 - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
+ expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1 + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1
- expand_dims( - expand_dims(
alpha_t * ((torch.exp(-h) - 1.0 + h) / h ** 2 - 0.5), dims alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5), dims
) )
* D2 * D2
) )
@ -1009,7 +1017,7 @@ class DPM_Solver:
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
- expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0 - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
- expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1 - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1
- expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h ** 2 - 0.5), dims) - expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5), dims)
* D2 * D2
) )
return x_t return x_t