make it work in sampling script

This commit is contained in:
Robin Rombach 2023-01-18 10:24:38 +01:00
parent 8ec7903da2
commit 639b3f3f01

View file

@@ -314,6 +314,10 @@ if __name__ == "__main__":
if use_karlo:
# uses the prior
karlo_sampler = state["karlo_prior"]
noise_level = None
if state["model"].noise_augmentor is not None:
noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0,
max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0)
with torch.no_grad():
karlo_prediction = iter(
karlo_sampler(
@@ -322,14 +326,25 @@ if __name__ == "__main__":
progressive_mode="final",
)
).__next__()
adm_cond = karlo_prediction
if noise_level is not None:
c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
adm_uc = torch.zeros_like(karlo_prediction)
else:
init_img = get_init_img(batch_size=number_cols)
with torch.no_grad():
adm_cond = state["model"].embedder(init_img)
adm_uc = torch.zeros_like(adm_cond)
if state["model"].noise_augmentor is not None:
noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0,
max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0)
c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
# assume this gives embeddings of noise levels
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
if st.button("Sample"):
print("running prompt:", prompt)