mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 15:44:58 +00:00
support image mixing in streamlit
This commit is contained in:
parent
fe1cf687e9
commit
4e89f578ea
1 changed files with 49 additions and 24 deletions
|
@ -18,7 +18,6 @@ from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from ldm.models.diffusion.plms import PLMSSampler
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
|
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
|
||||||
|
|
||||||
|
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
|
|
||||||
PROMPTS_ROOT = "scripts/prompts/"
|
PROMPTS_ROOT = "scripts/prompts/"
|
||||||
|
@ -46,8 +45,8 @@ def instantiate_from_config(config):
|
||||||
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
||||||
|
|
||||||
|
|
||||||
def get_interactive_image():
|
def get_interactive_image(key=None):
|
||||||
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
|
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
|
||||||
if image is not None:
|
if image is not None:
|
||||||
image = Image.open(image)
|
image = Image.open(image)
|
||||||
if not image.mode == "RGB":
|
if not image.mode == "RGB":
|
||||||
|
@ -55,8 +54,8 @@ def get_interactive_image():
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
def load_img(display=True):
|
def load_img(display=True, key=None):
|
||||||
image = get_interactive_image()
|
image = get_interactive_image(key=key)
|
||||||
if display:
|
if display:
|
||||||
st.image(image)
|
st.image(image)
|
||||||
w, h = image.size
|
w, h = image.size
|
||||||
|
@ -69,8 +68,8 @@ def load_img(display=True):
|
||||||
return 2. * image - 1.
|
return 2. * image - 1.
|
||||||
|
|
||||||
|
|
||||||
def get_init_img(batch_size=1):
|
def get_init_img(batch_size=1, key=None):
|
||||||
init_image = load_img().cuda()
|
init_image = load_img(key=key).cuda()
|
||||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||||
return init_image
|
return init_image
|
||||||
|
|
||||||
|
@ -99,7 +98,7 @@ def sample(
|
||||||
):
|
):
|
||||||
batch_size = n_samples
|
batch_size = n_samples
|
||||||
precision_scope = autocast if not use_full_precision else nullcontext
|
precision_scope = autocast if not use_full_precision else nullcontext
|
||||||
#decoderscope = autocast if not use_full_precision else nullcontext
|
# decoderscope = autocast if not use_full_precision else nullcontext
|
||||||
if use_full_precision: st.warning(f"Running {model.__class__.__name__} at full precision.")
|
if use_full_precision: st.warning(f"Running {model.__class__.__name__} at full precision.")
|
||||||
if isinstance(prompt, str):
|
if isinstance(prompt, str):
|
||||||
prompt = [prompt]
|
prompt = [prompt]
|
||||||
|
@ -274,8 +273,8 @@ if __name__ == "__main__":
|
||||||
st.title("Stable unCLIP")
|
st.title("Stable unCLIP")
|
||||||
mode = "txt2img"
|
mode = "txt2img"
|
||||||
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
|
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
|
||||||
use_karlo = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
|
use_karlo_prior = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
|
||||||
state = init(version=version, load_karlo_prior=use_karlo)
|
state = init(version=version, load_karlo_prior=use_karlo_prior)
|
||||||
prompt = st.text_input("Prompt", "a professional photograph")
|
prompt = st.text_input("Prompt", "a professional photograph")
|
||||||
negative_prompt = st.text_input("Negative Prompt", "")
|
negative_prompt = st.text_input("Negative Prompt", "")
|
||||||
scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
|
scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
|
||||||
|
@ -309,7 +308,7 @@ if __name__ == "__main__":
|
||||||
raise ValueError(f"unknown sampler {sampler}!")
|
raise ValueError(f"unknown sampler {sampler}!")
|
||||||
|
|
||||||
adm_cond, adm_uc = None, None
|
adm_cond, adm_uc = None, None
|
||||||
if use_karlo:
|
if use_karlo_prior:
|
||||||
# uses the prior
|
# uses the prior
|
||||||
karlo_sampler = state["karlo_prior"]
|
karlo_sampler = state["karlo_prior"]
|
||||||
noise_level = None
|
noise_level = None
|
||||||
|
@ -330,19 +329,44 @@ if __name__ == "__main__":
|
||||||
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
|
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
|
||||||
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
|
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
|
||||||
adm_uc = torch.zeros_like(adm_cond)
|
adm_uc = torch.zeros_like(adm_cond)
|
||||||
|
elif version == "Full Karlo":
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
init_img = get_init_img(batch_size=number_cols)
|
num_inputs = st.number_input("Number of Input Images", 1)
|
||||||
|
|
||||||
|
|
||||||
|
def make_conditionings_from_input(num=1, key=None):
|
||||||
|
init_img = get_init_img(batch_size=number_cols, key=key)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
adm_cond = state["model"].embedder(init_img)
|
adm_cond = state["model"].embedder(init_img)
|
||||||
|
weight = st.slider(f"Weight for Input {num}", min_value=-10., max_value=10., value=1.)
|
||||||
if state["model"].noise_augmentor is not None:
|
if state["model"].noise_augmentor is not None:
|
||||||
noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0,
|
noise_level = st.number_input(f"Noise Augmentation for CLIP embedding of input #{num}", min_value=0,
|
||||||
max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0)
|
max_value=state["model"].noise_augmentor.max_noise_level - 1,
|
||||||
|
value=0, )
|
||||||
c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
|
c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
|
||||||
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
|
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
|
||||||
# assume this gives embeddings of noise levels
|
adm_cond = torch.cat((c_adm, noise_level_emb), 1) * weight
|
||||||
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
|
|
||||||
adm_uc = torch.zeros_like(adm_cond)
|
adm_uc = torch.zeros_like(adm_cond)
|
||||||
|
return adm_cond, adm_uc, weight
|
||||||
|
|
||||||
|
|
||||||
|
adm_inputs = list()
|
||||||
|
weights = list()
|
||||||
|
for n in range(num_inputs):
|
||||||
|
adm_cond, adm_uc, w = make_conditionings_from_input(num=n + 1, key=n)
|
||||||
|
weights.append(w)
|
||||||
|
adm_inputs.append(adm_cond)
|
||||||
|
adm_cond = torch.stack(adm_inputs).sum(0) / sum(weights)
|
||||||
|
if num_inputs > 1:
|
||||||
|
if st.checkbox("Apply Noise to Embedding Mix", True):
|
||||||
|
noise_level = st.number_input(f"Noise Augmentation for averaged CLIP embeddings", min_value=0,
|
||||||
|
max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0, )
|
||||||
|
c_adm, noise_level_emb = state["model"].noise_augmentor(
|
||||||
|
adm_cond[:, :state["model"].noise_augmentor.time_embed.dim],
|
||||||
|
noise_level=repeat(
|
||||||
|
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
|
||||||
|
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
|
||||||
|
|
||||||
if st.button("Sample"):
|
if st.button("Sample"):
|
||||||
print("running prompt:", prompt)
|
print("running prompt:", prompt)
|
||||||
|
@ -350,10 +374,12 @@ if __name__ == "__main__":
|
||||||
t_progress = st.progress(0)
|
t_progress = st.progress(0)
|
||||||
result = st.empty()
|
result = st.empty()
|
||||||
|
|
||||||
|
|
||||||
def t_callback(t):
|
def t_callback(t):
|
||||||
t_progress.progress(min((t + 1) / steps, 1.))
|
t_progress.progress(min((t + 1) / steps, 1.))
|
||||||
|
|
||||||
if version == "KARLO":
|
|
||||||
|
if version == "Full Karlo":
|
||||||
outputs = st.empty()
|
outputs = st.empty()
|
||||||
karlo_sampler = state["karlo_prior"]
|
karlo_sampler = state["karlo_prior"]
|
||||||
all_samples = list()
|
all_samples = list()
|
||||||
|
@ -388,4 +414,3 @@ if __name__ == "__main__":
|
||||||
use_full_precision=force_full_precision,
|
use_full_precision=force_full_precision,
|
||||||
only_adm_cond=False
|
only_adm_cond=False
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue