chore: suppress unresolved forwardrefs type hint warnings
XuehaiPan committed Oct 29, 2022
1 parent ce349ca commit c8d71f5
Showing 10 changed files with 47 additions and 33 deletions.
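
Most of the changes below follow a single pattern: type aliases from `torchopt.typing` (`Params`, `OptState`, `TensorTree`, ...) were previously referenced only inside string-quoted annotations, so each import carried a `# pylint: disable=unused-import` comment; the commit drops the quotes, making the imports genuinely used and the suppression comments unnecessary. The `docs/source/conf.py` change handles the two aliases that still cannot be resolved by `sphinx_autodoc_typehints` (`TensorTree`, `OptionalTensorTree`) by filtering the corresponding warnings. A minimal before/after sketch of the annotation pattern (hypothetical function names, assuming `torchopt` is installed; this is not code from the repository):

```python
from torchopt.typing import Params  # before: referenced only inside string annotations,
                                    # so the import was paired with a pylint suppression

def scale_before(params: 'Params') -> 'Params':  # quoted forward reference
    return params

def scale_after(params: Params) -> Params:  # unquoted: the import is now a real reference
    return params
```
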
22 changes: 21 additions & 1 deletion docs/source/conf.py
@@ -25,6 +25,7 @@
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
+import logging
import os
import pathlib
import sys
@@ -43,6 +44,24 @@ def get_version() -> str:
    return version.__version__


+try:
+    import sphinx_autodoc_typehints
+except ImportError:
+    pass
+else:
+
+    class RecursiveForwardRefFilter(logging.Filter):
+        def filter(self, record):
+            if (
+                "name 'TensorTree' is not defined" in record.getMessage()
+                or "name 'OptionalTensorTree' is not defined" in record.getMessage()
+            ):
+                return False
+            return super().filter(record)
+
+    sphinx_autodoc_typehints._LOGGER.logger.addFilter(RecursiveForwardRefFilter())


# -- Project information -----------------------------------------------------

project = 'TorchOpt'
@@ -75,7 +94,7 @@ def get_version() -> str:
'sphinxcontrib.bibtex',
'sphinxcontrib.katex',
'sphinx_autodoc_typehints',
-'myst_nb', # This is used for the .ipynb notebooks
+'myst_nb', # this is used for the .ipynb notebooks
]

if not os.getenv('READTHEDOCS', None):
@@ -120,6 +139,7 @@ def get_version() -> str:
'exclude-members': '__module__, __dict__, __repr__, __str__, __weakref__',
}
autoclass_content = 'both'
+simplify_optional_unions = False

# -- Options for bibtex -----------------------------------------------------

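
For context on the `conf.py` change above: a `logging.Filter` whose `filter()` method returns `False` drops the record before it is emitted, which is how the two "is not defined" warnings are silenced without affecting anything else. Below is a standalone, stdlib-only sketch of the same mechanism; the logger name `demo` and the sample messages are made up for illustration, whereas the real change attaches the filter to `sphinx_autodoc_typehints._LOGGER.logger`.

```python
import logging


class RecursiveForwardRefFilter(logging.Filter):
    """Drop log records about the two unresolved forward references."""

    def filter(self, record: logging.LogRecord) -> bool:
        message = record.getMessage()
        if (
            "name 'TensorTree' is not defined" in message
            or "name 'OptionalTensorTree' is not defined" in message
        ):
            return False  # returning False suppresses the record
        return super().filter(record)


logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger('demo')  # made-up logger name for this demo
logger.addFilter(RecursiveForwardRefFilter())

logger.warning("Cannot resolve forward reference: name 'TensorTree' is not defined")  # silenced
logger.warning('unrelated warning')  # still emitted
```
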
5 changes: 2 additions & 3 deletions torchopt/alias/adamw.py
@@ -36,8 +36,7 @@
from torchopt.alias.utils import flip_sign_and_add_weight_decay, scale_by_neg_lr
from torchopt.combine import chain_flat
from torchopt.transform import add_decayed_weights, scale_by_accelerated_adam, scale_by_adam
-from torchopt.typing import Params # pylint: disable=unused-import
-from torchopt.typing import GradientTransformation, ScalarOrSchedule
+from torchopt.typing import GradientTransformation, Params, ScalarOrSchedule


__all__ = ['adamw']
@@ -51,7 +50,7 @@ def adamw(
weight_decay: float = 1e-2,
*,
eps_root: float = 0.0,
-mask: Optional[Union[Any, Callable[['Params'], Any]]] = None,
+mask: Optional[Union[Any, Callable[[Params], Any]]] = None,
moment_requires_grad: bool = False,
maximize: bool = False,
use_accelerated_op: bool = False,
4 changes: 2 additions & 2 deletions torchopt/diff/implicit/nn/module.py
@@ -24,7 +24,7 @@
import torchopt.nn
from torchopt import pytree
from torchopt.diff.implicit.decorator import custom_root
-from torchopt.typing import TensorTree, TupleOfTensors # pylint: disable=unused-import
+from torchopt.typing import TensorTree, TupleOfTensors
from torchopt.utils import extract_module_containers


@@ -228,7 +228,7 @@ def solve(self, batch, labels):
raise NotImplementedError # update parameters

# pylint: disable-next=redefined-builtin
-def residual(self, *input, **kwargs) -> 'TensorTree':
+def residual(self, *input, **kwargs) -> TensorTree:
r"""Computes the optimality residual.
This method stands for the residual to the optimal parameters after solving the inner
4 changes: 2 additions & 2 deletions torchopt/optim/adamw.py
@@ -20,7 +20,7 @@

from torchopt import alias
from torchopt.optim.base import Optimizer
-from torchopt.typing import Params, ScalarOrSchedule # pylint: disable=unused-import
+from torchopt.typing import Params, ScalarOrSchedule


__all__ = ['AdamW']
@@ -44,7 +44,7 @@ def __init__(
weight_decay: float = 1e-2,
*,
eps_root: float = 0.0,
-mask: Optional[Union[Any, Callable[['Params'], Any]]] = None,
+mask: Optional[Union[Any, Callable[[Params], Any]]] = None,
maximize: bool = False,
use_accelerated_op: bool = False,
) -> None:
13 changes: 4 additions & 9 deletions torchopt/optim/base.py
@@ -19,12 +19,7 @@
import torch

from torchopt import pytree
-from torchopt.typing import ( # pylint: disable=unused-import
-    GradientTransformation,
-    OptState,
-    Params,
-    TupleOfTensors,
-)
+from torchopt.typing import GradientTransformation, OptState, Params, TupleOfTensors
from torchopt.update import apply_updates


@@ -84,11 +79,11 @@ def f(p):

pytree.tree_map(f, self.param_groups) # type: ignore[arg-type]

-def state_dict(self) -> Tuple['OptState', ...]:
+def state_dict(self) -> Tuple[OptState, ...]:
"""Returns the state of the optimizer."""
return tuple(self.state_groups)

-def load_state_dict(self, state_dict: Sequence['OptState']) -> None:
+def load_state_dict(self, state_dict: Sequence[OptState]) -> None:
"""Loads the optimizer state.
Args:
@@ -121,7 +116,7 @@ def f(p):

return loss

-def add_param_group(self, params: 'Params') -> None:
+def add_param_group(self, params: Params) -> None:
"""Add a param group to the optimizer's :attr:`param_groups`."""
flat_params, params_treespec = pytree.tree_flatten(params)
flat_params: TupleOfTensors = tuple(flat_params)
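
The `state_dict()` / `load_state_dict()` signatures above now use the unquoted `OptState` alias. For reference, here is a hedged usage sketch of that API; it assumes `torch` and `torchopt` are installed, that TorchOpt's classic `torch.optim`-style wrapper (e.g. `torchopt.SGD`) behaves like its PyTorch counterpart, and the toy `net` is made up.

```python
import torch
import torch.nn as nn

import torchopt

# Made-up toy module; any nn.Module works here.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
optimizer = torchopt.SGD(net.parameters(), lr=0.1)

loss = net(torch.randn(2, 4)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

saved = optimizer.state_dict()    # Tuple[OptState, ...], one entry per param group
optimizer.load_state_dict(saved)  # restore those states later
```
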
6 changes: 3 additions & 3 deletions torchopt/optim/func/base.py
@@ -19,7 +19,7 @@
import torch

from torchopt.base import GradientTransformation
-from torchopt.typing import OptState, Params # pylint: disable=unused-import
+from torchopt.typing import OptState, Params
from torchopt.update import apply_updates


@@ -61,9 +61,9 @@ def __init__(self, impl: GradientTransformation, *, inplace: bool = False) -> None:
def step(
self,
loss: torch.Tensor,
-params: 'Params',
+params: Params,
inplace: Optional[bool] = None,
-) -> 'Params':
+) -> Params:
r"""Compute the gradients of loss to the network parameters and update network parameters.
Graph of the derivative will be constructed, allowing to compute higher order derivative
4 changes: 2 additions & 2 deletions torchopt/optim/meta/adamw.py
@@ -20,7 +20,7 @@

from torchopt import alias
from torchopt.optim.meta.base import MetaOptimizer
-from torchopt.typing import Params, ScalarOrSchedule # pylint: disable=unused-import
+from torchopt.typing import Params, ScalarOrSchedule


__all__ = ['MetaAdamW']
@@ -44,7 +44,7 @@ def __init__(
weight_decay: float = 1e-2,
*,
eps_root: float = 0.0,
-mask: Optional[Union[Any, Callable[['Params'], Any]]] = None,
+mask: Optional[Union[Any, Callable[[Params], Any]]] = None,
moment_requires_grad: bool = False,
maximize: bool = False,
use_accelerated_op: bool = False,
4 changes: 2 additions & 2 deletions torchopt/optim/meta/base.py
@@ -99,14 +99,14 @@ def add_param_group(self, net: nn.Module) -> None:
self.param_containers_groups.append(params_container)
self.state_groups.append(optimizer_state)

-def state_dict(self) -> Tuple['OptState', ...]:
+def state_dict(self) -> Tuple[OptState, ...]:
"""Extract the references of the optimizer states.
Note that the states are references, so any in-place operations will change the states
inside :class:`MetaOptimizer` at the same time.
"""
return tuple(self.state_groups)

-def load_state_dict(self, state_dict: Sequence['OptState']) -> None:
+def load_state_dict(self, state_dict: Sequence[OptState]) -> None:
"""Load the references of the optimizer states."""
self.state_groups[:] = list(state_dict)
4 changes: 2 additions & 2 deletions torchopt/update.py
@@ -32,13 +32,13 @@
"""Helper functions for applying updates."""

from torchopt import pytree
-from torchopt.typing import Params, Updates # pylint: disable=unused-import
+from torchopt.typing import Params, Updates


__all__ = ['apply_updates']


-def apply_updates(params: 'Params', updates: 'Updates', *, inplace: bool = True) -> 'Params':
+def apply_updates(params: Params, updates: Updates, *, inplace: bool = True) -> Params:
"""Applies an update to the corresponding parameters.
This is a utility functions that applies an update to a set of parameters, and then returns the
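
`apply_updates` above is the glue between the functional optimizers and plain tensor pytrees. A minimal sketch of how it is typically driven (assumes `torchopt` is installed; the single-tensor `params` dict is a made-up example and all keyword defaults are left untouched):

```python
import torch

import torchopt

# Made-up parameter pytree with a single tensor.
params = {'w': torch.tensor([1.0, 2.0, 3.0], requires_grad=True)}

optim = torchopt.sgd(lr=0.1)  # functional optimizer (a GradientTransformation)
opt_state = optim.init(params)

loss = (params['w'] ** 2).sum()
grads = {'w': torch.autograd.grad(loss, params['w'])[0]}  # grads mirror the params pytree

updates, opt_state = optim.update(grads, opt_state)
new_params = torchopt.apply_updates(params, updates, inplace=False)  # keep `params` unchanged
```
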
14 changes: 7 additions & 7 deletions torchopt/utils.py
@@ -35,7 +35,7 @@
import torch.nn as nn

from torchopt import pytree
-from torchopt.typing import OptState, TensorTree # pylint: disable=unused-import
+from torchopt.typing import OptState, TensorTree


if TYPE_CHECKING:
@@ -64,7 +64,7 @@ class ModuleState(NamedTuple):
CopyMode: TypeAlias = Literal['reference', 'copy', 'deepcopy', 'ref', 'clone', 'deepclone']


-def stop_gradient(target: Union['TensorTree', ModuleState, nn.Module, 'MetaOptimizer']) -> None:
+def stop_gradient(target: Union[TensorTree, ModuleState, nn.Module, 'MetaOptimizer']) -> None:
"""Stop the gradient for the input object.
Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the
@@ -123,7 +123,7 @@ def extract_state_dict(
with_buffers: bool = True,
enable_visual: bool = False,
visual_prefix: str = '',
-) -> Tuple['OptState', ...]:
+) -> Tuple[OptState, ...]:
...


@@ -137,7 +137,7 @@ def extract_state_dict(
detach_buffers: bool = False,
enable_visual: bool = False,
visual_prefix: str = '',
-) -> Union[ModuleState, Tuple['OptState', ...]]:
+) -> Union[ModuleState, Tuple[OptState, ...]]:
"""Extract target state.
Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the
@@ -312,7 +312,7 @@ def update_container(container, items):

def recover_state_dict(
target: Union[nn.Module, 'MetaOptimizer'],
-state: Union[ModuleState, Sequence['OptState']],
+state: Union[ModuleState, Sequence[OptState]],
) -> None:
"""Recover state.
@@ -478,8 +478,8 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor:


def module_detach_(
-target: Union['TensorTree', ModuleState, nn.Module, 'MetaOptimizer']
-) -> Union['TensorTree', ModuleState, nn.Module, 'MetaOptimizer']:
+target: Union[TensorTree, ModuleState, nn.Module, 'MetaOptimizer']
+) -> Union[TensorTree, ModuleState, nn.Module, 'MetaOptimizer']:
"""Detach a module from the computation graph.
Args:
