diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 61e843a29a6399..c3bcc67f27b20b 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -79,7 +79,7 @@ def test_train_loop_only(tmpdir): # fit model result = trainer.fit(model) assert result == 1 - assert trainer.callback_metrics['loss'] < 0.50 + assert trainer.callback_metrics['loss'] < 0.6 def test_train_val_loop_only(tmpdir): @@ -102,7 +102,7 @@ def test_train_val_loop_only(tmpdir): # fit model result = trainer.fit(model) assert result == 1 - assert trainer.callback_metrics['loss'] < 0.65 + assert trainer.callback_metrics['loss'] < 0.6 def test_full_loop(tmpdir):