# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# source: https://github.com/kakaobrain/karlo/blob/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/karlo/sampler/t2i.py#L15
# ------------------------------------------------------------------------------------

from typing import Iterator

import torch
import torchvision.transforms.functional as TVF
from torchvision.transforms import InterpolationMode

from .template import BaseSampler, CKPT_PATH
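
# CKPT_PATH (imported from .template) maps the stage keys used below --
# "prior", "decoder", "sr_256" -- to checkpoint filenames resolved under
# root_dir; see .template for the exact values.
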
class T2ISampler(BaseSampler):
    """
    A sampler for text-to-image generation.
    :param root_dir: directory for model checkpoints.
    :param sampling_type: ["default", "fast"]
    """

    def __init__(
        self,
        root_dir: str,
        sampling_type: str = "default",
    ):
        super().__init__(root_dir, sampling_type)

    @classmethod
    def from_pretrained(
        cls,
        root_dir: str,
        clip_model_path: str,
        clip_stat_path: str,
        sampling_type: str = "default",
    ):
        model = cls(
            root_dir=root_dir,
            sampling_type=sampling_type,
        )
        model.load_clip(clip_model_path)
        # Text-to-image-embedding prior.
        model.load_prior(
            f"{CKPT_PATH['prior']}",
            clip_stat_path=clip_stat_path,
            prior_config="configs/karlo/prior_1B_vit_l.yaml",
        )
        # Image-embedding-to-64x64 decoder.
        model.load_decoder(f"{CKPT_PATH['decoder']}", decoder_config="configs/karlo/decoder_900M_vit_l.yaml")
        # 64x64-to-256x256 super-resolution.
        model.load_sr_64_256(CKPT_PATH["sr_256"], sr_config="configs/karlo/improved_sr_64_256_1.4B.yaml")
        return model
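
    # Note: sampling_type presumably selects the timestep-respacing presets
    # (self._prior_sm, self._decoder_sm, self._sr_sm) consumed in __call__;
    # those presets live in BaseSampler (.template), not in this file.
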
    def preprocess(
        self,
        prompt: str,
        bsz: int,
    ):
        """Setup prompts & cfg scales"""
        prompts_batch = [prompt for _ in range(bsz)]

        prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")

        decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
        decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")

        # Get CLIP text feature.
        clip_model = self._clip
        tokenizer = self._tokenizer
        max_txt_length = self._prior.model.text_ctx

        # Tokenize the conditional prompts and an empty ("null") prompt for
        # classifier-free guidance, then stack them along the batch dimension
        # so both halves share one forward pass.
        tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
        cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
        if cf_token.shape != tok.shape:
            cf_token = cf_token.expand(tok.shape[0], -1)
            cf_mask = cf_mask.expand(tok.shape[0], -1)

        tok = torch.cat([tok, cf_token], dim=0)
        mask = torch.cat([mask, cf_mask], dim=0)

        tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
        txt_feat, txt_feat_seq = clip_model.encode_text(tok)

        return (
            prompts_batch,
            prior_cf_scales_batch,
            decoder_cf_scales_batch,
            txt_feat,
            txt_feat_seq,
            tok,
            mask,
        )
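
    # A minimal sketch of how the cf scales and the stacked [cond; uncond]
    # batch are meant to combine inside the prior/decoder (eps_cond and
    # eps_uncond are hypothetical per-half model outputs, not names from
    # this file); this is the standard classifier-free guidance formula:
    #
    #   guided_eps = eps_uncond + cf_scale * (eps_cond - eps_uncond)
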
    def __call__(
        self,
        prompt: str,
        bsz: int,
        progressive_mode=None,
    ) -> Iterator[torch.Tensor]:
        assert progressive_mode in ("loop", "stage", "final"), \
            "progressive_mode must be one of: loop, stage, final"
        with torch.no_grad(), torch.cuda.amp.autocast():
            (
                prompts_batch,
                prior_cf_scales_batch,
                decoder_cf_scales_batch,
                txt_feat,
                txt_feat_seq,
                tok,
                mask,
            ) = self.preprocess(
                prompt,
                bsz,
            )

            # Transform CLIP text feature into image feature.
            img_feat = self._prior(
                txt_feat,
                txt_feat_seq,
                mask,
                prior_cf_scales_batch,
                timestep_respacing=self._prior_sm,
            )

            # Generate 64x64px images.
            images_64_outputs = self._decoder(
                txt_feat,
                txt_feat_seq,
                tok,
                mask,
                img_feat,
                cf_guidance_scales=decoder_cf_scales_batch,
                timestep_respacing=self._decoder_sm,
            )

            # "loop" yields every denoising step; "stage" yields once after
            # each stage completes.
            images_64 = None
            for out in images_64_outputs:
                images_64 = out
                if progressive_mode == "loop":
                    yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
            if progressive_mode == "stage":
                yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)

            images_64 = torch.clamp(images_64, -1, 1)

            # Upsample 64x64 to 256x256: bicubic pre-resize, then
            # super-resolution refinement.
            images_256 = TVF.resize(
                images_64,
                [256, 256],
                interpolation=InterpolationMode.BICUBIC,
                antialias=True,
            )
            images_256_outputs = self._sr_64_256(
                images_256, timestep_respacing=self._sr_sm
            )

            for out in images_256_outputs:
                images_256 = out
                if progressive_mode == "loop":
                    yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
            if progressive_mode == "stage":
                yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)

            yield torch.clamp(images_256 * 0.5 + 0.5, 0.0, 1.0)
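
# A minimal usage sketch (checkpoint paths below are placeholders, not
# values shipped with this file):
#
#   sampler = T2ISampler.from_pretrained(
#       root_dir="checkpoints/karlo",
#       clip_model_path="checkpoints/karlo/ViT-L-14.pt",
#       clip_stat_path="checkpoints/karlo/ViT-L-14_stats.th",
#   )
#   for image in sampler("a watercolor painting of a fox", bsz=1,
#                        progressive_mode="final"):
#       pass  # the final yield is a [bsz, 3, 256, 256] tensor in [0, 1]
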
class PriorSampler(BaseSampler):
    """
    A sampler for text-to-image generation, but only the prior.
    :param root_dir: directory for model checkpoints.
    :param sampling_type: ["default", "fast"]
    """

    def __init__(
        self,
        root_dir: str,
        sampling_type: str = "default",
    ):
        super().__init__(root_dir, sampling_type)

    @classmethod
    def from_pretrained(
        cls,
        root_dir: str,
        clip_model_path: str,
        clip_stat_path: str,
        sampling_type: str = "default",
    ):
        model = cls(
            root_dir=root_dir,
            sampling_type=sampling_type,
        )
        model.load_clip(clip_model_path)
        model.load_prior(
            f"{CKPT_PATH['prior']}",
            clip_stat_path=clip_stat_path,
            prior_config="configs/karlo/prior_1B_vit_l.yaml",
        )
        return model
    def preprocess(
        self,
        prompt: str,
        bsz: int,
    ):
        """Setup prompts & cfg scales"""
        # Identical to T2ISampler.preprocess.
        prompts_batch = [prompt for _ in range(bsz)]

        prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")

        decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
        decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")

        # Get CLIP text feature.
        clip_model = self._clip
        tokenizer = self._tokenizer
        max_txt_length = self._prior.model.text_ctx

        tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
        cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
        if cf_token.shape != tok.shape:
            cf_token = cf_token.expand(tok.shape[0], -1)
            cf_mask = cf_mask.expand(tok.shape[0], -1)

        tok = torch.cat([tok, cf_token], dim=0)
        mask = torch.cat([mask, cf_mask], dim=0)

        tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
        txt_feat, txt_feat_seq = clip_model.encode_text(tok)

        return (
            prompts_batch,
            prior_cf_scales_batch,
            decoder_cf_scales_batch,
            txt_feat,
            txt_feat_seq,
            tok,
            mask,
        )
    def __call__(
        self,
        prompt: str,
        bsz: int,
        progressive_mode=None,
    ) -> Iterator[torch.Tensor]:
        assert progressive_mode in ("loop", "stage", "final"), \
            "progressive_mode must be one of: loop, stage, final"
        with torch.no_grad(), torch.cuda.amp.autocast():
            (
                prompts_batch,
                prior_cf_scales_batch,
                decoder_cf_scales_batch,
                txt_feat,
                txt_feat_seq,
                tok,
                mask,
            ) = self.preprocess(
                prompt,
                bsz,
            )

            # Transform CLIP text feature into image feature and stop there:
            # this sampler yields the embedding itself rather than an image.
            img_feat = self._prior(
                txt_feat,
                txt_feat_seq,
                mask,
                prior_cf_scales_batch,
                timestep_respacing=self._prior_sm,
            )

            yield img_feat
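
# PriorSampler usage sketch (hypothetical paths, matching the T2ISampler
# sketch above): it yields only the CLIP image embedding predicted from the
# text prompt, which a decoder can then condition on.
#
#   prior = PriorSampler.from_pretrained(
#       root_dir="checkpoints/karlo",
#       clip_model_path="checkpoints/karlo/ViT-L-14.pt",
#       clip_stat_path="checkpoints/karlo/ViT-L-14_stats.th",
#   )
#   img_feat = next(prior("a watercolor painting of a fox", bsz=1,
#                         progressive_mode="final"))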