Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor base profilers 3/5 #6621

Merged
merged 36 commits into from
Mar 23, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
fbbb9a2
Refactor basic profilers
carmocca Mar 22, 2021
2a82e05
Fixes
carmocca Mar 22, 2021
ff125e2
Unused import
carmocca Mar 22, 2021
01a760e
Introduce setup
carmocca Mar 22, 2021
b31831e
Profile on all ranks. Print to stdout on 0
carmocca Mar 22, 2021
f8a8772
Introduce dirpath + filename
carmocca Mar 22, 2021
aa4b7dd
CHANGELOG
carmocca Mar 22, 2021
8e3034e
Add tests. Address comments
carmocca Mar 22, 2021
e4e0dd6
add `on_run_stage_setup`
tchaton Mar 22, 2021
d0fdbb9
add on_run_stage_setup function
tchaton Mar 22, 2021
1a16bb3
update
tchaton Mar 22, 2021
63b6988
update lightning flow direction
tchaton Mar 22, 2021
a05acdd
remove trace
tchaton Mar 22, 2021
af72dff
Merge branch 'master' into refactor-base-profilers
tchaton Mar 22, 2021
59c941b
Undo code that should be in 3/4
carmocca Mar 22, 2021
da0f310
Multi-stage multi-rank
carmocca Mar 22, 2021
59c1b4c
2/5 changes
carmocca Mar 22, 2021
dd1dce0
Pass stage in __del__
carmocca Mar 22, 2021
12d014b
Merge branch 'master' into refactor-base-profilers
carmocca Mar 22, 2021
097a426
Remove TODOs
carmocca Mar 22, 2021
4d529fa
Describe on_evaluation_end. Add tests
carmocca Mar 22, 2021
58dcd4e
Typo
carmocca Mar 22, 2021
c37162f
Address comments
carmocca Mar 22, 2021
4c5f1f3
deepcopy tests
carmocca Mar 22, 2021
5ed73fb
Advanced teardown
carmocca Mar 22, 2021
897f8e5
Fix teardown test
carmocca Mar 22, 2021
e42be2a
Fix tests
carmocca Mar 22, 2021
32c301c
Minor change
carmocca Mar 22, 2021
af0c8ad
Update CHANGELOG.md
carmocca Mar 22, 2021
29a73c5
Fix test
carmocca Mar 22, 2021
cb756b8
Fix 6522
carmocca Mar 22, 2021
758b942
resolve ddp tests
tchaton Mar 23, 2021
fca4eb2
resolve tests
tchaton Mar 23, 2021
2919a39
resolve some tests
tchaton Mar 23, 2021
d7ca5fa
update tests
tchaton Mar 23, 2021
c7bb71b
update
tchaton Mar 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 97 additions & 87 deletions pytorch_lightning/profiler/profilers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import contextmanager
from typing import Optional, Union
from typing import Dict, Optional, Tuple

import numpy as np

Expand All @@ -30,22 +30,8 @@
log = logging.getLogger(__name__)


class BaseProfiler(ABC):
"""
If you wish to write a custom profiler, you should inhereit from this class.
"""

def __init__(self, output_streams: Optional[Union[list, tuple]] = None):
"""
Args:
output_streams: callable
"""
if output_streams:
if not isinstance(output_streams, (list, tuple)):
output_streams = [output_streams]
else:
output_streams = []
self.write_streams = output_streams
class AbstractProfiler(ABC):
"""Specification of a profiler."""

@abstractmethod
def start(self, action_name: str) -> None:
Expand All @@ -55,6 +41,49 @@ def start(self, action_name: str) -> None:
def stop(self, action_name: str) -> None:
"""Defines how to record the duration once an action is complete."""

@abstractmethod
def summary(self) -> str:
"""Create profiler summary in text format."""

def teardown(self) -> None:
"""Execute arbitrary post-profiling tear-down steps as defined by subclass."""
carmocca marked this conversation as resolved.
Show resolved Hide resolved


class BaseProfiler(AbstractProfiler):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tchaton Why do we need BaseProfiler & AbstractProfiler both?

Copy link
Contributor

@tchaton tchaton Mar 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is slightly cleaner.

"""
If you wish to write a custom profiler, you should inherit from this class.
"""

def __init__(
self,
output_filename: Optional[str] = None,
local_rank: Optional[int] = None,
log_dir: Optional[str] = None
) -> None:
self.output_fname = output_filename
self.output_file = None
self._file_prepared = False
self.write_streams = []
# the profiler can be used outside of lightning
# that's why we call `on_train_start` manually
self.on_train_start(local_rank=local_rank, log_dir=log_dir)

def on_train_start(self, local_rank: Optional[int] = None, log_dir: Optional[str] = None) -> None:
"""
This function is used by the Trainer to inject local_rank with `DDP`
and `TensorBoardLogger` log_dir in the profiler.
"""
self.local_rank = local_rank
self.log_dir = log_dir

def _prepare_file(self) -> None:
if not self._file_prepared:
if self.output_fname and self.output_file is None:
fs = get_filesystem(self.output_fname)
self.output_file = fs.open(self.output_fname, "w")
carmocca marked this conversation as resolved.
Show resolved Hide resolved
self.write_streams = [self.output_file.write] if self.output_file else [log.info]
self._file_prepared = True

@contextmanager
def profile(self, action_name: str) -> None:
"""
Expand Down Expand Up @@ -88,15 +117,41 @@ def profile_iterable(self, iterable, action_name: str) -> None:

def describe(self) -> None:
"""Logs a profile report after the conclusion of the training run."""
self._prepare_file()
for write in self.write_streams:
write(self.summary())
if self.output_file:
self.output_file.flush()
self.teardown()
carmocca marked this conversation as resolved.
Show resolved Hide resolved

def stats_to_str(self, stats: Dict[str, str]) -> str:
output = ["Profiler Report"]
for action, value in stats.items():
header = f"Profile stats for: {action}"
if getattr(self, "local_rank", None) is not None:
header += f" rank: {self.local_rank}"
output.append(header)
output.append(value)
return os.linesep.join(output)

def teardown(self) -> None:
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()
self.write_streams = []
self._file_prepared = False

@abstractmethod
def summary(self) -> str:
"""Create profiler summary in text format."""
def __del__(self) -> None:
self.teardown()

def on_train_start(self, local_rank: Optional[int] = None):
self.local_rank = local_rank
def start(self, action_name: str) -> None:
raise NotImplementedError

def stop(self, action_name: str) -> None:
raise NotImplementedError

def summary(self) -> str:
raise NotImplementedError
ananthsub marked this conversation as resolved.
Show resolved Hide resolved


class PassThroughProfiler(BaseProfiler):
Expand All @@ -105,9 +160,6 @@ class PassThroughProfiler(BaseProfiler):
The Trainer uses this class by default.
"""

def __init__(self):
super().__init__(output_streams=None)

def start(self, action_name: str) -> None:
pass

Expand All @@ -117,14 +169,17 @@ def stop(self, action_name: str) -> None:
def summary(self) -> str:
return ""

def teardown(self) -> None:
pass


class SimpleProfiler(BaseProfiler):
"""
This profiler simply records the duration of actions (in seconds) and reports
the mean duration of each action and the total time spent over the entire training run.
"""

def __init__(self, output_filename: Optional[str] = None, extended=True):
def __init__(self, output_filename: Optional[str] = None, extended: bool = True) -> None:
"""
Args:
output_filename: optionally save profile results to file instead of printing
Expand All @@ -135,19 +190,11 @@ def __init__(self, output_filename: Optional[str] = None, extended=True):
If you attempt to start an action which has already started, or
if you attempt to stop recording an action which was never started.
"""
self.current_actions = {}
self.current_actions: Dict[str, float] = {}
self.recorded_durations = defaultdict(list)
self.extended = extended

self.output_fname = output_filename
self.output_file = None
if self.output_fname:
fs = get_filesystem(self.output_fname)
self.output_file = fs.open(self.output_fname, "w")

streaming_out = [self.output_file.write] if self.output_file else [log.info]
super().__init__(output_filename=output_filename)
self.start_time = time.monotonic()
super().__init__(output_streams=streaming_out)

def start(self, action_name: str) -> None:
if action_name in self.current_actions:
Expand All @@ -162,31 +209,32 @@ def stop(self, action_name: str) -> None:
duration = end_time - start_time
self.recorded_durations[action_name].append(duration)

def make_report(self):
def _make_report(self) -> Tuple[list, float]:
total_duration = time.monotonic() - self.start_time
report = [[a, d, 100. * np.sum(d) / total_duration] for a, d in self.recorded_durations.items()]
report.sort(key=lambda x: x[2], reverse=True)
return report, total_duration

def summary(self) -> str:
output_string = "\n\nProfiler Report\n"
sep = os.linesep
output_string = f"Profiler Report{sep}"

if self.extended:

if len(self.recorded_durations) > 0:
max_key = np.max([len(k) for k in self.recorded_durations.keys()])

def log_row(action, mean, num_calls, total, per):
row = f"{os.linesep}{action:<{max_key}s}\t| {mean:<15}\t|"
row = f"{sep}{action:<{max_key}s}\t| {mean:<15}\t|"
row += f"{num_calls:<15}\t| {total:<15}\t| {per:<15}\t|"
return row

output_string += log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %")
output_string_len = len(output_string)
output_string += f"{os.linesep}{'-' * output_string_len}"
report, total_duration = self.make_report()
output_string += f"{sep}{'-' * output_string_len}"
report, total_duration = self._make_report()
output_string += log_row("Total", "-", "_", f"{total_duration:.5}", "100 %")
output_string += f"{os.linesep}{'-' * output_string_len}"
output_string += f"{sep}{'-' * output_string_len}"
for action, durations, duration_per in report:
output_string += log_row(
action,
Expand All @@ -198,27 +246,16 @@ def log_row(action, mean, num_calls, total, per):
else:

def log_row(action, mean, total):
return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}"
return f"{sep}{action:<20s}\t| {mean:<15}\t| {total:<15}"

output_string += log_row("Action", "Mean duration (s)", "Total time (s)")
output_string += f"{os.linesep}{'-' * 65}"
output_string += f"{sep}{'-' * 65}"

for action, durations in self.recorded_durations.items():
output_string += log_row(action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}")
output_string += os.linesep
output_string += sep
return output_string

def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
if self.output_file:
self.output_file.flush()

def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()


class AdvancedProfiler(BaseProfiler):
"""
Expand All @@ -227,7 +264,7 @@ class AdvancedProfiler(BaseProfiler):
verbose and you should only use this if you want very detailed reports.
"""

def __init__(self, output_filename: Optional[str] = None, line_count_restriction: float = 1.0):
def __init__(self, output_filename: Optional[str] = None, line_count_restriction: float = 1.0) -> None:
"""
Args:
output_filename: optionally save profile results to file instead of printing
Expand All @@ -240,18 +277,10 @@ def __init__(self, output_filename: Optional[str] = None, line_count_restriction
ValueError:
If you attempt to stop recording an action which was never started.
"""
super().__init__(output_filename=output_filename)
self.profiled_actions = {}
self.line_count_restriction = line_count_restriction

self.output_fname = output_filename
self.output_file = None
if self.output_fname:
fs = get_filesystem(self.output_fname)
self.output_file = fs.open(self.output_fname, "w")

streaming_out = [self.output_file.write] if self.output_file else [log.info]
super().__init__(output_streams=streaming_out)

def start(self, action_name: str) -> None:
if action_name not in self.profiled_actions:
self.profiled_actions[action_name] = cProfile.Profile()
Expand All @@ -260,9 +289,7 @@ def start(self, action_name: str) -> None:
def stop(self, action_name: str) -> None:
pr = self.profiled_actions.get(action_name)
if pr is None:
raise ValueError( # pragma: no-cover
f"Attempting to stop recording an action ({action_name}) which was never started."
)
raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")
pr.disable()

def summary(self) -> str:
Expand All @@ -272,21 +299,4 @@ def summary(self) -> str:
ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative')
ps.print_stats(self.line_count_restriction)
recorded_stats[action_name] = s.getvalue()

# log to standard out
output_string = f"{os.linesep}Profiler Report{os.linesep}"
for action, stats in recorded_stats.items():
output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}"

return output_string

def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
if self.output_file:
self.output_file.flush()

def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()
return self.stats_to_str(recorded_stats)
52 changes: 12 additions & 40 deletions pytorch_lightning/profiler/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,33 +159,21 @@ def __init__(
self.running_stack = []
self.profiler = None

self.output_fname = output_filename
self.output_file = None
if local_rank is not None:
self.on_train_start(local_rank=local_rank)
self.on_train_start = super().on_train_start
super().__init__(output_filename=output_filename, local_rank=local_rank)

def on_train_start(self, local_rank: Optional[str] = None):
self.local_rank = local_rank
def on_train_start(self, local_rank: Optional[int] = None, log_dir: Optional[str] = None) -> None:
super().on_train_start(local_rank=local_rank, log_dir=log_dir)

# when logging to `log.info`, only perform profiling on rank 0
if local_rank != 0 and self.output_fname is None:
self.wrap_functions_into_rank_zero_only()

if self.output_fname:
if local_rank is not None:
if '.txt' not in self.output_fname:
raise MisconfigurationException("Log file should be .txt file.")

self.output_fname = self.output_fname.replace(".txt", f"_{self.local_rank}.txt")

fs = get_filesystem(self.output_fname)
self.output_file = fs.open(self.output_fname, "w")
# if the user didn't provide `path_to_export_trace`,
# set it as TensorBoardLogger log_dir if exists
if self.path_to_export_trace is None:
self.path_to_export_trace = log_dir

streaming_out = [self.output_file.write] if self.output_file else [log.info]
super().__init__(output_streams=streaming_out)
# when logging to `log.info`, only perform profiling on rank 0
if local_rank is not None and local_rank > 0 and self.output_fname is None:
self._rank_zero_only_wrap()

def wrap_functions_into_rank_zero_only(self):
def _rank_zero_only_wrap(self) -> None:
self.start = rank_zero_only(self.start)
self.stop = rank_zero_only(self.stop)
self.summary = rank_zero_only(self.summary)
Expand Down Expand Up @@ -284,20 +272,4 @@ def summary(self) -> str:
table = data.table(sort_by=self.sort_by_key, row_limit=self.row_limit)
recorded_stats[action_name] = table

# log to standard out
output_string = f"{os.linesep}Profiler Report{os.linesep}"
for action, stats in recorded_stats.items():
output_string += (f"{os.linesep}Profile stats for: {action} rank: {local_rank} {os.linesep}{stats}")

return output_string

def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
if self.output_file:
self.output_file.flush()

def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()
return self.stats_to_str(recorded_stats)
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,7 @@ def call_teardown_hook(self, model: LightningModule) -> None:
else:
state = None

self.profiler.teardown()
self.teardown(stage=state)
model.teardown(stage=state)

Expand Down
Loading