Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add throughput utilities to Fabric and the Trainer #18848

Merged
merged 33 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
106409a
Code
carmocca Oct 24, 2023
471f281
Tests
carmocca Oct 24, 2023
2a7667f
Protected APIs
carmocca Oct 24, 2023
fe52e82
Docstrings
carmocca Oct 24, 2023
b0798ea
Callback
carmocca Oct 24, 2023
9e519b5
Throughput
carmocca Oct 24, 2023
c46e366
mypy
carmocca Oct 24, 2023
a0cb727
Merge branch 'master' into carmocca/speed-monitor
carmocca Oct 24, 2023
cfdc1bf
CHANGELOG
carmocca Oct 24, 2023
374c975
Docs
carmocca Oct 24, 2023
8a08678
Fixes
carmocca Oct 24, 2023
d4ee1ca
Fixes
carmocca Oct 24, 2023
4841a38
Fixes
carmocca Oct 24, 2023
202d764
Update conf.py
carmocca Oct 24, 2023
a79a0bf
Merge branch 'master' into carmocca/speed-monitor
carmocca Oct 25, 2023
8ae5ce8
Review suggestions
carmocca Oct 25, 2023
27a1499
Pull loggers out of utility. No subclass
carmocca Oct 25, 2023
bd96595
Fabric refactor
carmocca Oct 25, 2023
f8b130d
Trainer refactor
carmocca Oct 25, 2023
82d8b25
Tests. Custom window class
carmocca Oct 26, 2023
c4817a6
file name is simply throughput in fabric
carmocca Oct 26, 2023
62f076d
Details
carmocca Oct 26, 2023
ffa823b
Update and compute a la torchmetrics
carmocca Oct 26, 2023
14d6210
Details
carmocca Oct 26, 2023
6fe5d1e
Support all trainer stages
carmocca Oct 26, 2023
29cc382
mypy
carmocca Oct 26, 2023
5c04bd6
utilities update
carmocca Oct 26, 2023
5fe8a65
Fix docs
carmocca Oct 26, 2023
94c62fd
Examples
carmocca Oct 26, 2023
8e931a6
Discussion comments
carmocca Oct 26, 2023
476f6b0
Merge branch 'master' into carmocca/speed-monitor
carmocca Oct 27, 2023
e2ee8e3
Discussion comments
carmocca Oct 26, 2023
c161ac4
Do not return self
carmocca Oct 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source-fabric/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,9 @@ lightning.fabric.utilities
.. autofunction:: lightning.fabric.utilities.distributed.is_shared_filesystem

.. autofunction:: lightning.fabric.utilities.warnings.disable_possible_user_warnings

.. autofunction:: lightning.fabric.utilities.throughput.measure_flops

.. autoclass:: lightning.fabric.utilities.throughput.ThroughputMonitor

.. autoclass:: lightning.fabric.utilities.throughput.Throughput
1 change: 1 addition & 0 deletions docs/source-fabric/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@
("py:class", "lightning.fabric.wrappers._FabricOptimizer"),
("py:class", "lightning.fabric.loggers.csv_logs._ExperimentWriter"),
("py:class", "lightning.fabric.strategies.strategy._Sharded"),
("py:class", "lightning.fabric.utilities.throughput.Throughput"),
# Nitpick does not see abstract API
("py:meth", "lightning.fabric.plugins.collectives.Collective.init_group"),
# These seem to be missing in reference generated API
Expand Down
4 changes: 4 additions & 0 deletions docs/source-pytorch/api_references.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ callbacks
RichModelSummary
RichProgressBar
StochasticWeightAveraging
SpikeDetection
ThroughputMonitor
Timer
TQDMProgressBar

Expand Down Expand Up @@ -248,3 +250,5 @@ utilities
rank_zero
seed
warnings

.. autofunction:: lightning.pytorch.utilities.measure_flops
7 changes: 5 additions & 2 deletions docs/source-pytorch/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,9 @@ def _load_py_module(name: str, location: str) -> ModuleType:
("py:class", "lightning.fabric.utilities.types.ReduceLROnPlateau"),
("py:class", "lightning.fabric.utilities.types.Steppable"),
("py:class", "lightning.fabric.wrappers._FabricOptimizer"),
("py:class", "lightning.fabric.utilities.throughput.Throughput"),
("py:func", "lightning.fabric.utilities.throughput.measure_flops"),
("py:class", "lightning.fabric.utilities.spike.SpikeDetection"),
("py:meth", "lightning.pytorch.Callback.on_exception"),
("py:class", "lightning.pytorch.LightningModule"),
("py:meth", "lightning.pytorch.LightningModule.on_train_epoch_end"),
Expand Down Expand Up @@ -450,7 +453,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
("py:meth", "optimizer_step"),
("py:class", "out_dict"),
("py:meth", "prepare_data"),
("py:class", "pytorch_lightning.callbacks.device_stats_monitor.DeviceStatsMonitor"),
("py:class", "lightning.pytorch.callbacks.device_stats_monitor.DeviceStatsMonitor"),
("py:meth", "setup"),
("py:meth", "test_step"),
("py:meth", "toggle_optimizer"),
Expand Down Expand Up @@ -585,7 +588,7 @@ def package_list_from_file(file):
from lightning.pytorch import LightningDataModule, LightningModule, Trainer, seed_everything
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.cli import _JSONARGPARSE_SIGNATURES_AVAILABLE as _JSONARGPARSE_AVAILABLE
from lightning.pytorch.utilities import _TORCHVISION_AVAILABLE
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE
from lightning.pytorch.loggers.comet import _COMET_AVAILABLE
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- Added `lightning.fabric.utilities.ThroughputMonitor` and `lightning.fabric.utilities.Throughput` to track throughput and log it ([#18848](https://github.com/Lightning-AI/lightning/pull/18848))


### Changed
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/fabric/accelerators/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
from typing import Generator, List, Optional, Union, cast

import torch
from lightning_utilities.core.rank_zero import rank_zero_info

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.rank_zero import rank_zero_info


class CUDAAccelerator(Accelerator):
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/fabric/fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import torch.nn as nn
from lightning_utilities.core.apply_func import apply_to_collection
from lightning_utilities.core.overrides import is_overridden
from lightning_utilities.core.rank_zero import rank_zero_deprecation, rank_zero_warn
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler
Expand Down Expand Up @@ -67,6 +66,7 @@
)
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from lightning.fabric.utilities.registry import _load_external_callbacks
from lightning.fabric.utilities.seed import seed_everything
from lightning.fabric.utilities.types import ReduceOp
Expand Down
28 changes: 22 additions & 6 deletions src/lightning/fabric/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,30 @@
# limitations under the License.
"""General utilities."""

from lightning.fabric.utilities.apply_func import move_data_to_device # noqa: F401
from lightning.fabric.utilities.data import suggested_max_num_workers # noqa: F401
from lightning.fabric.utilities.distributed import is_shared_filesystem # noqa: F401
from lightning.fabric.utilities.enums import LightningEnum # noqa: F401
from lightning.fabric.utilities.rank_zero import ( # noqa: F401
from lightning.fabric.utilities.apply_func import move_data_to_device
from lightning.fabric.utilities.data import suggested_max_num_workers
from lightning.fabric.utilities.distributed import is_shared_filesystem
from lightning.fabric.utilities.enums import LightningEnum
from lightning.fabric.utilities.rank_zero import (
rank_zero_deprecation,
rank_zero_info,
rank_zero_only,
rank_zero_warn,
)
from lightning.fabric.utilities.warnings import disable_possible_user_warnings # noqa: F401
from lightning.fabric.utilities.throughput import Throughput, ThroughputMonitor, measure_flops
from lightning.fabric.utilities.warnings import disable_possible_user_warnings

__all__ = [
"disable_possible_user_warnings",
"is_shared_filesystem",
"LightningEnum",
"measure_flops",
"move_data_to_device",
"rank_zero_deprecation",
"rank_zero_info",
"rank_zero_only",
"rank_zero_warn",
"suggested_max_num_workers",
"Throughput",
"ThroughputMonitor",
]
2 changes: 2 additions & 0 deletions src/lightning/fabric/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@

_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)

_UTILITIES_GREATER_EQUAL_0_10 = compare_version("lightning_utilities", operator.ge, "0.10.0")
34 changes: 32 additions & 2 deletions src/lightning/fabric/utilities/rank_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
"""Utilities that can be used for calling functions on a particular rank."""
import logging
import os
from typing import Optional
from functools import wraps
from typing import Callable, Optional, TypeVar, overload

import lightning_utilities.core.rank_zero as rank_zero_module

Expand All @@ -25,11 +26,12 @@
rank_zero_debug,
rank_zero_deprecation,
rank_zero_info,
rank_zero_only,
rank_zero_warn,
)
from typing_extensions import ParamSpec

import lightning.fabric
from lightning.fabric.utilities.imports import _UTILITIES_GREATER_EQUAL_0_10

rank_zero_module.log = logging.getLogger(__name__)

Expand All @@ -50,6 +52,34 @@ def _get_rank(
return None


if not _UTILITIES_GREATER_EQUAL_0_10:
    # Backport: older ``lightning_utilities`` (<0.10.0) only supports the single-argument
    # ``rank_zero_only(fn)`` form, so re-implement it here with the ``default`` parameter.
    T = TypeVar("T")
    P = ParamSpec("P")

    @overload
    def rank_zero_only(fn: Callable[P, T]) -> Callable[P, Optional[T]]:
        ...

    @overload
    def rank_zero_only(fn: Callable[P, T], default: T) -> Callable[P, T]:
        ...

    def rank_zero_only(fn: Callable[P, T], default: Optional[T] = None) -> Callable[P, Optional[T]]:
        """Wrap a function so that it only executes on the process with global rank 0.

        Args:
            fn: The function to wrap.
            default: The value returned on all non-zero ranks instead of calling ``fn``.

        Raises:
            RuntimeError: If ``rank_zero_only.rank`` has not been set before the wrapped
                function is called.

        """

        @wraps(fn)
        def wrapped_fn(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
            rank = getattr(rank_zero_only, "rank", None)
            if rank is None:
                raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
            if rank == 0:
                return fn(*args, **kwargs)
            return default

        return wrapped_fn

    # Also seed the upstream function's ``rank`` attribute, in case callers imported it
    # directly from ``lightning_utilities`` rather than through this module.
    rank_zero_module.rank_zero_only.rank = getattr(rank_zero_module.rank_zero_only, "rank", _get_rank() or 0)
else:
    rank_zero_only = rank_zero_module.rank_zero_only  # type: ignore[assignment]

# add the attribute to the function but don't overwrite in case Trainer has already set it
rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank() or 0)

Expand Down
Loading