advanced profiler describe + cleaned up tests (#837)
* add py36 compatibility

* add test case to capture previous bug

* clean up tests

* clean up tests
jeremyjordan authored Feb 16, 2020
1 parent 06ca642 commit 4ae31cd
Showing 2 changed files with 74 additions and 34 deletions.
3 changes: 1 addition & 2 deletions pytorch_lightning/profiler/profiler.py
@@ -163,8 +163,7 @@ def describe(self):
         self.recorded_stats = {}
         for action_name, pr in self.profiled_actions.items():
             s = io.StringIO()
-            sortby = pstats.SortKey.CUMULATIVE
-            ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats(sortby)
+            ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative')
             ps.print_stats(self.line_count_restriction)
             self.recorded_stats[action_name] = s.getvalue()
         if self.output_filename is not None:
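
Context on the change above (not part of the diff): pstats.SortKey was only added in Python 3.7, so pstats.SortKey.CUMULATIVE raises AttributeError on Python 3.6, while the string 'cumulative' is the portable spelling of the same sort key. A minimal standalone sketch of the pattern, with an illustrative workload:

import cProfile
import io
import pstats

pr = cProfile.Profile()
pr.enable()
sum(i * i for i in range(100_000))  # illustrative workload to profile
pr.disable()

s = io.StringIO()
# pstats.SortKey.CUMULATIVE would fail on py36; the string key works everywhere
ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative')
ps.print_stats(5)  # cap output length, like line_count_restriction above
print(s.getvalue())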
105 changes: 73 additions & 32 deletions tests/test_profiler.py
@@ -1,55 +1,96 @@
 import time
 
 import numpy as np
+import pytest
 
 from pytorch_lightning.profiler import Profiler, AdvancedProfiler
 
+PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0001
+
+
+@pytest.fixture
+def simple_profiler():
+    profiler = Profiler()
+    return profiler
+
+
+@pytest.fixture
+def advanced_profiler():
+    profiler = AdvancedProfiler()
+    return profiler
+
 
-def test_simple_profiler():
-    p = Profiler()
-
-    with p.profile("a"):
-        time.sleep(3)
-
-    with p.profile("a"):
-        time.sleep(1)
-
-    with p.profile("b"):
-        time.sleep(2)
-
-    with p.profile("c"):
-        time.sleep(1)
+@pytest.mark.parametrize("action,expected", [("a", [3, 1]), ("b", [2]), ("c", [1])])
+def test_simple_profiler_durations(simple_profiler, action, expected):
+    """
+    ensure the reported durations are reasonably accurate
+    """
+
+    for duration in expected:
+        with simple_profiler.profile(action):
+            time.sleep(duration)
 
     # different environments have different precision when it comes to time.sleep()
     # see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796
-    np.testing.assert_allclose(p.recorded_durations["a"], [3, 1], rtol=0.2)
-    np.testing.assert_allclose(p.recorded_durations["b"], [2], rtol=0.2)
-    np.testing.assert_allclose(p.recorded_durations["c"], [1], rtol=0.2)
+    np.testing.assert_allclose(
+        simple_profiler.recorded_durations[action], expected, rtol=0.2
+    )
 
 
-def test_advanced_profiler():
-    def _get_duration(profile):
-        return sum([x.totaltime for x in profile.getstats()])
-
-    p = AdvancedProfiler()
-
-    with p.profile("a"):
-        time.sleep(3)
-
-    with p.profile("a"):
-        time.sleep(1)
-
-    with p.profile("b"):
-        time.sleep(2)
-
-    with p.profile("c"):
-        time.sleep(1)
+def test_simple_profiler_overhead(simple_profiler, n_iter=5):
+    """
+    ensure that the profiler doesn't introduce too much overhead during training
+    """
+    for _ in range(n_iter):
+        with simple_profiler.profile("no-op"):
+            pass
+
+    durations = np.array(simple_profiler.recorded_durations["no-op"])
+    assert all(durations < PROFILER_OVERHEAD_MAX_TOLERANCE)
+
+
+def test_simple_profiler_describe(simple_profiler):
+    """
+    ensure the profiler won't fail when reporting the summary
+    """
+    simple_profiler.describe()
+
+
+@pytest.mark.parametrize("action,expected", [("a", [3, 1]), ("b", [2]), ("c", [1])])
+def test_advanced_profiler_durations(advanced_profiler, action, expected):
+    def _get_total_duration(profile):
+        return sum([x.totaltime for x in profile.getstats()])
+
+    for duration in expected:
+        with advanced_profiler.profile(action):
+            time.sleep(duration)
 
     # different environments have different precision when it comes to time.sleep()
     # see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796
-    a_duration = _get_duration(p.profiled_actions["a"])
-    np.testing.assert_allclose(a_duration, [4], rtol=0.2)
-    b_duration = _get_duration(p.profiled_actions["b"])
-    np.testing.assert_allclose(b_duration, [2], rtol=0.2)
-    c_duration = _get_duration(p.profiled_actions["c"])
-    np.testing.assert_allclose(c_duration, [1], rtol=0.2)
+    recorded_total_duration = _get_total_duration(
+        advanced_profiler.profiled_actions[action]
+    )
+    expected_total_duration = np.sum(expected)
+    np.testing.assert_allclose(
+        recorded_total_duration, expected_total_duration, rtol=0.2
+    )


+def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
+    """
+    ensure that the profiler doesn't introduce too much overhead during training
+    """
+    for _ in range(n_iter):
+        with advanced_profiler.profile("no-op"):
+            pass
+
+    action_profile = advanced_profiler.profiled_actions["no-op"]
+    total_duration = sum([x.totaltime for x in action_profile.getstats()])
+    average_duration = total_duration / n_iter
+    assert average_duration < PROFILER_OVERHEAD_MAX_TOLERANCE
+
+
+def test_advanced_profiler_describe(advanced_profiler):
+    """
+    ensure the profiler won't fail when reporting the summary
+    """
+    advanced_profiler.describe()
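
For orientation, a minimal usage sketch of the interface these tests exercise, assuming the 0.6-era API shown in this diff (the action name "training_step" is illustrative):

from pytorch_lightning.profiler import AdvancedProfiler

profiler = AdvancedProfiler()

# profile() is a context manager that records stats for a named action
with profiler.profile("training_step"):
    sum(i * i for i in range(100_000))  # stand-in for real work

# describe() formats the recorded cProfile stats, sorted by cumulative
# time per the profiler.py change above
profiler.describe()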
