Skip to content

Commit

Permalink
Upgrade wandb version (#340)
Browse files Browse the repository at this point in the history
* Upgrade wandb version

* mock wandb for tests

* mock wandb for tests

* patch _package_available

* change location of patch

* change patch location

* update patch
  • Loading branch information
ashwinvaidya17 authored and samet-akcay committed May 31, 2022
1 parent 12ea388 commit a84319b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 28 deletions.
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ pytorch-lightning>=1.6.0
torchmetrics>=0.8.0
torchvision>=0.9.1
torchtext>=0.9.1
wandb==0.12.9
wandb==0.12.17
matplotlib>=3.4.3
gradio>=2.9.4
61 changes: 34 additions & 27 deletions tests/pre_merge/utils/loggers/test_get_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
# See the License for the specific language governing permissions
# and limitations under the License.

from unittest.mock import patch

patch("pytorch_lightning.utilities.imports._package_available", False)
patch("pytorch_lightning.loggers.wandb.WandbLogger")

import pytest
from omegaconf import OmegaConf
from pytorch_lightning.loggers import CSVLogger
Expand All @@ -37,36 +42,38 @@ def test_get_experiment_logger():
}
)

# get no logger
logger = get_experiment_logger(config=config)
assert isinstance(logger, bool)
config.project.logger = False
logger = get_experiment_logger(config=config)
assert isinstance(logger, bool)
with patch("pytorch_lightning.loggers.wandb.wandb"):

# get tensorboard
config.project.logger = "tensorboard"
logger = get_experiment_logger(config=config)
assert isinstance(logger[0], AnomalibTensorBoardLogger)
# get no logger
logger = get_experiment_logger(config=config)
assert isinstance(logger, bool)
config.project.logger = False
logger = get_experiment_logger(config=config)
assert isinstance(logger, bool)

# get wandb logger
config.project.logger = "wandb"
logger = get_experiment_logger(config=config)
assert isinstance(logger[0], AnomalibWandbLogger)
# get tensorboard
config.project.logger = "tensorboard"
logger = get_experiment_logger(config=config)
assert isinstance(logger[0], AnomalibTensorBoardLogger)

# get csv logger.
config.project.logger = "csv"
logger = get_experiment_logger(config=config)
assert isinstance(logger[0], CSVLogger)
# get wandb logger
config.project.logger = "wandb"
logger = get_experiment_logger(config=config)
assert isinstance(logger[0], AnomalibWandbLogger)

# get multiple loggers
config.project.logger = ["tensorboard", "wandb", "csv"]
logger = get_experiment_logger(config=config)
assert isinstance(logger[0], AnomalibTensorBoardLogger)
assert isinstance(logger[1], AnomalibWandbLogger)
assert isinstance(logger[2], CSVLogger)
# get csv logger.
config.project.logger = "csv"
logger = get_experiment_logger(config=config)
assert isinstance(logger[0], CSVLogger)

# raise unknown
with pytest.raises(UnknownLogger):
config.project.logger = "randomlogger"
# get multiple loggers
config.project.logger = ["tensorboard", "wandb", "csv"]
logger = get_experiment_logger(config=config)
assert isinstance(logger[0], AnomalibTensorBoardLogger)
assert isinstance(logger[1], AnomalibWandbLogger)
assert isinstance(logger[2], CSVLogger)

# raise unknown
with pytest.raises(UnknownLogger):
config.project.logger = "randomlogger"
logger = get_experiment_logger(config=config)

0 comments on commit a84319b

Please sign in to comment.