Skip to content

Commit

Permalink
test cloudpickle (#2105)
Browse files Browse the repository at this point in the history
* cloudpickle

* ci tests
  • Loading branch information
Borda committed Jun 9, 2020
1 parent de15759 commit 16a7326
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 5 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ jobs:
run: |
python -c "req = open('requirements.txt').read().replace('>', '=') ; open('requirements.txt', 'w').write(req)"
python -c "req = open('requirements-extra.txt').read().replace('>', '=') ; open('requirements-extra.txt', 'w').write(req)"
python -c "req = open('tests/requirements-devel.txt').read().replace('>', '=') ; open('tests/requirements-devel.txt', 'w').write(req)"
# Note: This uses an internal pip API and may not always work
# https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
Expand Down
2 changes: 2 additions & 0 deletions tests/models/test_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import pickle

import cloudpickle
import pytest
import torch

Expand Down Expand Up @@ -273,3 +274,4 @@ def test_model_saving_loading(tmpdir):
def test_model_pickle(tmpdir):
model = EvalModelTemplate()
pickle.dumps(model)
cloudpickle.dumps(model)
2 changes: 1 addition & 1 deletion tests/requirements-devel.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
# extended list of dependencies dor development and run lint and tests
-r ./requirements.txt

cloudpickle
cloudpickle>=1.2
21 changes: 17 additions & 4 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import glob
import math
import os
import pickle
import types
from argparse import Namespace

import cloudpickle
import pytest
import torch

Expand Down Expand Up @@ -671,10 +673,12 @@ def _optimizer_step(*args, **kwargs):
grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
assert (grad_norm - 1.0).abs() < 0.01, "Gradient norm != 1.0: {grad_norm}".format(grad_norm=grad_norm)

trainer = Trainer(max_steps=1,
max_epochs=1,
gradient_clip_val=1.0,
default_root_dir=tmpdir)
trainer = Trainer(
max_steps=1,
max_epochs=1,
gradient_clip_val=1.0,
default_root_dir=tmpdir
)

# for the test
model.optimizer_step = _optimizer_step
Expand Down Expand Up @@ -824,3 +828,12 @@ def __init__(self, **kwargs):
# when we pass in an unknown arg, the base class should complain
with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'abcdefg'"):
TrainerSubclass(abcdefg='unknown_arg')


def test_trainer_pickle(tmpdir):
trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir
)
pickle.dumps(trainer)
cloudpickle.dumps(trainer)

0 comments on commit 16a7326

Please sign in to comment.