Cherry pick SDXL demo update to 1.16.3 #18496

Merged 4 commits on Nov 18, 2023
183 changes: 126 additions & 57 deletions onnxruntime/python/tools/transformers/models/stable_diffusion/README.md

Large diffs are not rendered by default.

@@ -53,8 +53,31 @@
f"Batch size {len(prompt)} is larger than allowed {max_batch_size}. If dynamic shape is used, then maximum batch size is 4"
)

pipeline_info = PipelineInfo(args.version)
pipeline = init_pipeline(Txt2ImgPipeline, pipeline_info, engine_type, args, max_batch_size, batch_size)
# For TensorRT, the performance of an engine built with dynamic shape is very sensitive to the range of image sizes.
# Here, we reduce the image-size range for TensorRT to trade off flexibility for performance.
# This range covers the commonly used shapes: landscape 512x768, portrait 768x512, and square 512x512 or 768x768.
min_image_size = 512 if args.engine != "ORT_CUDA" else 256
max_image_size = 768 if args.engine != "ORT_CUDA" else 1024
pipeline_info = PipelineInfo(args.version, min_image_size=min_image_size, max_image_size=max_image_size)

# Ideally, the optimized batch size and image size for the TRT engine should align with the user's preference, i.e.,
# optimize for the most frequently used shape. We can let the user configure it once we develop a UI plugin.
# In this demo, we optimize for batch size 1 and image size 512x512 (or 768x768 for SD 2.0/2.1) for the dynamic engine.
# This is mainly for benchmarking, to simulate the case where we have no knowledge of the user's preference.
opt_batch_size = 1 if args.build_dynamic_batch else batch_size
opt_image_height = pipeline_info.default_image_size() if args.build_dynamic_shape else args.height
opt_image_width = pipeline_info.default_image_size() if args.build_dynamic_shape else args.width

pipeline = init_pipeline(
Txt2ImgPipeline,
pipeline_info,
engine_type,
args,
max_batch_size,
opt_batch_size,
opt_image_height,
opt_image_width,
)

if engine_type == EngineType.TRT:
max_device_memory = pipeline.backend.max_device_memory()
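
A dynamic TensorRT engine is built against a (min, opt, max) shape profile, which is why the image-size range above matters: the narrower the min-to-max span, the better the kernels TensorRT can select for the opt shape. As a rough illustration (a hypothetical helper, not the demo's code), image-size bounds map to UNet latent shapes at 1/8 of the image resolution:

LATENT_SCALE = 8  # the VAE downsamples images by 8x in each spatial dimension

def unet_shape_profile(batch, min_size, opt_height, opt_width, max_size, channels=4):
    # Map image-space bounds to latent-space (min, opt, max) shapes;
    # classifier-free-guidance batch doubling is omitted for simplicity.
    to_latent = lambda size: size // LATENT_SCALE
    return {
        "min": (batch, channels, to_latent(min_size), to_latent(min_size)),
        "opt": (batch, channels, to_latent(opt_height), to_latent(opt_width)),
        "max": (batch, channels, to_latent(max_size), to_latent(max_size)),
    }

print(unet_shape_profile(1, 512, 512, 512, 768))
# {'min': (1, 4, 64, 64), 'opt': (1, 4, 64, 64), 'max': (1, 4, 96, 96)}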
@@ -29,17 +29,7 @@
from pipeline_txt2img_xl import Txt2ImgXLPipeline


def run_demo():
"""Run Stable Diffusion XL Base + Refiner together (known as ensemble of expert denoisers) to generate an image."""

args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo")

prompt, negative_prompt = repeat_prompt(args)

# Recommended image sizes are those used in training (see Appendix I in https://arxiv.org/pdf/2307.01952.pdf).
image_height = args.height
image_width = args.width

def load_pipelines(args, batch_size):
# Register TensorRT plugins
engine_type = get_engine_type(args.engine)
if engine_type == EngineType.TRT:
@@ -49,37 +39,83 @@ def run_demo():

max_batch_size = 16
if (engine_type in [EngineType.ORT_TRT, EngineType.TRT]) and (
args.build_dynamic_shape or image_height > 512 or image_width > 512
args.build_dynamic_shape or args.height > 512 or args.width > 512
):
max_batch_size = 4

batch_size = len(prompt)
if batch_size > max_batch_size:
raise ValueError(f"Batch size {batch_size} is larger than allowed {max_batch_size}.")

# For TensorRT, the performance of an engine built with dynamic shape is very sensitive to the range of image sizes.
# Here, we reduce the image-size range for TensorRT to trade off flexibility for performance.
# This range covers the most frequent shapes: landscape (832x1216), portrait (1216x832), and square (1024x1024).
min_image_size = 832 if args.engine != "ORT_CUDA" else 512
max_image_size = 1216 if args.engine != "ORT_CUDA" else 2048

# The base needs no VAE decoder when it outputs latents (for the refiner) instead of images.
base_info = PipelineInfo(args.version, use_vae=False)
base = init_pipeline(Txt2ImgXLPipeline, base_info, engine_type, args, max_batch_size, batch_size)
base_info = PipelineInfo(
args.version, use_vae=args.disable_refiner, min_image_size=min_image_size, max_image_size=max_image_size
)

refiner_info = PipelineInfo(args.version, is_refiner=True)
refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type, args, max_batch_size, batch_size)
# Ideally, the optimized batch size and image size for the TRT engine should align with the user's preference, i.e.,
# optimize for the most frequently used shape. We can let the user configure it once we develop a UI plugin.
# In this demo, we optimize for batch size 1 and image size 1024x1024 for the SD XL dynamic engine.
# This is mainly for benchmarking, to simulate the case where we have no knowledge of the user's preference.
opt_batch_size = 1 if args.build_dynamic_batch else batch_size
opt_image_height = base_info.default_image_size() if args.build_dynamic_shape else args.height
opt_image_width = base_info.default_image_size() if args.build_dynamic_shape else args.width

base = init_pipeline(
Txt2ImgXLPipeline,
base_info,
engine_type,
args,
max_batch_size,
opt_batch_size,
opt_image_height,
opt_image_width,
)

refiner = None
if not args.disable_refiner:
refiner_info = PipelineInfo(
args.version, is_refiner=True, min_image_size=min_image_size, max_image_size=max_image_size
)
refiner = init_pipeline(
Img2ImgXLPipeline,
refiner_info,
engine_type,
args,
max_batch_size,
opt_batch_size,
opt_image_height,
opt_image_width,
)

if engine_type == EngineType.TRT:
max_device_memory = max(base.backend.max_device_memory(), refiner.backend.max_device_memory())
max_device_memory = max(base.backend.max_device_memory(), (refiner or base).backend.max_device_memory())
_, shared_device_memory = cudart.cudaMalloc(max_device_memory)
base.backend.activate_engines(shared_device_memory)
refiner.backend.activate_engines(shared_device_memory)
if refiner:
refiner.backend.activate_engines(shared_device_memory)

if engine_type == EngineType.ORT_CUDA:
enable_vae_slicing = args.enable_vae_slicing
if batch_size > 4 and not enable_vae_slicing:
print("Updating enable_vae_slicing to be True to avoid cuDNN error for batch size > 4.")
enable_vae_slicing = True
if enable_vae_slicing:
refiner.backend.enable_vae_slicing()
(refiner or base).backend.enable_vae_slicing()
return base, refiner
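
The VAE-slicing guard above works around a cuDNN error on large batches by decoding latents one image at a time. A minimal sketch of the idea, assuming a vae_decode callable standing in for the pipeline's VAE decoder:

import torch

def decode_latents_sliced(vae_decode, latents):
    # Decode one latent at a time to cap peak VAE memory, then reassemble
    # the batch; this is the behavior behind --enable-vae-slicing.
    images = [vae_decode(latents[i : i + 1]) for i in range(latents.shape[0])]
    return torch.cat(images, dim=0)

Slicing trades a little latency for a much smaller peak activation footprint, which is why the demo forces it on for batch sizes above 4.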


def run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False):
image_height = args.height
image_width = args.width
batch_size = len(prompt)
base.load_resources(image_height, image_width, batch_size)
refiner.load_resources(image_height, image_width, batch_size)
if refiner:
refiner.load_resources(image_height, image_width, batch_size)

def run_base_and_refiner(warmup=False):
images, time_base = base.run(
@@ -91,8 +127,13 @@ def run_base_and_refiner(warmup=False):
denoising_steps=args.denoising_steps,
guidance=args.guidance,
seed=args.seed,
return_type="latent",
return_type="latent" if refiner else "image",
)
if refiner is None:
return images, time_base

# Use the same seed in the base and the refiner.
seed = base.get_current_seed()

images, time_refiner = refiner.run(
prompt,
@@ -103,7 +144,7 @@
warmup=warmup,
denoising_steps=args.denoising_steps,
guidance=args.guidance,
seed=args.seed,
seed=seed,
)

return images, time_base + time_refiner
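
Note the seed handoff above: when --seed is omitted, the base pipeline draws a random seed, and the refiner must reuse that concrete value (via base.get_current_seed()) rather than None, so both stages denoise the same trajectory. A toy sketch of the pattern, with a hypothetical resolve_seed helper:

import random

def resolve_seed(seed=None):
    # Resolve the seed exactly once; a None seed becomes a concrete draw.
    return seed if seed is not None else random.randint(0, 2**31 - 1)

seed = resolve_seed(None)  # the base stage resolves the seed
refiner_seed = seed        # the refiner reuses the resolved value, never None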
@@ -112,25 +153,104 @@ def run_base_and_refiner(warmup=False):
# Run inference once to capture the CUDA graph.
_, _ = run_base_and_refiner(warmup=True)

print("[I] Warming up ..")
if args.num_warmup_runs > 0:
print("[I] Warming up ..")
for _ in range(args.num_warmup_runs):
_, _ = run_base_and_refiner(warmup=True)

if is_warm_up:
return

print("[I] Running StableDiffusion XL pipeline")
if args.nvtx_profile:
cudart.cudaProfilerStart()
_, latency = run_base_and_refiner(warmup=False)
if args.nvtx_profile:
cudart.cudaProfilerStop()

base.teardown()

print("|------------|--------------|")
print("| {:^10} | {:>9.2f} ms |".format("e2e", latency))
print("|------------|--------------|")
refiner.teardown()


def run_demo(args):
"""Run Stable Diffusion XL Base + Refiner together (known as ensemble of expert denoisers) to generate an image."""

prompt, negative_prompt = repeat_prompt(args)
batch_size = len(prompt)
base, refiner = load_pipelines(args, batch_size)
run_pipelines(args, base, refiner, prompt, negative_prompt)
base.teardown()
if refiner:
refiner.teardown()


def run_dynamic_shape_demo(args):
"""Run demo of generating images with different settings with ORT CUDA provider."""
args.engine = "ORT_CUDA"
args.disable_cuda_graph = True
base, refiner = load_pipelines(args, 1)

prompts = [
"starry night over Golden Gate Bridge by van gogh",
"beautiful photograph of Mt. Fuji during cherry blossom",
"little cute gremlin sitting on a bed, cinematic",
"cute grey cat with blue eyes, wearing a bowtie, acrylic painting",
"beautiful Renaissance Revival Estate, Hobbit-House, detailed painting, warm colors, 8k, trending on Artstation",
"blue owl, big green eyes, portrait, intricate metal design, unreal engine, octane render, realistic",
]

# batch size, height, width, scheduler, steps, prompt, seed
configs = [
(1, 832, 1216, "UniPC", 8, prompts[0], None),
(1, 1024, 1024, "DDIM", 24, prompts[1], None),
(1, 1216, 832, "UniPC", 16, prompts[2], None),
(1, 1344, 768, "DDIM", 24, prompts[3], None),
(2, 640, 1536, "UniPC", 16, prompts[4], 4312973633252712),
(2, 1152, 896, "DDIM", 24, prompts[5], 1964684802882906),
]

# Warm up each combination of (batch size, height, width) once before serving.
args.prompt = ["warm up"]
args.num_warmup_runs = 1
for batch_size, height, width, _, _, _, _ in configs:
args.batch_size = batch_size
args.height = height
args.width = width
print(f"\nWarm up batch_size={batch_size}, height={height}, width={width}")
prompt, negative_prompt = repeat_prompt(args)
run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=True)

# Run pipeline on a list of prompts.
args.num_warmup_runs = 0
for batch_size, height, width, scheduler, steps, example_prompt, seed in configs:
args.prompt = [example_prompt]
args.batch_size = batch_size
args.height = height
args.width = width
args.scheduler = scheduler
args.denoising_steps = steps
args.seed = seed
base.set_scheduler(scheduler)
if refiner:
refiner.set_scheduler(scheduler)
print(
f"\nbatch_size={batch_size}, height={height}, width={width}, scheduler={scheduler}, steps={steps}, prompt={example_prompt}, seed={seed}"
)
prompt, negative_prompt = repeat_prompt(args)
run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False)

base.teardown()
if refiner:
refiner.teardown()


if __name__ == "__main__":
coloredlogs.install(fmt="%(funcName)20s: %(message)s")
run_demo()

args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo")
no_prompt = isinstance(args.prompt, list) and len(args.prompt) == 1 and not args.prompt[0]
if no_prompt:
run_dynamic_shape_demo(args)
else:
run_demo(args)
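
The entry point now dispatches on whether a prompt was given: an explicit prompt runs the standard demo, and an empty prompt runs the dynamic-shape showcase. A self-contained sketch of why the check works with the reworked prompt argument (see the parse_arguments change below):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("prompt", nargs="*", default=[""])

args = parser.parse_args([])  # prompt omitted: argparse falls back to the default [""]
no_prompt = isinstance(args.prompt, list) and len(args.prompt) == 1 and not args.prompt[0]
print(args.prompt, no_prompt)  # [''] True

args = parser.parse_args(["starry night over Golden Gate Bridge by van gogh"])
print(args.prompt)  # ['starry night over Golden Gate Bridge by van gogh']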
@@ -78,13 +78,13 @@ def parse_arguments(is_xl: bool, description: str):
help="Root Directory to store torch or ONNX models, built engines and output images etc.",
)

parser.add_argument("prompt", nargs="+", help="Text prompt(s) to guide image generation.")
parser.add_argument("prompt", nargs="*", default=[""], help="Text prompt(s) to guide image generation.")

parser.add_argument(
"--negative-prompt", nargs="*", default=[""], help="Optional negative prompt(s) to guide the image generation."
)
parser.add_argument(
"--repeat-prompt",
"--batch-size",
type=int,
default=1,
choices=[1, 2, 4, 8, 16],
@@ -145,6 +145,10 @@ def parse_arguments(is_xl: bool, description: str):
parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.")
parser.add_argument("--disable-cuda-graph", action="store_true", help="Disable cuda graph.")

parser.add_argument(
"--disable-refiner", action="store_true", help="Disable refiner and only run base for XL pipeline."
)

group = parser.add_argument_group("Options for ORT_CUDA engine only")
group.add_argument("--enable-vae-slicing", action="store_true", help="True will feed only one image to VAE once.")

@@ -174,9 +178,9 @@ def parse_arguments(is_xl: bool, description: str):
)

# Validate image dimensions
if args.height % 8 != 0 or args.width % 8 != 0:
if args.height % 64 != 0 or args.width % 64 != 0:
raise ValueError(
f"Image height and width have to be divisible by 8 but specified as: {args.height} and {args.width}."
f"Image height and width have to be divisible by 64 but specified as: {args.height} and {args.width}."
)

if (args.build_dynamic_batch or args.build_dynamic_shape) and not args.disable_cuda_graph:
@@ -194,7 +198,7 @@
def repeat_prompt(args):
if not isinstance(args.prompt, list):
raise ValueError(f"`prompt` must be of type `str` or `str` list, but is {type(args.prompt)}")
prompt = args.prompt * args.repeat_prompt
prompt = args.prompt * args.batch_size

if not isinstance(args.negative_prompt, list):
raise ValueError(
@@ -209,7 +213,9 @@ def repeat_prompt(args):
return prompt, negative_prompt
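
With --repeat-prompt renamed to --batch-size, the batching effect is plain list multiplication over the positional prompts, e.g.:

prompt = ["a photo of a cat"] * 2  # --batch-size 2 turns one prompt into a batch of two
print(prompt)  # ['a photo of a cat', 'a photo of a cat']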


def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_size, batch_size):
def init_pipeline(
pipeline_class, pipeline_info, engine_type, args, max_batch_size, opt_batch_size, opt_image_height, opt_image_width
):
onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths(
work_dir=args.work_dir, pipeline_info=pipeline_info, engine_type=engine_type
)
@@ -234,9 +240,6 @@ def init_pipeline(
engine_dir=engine_dir,
framework_model_dir=framework_model_dir,
onnx_dir=onnx_dir,
opt_image_height=args.height,
opt_image_width=args.height,
opt_batch_size=batch_size,
force_engine_rebuild=args.force_engine_build,
device_id=torch.cuda.current_device(),
)
@@ -247,14 +250,15 @@
framework_model_dir,
onnx_dir,
args.onnx_opset,
opt_image_height=args.height,
opt_image_width=args.height,
opt_batch_size=batch_size,
opt_image_height=opt_image_height,
opt_image_width=opt_image_width,
opt_batch_size=opt_batch_size,
force_engine_rebuild=args.force_engine_build,
static_batch=not args.build_dynamic_batch,
static_image_shape=not args.build_dynamic_shape,
max_workspace_size=0,
device_id=torch.cuda.current_device(),
timing_cache=timing_cache,
)
elif engine_type == EngineType.TRT:
# Load TensorRT engines and pytorch modules
@@ -263,9 +267,9 @@ def init_pipeline(
framework_model_dir,
onnx_dir,
args.onnx_opset,
opt_batch_size=batch_size,
opt_image_height=args.height,
opt_image_width=args.height,
opt_batch_size=opt_batch_size,
opt_image_height=opt_image_height,
opt_image_width=opt_image_width,
force_export=args.force_onnx_export,
force_optimize=args.force_onnx_optimize,
force_build=args.force_engine_build,
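
One subtle fix rides along with the opt-shape plumbing above: the old calls passed opt_image_width=args.height, so non-square requests optimized the wrong width; the new opt_batch_size / opt_image_height / opt_image_width parameters are now computed in one place by the callers. A condensed sketch of that policy, assuming an args object carrying the demo's flags:

def optimization_targets(args, batch_size, default_image_size):
    # Dynamic engines optimize a representative shape (batch 1, the model's
    # default resolution); static engines optimize exactly the requested shape.
    opt_batch_size = 1 if args.build_dynamic_batch else batch_size
    opt_height = default_image_size if args.build_dynamic_shape else args.height
    opt_width = default_image_size if args.build_dynamic_shape else args.width
    return opt_batch_size, opt_height, opt_width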