From 639b3f3f01d2814aeed3d7bd62f5a38ff7a51509 Mon Sep 17 00:00:00 2001 From: Robin Rombach Date: Wed, 18 Jan 2023 10:24:38 +0100 Subject: [PATCH] make it work in sampling script --- scripts/streamlit/stablekarlo.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/scripts/streamlit/stablekarlo.py b/scripts/streamlit/stablekarlo.py index 905e62f..5121bd2 100644 --- a/scripts/streamlit/stablekarlo.py +++ b/scripts/streamlit/stablekarlo.py @@ -314,6 +314,10 @@ if __name__ == "__main__": if use_karlo: # uses the prior karlo_sampler = state["karlo_prior"] + noise_level = None + if state["model"].noise_augmentor is not None: + noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0, + max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0) with torch.no_grad(): karlo_prediction = iter( karlo_sampler( @@ -322,14 +326,25 @@ if __name__ == "__main__": progressive_mode="final", ) ).__next__() - adm_cond = karlo_prediction - adm_uc = torch.zeros_like(karlo_prediction) + adm_cond = karlo_prediction + if noise_level is not None: + c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat( + torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols)) + adm_cond = torch.cat((c_adm, noise_level_emb), 1) + adm_uc = torch.zeros_like(karlo_prediction) else: init_img = get_init_img(batch_size=number_cols) with torch.no_grad(): adm_cond = state["model"].embedder(init_img) adm_uc = torch.zeros_like(adm_cond) + if state["model"].noise_augmentor is not None: + noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0, + max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0) + c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat( + torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols)) + # assume this gives embeddings of noise levels + adm_cond = torch.cat((c_adm, noise_level_emb), 1) if st.button("Sample"): print("running prompt:", prompt)