diff --git a/pytorch_lightning/profiler/profiler.py b/pytorch_lightning/profiler/profiler.py
index 32f220897a9dc..ffecba5dff34b 100644
--- a/pytorch_lightning/profiler/profiler.py
+++ b/pytorch_lightning/profiler/profiler.py
@@ -163,8 +163,7 @@ def describe(self):
         self.recorded_stats = {}
         for action_name, pr in self.profiled_actions.items():
             s = io.StringIO()
-            sortby = pstats.SortKey.CUMULATIVE
-            ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats(sortby)
+            ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative')
             ps.print_stats(self.line_count_restriction)
             self.recorded_stats[action_name] = s.getvalue()
         if self.output_filename is not None:
diff --git a/tests/test_profiler.py b/tests/test_profiler.py
index 1b26030de8ea4..410e452bca577 100644
--- a/tests/test_profiler.py
+++ b/tests/test_profiler.py
@@ -1,55 +1,96 @@
 import time
-
 import numpy as np
+import pytest
 
 from pytorch_lightning.profiler import Profiler, AdvancedProfiler
 
+PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0001
+
+
+@pytest.fixture
+def simple_profiler():
+    profiler = Profiler()
+    return profiler
 
-def test_simple_profiler():
-    p = Profiler()
 
-    with p.profile("a"):
-        time.sleep(3)
+@pytest.fixture
+def advanced_profiler():
+    profiler = AdvancedProfiler()
+    return profiler
 
-    with p.profile("a"):
-        time.sleep(1)
 
-    with p.profile("b"):
-        time.sleep(2)
+@pytest.mark.parametrize("action,expected", [("a", [3, 1]), ("b", [2]), ("c", [1])])
+def test_simple_profiler_durations(simple_profiler, action, expected):
+    """
+    ensure the reported durations are reasonably accurate
+    """
 
-    with p.profile("c"):
-        time.sleep(1)
+    for duration in expected:
+        with simple_profiler.profile(action):
+            time.sleep(duration)
 
     # different environments have different precision when it comes to time.sleep()
     # see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796
-    np.testing.assert_allclose(p.recorded_durations["a"], [3, 1], rtol=0.2)
-    np.testing.assert_allclose(p.recorded_durations["b"], [2], rtol=0.2)
-    np.testing.assert_allclose(p.recorded_durations["c"], [1], rtol=0.2)
+    np.testing.assert_allclose(
+        simple_profiler.recorded_durations[action], expected, rtol=0.2
+    )
 
 
-def test_advanced_profiler():
-    def _get_duration(profile):
-        return sum([x.totaltime for x in profile.getstats()])
+def test_simple_profiler_overhead(simple_profiler, n_iter=5):
+    """
+    ensure that the profiler doesn't introduce too much overhead during training
+    """
+    for _ in range(n_iter):
+        with simple_profiler.profile("no-op"):
+            pass
+
+    durations = np.array(simple_profiler.recorded_durations["no-op"])
+    assert all(durations < PROFILER_OVERHEAD_MAX_TOLERANCE)
 
-    p = AdvancedProfiler()
 
-    with p.profile("a"):
-        time.sleep(3)
+def test_simple_profiler_describe(simple_profiler):
+    """
+    ensure the profiler won't fail when reporting the summary
+    """
+    simple_profiler.describe()
 
-    with p.profile("a"):
-        time.sleep(1)
 
-    with p.profile("b"):
-        time.sleep(2)
+@pytest.mark.parametrize("action,expected", [("a", [3, 1]), ("b", [2]), ("c", [1])])
+def test_advanced_profiler_durations(advanced_profiler, action, expected):
+    def _get_total_duration(profile):
+        return sum([x.totaltime for x in profile.getstats()])
 
-    with p.profile("c"):
-        time.sleep(1)
+    for duration in expected:
+        with advanced_profiler.profile(action):
+            time.sleep(duration)
 
     # different environments have different precision when it comes to time.sleep()
     # see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796
-    a_duration = _get_duration(p.profiled_actions["a"])
-    np.testing.assert_allclose(a_duration, [4], rtol=0.2)
-    b_duration = _get_duration(p.profiled_actions["b"])
-    np.testing.assert_allclose(b_duration, [2], rtol=0.2)
-    c_duration = _get_duration(p.profiled_actions["c"])
-    np.testing.assert_allclose(c_duration, [1], rtol=0.2)
+    recorded_total_duration = _get_total_duration(
+        advanced_profiler.profiled_actions[action]
+    )
+    expected_total_duration = np.sum(expected)
+    np.testing.assert_allclose(
+        recorded_total_duration, expected_total_duration, rtol=0.2
+    )
+
+
+def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
+    """
+    ensure that the profiler doesn't introduce too much overhead during training
+    """
+    for _ in range(n_iter):
+        with advanced_profiler.profile("no-op"):
+            pass
+
+    action_profile = advanced_profiler.profiled_actions["no-op"]
+    total_duration = sum([x.totaltime for x in action_profile.getstats()])
+    average_duration = total_duration / n_iter
+    assert average_duration < PROFILER_OVERHEAD_MAX_TOLERANCE
+
+
+def test_advanced_profiler_describe(advanced_profiler):
+    """
+    ensure the profiler won't fail when reporting the summary
+    """
+    advanced_profiler.describe()
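
Note on the profiler.py hunk: below is a minimal, standalone sketch (standard library only, not part of the patch) of the cProfile/pstats flow that AdvancedProfiler.describe() wraps. Passing the string 'cumulative' to sort_stats() works on every Python 3 version, whereas pstats.SortKey only exists on Python 3.7+; the 0.01 s sleep and the line limit of 5 are illustrative values, not values used by the library.

import cProfile
import io
import pstats
import time

pr = cProfile.Profile()
pr.enable()
time.sleep(0.01)  # stand-in for a profiled action
pr.disable()

s = io.StringIO()
# sort_stats('cumulative') is equivalent to sort_stats(pstats.SortKey.CUMULATIVE)
# but does not require the SortKey enum introduced in Python 3.7
ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative')
ps.print_stats(5)  # restrict the report to the top 5 entries
print(s.getvalue())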