From 17482ca0ad68ffce0fe87d6a2265c59252342ce2 Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Wed, 3 Apr 2024 09:13:43 -0700 Subject: [PATCH] fix remote file naming --- composer/callbacks/memory_snapshot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/composer/callbacks/memory_snapshot.py b/composer/callbacks/memory_snapshot.py index 9ea21a0a69..fa2cd69137 100644 --- a/composer/callbacks/memory_snapshot.py +++ b/composer/callbacks/memory_snapshot.py @@ -76,7 +76,7 @@ def __init__( max_entries: int = 100000, folder: str = '{run_name}/torch_traces', filename: str = 'rank{rank}.{batch}.memory_snapshot', - remote_file_name: Optional[str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.memory_snapshot', + remote_file_name: Optional[str] = '{run_name}/torch_memory_traces', overwrite: bool = False, ) -> None: self.batches_left_to_skip = skip_batches @@ -179,7 +179,7 @@ def export_memory_snapshot(self, state: State, logger: Logger) -> None: if self.remote_path_in_bucket is not None: for f in [snapshot_file, trace_plot_file]: - remote_file_name = (self.remote_path_in_bucket + os.path.basename(f)).lstrip('/') + remote_file_name = os.path.join(self.remote_path_in_bucket, os.path.basename(f)).lstrip('/') log.info(f'Uploading memory snapshot to remote: {remote_file_name} from {f}') try: logger.upload_file(remote_file_name=remote_file_name, file_path=f, overwrite=self.overwrite)