mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 15:44:58 +00:00
make it work
This commit is contained in:
parent
45287f9ed7
commit
929625ac9a
2 changed files with 5 additions and 7 deletions
|
@ -16,6 +16,7 @@ model:
|
||||||
conditioning_key: crossattn-adm
|
conditioning_key: crossattn-adm
|
||||||
scale_factor: 0.18215
|
scale_factor: 0.18215
|
||||||
monitor: val/loss_simple_ema
|
monitor: val/loss_simple_ema
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
embedder_config:
|
embedder_config:
|
||||||
target: ldm.modules.encoders.modules.ClipImageEmbedder
|
target: ldm.modules.encoders.modules.ClipImageEmbedder
|
||||||
|
|
|
@ -117,8 +117,6 @@ def sample(
|
||||||
if isinstance(prompts, tuple):
|
if isinstance(prompts, tuple):
|
||||||
prompts = list(prompts)
|
prompts = list(prompts)
|
||||||
c = model.get_learned_conditioning(prompts)
|
c = model.get_learned_conditioning(prompts)
|
||||||
if isinstance(model, Txt2ImgDiffusionWithPooledInput):
|
|
||||||
c, uc = c[0], uc[0]
|
|
||||||
|
|
||||||
if adm_cond is not None:
|
if adm_cond is not None:
|
||||||
if adm_cond.shape[0] == 1:
|
if adm_cond.shape[0] == 1:
|
||||||
|
@ -272,15 +270,14 @@ if __name__ == "__main__":
|
||||||
mode = "txt2img"
|
mode = "txt2img"
|
||||||
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
|
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
|
||||||
use_karlo = st.checkbox("Use KARLO prior", False)
|
use_karlo = st.checkbox("Use KARLO prior", False)
|
||||||
state = init(version=version, vae_version=vae_version, load_karlo_prior=use_karlo)
|
state = init(version=version, load_karlo_prior=use_karlo)
|
||||||
st.info(state["msg"])
|
st.info(state["msg"])
|
||||||
prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse")
|
prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse")
|
||||||
negative_prompt = st.text_input("Negative Prompt", "")
|
negative_prompt = st.text_input("Negative Prompt", "")
|
||||||
scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
|
scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
|
||||||
number_rows = st.number_input("num rows", value=2, min_value=1, max_value=10)
|
number_rows = st.number_input("num rows", value=2, min_value=1, max_value=10)
|
||||||
number_cols = st.number_input("num cols", value=2, min_value=1, max_value=10)
|
number_cols = st.number_input("num cols", value=2, min_value=1, max_value=10)
|
||||||
default_steps = 25
|
steps = st.sidebar.number_input("steps", value=40, min_value=1, max_value=1000)
|
||||||
steps = st.sidebar.number_input("steps", value=default_steps, min_value=1, max_value=1000)
|
|
||||||
eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
|
eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
|
||||||
force_full_precision = st.sidebar.checkbox("Force FP32", False)
|
force_full_precision = st.sidebar.checkbox("Force FP32", False)
|
||||||
if version != "Full Karlo":
|
if version != "Full Karlo":
|
||||||
|
@ -289,14 +286,14 @@ if __name__ == "__main__":
|
||||||
C = VERSION2SPECS[version]["C"]
|
C = VERSION2SPECS[version]["C"]
|
||||||
f = VERSION2SPECS[version]["f"]
|
f = VERSION2SPECS[version]["f"]
|
||||||
|
|
||||||
SAVE_PATH = os.path.join(SAVE_PATH, version + "_" + vae_version + "-decoder")
|
SAVE_PATH = os.path.join(SAVE_PATH, version)
|
||||||
os.makedirs(os.path.join(SAVE_PATH, "samples"), exist_ok=True)
|
os.makedirs(os.path.join(SAVE_PATH, "samples"), exist_ok=True)
|
||||||
|
|
||||||
seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
|
seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
|
||||||
seed_everything(seed)
|
seed_everything(seed)
|
||||||
|
|
||||||
ucg_schedule = None
|
ucg_schedule = None
|
||||||
sampler = st.sidebar.selectbox("Sampler", ["DDIM", "PLMS", "DPM"], 2)
|
sampler = st.sidebar.selectbox("Sampler", ["DDIM", "PLMS", "DPM"], 0)
|
||||||
if version == "Full Karlo":
|
if version == "Full Karlo":
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in a new issue