mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 15:44:58 +00:00
Update dpm_solver.py
This commit is contained in:
parent
462a9d3298
commit
683c14645e
1 changed files with 16 additions and 8 deletions
|
@ -131,7 +131,7 @@ class NoiseScheduleVP:
|
||||||
self.log_alpha_array.to(t.device),
|
self.log_alpha_array.to(t.device),
|
||||||
).reshape((-1))
|
).reshape((-1))
|
||||||
elif self.schedule == "linear":
|
elif self.schedule == "linear":
|
||||||
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
|
return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
|
||||||
elif self.schedule == "cosine":
|
elif self.schedule == "cosine":
|
||||||
log_alpha_fn = lambda s: torch.log(
|
log_alpha_fn = lambda s: torch.log(
|
||||||
torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)
|
torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)
|
||||||
|
@ -169,7 +169,7 @@ class NoiseScheduleVP:
|
||||||
* (self.beta_1 - self.beta_0)
|
* (self.beta_1 - self.beta_0)
|
||||||
* torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
|
* torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
|
||||||
)
|
)
|
||||||
Delta = self.beta_0 ** 2 + tmp
|
Delta = self.beta_0**2 + tmp
|
||||||
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
|
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
|
||||||
elif self.schedule == "discrete":
|
elif self.schedule == "discrete":
|
||||||
log_alpha = -0.5 * torch.logaddexp(
|
log_alpha = -0.5 * torch.logaddexp(
|
||||||
|
@ -522,15 +522,21 @@ class DPM_Solver:
|
||||||
if order == 3:
|
if order == 3:
|
||||||
K = steps // 3 + 1
|
K = steps // 3 + 1
|
||||||
if steps % 3 == 0:
|
if steps % 3 == 0:
|
||||||
orders = [3, ] * (
|
orders = [
|
||||||
|
3,
|
||||||
|
] * (
|
||||||
K - 2
|
K - 2
|
||||||
) + [2, 1]
|
) + [2, 1]
|
||||||
elif steps % 3 == 1:
|
elif steps % 3 == 1:
|
||||||
orders = [3, ] * (
|
orders = [
|
||||||
|
3,
|
||||||
|
] * (
|
||||||
K - 1
|
K - 1
|
||||||
) + [1]
|
) + [1]
|
||||||
else:
|
else:
|
||||||
orders = [3, ] * (
|
orders = [
|
||||||
|
3,
|
||||||
|
] * (
|
||||||
K - 1
|
K - 1
|
||||||
) + [2]
|
) + [2]
|
||||||
elif order == 2:
|
elif order == 2:
|
||||||
|
@ -541,7 +547,9 @@ class DPM_Solver:
|
||||||
] * K
|
] * K
|
||||||
else:
|
else:
|
||||||
K = steps // 2 + 1
|
K = steps // 2 + 1
|
||||||
orders = [2, ] * (
|
orders = [
|
||||||
|
2,
|
||||||
|
] * (
|
||||||
K - 1
|
K - 1
|
||||||
) + [1]
|
) + [1]
|
||||||
elif order == 1:
|
elif order == 1:
|
||||||
|
@ -1000,7 +1008,7 @@ class DPM_Solver:
|
||||||
- expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
|
- expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
|
||||||
+ expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1
|
+ expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1
|
||||||
- expand_dims(
|
- expand_dims(
|
||||||
alpha_t * ((torch.exp(-h) - 1.0 + h) / h ** 2 - 0.5), dims
|
alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5), dims
|
||||||
)
|
)
|
||||||
* D2
|
* D2
|
||||||
)
|
)
|
||||||
|
@ -1009,7 +1017,7 @@ class DPM_Solver:
|
||||||
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
|
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
|
||||||
- expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
|
- expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
|
||||||
- expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1
|
- expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1
|
||||||
- expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h ** 2 - 0.5), dims)
|
- expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5), dims)
|
||||||
* D2
|
* D2
|
||||||
)
|
)
|
||||||
return x_t
|
return x_t
|
||||||
|
|
Loading…
Reference in a new issue