Skip to content

Commit

Permalink
Added explicit antialias=False to ensure we can export this operation…
Browse files Browse the repository at this point in the history
… to ONNX (#1824)
  • Loading branch information
BloodAxe committed Feb 9, 2024
1 parent d5a85fd commit 94cc2d6
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/unit_tests/export_onnx_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
class TestModelsONNXExport(unittest.TestCase):
def test_models_onnx_export_with_deprecated_input_shape(self):
pretrained_model = models.get(Models.RESNET18, num_classes=1000, pretrained_weights="imagenet")
preprocess = Compose([Resize(224), Standardize(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
preprocess = Compose([Resize(224, antialias=False), Standardize(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
with tempfile.TemporaryDirectory() as tmpdirname:
out_path = os.path.join(tmpdirname, "resnet18.onnx")
models.convert_to_onnx(model=pretrained_model, out_path=out_path, input_shape=(3, 256, 256), pre_process=preprocess)
self.assertTrue(os.path.exists(out_path))

def test_models_onnx_export(self):
pretrained_model = models.get(Models.RESNET18, num_classes=1000, pretrained_weights="imagenet")
preprocess = Compose([Resize(224), Standardize(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
preprocess = Compose([Resize(224, antialias=False), Standardize(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
with tempfile.TemporaryDirectory() as tmpdirname:
out_path = os.path.join(tmpdirname, "resnet18.onnx")
models.convert_to_onnx(
Expand Down

0 comments on commit 94cc2d6

Please sign in to comment.