support image mixing in streamlit

This commit is contained in:
Robin Rombach 2023-02-23 12:45:18 +01:00
parent fe1cf687e9
commit 4e89f578ea

View file

@ -18,7 +18,6 @@ from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
torch.set_grad_enabled(False)
PROMPTS_ROOT = "scripts/prompts/"
@ -46,8 +45,8 @@ def instantiate_from_config(config):
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def get_interactive_image():
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
def get_interactive_image(key=None):
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
if image is not None:
image = Image.open(image)
if not image.mode == "RGB":
@ -55,8 +54,8 @@ def get_interactive_image():
return image
def load_img(display=True):
image = get_interactive_image()
def load_img(display=True, key=None):
image = get_interactive_image(key=key)
if display:
st.image(image)
w, h = image.size
@ -69,8 +68,8 @@ def load_img(display=True):
return 2. * image - 1.
def get_init_img(batch_size=1):
init_image = load_img().cuda()
def get_init_img(batch_size=1, key=None):
init_image = load_img(key=key).cuda()
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
return init_image
@ -99,7 +98,7 @@ def sample(
):
batch_size = n_samples
precision_scope = autocast if not use_full_precision else nullcontext
#decoderscope = autocast if not use_full_precision else nullcontext
# decoderscope = autocast if not use_full_precision else nullcontext
if use_full_precision: st.warning(f"Running {model.__class__.__name__} at full precision.")
if isinstance(prompt, str):
prompt = [prompt]
@ -274,8 +273,8 @@ if __name__ == "__main__":
st.title("Stable unCLIP")
mode = "txt2img"
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
use_karlo = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
state = init(version=version, load_karlo_prior=use_karlo)
use_karlo_prior = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
state = init(version=version, load_karlo_prior=use_karlo_prior)
prompt = st.text_input("Prompt", "a professional photograph")
negative_prompt = st.text_input("Negative Prompt", "")
scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
@ -309,7 +308,7 @@ if __name__ == "__main__":
raise ValueError(f"unknown sampler {sampler}!")
adm_cond, adm_uc = None, None
if use_karlo:
if use_karlo_prior:
# uses the prior
karlo_sampler = state["karlo_prior"]
noise_level = None
@ -330,19 +329,44 @@ if __name__ == "__main__":
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
adm_uc = torch.zeros_like(adm_cond)
elif version == "Full Karlo":
pass
else:
init_img = get_init_img(batch_size=number_cols)
with torch.no_grad():
adm_cond = state["model"].embedder(init_img)
if state["model"].noise_augmentor is not None:
noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0,
max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0)
c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
# assume this gives embeddings of noise levels
num_inputs = st.number_input("Number of Input Images", 1)
def make_conditionings_from_input(num=1, key=None):
init_img = get_init_img(batch_size=number_cols, key=key)
with torch.no_grad():
adm_cond = state["model"].embedder(init_img)
weight = st.slider(f"Weight for Input {num}", min_value=-10., max_value=10., value=1.)
if state["model"].noise_augmentor is not None:
noise_level = st.number_input(f"Noise Augmentation for CLIP embedding of input #{num}", min_value=0,
max_value=state["model"].noise_augmentor.max_noise_level - 1,
value=0, )
c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
adm_cond = torch.cat((c_adm, noise_level_emb), 1) * weight
adm_uc = torch.zeros_like(adm_cond)
return adm_cond, adm_uc, weight
adm_inputs = list()
weights = list()
for n in range(num_inputs):
adm_cond, adm_uc, w = make_conditionings_from_input(num=n + 1, key=n)
weights.append(w)
adm_inputs.append(adm_cond)
adm_cond = torch.stack(adm_inputs).sum(0) / sum(weights)
if num_inputs > 1:
if st.checkbox("Apply Noise to Embedding Mix", True):
noise_level = st.number_input(f"Noise Augmentation for averaged CLIP embeddings", min_value=0,
max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0, )
c_adm, noise_level_emb = state["model"].noise_augmentor(
adm_cond[:, :state["model"].noise_augmentor.time_embed.dim],
noise_level=repeat(
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
adm_uc = torch.zeros_like(adm_cond)
if st.button("Sample"):
print("running prompt:", prompt)
@ -350,10 +374,12 @@ if __name__ == "__main__":
t_progress = st.progress(0)
result = st.empty()
def t_callback(t):
t_progress.progress(min((t + 1) / steps, 1.))
if version == "KARLO":
if version == "Full Karlo":
outputs = st.empty()
karlo_sampler = state["karlo_prior"]
all_samples = list()
@ -388,4 +414,3 @@ if __name__ == "__main__":
use_full_precision=force_full_precision,
only_adm_cond=False
)