mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 15:44:58 +00:00
adding deepspeed for inference optimization
This commit is contained in:
parent
47b6b607fd
commit
5820f54cb0
3 changed files with 26 additions and 0 deletions
|
@ -13,4 +13,5 @@ transformers==4.19.2
|
||||||
webdataset==0.2.5
|
webdataset==0.2.5
|
||||||
open-clip-torch==2.7.0
|
open-clip-torch==2.7.0
|
||||||
gradio==3.11
|
gradio==3.11
|
||||||
|
deepspeed==0.7.5
|
||||||
-e .
|
-e .
|
||||||
|
|
|
@ -20,6 +20,8 @@ from scripts.txt2img import put_watermark
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
|
|
||||||
def chunk(it, size):
|
def chunk(it, size):
|
||||||
it = iter(it)
|
it = iter(it)
|
||||||
|
@ -43,6 +45,16 @@ def load_model_from_config(config, ckpt, verbose=False):
|
||||||
|
|
||||||
model.cuda()
|
model.cuda()
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
ds_engine = deepspeed.init_inference(model,
|
||||||
|
mp_size=2,
|
||||||
|
dtype=torch.half,
|
||||||
|
checkpoint=None,
|
||||||
|
replace_method='auto',
|
||||||
|
replace_with_kernel_inject=True)
|
||||||
|
|
||||||
|
model = ds_engine.module
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ from ldm.util import instantiate_from_config
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from ldm.models.diffusion.plms import PLMSSampler
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
|
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
|
|
||||||
|
@ -39,9 +40,21 @@ def load_model_from_config(config, ckpt, verbose=False):
|
||||||
if len(u) > 0 and verbose:
|
if len(u) > 0 and verbose:
|
||||||
print("unexpected keys:")
|
print("unexpected keys:")
|
||||||
print(u)
|
print(u)
|
||||||
|
|
||||||
|
|
||||||
model.cuda()
|
model.cuda()
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
# Initialize the DeepSpeed-Inference engine
|
||||||
|
ds_engine = deepspeed.init_inference(model,
|
||||||
|
mp_size=2,
|
||||||
|
dtype=torch.half,
|
||||||
|
checkpoint=None,
|
||||||
|
replace_method='auto',
|
||||||
|
replace_with_kernel_inject=True)
|
||||||
|
|
||||||
|
model = ds_engine.module
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue