mirror of https://github.com/Stability-AI/stablediffusion.git (synced 2024-12-22 07:34:58 +00:00)

commit aad6e38a78 (parent 929625ac9a)

    fix missing adm_in_channels and ClipImageEmbedder

2 changed files with 50 additions and 0 deletions
@@ -469,6 +469,7 @@ class UNetModel(nn.Module):
         num_attention_blocks=None,
         disable_middle_self_attn=False,
         use_linear_in_transformer=False,
+        adm_in_channels=None,
     ):
         super().__init__()
         if use_spatial_transformer:

@@ -536,6 +537,15 @@ class UNetModel(nn.Module):
             elif self.num_classes == "continuous":
                 print("setting up linear c_adm embedding layer")
                 self.label_emb = nn.Linear(1, time_embed_dim)
+            elif self.num_classes == "sequential":
+                assert adm_in_channels is not None
+                self.label_emb = nn.Sequential(
+                    nn.Sequential(
+                        linear(adm_in_channels, time_embed_dim),
+                        nn.SiLU(),
+                        linear(time_embed_dim, time_embed_dim),
+                    )
+                )
             else:
                 raise ValueError()

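For reference, a minimal sketch of what the new `"sequential"` branch builds and how a `c_adm` conditioning vector flows through it. The layer construction mirrors the diff above (the repo's `linear` helper is a thin wrapper around `nn.Linear`); the widths used here are illustrative assumptions, not values fixed by this commit.

```python
import torch
import torch.nn as nn

# Illustrative assumptions: a 1024-wide conditioning vector (e.g. a CLIP
# image embedding) and a typical time_embed_dim of 4 * model_channels.
adm_in_channels = 1024
time_embed_dim = 1280

# Mirrors the "sequential" branch added above; nn.Linear stands in for the
# repo's `linear` helper, which simply constructs an nn.Linear.
label_emb = nn.Sequential(
    nn.Sequential(
        nn.Linear(adm_in_channels, time_embed_dim),
        nn.SiLU(),
        nn.Linear(time_embed_dim, time_embed_dim),
    )
)

c_adm = torch.randn(4, adm_in_channels)  # batch of conditioning vectors
emb = label_emb(c_adm)                   # -> shape (4, time_embed_dim)
assert emb.shape == (4, time_embed_dim)
```

In `UNetModel.forward`, this label embedding is added to the timestep embedding, which is how an `adm_in_channels`-wide conditioning vector steers the model.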
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import kornia
 from torch.utils.checkpoint import checkpoint

 from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
@@ -131,6 +132,45 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         return self(text)


+from clip import load as load_clip
+
+
+class ClipImageEmbedder(nn.Module):
+    def __init__(
+            self,
+            model,
+            jit=False,
+            device='cuda' if torch.cuda.is_available() else 'cpu',
+            antialias=True,
+            ucg_rate=0.
+    ):
+        super().__init__()
+        self.model, _ = load_clip(name=model, device=device, jit=jit)
+
+        self.antialias = antialias
+
+        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+        self.ucg_rate = ucg_rate
+
+    def preprocess(self, x):
+        # normalize to [0,1]
+        x = kornia.geometry.resize(x, (224, 224),
+                                   interpolation='bicubic', align_corners=True,
+                                   antialias=self.antialias)
+        x = (x + 1.) / 2.
+        # re-normalize according to clip
+        x = kornia.enhance.normalize(x, self.mean, self.std)
+        return x
+
+    def forward(self, x, no_dropout=False):
+        # x is assumed to be in range [-1,1]
+        out = self.model.encode_image(self.preprocess(x))
+        out = out.to(x.dtype)
+        if self.ucg_rate > 0. and not no_dropout:
+            out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
+        return out
+
+
 class FrozenOpenCLIPEmbedder(AbstractEncoder):
     """
     Uses the OpenCLIP transformer encoder for text
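A minimal usage sketch for the new embedder, assuming OpenAI's `clip` package is installed; the `'ViT-L/14'` backbone, batch size, and input resolution are illustrative assumptions (any name accepted by `clip.load` works):

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Backbone name, batch size and resolution are assumptions for this example.
embedder = ClipImageEmbedder(model='ViT-L/14', device=device, ucg_rate=0.1).to(device)

# Input is expected in [-1, 1], NCHW, at any spatial size; preprocess()
# resizes to 224x224 and applies CLIP's mean/std normalization.
x = torch.rand(4, 3, 256, 256, device=device) * 2. - 1.

with torch.no_grad():
    z = embedder(x)                        # rows zeroed with prob. ucg_rate
    z_det = embedder(x, no_dropout=True)   # deterministic embeddings

print(z.shape)  # torch.Size([4, 768]) for ViT-L/14
```

The `ucg_rate` dropout zeroes each sample's embedding with probability `ucg_rate`, the usual way of producing unconditional examples for classifier-free guidance during training; `no_dropout=True` bypasses it.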