From 3aba736ee2bf3228a95c823f20dbfce0c3b1ff61 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 15 Aug 2023 09:36:31 -0700 Subject: [PATCH] Refactoring of Stable Diffusion scripts (#17138) Reduce duplicated code in two stable diffusion pipelines (CUDA and TensorRT). Move the common code to models.py --- .../models/stable_diffusion/models.py | 368 ++++++++++++++++ .../onnxruntime_cuda_txt2img.py | 412 +----------------- .../onnxruntime_tensorrt_txt2img.py | 328 +------------- .../models/stable_diffusion/ort_optimizer.py | 84 ++++ .../models/stable_diffusion/ort_utils.py | 118 ++++- 5 files changed, 594 insertions(+), 716 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/models.py create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/models.py new file mode 100644 index 000000000000..0f7688a3df9f --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/models.py @@ -0,0 +1,368 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# +# Copyright 2023 The HuggingFace Inc. team. +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Models used in Stable diffusion. 
+""" +import logging + +import onnx +import onnx_graphsurgeon as gs +import torch +from onnx import shape_inference +from ort_optimizer import OrtStableDiffusionOptimizer +from polygraphy.backend.onnx.loader import fold_constants + +logger = logging.getLogger(__name__) + + +class TrtOptimizer: + def __init__(self, onnx_graph): + self.graph = gs.import_onnx(onnx_graph) + + def cleanup(self): + self.graph.cleanup().toposort() + + def get_optimized_onnx_graph(self): + return gs.export_onnx(self.graph) + + def select_outputs(self, keep, names=None): + self.graph.outputs = [self.graph.outputs[o] for o in keep] + if names: + for i, name in enumerate(names): + self.graph.outputs[i].name = name + + def fold_constants(self): + onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) + self.graph = gs.import_onnx(onnx_graph) + + def infer_shapes(self): + onnx_graph = gs.export_onnx(self.graph) + if onnx_graph.ByteSize() > 2147483648: + raise TypeError("ERROR: model size exceeds supported 2GB limit") + else: + onnx_graph = shape_inference.infer_shapes(onnx_graph) + + self.graph = gs.import_onnx(onnx_graph) + + +class BaseModel: + def __init__(self, model, name, device="cuda", fp16=False, max_batch_size=16, embedding_dim=768, text_maxlen=77): + self.model = model + self.name = name + self.fp16 = fp16 + self.device = device + + self.min_batch = 1 + self.max_batch = max_batch_size + self.min_image_shape = 256 # min image resolution: 256x256 + self.max_image_shape = 1024 # max image resolution: 1024x1024 + self.min_latent_shape = self.min_image_shape // 8 + self.max_latent_shape = self.max_image_shape // 8 + + self.embedding_dim = embedding_dim + self.text_maxlen = text_maxlen + + self.model_type = name.lower() if name in ["CLIP", "UNet"] else "vae" + self.ort_optimizer = OrtStableDiffusionOptimizer(self.model_type) + + def get_model(self): + return self.model + + def get_input_names(self): + pass + + def get_output_names(self): + pass + + def get_dynamic_axes(self): + return None + + def get_sample_input(self, batch_size, image_height, image_width): + pass + + def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape): + """For TensorRT EP""" + ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + _, + _, + _, + _, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + + profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}" + + if self.name != "CLIP": + if static_image_shape: + profile_id += f"_h_{image_height}_w_{image_width}" + else: + profile_id += f"_h_{min_image_height}_{max_image_height}_w_{min_image_width}_{max_image_width}" + + return profile_id + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + """For TensorRT""" + return None + + def get_shape_dict(self, batch_size, image_height, image_width): + return None + + def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True): + self.ort_optimizer.optimize(input_onnx_path, optimized_onnx_path, to_fp16) + + def optimize_trt(self, input_onnx_path, optimized_onnx_path): + onnx_graph = onnx.load(input_onnx_path) + opt = TrtOptimizer(onnx_graph) + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + opt.cleanup() + onnx_opt_graph = opt.get_optimized_onnx_graph() + onnx.save(onnx_opt_graph, optimized_onnx_path) + + def check_dims(self, batch_size, image_height, 
image_width): + assert batch_size >= self.min_batch and batch_size <= self.max_batch + assert image_height % 8 == 0 or image_width % 8 == 0 + latent_height = image_height // 8 + latent_width = image_width // 8 + assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape + assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape + return (latent_height, latent_width) + + def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape): + min_batch = batch_size if static_batch else self.min_batch + max_batch = batch_size if static_batch else self.max_batch + latent_height = image_height // 8 + latent_width = image_width // 8 + min_image_height = image_height if static_image_shape else self.min_image_shape + max_image_height = image_height if static_image_shape else self.max_image_shape + min_image_width = image_width if static_image_shape else self.min_image_shape + max_image_width = image_width if static_image_shape else self.max_image_shape + min_latent_height = latent_height if static_image_shape else self.min_latent_shape + max_latent_height = latent_height if static_image_shape else self.max_latent_shape + min_latent_width = latent_width if static_image_shape else self.min_latent_shape + max_latent_width = latent_width if static_image_shape else self.max_latent_shape + return ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) + + +class CLIP(BaseModel): + def __init__(self, model, device, max_batch_size, embedding_dim): + super().__init__( + model=model, + name="CLIP", + device=device, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim, + ) + + def get_input_names(self): + return ["input_ids"] + + def get_output_names(self): + return ["text_embeddings"] + + def get_dynamic_axes(self): + return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_image_shape + ) + return { + "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return { + "input_ids": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), + } + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) + + def optimize_trt(self, input_onnx_path, optimized_onnx_path): + onnx_graph = onnx.load(input_onnx_path) + opt = TrtOptimizer(onnx_graph) + opt.select_outputs([0]) # delete graph output#1 + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + opt.select_outputs([0], names=["text_embeddings"]) # rename network output + opt.cleanup() + onnx_opt_graph = opt.get_optimized_onnx_graph() + onnx.save(onnx_opt_graph, optimized_onnx_path) + + +class UNet(BaseModel): + def __init__( + self, + model, + device="cuda", + fp16=False, # used by TRT + max_batch_size=16, + embedding_dim=768, + text_maxlen=77, + 
unet_dim=4, + ): + super().__init__( + model=model, + name="UNet", + device=device, + fp16=fp16, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim, + text_maxlen=text_maxlen, + ) + self.unet_dim = unet_dim + + def get_input_names(self): + return ["sample", "timestep", "encoder_hidden_states"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + return { + "sample": [ + (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), + (2 * batch_size, self.unet_dim, latent_height, latent_width), + (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), + ], + "encoder_hidden_states": [ + (2 * min_batch, self.text_maxlen, self.embedding_dim), + (2 * batch_size, self.text_maxlen, self.embedding_dim), + (2 * max_batch, self.text_maxlen, self.embedding_dim), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "timestep": [1], + "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), + "latent": (2 * batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn( + 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + ), + torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + ) + + +class VAE(BaseModel): + def __init__(self, model, device, max_batch_size, embedding_dim): + super().__init__( + model=model, + name="VAE Decoder", + device=device, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim, + ) + + def get_input_names(self): + return ["latent"] + + def get_output_names(self): + return ["images"] + + def get_dynamic_axes(self): + return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + return { + "latent": [ + (min_batch, 4, min_latent_height, min_latent_width), + (batch_size, 4, latent_height, latent_width), + (max_batch, 4, max_latent_height, max_latent_width), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + 
return { + "latent": (batch_size, 4, latent_height, latent_width), + "images": (batch_size, 3, image_height, image_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py index 37c74217c6b4..6134fa7bddcf 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py @@ -31,9 +31,8 @@ pip install onnxruntime-gpu """ -import gc +import logging import os -import shutil from typing import List, Optional, Union import torch @@ -44,385 +43,13 @@ StableDiffusionSafetyChecker, ) from diffusers.schedulers import DDIMScheduler -from diffusers.utils import DIFFUSERS_CACHE, logging +from diffusers.utils import DIFFUSERS_CACHE from huggingface_hub import snapshot_download -from ort_utils import OrtCudaSession +from models import CLIP, VAE, UNet +from ort_utils import Engines from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -import onnxruntime as ort -from onnxruntime.transformers.fusion_options import FusionOptions -from onnxruntime.transformers.onnx_model_clip import ClipOnnxModel -from onnxruntime.transformers.onnx_model_unet import UnetOnnxModel -from onnxruntime.transformers.onnx_model_vae import VaeOnnxModel -from onnxruntime.transformers.optimizer import optimize_by_onnxruntime, optimize_model - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class Engine(OrtCudaSession): - def __init__(self, engine_path, provider, device_id: int = 0, enable_cuda_graph=False): - self.engine_path = engine_path - self.provider = provider - self.provider_options = self.get_cuda_provider_options(device_id, enable_cuda_graph) - - device = torch.device("cuda", device_id) - ort_session = ort.InferenceSession( - self.engine_path, - providers=[ - (provider, self.provider_options), - "CPUExecutionProvider", - ], - ) - - super().__init__(ort_session, device, enable_cuda_graph) - - def get_cuda_provider_options(self, device_id: int, enable_cuda_graph: bool): - return { - "device_id": device_id, - "arena_extend_strategy": "kSameAsRequested", - "enable_cuda_graph": enable_cuda_graph, - } - - -class OrtStableDiffusionOptimizer: - def __init__(self, model_type: str): - assert model_type in ["vae", "unet", "clip"] - self.model_type = model_type - self.model_type_class_mapping = { - "unet": UnetOnnxModel, - "vae": VaeOnnxModel, - "clip": ClipOnnxModel, - } - - def optimize_by_ort(self, onnx_model): - import tempfile - from pathlib import Path - - import onnx - - # Use this step to see the final graph that executed by Onnx Runtime. - with tempfile.TemporaryDirectory() as tmp_dir: - # Save to a temporary file so that we can load it with Onnx Runtime. 
- logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...") - tmp_model_path = Path(tmp_dir) / "model.onnx" - onnx_model.save_model_to_file(str(tmp_model_path)) - ort_optimized_model_path = tmp_model_path - optimize_by_onnxruntime( - str(tmp_model_path), use_gpu=True, optimized_model_path=str(ort_optimized_model_path) - ) - model = onnx.load(str(ort_optimized_model_path), load_external_data=True) - return self.model_type_class_mapping[self.model_type](model) - - def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True): - """Optimize onnx model using ONNX Runtime transformers optimizer""" - logger.info(f"Optimize {input_fp32_onnx_path}...") - fusion_options = FusionOptions(self.model_type) - if self.model_type in ["unet"] and not float16: - fusion_options.enable_packed_kv = False - fusion_options.enable_packed_qkv = False - - m = optimize_model( - input_fp32_onnx_path, - model_type=self.model_type, - num_heads=0, # will be deduced from graph - hidden_size=0, # will be deduced from graph - opt_level=0, - optimization_options=fusion_options, - use_gpu=True, - ) - - if self.model_type == "clip": - m.prune_graph(outputs=["text_embeddings"]) # remove the pooler_output, and only keep the first output. - - if float16: - logger.info("Convert to float16 ...") - m.convert_float_to_float16( - keep_io_types=False, - op_block_list=["RandomNormalLike"], - ) - - # Note that ORT 1.15 could not save model larger than 2GB. This only works for float16 - if float16 or (self.model_type != "unet"): - m = self.optimize_by_ort(m) - - m.get_operator_statistics() - m.get_fused_operator_statistics() - m.save_model_to_file(optimized_onnx_path, use_external_data_format=(self.model_type == "unet") and not float16) - logger.info("%s is optimized: %s", self.model_type, optimized_onnx_path) - - -class BaseModel: - def __init__(self, model, name, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77): - self.model = model - self.name = name - self.device = device - - self.min_batch = 1 - self.max_batch = max_batch_size - self.min_image_shape = 256 # min image resolution: 256x256 - self.max_image_shape = 1024 # max image resolution: 1024x1024 - self.min_latent_shape = self.min_image_shape // 8 - self.max_latent_shape = self.max_image_shape // 8 - - self.embedding_dim = embedding_dim - self.text_maxlen = text_maxlen - - self.model_type = name.lower() if name in ["CLIP", "UNet"] else "vae" - self.optimizer = OrtStableDiffusionOptimizer(self.model_type) - - def get_model(self): - return self.model - - def get_input_names(self): - pass - - def get_output_names(self): - pass - - def get_dynamic_axes(self): - return None - - def get_sample_input(self, batch_size, image_height, image_width): - pass - - def get_shape_dict(self, batch_size, image_height, image_width): - return None - - def optimize(self, input_fp32_onnx_path, optimized_onnx_path, fp16): - self.optimizer.optimize(input_fp32_onnx_path, optimized_onnx_path, fp16) - - def check_dims(self, batch_size, image_height, image_width): - assert batch_size >= self.min_batch and batch_size <= self.max_batch - assert image_height % 8 == 0 or image_width % 8 == 0 - latent_height = image_height // 8 - latent_width = image_width // 8 - assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape - assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape - return (latent_height, latent_width) - - def get_minmax_dims(self, batch_size, image_height, image_width, 
static_batch, static_image_shape): - min_batch = batch_size if static_batch else self.min_batch - max_batch = batch_size if static_batch else self.max_batch - latent_height = image_height // 8 - latent_width = image_width // 8 - min_image_height = image_height if static_image_shape else self.min_image_shape - max_image_height = image_height if static_image_shape else self.max_image_shape - min_image_width = image_width if static_image_shape else self.min_image_shape - max_image_width = image_width if static_image_shape else self.max_image_shape - min_latent_height = latent_height if static_image_shape else self.min_latent_shape - max_latent_height = latent_height if static_image_shape else self.max_latent_shape - min_latent_width = latent_width if static_image_shape else self.min_latent_shape - max_latent_width = latent_width if static_image_shape else self.max_latent_shape - return ( - min_batch, - max_batch, - min_image_height, - max_image_height, - min_image_width, - max_image_width, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) - - -def get_onnx_path(model_name, onnx_dir): - return os.path.join(onnx_dir, model_name + ".onnx") - - -def get_engine_path(engine_dir, model_name, profile_id): - return os.path.join(engine_dir, model_name + profile_id + ".onnx") - - -def build_engines( - models, - engine_dir, - onnx_dir, - onnx_opset, - force_engine_rebuild: bool = False, - fp16: bool = True, - provider: str = "CUDAExecutionProvider", - device_id: int = 0, - enable_cuda_graph: bool = False, -): - profile_id = "_fp16" if fp16 else "_fp32" - - if force_engine_rebuild: - if os.path.isdir(onnx_dir): - logger.info("Remove existing directory %s since force_engine_rebuild is enabled", onnx_dir) - shutil.rmtree(onnx_dir) - if os.path.isdir(engine_dir): - logger.info("Remove existing directory %s since force_engine_rebuild is enabled", engine_dir) - shutil.rmtree(engine_dir) - - if not os.path.isdir(engine_dir): - os.makedirs(engine_dir) - - if not os.path.isdir(onnx_dir): - os.makedirs(onnx_dir) - - # Export models to ONNX - for model_name, model_obj in models.items(): - onnx_path = get_onnx_path(model_name, onnx_dir) - onnx_opt_path = get_engine_path(engine_dir, model_name, profile_id) - if os.path.exists(onnx_opt_path): - logger.info("Found cached optimized model: %s", onnx_opt_path) - else: - if os.path.exists(onnx_path): - logger.info("Found cached model: %s", onnx_path) - else: - logger.info("Exporting model: %s", onnx_path) - model = model_obj.get_model().to(model_obj.device) - with torch.inference_mode(): - inputs = model_obj.get_sample_input(1, 512, 512) - torch.onnx.export( - model, - inputs, - onnx_path, - export_params=True, - opset_version=onnx_opset, - do_constant_folding=True, - input_names=model_obj.get_input_names(), - output_names=model_obj.get_output_names(), - dynamic_axes=model_obj.get_dynamic_axes(), - ) - del model - torch.cuda.empty_cache() - gc.collect() - - # Optimize onnx - logger.info("Generating optimized model: %s", onnx_opt_path) - model_obj.optimize(onnx_path, onnx_opt_path, fp16) - - built_engines = {} - for model_name in models: - engine_path = get_engine_path(engine_dir, model_name, profile_id) - engine = Engine(engine_path, provider, device_id=device_id, enable_cuda_graph=enable_cuda_graph) - logger.info("%s options for %s: %s", provider, model_name, engine.provider_options) - built_engines[model_name] = engine - - return built_engines - - -def run_engine(engine, feed_dict): - return engine.infer(feed_dict) - - -class 
CLIP(BaseModel): - def __init__(self, model, device, max_batch_size, embedding_dim): - super().__init__( - model=model, - name="CLIP", - device=device, - max_batch_size=max_batch_size, - embedding_dim=embedding_dim, - ) - - def get_input_names(self): - return ["input_ids"] - - def get_output_names(self): - return ["text_embeddings", "pooler_output"] - - def get_dynamic_axes(self): - return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} - - def get_shape_dict(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - return { - "input_ids": (batch_size, self.text_maxlen), - "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), - # "pooler_output": (batch_size, self.embedding_dim) - } - - def get_sample_input(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) - - -class UNet(BaseModel): - def __init__( - self, - model, - device="cuda", - max_batch_size=16, - embedding_dim=768, - text_maxlen=77, - unet_dim=4, - ): - super().__init__( - model=model, - name="UNet", - device=device, - max_batch_size=max_batch_size, - embedding_dim=embedding_dim, - text_maxlen=text_maxlen, - ) - self.unet_dim = unet_dim - - def get_input_names(self): - return ["sample", "timestep", "encoder_hidden_states"] - - def get_output_names(self): - return ["latent"] - - def get_dynamic_axes(self): - return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "encoder_hidden_states": {0: "2B"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), - "timestep": [1], - "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), - "latent": (2 * batch_size, 4, latent_height, latent_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return ( - torch.randn( - 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device - ), - torch.tensor([1.0], dtype=torch.float32, device=self.device), - torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=torch.float32, device=self.device), - ) - - -class VAE(BaseModel): - def __init__(self, model, device, max_batch_size, embedding_dim): - super().__init__( - model=model, - name="VAE Decoder", - device=device, - max_batch_size=max_batch_size, - embedding_dim=embedding_dim, - ) - - def get_input_names(self): - return ["latent"] - - def get_output_names(self): - return ["images"] - - def get_dynamic_axes(self): - return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}} - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return { - "latent": (batch_size, 4, latent_height, latent_width), - "images": (batch_size, 3, image_height, image_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device) +logger = logging.getLogger(__name__) class 
OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline): @@ -457,7 +84,6 @@ def __init__( self.unet_in_channels = unet.config.in_channels self.inpaint = False - self.onnx_opset = onnx_opset self.onnx_dir = onnx_dir self.engine_dir = engine_dir self.force_engine_rebuild = force_engine_rebuild @@ -466,9 +92,8 @@ def __init__( self.max_batch_size = 16 self.models = {} # loaded in __load_models() - self.engines = {} # loaded in build_engines() + self.engines = Engines("CUDAExecutionProvider", onnx_opset) - self.provider = "CUDAExecutionProvider" self.fp16 = False def __load_models(self): @@ -484,6 +109,7 @@ def __load_models(self): self.models["unet"] = UNet( self.unet, device=self.torch_device, + fp16=self.fp16, max_batch_size=self.max_batch_size, embedding_dim=self.embedding_dim, unet_dim=(9 if self.inpaint else 4), @@ -529,18 +155,16 @@ def to( self.torch_device = torch.device(torch_device) # load models + self.fp16 = torch_dtype == torch.float16 self.__load_models() # build engines - self.fp16 = torch_dtype == torch.float16 - self.engines = build_engines( + self.engines.build( self.models, self.engine_dir, self.onnx_dir, - self.onnx_opset, force_engine_rebuild=self.force_engine_rebuild, fp16=self.fp16, - provider=self.provider, device_id=self.torch_device.index or torch.cuda.current_device(), enable_cuda_graph=self.enable_cuda_graph, ) @@ -582,7 +206,9 @@ def __encode_prompt(self, prompt, negative_prompt): ) # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt - text_embeddings = run_engine(self.engines["clip"], {"input_ids": text_input_ids})["text_embeddings"].clone() + text_embeddings = ( + self.engines.get_engine("clip").infer({"input_ids": text_input_ids})["text_embeddings"].clone() + ) # Tokenize negative prompt uncond_input_ids = ( @@ -597,7 +223,7 @@ def __encode_prompt(self, prompt, negative_prompt): .to(self.torch_device) ) - uncond_embeddings = run_engine(self.engines["clip"], {"input_ids": uncond_input_ids})["text_embeddings"] + uncond_embeddings = self.engines.get_engine("clip").infer({"input_ids": uncond_input_ids})["text_embeddings"] # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) @@ -618,8 +244,7 @@ def __denoise_latent(self, latents, text_embeddings, timesteps=None, mask=None, timestep_float = timestep.to(torch.float16) if self.fp16 else timestep.to(torch.float32) # Predict the noise residual - noise_pred = run_engine( - self.engines["unet"], + noise_pred = self.engines.get_engine("unet").infer( {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings}, )["latent"] @@ -633,14 +258,16 @@ def __denoise_latent(self, latents, text_embeddings, timesteps=None, mask=None, return latents def __decode_latent(self, latents): - images = run_engine(self.engines["vae"], {"latent": latents})["images"] + images = self.engines.get_engine("vae").infer({"latent": latents})["images"] images = (images / 2 + 0.5).clamp(0, 1) return images.cpu().permute(0, 2, 3, 1).float().numpy() def __allocate_buffers(self, image_height, image_width, batch_size): # Allocate output tensors for I/O bindings for model_name, obj in self.models.items(): - self.engines[model_name].allocate_buffers(obj.get_shape_dict(batch_size, image_height, image_width)) + self.engines.get_engine(model_name).allocate_buffers( + 
obj.get_shape_dict(batch_size, image_height, image_width) + ) @torch.no_grad() def __call__( @@ -736,9 +363,6 @@ def __call__( if __name__ == "__main__": - import torch - from diffusers import DDIMScheduler - model_name_or_path = "runwayml/stable-diffusion-v1-5" scheduler = DDIMScheduler.from_pretrained(model_name_or_path, subfolder="scheduler") diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py index d8abd56d0e65..6f3c215f3631 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py @@ -37,8 +37,6 @@ import shutil from typing import List, Optional, Union -import onnx -import onnx_graphsurgeon as gs import torch from cuda import cudart from diffusers.models import AutoencoderKL, UNet2DConditionModel @@ -50,9 +48,8 @@ from diffusers.schedulers import DDIMScheduler from diffusers.utils import DIFFUSERS_CACHE, logging from huggingface_hub import snapshot_download -from onnx import shape_inference +from models import CLIP, VAE, UNet from ort_utils import OrtCudaSession -from polygraphy.backend.onnx.loader import fold_constants from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer import onnxruntime as ort @@ -124,142 +121,6 @@ def get_tensorrt_provider_options(self, input_profile, workspace_size, fp16, dev return trt_ep_options -class Optimizer: - def __init__(self, onnx_graph): - self.graph = gs.import_onnx(onnx_graph) - - def cleanup(self): - self.graph.cleanup().toposort() - - def get_optimized_onnx_graph(self): - return gs.export_onnx(self.graph) - - def select_outputs(self, keep, names=None): - self.graph.outputs = [self.graph.outputs[o] for o in keep] - if names: - for i, name in enumerate(names): - self.graph.outputs[i].name = name - - def fold_constants(self): - onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) - self.graph = gs.import_onnx(onnx_graph) - - def infer_shapes(self): - onnx_graph = gs.export_onnx(self.graph) - if onnx_graph.ByteSize() > 2147483648: - raise TypeError("ERROR: model size exceeds supported 2GB limit") - else: - onnx_graph = shape_inference.infer_shapes(onnx_graph) - - self.graph = gs.import_onnx(onnx_graph) - - -class BaseModel: - def __init__(self, model, name, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77): - self.model = model - self.name = name - self.fp16 = fp16 - self.device = device - - self.min_batch = 1 - self.max_batch = max_batch_size - self.min_image_shape = 256 # min image resolution: 256x256 - self.max_image_shape = 1024 # max image resolution: 1024x1024 - self.min_latent_shape = self.min_image_shape // 8 - self.max_latent_shape = self.max_image_shape // 8 - - self.embedding_dim = embedding_dim - self.text_maxlen = text_maxlen - - def get_model(self): - return self.model - - def get_input_names(self): - pass - - def get_output_names(self): - pass - - def get_dynamic_axes(self): - return None - - def get_sample_input(self, batch_size, image_height, image_width): - pass - - def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape): - ( - min_batch, - max_batch, - min_image_height, - max_image_height, - min_image_width, - max_image_width, - _, - _, - _, - _, - ) = self.get_minmax_dims(batch_size, image_height, 
image_width, static_batch, static_image_shape) - - profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}" - - if self.name != "CLIP": - if static_image_shape: - profile_id += f"_h_{image_height}_w_{image_width}" - else: - profile_id += f"_h_{min_image_height}_{max_image_height}_w_{min_image_width}_{max_image_width}" - - return profile_id - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - return None - - def get_shape_dict(self, batch_size, image_height, image_width): - return None - - def optimize(self, onnx_graph): - opt = Optimizer(onnx_graph) - opt.cleanup() - opt.fold_constants() - opt.infer_shapes() - opt.cleanup() - return opt.get_optimized_onnx_graph() - - def check_dims(self, batch_size, image_height, image_width): - assert batch_size >= self.min_batch and batch_size <= self.max_batch - assert image_height % 8 == 0 or image_width % 8 == 0 - latent_height = image_height // 8 - latent_width = image_width // 8 - assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape - assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape - return (latent_height, latent_width) - - def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape): - min_batch = batch_size if static_batch else self.min_batch - max_batch = batch_size if static_batch else self.max_batch - latent_height = image_height // 8 - latent_width = image_width // 8 - min_image_height = image_height if static_image_shape else self.min_image_shape - max_image_height = image_height if static_image_shape else self.max_image_shape - min_image_width = image_width if static_image_shape else self.min_image_shape - max_image_width = image_width if static_image_shape else self.max_image_shape - min_latent_height = latent_height if static_image_shape else self.min_latent_shape - max_latent_height = latent_height if static_image_shape else self.max_latent_shape - min_latent_width = latent_width if static_image_shape else self.min_latent_shape - max_latent_width = latent_width if static_image_shape else self.max_latent_shape - return ( - min_batch, - max_batch, - min_image_height, - max_image_height, - min_image_width, - max_image_width, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) - - def get_onnx_path(model_name, onnx_dir, opt=True): return os.path.join(onnx_dir, model_name + (".opt" if opt else "") + ".onnx") @@ -352,8 +213,7 @@ def build_engines( # Optimize onnx if not os.path.exists(onnx_opt_path): logger.info("Generating optimizing model: %s", onnx_opt_path) - onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path)) - onnx.save(onnx_opt_graph, onnx_opt_path) + model_obj.optimize_trt(onnx_path, onnx_opt_path) else: logger.info("Found cached optimized model: %s", onnx_opt_path) @@ -403,177 +263,6 @@ def run_engine(engine, feed_dict): return engine.infer(feed_dict) -class CLIP(BaseModel): - def __init__(self, model, device, max_batch_size, embedding_dim): - super().__init__( - model=model, name="CLIP", device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim - ) - - def get_input_names(self): - return ["input_ids"] - - def get_output_names(self): - return ["text_embeddings", "pooler_output"] - - def get_dynamic_axes(self): - return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - 
self.check_dims(batch_size, image_height, image_width) - min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( - batch_size, image_height, image_width, static_batch, static_image_shape - ) - return { - "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)] - } - - def get_shape_dict(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - return { - "input_ids": (batch_size, self.text_maxlen), - "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), - } - - def get_sample_input(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) - - def optimize(self, onnx_graph): - opt = Optimizer(onnx_graph) - opt.select_outputs([0]) # delete graph output#1 - opt.cleanup() - opt.fold_constants() - opt.infer_shapes() - opt.select_outputs([0], names=["text_embeddings"]) # rename network output - opt.cleanup() - return opt.get_optimized_onnx_graph() - - -class UNet(BaseModel): - def __init__( - self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4 - ): - super().__init__( - model=model, - name="UNet", - fp16=fp16, - device=device, - max_batch_size=max_batch_size, - embedding_dim=embedding_dim, - text_maxlen=text_maxlen, - ) - self.unet_dim = unet_dim - - def get_input_names(self): - return ["sample", "timestep", "encoder_hidden_states"] - - def get_output_names(self): - return ["latent"] - - def get_dynamic_axes(self): - return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "encoder_hidden_states": {0: "2B"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - } - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - ( - min_batch, - max_batch, - _, - _, - _, - _, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) - return { - "sample": [ - (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), - (2 * batch_size, self.unet_dim, latent_height, latent_width), - (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), - ], - "encoder_hidden_states": [ - (2 * min_batch, self.text_maxlen, self.embedding_dim), - (2 * batch_size, self.text_maxlen, self.embedding_dim), - (2 * max_batch, self.text_maxlen, self.embedding_dim), - ], - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), - "timestep": [1], - "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), - "latent": (2 * batch_size, 4, latent_height, latent_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - dtype = torch.float16 if self.fp16 else torch.float32 - return ( - torch.randn( - 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device - ), - torch.tensor([1.0], dtype=torch.float32, device=self.device), - torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, 
device=self.device), - ) - - -class VAE(BaseModel): - def __init__(self, model, device, max_batch_size, embedding_dim): - super().__init__( - model=model, name="VAE decoder", device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim - ) - - def get_input_names(self): - return ["latent"] - - def get_output_names(self): - return ["images"] - - def get_dynamic_axes(self): - return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}} - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - ( - min_batch, - max_batch, - _, - _, - _, - _, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) - return { - "latent": [ - (min_batch, 4, min_latent_height, min_latent_width), - (batch_size, 4, latent_height, latent_width), - (max_batch, 4, max_latent_height, max_latent_width), - ] - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return { - "latent": (batch_size, 4, latent_height, latent_width), - "images": (batch_size, 3, image_height, image_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device) - - class OnnxruntimeTensorRTStableDiffusionPipeline(StableDiffusionPipeline): r""" Pipeline for text-to-image generation using TensorRT execution provider in ONNX Runtime. @@ -644,8 +333,8 @@ def __load_models(self): self.models["unet"] = UNet( self.unet, - fp16=True, device=self.torch_device, + fp16=True, max_batch_size=self.max_batch_size, embedding_dim=self.embedding_dim, unet_dim=(9 if self.inpaint else 4), @@ -888,23 +577,22 @@ def __call__( if __name__ == "__main__": - import torch - from diffusers import DDIMScheduler + model_name_or_path = "runwayml/stable-diffusion-v1-5" - scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="scheduler") + scheduler = DDIMScheduler.from_pretrained(model_name_or_path, subfolder="scheduler") pipe = OnnxruntimeTensorRTStableDiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-2-1", + model_name_or_path, revision="fp16", torch_dtype=torch.float16, scheduler=scheduler, image_height=512, image_width=512, - max_batch_size=1, + max_batch_size=4, ) # re-use cached folder to save ONNX models and TensorRT Engines - pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", revision="fp16") + pipe.set_cached_folder(model_name_or_path, revision="fp16") pipe = pipe.to("cuda") diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py new file mode 100644 index 000000000000..0824c8f07d6e --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py @@ -0,0 +1,84 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +""" +ONNX Model Optimizer for Stable Diffusion +""" + +import logging +import tempfile +from pathlib import Path + +import onnx + +from onnxruntime.transformers.fusion_options import FusionOptions +from onnxruntime.transformers.onnx_model_clip import ClipOnnxModel +from onnxruntime.transformers.onnx_model_unet import UnetOnnxModel +from onnxruntime.transformers.onnx_model_vae import VaeOnnxModel +from onnxruntime.transformers.optimizer import optimize_by_onnxruntime, optimize_model + +logger = logging.getLogger(__name__) + + +class OrtStableDiffusionOptimizer: + def __init__(self, model_type: str): + assert model_type in ["vae", "unet", "clip"] + self.model_type = model_type + self.model_type_class_mapping = { + "unet": UnetOnnxModel, + "vae": VaeOnnxModel, + "clip": ClipOnnxModel, + } + + def optimize_by_ort(self, onnx_model): + # Use this step to see the final graph that executed by Onnx Runtime. + with tempfile.TemporaryDirectory() as tmp_dir: + # Save to a temporary file so that we can load it with Onnx Runtime. + logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...") + tmp_model_path = Path(tmp_dir) / "model.onnx" + onnx_model.save_model_to_file(str(tmp_model_path)) + ort_optimized_model_path = tmp_model_path + optimize_by_onnxruntime( + str(tmp_model_path), use_gpu=True, optimized_model_path=str(ort_optimized_model_path) + ) + model = onnx.load(str(ort_optimized_model_path), load_external_data=True) + return self.model_type_class_mapping[self.model_type](model) + + def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True): + """Optimize onnx model using ONNX Runtime transformers optimizer""" + logger.info(f"Optimize {input_fp32_onnx_path}...") + fusion_options = FusionOptions(self.model_type) + if self.model_type in ["unet"] and not float16: + fusion_options.enable_packed_kv = False + fusion_options.enable_packed_qkv = False + + m = optimize_model( + input_fp32_onnx_path, + model_type=self.model_type, + num_heads=0, # will be deduced from graph + hidden_size=0, # will be deduced from graph + opt_level=0, + optimization_options=fusion_options, + use_gpu=True, + ) + + if self.model_type == "clip": + m.prune_graph(outputs=["text_embeddings"]) # remove the pooler_output, and only keep the first output. + + if float16: + logger.info("Convert to float16 ...") + m.convert_float_to_float16( + keep_io_types=False, + op_block_list=["RandomNormalLike"], + ) + + # Note that ORT 1.15 could not save model larger than 2GB. This only works for float16 + if float16 or (self.model_type != "unet"): + m = self.optimize_by_ort(m) + + m.get_operator_statistics() + m.get_fused_operator_statistics() + m.save_model_to_file(optimized_onnx_path, use_external_data_format=(self.model_type == "unet") and not float16) + logger.info("%s is optimized: %s", self.model_type, optimized_onnx_path) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py index 5ad43b6e3989..7192e4ad5584 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py @@ -3,17 +3,23 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- +import gc +import logging +import os +import shutil from collections import OrderedDict -from typing import Dict +from typing import Any, Dict import torch import onnxruntime as ort from onnxruntime.transformers.io_binding_helper import TypeHelper +logger = logging.getLogger(__name__) + class OrtCudaSession: - """ONNX Runtime Session for CUDA or TensorRT provider""" + """Inference Session with IO Binding for ONNX Runtime CUDA or TensorRT provider""" def __init__(self, ort_session: ort.InferenceSession, device: torch.device, enable_cuda_graph=False): self.ort_session = ort_session @@ -110,3 +116,111 @@ def infer(self, feed_dict): self.ort_session.run_with_iobinding(self.io_binding) return self.output_tensors + + +class Engine(OrtCudaSession): + def __init__(self, engine_path, provider: str, device_id: int = 0, enable_cuda_graph=False): + self.engine_path = engine_path + self.provider = provider + self.provider_options = self.get_cuda_provider_options(device_id, enable_cuda_graph) + + device = torch.device("cuda", device_id) + ort_session = ort.InferenceSession( + self.engine_path, + providers=[ + (provider, self.provider_options), + "CPUExecutionProvider", + ], + ) + + super().__init__(ort_session, device, enable_cuda_graph) + + def get_cuda_provider_options(self, device_id: int, enable_cuda_graph: bool) -> Dict[str, Any]: + return { + "device_id": device_id, + "arena_extend_strategy": "kSameAsRequested", + "enable_cuda_graph": enable_cuda_graph, + } + + +class Engines: + def __init__(self, provider, onnx_opset: int = 14): + self.provider = provider + self.engines = {} + self.onnx_opset = onnx_opset + + @staticmethod + def get_onnx_path(onnx_dir, model_name): + return os.path.join(onnx_dir, model_name + ".onnx") + + @staticmethod + def get_engine_path(engine_dir, model_name, profile_id): + return os.path.join(engine_dir, model_name + profile_id + ".onnx") + + def build( + self, + models, + engine_dir: str, + onnx_dir: str, + force_engine_rebuild: bool = False, + fp16: bool = True, + device_id: int = 0, + enable_cuda_graph: bool = False, + ): + profile_id = "_fp16" if fp16 else "_fp32" + + if force_engine_rebuild: + if os.path.isdir(onnx_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", onnx_dir) + shutil.rmtree(onnx_dir) + if os.path.isdir(engine_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", engine_dir) + shutil.rmtree(engine_dir) + + if not os.path.isdir(engine_dir): + os.makedirs(engine_dir) + + if not os.path.isdir(onnx_dir): + os.makedirs(onnx_dir) + + # Export models to ONNX + for model_name, model_obj in models.items(): + onnx_path = Engines.get_onnx_path(onnx_dir, model_name) + onnx_opt_path = Engines.get_engine_path(engine_dir, model_name, profile_id) + if os.path.exists(onnx_opt_path): + logger.info("Found cached optimized model: %s", onnx_opt_path) + else: + if os.path.exists(onnx_path): + logger.info("Found cached model: %s", onnx_path) + else: + logger.info("Exporting model: %s", onnx_path) + model = model_obj.get_model().to(model_obj.device) + with torch.inference_mode(): + inputs = model_obj.get_sample_input(1, 512, 512) + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=self.onnx_opset, + do_constant_folding=True, + input_names=model_obj.get_input_names(), + output_names=model_obj.get_output_names(), + dynamic_axes=model_obj.get_dynamic_axes(), + ) + del model + 
torch.cuda.empty_cache() + gc.collect() + + # Optimize onnx + logger.info("Generating optimized model: %s", onnx_opt_path) + model_obj.optimize_ort(onnx_path, onnx_opt_path, to_fp16=fp16) + + for model_name in models: + engine_path = Engines.get_engine_path(engine_dir, model_name, profile_id) + engine = Engine(engine_path, self.provider, device_id=device_id, enable_cuda_graph=enable_cuda_graph) + logger.info("%s options for %s: %s", self.provider, model_name, engine.provider_options) + self.engines[model_name] = engine + + def get_engine(self, model_name): + return self.engines[model_name]
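
The refactored helpers introduced by this patch (`models.py` wrappers plus the `Engines` manager in `ort_utils.py`) are small enough to exercise outside the two pipelines. Below is a minimal sketch, based only on the calls visible in the diff, of building the shared CUDA engines and running the CLIP engine directly. The checkpoint name, cache directories, and the tokenizer/text-encoder loading via `transformers` are illustrative assumptions and are not part of the patch.

```python
# Sketch only: exercises the refactored Engines/CLIP API from this patch.
# The model loading below and the local directories are illustrative
# assumptions, not something the patch itself prescribes.
import torch
from transformers import CLIPTextModel, CLIPTokenizer

from models import CLIP        # shared model wrappers added in models.py
from ort_utils import Engines  # engine manager added in ort_utils.py

device = torch.device("cuda", 0)
model_name = "runwayml/stable-diffusion-v1-5"  # assumption: same checkpoint as the script's __main__

tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(device)

# Wrap the PyTorch text encoder with the shared CLIP helper.
models = {"clip": CLIP(text_encoder, device=device, max_batch_size=16, embedding_dim=768)}

# Export to ONNX, run OrtStableDiffusionOptimizer, and create the ORT sessions.
engines = Engines("CUDAExecutionProvider")
engines.build(models, engine_dir="./engine", onnx_dir="./onnx", fp16=True, device_id=0)

# I/O binding needs output buffers sized for the request before inference.
engines.get_engine("clip").allocate_buffers(models["clip"].get_shape_dict(1, 512, 512))

input_ids = (
    tokenizer(
        "a photo of an astronaut riding a horse",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    .input_ids.to(torch.int32)
    .to(device)
)
text_embeddings = engines.get_engine("clip").infer({"input_ids": input_ids})["text_embeddings"]
print(text_embeddings.shape)  # expected: (1, 77, 768)
```

This mirrors what the refactored pipelines now do internally: `OnnxruntimeCudaStableDiffusionPipeline` constructs `Engines("CUDAExecutionProvider", onnx_opset)`, calls `engines.build(...)` in `to()`, and replaces the old `run_engine(...)` helper with `engines.get_engine(name).infer(...)`.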