Skip to content

Commit

Permalink
script/adjust: preserve all comments in requirements (#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Oct 10, 2023
1 parent 52181b5 commit 68ebf4b
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 35 deletions.
6 changes: 3 additions & 3 deletions .github/actions/unittesting/action.yml
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
name: Unittesting and coverage
name: Unittest and coverage
description: pull data samples -> unittests

inputs:
python-version:
description: Python version
required: true
pkg-name:
description: package name for coverage collctions
description: package name for coverage collections
required: true
requires:
description: define oldest or latest
required: false
default: ""
dirs:
description: Testing folders per domains
description: Testing folders per domains, space separated string
required: false
default: "."
pytest-args:
Expand Down
7 changes: 6 additions & 1 deletion .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,12 @@ jobs:
- name: Print 🖨️ dependencies
uses: ./.github/actions/pip-list

- name: Unittesting and coverage
- name: Unittest and coverage
uses: ./.github/actions/unittesting
with:
python-version: ${{ matrix.python-version }}
pytorch-version: ${{ matrix.pytorch-version }}
dirs: "unittests"
pkg-name: "lightning_utilities"
pytest-args: "--timeout=120"

Expand All @@ -72,6 +73,10 @@ jobs:
name: codecov-umbrella
fail_ci_if_error: false

- name: test Scripts
working-directory: ./tests
run: python -m pytest scripts --durations=50 --timeout=120

testing-guardian:
runs-on: ubuntu-latest
needs: pytester
Expand Down
75 changes: 44 additions & 31 deletions scripts/adjust-torch-versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import re
import sys
from typing import Dict, Optional
from typing import Dict, List, Optional

VERSIONS = [
{"torch": "2.2.0", "torchvision": "0.17.0", "torchtext": "0.17.0", "torchaudio": "2.2.0"}, # nightly
Expand Down Expand Up @@ -46,48 +46,61 @@ def find_latest(ver: str) -> Dict[str, str]:
raise ValueError(f"Missing {ver} in {VERSIONS}")


def adjust(requires: str, torch_version: Optional[str] = None) -> str:
def adjust(requires: List[str], pytorch_version: Optional[str] = None) -> List[str]:
"""Adjust the versions to be paired within pytorch ecosystem."""
if not torch_version:
if not pytorch_version:
import torch

torch_version = torch.__version__
if not torch_version:
raise ValueError(f"invalid torch: {torch_version}")
pytorch_version = torch.__version__
if not pytorch_version:
raise ValueError(f"invalid torch: {pytorch_version}")

# remove comments and strip whitespace
requires = re.sub(rf"\s*#.*{os.linesep}", os.linesep, requires).strip()

options = find_latest(torch_version)
requires_ = []
options = find_latest(pytorch_version)
logging.debug(f"determined ecosystem alignment: {options}")
for lib, version in options.items():
replace = f"{lib}=={version}" if version else ""
requires = re.sub(rf"\b{lib}(?![-_\w]).*", replace, requires)

return requires


def _offset_print(req: str, offset: str = "\t|\t") -> str:
for req in requires:
req_split = req.strip().split("#", maxsplit=1)
# anything before fst # shall be requirements
req = req_split[0].strip()
# anything after # in the line is comment
comment = "" if len(req_split) < 2 else " #" + req_split[1]
if not req:
# if only comment make it short
requires_.append(comment.strip())
continue
for lib, version in options.items():
replace = f"{lib}=={version}" if version else ""
req = re.sub(rf"\b{lib}(?![-_\w]).*", replace, req)
requires_.append(req + comment.rstrip())

return requires_


def _offset_print(reqs: List[str], offset: str = "\t|\t") -> str:
"""Adding offset to each line for the printing requirements."""
reqs = req.split(os.linesep)
reqs = [offset + r for r in reqs]
return os.linesep.join(reqs)


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)

if len(sys.argv) == 3:
requirements_path, torch_version = sys.argv[1:]
else:
requirements_path, torch_version = sys.argv[1], None

with open(requirements_path) as fp:
requirements = fp.read()
def main(requirements_path: str, torch_version: Optional[str] = None) -> None:
"""The main entry point with mapping to the CLI for positional arguments only."""
# rU - universal line ending - https://stackoverflow.com/a/2717154/4521646
with open(requirements_path, encoding="utf8") as fopen:
requirements = fopen.readlines()
requirements = adjust(requirements, torch_version)
logging.info(
f"requirements_path='{requirements_path}' with arg torch_version='{torch_version}' >>\n"
f"{_offset_print(requirements)}"
)
with open(requirements_path, "w") as fp:
fp.write(requirements)
with open(requirements_path, "w", encoding="utf8") as fopen:
fopen.writelines([r + os.linesep for r in requirements])


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
try:
from fire import Fire

Fire(main)
except (ModuleNotFoundError, ImportError):
main(*sys.argv[1:])
5 changes: 5 additions & 0 deletions tests/scripts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import os

_PATH_HERE = os.path.dirname(__file__)
_PATH_ROOT = os.path.dirname(os.path.dirname(_PATH_HERE))
_PATH_SCRIPTS = os.path.join(_PATH_ROOT, "scripts")
47 changes: 47 additions & 0 deletions tests/scripts/test_adjust_torch_versions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import os
import platform
import subprocess
import sys

from scripts import _PATH_SCRIPTS

REQUIREMENTS_SAMPLE = """
# This is sample requirements file
# with multi line comments
torchvision >=0.13.0, <0.16.0 # sample # comment
gym[classic,control] >=0.17.0, <0.27.0
ipython[all] <8.15.0 # strict
torchmetrics >=0.10.0, <1.3.0
deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" # strict
"""
REQUIREMENTS_EXPECTED = """
# This is sample requirements file
# with multi line comments
torchvision==0.11.1 # sample # comment
gym[classic,control] >=0.17.0, <0.27.0
ipython[all] <8.15.0 # strict
torchmetrics >=0.10.0, <1.3.0
deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" # strict
"""


def test_adjust_torch_versions_call(tmp_path) -> None:
path_script = os.path.join(_PATH_SCRIPTS, "adjust-torch-versions.py")
path_req_file = str(tmp_path / "requirements.txt")
with open(path_req_file, "w", encoding="utf8") as fopen:
fopen.write(REQUIREMENTS_SAMPLE)

return_code = subprocess.call([sys.executable, path_script, path_req_file, "1.10.0"]) # noqa: S603
assert return_code == 0

with open(path_req_file, encoding="utf8") as fopen:
req_result = fopen.read()
# ToDO: no idea why parsing lines on windows leave extra line after each line
# tried strip, regex, hard-coded replace but none worked... so adjusting tests
if platform.system() == "Windows":
req_result = req_result.replace("\n\n", "\n")
assert req_result == REQUIREMENTS_EXPECTED

0 comments on commit 68ebf4b

Please sign in to comment.