
Avoid duplicate logging of hyperparameters with Wandb #13899

Closed
remivanbel opened this issue Jul 28, 2022 · 3 comments · Fixed by #17574
Labels
bug (Something isn't working) · logger: wandb (Weights & Biases) · won't fix (This will not be worked on)

Comments

remivanbel commented Jul 28, 2022

🚀 Feature

Change the flattening format of the hyperparameter dicts for the WandbLogger, such that it is the same as the one used by wandb for logging its configuration.

Motivation

At the moment, I have the problem that certain hyperparameters are logged twice in wandb.
[screenshot: the same hyperparameters appearing twice in the W&B run config]

Because I want to use the hyperparameter sweep capabilities of wandb, I first have to get my config from wandb and then feed it to my LightningModule. However, with this setup the hyperparameters are saved twice: once during initialization of the WandbLogger, and once during initialization of MyTrainer (save_hyperparameters):

import wandb
from pytorch_lightning.loggers import WandbLogger
from parameters import config

wandb_logger = WandbLogger(project='my-project', config=config)
config = wandb.config.as_dict()  # read the (possibly sweep-overridden) config back from wandb
model = MyTrainer(**config)

with:

class MyTrainer(pl.LightningModule):
    def __init__(self, model_name, model_hparams, **config):
        """
        Inputs:
            model_name - Name of the model/CNN to run. Used for creating the model (see function below)
            model_hparams - Hyperparameters for the model, as dictionary.
        """
        super().__init__()
        # Create model
        self.model = create_model(model_name, model_hparams)
        # Create loss module
        self.loss_module = nn.MSELoss()

        self.save_hyperparameters()

Pitch

I think the most elegant fix would be to use '.' as the delimiter when flattening the hyperparameters, so that the keys are the same as the ones already used by Wandb.

params = _flatten_dict(params, delimiter=".")

https://github.com/Lightning-AI/lightning/blob/511875e5675b0543d89e2aae3950a7834b35238e/src/pytorch_lightning/loggers/wandb.py#L374-L379
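
To make the delimiter effect concrete, here is a rough stand-alone sketch (not Lightning's actual _flatten_dict, just a stand-in with the same idea):

# Stand-in for Lightning's _flatten_dict, only to illustrate the delimiter effect.
def flatten(d, delimiter="/", parent=""):
    out = {}
    for key, value in d.items():
        name = f"{parent}{delimiter}{key}" if parent else str(key)
        if isinstance(value, dict):
            out.update(flatten(value, delimiter, name))
        else:
            out[name] = value
    return out

hparams = {"model_name": "resnet", "model_hparams": {"lr": 1e-3}}

print(flatten(hparams, delimiter="/"))  # {'model_name': 'resnet', 'model_hparams/lr': 0.001}
print(flatten(hparams, delimiter="."))  # {'model_name': 'resnet', 'model_hparams.lr': 0.001}
# wandb shows nested config entries under dot-joined keys (as described above),
# so only the "." variant lands on the key that is already in wandb.config
# instead of adding a second, slightly different one.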

Alternatives

The alternative would be not to log your hyperparameters to the logger at all, using self.save_hyperparameters(logger=False), because all hyperparameters should already be present in the config provided to the WandbLogger. However, it is more fail-safe to also log any additional hyperparameters that are present.

cc @awaelchli @morganmcg1 @borisdayma @scottire @parambharat @AyushExel @manangoel99

@remivanbel remivanbel added the needs triage Waiting to be triaged by maintainers label Jul 28, 2022
@akihironitta akihironitta added bug Something isn't working logger: wandb Weights & Biases and removed needs triage Waiting to be triaged by maintainers labels Jul 29, 2022
manangoel99 (Contributor) commented

Hi @remivanbel ! Engineer from W&B here. I think adding a flatten option like this could possibly lead to some unexpected behaviour and might cause changes to other workflows because we would be manually changing the config dict keys.

I think for your use case, adding trainer.logger.experiment.config.update(config) might work.
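
Roughly something like this (a sketch reusing the names from the snippet in the issue description):

import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from parameters import config

wandb_logger = WandbLogger(project="my-project", config=config)
model = MyTrainer(**wandb.config.as_dict())

trainer = pl.Trainer(logger=wandb_logger)
# Push the full config into the run once, through wandb's own key handling;
# allow_val_change=True may be needed if some keys were already set.
trainer.logger.experiment.config.update(config, allow_val_change=True)
trainer.fit(model)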

remivanbel (Author) commented

Hi @manangoel99, it is an honor that my issue got noticed by a W&B engineer ;)

I agree that letting wandb handle the flattening is a more elegant solution (especially if, at some point in the future, wandb were to use another delimiter). It does make me wonder why Lightning still applies all these conversions to params before feeding them to self.experiment.config.update(). Would it be possible to feed the parameters directly to the logger with something like this? (I don't know whether converting and sanitizing is already handled by wandb.)

@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
    self.experiment.config.update(params, allow_val_change=True)

Whether allow_val_change should be True or False, I have no strong opinion; I think False would be safer, but I guess there is a good reason to make it True.
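
In the meantime, a user-side workaround along those lines could be to subclass the logger, roughly like this (just a sketch; PassthroughWandbLogger is a made-up name, and I haven't checked edge cases in detail):

from argparse import Namespace
from typing import Any, Dict, Union

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities import rank_zero_only


class PassthroughWandbLogger(WandbLogger):
    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
        # Hand the raw params straight to wandb and let it do its own flattening.
        if isinstance(params, Namespace):
            params = vars(params)
        self.experiment.config.update(params, allow_val_change=True)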

The reason I would like to see a change in the log_hyperparams method is that, even with this solution, I would still need to call

self.save_hyperparameters(logger=False)
self.logger.experiment.config.update(config)

if I also want to have access to self.hparams besides logging these hparams.
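
For context, this is roughly how that combination would sit in my module from above (just a sketch; the on_fit_start hook is one possible place for the update call):

class MyTrainer(pl.LightningModule):
    def __init__(self, model_name, model_hparams, **config):
        super().__init__()
        self.model = create_model(model_name, model_hparams)
        self.loss_module = nn.MSELoss()
        # Keep self.hparams available without going through logger.log_hyperparams().
        self.save_hyperparameters(logger=False)

    def on_fit_start(self) -> None:
        # Push the hparams into the run once, through wandb's own flattening;
        # allow_val_change=True in case the WandbLogger config already set these keys.
        if self.logger is not None:
            self.logger.experiment.config.update(dict(self.hparams), allow_val_change=True)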

Many thanks already, by the way, for your suggested solution; it's quick and simple, which is always nice!

stale bot commented Apr 15, 2023

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

@stale stale bot added the won't fix This will not be worked on label Apr 15, 2023