fix missing adm_in_channels and ClipImageEmbedder

commit aad6e38a78
parent 929625ac9a
Author: Robin Rombach
Date: 2023-01-14 17:50:50 +01:00

2 changed files with 50 additions and 0 deletions

ldm/modules/diffusionmodules/openaimodel.py

@@ -469,6 +469,7 @@ class UNetModel(nn.Module):
         num_attention_blocks=None,
         disable_middle_self_attn=False,
         use_linear_in_transformer=False,
+        adm_in_channels=None,
     ):
         super().__init__()
         if use_spatial_transformer:
@@ -536,6 +537,15 @@ class UNetModel(nn.Module):
             elif self.num_classes == "continuous":
                 print("setting up linear c_adm embedding layer")
                 self.label_emb = nn.Linear(1, time_embed_dim)
+            elif self.num_classes == "sequential":
+                assert adm_in_channels is not None
+                self.label_emb = nn.Sequential(
+                    nn.Sequential(
+                        linear(adm_in_channels, time_embed_dim),
+                        nn.SiLU(),
+                        linear(time_embed_dim, time_embed_dim),
+                    )
+                )
             else:
                 raise ValueError()
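
For orientation, a minimal standalone sketch of what the new "sequential" branch builds, with illustrative sizes and plain nn.Linear standing in for the repo's linear() helper; in UNetModel.forward the result is added to the timestep embedding:

import torch
import torch.nn as nn

# Illustrative sizes only: 768 matches a pooled CLIP ViT-L/14 embedding.
adm_in_channels, time_embed_dim, batch = 768, 1024, 4

# Same structure as the new label_emb: a two-layer MLP mapping the
# conditioning vector to the width of the timestep embedding.
label_emb = nn.Sequential(
    nn.Sequential(
        nn.Linear(adm_in_channels, time_embed_dim),
        nn.SiLU(),
        nn.Linear(time_embed_dim, time_embed_dim),
    )
)

c_adm = torch.randn(batch, adm_in_channels)  # e.g. a pooled image embedding
emb = label_emb(c_adm)                       # shape (batch, time_embed_dim)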

ldm/modules/encoders/modules.py

@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import kornia
 from torch.utils.checkpoint import checkpoint
 
 from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
@@ -131,6 +132,45 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         return self(text)
 
 
+from clip import load as load_clip
+
+
+class ClipImageEmbedder(nn.Module):
+    def __init__(
+            self,
+            model,
+            jit=False,
+            device='cuda' if torch.cuda.is_available() else 'cpu',
+            antialias=True,
+            ucg_rate=0.
+    ):
+        super().__init__()
+        self.model, _ = load_clip(name=model, device=device, jit=jit)
+        self.antialias = antialias
+        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+        self.ucg_rate = ucg_rate
+
+    def preprocess(self, x):
+        # normalize to [0,1]
+        x = kornia.geometry.resize(x, (224, 224),
+                                   interpolation='bicubic', align_corners=True,
+                                   antialias=self.antialias)
+        x = (x + 1.) / 2.
+        # re-normalize according to clip
+        x = kornia.enhance.normalize(x, self.mean, self.std)
+        return x
+
+    def forward(self, x, no_dropout=False):
+        # x is assumed to be in range [-1,1]
+        out = self.model.encode_image(self.preprocess(x))
+        out = out.to(x.dtype)
+        if self.ucg_rate > 0. and not no_dropout:
+            out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
+        return out
+
+
 class FrozenOpenCLIPEmbedder(AbstractEncoder):
     """
     Uses the OpenCLIP transformer encoder for text
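
A quick usage sketch of the new embedder, assuming the OpenAI clip package is installed; the import path and the "ViT-L/14" model name are illustrative:

import torch
from ldm.modules.encoders.modules import ClipImageEmbedder  # module path assumed

embedder = ClipImageEmbedder(model="ViT-L/14", device="cpu", ucg_rate=0.1)

# A batch of RGB images already scaled to [-1, 1]; preprocess() resizes
# to 224x224 and re-normalizes with the registered CLIP mean/std buffers.
images = torch.rand(2, 3, 256, 256) * 2. - 1.

emb = embedder(images)                   # (2, 768); each row zeroed with p=0.1
emb = embedder(images, no_dropout=True)  # deterministic; no rows dropped

The Bernoulli mask is classifier-free-guidance-style dropout of the conditioning: with probability ucg_rate an entire embedding is zeroed. Paired with the first file, setting num_classes="sequential" and adm_in_channels to the embedding width lets the UNet consume these pooled image embeddings as its c_adm input.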