Skip to content

Commit

Permalink
Feature/sg 000 fix import of onnx graphsurgeon (#1359)
Browse files Browse the repository at this point in the history
* Update readme

* Fix small bug in __repr__ implementation of KeypointsImageToTensor

* Test

* Test

* Test

* Test

* Test

* Test

* Make graphsurgeon an optional

* Make graphsurgeon an optional

* Properly handle imports of optional packages

* Added empty __init__.py files

* Do imports of gs inside the export call

* Do imports of gs inside the export call

* Fix DEKR's missing HasPredict interface

* Update notebook & example doc to reflect changes in imports & function names

* Update readme

* Put back images

* Install onnx_graphsurgeon in CI

* Install onnx_graphsurgeon in CI

* Fix version of ONNX-GS installed in CI and installed on-demand

* Fix arange_cpu not implemented for Half

* Fix arange_cpu not implemented for Half

* Fix graph merging for old pytorch (1.12) that crashed because of nodes with duplicate names
  • Loading branch information
BloodAxe authored and Louis-Dupont committed Aug 15, 2023
1 parent 85b2335 commit 72bd969
Show file tree
Hide file tree
Showing 19 changed files with 111 additions and 49 deletions.
1 change: 1 addition & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ jobs:
command: |
. venv/bin/activate
python3 -m pip install pytorch-quantization==2.1.2 --extra-index-url https://pypi.ngc.nvidia.com
python3 -m pip install onnx_graphsurgeon==0.3.27 --extra-index-url https://pypi.ngc.nvidia.com
- run:
name: run tests with coverage
Expand Down
39 changes: 21 additions & 18 deletions documentation/source/models_export.md
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ pred_classes, pred_classes.shape
For the sake of this tutorial, we will use a simple visualization function that is tailored for batch_size=1 only.
You can use it as a starting point for your own visualization code.


```python
from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
from super_gradients.training.utils.detection_utils import DetectionVisualization
Expand All @@ -266,8 +267,7 @@ def show_predictions_from_batch_format(image, predictions):
class_names = COCO_DETECTION_CLASSES_LIST
color_mapping = DetectionVisualization._generate_color_mapping(len(class_names))

for (x1, y1, x2, y2, class_score, class_index) in zip(pred_boxes[:, 0], pred_boxes[:, 1], pred_boxes[:, 2],
pred_boxes[:, 3], pred_scores, pred_classes):
for (x1, y1, x2, y2, class_score, class_index) in zip(pred_boxes[:, 0], pred_boxes[:, 1], pred_boxes[:, 2], pred_boxes[:, 3], pred_scores, pred_classes):
image = DetectionVisualization.draw_box_title(
image_np=image,
x1=int(x1),
Expand Down Expand Up @@ -306,7 +306,7 @@ You can explicitly specify output format of the predictions by setting the `outp


```python
from super_gradients.conversion.conversion_enums import DetectionOutputFormatMode
from super_gradients.conversion import DetectionOutputFormatMode

export_result = model.export("yolo_nas_s.onnx", output_predictions_format=DetectionOutputFormatMode.FLAT_FORMAT)
export_result
Expand Down Expand Up @@ -371,6 +371,9 @@ result[0].shape

(25, 7)




```python
def show_predictions_from_flat_format(image, predictions):
[flat_predictions] = predictions
Expand All @@ -382,17 +385,17 @@ def show_predictions_from_flat_format(image, predictions):
for (sample_index, x1, y1, x2, y2, class_score, class_index) in flat_predictions[flat_predictions[:, 0] == 0]:
class_index = int(class_index)
image = DetectionVisualization.draw_box_title(
image_np=image,
x1=int(x1),
y1=int(y1),
x2=int(x2),
y2=int(y2),
class_id=class_index,
class_names=class_names,
color_mapping=color_mapping,
box_thickness=2,
pred_conf=class_score,
)
image_np=image,
x1=int(x1),
y1=int(y1),
x2=int(x2),
y2=int(y2),
class_id=class_index,
class_names=class_names,
color_mapping=color_mapping,
box_thickness=2,
pred_conf=class_score,
)

plt.figure(figsize=(8, 8))
plt.imshow(image)
Expand Down Expand Up @@ -497,7 +500,7 @@ In the example below we use a dummy data-loader for sake of showing how to use t
```python
import torch
from torch.utils.data import DataLoader
from super_gradients.conversion.conversion_enums import ExportQuantizationMode
from super_gradients.conversion import ExportQuantizationMode

# THIS IS ONLY AN EXAMPLE. YOU SHOULD USE YOUR OWN DATA-LOADER HERE
dummy_calibration_dataset = [torch.randn((3, 640, 640), dtype=torch.float32) for _ in range(32)]
Expand All @@ -519,7 +522,7 @@ result = session.run(outputs, {inputs[0]: image_bchw})
show_predictions_from_flat_format(image, result)
```

25%|████████████████████████████████████████████████████████▎ | 4/16 [00:12<00:36, 3.08s/it]
25%|█████████████████████████████████████████████████ | 4/16 [00:11<00:34, 2.87s/it]



Expand Down Expand Up @@ -552,13 +555,13 @@ Therefore, ONNX Runtime backend is recommended for most use-cases and is used by
You can specify the desired execution backend by setting the `engine` argument of the `export()` method:

```python
from super_gradients.conversion.conversion_enums import ExportTargetBackend
from super_gradients.conversion import ExportTargetBackend

model.export(..., engine=ExportTargetBackend.ONNXRUNTIME)
```

```python
from super_gradients.conversion.conversion_enums import ExportTargetBackend
from super_gradients.conversion import ExportTargetBackend

model.export(..., engine=ExportTargetBackend.TENSORRT)
```
Expand Down
Binary file modified documentation/source/models_export_files/models_export_18_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified documentation/source/models_export_files/models_export_24_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified documentation/source/models_export_files/models_export_26_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified documentation/source/models_export_files/models_export_28_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified documentation/source/models_export_files/models_export_30_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 0 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,3 @@ numpy<=1.23
rapidfuzz
json-tricks==3.16.1
onnx-simplifier>=0.3.6,<1.0
--extra-index-url https://pypi.ngc.nvidia.com
onnx_graphsurgeon>=0.3.8,<0.4
3 changes: 3 additions & 0 deletions src/super_gradients/conversion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .conversion_enums import ExportTargetBackend, ExportQuantizationMode, DetectionOutputFormatMode

__all__ = ["ExportQuantizationMode", "DetectionOutputFormatMode", "ExportTargetBackend"]
22 changes: 22 additions & 0 deletions src/super_gradients/conversion/gs_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
def import_onnx_graphsurgeon_or_fail_with_instructions():
    """
    Import the ``onnx_graphsurgeon`` package and hand the module back to the caller.

    :return: The imported ``onnx_graphsurgeon`` module.
    :raises ImportError: If the package cannot be imported. The message contains the pip command to install it.
    """
    try:
        import onnx_graphsurgeon
    except ImportError:
        install_hint = "Please install it with pip install onnx_graphsurgeon==0.3.27 --extra-index-url https://pypi.ngc.nvidia.com"
        raise ImportError("onnx-graphsurgeon is required to use export API. " + install_hint)
    else:
        return onnx_graphsurgeon


def import_onnx_graphsurgeon_or_install():
    """
    Import ``onnx_graphsurgeon``, installing it on demand when it is missing.

    If the first import fails, pip is run as a child process of the current interpreter to install the pinned
    ``onnx_graphsurgeon==0.3.27`` release (using the NVIDIA extra index), and the import is then retried.

    :return: The imported ``onnx_graphsurgeon`` module.
    :raises ImportError: If the package still cannot be imported after the installation attempt.
    """
    try:
        import onnx_graphsurgeon as gs

        return gs
    except ImportError:
        import importlib
        import subprocess
        import sys

        # pip has no supported in-process API (pip.main is an internal entry point), so it is invoked
        # as "<current interpreter> -m pip" to install into the environment that is actually running.
        # A failed install is deliberately not raised here: the retry below reports it as an
        # ImportError that carries manual installation instructions.
        try:
            subprocess.run(
                [sys.executable, "-m", "pip", "install", "onnx_graphsurgeon==0.3.27", "--extra-index-url", "https://pypi.ngc.nvidia.com"],
                check=False,
            )
        except OSError:
            # The interpreter could not be launched; fall through to the instructive ImportError below.
            pass

        # Let the import system notice a package that was installed after the interpreter started.
        importlib.invalidate_caches()

        return import_onnx_graphsurgeon_or_fail_with_instructions()
Empty file.
6 changes: 4 additions & 2 deletions src/super_gradients/conversion/onnx/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,20 @@
import numpy as np
import onnx
import onnx.shape_inference
import onnx_graphsurgeon as gs
import torch
from onnx import TensorProto
from torch import nn, Tensor

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.conversion.conversion_enums import DetectionOutputFormatMode
from super_gradients.conversion.conversion_utils import numpy_dtype_to_torch_dtype
from super_gradients.conversion.gs_utils import import_onnx_graphsurgeon_or_fail_with_instructions
from super_gradients.conversion.onnx.utils import append_graphs

logger = get_logger(__name__)

gs = import_onnx_graphsurgeon_or_fail_with_instructions()


class PickNMSPredictionsAndReturnAsBatchedResult(nn.Module):
__constants__ = ("batch_size", "max_predictions_per_image")
Expand Down Expand Up @@ -53,7 +55,7 @@ def forward(self, pred_boxes: Tensor, pred_scores: Tensor, selected_indexes: Ten

batch_predictions = torch.zeros((self.batch_size, self.max_predictions_per_image, 6), dtype=predictions.dtype, device=predictions.device)

batch_indexes = torch.arange(start=0, end=self.batch_size, step=1, device=predictions.device, dtype=predictions.dtype)
batch_indexes = torch.arange(start=0, end=self.batch_size, step=1, device=predictions.device).to(dtype=predictions.dtype)
masks = batch_indexes.view(-1, 1).eq(predictions[:, 0].view(1, -1)) # [B, N]

num_predictions = torch.sum(masks, dim=1).long()
Expand Down
24 changes: 22 additions & 2 deletions src/super_gradients/conversion/onnx/utils.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,43 @@
import onnx_graphsurgeon as gs
from onnx import shape_inference

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.conversion.gs_utils import import_onnx_graphsurgeon_or_fail_with_instructions

logger = get_logger(__name__)

gs = import_onnx_graphsurgeon_or_fail_with_instructions()

def append_graphs(graph1: gs.Graph, graph2: gs.Graph) -> gs.Graph:

def append_prefix_to_graph(graph: gs.Graph, prefix: str) -> gs.Graph:
    """
    Put a prefix in front of the name of every node and every output of the graph to avoid name collisions.
    The graph is modified in place; no copy is made.
    :param graph: The graph whose nodes and outputs are renamed.
    :param prefix: The string placed in front of each node name and each output name.
    :return: The same graph instance that was passed in.
    """
    # Nodes are renamed first, then outputs
    for collection_name in ("nodes", "outputs"):
        for named_item in getattr(graph, collection_name):
            named_item.name = prefix + named_item.name

    return graph


def append_graphs(graph1: gs.Graph, graph2: gs.Graph, prefix: str = "graph2_") -> gs.Graph:
"""
Append one graph to another. This function modify graph1 in place.
Outputs from the first graph will be connected to inputs of the second graph.
:param graph1: The first graph. Will be modified in place.
:param graph2: The second graph to attach to the first graph.
:param prefix: The prefix to add to all nodes and outputs in the second graph to avoid name collisions.
:return: The first graph, with the second graph appended to it.
"""
if len(graph1.outputs) != len(graph2.inputs):
raise ValueError(f"Number of outputs ({len(graph1.outputs)}) does not match number of inputs ({len(graph2.inputs)})")

merged_graph = graph1
graph2 = append_prefix_to_graph(graph2, prefix)

for node in graph2.nodes:
merged_graph.nodes.append(node)
Expand Down
Empty file.
8 changes: 4 additions & 4 deletions src/super_gradients/conversion/tensorrt/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@

import numpy as np
import onnx
import onnx_graphsurgeon as gs
import torch
from torch import nn, Tensor

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.conversion.conversion_enums import DetectionOutputFormatMode
from super_gradients.conversion.conversion_utils import numpy_dtype_to_torch_dtype
from super_gradients.conversion.gs_utils import import_onnx_graphsurgeon_or_fail_with_instructions
from super_gradients.conversion.onnx.utils import append_graphs

logger = get_logger(__name__)

gs = import_onnx_graphsurgeon_or_fail_with_instructions()


class ConvertTRTFormatToFlatTensor(nn.Module):
__constants__ = ["batch_size", "max_predictions_per_image"]
Expand All @@ -40,9 +42,7 @@ def forward(self, num_predictions: Tensor, pred_boxes: Tensor, pred_scores: Tens
) # [B, max_predictions_per_image]

preds_indexes = (
torch.arange(start=0, end=self.max_predictions_per_image, step=1, device=num_predictions.device, dtype=pred_scores.dtype)
.view(1, -1, 1)
.repeat(self.batch_size, 1, 1)
torch.arange(start=0, end=self.max_predictions_per_image, step=1, device=num_predictions.device).view(1, -1, 1).repeat(self.batch_size, 1, 1)
) # [B, max_predictions_per_image, 1]

flat_predictions = torch.cat(
Expand Down
12 changes: 6 additions & 6 deletions src/super_gradients/examples/model_export/models_export.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@
" color_mapping = DetectionVisualization._generate_color_mapping(len(class_names))\n",
"\n",
" for (x1, y1, x2, y2, class_score, class_index) in zip(pred_boxes[:, 0], pred_boxes[:, 1], pred_boxes[:, 2], pred_boxes[:, 3], pred_scores, pred_classes):\n",
" image = DetectionVisualization._draw_box_title(\n",
" image = DetectionVisualization.draw_box_title(\n",
" image_np=image,\n",
" x1=int(x1),\n",
" y1=int(y1),\n",
Expand Down Expand Up @@ -450,7 +450,7 @@
}
],
"source": [
"from super_gradients.conversion.conversion_enums import DetectionOutputFormatMode\n",
"from super_gradients.conversion import DetectionOutputFormatMode\n",
"\n",
"export_result = model.export(\"yolo_nas_s.onnx\", output_predictions_format=DetectionOutputFormatMode.FLAT_FORMAT)\n",
"export_result"
Expand Down Expand Up @@ -514,7 +514,7 @@
"\n",
" for (sample_index, x1, y1, x2, y2, class_score, class_index) in flat_predictions[flat_predictions[:, 0] == 0]:\n",
" class_index = int(class_index)\n",
" image = DetectionVisualization._draw_box_title(\n",
" image = DetectionVisualization.draw_box_title(\n",
" image_np=image,\n",
" x1=int(x1),\n",
" y1=int(y1),\n",
Expand Down Expand Up @@ -715,7 +715,7 @@
"source": [
"import torch\n",
"from torch.utils.data import DataLoader\n",
"from super_gradients.conversion.conversion_enums import ExportQuantizationMode\n",
"from super_gradients.conversion import ExportQuantizationMode\n",
"\n",
"# THIS IS ONLY AN EXAMPLE. YOU SHOULD USE YOUR OWN DATA-LOADER HERE\n",
"dummy_calibration_dataset = [torch.randn((3, 640, 640), dtype=torch.float32) for _ in range(32)]\n",
Expand Down Expand Up @@ -778,13 +778,13 @@
"You can specify the desired execution backend by setting the `execution_backend` argument of `export()` method:\n",
"\n",
"```python\n",
"from super_gradients.conversion.conversion_enums import ExportTargetBackend\n",
"from super_gradients.conversion import ExportTargetBackend\n",
"\n",
"model.export(..., engine=ExportTargetBackend.ONNXRUNTIME)\n",
"```\n",
"\n",
"```python\n",
"from super_gradients.conversion.conversion_enums import ExportTargetBackend\n",
"from super_gradients.conversion import ExportTargetBackend\n",
"\n",
"model.export(..., engine=ExportTargetBackend.TENSORRT)\n",
"```"
Expand Down
21 changes: 13 additions & 8 deletions src/super_gradients/module_interfaces/exportable_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,11 @@
from torch.utils.data import DataLoader

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.conversion.conversion_enums import ExportTargetBackend, ExportQuantizationMode, DetectionOutputFormatMode
from super_gradients.conversion.conversion_utils import torch_dtype_to_numpy_dtype
from super_gradients.conversion.onnx.nms import attach_onnx_nms
from super_gradients.conversion.preprocessing_modules import CastTensorTo
from super_gradients.conversion.tensorrt.nms import attach_tensorrt_nms
from super_gradients.conversion import ExportTargetBackend, ExportQuantizationMode, DetectionOutputFormatMode
from super_gradients.conversion.gs_utils import import_onnx_graphsurgeon_or_install
from super_gradients.training.utils.export_utils import infer_format_from_file_name, infer_image_shape_from_model, infer_image_input_channels
from super_gradients.training.utils.quantization.fix_pytorch_quantization_modules import patch_pytorch_quantization_modules_if_needed
from super_gradients.training.utils.utils import infer_model_device, check_model_contains_quantized_modules
from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer
from super_gradients.training.utils.quantization.calibrator import QuantizationCalibrator


logger = get_logger(__name__)
Expand Down Expand Up @@ -128,7 +123,7 @@ def export(
nms_threshold: Optional[float] = None,
engine: Optional[ExportTargetBackend] = None,
quantization_mode: ExportQuantizationMode = Optional[None],
selective_quantizer: Optional[SelectiveQuantizer] = None,
selective_quantizer: Optional["SelectiveQuantizer"] = None, # noqa
calibration_loader: Optional[DataLoader] = None,
calibration_method: str = "percentile",
calibration_batches: int = 16,
Expand Down Expand Up @@ -213,10 +208,20 @@ def export(
:param num_pre_nms_predictions: (int) Number of predictions to keep before NMS.
:return:
"""

# Do imports here to avoid raising error of missing onnx_graphsurgeon package if it is not needed.
import_onnx_graphsurgeon_or_install()
from super_gradients.conversion.conversion_utils import torch_dtype_to_numpy_dtype
from super_gradients.conversion.onnx.nms import attach_onnx_nms
from super_gradients.conversion.preprocessing_modules import CastTensorTo
from super_gradients.conversion.tensorrt.nms import attach_tensorrt_nms

usage_instructions = []

try:
from pytorch_quantization import nn as quant_nn
from super_gradients.training.utils.quantization.calibrator import QuantizationCalibrator
from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer

patch_pytorch_quantization_modules_if_needed()
except ImportError:
Expand Down
16 changes: 10 additions & 6 deletions src/super_gradients/training/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,25 @@
import inspect
import math
import os
import random
import re
import tarfile
import time
import inspect
import typing
import warnings
from functools import lru_cache, wraps
from importlib import import_module
from itertools import islice

from pathlib import Path
from typing import Mapping, Optional, Tuple, Union, List, Dict, Any, Iterable
from zipfile import ZipFile

from pytorch_quantization.nn.modules._utils import QuantMixin
from super_gradients.training.utils.quantization.core import SGQuantMixin
from torch.nn.parallel import DistributedDataParallel

import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ExifTags
from jsonschema import validate
from torch.nn.parallel import DistributedDataParallel

from super_gradients.common.abstractions.abstract_logger import get_logger

Expand Down Expand Up @@ -674,6 +670,14 @@ def check_model_contains_quantized_modules(model: nn.Module) -> bool:
:param model: Model to check.
:return: True if the model contains any quantized modules, False otherwise.
"""
try:
from pytorch_quantization.nn.modules._utils import QuantMixin
except ImportError:
# If pytorch_quantization is not installed then by definition the model cannot contain any quantized modules
return False

from super_gradients.training.utils.quantization.core import SGQuantMixin

model = unwrap_model(model)
for m in model.modules():
if isinstance(m, (QuantMixin, SGQuantMixin)):
Expand Down
Loading

0 comments on commit 72bd969

Please sign in to comment.