
checkpoint save dir is not correctly set when _save_dir is given by wandb logger #2527

Closed
pableeto opened this issue Jul 6, 2020 · 6 comments · Fixed by #2681
Labels
bug (Something isn't working) · help wanted (Open to be worked on) · logger (Related to the Loggers)

Comments

@pableeto

pableeto commented Jul 6, 2020

🐛 Bug

When using ModelCheckpoint with its default parameters and a Wandb logger with save_dir set to some directory,
the checkpoint is still dumped to os.getcwd().

To Reproduce

........
logger = WandbLogger(save_dir='/path/to/experiment')
trainer = Trainer.from_argparse_args(other_args, logger=logger)

Expected behavior

The checkpoint should be saved under /path/to/experiment, the directory given by the Wandb logger's save_dir argument.

Additional context

The PL version I am using is 0.8.4, installed via pip.

I think the problem is related to the logic in the on_train_start() function in model_checkpoint.py:

        if trainer.logger is not None:
            # weights_save_path overrides anything
            if getattr(trainer, 'weights_save_path', None) is not None:
                save_dir = trainer.weights_save_path
            else:
                save_dir = (getattr(trainer.logger, 'save_dir', None)
                            or getattr(trainer.logger, '_save_dir', None)
                            or trainer.default_root_dir)

Unfortunately, the default of weights_save_path is not None; it is set to default_root_dir, which is os.getcwd() (see pytorch_lightning/trainer/callback_config.py, line 57):

        # if weights_save_path is still none here, set to current working dir
        if self.weights_save_path is None:
            self.weights_save_path = self.default_root_dir

Thus, ckpt_path is always set to weights_save_path instead of the logger's save_dir.
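To make the issue concrete, here is a self-contained sketch with stub objects (StubLogger / StubTrainer are hypothetical stand-ins, not the real Lightning classes) showing that the first branch of the quoted logic is always taken:

```python
import os

# Hypothetical stand-ins for the real Trainer / WandbLogger, just to
# exercise the on_train_start() logic quoted above.
class StubLogger:
    def __init__(self, save_dir):
        self.save_dir = save_dir
        self._save_dir = save_dir

class StubTrainer:
    def __init__(self, logger):
        self.logger = logger
        self.default_root_dir = os.getcwd()
        # callback_config.py falls back to default_root_dir, so in
        # practice weights_save_path is never None:
        self.weights_save_path = self.default_root_dir

trainer = StubTrainer(StubLogger('/path/to/experiment'))

# The branch from model_checkpoint.py:
if getattr(trainer, 'weights_save_path', None) is not None:
    save_dir = trainer.weights_save_path  # always taken
else:
    save_dir = (getattr(trainer.logger, 'save_dir', None)
                or getattr(trainer.logger, '_save_dir', None)
                or trainer.default_root_dir)

print(save_dir == os.getcwd())            # True
print(save_dir == '/path/to/experiment')  # False
```

The else branch (and with it the logger's save_dir) is unreachable as long as the Trainer always fills in weights_save_path.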

Fix

A quick patch for this might be as follows:

        if trainer.logger is not None:
            # weights_save_path overrides anything, unless it is still the
            # default (os.getcwd()) and the logger has its save_dir set to
            # some other folder
            weights_save_path = getattr(trainer, 'weights_save_path', None)
            loggers_save_path = (getattr(trainer.logger, 'save_dir', None)
                                 or getattr(trainer.logger, '_save_dir', None)
                                 or trainer.default_root_dir)
            avoid_weights_save_path = (weights_save_path == trainer.default_root_dir
                                       and loggers_save_path != trainer.default_root_dir)

            if weights_save_path is not None and not avoid_weights_save_path:
                save_dir = weights_save_path
            else:
                save_dir = loggers_save_path
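A quick sanity check of this patch with stub objects (again hypothetical, not the real Lightning classes): the logger's directory wins when weights_save_path is merely the default, while an explicitly set weights_save_path still overrides everything.

```python
import os

# Hypothetical stand-ins for the real Trainer / Logger, only to
# exercise the patched branch above.
class StubLogger:
    def __init__(self, save_dir):
        self.save_dir = save_dir

class StubTrainer:
    def __init__(self, logger, default_root_dir):
        self.logger = logger
        self.default_root_dir = default_root_dir
        self.weights_save_path = default_root_dir  # the problematic default

def resolve_save_dir(trainer):
    weights_save_path = getattr(trainer, 'weights_save_path', None)
    loggers_save_path = (getattr(trainer.logger, 'save_dir', None)
                         or getattr(trainer.logger, '_save_dir', None)
                         or trainer.default_root_dir)
    avoid_weights_save_path = (weights_save_path == trainer.default_root_dir
                               and loggers_save_path != trainer.default_root_dir)
    if weights_save_path is not None and not avoid_weights_save_path:
        return weights_save_path
    return loggers_save_path

trainer = StubTrainer(StubLogger('/path/to/experiment'), os.getcwd())
print(resolve_save_dir(trainer))   # '/path/to/experiment'

# An explicitly set weights_save_path still overrides the logger:
trainer.weights_save_path = '/explicit/weights/dir'
print(resolve_save_dir(trainer))   # '/explicit/weights/dir'
```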

I would be happy to fork the code and submit a PR, btw.

@pableeto added the bug and help wanted labels on Jul 6, 2020
@awaelchli
Member

awaelchli commented Jul 6, 2020

@pableeto I don't think we want to make this part of the code even more complicated.
Wouldn't the real fix be to remove this if block here:
https://github.com/PyTorchLightning/pytorch-lightning/blob/25ee51bc570503f331dceecc610d0eb355e22327/pytorch_lightning/trainer/callback_config.py#L57

I think it is a left-over from a recent refactor.

@pableeto
Author

pableeto commented Jul 6, 2020

Well, I guess that would be fine :)
Not sure if removing that block will cause other issues, though.

@pableeto
Author

pableeto commented Jul 6, 2020

> @pableeto I don't think we want to make this part of the code even more complicated.
> Wouldn't the real fix be to remove this if block here:
> https://github.com/PyTorchLightning/pytorch-lightning/blob/25ee51bc570503f331dceecc610d0eb355e22327/pytorch_lightning/trainer/callback_config.py#L57
>
> I think it is a left-over from a recent refactor.

I just found that if we remove this block, the pipeline crashes at line 391 of pytorch_lightning/trainer/training_io.py:

        folderpath = self.weights_save_path              # This will become None; the next line will crash
        if os.path.exists(folderpath):
            files = os.listdir(folderpath)
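The crash is easy to reproduce outside of Lightning: with weights_save_path left as None, os.path.exists(None) raises a TypeError (it only swallows OSError and ValueError internally). A minimal sketch:

```python
import os

# Sketch of why simply removing the fallback breaks training_io.py:
# the folderpath below stands in for self.weights_save_path being None.
folderpath = None
try:
    if os.path.exists(folderpath):   # raises TypeError for None
        files = os.listdir(folderpath)
    crashed = False
except TypeError:
    crashed = True

print(crashed)  # True
```

So the fallback cannot just be deleted; the None case has to be handled wherever weights_save_path is consumed.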

@williamFalcon
Contributor

@pableeto try:

pip install pytorch-lightning==0.8.5rc1

We can re-open if it is not fixed

@pableeto
Author

pableeto commented Jul 15, 2020

> @pableeto try:
>
> pip install pytorch-lightning==0.8.5rc1
>
> We can re-open if it is not fixed

@williamFalcon Just tried 0.8.5rc1; the problem still exists.

@awaelchli
Member

Fixed here #2681

@Borda added the logger label on Aug 4, 2020