mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2025-01-02 21:02:15 +00:00
Add env var for resume previous behavior
This commit is contained in:
parent
6e92cda76d
commit
e1797ae248
1 changed files with 8 additions and 2 deletions
|
@ -16,6 +16,9 @@ try:
|
|||
except:
|
||||
XFORMERS_IS_AVAILBLE = False
|
||||
|
||||
# CrossAttn precision handling
|
||||
import os
|
||||
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
@ -168,8 +171,11 @@ class CrossAttention(nn.Module):
|
|||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
# force cast to fp32 to avoid overflowing
|
||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||
q, k = q.float(), k.float()
|
||||
if _ATTN_PRECISION =="fp32":
|
||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||
q, k = q.float(), k.float()
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
else:
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
del q, k
|
||||
|
|
Loading…
Reference in a new issue