From 6e92cda76d12ab4051e98e53503e95698871d68e Mon Sep 17 00:00:00 2001
From: Dango233
Date: Wed, 7 Dec 2022 19:56:39 +0800
Subject: [PATCH 1/2] * Force cast to fp32 to avoid atten layer overflow

---
 ldm/modules/attention.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index d504d93..622c6fc 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -167,9 +167,13 @@ class CrossAttention(nn.Module):
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
 
-        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        # force cast to fp32 to avoid overflowing
+        with torch.autocast(enabled=False, device_type = 'cuda'):
+            q, k = q.float(), k.float()
+            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
         del q, k
-    
+
         if exists(mask):
             mask = rearrange(mask, 'b ... -> b (...)')
             max_neg_value = -torch.finfo(sim.dtype).max

From e1797ae248408ea47561eeb8755737f1e35784f2 Mon Sep 17 00:00:00 2001
From: Dango233
Date: Wed, 7 Dec 2022 21:38:32 +0800
Subject: [PATCH 2/2] Add env var for resume previous behavior

---
 ldm/modules/attention.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index 622c6fc..509cd87 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -16,6 +16,9 @@ try:
 except:
     XFORMERS_IS_AVAILBLE = False
 
+# CrossAttn precision handling
+import os
+_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
 
 def exists(val):
     return val is not None
@@ -168,8 +171,11 @@ class CrossAttention(nn.Module):
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
 
         # force cast to fp32 to avoid overflowing
-        with torch.autocast(enabled=False, device_type = 'cuda'):
-            q, k = q.float(), k.float()
+        if _ATTN_PRECISION =="fp32":
+            with torch.autocast(enabled=False, device_type = 'cuda'):
+                q, k = q.float(), k.float()
+                sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        else:
             sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
 
         del q, k
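
A minimal usage sketch (not part of the patch series above): with PATCH 2/2 applied, the fp32 cast is the default and the ATTN_PRECISION environment variable restores the previous fp16 attention path. _ATTN_PRECISION is read once, at import time of ldm.modules.attention, so the variable must be set before that module is imported; the snippet assumes the patched ldm/modules/attention.py is on the import path.

    # Opt back into the pre-patch fp16 attention behavior.
    import os
    os.environ["ATTN_PRECISION"] = "fp16"  # any value other than "fp32" skips the cast

    from ldm.modules import attention      # _ATTN_PRECISION is captured at this import
    assert attention._ATTN_PRECISION == "fp16"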