mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 07:34:58 +00:00
make it work
This commit is contained in:
parent
45287f9ed7
commit
929625ac9a
2 changed files with 5 additions and 7 deletions
|
@ -16,6 +16,7 @@ model:
|
|||
conditioning_key: crossattn-adm
|
||||
scale_factor: 0.18215
|
||||
monitor: val/loss_simple_ema
|
||||
use_ema: False
|
||||
|
||||
embedder_config:
|
||||
target: ldm.modules.encoders.modules.ClipImageEmbedder
|
||||
|
|
|
@ -117,8 +117,6 @@ def sample(
|
|||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
c = model.get_learned_conditioning(prompts)
|
||||
if isinstance(model, Txt2ImgDiffusionWithPooledInput):
|
||||
c, uc = c[0], uc[0]
|
||||
|
||||
if adm_cond is not None:
|
||||
if adm_cond.shape[0] == 1:
|
||||
|
@ -272,15 +270,14 @@ if __name__ == "__main__":
|
|||
mode = "txt2img"
|
||||
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
|
||||
use_karlo = st.checkbox("Use KARLO prior", False)
|
||||
state = init(version=version, vae_version=vae_version, load_karlo_prior=use_karlo)
|
||||
state = init(version=version, load_karlo_prior=use_karlo)
|
||||
st.info(state["msg"])
|
||||
prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse")
|
||||
negative_prompt = st.text_input("Negative Prompt", "")
|
||||
scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.)
|
||||
number_rows = st.number_input("num rows", value=2, min_value=1, max_value=10)
|
||||
number_cols = st.number_input("num cols", value=2, min_value=1, max_value=10)
|
||||
default_steps = 25
|
||||
steps = st.sidebar.number_input("steps", value=default_steps, min_value=1, max_value=1000)
|
||||
steps = st.sidebar.number_input("steps", value=40, min_value=1, max_value=1000)
|
||||
eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
|
||||
force_full_precision = st.sidebar.checkbox("Force FP32", False)
|
||||
if version != "Full Karlo":
|
||||
|
@ -289,14 +286,14 @@ if __name__ == "__main__":
|
|||
C = VERSION2SPECS[version]["C"]
|
||||
f = VERSION2SPECS[version]["f"]
|
||||
|
||||
SAVE_PATH = os.path.join(SAVE_PATH, version + "_" + vae_version + "-decoder")
|
||||
SAVE_PATH = os.path.join(SAVE_PATH, version)
|
||||
os.makedirs(os.path.join(SAVE_PATH, "samples"), exist_ok=True)
|
||||
|
||||
seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
|
||||
seed_everything(seed)
|
||||
|
||||
ucg_schedule = None
|
||||
sampler = st.sidebar.selectbox("Sampler", ["DDIM", "PLMS", "DPM"], 2)
|
||||
sampler = st.sidebar.selectbox("Sampler", ["DDIM", "PLMS", "DPM"], 0)
|
||||
if version == "Full Karlo":
|
||||
pass
|
||||
else:
|
||||
|
|
Loading…
Reference in a new issue