create trainer class(es) from main_ds.py (#20)
src/instructlab/training/trainer.py

@@ -1,32 +1,34 @@
import argparse
from pathlib import Path
# Standard

[GitHub Actions / lint: 8 warnings on line 1 of src/instructlab/training/trainer.py]

from datetime import timedelta
from pathlib import Path
from typing import Any
import argparse
import math
import os
import re
import time
import yaml
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, get_scheduler, MistralForCausalLM
from torch.distributed import (
    ReduceOp,
    all_reduce,
)

import deepspeed
# Third Party
from deepspeed.ops.adam import FusedAdam
from torch.distributed import ReduceOp, all_reduce
from tqdm import tqdm
from transformers import AutoModelForCausalLM, MistralForCausalLM, get_scheduler
import deepspeed
import torch
import yaml

# First Party
from instructlab.training.multipack_sampler import (
    find_packing_max_batch_len_and_grad_accum,
)
from instructlab.training.token_dataset import setup_dataloader, setup_dataset
from instructlab.training.tokenizer_utils import setup_tokenizer
from instructlab.training.utils import (
    convert_loss_to_reduce_sum,
    save_hf_format_ds,
    save_model_ds_native,
    set_random_seed,
    setup_logger,
    convert_loss_to_reduce_sum,
)

@@ -48,7 +50,6 @@
        is_padding_free: bool,
        seed: int,
    ):

        self.model_name_or_path: str = model_name_or_path
        self.data_path: str = data_path
        self.effective_batch_size: int = effective_batch_size
@@ -90,9 +91,9 @@
    def __init__(
        self,
        model_name_or_path: str,
        data_loader: any,  # TODO: type the DataLoader obj.
        data_loader: Any,  # TODO: type the DataLoader obj.
        output_dir: str,
        tokenizer: any,  # TODO: type the tokenizer obj.
        tokenizer: Any,  # TODO: type the tokenizer obj.
        effective_batch_size: int,
        max_batch_len: int,
        samples_per_gpu: int,

@@ -112,9 +113,9 @@
    ):
        # NOTE: this is exhaustive. In review, let's decide how to split this up into config groups.
        self.model_name_or_path: str = model_name_or_path
        self.data_loader: any = data_loader
        self.data_loader: Any = data_loader

Review comment:
> Does setting these as …

Reply:
> it's because the variables seem naked without a type
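To make the typing thread concrete, here is a minimal sketch of what the resolved annotations could look like. The tokenizer type follows the review comment at the bottom of this page ("The type is PreTrainedTokenizer"); typing `data_loader` as torch's `DataLoader` is an assumption, since the PR leaves it as a TODO, and the signature is trimmed to the relevant parameters:

```python
# Hypothetical resolution of the TODO annotations above; not part of this diff.
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer


class DataWrapper:
    def __init__(
        self,
        model_name_or_path: str,
        data_loader: DataLoader,  # assumed concrete type; the PR currently uses Any
        output_dir: str,
        tokenizer: PreTrainedTokenizer,  # per the review: "The type is PreTrainedTokenizer"
    ) -> None:
        self.model_name_or_path = model_name_or_path
        self.data_loader = data_loader
        self.output_dir = output_dir
        self.tokenizer = tokenizer
```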
        self.output_dir: str = output_dir
        self.tokenizer: any = tokenizer
        self.tokenizer: Any = tokenizer
        self.effective_batch_size: int = effective_batch_size
        self.max_batch_len: int = max_batch_len
        self.samples_per_gpu: int = samples_per_gpu
@@ -161,6 +162,7 @@
    def _setup_model(self):
        bnb_config = None
        if self.lora_rank > 0 and self.lora_quant_bits == "nf4":
            # Third Party
            from transformers import BitsAndBytesConfig

            bnb_config = BitsAndBytesConfig(

@@ -171,6 +173,7 @@
            )

        if self.is_padding_free:
            # Third Party
            from dolomite_engine.hf_models.models import GPTDolomiteForCausalLM

            model = GPTDolomiteForCausalLM.from_pretrained(

@@ -212,8 +215,9 @@
        # in the later stanza
        if self.lora_rank > 0:
            # if lora
            # Third Party
            from peft import LoraConfig
            from utils import prepare_peft_model, patch_target_module
            from utils import patch_target_module, prepare_peft_model

            if self.lora_target_modules is None:
                self.lora_target_modules = [
@@ -236,9 +240,12 @@
            )

        # patch DS to work with quantized models
        from deepspeed import DeepSpeedEngine
        # Standard
        from functools import partial

        # Third Party
        from deepspeed import DeepSpeedEngine

        if self.lora_quant_bits is not None:
            patch_target_module(
                "deepspeed.DeepSpeedEngine",

@@ -250,10 +257,11 @@
        # granite gradient checkpointing is handled uniformly
        # for both lora and full here
        if self.is_padding_free:
            # Third Party
            from dolomite_engine.enums import GradientCheckpointingMethod
            from dolomite_engine.gradient_checkpointing import (
                apply_gradient_checkpointing,
            )
            from dolomite_engine.enums import GradientCheckpointingMethod

            block_name = model._no_split_modules[0]
            apply_gradient_checkpointing(
@@ -293,7 +301,6 @@
        return model

    def _maybe_resume_training(self):

        model = self.model

        local_rank = int(os.environ["LOCAL_RANK"])

@@ -303,9 +310,7 @@
        # so we need to disable load_module_strict
        # - load checkpoint will find the latest checkpoint
        # - it will also load the optimizer and scheduler states by default
        load_module_strict = (
            self.lora_rank == 0
        )  # can only be true if lora is not used
        load_module_strict = self.lora_rank == 0  # can only be true if lora is not used
        output_dir = Path(self.output_dir) / "ds_native"
        model.load_checkpoint(output_dir, load_module_strict=load_module_strict)

@@ -333,14 +338,13 @@


class DeepSpeedTrainer:

    def __init__(
        self,
        _args: dict,
        model: any,  # TODO: type model obj
        data_loader: any,  # TODO: type the DataLoader obj.
        model: Any,  # TODO: type model obj
        data_loader: Any,  # TODO: type the DataLoader obj.
        output_dir: str,
        tokenizer: any,  # TODO: type the tokenizer obj.
        tokenizer: Any,  # TODO: type the tokenizer obj.
        effective_batch_size: int,
        max_batch_len: int,
        samples_per_gpu: int,

@@ -362,10 +366,10 @@
        lora_target_modules: list = None,
    ):
        # NOTE: this is exhaustive. In review, let's decide how to split this up into config groups.
        self.model: any = model
        self.data_loader: any = data_loader
        self.model: Any = model
        self.data_loader: Any = data_loader
        self.output_dir: str = output_dir
        self.tokenizer: any = tokenizer
        self.tokenizer: Any = tokenizer
        self.effective_batch_size: int = effective_batch_size
        self.max_batch_len: int = max_batch_len
        self.samples_per_gpu: int = samples_per_gpu
@@ -386,9 +390,7 @@
        self.world_size = world_size
        self.global_step = 1
        self.batch_size = effective_batch_size // grad_accum
        self.save_samples = (
            save_samples // self.batch_size
        ) * self.batch_size
        self.save_samples = (save_samples // self.batch_size) * self.batch_size
        self.last_step = last_step
        self._args = _args
        if save_samples_ds is not None:
@@ -405,7 +407,6 @@
        print(f"\033[93mNumber of samples per save: {self.save_samples}\033[0m")

    def _run_epoch(self, epoch: int):

        self.data_loader.batch_sampler.set_epoch(epoch)

        if self.local_rank == 0:

@@ -465,9 +466,7 @@
        ):
            if self.local_rank == 0:
                elapsed_time = time.time() - start
                overall_throughput = (
                    self.samples_per_gpu * self.world_size / elapsed_time
                )
                overall_throughput = self.samples_per_gpu * self.world_size / elapsed_time
                current_lr = self.model.lr_scheduler.get_last_lr()[0]
                cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
                cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
@@ -523,7 +522,6 @@


def main(_args: argparse.ArgumentParser):

    if os.environ["LOCAL_RANK"] == "0":
        print(f"\033[38;5;120m{yaml.dump(vars(_args), sort_keys=False)}\033[0m")

@@ -536,7 +534,7 @@
        data_path=_args.data_path,
        effective_batch_size=_args.effective_batch_size,
        max_batch_len=_args.max_batch_len,
        world_size=int(os.getenv("WORLD_SIZE")),
        world_size= int(os.getenv("WORLD_SIZE")),

Review comment:
> in general it's better to get this as …

Reply:
> we initialize distributed a bit earlier, do you suggest swapping it out?

Reply:
> also, there isn't a torch.distributed.get_local_rank() - is it typical to just use …?

Reply:
> @JamesKunstle No, that is not reliable. Actually …
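For reference on the thread above: once the process group is initialized, world size and global rank are exposed by torch.distributed, while torch.distributed has no get_local_rank(); launchers such as torchrun set the LOCAL_RANK environment variable instead. A minimal sketch of that pattern (whose reliability the reviewers debate above):

```python
# Sketch of the rank bookkeeping discussed above; assumes the process group
# has already been initialized (e.g. via torch.distributed.init_process_group()).
import os

import torch.distributed as dist


def distributed_info() -> tuple[int, int, int]:
    world_size = dist.get_world_size()  # alternative to os.getenv("WORLD_SIZE")
    global_rank = dist.get_rank()  # rank across all nodes
    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun; no get_local_rank() exists
    return world_size, global_rank, local_rank
```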
        is_padding_free=_args.is_granite,

Review comment:
> this introduces some coupling into the DataWrapper, since it actually doesn't really need to know if the model is granite or not. it just needs some dependency injection

Reply:
> The DataWrapper does need to know whether the model is padding-free, because the MultipackDataSampler uses that information for batch allocation w.r.t. padding. Do you intuit that there's a better split for this class? The sampler could be initialized elsewhere.

Reply:
> ok, i misunderstood. the flag is … But bear in mind that it is typically not canon (at least in the huggingface world) that a dataloader is prepared for a trainer; usually a dataset is passed to a trainer. Also it's not canon that a prepared model is passed to the trainer (at least in the huggingface world). Usually an unprepared model (i.e. not yet wrapped with deepspeed) is passed to the trainer, and it does it. Hence what i'm saying is that I feel you really only need one class.

Reply:
> ++ okay @fabianlim that makes a lot of sense.
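A rough sketch of the one-class shape suggested above, in which the trainer receives an unprepared model and a dataset, then builds the dataloader and does the deepspeed wrapping itself. All names and the config plumbing here are illustrative assumptions, not code from this PR:

```python
# Illustrative only: the single-trainer layout from the review discussion.
import deepspeed
from torch.utils.data import DataLoader, Dataset


class Trainer:
    def __init__(self, model, dataset: Dataset, ds_config: dict, batch_size: int):
        # dataset in, dataloader built internally; callers never need to
        # touch padding-free or sampler details
        self.data_loader = DataLoader(dataset, batch_size=batch_size)
        # the unprepared model is wrapped here, inside the trainer,
        # rather than being passed in already wrapped
        self.model_engine, self.optimizer, _, _ = deepspeed.initialize(
            model=model,
            model_parameters=model.parameters(),
            config=ds_config,
        )
```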
        seed=_args.seed,
    )
Review comment:
> The type is PreTrainedTokenizer

Reply:
> ++