mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 15:44:58 +00:00
Merge pull request #89 from Stability-AI/dango.patch.atten_overflow
* Force cast to fp32 to avoid atten layer overflow
This commit is contained in:
commit
8bde0cf64f
1 changed files with 12 additions and 2 deletions
|
@ -16,6 +16,9 @@ try:
|
||||||
except:
|
except:
|
||||||
XFORMERS_IS_AVAILBLE = False
|
XFORMERS_IS_AVAILBLE = False
|
||||||
|
|
||||||
|
# CrossAttn precision handling
|
||||||
|
import os
|
||||||
|
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
|
||||||
|
|
||||||
def exists(val):
|
def exists(val):
|
||||||
return val is not None
|
return val is not None
|
||||||
|
@ -167,9 +170,16 @@ class CrossAttention(nn.Module):
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||||
|
|
||||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
# force cast to fp32 to avoid overflowing
|
||||||
|
if _ATTN_PRECISION =="fp32":
|
||||||
|
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||||
|
q, k = q.float(), k.float()
|
||||||
|
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||||
|
else:
|
||||||
|
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||||
|
|
||||||
del q, k
|
del q, k
|
||||||
|
|
||||||
if exists(mask):
|
if exists(mask):
|
||||||
mask = rearrange(mask, 'b ... -> b (...)')
|
mask = rearrange(mask, 'b ... -> b (...)')
|
||||||
max_neg_value = -torch.finfo(sim.dtype).max
|
max_neg_value = -torch.finfo(sim.dtype).max
|
||||||
|
|
Loading…
Reference in a new issue