From 386a8a4a9b9f81f27d20cf606ac08fbe82caec8b Mon Sep 17 00:00:00 2001 From: Jeremy Jordan Date: Thu, 7 May 2020 22:18:58 -0400 Subject: [PATCH] added some documentation and tests --- pytorch_lightning/profiler/__init__.py | 66 ++++++++++++++++++++++- pytorch_lightning/profiler/profilers.py | 9 ++-- tests/test_profiler.py | 69 ++++++++++++++++++++++--- 3 files changed, 132 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/profiler/__init__.py b/pytorch_lightning/profiler/__init__.py index fc684d143e4b82..81de8fd211682e 100644 --- a/pytorch_lightning/profiler/__init__.py +++ b/pytorch_lightning/profiler/__init__.py @@ -49,8 +49,58 @@ on_train_end | 5.449e-06 | 5.449e-06 +Autograd Profiling +------------------ +If you would like to focus your profiling on the PyTorch-specific components, you should use the autograd +profiler. This leverages the native `torch.autograd.profiler`_ context manager to accurately measure time +spent on PyTorch ops on both the CPU and GPU. + +.. _`torch.autograd.profiler`: https://pytorch.org/docs/stable/autograd.html#profiler + +.. code-block:: python + + profiler = AutogradProfiler() + trainer = Trainer(..., profiler=profiler) + +The profiler's results will be printed at the completion of a training `fit()`. This profiler +report can be quite long, so you can also specify an `output_filename` to save the report instead +of logging it to the output in your terminal. The output below shows the profiling for the action +`model_forward`. + +.. 
code-block:: python + + Profiler Report + + Profile stats for: model_forward + -------------------------- --------------- --------------- --------------- --------------- --------------- --------------- + Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg Number of Calls + -------------------------- --------------- --------------- --------------- --------------- --------------- --------------- + batch_norm 0.05% 4.000us 65.88% 5.365ms 5.365ms 1 + _batch_norm_impl_index 0.06% 5.000us 65.83% 5.361ms 5.361ms 1 + native_batch_norm 65.75% 5.355ms 65.75% 5.355ms 5.355ms 1 + addmm 13.03% 1.061ms 13.03% 1.061ms 530.500us 2 + dropout 0.10% 8.000us 8.53% 695.000us 695.000us 1 + mul 6.47% 527.000us 6.47% 527.000us 175.667us 3 + tanh 4.44% 362.000us 4.44% 362.000us 362.000us 1 + unsigned short 3.76% 306.000us 3.76% 306.000us 153.000us 2 + log_softmax 0.04% 3.000us 3.25% 265.000us 265.000us 1 + _log_softmax 3.22% 262.000us 3.22% 262.000us 262.000us 1 + bernoulli_ 1.71% 139.000us 1.71% 139.000us 139.000us 1 + div_ 0.52% 42.000us 0.52% 42.000us 42.000us 1 + nll_loss 0.04% 3.000us 0.33% 27.000us 27.000us 1 + view 0.32% 26.000us 0.32% 26.000us 26.000us 1 + nll_loss_forward 0.29% 24.000us 0.29% 24.000us 24.000us 1 + add 0.11% 9.000us 0.11% 9.000us 9.000us 1 + empty 0.05% 4.000us 0.05% 4.000us 2.000us 2 + empty_like 0.01% 1.000us 0.05% 4.000us 4.000us 1 + detach 0.04% 3.000us 0.04% 3.000us 1.000us 3 + -------------------------- --------------- --------------- --------------- --------------- --------------- --------------- + Self CPU time total: 8.144ms + + + Advanced Profiling --------------------- +------------------ If you want more information on the functions called during each event, you can use the `AdvancedProfiler`. This option uses Python's cProfiler_ to provide a report of time spent on *each* function called within your code. 
@@ -87,6 +137,11 @@ 60000 1.651 0.000 6.839 0.000 functional.py:42(to_tensor) 60000 0.260 0.000 5.734 0.000 transforms.py:167(__call__) + + +Profiling custom events +----------------------- + You can also reference this profiler in your LightningModule to profile specific actions of interest. If you don't want to always have the profiler turned on, you can optionally pass a `PassThroughProfiler` which will allow you to skip profiling without having to make any code changes. Each profiler has a @@ -112,11 +167,18 @@ def custom_processing_step(self, data): """ -from pytorch_lightning.profiler.profilers import SimpleProfiler, AdvancedProfiler, PassThroughProfiler, BaseProfiler +from pytorch_lightning.profiler.profilers import ( + SimpleProfiler, + AdvancedProfiler, + AutogradProfiler, + PassThroughProfiler, + BaseProfiler, +) __all__ = [ 'BaseProfiler', 'SimpleProfiler', 'AdvancedProfiler', + 'AutogradProfiler', 'PassThroughProfiler', ] diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index 9483a48d7307d6..e8eb1387e6fea3 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -167,7 +167,8 @@ class AdvancedProfiler(BaseProfiler): """ This profiler uses Python's cProfiler to record more detailed information about time spent in each function call recorded during a given action. The output is quite - verbose and you should only use this if you want very detailed reports. + verbose and you should only use this if you want very detailed reports. This profiler + is most helpful when trying to identify code bottlenecks outside of your neural network. """ def __init__(self, output_filename: str = None, line_count_restriction: float = 1.0): @@ -230,8 +231,10 @@ def __del__(self): class AutogradProfiler(BaseProfiler): """ - This profiler uses Pytorch's torch.autograd.profiler.profile as a backend. It allows to - profile backend calls and optimize model forward/backward performance. 
+ This profiler uses PyTorch's torch.autograd.profiler.profile as a backend. This profiler + can record PyTorch operations on both CPU and GPU, recording the events of functions + being executed under the hood in C++. This profiler is most useful when trying to optimize + the forward and backward pass of your neural network. """ def __init__(self, use_cuda: bool = False, output_filename: str = None, row_limit: int = 20): """ diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 3bce379c1115c2..f8438f5b19defb 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -3,9 +3,10 @@ from pathlib import Path import numpy as np +import torch import pytest -from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler +from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler, AutogradProfiler PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005 @@ -14,6 +15,11 @@ def _get_python_cprofile_total_duration(profile): return sum([x.inlinetime for x in profile.getstats()]) +def _get_pytorch_profiler_total_duration(events): + total_time = sum([e.cpu_time + e.cuda_time for e in events]) + return total_time / 1e6 # convert microseconds to seconds + + def _sleep_generator(durations): """ the profile_iterable method needs an iterable in which we can ensure that we're @@ -36,6 +42,15 @@ def advanced_profiler(tmpdir): return profiler +@pytest.fixture +def autograd_profiler(tmpdir): + profiler = AutogradProfiler(output_filename=os.path.join(tmpdir, "profiler.txt")) + return profiler + + +# ===================== +# Simple Profiler +# ===================== @pytest.mark.parametrize(["action", "expected"], [ pytest.param("a", [3, 1]), pytest.param("b", [2]), @@ -105,12 +120,16 @@ def test_simple_profiler_value_errors(simple_profiler): simple_profiler.stop(action) +# ===================== +# Advanced Profiler +# ===================== @pytest.mark.parametrize(["action", "expected"], [ pytest.param("a", [3, 1]), pytest.param("b", [2]), pytest.param("c", [1])
]) def test_advanced_profiler_durations(advanced_profiler, action, expected): + """Ensure the reported durations are reasonably accurate.""" for duration in expected: with advanced_profiler.profile(action): @@ -149,9 +168,7 @@ def test_advanced_profiler_iterable_durations(advanced_profiler, action, expecte def test_advanced_profiler_overhead(advanced_profiler, n_iter=5): - """ - ensure that the profiler doesn't introduce too much overhead during training - """ + """Ensure that the profiler doesn't introduce too much overhead during training.""" for _ in range(n_iter): with advanced_profiler.profile("no-op"): pass @@ -163,9 +180,7 @@ def test_advanced_profiler_overhead(advanced_profiler, n_iter=5): def test_advanced_profiler_describe(tmpdir, advanced_profiler): - """ - ensure the profiler won't fail when reporting the summary - """ + """Ensure the profiler won't fail when reporting the summary.""" # record at least one event with advanced_profiler.profile("test"): pass @@ -184,3 +199,43 @@ def test_advanced_profiler_value_errors(advanced_profiler): advanced_profiler.start(action) advanced_profiler.stop(action) + + +# ===================== +# Autograd Profiler +# ===================== + +def test_autograd_profiler_overhead(autograd_profiler, n_iter=5): + """Ensure that the profiler doesn't introduce too much overhead during training.""" + for _ in range(n_iter): + with autograd_profiler.profile("no-op"): + a = torch.ones(42) + b = torch.abs(a) + c = a + b + + action_profile = autograd_profiler.profiled_actions["no-op"] + total_duration = _get_pytorch_profiler_total_duration(action_profile) + average_duration = total_duration / n_iter + assert average_duration < PROFILER_OVERHEAD_MAX_TOLERANCE + + +def test_autograd_profiler_describe(tmpdir, autograd_profiler): + """Ensure the profiler won't fail when reporting the summary.""" + with autograd_profiler.profile("test"): + pass + + # log to stdout and print to file + autograd_profiler.describe() + data = 
Path(autograd_profiler.output_fname).read_text() + assert len(data) > 0 + + +def test_autograd_profiler_value_errors(autograd_profiler): + """Ensure errors are raised where expected.""" + + action = "test" + with pytest.raises(ValueError): + autograd_profiler.stop(action) + + autograd_profiler.start(action) + autograd_profiler.stop(action)