stable unclip finetune

This commit is contained in:
Robin Rombach 2023-01-14 13:48:28 +01:00
parent d55bcd4d31
commit 45287f9ed7
26 changed files with 4361 additions and 0 deletions

View file

@ -137,6 +137,42 @@ Note: The inference config for all model versions is designed to be used with EM
For this reason `use_ema=False` is set in the configuration, otherwise the code will try to switch from For this reason `use_ema=False` is set in the configuration, otherwise the code will try to switch from
non-EMA to EMA weights. non-EMA to EMA weights.
### Stable Diffusion Meets Karlo
![upscaling-x4](assets/stable-samples/stable-unclip/panda.jpg)
_++++++ NOTE: preliminary checkpoint for internal testing ++++++_
Recently, [KakaoBrain](https://kakaobrain.com/) openly released [Karlo](https://github.com/kakaobrain/karlo), a pretrained, large-scale replication of [unCLIP](https://arxiv.org/abs/2204.06125) (also known as DALL·E 2).
We introduce _Stable Karlo_, a combination of the Karlo CLIP image embedding prior, and Stable Diffusion v2.1.
More precisely, we finetuned SD 2.1 to accept a CLIP ViT-L/14 image embedding in addition to the text encodings.
This means that the model can be used to produce image variations in the style of unCLIP, but can also be combined with the
embedding prior of KARLO and directly decodes to 768x768 pixel resolution.
To run the model, first download the KARLO checkpoints
```shell
mkdir -p checkpoints/karlo_models
cd checkpoints/karlo_models
wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/096db1af569b284eb76b3881534822d9/ViT-L-14.pt
wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th
wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt
cd ../../
```
and the finetuned SD2.1 checkpoint [+++prelim private upload on HF+++] from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview), and put the ckpt into the `checkpoints folder`
The, run
```
streamlit run scripts/streamlit/stablekarlo.py
```
The script optionally supports sampling from the full Karlo model. To do so, you need to download the 64x64 decoder and 64->256 upscaler
via
```shell
cd checkpoints/karlo_models
wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt
wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt
cd ../../
```
### Image Modification with Stable Diffusion ### Image Modification with Stable Diffusion
![depth2img-stable2](assets/stable-samples/depth2img/merged-0000.png) ![depth2img-stable2](assets/stable-samples/depth2img/merged-0000.png)

Binary file not shown.

After

Width:  |  Height:  |  Size: 171 KiB

View file

@ -0,0 +1,37 @@
model:
type: t2i-decoder
diffusion_sampler: uniform
hparams:
image_size: 64
num_channels: 320
num_res_blocks: 3
channel_mult: ''
attention_resolutions: 32,16,8
num_heads: -1
num_head_channels: 64
num_heads_upsample: -1
use_scale_shift_norm: true
dropout: 0.1
clip_dim: 768
clip_emb_mult: 4
text_ctx: 77
xf_width: 1536
xf_layers: 0
xf_heads: 0
xf_final_ln: false
resblock_updown: true
learn_sigma: true
text_drop: 0.3
clip_emb_type: image
clip_emb_drop: 0.1
use_plm: true
diffusion:
steps: 1000
learn_sigma: true
sigma_small: false
noise_schedule: squaredcos_cap_v2
use_kl: false
predict_xstart: false
rescale_learned_sigmas: true
timestep_respacing: ''

View file

@ -0,0 +1,27 @@
model:
type: improved_sr_64_256
diffusion_sampler: uniform
hparams:
channels: 320
depth: 3
channels_multiple:
- 1
- 2
- 3
- 4
dropout: 0.0
diffusion:
steps: 1000
learn_sigma: false
sigma_small: true
noise_schedule: squaredcos_cap_v2
use_kl: false
predict_xstart: false
rescale_learned_sigmas: true
timestep_respacing: '7'
sampling:
timestep_respacing: '7' # fix
clip_denoise: true

View file

@ -0,0 +1,21 @@
model:
type: prior
diffusion_sampler: uniform
hparams:
text_ctx: 77
xf_width: 2048
xf_layers: 20
xf_heads: 32
xf_final_ln: true
text_drop: 0.2
clip_dim: 768
diffusion:
steps: 1000
learn_sigma: false
sigma_small: true
noise_schedule: squaredcos_cap_v2
use_kl: false
predict_xstart: true
rescale_learned_sigmas: false
timestep_respacing: ''

View file

@ -0,0 +1,74 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion
params:
embedding_dropout: 0.25
parameterization: "v"
linear_start: 0.00085
linear_end: 0.0120
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 96
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn-adm
scale_factor: 0.18215
monitor: val/loss_simple_ema
embedder_config:
target: ldm.modules.encoders.modules.ClipImageEmbedder
params:
model: "ViT-L/14"
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
num_classes: "sequential"
adm_in_channels: 768
use_checkpoint: True
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: "softmax-xformers"
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"

View file

@ -1793,3 +1793,58 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
log = super().log_images(*args, **kwargs) log = super().log_images(*args, **kwargs)
log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w') log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
return log return log
class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
def __init__(self, embedder_config, embedding_key="jpg", embedding_dropout=0.5, freeze_embedder=True, *args, **kwargs):
super().__init__(*args, **kwargs)
self.embed_key = embedding_key
self.embedding_dropout = embedding_dropout
self._init_embedder(embedder_config, freeze_embedder)
def _init_embedder(self, config, freeze=True):
embedder = instantiate_from_config(config)
if freeze:
self.embedder = embedder.eval()
self.embedder.train = disabled_train
for param in self.embedder.parameters():
param.requires_grad = False
def get_input(self, batch, k, cond_key=None, bs=None, **kwargs):
outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs)
z, c = outputs[0], outputs[1]
img = batch[self.embed_key][:bs]
img = rearrange(img, 'b h w c -> b c h w')
c_adm = self.embedder(img)
if self.training:
c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0],
device=c_adm.device)[:, None]) * c_adm
all_conds = {"c_crossattn": [c], "c_adm": c_adm}
noutputs = [z, all_conds]
noutputs.extend(outputs[2:])
return noutputs
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, **kwargs):
log = dict()
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True,
return_original_cond=True)
log["inputs"] = x
log["reconstruction"] = xrec
assert self.model.conditioning_key is not None
assert self.cond_stage_key in ["caption", "txt"]
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
log["conditioning"] = xc
uc = self.get_unconditional_conditioning(N, kwargs.get('unconditional_guidance_label', ''))
unconditional_guidance_scale = kwargs.get('unconditional_guidance_scale', 5.)
uc_ = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
ema_scope = self.ema_scope if kwargs.get('use_ema_scope', True) else nullcontext
with ema_scope(f"Sampling"):
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=True,
ddim_steps=kwargs.get('ddim_steps', 50), eta=kwargs.get('ddim_eta', 0.),
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc_, )
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samplescfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
return log

View file

View file

@ -0,0 +1,512 @@
# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Optional, Tuple, Union
import torch
from torch.nn import functional as F
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import UnCLIPScheduler
from ...utils import is_accelerate_available, logging, randn_tensor
from .text_proj import UnCLIPTextProjModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class UnCLIPPipeline(DiffusionPipeline):
"""
Pipeline for text-to-image generation using unCLIP
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
text_encoder ([`CLIPTextModelWithProjection`]):
Frozen text-encoder.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
prior ([`PriorTransformer`]):
The canonincal unCLIP prior to approximate the image embedding from the text embedding.
text_proj ([`UnCLIPTextProjModel`]):
Utility class to prepare and combine the embeddings before they are passed to the decoder.
decoder ([`UNet2DConditionModel`]):
The decoder to invert the image embedding into an image.
super_res_first ([`UNet2DModel`]):
Super resolution unet. Used in all but the last step of the super resolution diffusion process.
super_res_last ([`UNet2DModel`]):
Super resolution unet. Used in the last step of the super resolution diffusion process.
prior_scheduler ([`UnCLIPScheduler`]):
Scheduler used in the prior denoising process. Just a modified DDPMScheduler.
decoder_scheduler ([`UnCLIPScheduler`]):
Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
super_res_scheduler ([`UnCLIPScheduler`]):
Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.
"""
prior: PriorTransformer
decoder: UNet2DConditionModel
text_proj: UnCLIPTextProjModel
text_encoder: CLIPTextModelWithProjection
tokenizer: CLIPTokenizer
super_res_first: UNet2DModel
super_res_last: UNet2DModel
prior_scheduler: UnCLIPScheduler
decoder_scheduler: UnCLIPScheduler
super_res_scheduler: UnCLIPScheduler
def __init__(
self,
prior: PriorTransformer,
decoder: UNet2DConditionModel,
text_encoder: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
text_proj: UnCLIPTextProjModel,
super_res_first: UNet2DModel,
super_res_last: UNet2DModel,
prior_scheduler: UnCLIPScheduler,
decoder_scheduler: UnCLIPScheduler,
super_res_scheduler: UnCLIPScheduler,
):
super().__init__()
self.register_modules(
prior=prior,
decoder=decoder,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_proj=text_proj,
super_res_first=super_res_first,
super_res_last=super_res_last,
prior_scheduler=prior_scheduler,
decoder_scheduler=decoder_scheduler,
super_res_scheduler=super_res_scheduler,
)
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
latents = latents * scheduler.init_noise_sigma
return latents
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
text_attention_mask: Optional[torch.Tensor] = None,
):
if text_model_output is None:
batch_size = len(prompt) if isinstance(prompt, list) else 1
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
text_mask = text_inputs.attention_mask.bool().to(device)
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_encoder_output = self.text_encoder(text_input_ids.to(device))
text_embeddings = text_encoder_output.text_embeds
text_encoder_hidden_states = text_encoder_output.last_hidden_state
else:
batch_size = text_model_output[0].shape[0]
text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
text_mask = text_attention_mask
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
uncond_tokens = [""] * batch_size
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
uncond_text_mask = uncond_input.attention_mask.bool().to(device)
uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len)
seq_len = uncond_text_encoder_hidden_states.shape[1]
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
batch_size * num_images_per_prompt, seq_len, -1
)
uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
# done duplicates
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
text_mask = torch.cat([uncond_text_mask, text_mask])
return text_embeddings, text_encoder_hidden_states, text_mask
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
# TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list
models = [
self.decoder,
self.text_proj,
self.text_encoder,
self.super_res_first,
self.super_res_last,
]
for cpu_offloaded_model in models:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
return self.device
for module in self.decoder.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
@torch.no_grad()
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
prior_num_inference_steps: int = 25,
decoder_num_inference_steps: int = 25,
super_res_num_inference_steps: int = 7,
generator: Optional[torch.Generator] = None,
prior_latents: Optional[torch.FloatTensor] = None,
decoder_latents: Optional[torch.FloatTensor] = None,
super_res_latents: Optional[torch.FloatTensor] = None,
text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
text_attention_mask: Optional[torch.Tensor] = None,
prior_guidance_scale: float = 4.0,
decoder_guidance_scale: float = 8.0,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation. This can only be left undefined if
`text_model_output` and `text_attention_mask` is passed.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prior_num_inference_steps (`int`, *optional*, defaults to 25):
The number of denoising steps for the prior. More denoising steps usually lead to a higher quality
image at the expense of slower inference.
decoder_num_inference_steps (`int`, *optional*, defaults to 25):
The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
image at the expense of slower inference.
super_res_num_inference_steps (`int`, *optional*, defaults to 7):
The number of denoising steps for super resolution. More denoising steps usually lead to a higher
quality image at the expense of slower inference.
generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*):
Pre-generated noisy latents to be used as inputs for the prior.
decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
Pre-generated noisy latents to be used as inputs for the decoder.
super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
Pre-generated noisy latents to be used as inputs for the decoder.
prior_guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
decoder_guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
text_model_output (`CLIPTextModelOutput`, *optional*):
Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs
can be passed for tasks like text embedding interpolations. Make sure to also pass
`text_attention_mask` in this case. `prompt` can the be left to `None`.
text_attention_mask (`torch.Tensor`, *optional*):
Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention
masks are necessary when passing `text_model_output`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
"""
if prompt is not None:
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
else:
batch_size = text_model_output[0].shape[0]
device = self._execution_device
batch_size = batch_size * num_images_per_prompt
do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
)
# prior
self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
prior_timesteps_tensor = self.prior_scheduler.timesteps
embedding_dim = self.prior.config.embedding_dim
prior_latents = self.prepare_latents(
(batch_size, embedding_dim),
text_embeddings.dtype,
device,
generator,
prior_latents,
self.prior_scheduler,
)
for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents
predicted_image_embedding = self.prior(
latent_model_input,
timestep=t,
proj_embedding=text_embeddings,
encoder_hidden_states=text_encoder_hidden_states,
attention_mask=text_mask,
).predicted_image_embedding
if do_classifier_free_guidance:
predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
predicted_image_embedding_text - predicted_image_embedding_uncond
)
if i + 1 == prior_timesteps_tensor.shape[0]:
prev_timestep = None
else:
prev_timestep = prior_timesteps_tensor[i + 1]
prior_latents = self.prior_scheduler.step(
predicted_image_embedding,
timestep=t,
sample=prior_latents,
generator=generator,
prev_timestep=prev_timestep,
).prev_sample
prior_latents = self.prior.post_process_latents(prior_latents)
image_embeddings = prior_latents
# done prior
# decoder
text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
image_embeddings=image_embeddings,
text_embeddings=text_embeddings,
text_encoder_hidden_states=text_encoder_hidden_states,
do_classifier_free_guidance=do_classifier_free_guidance,
)
decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
num_channels_latents = self.decoder.in_channels
height = self.decoder.sample_size
width = self.decoder.sample_size
decoder_latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width),
text_encoder_hidden_states.dtype,
device,
generator,
decoder_latents,
self.decoder_scheduler,
)
for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents
noise_pred = self.decoder(
sample=latent_model_input,
timestep=t,
encoder_hidden_states=text_encoder_hidden_states,
class_labels=additive_clip_time_embeddings,
attention_mask=decoder_text_mask,
).sample
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
if i + 1 == decoder_timesteps_tensor.shape[0]:
prev_timestep = None
else:
prev_timestep = decoder_timesteps_tensor[i + 1]
# compute the previous noisy sample x_t -> x_t-1
decoder_latents = self.decoder_scheduler.step(
noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
).prev_sample
decoder_latents = decoder_latents.clamp(-1, 1)
image_small = decoder_latents
# done decoder
# super res
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
channels = self.super_res_first.in_channels // 2
height = self.super_res_first.sample_size
width = self.super_res_first.sample_size
super_res_latents = self.prepare_latents(
(batch_size, channels, height, width),
image_small.dtype,
device,
generator,
super_res_latents,
self.super_res_scheduler,
)
interpolate_antialias = {}
if "antialias" in inspect.signature(F.interpolate).parameters:
interpolate_antialias["antialias"] = True
image_upscaled = F.interpolate(
image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
)
for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
# no classifier free guidance
if i == super_res_timesteps_tensor.shape[0] - 1:
unet = self.super_res_last
else:
unet = self.super_res_first
latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)
noise_pred = unet(
sample=latent_model_input,
timestep=t,
).sample
if i + 1 == super_res_timesteps_tensor.shape[0]:
prev_timestep = None
else:
prev_timestep = super_res_timesteps_tensor[i + 1]
# compute the previous noisy sample x_t -> x_t-1
super_res_latents = self.super_res_scheduler.step(
noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
).prev_sample
image = super_res_latents
# done super res
# post processing
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)

View file

View file

@ -0,0 +1,182 @@
# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------
# ------------------------------------------------------------------------------------
# Adapted from OpenAI's CLIP (https://github.com/openai/CLIP/)
# ------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
from clip.model import CLIP, convert_weights
from clip.simple_tokenizer import SimpleTokenizer, default_bpe
"""===== Monkey-Patching original CLIP for JIT compile ====="""
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = F.layer_norm(
x.type(torch.float32),
self.normalized_shape,
self.weight,
self.bias,
self.eps,
)
return ret.type(orig_type)
clip.model.LayerNorm = LayerNorm
delattr(clip.model.CLIP, "forward")
"""===== End of Monkey-Patching ====="""
class CustomizedCLIP(CLIP):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@torch.jit.export
def encode_image(self, image):
return self.visual(image)
@torch.jit.export
def encode_text(self, text):
# re-define this function to return unpooled text features
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
x_seq = x
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x_out = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x_out, x_seq
@torch.jit.ignore
def forward(self, image, text):
super().forward(image, text)
@classmethod
def load_from_checkpoint(cls, ckpt_path: str):
state_dict = torch.load(ckpt_path, map_location="cpu").state_dict()
vit = "visual.proj" in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len(
[
k
for k in state_dict.keys()
if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
]
)
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round(
(state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5
)
image_resolution = vision_patch_size * grid_size
else:
counts: list = [
len(
set(
k.split(".")[2]
for k in state_dict
if k.startswith(f"visual.layer{b}")
)
)
for b in [1, 2, 3, 4]
]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round(
(state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5
)
vision_patch_size = None
assert (
output_width**2 + 1
== state_dict["visual.attnpool.positional_embedding"].shape[0]
)
image_resolution = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(
set(
k.split(".")[2]
for k in state_dict
if k.startswith("transformer.resblocks")
)
)
model = cls(
embed_dim,
image_resolution,
vision_layers,
vision_width,
vision_patch_size,
context_length,
vocab_size,
transformer_width,
transformer_heads,
transformer_layers,
)
for key in ["input_resolution", "context_length", "vocab_size"]:
if key in state_dict:
del state_dict[key]
convert_weights(model)
model.load_state_dict(state_dict)
model.eval()
model.float()
return model
class CustomizedTokenizer(SimpleTokenizer):
def __init__(self):
super().__init__(bpe_path=default_bpe())
self.sot_token = self.encoder["<|startoftext|>"]
self.eot_token = self.encoder["<|endoftext|>"]
def padded_tokens_and_mask(self, texts, text_ctx):
assert isinstance(texts, list) and all(
isinstance(elem, str) for elem in texts
), "texts should be a list of strings"
all_tokens = [
[self.sot_token] + self.encode(text) + [self.eot_token] for text in texts
]
mask = [
[True] * min(text_ctx, len(tokens))
+ [False] * max(text_ctx - len(tokens), 0)
for tokens in all_tokens
]
mask = torch.tensor(mask, dtype=torch.bool)
result = torch.zeros(len(all_tokens), text_ctx, dtype=torch.int)
for i, tokens in enumerate(all_tokens):
if len(tokens) > text_ctx:
tokens = tokens[:text_ctx]
tokens[-1] = self.eot_token
result[i, : len(tokens)] = torch.tensor(tokens)
return result, mask

View file

@ -0,0 +1,193 @@
# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------
import copy
import torch
from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
from ldm.modules.karlo.kakao.modules.unet import PLMImUNet
class Text2ImProgressiveModel(torch.nn.Module):
"""
A decoder that generates 64x64px images based on the text prompt.
:param config: yaml config to define the decoder.
:param tokenizer: tokenizer used in clip.
"""
def __init__(
self,
config,
tokenizer,
):
super().__init__()
self._conf = config
self._model_conf = config.model.hparams
self._diffusion_kwargs = dict(
steps=config.diffusion.steps,
learn_sigma=config.diffusion.learn_sigma,
sigma_small=config.diffusion.sigma_small,
noise_schedule=config.diffusion.noise_schedule,
use_kl=config.diffusion.use_kl,
predict_xstart=config.diffusion.predict_xstart,
rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
timestep_respacing=config.diffusion.timestep_respacing,
)
self._tokenizer = tokenizer
self.model = self.create_plm_dec_model()
cf_token, cf_mask = self.set_cf_text_tensor()
self.register_buffer("cf_token", cf_token, persistent=False)
self.register_buffer("cf_mask", cf_mask, persistent=False)
@classmethod
def load_from_checkpoint(cls, config, tokenizer, ckpt_path, strict: bool = True):
ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
model = cls(config, tokenizer)
model.load_state_dict(ckpt, strict=strict)
return model
def create_plm_dec_model(self):
image_size = self._model_conf.image_size
if self._model_conf.channel_mult == "":
if image_size == 256:
channel_mult = (1, 1, 2, 2, 4, 4)
elif image_size == 128:
channel_mult = (1, 1, 2, 3, 4)
elif image_size == 64:
channel_mult = (1, 2, 3, 4)
else:
raise ValueError(f"unsupported image size: {image_size}")
else:
channel_mult = tuple(
int(ch_mult) for ch_mult in self._model_conf.channel_mult.split(",")
)
assert 2 ** (len(channel_mult) + 2) == image_size
attention_ds = []
for res in self._model_conf.attention_resolutions.split(","):
attention_ds.append(image_size // int(res))
return PLMImUNet(
text_ctx=self._model_conf.text_ctx,
xf_width=self._model_conf.xf_width,
in_channels=3,
model_channels=self._model_conf.num_channels,
out_channels=6 if self._model_conf.learn_sigma else 3,
num_res_blocks=self._model_conf.num_res_blocks,
attention_resolutions=tuple(attention_ds),
dropout=self._model_conf.dropout,
channel_mult=channel_mult,
num_heads=self._model_conf.num_heads,
num_head_channels=self._model_conf.num_head_channels,
num_heads_upsample=self._model_conf.num_heads_upsample,
use_scale_shift_norm=self._model_conf.use_scale_shift_norm,
resblock_updown=self._model_conf.resblock_updown,
clip_dim=self._model_conf.clip_dim,
clip_emb_mult=self._model_conf.clip_emb_mult,
clip_emb_type=self._model_conf.clip_emb_type,
clip_emb_drop=self._model_conf.clip_emb_drop,
)
def set_cf_text_tensor(self):
return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)
def get_sample_fn(self, timestep_respacing):
use_ddim = timestep_respacing.startswith(("ddim", "fast"))
diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
diffusion_kwargs.update(timestep_respacing=timestep_respacing)
diffusion = create_gaussian_diffusion(**diffusion_kwargs)
sample_fn = (
diffusion.ddim_sample_loop_progressive
if use_ddim
else diffusion.p_sample_loop_progressive
)
return sample_fn
def forward(
self,
txt_feat,
txt_feat_seq,
tok,
mask,
img_feat=None,
cf_guidance_scales=None,
timestep_respacing=None,
):
# cfg should be enabled in inference
assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0)
assert img_feat is not None
bsz = txt_feat.shape[0]
img_sz = self._model_conf.image_size
def guided_model_fn(x_t, ts, **kwargs):
half = x_t[: len(x_t) // 2]
combined = torch.cat([half, half], dim=0)
model_out = self.model(combined, ts, **kwargs)
eps, rest = model_out[:, :3], model_out[:, 3:]
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cf_guidance_scales.view(-1, 1, 1, 1) * (
cond_eps - uncond_eps
)
eps = torch.cat([half_eps, half_eps], dim=0)
return torch.cat([eps, rest], dim=1)
cf_feat = self.model.cf_param.unsqueeze(0)
cf_feat = cf_feat.expand(bsz // 2, -1)
feat = torch.cat([img_feat, cf_feat.to(txt_feat.device)], dim=0)
cond = {
"y": feat,
"txt_feat": txt_feat,
"txt_feat_seq": txt_feat_seq,
"mask": mask,
}
sample_fn = self.get_sample_fn(timestep_respacing)
sample_outputs = sample_fn(
guided_model_fn,
(bsz, 3, img_sz, img_sz),
noise=None,
device=txt_feat.device,
clip_denoised=True,
model_kwargs=cond,
)
for out in sample_outputs:
sample = out["sample"]
yield sample if cf_guidance_scales is None else sample[
: sample.shape[0] // 2
]
class Text2ImModel(Text2ImProgressiveModel):
def forward(
self,
txt_feat,
txt_feat_seq,
tok,
mask,
img_feat=None,
cf_guidance_scales=None,
timestep_respacing=None,
):
last_out = None
for out in super().forward(
txt_feat,
txt_feat_seq,
tok,
mask,
img_feat,
cf_guidance_scales,
timestep_respacing,
):
last_out = out
return last_out

View file

@ -0,0 +1,138 @@
# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------
import copy
import torch
from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
from ldm.modules.karlo.kakao.modules.xf import PriorTransformer
class PriorDiffusionModel(torch.nn.Module):
"""
A prior that generates clip image feature based on the text prompt.
:param config: yaml config to define the decoder.
:param tokenizer: tokenizer used in clip.
:param clip_mean: mean to normalize the clip image feature (zero-mean, unit variance).
:param clip_std: std to noramlize the clip image feature (zero-mean, unit variance).
"""
def __init__(self, config, tokenizer, clip_mean, clip_std):
super().__init__()
self._conf = config
self._model_conf = config.model.hparams
self._diffusion_kwargs = dict(
steps=config.diffusion.steps,
learn_sigma=config.diffusion.learn_sigma,
sigma_small=config.diffusion.sigma_small,
noise_schedule=config.diffusion.noise_schedule,
use_kl=config.diffusion.use_kl,
predict_xstart=config.diffusion.predict_xstart,
rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
timestep_respacing=config.diffusion.timestep_respacing,
)
self._tokenizer = tokenizer
self.register_buffer("clip_mean", clip_mean[None, :], persistent=False)
self.register_buffer("clip_std", clip_std[None, :], persistent=False)
causal_mask = self.get_causal_mask()
self.register_buffer("causal_mask", causal_mask, persistent=False)
self.model = PriorTransformer(
text_ctx=self._model_conf.text_ctx,
xf_width=self._model_conf.xf_width,
xf_layers=self._model_conf.xf_layers,
xf_heads=self._model_conf.xf_heads,
xf_final_ln=self._model_conf.xf_final_ln,
clip_dim=self._model_conf.clip_dim,
)
cf_token, cf_mask = self.set_cf_text_tensor()
self.register_buffer("cf_token", cf_token, persistent=False)
self.register_buffer("cf_mask", cf_mask, persistent=False)
@classmethod
def load_from_checkpoint(
cls, config, tokenizer, clip_mean, clip_std, ckpt_path, strict: bool = True
):
ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
model = cls(config, tokenizer, clip_mean, clip_std)
model.load_state_dict(ckpt, strict=strict)
return model
def set_cf_text_tensor(self):
return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)
def get_sample_fn(self, timestep_respacing):
use_ddim = timestep_respacing.startswith(("ddim", "fast"))
diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
diffusion_kwargs.update(timestep_respacing=timestep_respacing)
diffusion = create_gaussian_diffusion(**diffusion_kwargs)
sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop
return sample_fn
def get_causal_mask(self):
seq_len = self._model_conf.text_ctx + 4
mask = torch.empty(seq_len, seq_len)
mask.fill_(float("-inf"))
mask.triu_(1)
mask = mask[None, ...]
return mask
def forward(
self,
txt_feat,
txt_feat_seq,
mask,
cf_guidance_scales=None,
timestep_respacing=None,
denoised_fn=True,
):
# cfg should be enabled in inference
assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0)
bsz_ = txt_feat.shape[0]
bsz = bsz_ // 2
def guided_model_fn(x_t, ts, **kwargs):
half = x_t[: len(x_t) // 2]
combined = torch.cat([half, half], dim=0)
model_out = self.model(combined, ts, **kwargs)
eps, rest = (
model_out[:, : int(x_t.shape[1])],
model_out[:, int(x_t.shape[1]) :],
)
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cf_guidance_scales.view(-1, 1) * (
cond_eps - uncond_eps
)
eps = torch.cat([half_eps, half_eps], dim=0)
return torch.cat([eps, rest], dim=1)
cond = {
"text_emb": txt_feat,
"text_enc": txt_feat_seq,
"mask": mask,
"causal_mask": self.causal_mask,
}
sample_fn = self.get_sample_fn(timestep_respacing)
sample = sample_fn(
guided_model_fn,
(bsz_, self.model.clip_dim),
noise=None,
device=txt_feat.device,
clip_denoised=False,
denoised_fn=lambda x: torch.clamp(x, -10, 10),
model_kwargs=cond,
)
sample = (sample * self.clip_std) + self.clip_mean
return sample[:bsz]

View file

@ -0,0 +1,10 @@
# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------
from ldm.modules.karlo.kakao.models.sr_64_256 import SupRes64to256Progressive
class SupRes256to1kProgressive(SupRes64to256Progressive):
pass # no difference currently

View file

@ -0,0 +1,88 @@
# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------
import copy
import torch
from ldm.modules.karlo.kakao.modules.unet import SuperResUNetModel
from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
class ImprovedSupRes64to256ProgressiveModel(torch.nn.Module):
"""
ImprovedSR model fine-tunes the pretrained DDPM-based SR model by using adversarial and perceptual losses.
In specific, the low-resolution sample is iteratively recovered by 6 steps with the frozen pretrained SR model.
In the following additional one step, a seperate fine-tuned model recovers high-frequency details.
This approach greatly improves the fidelity of images of 256x256px, even with small number of reverse steps.
"""
def __init__(self, config):
super().__init__()
self._config = config
self._diffusion_kwargs = dict(
steps=config.diffusion.steps,
learn_sigma=config.diffusion.learn_sigma,
sigma_small=config.diffusion.sigma_small,
noise_schedule=config.diffusion.noise_schedule,
use_kl=config.diffusion.use_kl,
predict_xstart=config.diffusion.predict_xstart,
rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
)
self.model_first_steps = SuperResUNetModel(
in_channels=3, # auto-changed to 6 inside the model
model_channels=config.model.hparams.channels,
out_channels=3,
num_res_blocks=config.model.hparams.depth,
attention_resolutions=(), # no attention
dropout=config.model.hparams.dropout,
channel_mult=config.model.hparams.channels_multiple,
resblock_updown=True,
use_middle_attention=False,
)
self.model_last_step = SuperResUNetModel(
in_channels=3, # auto-changed to 6 inside the model
model_channels=config.model.hparams.channels,
out_channels=3,
num_res_blocks=config.model.hparams.depth,
attention_resolutions=(), # no attention
dropout=config.model.hparams.dropout,
channel_mult=config.model.hparams.channels_multiple,
resblock_updown=True,
use_middle_attention=False,
)
@classmethod
def load_from_checkpoint(cls, config, ckpt_path, strict: bool = True):
ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
model = cls(config)
model.load_state_dict(ckpt, strict=strict)
return model
def get_sample_fn(self, timestep_respacing):
diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
diffusion_kwargs.update(timestep_respacing=timestep_respacing)
diffusion = create_gaussian_diffusion(**diffusion_kwargs)
return diffusion.p_sample_loop_progressive_for_improved_sr
def forward(self, low_res, timestep_respacing="7", **kwargs):
assert (
timestep_respacing == "7"
), "different respacing method may work, but no guaranteed"
sample_fn = self.get_sample_fn(timestep_respacing)
sample_outputs = sample_fn(
self.model_first_steps,
self.model_last_step,
shape=low_res.shape,
clip_denoised=True,
model_kwargs=dict(low_res=low_res),
**kwargs,
)
for x in sample_outputs:
sample = x["sample"]
yield sample

View file

@ -0,0 +1,49 @@
# ------------------------------------------------------------------------------------
# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
# ------------------------------------------------------------------------------------
from .diffusion import gaussian_diffusion as gd
from .diffusion.respace import (
SpacedDiffusion,
space_timesteps,
)
def create_gaussian_diffusion(
steps,
learn_sigma,
sigma_small,
noise_schedule,
use_kl,
predict_xstart,
rescale_learned_sigmas,
timestep_respacing,
):
betas = gd.get_named_beta_schedule(noise_schedule, steps)
if use_kl:
loss_type = gd.LossType.RESCALED_KL
elif rescale_learned_sigmas:
loss_type = gd.LossType.RESCALED_MSE
else:
loss_type = gd.LossType.MSE
if not timestep_respacing:
timestep_respacing = [steps]
return SpacedDiffusion(
use_timesteps=space_timesteps(steps, timestep_respacing),
betas=betas,
model_mean_type=(
gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
),
model_var_type=(
(
gd.ModelVarType.FIXED_LARGE
if not sigma_small
else gd.ModelVarType.FIXED_SMALL
)
if not learn_sigma
else gd.ModelVarType.LEARNED_RANGE
),
loss_type=loss_type,
)

View file

@ -0,0 +1,828 @@
# ------------------------------------------------------------------------------------
# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
# ------------------------------------------------------------------------------------
import enum
import math
import numpy as np
import torch as th
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
warmup_time = int(num_diffusion_timesteps * warmup_frac)
betas[:warmup_time] = np.linspace(
beta_start, beta_end, warmup_time, dtype=np.float64
)
return betas
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
"""
This is the deprecated API for creating beta schedules.
See get_named_beta_schedule() for the new library of schedules.
"""
if beta_schedule == "quad":
betas = (
np.linspace(
beta_start**0.5,
beta_end**0.5,
num_diffusion_timesteps,
dtype=np.float64,
)
** 2
)
elif beta_schedule == "linear":
betas = np.linspace(
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
)
elif beta_schedule == "warmup10":
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
elif beta_schedule == "warmup50":
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
elif beta_schedule == "const":
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
betas = 1.0 / np.linspace(
num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
)
else:
raise NotImplementedError(beta_schedule)
assert betas.shape == (num_diffusion_timesteps,)
return betas
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
"""
Get a pre-defined beta schedule for the given name.
The beta schedule library consists of beta schedules which remain similar
in the limit of num_diffusion_timesteps.
Beta schedules may be added, but should not be removed or changed once
they are committed to maintain backwards compatibility.
"""
if schedule_name == "linear":
# Linear schedule from Ho et al, extended to work for any number of
# diffusion steps.
scale = 1000 / num_diffusion_timesteps
return get_beta_schedule(
"linear",
beta_start=scale * 0.0001,
beta_end=scale * 0.02,
num_diffusion_timesteps=num_diffusion_timesteps,
)
elif schedule_name == "squaredcos_cap_v2":
return betas_for_alpha_bar(
num_diffusion_timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
else:
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
class ModelMeanType(enum.Enum):
"""
Which type of output the model predicts.
"""
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
START_X = enum.auto() # the model predicts x_0
EPSILON = enum.auto() # the model predicts epsilon
class ModelVarType(enum.Enum):
"""
What is used as the model's output variance.
The LEARNED_RANGE option has been added to allow the model to predict
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
"""
LEARNED = enum.auto()
FIXED_SMALL = enum.auto()
FIXED_LARGE = enum.auto()
LEARNED_RANGE = enum.auto()
class LossType(enum.Enum):
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
RESCALED_MSE = (
enum.auto()
) # use raw MSE loss (with RESCALED_KL when learning variances)
KL = enum.auto() # use the variational lower-bound
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
def is_vb(self):
return self == LossType.KL or self == LossType.RESCALED_KL
class GaussianDiffusion(th.nn.Module):
"""
Utilities for training and sampling diffusion models.
Original ported from this codebase:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
:param betas: a 1-D numpy array of betas for each diffusion timestep,
starting at T and going to 1.
"""
def __init__(
self,
*,
betas,
model_mean_type,
model_var_type,
loss_type,
):
super(GaussianDiffusion, self).__init__()
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
self.loss_type = loss_type
# Use float64 for accuracy.
betas = np.array(betas, dtype=np.float64)
assert len(betas.shape) == 1, "betas must be 1-D"
assert (betas > 0).all() and (betas <= 1).all()
self.num_timesteps = int(betas.shape[0])
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
alphas_cumprod_next = np.append(alphas_cumprod[1:], 0.0)
assert alphas_cumprod_prev.shape == (self.num_timesteps,)
# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)
log_one_minus_alphas_cumprod = np.log(1.0 - alphas_cumprod)
sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod)
sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (
betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
posterior_log_variance_clipped = np.log(
np.append(posterior_variance[1], posterior_variance[1:])
)
posterior_mean_coef1 = (
betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
)
posterior_mean_coef2 = (
(1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
)
self.register_buffer("betas", th.from_numpy(betas), persistent=False)
self.register_buffer(
"alphas_cumprod", th.from_numpy(alphas_cumprod), persistent=False
)
self.register_buffer(
"alphas_cumprod_prev", th.from_numpy(alphas_cumprod_prev), persistent=False
)
self.register_buffer(
"alphas_cumprod_next", th.from_numpy(alphas_cumprod_next), persistent=False
)
self.register_buffer(
"sqrt_alphas_cumprod", th.from_numpy(sqrt_alphas_cumprod), persistent=False
)
self.register_buffer(
"sqrt_one_minus_alphas_cumprod",
th.from_numpy(sqrt_one_minus_alphas_cumprod),
persistent=False,
)
self.register_buffer(
"log_one_minus_alphas_cumprod",
th.from_numpy(log_one_minus_alphas_cumprod),
persistent=False,
)
self.register_buffer(
"sqrt_recip_alphas_cumprod",
th.from_numpy(sqrt_recip_alphas_cumprod),
persistent=False,
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod",
th.from_numpy(sqrt_recipm1_alphas_cumprod),
persistent=False,
)
self.register_buffer(
"posterior_variance", th.from_numpy(posterior_variance), persistent=False
)
self.register_buffer(
"posterior_log_variance_clipped",
th.from_numpy(posterior_log_variance_clipped),
persistent=False,
)
self.register_buffer(
"posterior_mean_coef1",
th.from_numpy(posterior_mean_coef1),
persistent=False,
)
self.register_buffer(
"posterior_mean_coef2",
th.from_numpy(posterior_mean_coef2),
persistent=False,
)
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
)
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = _extract_into_tensor(
self.log_one_minus_alphas_cumprod, t, x_start.shape
)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data for a given number of diffusion steps.
In other words, sample from q(x_t | x_0).
:param x_start: the initial data batch.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:param noise: if specified, the split-out normal noise.
:return: A noisy version of x_start.
"""
if noise is None:
noise = th.randn_like(x_start)
assert noise.shape == x_start.shape
return (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
* noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior:
q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(
self.posterior_log_variance_clipped, t, x_t.shape
)
assert (
posterior_mean.shape[0]
== posterior_variance.shape[0]
== posterior_log_variance_clipped.shape[0]
== x_start.shape[0]
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
model_kwargs=None,
**ignore_kwargs,
):
"""
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
the initial x, x_0.
:param model: the model, which takes a signal and a batch of timesteps
as input.
:param x: the [N x C x ...] tensor at time t.
:param t: a 1-D Tensor of timesteps.
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample. Applies before
clip_denoised.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict with the following keys:
- 'mean': the model mean output.
- 'variance': the model variance output.
- 'log_variance': the log of 'variance'.
- 'pred_xstart': the prediction for x_0.
"""
if model_kwargs is None:
model_kwargs = {}
B, C = x.shape[:2]
assert t.shape == (B,)
model_output = model(x, t, **model_kwargs)
if isinstance(model_output, tuple):
model_output, extra = model_output
else:
extra = None
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
assert model_output.shape == (B, C * 2, *x.shape[2:])
model_output, model_var_values = th.split(model_output, C, dim=1)
if self.model_var_type == ModelVarType.LEARNED:
model_log_variance = model_var_values
model_variance = th.exp(model_log_variance)
else:
min_log = _extract_into_tensor(
self.posterior_log_variance_clipped, t, x.shape
)
max_log = _extract_into_tensor(th.log(self.betas), t, x.shape)
# The model_var_values is [-1, 1] for [min_var, max_var].
frac = (model_var_values + 1) / 2
model_log_variance = frac * max_log + (1 - frac) * min_log
model_variance = th.exp(model_log_variance)
else:
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so
# to get a better decoder log likelihood.
ModelVarType.FIXED_LARGE: (
th.cat([self.posterior_variance[1][None], self.betas[1:]]),
th.log(th.cat([self.posterior_variance[1][None], self.betas[1:]])),
),
ModelVarType.FIXED_SMALL: (
self.posterior_variance,
self.posterior_log_variance_clipped,
),
}[self.model_var_type]
model_variance = _extract_into_tensor(model_variance, t, x.shape)
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
def process_xstart(x):
if denoised_fn is not None:
x = denoised_fn(x)
if clip_denoised:
return x.clamp(-1, 1)
return x
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
pred_xstart = process_xstart(
self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
)
model_mean = model_output
elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
if self.model_mean_type == ModelMeanType.START_X:
pred_xstart = process_xstart(model_output)
else:
pred_xstart = process_xstart(
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
)
model_mean, _, _ = self.q_posterior_mean_variance(
x_start=pred_xstart, x_t=x, t=t
)
else:
raise NotImplementedError(self.model_mean_type)
assert (
model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
)
return {
"mean": model_mean,
"variance": model_variance,
"log_variance": model_log_variance,
"pred_xstart": pred_xstart,
}
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
)
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- pred_xstart
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
"""
Compute the mean for the previous step, given a function cond_fn that
computes the gradient of a conditional log probability with respect to
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
condition on y.
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
"""
gradient = cond_fn(x, t, **model_kwargs)
new_mean = (
p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
)
return new_mean
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
"""
Compute what the p_mean_variance output would have been, should the
model's score function be conditioned by cond_fn.
See condition_mean() for details on cond_fn.
Unlike condition_mean(), this instead uses the conditioning strategy
from Song et al (2020).
"""
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
out = p_mean_var.copy()
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
out["mean"], _, _ = self.q_posterior_mean_variance(
x_start=out["pred_xstart"], x_t=x, t=t
)
return out
def p_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
):
"""
Sample x_{t-1} from the model at the given timestep.
:param model: the model to sample from.
:param x: the current tensor at x_{t-1}.
:param t: the value of t, starting at 0 for the first diffusion step.
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample.
:param cond_fn: if not None, this is a gradient function that acts
similarly to the model.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict containing the following keys:
- 'sample': a random sample from the model.
- 'pred_xstart': a prediction of x_0.
"""
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
noise = th.randn_like(x)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
if cond_fn is not None:
out["mean"] = self.condition_mean(
cond_fn, out, x, t, model_kwargs=model_kwargs
)
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def p_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
):
"""
Generate samples from the model.
:param model: the model module.
:param shape: the shape of the samples, (N, C, H, W).
:param noise: if specified, the noise from the encoder to sample.
Should be of the same shape as `shape`.
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample.
:param cond_fn: if not None, this is a gradient function that acts
similarly to the model.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param device: if specified, the device to create the samples on.
If not specified, use a model parameter's device.
:param progress: if True, show a tqdm progress bar.
:return: a non-differentiable batch of samples.
"""
final = None
for sample in self.p_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
device=device,
progress=progress,
):
final = sample
return final["sample"]
def p_sample_loop_progressive(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
):
"""
Generate samples from the model and yield intermediate samples from
each timestep of diffusion.
Arguments are the same as p_sample_loop().
Returns a generator over dicts, where each dict is the return value of
p_sample().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
else:
img = th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for idx, i in enumerate(indices):
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.p_sample(
model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
)
yield out
img = out["sample"]
def p_sample_loop_progressive_for_improved_sr(
self,
model,
model_aux,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
):
"""
Modified version of p_sample_loop_progressive for sampling from the improved sr model
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
else:
img = th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for idx, i in enumerate(indices):
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.p_sample(
model_aux if len(indices) - 1 == idx else model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
)
yield out
img = out["sample"]
def ddim_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
eta=0.0,
):
"""
Sample x_{t-1} from the model using DDIM.
Same usage as p_sample().
"""
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
if cond_fn is not None:
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
sigma = (
eta
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
)
# Equation 12.
noise = th.randn_like(x)
mean_pred = (
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
sample = mean_pred + nonzero_mask * sigma * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def ddim_reverse_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
eta=0.0,
):
"""
Sample x_{t+1} from the model using DDIM reverse ODE.
"""
assert eta == 0.0, "Reverse ODE only for deterministic path"
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
if cond_fn is not None:
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
- out["pred_xstart"]
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
# Equation 12. reversed
mean_pred = (
out["pred_xstart"] * th.sqrt(alpha_bar_next)
+ th.sqrt(1 - alpha_bar_next) * eps
)
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
def ddim_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
eta=0.0,
):
"""
Generate samples from the model using DDIM.
Same usage as p_sample_loop().
"""
final = None
for sample in self.ddim_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
device=device,
progress=progress,
eta=eta,
):
final = sample
return final["sample"]
def ddim_sample_loop_progressive(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
eta=0.0,
):
"""
Use DDIM to sample from the model and yield intermediate samples from
each timestep of DDIM.
Same usage as p_sample_loop_progressive().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
else:
img = th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.ddim_sample(
model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
eta=eta,
)
yield out
img = out["sample"]
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = arr.to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + th.zeros(broadcast_shape, device=timesteps.device)

View file

@ -0,0 +1,112 @@
# ------------------------------------------------------------------------------------
# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
# ------------------------------------------------------------------------------------
import torch as th
from .gaussian_diffusion import GaussianDiffusion
def space_timesteps(num_timesteps, section_counts):
"""
Create a list of timesteps to use from an original diffusion process,
given the number of timesteps we want to take from equally-sized portions
of the original process.
For example, if there's 300 timesteps and the section counts are [10,15,20]
then the first 100 timesteps are strided to be 10 timesteps, the second 100
are strided to be 15 timesteps, and the final 100 are strided to be 20.
:param num_timesteps: the number of diffusion steps in the original
process to divide up.
:param section_counts: either a list of numbers, or a string containing
comma-separated numbers, indicating the step count
per section. As a special case, use "ddimN" where N
is a number of steps to use the striding from the
DDIM paper.
:return: a set of diffusion steps from the original process to use.
"""
if isinstance(section_counts, str):
if section_counts.startswith("ddim"):
desired_count = int(section_counts[len("ddim") :])
for i in range(1, num_timesteps):
if len(range(0, num_timesteps, i)) == desired_count:
return set(range(0, num_timesteps, i))
raise ValueError(
f"cannot create exactly {num_timesteps} steps with an integer stride"
)
elif section_counts == "fast27":
steps = space_timesteps(num_timesteps, "10,10,3,2,2")
# Help reduce DDIM artifacts from noisiest timesteps.
steps.remove(num_timesteps - 1)
steps.add(num_timesteps - 3)
return steps
section_counts = [int(x) for x in section_counts.split(",")]
size_per = num_timesteps // len(section_counts)
extra = num_timesteps % len(section_counts)
start_idx = 0
all_steps = []
for i, section_count in enumerate(section_counts):
size = size_per + (1 if i < extra else 0)
if size < section_count:
raise ValueError(
f"cannot divide section of {size} steps into {section_count}"
)
if section_count <= 1:
frac_stride = 1
else:
frac_stride = (size - 1) / (section_count - 1)
cur_idx = 0.0
taken_steps = []
for _ in range(section_count):
taken_steps.append(start_idx + round(cur_idx))
cur_idx += frac_stride
all_steps += taken_steps
start_idx += size
return set(all_steps)
class SpacedDiffusion(GaussianDiffusion):
"""
A diffusion process which can skip steps in a base diffusion process.
:param use_timesteps: a collection (sequence or set) of timesteps from the
original diffusion process to retain.
:param kwargs: the kwargs to create the base diffusion process.
"""
def __init__(self, use_timesteps, **kwargs):
self.use_timesteps = set(use_timesteps)
self.original_num_steps = len(kwargs["betas"])
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
last_alpha_cumprod = 1.0
new_betas = []
timestep_map = []
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
if i in self.use_timesteps:
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
last_alpha_cumprod = alpha_cumprod
timestep_map.append(i)
kwargs["betas"] = th.tensor(new_betas).numpy()
super().__init__(**kwargs)
self.register_buffer("timestep_map", th.tensor(timestep_map), persistent=False)
def p_mean_variance(self, model, *args, **kwargs):
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
def condition_mean(self, cond_fn, *args, **kwargs):
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
def condition_score(self, cond_fn, *args, **kwargs):
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
def _wrap_model(self, model):
def wrapped(x, ts, **kwargs):
ts_cpu = ts.detach().to("cpu")
return model(
x, self.timestep_map[ts_cpu].to(device=ts.device, dtype=ts.dtype), **kwargs
)
return wrapped

View file

@ -0,0 +1,114 @@
# ------------------------------------------------------------------------------------
# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
# ------------------------------------------------------------------------------------
import math
import torch as th
import torch.nn as nn
import torch.nn.functional as F
class GroupNorm32(nn.GroupNorm):
def __init__(self, num_groups, num_channels, swish, eps=1e-5):
super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
self.swish = swish
def forward(self, x):
y = super().forward(x.float()).to(x.dtype)
if self.swish == 1.0:
y = F.silu(y)
elif self.swish:
y = y * F.sigmoid(y * float(self.swish))
return y
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def normalization(channels, swish=0.0):
"""
Make a standard normalization layer, with an optional swish activation.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = th.exp(
-math.log(max_period)
* th.arange(start=0, end=half, dtype=th.float32, device=timesteps.device)
/ half
)
args = timesteps[:, None].float() * freqs[None]
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
if dim % 2:
embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))

View file

@ -0,0 +1,68 @@
# ------------------------------------------------------------------------------------
# Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion)
# ------------------------------------------------------------------------------------
from abc import abstractmethod
import torch as th
def create_named_schedule_sampler(name, diffusion):
"""
Create a ScheduleSampler from a library of pre-defined samplers.
:param name: the name of the sampler.
:param diffusion: the diffusion object to sample for.
"""
if name == "uniform":
return UniformSampler(diffusion)
else:
raise NotImplementedError(f"unknown schedule sampler: {name}")
class ScheduleSampler(th.nn.Module):
"""
A distribution over timesteps in the diffusion process, intended to reduce
variance of the objective.
By default, samplers perform unbiased importance sampling, in which the
objective's mean is unchanged.
However, subclasses may override sample() to change how the resampled
terms are reweighted, allowing for actual changes in the objective.
"""
@abstractmethod
def weights(self):
"""
Get a numpy array of weights, one per diffusion step.
The weights needn't be normalized, but must be positive.
"""
def sample(self, batch_size, device):
"""
Importance-sample timesteps for a batch.
:param batch_size: the number of timesteps.
:param device: the torch device to save to.
:return: a tuple (timesteps, weights):
- timesteps: a tensor of timestep indices.
- weights: a tensor of weights to scale the resulting losses.
"""
w = self.weights()
p = w / th.sum(w)
indices = p.multinomial(batch_size, replacement=True)
weights = 1 / (len(p) * p[indices])
return indices, weights
class UniformSampler(ScheduleSampler):
def __init__(self, diffusion):
super(UniformSampler, self).__init__()
self.diffusion = diffusion
self.register_buffer(
"_weights", th.ones([diffusion.num_timesteps]), persistent=False
)
def weights(self):
return self._weights

View file

@ -0,0 +1,792 @@
# ------------------------------------------------------------------------------------
# Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion)
# ------------------------------------------------------------------------------------
import math
from abc import abstractmethod
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from .nn import (
avg_pool_nd,
conv_nd,
linear,
normalization,
timestep_embedding,
zero_module,
)
from .xf import LayerNorm
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb, encoder_out=None, mask=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, AttentionBlock):
x = layer(x, encoder_out, mask=mask)
else:
x = layer(x)
return x
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=1
)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
up=False,
down=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential(
normalization(channels, swish=1.0),
nn.Identity(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(
self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0
),
nn.SiLU() if use_scale_shift_norm else nn.Identity(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
class ResBlockNoTimeEmbedding(nn.Module):
"""
A residual block without time embedding
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
dims=2,
use_checkpoint=False,
up=False,
down=False,
**kwargs,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.in_layers = nn.Sequential(
normalization(channels, swish=1.0),
nn.Identity(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
self.out_layers = nn.Sequential(
normalization(self.out_channels, swish=1.0),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb=None):
"""
Apply the block to a Tensor, NOT conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
assert emb is None
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
h = self.out_layers(h)
return self.skip_connection(x) + h
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
use_checkpoint=False,
encoder_channels=None,
):
super().__init__()
self.channels = channels
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels, swish=0.0)
self.qkv = conv_nd(1, channels, channels * 3, 1)
self.attention = QKVAttention(self.num_heads)
if encoder_channels is not None:
self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x, encoder_out=None, mask=None):
b, c, *spatial = x.shape
qkv = self.qkv(self.norm(x).view(b, c, -1))
if encoder_out is not None:
encoder_out = self.encoder_kv(encoder_out)
h = self.attention(qkv, encoder_out, mask=mask)
else:
h = self.attention(qkv)
h = self.proj_out(h)
return x + h.reshape(b, c, *spatial)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv, encoder_kv=None, mask=None):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
if encoder_kv is not None:
assert encoder_kv.shape[1] == self.n_heads * ch * 2
ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
k = th.cat([ek, k], dim=-1)
v = th.cat([ev, v], dim=-1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum("bct,bcs->bts", q * scale, k * scale)
if mask is not None:
mask = F.pad(mask, (0, length), value=0.0)
mask = (
mask.unsqueeze(1)
.expand(-1, self.n_heads, -1)
.reshape(bs * self.n_heads, 1, -1)
)
weight = weight + mask
weight = th.softmax(weight, dim=-1)
a = th.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
class UNetModel(nn.Module):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param clip_dim: dimension of clip feature.
:param num_classes: if specified (as an int), then this model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
:param encoder_channels: use to make the dimension of query and kv same in AttentionBlock.
:param use_time_embedding: use time embedding for condition.
"""
def __init__(
self,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
clip_dim=None,
use_checkpoint=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
use_middle_attention=True,
resblock_updown=False,
encoder_channels=None,
use_time_embedding=True,
):
super().__init__()
if num_heads_upsample == -1:
num_heads_upsample = num_heads
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.clip_dim = clip_dim
self.use_checkpoint = use_checkpoint
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.use_middle_attention = use_middle_attention
self.use_time_embedding = use_time_embedding
if self.use_time_embedding:
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if self.clip_dim is not None:
self.clip_emb = nn.Linear(clip_dim, time_embed_dim)
else:
time_embed_dim = None
CustomResidualBlock = (
ResBlock if self.use_time_embedding else ResBlockNoTimeEmbedding
)
ch = input_ch = int(channel_mult[0] * model_channels)
self.input_blocks = nn.ModuleList(
[TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
)
self._feature_size = ch
input_block_chans = [ch]
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
CustomResidualBlock(
ch,
time_embed_dim,
dropout,
out_channels=int(mult * model_channels),
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = int(mult * model_channels)
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=num_head_channels,
encoder_channels=encoder_channels,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
CustomResidualBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
CustomResidualBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
*(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=num_head_channels,
encoder_channels=encoder_channels,
),
)
if self.use_middle_attention
else tuple(), # add AttentionBlock or not
CustomResidualBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
CustomResidualBlock(
ch + ich,
time_embed_dim,
dropout,
out_channels=int(model_channels * mult),
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = int(model_channels * mult)
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads_upsample,
num_head_channels=num_head_channels,
encoder_channels=encoder_channels,
)
)
if level and i == num_res_blocks:
out_ch = ch
layers.append(
CustomResidualBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
normalization(ch, swish=1.0),
nn.Identity(),
zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
)
def forward(self, x, timesteps, y=None):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
assert (y is not None) == (
self.clip_dim is not None
), "must specify y if and only if the model is clip-rep-conditional"
hs = []
if self.use_time_embedding:
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
if self.clip_dim is not None:
emb = emb + self.clip_emb(y)
else:
emb = None
h = x
for module in self.input_blocks:
h = module(h, emb)
hs.append(h)
h = self.middle_block(h, emb)
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb)
return self.out(h)
class SuperResUNetModel(UNetModel):
"""
A UNetModel that performs super-resolution.
Expects an extra kwarg `low_res` to condition on a low-resolution image.
Assumes that the shape of low-resolution and the input should be the same.
"""
def __init__(self, *args, **kwargs):
if "in_channels" in kwargs:
kwargs = dict(kwargs)
kwargs["in_channels"] = kwargs["in_channels"] * 2
else:
# Curse you, Python. Or really, just curse positional arguments :|.
args = list(args)
args[1] = args[1] * 2
super().__init__(*args, **kwargs)
def forward(self, x, timesteps, low_res=None, **kwargs):
_, _, new_height, new_width = x.shape
assert new_height == low_res.shape[2] and new_width == low_res.shape[3]
x = th.cat([x, low_res], dim=1)
return super().forward(x, timesteps, **kwargs)
class PLMImUNet(UNetModel):
"""
A UNetModel that conditions on text with a pretrained text encoder in CLIP.
:param text_ctx: number of text tokens to expect.
:param xf_width: width of the transformer.
:param clip_emb_mult: #extra tokens by projecting clip text feature.
:param clip_emb_type: type of condition (here, we fix clip image feature).
:param clip_emb_drop: dropout rato of clip image feature for cfg.
"""
def __init__(
self,
text_ctx,
xf_width,
*args,
clip_emb_mult=None,
clip_emb_type="image",
clip_emb_drop=0.0,
**kwargs,
):
self.text_ctx = text_ctx
self.xf_width = xf_width
self.clip_emb_mult = clip_emb_mult
self.clip_emb_type = clip_emb_type
self.clip_emb_drop = clip_emb_drop
if not xf_width:
super().__init__(*args, **kwargs, encoder_channels=None)
else:
super().__init__(*args, **kwargs, encoder_channels=xf_width)
# Project text encoded feat seq from pre-trained text encoder in CLIP
self.text_seq_proj = nn.Sequential(
nn.Linear(self.clip_dim, xf_width),
LayerNorm(xf_width),
)
# Project CLIP text feat
self.text_feat_proj = nn.Linear(self.clip_dim, self.model_channels * 4)
assert clip_emb_mult is not None
assert clip_emb_type == "image"
assert self.clip_dim is not None, "CLIP representation dim should be specified"
self.clip_tok_proj = nn.Linear(
self.clip_dim, self.xf_width * self.clip_emb_mult
)
if self.clip_emb_drop > 0:
self.cf_param = nn.Parameter(th.empty(self.clip_dim, dtype=th.float32))
def proc_clip_emb_drop(self, feat):
if self.clip_emb_drop > 0:
bsz, feat_dim = feat.shape
assert (
feat_dim == self.clip_dim
), f"CLIP input dim: {feat_dim}, model CLIP dim: {self.clip_dim}"
drop_idx = th.rand((bsz,), device=feat.device) < self.clip_emb_drop
feat = th.where(
drop_idx[..., None], self.cf_param[None].type_as(feat), feat
)
return feat
def forward(
self, x, timesteps, txt_feat=None, txt_feat_seq=None, mask=None, y=None
):
bsz = x.shape[0]
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
emb = emb + self.clip_emb(y)
xf_out = self.text_seq_proj(txt_feat_seq)
xf_out = xf_out.permute(0, 2, 1)
emb = emb + self.text_feat_proj(txt_feat)
xf_out = th.cat(
[
self.clip_tok_proj(y).reshape(bsz, -1, self.clip_emb_mult),
xf_out,
],
dim=2,
)
mask = F.pad(mask, (self.clip_emb_mult, 0), value=True)
mask = th.where(mask, 0.0, float("-inf"))
h = x
for module in self.input_blocks:
h = module(h, emb, xf_out, mask=mask)
hs.append(h)
h = self.middle_block(h, emb, xf_out, mask=mask)
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, xf_out, mask=mask)
h = self.out(h)
return h

View file

@ -0,0 +1,231 @@
# ------------------------------------------------------------------------------------
# Adapted from the repos below:
# (a) Guided-Diffusion (https://github.com/openai/guided-diffusion)
# (b) CLIP ViT (https://github.com/openai/CLIP/)
# ------------------------------------------------------------------------------------
import math
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from .nn import timestep_embedding
def convert_module_to_f16(param):
"""
Convert primitive modules to float16.
"""
if isinstance(param, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
param.weight.data = param.weight.data.half()
if param.bias is not None:
param.bias.data = param.bias.data.half()
class LayerNorm(nn.LayerNorm):
"""
Implementation that supports fp16 inputs but fp32 gains/biases.
"""
def forward(self, x: th.Tensor):
return super().forward(x.float()).to(x.dtype)
class MultiheadAttention(nn.Module):
def __init__(self, n_ctx, width, heads):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.heads = heads
self.c_qkv = nn.Linear(width, width * 3)
self.c_proj = nn.Linear(width, width)
self.attention = QKVMultiheadAttention(heads, n_ctx)
def forward(self, x, mask=None):
x = self.c_qkv(x)
x = self.attention(x, mask=mask)
x = self.c_proj(x)
return x
class MLP(nn.Module):
def __init__(self, width):
super().__init__()
self.width = width
self.c_fc = nn.Linear(width, width * 4)
self.c_proj = nn.Linear(width * 4, width)
self.gelu = nn.GELU()
def forward(self, x):
return self.c_proj(self.gelu(self.c_fc(x)))
class QKVMultiheadAttention(nn.Module):
def __init__(self, n_heads: int, n_ctx: int):
super().__init__()
self.n_heads = n_heads
self.n_ctx = n_ctx
def forward(self, qkv, mask=None):
bs, n_ctx, width = qkv.shape
attn_ch = width // self.n_heads // 3
scale = 1 / math.sqrt(math.sqrt(attn_ch))
qkv = qkv.view(bs, n_ctx, self.n_heads, -1)
q, k, v = th.split(qkv, attn_ch, dim=-1)
weight = th.einsum("bthc,bshc->bhts", q * scale, k * scale)
wdtype = weight.dtype
if mask is not None:
weight = weight + mask[:, None, ...]
weight = th.softmax(weight, dim=-1).type(wdtype)
return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
n_ctx: int,
width: int,
heads: int,
):
super().__init__()
self.attn = MultiheadAttention(
n_ctx,
width,
heads,
)
self.ln_1 = LayerNorm(width)
self.mlp = MLP(width)
self.ln_2 = LayerNorm(width)
def forward(self, x, mask=None):
x = x + self.attn(self.ln_1(x), mask=mask)
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(
self,
n_ctx: int,
width: int,
layers: int,
heads: int,
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.layers = layers
self.resblocks = nn.ModuleList(
[
ResidualAttentionBlock(
n_ctx,
width,
heads,
)
for _ in range(layers)
]
)
def forward(self, x, mask=None):
for block in self.resblocks:
x = block(x, mask=mask)
return x
class PriorTransformer(nn.Module):
"""
A Causal Transformer that conditions on CLIP text embedding, text.
:param text_ctx: number of text tokens to expect.
:param xf_width: width of the transformer.
:param xf_layers: depth of the transformer.
:param xf_heads: heads in the transformer.
:param xf_final_ln: use a LayerNorm after the output layer.
:param clip_dim: dimension of clip feature.
"""
def __init__(
self,
text_ctx,
xf_width,
xf_layers,
xf_heads,
xf_final_ln,
clip_dim,
):
super().__init__()
self.text_ctx = text_ctx
self.xf_width = xf_width
self.xf_layers = xf_layers
self.xf_heads = xf_heads
self.clip_dim = clip_dim
self.ext_len = 4
self.time_embed = nn.Sequential(
nn.Linear(xf_width, xf_width),
nn.SiLU(),
nn.Linear(xf_width, xf_width),
)
self.text_enc_proj = nn.Linear(clip_dim, xf_width)
self.text_emb_proj = nn.Linear(clip_dim, xf_width)
self.clip_img_proj = nn.Linear(clip_dim, xf_width)
self.out_proj = nn.Linear(xf_width, clip_dim)
self.transformer = Transformer(
text_ctx + self.ext_len,
xf_width,
xf_layers,
xf_heads,
)
if xf_final_ln:
self.final_ln = LayerNorm(xf_width)
else:
self.final_ln = None
self.positional_embedding = nn.Parameter(
th.empty(1, text_ctx + self.ext_len, xf_width)
)
self.prd_emb = nn.Parameter(th.randn((1, 1, xf_width)))
nn.init.normal_(self.prd_emb, std=0.01)
nn.init.normal_(self.positional_embedding, std=0.01)
def forward(
self,
x,
timesteps,
text_emb=None,
text_enc=None,
mask=None,
causal_mask=None,
):
bsz = x.shape[0]
mask = F.pad(mask, (0, self.ext_len), value=True)
t_emb = self.time_embed(timestep_embedding(timesteps, self.xf_width))
text_enc = self.text_enc_proj(text_enc)
text_emb = self.text_emb_proj(text_emb)
x = self.clip_img_proj(x)
input_seq = [
text_enc,
text_emb[:, None, :],
t_emb[:, None, :],
x[:, None, :],
self.prd_emb.to(x.dtype).expand(bsz, -1, -1),
]
input = th.cat(input_seq, dim=1)
input = input + self.positional_embedding.to(input.dtype)
mask = th.where(mask, 0.0, float("-inf"))
mask = (mask[:, None, :] + causal_mask).to(input.dtype)
out = self.transformer(input, mask=mask)
if self.final_ln is not None:
out = self.final_ln(out)
out = self.out_proj(out[:, -1])
return out

View file

@ -0,0 +1,272 @@
# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# source: https://github.com/kakaobrain/karlo/blob/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/karlo/sampler/t2i.py#L15
# ------------------------------------------------------------------------------------
from typing import Iterator
import torch
import torchvision.transforms.functional as TVF
from torchvision.transforms import InterpolationMode
from .template import BaseSampler, CKPT_PATH
class T2ISampler(BaseSampler):
"""
A sampler for text-to-image generation.
:param root_dir: directory for model checkpoints.
:param sampling_type: ["default", "fast"]
"""
def __init__(
self,
root_dir: str,
sampling_type: str = "default",
):
super().__init__(root_dir, sampling_type)
@classmethod
def from_pretrained(
cls,
root_dir: str,
clip_model_path: str,
clip_stat_path: str,
sampling_type: str = "default",
):
model = cls(
root_dir=root_dir,
sampling_type=sampling_type,
)
model.load_clip(clip_model_path)
model.load_prior(
f"{CKPT_PATH['prior']}",
clip_stat_path=clip_stat_path,
prior_config="configs/karlo/prior_1B_vit_l.yaml"
)
model.load_decoder(f"{CKPT_PATH['decoder']}", decoder_config="configs/karlo/decoder_900M_vit_l.yaml")
model.load_sr_64_256(CKPT_PATH["sr_256"], sr_config="configs/karlo/improved_sr_64_256_1.4B.yaml")
return model
def preprocess(
self,
prompt: str,
bsz: int,
):
"""Setup prompts & cfg scales"""
prompts_batch = [prompt for _ in range(bsz)]
prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
""" Get CLIP text feature """
clip_model = self._clip
tokenizer = self._tokenizer
max_txt_length = self._prior.model.text_ctx
tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
if not (cf_token.shape == tok.shape):
cf_token = cf_token.expand(tok.shape[0], -1)
cf_mask = cf_mask.expand(tok.shape[0], -1)
tok = torch.cat([tok, cf_token], dim=0)
mask = torch.cat([mask, cf_mask], dim=0)
tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
txt_feat, txt_feat_seq = clip_model.encode_text(tok)
return (
prompts_batch,
prior_cf_scales_batch,
decoder_cf_scales_batch,
txt_feat,
txt_feat_seq,
tok,
mask,
)
def __call__(
self,
prompt: str,
bsz: int,
progressive_mode=None,
) -> Iterator[torch.Tensor]:
assert progressive_mode in ("loop", "stage", "final")
with torch.no_grad(), torch.cuda.amp.autocast():
(
prompts_batch,
prior_cf_scales_batch,
decoder_cf_scales_batch,
txt_feat,
txt_feat_seq,
tok,
mask,
) = self.preprocess(
prompt,
bsz,
)
""" Transform CLIP text feature into image feature """
img_feat = self._prior(
txt_feat,
txt_feat_seq,
mask,
prior_cf_scales_batch,
timestep_respacing=self._prior_sm,
)
""" Generate 64x64px images """
images_64_outputs = self._decoder(
txt_feat,
txt_feat_seq,
tok,
mask,
img_feat,
cf_guidance_scales=decoder_cf_scales_batch,
timestep_respacing=self._decoder_sm,
)
images_64 = None
for k, out in enumerate(images_64_outputs):
images_64 = out
if progressive_mode == "loop":
yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
if progressive_mode == "stage":
yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
images_64 = torch.clamp(images_64, -1, 1)
""" Upsample 64x64 to 256x256 """
images_256 = TVF.resize(
images_64,
[256, 256],
interpolation=InterpolationMode.BICUBIC,
antialias=True,
)
images_256_outputs = self._sr_64_256(
images_256, timestep_respacing=self._sr_sm
)
for k, out in enumerate(images_256_outputs):
images_256 = out
if progressive_mode == "loop":
yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
if progressive_mode == "stage":
yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
yield torch.clamp(images_256 * 0.5 + 0.5, 0.0, 1.0)
class PriorSampler(BaseSampler):
"""
A sampler for text-to-image generation, but only the prior.
:param root_dir: directory for model checkpoints.
:param sampling_type: ["default", "fast"]
"""
def __init__(
self,
root_dir: str,
sampling_type: str = "default",
):
super().__init__(root_dir, sampling_type)
@classmethod
def from_pretrained(
cls,
root_dir: str,
clip_model_path: str,
clip_stat_path: str,
sampling_type: str = "default",
):
model = cls(
root_dir=root_dir,
sampling_type=sampling_type,
)
model.load_clip(clip_model_path)
model.load_prior(
f"{CKPT_PATH['prior']}",
clip_stat_path=clip_stat_path,
prior_config="configs/karlo/prior_1B_vit_l.yaml"
)
return model
def preprocess(
self,
prompt: str,
bsz: int,
):
"""Setup prompts & cfg scales"""
prompts_batch = [prompt for _ in range(bsz)]
prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
""" Get CLIP text feature """
clip_model = self._clip
tokenizer = self._tokenizer
max_txt_length = self._prior.model.text_ctx
tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
if not (cf_token.shape == tok.shape):
cf_token = cf_token.expand(tok.shape[0], -1)
cf_mask = cf_mask.expand(tok.shape[0], -1)
tok = torch.cat([tok, cf_token], dim=0)
mask = torch.cat([mask, cf_mask], dim=0)
tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
txt_feat, txt_feat_seq = clip_model.encode_text(tok)
return (
prompts_batch,
prior_cf_scales_batch,
decoder_cf_scales_batch,
txt_feat,
txt_feat_seq,
tok,
mask,
)
def __call__(
self,
prompt: str,
bsz: int,
progressive_mode=None,
) -> Iterator[torch.Tensor]:
assert progressive_mode in ("loop", "stage", "final")
with torch.no_grad(), torch.cuda.amp.autocast():
(
prompts_batch,
prior_cf_scales_batch,
decoder_cf_scales_batch,
txt_feat,
txt_feat_seq,
tok,
mask,
) = self.preprocess(
prompt,
bsz,
)
""" Transform CLIP text feature into image feature """
img_feat = self._prior(
txt_feat,
txt_feat_seq,
mask,
prior_cf_scales_batch,
timestep_respacing=self._prior_sm,
)
yield img_feat

View file

@ -0,0 +1,141 @@
# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------
import os
import logging
import torch
from omegaconf import OmegaConf
from ldm.modules.karlo.kakao.models.clip import CustomizedCLIP, CustomizedTokenizer
from ldm.modules.karlo.kakao.models.prior_model import PriorDiffusionModel
from ldm.modules.karlo.kakao.models.decoder_model import Text2ImProgressiveModel
from ldm.modules.karlo.kakao.models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel
SAMPLING_CONF = {
"default": {
"prior_sm": "25",
"prior_n_samples": 1,
"prior_cf_scale": 4.0,
"decoder_sm": "50",
"decoder_cf_scale": 8.0,
"sr_sm": "7",
},
"fast": {
"prior_sm": "25",
"prior_n_samples": 1,
"prior_cf_scale": 4.0,
"decoder_sm": "25",
"decoder_cf_scale": 8.0,
"sr_sm": "7",
},
}
CKPT_PATH = {
"prior": "prior-ckpt-step=01000000-of-01000000.ckpt",
"decoder": "decoder-ckpt-step=01000000-of-01000000.ckpt",
"sr_256": "improved-sr-ckpt-step=1.2M.ckpt",
}
class BaseSampler:
_PRIOR_CLASS = PriorDiffusionModel
_DECODER_CLASS = Text2ImProgressiveModel
_SR256_CLASS = ImprovedSupRes64to256ProgressiveModel
def __init__(
self,
root_dir: str,
sampling_type: str = "fast",
):
self._root_dir = root_dir
sampling_type = SAMPLING_CONF[sampling_type]
self._prior_sm = sampling_type["prior_sm"]
self._prior_n_samples = sampling_type["prior_n_samples"]
self._prior_cf_scale = sampling_type["prior_cf_scale"]
assert self._prior_n_samples == 1
self._decoder_sm = sampling_type["decoder_sm"]
self._decoder_cf_scale = sampling_type["decoder_cf_scale"]
self._sr_sm = sampling_type["sr_sm"]
def __repr__(self):
line = ""
line += f"Prior, sampling method: {self._prior_sm}, cf_scale: {self._prior_cf_scale}\n"
line += f"Decoder, sampling method: {self._decoder_sm}, cf_scale: {self._decoder_cf_scale}\n"
line += f"SR(64->256), sampling method: {self._sr_sm}"
return line
def load_clip(self, clip_path: str):
clip = CustomizedCLIP.load_from_checkpoint(
os.path.join(self._root_dir, clip_path)
)
clip = torch.jit.script(clip)
clip.cuda()
clip.eval()
self._clip = clip
self._tokenizer = CustomizedTokenizer()
def load_prior(
self,
ckpt_path: str,
clip_stat_path: str,
prior_config: str = "configs/prior_1B_vit_l.yaml"
):
logging.info(f"Loading prior: {ckpt_path}")
config = OmegaConf.load(prior_config)
clip_mean, clip_std = torch.load(
os.path.join(self._root_dir, clip_stat_path), map_location="cpu"
)
prior = self._PRIOR_CLASS.load_from_checkpoint(
config,
self._tokenizer,
clip_mean,
clip_std,
os.path.join(self._root_dir, ckpt_path),
strict=True,
)
prior.cuda()
prior.eval()
logging.info("done.")
self._prior = prior
def load_decoder(self, ckpt_path: str, decoder_config: str = "configs/decoder_900M_vit_l.yaml"):
logging.info(f"Loading decoder: {ckpt_path}")
config = OmegaConf.load(decoder_config)
decoder = self._DECODER_CLASS.load_from_checkpoint(
config,
self._tokenizer,
os.path.join(self._root_dir, ckpt_path),
strict=True,
)
decoder.cuda()
decoder.eval()
logging.info("done.")
self._decoder = decoder
def load_sr_64_256(self, ckpt_path: str, sr_config: str = "configs/improved_sr_64_256_1.4B.yaml"):
logging.info(f"Loading SR(64->256): {ckpt_path}")
config = OmegaConf.load(sr_config)
sr = self._SR256_CLASS.load_from_checkpoint(
config, os.path.join(self._root_dir, ckpt_path), strict=True
)
sr.cuda()
sr.eval()
logging.info("done.")
self._sr_64_256 = sr

View file

@ -0,0 +1,381 @@
import importlib
import streamlit as st
import torch
import cv2
import numpy as np
import PIL
from omegaconf import OmegaConf
from PIL import Image
from tqdm import trange
import io, os
from torch import autocast
from einops import rearrange, repeat
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from contextlib import nullcontext
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
torch.set_grad_enabled(False)
PROMPTS_ROOT = "scripts/prompts/"
SAVE_PATH = "outputs/demo/stable-karlo/"
VERSION2SPECS = {
"Stable Karlo": {"H": 768, "W": 768, "C": 4, "f": 8},
"Full Karlo": {}
}
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
importlib.invalidate_caches()
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def instantiate_from_config(config):
if not "target" in config:
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def get_interactive_image():
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
if image is not None:
image = Image.open(image)
if not image.mode == "RGB":
image = image.convert("RGB")
return image
def load_img(display=True):
image = get_interactive_image()
if display:
st.image(image)
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
w, h = map(lambda x: x - x % 64, (w, h))
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2. * image - 1.
def get_init_img(batch_size=1):
init_image = load_img().cuda()
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
return init_image
def sample(
model,
prompt,
n_runs=3,
n_samples=2,
H=512,
W=512,
C=4,
f=8,
scale=10.0,
ddim_steps=50,
ddim_eta=0.0,
callback=None,
skip_single_save=False,
save_grid=True,
ucg_schedule=None,
negative_prompt="",
adm_cond=None,
adm_uc=None,
use_full_precision=False,
only_adm_cond=False
):
batch_size = n_samples
precision_scope = autocast if not use_full_precision else nullcontext
if use_full_precision: st.warning(f"Sampling {model.__class__.__name__} at full precision.")
if isinstance(prompt, str):
prompt = [prompt]
prompts = batch_size * prompt
outputs = st.empty()
with precision_scope("cuda"):
with model.ema_scope():
all_samples = list()
for n in trange(n_runs, desc="Sampling"):
shape = [C, H // f, W // f]
if not only_adm_cond:
uc = None
if scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [negative_prompt])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
if isinstance(model, Txt2ImgDiffusionWithPooledInput):
c, uc = c[0], uc[0]
if adm_cond is not None:
if adm_cond.shape[0] == 1:
adm_cond = repeat(adm_cond, '1 ... -> b ...', b=batch_size)
if adm_uc is None:
st.warning("Not guiding via c_adm")
adm_uc = adm_cond
else:
if adm_uc.shape[0] == 1:
adm_uc = repeat(adm_uc, '1 ... -> b ...', b=batch_size)
if not only_adm_cond:
c = {"c_crossattn": [c], "c_adm": adm_cond}
uc = {"c_crossattn": [uc], "c_adm": adm_uc}
else:
c = adm_cond
uc = adm_uc
samples_ddim, _ = sampler.sample(S=ddim_steps,
conditioning=c,
batch_size=batch_size,
shape=shape,
verbose=False,
unconditional_guidance_scale=scale,
unconditional_conditioning=uc,
eta=ddim_eta,
x_T=None,
callback=callback,
ucg_schedule=ucg_schedule
)
x_samples = model.decode_first_stage(samples_ddim)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
if not skip_single_save:
base_count = len(os.listdir(os.path.join(SAVE_PATH, "samples")))
for x_sample in x_samples:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
Image.fromarray(x_sample.astype(np.uint8)).save(
os.path.join(SAVE_PATH, "samples", f"{base_count:09}.png"))
base_count += 1
all_samples.append(x_samples)
# get grid of all samples
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n h) (b w) c')
outputs.image(grid.cpu().numpy())
# additionally, save grid
grid = Image.fromarray((255. * grid.cpu().numpy()).astype(np.uint8))
if save_grid:
grid_count = len(os.listdir(SAVE_PATH)) - 1
grid.save(os.path.join(SAVE_PATH, f'grid-{grid_count:06}.png'))
return x_samples
def make_oscillating_guidance_schedule(num_steps, max_weight=15., min_weight=1.):
schedule = list()
for i in range(num_steps):
if float(i / num_steps) < 0.1:
schedule.append(max_weight)
elif i % 2 == 0:
schedule.append(min_weight)
else:
schedule.append(max_weight)
print(f"OSCILLATING GUIDANCE SCHEDULE: \n {schedule}")
return schedule
def torch2np(x):
x = ((x + 1.0) * 127.5).clamp(0, 255).to(dtype=torch.uint8)
x = x.permute(0, 2, 3, 1).detach().cpu().numpy()
return x
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def init(version="Stable Karlo", load_karlo_prior=False):
state = dict()
if not "model" in state:
if version == "Stable Karlo":
config = "configs/stable-diffusion/v2-1-stable-karlo-inference.yaml"
ckpt = "checkpoints/v2-1-stable-unclip-ft.ckpt"
elif version == "Full Karlo":
from ldm.modules.karlo.kakao.sampler import T2ISampler
st.info("Loading full KARLO..")
karlo = T2ISampler.from_pretrained(
root_dir="checkpoints/karlo_models",
clip_model_path="ViT-L-14.pt",
clip_stat_path="ViT-L-14_stats.th",
sampling_type="default",
)
state["karlo_prior"] = karlo
state["msg"] = "loaded full Karlo"
return state
else:
raise ValueError(f"version {version} unknown!")
config = OmegaConf.load(config)
model, msg = load_model_from_config(config, ckpt, vae_sd=None)
state["msg"] = msg
if load_karlo_prior:
from ldm.modules.karlo.kakao.sampler import PriorSampler
st.info("Loading KARLO CLIP prior...")
karlo_prior = PriorSampler.from_pretrained(
root_dir="/fsx/robin/checkpoints/karlo_models",
clip_model_path="ViT-L-14.pt",
clip_stat_path="ViT-L-14_stats.th",
sampling_type="default",
)
state["karlo_prior"] = karlo_prior
state["model"] = model
state["ckpt"] = ckpt
state["config"] = config
return state
def load_model_from_config(config, ckpt, verbose=False, vae_sd=None):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
msg = None
if "global_step" in pl_sd:
msg = f"This is global step {pl_sd['global_step']}. "
if "model_ema.num_updates" in pl_sd["state_dict"]:
msg += f"And we got {pl_sd['state_dict']['model_ema.num_updates']} EMA updates."
global_step = pl_sd.get("global_step", "?")
sd = pl_sd["state_dict"]
if vae_sd is not None:
for k in sd.keys():
if "first_stage" in k:
sd[k] = vae_sd[k[len("first_stage_model."):]]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
model.cuda()
model.eval()
print(f"Loaded global step {global_step}")
return model, msg
if __name__ == "__main__":
st.title("Stable Karlo")
mode = "txt2img"
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
use_karlo = st.checkbox("Use KARLO prior", False)
state = init(version=version, vae_version=vae_version, load_karlo_prior=use_karlo)
st.info(state["msg"])
prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse")
negative_prompt = st.text_input("Negative Prompt", "")
scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
number_rows = st.number_input("num rows", value=2, min_value=1, max_value=10)
number_cols = st.number_input("num cols", value=2, min_value=1, max_value=10)
default_steps = 25
steps = st.sidebar.number_input("steps", value=default_steps, min_value=1, max_value=1000)
eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
force_full_precision = st.sidebar.checkbox("Force FP32", False)
if version != "Full Karlo":
H = st.sidebar.number_input("H", value=VERSION2SPECS[version]["H"], min_value=64, max_value=2048)
W = st.sidebar.number_input("W", value=VERSION2SPECS[version]["W"], min_value=64, max_value=2048)
C = VERSION2SPECS[version]["C"]
f = VERSION2SPECS[version]["f"]
SAVE_PATH = os.path.join(SAVE_PATH, version + "_" + vae_version + "-decoder")
os.makedirs(os.path.join(SAVE_PATH, "samples"), exist_ok=True)
seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
seed_everything(seed)
ucg_schedule = None
sampler = st.sidebar.selectbox("Sampler", ["DDIM", "PLMS", "DPM"], 2)
if version == "Full Karlo":
pass
else:
if sampler == "PLMS":
st.warning("NOTE: Some models (such as v-pred) currently only support DDIM/DPM sampling here")
sampler = PLMSSampler(state["model"])
elif sampler == "DPM":
st.warning("NOTE: Using DPM sampler with default sampling parameters (DPM-2)")
sampler = DPMSolverSampler(state["model"])
elif sampler == "DDIM":
sampler = DDIMSampler(state["model"])
if st.checkbox("Try oscillating guidance?", False):
ucg_schedule = make_oscillating_guidance_schedule(num_steps=steps, max_weight=scale, min_weight=1.)
else:
raise ValueError(f"unknown sampler {sampler}!")
adm_cond, adm_uc = None, None
if use_karlo:
# uses the prior
karlo_sampler = state["karlo_prior"]
with torch.no_grad():
karlo_prediction = iter(
karlo_sampler(
prompt=prompt,
bsz=number_cols,
progressive_mode="final",
)
).__next__()
adm_cond = karlo_prediction
adm_uc = torch.zeros_like(karlo_prediction)
else:
init_img = get_init_img(batch_size=number_cols)
with torch.no_grad():
adm_cond = state["model"].embedder(init_img)
adm_uc = torch.zeros_like(adm_cond)
if st.button("Sample"):
print("running prompt:", prompt)
st.text("Sampling")
t_progress = st.progress(0)
result = st.empty()
def t_callback(t):
t_progress.progress(min((t + 1) / steps, 1.))
if version == "KARLO":
outputs = st.empty()
karlo_sampler = state["karlo_prior"]
all_samples = list()
with torch.no_grad():
for _ in range(number_rows):
karlo_prediction = iter(
karlo_sampler(
prompt=prompt,
bsz=number_cols,
progressive_mode="final",
)
).__next__()
all_samples.append(karlo_prediction)
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n h) (b w) c')
outputs.image(grid.cpu().numpy())
else:
samples = sample(
state["model"],
prompt,
n_runs=number_rows,
n_samples=number_cols,
H=H, W=W, C=C, f=f,
scale=scale,
ddim_steps=steps,
ddim_eta=eta,
callback=t_callback,
ucg_schedule=ucg_schedule,
negative_prompt=negative_prompt,
adm_cond=adm_cond, adm_uc=adm_uc,
use_full_precision=force_full_precision,
only_adm_cond=False
)