add autograd profiler #1693

Closed · wants to merge 10 commits
66 changes: 64 additions & 2 deletions pytorch_lightning/profiler/__init__.py
@@ -49,8 +49,58 @@
on_train_end | 5.449e-06 | 5.449e-06


Autograd Profiling
------------------
If you would like to focus your profiling on the PyTorch-specific components, you should use the autograd
profiler. This leverages the native `torch.autograd.profiler`_ context manager to accurately measure time
spent on PyTorch ops on both the CPU and GPU.

.. _`torch.autograd.profiler`: https://pytorch.org/docs/stable/autograd.html#profiler

.. code-block:: python

profiler = AutogradProfiler()
trainer = Trainer(..., profiler=profiler)
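
Under the hood, each profiled action is wrapped in the native context manager. A minimal
sketch of the equivalent direct usage (`model` and `batch` are placeholder names):

.. code-block:: python

    import torch

    with torch.autograd.profiler.profile() as prof:
        output = model(batch)

    print(prof.key_averages().table(sort_by="cpu_time_total"))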

The profiler's results are printed when `fit()` completes. The report can be quite
long, so you can also specify an `output_filename` to save the report to a file instead
of logging it to your terminal (see the sketch after the example report below). The
output below shows the profiling for the action `model_forward`.

.. code-block:: python

Profiler Report

Profile stats for: model_forward
-------------------------- --------------- --------------- --------------- --------------- --------------- ---------------
Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg Number of Calls
-------------------------- --------------- --------------- --------------- --------------- --------------- ---------------
batch_norm 0.05% 4.000us 65.88% 5.365ms 5.365ms 1
_batch_norm_impl_index 0.06% 5.000us 65.83% 5.361ms 5.361ms 1
native_batch_norm 65.75% 5.355ms 65.75% 5.355ms 5.355ms 1
addmm 13.03% 1.061ms 13.03% 1.061ms 530.500us 2
dropout 0.10% 8.000us 8.53% 695.000us 695.000us 1
mul 6.47% 527.000us 6.47% 527.000us 175.667us 3
tanh 4.44% 362.000us 4.44% 362.000us 362.000us 1
unsigned short 3.76% 306.000us 3.76% 306.000us 153.000us 2
log_softmax 0.04% 3.000us 3.25% 265.000us 265.000us 1
_log_softmax 3.22% 262.000us 3.22% 262.000us 262.000us 1
bernoulli_ 1.71% 139.000us 1.71% 139.000us 139.000us 1
div_ 0.52% 42.000us 0.52% 42.000us 42.000us 1
nll_loss 0.04% 3.000us 0.33% 27.000us 27.000us 1
view 0.32% 26.000us 0.32% 26.000us 26.000us 1
nll_loss_forward 0.29% 24.000us 0.29% 24.000us 24.000us 1
add 0.11% 9.000us 0.11% 9.000us 9.000us 1
empty 0.05% 4.000us 0.05% 4.000us 2.000us 2
empty_like 0.01% 1.000us 0.05% 4.000us 4.000us 1
detach 0.04% 3.000us 0.04% 3.000us 1.000us 3
-------------------------- --------------- --------------- --------------- --------------- --------------- ---------------
Self CPU time total: 8.144ms
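
If you prefer the report on disk, pass an `output_filename` when constructing the
profiler. A short sketch (`profiler_report.txt` is an arbitrary file name):

.. code-block:: python

    profiler = AutogradProfiler(output_filename="profiler_report.txt")
    trainer = Trainer(..., profiler=profiler)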



Advanced Profiling
------------------

If you want more information on the functions called during each event, you can use the `AdvancedProfiler`.
This option uses Python's cProfiler_ to provide a report of time spent on *each* function called within your code.
@@ -91,6 +141,11 @@
60000 1.651 0.000 6.839 0.000 functional.py:42(to_tensor)
60000 0.260 0.000 5.734 0.000 transforms.py:167(__call__)
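
As with the autograd profiler, you enable it by passing an instance to the `Trainer`
(a brief sketch, assuming the same setup as above):

.. code-block:: python

    profiler = AdvancedProfiler()
    trainer = Trainer(..., profiler=profiler)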



Profiling custom events
-----------------------

You can also reference this profiler in your LightningModule to profile specific actions of interest.
If you don't want to always have the profiler turned on, you can optionally pass a `PassThroughProfiler`
which will allow you to skip profiling without having to make any code changes. Each profiler has a
@@ -116,11 +171,18 @@ def custom_processing_step(self, data):

"""

from pytorch_lightning.profiler.profilers import (
    SimpleProfiler,
    AdvancedProfiler,
    AutogradProfiler,
    PassThroughProfiler,
    BaseProfiler,
)

__all__ = [
'BaseProfiler',
'SimpleProfiler',
'AdvancedProfiler',
'AutogradProfiler',
'PassThroughProfiler',
]
121 changes: 120 additions & 1 deletion pytorch_lightning/profiler/profilers.py
@@ -26,6 +26,7 @@

import fsspec
import numpy as np
import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -219,7 +220,8 @@ class AdvancedProfiler(BaseProfiler):
"""
This profiler uses Python's cProfile to record more detailed information about
time spent in each function call recorded during a given action. The output is quite
verbose and you should only use this if you want very detailed reports. This profiler
is most helpful when trying to identify code bottlenecks outside of your neural network.
"""

def __init__(self, output_filename: Optional[str] = None, line_count_restriction: float = 1.0):
@@ -283,3 +285,120 @@ def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()


class AutogradProfiler(BaseProfiler):
"""
This profiler uses PyTorch's `torch.autograd.profiler.profile` as a backend. This profiler
can record PyTorch operations on both CPU and GPU, recording the events of functions
being executed under the hood in C++. This profiler is most useful when trying to optimize
the forward and backward pass of your neural network.

Caution! The `AutogradProfiler` doesn't currently work with dataloaders with workers
enabled (`num_workers > 0`). Please disable workers when using this profiler.
"""

def __init__(
self,
output_filename: Optional[str] = None,
profile_memory: bool = True,
use_cuda: bool = False,
row_limit: int = 20,
group_by_input_shapes: bool = False
):
"""
Args:
output_filename: optionally save profile results to file instead of printing
to std out when training is finished.
profile_memory: Record memory usage.
use_cuda: Measure execution time of CUDA kernels.
row_limit: Limit the number of rows in a table, `0` is a special value that
removes the limit completely.
group_by_input_shapes: Include operator input shapes and group calls by shape.
"""
self.profile_memory = profile_memory
self.group_by_input_shapes = group_by_input_shapes
self.use_cuda = use_cuda
self.row_limit = row_limit
self.profiled_actions = {}
self.context_names = {}
self.profiler = None
# stack of currently running profilers
self.running_stack = []

self.output_fname = output_filename
self.output_file = open(self.output_fname, 'w') if self.output_fname else None

streaming_out = [self.output_file.write] if self.output_file else [log.info]
super().__init__(output_streams=streaming_out)

def _start(self, action_name: str) -> None:
self.profiler = torch.autograd.profiler.profile(
profile_memory=self.profile_memory,
use_cuda=self.use_cuda,
record_shapes=self.group_by_input_shapes
)
self.profiler.__enter__()

def start(self, action_name: str) -> None:
# stop the running profiler if any
if len(self.running_stack) > 0:
self._stop(self.running_stack[-1])
self.running_stack.append(action_name)

self.context_names[action_name] = "/".join(self.running_stack)

self._start(action_name)

def _stop(self, action_name: str) -> None:
self.profiler.__exit__(
exc_type=None,
exc_val=None,
exc_tb=None
)
events = self.profiler.function_events
self.profiler = None
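# credit the recorded events to every action on the stack, so nested actions share them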
for name in self.running_stack:
if name not in self.profiled_actions:
self.profiled_actions[name] = events
else:
self.profiled_actions[name] += events

def stop(self, action_name: str) -> None:
if len(self.running_stack) == 0 or self.running_stack[-1] != action_name:
raise ValueError(  # pragma: no cover
f"Attempting to stop recording an action ({action_name}) which was never started."
)
self._stop(action_name)
self.running_stack.pop()
# restore running profiler
if len(self.running_stack) > 0:
self._start(self.running_stack[-1])

def summary(self) -> str:
output_string = f"{os.linesep}Profiler Report{os.linesep}"

for name, events in self.profiled_actions.items():
# next line is a workaround for a pytorch issue (fixed on master, still present
# on 1.7). Without it the code fails with `AssertionError: There is already a CPU
# parent event for detach`
events.populate_cpu_children = lambda: None

report = events.key_averages(group_by_input_shapes=self.group_by_input_shapes).table(
sort_by=("cuda_time_total" if self.use_cuda else "cpu_time_total"),
row_limit=self.row_limit
)
output_string += f"{os.linesep}Profile stats for: {self.context_names[name]}{os.linesep}{report}"

return output_string

def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
if self.output_file:
self.output_file.flush()

def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()
69 changes: 62 additions & 7 deletions tests/test_profiler.py
@@ -4,8 +4,9 @@

import numpy as np
import pytest
import torch

from pytorch_lightning.profiler import AdvancedProfiler, AutogradProfiler, SimpleProfiler

PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005

@@ -14,6 +15,11 @@ def _get_python_cprofile_total_duration(profile):
return sum([x.inlinetime for x in profile.getstats()])


def _get_pytorch_profiler_total_duration(events):
total_time = sum([e.cpu_time + e.cuda_time for e in events])
return total_time / 1e6 # convert microseconds to seconds


def _sleep_generator(durations):
"""
the profile_iterable method needs an iterable in which we can ensure that we're
@@ -36,6 +42,15 @@ def advanced_profiler(tmpdir):
return profiler


@pytest.fixture
def autograd_profiler(tmpdir):
profiler = AutogradProfiler(output_filename=os.path.join(tmpdir, "profiler.txt"))
return profiler


# =====================
# Simple Profiler
# =====================
@pytest.mark.parametrize(["action", "expected"], [
pytest.param("a", [3, 1]),
pytest.param("b", [2]),
@@ -105,12 +120,16 @@ def test_simple_profiler_value_errors(simple_profiler):
simple_profiler.stop(action)


# =====================
# Advanced Profiler
# =====================
@pytest.mark.parametrize(["action", "expected"], [
pytest.param("a", [3, 1]),
pytest.param("b", [2]),
pytest.param("c", [1])
])
def test_advanced_profiler_durations(advanced_profiler, action, expected):
"""Ensure the reported durations are reasonably accurate."""

for duration in expected:
with advanced_profiler.profile(action):
@@ -149,9 +168,7 @@ def test_advanced_profiler_iterable_durations(advanced_profiler, action, expected):


def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
"""
ensure that the profiler doesn't introduce too much overhead during training
"""
"""Ensure that the profiler doesn't introduce too much overhead during training."""
for _ in range(n_iter):
with advanced_profiler.profile("no-op"):
pass
@@ -163,9 +180,7 @@ def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):


def test_advanced_profiler_describe(tmpdir, advanced_profiler):
"""
ensure the profiler won't fail when reporting the summary
"""
"""Ensure the profiler won't fail when reporting the summary."""
# record at least one event
with advanced_profiler.profile("test"):
pass
@@ -184,3 +199,43 @@ def test_advanced_profiler_value_errors(advanced_profiler):

advanced_profiler.start(action)
advanced_profiler.stop(action)


# =====================
# Autograd Profiler
# =====================

def test_autograd_profiler_overhead(autograd_profiler, n_iter=5):
"""Ensure that the profiler doesn't introduce too much overhead during training."""
for _ in range(n_iter):
with autograd_profiler.profile("no-op"):
a = torch.ones(42)
b = torch.abs(a)
c = a + b

action_profile = autograd_profiler.profiled_actions["no-op"]
total_duration = _get_pytorch_profiler_total_duration(action_profile)
average_duration = total_duration / n_iter
assert average_duration < PROFILER_OVERHEAD_MAX_TOLERANCE


def test_autograd_profiler_describe(tmpdir, autograd_profiler):
"""Ensure the profiler won't fail when reporting the summary."""
with autograd_profiler.profile("test"):
pass

# generate the report and flush it to the output file
autograd_profiler.describe()
data = Path(autograd_profiler.output_fname).read_text()
assert len(data) > 0


def test_autograd_profiler_value_errors(autograd_profiler):
"""Ensure errors are raised where expected."""

action = "test"
with pytest.raises(ValueError):
autograd_profiler.stop(action)

autograd_profiler.start(action)
autograd_profiler.stop(action)