
[feat] Add PyTorch Profiler. #5560

Merged: 43 commits, Jan 26, 2021

Commits
ad00b97
add profiler
tchaton Jan 18, 2021
cfae67b
add profiler
tchaton Jan 18, 2021
2f1020c
Merge branch 'release/1.2-dev' into feat/torch_profiler
tchaton Jan 18, 2021
5931c18
update
tchaton Jan 19, 2021
11d8c61
Merge branch 'feat/torch_profiler' of https://github.com/PyTorchLight…
tchaton Jan 19, 2021
c85661a
resolve flake8
tchaton Jan 19, 2021
9a62eb8
update doc
tchaton Jan 19, 2021
6f54b69
update changelog
tchaton Jan 19, 2021
1bbe314
clean doc
tchaton Jan 19, 2021
bd035da
delete prof file
tchaton Jan 19, 2021
b0cfe7a
Merge branch 'release/1.2-dev' into feat/torch_profiler
tchaton Jan 19, 2021
e689cda
merge pr codebase
tchaton Jan 21, 2021
803aaa2
update
tchaton Jan 21, 2021
2e91d9e
Merge branch 'release/1.2-dev' into feat/torch_profiler
tchaton Jan 21, 2021
698b43a
update doc
tchaton Jan 21, 2021
991958f
Merge branch 'feat/torch_profiler' of https://github.com/PyTorchLight…
tchaton Jan 21, 2021
da9a56d
update doc
tchaton Jan 21, 2021
3b119fd
update doc
tchaton Jan 21, 2021
75c966f
update on comments
tchaton Jan 22, 2021
c10ab8c
Merge branch 'release/1.2-dev' into feat/torch_profiler
tchaton Jan 22, 2021
f6ae283
update docstring
tchaton Jan 22, 2021
29b9a58
Merge branch 'feat/torch_profiler' of https://github.com/PyTorchLight…
tchaton Jan 22, 2021
f0aed96
update docstring
tchaton Jan 22, 2021
5dd2b4d
try
Borda Jan 22, 2021
03b3ea5
update test
tchaton Jan 22, 2021
5663b5f
Merge branch 'feat/torch_profiler' of https://github.com/PyTorchLight…
tchaton Jan 22, 2021
1e6a953
Update pytorch_lightning/profiler/__init__.py
tchaton Jan 22, 2021
21ae2da
Update pytorch_lightning/profiler/__init__.py
tchaton Jan 22, 2021
f6f0d89
update on comments
tchaton Jan 22, 2021
7ca9b7c
Merge branch 'feat/torch_profiler' of https://github.com/PyTorchLight…
tchaton Jan 22, 2021
2ea05de
remove old code
tchaton Jan 22, 2021
c8d24b8
Merge branch 'release/1.2-dev' into feat/torch_profiler
tchaton Jan 24, 2021
c397603
add support for ddp
tchaton Jan 25, 2021
1db6e67
resolve flake8
tchaton Jan 25, 2021
4e9a86c
Merge branch 'release/1.2-dev' into feat/torch_profiler
tchaton Jan 25, 2021
e9866bb
Update pytorch_lightning/profiler/__init__.py
tchaton Jan 25, 2021
d65beee
resolve tests
tchaton Jan 25, 2021
bb642ba
Merge branch 'feat/torch_profiler' of https://github.com/PyTorchLight…
tchaton Jan 25, 2021
8338c5e
resolve flake8
tchaton Jan 25, 2021
85f9aa2
Merge branch 'release/1.2-dev' into feat/torch_profiler
tchaton Jan 25, 2021
e6263e6
Merge branch 'release/1.2-dev' into feat/torch_profiler
tchaton Jan 26, 2021
9ae56cc
resolve flake8
tchaton Jan 26, 2021
8d62f41
Merge branch 'release/1.2-dev' into feat/torch_profiler
mergify[bot] Jan 26, 2021
1 change: 1 addition & 0 deletions .gitignore
@@ -141,3 +141,4 @@ pytorch\ lightning
test-reports/
wandb
.forked/
*.prof
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -57,6 +57,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `Recall` and `Precision` metrics (and their functional counterparts `recall` and `precision`) can now be generalized to Recall@K and Precision@K with the use of `top_k` parameter ([#4842](https://github.com/PyTorchLightning/pytorch-lightning/pull/4842))


- Added `PyTorchProfiler` ([#5560](https://github.com/PyTorchLightning/pytorch-lightning/pull/5560))


### Changed

78 changes: 76 additions & 2 deletions pytorch_lightning/profiler/__init__.py
@@ -50,7 +50,7 @@


Advanced Profiling
--------------------
------------------

If you want more information on the functions called during each event, you can use the `AdvancedProfiler`.
This option uses Python's cProfiler_ to provide a report of time spent on *each* function called within your code.
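The mechanism behind this kind of per-function report can be sketched with the standard library alone. The snippet below is an illustration of the general cProfile/pstats technique, not Lightning's actual `AdvancedProfiler` implementation; the `profile_action` and `training_step` names are hypothetical:

```python
# Sketch of the cProfile/pstats technique behind AdvancedProfiler-style
# reports (illustrative only; not the library's actual implementation).
import cProfile
import io
import pstats


def profile_action(func, *args, **kwargs):
    """Run ``func`` under cProfile and return (result, per-function report)."""
    profiler = cProfile.Profile()
    profiler.enable()
    result = func(*args, **kwargs)
    profiler.disable()
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream).sort_stats("cumulative")
    stats.print_stats(10)  # cap the report at 10 rows
    return result, stream.getvalue()


def training_step():  # stand-in for a real training hook
    return sum(i * i for i in range(10_000))


result, report = profile_action(training_step)
```

The report lists cumulative and per-call time for every function invoked inside the profiled call, which is exactly the granularity the `AdvancedProfiler` adds over the simple profiler.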
@@ -114,13 +114,87 @@ def custom_processing_step(self, data):
model = MyModel(profiler)
trainer = Trainer(profiler=profiler, max_epochs=1)


PyTorch Profiling
-----------------

Autograd includes a profiler that lets you inspect the cost of different operators
inside your model - both on the CPU and GPU.

See the `PyTorch Profiler documentation <https://pytorch-lightning.readthedocs.io/en/stable/profiler.html>`_ for more details.

.. code-block:: python

trainer = Trainer(..., profiler="pytorch")

or

profiler = PyTorchProfiler(...)
trainer = Trainer(..., profiler=profiler)

The profiler's results are printed upon completion of `fit()`. This report
can be quite long, so you can also specify an `output_filename` to save the report
to a file instead of logging it to the terminal.

By default, this profiler records only the `training_step_and_backward`, `validation_step` and `test_step` functions.
The output below shows the profiling for the action `training_step_and_backward`.
You can pass ``PyTorchProfiler(profiled_functions=[...])`` to extend the scope of profiled functions.

.. code-block:: python

Profiler Report

Profile stats for: training_step_and_backward
--------------------- --------------- --------------- --------------- --------------- ---------------
Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg
--------------------- --------------- --------------- --------------- --------------- ---------------
t 62.10% 1.044ms 62.77% 1.055ms 1.055ms
addmm 32.32% 543.135us 32.69% 549.362us 549.362us
mse_loss 1.35% 22.657us 3.58% 60.105us 60.105us
mean 0.22% 3.694us 2.05% 34.523us 34.523us
div_ 0.64% 10.756us 1.90% 32.001us 16.000us
ones_like 0.21% 3.461us 0.81% 13.669us 13.669us
sum_out 0.45% 7.638us 0.74% 12.432us 12.432us
transpose 0.23% 3.786us 0.68% 11.393us 11.393us
as_strided 0.60% 10.060us 0.60% 10.060us 3.353us
to 0.18% 3.059us 0.44% 7.464us 7.464us
empty_like 0.14% 2.387us 0.41% 6.859us 6.859us
empty_strided 0.38% 6.351us 0.38% 6.351us 3.175us
fill_ 0.28% 4.782us 0.33% 5.566us 2.783us
expand 0.20% 3.336us 0.28% 4.743us 4.743us
empty 0.27% 4.456us 0.27% 4.456us 2.228us
copy_ 0.15% 2.526us 0.15% 2.526us 2.526us
broadcast_tensors 0.15% 2.492us 0.15% 2.492us 2.492us
size 0.06% 0.967us 0.06% 0.967us 0.484us
is_complex 0.06% 0.961us 0.06% 0.961us 0.481us
stride 0.03% 0.517us 0.03% 0.517us 0.517us
--------------------- --------------- --------------- --------------- --------------- ---------------
Self CPU time total: 1.681ms

When running with `PyTorchProfiler(emit_nvtx=True)`, you should launch your script under `nvprof`:

nvprof --profile-from-start off -o trace_name.prof -- <regular command here>

To visualize the profiled operations, you can either:

* Use: nvvp trace_name.prof

* Use: python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))'

"""

from pytorch_lightning.profiler.profilers import AdvancedProfiler, BaseProfiler, PassThroughProfiler, SimpleProfiler
from pytorch_lightning.profiler.profilers import (
AdvancedProfiler,
BaseProfiler,
PassThroughProfiler,
PyTorchProfiler,
SimpleProfiler,
)

__all__ = [
'BaseProfiler',
'SimpleProfiler',
'AdvancedProfiler',
'PassThroughProfiler',
'PyTorchProfiler',
]
233 changes: 232 additions & 1 deletion pytorch_lightning/profiler/profilers.py
@@ -15,19 +15,23 @@
"""Profiler to check if there are any bottlenecks in your code."""

import cProfile
import inspect
import io
import os
import pstats
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import contextmanager
from typing import Optional, Union
from typing import List, Optional, Union

import numpy as np
import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class BaseProfiler(ABC):
@@ -282,3 +286,230 @@ def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()


class PyTorchProfiler(BaseProfiler):

PROFILER_OVERHEAD_MAX_TOLERANCE = 7.5e-4
PROFILED_FUNCTIONS = ["training_step_and_backward", "validation_step", "test_step"]
AVAILABLE_SORT_KEYS = [
"cpu_time", "cuda_time", "cpu_time_total",
"cuda_time_total", "cpu_memory_usage", "cuda_memory_usage",
"self_cpu_memory_usage", "self_cuda_memory_usage", "count"
]

def __init__(
self,
output_filename: Optional[str] = None,
enabled: bool = True,
use_cuda: bool = False,
record_shapes: bool = False,
profile_memory: bool = False,
group_by_input_shapes: bool = False,
with_stack: bool = False,
use_kineto: bool = False,
use_cpu: bool = False,
emit_nvtx: bool = False,
export_to_chrome: bool = False,
path_to_export_trace: str = None,
row_limit: int = 20,
sort_by_key: Optional[str] = None,
profiled_functions: Optional[List] = None,
):
"""
This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of
different operators inside your model - both on the CPU and GPU

Args:

output_filename: optionally save profile results to file instead of printing
to std out when training is finished.

enabled: Setting this to False makes this context manager a no-op.

use_cuda: Enables timing of CUDA events as well using the cudaEvent API.
Adds approximately 4us of overhead to each tensor operation.

record_shapes: If shapes recording is set, information about input dimensions will be collected.

profile_memory: Whether to report memory usage (1.6.0)

group_by_input_shapes: Include operator input shapes and group calls by shape.

with_stack: Record source information (file and line number) for the ops (1.7.0)

use_kineto: Experimental support for the Kineto profiler (1.8.0)

use_cpu: Used together with ``use_kineto=True``; can be used to lower the overhead for GPU-only profiling (1.8.0)

emit_nvtx: Context manager that makes every autograd operation emit an NVTX range
Run::

nvprof --profile-from-start off -o trace_name.prof -- <regular command here>

To visualize, you can either use::

nvvp trace_name.prof
torch.autograd.profiler.load_nvprof(path)

export_to_chrome: Whether to export the sequence of profiled operators to a Chrome trace.

path_to_export_trace: Directory path to export traces to. By default, traces are saved
in the directory where the script is run.

row_limit: Limit the number of rows in a table, `0` is a special value that
removes the limit completely.

sort_by_key: Key used to sort the profiled table. See ``AVAILABLE_SORT_KEYS`` for valid values.

profiled_functions: List of function names to profile; a profiling context manager is
created around each. Any other function is passed through unprofiled.
"""

self.profiled_actions = {}
# PyTorch Profiler doesn't seem to work with multiple processes
# todo: Try to find a solution
self.enabled = enabled and os.getenv("LOCAL_RANK", None) is None
self.profiled_functions = profiled_functions or self.PROFILED_FUNCTIONS
self.use_cuda = use_cuda
self.record_shapes = record_shapes
self.profile_memory = profile_memory
self.sort_by_key = sort_by_key or ("cuda_time_total" if self.use_cuda else "cpu_time_total")
self.with_stack = with_stack
self.group_by_input_shapes = group_by_input_shapes and record_shapes
self.use_kineto = use_kineto
self.use_cpu = use_cpu
self.row_limit = row_limit
self.emit_nvtx = emit_nvtx
self.export_to_chrome = export_to_chrome
self.path_to_export_trace = path_to_export_trace

if export_to_chrome and path_to_export_trace is None:
rank_zero_warn(
"The exported trace will be saved locally as `path_to_export_trace` is None. "
"Note: Each function will generate its own traced file.")

if self.sort_by_key not in self.AVAILABLE_SORT_KEYS:
raise MisconfigurationException(
f"Found sort_by_key: {sort_by_key}. Should be one of {self.AVAILABLE_SORT_KEYS}.")

self.profiled_actions = {}
self.context_names = {}
self.running_stack = []
self.profiler = None

self.output_fname = output_filename
self.output_file = None

if self.output_fname:
fs = get_filesystem(self.output_fname)
self.output_file = fs.open(self.output_fname, "w")

streaming_out = [self.output_file.write] if self.output_file else [log.info]
super().__init__(output_streams=streaming_out)

def start(self, action_name: str) -> None:
# stop the running profiler if any
if action_name in self.profiled_functions:
if len(self.running_stack) > 0:
self._stop(self.running_stack[-1])
self.running_stack.append(action_name)

self.context_names[action_name] = "/".join(self.running_stack)

self._start(action_name)

def _start(self, action_name: str) -> None:
if self.emit_nvtx:
self._create_profiler(action_name, torch.cuda.profiler.profile, enter=False)
self._create_profiler(action_name, torch.autograd.profiler.emit_nvtx)
else:
self._create_profiler(action_name, torch.autograd.profiler.profile)

def _create_profiler(self, action_name, profiler, enter=True):
init_args = inspect.signature(profiler.__init__).parameters
profiler_args = {
k: v for k, v in vars(self).items() if k in init_args
}
pr = profiler(**profiler_args)
if enter:
pr = pr.__enter__()
self.profiler = pr
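The `inspect.signature` filtering in `_create_profiler` lets one `__init__` drive several profiler constructors with different signatures: only the instance attributes matching the target constructor's parameters are passed through. A minimal standalone sketch of the same trick (the `Recorder` and `Owner` classes below are hypothetical, for illustration only):

```python
# Sketch of the argument-filtering technique used in `_create_profiler`:
# pick out only those instance attributes that match parameters of the
# target class's __init__, and drop everything else.
import inspect


class Recorder:
    def __init__(self, use_cuda=False, record_shapes=False):
        self.use_cuda = use_cuda
        self.record_shapes = record_shapes


class Owner:
    def __init__(self):
        # A superset of attributes; only some match Recorder.__init__.
        self.use_cuda = True
        self.record_shapes = True
        self.row_limit = 20  # no matching Recorder parameter -> filtered out

    def create(self, target_cls):
        init_params = inspect.signature(target_cls.__init__).parameters
        kwargs = {k: v for k, v in vars(self).items() if k in init_params}
        return target_cls(**kwargs)


rec = Owner().create(Recorder)
```

This is why `torch.autograd.profiler.profile`, `emit_nvtx` and `torch.cuda.profiler.profile` can all be constructed from the same attribute pool without per-profiler argument lists.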

def _stop(self, action_name: str) -> None:
if self.profiler is None:
return

self.profiler.__exit__(
exc_type=None,
exc_val=None,
exc_tb=None
)

function_events = self.profiler.function_events
self.profiler = None
for name in self.running_stack:
if name not in self.profiled_actions:
self.profiled_actions[name] = function_events
else:
self.profiled_actions[name] += function_events

def stop(self, action_name: str) -> None:
if action_name in self.profiled_functions:
if len(self.running_stack) == 0 or self.running_stack[-1] != action_name:
raise ValueError( # pragma: no-cover
f"Attempting to stop recording an action ({action_name}) which was never started."
)
self._stop(action_name)
self.running_stack.pop()
# restore running profiler
if len(self.running_stack) > 0:
self._start(self.running_stack[-1])
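The `start`/`stop` bookkeeping above pauses the currently running profile when a nested action begins and resumes it when the nested action stops, so only the innermost action records at any time. A stripped-down, torch-free sketch of that stack discipline (the `StackProfiler` class and the `events` log are hypothetical illustrations):

```python
# Torch-free sketch of the nested-action stack used by `start`/`stop`:
# starting an inner action pauses the outer one; stopping the inner
# action resumes recording for the outer one.
events = []


class StackProfiler:
    def __init__(self):
        self.running_stack = []

    def start(self, action):
        if self.running_stack:
            events.append(("pause", self.running_stack[-1]))
        self.running_stack.append(action)
        events.append(("record", action))

    def stop(self, action):
        if not self.running_stack or self.running_stack[-1] != action:
            raise ValueError(f"Attempting to stop ({action}) which was never started.")
        events.append(("stop", self.running_stack.pop()))
        if self.running_stack:
            events.append(("record", self.running_stack[-1]))  # resume the outer action


p = StackProfiler()
p.start("training_step_and_backward")
p.start("validation_step")
p.stop("validation_step")
p.stop("training_step_and_backward")
```

The same restore step appears at the end of the real `stop`: after popping the finished action, `_start` is called again on the new top of `running_stack`.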

def summary(self) -> str:
recorded_stats = {}
output_string = ''

if self.enabled:
for action_name, function_events in self.profiled_actions.items():

# next line is a workaround for a pytorch issue (fixed on master, still present
# on 1.7). Without it the code fails with `AssertionError: There is already a CPU
# parent event for detach`
function_events.populate_cpu_children = lambda: None

if self.export_to_chrome:
filename = f"{action_name}_trace.json"
path_to_trace = filename if self.path_to_export_trace is None \
else os.path.join(self.path_to_export_trace, filename)
function_events.export_chrome_trace(path_to_trace)

if self.emit_nvtx:
return output_string

else:
table = function_events.key_averages(
group_by_input_shapes=self.group_by_input_shapes).table(
sort_by=self.sort_by_key,
row_limit=self.row_limit)
recorded_stats[action_name] = table

# log to standard out
output_string = f"{os.linesep}Profiler Report{os.linesep}"
for action, stats in recorded_stats.items():
output_string += (
f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}"
)

return output_string

def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
if self.output_file:
self.output_file.flush()

def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()
9 changes: 8 additions & 1 deletion pytorch_lightning/trainer/connectors/profiler_connector.py
@@ -14,13 +14,20 @@

from typing import Union

from pytorch_lightning.profiler import AdvancedProfiler, BaseProfiler, PassThroughProfiler, SimpleProfiler
from pytorch_lightning.profiler import (
AdvancedProfiler,
BaseProfiler,
PassThroughProfiler,
PyTorchProfiler,
SimpleProfiler,
)
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

PROFILERS = {
"simple": SimpleProfiler,
"advanced": AdvancedProfiler,
"pytorch": PyTorchProfiler,
}
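The mapping above is what lets `Trainer(profiler="pytorch")` resolve a string to a profiler class. A simplified sketch of that resolution step (the placeholder classes and the `resolve_profiler` helper below are hypothetical; the real connector also handles booleans, instances, and raises `MisconfigurationException`):

```python
# Simplified sketch of string-to-class profiler resolution, mirroring the
# PROFILERS mapping above (the classes here are stand-in placeholders).
class SimpleProfiler: ...
class AdvancedProfiler: ...
class PyTorchProfiler: ...

PROFILERS = {
    "simple": SimpleProfiler,
    "advanced": AdvancedProfiler,
    "pytorch": PyTorchProfiler,
}


def resolve_profiler(profiler):
    """Return a profiler instance from a string key, or pass an instance through."""
    if isinstance(profiler, str):
        key = profiler.lower()
        if key not in PROFILERS:
            raise ValueError(
                f"Unknown profiler {profiler!r}; expected one of {list(PROFILERS)}"
            )
        return PROFILERS[key]()
    return profiler


prof = resolve_profiler("pytorch")
```

Keeping the mapping as a module-level dict means adding a new built-in profiler is a one-line change plus an import.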

