mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 15:44:58 +00:00
Get rid of "No module 'xformers'" message (and fix variable name)
This commit is contained in:
parent
cf1d67a6fd
commit
9d44460f62
2 changed files with 8 additions and 10 deletions
|
@ -12,9 +12,9 @@ from ldm.modules.diffusionmodules.util import checkpoint
|
||||||
try:
|
try:
|
||||||
import xformers
|
import xformers
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
XFORMERS_IS_AVAILBLE = True
|
XFORMERS_AVAILABLE = True
|
||||||
except:
|
except ImportError:
|
||||||
XFORMERS_IS_AVAILBLE = False
|
XFORMERS_AVAILABLE = False
|
||||||
|
|
||||||
# CrossAttn precision handling
|
# CrossAttn precision handling
|
||||||
import os
|
import os
|
||||||
|
@ -251,7 +251,7 @@ class BasicTransformerBlock(nn.Module):
|
||||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
||||||
disable_self_attn=False):
|
disable_self_attn=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
|
attn_mode = "softmax-xformers" if XFORMERS_AVAILABLE else "softmax"
|
||||||
assert attn_mode in self.ATTENTION_MODES
|
assert attn_mode in self.ATTENTION_MODES
|
||||||
attn_cls = self.ATTENTION_MODES[attn_mode]
|
attn_cls = self.ATTENTION_MODES[attn_mode]
|
||||||
self.disable_self_attn = disable_self_attn
|
self.disable_self_attn = disable_self_attn
|
||||||
|
@ -338,4 +338,3 @@ class SpatialTransformer(nn.Module):
|
||||||
if not self.use_linear:
|
if not self.use_linear:
|
||||||
x = self.proj_out(x)
|
x = self.proj_out(x)
|
||||||
return x + x_in
|
return x + x_in
|
||||||
|
|
||||||
|
|
|
@ -11,10 +11,9 @@ from ldm.modules.attention import MemoryEfficientCrossAttention
|
||||||
try:
|
try:
|
||||||
import xformers
|
import xformers
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
XFORMERS_IS_AVAILBLE = True
|
XFORMERS_AVAILABLE = True
|
||||||
except:
|
except ImportError:
|
||||||
XFORMERS_IS_AVAILBLE = False
|
XFORMERS_AVAILABLE = False
|
||||||
print("No module 'xformers'. Proceeding without it.")
|
|
||||||
|
|
||||||
|
|
||||||
def get_timestep_embedding(timesteps, embedding_dim):
|
def get_timestep_embedding(timesteps, embedding_dim):
|
||||||
|
@ -279,7 +278,7 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
|
||||||
|
|
||||||
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
||||||
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
|
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
|
||||||
if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
|
if XFORMERS_AVAILABLE and attn_type == "vanilla":
|
||||||
attn_type = "vanilla-xformers"
|
attn_type = "vanilla-xformers"
|
||||||
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||||
if attn_type == "vanilla":
|
if attn_type == "vanilla":
|
||||||
|
|
Loading…
Reference in a new issue