diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index d937e3f4213e0..1ec1ca3ba0c83 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -1,76 +1,94 @@ # Stable Diffusion GPU Optimization -## Overview - -[Stable Diffusion](https://stability.ai/blog/stable-diffusion-announcement) is a text-to-image latent diffusion model for image generation. Explanation of the Stable Diffusion can be found in [Stable Diffusion with Diffusers](https://huggingface.co/blog/stable_diffusion). - -## Optimizations for Stable Diffusion - ONNX Runtime uses the following optimizations to speed up Stable Diffusion in CUDA: * [Flash Attention](https://arxiv.org/abs/2205.14135) for float16 precision. Flash Attention uses tiling to reduce number of GPU memory reads/writes, and improves performance with less memory for long sequence length. The kernel requires GPUs of Compute Capability >= 7.5 (like T4, A100, and RTX 2060~4090). * [Memory Efficient Attention](https://arxiv.org/abs/2112.05682v2) for float32 precision or older GPUs (like V100). We used the fused multi-head attention kernel in CUTLASS, and the kernel was contributed by xFormers. * Channel-last (NHWC) convolution. For NVidia GPU with Tensor Cores support, NHWC tensor layout is recommended for convolution. See [Tensor Layouts In Memory: NCHW vs NHWC](https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#tensor-layout). -* GroupNorm kernel for NHWC tensor layout. +* GroupNorm for NHWC tensor layout, and SkipGroupNorm fusion which fuses GroupNorm with Add bias and residual inputs * SkipLayerNormalization which fuses LayerNormalization with Add bias and residual inputs. * BiasSplitGelu is a fusion of Add bias with SplitGelu activation. * BiasAdd fuses Add bias and residual. * Reduce Transpose nodes by graph transformation. -These optimizations are firstly carried out on CUDA EP. They may not work on other EP. To show the impact of each optimization on latency and GPU memory, we did some experiments: +These optimizations are firstly carried out on CUDA EP. They may not work on other EP. -### Results on RTX 3060 GPU: +## Scripts: -| Optimizations | Average Latency (batch_size=1) | Memory in MB (batch_size=1) | Average Latency (batch_size=8) | Memory in MB (batch_size=8) | -| ---------------------------------------------------------------------------------- | ------------------------------ | --------------------------- | ------------------------------ | --------------------------- | -| Raw FP32 models | 25.6 | 10,667 | OOM | OOM | -| FP16 baseline | 10.2 | 10,709 | OOM | OOM | -| FP16 baseline + FMHA | 6.1 | 7,719 | 39.1 | 10,821 | -| FP16 baseline + FMHA + NhwcConv | 5.5 | 7,656 | 38.8 | 11,615 | -| FP16 baseline + FMHA + NhwcConv + GroupNorm | 5.1 | 6,673 | 35.8 | 10,763 | -| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu | 4.9 | 4,447 | 33.7 | 6,669 | -| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu + Packed QKV | 4.8 | 4,625 | 33.5 | 6,663 | -| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu + Packed QKV + BiasAdd | 4.7 | 4,480 | 33.3 | 6,499 | +| Script | Description | +| ---------------------------------------------- | ----------------------------------------------------------------------------------------- | +| [demo_txt2img_xl.py](./demo_txt2img_xl.py) | Demo of text to image generation using Stable Diffusion XL model. | +| [demo_txt2img.py](./demo_txt2img.py) | Demo of text to image generation using Stable Diffusion models except XL. | +| [optimize_pipeline.py](./optimize_pipeline.py) | Optimize Stable Diffusion ONNX models exported from Huggingface diffusers or optimum | +| [benchmark.py](./benchmark.py) | Benchmark latency and memory of OnnxRuntime, xFormers or PyTorch 2.0 on stable diffusion. | -FP16 baseline contains optimizations available in ONNX Runtime 1.13 including LayerNormalization, SkipLayerNormalization, Gelu and float16 conversion. -Here FMHA means Attention and MultiHeadAttention operators with Flash Attention and Memory Efficient Attention kernels but inputs are not packed. Packed QKV means the inputs are packed. +## Run demo with docker -The last two optimizations (Packed QKV and BiasAdd) are only available in nightly package. Compared to 1.14.1, nightly package has slight improvement in performance. +#### Clone the onnxruntime repository +``` +git clone https://github.com/microsoft/onnxruntime +cd onnxruntime +``` -### Results on MI250X with 1 GCD +#### Launch NVIDIA pytorch container -With runtime tuning enabled, we get following performance number on one GCD of a MI250X GPU: +Install nvidia-docker using [these instructions](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker). -| Optimizations | Average Latency (batch_size=1) | Memory in MB (batch_size=1) | Average Latency (batch_size=8) | Memory in MB (batch_size=8) | -| --------------------------------------------------------------------- | ------------------------------ | --------------------------- | ------------------------------ | --------------------------- | -| Raw FP32 models | 6.7 | 17,319 | 36.4 * | 33,787 | -| FP16 baseline | 4.1 | 8,945 | 24.0 * | 34,493 | -| FP16 baseline + FMHA | 2.6 | 4,886 | 15.0 | 10,146 | -| FP16 baseline + FMHA + NhwcConv | 2.4 | 4,952 | 14.8 | 9,632 | -| FP16 baseline + FMHA + NhwcConv + GroupNorm | 2.3 | 4,906 | 13.6 | 9,774 | -| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu | 2.2 | 4,910 | 12.5 | 9,646 | -| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu + BiasAdd | 2.2 | 4,910 | 12.5 | 9,778 | +``` +docker run --rm -it --gpus all -v $PWD:/workspace nvcr.io/nvidia/pytorch:23.10-py3 /bin/bash +``` -The entries marked with `*` produce suspicious output images. The might be numerical stability or correctness issue for the pipeline. The performance number is for reference only. +#### Build onnxruntime from source +After launching the docker, you can build and install onnxruntime-gpu wheel like the following. +``` +export CUDACXX=/usr/local/cuda-12.2/bin/nvcc +git config --global --add safe.directory '*' +sh build.sh --config Release --build_shared_lib --parallel --use_cuda --cuda_version 12.2 \ + --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/lib/x86_64-linux-gnu/ --build_wheel --skip_tests \ + --use_tensorrt --tensorrt_home /usr/src/tensorrt \ + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF \ + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 \ + --allow_running_as_root +python3 -m pip install --upgrade pip +python3 -m pip install build/Linux/Release/dist/onnxruntime_gpu-1.17.0-cp310-cp310-linux_x86_64.whl --force-reinstall +``` -## Scripts: +If the GPU is not A100, change `CMAKE_CUDA_ARCHITECTURES=80` in the command line according to the GPU compute capacity. -| Script | Description | -| ---------------------------------------------- | ----------------------------------------------------------------------------------------- | -| [optimize_pipeline.py](./optimize_pipeline.py) | Optimize Stable Diffusion ONNX models | -| [benchmark.py](./benchmark.py) | Benchmark latency and memory of OnnxRuntime, xFormers or PyTorch 2.0 on stable diffusion. | +#### Install required packages +``` +cd /workspace/onnxruntime/python/tools/transformers/models/stable_diffusion +python3 -m pip install -r requirements-cuda12.txt +python3 -m pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com +``` -In below example, we run the scripts in source code directory. You can get source code like the following: +### Run Demo +You can review the usage of supported pipelines like the following: ``` -git clone https://github.com/microsoft/onnxruntime -cd onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion +python3 demo_txt2img.py --help +python3 demo_txt2img_xl.py --help ``` -## Example of Stable Diffusion 1.5 +For example: +`--engine {ORT_CUDA,ORT_TRT,TRT}` can be used to choose different backend engines including CUDA or TensorRT execution provider of ONNX Runtime, or TensorRT. +`--work-dir WORK_DIR` can be used to load or save models under the given directory. You can download the [optimized ONNX models of Stable Diffusion XL 1.0](https://huggingface.co/tlwu/stable-diffusion-xl-1.0-onnxruntime#usage-example) to save time in running the XL demo. + +#### Generate an image guided by a text prompt +```python3 demo_txt2img.py "astronaut riding a horse on mars"``` + +#### Generate an image with Stable Diffusion XL guided by a text prompt +```python3 demo_txt2img_xl.py "starry night over Golden Gate Bridge by van gogh"``` -Below is an example to optimize Stable Diffusion 1.5 in Linux. For Windows OS, please change the format of path to be like `.\sd` instead of `./sd`. +If you do not provide prompt, the script will generate different image sizes for a list of prompts for demonstration. + +## Optimize Stable Diffusion ONNX models for Hugging Face Diffusers or Optimum + +If you are able to run the above demo with docker, you can use the docker and skip the following setup and fast forward to [Export ONNX pipeline](#export-onnx-pipeline). + +Below setup does not use docker. We'll use the environment to optimize ONNX models of Stable Diffusion exported by huggingface diffusers or optimum. +For Windows OS, please change the format of path to be like `.\sd` instead of `./sd`. It is recommended to create a Conda environment with Python 3.10 for the following setup: ``` @@ -78,7 +96,7 @@ conda create -n py310 python=3.10 conda activate py310 ``` -### Setup Environment (CUDA) +### Setup Environment (CUDA) without docker First, we need install CUDA 11.8 or 12.1, [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html) 8.5 or above, and [TensorRT 8.6.1](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html) in the machine. @@ -86,12 +104,19 @@ First, we need install CUDA 11.8 or 12.1, [cuDNN](https://docs.nvidia.com/deeple In the Conda environment, install PyTorch 2.1 or above, and other required packages like the following: ``` -pip install torch --index-url https://download.pytorch.org/whl/nightly/cu118 +pip install torch --index-url https://download.pytorch.org/whl/cu118 pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com pip install -r requirements-cuda11.txt ``` -We cannot directly `pip install tensorrt` for CUDA 11. Follow https://github.com/NVIDIA/TensorRT/issues/2773 to install TensorRT for CUDA 11 in Linux. For Windows, pip install the tensorrt wheel in the downloaded TensorRT zip file instead. +For Windows, install nvtx like the following: +``` +conda install -c conda-forge nvtx +``` + +We cannot directly `pip install tensorrt` for CUDA 11. Follow https://github.com/NVIDIA/TensorRT/issues/2773 to install TensorRT for CUDA 11 in Linux. + +For Windows, pip install the tensorrt wheel in the downloaded TensorRT zip file instead. Like `pip install tensorrt-8.6.1.6.windows10.x86_64.cuda-11.8\tensorrt-8.6.1.6\python\tensorrt-8.6.1-cp310-none-win_amd64.whl`. #### CUDA 12.*: The official package of onnxruntime-gpu 1.16.* is built for CUDA 11.8. To use CUDA 12.*, you will need [build onnxruntime from source](https://onnxruntime.ai/docs/build/inferencing.html). @@ -99,6 +124,7 @@ The official package of onnxruntime-gpu 1.16.* is built for CUDA 11.8. To use CU ``` git clone --recursive https://github.com/Microsoft/onnxruntime.git cd onnxruntime +pip install cmake pip install -r requirements-dev.txt ``` Follow [example script for A100 in Ubuntu](https://github.com/microsoft/onnxruntime/blob/26a7b63716e3125bfe35fe3663ba10d2d7322628/build_release.sh) @@ -106,7 +132,7 @@ or [example script for RTX 4090 in Windows](https://github.com/microsoft/onnxrun Then install other python packages like the following: ``` -pip install torch --index-url https://download.pytorch.org/whl/nightly/cu121 +pip install torch --index-url https://download.pytorch.org/whl/cu121 pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com pip install -r requirements-cuda12.txt ``` @@ -182,7 +208,13 @@ Example to optimize the exported float32 ONNX models, and save to float16 models python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i ./sd_v1_5/fp32 -o ./sd_v1_5/fp16 --float16 ``` -For SDXL model, it is recommended to use a machine with 32 GB or more memory to optimize. +In all examples below, we run the scripts in source code directory. You can get source code like the following: +``` +git clone https://github.com/microsoft/onnxruntime +cd onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion +``` + +For SDXL model, it is recommended to use a machine with 48 GB or more memory to optimize. ``` python optimize_pipeline.py -i ./sd_xl_base_onnx -o ./sd_xl_base_fp16 --float16 ``` @@ -265,6 +297,44 @@ python benchmark.py -e tensorrt --height 1024 --width 1024 -s 30 -b 1 -v xl-1.0 python benchmark.py -e onnxruntime -r tensorrt --height 1024 --width 1024 -s 30 -b 1 -v xl-1.0 --enable_cuda_graph ``` +### Results on RTX 3060 GPU: + +To show the impact of each optimization on latency and GPU memory, we did some experiments: + +| Optimizations | Average Latency (batch_size=1) | Memory in MB (batch_size=1) | Average Latency (batch_size=8) | Memory in MB (batch_size=8) | +| ---------------------------------------------------------------------------------- | ------------------------------ | --------------------------- | ------------------------------ | --------------------------- | +| Raw FP32 models | 25.6 | 10,667 | OOM | OOM | +| FP16 baseline | 10.2 | 10,709 | OOM | OOM | +| FP16 baseline + FMHA | 6.1 | 7,719 | 39.1 | 10,821 | +| FP16 baseline + FMHA + NhwcConv | 5.5 | 7,656 | 38.8 | 11,615 | +| FP16 baseline + FMHA + NhwcConv + GroupNorm | 5.1 | 6,673 | 35.8 | 10,763 | +| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu | 4.9 | 4,447 | 33.7 | 6,669 | +| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu + Packed QKV | 4.8 | 4,625 | 33.5 | 6,663 | +| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu + Packed QKV + BiasAdd | 4.7 | 4,480 | 33.3 | 6,499 | + +FP16 baseline contains optimizations available in ONNX Runtime 1.13 including LayerNormalization, SkipLayerNormalization, Gelu and float16 conversion. + +Here FMHA means Attention and MultiHeadAttention operators with Flash Attention and Memory Efficient Attention kernels but inputs are not packed. Packed QKV means the inputs are packed. + +The last two optimizations (Packed QKV and BiasAdd) are only available in nightly package. Compared to 1.14.1, nightly package has slight improvement in performance. + +### Results on MI250X with 1 GCD + +With runtime tuning enabled, we get following performance number on one GCD of a MI250X GPU: + +| Optimizations | Average Latency (batch_size=1) | Memory in MB (batch_size=1) | Average Latency (batch_size=8) | Memory in MB (batch_size=8) | +| --------------------------------------------------------------------- | ------------------------------ | --------------------------- | ------------------------------ | --------------------------- | +| Raw FP32 models | 6.7 | 17,319 | 36.4 * | 33,787 | +| FP16 baseline | 4.1 | 8,945 | 24.0 * | 34,493 | +| FP16 baseline + FMHA | 2.6 | 4,886 | 15.0 | 10,146 | +| FP16 baseline + FMHA + NhwcConv | 2.4 | 4,952 | 14.8 | 9,632 | +| FP16 baseline + FMHA + NhwcConv + GroupNorm | 2.3 | 4,906 | 13.6 | 9,774 | +| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu | 2.2 | 4,910 | 12.5 | 9,646 | +| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu + BiasAdd | 2.2 | 4,910 | 12.5 | 9,778 | + +The entries marked with `*` produce suspicious output images. The might be numerical stability or correctness issue for the pipeline. The performance number is for reference only. + + ### Example Benchmark output Common settings for below test results: @@ -400,7 +470,8 @@ Results are from Standard_NC4as_T4_v3 Azure virtual machine: ### Credits -Some CUDA kernels (Flash Attention, GroupNorm, SplitGelu and BiasAdd etc.) were originally implemented in [TensorRT](https://github.com/nviDIA/TensorRT) by Nvidia. +Some CUDA kernels (TensorRT Fused Attention, GroupNorm, SplitGelu and BiasAdd etc.) and demo diffusion were originally implemented in [TensorRT](https://github.com/nviDIA/TensorRT) by Nvidia. +We use [Flash Attention v2](https://github.com/Dao-AILab/flash-attention) in Linux. We use Memory efficient attention from [CUTLASS](https://github.com/NVIDIA/cutlass). The kernels were developed by Meta xFormers. The ONNX export script and pipeline for stable diffusion was developed by Huggingface [diffusers](https://github.com/huggingface/diffusers) library. @@ -408,10 +479,8 @@ Most ROCm kernel optimizations are from [composable kernel](https://github.com/R Some kernels are enabled by MIOpen. We hereby thank for the AMD developers' collaboration. ### Future Works - -There are other optimizations might improve the performance or reduce memory footprint: -* Export the whole pipeline into a single ONNX model. Currently, there are multiple ONNX models (CLIP, VAE and U-Net etc). Each model uses separated thread pool and memory allocator. Combine them into one model could share thread pool and memory allocator. The end result is more efficient and less memory footprint. -* For Stable Diffusion 2.1, we disable TensorRT flash attention kernel and use only memory efficient attention. It is possible to add flash attention in Windows to improve performance. -* Reduce GPU memory footprint by actively deleting buffers for intermediate results. -* Safety Checker Optimization -* Leverage FP8 in latest GPU +* Update demo to support inpainting, LoRA Weights and Control Net. +* Support flash attention in Windows. +* Integration with UI. +* Optimization for H100 GPU. +* Export the whole pipeline into a single ONNX model. This senario is mainly for mobile device. diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py index fb051ac1ed3b4..4636f139d4613 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py @@ -53,8 +53,31 @@ f"Batch size {len(prompt)} is larger than allowed {max_batch_size}. If dynamic shape is used, then maximum batch size is 4" ) - pipeline_info = PipelineInfo(args.version) - pipeline = init_pipeline(Txt2ImgPipeline, pipeline_info, engine_type, args, max_batch_size, batch_size) + # For TensorRT, performance of engine built with dynamic shape is very sensitive to the range of image size. + # Here, we reduce the range of image size for TensorRT to trade-off flexibility and performance. + # This range can cover common used shape of landscape 512x768, portrait 768x512, or square 512x512 and 768x768. + min_image_size = 512 if args.engine != "ORT_CUDA" else 256 + max_image_size = 768 if args.engine != "ORT_CUDA" else 1024 + pipeline_info = PipelineInfo(args.version, min_image_size=min_image_size, max_image_size=max_image_size) + + # Ideally, the optimized batch size and image size for TRT engine shall align with user's preference. That is to + # optimize the shape used most frequently. We can let user config it when we develop a UI plugin. + # In this demo, we optimize batch size 1 and image size 512x512 (or 768x768 for SD 2.0/2.1) for dynamic engine. + # This is mainly for benchmark purpose to simulate the case that we have no knowledge of user's preference. + opt_batch_size = 1 if args.build_dynamic_batch else batch_size + opt_image_height = pipeline_info.default_image_size() if args.build_dynamic_shape else args.height + opt_image_width = pipeline_info.default_image_size() if args.build_dynamic_shape else args.width + + pipeline = init_pipeline( + Txt2ImgPipeline, + pipeline_info, + engine_type, + args, + max_batch_size, + opt_batch_size, + opt_image_height, + opt_image_width, + ) if engine_type == EngineType.TRT: max_device_memory = max(pipeline.backend.max_device_memory(), pipeline.backend.max_device_memory()) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py index 16e776a08282c..4f9ecf6cbb152 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py @@ -29,17 +29,7 @@ from pipeline_txt2img_xl import Txt2ImgXLPipeline -def run_demo(): - """Run Stable Diffusion XL Base + Refiner together (known as ensemble of expert denoisers) to generate an image.""" - - args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo") - - prompt, negative_prompt = repeat_prompt(args) - - # Recommend image size as one of those used in training (see Appendix I in https://arxiv.org/pdf/2307.01952.pdf). - image_height = args.height - image_width = args.width - +def load_pipelines(args, batch_size): # Register TensorRT plugins engine_type = get_engine_type(args.engine) if engine_type == EngineType.TRT: @@ -49,26 +39,65 @@ def run_demo(): max_batch_size = 16 if (engine_type in [EngineType.ORT_TRT, EngineType.TRT]) and ( - args.build_dynamic_shape or image_height > 512 or image_width > 512 + args.build_dynamic_shape or args.height > 512 or args.width > 512 ): max_batch_size = 4 - batch_size = len(prompt) if batch_size > max_batch_size: raise ValueError(f"Batch size {batch_size} is larger than allowed {max_batch_size}.") + # For TensorRT, performance of engine built with dynamic shape is very sensitive to the range of image size. + # Here, we reduce the range of image size for TensorRT to trade-off flexibility and performance. + # This range can cover most frequent shape of landscape (832x1216), portrait (1216x832) or square (1024x1024). + min_image_size = 832 if args.engine != "ORT_CUDA" else 512 + max_image_size = 1216 if args.engine != "ORT_CUDA" else 2048 + # No VAE decoder in base when it outputs latent instead of image. - base_info = PipelineInfo(args.version, use_vae=False) - base = init_pipeline(Txt2ImgXLPipeline, base_info, engine_type, args, max_batch_size, batch_size) + base_info = PipelineInfo( + args.version, use_vae=args.disable_refiner, min_image_size=min_image_size, max_image_size=max_image_size + ) - refiner_info = PipelineInfo(args.version, is_refiner=True) - refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type, args, max_batch_size, batch_size) + # Ideally, the optimized batch size and image size for TRT engine shall align with user's preference. That is to + # optimize the shape used most frequently. We can let user config it when we develop a UI plugin. + # In this demo, we optimize batch size 1 and image size 1024x1024 for SD XL dynamic engine. + # This is mainly for benchmark purpose to simulate the case that we have no knowledge of user's preference. + opt_batch_size = 1 if args.build_dynamic_batch else batch_size + opt_image_height = base_info.default_image_size() if args.build_dynamic_shape else args.height + opt_image_width = base_info.default_image_size() if args.build_dynamic_shape else args.width + + base = init_pipeline( + Txt2ImgXLPipeline, + base_info, + engine_type, + args, + max_batch_size, + opt_batch_size, + opt_image_height, + opt_image_width, + ) + + refiner = None + if not args.disable_refiner: + refiner_info = PipelineInfo( + args.version, is_refiner=True, min_image_size=min_image_size, max_image_size=max_image_size + ) + refiner = init_pipeline( + Img2ImgXLPipeline, + refiner_info, + engine_type, + args, + max_batch_size, + opt_batch_size, + opt_image_height, + opt_image_width, + ) if engine_type == EngineType.TRT: - max_device_memory = max(base.backend.max_device_memory(), refiner.backend.max_device_memory()) + max_device_memory = max(base.backend.max_device_memory(), (refiner or base).backend.max_device_memory()) _, shared_device_memory = cudart.cudaMalloc(max_device_memory) base.backend.activate_engines(shared_device_memory) - refiner.backend.activate_engines(shared_device_memory) + if refiner: + refiner.backend.activate_engines(shared_device_memory) if engine_type == EngineType.ORT_CUDA: enable_vae_slicing = args.enable_vae_slicing @@ -76,10 +105,17 @@ def run_demo(): print("Updating enable_vae_slicing to be True to avoid cuDNN error for batch size > 4.") enable_vae_slicing = True if enable_vae_slicing: - refiner.backend.enable_vae_slicing() + (refiner or base).backend.enable_vae_slicing() + return base, refiner + +def run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False): + image_height = args.height + image_width = args.width + batch_size = len(prompt) base.load_resources(image_height, image_width, batch_size) - refiner.load_resources(image_height, image_width, batch_size) + if refiner: + refiner.load_resources(image_height, image_width, batch_size) def run_base_and_refiner(warmup=False): images, time_base = base.run( @@ -91,8 +127,13 @@ def run_base_and_refiner(warmup=False): denoising_steps=args.denoising_steps, guidance=args.guidance, seed=args.seed, - return_type="latent", + return_type="latent" if refiner else "image", ) + if refiner is None: + return images, time_base + + # Use same seed in base and refiner. + seed = base.get_current_seed() images, time_refiner = refiner.run( prompt, @@ -103,7 +144,7 @@ def run_base_and_refiner(warmup=False): warmup=warmup, denoising_steps=args.denoising_steps, guidance=args.guidance, - seed=args.seed, + seed=seed, ) return images, time_base + time_refiner @@ -112,10 +153,14 @@ def run_base_and_refiner(warmup=False): # inference once to get cuda graph _, _ = run_base_and_refiner(warmup=True) - print("[I] Warming up ..") + if args.num_warmup_runs > 0: + print("[I] Warming up ..") for _ in range(args.num_warmup_runs): _, _ = run_base_and_refiner(warmup=True) + if is_warm_up: + return + print("[I] Running StableDiffusion XL pipeline") if args.nvtx_profile: cudart.cudaProfilerStart() @@ -123,14 +168,89 @@ def run_base_and_refiner(warmup=False): if args.nvtx_profile: cudart.cudaProfilerStop() - base.teardown() - print("|------------|--------------|") print("| {:^10} | {:>9.2f} ms |".format("e2e", latency)) print("|------------|--------------|") - refiner.teardown() + + +def run_demo(args): + """Run Stable Diffusion XL Base + Refiner together (known as ensemble of expert denoisers) to generate an image.""" + + prompt, negative_prompt = repeat_prompt(args) + batch_size = len(prompt) + base, refiner = load_pipelines(args, batch_size) + run_pipelines(args, base, refiner, prompt, negative_prompt) + base.teardown() + if refiner: + refiner.teardown() + + +def run_dynamic_shape_demo(args): + """Run demo of generating images with different settings with ORT CUDA provider.""" + args.engine = "ORT_CUDA" + args.disable_cuda_graph = True + base, refiner = load_pipelines(args, 1) + + prompts = [ + "starry night over Golden Gate Bridge by van gogh", + "beautiful photograph of Mt. Fuji during cherry blossom", + "little cute gremlin sitting on a bed, cinematic", + "cute grey cat with blue eyes, wearing a bowtie, acrylic painting", + "beautiful Renaissance Revival Estate, Hobbit-House, detailed painting, warm colors, 8k, trending on Artstation", + "blue owl, big green eyes, portrait, intricate metal design, unreal engine, octane render, realistic", + ] + + # batch size, height, width, scheduler, steps, prompt, seed + configs = [ + (1, 832, 1216, "UniPC", 8, prompts[0], None), + (1, 1024, 1024, "DDIM", 24, prompts[1], None), + (1, 1216, 832, "UniPC", 16, prompts[2], None), + (1, 1344, 768, "DDIM", 24, prompts[3], None), + (2, 640, 1536, "UniPC", 16, prompts[4], 4312973633252712), + (2, 1152, 896, "DDIM", 24, prompts[5], 1964684802882906), + ] + + # Warm up each combination of (batch size, height, width) once before serving. + args.prompt = ["warm up"] + args.num_warmup_runs = 1 + for batch_size, height, width, _, _, _, _ in configs: + args.batch_size = batch_size + args.height = height + args.width = width + print(f"\nWarm up batch_size={batch_size}, height={height}, width={width}") + prompt, negative_prompt = repeat_prompt(args) + run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=True) + + # Run pipeline on a list of prompts. + args.num_warmup_runs = 0 + for batch_size, height, width, scheduler, steps, example_prompt, seed in configs: + args.prompt = [example_prompt] + args.batch_size = batch_size + args.height = height + args.width = width + args.scheduler = scheduler + args.denoising_steps = steps + args.seed = seed + base.set_scheduler(scheduler) + if refiner: + refiner.set_scheduler(scheduler) + print( + f"\nbatch_size={batch_size}, height={height}, width={width}, scheduler={scheduler}, steps={steps}, prompt={example_prompt}, seed={seed}" + ) + prompt, negative_prompt = repeat_prompt(args) + run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False) + + base.teardown() + if refiner: + refiner.teardown() if __name__ == "__main__": coloredlogs.install(fmt="%(funcName)20s: %(message)s") - run_demo() + + args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo") + no_prompt = isinstance(args.prompt, list) and len(args.prompt) == 1 and not args.prompt[0] + if no_prompt: + run_dynamic_shape_demo(args) + else: + run_demo(args) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py index e65efd2c53839..39ee273a3130d 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -78,13 +78,13 @@ def parse_arguments(is_xl: bool, description: str): help="Root Directory to store torch or ONNX models, built engines and output images etc.", ) - parser.add_argument("prompt", nargs="+", help="Text prompt(s) to guide image generation.") + parser.add_argument("prompt", nargs="*", default=[""], help="Text prompt(s) to guide image generation.") parser.add_argument( "--negative-prompt", nargs="*", default=[""], help="Optional negative prompt(s) to guide the image generation." ) parser.add_argument( - "--repeat-prompt", + "--batch-size", type=int, default=1, choices=[1, 2, 4, 8, 16], @@ -145,6 +145,10 @@ def parse_arguments(is_xl: bool, description: str): parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.") parser.add_argument("--disable-cuda-graph", action="store_true", help="Disable cuda graph.") + parser.add_argument( + "--disable-refiner", action="store_true", help="Disable refiner and only run base for XL pipeline." + ) + group = parser.add_argument_group("Options for ORT_CUDA engine only") group.add_argument("--enable-vae-slicing", action="store_true", help="True will feed only one image to VAE once.") @@ -174,9 +178,9 @@ def parse_arguments(is_xl: bool, description: str): ) # Validate image dimensions - if args.height % 8 != 0 or args.width % 8 != 0: + if args.height % 64 != 0 or args.width % 64 != 0: raise ValueError( - f"Image height and width have to be divisible by 8 but specified as: {args.height} and {args.width}." + f"Image height and width have to be divisible by 64 but specified as: {args.height} and {args.width}." ) if (args.build_dynamic_batch or args.build_dynamic_shape) and not args.disable_cuda_graph: @@ -194,7 +198,7 @@ def parse_arguments(is_xl: bool, description: str): def repeat_prompt(args): if not isinstance(args.prompt, list): raise ValueError(f"`prompt` must be of type `str` or `str` list, but is {type(args.prompt)}") - prompt = args.prompt * args.repeat_prompt + prompt = args.prompt * args.batch_size if not isinstance(args.negative_prompt, list): raise ValueError( @@ -209,7 +213,9 @@ def repeat_prompt(args): return prompt, negative_prompt -def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_size, batch_size): +def init_pipeline( + pipeline_class, pipeline_info, engine_type, args, max_batch_size, opt_batch_size, opt_image_height, opt_image_width +): onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( work_dir=args.work_dir, pipeline_info=pipeline_info, engine_type=engine_type ) @@ -234,9 +240,6 @@ def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_si engine_dir=engine_dir, framework_model_dir=framework_model_dir, onnx_dir=onnx_dir, - opt_image_height=args.height, - opt_image_width=args.height, - opt_batch_size=batch_size, force_engine_rebuild=args.force_engine_build, device_id=torch.cuda.current_device(), ) @@ -247,14 +250,15 @@ def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_si framework_model_dir, onnx_dir, args.onnx_opset, - opt_image_height=args.height, - opt_image_width=args.height, - opt_batch_size=batch_size, + opt_image_height=opt_image_height, + opt_image_width=opt_image_width, + opt_batch_size=opt_batch_size, force_engine_rebuild=args.force_engine_build, static_batch=not args.build_dynamic_batch, static_image_shape=not args.build_dynamic_shape, max_workspace_size=0, device_id=torch.cuda.current_device(), + timing_cache=timing_cache, ) elif engine_type == EngineType.TRT: # Load TensorRT engines and pytorch modules @@ -263,9 +267,9 @@ def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_si framework_model_dir, onnx_dir, args.onnx_opset, - opt_batch_size=batch_size, - opt_image_height=args.height, - opt_image_width=args.height, + opt_batch_size=opt_batch_size, + opt_image_height=opt_image_height, + opt_image_width=opt_image_width, force_export=args.force_onnx_export, force_optimize=args.force_onnx_optimize, force_build=args.force_engine_build, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py index 8b7579653d1b5..514205d3b8945 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py @@ -82,12 +82,23 @@ def infer_shapes(self): class PipelineInfo: - def __init__(self, version: str, is_inpaint: bool = False, is_refiner: bool = False, use_vae=False): + def __init__( + self, + version: str, + is_inpaint: bool = False, + is_refiner: bool = False, + use_vae=False, + min_image_size=256, + max_image_size=1024, + use_fp16_vae=True, + ): self.version = version self._is_inpaint = is_inpaint self._is_refiner = is_refiner self._use_vae = use_vae - + self._min_image_size = min_image_size + self._max_image_size = max_image_size + self._use_fp16_vae = use_fp16_vae if is_refiner: assert self.is_xl() @@ -118,6 +129,13 @@ def stages(self) -> List[str]: def vae_scaling_factor(self) -> float: return 0.13025 if self.is_xl() else 0.18215 + def vae_torch_fallback(self) -> bool: + return self.is_xl() and not self._use_fp16_vae + + def custom_fp16_vae(self) -> Optional[str]: + # For SD XL, use a VAE that fine-tuned to run in fp16 precision without generating NaNs + return "madebyollin/sdxl-vae-fp16-fix" if self._use_fp16_vae and self.is_xl() else None + @staticmethod def supported_versions(is_xl: bool): return ["xl-1.0"] if is_xl else ["1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base"] @@ -187,6 +205,19 @@ def unet_embedding_dim(self): else: raise ValueError(f"Invalid version {self.version}") + def min_image_size(self): + return self._min_image_size + + def max_image_size(self): + return self._max_image_size + + def default_image_size(self): + if self.is_xl(): + return 1024 + if self.version in ("2.0", "2.1"): + return 768 + return 512 + class BaseModel: def __init__( @@ -209,8 +240,8 @@ def __init__( self.min_batch = 1 self.max_batch = max_batch_size - self.min_image_shape = 256 # min image resolution: 256x256 - self.max_image_shape = 1024 # max image resolution: 1024x1024 + self.min_image_shape = pipeline_info.min_image_size() + self.max_image_shape = pipeline_info.max_image_size() self.min_latent_shape = self.min_image_shape // 8 self.max_latent_shape = self.max_image_shape // 8 diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py index ec3041e134e75..26c8450c57de9 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py @@ -44,7 +44,6 @@ def __init__( alphas = 1.0 - betas self.alphas_cumprod = torch.cumprod(alphas, dim=0) - # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 @@ -71,7 +70,7 @@ def configure(self): self.variance = torch.from_numpy(variance).to(self.device) timesteps = self.timesteps.long().cpu() - self.alphas_cumprod = self.alphas_cumprod[timesteps].to(self.device) + self.filtered_alphas_cumprod = self.alphas_cumprod[timesteps].to(self.device) self.final_alpha_cumprod = self.final_alpha_cumprod.to(self.device) def scale_model_input(self, sample: torch.FloatTensor, idx, *args, **kwargs) -> torch.FloatTensor: @@ -124,9 +123,9 @@ def step( # - pred_prev_sample -> "x_t-1" prev_idx = idx + 1 - alpha_prod_t = self.alphas_cumprod[idx] + alpha_prod_t = self.filtered_alphas_cumprod[idx] alpha_prod_t_prev = ( - self.alphas_cumprod[prev_idx] if prev_idx < self.num_inference_steps else self.final_alpha_cumprod + self.filtered_alphas_cumprod[prev_idx] if prev_idx < self.num_inference_steps else self.final_alpha_cumprod ) beta_prod_t = 1 - alpha_prod_t @@ -179,15 +178,15 @@ def step( variance_noise = torch.randn( model_output.shape, generator=generator, device=device, dtype=model_output.dtype ) - variance = variance ** (0.5) * eta * variance_noise + variance = std_dev_t * variance_noise prev_sample = prev_sample + variance return prev_sample def add_noise(self, init_latents, noise, idx, latent_timestep): - sqrt_alpha_prod = self.alphas_cumprod[idx] ** 0.5 - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[idx]) ** 0.5 + sqrt_alpha_prod = self.filtered_alphas_cumprod[idx] ** 0.5 + sqrt_one_minus_alpha_prod = (1 - self.filtered_alphas_cumprod[idx]) ** 0.5 noisy_latents = sqrt_alpha_prod * init_latents + sqrt_one_minus_alpha_prod * noise return noisy_latents diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py index dfdfa007d74eb..ace75bfbae7cb 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -60,15 +60,8 @@ def __init__( self.torch_device = torch.device(device, torch.cuda.current_device()) self.stages = pipeline_info.stages() - # TODO: use custom fp16 for ORT_TRT, and no need to fallback to torch. - self.vae_torch_fallback = self.pipeline_info.is_xl() and engine_type != EngineType.ORT_CUDA - - # For SD XL, use an VAE that modified to run in fp16 precision without generating NaNs. - self.custom_fp16_vae = ( - "madebyollin/sdxl-vae-fp16-fix" - if self.pipeline_info.is_xl() and self.engine_type == EngineType.ORT_CUDA - else None - ) + self.vae_torch_fallback = self.pipeline_info.vae_torch_fallback() + self.custom_fp16_vae = self.pipeline_info.custom_fp16_vae() self.models = {} self.engines = {} diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py index 07c675b2ed990..a03ca7ce2912c 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py @@ -159,9 +159,6 @@ def build_engines( framework_model_dir: str, onnx_dir: str, onnx_opset_version: int = 17, - opt_image_height: int = 512, - opt_image_width: int = 512, - opt_batch_size: int = 1, force_engine_rebuild: bool = False, device_id: int = 0, save_fp32_intermediate_model=False, @@ -209,7 +206,8 @@ def build_engines( with torch.inference_mode(): # For CUDA EP, export FP32 onnx since some graph fusion only supports fp32 graph pattern. - inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) + # Export model with sample of batch size 1, image size 512 x 512 + inputs = model_obj.get_sample_input(1, 512, 512) torch.onnx.export( model, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py index 8a39dc2ed63fc..d966833aba394 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py @@ -13,6 +13,7 @@ from diffusion_models import PipelineInfo from engine_builder import EngineBuilder, EngineType from ort_utils import CudaSession +from packaging import version import onnxruntime as ort @@ -20,7 +21,17 @@ class OrtTensorrtEngine(CudaSession): - def __init__(self, engine_path, device_id, onnx_path, fp16, input_profile, workspace_size, enable_cuda_graph): + def __init__( + self, + engine_path, + device_id, + onnx_path, + fp16, + input_profile, + workspace_size, + enable_cuda_graph, + timing_cache_path=None, + ): self.engine_path = engine_path self.ort_trt_provider_options = self.get_tensorrt_provider_options( input_profile, @@ -28,6 +39,7 @@ def __init__(self, engine_path, device_id, onnx_path, fp16, input_profile, works fp16, device_id, enable_cuda_graph, + timing_cache_path=timing_cache_path, ) session_options = ort.SessionOptions() @@ -45,7 +57,9 @@ def __init__(self, engine_path, device_id, onnx_path, fp16, input_profile, works device = torch.device("cuda", device_id) super().__init__(ort_session, device, enable_cuda_graph) - def get_tensorrt_provider_options(self, input_profile, workspace_size, fp16, device_id, enable_cuda_graph): + def get_tensorrt_provider_options( + self, input_profile, workspace_size, fp16, device_id, enable_cuda_graph, timing_cache_path=None + ): trt_ep_options = { "device_id": device_id, "trt_fp16_enable": fp16, @@ -55,6 +69,9 @@ def get_tensorrt_provider_options(self, input_profile, workspace_size, fp16, dev "trt_engine_cache_path": self.engine_path, } + if version.parse(ort.__version__) > version.parse("1.16.2") and timing_cache_path is not None: + trt_ep_options["trt_timing_cache_path"] = timing_cache_path + if enable_cuda_graph: trt_ep_options["trt_cuda_graph_enable"] = True @@ -153,6 +170,7 @@ def build_engines( static_image_shape=True, max_workspace_size=0, device_id=0, + timing_cache=None, ): self.torch_device = torch.device("cuda", device_id) self.load_models(framework_model_dir) @@ -224,7 +242,6 @@ def build_engines( engine_path = self.get_engine_path(engine_dir, model_name, profile_id) onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) - if not self.has_engine_file(engine_path): logger.info( "Building TensorRT engine for %s from %s to %s. It can take a while to complete...", @@ -251,6 +268,7 @@ def build_engines( input_profile=input_profile, workspace_size=self.get_work_space_size(model_name, max_workspace_size), enable_cuda_graph=self.use_cuda_graph, + timing_cache_path=timing_cache, ) built_engines[model_name] = engine diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py index 4b48396b6c783..28e79abb9f018 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py @@ -91,7 +91,10 @@ def optimize( if keep_outputs: m.prune_graph(outputs=keep_outputs) - use_external_data_format = m.model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF + model_size = m.model.ByteSize() + + # model size might be negative (overflow?) in Windows. + use_external_data_format = model_size <= 0 or model_size >= onnx.checker.MAXIMUM_PROTOBUF # Note that ORT < 1.16 could not save model larger than 2GB. # This step is is optional since it has no impact on inference latency. diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py index e28db2b77105a..e675c9a7b3bf5 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py @@ -102,35 +102,19 @@ def __init__( self.verbose = verbose self.nvtx_profile = nvtx_profile - # Scheduler options - sched_opts = {"num_train_timesteps": 1000, "beta_start": 0.00085, "beta_end": 0.012} - if self.version in ("2.0", "2.1"): - sched_opts["prediction_type"] = "v_prediction" - else: - sched_opts["prediction_type"] = "epsilon" - - if scheduler == "DDIM": - self.scheduler = DDIMScheduler(device=self.device, **sched_opts) - elif scheduler == "EulerA": - self.scheduler = EulerAncestralDiscreteScheduler(device=self.device, **sched_opts) - elif scheduler == "UniPC": - self.scheduler = UniPCMultistepScheduler(device=self.device) - else: - raise ValueError("Scheduler should be either DDIM, EulerA or UniPC") - self.stages = pipeline_info.stages() - self.vae_torch_fallback = self.pipeline_info.is_xl() - self.use_cuda_graph = use_cuda_graph self.tokenizer = None self.tokenizer2 = None - self.generator = None - self.denoising_steps = None + self.generator = torch.Generator(device="cuda") self.actual_steps = None + self.current_scheduler = None + self.set_scheduler(scheduler) + # backend engine self.engine_type = engine_type if engine_type == EngineType.TRT: @@ -162,10 +146,31 @@ def __init__( def is_backend_tensorrt(self): return self.engine_type == EngineType.TRT + def set_scheduler(self, scheduler: str): + if scheduler == self.current_scheduler: + return + + # Scheduler options + sched_opts = {"num_train_timesteps": 1000, "beta_start": 0.00085, "beta_end": 0.012} + if self.version in ("2.0", "2.1"): + sched_opts["prediction_type"] = "v_prediction" + else: + sched_opts["prediction_type"] = "epsilon" + + if scheduler == "DDIM": + self.scheduler = DDIMScheduler(device=self.device, **sched_opts) + elif scheduler == "EulerA": + self.scheduler = EulerAncestralDiscreteScheduler(device=self.device, **sched_opts) + elif scheduler == "UniPC": + self.scheduler = UniPCMultistepScheduler(device=self.device) + else: + raise ValueError("Scheduler should be either DDIM, EulerA or UniPC") + + self.current_scheduler = scheduler + self.denoising_steps = None + def set_denoising_steps(self, denoising_steps: int): - if self.denoising_steps != denoising_steps: - assert self.denoising_steps is None # TODO(tianleiwu): support changing steps in different runs - # Pre-compute latent input scales and linear multistep coefficients + if not (self.denoising_steps == denoising_steps and isinstance(self.scheduler, DDIMScheduler)): self.scheduler.set_timesteps(denoising_steps) self.scheduler.configure() self.denoising_steps = denoising_steps @@ -176,8 +181,13 @@ def load_resources(self, image_height, image_width, batch_size): self.backend.load_resources(image_height, image_width, batch_size) def set_random_seed(self, seed): - # Initialize noise generator. Usually, it is done before a batch of inference. - self.generator = torch.Generator(device="cuda").manual_seed(seed) if isinstance(seed, int) else None + if isinstance(seed, int): + self.generator.manual_seed(seed) + else: + self.generator.seed() + + def get_current_seed(self): + return self.generator.initial_seed() def teardown(self): for e in self.events.values(): @@ -447,8 +457,18 @@ def save_images(self, images, pipeline, prompt): images = self.to_pil_image(images) random_session_id = str(random.randint(1000, 9999)) for i, image in enumerate(images): + seed = str(self.get_current_seed()) image_path = os.path.join( - self.output_dir, image_name_prefix + str(i + 1) + "-" + random_session_id + ".png" + self.output_dir, image_name_prefix + str(i + 1) + "-" + random_session_id + "-" + seed + ".png" ) print(f"Saving image {i+1} / {len(images)} to: {image_path}") - image.save(image_path) + + from PIL import PngImagePlugin + + metadata = PngImagePlugin.PngInfo() + metadata.add_text("prompt", prompt[i]) + metadata.add_text("batch_size", str(len(images))) + metadata.add_text("denoising_steps", str(self.denoising_steps)) + metadata.add_text("actual_steps", str(self.actual_steps)) + metadata.add_text("seed", seed) + image.save(image_path, "PNG", pnginfo=metadata) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt index 5f908c4f5ff39..447cb54f98ed2 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt @@ -1,7 +1,7 @@ -r requirements.txt # Official onnxruntime-gpu 1.16.1 is built with CUDA 11.8. -onnxruntime-gpu>=1.16.1 +onnxruntime-gpu>=1.16.2 py3nvml @@ -12,7 +12,8 @@ cuda-python==11.8.0 # For windows, cuda-python need the following pywin32; platform_system == "Windows" -nvtx +# For windows, run `conda install -c conda-forge nvtx` instead +nvtx; platform_system != "Windows" # Please install PyTorch 2.1 or above for CUDA 11.8 using one of the following commands: # pip3 install torch --index-url https://download.pytorch.org/whl/cu118 diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt index e4e765831c1b3..1ff0e3c1cf5af 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt @@ -1,18 +1,19 @@ -r requirements.txt # For CUDA 12.*, you will need build onnxruntime-gpu from source and install the wheel. See README.md for detail. -# onnxruntime-gpu>=1.16.1 +# onnxruntime-gpu>=1.16.2 py3nvml # The version of cuda-python shall be compatible with installed CUDA version. # For example, if your CUDA version is 12.1, you can install cuda-python 12.1. -cuda-python==12.1.0 +cuda-python>=12.1.0 # For windows, cuda-python need the following pywin32; platform_system == "Windows" -nvtx +# For windows, run `conda install -c conda-forge nvtx` instead +nvtx; platform_system != "Windows" # Please install PyTorch 2.1 or above for 12.1 using one of the following commands: # pip3 install torch --index-url https://download.pytorch.org/whl/cu121 diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt index 9386a941fb323..a00e25ddd983f 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt @@ -1,8 +1,8 @@ -diffusers>=0.19.3 -transformers>=4.31.0 +diffusers==0.19.3 +transformers==4.31.0 numpy>=1.24.1 accelerate -onnx>=1.13.0 +onnx==1.14.0 coloredlogs packaging # Use newer version of protobuf might cause crash @@ -10,6 +10,8 @@ protobuf==3.20.3 psutil sympy # The following are for SDXL -optimum>=1.11.1 +optimum==1.13.1 safetensors invisible_watermark +# newer version of opencv-python migth encounter module 'cv2.dnn' has no attribute 'DictValue' error +opencv-python==4.8.0.74