Update dpm_solver.py

This commit is contained in:
Andres Caicedo 2023-05-08 14:43:25 +02:00
parent 462a9d3298
commit 683c14645e
No known key found for this signature in database
GPG key ID: 6E797C4F5A327624

View file

@ -131,7 +131,7 @@ class NoiseScheduleVP:
self.log_alpha_array.to(t.device),
).reshape((-1))
elif self.schedule == "linear":
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
elif self.schedule == "cosine":
log_alpha_fn = lambda s: torch.log(
torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)
@ -169,7 +169,7 @@ class NoiseScheduleVP:
* (self.beta_1 - self.beta_0)
* torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
)
Delta = self.beta_0 ** 2 + tmp
Delta = self.beta_0**2 + tmp
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
elif self.schedule == "discrete":
log_alpha = -0.5 * torch.logaddexp(
@ -522,15 +522,21 @@ class DPM_Solver:
if order == 3:
K = steps // 3 + 1
if steps % 3 == 0:
orders = [3, ] * (
orders = [
3,
] * (
K - 2
) + [2, 1]
elif steps % 3 == 1:
orders = [3, ] * (
orders = [
3,
] * (
K - 1
) + [1]
else:
orders = [3, ] * (
orders = [
3,
] * (
K - 1
) + [2]
elif order == 2:
@ -541,7 +547,9 @@ class DPM_Solver:
] * K
else:
K = steps // 2 + 1
orders = [2, ] * (
orders = [
2,
] * (
K - 1
) + [1]
elif order == 1:
@ -1000,7 +1008,7 @@ class DPM_Solver:
- expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
+ expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1
- expand_dims(
alpha_t * ((torch.exp(-h) - 1.0 + h) / h ** 2 - 0.5), dims
alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5), dims
)
* D2
)
@ -1009,7 +1017,7 @@ class DPM_Solver:
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
- expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
- expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1
- expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h ** 2 - 0.5), dims)
- expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5), dims)
* D2
)
return x_t