
BaseSGLogger storage_location parameter is systematically overridden, why? #2024

Open
beatrice93 opened this issue Jun 24, 2024 · 0 comments

beatrice93 commented Jun 24, 2024

💡 Your Question

Hi there,

Context:
I am trying to train a YOLO-NAS model and store my checkpoints remotely in an AWS S3 bucket.
Among other settings, I pass the following training parameters to the Trainer:

  "sg_logger": "base_sg_logger",
  "sg_logger_params": {
    "storage_location": "s3://my-bucket/models/",
    "save_checkpoints_remote": true
  },

...but to no effect. The checkpoints don't appear in my S3 bucket.

The culprit seems to be these few lines in the Trainer class definition (I removed some of the code for clarity):

def _initialize_sg_logger_objects(self, additional_configs_to_log: Dict = None):
    """Initialize object that collect, write to disk, monitor and store remotely all training outputs"""
    sg_logger = core_utils.get_param(self.training_params, "sg_logger")

    # OVERRIDE SOME PARAMETERS TO MAKE SURE THEY MATCH THE TRAINING PARAMETERS
    general_sg_logger_params = {
        "experiment_name": self.experiment_name,
        "storage_location": "local",
        "resumed": self.load_checkpoint,
        "training_params": self.training_params,
        "checkpoints_dir_path": self.checkpoints_dir_path,
    }

    if isinstance(sg_logger, str):
        sg_logger_cls = SG_LOGGERS.get(sg_logger)
        sg_logger_params = core_utils.get_param(self.training_params, "sg_logger_params", {})
        if issubclass(sg_logger_cls, BaseSGLogger):
            sg_logger_params = {**sg_logger_params, **general_sg_logger_params}

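Because general_sg_logger_params is unpacked last in {**sg_logger_params, **general_sg_logger_params}, its keys take precedence. So whatever the user specifies, the logger's storage_location parameter is overridden with "local" whenever the logger class is a subclass of BaseSGLogger.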
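A minimal reproduction of that merge using plain dicts (no SuperGradients import). The variable names are illustrative, and general_params keeps only the conflicting key of general_sg_logger_params. The second merge shows a hypothetical reversed order for comparison; it is not what the library does.

```python
# Parameters as the user passes them in sg_logger_params.
user_params = {
    "storage_location": "s3://my-bucket/models/",
    "save_checkpoints_remote": True,
}
# Stand-in for general_sg_logger_params, reduced to the conflicting key.
general_params = {"storage_location": "local"}

# Order used in the quoted Trainer code: the right-hand mapping wins on conflicts.
merged_current = {**user_params, **general_params}
print(merged_current["storage_location"])  # prints: local

# Hypothetical reversed order (not the library's behavior): user values would win.
merged_reversed = {**general_params, **user_params}
print(merged_reversed["storage_location"])  # prints: s3://my-bucket/models/
```

Note that non-conflicting keys such as save_checkpoints_remote survive either merge; only keys present in both dicts are affected.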

My question:
Why is this? And how do I set up my logger so that it stores checkpoints and logs to S3?
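One stopgap that does not depend on the logger at all is to upload the checkpoint files from the local checkpoints directory to the bucket yourself after training. This is a minimal sketch under several assumptions: boto3 is installed, AWS credentials are configured, and checkpoints are saved as *.pth files. split_s3_uri, build_key and upload_checkpoints are hypothetical helpers, not part of SuperGradients.

```python
from pathlib import Path
from typing import List, Tuple
from urllib.parse import urlparse


def split_s3_uri(uri: str) -> Tuple[str, str]:
    """Split 's3://bucket/prefix/' into ('bucket', 'prefix/')."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URI: {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")


def build_key(prefix: str, filename: str) -> str:
    """Join an S3 key prefix and a file name with exactly one '/' between them."""
    prefix = prefix.strip("/")
    return f"{prefix}/{filename}" if prefix else filename


def upload_checkpoints(checkpoints_dir: str, s3_uri: str, pattern: str = "*.pth") -> List[str]:
    """Upload files matching `pattern` (assumed checkpoint extension) to S3; return the keys used."""
    import boto3  # imported lazily so the pure helpers above work without boto3 installed

    bucket, prefix = split_s3_uri(s3_uri)
    s3 = boto3.client("s3")
    keys = []
    for path in sorted(Path(checkpoints_dir).glob(pattern)):
        key = build_key(prefix, path.name)
        s3.upload_file(str(path), bucket, key)  # Filename, Bucket, Key
        keys.append(key)
    return keys
```

Usage would be something like upload_checkpoints("<your checkpoints_dir_path>", "s3://my-bucket/models/") once training finishes. This sidesteps the override rather than explaining it, so it does not answer why the Trainer forces "local".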

Versions

No response
