mirror of https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 15:44:58 +00:00
Support gradio and streamlit
This commit is contained in:
parent d6933311e7
commit 90d4c71350
13 changed files with 154 additions and 68 deletions
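The pattern repeated across all 13 files: hard-coded "cuda" device strings are replaced by a get_device() helper that also detects Apple's MPS backend, and torch.autocast is swapped for contextlib.nullcontext on MPS, where autocast is not supported. A minimal sketch of the combined pattern, assuming only PyTorch (the run() wrapper below is illustrative and not code from this commit):

    from contextlib import nullcontext

    import torch


    def get_device():
        # Prefer CUDA, then Apple's Metal backend (MPS), then CPU.
        if torch.cuda.is_available():
            return 'cuda'
        elif torch.backends.mps.is_available():
            return 'mps'
        else:
            return 'cpu'


    def run(model, batch):
        # Illustrative wrapper (not part of the commit): autocast has no
        # MPS support, so on MPS a no-op context manager is used instead.
        device = torch.device(get_device())
        model = model.to(device)
        precision_scope = nullcontext if device.type == 'mps' else torch.autocast
        with torch.no_grad(), precision_scope(device.type):
            return model(batch.to(device))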
@@ -1 +0,0 @@
-# Only Import opt
@@ -6,10 +6,17 @@ from torch.utils.checkpoint import checkpoint
 from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
 
 import open_clip
-from ldm import global_opt as g
 from ldm.util import default, count_params, autocast
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
 
 class AbstractEncoder(nn.Module):
     def __init__(self):
@@ -61,9 +68,9 @@ def disabled_train(self, mode=True):
 class FrozenT5Embedder(AbstractEncoder):
     """Uses the T5 transformer encoder for text"""
 
-    def __init__(self, version="google/t5-v1_1-large", device=g.opt.device, max_length=77, freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+    def __init__(self, version="google/t5-v1_1-large", device=get_device(), max_length=77, freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
         super().__init__()
-        self.tokenizer = T5Tokenizer.from_pretrained(version)
+        self.tokenizer = T5Tokenizer.from_pretrained(version, device)
         self.transformer = T5EncoderModel.from_pretrained(version)
         self.device = device
         self.max_length = max_length  # TODO: typical value?
@@ -98,7 +105,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         "hidden"
     ]
 
-    def __init__(self, version="openai/clip-vit-large-patch14", device=g.opt.device, max_length=77,
+    def __init__(self, version="openai/clip-vit-large-patch14", device=get_device(), max_length=77,
                  freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
@@ -142,7 +149,7 @@ class ClipImageEmbedder(nn.Module):
             self,
             model,
             jit=False,
-            device='cuda' if torch.cuda.is_available() else 'cpu',
+            device=get_device(),
             antialias=True,
             ucg_rate=0.
     ):
@@ -185,7 +192,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
         "penultimate"
     ]
 
-    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device=g.opt.device, max_length=77,
+    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device=get_device(), max_length=77,
                  freeze=True, layer="last"):
         super().__init__()
         assert layer in self.LAYERS
@@ -243,7 +250,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
     Uses the OpenCLIP vision transformer encoder for images
     """
 
-    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device=get_device(), max_length=77,
                  freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
         super().__init__()
         model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
@@ -298,7 +305,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
 
 
 class FrozenCLIPT5Encoder(AbstractEncoder):
-    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device=g.opt.device,
+    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device=get_device(),
                  clip_max_length=77, t5_max_length=77):
         super().__init__()
         self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
@@ -11,6 +11,7 @@ import torch
 import torchvision.transforms.functional as TVF
 from torchvision.transforms import InterpolationMode
 
+from contextlib import nullcontext
 from .template import BaseSampler, CKPT_PATH
 
 
@@ -31,6 +32,7 @@ class T2ISampler(BaseSampler):
     @classmethod
     def from_pretrained(
         cls,
+        device,
         root_dir: str,
         clip_model_path: str,
         clip_stat_path: str,
@@ -41,7 +43,7 @@ class T2ISampler(BaseSampler):
             root_dir=root_dir,
             sampling_type=sampling_type,
         )
-        model.load_clip(clip_model_path)
+        model.load_clip(clip_model_path, device)
        model.load_prior(
             f"{CKPT_PATH['prior']}",
             clip_stat_path=clip_stat_path,
@@ -60,10 +62,10 @@ class T2ISampler(BaseSampler):
         prompts_batch = [prompt for _ in range(bsz)]
 
         prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
-        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
+        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device=self.device)
 
         decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
-        decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
+        decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device=self.device)
 
         """ Get CLIP text feature """
         clip_model = self._clip
@@ -79,7 +81,7 @@ class T2ISampler(BaseSampler):
             tok = torch.cat([tok, cf_token], dim=0)
             mask = torch.cat([mask, cf_mask], dim=0)
 
-        tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
+        tok, mask = tok.to(device=self.device), mask.to(device=self.device)
         txt_feat, txt_feat_seq = clip_model.encode_text(tok)
 
         return (
@@ -99,7 +101,8 @@ class T2ISampler(BaseSampler):
         progressive_mode=None,
     ) -> Iterator[torch.Tensor]:
         assert progressive_mode in ("loop", "stage", "final")
-        with torch.no_grad(), torch.cuda.amp.autocast():
+        precision_scope = nullcontext if self.device.type == 'mps' else torch.cuda.amp.autocast
+        with torch.no_grad(), precision_scope(self.device.type):
             (
                 prompts_batch,
                 prior_cf_scales_batch,
@@ -181,6 +184,7 @@ class PriorSampler(BaseSampler):
     @classmethod
     def from_pretrained(
         cls,
+        device,
         root_dir: str,
         clip_model_path: str,
         clip_stat_path: str,
@@ -190,7 +194,7 @@ class PriorSampler(BaseSampler):
             root_dir=root_dir,
             sampling_type=sampling_type,
         )
-        model.load_clip(clip_model_path)
+        model.load_clip(clip_model_path, device)
        model.load_prior(
             f"{CKPT_PATH['prior']}",
             clip_stat_path=clip_stat_path,
@@ -207,10 +211,10 @@ class PriorSampler(BaseSampler):
         prompts_batch = [prompt for _ in range(bsz)]
 
         prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
-        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
+        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device=self.device)
 
         decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
-        decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
+        decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device=self.device)
 
         """ Get CLIP text feature """
         clip_model = self._clip
@@ -226,7 +230,7 @@ class PriorSampler(BaseSampler):
             tok = torch.cat([tok, cf_token], dim=0)
             mask = torch.cat([mask, cf_mask], dim=0)
 
-        tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
+        tok, mask = tok.to(device=self.device), mask.to(device=self.device)
         txt_feat, txt_feat_seq = clip_model.encode_text(tok)
 
         return (
@@ -246,7 +250,8 @@ class PriorSampler(BaseSampler):
         progressive_mode=None,
     ) -> Iterator[torch.Tensor]:
         assert progressive_mode in ("loop", "stage", "final")
-        with torch.no_grad(), torch.cuda.amp.autocast():
+        precision_scope = nullcontext if self.device.type == 'mps' else torch.autocast
+        with torch.no_grad(), precision_scope(self.device.type):
             (
                 prompts_batch,
                 prior_cf_scales_batch,
@@ -73,15 +73,16 @@ class BaseSampler:
 
         return line
 
-    def load_clip(self, clip_path: str):
+    def load_clip(self, clip_path: str, device):
         clip = CustomizedCLIP.load_from_checkpoint(
             os.path.join(self._root_dir, clip_path)
         )
         clip = torch.jit.script(clip)
-        clip.cuda()
+        clip.to(device)
         clip.eval()
 
         self._clip = clip
+        self.device = device
         self._tokenizer = CustomizedTokenizer()
 
     def load_prior(
@@ -105,7 +106,7 @@ class BaseSampler:
             os.path.join(self._root_dir, ckpt_path),
             strict=True,
         )
-        prior.cuda()
+        prior.to(self.device)
         prior.eval()
         logging.info("done.")
 
@@ -121,7 +122,7 @@ class BaseSampler:
             os.path.join(self._root_dir, ckpt_path),
             strict=True,
         )
-        decoder.cuda()
+        decoder.to(self.device)
         decoder.eval()
         logging.info("done.")
 
@@ -134,7 +135,7 @@ class BaseSampler:
         sr = self._SR256_CLASS.load_from_checkpoint(
             config, os.path.join(self._root_dir, ckpt_path), strict=True
         )
-        sr.cuda()
+        sr.to(self.device)
         sr.eval()
         logging.info("done.")
 
@@ -6,6 +6,7 @@ from PIL import Image
 from omegaconf import OmegaConf
 from einops import repeat, rearrange
 from pytorch_lightning import seed_everything
+from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 
 from scripts.txt2img import put_watermark
@@ -16,15 +17,23 @@ from ldm.data.util import AddMiDaS
 torch.set_grad_enabled(False)
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 def initialize_model(config, ckpt):
     config = OmegaConf.load(config)
     model = instantiate_from_config(config.model)
     model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
 
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = model.to(device)
-    sampler = DDIMSampler(model)
+    sampler = DDIMSampler(model, device)
     return sampler
 
 
@@ -54,8 +63,7 @@ def make_batch_sd(
 
 def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None,
           do_full_sample=False):
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = sampler.model
     seed_everything(seed)
 
@@ -64,8 +72,9 @@ def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=No
     wm_encoder = WatermarkEncoder()
     wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
 
+    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
     with torch.no_grad(),\
-            torch.autocast("cuda"):
+            precision_scope(device.type):
         batch = make_batch_sd(
             image, txt=prompt, device=device, num_samples=num_samples)
         z = model.get_first_stage_encoding(model.encode_first_stage(
@@ -6,6 +6,7 @@ import gradio as gr
 from PIL import Image
 from omegaconf import OmegaConf
 from einops import repeat
+from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 from pathlib import Path
 
@@ -16,6 +17,16 @@ from ldm.util import instantiate_from_config
 torch.set_grad_enabled(False)
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
+
 def put_watermark(img, wm_encoder=None):
     if wm_encoder is not None:
         img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
@@ -30,10 +41,9 @@ def initialize_model(config, ckpt):
 
     model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
 
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = model.to(device)
-    sampler = DDIMSampler(model)
+    sampler = DDIMSampler(model, device)
 
     return sampler
 
@@ -67,8 +77,7 @@ def make_batch_sd(
 
 
 def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512):
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = sampler.model
 
     print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
@@ -81,8 +90,9 @@ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1
     start_code = torch.from_numpy(start_code).to(
         device=device, dtype=torch.float32)
 
+    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
     with torch.no_grad(),\
-            torch.autocast("cuda"):
+            precision_scope(device.type):
         batch = make_batch_sd(image, mask, txt=prompt,
                               device=device, num_samples=num_samples)
 
@@ -6,6 +6,7 @@ from PIL import Image
 from omegaconf import OmegaConf
 from einops import repeat, rearrange
 from pytorch_lightning import seed_everything
+from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 
 from scripts.txt2img import put_watermark
@@ -17,15 +18,23 @@ from ldm.util import exists, instantiate_from_config
 torch.set_grad_enabled(False)
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 def initialize_model(config, ckpt):
     config = OmegaConf.load(config)
     model = instantiate_from_config(config.model)
     model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
 
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = get_device()
     model = model.to(device)
-    sampler = DDIMSampler(model)
+    sampler = DDIMSampler(model, device)
     return sampler
 
 
@@ -54,8 +63,7 @@ def make_noise_augmentation(model, batch, noise_level=None):
 
 
 def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None):
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = sampler.model
     seed_everything(seed)
     prng = np.random.RandomState(seed)
@@ -67,8 +75,9 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
     wm = "SDV2"
     wm_encoder = WatermarkEncoder()
     wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
+    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
     with torch.no_grad(),\
-            torch.autocast("cuda"):
+            precision_scope(device.type):
         batch = make_batch_sd(
             image, txt=prompt, device=device, num_samples=num_samples)
         c = model.cond_stage_model.encode(batch["txt"])
@@ -16,16 +16,15 @@ from pytorch_lightning import seed_everything
 from imwatermark import WatermarkEncoder
 
 
-from ldm import global_opt as g
 from scripts.txt2img import put_watermark
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 
 
 def get_device():
-    if(torch.cuda.is_available()):
+    if torch.cuda.is_available():
         return 'cuda'
-    elif(torch.backends.mps.is_available()):
+    elif torch.backends.mps.is_available():
         return 'mps'
     else:
         return 'cpu'
@@ -197,7 +196,7 @@ def main():
         default="autocast"
     )
 
-    opt = g.opt = parser.parse_args()
+    opt = parser.parse_args()
     seed_everything(opt.seed)
 
     config = OmegaConf.load(f"{opt.config}")
@@ -6,6 +6,7 @@ from PIL import Image
 from omegaconf import OmegaConf
 from einops import repeat, rearrange
 from pytorch_lightning import seed_everything
+from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 
 from scripts.txt2img import put_watermark
@@ -16,15 +17,25 @@ from ldm.data.util import AddMiDaS
 torch.set_grad_enabled(False)
 
 
+
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 @st.cache(allow_output_mutation=True)
 def initialize_model(config, ckpt):
     config = OmegaConf.load(config)
     model = instantiate_from_config(config.model)
     model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
 
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = model.to(device)
-    sampler = DDIMSampler(model)
+    sampler = DDIMSampler(model, device)
     return sampler
 
 
@@ -52,7 +63,7 @@ def make_batch_sd(
 
 def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None,
           do_full_sample=False):
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = sampler.model
     seed_everything(seed)
 
@@ -61,8 +72,9 @@ def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=No
     wm_encoder = WatermarkEncoder()
     wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
 
+    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
     with torch.no_grad(),\
-            torch.autocast("cuda"):
+            precision_scope(device.type):
         batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
         z = model.get_first_stage_encoding(model.encode_first_stage(batch[model.first_stage_key]))  # move to latent space
         c = model.cond_stage_model.encode(batch["txt"])
@@ -7,6 +7,7 @@ from PIL import Image
 from omegaconf import OmegaConf
 from einops import repeat
 from streamlit_drawable_canvas import st_canvas
+from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 
 from ldm.models.diffusion.ddim import DDIMSampler
@@ -16,6 +17,15 @@ from ldm.util import instantiate_from_config
 torch.set_grad_enabled(False)
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 def put_watermark(img, wm_encoder=None):
     if wm_encoder is not None:
         img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
@@ -31,9 +41,9 @@ def initialize_model(config, ckpt):
 
     model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
 
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = model.to(device)
-    sampler = DDIMSampler(model)
+    sampler = DDIMSampler(model, device)
 
     return sampler
 
@@ -67,7 +77,7 @@ def make_batch_sd(
 
 
 def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512, eta=1.):
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = get_device()
     model = sampler.model
 
     print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
@@ -79,8 +89,9 @@ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1
     start_code = prng.randn(num_samples, 4, h // 8, w // 8)
     start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)
 
+    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
     with torch.no_grad(),\
-            torch.autocast("cuda"):
+            precision_scope(device.type):
         batch = make_batch_sd(image, mask, txt=prompt, device=device, num_samples=num_samples)
 
         c = model.cond_stage_model.encode(batch["txt"])
@@ -30,6 +30,15 @@ VERSION2SPECS = {
 }
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 def get_obj_from_str(string, reload=False):
     module, cls = string.rsplit(".", 1)
     importlib.invalidate_caches()
@@ -69,7 +78,7 @@ def load_img(display=True, key=None):
 
 
 def get_init_img(batch_size=1, key=None):
-    init_image = load_img(key=key).cuda()
+    init_image = load_img(key=key).to(get_device())
     init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
     return init_image
 
@@ -97,7 +106,10 @@ def sample(
         only_adm_cond=False
 ):
     batch_size = n_samples
+    device = torch.device(get_device())
     precision_scope = autocast if not use_full_precision else nullcontext
+    if device.type == 'mps':
+        precision_scope = nullcontext
     # decoderscope = autocast if not use_full_precision else nullcontext
     if use_full_precision: st.warning(f"Running {model.__class__.__name__} at full precision.")
     if isinstance(prompt, str):
@@ -106,7 +118,7 @@ def sample(
 
     outputs = st.empty()
 
-    with precision_scope("cuda"):
+    with precision_scope(device.type):
         with model.ema_scope():
             all_samples = list()
             for n in trange(n_runs, desc="Sampling"):
@@ -208,6 +220,7 @@ def init(version="Stable unCLIP-L", load_karlo_prior=False):
         from ldm.modules.karlo.kakao.sampler import T2ISampler
         st.info("Loading full KARLO..")
         karlo = T2ISampler.from_pretrained(
+            device=get_device(),
             root_dir="checkpoints/karlo_models",
             clip_model_path="ViT-L-14.pt",
             clip_stat_path="ViT-L-14_stats.th",
@@ -227,6 +240,7 @@ def init(version="Stable unCLIP-L", load_karlo_prior=False):
         from ldm.modules.karlo.kakao.sampler import PriorSampler
         st.info("Loading KARLO CLIP prior...")
         karlo_prior = PriorSampler.from_pretrained(
+            device=get_device(),
             root_dir="checkpoints/karlo_models",
             clip_model_path="ViT-L-14.pt",
             clip_stat_path="ViT-L-14_stats.th",
@@ -263,7 +277,7 @@ def load_model_from_config(config, ckpt, verbose=False, vae_sd=None):
         print("unexpected keys:")
         print(u)
 
-    model.cuda()
+    model.to(get_device())
     model.eval()
     print(f"Loaded global step {global_step}")
     return model, msg
@@ -301,7 +315,7 @@ if __name__ == "__main__":
         pass
     else:
         if sampler == "DPM":
-            sampler = DPMSolverSampler(state["model"])
+            sampler = DPMSolverSampler(state["model"], device=get_device())
         elif sampler == "DDIM":
             sampler = DDIMSampler(state["model"])
         else:
@@ -6,6 +6,7 @@ from PIL import Image
 from omegaconf import OmegaConf
 from einops import repeat, rearrange
 from pytorch_lightning import seed_everything
+from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 
 from scripts.txt2img import put_watermark
@@ -17,15 +18,24 @@ from ldm.util import exists, instantiate_from_config
 torch.set_grad_enabled(False)
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 @st.cache(allow_output_mutation=True)
 def initialize_model(config, ckpt):
     config = OmegaConf.load(config)
     model = instantiate_from_config(config.model)
     model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
 
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = model.to(device)
-    sampler = DDIMSampler(model)
+    sampler = DDIMSampler(model, device)
     return sampler
 
 
@@ -53,7 +63,7 @@ def make_noise_augmentation(model, batch, noise_level=None):
 
 
 def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None):
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = sampler.model
     seed_everything(seed)
     prng = np.random.RandomState(seed)
@@ -64,8 +74,9 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
     wm = "SDV2"
     wm_encoder = WatermarkEncoder()
     wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
+    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
     with torch.no_grad(),\
-            torch.autocast("cuda"):
+            precision_scope(device.type):
         batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
         c = model.cond_stage_model.encode(batch["txt"])
         c_cat = list()
@@ -13,7 +13,6 @@ from torch import autocast
 from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 
-from ldm import global_opt as g
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
@@ -22,9 +21,9 @@ from ldm.models.diffusion.dpm_solver import DPMSolverSampler
 torch.set_grad_enabled(False)
 
 def get_device():
-    if(torch.cuda.is_available()):
+    if torch.cuda.is_available():
         return 'cuda'
-    elif(torch.backends.mps.is_available()):
+    elif torch.backends.mps.is_available():
         return 'mps'
     else:
         return 'cpu'
@@ -389,5 +388,5 @@ def main(opt):
 
 
 if __name__ == "__main__":
-    opt = g.opt = parse_args()
+    opt = parse_args()
     main(opt)