diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index 509cd87..6014635 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -12,9 +12,9 @@ from ldm.modules.diffusionmodules.util import checkpoint
 try:
     import xformers
     import xformers.ops
-    XFORMERS_IS_AVAILBLE = True
-except:
-    XFORMERS_IS_AVAILBLE = False
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    XFORMERS_AVAILABLE = False
 
 # CrossAttn precision handling
 import os
@@ -251,7 +251,7 @@ class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                  disable_self_attn=False):
         super().__init__()
-        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+        attn_mode = "softmax-xformers" if XFORMERS_AVAILABLE else "softmax"
         assert attn_mode in self.ATTENTION_MODES
         attn_cls = self.ATTENTION_MODES[attn_mode]
         self.disable_self_attn = disable_self_attn
@@ -338,4 +338,3 @@ class SpatialTransformer(nn.Module):
         if not self.use_linear:
             x = self.proj_out(x)
         return x + x_in
-
diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py
index b089eeb..50fc9d0 100644
--- a/ldm/modules/diffusionmodules/model.py
+++ b/ldm/modules/diffusionmodules/model.py
@@ -11,10 +11,9 @@ from ldm.modules.attention import MemoryEfficientCrossAttention
 try:
     import xformers
     import xformers.ops
-    XFORMERS_IS_AVAILBLE = True
-except:
-    XFORMERS_IS_AVAILBLE = False
-    print("No module 'xformers'. Proceeding without it.")
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    XFORMERS_AVAILABLE = False
 
 
 def get_timestep_embedding(timesteps, embedding_dim):
@@ -279,7 +278,7 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
 
 def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
-    if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+    if XFORMERS_AVAILABLE and attn_type == "vanilla":
         attn_type = "vanilla-xformers"
     print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":