Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix use of jsonargparse avoiding reliance on non-public internal logic #1620

Merged
merged 1 commit into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ torchmetrics >0.7.0, <0.11.0 # strict
pytorch-lightning >1.8.0, <2.0.0 # strict
pyDeprecate >0.2.0
pandas >1.1.0, <=1.5.2
jsonargparse[signatures] >4.0.0, <=4.9.0
jsonargparse[signatures] >=4.22.0, <4.23.0
click >=7.1.2, <=8.1.3
protobuf <=3.20.1
fsspec[http] >=2022.5.0,<=2023.6.0
Expand Down
18 changes: 13 additions & 5 deletions src/flash/core/utilities/flash_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union

import pytorch_lightning as pl
from jsonargparse import ArgumentParser
from jsonargparse.signatures import get_class_signature_functions
from jsonargparse import ArgumentParser, class_from_function
from lightning_utilities.core.overrides import is_overridden
from pytorch_lightning import LightningModule, Trainer

Expand All @@ -31,7 +30,6 @@
LightningArgumentParser,
LightningCLI,
SaveConfigCallback,
class_from_function,
)
from flash.core.utilities.stability import beta

Expand Down Expand Up @@ -107,6 +105,16 @@
return wrapper


def get_class_signature_functions(classes):
Borda marked this conversation as resolved.
Show resolved Hide resolved
signatures = []
for num, cls in enumerate(classes):
if cls.__new__ is not object.__new__ and not any(cls.__new__ is c.__new__ for c in classes[num + 1 :]):
signatures.append((cls, cls.__new__))
if not any(cls.__init__ is c.__init__ for c in classes[num + 1 :]):
signatures.append((cls, cls.__init__))
return signatures

Check warning on line 115 in src/flash/core/utilities/flash_cli.py

View check run for this annotation

Codecov / codecov/patch

src/flash/core/utilities/flash_cli.py#L109-L115

Added lines #L109 - L115 were not covered by tests


def get_overlapping_args(func_a, func_b) -> Set[str]:
func_a = get_class_signature_functions([func_a])[0][1]
func_b = get_class_signature_functions([func_b])[0][1]
Expand Down Expand Up @@ -214,7 +222,7 @@
def add_subcommand_from_function(self, subcommands, function, function_name=None):
subcommand = ArgumentParser()
if get_kwarg_name(function) == "data_module_kwargs":
datamodule_function = class_from_function(function, return_type=self.local_datamodule_class)
datamodule_function = class_from_function(function, self.local_datamodule_class)
subcommand.add_class_arguments(
datamodule_function,
fail_untyped=False,
Expand All @@ -233,7 +241,7 @@
},
)
else:
datamodule_function = class_from_function(drop_kwargs(function), return_type=self.local_datamodule_class)
datamodule_function = class_from_function(drop_kwargs(function), self.local_datamodule_class)
subcommand.add_class_arguments(datamodule_function, fail_untyped=False)
subcommand_name = function_name or function.__name__
subcommands.add_subcommand(subcommand_name, subcommand)
Expand Down
41 changes: 4 additions & 37 deletions src/flash/core/utilities/lightning_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@
import os
import warnings
from argparse import Namespace
from functools import wraps
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast

import torch
from jsonargparse import ActionConfigFile, ArgumentParser, set_config_read_mode
from jsonargparse.signatures import ClassFromFunctionBase
from jsonargparse.typehints import ClassType
from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, set_config_read_mode
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.cloud_io import get_filesystem
Expand All @@ -25,46 +22,16 @@
LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]


def class_from_function(
func: Callable[..., ClassType],
return_type: Optional[Type[ClassType]] = None,
) -> Type[ClassType]:
"""Creates a dynamic class which if instantiated is equivalent to calling func.

Args:
func: A function that returns an instance of a class. It must have a return type annotation.
"""

@wraps(func)
def __new__(cls, *args, **kwargs):
return func(*args, **kwargs)

if return_type is None:
return_type = inspect.signature(func).return_annotation

if isinstance(return_type, str):
raise RuntimeError("Classmethod instantiation is not supported when the return type annotation is a string.")

class ClassFromFunction(return_type, ClassFromFunctionBase): # type: ignore
pass

ClassFromFunction.__new__ = __new__ # type: ignore
ClassFromFunction.__doc__ = func.__doc__
ClassFromFunction.__name__ = func.__name__

return ClassFromFunction


class LightningArgumentParser(ArgumentParser):
"""Extension of jsonargparse's ArgumentParser for pytorch-lightning."""

def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
Borda marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize argument parser that supports configuration file input.

For full details of accepted arguments see
`ArgumentParser.__init__ <https://jsonargparse.readthedocs.io/en/stable/#jsonargparse.core.ArgumentParser.__init__>`_.
"""
super().__init__(*args, parse_as_dict=parse_as_dict, **kwargs)
super().__init__(*args, **kwargs)
self.add_argument(
"--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format."
)
Expand Down Expand Up @@ -95,7 +62,7 @@ def add_lightning_class_args(

if inspect.isclass(lightning_class) and issubclass(
cast(type, lightning_class),
(Trainer, LightningModule, LightningDataModule, Callback, ClassFromFunctionBase),
(Trainer, LightningModule, LightningDataModule, Callback),
):
if issubclass(cast(type, lightning_class), Callback):
self.callback_keys.append(nested_key)
Expand Down
28 changes: 14 additions & 14 deletions tests/core/utilities/test_lightning_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_default_args(mock_argparse, tmpdir):
"""Tests default argument parser for Trainer."""
mock_argparse.return_value = Namespace(**Trainer.default_attributes())

parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
parser = LightningArgumentParser(add_help=False)
args = parser.parse_args([])

args.max_epochs = 5
Expand All @@ -54,7 +54,7 @@ def test_default_args(mock_argparse, tmpdir):
@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], ["--default_root_dir=./"], []])
def test_add_argparse_args_redefined(cli_args):
"""Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness."""
parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
parser = LightningArgumentParser(add_help=False)
parser.add_lightning_class_args(Trainer, None)

args = parser.parse_args(cli_args)
Expand All @@ -79,19 +79,19 @@ def test_add_argparse_args_redefined(cli_args):
("--auto_lr_find=True --auto_scale_batch_size=power", {"auto_lr_find": True, "auto_scale_batch_size": "power"}),
(
"--auto_lr_find any_string --auto_scale_batch_size ON",
{"auto_lr_find": "any_string", "auto_scale_batch_size": True},
Borda marked this conversation as resolved.
Show resolved Hide resolved
{"auto_lr_find": "any_string", "auto_scale_batch_size": "ON"},
),
("--auto_lr_find=Yes --auto_scale_batch_size=On", {"auto_lr_find": True, "auto_scale_batch_size": True}),
("--auto_lr_find Off --auto_scale_batch_size No", {"auto_lr_find": False, "auto_scale_batch_size": False}),
("--auto_lr_find TRUE --auto_scale_batch_size FALSE", {"auto_lr_find": True, "auto_scale_batch_size": False}),
("--auto_lr_find=Yes --auto_scale_batch_size=On", {"auto_lr_find": True, "auto_scale_batch_size": "On"}),
("--auto_lr_find Off --auto_scale_batch_size No", {"auto_lr_find": False, "auto_scale_batch_size": "No"}),
("--auto_lr_find TRUE --auto_scale_batch_size FALSE", {"auto_lr_find": True, "auto_scale_batch_size": "FALSE"}),
("--limit_train_batches=100", {"limit_train_batches": 100}),
("--limit_train_batches 0.8", {"limit_train_batches": 0.8}),
],
)
def test_parse_args_parsing(cli_args, expected):
"""Test parsing simple types and None optionals not modified."""
cli_args = cli_args.split(" ") if cli_args else []
parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
parser = LightningArgumentParser(add_help=False)
parser.add_lightning_class_args(Trainer, None)
with patch("sys.argv", ["any.py"] + cli_args):
args = parser.parse_args()
Expand All @@ -112,7 +112,7 @@ def test_parse_args_parsing(cli_args, expected):
)
def test_parse_args_parsing_complex_types(cli_args, expected, instantiate):
"""Test parsing complex types."""
parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
parser = LightningArgumentParser(add_help=False)
parser.add_lightning_class_args(Trainer, None)
with patch("sys.argv", ["any.py"] + cli_args):
args = parser.parse_args()
Expand All @@ -137,7 +137,7 @@ def test_parse_args_parsing_gpus(mocker, cli_args, expected_gpu):
"""Test parsing of gpus and instantiation of Trainer."""
mocker.patch("lightning_lite.utilities.device_parser._get_all_available_gpus", return_value=[0, 1])
cli_args = cli_args.split(" ") if cli_args else []
parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
parser = LightningArgumentParser(add_help=False)
parser.add_lightning_class_args(Trainer, None)
with patch("sys.argv", ["any.py"] + cli_args):
args = parser.parse_args()
Expand Down Expand Up @@ -310,8 +310,8 @@ def test_lightning_cli_args(tmpdir):
config = yaml.safe_load(f.read())
assert "model" not in config
assert "model" not in cli.config
assert config["data"] == cli.config["data"]
assert config["trainer"] == cli.config["trainer"]
assert config["data"] == cli.config["data"].as_dict()
assert config["trainer"] == cli.config["trainer"].as_dict()


@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.")
Expand Down Expand Up @@ -363,9 +363,9 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir):
assert os.path.isfile(config_path)
with open(config_path) as f:
config = yaml.safe_load(f.read())
assert config["model"] == cli.config["model"]
assert config["data"] == cli.config["data"]
assert config["trainer"] == cli.config["trainer"]
assert config["model"] == cli.config["model"].as_dict()
assert config["data"] == cli.config["data"].as_dict()
assert config["trainer"] == cli.config["trainer"].as_dict()


def any_model_any_data_cli():
Expand Down