-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
move each format exporter to individual file
- Loading branch information
Mikel Broström
committed
Jul 12, 2024
1 parent
b0dc822
commit 311e679
Showing
7 changed files
with
273 additions
and
314 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import logging | ||
import torch | ||
from pathlib import Path | ||
from boxmot.utils.checks import RequirementsChecker | ||
from boxmot.utils import logger as LOGGER | ||
|
||
checker = RequirementsChecker() | ||
|
||
class BaseExporter: | ||
def __init__(self, model, im, file, optimize=False, dynamic=False, half=False, simplify=False): | ||
self.model = model | ||
self.im = im | ||
self.file = Path(file) | ||
self.optimize = optimize | ||
self.dynamic = dynamic | ||
self.half = half | ||
self.simplify = simplify | ||
|
||
@staticmethod | ||
def file_size(path): | ||
path = Path(path) | ||
if path.is_file(): | ||
return path.stat().st_size / 1e6 | ||
elif path.is_dir(): | ||
return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / 1e6 | ||
else: | ||
return 0.0 | ||
|
||
def export(self): | ||
raise NotImplementedError("Export method must be implemented in subclasses.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import torch | ||
import onnx | ||
from base_exporter import BaseExporter | ||
|
||
class ONNXExporter(BaseExporter): | ||
def export(self): | ||
try: | ||
checker.check_packages(("onnx==1.14.0",)) | ||
f = self.file.with_suffix(".onnx") | ||
LOGGER.info(f"\nStarting export with onnx {onnx.__version__}...") | ||
|
||
dynamic = {"images": {0: "batch"}, "output": {0: "batch"}} if self.dynamic else None | ||
|
||
torch.onnx.export( | ||
self.model.cpu() if self.dynamic else self.model, | ||
self.im.cpu() if self.dynamic else self.im, | ||
f, | ||
verbose=False, | ||
opset_version=12, | ||
do_constant_folding=True, | ||
input_names=["images"], | ||
output_names=["output"], | ||
dynamic_axes=dynamic, | ||
) | ||
|
||
model_onnx = onnx.load(f) | ||
onnx.checker.check_model(model_onnx) | ||
onnx.save(model_onnx, f) | ||
|
||
if self.simplify: | ||
self.simplify_model(model_onnx, f) | ||
|
||
LOGGER.info(f"Export success, saved as {f} ({self.file_size(f):.1f} MB)") | ||
return f | ||
except Exception as e: | ||
LOGGER.error(f"Export failure: {e}") | ||
|
||
def simplify_model(self, model_onnx, f): | ||
try: | ||
cuda = torch.cuda.is_available() | ||
checker.check_packages( | ||
( | ||
"onnxruntime-gpu" if cuda else "onnxruntime", | ||
"onnx-simplifier>=0.4.1", | ||
) | ||
) | ||
import onnxsim | ||
|
||
LOGGER.info( | ||
f"Simplifying with onnx-simplifier {onnxsim.__version__}..." | ||
) | ||
model_onnx, check = onnxsim.simplify(model_onnx) | ||
assert check, "assert check failed" | ||
onnx.save(model_onnx, f) | ||
except Exception as e: | ||
LOGGER.error(f"Simplifier failure: {e}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import os | ||
import openvino.runtime as ov | ||
from openvino.tools import mo | ||
from base_exporter import BaseExporter | ||
|
||
class OpenVINOExporter(BaseExporter): | ||
def export(self): | ||
checker.check_packages( | ||
("openvino-dev>=2023.0",) | ||
) | ||
f = str(self.file).replace(self.file.suffix, f"_openvino_model{os.sep}") | ||
f_onnx = self.file.with_suffix(".onnx") | ||
f_ov = str(Path(f) / self.file.with_suffix(".xml").name) | ||
try: | ||
LOGGER.info(f"\nStarting export with openvino {ov.__version__}...") | ||
ov_model = mo.convert_model( | ||
f_onnx, | ||
model_name=self.file.with_suffix(".xml"), | ||
framework="onnx", | ||
compress_to_fp16=self.half, | ||
) | ||
ov.serialize(ov_model, f_ov) | ||
except Exception as e: | ||
LOGGER.error(f"Export failure: {e}") | ||
LOGGER.info(f"Export success, saved as {f_ov} ({self.file_size(f_ov):.1f} MB)") | ||
return f |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import platform | ||
import torch | ||
from base_exporter import BaseExporter | ||
|
||
class EngineExporter(BaseExporter): | ||
def export(self): | ||
try: | ||
assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. `python export.py --device 0`" | ||
try: | ||
import tensorrt as trt | ||
except ImportError: | ||
if platform.system() == "Linux": | ||
checker.check_packages( | ||
["nvidia-tensorrt"], | ||
cmds=("-U --index-url https://pypi.ngc.nvidia.com",), | ||
) | ||
import tensorrt as trt | ||
|
||
onnx_file = self.export_onnx() | ||
LOGGER.info(f"\nStarting export with TensorRT {trt.__version__}...") | ||
assert onnx_file.exists(), f"Failed to export ONNX file: {onnx_file}" | ||
f = self.file.with_suffix(".engine") | ||
logger = trt.Logger(trt.Logger.INFO) | ||
if self.verbose: | ||
logger.min_severity = trt.Logger.Severity.VERBOSE | ||
|
||
builder = trt.Builder(logger) | ||
config = builder.create_builder_config() | ||
config.max_workspace_size = self.workspace * 1 << 30 | ||
|
||
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) | ||
network = builder.create_network(flag) | ||
parser = trt.OnnxParser(network, logger) | ||
if not parser.parse_from_file(str(onnx_file)): | ||
raise RuntimeError(f"Failed to load ONNX file: {onnx_file}") | ||
|
||
inputs = [network.get_input(i) for i in range(network.num_inputs)] | ||
outputs = [network.get_output(i) for i in range(network.num_outputs)] | ||
LOGGER.info("Network Description:") | ||
for inp in inputs: | ||
LOGGER.info(f'\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}') | ||
for out in outputs: | ||
LOGGER.info(f'\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}') | ||
|
||
if self.dynamic: | ||
if self.im.shape[0] <= 1: | ||
LOGGER.warning("WARNING: --dynamic model requires maximum --batch-size argument") | ||
profile = builder.create_optimization_profile() | ||
for inp in inputs: | ||
if self.half: | ||
inp.dtype = trt.float16 | ||
profile.set_shape( | ||
inp.name, | ||
(1, *self.im.shape[1:]), | ||
(max(1, self.im.shape[0] // 2), *self.im.shape[1:]), | ||
self.im.shape, | ||
) | ||
config.add_optimization_profile(profile) | ||
|
||
LOGGER.info(f"Building FP{16 if builder.platform_has_fast_fp16 and self.half else 32} engine in {f}") | ||
if builder.platform_has_fast_fp16 and self.half: | ||
config.set_flag(trt.BuilderFlag.FP16) | ||
config.default_device_type = trt.DeviceType.GPU | ||
with builder.build_engine(network, config) as engine, open(f, "wb") as t: | ||
t.write(engine.serialize()) | ||
LOGGER.info(f"Export success, saved as {f} ({self.file_size(f):.1f} MB)") | ||
return f | ||
except Exception as e: | ||
LOGGER.error(f"\nExport failure: {e}") | ||
|
||
def export_onnx(self): | ||
onnx_exporter = ONNXExporter(self.model, self.im, self.file, self.optimize, self.dynamic, self.half, self.simplify) | ||
return onnx_exporter.export() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import subprocess | ||
from base_exporter import BaseExporter | ||
|
||
class TFLiteExporter(BaseExporter): | ||
def export(self): | ||
try: | ||
checker.check_packages( | ||
("onnx2tf>=1.15.4", "tensorflow", "onnx_graphsurgeon>=0.3.26", "sng4onnx>=1.0.1"), | ||
cmds='--extra-index-url https://pypi.ngc.nvidia.com' | ||
) | ||
import onnx2tf | ||
|
||
LOGGER.info(f"\nStarting {self.file} export with onnx2tf {onnx2tf.__version__}") | ||
f = str(self.file).replace(".onnx", f"_saved_model{os.sep}") | ||
cmd = f"onnx2tf -i {self.file} -o {f} -osd -coion --non_verbose" | ||
subprocess.check_output(cmd.split()) | ||
LOGGER.info(f"Export success, results saved in {f} ({self.file_size(f):.1f} MB)") | ||
return f | ||
except Exception as e: | ||
LOGGER.error(f"\nExport failure: {e}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import torch | ||
from base_exporter import BaseExporter | ||
|
||
class TorchScriptExporter(BaseExporter): | ||
def export(self): | ||
try: | ||
LOGGER.info(f"\nStarting export with torch {torch.__version__}...") | ||
f = self.file.with_suffix(".torchscript") | ||
ts = torch.jit.trace(self.model, self.im, strict=False) | ||
if self.optimize: | ||
torch.utils.mobile_optimizer.optimize_for_mobile(ts)._save_for_lite_interpreter(str(f)) | ||
else: | ||
ts.save(str(f)) | ||
|
||
LOGGER.info(f"Export success, saved as {f} ({self.file_size(f):.1f} MB)") | ||
return f | ||
except Exception as e: | ||
LOGGER.error(f"Export failure: {e}") |
Oops, something went wrong.