Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 18, 2024
1 parent 6759521 commit 26afda9
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions scripts/adjust-torch-versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Dict, List, Optional


def _determine_torchaudio(torch_version : str) -> str:
def _determine_torchaudio(torch_version: str) -> str:
"""Determine the torchaudio version based on the torch version.
>>> _determine_torchaudio("1.9.0")
Expand All @@ -19,6 +19,7 @@ def _determine_torchaudio(torch_version : str) -> str:
'2.4.1'
>>> _determine_torchaudio("1.8.2")
'0.9.1'
"""
_VERSION_EXCEPTIONS = {
"1.8.2": "0.9.1",
Expand All @@ -27,15 +28,15 @@ def _determine_torchaudio(torch_version : str) -> str:
torch_ver = re.search(r"([\.\d]+)", torch_version).groups()[0]
if torch_ver in _VERSION_EXCEPTIONS:
return _VERSION_EXCEPTIONS[torch_ver]
ver_major, ver_minor, ver_bugfix= map(int, torch_ver.split("."))
ver_major, ver_minor, ver_bugfix = map(int, torch_ver.split("."))
ta_ver_array = [ver_major, ver_minor, ver_bugfix]
if ver_major == 1:
ta_ver_array[0] = 0
ta_ver_array[2] = ver_bugfix
return ".".join(map(str, ta_ver_array))


def _determine_torchtext(torch_version : str) -> str:
def _determine_torchtext(torch_version: str) -> str:
"""Determine the torchtext version based on the torch version.
>>> _determine_torchtext("1.9.0")
Expand All @@ -44,6 +45,7 @@ def _determine_torchtext(torch_version : str) -> str:
'0.18.0'
>>> _determine_torchtext("1.8.2")
'0.9.1'
"""
_VERSION_EXCEPTIONS = {
"2.0.1": "0.15.2",
Expand All @@ -54,7 +56,7 @@ def _determine_torchtext(torch_version : str) -> str:
torch_ver = re.search(r"([\.\d]+)", torch_version).groups()[0]
if torch_ver in _VERSION_EXCEPTIONS:
return _VERSION_EXCEPTIONS[torch_ver]
ver_major, ver_minor, ver_bugfix= map(int, torch_ver.split("."))
ver_major, ver_minor, ver_bugfix = map(int, torch_ver.split("."))
tt_ver_array = [0, 0, 0]
if ver_major == 1:
tt_ver_array[1] = ver_minor + 1
Expand All @@ -70,7 +72,7 @@ def _determine_torchtext(torch_version : str) -> str:
return ".".join(map(str, tt_ver_array))


def _determine_torchvision(torch_version : str) -> str:
def _determine_torchvision(torch_version: str) -> str:
"""Determine the torchvision version based on the torch version.
>>> _determine_torchvision("1.9.0")
Expand All @@ -79,6 +81,7 @@ def _determine_torchvision(torch_version : str) -> str:
'0.19.1'
>>> _determine_torchvision("2.0.1")
'0.15.2'
"""
_VERSION_EXCEPTIONS = {
"2.0.1": "0.15.2",
Expand All @@ -92,7 +95,7 @@ def _determine_torchvision(torch_version : str) -> str:
torch_ver = re.search(r"([\.\d]+)", torch_version).groups()[0]
if torch_ver in _VERSION_EXCEPTIONS:
return _VERSION_EXCEPTIONS[torch_ver]
ver_major, ver_minor, ver_bugfix= map(int, torch_ver.split("."))
ver_major, ver_minor, ver_bugfix = map(int, torch_ver.split("."))
tv_ver_array = [0, 0, 0]
if ver_major == 1:
tv_ver_array[1] = ver_minor + 1
Expand All @@ -113,6 +116,7 @@ def find_latest(ver: str) -> Dict[str, str]:
'torchaudio': '2.1.0',
'torchtext': '0.16.0',
'torchvision': '0.16.0'}
"""
# drop all except semantic version
ver = re.search(r"([\.\d]+)", ver).groups()[0]
Expand All @@ -138,6 +142,7 @@ def adjust(requires: List[str], pytorch_version: Optional[str] = None) -> List[s
'torchvision==0.16.0',
'torchtext==0.16.0',
'torchaudio==2.1.0']
"""
if not pytorch_version:
import torch
Expand Down Expand Up @@ -172,6 +177,7 @@ def _offset_print(reqs: List[str], offset: str = "\t|\t") -> str:
>>> _offset_print(["torch==2.1.0", "torchvision==0.16.0", "torchtext==0.16.0", "torchaudio==2.1.0"])
'\t|\ttorch==2.1.0\n\t|\ttorchvision==0.16.0\n\t|\ttorchtext==0.16.0\n\t|\ttorchaudio==2.1.0'
"""
reqs = [offset + r for r in reqs]
return os.linesep.join(reqs)
Expand Down

0 comments on commit 26afda9

Please sign in to comment.