
Commit 97c13b0

Merge pull request #232 from huggingface/xrsrke/precommit-s3
precommit
zzhhjjj authored Sep 9, 2024
2 parents 38d64fb + 1f04626 commit 97c13b0
Showing 7 changed files with 30 additions and 31 deletions.
8 changes: 4 additions & 4 deletions src/nanotron/config/config.py
@@ -2,13 +2,13 @@
import os
from dataclasses import dataclass, fields
from pathlib import Path
from datasets.download.streaming_download_manager import xPath
from typing import List, Optional, Type, Union

import dacite
import torch
import yaml
from dacite import from_dict
from datasets.download.streaming_download_manager import xPath
from yaml.loader import SafeLoader

from nanotron.config.lighteval_config import LightEvalConfig
@@ -108,6 +108,7 @@ def __post_init__(self):
if isinstance(self.s5cmd_path, str):
self.s5cmd_path = xPath(self.s5cmd_path)


@dataclass
class NanosetDatasetsArgs:
dataset_folder: Union[str, List[str]]
@@ -151,7 +152,6 @@ class CheckpointsArgs:
checkpoints_path: where to save the checkpoints
checkpoint_interval: how often to save the checkpoints
resume_checkpoint_path: if you want to load from a specific checkpoint path
"""

checkpoints_path: Path
@@ -350,15 +350,15 @@ class Config:
data_stages: Optional[List[DatasetStageArgs]] = None
profiler: Optional[ProfilerArgs] = None
lighteval: Optional[LightEvalConfig] = None
s3_upload : Optional[S3UploadArgs] = None
s3_upload: Optional[S3UploadArgs] = None

@classmethod
def create_empty(cls):
cls_fields = fields(cls)
return cls(**{f.name: None for f in cls_fields})

def __post_init__(self):

if self.s3_upload is not None:
self.s3_upload.__post_init__()

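Not part of the diff: a minimal sketch of how the relocated xPath import and the new s3_upload field fit together, assuming S3UploadArgs sits next to Config in nanotron.config.config and exposes the fields referenced elsewhere in this commit (upload_s3_path, remove_after_upload, s5cmd_numworkers, s5cmd_concurrency, s5cmd_path); anything beyond those names is an assumption. The import reorder and the extra blank line before NanosetDatasetsArgs are presumably just the isort/black output of the pre-commit hooks named in the commit message.

# Sketch only -- field names are inferred from attributes used in this commit,
# not a confirmed nanotron API.
from nanotron.config.config import Config, S3UploadArgs  # assumed import location

s3_upload = S3UploadArgs(
    upload_s3_path="s3://my-bucket/run-1/checkpoints",  # hypothetical destination
    remove_after_upload=True,
    s5cmd_numworkers=16,
    s5cmd_concurrency=5,
    s5cmd_path="/usr/local/bin/s5cmd",  # str -> xPath inside S3UploadArgs.__post_init__ (hunk above)
)

config = Config.create_empty()
config.s3_upload = s3_upload
config.__post_init__()  # only touches s3_upload when it is not None, as shown above
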
2 changes: 1 addition & 1 deletion src/nanotron/helpers.py
@@ -69,7 +69,7 @@ def init_random_states(parallel_config: ParallelismArgs, tp_pg: ProcessGroup):
{"tp_synced": get_synced_random_state(random_state=get_current_random_state(), pg=tp_pg)}
)
else:
# We don't need to sync across TP when using sequence parallel (REDUCE_SCATTER)
# NOTE: We don't need to sync across TP when using sequence parallel (REDUCE_SCATTER)
random_states = RandomStates({})
return random_states

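For orientation only: the hunk above shows just the else branch of init_random_states. Below is a hypothetical reconstruction of the whole function; the import paths and the guard on the tensor-parallel mode are assumptions, not something this diff confirms.

# Hypothetical sketch -- only the else branch appears in the hunk above.
from nanotron.config import ParallelismArgs  # assumed export
from nanotron.distributed import ProcessGroup
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode  # assumed path
from nanotron.random import RandomStates, get_current_random_state, get_synced_random_state  # assumed path


def init_random_states(parallel_config: ParallelismArgs, tp_pg: ProcessGroup):
    # With sequence parallel (REDUCE_SCATTER) each TP rank owns its own slice of the
    # sequence, so per-rank randomness is fine; otherwise dropout masks must match
    # across TP ranks and the random state is synced over tp_pg.
    if parallel_config is None or parallel_config.tp_mode is TensorParallelLinearMode.ALL_REDUCE:  # assumed guard
        return RandomStates(
            {"tp_synced": get_synced_random_state(random_state=get_current_random_state(), pg=tp_pg)}
        )
    return RandomStates({})
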
2 changes: 1 addition & 1 deletion src/nanotron/s3_checkpoints/__init__.py
@@ -1,4 +1,4 @@
from .fsspec import check_path_is_local, fs_copy, fs_open
from .s3_mover import S3Mover

__all__ = ["S3Mover", "fs_open", "fs_copy", "check_path_is_local"]
__all__ = ["S3Mover", "fs_open", "fs_copy", "check_path_is_local"]
2 changes: 1 addition & 1 deletion src/nanotron/s3_checkpoints/fsspec.py
@@ -19,7 +19,7 @@ def fs_open(
file: Union[str, Path],
mode="r",
):
# TODO @thomasw21: pass storage options
# TODO @thomasw21: pass storage options.
fs, path = get_filesystem_and_path(file)
with fs.open(path, mode=mode) as f:
yield f
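
Not part of the diff: a usage sketch for fs_open, which resolves the right fsspec filesystem for either a local path or an s3:// URI (it is re-exported from the package __init__ shown above together with check_path_is_local and fs_copy). The bucket below is hypothetical, and treating fs_open as a context manager is an assumption based on the yield in the hunk.

# Usage sketch only; the s3:// object is hypothetical.
from nanotron.s3_checkpoints import check_path_is_local, fs_open

with fs_open("s3://my-bucket/run-1/config.yaml", mode="r") as f:  # assumed @contextmanager
    text = f.read()

print(check_path_is_local("/tmp/checkpoints"))      # expected: True
print(check_path_is_local("s3://my-bucket/run-1"))  # expected: False
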
6 changes: 3 additions & 3 deletions src/nanotron/s3_checkpoints/s3_mover.py
@@ -10,6 +10,7 @@
import torch
from datasets.download.streaming_download_manager import xPath
from filelock import FileLock, Timeout

from nanotron import distributed as dist
from nanotron import logging
from nanotron.distributed import ProcessGroup
@@ -19,7 +20,7 @@


class S3Mover:
#TODO @eliebak update the doc to state that it also the function use to download it to the disk with start_downloading
# TODO @eliebak update the doc to state that it also the function use to download it to the disk with start_downloading
"""Take care of uploading a checkpoint to S3 in the background and remove it from the disk.
Args:
@@ -70,7 +71,6 @@ def __init__(
self,
local_path: xPath,
s3_path: xPath,
# duplicate_checkpoint_path: Optional[xPath] = None,
post_upload_callback: Optional[callable] = None,
remove_after_upload: Optional[bool] = True,
s5cmd_numworkers: Optional[int] = None,
@@ -219,7 +219,7 @@ def distributed_wait_for_completion(self, group: Optional[ProcessGroup] = None):
self._warning(
f"[S3] Waiting {self.state.value}: {all_saved} / {group.size()}. Stdout: {len(stdout_lines)} end: {stdout_lines[-1:]}",
)
# sync all our saves on NCCL we could do a dist barrier later but this helps us not loosing NCCL connections down the line
# sync all our saves on NCCL we could do a dist barrier later but this helps us not losing NCCL connections down the line
test_tensor = torch.tensor([self.is_previous_save_finished()], device=torch.device("cuda"))
test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(group.size())]
dist.all_gather(test_tensor_list, test_tensor, group=group, async_op=False)
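
Not part of the diff: a sketch of how S3Mover is typically driven, mirroring the constructor arguments visible above and the calls made from trainer.py later in this commit. Paths, worker counts, and the process-group handle are placeholders.

# Sketch only; values are hypothetical, and world_pg stands in for a real
# torch.distributed ProcessGroup (the trainer below passes parallel_context.world_pg).
from datasets.download.streaming_download_manager import xPath
from nanotron.s3_checkpoints import S3Mover

mover = S3Mover(
    local_path=xPath("/scratch/checkpoints/10"),
    s3_path=xPath("s3://my-bucket/run-1/checkpoints/10"),
    remove_after_upload=True,  # drop the local copy once the upload has finished
    s5cmd_numworkers=16,
    s5cmd_concurrency=5,
)

mover.start_uploading()  # launches the s5cmd upload in the background
# ... training continues; before reusing the path or tearing down NCCL, every rank waits:
# mover.distributed_wait_for_completion(group=world_pg)

The all_gather on a boolean tensor inside distributed_wait_for_completion (visible in the hunk above) is what keeps all ranks looping together until every one of them reports its upload as finished.
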
9 changes: 4 additions & 5 deletions src/nanotron/serialize/main.py
@@ -1,10 +1,9 @@
import os
from pathlib import Path
from typing import Optional, cast
from datasets.download.streaming_download_manager import xPath
import os

from nanotron.s3_checkpoints import S3Mover, check_path_is_local, fs_open
import torch
from datasets.download.streaming_download_manager import xPath
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import LambdaLR
@@ -13,11 +12,11 @@
from nanotron import logging
from nanotron import optim as optim
from nanotron.config import Config
from nanotron.constants import MODEL_CONFIG_FILE_NAME
from nanotron.distributed import get_global_rank
from nanotron.logging import log_rank
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.s3_checkpoints import S3Mover, check_path_is_local, fs_open
from nanotron.sanity_checks import (
assert_tensor_synced_across_pg,
check_optim_state_in_sync,
@@ -43,7 +42,7 @@
Version 1:
- serialize -> dumps every process weights in individual files
- load -> assume topology is exactly the same
- load -> assume topology is exactly the same.
"""


32 changes: 16 additions & 16 deletions src/nanotron/trainer.py
@@ -19,7 +19,6 @@
cast,
)

from nanotron.s3_checkpoints import S3Mover, check_path_is_local
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
@@ -77,6 +76,7 @@
tie_parameters,
)
from nanotron.random import set_random_seed
from nanotron.s3_checkpoints import S3Mover, check_path_is_local
from nanotron.sanity_checks import (
after_optim_step_sanity_checks,
after_tbi_sanity_checks,
@@ -149,14 +149,12 @@ def __init__(
data_parallel_size=self.config.parallelism.dp,
expert_parallel_size=self.config.parallelism.expert_parallel_size,
)

self.pre_init()

# Set log levels
set_ranks_logging_level(parallel_context=self.parallel_context, logging_config=self.config.logging)



# Log benchmark info
if os.environ.get("NANOTRON_BENCHMARK", "0") == "1":
log_throughput(self.config, self.parallel_context)
@@ -263,12 +261,11 @@ def pre_init(self):
def post_init(self):
# S3 Mover and save initial state
if self.config.s3_upload is not None:
# Only local rank 0 should upload
# NOTE: Only local rank 0 should upload
dummy = bool(int(os.environ.get("LOCAL_RANK", None)) != 0)
self.s3_mover = S3Mover(
local_path=self.config.checkpoints.checkpoints_path,
s3_path=self.config.s3_upload.upload_s3_path,
# duplicate_checkpoint_path=self.config.checkpoints.resume_checkpoint_path,
remove_after_upload=self.config.s3_upload.remove_after_upload,
s5cmd_numworkers=self.config.s3_upload.s5cmd_numworkers,
s5cmd_concurrency=self.config.s3_upload.s5cmd_concurrency,
@@ -307,7 +304,7 @@ def post_train_step(self):
def post_training(self):
if self.s3_mover is not None:
self.s3_mover.distributed_wait_for_completion(group=self.parallel_context.world_pg)

def _print_training_plan(self):
if hasattr(self.config, "data_stages") and self.config.data_stages is not None:
stages_info = "".join(
@@ -464,10 +461,10 @@ def train(
self.save_checkpoint()

dist.barrier() # let's wait for everyone before leaving

if self.config.checkpoints.save_final_state:
self.save_checkpoint()

self.post_training()

def training_step(
@@ -711,17 +708,21 @@ def _init_model_instance(self) -> NanotronModel:
def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel:
unwrapped_model = model.module if isinstance(model, DistributedDataParallel) else model

# Load or initialize model weights
# Load or initialize model weights
reloaded_from_checkpoint = False
if self.init_checkpoint_path is not None:
# Load from a pre existing checkpoint
# Load from a pre existing checkpoint
if check_path_is_local(self.init_checkpoint_path):
# Reload from a training checkpoint
log_rank(f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0)
# Reload from a training checkpoint
log_rank(
f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0
)
self.param_shard_metadata = load_weights(
model=unwrapped_model, parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path
model=unwrapped_model,
parallel_context=self.parallel_context,
root_folder=self.init_checkpoint_path,
)
reloaded_from_checkpoint=True
reloaded_from_checkpoint = True
if not reloaded_from_checkpoint:
log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO, rank=0)
if isinstance(self.config.model.init_method, ExistingCheckpointInit):
@@ -865,7 +866,6 @@ def post_save_checkpoint(self):
if self.s3_mover is not None:
self.s3_mover.start_uploading()


def save_checkpoint(self) -> Path:
self.pre_save_checkpoint()

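Not part of the diff: a condensed sketch of the checkpoint-to-S3 flow that the trainer hunks above wire together, using only attributes and calls visible in this commit. The local-rank-0 gating is computed in post_init, but how the resulting dummy flag reaches S3Mover is elided by the diff, so it is left as a comment here.

# Sketch of the flow only, not the actual NanotronTrainer methods.
import os

from nanotron.s3_checkpoints import S3Mover


def post_init(trainer):
    if trainer.config.s3_upload is not None:
        # NOTE: Only local rank 0 should upload; the flag's hand-off to S3Mover is
        # cut off in the hunk above, so it is merely computed here.
        dummy = bool(int(os.environ.get("LOCAL_RANK", "0")) != 0)  # noqa: F841
        trainer.s3_mover = S3Mover(
            local_path=trainer.config.checkpoints.checkpoints_path,
            s3_path=trainer.config.s3_upload.upload_s3_path,
            remove_after_upload=trainer.config.s3_upload.remove_after_upload,
            s5cmd_numworkers=trainer.config.s3_upload.s5cmd_numworkers,
            s5cmd_concurrency=trainer.config.s3_upload.s5cmd_concurrency,
        )
    else:
        trainer.s3_mover = None


def post_save_checkpoint(trainer):
    if trainer.s3_mover is not None:
        trainer.s3_mover.start_uploading()  # background upload of the checkpoint just written


def post_training(trainer):
    if trainer.s3_mover is not None:
        # Block every rank until the last upload is done, so NCCL is not torn down mid-upload.
        trainer.s3_mover.distributed_wait_for_completion(group=trainer.parallel_context.world_pg)
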
