mirror of https://github.com/Stability-AI/stablediffusion.git (synced 2024-12-22 07:34:58 +00:00)

Support gradio and streamlit

parent d6933311e7
commit 90d4c71350

13 changed files with 154 additions and 68 deletions
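In short: the commit deletes the one-line global-options module, threads an explicit device through the Karlo samplers, replaces hard-coded "cuda" placement with a small get_device() helper (CUDA, then Apple MPS, then CPU), and swaps torch.autocast for contextlib.nullcontext on MPS, where autocast was not supported at the time. The recurring pattern, as a standalone sketch rather than a file from this diff:

    import torch
    from contextlib import nullcontext

    def get_device():
        # Prefer CUDA, then Apple Silicon (MPS), then CPU.
        if torch.cuda.is_available():
            return 'cuda'
        elif torch.backends.mps.is_available():
            return 'mps'
        else:
            return 'cpu'

    device = torch.device(get_device())
    # autocast has no MPS backend here, so fall back to a no-op context.
    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
    with torch.no_grad(), precision_scope(device.type):
        pass  # model inference goes here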
@@ -1 +0,0 @@
-# Only Import opt
==== next file ====

@@ -6,10 +6,17 @@ from torch.utils.checkpoint import checkpoint
 from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
 
 import open_clip
-from ldm import global_opt as g
 from ldm.util import default, count_params, autocast
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 class AbstractEncoder(nn.Module):
     def __init__(self):
@@ -61,9 +68,9 @@ def disabled_train(self, mode=True):
 class FrozenT5Embedder(AbstractEncoder):
     """Uses the T5 transformer encoder for text"""
 
-    def __init__(self, version="google/t5-v1_1-large", device=g.opt.device, max_length=77, freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+    def __init__(self, version="google/t5-v1_1-large", device=get_device(), max_length=77, freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
         super().__init__()
-        self.tokenizer = T5Tokenizer.from_pretrained(version)
+        self.tokenizer = T5Tokenizer.from_pretrained(version, device)
         self.transformer = T5EncoderModel.from_pretrained(version)
         self.device = device
         self.max_length = max_length  # TODO: typical value?
@@ -98,7 +105,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         "hidden"
     ]
 
-    def __init__(self, version="openai/clip-vit-large-patch14", device=g.opt.device, max_length=77,
+    def __init__(self, version="openai/clip-vit-large-patch14", device=get_device(), max_length=77,
                  freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
@@ -142,7 +149,7 @@ class ClipImageEmbedder(nn.Module):
             self,
             model,
             jit=False,
-            device='cuda' if torch.cuda.is_available() else 'cpu',
+            device=get_device(),
             antialias=True,
             ucg_rate=0.
     ):
@@ -185,7 +192,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
         "penultimate"
     ]
 
-    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device=g.opt.device, max_length=77,
+    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device=get_device(), max_length=77,
                  freeze=True, layer="last"):
         super().__init__()
         assert layer in self.LAYERS
@@ -243,7 +250,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
     Uses the OpenCLIP vision transformer encoder for images
     """
 
-    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device=get_device(), max_length=77,
                  freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
         super().__init__()
         model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
@@ -298,7 +305,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
 
 
 class FrozenCLIPT5Encoder(AbstractEncoder):
-    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device=g.opt.device,
+    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device=get_device(),
                  clip_max_length=77, t5_max_length=77):
         super().__init__()
         self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
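A Python subtlety worth noting in the encoder signatures above: a default such as device=get_device() is evaluated once, when the def statement executes at import time, not on every call. A minimal illustration of the general behavior (standalone, not from this diff):

    def pick():
        print("evaluated once, at definition time")
        return 'cpu'

    def embed(device=pick()):  # pick() runs here, as the module is imported
        return device

    print(embed())  # 'cpu' -- pick() is not called again
    print(embed())  # 'cpu'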
==== next file ====

@@ -11,6 +11,7 @@ import torch
 import torchvision.transforms.functional as TVF
 from torchvision.transforms import InterpolationMode
 
+from contextlib import nullcontext
 from .template import BaseSampler, CKPT_PATH
 
 
@@ -31,6 +32,7 @@ class T2ISampler(BaseSampler):
     @classmethod
     def from_pretrained(
         cls,
+        device,
         root_dir: str,
         clip_model_path: str,
         clip_stat_path: str,
@@ -41,7 +43,7 @@ class T2ISampler(BaseSampler):
             root_dir=root_dir,
             sampling_type=sampling_type,
         )
-        model.load_clip(clip_model_path)
+        model.load_clip(clip_model_path, device)
         model.load_prior(
             f"{CKPT_PATH['prior']}",
             clip_stat_path=clip_stat_path,
@@ -60,10 +62,10 @@ class T2ISampler(BaseSampler):
         prompts_batch = [prompt for _ in range(bsz)]
 
         prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
-        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
+        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device=self.device)
 
         decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
-        decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
+        decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device=self.device)
 
         """ Get CLIP text feature """
         clip_model = self._clip
@@ -79,7 +81,7 @@ class T2ISampler(BaseSampler):
             tok = torch.cat([tok, cf_token], dim=0)
             mask = torch.cat([mask, cf_mask], dim=0)
 
-        tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
+        tok, mask = tok.to(device=self.device), mask.to(device=self.device)
         txt_feat, txt_feat_seq = clip_model.encode_text(tok)
 
         return (
@@ -99,7 +101,8 @@ class T2ISampler(BaseSampler):
         progressive_mode=None,
     ) -> Iterator[torch.Tensor]:
         assert progressive_mode in ("loop", "stage", "final")
-        with torch.no_grad(), torch.cuda.amp.autocast():
+        precision_scope = nullcontext if self.device.type == 'mps' else torch.cuda.amp.autocast
+        with torch.no_grad(), precision_scope(self.device.type):
             (
                 prompts_batch,
                 prior_cf_scales_batch,
@@ -181,6 +184,7 @@ class PriorSampler(BaseSampler):
     @classmethod
     def from_pretrained(
         cls,
+        device,
         root_dir: str,
         clip_model_path: str,
         clip_stat_path: str,
@@ -190,7 +194,7 @@ class PriorSampler(BaseSampler):
             root_dir=root_dir,
             sampling_type=sampling_type,
         )
-        model.load_clip(clip_model_path)
+        model.load_clip(clip_model_path, device)
         model.load_prior(
             f"{CKPT_PATH['prior']}",
             clip_stat_path=clip_stat_path,
@@ -207,10 +211,10 @@ class PriorSampler(BaseSampler):
         prompts_batch = [prompt for _ in range(bsz)]
 
         prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
-        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
+        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device=self.device)
 
         decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
-        decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
+        decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device=self.device)
 
         """ Get CLIP text feature """
         clip_model = self._clip
@@ -226,7 +230,7 @@ class PriorSampler(BaseSampler):
             tok = torch.cat([tok, cf_token], dim=0)
             mask = torch.cat([mask, cf_mask], dim=0)
 
-        tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
+        tok, mask = tok.to(device=self.device), mask.to(device=self.device)
         txt_feat, txt_feat_seq = clip_model.encode_text(tok)
 
         return (
@@ -246,7 +250,8 @@ class PriorSampler(BaseSampler):
         progressive_mode=None,
     ) -> Iterator[torch.Tensor]:
         assert progressive_mode in ("loop", "stage", "final")
-        with torch.no_grad(), torch.cuda.amp.autocast():
+        precision_scope = nullcontext if self.device.type == 'mps' else torch.autocast
+        with torch.no_grad(), precision_scope(self.device.type):
             (
                 prompts_batch,
                 prior_cf_scales_batch,
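The two Karlo samplers now accept the device in from_pretrained and hand it to load_clip, which records it as self.device for the tensor placements above. A reduced sketch of that flow, with hypothetical simplified names:

    import torch

    class Sampler:
        @classmethod
        def from_pretrained(cls, device, root_dir):
            model = cls()
            model.load_clip("ViT-L-14.pt", device)  # device handed through
            return model

        def load_clip(self, clip_path, device):
            # checkpoint loading elided; remember the device for later calls
            self.device = torch.device(device)

        def encode(self, tok):
            return tok.to(device=self.device)  # replaces the hard-coded "cuda"

    s = Sampler.from_pretrained('cpu', root_dir="checkpoints")
    print(s.encode(torch.zeros(1)).device)  # cpu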
==== next file ====

@@ -73,15 +73,16 @@ class BaseSampler:
 
         return line
 
-    def load_clip(self, clip_path: str):
+    def load_clip(self, clip_path: str, device):
         clip = CustomizedCLIP.load_from_checkpoint(
             os.path.join(self._root_dir, clip_path)
         )
         clip = torch.jit.script(clip)
-        clip.cuda()
+        clip.to(device)
         clip.eval()
 
         self._clip = clip
+        self.device = device
         self._tokenizer = CustomizedTokenizer()
 
     def load_prior(
@@ -105,7 +106,7 @@ class BaseSampler:
             os.path.join(self._root_dir, ckpt_path),
             strict=True,
         )
-        prior.cuda()
+        prior.to(self.device)
         prior.eval()
         logging.info("done.")
 
@@ -121,7 +122,7 @@ class BaseSampler:
             os.path.join(self._root_dir, ckpt_path),
             strict=True,
         )
-        decoder.cuda()
+        decoder.to(self.device)
         decoder.eval()
         logging.info("done.")
 
@@ -134,7 +135,7 @@ class BaseSampler:
         sr = self._SR256_CLASS.load_from_checkpoint(
             config, os.path.join(self._root_dir, ckpt_path), strict=True
         )
-        sr.cuda()
+        sr.to(self.device)
         sr.eval()
         logging.info("done.")
 
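module.cuda() is just the CUDA-only special case of module.to(device); switching the BaseSampler loaders to .to(device) is what lets the same code run on CPU and MPS. For instance (a generic sketch, not this repo's model classes):

    import torch
    import torch.nn as nn

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = nn.Linear(4, 4)
    net.to(device)  # moves parameters and buffers in place
    net.eval()
    print(next(net.parameters()).device)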
==== next file ====

@@ -6,6 +6,7 @@ from PIL import Image
 from omegaconf import OmegaConf
 from einops import repeat, rearrange
 from pytorch_lightning import seed_everything
+from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 
 from scripts.txt2img import put_watermark
@@ -16,15 +17,23 @@ from ldm.data.util import AddMiDaS
 torch.set_grad_enabled(False)
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 def initialize_model(config, ckpt):
     config = OmegaConf.load(config)
     model = instantiate_from_config(config.model)
     model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
 
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = model.to(device)
-    sampler = DDIMSampler(model)
+    sampler = DDIMSampler(model, device)
     return sampler
 
 
@@ -54,8 +63,7 @@ def make_batch_sd(
 
 def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None,
           do_full_sample=False):
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = sampler.model
     seed_everything(seed)
 
@@ -64,8 +72,9 @@ def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None,
     wm_encoder = WatermarkEncoder()
     wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
 
+    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
     with torch.no_grad(),\
-            torch.autocast("cuda"):
+            precision_scope(device.type):
         batch = make_batch_sd(
             image, txt=prompt, device=device, num_samples=num_samples)
         z = model.get_first_stage_encoding(model.encode_first_stage(
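One caveat with positional calls like DDIMSampler(model, device): if the sampler's second positional parameter is something other than the device (the upstream DDIMSampler.__init__ declares schedule there), the device silently binds to the wrong name, so the keyword form is safer. A generic illustration with a hypothetical Sampler class:

    class Sampler:
        def __init__(self, model, schedule="linear", device="cuda"):
            self.schedule, self.device = schedule, device

    s = Sampler("model", "cpu")          # "cpu" binds to schedule, not device
    print(s.schedule, s.device)          # -> cpu cuda
    s = Sampler("model", device="cpu")   # keyword form keeps the intent clear
    print(s.schedule, s.device)          # -> linear cpu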
==== next file ====

@@ -6,6 +6,7 @@ import gradio as gr
 from PIL import Image
 from omegaconf import OmegaConf
 from einops import repeat
+from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 from pathlib import Path
 
@@ -16,6 +17,16 @@ from ldm.util import instantiate_from_config
 torch.set_grad_enabled(False)
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 def put_watermark(img, wm_encoder=None):
     if wm_encoder is not None:
         img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
@@ -30,10 +41,9 @@ def initialize_model(config, ckpt):
 
     model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
 
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = model.to(device)
-    sampler = DDIMSampler(model)
+    sampler = DDIMSampler(model, device)
 
     return sampler
 
@@ -67,8 +77,7 @@ def make_batch_sd(
 
 
 def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512):
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = sampler.model
 
     print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
@@ -81,8 +90,9 @@ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512):
     start_code = torch.from_numpy(start_code).to(
         device=device, dtype=torch.float32)
 
-    with torch.no_grad(), \
-            torch.autocast("cuda"):
+    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
+    with torch.no_grad(),\
+            precision_scope(device.type):
         batch = make_batch_sd(image, mask, txt=prompt,
                               device=device, num_samples=num_samples)
 
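For context, the start_code tensor prepared above is the seeded initial latent: noise drawn at one eighth of the image resolution with four channels, then moved to the chosen device. In outline (standalone sketch):

    import numpy as np
    import torch

    seed, num_samples, h, w = 42, 1, 512, 512
    device = torch.device('cpu')  # stand-in for torch.device(get_device())

    prng = np.random.RandomState(seed)
    start_code = prng.randn(num_samples, 4, h // 8, w // 8)
    start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)
    print(start_code.shape)  # torch.Size([1, 4, 64, 64])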
==== next file ====

@@ -6,6 +6,7 @@ from PIL import Image
 from omegaconf import OmegaConf
 from einops import repeat, rearrange
 from pytorch_lightning import seed_everything
+from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 
 from scripts.txt2img import put_watermark
@@ -17,15 +18,23 @@ from ldm.util import exists, instantiate_from_config
 torch.set_grad_enabled(False)
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 def initialize_model(config, ckpt):
     config = OmegaConf.load(config)
     model = instantiate_from_config(config.model)
     model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
 
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = get_device()
     model = model.to(device)
-    sampler = DDIMSampler(model)
+    sampler = DDIMSampler(model, device)
     return sampler
 
 
@@ -54,8 +63,7 @@ def make_noise_augmentation(model, batch, noise_level=None):
 
 
 def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None):
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = sampler.model
     seed_everything(seed)
     prng = np.random.RandomState(seed)
@@ -67,8 +75,9 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None):
     wm = "SDV2"
     wm_encoder = WatermarkEncoder()
     wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
+    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
     with torch.no_grad(),\
-            torch.autocast("cuda"):
+            precision_scope(device.type):
         batch = make_batch_sd(
             image, txt=prompt, device=device, num_samples=num_samples)
         c = model.cond_stage_model.encode(batch["txt"])
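The invisible-watermark step is untouched by this commit; for reference, the encoder seen in these hunks is set up roughly like this (assuming the imwatermark package):

    from imwatermark import WatermarkEncoder

    wm = "SDV2"
    wm_encoder = WatermarkEncoder()
    wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
    # put_watermark(img, wm_encoder) then stamps each decoded image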
==== next file ====

@@ -16,16 +16,15 @@ from pytorch_lightning import seed_everything
 from imwatermark import WatermarkEncoder
 
 
-from ldm import global_opt as g
 from scripts.txt2img import put_watermark
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 
 
 def get_device():
-    if(torch.cuda.is_available()):
+    if torch.cuda.is_available():
         return 'cuda'
-    elif(torch.backends.mps.is_available()):
+    elif torch.backends.mps.is_available():
         return 'mps'
     else:
         return 'cpu'
@@ -197,7 +196,7 @@ def main():
         default="autocast"
     )
 
-    opt = g.opt = parser.parse_args()
+    opt = parser.parse_args()
     seed_everything(opt.seed)
 
     config = OmegaConf.load(f"{opt.config}")
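With the ldm.global_opt module deleted, the scripts stop mirroring parsed arguments onto a module-level global (opt = g.opt = ...) and simply keep them local. Schematically:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42)
    opt = parser.parse_args()  # previously also stashed globally via g.opt
    print(opt.seed)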
==== next file ====

@@ -6,6 +6,7 @@ from PIL import Image
 from omegaconf import OmegaConf
 from einops import repeat, rearrange
 from pytorch_lightning import seed_everything
+from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 
 from scripts.txt2img import put_watermark
@@ -16,15 +17,25 @@ from ldm.data.util import AddMiDaS
 torch.set_grad_enabled(False)
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 @st.cache(allow_output_mutation=True)
 def initialize_model(config, ckpt):
     config = OmegaConf.load(config)
     model = instantiate_from_config(config.model)
     model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
 
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = model.to(device)
-    sampler = DDIMSampler(model)
+    sampler = DDIMSampler(model, device)
     return sampler
 
 
@@ -52,7 +63,7 @@ def make_batch_sd(
 
 def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None,
          do_full_sample=False):
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = sampler.model
     seed_everything(seed)
 
@@ -61,8 +72,9 @@ def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None,
     wm_encoder = WatermarkEncoder()
     wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
 
+    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
    with torch.no_grad(),\
-            torch.autocast("cuda"):
+            precision_scope(device.type):
         batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
         z = model.get_first_stage_encoding(model.encode_first_stage(batch[model.first_stage_key]))  # move to latent space
         c = model.cond_stage_model.encode(batch["txt"])
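The streamlit scripts keep memoizing model construction with st.cache so the checkpoint is loaded once per session; allow_output_mutation was the legacy (pre-1.18) way to cache a mutable model object. Schematically:

    import streamlit as st

    @st.cache(allow_output_mutation=True)  # newer Streamlit: st.cache_resource
    def initialize_model(config, ckpt):
        ...  # heavy construction runs once; reruns reuse the cached sampler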
==== next file ====

@@ -7,6 +7,7 @@ from PIL import Image
 from omegaconf import OmegaConf
 from einops import repeat
 from streamlit_drawable_canvas import st_canvas
+from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 
 from ldm.models.diffusion.ddim import DDIMSampler
@@ -16,6 +17,15 @@ from ldm.util import instantiate_from_config
 torch.set_grad_enabled(False)
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 def put_watermark(img, wm_encoder=None):
     if wm_encoder is not None:
         img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
@@ -31,9 +41,9 @@ def initialize_model(config, ckpt):
 
     model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
 
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = model.to(device)
-    sampler = DDIMSampler(model)
+    sampler = DDIMSampler(model, device)
 
     return sampler
@@ -67,7 +77,7 @@ def make_batch_sd(
 
 
 def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512, eta=1.):
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = get_device()
     model = sampler.model
 
     print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
@@ -79,8 +89,9 @@ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512, eta=1.):
     start_code = prng.randn(num_samples, 4, h // 8, w // 8)
     start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)
 
-    with torch.no_grad(), \
-            torch.autocast("cuda"):
+    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
+    with torch.no_grad(),\
+            precision_scope(device.type):
         batch = make_batch_sd(image, mask, txt=prompt, device=device, num_samples=num_samples)
 
         c = model.cond_stage_model.encode(batch["txt"])
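Note that get_device() returns a plain string, while the device.type checks introduced in this commit require a torch.device; call sites such as inpaint above that skip the torch.device(...) wrapper would hit an AttributeError on a bare string. A quick check:

    import torch

    device_str = 'cpu'                    # what get_device() returns
    device = torch.device(device_str)     # what the .type checks require
    print(device.type)                    # 'cpu'
    print(hasattr(device_str, 'type'))    # False -- str has no .type attribute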
==== next file ====

@@ -30,6 +30,15 @@ VERSION2SPECS = {
 }
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 def get_obj_from_str(string, reload=False):
     module, cls = string.rsplit(".", 1)
     importlib.invalidate_caches()
@@ -69,7 +78,7 @@ def load_img(display=True, key=None):
 
 
 def get_init_img(batch_size=1, key=None):
-    init_image = load_img(key=key).cuda()
+    init_image = load_img(key=key).to(get_device())
     init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
     return init_image
 
@@ -97,7 +106,10 @@ def sample(
         only_adm_cond=False
 ):
     batch_size = n_samples
+    device = torch.device(get_device())
     precision_scope = autocast if not use_full_precision else nullcontext
+    if device.type == 'mps':
+        precision_scope = nullcontext
     # decoderscope = autocast if not use_full_precision else nullcontext
     if use_full_precision: st.warning(f"Running {model.__class__.__name__} at full precision.")
     if isinstance(prompt, str):
@@ -106,7 +118,7 @@ def sample(
 
     outputs = st.empty()
 
-    with precision_scope("cuda"):
+    with precision_scope(device.type):
         with model.ema_scope():
             all_samples = list()
             for n in trange(n_runs, desc="Sampling"):
@@ -208,6 +220,7 @@ def init(version="Stable unCLIP-L", load_karlo_prior=False):
         from ldm.modules.karlo.kakao.sampler import T2ISampler
         st.info("Loading full KARLO..")
         karlo = T2ISampler.from_pretrained(
+            device=get_device(),
             root_dir="checkpoints/karlo_models",
             clip_model_path="ViT-L-14.pt",
             clip_stat_path="ViT-L-14_stats.th",
@@ -227,6 +240,7 @@ def init(version="Stable unCLIP-L", load_karlo_prior=False):
         from ldm.modules.karlo.kakao.sampler import PriorSampler
         st.info("Loading KARLO CLIP prior...")
         karlo_prior = PriorSampler.from_pretrained(
+            device=get_device(),
             root_dir="checkpoints/karlo_models",
             clip_model_path="ViT-L-14.pt",
             clip_stat_path="ViT-L-14_stats.th",
@@ -263,7 +277,7 @@ def load_model_from_config(config, ckpt, verbose=False, vae_sd=None):
         print("unexpected keys:")
         print(u)
 
-    model.cuda()
+    model.to(get_device())
     model.eval()
     print(f"Loaded global step {global_step}")
     return model, msg
@@ -301,7 +315,7 @@ if __name__ == "__main__":
         pass
     else:
         if sampler == "DPM":
-            sampler = DPMSolverSampler(state["model"])
+            sampler = DPMSolverSampler(state["model"], device=get_device())
         elif sampler == "DDIM":
             sampler = DDIMSampler(state["model"])
         else:
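The Stable unCLIP demo's sample() combines the user-facing use_full_precision toggle with the new MPS guard before entering the autocast scope. The selection logic in isolation (standalone sketch):

    import torch
    from torch import autocast
    from contextlib import nullcontext

    use_full_precision = False
    device = torch.device('cpu')  # stand-in for torch.device(get_device())

    precision_scope = autocast if not use_full_precision else nullcontext
    if device.type == 'mps':      # no autocast backend for MPS here
        precision_scope = nullcontext

    with precision_scope(device.type):
        y = torch.ones(2, 2) @ torch.ones(2, 2)
    print(y)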
==== next file ====

@@ -6,6 +6,7 @@ from PIL import Image
 from omegaconf import OmegaConf
 from einops import repeat, rearrange
 from pytorch_lightning import seed_everything
+from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 
 from scripts.txt2img import put_watermark
@@ -17,15 +18,24 @@ from ldm.util import exists, instantiate_from_config
 torch.set_grad_enabled(False)
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 @st.cache(allow_output_mutation=True)
 def initialize_model(config, ckpt):
     config = OmegaConf.load(config)
     model = instantiate_from_config(config.model)
     model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
 
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = model.to(device)
-    sampler = DDIMSampler(model)
+    sampler = DDIMSampler(model, device)
     return sampler
 
 
@@ -53,7 +63,7 @@ def make_noise_augmentation(model, batch, noise_level=None):
 
 
 def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None):
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = sampler.model
     seed_everything(seed)
     prng = np.random.RandomState(seed)
@@ -64,8 +74,9 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None):
     wm = "SDV2"
     wm_encoder = WatermarkEncoder()
     wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
+    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
     with torch.no_grad(),\
-            torch.autocast("cuda"):
+            precision_scope(device.type):
         batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
         c = model.cond_stage_model.encode(batch["txt"])
         c_cat = list()
==== next file ====

@@ -13,7 +13,6 @@ from torch import autocast
 from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 
-from ldm import global_opt as g
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
@@ -22,9 +21,9 @@ from ldm.models.diffusion.dpm_solver import DPMSolverSampler
 torch.set_grad_enabled(False)
 
 def get_device():
-    if(torch.cuda.is_available()):
+    if torch.cuda.is_available():
         return 'cuda'
-    elif(torch.backends.mps.is_available()):
+    elif torch.backends.mps.is_available():
         return 'mps'
     else:
         return 'cpu'
@@ -389,5 +388,5 @@ def main(opt):
 
 
 if __name__ == "__main__":
-    opt = g.opt = parse_args()
+    opt = parse_args()
     main(opt)