Skip to content

Commit

Permalink
fix test for profiler (#800)
Browse files Browse the repository at this point in the history
* fix test for profiler

* use allclose

* user relative tol
  • Loading branch information
Borda committed Feb 9, 2020
1 parent 5130841 commit fc0ad03
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions tests/test_profiler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from pytorch_lightning.profiler import Profiler, AdvancedProfiler
import time

import numpy as np

from pytorch_lightning.profiler import Profiler, AdvancedProfiler


def test_simple_profiler():
p = Profiler()
Expand All @@ -19,13 +21,14 @@ def test_simple_profiler():
time.sleep(1)

# different environments have different precision when it comes to time.sleep()
np.testing.assert_almost_equal(p.recorded_durations["a"], [3, 1], decimal=1)
np.testing.assert_almost_equal(p.recorded_durations["b"], [2], decimal=1)
np.testing.assert_almost_equal(p.recorded_durations["c"], [1], decimal=1)
# see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796
np.testing.assert_allclose(p.recorded_durations["a"], [3, 1], rtol=0.2)
np.testing.assert_allclose(p.recorded_durations["b"], [2], rtol=0.2)
np.testing.assert_allclose(p.recorded_durations["c"], [1], rtol=0.2)


def test_advanced_profiler():
def get_duration(profile):
def _get_duration(profile):
return sum([x.totaltime for x in profile.getstats()])

p = AdvancedProfiler()
Expand All @@ -42,9 +45,11 @@ def get_duration(profile):
with p.profile("c"):
time.sleep(1)

a_duration = get_duration(p.profiled_actions["a"])
np.testing.assert_almost_equal(a_duration, [4], decimal=1)
b_duration = get_duration(p.profiled_actions["b"])
np.testing.assert_almost_equal(b_duration, [2], decimal=1)
c_duration = get_duration(p.profiled_actions["c"])
np.testing.assert_almost_equal(c_duration, [1], decimal=1)
# different environments have different precision when it comes to time.sleep()
# see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796
a_duration = _get_duration(p.profiled_actions["a"])
np.testing.assert_allclose(a_duration, [4], rtol=0.2)
b_duration = _get_duration(p.profiled_actions["b"])
np.testing.assert_allclose(b_duration, [2], rtol=0.2)
c_duration = _get_duration(p.profiled_actions["c"])
np.testing.assert_allclose(c_duration, [1], rtol=0.2)

0 comments on commit fc0ad03

Please sign in to comment.