Support gradio and streamlit

This commit is contained in:
Ftps 2023-04-02 21:00:29 +09:00
parent d6933311e7
commit 90d4c71350
13 changed files with 154 additions and 68 deletions

View file

@ -1 +0,0 @@
# Only Import opt

View file

@ -6,10 +6,17 @@ from torch.utils.checkpoint import checkpoint
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
import open_clip import open_clip
from ldm import global_opt as g
from ldm.util import default, count_params, autocast from ldm.util import default, count_params, autocast
def get_device():
if torch.cuda.is_available():
return 'cuda'
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
class AbstractEncoder(nn.Module): class AbstractEncoder(nn.Module):
def __init__(self): def __init__(self):
@ -61,9 +68,9 @@ def disabled_train(self, mode=True):
class FrozenT5Embedder(AbstractEncoder): class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text""" """Uses the T5 transformer encoder for text"""
def __init__(self, version="google/t5-v1_1-large", device=g.opt.device, max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl def __init__(self, version="google/t5-v1_1-large", device=get_device(), max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__() super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version) self.tokenizer = T5Tokenizer.from_pretrained(version, device)
self.transformer = T5EncoderModel.from_pretrained(version) self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device self.device = device
self.max_length = max_length # TODO: typical value? self.max_length = max_length # TODO: typical value?
@ -98,7 +105,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
"hidden" "hidden"
] ]
def __init__(self, version="openai/clip-vit-large-patch14", device=g.opt.device, max_length=77, def __init__(self, version="openai/clip-vit-large-patch14", device=get_device(), max_length=77,
freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
super().__init__() super().__init__()
assert layer in self.LAYERS assert layer in self.LAYERS
@ -142,7 +149,7 @@ class ClipImageEmbedder(nn.Module):
self, self,
model, model,
jit=False, jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu', device=get_device(),
antialias=True, antialias=True,
ucg_rate=0. ucg_rate=0.
): ):
@ -185,7 +192,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
"penultimate" "penultimate"
] ]
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device=g.opt.device, max_length=77, def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device=get_device(), max_length=77,
freeze=True, layer="last"): freeze=True, layer="last"):
super().__init__() super().__init__()
assert layer in self.LAYERS assert layer in self.LAYERS
@ -243,7 +250,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
Uses the OpenCLIP vision transformer encoder for images Uses the OpenCLIP vision transformer encoder for images
""" """
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device=get_device(), max_length=77,
freeze=True, layer="pooled", antialias=True, ucg_rate=0.): freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
super().__init__() super().__init__()
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
@ -298,7 +305,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
class FrozenCLIPT5Encoder(AbstractEncoder): class FrozenCLIPT5Encoder(AbstractEncoder):
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device=g.opt.device, def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device=get_device(),
clip_max_length=77, t5_max_length=77): clip_max_length=77, t5_max_length=77):
super().__init__() super().__init__()
self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)

View file

@ -11,6 +11,7 @@ import torch
import torchvision.transforms.functional as TVF import torchvision.transforms.functional as TVF
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
from contextlib import nullcontext
from .template import BaseSampler, CKPT_PATH from .template import BaseSampler, CKPT_PATH
@ -31,6 +32,7 @@ class T2ISampler(BaseSampler):
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, cls,
device,
root_dir: str, root_dir: str,
clip_model_path: str, clip_model_path: str,
clip_stat_path: str, clip_stat_path: str,
@ -41,7 +43,7 @@ class T2ISampler(BaseSampler):
root_dir=root_dir, root_dir=root_dir,
sampling_type=sampling_type, sampling_type=sampling_type,
) )
model.load_clip(clip_model_path) model.load_clip(clip_model_path, device)
model.load_prior( model.load_prior(
f"{CKPT_PATH['prior']}", f"{CKPT_PATH['prior']}",
clip_stat_path=clip_stat_path, clip_stat_path=clip_stat_path,
@ -60,10 +62,10 @@ class T2ISampler(BaseSampler):
prompts_batch = [prompt for _ in range(bsz)] prompts_batch = [prompt for _ in range(bsz)]
prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch) prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda") prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device=self.device)
decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch) decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda") decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device=self.device)
""" Get CLIP text feature """ """ Get CLIP text feature """
clip_model = self._clip clip_model = self._clip
@ -79,7 +81,7 @@ class T2ISampler(BaseSampler):
tok = torch.cat([tok, cf_token], dim=0) tok = torch.cat([tok, cf_token], dim=0)
mask = torch.cat([mask, cf_mask], dim=0) mask = torch.cat([mask, cf_mask], dim=0)
tok, mask = tok.to(device="cuda"), mask.to(device="cuda") tok, mask = tok.to(device=self.device), mask.to(device=self.device)
txt_feat, txt_feat_seq = clip_model.encode_text(tok) txt_feat, txt_feat_seq = clip_model.encode_text(tok)
return ( return (
@ -99,7 +101,8 @@ class T2ISampler(BaseSampler):
progressive_mode=None, progressive_mode=None,
) -> Iterator[torch.Tensor]: ) -> Iterator[torch.Tensor]:
assert progressive_mode in ("loop", "stage", "final") assert progressive_mode in ("loop", "stage", "final")
with torch.no_grad(), torch.cuda.amp.autocast(): precision_scope = nullcontext if self.device.type == 'mps' else torch.cuda.amp.autocast
with torch.no_grad(), precision_scope(self.device.type):
( (
prompts_batch, prompts_batch,
prior_cf_scales_batch, prior_cf_scales_batch,
@ -181,6 +184,7 @@ class PriorSampler(BaseSampler):
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, cls,
device,
root_dir: str, root_dir: str,
clip_model_path: str, clip_model_path: str,
clip_stat_path: str, clip_stat_path: str,
@ -190,7 +194,7 @@ class PriorSampler(BaseSampler):
root_dir=root_dir, root_dir=root_dir,
sampling_type=sampling_type, sampling_type=sampling_type,
) )
model.load_clip(clip_model_path) model.load_clip(clip_model_path, device)
model.load_prior( model.load_prior(
f"{CKPT_PATH['prior']}", f"{CKPT_PATH['prior']}",
clip_stat_path=clip_stat_path, clip_stat_path=clip_stat_path,
@ -207,10 +211,10 @@ class PriorSampler(BaseSampler):
prompts_batch = [prompt for _ in range(bsz)] prompts_batch = [prompt for _ in range(bsz)]
prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch) prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda") prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device=self.device)
decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch) decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda") decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device=self.device)
""" Get CLIP text feature """ """ Get CLIP text feature """
clip_model = self._clip clip_model = self._clip
@ -226,7 +230,7 @@ class PriorSampler(BaseSampler):
tok = torch.cat([tok, cf_token], dim=0) tok = torch.cat([tok, cf_token], dim=0)
mask = torch.cat([mask, cf_mask], dim=0) mask = torch.cat([mask, cf_mask], dim=0)
tok, mask = tok.to(device="cuda"), mask.to(device="cuda") tok, mask = tok.to(device=self.device), mask.to(device=self.device)
txt_feat, txt_feat_seq = clip_model.encode_text(tok) txt_feat, txt_feat_seq = clip_model.encode_text(tok)
return ( return (
@ -246,7 +250,8 @@ class PriorSampler(BaseSampler):
progressive_mode=None, progressive_mode=None,
) -> Iterator[torch.Tensor]: ) -> Iterator[torch.Tensor]:
assert progressive_mode in ("loop", "stage", "final") assert progressive_mode in ("loop", "stage", "final")
with torch.no_grad(), torch.cuda.amp.autocast(): precision_scope = nullcontext if self.device.type == 'mps' else torch.autocast
with torch.no_grad(), precision_scope(self.device.type):
( (
prompts_batch, prompts_batch,
prior_cf_scales_batch, prior_cf_scales_batch,

View file

@ -73,15 +73,16 @@ class BaseSampler:
return line return line
def load_clip(self, clip_path: str): def load_clip(self, clip_path: str, device):
clip = CustomizedCLIP.load_from_checkpoint( clip = CustomizedCLIP.load_from_checkpoint(
os.path.join(self._root_dir, clip_path) os.path.join(self._root_dir, clip_path)
) )
clip = torch.jit.script(clip) clip = torch.jit.script(clip)
clip.cuda() clip.to(device)
clip.eval() clip.eval()
self._clip = clip self._clip = clip
self.device = device
self._tokenizer = CustomizedTokenizer() self._tokenizer = CustomizedTokenizer()
def load_prior( def load_prior(
@ -105,7 +106,7 @@ class BaseSampler:
os.path.join(self._root_dir, ckpt_path), os.path.join(self._root_dir, ckpt_path),
strict=True, strict=True,
) )
prior.cuda() prior.to(self.device)
prior.eval() prior.eval()
logging.info("done.") logging.info("done.")
@ -121,7 +122,7 @@ class BaseSampler:
os.path.join(self._root_dir, ckpt_path), os.path.join(self._root_dir, ckpt_path),
strict=True, strict=True,
) )
decoder.cuda() decoder.to(self.device)
decoder.eval() decoder.eval()
logging.info("done.") logging.info("done.")
@ -134,7 +135,7 @@ class BaseSampler:
sr = self._SR256_CLASS.load_from_checkpoint( sr = self._SR256_CLASS.load_from_checkpoint(
config, os.path.join(self._root_dir, ckpt_path), strict=True config, os.path.join(self._root_dir, ckpt_path), strict=True
) )
sr.cuda() sr.to(self.device)
sr.eval() sr.eval()
logging.info("done.") logging.info("done.")

View file

@ -6,6 +6,7 @@ from PIL import Image
from omegaconf import OmegaConf from omegaconf import OmegaConf
from einops import repeat, rearrange from einops import repeat, rearrange
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from contextlib import nullcontext
from imwatermark import WatermarkEncoder from imwatermark import WatermarkEncoder
from scripts.txt2img import put_watermark from scripts.txt2img import put_watermark
@ -16,15 +17,23 @@ from ldm.data.util import AddMiDaS
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
def get_device():
if torch.cuda.is_available():
return 'cuda'
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
def initialize_model(config, ckpt): def initialize_model(config, ckpt):
config = OmegaConf.load(config) config = OmegaConf.load(config)
model = instantiate_from_config(config.model) model = instantiate_from_config(config.model)
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False) model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
device = torch.device( device = torch.device(get_device())
"cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device) model = model.to(device)
sampler = DDIMSampler(model) sampler = DDIMSampler(model, device)
return sampler return sampler
@ -54,8 +63,7 @@ def make_batch_sd(
def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None, def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None,
do_full_sample=False): do_full_sample=False):
device = torch.device( device = torch.device(get_device())
"cuda") if torch.cuda.is_available() else torch.device("cpu")
model = sampler.model model = sampler.model
seed_everything(seed) seed_everything(seed)
@ -64,8 +72,9 @@ def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=No
wm_encoder = WatermarkEncoder() wm_encoder = WatermarkEncoder()
wm_encoder.set_watermark('bytes', wm.encode('utf-8')) wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
precision_scope = nullcontext if device.type == 'mps' else torch.autocast
with torch.no_grad(),\ with torch.no_grad(),\
torch.autocast("cuda"): precision_scope(device.type):
batch = make_batch_sd( batch = make_batch_sd(
image, txt=prompt, device=device, num_samples=num_samples) image, txt=prompt, device=device, num_samples=num_samples)
z = model.get_first_stage_encoding(model.encode_first_stage( z = model.get_first_stage_encoding(model.encode_first_stage(

View file

@ -6,6 +6,7 @@ import gradio as gr
from PIL import Image from PIL import Image
from omegaconf import OmegaConf from omegaconf import OmegaConf
from einops import repeat from einops import repeat
from contextlib import nullcontext
from imwatermark import WatermarkEncoder from imwatermark import WatermarkEncoder
from pathlib import Path from pathlib import Path
@ -16,6 +17,16 @@ from ldm.util import instantiate_from_config
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
def get_device():
if torch.cuda.is_available():
return 'cuda'
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
def put_watermark(img, wm_encoder=None): def put_watermark(img, wm_encoder=None):
if wm_encoder is not None: if wm_encoder is not None:
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
@ -30,10 +41,9 @@ def initialize_model(config, ckpt):
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False) model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
device = torch.device( device = torch.device(get_device())
"cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device) model = model.to(device)
sampler = DDIMSampler(model) sampler = DDIMSampler(model, device)
return sampler return sampler
@ -67,8 +77,7 @@ def make_batch_sd(
def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512): def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512):
device = torch.device( device = torch.device(get_device())
"cuda") if torch.cuda.is_available() else torch.device("cpu")
model = sampler.model model = sampler.model
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
@ -81,8 +90,9 @@ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1
start_code = torch.from_numpy(start_code).to( start_code = torch.from_numpy(start_code).to(
device=device, dtype=torch.float32) device=device, dtype=torch.float32)
precision_scope = nullcontext if device.type == 'mps' else torch.autocast
with torch.no_grad(),\ with torch.no_grad(),\
torch.autocast("cuda"): precision_scope(device.type):
batch = make_batch_sd(image, mask, txt=prompt, batch = make_batch_sd(image, mask, txt=prompt,
device=device, num_samples=num_samples) device=device, num_samples=num_samples)

View file

@ -6,6 +6,7 @@ from PIL import Image
from omegaconf import OmegaConf from omegaconf import OmegaConf
from einops import repeat, rearrange from einops import repeat, rearrange
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from contextlib import nullcontext
from imwatermark import WatermarkEncoder from imwatermark import WatermarkEncoder
from scripts.txt2img import put_watermark from scripts.txt2img import put_watermark
@ -17,15 +18,23 @@ from ldm.util import exists, instantiate_from_config
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
def get_device():
if torch.cuda.is_available():
return 'cuda'
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
def initialize_model(config, ckpt): def initialize_model(config, ckpt):
config = OmegaConf.load(config) config = OmegaConf.load(config)
model = instantiate_from_config(config.model) model = instantiate_from_config(config.model)
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False) model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
device = torch.device( device = get_device()
"cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device) model = model.to(device)
sampler = DDIMSampler(model) sampler = DDIMSampler(model, device)
return sampler return sampler
@ -54,8 +63,7 @@ def make_noise_augmentation(model, batch, noise_level=None):
def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None): def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None):
device = torch.device( device = torch.device(get_device())
"cuda") if torch.cuda.is_available() else torch.device("cpu")
model = sampler.model model = sampler.model
seed_everything(seed) seed_everything(seed)
prng = np.random.RandomState(seed) prng = np.random.RandomState(seed)
@ -67,8 +75,9 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
wm = "SDV2" wm = "SDV2"
wm_encoder = WatermarkEncoder() wm_encoder = WatermarkEncoder()
wm_encoder.set_watermark('bytes', wm.encode('utf-8')) wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
precision_scope = nullcontext if device.type == 'mps' else torch.autocast
with torch.no_grad(),\ with torch.no_grad(),\
torch.autocast("cuda"): precision_scope(device.type):
batch = make_batch_sd( batch = make_batch_sd(
image, txt=prompt, device=device, num_samples=num_samples) image, txt=prompt, device=device, num_samples=num_samples)
c = model.cond_stage_model.encode(batch["txt"]) c = model.cond_stage_model.encode(batch["txt"])

View file

@ -16,16 +16,15 @@ from pytorch_lightning import seed_everything
from imwatermark import WatermarkEncoder from imwatermark import WatermarkEncoder
from ldm import global_opt as g
from scripts.txt2img import put_watermark from scripts.txt2img import put_watermark
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
def get_device(): def get_device():
if(torch.cuda.is_available()): if torch.cuda.is_available():
return 'cuda' return 'cuda'
elif(torch.backends.mps.is_available()): elif torch.backends.mps.is_available():
return 'mps' return 'mps'
else: else:
return 'cpu' return 'cpu'
@ -197,7 +196,7 @@ def main():
default="autocast" default="autocast"
) )
opt = g.opt = parser.parse_args() opt = parser.parse_args()
seed_everything(opt.seed) seed_everything(opt.seed)
config = OmegaConf.load(f"{opt.config}") config = OmegaConf.load(f"{opt.config}")

View file

@ -6,6 +6,7 @@ from PIL import Image
from omegaconf import OmegaConf from omegaconf import OmegaConf
from einops import repeat, rearrange from einops import repeat, rearrange
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from contextlib import nullcontext
from imwatermark import WatermarkEncoder from imwatermark import WatermarkEncoder
from scripts.txt2img import put_watermark from scripts.txt2img import put_watermark
@ -16,15 +17,25 @@ from ldm.data.util import AddMiDaS
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
def get_device():
if torch.cuda.is_available():
return 'cuda'
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
@st.cache(allow_output_mutation=True) @st.cache(allow_output_mutation=True)
def initialize_model(config, ckpt): def initialize_model(config, ckpt):
config = OmegaConf.load(config) config = OmegaConf.load(config)
model = instantiate_from_config(config.model) model = instantiate_from_config(config.model)
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False) model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") device = torch.device(get_device())
model = model.to(device) model = model.to(device)
sampler = DDIMSampler(model) sampler = DDIMSampler(model, device)
return sampler return sampler
@ -52,7 +63,7 @@ def make_batch_sd(
def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None, def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None,
do_full_sample=False): do_full_sample=False):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") device = torch.device(get_device())
model = sampler.model model = sampler.model
seed_everything(seed) seed_everything(seed)
@ -61,8 +72,9 @@ def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=No
wm_encoder = WatermarkEncoder() wm_encoder = WatermarkEncoder()
wm_encoder.set_watermark('bytes', wm.encode('utf-8')) wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
precision_scope = nullcontext if device.type == 'mps' else torch.autocast
with torch.no_grad(),\ with torch.no_grad(),\
torch.autocast("cuda"): precision_scope(device.type):
batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples) batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
z = model.get_first_stage_encoding(model.encode_first_stage(batch[model.first_stage_key])) # move to latent space z = model.get_first_stage_encoding(model.encode_first_stage(batch[model.first_stage_key])) # move to latent space
c = model.cond_stage_model.encode(batch["txt"]) c = model.cond_stage_model.encode(batch["txt"])

View file

@ -7,6 +7,7 @@ from PIL import Image
from omegaconf import OmegaConf from omegaconf import OmegaConf
from einops import repeat from einops import repeat
from streamlit_drawable_canvas import st_canvas from streamlit_drawable_canvas import st_canvas
from contextlib import nullcontext
from imwatermark import WatermarkEncoder from imwatermark import WatermarkEncoder
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
@ -16,6 +17,15 @@ from ldm.util import instantiate_from_config
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
def get_device():
if torch.cuda.is_available():
return 'cuda'
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
def put_watermark(img, wm_encoder=None): def put_watermark(img, wm_encoder=None):
if wm_encoder is not None: if wm_encoder is not None:
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
@ -31,9 +41,9 @@ def initialize_model(config, ckpt):
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False) model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") device = torch.device(get_device())
model = model.to(device) model = model.to(device)
sampler = DDIMSampler(model) sampler = DDIMSampler(model, device)
return sampler return sampler
@ -67,7 +77,7 @@ def make_batch_sd(
def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512, eta=1.): def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512, eta=1.):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") device = get_device()
model = sampler.model model = sampler.model
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
@ -79,8 +89,9 @@ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1
start_code = prng.randn(num_samples, 4, h // 8, w // 8) start_code = prng.randn(num_samples, 4, h // 8, w // 8)
start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32) start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)
precision_scope = nullcontext if device.type == 'mps' else torch.autocast
with torch.no_grad(),\ with torch.no_grad(),\
torch.autocast("cuda"): precision_scope(device.type):
batch = make_batch_sd(image, mask, txt=prompt, device=device, num_samples=num_samples) batch = make_batch_sd(image, mask, txt=prompt, device=device, num_samples=num_samples)
c = model.cond_stage_model.encode(batch["txt"]) c = model.cond_stage_model.encode(batch["txt"])

View file

@ -30,6 +30,15 @@ VERSION2SPECS = {
} }
def get_device():
if torch.cuda.is_available():
return 'cuda'
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
def get_obj_from_str(string, reload=False): def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1) module, cls = string.rsplit(".", 1)
importlib.invalidate_caches() importlib.invalidate_caches()
@ -69,7 +78,7 @@ def load_img(display=True, key=None):
def get_init_img(batch_size=1, key=None): def get_init_img(batch_size=1, key=None):
init_image = load_img(key=key).cuda() init_image = load_img(key=key).to(get_device())
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
return init_image return init_image
@ -97,7 +106,10 @@ def sample(
only_adm_cond=False only_adm_cond=False
): ):
batch_size = n_samples batch_size = n_samples
device = torch.device(get_device())
precision_scope = autocast if not use_full_precision else nullcontext precision_scope = autocast if not use_full_precision else nullcontext
if device.type == 'mps':
precision_scope = nullcontext
# decoderscope = autocast if not use_full_precision else nullcontext # decoderscope = autocast if not use_full_precision else nullcontext
if use_full_precision: st.warning(f"Running {model.__class__.__name__} at full precision.") if use_full_precision: st.warning(f"Running {model.__class__.__name__} at full precision.")
if isinstance(prompt, str): if isinstance(prompt, str):
@ -106,7 +118,7 @@ def sample(
outputs = st.empty() outputs = st.empty()
with precision_scope("cuda"): with precision_scope(device.type):
with model.ema_scope(): with model.ema_scope():
all_samples = list() all_samples = list()
for n in trange(n_runs, desc="Sampling"): for n in trange(n_runs, desc="Sampling"):
@ -208,6 +220,7 @@ def init(version="Stable unCLIP-L", load_karlo_prior=False):
from ldm.modules.karlo.kakao.sampler import T2ISampler from ldm.modules.karlo.kakao.sampler import T2ISampler
st.info("Loading full KARLO..") st.info("Loading full KARLO..")
karlo = T2ISampler.from_pretrained( karlo = T2ISampler.from_pretrained(
device=get_device(),
root_dir="checkpoints/karlo_models", root_dir="checkpoints/karlo_models",
clip_model_path="ViT-L-14.pt", clip_model_path="ViT-L-14.pt",
clip_stat_path="ViT-L-14_stats.th", clip_stat_path="ViT-L-14_stats.th",
@ -227,6 +240,7 @@ def init(version="Stable unCLIP-L", load_karlo_prior=False):
from ldm.modules.karlo.kakao.sampler import PriorSampler from ldm.modules.karlo.kakao.sampler import PriorSampler
st.info("Loading KARLO CLIP prior...") st.info("Loading KARLO CLIP prior...")
karlo_prior = PriorSampler.from_pretrained( karlo_prior = PriorSampler.from_pretrained(
device=get_device(),
root_dir="checkpoints/karlo_models", root_dir="checkpoints/karlo_models",
clip_model_path="ViT-L-14.pt", clip_model_path="ViT-L-14.pt",
clip_stat_path="ViT-L-14_stats.th", clip_stat_path="ViT-L-14_stats.th",
@ -263,7 +277,7 @@ def load_model_from_config(config, ckpt, verbose=False, vae_sd=None):
print("unexpected keys:") print("unexpected keys:")
print(u) print(u)
model.cuda() model.to(get_device())
model.eval() model.eval()
print(f"Loaded global step {global_step}") print(f"Loaded global step {global_step}")
return model, msg return model, msg
@ -301,7 +315,7 @@ if __name__ == "__main__":
pass pass
else: else:
if sampler == "DPM": if sampler == "DPM":
sampler = DPMSolverSampler(state["model"]) sampler = DPMSolverSampler(state["model"], device=get_device())
elif sampler == "DDIM": elif sampler == "DDIM":
sampler = DDIMSampler(state["model"]) sampler = DDIMSampler(state["model"])
else: else:

View file

@ -6,6 +6,7 @@ from PIL import Image
from omegaconf import OmegaConf from omegaconf import OmegaConf
from einops import repeat, rearrange from einops import repeat, rearrange
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from contextlib import nullcontext
from imwatermark import WatermarkEncoder from imwatermark import WatermarkEncoder
from scripts.txt2img import put_watermark from scripts.txt2img import put_watermark
@ -17,15 +18,24 @@ from ldm.util import exists, instantiate_from_config
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
def get_device():
if torch.cuda.is_available():
return 'cuda'
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
@st.cache(allow_output_mutation=True) @st.cache(allow_output_mutation=True)
def initialize_model(config, ckpt): def initialize_model(config, ckpt):
config = OmegaConf.load(config) config = OmegaConf.load(config)
model = instantiate_from_config(config.model) model = instantiate_from_config(config.model)
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False) model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") device = torch.device(get_device())
model = model.to(device) model = model.to(device)
sampler = DDIMSampler(model) sampler = DDIMSampler(model, device)
return sampler return sampler
@ -53,7 +63,7 @@ def make_noise_augmentation(model, batch, noise_level=None):
def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None): def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") device = torch.device(get_device())
model = sampler.model model = sampler.model
seed_everything(seed) seed_everything(seed)
prng = np.random.RandomState(seed) prng = np.random.RandomState(seed)
@ -64,8 +74,9 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
wm = "SDV2" wm = "SDV2"
wm_encoder = WatermarkEncoder() wm_encoder = WatermarkEncoder()
wm_encoder.set_watermark('bytes', wm.encode('utf-8')) wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
precision_scope = nullcontext if device.type == 'mps' else torch.autocast
with torch.no_grad(),\ with torch.no_grad(),\
torch.autocast("cuda"): precision_scope(device.type):
batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples) batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
c = model.cond_stage_model.encode(batch["txt"]) c = model.cond_stage_model.encode(batch["txt"])
c_cat = list() c_cat = list()

View file

@ -13,7 +13,6 @@ from torch import autocast
from contextlib import nullcontext from contextlib import nullcontext
from imwatermark import WatermarkEncoder from imwatermark import WatermarkEncoder
from ldm import global_opt as g
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.plms import PLMSSampler
@ -22,9 +21,9 @@ from ldm.models.diffusion.dpm_solver import DPMSolverSampler
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
def get_device(): def get_device():
if(torch.cuda.is_available()): if torch.cuda.is_available():
return 'cuda' return 'cuda'
elif(torch.backends.mps.is_available()): elif torch.backends.mps.is_available():
return 'mps' return 'mps'
else: else:
return 'cpu' return 'cpu'
@ -389,5 +388,5 @@ def main(opt):
if __name__ == "__main__": if __name__ == "__main__":
opt = g.opt = parse_args() opt = parse_args()
main(opt) main(opt)