Get rid of "No module 'xformers'" message (and fix variable name)

Aarni Koskela 2023-06-15 13:03:08 +03:00
parent cf1d67a6fd
commit 9d44460f62
2 changed files with 8 additions and 10 deletions

ldm/modules/attention.py

@@ -12,9 +12,9 @@ from ldm.modules.diffusionmodules.util import checkpoint
 try:
     import xformers
     import xformers.ops
-    XFORMERS_IS_AVAILBLE = True
-except:
-    XFORMERS_IS_AVAILBLE = False
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    XFORMERS_AVAILABLE = False
 
 # CrossAttn precision handling
 import os
@@ -251,7 +251,7 @@ class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                  disable_self_attn=False):
         super().__init__()
-        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+        attn_mode = "softmax-xformers" if XFORMERS_AVAILABLE else "softmax"
         assert attn_mode in self.ATTENTION_MODES
         attn_cls = self.ATTENTION_MODES[attn_mode]
         self.disable_self_attn = disable_self_attn
@@ -338,4 +338,3 @@ class SpatialTransformer(nn.Module):
         if not self.use_linear:
             x = self.proj_out(x)
         return x + x_in
-

ldm/modules/diffusionmodules/model.py

@@ -11,10 +11,9 @@ from ldm.modules.attention import MemoryEfficientCrossAttention
 try:
     import xformers
     import xformers.ops
-    XFORMERS_IS_AVAILBLE = True
-except:
-    XFORMERS_IS_AVAILBLE = False
-    print("No module 'xformers'. Proceeding without it.")
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    XFORMERS_AVAILABLE = False
 
 
 def get_timestep_embedding(timesteps, embedding_dim):
@@ -279,7 +278,7 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
 
 def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
-    if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+    if XFORMERS_AVAILABLE and attn_type == "vanilla":
         attn_type = "vanilla-xformers"
     print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
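
The renamed flag is what steers backend selection in make_attn above: a requested "vanilla" attention is upgraded to the xformers-backed variant whenever the import guard succeeded. A hedged, standalone sketch of that selection logic (pick_attn_type is an illustrative helper, not a function in the repository):

    def pick_attn_type(attn_type: str, xformers_available: bool) -> str:
        # Mirrors the make_attn branch in the hunk above: prefer the
        # memory-efficient xformers kernel only for plain "vanilla" attention.
        if xformers_available and attn_type == "vanilla":
            return "vanilla-xformers"
        return attn_type

    assert pick_attn_type("vanilla", xformers_available=True) == "vanilla-xformers"
    assert pick_attn_type("vanilla", xformers_available=False) == "vanilla"
    assert pick_attn_type("linear", xformers_available=True) == "linear"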