* Use Cutlass ops when possible for a +15% speedup

This commit is contained in:
Dango233 2022-12-07 18:01:50 +08:00
parent d7440ac160
commit f07c5ec5dc
2 changed files with 2 additions and 2 deletions

View file

@ -201,7 +201,7 @@ class MemoryEfficientCrossAttention(nn.Module):
self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None self.attention_op: Optional[Any] = xformers.ops.MemoryEfficientAttentionCutlassOp if hasattr(xformers.ops, "MemoryEfficientAttentionCutlassOp") else None
def forward(self, x, context=None, mask=None): def forward(self, x, context=None, mask=None):
q = self.to_q(x) q = self.to_q(x)

View file

@ -234,7 +234,7 @@ class MemoryEfficientAttnBlock(nn.Module):
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0)
self.attention_op: Optional[Any] = None self.attention_op: Optional[Any] = xformers.ops.MemoryEfficientAttentionCutlassOp if hasattr(xformers.ops, "MemoryEfficientAttentionCutlassOp") else None
def forward(self, x): def forward(self, x):
h_ = x h_ = x