Refactoring of Stable Diffusion scripts #17138
Merged 1 commit on Aug 15, 2023
@@ -0,0 +1,368 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
#
# Copyright 2023 The HuggingFace Inc. team.
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Models used in Stable diffusion.
"""
import logging

import onnx
import onnx_graphsurgeon as gs
import torch
from onnx import shape_inference
from ort_optimizer import OrtStableDiffusionOptimizer
from polygraphy.backend.onnx.loader import fold_constants

logger = logging.getLogger(__name__)


class TrtOptimizer:
    def __init__(self, onnx_graph):
        self.graph = gs.import_onnx(onnx_graph)

    def cleanup(self):
        self.graph.cleanup().toposort()

    def get_optimized_onnx_graph(self):
        return gs.export_onnx(self.graph)

    def select_outputs(self, keep, names=None):
        self.graph.outputs = [self.graph.outputs[o] for o in keep]
        if names:
            for i, name in enumerate(names):
                self.graph.outputs[i].name = name

    def fold_constants(self):
        onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)
        self.graph = gs.import_onnx(onnx_graph)

    def infer_shapes(self):
        onnx_graph = gs.export_onnx(self.graph)
        if onnx_graph.ByteSize() > 2147483648:  # 2 GB protobuf serialization limit
            raise TypeError("ERROR: model size exceeds supported 2GB limit")
        onnx_graph = shape_inference.infer_shapes(onnx_graph)
        self.graph = gs.import_onnx(onnx_graph)


class BaseModel:
    def __init__(self, model, name, device="cuda", fp16=False, max_batch_size=16, embedding_dim=768, text_maxlen=77):
        self.model = model
        self.name = name
        self.fp16 = fp16
        self.device = device

        self.min_batch = 1
        self.max_batch = max_batch_size
        self.min_image_shape = 256  # min image resolution: 256x256
        self.max_image_shape = 1024  # max image resolution: 1024x1024
        self.min_latent_shape = self.min_image_shape // 8
        self.max_latent_shape = self.max_image_shape // 8

        self.embedding_dim = embedding_dim
        self.text_maxlen = text_maxlen

        self.model_type = name.lower() if name in ["CLIP", "UNet"] else "vae"
        self.ort_optimizer = OrtStableDiffusionOptimizer(self.model_type)

    def get_model(self):
        return self.model

    def get_input_names(self):
        pass

    def get_output_names(self):
        pass

    def get_dynamic_axes(self):
        return None

    def get_sample_input(self, batch_size, image_height, image_width):
        pass

    def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """For TensorRT EP"""
        (
            min_batch,
            max_batch,
            min_image_height,
            max_image_height,
            min_image_width,
            max_image_width,
            _,
            _,
            _,
            _,
        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)

        profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}"

        if self.name != "CLIP":
            if static_image_shape:
                profile_id += f"_h_{image_height}_w_{image_width}"
            else:
                profile_id += f"_h_{min_image_height}_{max_image_height}_w_{min_image_width}_{max_image_width}"

        return profile_id
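
    # Example (illustrative): with the default limits above, the UNet profile id
    # for batch_size=1 at 512x512 is "_b_1_h_512_w_512" when both batch and image
    # shape are static, and "_b_1_16_h_256_1024_w_256_1024" when both are dynamic.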

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """For TensorRT"""
        return None

    def get_shape_dict(self, batch_size, image_height, image_width):
        return None

    def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True):
        self.ort_optimizer.optimize(input_onnx_path, optimized_onnx_path, to_fp16)

    def optimize_trt(self, input_onnx_path, optimized_onnx_path):
        onnx_graph = onnx.load(input_onnx_path)
        opt = TrtOptimizer(onnx_graph)
        opt.cleanup()
        opt.fold_constants()
        opt.infer_shapes()
        opt.cleanup()
        onnx_opt_graph = opt.get_optimized_onnx_graph()
        onnx.save(onnx_opt_graph, optimized_onnx_path)

    def check_dims(self, batch_size, image_height, image_width):
        assert self.min_batch <= batch_size <= self.max_batch
        # Both spatial dimensions must be divisible by the 8x VAE scale factor.
        assert image_height % 8 == 0 and image_width % 8 == 0
        latent_height = image_height // 8
        latent_width = image_width // 8
        assert self.min_latent_shape <= latent_height <= self.max_latent_shape
        assert self.min_latent_shape <= latent_width <= self.max_latent_shape
        return (latent_height, latent_width)

    def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        min_batch = batch_size if static_batch else self.min_batch
        max_batch = batch_size if static_batch else self.max_batch
        latent_height = image_height // 8
        latent_width = image_width // 8
        min_image_height = image_height if static_image_shape else self.min_image_shape
        max_image_height = image_height if static_image_shape else self.max_image_shape
        min_image_width = image_width if static_image_shape else self.min_image_shape
        max_image_width = image_width if static_image_shape else self.max_image_shape
        min_latent_height = latent_height if static_image_shape else self.min_latent_shape
        max_latent_height = latent_height if static_image_shape else self.max_latent_shape
        min_latent_width = latent_width if static_image_shape else self.min_latent_shape
        max_latent_width = latent_width if static_image_shape else self.max_latent_shape
        return (
            min_batch,
            max_batch,
            min_image_height,
            max_image_height,
            min_image_width,
            max_image_width,
            min_latent_height,
            max_latent_height,
            min_latent_width,
            max_latent_width,
        )
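
    # Example (illustrative): get_minmax_dims(1, 512, 512, static_batch=False,
    # static_image_shape=False) returns
    # (1, 16, 256, 1024, 256, 1024, 32, 128, 32, 128) with the default limits,
    # since the latent bounds are the image bounds divided by 8.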


class CLIP(BaseModel):
    def __init__(self, model, device, max_batch_size, embedding_dim):
        super().__init__(
            model=model,
            name="CLIP",
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
        )

    def get_input_names(self):
        return ["input_ids"]

    def get_output_names(self):
        return ["text_embeddings"]

    def get_dynamic_axes(self):
        return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}}

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        self.check_dims(batch_size, image_height, image_width)
        min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(
            batch_size, image_height, image_width, static_batch, static_image_shape
        )
        return {
            "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        self.check_dims(batch_size, image_height, image_width)
        return {
            "input_ids": (batch_size, self.text_maxlen),
            "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        self.check_dims(batch_size, image_height, image_width)
        return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)

    def optimize_trt(self, input_onnx_path, optimized_onnx_path):
        onnx_graph = onnx.load(input_onnx_path)
        opt = TrtOptimizer(onnx_graph)
        opt.select_outputs([0])  # delete graph output#1
        opt.cleanup()
        opt.fold_constants()
        opt.infer_shapes()
        opt.select_outputs([0], names=["text_embeddings"])  # rename network output
        opt.cleanup()
        onnx_opt_graph = opt.get_optimized_onnx_graph()
        onnx.save(onnx_opt_graph, optimized_onnx_path)
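
    # Example (illustrative): for batch_size=1 at 512x512 with a dynamic batch,
    # get_input_profile yields {"input_ids": [(1, 77), (1, 77), (16, 77)]},
    # i.e. the (min, opt, max) shapes of a TensorRT optimization profile.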


class UNet(BaseModel):
    def __init__(
        self,
        model,
        device="cuda",
        fp16=False,  # used by TRT
        max_batch_size=16,
        embedding_dim=768,
        text_maxlen=77,
        unet_dim=4,
    ):
        super().__init__(
            model=model,
            name="UNet",
            device=device,
            fp16=fp16,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
            text_maxlen=text_maxlen,
        )
        self.unet_dim = unet_dim

    def get_input_names(self):
        return ["sample", "timestep", "encoder_hidden_states"]

    def get_output_names(self):
        return ["latent"]

    def get_dynamic_axes(self):
        return {
            "sample": {0: "2B", 2: "H", 3: "W"},
            "encoder_hidden_states": {0: "2B"},
            "latent": {0: "2B", 2: "H", 3: "W"},
        }

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        (
            min_batch,
            max_batch,
            _,
            _,
            _,
            _,
            min_latent_height,
            max_latent_height,
            min_latent_width,
            max_latent_width,
        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
        # Batch is doubled ("2B") because classifier-free guidance runs the
        # unconditional and conditional inputs concatenated in one pass.
        return {
            "sample": [
                (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width),
                (2 * batch_size, self.unet_dim, latent_height, latent_width),
                (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width),
            ],
            "encoder_hidden_states": [
                (2 * min_batch, self.text_maxlen, self.embedding_dim),
                (2 * batch_size, self.text_maxlen, self.embedding_dim),
                (2 * max_batch, self.text_maxlen, self.embedding_dim),
            ],
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        return {
            "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width),
            "timestep": (1,),
            "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim),
            "latent": (2 * batch_size, 4, latent_height, latent_width),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        dtype = torch.float16 if self.fp16 else torch.float32
        return (
            torch.randn(
                2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device
            ),
            torch.tensor([1.0], dtype=torch.float32, device=self.device),
            torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
        )
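

# A minimal export sketch (illustrative only; the real export code lives in the
# pipeline scripts, and `_export_onnx_sketch` plus the opset choice are
# assumptions, not part of this module): it shows how the wrapper hooks above
# are typically consumed by torch.onnx.export.
def _export_onnx_sketch(wrapped, onnx_path, opset_version=17):
    # Sample inputs, tensor names, and dynamic axes all come from the wrapper.
    torch.onnx.export(
        wrapped.get_model(),
        wrapped.get_sample_input(1, 512, 512),
        onnx_path,
        input_names=wrapped.get_input_names(),
        output_names=wrapped.get_output_names(),
        dynamic_axes=wrapped.get_dynamic_axes(),
        opset_version=opset_version,
    )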


class VAE(BaseModel):
    def __init__(self, model, device, max_batch_size, embedding_dim):
        super().__init__(
            model=model,
            name="VAE Decoder",
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
        )

    def get_input_names(self):
        return ["latent"]

    def get_output_names(self):
        return ["images"]

    def get_dynamic_axes(self):
        return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}}

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        (
            min_batch,
            max_batch,
            _,
            _,
            _,
            _,
            min_latent_height,
            max_latent_height,
            min_latent_width,
            max_latent_width,
        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
        return {
            "latent": [
                (min_batch, 4, min_latent_height, min_latent_width),
                (batch_size, 4, latent_height, latent_width),
                (max_batch, 4, max_latent_height, max_latent_width),
            ]
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        return {
            "latent": (batch_size, 4, latent_height, latent_width),
            "images": (batch_size, 3, image_height, image_width),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device)
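

# A minimal allocation sketch (illustrative only; `_allocate_buffers_sketch` and
# its uniform dtype policy are assumptions, not part of this module): shape
# dicts like the ones above can drive output-buffer allocation for I/O binding.
def _allocate_buffers_sketch(wrapped, batch_size=1, image_height=512, image_width=512):
    dtype = torch.float16 if wrapped.fp16 else torch.float32
    return {
        name: torch.empty(*shape, dtype=dtype, device=wrapped.device)
        for name, shape in wrapped.get_shape_dict(batch_size, image_height, image_width).items()
    }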