mirror of https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 07:34:58 +00:00

Merge pull request #147 from aalbersk/intel_cpu_optimizations
[Txt2Img] CPU support + TorchScript and Intel® Extension for PyTorch* optimizations

Commit fc1488421a · 11 changed files with 450 additions and 21 deletions
README.md · 35 lines changed
@ -137,6 +137,41 @@ Note: The inference config for all model versions is designed to be used with EM

For this reason `use_ema=False` is set in the configuration, otherwise the code will try to switch from
non-EMA to EMA weights.

#### Enable Intel® Extension for PyTorch* optimizations in Text-to-Image script

If you're planning on running Text-to-Image on an Intel® CPU, try sampling an image with TorchScript and Intel® Extension for PyTorch* optimizations. Intel® Extension for PyTorch* extends PyTorch with up-to-date feature optimizations for an extra performance boost on Intel® hardware. It can convert the memory layout of operators to the Channels Last format, which is generally beneficial on Intel CPUs, take advantage of the most advanced instruction set available on the machine, and apply many other operator optimizations.
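As a rough, self-contained sketch of what those optimizations look like at the API level (toy model and shapes of our choosing, not part of this repo):

```python
import torch
import intel_extension_for_pytorch as ipex

# Toy stand-in for a real network, purely for illustration
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3, padding=1),
    torch.nn.ReLU(),
).eval()

# Channels Last memory format is usually faster for conv nets on Intel CPUs
model = model.to(memory_format=torch.channels_last)
# ipex.optimize applies operator/layout optimizations for the host CPU;
# pass dtype=torch.bfloat16 to additionally run eligible ops in bf16
model = ipex.optimize(model)

x = torch.randn(1, 3, 64, 64).to(memory_format=torch.channels_last)
with torch.no_grad():
    y = model(x)
```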

**Prerequisites**

Before running the script, make sure you have all the needed libraries installed (the optimizations were verified on `Ubuntu 20.04`). Install [jemalloc](https://github.com/jemalloc/jemalloc), [numactl](https://linux.die.net/man/8/numactl), Intel® OpenMP and Intel® Extension for PyTorch*:
```bash
apt-get install numactl libjemalloc-dev
pip install intel-openmp
pip install intel_extension_for_pytorch -f https://software.intel.com/ipex-whl-stable
```
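A quick, optional sanity check (not part of the PR) that the extension is importable against your PyTorch build:

```python
import torch
import intel_extension_for_pytorch as ipex

# ipex releases are pinned to specific torch versions; mismatches fail at import
print(torch.__version__, ipex.__version__)
```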

To sample from the _SD2.1-v_ model with TorchScript+IPEX optimizations, run the following. Remember to specify the desired number of instances you want to run the program on ([more](https://github.com/intel/intel-extension-for-pytorch/blob/master/intel_extension_for_pytorch/cpu/launch.py#L48)).
```bash
MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000 python -m intel_extension_for_pytorch.cpu.launch --ninstance <number of instances> --enable_jemalloc scripts/txt2img.py --prompt "a corgi is playing guitar, oil on canvas" --ckpt <path/to/768model.ckpt> --config configs/stable-diffusion/intel/v2-inference-v-fp32.yaml --H 768 --W 768 --precision full --device cpu --torchscript --ipex
```

To sample from the base model with IPEX optimizations, use
```bash
MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000 python -m intel_extension_for_pytorch.cpu.launch --ninstance <number of instances> --enable_jemalloc scripts/txt2img.py --prompt "a corgi is playing guitar, oil on canvas" --ckpt <path/to/model.ckpt> --config configs/stable-diffusion/intel/v2-inference-fp32.yaml --n_samples 1 --n_iter 4 --precision full --device cpu --torchscript --ipex
```

If you're using a CPU that supports `bfloat16`, consider sampling from the model with bfloat16 enabled for a performance boost, like so:
```bash
# SD2.1-v
MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000 python -m intel_extension_for_pytorch.cpu.launch --ninstance <number of instances> --enable_jemalloc scripts/txt2img.py --prompt "a corgi is playing guitar, oil on canvas" --ckpt <path/to/768model.ckpt> --config configs/stable-diffusion/intel/v2-inference-v-bf16.yaml --H 768 --W 768 --precision full --device cpu --torchscript --ipex --bf16

# SD2.1-base
MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000 python -m intel_extension_for_pytorch.cpu.launch --ninstance <number of instances> --enable_jemalloc scripts/txt2img.py --prompt "a corgi is playing guitar, oil on canvas" --ckpt <path/to/model.ckpt> --config configs/stable-diffusion/intel/v2-inference-bf16.yaml --precision full --device cpu --torchscript --ipex --bf16
```
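There is no single official API here to query CPU bfloat16 support; one hedged, Linux-only heuristic (not from this repo) is to look for the relevant ISA flags:

```python
# Heuristic check: native bf16 acceleration is typically advertised via the
# AVX512-BF16 or AMX-BF16 CPU flags in /proc/cpuinfo (Linux only)
with open("/proc/cpuinfo") as f:
    flags = f.read()
print(any(flag in flags for flag in ("avx512_bf16", "amx_bf16")))
```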

### Image Modification with Stable Diffusion

![depth2img-stable2](assets/stable-samples/depth2img/merged-0000.png)
configs/stable-diffusion/intel/v2-inference-bf16.yaml · 71 lines · new file
@ -0,0 +1,71 @@
```yaml
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: MIT

model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: False
        use_fp16: False
        use_bf16: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
```
configs/stable-diffusion/intel/v2-inference-fp32.yaml · 70 lines · new file
@ -0,0 +1,70 @@
```yaml
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: MIT

model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: False
        use_fp16: False
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
```
configs/stable-diffusion/intel/v2-inference-v-bf16.yaml · 72 lines · new file
@ -0,0 +1,72 @@
```yaml
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: MIT

model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: False
        use_fp16: False
        use_bf16: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
```
configs/stable-diffusion/intel/v2-inference-v-fp32.yaml · 71 lines · new file
@ -0,0 +1,71 @@
```yaml
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: MIT

model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: False
        use_fp16: False
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
```
ldm/models/diffusion/ddim.py

```diff
@ -8,16 +8,17 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak


 class DDIMSampler(object):
-    def __init__(self, model, schedule="linear", **kwargs):
+    def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
         super().__init__()
         self.model = model
         self.ddpm_num_timesteps = model.num_timesteps
         self.schedule = schedule
+        self.device = device

     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
+            if attr.device != self.device:
+                attr = attr.to(self.device)
         setattr(self, name, attr)

     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
```
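The change above (and the identical edits to the DPM-Solver and PLMS samplers further down) boils down to parametrizing where buffers live instead of hard-coding CUDA. A minimal, runnable distillation of the pattern, using a toy class of our own:

```python
import torch

class BufferHolder:
    """Toy version of the samplers' register_buffer, for illustration only."""

    def __init__(self, device=torch.device("cpu")):
        self.device = device

    def register_buffer(self, name, attr):
        # Move tensors to the configured device instead of hard-coded CUDA
        if isinstance(attr, torch.Tensor) and attr.device != self.device:
            attr = attr.to(self.device)
        setattr(self, name, attr)

holder = BufferHolder()
holder.register_buffer("alphas", torch.ones(4))
print(holder.alphas.device)  # cpu
```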
ldm/models/diffusion/ddpm.py

```diff
@ -1326,7 +1326,13 @@ class DiffusionWrapper(pl.LightningModule):
                 cc = torch.cat(c_crossattn, 1)
             else:
                 cc = c_crossattn
-            out = self.diffusion_model(x, t, context=cc)
+            if hasattr(self, "scripted_diffusion_model"):
+                # TorchScript changes names of the arguments
+                # with argument cc defined as context=cc scripted model will produce
+                # an error: RuntimeError: forward() is missing value for argument 'argument_3'.
+                out = self.scripted_diffusion_model(x, t, cc)
+            else:
+                out = self.diffusion_model(x, t, context=cc)
         elif self.conditioning_key == 'hybrid':
             xc = torch.cat([x] + c_concat, dim=1)
             cc = torch.cat(c_crossattn, 1)
```
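The inline comment refers to a real tracing quirk: after `torch.jit.trace` and especially `torch.jit.optimize_for_inference`, the frozen module may no longer accept the original keyword names, which is why the scripted UNet is called positionally. A self-contained toy reproduction of the safe calling convention (module and shapes are ours, not from this repo):

```python
import torch

class Toy(torch.nn.Module):
    def forward(self, x, t, context):
        return x * t + context

example = (torch.ones(2), torch.ones(2), torch.ones(2))
with torch.no_grad():
    traced = torch.jit.trace(Toy().eval(), example)
    traced = torch.jit.optimize_for_inference(traced)

# Positional arguments always work; keyword names may not survive the
# freeze/optimize pass (the PR saw "forward() is missing value for
# argument 'argument_3'" when calling with context=...)
y = traced(torch.ones(2), torch.ones(2), torch.ones(2))
```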
ldm/models/diffusion/dpm_solver/sampler.py

```diff
@ -11,16 +11,17 @@ MODEL_TYPES = {


 class DPMSolverSampler(object):
-    def __init__(self, model, **kwargs):
+    def __init__(self, model, device=torch.device("cuda"), **kwargs):
         super().__init__()
         self.model = model
+        self.device = device
         to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
         self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))

     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
+            if attr.device != self.device:
+                attr = attr.to(self.device)
         setattr(self, name, attr)

     @torch.no_grad()
```
ldm/models/diffusion/plms.py

```diff
@ -10,16 +10,17 @@ from ldm.models.diffusion.sampling_util import norm_thresholding


 class PLMSSampler(object):
-    def __init__(self, model, schedule="linear", **kwargs):
+    def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
         super().__init__()
         self.model = model
         self.ddpm_num_timesteps = model.num_timesteps
         self.schedule = schedule
+        self.device = device

     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
+            if attr.device != self.device:
+                attr = attr.to(self.device)
         setattr(self, name, attr)

     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
```
ldm/modules/diffusionmodules/openaimodel.py

```diff
@ -454,6 +454,7 @@ class UNetModel(nn.Module):
         num_classes=None,
         use_checkpoint=False,
         use_fp16=False,
+        use_bf16=False,
         num_heads=-1,
         num_head_channels=-1,
         num_heads_upsample=-1,
@ -518,6 +519,7 @@ class UNetModel(nn.Module):
         self.num_classes = num_classes
         self.use_checkpoint = use_checkpoint
         self.dtype = th.float16 if use_fp16 else th.float32
+        self.dtype = th.bfloat16 if use_bf16 else self.dtype
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
```
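A distilled, runnable view of the dtype precedence these two hunks introduce (the function name is ours; the logic mirrors the added lines):

```python
import torch as th

def unet_dtype(use_fp16: bool = False, use_bf16: bool = False) -> th.dtype:
    # Same two assignments as UNetModel.__init__: fp16 if requested,
    # then bf16 takes precedence when enabled
    dtype = th.float16 if use_fp16 else th.float32
    return th.bfloat16 if use_bf16 else dtype

assert unet_dtype() is th.float32
assert unet_dtype(use_fp16=True) is th.float16
assert unet_dtype(use_fp16=True, use_bf16=True) is th.bfloat16
```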
scripts/txt2img.py

```diff
@ -25,7 +25,7 @@ def chunk(it, size):
     return iter(lambda: tuple(islice(it, size)), ())


-def load_model_from_config(config, ckpt, verbose=False):
+def load_model_from_config(config, ckpt, device=torch.device("cuda"), verbose=False):
     print(f"Loading model from {ckpt}")
     pl_sd = torch.load(ckpt, map_location="cpu")
     if "global_step" in pl_sd:
```
```diff
@ -40,7 +40,13 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)

-    model.cuda()
+    if device == torch.device("cuda"):
+        model.cuda()
+    elif device == torch.device("cpu"):
+        model.cpu()
+        model.cond_stage_model.device = "cpu"
+    else:
+        raise ValueError(f"Incorrect device name. Received: {device}")
     model.eval()
     return model
```
```diff
@ -171,6 +177,28 @@ def parse_args():
         default=1,
         help="repeat each prompt in file this often",
     )
+    parser.add_argument(
+        "--device",
+        type=str,
+        help="Device on which Stable Diffusion will be run",
+        choices=["cpu", "cuda"],
+        default="cpu"
+    )
+    parser.add_argument(
+        "--torchscript",
+        action='store_true',
+        help="Use TorchScript",
+    )
+    parser.add_argument(
+        "--ipex",
+        action='store_true',
+        help="Use Intel® Extension for PyTorch*",
+    )
+    parser.add_argument(
+        "--bf16",
+        action='store_true',
+        help="Use bfloat16",
+    )
     opt = parser.parse_args()
     return opt
```
```diff
@ -187,17 +215,15 @@ def main(opt):
     seed_everything(opt.seed)

     config = OmegaConf.load(f"{opt.config}")
-    model = load_model_from_config(config, f"{opt.ckpt}")
-
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-    model = model.to(device)
+    device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu")
+    model = load_model_from_config(config, f"{opt.ckpt}", device)

     if opt.plms:
-        sampler = PLMSSampler(model)
+        sampler = PLMSSampler(model, device=device)
     elif opt.dpm:
-        sampler = DPMSolverSampler(model)
+        sampler = DPMSolverSampler(model, device=device)
     else:
-        sampler = DDIMSampler(model)
+        sampler = DDIMSampler(model, device=device)

     os.makedirs(opt.outdir, exist_ok=True)
     outpath = opt.outdir
```
```diff
@ -231,9 +257,82 @@ def main(opt):
     if opt.fixed_code:
         start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)

-    precision_scope = autocast if opt.precision == "autocast" else nullcontext
+    if opt.torchscript or opt.ipex:
+        transformer = model.cond_stage_model.model
+        unet = model.model.diffusion_model
+        decoder = model.first_stage_model.decoder
+        additional_context = torch.cpu.amp.autocast() if opt.bf16 else nullcontext()
+        shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
+
+        if opt.bf16 and not opt.torchscript and not opt.ipex:
+            raise ValueError('Bfloat16 is supported only for torchscript+ipex')
+        if opt.bf16 and unet.dtype != torch.bfloat16:
+            raise ValueError("Use configs/stable-diffusion/intel/ configs with bf16 enabled if " +
+                             "you'd like to use bfloat16 with CPU.")
+        if unet.dtype == torch.float16 and device == torch.device("cpu"):
+            raise ValueError("Use configs/stable-diffusion/intel/ configs for your model if you'd like to run it on CPU.")
+
+        if opt.ipex:
+            import intel_extension_for_pytorch as ipex
+            bf16_dtype = torch.bfloat16 if opt.bf16 else None
+            transformer = transformer.to(memory_format=torch.channels_last)
+            transformer = ipex.optimize(transformer, level="O1", inplace=True)
+
+            unet = unet.to(memory_format=torch.channels_last)
+            unet = ipex.optimize(unet, level="O1", auto_kernel_selection=True, inplace=True, dtype=bf16_dtype)
+
+            decoder = decoder.to(memory_format=torch.channels_last)
+            decoder = ipex.optimize(decoder, level="O1", auto_kernel_selection=True, inplace=True, dtype=bf16_dtype)
+
+        if opt.torchscript:
+            with torch.no_grad(), additional_context:
+                # get UNET scripted
+                if unet.use_checkpoint:
+                    raise ValueError("Gradient checkpoint won't work with tracing. " +
+                                     "Use configs/stable-diffusion/intel/ configs for your model or disable checkpoint in your config.")
+
+                img_in = torch.ones(2, 4, 96, 96, dtype=torch.float32)
+                t_in = torch.ones(2, dtype=torch.int64)
+                context = torch.ones(2, 77, 1024, dtype=torch.float32)
+                scripted_unet = torch.jit.trace(unet, (img_in, t_in, context))
+                scripted_unet = torch.jit.optimize_for_inference(scripted_unet)
+                print(type(scripted_unet))
+                model.model.scripted_diffusion_model = scripted_unet
+
+                # get Decoder for first stage model scripted
+                samples_ddim = torch.ones(1, 4, 96, 96, dtype=torch.float32)
+                scripted_decoder = torch.jit.trace(decoder, (samples_ddim))
+                scripted_decoder = torch.jit.optimize_for_inference(scripted_decoder)
+                print(type(scripted_decoder))
+                model.first_stage_model.decoder = scripted_decoder
+
+        prompts = data[0]
+        print("Running a forward pass to initialize optimizations")
+        uc = None
+        if opt.scale != 1.0:
+            uc = model.get_learned_conditioning(batch_size * [""])
+        if isinstance(prompts, tuple):
+            prompts = list(prompts)
+
+        with torch.no_grad(), additional_context:
+            for _ in range(3):
+                c = model.get_learned_conditioning(prompts)
+            samples_ddim, _ = sampler.sample(S=5,
+                                             conditioning=c,
+                                             batch_size=batch_size,
+                                             shape=shape,
+                                             verbose=False,
+                                             unconditional_guidance_scale=opt.scale,
+                                             unconditional_conditioning=uc,
+                                             eta=opt.ddim_eta,
+                                             x_T=start_code)
+            print("Running a forward pass for decoder")
+            for _ in range(3):
+                x_samples_ddim = model.decode_first_stage(samples_ddim)
+
+    precision_scope = autocast if opt.precision=="autocast" or opt.bf16 else nullcontext
     with torch.no_grad(), \
-        precision_scope("cuda"), \
+        precision_scope(opt.device), \
         model.ema_scope():
             all_samples = list()
             for n in trange(opt.n_iter, desc="Sampling"):
```
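For reference, the fixed tracing shapes in the hunk above can be derived from the sampling settings. A small check (our reading: the batch of 2 plausibly covers the conditional and unconditional branches of classifier-free guidance, and 77/1024 come from the OpenCLIP text encoder):

```python
# Why torch.jit.trace uses (2, 4, 96, 96) and (2, 77, 1024) for 768x768 output
H = W = 768        # --H/--W passed for the SD2.1-v model
f, C = 8, 4        # default latent downsampling factor and channel count
print((2, C, H // f, W // f))  # UNet latent input -> (2, 4, 96, 96)
print((2, 77, 1024))           # OpenCLIP context: 77 tokens x 1024 dims
```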