Skip to content

Commit

Permalink
feat: add button to enable LoRAs (#2210)
Browse files Browse the repository at this point in the history
* Initial commit

* Update README.md

* sync with original main Fooocus repo

* update with my gitignore setup

* add max lora config feature

* Revert "add max lora config feature"

This reverts commit cfe7463.

* add lora enabler feature

* Update README.md

* Update .gitignore

* update

* merge

* revert changes

* revert

* feat: change width of LoRA columns

* refactor: rename lora_enable to lora_enabled, optimize code

---------

Co-authored-by: Manuel Schmid <manuel.schmid@odt.net>
  • Loading branch information
MindOfMatter and mashb1t committed Feb 25, 2024
1 parent eebd775 commit 468d704
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 6 deletions.
10 changes: 9 additions & 1 deletion modules/async_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,14 @@ def build_image_wall(async_task):
# must use deep copy otherwise gradio is super laggy. Do not use list.append() .
async_task.results = async_task.results + [wall]
return

def apply_enabled_loras(loras):
enabled_loras = []
for lora_enabled, lora_model, lora_weight in loras:
if lora_enabled:
enabled_loras.append([lora_model, lora_weight])

return enabled_loras

@torch.no_grad()
@torch.inference_mode()
Expand All @@ -137,7 +145,7 @@ def handler(async_task):
base_model_name = args.pop()
refiner_model_name = args.pop()
refiner_switch = args.pop()
loras = [[str(args.pop()), float(args.pop())] for _ in range(5)]
loras = apply_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop()), ] for _ in range(5)])
input_image_checkbox = args.pop()
current_tab = args.pop()
uov_method = args.pop()
Expand Down
24 changes: 24 additions & 0 deletions modules/html.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,30 @@
margin-left: -5px !important;
}
.lora_enable {
flex-grow: 1 !important;
}
.lora_enable label {
height: 100%;
}
.lora_enable label input {
margin: auto;
}
.lora_enable label span {
display: none;
}
.lora_model {
flex-grow: 5 !important;
}
.lora_weight {
flex-grow: 5 !important;
}
'''
progress_html = '''
<div class="loader-container">
Expand Down
6 changes: 4 additions & 2 deletions modules/meta_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,12 @@ def load_parameter_button_click(raw_prompt_txt, is_generating):
try:
n, w = loaded_parameter_dict.get(f'LoRA {i}').split(' : ')
w = float(w)
results.append(True)
results.append(n)
results.append(w)
except:
results.append(gr.update())
results.append(gr.update())
results.append(True)
results.append("None")
results.append(1.0)

return results
9 changes: 6 additions & 3 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,14 @@ def update_history_link():

for i, (n, v) in enumerate(modules.config.default_loras):
with gr.Row():
lora_enabled = gr.Checkbox(label='Enable', value=True,
elem_classes=['lora_enable', 'min_check'])
lora_model = gr.Dropdown(label=f'LoRA {i + 1}',
choices=['None'] + modules.config.lora_filenames, value=n)
choices=['None'] + modules.config.lora_filenames, value=n,
elem_classes='lora_model')
lora_weight = gr.Slider(label='Weight', minimum=-2, maximum=2, step=0.01, value=v,
elem_classes='lora_weight')
lora_ctrls += [lora_model, lora_weight]
lora_ctrls += [lora_enabled, lora_model, lora_weight]

with gr.Row():
model_refresh = gr.Button(label='Refresh', value='\U0001f504 Refresh All Files', variant='secondary', elem_classes='refresh_button')
Expand Down Expand Up @@ -471,7 +474,7 @@ def model_refresh_clicked():
results = []
results += [gr.update(choices=modules.config.model_filenames), gr.update(choices=['None'] + modules.config.model_filenames)]
for i in range(5):
results += [gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
results += [gr.update(choices=['None'] + modules.config.lora_filenames), gr.update(), gr.update(interactive=True)]
return results

model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls,
Expand Down

0 comments on commit 468d704

Please sign in to comment.