Skip to content

Commit

Permalink
Merge pull request #1515 from mikel-brostrom/refactor-exporters
Browse files Browse the repository at this point in the history
move each format exporter to individual file
  • Loading branch information
mikel-brostrom committed Jul 12, 2024
2 parents b0dc822 + 099e130 commit 71214c8
Show file tree
Hide file tree
Showing 7 changed files with 286 additions and 313 deletions.
30 changes: 30 additions & 0 deletions boxmot/appearance/exporters/base_exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import logging
import torch
from pathlib import Path
from boxmot.utils.checks import RequirementsChecker
from boxmot.utils import logger as LOGGER


class BaseExporter:
def __init__(self, model, im, file, optimize=False, dynamic=False, half=False, simplify=False):
self.model = model
self.im = im
self.file = Path(file)
self.optimize = optimize
self.dynamic = dynamic
self.half = half
self.simplify = simplify
self.checker = RequirementsChecker()

@staticmethod
def file_size(path):
path = Path(path)
if path.is_file():
return path.stat().st_size / 1e6
elif path.is_dir():
return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / 1e6
else:
return 0.0

def export(self):
raise NotImplementedError("Export method must be implemented in subclasses.")
58 changes: 58 additions & 0 deletions boxmot/appearance/exporters/onnx_exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import torch
import onnx
from boxmot.appearance.exporters.base_exporter import BaseExporter
from boxmot.utils import logger as LOGGER


class ONNXExporter(BaseExporter):
def export(self):
try:
self.checker.check_packages(("onnx==1.14.0",))
f = self.file.with_suffix(".onnx")
LOGGER.info(f"\nStarting export with onnx {onnx.__version__}...")

dynamic = {"images": {0: "batch"}, "output": {0: "batch"}} if self.dynamic else None

torch.onnx.export(
self.model.cpu() if self.dynamic else self.model,
self.im.cpu() if self.dynamic else self.im,
f,
verbose=False,
opset_version=12,
do_constant_folding=True,
input_names=["images"],
output_names=["output"],
dynamic_axes=dynamic,
)

model_onnx = onnx.load(f)
onnx.checker.check_model(model_onnx)
onnx.save(model_onnx, f)

if self.simplify:
self.simplify_model(model_onnx, f)

LOGGER.info(f"Export success, saved as {f} ({self.file_size(f):.1f} MB)")
return f
except Exception as e:
LOGGER.error(f"Export failure: {e}")

def simplify_model(self, model_onnx, f):
try:
cuda = torch.cuda.is_available()
self.checker.check_packages(
(
"onnxruntime-gpu" if cuda else "onnxruntime",
"onnx-simplifier>=0.4.1",
)
)
import onnxsim

LOGGER.info(
f"Simplifying with onnx-simplifier {onnxsim.__version__}..."
)
model_onnx, check = onnxsim.simplify(model_onnx)
assert check, "assert check failed"
onnx.save(model_onnx, f)
except Exception as e:
LOGGER.error(f"Simplifier failure: {e}")
29 changes: 29 additions & 0 deletions boxmot/appearance/exporters/openvino_exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os
from pathlib import Path
import openvino.runtime as ov
from openvino.tools import mo
from boxmot.appearance.exporters.base_exporter import BaseExporter
from boxmot.utils import logger as LOGGER


class OpenVINOExporter(BaseExporter):
def export(self):
self.checker.check_packages(
("openvino-dev>=2023.0",)
)
f = str(self.file).replace(self.file.suffix, f"_openvino_model{os.sep}")
f_onnx = self.file.with_suffix(".onnx")
f_ov = str(Path(f) / self.file.with_suffix(".xml").name)
try:
LOGGER.info(f"\nStarting export with openvino {ov.__version__}...")
ov_model = mo.convert_model(
f_onnx,
model_name=self.file.with_suffix(".xml"),
framework="onnx",
compress_to_fp16=self.half,
)
ov.serialize(ov_model, f_ov)
except Exception as e:
LOGGER.error(f"Export failure: {e}")
LOGGER.info(f"Export success, saved as {f_ov} ({self.file_size(f_ov):.1f} MB)")
return f
75 changes: 75 additions & 0 deletions boxmot/appearance/exporters/tensorrt_exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import platform
import torch
from boxmot.appearance.exporters.base_exporter import BaseExporter
from boxmot.utils import logger as LOGGER


class EngineExporter(BaseExporter):
def export(self):
try:
assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. `python export.py --device 0`"
try:
import tensorrt as trt
except ImportError:
if platform.system() == "Linux":
self.checker.check_packages(
["nvidia-tensorrt"],
cmds=("-U --index-url https://pypi.ngc.nvidia.com",),
)
import tensorrt as trt

onnx_file = self.export_onnx()
LOGGER.info(f"\nStarting export with TensorRT {trt.__version__}...")
assert onnx_file.exists(), f"Failed to export ONNX file: {onnx_file}"
f = self.file.with_suffix(".engine")
logger = trt.Logger(trt.Logger.INFO)
if self.verbose:
logger.min_severity = trt.Logger.Severity.VERBOSE

builder = trt.Builder(logger)
config = builder.create_builder_config()
config.max_workspace_size = self.workspace * 1 << 30

flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file(str(onnx_file)):
raise RuntimeError(f"Failed to load ONNX file: {onnx_file}")

inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
LOGGER.info("Network Description:")
for inp in inputs:
LOGGER.info(f'\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
for out in outputs:
LOGGER.info(f'\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')

if self.dynamic:
if self.im.shape[0] <= 1:
LOGGER.warning("WARNING: --dynamic model requires maximum --batch-size argument")
profile = builder.create_optimization_profile()
for inp in inputs:
if self.half:
inp.dtype = trt.float16
profile.set_shape(
inp.name,
(1, *self.im.shape[1:]),
(max(1, self.im.shape[0] // 2), *self.im.shape[1:]),
self.im.shape,
)
config.add_optimization_profile(profile)

LOGGER.info(f"Building FP{16 if builder.platform_has_fast_fp16 and self.half else 32} engine in {f}")
if builder.platform_has_fast_fp16 and self.half:
config.set_flag(trt.BuilderFlag.FP16)
config.default_device_type = trt.DeviceType.GPU
with builder.build_engine(network, config) as engine, open(f, "wb") as t:
t.write(engine.serialize())
LOGGER.info(f"Export success, saved as {f} ({self.file_size(f):.1f} MB)")
return f
except Exception as e:
LOGGER.error(f"\nExport failure: {e}")

def export_onnx(self):
onnx_exporter = ONNXExporter(self.model, self.im, self.file, self.optimize, self.dynamic, self.half, self.simplify)
return onnx_exporter.export()
22 changes: 22 additions & 0 deletions boxmot/appearance/exporters/tflite_exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import subprocess
from boxmot.appearance.exporters.base_exporter import BaseExporter
from boxmot.utils import logger as LOGGER


class TFLiteExporter(BaseExporter):
def export(self):
try:
self.checker.check_packages(
("onnx2tf>=1.15.4", "tensorflow", "onnx_graphsurgeon>=0.3.26", "sng4onnx>=1.0.1"),
cmds='--extra-index-url https://pypi.ngc.nvidia.com'
)
import onnx2tf

LOGGER.info(f"\nStarting {self.file} export with onnx2tf {onnx2tf.__version__}")
f = str(self.file).replace(".onnx", f"_saved_model{os.sep}")
cmd = f"onnx2tf -i {self.file} -o {f} -osd -coion --non_verbose"
subprocess.check_output(cmd.split())
LOGGER.info(f"Export success, results saved in {f} ({self.file_size(f):.1f} MB)")
return f
except Exception as e:
LOGGER.error(f"\nExport failure: {e}")
20 changes: 20 additions & 0 deletions boxmot/appearance/exporters/torchscript_exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch
from boxmot.appearance.exporters.base_exporter import BaseExporter
from boxmot.utils import logger as LOGGER


class TorchScriptExporter(BaseExporter):
def export(self):
try:
LOGGER.info(f"\nStarting export with torch {torch.__version__}...")
f = self.file.with_suffix(".torchscript")
ts = torch.jit.trace(self.model, self.im, strict=False)
if self.optimize:
torch.utils.mobile_optimizer.optimize_for_mobile(ts)._save_for_lite_interpreter(str(f))
else:
ts.save(str(f))

LOGGER.info(f"Export success, saved as {f} ({self.file_size(f):.1f} MB)")
return f
except Exception as e:
LOGGER.error(f"Export failure: {e}")
Loading

0 comments on commit 71214c8

Please sign in to comment.