Mirror of https://github.com/Stability-AI/stablediffusion.git (synced 2024-12-22 07:34:58 +00:00)

Merge pull request #215 from Stability-AI/unclip_prerelease: Add unCLIP finetunes

Commit 3396b0ff88
35 changed files with 4684 additions and 24 deletions
README.md

@@ -8,12 +8,14 @@ new checkpoints. The following list provides an overview of all currently availa

## News

**March 24, 2023**

*Stable UnCLIP 2.1*

- New stable diffusion finetune (_Stable unCLIP 2.1_, [HuggingFace](https://huggingface.co/stabilityai/)) at 768x768 resolution, based on SD2.1-768. This model allows for image variations and mixing operations as described in [*Hierarchical Text-Conditional Image Generation with CLIP Latents*](https://arxiv.org/abs/2204.06125), and, thanks to its modularity, can be combined with other models such as [KARLO](https://github.com/kakaobrain/karlo). Comes in two variants: [*Stable unCLIP-L*](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/blob/main/sd21-unclip-l.ckpt) and [*Stable unCLIP-H*](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/blob/main/sd21-unclip-h.ckpt), which are conditioned on CLIP ViT-L and ViT-H image embeddings, respectively. Instructions are available [here](doc/UNCLIP.MD).

**December 7, 2022**

*Version 2.1*
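As a quick illustrative sketch of how the new finetunes plug into this repo (not part of the diff; the checkpoint path and device are assumptions), the added inference configs can be instantiated with the repo's existing utilities:

```python
# Hedged sketch: build the Stable unCLIP-H model from the new inference config.
# The checkpoint filename/location below is an assumption.
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("configs/stable-diffusion/v2-1-stable-unclip-h-inference.yaml")
model = instantiate_from_config(config.model)

ckpt = torch.load("checkpoints/sd21-unclip-h.ckpt", map_location="cpu")  # assumed path
model.load_state_dict(ckpt["state_dict"], strict=False)
model = model.to("cuda").eval()
```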
checkpoints/checkpoints.txt (new file, 1 line)

@@ -0,0 +1 @@
Put unCLIP checkpoints here.
configs/karlo/decoder_900M_vit_l.yaml (new file, 37 lines)

@@ -0,0 +1,37 @@
model:
  type: t2i-decoder
  diffusion_sampler: uniform
  hparams:
    image_size: 64
    num_channels: 320
    num_res_blocks: 3
    channel_mult: ''
    attention_resolutions: 32,16,8
    num_heads: -1
    num_head_channels: 64
    num_heads_upsample: -1
    use_scale_shift_norm: true
    dropout: 0.1
    clip_dim: 768
    clip_emb_mult: 4
    text_ctx: 77
    xf_width: 1536
    xf_layers: 0
    xf_heads: 0
    xf_final_ln: false
    resblock_updown: true
    learn_sigma: true
    text_drop: 0.3
    clip_emb_type: image
    clip_emb_drop: 0.1
    use_plm: true

diffusion:
  steps: 1000
  learn_sigma: true
  sigma_small: false
  noise_schedule: squaredcos_cap_v2
  use_kl: false
  predict_xstart: false
  rescale_learned_sigmas: true
  timestep_respacing: ''
configs/karlo/improved_sr_64_256_1.4B.yaml (new file, 27 lines)

@@ -0,0 +1,27 @@
model:
  type: improved_sr_64_256
  diffusion_sampler: uniform
  hparams:
    channels: 320
    depth: 3
    channels_multiple:
      - 1
      - 2
      - 3
      - 4
    dropout: 0.0

diffusion:
  steps: 1000
  learn_sigma: false
  sigma_small: true
  noise_schedule: squaredcos_cap_v2
  use_kl: false
  predict_xstart: false
  rescale_learned_sigmas: true
  timestep_respacing: '7'


sampling:
  timestep_respacing: '7'  # fix
  clip_denoise: true
configs/karlo/prior_1B_vit_l.yaml (new file, 21 lines)

@@ -0,0 +1,21 @@
model:
  type: prior
  diffusion_sampler: uniform
  hparams:
    text_ctx: 77
    xf_width: 2048
    xf_layers: 20
    xf_heads: 32
    xf_final_ln: true
    text_drop: 0.2
    clip_dim: 768

diffusion:
  steps: 1000
  learn_sigma: false
  sigma_small: true
  noise_schedule: squaredcos_cap_v2
  use_kl: false
  predict_xstart: true
  rescale_learned_sigmas: false
  timestep_respacing: ''
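For reference, a minimal sketch (assumed usage, not from this PR) of reading one of these Karlo configs the same way the wrappers under `ldm/modules/karlo` consume them (`config.model.hparams`, `config.diffusion`):

```python
# Illustrative only: inspect a Karlo config with OmegaConf.
from omegaconf import OmegaConf

config = OmegaConf.load("configs/karlo/prior_1B_vit_l.yaml")
print(config.model.type)                 # "prior"
print(config.model.hparams.xf_width)     # 2048
print(config.diffusion.noise_schedule)   # "squaredcos_cap_v2"
```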
configs/stable-diffusion/v2-1-stable-unclip-h-inference.yaml (new file, 80 lines)

@@ -0,0 +1,80 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion
  params:
    embedding_dropout: 0.25
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 96
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn-adm
    scale_factor: 0.18215
    monitor: val/loss_simple_ema
    use_ema: False

    embedder_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder

    noise_aug_config:
      target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation
      params:
        timestep_dim: 1024
        noise_schedule_config:
          timesteps: 1000
          beta_schedule: squaredcos_cap_v2

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        num_classes: "sequential"
        adm_in_channels: 2048
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
            - 4
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
configs/stable-diffusion/v2-1-stable-unclip-l-inference.yaml (new file, 83 lines)

@@ -0,0 +1,83 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion
  params:
    embedding_dropout: 0.25
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 96
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn-adm
    scale_factor: 0.18215
    monitor: val/loss_simple_ema
    use_ema: False

    embedder_config:
      target: ldm.modules.encoders.modules.ClipImageEmbedder
      params:
        model: "ViT-L/14"

    noise_aug_config:
      target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation
      params:
        clip_stats_path: "checkpoints/karlo_models/ViT-L-14_stats.th"
        timestep_dim: 768
        noise_schedule_config:
          timesteps: 1000
          beta_schedule: squaredcos_cap_v2

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        num_classes: "sequential"
        adm_in_channels: 1536
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
            - 4
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
ldm/models/diffusion/ddpm.py

@@ -1799,3 +1799,75 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
        log = super().log_images(*args, **kwargs)
        log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
        return log


class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
    def __init__(self, embedder_config, embedding_key="jpg", embedding_dropout=0.5,
                 freeze_embedder=True, noise_aug_config=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.embed_key = embedding_key
        self.embedding_dropout = embedding_dropout
        self._init_embedder(embedder_config, freeze_embedder)
        self._init_noise_aug(noise_aug_config)

    def _init_embedder(self, config, freeze=True):
        embedder = instantiate_from_config(config)
        if freeze:
            self.embedder = embedder.eval()
            self.embedder.train = disabled_train
            for param in self.embedder.parameters():
                param.requires_grad = False

    def _init_noise_aug(self, config):
        if config is not None:
            # use the KARLO schedule for noise augmentation on CLIP image embeddings
            noise_augmentor = instantiate_from_config(config)
            assert isinstance(noise_augmentor, nn.Module)
            noise_augmentor = noise_augmentor.eval()
            noise_augmentor.train = disabled_train
            self.noise_augmentor = noise_augmentor
        else:
            self.noise_augmentor = None

    def get_input(self, batch, k, cond_key=None, bs=None, **kwargs):
        outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs)
        z, c = outputs[0], outputs[1]
        img = batch[self.embed_key][:bs]
        img = rearrange(img, 'b h w c -> b c h w')
        c_adm = self.embedder(img)
        if self.noise_augmentor is not None:
            c_adm, noise_level_emb = self.noise_augmentor(c_adm)
            # assume this gives embeddings of noise levels
            c_adm = torch.cat((c_adm, noise_level_emb), 1)
        if self.training:
            c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0],
                                                                               device=c_adm.device)[:, None]) * c_adm
        all_conds = {"c_crossattn": [c], "c_adm": c_adm}
        noutputs = [z, all_conds]
        noutputs.extend(outputs[2:])
        return noutputs

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=4, **kwargs):
        log = dict()
        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True,
                                           return_original_cond=True)
        log["inputs"] = x
        log["reconstruction"] = xrec
        assert self.model.conditioning_key is not None
        assert self.cond_stage_key in ["caption", "txt"]
        xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
        log["conditioning"] = xc
        uc = self.get_unconditional_conditioning(N, kwargs.get('unconditional_guidance_label', ''))
        unconditional_guidance_scale = kwargs.get('unconditional_guidance_scale', 5.)

        uc_ = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
        ema_scope = self.ema_scope if kwargs.get('use_ema_scope', True) else nullcontext
        with ema_scope(f"Sampling"):
            samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=True,
                                             ddim_steps=kwargs.get('ddim_steps', 50), eta=kwargs.get('ddim_eta', 0.),
                                             unconditional_guidance_scale=unconditional_guidance_scale,
                                             unconditional_conditioning=uc_, )
            x_samples_cfg = self.decode_first_stage(samples_cfg)
            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
        return log
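To make the conditioning layout above concrete, here is a small illustrative sketch (shapes assumed for the ViT-H variant; all tensors are dummies) of how `c_adm` is assembled and paired with the cross-attention context:

```python
# Illustrative only; dimensions assume the unclip-h config (1024-d CLIP image
# embedding, timestep_dim: 1024, adm_in_channels: 2048).
import torch

batch = 4
clip_emb = torch.randn(batch, 1024)          # FrozenOpenCLIPImageEmbedder output
noise_level_emb = torch.randn(batch, 1024)   # Timestep(1024) embedding of the sampled noise level
c_adm = torch.cat((clip_emb, noise_level_emb), 1)   # (4, 2048), matches adm_in_channels: 2048

txt_ctx = torch.randn(batch, 77, 1024)       # FrozenOpenCLIPEmbedder text context
cond = {"c_crossattn": [txt_ctx], "c_adm": c_adm}   # layout consumed via conditioning_key: crossattn-adm
```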
ldm/models/diffusion/dpm_solver/dpm_solver.py

@@ -307,6 +307,15 @@ def model_wrapper(
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                if isinstance(condition, dict):
                    assert isinstance(unconditional_condition, dict)
                    c_in = dict()
                    for k in condition:
                        if isinstance(condition[k], list):
                            c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))]
                        else:
                            c_in[k] = torch.cat([unconditional_condition[k], condition[k]])
                else:
                    c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)
ldm/models/diffusion/dpm_solver/sampler.py

@@ -3,7 +3,6 @@ import torch

from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver


MODEL_TYPES = {
    "eps": "noise",
    "v": "v"

@@ -51,10 +50,18 @@ class DPMSolverSampler(object):
               ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list): ctmp = ctmp[0]
                if isinstance(ctmp, torch.Tensor):
                    cbs = ctmp.shape[0]
                    if cbs != batch_size:
                        print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            elif isinstance(conditioning, list):
                for ctmp in conditioning:
                    if ctmp.shape[0] != batch_size:
                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
            else:
                if isinstance(conditioning, torch.Tensor):
                    if conditioning.shape[0] != batch_size:
                        print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

@@ -83,6 +90,7 @@ class DPMSolverSampler(object):
        )

        dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
                              lower_order_final=True)

        return x.to(device), None
ldm/modules/diffusionmodules/openaimodel.py

@@ -409,6 +409,15 @@ class QKVAttention(nn.Module):
        return count_flops_attn(model, _x, y)


class Timestep(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        return timestep_embedding(t, self.dim)


class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

@@ -470,6 +479,7 @@ class UNetModel(nn.Module):
            num_attention_blocks=None,
            disable_middle_self_attn=False,
            use_linear_in_transformer=False,
            adm_in_channels=None,
    ):
        super().__init__()
        if use_spatial_transformer:

@@ -538,6 +548,15 @@ class UNetModel(nn.Module):
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()
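As a quick illustration (assumed usage, not part of the diff), the new `Timestep` module simply wraps the existing sinusoidal `timestep_embedding` helper, so the noise-level embedding concatenated into `c_adm` has the width set by `timestep_dim` in the configs:

```python
# Illustrative check of the Timestep wrapper.
import torch
from ldm.modules.diffusionmodules.util import timestep_embedding

noise_levels = torch.randint(0, 1000, (4,))
emb = timestep_embedding(noise_levels, 1024)   # what Timestep(1024)(noise_levels) returns
print(emb.shape)                               # torch.Size([4, 1024])
```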
ldm/modules/diffusionmodules/util.py

@@ -34,6 +34,13 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)

    elif schedule == "squaredcos_cap_v2":  # used for karlo prior
        # return early
        return betas_for_alpha_bar(
            n_timestep,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":

@@ -218,6 +225,7 @@ class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
ldm/modules/encoders/modules.py

@@ -1,11 +1,12 @@
import torch
import torch.nn as nn
import kornia
from torch.utils.checkpoint import checkpoint

from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel

import open_clip
from ldm.util import default, count_params
from ldm.util import default, count_params, autocast


class AbstractEncoder(nn.Module):

@@ -37,7 +38,7 @@ class ClassEmbedder(nn.Module):
        c = batch[key][:, None]
        if self.ucg_rate > 0. and not disable_dropout:
            mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
            c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
            c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
            c = c.long()
        c = self.embedding(c)
        return c

@@ -57,7 +58,9 @@ def disabled_train(self, mode=True):

class FrozenT5Embedder(AbstractEncoder):
    """Uses the T5 transformer encoder for text"""
    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl

    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77,
                 freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)

@@ -68,7 +71,7 @@ class FrozenT5Embedder(AbstractEncoder):

    def freeze(self):
        self.transformer = self.transformer.eval()
        #self.train = disabled_train
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

@@ -92,6 +95,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
        "pooled",
        "hidden"
    ]

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
                 freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
        super().__init__()

@@ -110,7 +114,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):

    def freeze(self):
        self.transformer = self.transformer.eval()
        #self.train = disabled_train
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

@@ -118,7 +122,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":

@@ -131,15 +135,54 @@ class FrozenCLIPEmbedder(AbstractEncoder):
        return self(text)


class ClipImageEmbedder(nn.Module):
    def __init__(
            self,
            model,
            jit=False,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            antialias=True,
            ucg_rate=0.
    ):
        super().__init__()
        from clip import load as load_clip
        self.model, _ = load_clip(name=model, device=device, jit=jit)

        self.antialias = antialias

        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
        self.ucg_rate = ucg_rate

    def preprocess(self, x):
        # normalize to [0,1]
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic', align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        # re-normalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x, no_dropout=False):
        # x is assumed to be in range [-1,1]
        out = self.model.encode_image(self.preprocess(x))
        out = out.to(x.dtype)
        if self.ucg_rate > 0. and not no_dropout:
            out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
        return out


class FrozenOpenCLIPEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    LAYERS = [
        #"pooled",
        # "pooled",
        "last",
        "penultimate"
    ]

    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
                 freeze=True, layer="last"):
        super().__init__()

@@ -179,7 +222,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break

@@ -193,14 +236,73 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
        return self(text)


class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP vision transformer encoder for images
    """

    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
                 freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
                                                            pretrained=version, )
        del model.transformer
        self.model = model

        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "penultimate":
            raise NotImplementedError()
            self.layer_idx = 1

        self.antialias = antialias

        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
        self.ucg_rate = ucg_rate

    def preprocess(self, x):
        # normalize to [0,1]
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic', align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    @autocast
    def forward(self, image, no_dropout=False):
        z = self.encode_with_vision_transformer(image)
        if self.ucg_rate > 0. and not no_dropout:
            z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
        return z

    def encode_with_vision_transformer(self, img):
        img = self.preprocess(img)
        x = self.model.visual(img)
        return x

    def encode(self, text):
        return self(text)


class FrozenCLIPT5Encoder(AbstractEncoder):
    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
                 clip_max_length=77, t5_max_length=77):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
        print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
              f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
        print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
              f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")

    def encode(self, text):
        return self(text)

@@ -211,3 +313,38 @@ class FrozenCLIPT5Encoder(AbstractEncoder):
        return [clip_z, t5_z]


from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from ldm.modules.diffusionmodules.openaimodel import Timestep


class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
    def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
        super().__init__(*args, **kwargs)
        if clip_stats_path is None:
            clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
        else:
            clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
        self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
        self.register_buffer("data_std", clip_std[None, :], persistent=False)
        self.time_embed = Timestep(timestep_dim)

    def scale(self, x):
        # re-normalize to centered mean and unit variance
        x = (x - self.data_mean) * 1. / self.data_std
        return x

    def unscale(self, x):
        # back to original data stats
        x = (x * self.data_std) + self.data_mean
        return x

    def forward(self, x, noise_level=None):
        if noise_level is None:
            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        else:
            assert isinstance(noise_level, torch.Tensor)
        x = self.scale(x)
        z = self.q_sample(x, noise_level)
        z = self.unscale(z)
        noise_level = self.time_embed(noise_level)
        return z, noise_level
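A hedged usage sketch for the two new conditioning pieces above (the constructor kwargs mirror the unclip yamls, the input image is a dummy, and weights download on first use):

```python
# Illustrative only: embed an image and noise-augment the embedding, roughly as
# ImageEmbeddingConditionedLatentDiffusion.get_input does. Constructor call is an assumption.
import torch
from ldm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder, CLIPEmbeddingNoiseAugmentation

embedder = FrozenOpenCLIPImageEmbedder()        # ViT-H-14 / laion2b_s32b_b79k defaults
img = torch.rand(1, 3, 256, 256) * 2. - 1.      # dummy image in [-1, 1]
c_adm = embedder(img)                           # (1, 1024) CLIP image embedding

noise_aug = CLIPEmbeddingNoiseAugmentation(     # kwargs as in the unclip-h yaml
    noise_schedule_config={"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"},
    timestep_dim=1024,
)
c_adm_aug, noise_level_emb = noise_aug(c_adm, noise_level=torch.tensor([100]))
# c_adm_aug: noised embedding; noise_level_emb: 1024-d embedding of the noise level
```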
ldm/modules/karlo/__init__.py (new file, empty)

ldm/modules/karlo/diffusers_pipeline.py (new file, 512 lines)
@@ -0,0 +1,512 @@
# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import functional as F

from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextModelOutput

from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import UnCLIPScheduler
from ...utils import is_accelerate_available, logging, randn_tensor
from .text_proj import UnCLIPTextProjModel


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class UnCLIPPipeline(DiffusionPipeline):
    """
    Pipeline for text-to-image generation using unCLIP
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
    Args:
        text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        text_proj ([`UnCLIPTextProjModel`]):
            Utility class to prepare and combine the embeddings before they are passed to the decoder.
        decoder ([`UNet2DConditionModel`]):
            The decoder to invert the image embedding into an image.
        super_res_first ([`UNet2DModel`]):
            Super resolution unet. Used in all but the last step of the super resolution diffusion process.
        super_res_last ([`UNet2DModel`]):
            Super resolution unet. Used in the last step of the super resolution diffusion process.
        prior_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the prior denoising process. Just a modified DDPMScheduler.
        decoder_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
        super_res_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.
    """

    prior: PriorTransformer
    decoder: UNet2DConditionModel
    text_proj: UnCLIPTextProjModel
    text_encoder: CLIPTextModelWithProjection
    tokenizer: CLIPTokenizer
    super_res_first: UNet2DModel
    super_res_last: UNet2DModel

    prior_scheduler: UnCLIPScheduler
    decoder_scheduler: UnCLIPScheduler
    super_res_scheduler: UnCLIPScheduler

    def __init__(
        self,
        prior: PriorTransformer,
        decoder: UNet2DConditionModel,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        text_proj: UnCLIPTextProjModel,
        super_res_first: UNet2DModel,
        super_res_last: UNet2DModel,
        prior_scheduler: UnCLIPScheduler,
        decoder_scheduler: UnCLIPScheduler,
        super_res_scheduler: UnCLIPScheduler,
    ):
        super().__init__()

        self.register_modules(
            prior=prior,
            decoder=decoder,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_proj=text_proj,
            super_res_first=super_res_first,
            super_res_last=super_res_last,
            prior_scheduler=prior_scheduler,
            decoder_scheduler=decoder_scheduler,
            super_res_scheduler=super_res_scheduler,
        )

    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
    ):
        if text_model_output is None:
            batch_size = len(prompt) if isinstance(prompt, list) else 1
            # get prompt text embeddings
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            text_mask = text_inputs.attention_mask.bool().to(device)

            if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
                removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )
                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

            text_encoder_output = self.text_encoder(text_input_ids.to(device))

            text_embeddings = text_encoder_output.text_embeds
            text_encoder_hidden_states = text_encoder_output.last_hidden_state

        else:
            batch_size = text_model_output[0].shape[0]
            text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
            text_mask = text_attention_mask

        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            uncond_tokens = [""] * batch_size

            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
            uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))

            uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
            uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method

            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len)

            seq_len = uncond_text_encoder_hidden_states.shape[1]
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )
            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

            # done duplicates

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])

            text_mask = torch.cat([uncond_text_mask, text_mask])

        return text_embeddings, text_encoder_hidden_states, text_mask

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
        when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list
        models = [
            self.decoder,
            self.text_proj,
            self.text_encoder,
            self.super_res_first,
            self.super_res_last,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
            return self.device
        for module in self.decoder.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
        prior_num_inference_steps: int = 25,
        decoder_num_inference_steps: int = 25,
        super_res_num_inference_steps: int = 7,
        generator: Optional[torch.Generator] = None,
        prior_latents: Optional[torch.FloatTensor] = None,
        decoder_latents: Optional[torch.FloatTensor] = None,
        super_res_latents: Optional[torch.FloatTensor] = None,
        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        prior_guidance_scale: float = 4.0,
        decoder_guidance_scale: float = 8.0,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.
        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation. This can only be left undefined if
                `text_model_output` and `text_attention_mask` is passed.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            prior_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps for the prior. More denoising steps usually lead to a higher quality
                image at the expense of slower inference.
            decoder_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
                image at the expense of slower inference.
            super_res_num_inference_steps (`int`, *optional*, defaults to 7):
                The number of denoising steps for super resolution. More denoising steps usually lead to a higher
                quality image at the expense of slower inference.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*):
                Pre-generated noisy latents to be used as inputs for the prior.
            decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
                Pre-generated noisy latents to be used as inputs for the decoder.
            super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
                Pre-generated noisy latents to be used as inputs for the decoder.
            prior_guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            decoder_guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            text_model_output (`CLIPTextModelOutput`, *optional*):
                Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs
                can be passed for tasks like text embedding interpolations. Make sure to also pass
                `text_attention_mask` in this case. `prompt` can then be left to `None`.
            text_attention_mask (`torch.Tensor`, *optional*):
                Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention
                masks are necessary when passing `text_model_output`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
        """
        if prompt is not None:
            if isinstance(prompt, str):
                batch_size = 1
            elif isinstance(prompt, list):
                batch_size = len(prompt)
            else:
                raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        else:
            batch_size = text_model_output[0].shape[0]

        device = self._execution_device

        batch_size = batch_size * num_images_per_prompt

        do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0

        text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
        )

        # prior

        self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
        prior_timesteps_tensor = self.prior_scheduler.timesteps

        embedding_dim = self.prior.config.embedding_dim

        prior_latents = self.prepare_latents(
            (batch_size, embedding_dim),
            text_embeddings.dtype,
            device,
            generator,
            prior_latents,
            self.prior_scheduler,
        )

        for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents

            predicted_image_embedding = self.prior(
                latent_model_input,
                timestep=t,
                proj_embedding=text_embeddings,
                encoder_hidden_states=text_encoder_hidden_states,
                attention_mask=text_mask,
            ).predicted_image_embedding

            if do_classifier_free_guidance:
                predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
                predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
                    predicted_image_embedding_text - predicted_image_embedding_uncond
                )

            if i + 1 == prior_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = prior_timesteps_tensor[i + 1]

            prior_latents = self.prior_scheduler.step(
                predicted_image_embedding,
                timestep=t,
                sample=prior_latents,
                generator=generator,
                prev_timestep=prev_timestep,
            ).prev_sample

        prior_latents = self.prior.post_process_latents(prior_latents)

        image_embeddings = prior_latents

        # done prior

        # decoder

        text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
            image_embeddings=image_embeddings,
            text_embeddings=text_embeddings,
            text_encoder_hidden_states=text_encoder_hidden_states,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )

        decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)

        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
        decoder_timesteps_tensor = self.decoder_scheduler.timesteps

        num_channels_latents = self.decoder.in_channels
        height = self.decoder.sample_size
        width = self.decoder.sample_size

        decoder_latents = self.prepare_latents(
            (batch_size, num_channels_latents, height, width),
            text_encoder_hidden_states.dtype,
            device,
            generator,
            decoder_latents,
            self.decoder_scheduler,
        )

        for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents

            noise_pred = self.decoder(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=text_encoder_hidden_states,
                class_labels=additive_clip_time_embeddings,
                attention_mask=decoder_text_mask,
            ).sample

            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
                noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            if i + 1 == decoder_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = decoder_timesteps_tensor[i + 1]

            # compute the previous noisy sample x_t -> x_t-1
            decoder_latents = self.decoder_scheduler.step(
                noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

        decoder_latents = decoder_latents.clamp(-1, 1)

        image_small = decoder_latents

        # done decoder

        # super res

        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
        super_res_timesteps_tensor = self.super_res_scheduler.timesteps

        channels = self.super_res_first.in_channels // 2
        height = self.super_res_first.sample_size
        width = self.super_res_first.sample_size

        super_res_latents = self.prepare_latents(
            (batch_size, channels, height, width),
            image_small.dtype,
            device,
            generator,
            super_res_latents,
            self.super_res_scheduler,
        )

        interpolate_antialias = {}
        if "antialias" in inspect.signature(F.interpolate).parameters:
            interpolate_antialias["antialias"] = True

        image_upscaled = F.interpolate(
            image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
        )

        for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
            # no classifier free guidance

            if i == super_res_timesteps_tensor.shape[0] - 1:
                unet = self.super_res_last
            else:
                unet = self.super_res_first

            latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)

            noise_pred = unet(
                sample=latent_model_input,
                timestep=t,
            ).sample

            if i + 1 == super_res_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = super_res_timesteps_tensor[i + 1]

            # compute the previous noisy sample x_t -> x_t-1
            super_res_latents = self.super_res_scheduler.step(
                noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

        image = super_res_latents
        # done super res

        # post processing

        image = image * 0.5 + 0.5
        image = image.clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
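This file is a vendored copy of the diffusers unCLIP pipeline; as a hedged sketch, the equivalent upstream class can be driven like this (the model id, dtype, and prompt are assumptions):

```python
# Assumed usage of the upstream diffusers UnCLIPPipeline (the vendored copy above
# is wired in via relative imports and is not meant to be called directly).
import torch
from diffusers import UnCLIPPipeline

pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16).to("cuda")
image = pipe(
    "a high-resolution photograph of a big red frog on a green leaf",
    prior_guidance_scale=4.0,
    decoder_guidance_scale=8.0,
).images[0]
image.save("frog.png")
```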
ldm/modules/karlo/kakao/__init__.py (new file, empty)

ldm/modules/karlo/kakao/models/__init__.py (new file, empty)

ldm/modules/karlo/kakao/models/clip.py (new file, 182 lines)
@@ -0,0 +1,182 @@
# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------
# ------------------------------------------------------------------------------------
# Adapted from OpenAI's CLIP (https://github.com/openai/CLIP/)
# ------------------------------------------------------------------------------------


import torch
import torch.nn as nn
import torch.nn.functional as F
import clip

from clip.model import CLIP, convert_weights
from clip.simple_tokenizer import SimpleTokenizer, default_bpe


"""===== Monkey-Patching original CLIP for JIT compile ====="""


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = F.layer_norm(
            x.type(torch.float32),
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        )
        return ret.type(orig_type)


clip.model.LayerNorm = LayerNorm
delattr(clip.model.CLIP, "forward")

"""===== End of Monkey-Patching ====="""


class CustomizedCLIP(CLIP):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @torch.jit.export
    def encode_image(self, image):
        return self.visual(image)

    @torch.jit.export
    def encode_text(self, text):
        # re-define this function to return unpooled text features

        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        x_seq = x
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x_out = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x_out, x_seq

    @torch.jit.ignore
    def forward(self, image, text):
        super().forward(image, text)

    @classmethod
    def load_from_checkpoint(cls, ckpt_path: str):
        state_dict = torch.load(ckpt_path, map_location="cpu").state_dict()

        vit = "visual.proj" in state_dict
        if vit:
            vision_width = state_dict["visual.conv1.weight"].shape[0]
            vision_layers = len(
                [
                    k
                    for k in state_dict.keys()
                    if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
                ]
            )
            vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
            grid_size = round(
                (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5
            )
            image_resolution = vision_patch_size * grid_size
        else:
            counts: list = [
                len(
                    set(
                        k.split(".")[2]
                        for k in state_dict
                        if k.startswith(f"visual.layer{b}")
                    )
                )
                for b in [1, 2, 3, 4]
            ]
            vision_layers = tuple(counts)
            vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
            output_width = round(
                (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5
            )
            vision_patch_size = None
            assert (
                output_width**2 + 1
                == state_dict["visual.attnpool.positional_embedding"].shape[0]
            )
            image_resolution = output_width * 32

        embed_dim = state_dict["text_projection"].shape[1]
        context_length = state_dict["positional_embedding"].shape[0]
        vocab_size = state_dict["token_embedding.weight"].shape[0]
        transformer_width = state_dict["ln_final.weight"].shape[0]
        transformer_heads = transformer_width // 64
        transformer_layers = len(
            set(
                k.split(".")[2]
                for k in state_dict
                if k.startswith("transformer.resblocks")
            )
        )

        model = cls(
            embed_dim,
            image_resolution,
            vision_layers,
            vision_width,
            vision_patch_size,
            context_length,
            vocab_size,
            transformer_width,
            transformer_heads,
            transformer_layers,
        )

        for key in ["input_resolution", "context_length", "vocab_size"]:
            if key in state_dict:
                del state_dict[key]

        convert_weights(model)
        model.load_state_dict(state_dict)
        model.eval()
        model.float()
        return model


class CustomizedTokenizer(SimpleTokenizer):
    def __init__(self):
        super().__init__(bpe_path=default_bpe())

        self.sot_token = self.encoder["<|startoftext|>"]
        self.eot_token = self.encoder["<|endoftext|>"]

    def padded_tokens_and_mask(self, texts, text_ctx):
        assert isinstance(texts, list) and all(
            isinstance(elem, str) for elem in texts
        ), "texts should be a list of strings"

        all_tokens = [
            [self.sot_token] + self.encode(text) + [self.eot_token] for text in texts
        ]

        mask = [
            [True] * min(text_ctx, len(tokens))
            + [False] * max(text_ctx - len(tokens), 0)
            for tokens in all_tokens
        ]
        mask = torch.tensor(mask, dtype=torch.bool)
        result = torch.zeros(len(all_tokens), text_ctx, dtype=torch.int)
        for i, tokens in enumerate(all_tokens):
            if len(tokens) > text_ctx:
                tokens = tokens[:text_ctx]
                tokens[-1] = self.eot_token
            result[i, : len(tokens)] = torch.tensor(tokens)

        return result, mask
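A brief illustrative sketch of how `CustomizedTokenizer.padded_tokens_and_mask` produces the token/mask pair the Karlo decoder consumes (`text_ctx: 77` in the configs above); the prompt is made up:

```python
# Illustrative usage of CustomizedTokenizer; requires OpenAI's `clip` package.
from ldm.modules.karlo.kakao.models.clip import CustomizedTokenizer

tokenizer = CustomizedTokenizer()
tokens, mask = tokenizer.padded_tokens_and_mask(["a photo of a corgi"], text_ctx=77)
print(tokens.shape, mask.shape)   # torch.Size([1, 77]) torch.Size([1, 77])
```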
ldm/modules/karlo/kakao/models/decoder_model.py (new file, 193 lines)
@@ -0,0 +1,193 @@
# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------

import copy
import torch

from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
from ldm.modules.karlo.kakao.modules.unet import PLMImUNet


class Text2ImProgressiveModel(torch.nn.Module):
    """
    A decoder that generates 64x64px images based on the text prompt.

    :param config: yaml config to define the decoder.
    :param tokenizer: tokenizer used in clip.
    """

    def __init__(
        self,
        config,
        tokenizer,
    ):
        super().__init__()

        self._conf = config
        self._model_conf = config.model.hparams
        self._diffusion_kwargs = dict(
            steps=config.diffusion.steps,
            learn_sigma=config.diffusion.learn_sigma,
            sigma_small=config.diffusion.sigma_small,
            noise_schedule=config.diffusion.noise_schedule,
            use_kl=config.diffusion.use_kl,
            predict_xstart=config.diffusion.predict_xstart,
            rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
            timestep_respacing=config.diffusion.timestep_respacing,
        )
        self._tokenizer = tokenizer

        self.model = self.create_plm_dec_model()

        cf_token, cf_mask = self.set_cf_text_tensor()
        self.register_buffer("cf_token", cf_token, persistent=False)
        self.register_buffer("cf_mask", cf_mask, persistent=False)

    @classmethod
    def load_from_checkpoint(cls, config, tokenizer, ckpt_path, strict: bool = True):
        ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]

        model = cls(config, tokenizer)
        model.load_state_dict(ckpt, strict=strict)
        return model

    def create_plm_dec_model(self):
        image_size = self._model_conf.image_size
        if self._model_conf.channel_mult == "":
            if image_size == 256:
                channel_mult = (1, 1, 2, 2, 4, 4)
            elif image_size == 128:
                channel_mult = (1, 1, 2, 3, 4)
            elif image_size == 64:
                channel_mult = (1, 2, 3, 4)
            else:
                raise ValueError(f"unsupported image size: {image_size}")
        else:
            channel_mult = tuple(
                int(ch_mult) for ch_mult in self._model_conf.channel_mult.split(",")
            )
            assert 2 ** (len(channel_mult) + 2) == image_size

        attention_ds = []
        for res in self._model_conf.attention_resolutions.split(","):
            attention_ds.append(image_size // int(res))

        return PLMImUNet(
            text_ctx=self._model_conf.text_ctx,
            xf_width=self._model_conf.xf_width,
            in_channels=3,
            model_channels=self._model_conf.num_channels,
            out_channels=6 if self._model_conf.learn_sigma else 3,
            num_res_blocks=self._model_conf.num_res_blocks,
            attention_resolutions=tuple(attention_ds),
            dropout=self._model_conf.dropout,
            channel_mult=channel_mult,
            num_heads=self._model_conf.num_heads,
            num_head_channels=self._model_conf.num_head_channels,
            num_heads_upsample=self._model_conf.num_heads_upsample,
            use_scale_shift_norm=self._model_conf.use_scale_shift_norm,
            resblock_updown=self._model_conf.resblock_updown,
            clip_dim=self._model_conf.clip_dim,
            clip_emb_mult=self._model_conf.clip_emb_mult,
            clip_emb_type=self._model_conf.clip_emb_type,
            clip_emb_drop=self._model_conf.clip_emb_drop,
        )

    def set_cf_text_tensor(self):
        return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)

    def get_sample_fn(self, timestep_respacing):
        use_ddim = timestep_respacing.startswith(("ddim", "fast"))

        diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
        diffusion_kwargs.update(timestep_respacing=timestep_respacing)
        diffusion = create_gaussian_diffusion(**diffusion_kwargs)
        sample_fn = (
            diffusion.ddim_sample_loop_progressive
            if use_ddim
            else diffusion.p_sample_loop_progressive
        )

        return sample_fn

    def forward(
        self,
        txt_feat,
        txt_feat_seq,
        tok,
        mask,
        img_feat=None,
        cf_guidance_scales=None,
        timestep_respacing=None,
|
||||
):
|
||||
# classifier-free guidance must be enabled at inference
|
||||
assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0)
|
||||
assert img_feat is not None
|
||||
|
||||
bsz = txt_feat.shape[0]
|
||||
img_sz = self._model_conf.image_size
|
||||
|
||||
def guided_model_fn(x_t, ts, **kwargs):
|
||||
half = x_t[: len(x_t) // 2]
|
||||
combined = torch.cat([half, half], dim=0)
|
||||
model_out = self.model(combined, ts, **kwargs)
|
||||
eps, rest = model_out[:, :3], model_out[:, 3:]
|
||||
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
||||
half_eps = uncond_eps + cf_guidance_scales.view(-1, 1, 1, 1) * (
|
||||
cond_eps - uncond_eps
|
||||
)
|
||||
eps = torch.cat([half_eps, half_eps], dim=0)
|
||||
return torch.cat([eps, rest], dim=1)
|
||||
|
||||
cf_feat = self.model.cf_param.unsqueeze(0)
|
||||
cf_feat = cf_feat.expand(bsz // 2, -1)
|
||||
feat = torch.cat([img_feat, cf_feat.to(txt_feat.device)], dim=0)
|
||||
|
||||
cond = {
|
||||
"y": feat,
|
||||
"txt_feat": txt_feat,
|
||||
"txt_feat_seq": txt_feat_seq,
|
||||
"mask": mask,
|
||||
}
|
||||
sample_fn = self.get_sample_fn(timestep_respacing)
|
||||
sample_outputs = sample_fn(
|
||||
guided_model_fn,
|
||||
(bsz, 3, img_sz, img_sz),
|
||||
noise=None,
|
||||
device=txt_feat.device,
|
||||
clip_denoised=True,
|
||||
model_kwargs=cond,
|
||||
)
|
||||
|
||||
for out in sample_outputs:
|
||||
sample = out["sample"]
|
||||
yield sample if cf_guidance_scales is None else sample[
|
||||
: sample.shape[0] // 2
|
||||
]
|
||||
|
||||
|
||||
class Text2ImModel(Text2ImProgressiveModel):
|
||||
def forward(
|
||||
self,
|
||||
txt_feat,
|
||||
txt_feat_seq,
|
||||
tok,
|
||||
mask,
|
||||
img_feat=None,
|
||||
cf_guidance_scales=None,
|
||||
timestep_respacing=None,
|
||||
):
|
||||
last_out = None
|
||||
for out in super().forward(
|
||||
txt_feat,
|
||||
txt_feat_seq,
|
||||
tok,
|
||||
mask,
|
||||
img_feat,
|
||||
cf_guidance_scales,
|
||||
timestep_respacing,
|
||||
):
|
||||
last_out = out
|
||||
return last_out
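# --- Illustrative sketch (not part of the original file): the classifier-free
# guidance arithmetic used in guided_model_fn above, shown on dummy tensors.
# The batch holds a conditional half and an unconditional (null-embedding)
# half, and the guided epsilon is eps_uncond + scale * (eps_cond - eps_uncond).
if __name__ == "__main__":
    _scale = torch.tensor([4.0])
    _eps_cond = torch.randn(1, 3, 64, 64)
    _eps_uncond = torch.randn(1, 3, 64, 64)
    _eps = _eps_uncond + _scale.view(-1, 1, 1, 1) * (_eps_cond - _eps_uncond)
    assert _eps.shape == (1, 3, 64, 64)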
|
ldm/modules/karlo/kakao/models/prior_model.py  (new file, 138 lines)
@@ -0,0 +1,138 @@
# ------------------------------------------------------------------------------------
|
||||
# Karlo-v1.0.alpha
|
||||
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
|
||||
# ------------------------------------------------------------------------------------
|
||||
|
||||
import copy
|
||||
import torch
|
||||
|
||||
from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
|
||||
from ldm.modules.karlo.kakao.modules.xf import PriorTransformer
|
||||
|
||||
|
||||
class PriorDiffusionModel(torch.nn.Module):
|
||||
"""
|
||||
A prior that generates a CLIP image feature based on the text prompt.
|
||||
|
||||
:param config: yaml config to define the prior.
|
||||
:param tokenizer: tokenizer used in CLIP.
|
||||
:param clip_mean: mean used to normalize the CLIP image feature (to zero mean, unit variance).
|
||||
:param clip_std: std used to normalize the CLIP image feature (to zero mean, unit variance).
|
||||
"""
|
||||
|
||||
def __init__(self, config, tokenizer, clip_mean, clip_std):
|
||||
super().__init__()
|
||||
|
||||
self._conf = config
|
||||
self._model_conf = config.model.hparams
|
||||
self._diffusion_kwargs = dict(
|
||||
steps=config.diffusion.steps,
|
||||
learn_sigma=config.diffusion.learn_sigma,
|
||||
sigma_small=config.diffusion.sigma_small,
|
||||
noise_schedule=config.diffusion.noise_schedule,
|
||||
use_kl=config.diffusion.use_kl,
|
||||
predict_xstart=config.diffusion.predict_xstart,
|
||||
rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
|
||||
timestep_respacing=config.diffusion.timestep_respacing,
|
||||
)
|
||||
self._tokenizer = tokenizer
|
||||
|
||||
self.register_buffer("clip_mean", clip_mean[None, :], persistent=False)
|
||||
self.register_buffer("clip_std", clip_std[None, :], persistent=False)
|
||||
|
||||
causal_mask = self.get_causal_mask()
|
||||
self.register_buffer("causal_mask", causal_mask, persistent=False)
|
||||
|
||||
self.model = PriorTransformer(
|
||||
text_ctx=self._model_conf.text_ctx,
|
||||
xf_width=self._model_conf.xf_width,
|
||||
xf_layers=self._model_conf.xf_layers,
|
||||
xf_heads=self._model_conf.xf_heads,
|
||||
xf_final_ln=self._model_conf.xf_final_ln,
|
||||
clip_dim=self._model_conf.clip_dim,
|
||||
)
|
||||
|
||||
cf_token, cf_mask = self.set_cf_text_tensor()
|
||||
self.register_buffer("cf_token", cf_token, persistent=False)
|
||||
self.register_buffer("cf_mask", cf_mask, persistent=False)
|
||||
|
||||
@classmethod
|
||||
def load_from_checkpoint(
|
||||
cls, config, tokenizer, clip_mean, clip_std, ckpt_path, strict: bool = True
|
||||
):
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
|
||||
|
||||
model = cls(config, tokenizer, clip_mean, clip_std)
|
||||
model.load_state_dict(ckpt, strict=strict)
|
||||
return model
|
||||
|
||||
def set_cf_text_tensor(self):
|
||||
return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)
|
||||
|
||||
def get_sample_fn(self, timestep_respacing):
|
||||
use_ddim = timestep_respacing.startswith(("ddim", "fast"))
|
||||
|
||||
diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
|
||||
diffusion_kwargs.update(timestep_respacing=timestep_respacing)
|
||||
diffusion = create_gaussian_diffusion(**diffusion_kwargs)
|
||||
sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop
|
||||
|
||||
return sample_fn
|
||||
|
||||
def get_causal_mask(self):
|
||||
seq_len = self._model_conf.text_ctx + 4
|
||||
mask = torch.empty(seq_len, seq_len)
|
||||
mask.fill_(float("-inf"))
|
||||
mask.triu_(1)
|
||||
mask = mask[None, ...]
|
||||
return mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
txt_feat,
|
||||
txt_feat_seq,
|
||||
mask,
|
||||
cf_guidance_scales=None,
|
||||
timestep_respacing=None,
|
||||
denoised_fn=True,
|
||||
):
|
||||
# classifier-free guidance must be enabled at inference
|
||||
assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0)
|
||||
|
||||
bsz_ = txt_feat.shape[0]
|
||||
bsz = bsz_ // 2
|
||||
|
||||
def guided_model_fn(x_t, ts, **kwargs):
|
||||
half = x_t[: len(x_t) // 2]
|
||||
combined = torch.cat([half, half], dim=0)
|
||||
model_out = self.model(combined, ts, **kwargs)
|
||||
eps, rest = (
|
||||
model_out[:, : int(x_t.shape[1])],
|
||||
model_out[:, int(x_t.shape[1]) :],
|
||||
)
|
||||
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
||||
half_eps = uncond_eps + cf_guidance_scales.view(-1, 1) * (
|
||||
cond_eps - uncond_eps
|
||||
)
|
||||
eps = torch.cat([half_eps, half_eps], dim=0)
|
||||
return torch.cat([eps, rest], dim=1)
|
||||
|
||||
cond = {
|
||||
"text_emb": txt_feat,
|
||||
"text_enc": txt_feat_seq,
|
||||
"mask": mask,
|
||||
"causal_mask": self.causal_mask,
|
||||
}
|
||||
sample_fn = self.get_sample_fn(timestep_respacing)
|
||||
sample = sample_fn(
|
||||
guided_model_fn,
|
||||
(bsz_, self.model.clip_dim),
|
||||
noise=None,
|
||||
device=txt_feat.device,
|
||||
clip_denoised=False,
|
||||
denoised_fn=lambda x: torch.clamp(x, -10, 10),
|
||||
model_kwargs=cond,
|
||||
)
|
||||
sample = (sample * self.clip_std) + self.clip_mean
|
||||
|
||||
return sample[:bsz]
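# --- Illustrative sketch (not part of the original file): the causal mask
# built in get_causal_mask() for text_ctx=77 spans 77 + 4 prior tokens, with
# -inf strictly above the diagonal; the sampled feature is mapped back to
# CLIP space via sample * clip_std + clip_mean.
if __name__ == "__main__":
    _seq_len = 77 + 4
    _m = torch.empty(_seq_len, _seq_len)
    _m.fill_(float("-inf"))
    _m.triu_(1)
    assert _m.shape == (_seq_len, _seq_len) and _m[0, 0].item() == 0.0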
|
ldm/modules/karlo/kakao/models/sr_256_1k.py  (new file, 10 lines)
@@ -0,0 +1,10 @@
# ------------------------------------------------------------------------------------
|
||||
# Karlo-v1.0.alpha
|
||||
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
|
||||
# ------------------------------------------------------------------------------------
|
||||
|
||||
from ldm.modules.karlo.kakao.models.sr_64_256 import SupRes64to256Progressive
|
||||
|
||||
|
||||
class SupRes256to1kProgressive(SupRes64to256Progressive):
|
||||
pass # no difference currently
|
ldm/modules/karlo/kakao/models/sr_64_256.py  (new file, 88 lines)
@@ -0,0 +1,88 @@
# ------------------------------------------------------------------------------------
|
||||
# Karlo-v1.0.alpha
|
||||
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
|
||||
# ------------------------------------------------------------------------------------
|
||||
|
||||
import copy
|
||||
import torch
|
||||
|
||||
from ldm.modules.karlo.kakao.modules.unet import SuperResUNetModel
|
||||
from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
|
||||
|
||||
|
||||
class ImprovedSupRes64to256ProgressiveModel(torch.nn.Module):
|
||||
"""
|
||||
ImprovedSR model fine-tunes the pretrained DDPM-based SR model by using adversarial and perceptual losses.
|
||||
Specifically, the low-resolution sample is iteratively recovered in 6 steps with the frozen pretrained SR model.
|
||||
In one additional final step, a separate fine-tuned model recovers high-frequency details.
|
||||
This approach greatly improves the fidelity of 256x256px images, even with a small number of reverse steps.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self._config = config
|
||||
self._diffusion_kwargs = dict(
|
||||
steps=config.diffusion.steps,
|
||||
learn_sigma=config.diffusion.learn_sigma,
|
||||
sigma_small=config.diffusion.sigma_small,
|
||||
noise_schedule=config.diffusion.noise_schedule,
|
||||
use_kl=config.diffusion.use_kl,
|
||||
predict_xstart=config.diffusion.predict_xstart,
|
||||
rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
|
||||
)
|
||||
|
||||
self.model_first_steps = SuperResUNetModel(
|
||||
in_channels=3, # auto-changed to 6 inside the model
|
||||
model_channels=config.model.hparams.channels,
|
||||
out_channels=3,
|
||||
num_res_blocks=config.model.hparams.depth,
|
||||
attention_resolutions=(), # no attention
|
||||
dropout=config.model.hparams.dropout,
|
||||
channel_mult=config.model.hparams.channels_multiple,
|
||||
resblock_updown=True,
|
||||
use_middle_attention=False,
|
||||
)
|
||||
self.model_last_step = SuperResUNetModel(
|
||||
in_channels=3, # auto-changed to 6 inside the model
|
||||
model_channels=config.model.hparams.channels,
|
||||
out_channels=3,
|
||||
num_res_blocks=config.model.hparams.depth,
|
||||
attention_resolutions=(), # no attention
|
||||
dropout=config.model.hparams.dropout,
|
||||
channel_mult=config.model.hparams.channels_multiple,
|
||||
resblock_updown=True,
|
||||
use_middle_attention=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_from_checkpoint(cls, config, ckpt_path, strict: bool = True):
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
|
||||
|
||||
model = cls(config)
|
||||
model.load_state_dict(ckpt, strict=strict)
|
||||
return model
|
||||
|
||||
def get_sample_fn(self, timestep_respacing):
|
||||
diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
|
||||
diffusion_kwargs.update(timestep_respacing=timestep_respacing)
|
||||
diffusion = create_gaussian_diffusion(**diffusion_kwargs)
|
||||
return diffusion.p_sample_loop_progressive_for_improved_sr
|
||||
|
||||
def forward(self, low_res, timestep_respacing="7", **kwargs):
|
||||
assert (
|
||||
timestep_respacing == "7"
|
||||
), "different respacing method may work, but no guaranteed"
|
||||
|
||||
sample_fn = self.get_sample_fn(timestep_respacing)
|
||||
sample_outputs = sample_fn(
|
||||
self.model_first_steps,
|
||||
self.model_last_step,
|
||||
shape=low_res.shape,
|
||||
clip_denoised=True,
|
||||
model_kwargs=dict(low_res=low_res),
|
||||
**kwargs,
|
||||
)
|
||||
for x in sample_outputs:
|
||||
sample = x["sample"]
|
||||
yield sample
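# --- Illustrative sketch (not part of the original file): with
# timestep_respacing="7" the loop above visits 7 spaced timesteps; the frozen
# SR model handles the first 6 and model_last_step only the final, least
# noisy step (see p_sample_loop_progressive_for_improved_sr).
if __name__ == "__main__":
    _indices = list(range(7))[::-1]  # [6, 5, ..., 0]
    _use_last = [len(_indices) - 1 == i for i, _ in enumerate(_indices)]
    assert _use_last == [False] * 6 + [True]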
|
ldm/modules/karlo/kakao/modules/__init__.py  (new file, 49 lines)
@@ -0,0 +1,49 @@
# ------------------------------------------------------------------------------------
|
||||
# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
|
||||
# ------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
from .diffusion import gaussian_diffusion as gd
|
||||
from .diffusion.respace import (
|
||||
SpacedDiffusion,
|
||||
space_timesteps,
|
||||
)
|
||||
|
||||
|
||||
def create_gaussian_diffusion(
|
||||
steps,
|
||||
learn_sigma,
|
||||
sigma_small,
|
||||
noise_schedule,
|
||||
use_kl,
|
||||
predict_xstart,
|
||||
rescale_learned_sigmas,
|
||||
timestep_respacing,
|
||||
):
|
||||
betas = gd.get_named_beta_schedule(noise_schedule, steps)
|
||||
if use_kl:
|
||||
loss_type = gd.LossType.RESCALED_KL
|
||||
elif rescale_learned_sigmas:
|
||||
loss_type = gd.LossType.RESCALED_MSE
|
||||
else:
|
||||
loss_type = gd.LossType.MSE
|
||||
if not timestep_respacing:
|
||||
timestep_respacing = [steps]
|
||||
|
||||
return SpacedDiffusion(
|
||||
use_timesteps=space_timesteps(steps, timestep_respacing),
|
||||
betas=betas,
|
||||
model_mean_type=(
|
||||
gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
|
||||
),
|
||||
model_var_type=(
|
||||
(
|
||||
gd.ModelVarType.FIXED_LARGE
|
||||
if not sigma_small
|
||||
else gd.ModelVarType.FIXED_SMALL
|
||||
)
|
||||
if not learn_sigma
|
||||
else gd.ModelVarType.LEARNED_RANGE
|
||||
),
|
||||
loss_type=loss_type,
|
||||
)
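# --- Usage sketch (not part of the original file; the argument values below
# are illustrative, the real ones come from the karlo yaml configs):
#
#   diffusion = create_gaussian_diffusion(
#       steps=1000, learn_sigma=True, sigma_small=False,
#       noise_schedule="squaredcos_cap_v2", use_kl=False,
#       predict_xstart=False, rescale_learned_sigmas=True,
#       timestep_respacing="",          # "" -> keep all `steps` timesteps
#   )
#
# With these flags the factory returns a SpacedDiffusion using EPSILON mean
# prediction, LEARNED_RANGE variance, and RESCALED_MSE loss.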
|
ldm/modules/karlo/kakao/modules/diffusion/gaussian_diffusion.py  (new file, 828 lines)
@@ -0,0 +1,828 @@
# ------------------------------------------------------------------------------------
|
||||
# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
|
||||
# ------------------------------------------------------------------------------------
|
||||
|
||||
import enum
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
||||
|
||||
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
|
||||
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
||||
warmup_time = int(num_diffusion_timesteps * warmup_frac)
|
||||
betas[:warmup_time] = np.linspace(
|
||||
beta_start, beta_end, warmup_time, dtype=np.float64
|
||||
)
|
||||
return betas
|
||||
|
||||
|
||||
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
|
||||
"""
|
||||
This is the deprecated API for creating beta schedules.
|
||||
See get_named_beta_schedule() for the new library of schedules.
|
||||
"""
|
||||
if beta_schedule == "quad":
|
||||
betas = (
|
||||
np.linspace(
|
||||
beta_start**0.5,
|
||||
beta_end**0.5,
|
||||
num_diffusion_timesteps,
|
||||
dtype=np.float64,
|
||||
)
|
||||
** 2
|
||||
)
|
||||
elif beta_schedule == "linear":
|
||||
betas = np.linspace(
|
||||
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
|
||||
)
|
||||
elif beta_schedule == "warmup10":
|
||||
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
|
||||
elif beta_schedule == "warmup50":
|
||||
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
|
||||
elif beta_schedule == "const":
|
||||
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
||||
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
|
||||
betas = 1.0 / np.linspace(
|
||||
num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(beta_schedule)
|
||||
assert betas.shape == (num_diffusion_timesteps,)
|
||||
return betas
|
||||
|
||||
|
||||
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
||||
"""
|
||||
Get a pre-defined beta schedule for the given name.
|
||||
The beta schedule library consists of beta schedules which remain similar
|
||||
in the limit of num_diffusion_timesteps.
|
||||
Beta schedules may be added, but should not be removed or changed once
|
||||
they are committed to maintain backwards compatibility.
|
||||
"""
|
||||
if schedule_name == "linear":
|
||||
# Linear schedule from Ho et al, extended to work for any number of
|
||||
# diffusion steps.
|
||||
scale = 1000 / num_diffusion_timesteps
|
||||
return get_beta_schedule(
|
||||
"linear",
|
||||
beta_start=scale * 0.0001,
|
||||
beta_end=scale * 0.02,
|
||||
num_diffusion_timesteps=num_diffusion_timesteps,
|
||||
)
|
||||
elif schedule_name == "squaredcos_cap_v2":
|
||||
return betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
||||
:param num_diffusion_timesteps: the number of betas to produce.
|
||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that
|
||||
part of the diffusion process.
|
||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
"""
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return np.array(betas)
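# --- Illustrative sketch (not part of the original file): the
# "squaredcos_cap_v2" schedule used throughout Karlo is betas_for_alpha_bar
# with the cosine alpha_bar from get_named_beta_schedule; every beta is
# capped at max_beta=0.999.
if __name__ == "__main__":
    _betas = betas_for_alpha_bar(
        1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
    )
    assert _betas.shape == (1000,) and float(_betas.max()) <= 0.999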
|
||||
|
||||
|
||||
class ModelMeanType(enum.Enum):
|
||||
"""
|
||||
Which type of output the model predicts.
|
||||
"""
|
||||
|
||||
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
|
||||
START_X = enum.auto() # the model predicts x_0
|
||||
EPSILON = enum.auto() # the model predicts epsilon
|
||||
|
||||
|
||||
class ModelVarType(enum.Enum):
|
||||
"""
|
||||
What is used as the model's output variance.
|
||||
The LEARNED_RANGE option has been added to allow the model to predict
|
||||
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
|
||||
"""
|
||||
|
||||
LEARNED = enum.auto()
|
||||
FIXED_SMALL = enum.auto()
|
||||
FIXED_LARGE = enum.auto()
|
||||
LEARNED_RANGE = enum.auto()
|
||||
|
||||
|
||||
class LossType(enum.Enum):
|
||||
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
|
||||
RESCALED_MSE = (
|
||||
enum.auto()
|
||||
) # use raw MSE loss (with RESCALED_KL when learning variances)
|
||||
KL = enum.auto() # use the variational lower-bound
|
||||
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
|
||||
|
||||
def is_vb(self):
|
||||
return self == LossType.KL or self == LossType.RESCALED_KL
|
||||
|
||||
|
||||
class GaussianDiffusion(th.nn.Module):
|
||||
"""
|
||||
Utilities for training and sampling diffusion models.
|
||||
Originally ported from this codebase:
|
||||
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
|
||||
:param betas: a 1-D numpy array of betas for each diffusion timestep,
|
||||
starting at T and going to 1.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
betas,
|
||||
model_mean_type,
|
||||
model_var_type,
|
||||
loss_type,
|
||||
):
|
||||
super(GaussianDiffusion, self).__init__()
|
||||
self.model_mean_type = model_mean_type
|
||||
self.model_var_type = model_var_type
|
||||
self.loss_type = loss_type
|
||||
|
||||
# Use float64 for accuracy.
|
||||
betas = np.array(betas, dtype=np.float64)
|
||||
assert len(betas.shape) == 1, "betas must be 1-D"
|
||||
assert (betas > 0).all() and (betas <= 1).all()
|
||||
|
||||
self.num_timesteps = int(betas.shape[0])
|
||||
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
|
||||
alphas_cumprod_next = np.append(alphas_cumprod[1:], 0.0)
|
||||
assert alphas_cumprod_prev.shape == (self.num_timesteps,)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
|
||||
sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)
|
||||
log_one_minus_alphas_cumprod = np.log(1.0 - alphas_cumprod)
|
||||
sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod)
|
||||
sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1)
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
posterior_variance = (
|
||||
betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
||||
)
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
posterior_log_variance_clipped = np.log(
|
||||
np.append(posterior_variance[1], posterior_variance[1:])
|
||||
)
|
||||
posterior_mean_coef1 = (
|
||||
betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
||||
)
|
||||
posterior_mean_coef2 = (
|
||||
(1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
|
||||
)
|
||||
|
||||
self.register_buffer("betas", th.from_numpy(betas), persistent=False)
|
||||
self.register_buffer(
|
||||
"alphas_cumprod", th.from_numpy(alphas_cumprod), persistent=False
|
||||
)
|
||||
self.register_buffer(
|
||||
"alphas_cumprod_prev", th.from_numpy(alphas_cumprod_prev), persistent=False
|
||||
)
|
||||
self.register_buffer(
|
||||
"alphas_cumprod_next", th.from_numpy(alphas_cumprod_next), persistent=False
|
||||
)
|
||||
|
||||
self.register_buffer(
|
||||
"sqrt_alphas_cumprod", th.from_numpy(sqrt_alphas_cumprod), persistent=False
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_one_minus_alphas_cumprod",
|
||||
th.from_numpy(sqrt_one_minus_alphas_cumprod),
|
||||
persistent=False,
|
||||
)
|
||||
self.register_buffer(
|
||||
"log_one_minus_alphas_cumprod",
|
||||
th.from_numpy(log_one_minus_alphas_cumprod),
|
||||
persistent=False,
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_recip_alphas_cumprod",
|
||||
th.from_numpy(sqrt_recip_alphas_cumprod),
|
||||
persistent=False,
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_recipm1_alphas_cumprod",
|
||||
th.from_numpy(sqrt_recipm1_alphas_cumprod),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
self.register_buffer(
|
||||
"posterior_variance", th.from_numpy(posterior_variance), persistent=False
|
||||
)
|
||||
self.register_buffer(
|
||||
"posterior_log_variance_clipped",
|
||||
th.from_numpy(posterior_log_variance_clipped),
|
||||
persistent=False,
|
||||
)
|
||||
self.register_buffer(
|
||||
"posterior_mean_coef1",
|
||||
th.from_numpy(posterior_mean_coef1),
|
||||
persistent=False,
|
||||
)
|
||||
self.register_buffer(
|
||||
"posterior_mean_coef2",
|
||||
th.from_numpy(posterior_mean_coef2),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def q_mean_variance(self, x_start, t):
|
||||
"""
|
||||
Get the distribution q(x_t | x_0).
|
||||
:param x_start: the [N x C x ...] tensor of noiseless inputs.
|
||||
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
||||
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
|
||||
"""
|
||||
mean = (
|
||||
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
)
|
||||
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
||||
log_variance = _extract_into_tensor(
|
||||
self.log_one_minus_alphas_cumprod, t, x_start.shape
|
||||
)
|
||||
return mean, variance, log_variance
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
"""
|
||||
Diffuse the data for a given number of diffusion steps.
|
||||
In other words, sample from q(x_t | x_0).
|
||||
:param x_start: the initial data batch.
|
||||
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
||||
:param noise: if specified, the split-out normal noise.
|
||||
:return: A noisy version of x_start.
|
||||
"""
|
||||
if noise is None:
|
||||
noise = th.randn_like(x_start)
|
||||
assert noise.shape == x_start.shape
|
||||
return (
|
||||
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
|
||||
* noise
|
||||
)
|
||||
|
||||
def q_posterior_mean_variance(self, x_start, x_t, t):
|
||||
"""
|
||||
Compute the mean and variance of the diffusion posterior:
|
||||
q(x_{t-1} | x_t, x_0)
|
||||
"""
|
||||
assert x_start.shape == x_t.shape
|
||||
posterior_mean = (
|
||||
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
||||
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
||||
)
|
||||
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
||||
posterior_log_variance_clipped = _extract_into_tensor(
|
||||
self.posterior_log_variance_clipped, t, x_t.shape
|
||||
)
|
||||
assert (
|
||||
posterior_mean.shape[0]
|
||||
== posterior_variance.shape[0]
|
||||
== posterior_log_variance_clipped.shape[0]
|
||||
== x_start.shape[0]
|
||||
)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(
|
||||
self,
|
||||
model,
|
||||
x,
|
||||
t,
|
||||
clip_denoised=True,
|
||||
denoised_fn=None,
|
||||
model_kwargs=None,
|
||||
**ignore_kwargs,
|
||||
):
|
||||
"""
|
||||
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
|
||||
the initial x, x_0.
|
||||
:param model: the model, which takes a signal and a batch of timesteps
|
||||
as input.
|
||||
:param x: the [N x C x ...] tensor at time t.
|
||||
:param t: a 1-D Tensor of timesteps.
|
||||
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
|
||||
:param denoised_fn: if not None, a function which applies to the
|
||||
x_start prediction before it is used to sample. Applies before
|
||||
clip_denoised.
|
||||
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
||||
pass to the model. This can be used for conditioning.
|
||||
:return: a dict with the following keys:
|
||||
- 'mean': the model mean output.
|
||||
- 'variance': the model variance output.
|
||||
- 'log_variance': the log of 'variance'.
|
||||
- 'pred_xstart': the prediction for x_0.
|
||||
"""
|
||||
if model_kwargs is None:
|
||||
model_kwargs = {}
|
||||
|
||||
B, C = x.shape[:2]
|
||||
assert t.shape == (B,)
|
||||
model_output = model(x, t, **model_kwargs)
|
||||
if isinstance(model_output, tuple):
|
||||
model_output, extra = model_output
|
||||
else:
|
||||
extra = None
|
||||
|
||||
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
|
||||
assert model_output.shape == (B, C * 2, *x.shape[2:])
|
||||
model_output, model_var_values = th.split(model_output, C, dim=1)
|
||||
if self.model_var_type == ModelVarType.LEARNED:
|
||||
model_log_variance = model_var_values
|
||||
model_variance = th.exp(model_log_variance)
|
||||
else:
|
||||
min_log = _extract_into_tensor(
|
||||
self.posterior_log_variance_clipped, t, x.shape
|
||||
)
|
||||
max_log = _extract_into_tensor(th.log(self.betas), t, x.shape)
|
||||
# The model_var_values is [-1, 1] for [min_var, max_var].
|
||||
frac = (model_var_values + 1) / 2
|
||||
model_log_variance = frac * max_log + (1 - frac) * min_log
|
||||
model_variance = th.exp(model_log_variance)
|
||||
else:
|
||||
model_variance, model_log_variance = {
|
||||
# for fixedlarge, we set the initial (log-)variance like so
|
||||
# to get a better decoder log likelihood.
|
||||
ModelVarType.FIXED_LARGE: (
|
||||
th.cat([self.posterior_variance[1][None], self.betas[1:]]),
|
||||
th.log(th.cat([self.posterior_variance[1][None], self.betas[1:]])),
|
||||
),
|
||||
ModelVarType.FIXED_SMALL: (
|
||||
self.posterior_variance,
|
||||
self.posterior_log_variance_clipped,
|
||||
),
|
||||
}[self.model_var_type]
|
||||
model_variance = _extract_into_tensor(model_variance, t, x.shape)
|
||||
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
|
||||
|
||||
def process_xstart(x):
|
||||
if denoised_fn is not None:
|
||||
x = denoised_fn(x)
|
||||
if clip_denoised:
|
||||
return x.clamp(-1, 1)
|
||||
return x
|
||||
|
||||
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
|
||||
pred_xstart = process_xstart(
|
||||
self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
|
||||
)
|
||||
model_mean = model_output
|
||||
elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
|
||||
if self.model_mean_type == ModelMeanType.START_X:
|
||||
pred_xstart = process_xstart(model_output)
|
||||
else:
|
||||
pred_xstart = process_xstart(
|
||||
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
|
||||
)
|
||||
model_mean, _, _ = self.q_posterior_mean_variance(
|
||||
x_start=pred_xstart, x_t=x, t=t
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(self.model_mean_type)
|
||||
|
||||
assert (
|
||||
model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
|
||||
)
|
||||
return {
|
||||
"mean": model_mean,
|
||||
"variance": model_variance,
|
||||
"log_variance": model_log_variance,
|
||||
"pred_xstart": pred_xstart,
|
||||
}
|
||||
|
||||
def _predict_xstart_from_eps(self, x_t, t, eps):
|
||||
assert x_t.shape == eps.shape
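# Invert q_sample: x_0 = sqrt(1 / alphas_cumprod[t]) * x_t - sqrt(1 / alphas_cumprod[t] - 1) * eps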
|
||||
return (
|
||||
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
||||
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
|
||||
)
|
||||
|
||||
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
||||
return (
|
||||
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
||||
- pred_xstart
|
||||
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
||||
|
||||
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
||||
"""
|
||||
Compute the mean for the previous step, given a function cond_fn that
|
||||
computes the gradient of a conditional log probability with respect to
|
||||
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
|
||||
condition on y.
|
||||
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
|
||||
"""
|
||||
gradient = cond_fn(x, t, **model_kwargs)
|
||||
new_mean = (
|
||||
p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
|
||||
)
|
||||
return new_mean
|
||||
|
||||
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
||||
"""
|
||||
Compute what the p_mean_variance output would have been, should the
|
||||
model's score function be conditioned by cond_fn.
|
||||
See condition_mean() for details on cond_fn.
|
||||
Unlike condition_mean(), this instead uses the conditioning strategy
|
||||
from Song et al (2020).
|
||||
"""
|
||||
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
||||
|
||||
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
|
||||
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
|
||||
|
||||
out = p_mean_var.copy()
|
||||
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
|
||||
out["mean"], _, _ = self.q_posterior_mean_variance(
|
||||
x_start=out["pred_xstart"], x_t=x, t=t
|
||||
)
|
||||
return out
|
||||
|
||||
def p_sample(
|
||||
self,
|
||||
model,
|
||||
x,
|
||||
t,
|
||||
clip_denoised=True,
|
||||
denoised_fn=None,
|
||||
cond_fn=None,
|
||||
model_kwargs=None,
|
||||
):
|
||||
"""
|
||||
Sample x_{t-1} from the model at the given timestep.
|
||||
:param model: the model to sample from.
|
||||
:param x: the current tensor at x_{t-1}.
|
||||
:param t: the value of t, starting at 0 for the first diffusion step.
|
||||
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
|
||||
:param denoised_fn: if not None, a function which applies to the
|
||||
x_start prediction before it is used to sample.
|
||||
:param cond_fn: if not None, this is a gradient function that acts
|
||||
similarly to the model.
|
||||
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
||||
pass to the model. This can be used for conditioning.
|
||||
:return: a dict containing the following keys:
|
||||
- 'sample': a random sample from the model.
|
||||
- 'pred_xstart': a prediction of x_0.
|
||||
"""
|
||||
out = self.p_mean_variance(
|
||||
model,
|
||||
x,
|
||||
t,
|
||||
clip_denoised=clip_denoised,
|
||||
denoised_fn=denoised_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
noise = th.randn_like(x)
|
||||
nonzero_mask = (
|
||||
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
||||
) # no noise when t == 0
|
||||
if cond_fn is not None:
|
||||
out["mean"] = self.condition_mean(
|
||||
cond_fn, out, x, t, model_kwargs=model_kwargs
|
||||
)
|
||||
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
||||
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
||||
|
||||
def p_sample_loop(
|
||||
self,
|
||||
model,
|
||||
shape,
|
||||
noise=None,
|
||||
clip_denoised=True,
|
||||
denoised_fn=None,
|
||||
cond_fn=None,
|
||||
model_kwargs=None,
|
||||
device=None,
|
||||
progress=False,
|
||||
):
|
||||
"""
|
||||
Generate samples from the model.
|
||||
:param model: the model module.
|
||||
:param shape: the shape of the samples, (N, C, H, W).
|
||||
:param noise: if specified, the noise from the encoder to sample.
|
||||
Should be of the same shape as `shape`.
|
||||
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
|
||||
:param denoised_fn: if not None, a function which applies to the
|
||||
x_start prediction before it is used to sample.
|
||||
:param cond_fn: if not None, this is a gradient function that acts
|
||||
similarly to the model.
|
||||
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
||||
pass to the model. This can be used for conditioning.
|
||||
:param device: if specified, the device to create the samples on.
|
||||
If not specified, use a model parameter's device.
|
||||
:param progress: if True, show a tqdm progress bar.
|
||||
:return: a non-differentiable batch of samples.
|
||||
"""
|
||||
final = None
|
||||
for sample in self.p_sample_loop_progressive(
|
||||
model,
|
||||
shape,
|
||||
noise=noise,
|
||||
clip_denoised=clip_denoised,
|
||||
denoised_fn=denoised_fn,
|
||||
cond_fn=cond_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
device=device,
|
||||
progress=progress,
|
||||
):
|
||||
final = sample
|
||||
return final["sample"]
|
||||
|
||||
def p_sample_loop_progressive(
|
||||
self,
|
||||
model,
|
||||
shape,
|
||||
noise=None,
|
||||
clip_denoised=True,
|
||||
denoised_fn=None,
|
||||
cond_fn=None,
|
||||
model_kwargs=None,
|
||||
device=None,
|
||||
progress=False,
|
||||
):
|
||||
"""
|
||||
Generate samples from the model and yield intermediate samples from
|
||||
each timestep of diffusion.
|
||||
Arguments are the same as p_sample_loop().
|
||||
Returns a generator over dicts, where each dict is the return value of
|
||||
p_sample().
|
||||
"""
|
||||
if device is None:
|
||||
device = next(model.parameters()).device
|
||||
assert isinstance(shape, (tuple, list))
|
||||
if noise is not None:
|
||||
img = noise
|
||||
else:
|
||||
img = th.randn(*shape, device=device)
|
||||
indices = list(range(self.num_timesteps))[::-1]
|
||||
|
||||
if progress:
|
||||
# Lazy import so that we don't depend on tqdm.
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
indices = tqdm(indices)
|
||||
|
||||
for idx, i in enumerate(indices):
|
||||
t = th.tensor([i] * shape[0], device=device)
|
||||
with th.no_grad():
|
||||
out = self.p_sample(
|
||||
model,
|
||||
img,
|
||||
t,
|
||||
clip_denoised=clip_denoised,
|
||||
denoised_fn=denoised_fn,
|
||||
cond_fn=cond_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
yield out
|
||||
img = out["sample"]
|
||||
|
||||
def p_sample_loop_progressive_for_improved_sr(
|
||||
self,
|
||||
model,
|
||||
model_aux,
|
||||
shape,
|
||||
noise=None,
|
||||
clip_denoised=True,
|
||||
denoised_fn=None,
|
||||
cond_fn=None,
|
||||
model_kwargs=None,
|
||||
device=None,
|
||||
progress=False,
|
||||
):
|
||||
"""
|
||||
Modified version of p_sample_loop_progressive for sampling from the improved SR model.
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
device = next(model.parameters()).device
|
||||
assert isinstance(shape, (tuple, list))
|
||||
if noise is not None:
|
||||
img = noise
|
||||
else:
|
||||
img = th.randn(*shape, device=device)
|
||||
indices = list(range(self.num_timesteps))[::-1]
|
||||
|
||||
if progress:
|
||||
# Lazy import so that we don't depend on tqdm.
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
indices = tqdm(indices)
|
||||
|
||||
for idx, i in enumerate(indices):
|
||||
t = th.tensor([i] * shape[0], device=device)
|
||||
with th.no_grad():
|
||||
out = self.p_sample(
|
||||
model_aux if len(indices) - 1 == idx else model,
|
||||
img,
|
||||
t,
|
||||
clip_denoised=clip_denoised,
|
||||
denoised_fn=denoised_fn,
|
||||
cond_fn=cond_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
yield out
|
||||
img = out["sample"]
|
||||
|
||||
def ddim_sample(
|
||||
self,
|
||||
model,
|
||||
x,
|
||||
t,
|
||||
clip_denoised=True,
|
||||
denoised_fn=None,
|
||||
cond_fn=None,
|
||||
model_kwargs=None,
|
||||
eta=0.0,
|
||||
):
|
||||
"""
|
||||
Sample x_{t-1} from the model using DDIM.
|
||||
Same usage as p_sample().
|
||||
"""
|
||||
out = self.p_mean_variance(
|
||||
model,
|
||||
x,
|
||||
t,
|
||||
clip_denoised=clip_denoised,
|
||||
denoised_fn=denoised_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
if cond_fn is not None:
|
||||
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
||||
|
||||
# Usually our model outputs epsilon, but we re-derive it
|
||||
# in case we used x_start or x_prev prediction.
|
||||
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
|
||||
|
||||
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
||||
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
|
||||
sigma = (
|
||||
eta
|
||||
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
|
||||
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
|
||||
)
|
||||
# Equation 12.
|
||||
noise = th.randn_like(x)
|
||||
mean_pred = (
|
||||
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
|
||||
+ th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
|
||||
)
|
||||
nonzero_mask = (
|
||||
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
||||
) # no noise when t == 0
|
||||
sample = mean_pred + nonzero_mask * sigma * noise
|
||||
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
||||
|
||||
def ddim_reverse_sample(
|
||||
self,
|
||||
model,
|
||||
x,
|
||||
t,
|
||||
clip_denoised=True,
|
||||
denoised_fn=None,
|
||||
cond_fn=None,
|
||||
model_kwargs=None,
|
||||
eta=0.0,
|
||||
):
|
||||
"""
|
||||
Sample x_{t+1} from the model using DDIM reverse ODE.
|
||||
"""
|
||||
assert eta == 0.0, "Reverse ODE only for deterministic path"
|
||||
out = self.p_mean_variance(
|
||||
model,
|
||||
x,
|
||||
t,
|
||||
clip_denoised=clip_denoised,
|
||||
denoised_fn=denoised_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
if cond_fn is not None:
|
||||
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
||||
# Usually our model outputs epsilon, but we re-derive it
|
||||
# in case we used x_start or x_prev prediction.
|
||||
eps = (
|
||||
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
|
||||
- out["pred_xstart"]
|
||||
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
|
||||
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
|
||||
|
||||
# Equation 12. reversed
|
||||
mean_pred = (
|
||||
out["pred_xstart"] * th.sqrt(alpha_bar_next)
|
||||
+ th.sqrt(1 - alpha_bar_next) * eps
|
||||
)
|
||||
|
||||
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
|
||||
|
||||
def ddim_sample_loop(
|
||||
self,
|
||||
model,
|
||||
shape,
|
||||
noise=None,
|
||||
clip_denoised=True,
|
||||
denoised_fn=None,
|
||||
cond_fn=None,
|
||||
model_kwargs=None,
|
||||
device=None,
|
||||
progress=False,
|
||||
eta=0.0,
|
||||
):
|
||||
"""
|
||||
Generate samples from the model using DDIM.
|
||||
Same usage as p_sample_loop().
|
||||
"""
|
||||
final = None
|
||||
for sample in self.ddim_sample_loop_progressive(
|
||||
model,
|
||||
shape,
|
||||
noise=noise,
|
||||
clip_denoised=clip_denoised,
|
||||
denoised_fn=denoised_fn,
|
||||
cond_fn=cond_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
device=device,
|
||||
progress=progress,
|
||||
eta=eta,
|
||||
):
|
||||
final = sample
|
||||
return final["sample"]
|
||||
|
||||
def ddim_sample_loop_progressive(
|
||||
self,
|
||||
model,
|
||||
shape,
|
||||
noise=None,
|
||||
clip_denoised=True,
|
||||
denoised_fn=None,
|
||||
cond_fn=None,
|
||||
model_kwargs=None,
|
||||
device=None,
|
||||
progress=False,
|
||||
eta=0.0,
|
||||
):
|
||||
"""
|
||||
Use DDIM to sample from the model and yield intermediate samples from
|
||||
each timestep of DDIM.
|
||||
Same usage as p_sample_loop_progressive().
|
||||
"""
|
||||
if device is None:
|
||||
device = next(model.parameters()).device
|
||||
assert isinstance(shape, (tuple, list))
|
||||
if noise is not None:
|
||||
img = noise
|
||||
else:
|
||||
img = th.randn(*shape, device=device)
|
||||
indices = list(range(self.num_timesteps))[::-1]
|
||||
|
||||
if progress:
|
||||
# Lazy import so that we don't depend on tqdm.
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
indices = tqdm(indices)
|
||||
|
||||
for i in indices:
|
||||
t = th.tensor([i] * shape[0], device=device)
|
||||
with th.no_grad():
|
||||
out = self.ddim_sample(
|
||||
model,
|
||||
img,
|
||||
t,
|
||||
clip_denoised=clip_denoised,
|
||||
denoised_fn=denoised_fn,
|
||||
cond_fn=cond_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
eta=eta,
|
||||
)
|
||||
yield out
|
||||
img = out["sample"]
|
||||
|
||||
|
||||
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
||||
"""
|
||||
Extract values from a 1-D numpy array for a batch of indices.
|
||||
:param arr: the 1-D tensor (a registered buffer) to index.
|
||||
:param timesteps: a tensor of indices into the array to extract.
|
||||
:param broadcast_shape: a larger shape of K dimensions with the batch
|
||||
dimension equal to the length of timesteps.
|
||||
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
||||
"""
|
||||
res = arr.to(device=timesteps.device)[timesteps].float()
|
||||
while len(res.shape) < len(broadcast_shape):
|
||||
res = res[..., None]
|
||||
return res + th.zeros(broadcast_shape, device=timesteps.device)
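# --- Illustrative sketch (not part of the original file): _extract_into_tensor
# gathers one scalar per batch element and broadcasts it over the remaining
# dimensions.
if __name__ == "__main__":
    _arr = th.linspace(0.0, 1.0, 10)
    _t = th.tensor([0, 9])
    _out = _extract_into_tensor(_arr, _t, (2, 3, 4, 4))
    assert _out.shape == (2, 3, 4, 4) and float(_out[1, 0, 0, 0]) == 1.0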
|
ldm/modules/karlo/kakao/modules/diffusion/respace.py  (new file, 112 lines)
@@ -0,0 +1,112 @@
# ------------------------------------------------------------------------------------
|
||||
# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
|
||||
# ------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
import torch as th
|
||||
|
||||
from .gaussian_diffusion import GaussianDiffusion
|
||||
|
||||
|
||||
def space_timesteps(num_timesteps, section_counts):
|
||||
"""
|
||||
Create a list of timesteps to use from an original diffusion process,
|
||||
given the number of timesteps we want to take from equally-sized portions
|
||||
of the original process.
|
||||
|
||||
For example, if there are 300 timesteps and the section counts are [10,15,20]
|
||||
then the first 100 timesteps are strided to be 10 timesteps, the second 100
|
||||
are strided to be 15 timesteps, and the final 100 are strided to be 20.
|
||||
|
||||
:param num_timesteps: the number of diffusion steps in the original
|
||||
process to divide up.
|
||||
:param section_counts: either a list of numbers, or a string containing
|
||||
comma-separated numbers, indicating the step count
|
||||
per section. As a special case, use "ddimN" where N
|
||||
is a number of steps to use the striding from the
|
||||
DDIM paper.
|
||||
:return: a set of diffusion steps from the original process to use.
|
||||
"""
|
||||
if isinstance(section_counts, str):
|
||||
if section_counts.startswith("ddim"):
|
||||
desired_count = int(section_counts[len("ddim") :])
|
||||
for i in range(1, num_timesteps):
|
||||
if len(range(0, num_timesteps, i)) == desired_count:
|
||||
return set(range(0, num_timesteps, i))
|
||||
raise ValueError(
|
||||
f"cannot create exactly {num_timesteps} steps with an integer stride"
|
||||
)
|
||||
elif section_counts == "fast27":
|
||||
steps = space_timesteps(num_timesteps, "10,10,3,2,2")
|
||||
# Help reduce DDIM artifacts from noisiest timesteps.
|
||||
steps.remove(num_timesteps - 1)
|
||||
steps.add(num_timesteps - 3)
|
||||
return steps
|
||||
section_counts = [int(x) for x in section_counts.split(",")]
|
||||
size_per = num_timesteps // len(section_counts)
|
||||
extra = num_timesteps % len(section_counts)
|
||||
start_idx = 0
|
||||
all_steps = []
|
||||
for i, section_count in enumerate(section_counts):
|
||||
size = size_per + (1 if i < extra else 0)
|
||||
if size < section_count:
|
||||
raise ValueError(
|
||||
f"cannot divide section of {size} steps into {section_count}"
|
||||
)
|
||||
if section_count <= 1:
|
||||
frac_stride = 1
|
||||
else:
|
||||
frac_stride = (size - 1) / (section_count - 1)
|
||||
cur_idx = 0.0
|
||||
taken_steps = []
|
||||
for _ in range(section_count):
|
||||
taken_steps.append(start_idx + round(cur_idx))
|
||||
cur_idx += frac_stride
|
||||
all_steps += taken_steps
|
||||
start_idx += size
|
||||
return set(all_steps)
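# --- Usage sketch (not part of the original file):
#   space_timesteps(300, [10, 15, 20])  -> 45 steps, 10/15/20 taken per third
#   space_timesteps(1000, "ddim25")     -> 25 evenly strided DDIM steps
#   space_timesteps(1000, "7")          -> the 7-step respacing used by the SR model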
|
||||
|
||||
|
||||
class SpacedDiffusion(GaussianDiffusion):
|
||||
"""
|
||||
A diffusion process which can skip steps in a base diffusion process.
|
||||
|
||||
:param use_timesteps: a collection (sequence or set) of timesteps from the
|
||||
original diffusion process to retain.
|
||||
:param kwargs: the kwargs to create the base diffusion process.
|
||||
"""
|
||||
|
||||
def __init__(self, use_timesteps, **kwargs):
|
||||
self.use_timesteps = set(use_timesteps)
|
||||
self.original_num_steps = len(kwargs["betas"])
|
||||
|
||||
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
|
||||
last_alpha_cumprod = 1.0
|
||||
new_betas = []
|
||||
timestep_map = []
|
||||
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
|
||||
if i in self.use_timesteps:
|
||||
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
|
||||
last_alpha_cumprod = alpha_cumprod
|
||||
timestep_map.append(i)
|
||||
kwargs["betas"] = th.tensor(new_betas).numpy()
|
||||
super().__init__(**kwargs)
|
||||
self.register_buffer("timestep_map", th.tensor(timestep_map), persistent=False)
|
||||
|
||||
def p_mean_variance(self, model, *args, **kwargs):
|
||||
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
|
||||
|
||||
def condition_mean(self, cond_fn, *args, **kwargs):
|
||||
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
|
||||
|
||||
def condition_score(self, cond_fn, *args, **kwargs):
|
||||
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
|
||||
|
||||
def _wrap_model(self, model):
|
||||
def wrapped(x, ts, **kwargs):
|
||||
ts_cpu = ts.detach().to("cpu")
|
||||
return model(
|
||||
x, self.timestep_map[ts_cpu].to(device=ts.device, dtype=ts.dtype), **kwargs
|
||||
)
|
||||
|
||||
return wrapped
|
ldm/modules/karlo/kakao/modules/nn.py  (new file, 114 lines)
@@ -0,0 +1,114 @@
# ------------------------------------------------------------------------------------
|
||||
# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
|
||||
# ------------------------------------------------------------------------------------
|
||||
|
||||
import math
|
||||
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def __init__(self, num_groups, num_channels, swish, eps=1e-5):
|
||||
super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
|
||||
self.swish = swish
|
||||
|
||||
def forward(self, x):
|
||||
y = super().forward(x.float()).to(x.dtype)
|
||||
if self.swish == 1.0:
|
||||
y = F.silu(y)
|
||||
elif self.swish:
|
||||
y = y * F.sigmoid(y * float(self.swish))
|
||||
return y
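# --- Illustrative sketch (not part of the original file): GroupNorm32
# normalizes in float32 and optionally fuses a SiLU/Swish activation.
if __name__ == "__main__":
    _gn = GroupNorm32(num_groups=32, num_channels=64, swish=1.0)
    _y = _gn(th.randn(2, 64, 8, 8))
    assert _y.shape == (2, 64, 8, 8)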
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
|
||||
"""
|
||||
Create a linear module.
|
||||
"""
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D average pooling module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.AvgPool1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.AvgPool2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.AvgPool3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def scale_module(module, scale):
|
||||
"""
|
||||
Scale the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().mul_(scale)
|
||||
return module
|
||||
|
||||
|
||||
def normalization(channels, swish=0.0):
|
||||
"""
|
||||
Make a standard normalization layer, with an optional swish activation.
|
||||
|
||||
:param channels: number of input channels.
|
||||
:return: an nn.Module for normalization.
|
||||
"""
|
||||
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
|
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = th.exp(
|
||||
-math.log(max_period)
|
||||
* th.arange(start=0, end=half, dtype=th.float32, device=timesteps.device)
|
||||
/ half
|
||||
)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
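# --- Illustrative sketch (not part of the original file): embeddings for a
# batch of 4 timesteps with dim=320 have shape (4, 320), half cosines and
# half sines.
if __name__ == "__main__":
    _emb = timestep_embedding(th.tensor([0, 10, 100, 999]), 320)
    assert _emb.shape == (4, 320)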
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
ldm/modules/karlo/kakao/modules/resample.py  (new file, 68 lines)
@@ -0,0 +1,68 @@
# ------------------------------------------------------------------------------------
|
||||
# Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion)
|
||||
# ------------------------------------------------------------------------------------
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
import torch as th
|
||||
|
||||
|
||||
def create_named_schedule_sampler(name, diffusion):
|
||||
"""
|
||||
Create a ScheduleSampler from a library of pre-defined samplers.
|
||||
|
||||
:param name: the name of the sampler.
|
||||
:param diffusion: the diffusion object to sample for.
|
||||
"""
|
||||
if name == "uniform":
|
||||
return UniformSampler(diffusion)
|
||||
else:
|
||||
raise NotImplementedError(f"unknown schedule sampler: {name}")
|
||||
|
||||
|
||||
class ScheduleSampler(th.nn.Module):
|
||||
"""
|
||||
A distribution over timesteps in the diffusion process, intended to reduce
|
||||
variance of the objective.
|
||||
|
||||
By default, samplers perform unbiased importance sampling, in which the
|
||||
objective's mean is unchanged.
|
||||
However, subclasses may override sample() to change how the resampled
|
||||
terms are reweighted, allowing for actual changes in the objective.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def weights(self):
|
||||
"""
|
||||
Get a numpy array of weights, one per diffusion step.
|
||||
|
||||
The weights needn't be normalized, but must be positive.
|
||||
"""
|
||||
|
||||
def sample(self, batch_size, device):
|
||||
"""
|
||||
Importance-sample timesteps for a batch.
|
||||
|
||||
:param batch_size: the number of timesteps.
|
||||
:param device: the torch device to save to.
|
||||
:return: a tuple (timesteps, weights):
|
||||
- timesteps: a tensor of timestep indices.
|
||||
- weights: a tensor of weights to scale the resulting losses.
|
||||
"""
|
||||
w = self.weights()
|
||||
p = w / th.sum(w)
|
||||
indices = p.multinomial(batch_size, replacement=True)
|
||||
weights = 1 / (len(p) * p[indices])
|
||||
return indices, weights
|
||||
|
||||
|
||||
class UniformSampler(ScheduleSampler):
|
||||
def __init__(self, diffusion):
|
||||
super(UniformSampler, self).__init__()
|
||||
self.diffusion = diffusion
|
||||
self.register_buffer(
|
||||
"_weights", th.ones([diffusion.num_timesteps]), persistent=False
|
||||
)
|
||||
|
||||
def weights(self):
|
||||
return self._weights
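A minimal sketch of how a schedule sampler is typically used in a training step. The `diffusion` object is assumed to expose `num_timesteps`; the loss call is hypothetical and only illustrates how the returned weights would be applied.

sampler = create_named_schedule_sampler("uniform", diffusion)
t, weights = sampler.sample(batch_size=4, device=th.device("cpu"))
# t: 4 timestep indices drawn uniformly; weights are all 1.0 for the uniform sampler
# losses = diffusion.training_losses(model, x_start, t)   # hypothetical API
# loss = (losses["loss"] * weights).mean()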
|
792 ldm/modules/karlo/kakao/modules/unet.py (Normal file)
@@ -0,0 +1,792 @@
# ------------------------------------------------------------------------------------
|
||||
# Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion)
|
||||
# ------------------------------------------------------------------------------------
|
||||
|
||||
import math
|
||||
from abc import abstractmethod
|
||||
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .nn import (
|
||||
avg_pool_nd,
|
||||
conv_nd,
|
||||
linear,
|
||||
normalization,
|
||||
timestep_embedding,
|
||||
zero_module,
|
||||
)
|
||||
from .xf import LayerNorm
|
||||
|
||||
|
||||
class TimestepBlock(nn.Module):
|
||||
"""
|
||||
Any module where forward() takes timestep embeddings as a second argument.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
Apply the module to `x` given `emb` timestep embeddings.
|
||||
"""
|
||||
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
"""
|
||||
A sequential module that passes timestep embeddings to the children that
|
||||
support it as an extra input.
|
||||
"""
|
||||
|
||||
def forward(self, x, emb, encoder_out=None, mask=None):
|
||||
for layer in self:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, AttentionBlock):
|
||||
x = layer(x, encoder_out, mask=mask)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
if use_conv:
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.dims == 3:
|
||||
x = F.interpolate(
|
||||
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
||||
)
|
||||
else:
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer with an optional convolution.
|
||||
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
downsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
stride = 2 if dims != 3 else (1, 2, 2)
|
||||
if use_conv:
|
||||
self.op = conv_nd(
|
||||
dims, self.channels, self.out_channels, 3, stride=stride, padding=1
|
||||
)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class ResBlock(TimestepBlock):
|
||||
"""
|
||||
A residual block that can optionally change the number of channels.
|
||||
|
||||
:param channels: the number of input channels.
|
||||
:param emb_channels: the number of timestep embedding channels.
|
||||
:param dropout: the rate of dropout.
|
||||
:param out_channels: if specified, the number of out channels.
|
||||
:param use_conv: if True and out_channels is specified, use a spatial
|
||||
convolution instead of a smaller 1x1 convolution to change the
|
||||
channels in the skip connection.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D.
|
||||
:param use_checkpoint: if True, use gradient checkpointing on this module.
|
||||
:param up: if True, use this block for upsampling.
|
||||
:param down: if True, use this block for downsampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
emb_channels,
|
||||
dropout,
|
||||
out_channels=None,
|
||||
use_conv=False,
|
||||
use_scale_shift_norm=False,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
up=False,
|
||||
down=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.emb_channels = emb_channels
|
||||
self.dropout = dropout
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
normalization(channels, swish=1.0),
|
||||
nn.Identity(),
|
||||
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
||||
)
|
||||
|
||||
self.updown = up or down
|
||||
|
||||
if up:
|
||||
self.h_upd = Upsample(channels, False, dims)
|
||||
self.x_upd = Upsample(channels, False, dims)
|
||||
elif down:
|
||||
self.h_upd = Downsample(channels, False, dims)
|
||||
self.x_upd = Downsample(channels, False, dims)
|
||||
else:
|
||||
self.h_upd = self.x_upd = nn.Identity()
|
||||
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
linear(
|
||||
emb_channels,
|
||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
normalization(
|
||||
self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0
|
||||
),
|
||||
nn.SiLU() if use_scale_shift_norm else nn.Identity(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
|
||||
),
|
||||
)
|
||||
|
||||
if self.out_channels == channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
elif use_conv:
|
||||
self.skip_connection = conv_nd(
|
||||
dims, channels, self.out_channels, 3, padding=1
|
||||
)
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
||||
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
Apply the block to a Tensor, conditioned on a timestep embedding.
|
||||
|
||||
:param x: an [N x C x ...] Tensor of features.
|
||||
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
if self.updown:
|
||||
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
||||
h = in_rest(x)
|
||||
h = self.h_upd(h)
|
||||
x = self.x_upd(x)
|
||||
h = in_conv(h)
|
||||
else:
|
||||
h = self.in_layers(x)
|
||||
emb_out = self.emb_layers(emb)
|
||||
while len(emb_out.shape) < len(h.shape):
|
||||
emb_out = emb_out[..., None]
|
||||
if self.use_scale_shift_norm:
|
||||
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
||||
scale, shift = th.chunk(emb_out, 2, dim=1)
|
||||
h = out_norm(h) * (1 + scale) + shift
|
||||
h = out_rest(h)
|
||||
else:
|
||||
h = h + emb_out
|
||||
h = self.out_layers(h)
|
||||
return self.skip_connection(x) + h
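The `use_scale_shift_norm` branch is a FiLM-like conditioning step: the embedding is projected to 2 * out_channels, chunked into a scale and a shift, and applied after the normalization. A standalone sketch of just the broadcasting (tensor sizes chosen arbitrarily, normalization omitted):

import torch as th

N, C = 2, 320
h = th.randn(N, C, 64, 64)
emb_out = th.randn(N, 2 * C, 1, 1)              # output of emb_layers, expanded to 4D
scale, shift = th.chunk(emb_out, 2, dim=1)      # each [N, C, 1, 1]
h = h * (1 + scale) + shift                     # broadcast over the spatial dimensions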
|
||||
|
||||
|
||||
class ResBlockNoTimeEmbedding(nn.Module):
|
||||
"""
|
||||
A residual block without time embedding
|
||||
|
||||
:param channels: the number of input channels.
|
||||
:param emb_channels: the number of timestep embedding channels.
|
||||
:param dropout: the rate of dropout.
|
||||
:param out_channels: if specified, the number of out channels.
|
||||
:param use_conv: if True and out_channels is specified, use a spatial
|
||||
convolution instead of a smaller 1x1 convolution to change the
|
||||
channels in the skip connection.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D.
|
||||
:param use_checkpoint: if True, use gradient checkpointing on this module.
|
||||
:param up: if True, use this block for upsampling.
|
||||
:param down: if True, use this block for downsampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
emb_channels,
|
||||
dropout,
|
||||
out_channels=None,
|
||||
use_conv=False,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
up=False,
|
||||
down=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.emb_channels = emb_channels
|
||||
self.dropout = dropout
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
normalization(channels, swish=1.0),
|
||||
nn.Identity(),
|
||||
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
||||
)
|
||||
|
||||
self.updown = up or down
|
||||
|
||||
if up:
|
||||
self.h_upd = Upsample(channels, False, dims)
|
||||
self.x_upd = Upsample(channels, False, dims)
|
||||
elif down:
|
||||
self.h_upd = Downsample(channels, False, dims)
|
||||
self.x_upd = Downsample(channels, False, dims)
|
||||
else:
|
||||
self.h_upd = self.x_upd = nn.Identity()
|
||||
|
||||
self.out_layers = nn.Sequential(
|
||||
normalization(self.out_channels, swish=1.0),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
|
||||
),
|
||||
)
|
||||
|
||||
if self.out_channels == channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
elif use_conv:
|
||||
self.skip_connection = conv_nd(
|
||||
dims, channels, self.out_channels, 3, padding=1
|
||||
)
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
||||
|
||||
def forward(self, x, emb=None):
|
||||
"""
|
||||
Apply the block to a Tensor, NOT conditioned on a timestep embedding.
|
||||
|
||||
:param x: an [N x C x ...] Tensor of features.
|
||||
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
assert emb is None
|
||||
|
||||
if self.updown:
|
||||
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
||||
h = in_rest(x)
|
||||
h = self.h_upd(h)
|
||||
x = self.x_upd(x)
|
||||
h = in_conv(h)
|
||||
else:
|
||||
h = self.in_layers(x)
|
||||
h = self.out_layers(h)
|
||||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
An attention block that allows spatial positions to attend to each other.
|
||||
|
||||
Originally ported from here, but adapted to the N-d case.
|
||||
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
use_checkpoint=False,
|
||||
encoder_channels=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
if num_head_channels == -1:
|
||||
self.num_heads = num_heads
|
||||
else:
|
||||
assert (
|
||||
channels % num_head_channels == 0
|
||||
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||
self.num_heads = channels // num_head_channels
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.norm = normalization(channels, swish=0.0)
|
||||
self.qkv = conv_nd(1, channels, channels * 3, 1)
|
||||
self.attention = QKVAttention(self.num_heads)
|
||||
|
||||
if encoder_channels is not None:
|
||||
self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
|
||||
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
||||
|
||||
def forward(self, x, encoder_out=None, mask=None):
|
||||
b, c, *spatial = x.shape
|
||||
qkv = self.qkv(self.norm(x).view(b, c, -1))
|
||||
if encoder_out is not None:
|
||||
encoder_out = self.encoder_kv(encoder_out)
|
||||
h = self.attention(qkv, encoder_out, mask=mask)
|
||||
else:
|
||||
h = self.attention(qkv)
|
||||
h = self.proj_out(h)
|
||||
return x + h.reshape(b, c, *spatial)
|
||||
|
||||
|
||||
class QKVAttention(nn.Module):
|
||||
"""
|
||||
A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
|
||||
"""
|
||||
|
||||
def __init__(self, n_heads):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv, encoder_kv=None, mask=None):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
|
||||
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
||||
:return: an [N x (H * C) x T] tensor after attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
||||
if encoder_kv is not None:
|
||||
assert encoder_kv.shape[1] == self.n_heads * ch * 2
|
||||
ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
|
||||
k = th.cat([ek, k], dim=-1)
|
||||
v = th.cat([ev, v], dim=-1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = th.einsum("bct,bcs->bts", q * scale, k * scale)
|
||||
if mask is not None:
|
||||
mask = F.pad(mask, (0, length), value=0.0)
|
||||
mask = (
|
||||
mask.unsqueeze(1)
|
||||
.expand(-1, self.n_heads, -1)
|
||||
.reshape(bs * self.n_heads, 1, -1)
|
||||
)
|
||||
weight = weight + mask
|
||||
weight = th.softmax(weight, dim=-1)
|
||||
a = th.einsum("bts,bcs->bct", weight, v)
|
||||
return a.reshape(bs, -1, length)
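A shape walk-through of the attention above (sizes are illustrative; H heads, C channels per head, T tokens; assumes this module's imports):

H, C, T, N = 8, 64, 256, 2
qkv = th.randn(N, H * 3 * C, T)     # packed queries, keys, and values
attn = QKVAttention(n_heads=H)
out = attn(qkv)                     # [N, H * C, T]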
|
||||
|
||||
|
||||
class UNetModel(nn.Module):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding.
|
||||
|
||||
:param in_channels: channels in the input Tensor.
|
||||
:param model_channels: base channel count for the model.
|
||||
:param out_channels: channels in the output Tensor.
|
||||
:param num_res_blocks: number of residual blocks per downsample.
|
||||
:param attention_resolutions: a collection of downsample rates at which
|
||||
attention will take place. May be a set, list, or tuple.
|
||||
For example, if this contains 4, then at 4x downsampling, attention
|
||||
will be used.
|
||||
:param dropout: the dropout probability.
|
||||
:param channel_mult: channel multiplier for each level of the UNet.
|
||||
:param conv_resample: if True, use learned convolutions for upsampling and
|
||||
downsampling.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D.
|
||||
:param clip_dim: dimension of clip feature.
|
||||
:param num_classes: if specified (as an int), then this model will be
|
||||
class-conditional with `num_classes` classes.
|
||||
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
|
||||
:param num_heads: the number of attention heads in each attention layer.
|
||||
:param num_head_channels: if specified, ignore num_heads and instead use
|
||||
a fixed channel width per attention head.
|
||||
:param num_heads_upsample: works with num_heads to set a different number
|
||||
of heads for upsampling. Deprecated.
|
||||
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
||||
:param resblock_updown: use residual blocks for up/downsampling.
|
||||
:param encoder_channels: used to make the dimensions of the query and key/value match in AttentionBlock.
|
||||
:param use_time_embedding: use time embedding for condition.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
model_channels,
|
||||
out_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
clip_dim=None,
|
||||
use_checkpoint=False,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
use_middle_attention=True,
|
||||
resblock_updown=False,
|
||||
encoder_channels=None,
|
||||
use_time_embedding=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
num_heads_upsample = num_heads
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.clip_dim = clip_dim
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
self.use_middle_attention = use_middle_attention
|
||||
self.use_time_embedding = use_time_embedding
|
||||
|
||||
if self.use_time_embedding:
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
if self.clip_dim is not None:
|
||||
self.clip_emb = nn.Linear(clip_dim, time_embed_dim)
|
||||
else:
|
||||
time_embed_dim = None
|
||||
|
||||
CustomResidualBlock = (
|
||||
ResBlock if self.use_time_embedding else ResBlockNoTimeEmbedding
|
||||
)
|
||||
ch = input_ch = int(channel_mult[0] * model_channels)
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
|
||||
)
|
||||
self._feature_size = ch
|
||||
input_block_chans = [ch]
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for _ in range(num_res_blocks):
|
||||
layers = [
|
||||
CustomResidualBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=int(mult * model_channels),
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = int(mult * model_channels)
|
||||
if ds in attention_resolutions:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
encoder_channels=encoder_channels,
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
CustomResidualBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch
|
||||
)
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
CustomResidualBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
*(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
encoder_channels=encoder_channels,
|
||||
),
|
||||
)
|
||||
if self.use_middle_attention
|
||||
else tuple(), # add AttentionBlock or not
|
||||
CustomResidualBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
)
|
||||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
for level, mult in list(enumerate(channel_mult))[::-1]:
|
||||
for i in range(num_res_blocks + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
CustomResidualBlock(
|
||||
ch + ich,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=int(model_channels * mult),
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = int(model_channels * mult)
|
||||
if ds in attention_resolutions:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads_upsample,
|
||||
num_head_channels=num_head_channels,
|
||||
encoder_channels=encoder_channels,
|
||||
)
|
||||
)
|
||||
if level and i == num_res_blocks:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
CustomResidualBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
up=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
||||
)
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch, swish=1.0),
|
||||
nn.Identity(),
|
||||
zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
|
||||
)
|
||||
|
||||
def forward(self, x, timesteps, y=None):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:param y: an [N] Tensor of labels, if class-conditional.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
assert (y is not None) == (
|
||||
self.clip_dim is not None
|
||||
), "must specify y if and only if the model is clip-rep-conditional"
|
||||
|
||||
hs = []
|
||||
if self.use_time_embedding:
|
||||
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
if self.clip_dim is not None:
|
||||
emb = emb + self.clip_emb(y)
|
||||
else:
|
||||
emb = None
|
||||
|
||||
h = x
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb)
|
||||
for module in self.output_blocks:
|
||||
h = th.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb)
|
||||
|
||||
return self.out(h)
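A minimal instantiation sketch of this UNet. The hyperparameters below are illustrative, not the shipped Karlo configs; it assumes this module's imports (torch as th) and a CLIP-conditional setup, so `y` must be passed.

unet = UNetModel(
    in_channels=3,
    model_channels=64,
    out_channels=6,                   # e.g. 2x channels when mean and variance are predicted
    num_res_blocks=2,
    attention_resolutions=(4,),       # attention at 4x downsampling
    channel_mult=(1, 2, 4),
    clip_dim=768,
    num_head_channels=32,
)
x = th.randn(2, 3, 64, 64)
t = th.randint(0, 1000, (2,))
y = th.randn(2, 768)                  # CLIP image embedding; required since clip_dim is set
eps = unet(x, t, y=y)                 # [2, 6, 64, 64]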
|
||||
|
||||
|
||||
class SuperResUNetModel(UNetModel):
|
||||
"""
|
||||
A UNetModel that performs super-resolution.
|
||||
|
||||
Expects an extra kwarg `low_res` to condition on a low-resolution image.
|
||||
Assumes that the low-resolution image and the input have the same spatial shape.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if "in_channels" in kwargs:
|
||||
kwargs = dict(kwargs)
|
||||
kwargs["in_channels"] = kwargs["in_channels"] * 2
|
||||
else:
|
||||
# Curse you, Python. Or really, just curse positional arguments :|.
|
||||
args = list(args)
|
||||
args[1] = args[1] * 2
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, x, timesteps, low_res=None, **kwargs):
|
||||
_, _, new_height, new_width = x.shape
|
||||
assert new_height == low_res.shape[2] and new_width == low_res.shape[3]
|
||||
|
||||
x = th.cat([x, low_res], dim=1)
|
||||
return super().forward(x, timesteps, **kwargs)
|
||||
|
||||
|
||||
class PLMImUNet(UNetModel):
|
||||
"""
|
||||
A UNetModel that conditions on text with a pretrained text encoder in CLIP.
|
||||
|
||||
:param text_ctx: number of text tokens to expect.
|
||||
:param xf_width: width of the transformer.
|
||||
:param clip_emb_mult: number of extra tokens obtained by projecting the CLIP feature.
|
||||
:param clip_emb_type: type of condition (here, we fix clip image feature).
|
||||
:param clip_emb_drop: dropout ratio of the CLIP image feature for classifier-free guidance.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_ctx,
|
||||
xf_width,
|
||||
*args,
|
||||
clip_emb_mult=None,
|
||||
clip_emb_type="image",
|
||||
clip_emb_drop=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.text_ctx = text_ctx
|
||||
self.xf_width = xf_width
|
||||
self.clip_emb_mult = clip_emb_mult
|
||||
self.clip_emb_type = clip_emb_type
|
||||
self.clip_emb_drop = clip_emb_drop
|
||||
|
||||
if not xf_width:
|
||||
super().__init__(*args, **kwargs, encoder_channels=None)
|
||||
else:
|
||||
super().__init__(*args, **kwargs, encoder_channels=xf_width)
|
||||
|
||||
# Project text encoded feat seq from pre-trained text encoder in CLIP
|
||||
self.text_seq_proj = nn.Sequential(
|
||||
nn.Linear(self.clip_dim, xf_width),
|
||||
LayerNorm(xf_width),
|
||||
)
|
||||
# Project CLIP text feat
|
||||
self.text_feat_proj = nn.Linear(self.clip_dim, self.model_channels * 4)
|
||||
|
||||
assert clip_emb_mult is not None
|
||||
assert clip_emb_type == "image"
|
||||
assert self.clip_dim is not None, "CLIP representation dim should be specified"
|
||||
|
||||
self.clip_tok_proj = nn.Linear(
|
||||
self.clip_dim, self.xf_width * self.clip_emb_mult
|
||||
)
|
||||
if self.clip_emb_drop > 0:
|
||||
self.cf_param = nn.Parameter(th.empty(self.clip_dim, dtype=th.float32))
|
||||
|
||||
def proc_clip_emb_drop(self, feat):
|
||||
if self.clip_emb_drop > 0:
|
||||
bsz, feat_dim = feat.shape
|
||||
assert (
|
||||
feat_dim == self.clip_dim
|
||||
), f"CLIP input dim: {feat_dim}, model CLIP dim: {self.clip_dim}"
|
||||
drop_idx = th.rand((bsz,), device=feat.device) < self.clip_emb_drop
|
||||
feat = th.where(
|
||||
drop_idx[..., None], self.cf_param[None].type_as(feat), feat
|
||||
)
|
||||
return feat
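This implements the classifier-free guidance dropout on the CLIP embedding: each sample in the batch is replaced by the learned null embedding `cf_param` with probability `clip_emb_drop`. A standalone sketch of the same masking pattern (sizes are illustrative, and a zero vector stands in for the learned null embedding):

import torch as th

bsz, clip_dim, p_drop = 4, 768, 0.1
feat = th.randn(bsz, clip_dim)
null_emb = th.zeros(clip_dim)                         # stands in for the learned cf_param
drop = th.rand(bsz) < p_drop                          # True -> replace with the null embedding
feat = th.where(drop[:, None], null_emb[None], feat)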
|
||||
|
||||
def forward(
|
||||
self, x, timesteps, txt_feat=None, txt_feat_seq=None, mask=None, y=None
|
||||
):
|
||||
bsz = x.shape[0]
|
||||
hs = []
|
||||
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
emb = emb + self.clip_emb(y)
|
||||
|
||||
xf_out = self.text_seq_proj(txt_feat_seq)
|
||||
xf_out = xf_out.permute(0, 2, 1)
|
||||
emb = emb + self.text_feat_proj(txt_feat)
|
||||
xf_out = th.cat(
|
||||
[
|
||||
self.clip_tok_proj(y).reshape(bsz, -1, self.clip_emb_mult),
|
||||
xf_out,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
mask = F.pad(mask, (self.clip_emb_mult, 0), value=True)
|
||||
mask = th.where(mask, 0.0, float("-inf"))
|
||||
|
||||
h = x
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb, xf_out, mask=mask)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, xf_out, mask=mask)
|
||||
for module in self.output_blocks:
|
||||
h = th.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb, xf_out, mask=mask)
|
||||
h = self.out(h)
|
||||
|
||||
return h
|
231 ldm/modules/karlo/kakao/modules/xf.py (Normal file)
@@ -0,0 +1,231 @@
# ------------------------------------------------------------------------------------
|
||||
# Adapted from the repos below:
|
||||
# (a) Guided-Diffusion (https://github.com/openai/guided-diffusion)
|
||||
# (b) CLIP ViT (https://github.com/openai/CLIP/)
|
||||
# ------------------------------------------------------------------------------------
|
||||
|
||||
import math
|
||||
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .nn import timestep_embedding
|
||||
|
||||
|
||||
def convert_module_to_f16(param):
|
||||
"""
|
||||
Convert primitive modules to float16.
|
||||
"""
|
||||
if isinstance(param, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
|
||||
param.weight.data = param.weight.data.half()
|
||||
if param.bias is not None:
|
||||
param.bias.data = param.bias.data.half()
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""
|
||||
Implementation that supports fp16 inputs but fp32 gains/biases.
|
||||
"""
|
||||
|
||||
def forward(self, x: th.Tensor):
|
||||
return super().forward(x.float()).to(x.dtype)
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
|
||||
def __init__(self, n_ctx, width, heads):
|
||||
super().__init__()
|
||||
self.n_ctx = n_ctx
|
||||
self.width = width
|
||||
self.heads = heads
|
||||
self.c_qkv = nn.Linear(width, width * 3)
|
||||
self.c_proj = nn.Linear(width, width)
|
||||
self.attention = QKVMultiheadAttention(heads, n_ctx)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
x = self.c_qkv(x)
|
||||
x = self.attention(x, mask=mask)
|
||||
x = self.c_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, width):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.c_fc = nn.Linear(width, width * 4)
|
||||
self.c_proj = nn.Linear(width * 4, width)
|
||||
self.gelu = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.c_proj(self.gelu(self.c_fc(x)))
|
||||
|
||||
|
||||
class QKVMultiheadAttention(nn.Module):
|
||||
def __init__(self, n_heads: int, n_ctx: int):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
self.n_ctx = n_ctx
|
||||
|
||||
def forward(self, qkv, mask=None):
|
||||
bs, n_ctx, width = qkv.shape
|
||||
attn_ch = width // self.n_heads // 3
|
||||
scale = 1 / math.sqrt(math.sqrt(attn_ch))
|
||||
qkv = qkv.view(bs, n_ctx, self.n_heads, -1)
|
||||
q, k, v = th.split(qkv, attn_ch, dim=-1)
|
||||
weight = th.einsum("bthc,bshc->bhts", q * scale, k * scale)
|
||||
wdtype = weight.dtype
|
||||
if mask is not None:
|
||||
weight = weight + mask[:, None, ...]
|
||||
weight = th.softmax(weight, dim=-1).type(wdtype)
|
||||
return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_ctx: int,
|
||||
width: int,
|
||||
heads: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn = MultiheadAttention(
|
||||
n_ctx,
|
||||
width,
|
||||
heads,
|
||||
)
|
||||
self.ln_1 = LayerNorm(width)
|
||||
self.mlp = MLP(width)
|
||||
self.ln_2 = LayerNorm(width)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
x = x + self.attn(self.ln_1(x), mask=mask)
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_ctx: int,
|
||||
width: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.n_ctx = n_ctx
|
||||
self.width = width
|
||||
self.layers = layers
|
||||
self.resblocks = nn.ModuleList(
|
||||
[
|
||||
ResidualAttentionBlock(
|
||||
n_ctx,
|
||||
width,
|
||||
heads,
|
||||
)
|
||||
for _ in range(layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
for block in self.resblocks:
|
||||
x = block(x, mask=mask)
|
||||
return x
|
||||
|
||||
|
||||
class PriorTransformer(nn.Module):
|
||||
"""
|
||||
A causal Transformer that conditions on the CLIP text embedding and the encoded text sequence.
|
||||
|
||||
:param text_ctx: number of text tokens to expect.
|
||||
:param xf_width: width of the transformer.
|
||||
:param xf_layers: depth of the transformer.
|
||||
:param xf_heads: heads in the transformer.
|
||||
:param xf_final_ln: use a LayerNorm after the output layer.
|
||||
:param clip_dim: dimension of clip feature.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_ctx,
|
||||
xf_width,
|
||||
xf_layers,
|
||||
xf_heads,
|
||||
xf_final_ln,
|
||||
clip_dim,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.text_ctx = text_ctx
|
||||
self.xf_width = xf_width
|
||||
self.xf_layers = xf_layers
|
||||
self.xf_heads = xf_heads
|
||||
self.clip_dim = clip_dim
|
||||
self.ext_len = 4
|
||||
|
||||
self.time_embed = nn.Sequential(
|
||||
nn.Linear(xf_width, xf_width),
|
||||
nn.SiLU(),
|
||||
nn.Linear(xf_width, xf_width),
|
||||
)
|
||||
self.text_enc_proj = nn.Linear(clip_dim, xf_width)
|
||||
self.text_emb_proj = nn.Linear(clip_dim, xf_width)
|
||||
self.clip_img_proj = nn.Linear(clip_dim, xf_width)
|
||||
self.out_proj = nn.Linear(xf_width, clip_dim)
|
||||
self.transformer = Transformer(
|
||||
text_ctx + self.ext_len,
|
||||
xf_width,
|
||||
xf_layers,
|
||||
xf_heads,
|
||||
)
|
||||
if xf_final_ln:
|
||||
self.final_ln = LayerNorm(xf_width)
|
||||
else:
|
||||
self.final_ln = None
|
||||
|
||||
self.positional_embedding = nn.Parameter(
|
||||
th.empty(1, text_ctx + self.ext_len, xf_width)
|
||||
)
|
||||
self.prd_emb = nn.Parameter(th.randn((1, 1, xf_width)))
|
||||
|
||||
nn.init.normal_(self.prd_emb, std=0.01)
|
||||
nn.init.normal_(self.positional_embedding, std=0.01)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timesteps,
|
||||
text_emb=None,
|
||||
text_enc=None,
|
||||
mask=None,
|
||||
causal_mask=None,
|
||||
):
|
||||
bsz = x.shape[0]
|
||||
mask = F.pad(mask, (0, self.ext_len), value=True)
|
||||
|
||||
t_emb = self.time_embed(timestep_embedding(timesteps, self.xf_width))
|
||||
text_enc = self.text_enc_proj(text_enc)
|
||||
text_emb = self.text_emb_proj(text_emb)
|
||||
x = self.clip_img_proj(x)
|
||||
|
||||
input_seq = [
|
||||
text_enc,
|
||||
text_emb[:, None, :],
|
||||
t_emb[:, None, :],
|
||||
x[:, None, :],
|
||||
self.prd_emb.to(x.dtype).expand(bsz, -1, -1),
|
||||
]
|
||||
input = th.cat(input_seq, dim=1)
|
||||
input = input + self.positional_embedding.to(input.dtype)
|
||||
|
||||
mask = th.where(mask, 0.0, float("-inf"))
|
||||
mask = (mask[:, None, :] + causal_mask).to(input.dtype)
|
||||
|
||||
out = self.transformer(input, mask=mask)
|
||||
if self.final_ln is not None:
|
||||
out = self.final_ln(out)
|
||||
|
||||
out = self.out_proj(out[:, -1])
|
||||
|
||||
return out
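The input sequence concatenates five pieces along the token axis, which is why `ext_len` is 4: the projected text-encoder sequence plus one token each for the pooled text embedding, the timestep embedding, the current noisy image embedding, and the learned `prd_emb` query whose output position is read out by `out_proj`. A shape sketch (illustrative batch size B):

# text_enc:            [B, text_ctx, xf_width]
# text_emb, t_emb, x:  each [B, 1, xf_width] after projection
# prd_emb:             [B, 1, xf_width]
# input:               [B, text_ctx + 4, xf_width]
# out = out_proj(transformer(input)[:, -1])   # -> [B, clip_dim]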
|
272 ldm/modules/karlo/kakao/sampler.py (Normal file)
@@ -0,0 +1,272 @@
# ------------------------------------------------------------------------------------
|
||||
# Karlo-v1.0.alpha
|
||||
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
|
||||
|
||||
# source: https://github.com/kakaobrain/karlo/blob/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/karlo/sampler/t2i.py#L15
|
||||
# ------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Iterator
|
||||
|
||||
import torch
|
||||
import torchvision.transforms.functional as TVF
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
from .template import BaseSampler, CKPT_PATH
|
||||
|
||||
|
||||
class T2ISampler(BaseSampler):
|
||||
"""
|
||||
A sampler for text-to-image generation.
|
||||
:param root_dir: directory for model checkpoints.
|
||||
:param sampling_type: ["default", "fast"]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_dir: str,
|
||||
sampling_type: str = "default",
|
||||
):
|
||||
super().__init__(root_dir, sampling_type)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
root_dir: str,
|
||||
clip_model_path: str,
|
||||
clip_stat_path: str,
|
||||
sampling_type: str = "default",
|
||||
):
|
||||
|
||||
model = cls(
|
||||
root_dir=root_dir,
|
||||
sampling_type=sampling_type,
|
||||
)
|
||||
model.load_clip(clip_model_path)
|
||||
model.load_prior(
|
||||
f"{CKPT_PATH['prior']}",
|
||||
clip_stat_path=clip_stat_path,
|
||||
prior_config="configs/karlo/prior_1B_vit_l.yaml"
|
||||
)
|
||||
model.load_decoder(f"{CKPT_PATH['decoder']}", decoder_config="configs/karlo/decoder_900M_vit_l.yaml")
|
||||
model.load_sr_64_256(CKPT_PATH["sr_256"], sr_config="configs/karlo/improved_sr_64_256_1.4B.yaml")
|
||||
return model
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
prompt: str,
|
||||
bsz: int,
|
||||
):
|
||||
"""Setup prompts & cfg scales"""
|
||||
prompts_batch = [prompt for _ in range(bsz)]
|
||||
|
||||
prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
|
||||
prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
|
||||
|
||||
decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
|
||||
decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
|
||||
|
||||
""" Get CLIP text feature """
|
||||
clip_model = self._clip
|
||||
tokenizer = self._tokenizer
|
||||
max_txt_length = self._prior.model.text_ctx
|
||||
|
||||
tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
|
||||
cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
|
||||
if not (cf_token.shape == tok.shape):
|
||||
cf_token = cf_token.expand(tok.shape[0], -1)
|
||||
cf_mask = cf_mask.expand(tok.shape[0], -1)
|
||||
|
||||
tok = torch.cat([tok, cf_token], dim=0)
|
||||
mask = torch.cat([mask, cf_mask], dim=0)
|
||||
|
||||
tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
|
||||
txt_feat, txt_feat_seq = clip_model.encode_text(tok)
|
||||
|
||||
return (
|
||||
prompts_batch,
|
||||
prior_cf_scales_batch,
|
||||
decoder_cf_scales_batch,
|
||||
txt_feat,
|
||||
txt_feat_seq,
|
||||
tok,
|
||||
mask,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
bsz: int,
|
||||
progressive_mode=None,
|
||||
) -> Iterator[torch.Tensor]:
|
||||
assert progressive_mode in ("loop", "stage", "final")
|
||||
with torch.no_grad(), torch.cuda.amp.autocast():
|
||||
(
|
||||
prompts_batch,
|
||||
prior_cf_scales_batch,
|
||||
decoder_cf_scales_batch,
|
||||
txt_feat,
|
||||
txt_feat_seq,
|
||||
tok,
|
||||
mask,
|
||||
) = self.preprocess(
|
||||
prompt,
|
||||
bsz,
|
||||
)
|
||||
|
||||
""" Transform CLIP text feature into image feature """
|
||||
img_feat = self._prior(
|
||||
txt_feat,
|
||||
txt_feat_seq,
|
||||
mask,
|
||||
prior_cf_scales_batch,
|
||||
timestep_respacing=self._prior_sm,
|
||||
)
|
||||
|
||||
""" Generate 64x64px images """
|
||||
images_64_outputs = self._decoder(
|
||||
txt_feat,
|
||||
txt_feat_seq,
|
||||
tok,
|
||||
mask,
|
||||
img_feat,
|
||||
cf_guidance_scales=decoder_cf_scales_batch,
|
||||
timestep_respacing=self._decoder_sm,
|
||||
)
|
||||
|
||||
images_64 = None
|
||||
for k, out in enumerate(images_64_outputs):
|
||||
images_64 = out
|
||||
if progressive_mode == "loop":
|
||||
yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
|
||||
if progressive_mode == "stage":
|
||||
yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
|
||||
|
||||
images_64 = torch.clamp(images_64, -1, 1)
|
||||
|
||||
""" Upsample 64x64 to 256x256 """
|
||||
images_256 = TVF.resize(
|
||||
images_64,
|
||||
[256, 256],
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
antialias=True,
|
||||
)
|
||||
images_256_outputs = self._sr_64_256(
|
||||
images_256, timestep_respacing=self._sr_sm
|
||||
)
|
||||
|
||||
for k, out in enumerate(images_256_outputs):
|
||||
images_256 = out
|
||||
if progressive_mode == "loop":
|
||||
yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
|
||||
if progressive_mode == "stage":
|
||||
yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
|
||||
|
||||
yield torch.clamp(images_256 * 0.5 + 0.5, 0.0, 1.0)
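A minimal usage sketch of the full pipeline, mirroring how the streamlit demo calls it. The checkpoint paths are illustrative and must point at downloaded Karlo checkpoints; the call requires a CUDA device.

sampler = T2ISampler.from_pretrained(
    root_dir="checkpoints/karlo_models",
    clip_model_path="ViT-L-14.pt",
    clip_stat_path="ViT-L-14_stats.th",
    sampling_type="default",
)
images = next(iter(sampler("a corgi wearing a red bowtie", bsz=2, progressive_mode="final")))
# images: final super-resolved samples as a [bsz, 3, 256, 256] tensor in [0, 1]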
|
||||
|
||||
|
||||
class PriorSampler(BaseSampler):
|
||||
"""
|
||||
A sampler for text-to-image generation, but only the prior.
|
||||
:param root_dir: directory for model checkpoints.
|
||||
:param sampling_type: ["default", "fast"]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_dir: str,
|
||||
sampling_type: str = "default",
|
||||
):
|
||||
super().__init__(root_dir, sampling_type)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
root_dir: str,
|
||||
clip_model_path: str,
|
||||
clip_stat_path: str,
|
||||
sampling_type: str = "default",
|
||||
):
|
||||
model = cls(
|
||||
root_dir=root_dir,
|
||||
sampling_type=sampling_type,
|
||||
)
|
||||
model.load_clip(clip_model_path)
|
||||
model.load_prior(
|
||||
f"{CKPT_PATH['prior']}",
|
||||
clip_stat_path=clip_stat_path,
|
||||
prior_config="configs/karlo/prior_1B_vit_l.yaml"
|
||||
)
|
||||
return model
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
prompt: str,
|
||||
bsz: int,
|
||||
):
|
||||
"""Setup prompts & cfg scales"""
|
||||
prompts_batch = [prompt for _ in range(bsz)]
|
||||
|
||||
prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
|
||||
prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
|
||||
|
||||
decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
|
||||
decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
|
||||
|
||||
""" Get CLIP text feature """
|
||||
clip_model = self._clip
|
||||
tokenizer = self._tokenizer
|
||||
max_txt_length = self._prior.model.text_ctx
|
||||
|
||||
tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
|
||||
cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
|
||||
if not (cf_token.shape == tok.shape):
|
||||
cf_token = cf_token.expand(tok.shape[0], -1)
|
||||
cf_mask = cf_mask.expand(tok.shape[0], -1)
|
||||
|
||||
tok = torch.cat([tok, cf_token], dim=0)
|
||||
mask = torch.cat([mask, cf_mask], dim=0)
|
||||
|
||||
tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
|
||||
txt_feat, txt_feat_seq = clip_model.encode_text(tok)
|
||||
|
||||
return (
|
||||
prompts_batch,
|
||||
prior_cf_scales_batch,
|
||||
decoder_cf_scales_batch,
|
||||
txt_feat,
|
||||
txt_feat_seq,
|
||||
tok,
|
||||
mask,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
bsz: int,
|
||||
progressive_mode=None,
|
||||
) -> Iterator[torch.Tensor]:
|
||||
assert progressive_mode in ("loop", "stage", "final")
|
||||
with torch.no_grad(), torch.cuda.amp.autocast():
|
||||
(
|
||||
prompts_batch,
|
||||
prior_cf_scales_batch,
|
||||
decoder_cf_scales_batch,
|
||||
txt_feat,
|
||||
txt_feat_seq,
|
||||
tok,
|
||||
mask,
|
||||
) = self.preprocess(
|
||||
prompt,
|
||||
bsz,
|
||||
)
|
||||
|
||||
""" Transform CLIP text feature into image feature """
|
||||
img_feat = self._prior(
|
||||
txt_feat,
|
||||
txt_feat_seq,
|
||||
mask,
|
||||
prior_cf_scales_batch,
|
||||
timestep_respacing=self._prior_sm,
|
||||
)
|
||||
|
||||
yield img_feat
|
141 ldm/modules/karlo/kakao/template.py (Normal file)
@@ -0,0 +1,141 @@
# ------------------------------------------------------------------------------------
|
||||
# Karlo-v1.0.alpha
|
||||
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
|
||||
# ------------------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
import logging
|
||||
import torch
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from ldm.modules.karlo.kakao.models.clip import CustomizedCLIP, CustomizedTokenizer
|
||||
from ldm.modules.karlo.kakao.models.prior_model import PriorDiffusionModel
|
||||
from ldm.modules.karlo.kakao.models.decoder_model import Text2ImProgressiveModel
|
||||
from ldm.modules.karlo.kakao.models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel
|
||||
|
||||
|
||||
SAMPLING_CONF = {
|
||||
"default": {
|
||||
"prior_sm": "25",
|
||||
"prior_n_samples": 1,
|
||||
"prior_cf_scale": 4.0,
|
||||
"decoder_sm": "50",
|
||||
"decoder_cf_scale": 8.0,
|
||||
"sr_sm": "7",
|
||||
},
|
||||
"fast": {
|
||||
"prior_sm": "25",
|
||||
"prior_n_samples": 1,
|
||||
"prior_cf_scale": 4.0,
|
||||
"decoder_sm": "25",
|
||||
"decoder_cf_scale": 8.0,
|
||||
"sr_sm": "7",
|
||||
},
|
||||
}
|
||||
|
||||
CKPT_PATH = {
|
||||
"prior": "prior-ckpt-step=01000000-of-01000000.ckpt",
|
||||
"decoder": "decoder-ckpt-step=01000000-of-01000000.ckpt",
|
||||
"sr_256": "improved-sr-ckpt-step=1.2M.ckpt",
|
||||
}
|
||||
|
||||
|
||||
class BaseSampler:
|
||||
_PRIOR_CLASS = PriorDiffusionModel
|
||||
_DECODER_CLASS = Text2ImProgressiveModel
|
||||
_SR256_CLASS = ImprovedSupRes64to256ProgressiveModel
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_dir: str,
|
||||
sampling_type: str = "fast",
|
||||
):
|
||||
self._root_dir = root_dir
|
||||
|
||||
sampling_type = SAMPLING_CONF[sampling_type]
|
||||
self._prior_sm = sampling_type["prior_sm"]
|
||||
self._prior_n_samples = sampling_type["prior_n_samples"]
|
||||
self._prior_cf_scale = sampling_type["prior_cf_scale"]
|
||||
|
||||
assert self._prior_n_samples == 1
|
||||
|
||||
self._decoder_sm = sampling_type["decoder_sm"]
|
||||
self._decoder_cf_scale = sampling_type["decoder_cf_scale"]
|
||||
|
||||
self._sr_sm = sampling_type["sr_sm"]
|
||||
|
||||
def __repr__(self):
|
||||
line = ""
|
||||
line += f"Prior, sampling method: {self._prior_sm}, cf_scale: {self._prior_cf_scale}\n"
|
||||
line += f"Decoder, sampling method: {self._decoder_sm}, cf_scale: {self._decoder_cf_scale}\n"
|
||||
line += f"SR(64->256), sampling method: {self._sr_sm}"
|
||||
|
||||
return line
|
||||
|
||||
def load_clip(self, clip_path: str):
|
||||
clip = CustomizedCLIP.load_from_checkpoint(
|
||||
os.path.join(self._root_dir, clip_path)
|
||||
)
|
||||
clip = torch.jit.script(clip)
|
||||
clip.cuda()
|
||||
clip.eval()
|
||||
|
||||
self._clip = clip
|
||||
self._tokenizer = CustomizedTokenizer()
|
||||
|
||||
def load_prior(
|
||||
self,
|
||||
ckpt_path: str,
|
||||
clip_stat_path: str,
|
||||
prior_config: str = "configs/prior_1B_vit_l.yaml"
|
||||
):
|
||||
logging.info(f"Loading prior: {ckpt_path}")
|
||||
|
||||
config = OmegaConf.load(prior_config)
|
||||
clip_mean, clip_std = torch.load(
|
||||
os.path.join(self._root_dir, clip_stat_path), map_location="cpu"
|
||||
)
|
||||
|
||||
prior = self._PRIOR_CLASS.load_from_checkpoint(
|
||||
config,
|
||||
self._tokenizer,
|
||||
clip_mean,
|
||||
clip_std,
|
||||
os.path.join(self._root_dir, ckpt_path),
|
||||
strict=True,
|
||||
)
|
||||
prior.cuda()
|
||||
prior.eval()
|
||||
logging.info("done.")
|
||||
|
||||
self._prior = prior
|
||||
|
||||
def load_decoder(self, ckpt_path: str, decoder_config: str = "configs/decoder_900M_vit_l.yaml"):
|
||||
logging.info(f"Loading decoder: {ckpt_path}")
|
||||
|
||||
config = OmegaConf.load(decoder_config)
|
||||
decoder = self._DECODER_CLASS.load_from_checkpoint(
|
||||
config,
|
||||
self._tokenizer,
|
||||
os.path.join(self._root_dir, ckpt_path),
|
||||
strict=True,
|
||||
)
|
||||
decoder.cuda()
|
||||
decoder.eval()
|
||||
logging.info("done.")
|
||||
|
||||
self._decoder = decoder
|
||||
|
||||
def load_sr_64_256(self, ckpt_path: str, sr_config: str = "configs/improved_sr_64_256_1.4B.yaml"):
|
||||
logging.info(f"Loading SR(64->256): {ckpt_path}")
|
||||
|
||||
config = OmegaConf.load(sr_config)
|
||||
sr = self._SR256_CLASS.load_from_checkpoint(
|
||||
config, os.path.join(self._root_dir, ckpt_path), strict=True
|
||||
)
|
||||
sr.cuda()
|
||||
sr.eval()
|
||||
logging.info("done.")
|
||||
|
||||
self._sr_64_256 = sr
|
10 ldm/util.py
@@ -8,6 +8,16 @@ from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
|
||||
def autocast(f):
|
||||
def do_autocast(*args, **kwargs):
|
||||
with torch.cuda.amp.autocast(enabled=True,
|
||||
dtype=torch.get_autocast_gpu_dtype(),
|
||||
cache_enabled=torch.is_autocast_cache_enabled()):
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return do_autocast
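A usage sketch for the decorator added above: the wrapped function runs under CUDA autocast with the globally configured autocast dtype. The function name below is hypothetical.

@autocast
def denoise_step(model, x, t):
    return model(x, t)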
|
||||
|
||||
|
||||
def log_txt_as_img(wh, xc, size=10):
|
||||
# wh a tuple of (width, height)
|
||||
# xc a list of captions to plot
|
||||
|
@@ -97,7 +97,7 @@ We currently provide the following checkpoints, for various versions:
- `512-base-ema.ckpt`: Fine-tuned on `512-base-ema.ckpt` 2.0 with 220k extra steps taken, with `punsafe=0.98` on the same dataset.
|
||||
- `768-v-ema.ckpt`: Resumed from `768-v-ema.ckpt` 2.0 with an additional 55k steps on the same dataset (`punsafe=0.1`), and then fine-tuned for another 155k extra steps with `punsafe=0.98`.
|
||||
|
||||
SD-unCLIP 2.1 is a finetuned version of Stable Diffusion 2.1, modified to accept (noisy) CLIP image embedding in addition to the text prompt, and can be used to create image variations ([Examples](../assets/stable-samples/stable-unclip/unclip-variations_noise.png)) or can be chained with text-to-image CLIP priors. The amount of noise added to the image embedding can be specified via the `noise_level` (0 means no noise, 1000 full noise).
|
||||
**SD-unCLIP 2.1** is a finetuned version of Stable Diffusion 2.1, modified to accept (noisy) CLIP image embedding in addition to the text prompt, and can be used to create image variations ([Examples](../assets/stable-samples/stable-unclip/unclip-variations_noise.png)) or can be chained with text-to-image CLIP priors. The amount of noise added to the image embedding can be specified via the `noise_level` (0 means no noise, 1000 full noise).
|
||||
|
||||
### Version 2.0
|
||||
|
||||
|
|
416 scripts/streamlit/stableunclip.py (Normal file)
@@ -0,0 +1,416 @@
import importlib
|
||||
import streamlit as st
|
||||
import torch
|
||||
import cv2
|
||||
import numpy as np
|
||||
import PIL
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from tqdm import trange
|
||||
import io, os
|
||||
from torch import autocast
|
||||
from einops import rearrange, repeat
|
||||
from torchvision.utils import make_grid
|
||||
from pytorch_lightning import seed_everything
|
||||
from contextlib import nullcontext
|
||||
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
PROMPTS_ROOT = "scripts/prompts/"
|
||||
SAVE_PATH = "outputs/demo/stable-unclip/"
|
||||
|
||||
VERSION2SPECS = {
|
||||
"Stable unCLIP-L": {"H": 768, "W": 768, "C": 4, "f": 8},
|
||||
"Stable unOpenCLIP-H": {"H": 768, "W": 768, "C": 4, "f": 8},
|
||||
"Full Karlo": {}
|
||||
}
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False):
|
||||
module, cls = string.rsplit(".", 1)
|
||||
importlib.invalidate_caches()
|
||||
if reload:
|
||||
module_imp = importlib.import_module(module)
|
||||
importlib.reload(module_imp)
|
||||
return getattr(importlib.import_module(module, package=None), cls)
|
||||
|
||||
|
||||
def instantiate_from_config(config):
|
||||
if not "target" in config:
|
||||
raise KeyError("Expected key `target` to instantiate.")
|
||||
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
||||
|
||||
|
||||
def get_interactive_image(key=None):
|
||||
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
|
||||
if image is not None:
|
||||
image = Image.open(image)
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
return image
|
||||
|
||||
|
||||
def load_img(display=True, key=None):
|
||||
image = get_interactive_image(key=key)
|
||||
if display:
|
||||
st.image(image)
|
||||
w, h = image.size
|
||||
print(f"loaded input image of size ({w}, {h})")
|
||||
w, h = map(lambda x: x - x % 64, (w, h))
|
||||
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
return 2. * image - 1.
|
||||
|
||||
|
||||
def get_init_img(batch_size=1, key=None):
|
||||
init_image = load_img(key=key).cuda()
|
||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||
return init_image
|
||||
|
||||
|
||||
def sample(
|
||||
model,
|
||||
prompt,
|
||||
n_runs=3,
|
||||
n_samples=2,
|
||||
H=512,
|
||||
W=512,
|
||||
C=4,
|
||||
f=8,
|
||||
scale=10.0,
|
||||
ddim_steps=50,
|
||||
ddim_eta=0.0,
|
||||
callback=None,
|
||||
skip_single_save=False,
|
||||
save_grid=True,
|
||||
ucg_schedule=None,
|
||||
negative_prompt="",
|
||||
adm_cond=None,
|
||||
adm_uc=None,
|
||||
use_full_precision=False,
|
||||
only_adm_cond=False
|
||||
):
|
||||
batch_size = n_samples
|
||||
precision_scope = autocast if not use_full_precision else nullcontext
|
||||
# decoderscope = autocast if not use_full_precision else nullcontext
|
||||
if use_full_precision: st.warning(f"Running {model.__class__.__name__} at full precision.")
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
prompts = batch_size * prompt
|
||||
|
||||
outputs = st.empty()
|
||||
|
||||
with precision_scope("cuda"):
|
||||
with model.ema_scope():
|
||||
all_samples = list()
|
||||
for n in trange(n_runs, desc="Sampling"):
|
||||
shape = [C, H // f, W // f]
|
||||
if not only_adm_cond:
|
||||
uc = None
|
||||
if scale != 1.0:
|
||||
uc = model.get_learned_conditioning(batch_size * [negative_prompt])
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
c = model.get_learned_conditioning(prompts)
|
||||
|
||||
if adm_cond is not None:
|
||||
if adm_cond.shape[0] == 1:
|
||||
adm_cond = repeat(adm_cond, '1 ... -> b ...', b=batch_size)
|
||||
if adm_uc is None:
|
||||
st.warning("Not guiding via c_adm")
|
||||
adm_uc = adm_cond
|
||||
else:
|
||||
if adm_uc.shape[0] == 1:
|
||||
adm_uc = repeat(adm_uc, '1 ... -> b ...', b=batch_size)
|
||||
if not only_adm_cond:
|
||||
c = {"c_crossattn": [c], "c_adm": adm_cond}
|
||||
uc = {"c_crossattn": [uc], "c_adm": adm_uc}
|
||||
else:
|
||||
c = adm_cond
|
||||
uc = adm_uc
|
||||
samples_ddim, _ = sampler.sample(S=ddim_steps,
|
||||
conditioning=c,
|
||||
batch_size=batch_size,
|
||||
shape=shape,
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=ddim_eta,
|
||||
x_T=None,
|
||||
callback=callback,
|
||||
ucg_schedule=ucg_schedule
|
||||
)
|
||||
x_samples = model.decode_first_stage(samples_ddim)
|
||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
if not skip_single_save:
|
||||
base_count = len(os.listdir(os.path.join(SAVE_PATH, "samples")))
|
||||
for x_sample in x_samples:
|
||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||
Image.fromarray(x_sample.astype(np.uint8)).save(
|
||||
os.path.join(SAVE_PATH, "samples", f"{base_count:09}.png"))
|
||||
base_count += 1
|
||||
|
||||
all_samples.append(x_samples)
|
||||
|
||||
# get grid of all samples
|
||||
grid = torch.stack(all_samples, 0)
|
||||
grid = rearrange(grid, 'n b c h w -> (n h) (b w) c')
|
||||
outputs.image(grid.cpu().numpy())
|
||||
|
||||
# additionally, save grid
|
||||
grid = Image.fromarray((255. * grid.cpu().numpy()).astype(np.uint8))
|
||||
if save_grid:
|
||||
grid_count = len(os.listdir(SAVE_PATH)) - 1
|
||||
grid.save(os.path.join(SAVE_PATH, f'grid-{grid_count:06}.png'))
|
||||
|
||||
return x_samples
|
||||
|
||||
|
||||
def make_oscillating_guidance_schedule(num_steps, max_weight=15., min_weight=1.):
|
||||
schedule = list()
|
||||
for i in range(num_steps):
|
||||
if float(i / num_steps) < 0.1:
|
||||
schedule.append(max_weight)
|
||||
elif i % 2 == 0:
|
||||
schedule.append(min_weight)
|
||||
else:
|
||||
schedule.append(max_weight)
|
||||
print(f"OSCILLATING GUIDANCE SCHEDULE: \n {schedule}")
|
||||
return schedule
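For example, the function above with num_steps=10 keeps full guidance for the first 10% of steps and then alternates between the minimum and maximum weights:

# make_oscillating_guidance_schedule(10)
# -> [15.0, 15.0, 1.0, 15.0, 1.0, 15.0, 1.0, 15.0, 1.0, 15.0]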
|
||||
|
||||
|
||||
def torch2np(x):
|
||||
x = ((x + 1.0) * 127.5).clamp(0, 255).to(dtype=torch.uint8)
|
||||
x = x.permute(0, 2, 3, 1).detach().cpu().numpy()
|
||||
return x
|
||||
|
||||
|
||||
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
|
||||
def init(version="Stable unCLIP-L", load_karlo_prior=False):
|
||||
state = dict()
|
||||
if not "model" in state:
|
||||
if version == "Stable unCLIP-L":
|
||||
config = "configs/stable-diffusion/v2-1-stable-unclip-l-inference.yaml"
|
||||
ckpt = "checkpoints/sd21-unclip-l.ckpt"
|
||||
|
||||
elif version == "Stable unOpenCLIP-H":
|
||||
config = "configs/stable-diffusion/v2-1-stable-unclip-h-inference.yaml"
|
||||
ckpt = "checkpoints/sd21-unclip-h.ckpt"
|
||||
|
||||
elif version == "Full Karlo":
|
||||
from ldm.modules.karlo.kakao.sampler import T2ISampler
|
||||
st.info("Loading full KARLO..")
|
||||
karlo = T2ISampler.from_pretrained(
|
||||
root_dir="checkpoints/karlo_models",
|
||||
clip_model_path="ViT-L-14.pt",
|
||||
clip_stat_path="ViT-L-14_stats.th",
|
||||
sampling_type="default",
|
||||
)
|
||||
state["karlo_prior"] = karlo
|
||||
state["msg"] = "loaded full Karlo"
|
||||
return state
|
||||
else:
|
||||
raise ValueError(f"version {version} unknown!")
|
||||
|
||||
config = OmegaConf.load(config)
|
||||
model, msg = load_model_from_config(config, ckpt, vae_sd=None)
|
||||
state["msg"] = msg
|
||||
|
||||
if load_karlo_prior:
|
||||
from ldm.modules.karlo.kakao.sampler import PriorSampler
|
||||
st.info("Loading KARLO CLIP prior...")
|
||||
karlo_prior = PriorSampler.from_pretrained(
|
||||
root_dir="checkpoints/karlo_models",
|
||||
clip_model_path="ViT-L-14.pt",
|
||||
clip_stat_path="ViT-L-14_stats.th",
|
||||
sampling_type="default",
|
||||
)
|
||||
state["karlo_prior"] = karlo_prior
|
||||
state["model"] = model
|
||||
state["ckpt"] = ckpt
|
||||
state["config"] = config
|
||||
return state
|
||||
|
||||
|
||||
def load_model_from_config(config, ckpt, verbose=False, vae_sd=None):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    msg = None
    if "global_step" in pl_sd:
        msg = f"This is global step {pl_sd['global_step']}. "
        if "model_ema.num_updates" in pl_sd["state_dict"]:
            msg += f"And we got {pl_sd['state_dict']['model_ema.num_updates']} EMA updates."
    global_step = pl_sd.get("global_step", "?")
    sd = pl_sd["state_dict"]
    if vae_sd is not None:
        for k in sd.keys():
            if "first_stage" in k:
                sd[k] = vae_sd[k[len("first_stage_model."):]]

    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    print(f"Loaded global step {global_step}")
    return model, msg

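# Streamlit entry point: build the UI controls, prepare the image-embedding
# conditioning, and run sampling when the "Sample" button is pressed.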
if __name__ == "__main__":
    st.title("Stable unCLIP")
    mode = "txt2img"
    version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
    use_karlo_prior = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
    state = init(version=version, load_karlo_prior=use_karlo_prior)
    prompt = st.text_input("Prompt", "a professional photograph")
    negative_prompt = st.text_input("Negative Prompt", "")
    scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
    number_rows = st.number_input("num rows", value=2, min_value=1, max_value=10)
    number_cols = st.number_input("num cols", value=2, min_value=1, max_value=10)
    steps = st.sidebar.number_input("steps", value=20, min_value=1, max_value=1000)
    eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
    force_full_precision = st.sidebar.checkbox("Force FP32", False)  # TODO: check if/where things break.
    if version != "Full Karlo":
        H = st.sidebar.number_input("H", value=VERSION2SPECS[version]["H"], min_value=64, max_value=2048)
        W = st.sidebar.number_input("W", value=VERSION2SPECS[version]["W"], min_value=64, max_value=2048)
        C = VERSION2SPECS[version]["C"]
        f = VERSION2SPECS[version]["f"]

    SAVE_PATH = os.path.join(SAVE_PATH, version)
    os.makedirs(os.path.join(SAVE_PATH, "samples"), exist_ok=True)

    seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
    seed_everything(seed)

    ucg_schedule = None
    sampler = st.sidebar.selectbox("Sampler", ["DDIM", "DPM"], 0)
    if version == "Full Karlo":
        pass
    else:
        if sampler == "DPM":
            sampler = DPMSolverSampler(state["model"])
        elif sampler == "DDIM":
            sampler = DDIMSampler(state["model"])
        else:
            raise ValueError(f"unknown sampler {sampler}!")

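    # Image-embedding ("adm") conditioning: either predicted from the text
    # prompt by the KARLO prior, or computed from user-supplied input images
    # with the model's CLIP image embedder, plus optional noise augmentation.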
    adm_cond, adm_uc = None, None
    if use_karlo_prior:
        # uses the prior
        karlo_sampler = state["karlo_prior"]
        noise_level = None
        if state["model"].noise_augmentor is not None:
            noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0,
                                          max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0)
        with torch.no_grad():
            karlo_prediction = iter(
                karlo_sampler(
                    prompt=prompt,
                    bsz=number_cols,
                    progressive_mode="final",
                )
            ).__next__()
            adm_cond = karlo_prediction
            if noise_level is not None:
                c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
                    torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
                adm_cond = torch.cat((c_adm, noise_level_emb), 1)
            adm_uc = torch.zeros_like(adm_cond)
    elif version == "Full Karlo":
        pass
    else:
        num_inputs = st.number_input("Number of Input Images", 1)

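        # Nested helper so it can read the Streamlit widgets above: embeds one
        # input image with the model's CLIP image embedder, optionally applies
        # noise augmentation, and returns the weighted conditioning tensors.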
        def make_conditionings_from_input(num=1, key=None):
            init_img = get_init_img(batch_size=number_cols, key=key)
            with torch.no_grad():
                adm_cond = state["model"].embedder(init_img)
            weight = st.slider(f"Weight for Input {num}", min_value=-10., max_value=10., value=1.)
            if state["model"].noise_augmentor is not None:
                noise_level = st.number_input(f"Noise Augmentation for CLIP embedding of input #{num}", min_value=0,
                                              max_value=state["model"].noise_augmentor.max_noise_level - 1,
                                              value=0, )
                c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
                    torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
                adm_cond = torch.cat((c_adm, noise_level_emb), 1) * weight
            adm_uc = torch.zeros_like(adm_cond)
            return adm_cond, adm_uc, weight

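        # Combine the weighted per-image embeddings into a single conditioning
        # and, for more than one input, optionally noise-augment the mix again.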
        adm_inputs = list()
        weights = list()
        for n in range(num_inputs):
            adm_cond, adm_uc, w = make_conditionings_from_input(num=n + 1, key=n)
            weights.append(w)
            adm_inputs.append(adm_cond)
        adm_cond = torch.stack(adm_inputs).sum(0) / sum(weights)
        if num_inputs > 1:
            if st.checkbox("Apply Noise to Embedding Mix", True):
                noise_level = st.number_input("Noise Augmentation for averaged CLIP embeddings", min_value=0,
                                              max_value=state["model"].noise_augmentor.max_noise_level - 1, value=50, )
                c_adm, noise_level_emb = state["model"].noise_augmentor(
                    adm_cond[:, :state["model"].noise_augmentor.time_embed.dim],
                    noise_level=repeat(
                        torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
                adm_cond = torch.cat((c_adm, noise_level_emb), 1)

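    # Sampling is triggered from the UI; the callback below drives the
    # Streamlit progress bar from inside the DDIM/DPM sampling loop.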
    if st.button("Sample"):
        print("running prompt:", prompt)
        st.text("Sampling")
        t_progress = st.progress(0)
        result = st.empty()


        def t_callback(t):
            t_progress.progress(min((t + 1) / steps, 1.))

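        # "Full Karlo" runs the complete KARLO text-to-image sampler and shows
        # a grid directly; the Stable unCLIP variants go through the sample()
        # helper defined earlier in this script, using the adm conditioning
        # prepared above.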
        if version == "Full Karlo":
            outputs = st.empty()
            karlo_sampler = state["karlo_prior"]
            all_samples = list()
            with torch.no_grad():
                for _ in range(number_rows):
                    karlo_prediction = iter(
                        karlo_sampler(
                            prompt=prompt,
                            bsz=number_cols,
                            progressive_mode="final",
                        )
                    ).__next__()
                    all_samples.append(karlo_prediction)
                grid = torch.stack(all_samples, 0)
                grid = rearrange(grid, 'n b c h w -> (n h) (b w) c')
                outputs.image(grid.cpu().numpy())

        else:
            samples = sample(
                state["model"],
                prompt,
                n_runs=number_rows,
                n_samples=number_cols,
                H=H, W=W, C=C, f=f,
                scale=scale,
                ddim_steps=steps,
                ddim_eta=eta,
                callback=t_callback,
                ucg_schedule=ucg_schedule,
                negative_prompt=negative_prompt,
                adm_cond=adm_cond, adm_uc=adm_uc,
                use_full_precision=force_full_precision,
                only_adm_cond=False
            )