From 6b0f10fda51049fc9d60ef6e025e3277f37dc325 Mon Sep 17 00:00:00 2001
From: Richa Gadgil
Date: Fri, 9 Aug 2024 15:28:58 -0700
Subject: [PATCH] more batch sizes for SD2.1 (#3300)

---
 .../python_stable_diffusion_21/README.md     |  6 +-
 .../python_stable_diffusion_21/gradio_app.py |  4 +-
 .../python_stable_diffusion_21/txt2img.py    | 61 ++++++++++++-------
 3 files changed, 44 insertions(+), 27 deletions(-)

diff --git a/examples/diffusion/python_stable_diffusion_21/README.md b/examples/diffusion/python_stable_diffusion_21/README.md
index dc977b12837..74ff657a4f6 100644
--- a/examples/diffusion/python_stable_diffusion_21/README.md
+++ b/examples/diffusion/python_stable_diffusion_21/README.md
@@ -37,12 +37,12 @@ optimum-cli export onnx --model stabilityai/stable-diffusion-2-1 models/sd21-onnx
 ```
 *Note: `models/sd21-onnx` will be used in the scripts.*
 
-Run the text-to-image script with the following example prompt and seed:
+Run the text-to-image script with the following example prompt and seed (optionally, set `--batch` to change the number of images generated for the prompt):
 ```bash
-python txt2img.py --prompt "a photograph of an astronaut riding a horse" --seed 13 --output astro_horse.jpg
+python txt2img.py --prompt "a photograph of an astronaut riding a horse" --seed 13 --output astro_horse.jpg --batch 1
 ```
 
-*Note: The first run will compile the models and cache them to make subsequent runs faster.*
+*Note: The first run will compile the models and cache them to make subsequent runs faster. Running with a new batch size will recompile the models for that size.*
 
 The result should look like this:

diff --git a/examples/diffusion/python_stable_diffusion_21/gradio_app.py b/examples/diffusion/python_stable_diffusion_21/gradio_app.py
index d9cc2c3f2f2..aeb35f0e9ac 100644
--- a/examples/diffusion/python_stable_diffusion_21/gradio_app.py
+++ b/examples/diffusion/python_stable_diffusion_21/gradio_app.py
@@ -30,7 +30,7 @@ def main():
     args = get_args()
     # Note: This will load the models, which can take several minutes
     sd = StableDiffusionMGX(args.onnx_model_path, args.compiled_model_path,
-                            args.fp16, args.force_compile,
+                            args.fp16, args.batch, args.force_compile,
                             args.exhaustive_tune)
     sd.warmup(5)
@@ -51,7 +51,7 @@ def gr_wrapper(prompt, negative_prompt, steps, seed, scale):
             gr.Slider(
                 1, 20, step=0.1, value=args.scale, label="Guidance scale"),
         ],
-        "image",
+        gr.Gallery(),
     )
     demo.launch()
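The output component switches from `"image"` to `gr.Gallery()` because the pipeline now returns a list of images, one per batch entry. A minimal sketch of the contract this relies on, with a hypothetical `generate` stub standing in for the real StableDiffusionMGX pipeline:

```python
import gradio as gr
from PIL import Image


def generate(prompt: str) -> list:
    # Stand-in for StableDiffusionMGX.run + convert_to_rgb_image:
    # a gr.Gallery output accepts a list of images, one per batch entry.
    return [Image.new("RGB", (64, 64), c) for c in ("red", "green", "blue")]


demo = gr.Interface(fn=generate, inputs="text", outputs=gr.Gallery())
demo.launch()
```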
"vae_decoder", {"latent_sample": [self.batch, 4, 64, 64]}, onnx_model_path, compiled_model_path=compiled_model_path, use_fp16="vae" in fp16, force_compile=force_compile, exhaustive_tune=exhaustive_tune, - offload_copy=False), + offload_copy=False, + batch=self.batch), "clip": StableDiffusionMGX.load_mgx_model( "text_encoder", {"input_ids": [2, 77]}, @@ -238,8 +247,8 @@ def __init__(self, onnx_model_path, compiled_model_path, fp16, "unet": StableDiffusionMGX.load_mgx_model( "unet", { - "sample": [2, 4, 64, 64], - "encoder_hidden_states": [2, 77, 1024], + "sample": [2 * self.batch, 4, 64, 64], + "encoder_hidden_states": [2 * self.batch, 77, 1024], "timestep": [1], }, onnx_model_path, @@ -247,7 +256,8 @@ def __init__(self, onnx_model_path, compiled_model_path, fp16, use_fp16="unet" in fp16, force_compile=force_compile, exhaustive_tune=exhaustive_tune, - offload_copy=False) + offload_copy=False, + batch=self.batch) } self.tensors = { @@ -317,7 +327,7 @@ def run(self, prompt, negative_prompt, steps, seed, scale): f"Creating random input data ({1}x{4}x{64}x{64}) (latents) with seed={seed}..." ) latents = torch.randn( - (1, 4, 64, 64), + (self.batch, 4, 64, 64), generator=torch.manual_seed(seed)).to(device="cuda") print("Apply initial noise sigma\n") @@ -369,12 +379,13 @@ def load_mgx_model(name, use_fp16=False, force_compile=False, exhaustive_tune=False, - offload_copy=True): + offload_copy=True, + batch=1): print(f"Loading {name} model...") if compiled_model_path is None: compiled_model_path = onnx_model_path onnx_file = f"{onnx_model_path}/{name}/model.onnx" - mxr_file = f"{compiled_model_path}/{name}/model_{'fp16' if use_fp16 else 'fp32'}_{'gpu' if not offload_copy else 'oc'}.mxr" + mxr_file = f"{compiled_model_path}/{name}/model_{'fp16' if use_fp16 else 'fp32'}_b{batch}_{'gpu' if not offload_copy else 'oc'}.mxr" if not force_compile and os.path.isfile(mxr_file): print(f"Found mxr, loading it from {mxr_file}") model = mgx.load(mxr_file, format="msgpack") @@ -410,14 +421,16 @@ def get_embeddings(self, prompt_tokens): copy_tensor_sync(self.tensors["clip"]["input_ids"], prompt_tokens.input_ids.to(torch.int32)) run_model_sync(self.models["clip"], self.model_args["clip"]) - return self.tensors["clip"][get_output_name(0)] + text_embeds = self.tensors["clip"][get_output_name(0)] + return torch.cat( + [torch.cat([i] * self.batch) for i in text_embeds.split(1)]) @staticmethod def convert_to_rgb_image(image): image = (image / 2 + 0.5).clamp(0, 1) image = image.detach().cpu().permute(0, 2, 3, 1).numpy() images = (image * 255).round().astype("uint8") - return Image.fromarray(images[0]) + return [Image.fromarray(images[i]) for i in range(images.shape[0])] @staticmethod def save_image(pil_image, filename="output.png"): @@ -458,14 +471,17 @@ def warmup(self, num_runs): self.profile_start("warmup") copy_tensor_sync(self.tensors["clip"]["input_ids"], torch.ones((2, 77)).to(torch.int32)) - copy_tensor_sync(self.tensors["unet"]["sample"], - torch.randn((2, 4, 64, 64)).to(torch.float32)) - copy_tensor_sync(self.tensors["unet"]["encoder_hidden_states"], - torch.randn((2, 77, 1024)).to(torch.float32)) + copy_tensor_sync( + self.tensors["unet"]["sample"], + torch.randn((2 * self.batch, 4, 64, 64)).to(torch.float32)) + copy_tensor_sync( + self.tensors["unet"]["encoder_hidden_states"], + torch.randn((2 * self.batch, 77, 1024)).to(torch.float32)) copy_tensor_sync(self.tensors["unet"]["timestep"], torch.atleast_1d(torch.randn(1).to(torch.int64))) - copy_tensor_sync(self.tensors["vae"]["latent_sample"], - 
@@ -458,14 +471,17 @@ def warmup(self, num_runs):
         self.profile_start("warmup")
         copy_tensor_sync(self.tensors["clip"]["input_ids"],
                          torch.ones((2, 77)).to(torch.int32))
-        copy_tensor_sync(self.tensors["unet"]["sample"],
-                         torch.randn((2, 4, 64, 64)).to(torch.float32))
-        copy_tensor_sync(self.tensors["unet"]["encoder_hidden_states"],
-                         torch.randn((2, 77, 1024)).to(torch.float32))
+        copy_tensor_sync(
+            self.tensors["unet"]["sample"],
+            torch.randn((2 * self.batch, 4, 64, 64)).to(torch.float32))
+        copy_tensor_sync(
+            self.tensors["unet"]["encoder_hidden_states"],
+            torch.randn((2 * self.batch, 77, 1024)).to(torch.float32))
         copy_tensor_sync(self.tensors["unet"]["timestep"],
                          torch.atleast_1d(torch.randn(1).to(torch.int64)))
-        copy_tensor_sync(self.tensors["vae"]["latent_sample"],
-                         torch.randn((1, 4, 64, 64)).to(torch.float32))
+        copy_tensor_sync(
+            self.tensors["vae"]["latent_sample"],
+            torch.randn((self.batch, 4, 64, 64)).to(torch.float32))
 
         for _ in range(num_runs):
             run_model_sync(self.models["clip"], self.model_args["clip"])
@@ -478,7 +494,7 @@ def warmup(self, num_runs):
 
     args = get_args()
     sd = StableDiffusionMGX(args.onnx_model_path, args.compiled_model_path,
-                            args.fp16, args.force_compile,
+                            args.fp16, args.batch, args.force_compile,
                             args.exhaustive_tune)
     print("Warmup")
     sd.warmup(5)
@@ -492,7 +508,8 @@ def warmup(self, num_runs):
     sd.cleanup()
 
     print("Convert result to rgb image...")
-    image = StableDiffusionMGX.convert_to_rgb_image(result)
-    filename = args.output if args.output else f"output_s{args.seed}_t{args.steps}.png"
-    StableDiffusionMGX.save_image(image, filename)
-    print(f"Image saved to {filename}")
+    images = StableDiffusionMGX.convert_to_rgb_image(result)
+    for i, image in enumerate(images):
+        filename = f"{i}_{args.output}" if args.output else f"output_s{args.seed}_t{args.steps}_{i}.png"
+        StableDiffusionMGX.save_image(image, filename)
+        print(f"Image saved to {filename}")
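Because MIGraphX compiles each model for fixed input shapes, the cache file name in `load_mgx_model` now keys each compiled artifact on precision, batch size, and placement, so different batch sizes no longer overwrite one another's cached `.mxr`. A small standalone sketch of that naming scheme (the `mxr_path` helper is illustrative, not part of the example):

```python
def mxr_path(compiled_model_path: str, name: str, use_fp16: bool,
             batch: int, offload_copy: bool) -> str:
    # Mirrors the f-string in load_mgx_model: one cache file per
    # (model, precision, batch size, placement) combination.
    precision = "fp16" if use_fp16 else "fp32"
    placement = "gpu" if not offload_copy else "oc"
    return f"{compiled_model_path}/{name}/model_{precision}_b{batch}_{placement}.mxr"


# Batch sizes 1 and 4 map to distinct cache files:
print(mxr_path("models/sd21-onnx", "unet", True, 1, False))
# -> models/sd21-onnx/unet/model_fp16_b1_gpu.mxr
print(mxr_path("models/sd21-onnx", "unet", True, 4, False))
# -> models/sd21-onnx/unet/model_fp16_b4_gpu.mxr
```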