Mirror of https://github.com/Stability-AI/stablediffusion.git (synced 2024-12-22 07:34:58 +00:00)
Get rid of "No module 'xformers'" message (and fix variable name)

commit 9d44460f62 (parent cf1d67a6fd)
2 changed files with 8 additions and 10 deletions
ldm/modules/attention.py

@@ -12,9 +12,9 @@ from ldm.modules.diffusionmodules.util import checkpoint
 try:
     import xformers
     import xformers.ops
-    XFORMERS_IS_AVAILBLE = True
-except:
-    XFORMERS_IS_AVAILBLE = False
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    XFORMERS_AVAILABLE = False
 
 # CrossAttn precision handling
 import os
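
The guard above is the usual optional-dependency pattern; a minimal standalone sketch of how it reads after this commit (names taken from the diff):

# Minimal sketch of the optional-dependency guard after this commit.
# Catching ImportError specifically, instead of the old bare `except:`,
# keeps the guard from masking unrelated failures (e.g. a SyntaxError in
# a broken install, or KeyboardInterrupt during import).
try:
    import xformers
    import xformers.ops
    XFORMERS_AVAILABLE = True
except ImportError:
    XFORMERS_AVAILABLE = False

print(f"xformers available: {XFORMERS_AVAILABLE}")
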
@@ -251,7 +251,7 @@ class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                  disable_self_attn=False):
         super().__init__()
-        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+        attn_mode = "softmax-xformers" if XFORMERS_AVAILABLE else "softmax"
         assert attn_mode in self.ATTENTION_MODES
         attn_cls = self.ATTENTION_MODES[attn_mode]
         self.disable_self_attn = disable_self_attn
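
The renamed flag feeds the backend selection shown above. A self-contained sketch of that selection, with stub classes and an ATTENTION_MODES mapping that are assumptions standing in for the real classes in ldm/modules/attention.py:

# Self-contained sketch of the attention-backend selection. The stubs and the
# exact ATTENTION_MODES contents are assumptions, not part of this diff.
class CrossAttention:                  # stub for the vanilla softmax attention
    pass

class MemoryEfficientCrossAttention:   # stub for the xformers-backed variant
    pass

ATTENTION_MODES = {
    "softmax": CrossAttention,
    "softmax-xformers": MemoryEfficientCrossAttention,
}

XFORMERS_AVAILABLE = False  # pretend xformers failed to import

attn_mode = "softmax-xformers" if XFORMERS_AVAILABLE else "softmax"
assert attn_mode in ATTENTION_MODES
attn_cls = ATTENTION_MODES[attn_mode]
print(attn_cls.__name__)  # -> CrossAttention
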
@@ -338,4 +338,3 @@ class SpatialTransformer(nn.Module):
         if not self.use_linear:
             x = self.proj_out(x)
         return x + x_in
-
ldm/modules/diffusionmodules/model.py

@@ -11,10 +11,9 @@ from ldm.modules.attention import MemoryEfficientCrossAttention
 try:
     import xformers
     import xformers.ops
-    XFORMERS_IS_AVAILBLE = True
-except:
-    XFORMERS_IS_AVAILBLE = False
-    print("No module 'xformers'. Proceeding without it.")
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    XFORMERS_AVAILABLE = False
 
 
 def get_timestep_embedding(timesteps, embedding_dim):
@@ -279,7 +278,7 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
 
 def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
-    if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+    if XFORMERS_AVAILABLE and attn_type == "vanilla":
         attn_type = "vanilla-xformers"
     print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
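
The same flag gates make_attn: when xformers is importable, a request for plain "vanilla" attention is silently upgraded to the memory-efficient backend. A runnable sketch of just that selection logic (pick_attn_type is a hypothetical helper extracted for illustration; the real code mutates attn_type inline inside make_attn):

def pick_attn_type(attn_type: str, xformers_available: bool) -> str:
    # Same validation and upgrade rule as the diff above.
    assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn",
                         "linear", "none"], f'attn_type {attn_type} unknown'
    if xformers_available and attn_type == "vanilla":
        return "vanilla-xformers"  # silent upgrade to the xformers backend
    return attn_type

print(pick_attn_type("vanilla", True))    # vanilla-xformers
print(pick_attn_type("vanilla", False))   # vanilla
print(pick_attn_type("linear", True))     # linear (only "vanilla" is upgraded)
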