mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 15:44:58 +00:00
make it work in sampling script
This commit is contained in:
parent
8ec7903da2
commit
639b3f3f01
1 changed files with 17 additions and 2 deletions
|
@ -314,6 +314,10 @@ if __name__ == "__main__":
|
||||||
if use_karlo:
|
if use_karlo:
|
||||||
# uses the prior
|
# uses the prior
|
||||||
karlo_sampler = state["karlo_prior"]
|
karlo_sampler = state["karlo_prior"]
|
||||||
|
noise_level = None
|
||||||
|
if state["model"].noise_augmentor is not None:
|
||||||
|
noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0,
|
||||||
|
max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
karlo_prediction = iter(
|
karlo_prediction = iter(
|
||||||
karlo_sampler(
|
karlo_sampler(
|
||||||
|
@ -323,6 +327,10 @@ if __name__ == "__main__":
|
||||||
)
|
)
|
||||||
).__next__()
|
).__next__()
|
||||||
adm_cond = karlo_prediction
|
adm_cond = karlo_prediction
|
||||||
|
if noise_level is not None:
|
||||||
|
c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
|
||||||
|
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
|
||||||
|
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
|
||||||
adm_uc = torch.zeros_like(karlo_prediction)
|
adm_uc = torch.zeros_like(karlo_prediction)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -330,6 +338,13 @@ if __name__ == "__main__":
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
adm_cond = state["model"].embedder(init_img)
|
adm_cond = state["model"].embedder(init_img)
|
||||||
adm_uc = torch.zeros_like(adm_cond)
|
adm_uc = torch.zeros_like(adm_cond)
|
||||||
|
if state["model"].noise_augmentor is not None:
|
||||||
|
noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0,
|
||||||
|
max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0)
|
||||||
|
c_adm, noise_level_emb = state["model"].noise_augmentor(adm_cond, noise_level=repeat(
|
||||||
|
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
|
||||||
|
# assume this gives embeddings of noise levels
|
||||||
|
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
|
||||||
|
|
||||||
if st.button("Sample"):
|
if st.button("Sample"):
|
||||||
print("running prompt:", prompt)
|
print("running prompt:", prompt)
|
||||||
|
|
Loading…
Reference in a new issue