diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 123b7544c9768..e728d554ecb3c 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -33,7 +33,7 @@ class WandbLogger(LightningLoggerBase): """ def __init__(self, name=None, save_dir=None, offline=False, id=None, anonymous=False, - version=None, project=None, tags=None, experiment=None): + version=None, project=None, tags=None, experiment=None, entity=None): super().__init__() self._name = name self._save_dir = save_dir @@ -43,6 +43,7 @@ def __init__(self, name=None, save_dir=None, offline=False, id=None, anonymous=F self._project = project self._experiment = experiment self._offline = offline + self._entity = entity def __getstate__(self): state = self.__dict__.copy() @@ -68,7 +69,7 @@ def experiment(self): os.environ["WANDB_MODE"] = "dryrun" self._experiment = wandb.init( name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous, - id=self._id, resume="allow", tags=self._tags) + id=self._id, resume="allow", tags=self._tags, entity=self._entity) return self._experiment def watch(self, model, log="gradients", log_freq=100):