Mirror of https://github.com/Stability-AI/stablediffusion.git, synced 2024-12-22 15:44:58 +00:00
Added type hints and improved docstrings for DPMSolverSampler class.
This commit is contained in:
parent cf1d67a6fd
commit 80ca3cd485
1 changed file with 69 additions and 28 deletions
@@ -1,4 +1,3 @@
 """SAMPLING ONLY."""
 import torch
-
 from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
@@ -8,46 +7,87 @@ MODEL_TYPES = {
     "v": "v"
 }
 
 
 class DPMSolverSampler(object):
-    def __init__(self, model, device=torch.device("cuda"), **kwargs):
+    def __init__(self, model: torch.nn.Module, device: torch.device = torch.device("cuda"), **kwargs) -> None:
+        """
+        Initialize the DPMSolverSampler.
+
+        Args:
+            model (torch.nn.Module): The model to use for sampling.
+            device (torch.device, optional): The device to use. Defaults to torch.device("cuda").
+            **kwargs: Additional keyword arguments.
+        """
         super().__init__()
         self.model = model
         self.device = device
         to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
         self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
 
-    def register_buffer(self, name, attr):
-        if type(attr) == torch.Tensor:
+    def register_buffer(self, name: str, attr: torch.Tensor) -> None:
+        """
+        Register a buffer in the module.
+
+        Args:
+            name (str): The name of the buffer.
+            attr (torch.Tensor): The tensor to register as a buffer.
+        """
+        if isinstance(attr, torch.Tensor):
             if attr.device != self.device:
                 attr = attr.to(self.device)
         setattr(self, name, attr)
 
     @torch.no_grad()
     def sample(self,
-               S,
-               batch_size,
-               shape,
-               conditioning=None,
-               callback=None,
-               normals_sequence=None,
-               img_callback=None,
-               quantize_x0=False,
-               eta=0.,
-               mask=None,
-               x0=None,
-               temperature=1.,
-               noise_dropout=0.,
-               score_corrector=None,
-               corrector_kwargs=None,
-               verbose=True,
-               x_T=None,
-               log_every_t=100,
-               unconditional_guidance_scale=1.,
-               unconditional_conditioning=None,
-               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
-               **kwargs
-               ):
+               S: int,
+               batch_size: int,
+               shape: tuple,
+               conditioning: dict = None,
+               callback: callable = None,
+               normals_sequence: list = None,
+               img_callback: callable = None,
+               quantize_x0: bool = False,
+               eta: float = 0.,
+               mask: torch.Tensor = None,
+               x0: torch.Tensor = None,
+               temperature: float = 1.,
+               noise_dropout: float = 0.,
+               score_corrector: callable = None,
+               corrector_kwargs: dict = None,
+               verbose: bool = True,
+               x_T: torch.Tensor = None,
+               log_every_t: int = 100,
+               unconditional_guidance_scale: float = 1.,
+               unconditional_conditioning: torch.Tensor = None,
+               **kwargs) -> tuple:
+        """
+        Perform sampling using the DPM Solver.
+
+        Args:
+            S (int): Number of steps.
+            batch_size (int): Batch size for sampling.
+            shape (tuple): Shape of the samples (C, H, W).
+            conditioning (dict, optional): Conditioning information. Defaults to None.
+            callback (callable, optional): Callback function. Defaults to None.
+            normals_sequence (list, optional): Sequence of normals. Defaults to None.
+            img_callback (callable, optional): Image callback function. Defaults to None.
+            quantize_x0 (bool, optional): Flag for quantizing x0. Defaults to False.
+            eta (float, optional): Eta parameter. Defaults to 0..
+            mask (torch.Tensor, optional): Mask tensor. Defaults to None.
+            x0 (torch.Tensor, optional): Initial x0 tensor. Defaults to None.
+            temperature (float, optional): Temperature parameter. Defaults to 1..
+            noise_dropout (float, optional): Noise dropout parameter. Defaults to 0..
+            score_corrector (callable, optional): Score corrector. Defaults to None.
+            corrector_kwargs (dict, optional): Keyword arguments for the score corrector. Defaults to None.
+            verbose (bool, optional): Verbose flag. Defaults to True.
+            x_T (torch.Tensor, optional): Initial x_T tensor. Defaults to None.
+            log_every_t (int, optional): Log interval. Defaults to 100.
+            unconditional_guidance_scale (float, optional): Guidance scale for unconditional sampling. Defaults to 1..
+            unconditional_conditioning (torch.Tensor, optional): Conditioning tensor for unconditional sampling. Defaults to None.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            tuple: Sampled tensor and additional information.
+        """
         if conditioning is not None:
             if isinstance(conditioning, dict):
                 ctmp = conditioning[list(conditioning.keys())[0]]
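The effect of the annotated constructor is easier to see in use. Below is a minimal sketch, assuming a toy stand-in model: DummyDiffusionModel, its beta schedule, and the CPU device choice are illustrative and not part of this commit, while the import path follows the repo's sampling scripts.

import torch
from ldm.models.diffusion.dpm_solver import DPMSolverSampler  # repo's usual import path

class DummyDiffusionModel(torch.nn.Module):
    """Illustrative stand-in exposing only what DPMSolverSampler.__init__ reads."""
    def __init__(self, num_timesteps: int = 1000):
        super().__init__()
        # A toy cumulative-alpha schedule; a real LatentDiffusion model precomputes this.
        betas = torch.linspace(1e-4, 2e-2, num_timesteps)
        self.register_buffer("alphas_cumprod", torch.cumprod(1.0 - betas, dim=0))

    @property
    def device(self) -> torch.device:
        # __init__'s to_torch lambda reads model.device, so the stand-in must expose it.
        return self.alphas_cumprod.device

sampler = DPMSolverSampler(DummyDiffusionModel(), device=torch.device("cpu"))
# register_buffer keeps the schedule on the sampler's device (moving it if necessary):
assert sampler.alphas_cumprod.device == torch.device("cpu")

Switching the check from type(attr) == torch.Tensor to isinstance(attr, torch.Tensor) also lets torch.Tensor subclasses such as torch.nn.Parameter take the device-check path instead of being stored as-is.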
@@ -94,3 +134,4 @@ class DPMSolverSampler(object):
                               lower_order_final=True)
 
         return x.to(device), None
+
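To put the newly documented signature in context, here is a sketch of a typical text-to-image call. It assumes a LatentDiffusion model already loaded from a checkpoint and wrapped as sampler = DPMSolverSampler(model); the prompt, guidance scale, and 4x64x64 latent shape are illustrative, and get_learned_conditioning / decode_first_stage come from the surrounding codebase rather than this diff, so the snippet is not runnable on its own.

# `model` is a loaded LatentDiffusion checkpoint (assumption, see above);
# `sampler = DPMSolverSampler(model)` wraps it for DPM-Solver sampling.
batch_size = 2
cond = model.get_learned_conditioning(batch_size * ["a photograph of an astronaut"])
uncond = model.get_learned_conditioning(batch_size * [""])

samples, _ = sampler.sample(
    S=25,                               # number of solver steps
    batch_size=batch_size,
    shape=(4, 64, 64),                  # latent (C, H, W), not pixel space
    conditioning=cond,
    unconditional_guidance_scale=7.5,
    unconditional_conditioning=uncond,
    eta=0.0,
    verbose=False,
)
images = model.decode_first_stage(samples)  # decode latents back to image space

The second element of the returned tuple is always None here, matching the return x.to(device), None line in the final hunk.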