Commit

fix
IlyasMoutawwakil committed Oct 14, 2024
1 parent e956e81 commit 6810e4d
Showing 4 changed files with 114 additions and 164 deletions.
1 change: 0 additions & 1 deletion optimum/onnxruntime/modeling_decoder.py
@@ -161,7 +161,6 @@ def __init__(
setattr(self.generation_config, param_name, param_value)
setattr(self.config, param_name, None)

self.onnx_paths = [self.model_path]
self.use_merged = "use_cache_branch" in self.input_names
self.model_type = self.config.model_type

204 changes: 85 additions & 119 deletions optimum/onnxruntime/modeling_diffusion.py
@@ -113,25 +113,24 @@ def __init__(
# would end up removing the directory containing the underlying ONNX model.
self._model_save_dir_tempdirectory_instance = model_save_dir
self.model_save_dir = Path(model_save_dir.name)
elif isinstance(model_save_dir, TemporaryDirectory):
self.model_save_dir = Path(model_save_dir.name)
elif isinstance(model_save_dir, (str, Path)):
self.model_save_dir = Path(model_save_dir)
else:
self.model_save_dir = Path(unet_session._model_path).parent

# because OptimizedModel requires it
self.preprocessors = kwargs.pop("preprocessors", [])

# TODO: Maybe move this to from_pretrained so that the pipeline class can be instantiated with ORTModel instances
self.unet = ORTModelUnet(unet_session, self) if unet_session is not None else None
self.vae_decoder = ORTModelVaeDecoder(vae_decoder_session, self) if vae_decoder_session is not None else None
self.vae_encoder = ORTModelVaeEncoder(vae_encoder_session, self) if vae_encoder_session is not None else None
self.unet = ORTModelUnet(unet_session, use_io_binding) if unet_session is not None else None
self.vae_decoder = (
ORTModelVaeDecoder(vae_decoder_session, use_io_binding) if vae_decoder_session is not None else None
)
self.vae_encoder = (
ORTModelVaeEncoder(vae_encoder_session, use_io_binding) if vae_encoder_session is not None else None
)
self.text_encoder = (
ORTModelTextEncoder(text_encoder_session, self) if text_encoder_session is not None else None
ORTModelTextEncoder(text_encoder_session, use_io_binding) if text_encoder_session is not None else None
)
self.text_encoder_2 = (
ORTModelTextEncoder(text_encoder_2_session, self) if text_encoder_2_session is not None else None
ORTModelTextEncoder(text_encoder_2_session, use_io_binding) if text_encoder_2_session is not None else None
)
# We wrap the VAE Decoder & Encoder in a single object to simulate diffusers API
self.vae = ORTWrapperVae(self.vae_encoder, self.vae_decoder)
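
ORTWrapperVae itself is not shown in this diff; a minimal sketch of what such a wrapper could look like, assuming it only needs to forward encode/decode calls and a config to the two parts (all names in the sketch are illustrative, not the actual implementation):

```python
# Illustrative sketch only: the actual ORTWrapperVae is defined elsewhere in
# this file. The wrapper delegates to the two ONNX parts so that pipeline code
# written against diffusers' AutoencoderKL keeps working.
class WrapperVaeSketch:
    def __init__(self, encoder, decoder):
        self.encoder = encoder  # e.g. an ORTModelVaeEncoder, may be None
        self.decoder = decoder  # e.g. an ORTModelVaeDecoder

    @property
    def config(self):
        # diffusers reads vae.config (e.g. scaling_factor), so surface the
        # decoder's config as the wrapper's config
        return self.decoder.config

    def encode(self, sample):
        return self.encoder(sample=sample)

    def decode(self, latent_sample):
        return self.decoder(latent_sample=latent_sample)
```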
@@ -167,13 +166,8 @@ def __init__(
diffusers_pipeline_args[key] = all_pipeline_init_args[key]
self.auto_model_class.__init__(self, **diffusers_pipeline_args)

if use_io_binding is None:
if self.provider == "CUDAExecutionProvider":
self.use_io_binding = True
else:
self.use_io_binding = False
else:
self.use_io_binding = use_io_binding
# Forced on every class inheriting from OptimizedModel
self.preprocessors = kwargs.pop("preprocessors", [])

@property
def components(self) -> Dict[str, Any]:
@@ -185,49 +179,53 @@ def components(self) -> Dict[str, Any]:
"safety_checker": self.safety_checker,
"image_encoder": self.image_encoder,
}
components = dict(filter(lambda x: x[1] is not None, components.items()))
components = {k: v for k, v in components.items() if v is not None}
return components

def _validate_same_attribute_value_across_components(self, attribute: str):
# The idea is that these attributes make sense for the pipeline as a whole only when they are the same across
# all components, so we do support these attributes but also allow the user to experiment with undefined behavior
# like having heterogeneous devices or io bindings across components.
attribute_values = {
name: getattr(component, attribute)
for name, component in self.components.items()
if hasattr(component, attribute)
}

if len(set(attribute_values.values())) > 1:
raise ValueError(f"Attribute {attribute} is not the same across components: {attribute_values}.")

return next(iter(attribute_values.values()))
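
To make the contract concrete, a small self-contained sketch of the same check, using stand-in component objects rather than the actual ORT parts:

```python
# Self-contained illustration of the check above: a pipeline-level attribute
# is only well-defined when every component reports the same value for it.
class FakePart:
    def __init__(self, device):
        self.device = device

components = {"unet": FakePart("cuda"), "vae_decoder": FakePart("cuda")}
values = {name: part.device for name, part in components.items() if hasattr(part, "device")}
assert len(set(values.values())) == 1  # homogeneous: the pipeline device is "cuda"

components["vae_decoder"] = FakePart("cpu")  # heterogeneous devices across parts
values = {name: part.device for name, part in components.items()}
if len(set(values.values())) > 1:
    # mirrors the ValueError raised by the method above
    raise ValueError(f"Attribute device is not the same across components: {values}.")
```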

@property
def device(self) -> torch.device:
return self.unet.device
return self._validate_same_attribute_value_across_components("device")

@property
def providers(self):
return self.unet.providers
def dtype(self) -> torch.dtype:
return self._validate_same_attribute_value_across_components("dtype")

@property
def providers_options(self):
return self.unet.providers_options
def provider(self) -> str:
return self._validate_same_attribute_value_across_components("provider")

@property
def provider(self):
return self.unet.provider
def provider_options(self) -> Dict[str, Any]:
return self._validate_same_attribute_value_across_components("provider_options")

@property
def provider_options(self):
return self.unet.provider_options

def to(self, *args, device: Optional[Union[torch.device, int, str]] = None, dtype: Optional[torch.dtype] = None):
for arg in args:
if isinstance(arg, torch.device):
device = arg
elif isinstance(arg, (int, str)):
device = torch.device(arg)
elif isinstance(arg, torch.dtype):
dtype = arg

if dtype is not None and dtype != self.dtype:
raise NotImplementedError(
f"Cannot change the dtype of the pipeline from {self.dtype} to {dtype}. "
f"Please export the pipeline with the desired dtype."
)
def use_io_binding(self) -> bool:
return self._validate_same_attribute_value_across_components("use_io_binding")

if device is not None and device != self.device:
for component in self.components.values():
if component is not None:
component.to(device=device, dtype=dtype)
@use_io_binding.setter
def use_io_binding(self, value):
for component in self.components.values():
if hasattr(component, "use_io_binding"):
component.use_io_binding = value

def to(self, *args, **kwargs):
for component in self.components.values():
component.to(*args, **kwargs)
return self
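
With the rewritten to() and the use_io_binding setter, device and IO-binding settings fan out to every component. Assuming a CUDA-enabled onnxruntime build and an exported pipeline directory (path and prompt are illustrative), usage could look like:

```python
# Hypothetical usage; assumes a CUDA build of onnxruntime and an ONNX export
# of a diffusion model at the illustrative path "sd_onnx/".
from optimum.onnxruntime import ORTDiffusionPipeline

pipe = ORTDiffusionPipeline.from_pretrained("sd_onnx/")
pipe.to("cuda")             # forwarded to unet, vae, text encoder(s), ...
pipe.use_io_binding = True  # the setter propagates to every component

print(pipe.device)          # raises ValueError if components ever diverge
image = pipe("a photo of an astronaut riding a horse").images[0]
```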

def __call__(self, *args, **kwargs):
@@ -454,36 +452,32 @@ def _save_config(self, save_directory):
class ORTPipelinePart(ConfigMixin):
config_name: str = CONFIG_NAME

def __init__(self, session: ort.InferenceSession, parent_pipeline: ORTDiffusionPipeline):
# config should be in the same directory as the onnx model
def __init__(self, session: ort.InferenceSession, use_io_binding: Optional[bool]):
# config is mandatory for the model part to be used for inference
config_file_path = Path(session._model_path).parent / self.config_name
if not config_file_path.is_file():
# config is mandatory for the model part to be used for inference
raise ValueError(f"Configuration file for {self.__class__.__name__} not found at {config_file_path}")
config_dict = self._dict_from_json_file(config_file_path)
self.register_to_config(**config_dict)
else:
self.register_to_config(**self._dict_from_json_file(config_file_path))

self.session = session
self.parent_pipeline = parent_pipeline
self.use_io_binding = use_io_binding or session.get_providers()[0] in ["CUDAExecutionProvider"]

self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())}
self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())}

self.input_shapes = {input_key.name: input_key.shape for input_key in self.session.get_inputs()}
self.input_dtypes = {input_key.name: input_key.type for input_key in self.session.get_inputs()}
self.output_dtypes = {output_key.name: output_key.type for output_key in self.session.get_outputs()}

self.input_shapes = {input_key.name: input_key.shape for input_key in self.session.get_inputs()}
self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())}
self.output_shapes = {output_key.name: output_key.shape for output_key in self.session.get_outputs()}
self.output_dtypes = {output_key.name: output_key.type for output_key in self.session.get_outputs()}

self._known_symbols = {name: value for name, value in self.config.items() if isinstance(value, int)}
self._compiled_input_shapes = self._compile_shapes(self.input_shapes)
self._known_symbols = {k: v for k, v in self.config.items() if isinstance(v, int)}
self._compiled_output_shapes = self._compile_shapes(self.output_shapes)
self._compiled_input_shapes = self._compile_shapes(self.input_shapes)

self._providers = self.session.get_providers()
self._providers_options = self.session.get_provider_options()
self._device = get_device_for_provider(
provider=self._providers[0], provider_options=next(iter(self._providers_options.values()))
)
self._device = get_device_for_provider(provider=self.provider, provider_options=self.provider_options)

def _compile_shapes(self, shapes: Dict[str, Tuple[Union[int, str]]]) -> Dict[str, Tuple[sp.Basic]]:
compiled_shapes = {}
@@ -497,28 +491,16 @@ def _compile_shapes(self, shapes: Dict[str, Tuple[Union[int, str]]]) -> Dict[str, Tuple[sp.Basic]]:
return compiled_shapes
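
The body of _compile_shapes is collapsed in this diff; a plausible reconstruction, assuming each dimension (an int for static axes, a string such as "batch_size" for dynamic ones) is turned into a sympy expression with known config symbols substituted in:

```python
# Plausible reconstruction of the collapsed body (an assumption, not a copy):
# static ONNX axes are ints, dynamic axes are strings such as "batch_size";
# sympify turns both into sympy expressions, and symbols that appear in the
# model config (e.g. "sample_size") are substituted with their int values.
import sympy as sp

def compile_shapes_sketch(shapes, known_symbols):
    compiled_shapes = {}
    for name, shape in shapes.items():
        compiled_shape = []
        for dim in shape:
            expr = sp.sympify(dim)           # int -> sp.Integer, str -> symbolic expression
            expr = expr.subs(known_symbols)  # e.g. {"sample_size": 64}
            compiled_shape.append(expr)
        compiled_shapes[name] = tuple(compiled_shape)
    return compiled_shapes
```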

@property
def providers(self):
# all providers
return self._providers
def device(self):
return self._device

@property
def provider(self):
# main provider
return self._providers[0]

@property
def providers_options(self):
# all provider options
return self._providers_options

@property
def provider_options(self):
# main provider options
return self.providers_options[self.provider]

@property
def device(self):
return self._device
return self._providers_options[self._providers[0]]

@property
def dtype(self):
@@ -534,11 +516,9 @@ def dtype(self):

return None
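
The dtype body is likewise collapsed; a plausible reading, assuming it maps ONNX element-type strings such as "tensor(float16)" to the corresponding torch dtype and returns None when no floating-point tensor is found:

```python
# Plausible sketch of the collapsed logic (an assumption): onnxruntime reports
# element types as strings like "tensor(float)"; the first floating-point
# tensor found fixes the part's dtype, otherwise None is returned.
import torch

ORT_TO_TORCH_DTYPE = {
    "tensor(float)": torch.float32,
    "tensor(float16)": torch.float16,
    "tensor(bfloat16)": torch.bfloat16,
}

def dtype_sketch(input_dtypes):
    for ort_type in input_dtypes.values():
        torch_dtype = ORT_TO_TORCH_DTYPE.get(ort_type)
        if torch_dtype is not None:
            return torch_dtype
    return None
```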

def to(self, *args, device: Optional[Union[torch.device, str, int]] = None, dtype: Optional[torch.dtype] = None):
def to(self, *args, device: Optional[Union[int, str, torch.device]] = None, dtype: Optional[torch.dtype] = None):
for arg in args:
if isinstance(arg, torch.device):
device = arg
elif isinstance(arg, (int, str)):
if isinstance(arg, (int, str, torch.device)):
device = torch.device(arg)
elif isinstance(arg, torch.dtype):
dtype = arg
@@ -556,19 +536,6 @@ def to(self, *args, device: Optional[Union[torch.device, str, int]] = None, dtype: Optional[torch.dtype] = None):
provider = get_provider_for_device(device)
validate_provider_availability(provider)

if self.use_io_binding is False and provider == "CUDAExecutionProvider":
self.use_io_binding = True
logger.info(
"use_io_binding was set to False with a CUDAExecutionProvider, setting it to True as it can speed up inference. "
"It is possible to disable this feature manually by setting the use_io_binding attribute back to False."
)
elif self.use_io_binding is True and provider == "ROCMExecutionProvider":
self.use_io_binding = False
logger.warning(
"use_io_binding was set to True with a ROCMExecutionProvider, setting it to False as it is not supported. "
"It is possible to enable this feature manually by setting the use_io_binding attribute back to True."
)

self.session.set_providers([provider], provider_options=[provider_options])

self._providers = self.session.get_providers()
@@ -577,8 +544,8 @@ def to(self, *args, device: Optional[Union[torch.device, str, int]] = None, dtype: Optional[torch.dtype] = None):
if self.provider != provider or self.provider_options != provider_options:
raise ValueError(
f"Failed to set the device to {device}. "
f"Requested provider: {provider}, Requested provider options: {provider_options}. "
f"Session provider: {self.provider}, Session provider options: {self.provider_options}"
f"Requested provider {provider} with options: {provider_options}, "
f"but got provider {self.provider} with options: {self.provider_options}."
)

self._device = device
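
Taken together, the part-level to() swaps execution providers on the live session rather than rebuilding it. Assuming onnxruntime-gpu is installed and pipe is the pipeline from the earlier example, the flow could be exercised as:

```python
# Hypothetical usage of the part-level to(): the provider is switched on the
# live InferenceSession (assumes onnxruntime-gpu; pipe is from the example above).
import torch

part = pipe.unet
part.to(torch.device("cuda:0"))  # device is mapped to CUDAExecutionProvider
print(part.provider)             # "CUDAExecutionProvider"
print(part.provider_options)     # e.g. {"device_id": "0", ...}
part.to("cpu")                   # and back to CPUExecutionProvider
```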
@@ -642,7 +609,7 @@ def _get_io_binding_outputs(self, io_binding: ort.IOBinding) -> Dict[str, torch.Tensor]:

return model_outputs

def prepare_io_binding(self, model_inputs: torch.Tensor) -> Tuple[ort.IOBinding, Dict[str, torch.Tensor]]:
def _prepare_io_binding(self, model_inputs: torch.Tensor) -> Tuple[ort.IOBinding, Dict[str, torch.Tensor]]:
io_binding = self.session.io_binding()

for input_name in self.input_names.keys():
@@ -683,7 +650,6 @@ def prepare_io_binding(self, model_inputs: torch.Tensor) -> Tuple[ort.IOBinding, Dict[str, torch.Tensor]]:
buffer_ptr=output_tensor.data_ptr(),
shape=tuple(output_tensor.size()),
)

except Exception as e:
logger.error(
f"Failed to prepare IO binding for {self.__class__.__name__}. "
@@ -700,7 +666,7 @@ def prepare_io_binding(self, model_inputs: torch.Tensor) -> Tuple[ort.IOBinding, Dict[str, torch.Tensor]]:

return io_binding, model_outputs

def prepare_onnx_inputs(self, **inputs: Union[torch.Tensor, np.ndarray]) -> Dict[str, np.ndarray]:
def _prepare_onnx_inputs(self, **inputs: Union[torch.Tensor, np.ndarray]) -> Dict[str, np.ndarray]:
onnx_inputs = {}

for input_name in self.input_names.keys():
@@ -716,7 +682,7 @@ def prepare_onnx_inputs(self, **inputs: Union[torch.Tensor, np.ndarray]) -> Dict[str, np.ndarray]:

return onnx_inputs
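
The collapsed middle of prepare_onnx_inputs presumably converts torch tensors to numpy arrays with the element type the session declares; a hedged sketch of that conversion (the dtype mapping shown is illustrative, not the actual helper):

```python
# Hedged sketch of the collapsed conversion; the dtype mapping is illustrative.
import numpy as np
import torch

# assumed subset of an ONNX-to-numpy element-type mapping
ORT_TO_NP_DTYPE = {"tensor(float)": np.float32, "tensor(int64)": np.int64}

def prepare_onnx_inputs_sketch(input_names, input_dtypes, **inputs):
    onnx_inputs = {}
    for input_name in input_names.keys():
        value = inputs.pop(input_name)
        if isinstance(value, torch.Tensor):
            value = value.cpu().numpy()  # session.run consumes numpy arrays
        expected = ORT_TO_NP_DTYPE.get(input_dtypes[input_name], value.dtype)
        onnx_inputs[input_name] = value.astype(expected)
    return onnx_inputs
```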

def prepare_onnx_outputs(self, *onnx_outputs: np.ndarray) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
def _prepare_onnx_outputs(self, *onnx_outputs: np.ndarray) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
model_outputs = {}

for output_name, idx in self.output_names.items():
@@ -773,15 +739,15 @@ def forward(
**(added_cond_kwargs or {}),
}

if self.parent_pipeline.use_io_binding:
io_binding, model_outputs = self.prepare_io_binding(model_inputs)
if self.use_io_binding:
io_binding, model_outputs = self._prepare_io_binding(model_inputs)
self.session.run_with_iobinding(io_binding)
if model_outputs is None:
model_outputs = self.get_io_binding_outputs(io_binding, model_outputs)
model_outputs = self._get_io_binding_outputs(io_binding, model_outputs)
else:
onnx_inputs = self.prepare_onnx_inputs(**model_inputs)
onnx_inputs = self._prepare_onnx_inputs(**model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self.prepare_onnx_outputs(*onnx_outputs)
model_outputs = self._prepare_onnx_outputs(*onnx_outputs)

if return_dict:
return model_outputs
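
The same dual execution path (IO binding when use_io_binding is set, plain session.run otherwise) recurs in each forward below. A hypothetical direct call into the UNet part, with illustrative tensor shapes for an SD 1.5-style export:

```python
# Hypothetical direct call into the UNet part; shapes are illustrative and
# depend on the exported model (here an SD 1.5-style export, 64x64 latents).
import torch

sample = torch.randn(2, 4, 64, 64)               # noisy latents
timestep = torch.tensor([981])                   # current diffusion step
encoder_hidden_states = torch.randn(2, 77, 768)  # text-encoder embeddings

out = pipe.unet(
    sample=sample,
    timestep=timestep,
    encoder_hidden_states=encoder_hidden_states,
)
# the output key depends on the export (e.g. "out_sample")
noise_pred = next(iter(out.values()))
```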
@@ -799,15 +765,15 @@ def forward(
):
model_inputs = {"input_ids": input_ids}

if self.parent_pipeline.use_io_binding:
io_binding, model_outputs = self.prepare_io_binding(model_inputs)
if self.use_io_binding:
io_binding, model_outputs = self._prepare_io_binding(model_inputs)
self.session.run_with_iobinding(io_binding)
if model_outputs is None:
model_outputs = self.get_io_binding_outputs(io_binding, model_outputs)
model_outputs = self._get_io_binding_outputs(io_binding, model_outputs)
else:
onnx_inputs = self.prepare_onnx_inputs(**model_inputs)
onnx_inputs = self._prepare_onnx_inputs(**model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self.prepare_onnx_outputs(*onnx_outputs)
model_outputs = self._prepare_onnx_outputs(*onnx_outputs)

if output_hidden_states:
model_outputs["hidden_states"] = []
@@ -846,15 +812,15 @@ def forward(
):
model_inputs = {"sample": sample}

if self.parent_pipeline.use_io_binding:
io_binding, model_outputs = self.prepare_io_binding(model_inputs)
if self.use_io_binding:
io_binding, model_outputs = self._prepare_io_binding(model_inputs)
self.session.run_with_iobinding(io_binding)
if model_outputs is None:
model_outputs = self.get_io_binding_outputs(io_binding, model_outputs)
model_outputs = self._get_io_binding_outputs(io_binding, model_outputs)
else:
onnx_inputs = self.prepare_onnx_inputs(**model_inputs)
onnx_inputs = self._prepare_onnx_inputs(**model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self.prepare_onnx_outputs(*onnx_outputs)
model_outputs = self._prepare_onnx_outputs(*onnx_outputs)

if "latent_sample" in model_outputs:
model_outputs["latents"] = model_outputs.pop("latent_sample")
@@ -892,15 +858,15 @@ def forward(
):
model_inputs = {"latent_sample": latent_sample}

if self.parent_pipeline.use_io_binding:
io_binding, model_outputs = self.prepare_io_binding(model_inputs)
if self.use_io_binding:
io_binding, model_outputs = self._prepare_io_binding(model_inputs)
self.session.run_with_iobinding(io_binding)
if model_outputs is None:
model_outputs = self.get_io_binding_outputs(io_binding, model_outputs)
model_outputs = self._get_io_binding_outputs(io_binding, model_outputs)
else:
onnx_inputs = self.prepare_onnx_inputs(**model_inputs)
onnx_inputs = self._prepare_onnx_inputs(**model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self.prepare_onnx_outputs(*onnx_outputs)
model_outputs = self._prepare_onnx_outputs(*onnx_outputs)

if "latent_sample" in model_outputs:
model_outputs["latents"] = model_outputs.pop("latent_sample")
