add cpu support & add intel ipex optimizations

This commit is contained in:
aalbersk 2022-12-20 15:18:33 +00:00
parent d55bcd4d31
commit 7ad54c5ee9
10 changed files with 403 additions and 21 deletions

View file

@ -0,0 +1,68 @@
model:
base_learning_rate: 1.0e-4
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: False
use_fp16: False
use_bf16: True
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"

View file

@ -0,0 +1,67 @@
model:
base_learning_rate: 1.0e-4
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: False
use_fp16: False
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"

View file

@ -0,0 +1,69 @@
model:
base_learning_rate: 1.0e-4
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
parameterization: "v"
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: False
use_fp16: False
use_bf16: True
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"

View file

@ -0,0 +1,68 @@
model:
base_learning_rate: 1.0e-4
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
parameterization: "v"
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: False
use_fp16: False
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"

View file

@ -8,16 +8,17 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak
class DDIMSampler(object): class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs): def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
super().__init__() super().__init__()
self.model = model self.model = model
self.ddpm_num_timesteps = model.num_timesteps self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule self.schedule = schedule
self.device = device
def register_buffer(self, name, attr): def register_buffer(self, name, attr):
if type(attr) == torch.Tensor: if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"): if attr.device != self.device:
attr = attr.to(torch.device("cuda")) attr = attr.to(self.device)
setattr(self, name, attr) setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):

View file

@ -1326,6 +1326,12 @@ class DiffusionWrapper(pl.LightningModule):
cc = torch.cat(c_crossattn, 1) cc = torch.cat(c_crossattn, 1)
else: else:
cc = c_crossattn cc = c_crossattn
if hasattr(self, "scripted_diffusion_model"):
# TorchScript changes names of the arguments
# with argument cc defined as context=cc scripted model will produce
# an error: RuntimeError: forward() is missing value for argument 'argument_3'.
out = self.scripted_diffusion_model(x, t, cc)
else:
out = self.diffusion_model(x, t, context=cc) out = self.diffusion_model(x, t, context=cc)
elif self.conditioning_key == 'hybrid': elif self.conditioning_key == 'hybrid':
xc = torch.cat([x] + c_concat, dim=1) xc = torch.cat([x] + c_concat, dim=1)

View file

@ -11,16 +11,17 @@ MODEL_TYPES = {
class DPMSolverSampler(object): class DPMSolverSampler(object):
def __init__(self, model, **kwargs): def __init__(self, model, device=torch.device("cuda"), **kwargs):
super().__init__() super().__init__()
self.model = model self.model = model
self.device = device
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
def register_buffer(self, name, attr): def register_buffer(self, name, attr):
if type(attr) == torch.Tensor: if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"): if attr.device != self.device:
attr = attr.to(torch.device("cuda")) attr = attr.to(self.device)
setattr(self, name, attr) setattr(self, name, attr)
@torch.no_grad() @torch.no_grad()

View file

@ -10,16 +10,17 @@ from ldm.models.diffusion.sampling_util import norm_thresholding
class PLMSSampler(object): class PLMSSampler(object):
def __init__(self, model, schedule="linear", **kwargs): def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
super().__init__() super().__init__()
self.model = model self.model = model
self.ddpm_num_timesteps = model.num_timesteps self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule self.schedule = schedule
self.device = device
def register_buffer(self, name, attr): def register_buffer(self, name, attr):
if type(attr) == torch.Tensor: if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"): if attr.device != self.device:
attr = attr.to(torch.device("cuda")) attr = attr.to(self.device)
setattr(self, name, attr) setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):

View file

@ -454,6 +454,7 @@ class UNetModel(nn.Module):
num_classes=None, num_classes=None,
use_checkpoint=False, use_checkpoint=False,
use_fp16=False, use_fp16=False,
use_bf16=False,
num_heads=-1, num_heads=-1,
num_head_channels=-1, num_head_channels=-1,
num_heads_upsample=-1, num_heads_upsample=-1,
@ -518,6 +519,7 @@ class UNetModel(nn.Module):
self.num_classes = num_classes self.num_classes = num_classes
self.use_checkpoint = use_checkpoint self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32 self.dtype = th.float16 if use_fp16 else th.float32
self.dtype = th.bfloat16 if use_bf16 else self.dtype
self.num_heads = num_heads self.num_heads = num_heads
self.num_head_channels = num_head_channels self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample self.num_heads_upsample = num_heads_upsample

View file

@ -25,7 +25,7 @@ def chunk(it, size):
return iter(lambda: tuple(islice(it, size)), ()) return iter(lambda: tuple(islice(it, size)), ())
def load_model_from_config(config, ckpt, verbose=False): def load_model_from_config(config, ckpt, device=torch.device("cuda"), verbose=False):
print(f"Loading model from {ckpt}") print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu") pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd: if "global_step" in pl_sd:
@ -40,7 +40,13 @@ def load_model_from_config(config, ckpt, verbose=False):
print("unexpected keys:") print("unexpected keys:")
print(u) print(u)
if device == torch.device("cuda"):
model.cuda() model.cuda()
elif device == torch.device("cpu"):
model.cpu()
model.cond_stage_model.device = "cpu"
else:
raise ValueError(f"Incorrect device name. Received: {device}")
model.eval() model.eval()
return model return model
@ -171,6 +177,28 @@ def parse_args():
default=1, default=1,
help="repeat each prompt in file this often", help="repeat each prompt in file this often",
) )
parser.add_argument(
"--device",
type=str,
help="Device on which Stable Diffusion will be run",
choices=["cpu", "cuda"],
default="cpu"
)
parser.add_argument(
"--torchscript",
action='store_true',
help="Use TorchScript",
)
parser.add_argument(
"--ipex",
action='store_true',
help="Use Intel® Extension for PyTorch*",
)
parser.add_argument(
"--bf16",
action='store_true',
help="Use bfloat16",
)
opt = parser.parse_args() opt = parser.parse_args()
return opt return opt
@ -187,17 +215,15 @@ def main(opt):
seed_everything(opt.seed) seed_everything(opt.seed)
config = OmegaConf.load(f"{opt.config}") config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}") device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu")
model = load_model_from_config(config, f"{opt.ckpt}", device)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
if opt.plms: if opt.plms:
sampler = PLMSSampler(model) sampler = PLMSSampler(model, device=device)
elif opt.dpm: elif opt.dpm:
sampler = DPMSolverSampler(model) sampler = DPMSolverSampler(model, device=device)
else: else:
sampler = DDIMSampler(model) sampler = DDIMSampler(model, device=device)
os.makedirs(opt.outdir, exist_ok=True) os.makedirs(opt.outdir, exist_ok=True)
outpath = opt.outdir outpath = opt.outdir
@ -231,9 +257,82 @@ def main(opt):
if opt.fixed_code: if opt.fixed_code:
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
precision_scope = autocast if opt.precision == "autocast" else nullcontext if opt.torchscript or opt.ipex:
transformer = model.cond_stage_model.model
unet = model.model.diffusion_model
decoder = model.first_stage_model.decoder
additional_context = torch.cpu.amp.autocast() if opt.bf16 else nullcontext()
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
if opt.bf16 and not opt.torchscript and not opt.ipex:
raise ValueError('Bfloat16 is supported only for torchscript+ipex')
if opt.bf16 and unet.dtype != torch.bfloat16:
raise ValueError("Use configs/stable-diffusion/ipex/ configs with bf16 enabled if " +
"you'd like to use bfloat16 with CPU.")
if unet.dtype == torch.float16 and device == torch.device("cpu"):
raise ValueError("Use configs/stable-diffusion/ipex/ configs for your model if you'd like to run it on CPU.")
if opt.ipex:
import intel_extension_for_pytorch as ipex
bf16_dtype = torch.bfloat16 if opt.bf16 else None
transformer = transformer.to(memory_format=torch.channels_last)
transformer = ipex.optimize(transformer, level="O1", inplace=True)
unet = unet.to(memory_format=torch.channels_last)
unet = ipex.optimize(unet, level="O1", auto_kernel_selection=True, inplace=True, dtype=bf16_dtype)
decoder = decoder.to(memory_format=torch.channels_last)
decoder = ipex.optimize(decoder, level="O1", auto_kernel_selection=True, inplace=True, dtype=bf16_dtype)
if opt.torchscript:
with torch.no_grad(), additional_context:
# get UNET scripted
if unet.use_checkpoint:
raise ValueError("Gradient checkpoint won't work with tracing. " +
"Use configs/stable-diffusion/ipex/ configs for your model or disable checkpoint in your config.")
img_in = torch.ones(2, 4, 96, 96, dtype=torch.float32)
t_in = torch.ones(2, dtype=torch.int64)
context = torch.ones(2, 77, 1024, dtype=torch.float32)
scripted_unet = torch.jit.trace(unet, (img_in, t_in, context))
scripted_unet = torch.jit.optimize_for_inference(scripted_unet)
print(type(scripted_unet))
model.model.scripted_diffusion_model = scripted_unet
# get Decoder for first stage model scripted
samples_ddim = torch.ones(1, 4, 96, 96, dtype=torch.float32)
scripted_decoder = torch.jit.trace(decoder, (samples_ddim))
scripted_decoder = torch.jit.optimize_for_inference(scripted_decoder)
print(type(scripted_decoder))
model.first_stage_model.decoder = scripted_decoder
prompts = data[0]
print("Running a forward pass to initialize optimizations")
uc = None
if opt.scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
with torch.no_grad(), additional_context:
for _ in range(3):
c = model.get_learned_conditioning(prompts)
samples_ddim, _ = sampler.sample(S=5,
conditioning=c,
batch_size=batch_size,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
x_T=start_code)
print("Running a forward pass for decoder")
for _ in range(3):
x_samples_ddim = model.decode_first_stage(samples_ddim)
precision_scope = autocast if opt.precision=="autocast" or opt.bf16 else nullcontext
with torch.no_grad(), \ with torch.no_grad(), \
precision_scope("cuda"), \ precision_scope(opt.device), \
model.ema_scope(): model.ema_scope():
all_samples = list() all_samples = list()
for n in trange(opt.n_iter, desc="Sampling"): for n in trange(opt.n_iter, desc="Sampling"):