Trainer.from_argparse_args with additional kwargs causes model to not be saved #1714

Closed
crvineeth97 opened this issue May 3, 2020 · 7 comments
Labels: bug (Something isn't working), help wanted (Open to be worked on), priority: 0 (High priority task)

Comments

@crvineeth97

crvineeth97 commented May 3, 2020

🐛 Bug

When using Trainer.from_argparse_args() to initialize the trainer, there are often specific arguments that we would like to keep constant and not send as part of hparams. If one of these extra arguments is an object, such as a TensorBoardLogger or a ModelCheckpoint, the model will not be saved, because the object gets added to hparams and cannot be pickled.

To Reproduce

Code sample

import os
from argparse import ArgumentParser

import pytorch_lightning as pl
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST


class LitModel(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {"train_loss": loss}
        return {"loss": loss, "log": tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def train_dataloader(self):
        dataset = MNIST(
            os.getcwd(), train=True, download=True, transform=transforms.ToTensor()
        )
        loader = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=True)
        return loader


def main(hparams):
    logger = TensorBoardLogger(save_dir=os.getenv("HOME"), name="logs")
    net = LitModel(hparams)
    trainer = Trainer.from_argparse_args(
        hparams, logger=logger, checkpoint_callback=True, overfit_pct=0.01,
    )
    trainer.fit(net)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--sample", type=int, default=42, help="Sample Argument")
    hparams = parser.parse_args()
    main(hparams)

Error

TypeError: cannot pickle '_thread.lock' object
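
The same error can be reproduced without Lightning. The following is a minimal sketch, assuming only that the extra object holds a thread lock somewhere in its state. Here `queue.Queue` is used as a stand-in for such an object (it is not what the logger actually is), and `try_pickle` is a hypothetical helper written for this illustration.

```python
import pickle
import queue
from argparse import Namespace


def try_pickle(obj):
    """Return None if obj pickles cleanly, otherwise the error message."""
    try:
        pickle.dumps(obj)
        return None
    except TypeError as err:
        return str(err)


# Plain hyperparameters pickle without trouble.
plain = Namespace(sample=42)

# Stand-in for hparams after a logger-like object has been added to it:
# queue.Queue owns a threading.Lock, which pickle refuses to serialize.
polluted = Namespace(sample=42, logger=queue.Queue())

print(try_pickle(plain))     # None
print(try_pickle(polluted))  # the TypeError message raised by pickle
```

Checkpointing serializes hparams together with the model state, so a single unpicklable attribute is enough to make the whole save fail.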

Expected behavior

The model should be saved without the extra keyword arguments being treated as part of hparams.

Environment

  • CUDA:
    • GPU:
      • GeForce GTX 1050 Ti
    • available: True
    • version: 10.1
  • Packages:
    • numpy: 1.18.1
    • pyTorch_debug: False
    • pyTorch_version: 1.4.0
    • pytorch-lightning: 0.7.5
    • tensorboard: 2.2.1
    • tqdm: 4.45.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.8.2
    • version: 32-Ubuntu SMP Wed Apr 22 17:40:10 UTC 2020

Additional context

Possible Fix

I believe that a possible fix is to change the from_argparse_args method
from

    @classmethod
    def from_argparse_args(cls, args, **kwargs):

        params = vars(args)
        params.update(**kwargs)

        return cls(**params)

to

    @classmethod
    def from_argparse_args(cls, args, **kwargs):

        params = vars(args)

        return cls(**params, **kwargs)

This ensures that the **kwargs are not added to the hparams of the model, so the model gets saved successfully. I'm not sure what other impact this change might have, though.
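
The reason the first version leaks into hparams is that `vars(args)` returns the Namespace's own `__dict__`, not a copy, so `params.update(...)` writes straight into the caller's object. Below is a minimal sketch of both variants. `DemoTrainer` and its two method names are hypothetical stand-ins that only record their keyword arguments; they are not the real Trainer.

```python
from argparse import Namespace


class DemoTrainer:
    """Hypothetical stand-in for Trainer; it only records its keyword arguments."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @classmethod
    def from_args_mutating(cls, args, **kwargs):
        params = vars(args)      # the Namespace's own __dict__, not a copy
        params.update(**kwargs)  # so this also writes into args
        return cls(**params)

    @classmethod
    def from_args_safe(cls, args, **kwargs):
        params = vars(args)
        # Unpacking builds a fresh dict for the call; args is left untouched.
        return cls(**params, **kwargs)


hparams = Namespace(sample=42)

DemoTrainer.from_args_safe(hparams, logger="a logger object")
print(hasattr(hparams, "logger"))  # False

DemoTrainer.from_args_mutating(hparams, logger="a logger object")
print(hasattr(hparams, "logger"))  # True
```

One behavioral difference to keep in mind: with `cls(**params, **kwargs)`, a key that appears in both `args` and `kwargs` raises a TypeError for the duplicate keyword, whereas the `update` version lets the keyword argument silently override. Merging into a new dict first, as in `cls(**{**params, **kwargs})`, would keep the override behavior without mutating `args`.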

crvineeth97 added the bug and help wanted labels May 3, 2020
@github-actions
Contributor

github-actions bot commented May 3, 2020

Hi! Thanks for your contribution! Great first issue!

@elias-ramzi
Contributor

I had the same issue. I fixed it using the following "import copy; self.hparams=copy.deepcopy(hparams)". This is a quick fix, but maybe it can help you.
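
A note on why this workaround is order-dependent: the copy only helps if it is taken before the unpicklable object is attached to hparams. The sketch below illustrates this under the assumption that the extra object holds a thread lock; `queue.Queue` is used as a stand-in for it.

```python
import copy
import queue
from argparse import Namespace

hparams = Namespace(sample=42)

# Snapshot taken while hparams still holds only plain values,
# e.g. in the model's __init__ when the model is built before the trainer.
snapshot = copy.deepcopy(hparams)

# Simulate what happens later: an object that owns a lock gets attached.
hparams.logger = queue.Queue()  # stand-in for a logger or callback

# The earlier snapshot is unaffected and stays picklable.
print(hasattr(snapshot, "logger"))  # False

# Copying after the fact fails, because deepcopy cannot copy the lock.
try:
    copy.deepcopy(hparams)
except TypeError as err:
    print("deepcopy failed:", err)
```

In the reproduction script, the model is constructed before the trainer, which is why copying inside `__init__` works there.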

@Uroc327

Uroc327 commented May 12, 2020

I usually have some lines like

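from argparse import Namespace

mparams = Namespace(**vars(params))  # shallow copy of the parsed arguments
del mparams.logger  # drop the logger before handing mparams to the model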

for this after I've created the trainer object with params.
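
The same idea can be generalized so that attribute names do not have to be listed by hand. This is only a sketch: `picklable_copy` is a hypothetical helper, not part of Lightning, and `queue.Queue` again stands in for an object that holds a lock.

```python
import pickle
import queue
from argparse import Namespace


def picklable_copy(args):
    """Return a new Namespace holding only the attributes that pickle cleanly."""
    kept = {}
    for name, value in vars(args).items():
        try:
            pickle.dumps(value)
        except Exception:
            continue  # skip loggers, callbacks and anything else that cannot be pickled
        kept[name] = value
    return Namespace(**kept)


params = Namespace(sample=42, lr=0.001, logger=queue.Queue())
mparams = picklable_copy(params)

print(sorted(vars(mparams)))  # ['lr', 'sample']
```

The original `params` is left untouched, so the trainer can still be built from it, while `mparams` is what gets passed to the model.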

@xiadingZ

mparams = Namespace(**vars(params))
del mparams.logger

I tried your solution, but with it the trainer can't save the model. Here is my code:

def main(hparams):

    checkpoint_callback = ModelCheckpoint(
        filepath=os.getcwd(),
        monitor='val_loss',
        mode='min',
    )

    if hparams.ddp == 1:
        hparams.distributed_backend = 'ddp'
    if hparams.early_stop == 1:
        hparams.early_stop_callback=True
    hparams.img_size = 256

    trainer = pl.Trainer.from_argparse_args(hparams, checkpoint_callback=checkpoint_callback)
    mparams = Namespace(**vars(hparams))
    del mparams.checkpoint_callback
    hparams = mparams
    if hparams.model == 'base':
        model = PVM_Baseline(hparams)
    elif hparams.model == 'gcn':
        model = PVM_GCN(hparams)
    # Run learning rate finder

    if hparams.lr == 0:
        # find best lr
        hparams.lr = 0.001
        hparams.auto_lr_find='lr'
    if hparams.test == 0:
        trainer.fit(model)
    else:
        checkpoint = osp.join(hparams.checkpoint, 'best_model.ckpt')
        tags_csv = osp.join(hparams.checkpoint, 'meta_tags.csv')
        model = model.load_from_metrics(
            weights_path=checkpoint,
            tags_csv=tags_csv,
            on_gpu=True,
            map_location=None
        )
        trainer.test(model)

I have also tried self.hparams = copy.deepcopy(hparams) in dataset.py, but it throws an error:

  File "/Users/xd/code/PVM/models/baseline.py", line 207, in val_dataloader
    dataset = PVMDataset(self.hparams, 'val', self.word_dict, transform)
  File "/Users/xd/code/PVM/dataset/pvm_dataset.py", line 21, in __init__
    self.hparams = deepcopy(hparams)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 248, in _deepcopy_method
    return type(x)(x.__func__, deepcopy(x.__self__, memo))
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 169, in deepcopy
    rv = reductor(4)
TypeError: can't pickle _thread.lock objects

Is there any solution to this problem?

@adriantre
Contributor

I had the same issue. I fixed it using the following "import copy; self.hparams=copy.deepcopy(hparams)". This is a quick fix, but maybe it can help you.

This solved an issue I had that occurred only when returning val_loss from validation_epoch_end. The error was the same, although with a different traceback:
TypeError: can't pickle _thread.lock objects


Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/opt/conda/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/lightning_run.py", line 116, in <module>
    model, train_dataloader=train_dataloader, val_dataloaders=val_dataloader
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 793, in fit
    self.run_pretrain_routine(model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 913, in run_pretrain_routine
    self.train()
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 347, in train
    self.run_training_epoch()
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 452, in run_training_epoch
    self.call_checkpoint_callback()
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 789, in call_checkpoint_callback
    self.checkpoint_callback.on_validation_end(self, self.get_model())
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py", line 10, in wrapped_fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 231, in on_validation_end
    self._do_check_save(filepath, current, epoch)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 265, in _do_check_save
    self._save_model(filepath)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 142, in _save_model
    self.save_function(filepath)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_io.py", line 253, in save_checkpoint
    self._atomic_save(checkpoint, filepath)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_io.py", line 244, in _atomic_save
    torch.save(checkpoint, tmp_path)
  File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 328, in save
    _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
  File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 401, in _legacy_save
    pickler.dump(obj)
TypeError: can't pickle _thread.lock objects

@crvineeth97
Author

crvineeth97 commented May 14, 2020

Changing just this small part of the source code, as in the possible fix from the issue description above, worked for me.

Another workaround would be to construct the trainer directly, e.g. Trainer(gpus=hparams.gpus, ...), which is not recommended but will definitely work.

@acxz

acxz commented Jun 10, 2020

Just confirmed that #2029 fixes the issue for me.

Borda closed this as completed Jun 10, 2020