Merge pull request #89 from Stability-AI/dango.patch.atten_overflow

* Force cast to fp32 to avoid atten layer overflow
This commit is contained in:
Robin Rombach 2022-12-07 14:54:35 +01:00 committed by GitHub
commit 8bde0cf64f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -16,6 +16,9 @@ try:
except: except:
XFORMERS_IS_AVAILBLE = False XFORMERS_IS_AVAILBLE = False
# CrossAttn precision handling
import os
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
def exists(val): def exists(val):
return val is not None return val is not None
@ -167,9 +170,16 @@ class CrossAttention(nn.Module):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'):
q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
del q, k del q, k
if exists(mask): if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)') mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max max_neg_value = -torch.finfo(sim.dtype).max