use pytorch built-in SiLU function to save GPU memory usage

This commit is contained in:
Haoxiang Li 2023-10-04 18:04:12 -07:00
parent cf1d67a6fd
commit 51c813001b

View file

@ -5,6 +5,7 @@ import torch.nn as nn
import numpy as np import numpy as np
from einops import rearrange from einops import rearrange
from typing import Optional, Any from typing import Optional, Any
import torch.nn.functional as F
from ldm.modules.attention import MemoryEfficientCrossAttention from ldm.modules.attention import MemoryEfficientCrossAttention
@ -40,7 +41,7 @@ def get_timestep_embedding(timesteps, embedding_dim):
def nonlinearity(x): def nonlinearity(x):
# swish # swish
return x*torch.sigmoid(x) return F.silu(x)
def Normalize(in_channels, num_groups=32): def Normalize(in_channels, num_groups=32):