support image mixing in streamlit

This commit is contained in:
Robin Rombach 2023-02-23 12:45:18 +01:00
parent fe1cf687e9
commit 4e89f578ea

View file

@ -18,7 +18,6 @@ from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler from ldm.models.diffusion.dpm_solver import DPMSolverSampler
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
PROMPTS_ROOT = "scripts/prompts/" PROMPTS_ROOT = "scripts/prompts/"
@ -46,8 +45,8 @@ def instantiate_from_config(config):
return get_obj_from_str(config["target"])(**config.get("params", dict())) return get_obj_from_str(config["target"])(**config.get("params", dict()))
def get_interactive_image(): def get_interactive_image(key=None):
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"]) image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
if image is not None: if image is not None:
image = Image.open(image) image = Image.open(image)
if not image.mode == "RGB": if not image.mode == "RGB":
@ -55,8 +54,8 @@ def get_interactive_image():
return image return image
def load_img(display=True): def load_img(display=True, key=None):
image = get_interactive_image() image = get_interactive_image(key=key)
if display: if display:
st.image(image) st.image(image)
w, h = image.size w, h = image.size
@ -69,8 +68,8 @@ def load_img(display=True):
return 2. * image - 1. return 2. * image - 1.
def get_init_img(batch_size=1): def get_init_img(batch_size=1, key=None):
init_image = load_img().cuda() init_image = load_img(key=key).cuda()
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
return init_image return init_image
@ -274,8 +273,8 @@ if __name__ == "__main__":
st.title("Stable unCLIP") st.title("Stable unCLIP")
mode = "txt2img" mode = "txt2img"
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0) version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
use_karlo = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False) use_karlo_prior = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
state = init(version=version, load_karlo_prior=use_karlo) state = init(version=version, load_karlo_prior=use_karlo_prior)
prompt = st.text_input("Prompt", "a professional photograph") prompt = st.text_input("Prompt", "a professional photograph")
negative_prompt = st.text_input("Negative Prompt", "") negative_prompt = st.text_input("Negative Prompt", "")
scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.) scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
@ -309,7 +308,7 @@ if __name__ == "__main__":
raise ValueError(f"unknown sampler {sampler}!") raise ValueError(f"unknown sampler {sampler}!")
adm_cond, adm_uc = None, None adm_cond, adm_uc = None, None
if use_karlo: if use_karlo_prior:
# uses the prior # uses the prior
karlo_sampler = state["karlo_prior"] karlo_sampler = state["karlo_prior"]
noise_level = None noise_level = None
@ -330,19 +329,44 @@ if __name__ == "__main__":
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols)) torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
adm_cond = torch.cat((c_adm, noise_level_emb), 1) adm_cond = torch.cat((c_adm, noise_level_emb), 1)
adm_uc = torch.zeros_like(adm_cond) adm_uc = torch.zeros_like(adm_cond)
elif version == "Full Karlo":
pass
else: else:
init_img = get_init_img(batch_size=number_cols) num_inputs = st.number_input("Number of Input Images", 1)
def make_conditionings_from_input(num=1, key=None):
init_img = get_init_img(batch_size=number_cols, key=key)
with torch.no_grad(): with torch.no_grad():
adm_cond = state["model"].embedder(init_img) adm_cond = state["model"].embedder(init_img)
weight = st.slider(f"Weight for Input {num}", min_value=-10., max_value=10., value=1.)
if state["model"].noise_augmentor is not None: if state["model"].noise_augmentor is not None:
noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0, noise_level = st.number_input(f"Noise Augmentation for CLIP embedding of input #{num}", min_value=0,
max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0) max_value=state["model"].noise_augmentor.max_noise_level - 1,
value=0, )
c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat( c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols)) torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
# assume this gives embeddings of noise levels adm_cond = torch.cat((c_adm, noise_level_emb), 1) * weight
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
adm_uc = torch.zeros_like(adm_cond) adm_uc = torch.zeros_like(adm_cond)
return adm_cond, adm_uc, weight
adm_inputs = list()
weights = list()
for n in range(num_inputs):
adm_cond, adm_uc, w = make_conditionings_from_input(num=n + 1, key=n)
weights.append(w)
adm_inputs.append(adm_cond)
adm_cond = torch.stack(adm_inputs).sum(0) / sum(weights)
if num_inputs > 1:
if st.checkbox("Apply Noise to Embedding Mix", True):
noise_level = st.number_input(f"Noise Augmentation for averaged CLIP embeddings", min_value=0,
max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0, )
c_adm, noise_level_emb = state["model"].noise_augmentor(
adm_cond[:, :state["model"].noise_augmentor.time_embed.dim],
noise_level=repeat(
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
if st.button("Sample"): if st.button("Sample"):
print("running prompt:", prompt) print("running prompt:", prompt)
@ -350,10 +374,12 @@ if __name__ == "__main__":
t_progress = st.progress(0) t_progress = st.progress(0)
result = st.empty() result = st.empty()
def t_callback(t): def t_callback(t):
t_progress.progress(min((t + 1) / steps, 1.)) t_progress.progress(min((t + 1) / steps, 1.))
if version == "KARLO":
if version == "Full Karlo":
outputs = st.empty() outputs = st.empty()
karlo_sampler = state["karlo_prior"] karlo_sampler = state["karlo_prior"]
all_samples = list() all_samples = list()
@ -388,4 +414,3 @@ if __name__ == "__main__":
use_full_precision=force_full_precision, use_full_precision=force_full_precision,
only_adm_cond=False only_adm_cond=False
) )