diff --git a/setup.py b/setup.py index ed689439f9a..cf38c8b3d0e 100644 --- a/setup.py +++ b/setup.py @@ -62,17 +62,17 @@ _deepsparse_ent_deps = [f"deepsparse-ent~={version_nm_deps}"] _onnxruntime_deps = ["onnxruntime>=1.0.0"] -supported_torch_version = "torch>=1.7.0,<1.14" +supported_torch_version = "torch>=1.7.0,<=2.0" _pytorch_deps = [ supported_torch_version, "gputils", ] _pytorch_all_deps = _pytorch_deps + [ - "torchvision>=0.3.0,<0.15", - "torchaudio<=0.13", + "torchvision>=0.3.0,<=0.15.1", + "torchaudio<=2.0.1", ] _pytorch_vision_deps = _pytorch_deps + [ - "torchvision>=0.3.0,<0.15", + "torchvision>=0.3.0,<=0.15.1", "opencv-python<=4.6.0.66", ] _transformers_deps = _pytorch_deps + [ diff --git a/src/sparseml/pytorch/base.py b/src/sparseml/pytorch/base.py index bfacaaab6f4..a2dafde673a 100644 --- a/src/sparseml/pytorch/base.py +++ b/src/sparseml/pytorch/base.py @@ -49,7 +49,7 @@ _TORCH_MIN_VERSION = "1.0.0" -_TORCH_MAX_VERSION = "1.13.100" # set bug to 100 to support all future 1.9.X versions +_TORCH_MAX_VERSION = "2.0.100" def check_torch_install(