fix missing adm_in_channels and ClipImageEmbedder

This commit is contained in:
Robin Rombach 2023-01-14 17:50:50 +01:00
parent 929625ac9a
commit aad6e38a78
2 changed files with 50 additions and 0 deletions

View file

@ -469,6 +469,7 @@ class UNetModel(nn.Module):
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
adm_in_channels=None,
):
super().__init__()
if use_spatial_transformer:
@ -536,6 +537,15 @@ class UNetModel(nn.Module):
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "sequential":
assert adm_in_channels is not None
self.label_emb = nn.Sequential(
nn.Sequential(
linear(adm_in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
)
else:
raise ValueError()

View file

@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import kornia
from torch.utils.checkpoint import checkpoint
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
@ -131,6 +132,45 @@ class FrozenCLIPEmbedder(AbstractEncoder):
return self(text)
from clip import load as load_clip
class ClipImageEmbedder(nn.Module):
def __init__(
self,
model,
jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu',
antialias=True,
ucg_rate=0.
):
super().__init__()
self.model, _ = load_clip(name=model, device=device, jit=jit)
self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
self.ucg_rate = ucg_rate
def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic', align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# re-normalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def forward(self, x, no_dropout=False):
# x is assumed to be in range [-1,1]
out = self.model.encode_image(self.preprocess(x))
out = out.to(x.dtype)
if self.ucg_rate > 0. and not no_dropout:
out = torch.bernoulli((1.-self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
return out
class FrozenOpenCLIPEmbedder(AbstractEncoder):
"""
Uses the OpenCLIP transformer encoder for text