Use PyTorch's built-in SiLU function to reduce GPU memory usage

Haoxiang Li 2023-10-04 18:04:12 -07:00
parent cf1d67a6fd
commit 51c813001b


@@ -5,6 +5,7 @@ import torch.nn as nn
 import numpy as np
 from einops import rearrange
 from typing import Optional, Any
+import torch.nn.functional as F
 from ldm.modules.attention import MemoryEfficientCrossAttention
@@ -40,7 +41,7 @@ def get_timestep_embedding(timesteps, embedding_dim):
 def nonlinearity(x):
     # swish
-    return x*torch.sigmoid(x)
+    return F.silu(x)
 def Normalize(in_channels, num_groups=32):
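
For context (not part of the commit itself): F.silu(x) computes the same value as x * torch.sigmoid(x), but as a single fused op, so autograd does not have to keep the separate sigmoid output alive for the backward pass. Below is a minimal sketch of the equivalence and a rough peak-memory comparison; the tensor shapes and the measurement harness are illustrative assumptions, not code from this repository.

import torch
import torch.nn.functional as F

def swish_manual(x):
    # original implementation: two ops; autograd saves the intermediate
    # sigmoid output so it can backprop through the multiplication
    return x * torch.sigmoid(x)

def swish_builtin(x):
    # fused op; silu(x) == x * sigmoid(x), and backward only needs the input
    return F.silu(x)

# equivalence check on CPU (shape chosen arbitrarily for illustration)
x = torch.randn(2, 64, 32, 32)
assert torch.allclose(swish_manual(x), swish_builtin(x), atol=1e-6)

if torch.cuda.is_available():
    def peak_memory(fn):
        # measure peak allocated bytes for a forward + backward pass
        torch.cuda.reset_peak_memory_stats()
        t = torch.randn(8, 512, 64, 64, device="cuda", requires_grad=True)
        fn(t).sum().backward()
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated()

    print("manual swish peak bytes:", peak_memory(swish_manual))
    print("F.silu       peak bytes:", peak_memory(swish_builtin))

On a CUDA device the fused version should report a lower peak, since the manual version materializes and retains an extra activation-sized tensor per call; exact numbers depend on the hardware and PyTorch version.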