From aad6e38a781f70eeba828e4fdb64b22a31a66884 Mon Sep 17 00:00:00 2001
From: Robin Rombach
Date: Sat, 14 Jan 2023 17:50:50 +0100
Subject: [PATCH] fix missing adm_in_channels and ClipImageEmbedder

---
 ldm/modules/diffusionmodules/openaimodel.py | 10 ++++++
 ldm/modules/encoders/modules.py             | 40 +++++++++++++++++++++
 2 files changed, 50 insertions(+)

diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py
index 7df6b5a..157d3b2 100644
--- a/ldm/modules/diffusionmodules/openaimodel.py
+++ b/ldm/modules/diffusionmodules/openaimodel.py
@@ -469,6 +469,7 @@ class UNetModel(nn.Module):
         num_attention_blocks=None,
         disable_middle_self_attn=False,
         use_linear_in_transformer=False,
+        adm_in_channels=None,
     ):
         super().__init__()
         if use_spatial_transformer:
@@ -536,6 +537,15 @@ class UNetModel(nn.Module):
         elif self.num_classes == "continuous":
             print("setting up linear c_adm embedding layer")
             self.label_emb = nn.Linear(1, time_embed_dim)
+        elif self.num_classes == "sequential":
+            assert adm_in_channels is not None
+            self.label_emb = nn.Sequential(
+                nn.Sequential(
+                    linear(adm_in_channels, time_embed_dim),
+                    nn.SiLU(),
+                    linear(time_embed_dim, time_embed_dim),
+                )
+            )
         else:
             raise ValueError()
 
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index 4edd549..520bb27 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import kornia
 from torch.utils.checkpoint import checkpoint
 
 from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
@@ -131,6 +132,45 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         return self(text)
 
 
+from clip import load as load_clip
+
+class ClipImageEmbedder(nn.Module):
+    def __init__(
+            self,
+            model,
+            jit=False,
+            device='cuda' if torch.cuda.is_available() else 'cpu',
+            antialias=True,
+            ucg_rate=0.
+    ):
+        super().__init__()
+        self.model, _ = load_clip(name=model, device=device, jit=jit)
+
+        self.antialias = antialias
+
+        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+        self.ucg_rate = ucg_rate
+
+    def preprocess(self, x):
+        # normalize to [0,1]
+        x = kornia.geometry.resize(x, (224, 224),
+                                   interpolation='bicubic', align_corners=True,
+                                   antialias=self.antialias)
+        x = (x + 1.) / 2.
+        # re-normalize according to clip
+        x = kornia.enhance.normalize(x, self.mean, self.std)
+        return x
+
+    def forward(self, x, no_dropout=False):
+        # x is assumed to be in range [-1,1]
+        out = self.model.encode_image(self.preprocess(x))
+        out = out.to(x.dtype)
+        if self.ucg_rate > 0. and not no_dropout:
+            out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
+        return out
+
+
 class FrozenOpenCLIPEmbedder(AbstractEncoder):
     """
     Uses the OpenCLIP transformer encoder for text
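
Note (not part of the patch): the new "sequential" branch builds a small MLP
that maps a flat conditioning vector of width adm_in_channels to the
timestep-embedding width, so it can be added to the time embedding inside
UNetModel. A minimal standalone sketch with illustrative sizes follows;
`linear` in openaimodel.py is a thin wrapper around nn.Linear, and
time_embed_dim is model_channels * 4:

    import torch
    import torch.nn as nn

    adm_in_channels = 1280    # illustrative: width of the flat conditioning vector
    time_embed_dim = 320 * 4  # illustrative: model_channels * 4 in UNetModel

    # Equivalent of the module built by the "sequential" branch above.
    label_emb = nn.Sequential(
        nn.Sequential(
            nn.Linear(adm_in_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )
    )

    y = torch.randn(2, adm_in_channels)  # one conditioning vector per sample
    emb = label_emb(y)                   # -> shape (2, time_embed_dim)

With num_classes == "sequential", the `y` argument to UNetModel.forward is
therefore expected to be a float tensor of shape (batch, adm_in_channels)
rather than integer class labels.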
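
Note (not part of the patch): a minimal usage sketch for the new
ClipImageEmbedder, assuming the OpenAI `clip` and `kornia` packages are
installed; the model name and ucg_rate below are illustrative:

    import torch
    from ldm.modules.encoders.modules import ClipImageEmbedder

    # Loads OpenAI CLIP weights onto CUDA if available, else CPU.
    embedder = ClipImageEmbedder(model="ViT-L/14", ucg_rate=0.1)

    # Inputs are expected in [-1, 1], NCHW, on the same device as the model.
    x = torch.rand(2, 3, 256, 256) * 2. - 1.
    if torch.cuda.is_available():
        x = x.cuda()

    emb = embedder(x)                    # (2, 768) for ViT-L/14
    emb = embedder(x, no_dropout=True)   # bypass the ucg_rate dropout

When ucg_rate > 0, each sample's embedding is zeroed with probability
ucg_rate (the torch.bernoulli mask in forward), which trains the
unconditional branch needed for classifier-free guidance; passing
no_dropout=True disables this at sampling time.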