From e5a500dafc3d2a7958300a3a270383f6c5691860 Mon Sep 17 00:00:00 2001
From: MindOfMatter
Date: Thu, 25 Jan 2024 15:56:18 -0500
Subject: [PATCH] add lora enabler feature

---
 modules/async_worker.py | 15 ++++++++++++++-
 modules/html.py         | 16 ++++++++++++++++
 modules/meta_parser.py  | 10 ++++++----
 webui.py                |  5 +++--
 4 files changed, 39 insertions(+), 7 deletions(-)

diff --git a/modules/async_worker.py b/modules/async_worker.py
index b2af671..14fb349 100644
--- a/modules/async_worker.py
+++ b/modules/async_worker.py
@@ -110,6 +110,18 @@ def build_image_wall(async_task):
         # must use deep copy otherwise gradio is super laggy. Do not use list.append() .
         async_task.results = async_task.results + [wall]
         return
+
+    def apply_enable_loras(loras):
+        # Initialize an empty list to hold the LoRAs with the enable status applied
+        loras_with_applied_enable = []
+
+        # Process each LoRA setting based on its enable state
+        for lora_setting in loras:
+            lora_model, lora_weight, lora_enable = lora_setting
+            if lora_enable:  # Only add the LoRA setting if it is enabled
+                loras_with_applied_enable.append([lora_model, lora_weight])
+
+        return loras_with_applied_enable
 
     @torch.no_grad()
     @torch.inference_mode()
@@ -131,7 +143,8 @@ def handler(async_task):
         base_model_name = args.pop()
         refiner_model_name = args.pop()
         refiner_switch = args.pop()
-        loras = [[str(args.pop()), float(args.pop())] for _ in range(5)]
+        loras = [[str(args.pop()), float(args.pop()), bool(args.pop())] for _ in range(5)]
+        loras = apply_enable_loras(loras)
         input_image_checkbox = args.pop()
         current_tab = args.pop()
         uov_method = args.pop()
diff --git a/modules/html.py b/modules/html.py
index 3ec6f2d..2d25f41 100644
--- a/modules/html.py
+++ b/modules/html.py
@@ -112,6 +112,22 @@
 margin-left: -5px !important;
 }
 
+.lora_enable {
+    min-width: min(0px, 100%) !important;
+}
+
+.lora_enable label {
+    height: 100%;
+}
+
+.lora_enable label input {
+    margin: auto;
+}
+
+.lora_enable label span {
+    display: none;
+}
+
 '''
 
 progress_html = '''<div class="loader-container">
diff --git a/modules/meta_parser.py b/modules/meta_parser.py
index 07b42a1..6bbdd38 100644
--- a/modules/meta_parser.py
+++ b/modules/meta_parser.py
@@ -139,10 +139,12 @@ def load_parameter_button_click(raw_prompt_txt, is_generating):
         try:
             n, w = loaded_parameter_dict.get(f'LoRA {i}').split(' : ')
             w = float(w)
-            results.append(n)
-            results.append(w)
+            results.append(n)  # Update LoRA model
+            results.append(w)  # Update LoRA weight
+            results.append(True)  # Enable the LoRA setting by default
         except:
-            results.append(gr.update())
-            results.append(gr.update())
+            results.append("None")  # Reset LoRA model to default
+            results.append(1.0)  # Reset LoRA weight to default
+            results.append(True)  # Keep the LoRA setting enabled by default
 
     return results
diff --git a/webui.py b/webui.py
index fadd852..a725b1f 100644
--- a/webui.py
+++ b/webui.py
@@ -313,11 +313,12 @@ def refresh_seed(r, seed_string):
 
                 for i, (n, v) in enumerate(modules.config.default_loras):
                     with gr.Row():
+                        lora_enable = gr.Checkbox(label='Enable', value=True, elem_classes='lora_enable', container=True)
                         lora_model = gr.Dropdown(label=f'LoRA {i + 1}',
                                                  choices=['None'] + modules.config.lora_filenames, value=n)
                         lora_weight = gr.Slider(label='Weight', minimum=-2, maximum=2, step=0.01, value=v,
                                                 elem_classes='lora_weight')
-                        lora_ctrls += [lora_model, lora_weight]
+                        lora_ctrls += [lora_model, lora_weight, lora_enable]
 
                 with gr.Row():
                     model_refresh = gr.Button(label='Refresh', value='\U0001f504 Refresh All Files', variant='secondary', elem_classes='refresh_button')
@@ -467,7 +468,7 @@ def model_refresh_clicked():
            results = []
            results += [gr.update(choices=modules.config.model_filenames), gr.update(choices=['None'] + modules.config.model_filenames)]
            for i in range(5):
-                results += [gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
+                results += [gr.update(choices=['None'] + modules.config.lora_filenames), gr.update(), gr.update(interactive=True)]
            return results
 
        model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls,
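
Illustration (not part of the patch): a minimal standalone sketch of what the new enable flag does to the
LoRA arguments. With this change the UI emits [model, weight, enable] triples for each of the five LoRA
rows, and apply_enable_loras in modules/async_worker.py reduces them back to the [model, weight] pairs the
rest of the pipeline already expects. The filenames below are made-up sample values, not defaults from the
repository.

    # Standalone copy of the helper added in modules/async_worker.py, shown here for illustration only.
    def apply_enable_loras(loras):
        loras_with_applied_enable = []
        for lora_model, lora_weight, lora_enable in loras:
            if lora_enable:  # keep only rows whose 'Enable' checkbox is ticked
                loras_with_applied_enable.append([lora_model, lora_weight])
        return loras_with_applied_enable

    # Hypothetical control values as they would be popped from the Gradio args list.
    loras = [
        ['example_offset_lora.safetensors', 0.5, True],
        ['None', 1.0, True],
        ['example_style_lora.safetensors', 0.8, False],  # disabled via the new checkbox
    ]

    print(apply_enable_loras(loras))
    # -> [['example_offset_lora.safetensors', 0.5], ['None', 1.0]]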