-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
test_wandb.py
69 lines (51 loc) · 2.15 KB
/
test_wandb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import pickle
from unittest import mock
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_wandb_logger(wandb):
    """Verify that basic functionality of the wandb logger works.

    Wandb doesn't work well with pytest, so the ``wandb`` module used by the
    logger is replaced with a mock and every check is made against that mock.
    ``wandb.init()`` always returns the same child mock, which stands in for
    the run object the logger talks to.
    """
    try:
        logger = WandbLogger(anonymous=True, offline=True)

        # Metrics logged without a step are forwarded to the run unchanged.
        logger.log_metrics({'acc': 1.0})
        wandb.init().log.assert_called_once_with({'acc': 1.0})

        # With an explicit step, the step travels along as 'global_step'.
        wandb.init().log.reset_mock()
        logger.log_metrics({'acc': 1.0}, step=3)
        wandb.init().log.assert_called_once_with({'global_step': 3, 'acc': 1.0})

        # Hyperparameters reach the run config with None stringified and
        # nested dicts flattened into '/'-joined keys; lists pass through.
        logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
        wandb.init().config.update.assert_called_once_with(
            {'test': 'None', 'nested/a': 1, 'b': [2, 3, 4]},
            allow_val_change=True,
        )

        # Positional watch() arguments are forwarded as keywords.
        logger.watch('model', 'log', 10)
        wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)

        # Name and version mirror the underlying run.
        assert logger.name == wandb.init().project_name()
        assert logger.version == wandb.init().id
    finally:
        # An offline logger is expected to export WANDB_MODE once it creates
        # its run; drop it (if present) so the setting cannot leak into tests
        # that run after this one, even when an assertion above fails.
        os.environ.pop('WANDB_MODE', None)
@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_wandb_pickle(wandb, tmpdir):
    """Verify that pickling a trainer with a wandb logger works.

    Wandb doesn't work well with pytest so we have to mock it out here.

    Args:
        wandb: mock injected by ``mock.patch`` in place of the ``wandb``
            module used by the logger.
        tmpdir: pytest temporary directory, used as the trainer's root dir.
    """

    class Experiment:
        """Minimal stand-in for the run object returned by ``wandb.init``."""

        id = 'the_id'

        def project_name(self):
            return 'the_project_name'

    wandb.init.return_value = Experiment()

    try:
        logger = WandbLogger(id='the_id', offline=True)
        trainer = Trainer(
            default_root_dir=tmpdir,
            max_epochs=1,
            logger=logger,
        )
        # Access the experiment to ensure it's created
        assert trainer.logger.experiment, 'missing experiment'

        # Round-trip the whole trainer (logger included) through pickle.
        pkl_bytes = pickle.dumps(trainer)
        trainer2 = pickle.loads(pkl_bytes)

        # The offline logger must have switched wandb to dry-run mode.
        assert os.environ['WANDB_MODE'] == 'dryrun'
        assert trainer2.logger.__class__.__name__ == WandbLogger.__name__
        assert trainer2.logger.experiment, 'missing experiment'

        # The run must have been initialised with the id given to the logger.
        wandb.init.assert_called()
        assert 'id' in wandb.init.call_args[1]
        assert wandb.init.call_args[1]['id'] == 'the_id'
    finally:
        # Clean up in ``finally`` so a failing assertion cannot leave
        # WANDB_MODE behind for later tests; ``pop`` with a default also
        # avoids a secondary KeyError when the variable was never set.
        os.environ.pop('WANDB_MODE', None)