mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 07:34:58 +00:00
support image mixing in streamlit
This commit is contained in:
parent
fe1cf687e9
commit
4e89f578ea
1 changed files with 49 additions and 24 deletions
|
@ -18,7 +18,6 @@ from ldm.models.diffusion.ddim import DDIMSampler
|
|||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
|
||||
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
PROMPTS_ROOT = "scripts/prompts/"
|
||||
|
@ -46,8 +45,8 @@ def instantiate_from_config(config):
|
|||
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
||||
|
||||
|
||||
def get_interactive_image():
|
||||
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
|
||||
def get_interactive_image(key=None):
|
||||
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
|
||||
if image is not None:
|
||||
image = Image.open(image)
|
||||
if not image.mode == "RGB":
|
||||
|
@ -55,8 +54,8 @@ def get_interactive_image():
|
|||
return image
|
||||
|
||||
|
||||
def load_img(display=True):
|
||||
image = get_interactive_image()
|
||||
def load_img(display=True, key=None):
|
||||
image = get_interactive_image(key=key)
|
||||
if display:
|
||||
st.image(image)
|
||||
w, h = image.size
|
||||
|
@ -69,8 +68,8 @@ def load_img(display=True):
|
|||
return 2. * image - 1.
|
||||
|
||||
|
||||
def get_init_img(batch_size=1):
|
||||
init_image = load_img().cuda()
|
||||
def get_init_img(batch_size=1, key=None):
|
||||
init_image = load_img(key=key).cuda()
|
||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||
return init_image
|
||||
|
||||
|
@ -99,7 +98,7 @@ def sample(
|
|||
):
|
||||
batch_size = n_samples
|
||||
precision_scope = autocast if not use_full_precision else nullcontext
|
||||
#decoderscope = autocast if not use_full_precision else nullcontext
|
||||
# decoderscope = autocast if not use_full_precision else nullcontext
|
||||
if use_full_precision: st.warning(f"Running {model.__class__.__name__} at full precision.")
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
@ -274,8 +273,8 @@ if __name__ == "__main__":
|
|||
st.title("Stable unCLIP")
|
||||
mode = "txt2img"
|
||||
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
|
||||
use_karlo = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
|
||||
state = init(version=version, load_karlo_prior=use_karlo)
|
||||
use_karlo_prior = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
|
||||
state = init(version=version, load_karlo_prior=use_karlo_prior)
|
||||
prompt = st.text_input("Prompt", "a professional photograph")
|
||||
negative_prompt = st.text_input("Negative Prompt", "")
|
||||
scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
|
||||
|
@ -309,7 +308,7 @@ if __name__ == "__main__":
|
|||
raise ValueError(f"unknown sampler {sampler}!")
|
||||
|
||||
adm_cond, adm_uc = None, None
|
||||
if use_karlo:
|
||||
if use_karlo_prior:
|
||||
# uses the prior
|
||||
karlo_sampler = state["karlo_prior"]
|
||||
noise_level = None
|
||||
|
@ -330,19 +329,44 @@ if __name__ == "__main__":
|
|||
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
|
||||
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
|
||||
adm_uc = torch.zeros_like(adm_cond)
|
||||
|
||||
elif version == "Full Karlo":
|
||||
pass
|
||||
else:
|
||||
init_img = get_init_img(batch_size=number_cols)
|
||||
num_inputs = st.number_input("Number of Input Images", 1)
|
||||
|
||||
|
||||
def make_conditionings_from_input(num=1, key=None):
|
||||
init_img = get_init_img(batch_size=number_cols, key=key)
|
||||
with torch.no_grad():
|
||||
adm_cond = state["model"].embedder(init_img)
|
||||
weight = st.slider(f"Weight for Input {num}", min_value=-10., max_value=10., value=1.)
|
||||
if state["model"].noise_augmentor is not None:
|
||||
noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0,
|
||||
max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0)
|
||||
noise_level = st.number_input(f"Noise Augmentation for CLIP embedding of input #{num}", min_value=0,
|
||||
max_value=state["model"].noise_augmentor.max_noise_level - 1,
|
||||
value=0, )
|
||||
c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
|
||||
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
|
||||
# assume this gives embeddings of noise levels
|
||||
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
|
||||
adm_cond = torch.cat((c_adm, noise_level_emb), 1) * weight
|
||||
adm_uc = torch.zeros_like(adm_cond)
|
||||
return adm_cond, adm_uc, weight
|
||||
|
||||
|
||||
adm_inputs = list()
|
||||
weights = list()
|
||||
for n in range(num_inputs):
|
||||
adm_cond, adm_uc, w = make_conditionings_from_input(num=n + 1, key=n)
|
||||
weights.append(w)
|
||||
adm_inputs.append(adm_cond)
|
||||
adm_cond = torch.stack(adm_inputs).sum(0) / sum(weights)
|
||||
if num_inputs > 1:
|
||||
if st.checkbox("Apply Noise to Embedding Mix", True):
|
||||
noise_level = st.number_input(f"Noise Augmentation for averaged CLIP embeddings", min_value=0,
|
||||
max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0, )
|
||||
c_adm, noise_level_emb = state["model"].noise_augmentor(
|
||||
adm_cond[:, :state["model"].noise_augmentor.time_embed.dim],
|
||||
noise_level=repeat(
|
||||
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
|
||||
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
|
||||
|
||||
if st.button("Sample"):
|
||||
print("running prompt:", prompt)
|
||||
|
@ -350,10 +374,12 @@ if __name__ == "__main__":
|
|||
t_progress = st.progress(0)
|
||||
result = st.empty()
|
||||
|
||||
|
||||
def t_callback(t):
|
||||
t_progress.progress(min((t + 1) / steps, 1.))
|
||||
|
||||
if version == "KARLO":
|
||||
|
||||
if version == "Full Karlo":
|
||||
outputs = st.empty()
|
||||
karlo_sampler = state["karlo_prior"]
|
||||
all_samples = list()
|
||||
|
@ -388,4 +414,3 @@ if __name__ == "__main__":
|
|||
use_full_precision=force_full_precision,
|
||||
only_adm_cond=False
|
||||
)
|
||||
|
Loading…
Reference in a new issue