mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 07:34:58 +00:00
Update dpm_solver.py
This commit is contained in:
parent
462a9d3298
commit
683c14645e
1 changed files with 16 additions and 8 deletions
|
@ -131,7 +131,7 @@ class NoiseScheduleVP:
|
|||
self.log_alpha_array.to(t.device),
|
||||
).reshape((-1))
|
||||
elif self.schedule == "linear":
|
||||
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
|
||||
return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
|
||||
elif self.schedule == "cosine":
|
||||
log_alpha_fn = lambda s: torch.log(
|
||||
torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)
|
||||
|
@ -169,7 +169,7 @@ class NoiseScheduleVP:
|
|||
* (self.beta_1 - self.beta_0)
|
||||
* torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
|
||||
)
|
||||
Delta = self.beta_0 ** 2 + tmp
|
||||
Delta = self.beta_0**2 + tmp
|
||||
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
|
||||
elif self.schedule == "discrete":
|
||||
log_alpha = -0.5 * torch.logaddexp(
|
||||
|
@ -522,15 +522,21 @@ class DPM_Solver:
|
|||
if order == 3:
|
||||
K = steps // 3 + 1
|
||||
if steps % 3 == 0:
|
||||
orders = [3, ] * (
|
||||
orders = [
|
||||
3,
|
||||
] * (
|
||||
K - 2
|
||||
) + [2, 1]
|
||||
elif steps % 3 == 1:
|
||||
orders = [3, ] * (
|
||||
orders = [
|
||||
3,
|
||||
] * (
|
||||
K - 1
|
||||
) + [1]
|
||||
else:
|
||||
orders = [3, ] * (
|
||||
orders = [
|
||||
3,
|
||||
] * (
|
||||
K - 1
|
||||
) + [2]
|
||||
elif order == 2:
|
||||
|
@ -541,7 +547,9 @@ class DPM_Solver:
|
|||
] * K
|
||||
else:
|
||||
K = steps // 2 + 1
|
||||
orders = [2, ] * (
|
||||
orders = [
|
||||
2,
|
||||
] * (
|
||||
K - 1
|
||||
) + [1]
|
||||
elif order == 1:
|
||||
|
@ -1000,7 +1008,7 @@ class DPM_Solver:
|
|||
- expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
|
||||
+ expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1
|
||||
- expand_dims(
|
||||
alpha_t * ((torch.exp(-h) - 1.0 + h) / h ** 2 - 0.5), dims
|
||||
alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5), dims
|
||||
)
|
||||
* D2
|
||||
)
|
||||
|
@ -1009,7 +1017,7 @@ class DPM_Solver:
|
|||
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
|
||||
- expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
|
||||
- expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1
|
||||
- expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h ** 2 - 0.5), dims)
|
||||
- expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5), dims)
|
||||
* D2
|
||||
)
|
||||
return x_t
|
||||
|
|
Loading…
Reference in a new issue