
Log validation metrics before training #3625

Closed
bartonp2 opened this issue Sep 23, 2020 · 8 comments
Labels
question (Further information is requested), won't fix (This will not be worked on)

Comments

@bartonp2

❓ Questions and Help

Is there an easy way to run a full evaluation on the validation set before starting training? I would like this as a baseline, to see where I'm starting from and whether the network learns anything at all.

While #1715 allows running the sanity check on the complete validation set, this does not log any metrics.
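
(For context, and not a quote from #1715: the full sanity check mentioned there is presumably the num_sanity_val_steps Trainer flag, which can run validation over every batch before training starts but does not send anything to the loggers.)

from pytorch_lightning import Trainer

# -1 runs the pre-training sanity check over all validation batches;
# sanity-check results are not logged.
trainer = Trainer(num_sanity_val_steps=-1)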

I tried the code as recommended:

class run_validation_on_start(Callback):
    def __init__(self):
        pass

    def on_train_start(self, trainer: Trainer, pl_module):
        return trainer.run_evaluation(test_mode=False)

Originally posted by @dvirginz in #1715 (comment)

but this gives me the following error:

Traceback (most recent call last):
  File "/scratch/bartonp/miniconda/envs/eco/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/scratch/bartonp/miniconda/envs/eco/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/scratch/bartonp/avamap/trainer/train.py", line 95, in <module>
    main(hparams)
  File "/scratch/bartonp/avamap/trainer/train.py", line 52, in main
    trainer.fit(model, train_loader, val_loader)
  File "/scratch/bartonp/miniconda/envs/eco/lib/python3.7/site-packages/pytorch_lightning/trainer/states.py", line 48, in wrapped_fn
    result = fn(self, *args, **kwargs)
  File "/scratch/bartonp/miniconda/envs/eco/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1073, in fit
    results = self.accelerator_backend.train(model)
  File "/scratch/bartonp/miniconda/envs/eco/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_backend.py", line 51, in train
    results = self.trainer.run_pretrain_routine(model)
  File "/scratch/bartonp/miniconda/envs/eco/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1239, in run_pretrain_routine
    self.train()
  File "/scratch/bartonp/miniconda/envs/eco/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 363, in train
    self.on_train_start()
  File "/scratch/bartonp/miniconda/envs/eco/lib/python3.7/site-packages/pytorch_lightning/trainer/callback_hook.py", line 111, in on_train_start
    callback.on_train_start(self, self.get_model())
  File "/scratch/bartonp/avamap/trainer/train.py", line 17, in on_train_start
    return trainer.run_evaluation(test_mode=False)
  File "/scratch/bartonp/miniconda/envs/eco/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 603, in run_evaluation
    self.on_validation_end()
  File "/scratch/bartonp/miniconda/envs/eco/lib/python3.7/site-packages/pytorch_lightning/trainer/callback_hook.py", line 176, in on_validation_end
    callback.on_validation_end(self, self.get_model())
  File "/scratch/bartonp/miniconda/envs/eco/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py", line 27, in wrapped_fn
    return fn(*args, **kwargs)
  File "/scratch/bartonp/miniconda/envs/eco/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 357, in on_validation_end
    filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
  File "/scratch/bartonp/miniconda/envs/eco/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 253, in format_checkpoint_name
    groups = re.findall(r'(\{.*?)[:\}]', self.filename)
  File "/scratch/bartonp/miniconda/envs/eco/lib/python3.7/re.py", line 223, in findall
    return _compile(pattern, flags).findall(string)
TypeError: expected string or bytes-like object

Is there no simple way to run and log the validation set before training?

bartonp2 added the question (Further information is requested) label on Sep 23, 2020
@github-actions

Hi! Thanks for your contribution, great first issue!

@oyj0594 commented Sep 23, 2020

This happens because model checkpointing runs as part of run_evaluation.
Either use ModelCheckpoint explicitly alongside it, or bypass the issue by setting the checkpoint filename inside your function.
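
For illustration, a minimal sketch of that second workaround, assuming the 0.9/1.0-era API shown in the traceback above (trainer.run_evaluation(test_mode=False) and a filename attribute on the checkpoint callback); treat it as a starting point rather than a tested fix:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback

class RunValidationOnStart(Callback):
    def on_train_start(self, trainer: Trainer, pl_module):
        # format_checkpoint_name() fails when the checkpoint callback's
        # filename is still None at this point, so set a concrete template first.
        ckpt = trainer.checkpoint_callback
        if ckpt is not None and getattr(ckpt, "filename", None) is None:
            ckpt.filename = "{epoch}"
        return trainer.run_evaluation(test_mode=False)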

@S-aiueo32

I faced this issue too. It seems like important functionality for transfer learning or fine-tuning.
I roughly think it could be handled by a simple Trainer flag like:

trainer = Trainer(
    ...,
    eval_before_training=True
)

What do you think, @williamFalcon and @othercontributors?

@hantoine commented Oct 7, 2020

I agree that it would be nice to have a flag like eval_before_training to do this easily.
In the meantime, I managed to make it work by calling the checkpoint callback's on_train_start method first:

class ValidationOnStartCallback(pl.callbacks.Callback):
    def on_train_start(self, trainer, pl_module):
        trainer.checkpoint_callback.on_train_start(trainer, pl_module)
        return trainer.run_evaluation(test_mode=False)

@stale bot commented Nov 6, 2020

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

stale bot added the won't fix (This will not be worked on) label on Nov 6, 2020
stale bot closed this as completed on Nov 13, 2020
@michelbotros

I agree with @S-aiueo32 that this would be a neat solution for logging metrics before training (very useful in transfer learning). Or has another neat way of doing this appeared since then?

@orm011 commented Aug 26, 2021

At least as of 1.4.4 (but maybe earlier):

trainer = pl.Trainer(...., logger=[tb_logger, csv_logger], callbacks=[chk])
trainer.validate(model, dataloaders=[val_loader])
trainer.fit(model, train_loader, val_loader)

This does what you want, assuming the model weights are initialized to what you need.
The metrics get printed and also logged.
Not sure why this use case is so hard to find answers for; it seems very helpful for fine-tuning.
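
For illustration, a fuller sketch of the same validate-then-fit pattern (assuming PL 1.4+; the logger directories, monitored metric, and the model/dataloader variables are placeholders, not taken from the snippet above):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

tb_logger = TensorBoardLogger("logs/")
csv_logger = CSVLogger("logs/")
chk = ModelCheckpoint(monitor="val_loss", mode="min")

trainer = pl.Trainer(max_epochs=10, logger=[tb_logger, csv_logger], callbacks=[chk])

# Log validation metrics once with the initial (e.g. pretrained) weights ...
trainer.validate(model, dataloaders=[val_loader])
# ... then train as usual.
trainer.fit(model, train_loader, val_loader)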

@Toekan commented Oct 28, 2023

Hey,

Thanks for all the helpful comments here; unfortunately they don't work for me. I'm using LightningCLI, so I can't run:

trainer = pl.Trainer(...., logger=[tb_logger, csv_logger], callbacks=[chk])
trainer.validate(model, dataloaders=[val_loader])
trainer.fit(model, train_loader, val_loader)

myself as suggested above. I've been trying to get a callback to work as suggested by @hantoine, but trainer.run_evaluation no longer exists (I'm using PL 2.0.9), so I tried to update it:

class ValidationOnStartCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        trainer.checkpoint_callback.on_train_start(trainer, pl_module)
        trainer.validate(pl_module, trainer.datamodule)

but I'm running into various problems. With the example given above I get:

  File ".../pytorch_lightning/trainer/connectors/data_connector.py", line 484, in _process_dataloader
    raise RuntimeError("Unexpected state")
RuntimeError: Unexpected state

because running trainer.validate changed the state of my trainer. When I hack the state back in my callback:

class ValidationOnStartCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        trainer.checkpoint_callback.on_train_start(trainer, pl_module)
        trainer.validate(pl_module, trainer.datamodule)
        trainer.state = TrainerState(
            status=TrainerStatus.RUNNING,
            fn=TrainerFn.FITTING,
            stage=RunningStage.TRAINING
        )

I'm getting the following error:

File .../pytorch_lightning/strategies/ddp.py:332, in DDPStrategy.training_step(self, *args, **kwargs)
--> 332 return self.model(*args, **kwargs)

...

TypeError: LitImageClassification.forward() takes 2 positional arguments but 3 were given

This seems to happen because self.model on line 332 is a LightningModule, whereas when running without the callback it is a torch.nn.parallel.distributed.DistributedDataParallel, which handles the args and kwargs in ddp.py fine.

Did anyone get the callback approach to work? I agree with the general sentiment here that this is very useful functionality when fine-tuning models (a lot of the gains are often made in the first epoch of training, so having a before-and-after of validation performance matters), and I'm surprised how hard this turns out to be.
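
For what it's worth, an untested sketch of keeping LightningCLI while reproducing the validate-then-fit pattern: instantiate the CLI with run=False so it only builds the trainer, model, and datamodule, then drive the calls manually. This assumes PL 2.x where LightningCLI accepts run=False; the module and datamodule class paths below are placeholders.

from pytorch_lightning.cli import LightningCLI

# Placeholders: swap in your own LightningModule / LightningDataModule.
from my_project.models import LitImageClassification
from my_project.data import MyDataModule

def main():
    # run=False builds trainer/model/datamodule from the config without calling fit.
    cli = LightningCLI(LitImageClassification, MyDataModule, run=False)
    # Log the pre-training baseline, then train as usual.
    cli.trainer.validate(cli.model, datamodule=cli.datamodule)
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)

if __name__ == "__main__":
    main()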
