mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 15:44:58 +00:00
Added type hints and improved docstrings for DPMSolverSampler class.
This commit is contained in:
parent
cf1d67a6fd
commit
80ca3cd485
1 changed file with 69 additions and 28 deletions
|
@ -1,4 +1,3 @@
|
||||||
"""SAMPLING ONLY."""
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
|
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
|
||||||
|
@ -8,46 +7,87 @@ MODEL_TYPES = {
|
||||||
"v": "v"
|
"v": "v"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class DPMSolverSampler(object):
|
class DPMSolverSampler(object):
|
||||||
def __init__(self, model, device=torch.device("cuda"), **kwargs):
|
def __init__(self, model: torch.nn.Module, device: torch.device = torch.device("cuda"), **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the DPMSolverSampler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (torch.nn.Module): The model to use for sampling.
|
||||||
|
device (torch.device, optional): The device to use. Defaults to torch.device("cuda").
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
self.device = device
|
self.device = device
|
||||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
|
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
|
||||||
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
|
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
|
||||||
|
|
||||||
def register_buffer(self, name, attr):
|
def register_buffer(self, name: str, attr: torch.Tensor) -> None:
|
||||||
if type(attr) == torch.Tensor:
|
"""
|
||||||
|
Register a buffer in the module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name of the buffer.
|
||||||
|
attr (torch.Tensor): The tensor to register as a buffer.
|
||||||
|
"""
|
||||||
|
if isinstance(attr, torch.Tensor):
|
||||||
if attr.device != self.device:
|
if attr.device != self.device:
|
||||||
attr = attr.to(self.device)
|
attr = attr.to(self.device)
|
||||||
setattr(self, name, attr)
|
setattr(self, name, attr)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample(self,
|
def sample(self,
|
||||||
S,
|
S: int,
|
||||||
batch_size,
|
batch_size: int,
|
||||||
shape,
|
shape: tuple,
|
||||||
conditioning=None,
|
conditioning: dict = None,
|
||||||
callback=None,
|
callback: callable = None,
|
||||||
normals_sequence=None,
|
normals_sequence: list = None,
|
||||||
img_callback=None,
|
img_callback: callable = None,
|
||||||
quantize_x0=False,
|
quantize_x0: bool = False,
|
||||||
eta=0.,
|
eta: float = 0.,
|
||||||
mask=None,
|
mask: torch.Tensor = None,
|
||||||
x0=None,
|
x0: torch.Tensor = None,
|
||||||
temperature=1.,
|
temperature: float = 1.,
|
||||||
noise_dropout=0.,
|
noise_dropout: float = 0.,
|
||||||
score_corrector=None,
|
score_corrector: callable = None,
|
||||||
corrector_kwargs=None,
|
corrector_kwargs: dict = None,
|
||||||
verbose=True,
|
verbose: bool = True,
|
||||||
x_T=None,
|
x_T: torch.Tensor = None,
|
||||||
log_every_t=100,
|
log_every_t: int = 100,
|
||||||
unconditional_guidance_scale=1.,
|
unconditional_guidance_scale: float = 1.,
|
||||||
unconditional_conditioning=None,
|
unconditional_conditioning: torch.Tensor = None,
|
||||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
**kwargs) -> tuple:
|
||||||
**kwargs
|
"""
|
||||||
):
|
Perform sampling using the DPM Solver.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
S (int): Number of steps.
|
||||||
|
batch_size (int): Batch size for sampling.
|
||||||
|
shape (tuple): Shape of the samples (C, H, W).
|
||||||
|
conditioning (dict, optional): Conditioning information. Defaults to None.
|
||||||
|
callback (callable, optional): Callback function. Defaults to None.
|
||||||
|
normals_sequence (list, optional): Sequence of normals. Defaults to None.
|
||||||
|
img_callback (callable, optional): Image callback function. Defaults to None.
|
||||||
|
quantize_x0 (bool, optional): Flag for quantizing x0. Defaults to False.
|
||||||
|
eta (float, optional): Eta parameter. Defaults to 0..
|
||||||
|
mask (torch.Tensor, optional): Mask tensor. Defaults to None.
|
||||||
|
x0 (torch.Tensor, optional): Initial x0 tensor. Defaults to None.
|
||||||
|
temperature (float, optional): Temperature parameter. Defaults to 1..
|
||||||
|
noise_dropout (float, optional): Noise dropout parameter. Defaults to 0..
|
||||||
|
score_corrector (callable, optional): Score corrector. Defaults to None.
|
||||||
|
corrector_kwargs (dict, optional): Keyword arguments for the score corrector. Defaults to None.
|
||||||
|
verbose (bool, optional): Verbose flag. Defaults to True.
|
||||||
|
x_T (torch.Tensor, optional): Initial x_T tensor. Defaults to None.
|
||||||
|
log_every_t (int, optional): Log interval. Defaults to 100.
|
||||||
|
unconditional_guidance_scale (float, optional): Guidance scale for unconditional sampling. Defaults to 1..
|
||||||
|
unconditional_conditioning (torch.Tensor, optional): Conditioning tensor for unconditional sampling. Defaults to None.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: Sampled tensor and additional information.
|
||||||
|
"""
|
||||||
if conditioning is not None:
|
if conditioning is not None:
|
||||||
if isinstance(conditioning, dict):
|
if isinstance(conditioning, dict):
|
||||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||||
|
@ -94,3 +134,4 @@ class DPMSolverSampler(object):
|
||||||
lower_order_final=True)
|
lower_order_final=True)
|
||||||
|
|
||||||
return x.to(device), None
|
return x.to(device), None
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue