
give "validation sanity check" flag for "validation_epoch_end" & "validation_step" #1391

Closed
davinnovation opened this issue Apr 6, 2020 · 12 comments
Labels
feature Is an improvement or enhancement help wanted Open to be worked on won't fix This will not be worked on

Comments

@davinnovation
Contributor

davinnovation commented Apr 6, 2020

🚀 Feature

Motivation

When using a custom saver or logger inside the validation functions (validation_epoch_end, validation_step), Trainer.fit() always runs the validation sanity check first, so the saver/logger also fires during that run and produces messy logs.

Pitch

def validation_step(self, batch, batch_nb, sanity_check):
    if sanity_check:
        ...
def validation_epoch_end(self, outputs, sanity_check):
    if sanity_check:
        ...

or

def validation_step(self, batch, batch_nb):
    if self.sanity_check:
        ...
def validation_epoch_end(self, outputs):
    if self.sanity_check:
        ...

Alternatives

None

Additional context

None

@davinnovation davinnovation added feature Is an improvement or enhancement help wanted Open to be worked on labels Apr 6, 2020
@awaelchli
Member

This could be addressed with the Trainer states: #1633

@stale

stale bot commented Jul 3, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the won't fix This will not be worked on label Jul 3, 2020
@stale stale bot closed this as completed Jul 12, 2020
@ZhaofengWu
Contributor

ZhaofengWu commented Dec 23, 2020

@awaelchli I'm looking at #1633 and the merged PR #2541. What's the training state that corresponds to this usage? Is it TrainerState.INITIALIZING? Though it looks like fit starts with TrainerState.RUNNING, while the sanity check happens within fit. Does this mean there's still no way to do this right now?

@ZhaofengWu
Contributor

Nvm, found this flag trainer.running_sanity_check.

@noamzilo
Contributor

Is there a way to disable .log and .logger for the sanity step using the framework in an elegant way?

@awaelchli
Member

@noamzilo how about we provide logger.enable() and logger.disable() methods?
Then the user could call these for example in the LightningModule, to temporarily disable logging.
For example:

def on_sanity_check_start(self):
    self.logger.disable()

def on_sanity_check_end(self):
    self.logger.enable()
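The `logger.enable()`/`logger.disable()` methods above are only a proposal, not an existing Lightning API. As a rough illustration of the idea, here is a minimal sketch of a hypothetical wrapper (`ToggleableLogger` and `RecordingLogger` are invented names) that forwards `log_metrics(metrics, step)` calls to a wrapped logger only while enabled. It assumes the wrapped object exposes a `log_metrics(metrics, step)` method, as Lightning's logger classes do.

```python
class ToggleableLogger:
    """Hypothetical wrapper (not a Lightning class): forwards log_metrics() only while enabled."""

    def __init__(self, inner):
        self._inner = inner      # any object with log_metrics(metrics, step)
        self._enabled = True

    def disable(self):
        self._enabled = False

    def enable(self):
        self._enabled = True

    def log_metrics(self, metrics, step=None):
        # Silently drop metrics while disabled (e.g. during the sanity check).
        if self._enabled:
            self._inner.log_metrics(metrics, step)


class RecordingLogger:
    """Stand-in logger for this example only; records every call it receives."""

    def __init__(self):
        self.calls = []

    def log_metrics(self, metrics, step=None):
        self.calls.append((metrics, step))


inner = RecordingLogger()
logger = ToggleableLogger(inner)

logger.disable()                               # e.g. when the sanity check starts
logger.log_metrics({"val_loss": 0.9}, step=0)  # dropped
logger.enable()                                # e.g. when the sanity check ends
logger.log_metrics({"val_loss": 0.5}, step=1)  # forwarded

print(inner.calls)  # [({'val_loss': 0.5}, 1)]
```

This only shows the toggle itself; plugging it into a real Trainer would presumably require a proper Logger subclass rather than a bare wrapper.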

@noamzilo
Contributor

noamzilo commented Jan 18, 2021


sounds great :)
what about .log?


While I'm at it, I would like to raise a concern I ran into; I'm not sure whether I am doing something wrong:

[Two screenshots of the TensorBoard "losses" plot (train_loss and val_loss); images not included]

These are plotted using

tb.add_scalars("losses", {"val_loss": loss}, global_step=self.current_epoch)
tb.add_scalars("losses", {"train_loss": loss}, global_step=self.current_epoch)

(tb being self.logger.experiment)

with

trainer = Trainer(
    overfit_batches=True,
    logger=TensorBoardLogger(...),
)

and a single sample per batch.

With this configuration, the training loss should always equal the validation loss.

However, as the graphs show (and I verified this with the debugger), the training loss lags the val loss by one epoch, which makes it look as if the train loss is LARGER than the val loss; that is what caught my attention.
They are exactly equal, but with a lag of 1.

Using tb.add_scalars("losses", {"train_loss": loss}, global_step=self.current_epoch - 1) fixes this, but I doubt that was the original designers' intention.

Am I doing something wrong? Did I find a bug?

Initially, I thought this was due to the sanity epoch, but it seems this isn't the case.

This happens with every dataset I have tried.
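One possible explanation for the one-epoch offset, assuming validation runs after that epoch's optimizer step(s) while the logged training loss is the one computed in the training step before the update: the train loss at epoch e is measured on the weights before the step, and the val loss at epoch e on the weights after it. The plain-Python simulation below (no Lightning; a hypothetical one-parameter model fitted to a single sample by gradient descent) reproduces that pattern. It is a sketch of the mechanism, not a diagnosis of the exact setup above.

```python
# Toy model: prediction = w, single target, loss = (w - TARGET) ** 2.
TARGET = 3.0
LR = 0.1


def loss(w):
    return (w - TARGET) ** 2


def grad(w):
    return 2.0 * (w - TARGET)


w = 0.0
train_losses, val_losses = [], []
for epoch in range(5):
    # "Training step": loss is computed on the current weights, before the update.
    train_losses.append(loss(w))
    w -= LR * grad(w)  # optimizer step
    # "Validation": runs after the step, on the same single sample.
    val_losses.append(loss(w))

# The val loss at epoch e is the train loss at epoch e + 1 (same weights, same sample).
for e in range(len(val_losses) - 1):
    assert val_losses[e] == train_losses[e + 1]

print([round(x, 4) for x in train_losses])
print([round(x, 4) for x in val_losses])
```

If this is what happens in a real run, both curves are correct; they are just measured at different points within each epoch, so the val curve appears to lead by one epoch while the loss is decreasing.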

@AtomScott

AtomScott commented Feb 8, 2021

@noamzilo @awaelchli has this been implemented?

def on_sanity_check_start(self):
    self.logger.disable()

def on_sanity_check_end(self):
    self.logger.enable()

If not, shouldn't we reopen this issue? This is a feature I would use daily, and something I had to write my own workaround for, so I don't mind getting my hands dirty.

@jmerkow

jmerkow commented Sep 14, 2021

@noamzilo @awaelchli Bumping this issue....

@ZhaofengWu how do you access this flag from a PL module?

@ananthsub
Contributor

ananthsub commented Sep 14, 2021

You can use `if self.trainer.sanity_checking:` inside the LightningModule.
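A minimal sketch of that pattern, written without importing Lightning so it runs standalone. `MyModel` and `saved_outputs` are hypothetical names, and `SimpleNamespace` objects stand in for the real `Trainer`; in an actual project the class would subclass `LightningModule`, and Lightning would attach `self.trainer` itself instead of it being passed in.

```python
from types import SimpleNamespace


class MyModel:
    """Stand-in for a LightningModule; here the trainer is attached manually."""

    def __init__(self, trainer):
        self.trainer = trainer   # Lightning would set this for a real LightningModule
        self.saved_outputs = []  # hypothetical target of a custom saver

    def on_validation_epoch_end(self):
        # Skip custom saving/logging during the sanity-check validation run.
        if self.trainer.sanity_checking:
            return
        self.saved_outputs.append(self.trainer.current_epoch)


model = MyModel(SimpleNamespace(sanity_checking=True, current_epoch=0))
model.on_validation_epoch_end()  # sanity check: nothing is saved

model.trainer = SimpleNamespace(sanity_checking=False, current_epoch=0)
model.on_validation_epoch_end()  # regular validation: epoch 0 is saved

print(model.saved_outputs)  # [0]
```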

@brendanartley

brendanartley commented Jul 11, 2023

It seems that the trainer.sanity_checking variable is not accessible in a callback?

import pytorch_lightning as pl

class CustomCallback(pl.Callback):
    def __init__(self):
        super().__init__()

    def on_validation_epoch_end(self, trainer, module):
        if not trainer.sanity_checking:
            return
        # do something here (only reached during the sanity check)
        ...

This throws the following error.

AttributeError: 'Trainer' object has no attribute 'running_sanity_check'

@klieret
Contributor

klieret commented Jul 24, 2023

Are you sure you didn't accidentally write trainer.running_sanity_check instead of trainer.sanity_checking? I can access the variable exactly as in your code snippet.
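The thread mentions two attribute names: `trainer.running_sanity_check` (in an earlier comment) and `trainer.sanity_checking`, and the traceback above suggests the older name no longer exists on that Trainer. For code that has to run against both older and newer Lightning versions, a small helper can check whichever attribute is present. This is a hypothetical helper, not a Lightning API; `SimpleNamespace` stands in for the Trainer, and `SanityAwareCallback` would subclass `pl.Callback` in real use.

```python
from types import SimpleNamespace


def is_sanity_checking(trainer):
    """Hypothetical helper: True while the trainer runs its sanity check.

    Prefers the newer `sanity_checking` attribute and falls back to the
    older `running_sanity_check` name mentioned earlier in this thread.
    """
    if hasattr(trainer, "sanity_checking"):
        return bool(trainer.sanity_checking)
    return bool(getattr(trainer, "running_sanity_check", False))


class SanityAwareCallback:  # in real use: subclass pl.Callback
    def __init__(self):
        self.epochs_seen = []

    def on_validation_epoch_end(self, trainer, pl_module):
        if is_sanity_checking(trainer):
            return  # ignore the sanity-check pass
        self.epochs_seen.append(trainer.current_epoch)


cb = SanityAwareCallback()
# Newer-style trainer during the sanity check: skipped.
cb.on_validation_epoch_end(SimpleNamespace(sanity_checking=True, current_epoch=0), None)
# Older-style trainer, regular validation: recorded.
cb.on_validation_epoch_end(SimpleNamespace(running_sanity_check=False, current_epoch=0), None)
print(cb.epochs_seen)  # [0]
```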


9 participants