diff --git a/anomalib/core/callbacks/compress.py b/anomalib/core/callbacks/compress.py index 81c2a87443..8015ca1855 100644 --- a/anomalib/core/callbacks/compress.py +++ b/anomalib/core/callbacks/compress.py @@ -28,7 +28,12 @@ def on_train_end(self, trainer, pl_module: LightningModule) -> None: # pylint: onnx_path = os.path.join(self.dirpath, self.filename + ".onnx") height, width = self.input_size torch.onnx.export( - pl_module.model, torch.zeros((1, 3, height, width)).to(pl_module.device), onnx_path, opset_version=11 + pl_module.model, + torch.zeros((1, 3, height, width)).to(pl_module.device), + onnx_path, + opset_version=11, + input_names=["input"], + output_names=["output"], ) optimize_command = "mo --input_model " + onnx_path + " --output_dir " + self.dirpath os.system(optimize_command)