create trainer class(es) from main_ds.py #20

Closed
wants to merge 7 commits into from

Changes from 1 commit
78 changes: 38 additions & 40 deletions src/instructlab/training/trainer.py
@@ -1,32 +1,34 @@
import argparse
from pathlib import Path
# Standard

Check warning on line 1 in src/instructlab/training/trainer.py

GitHub Actions / lint

R0801: Similar lines in 2 files
from datetime import timedelta
from pathlib import Path
from typing import Any
import argparse
import math
import os
import re
import time
import yaml
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, get_scheduler, MistralForCausalLM
from torch.distributed import (
ReduceOp,
all_reduce,
)

import deepspeed
# Third Party
from deepspeed.ops.adam import FusedAdam
from torch.distributed import ReduceOp, all_reduce
from tqdm import tqdm
from transformers import AutoModelForCausalLM, MistralForCausalLM, get_scheduler

Check warning on line 15 in src/instructlab/training/trainer.py

GitHub Actions / lint

W0611: Unused MistralForCausalLM imported from transformers (unused-import)
import deepspeed
import torch
import yaml

# First Party
from instructlab.training.multipack_sampler import (
find_packing_max_batch_len_and_grad_accum,
)
from instructlab.training.token_dataset import setup_dataloader, setup_dataset
from instructlab.training.tokenizer_utils import setup_tokenizer
from instructlab.training.utils import (
convert_loss_to_reduce_sum,
save_hf_format_ds,
save_model_ds_native,
set_random_seed,
setup_logger,
convert_loss_to_reduce_sum,
)


@@ -48,7 +50,6 @@
is_padding_free: bool,
seed: int,
):

self.model_name_or_path: str = model_name_or_path
self.data_path: str = data_path
self.effective_batch_size: int = effective_batch_size
@@ -90,9 +91,9 @@
def __init__(
self,
model_name_or_path: str,
data_loader: any, # TODO: type the DataLoader obj.
data_loader: Any, # TODO: type the DataLoader obj.
output_dir: str,
tokenizer: any, # TODO: type the tokenizer obj.
tokenizer: Any, # TODO: type the tokenizer obj.
Member

The type is PreTrainedTokenizer

Contributor Author

++
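A minimal sketch of the concretely typed signature being suggested here, assuming transformers.PreTrainedTokenizer (named in this thread) and torch.utils.data.DataLoader for the still-TODO loader type; the class name below is only a placeholder:

# Sketch only, not the PR's code.
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer


class ModelWrapper:  # placeholder name for the class under review
    def __init__(
        self,
        model_name_or_path: str,
        data_loader: DataLoader,  # replaces Any / "TODO: type the DataLoader obj."
        output_dir: str,
        tokenizer: PreTrainedTokenizer,  # replaces Any / "TODO: type the tokenizer obj."
        effective_batch_size: int,
        max_batch_len: int,
    ) -> None:
        self.model_name_or_path: str = model_name_or_path
        self.data_loader: DataLoader = data_loader
        self.output_dir: str = output_dir
        self.tokenizer: PreTrainedTokenizer = tokenizer
        self.effective_batch_size: int = effective_batch_size
        self.max_batch_len: int = max_batch_len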

effective_batch_size: int,
max_batch_len: int,
samples_per_gpu: int,
@@ -112,9 +113,9 @@
):
# NOTE: this is exhaustive. In review, let's decide how to split this up into config groups.
self.model_name_or_path: str = model_name_or_path
self.data_loader: any = data_loader
self.data_loader: Any = data_loader
Member

Does setting these as Any make it so that subsequent references are typed? Or will they appear as any?

Contributor Author

It's because the variables seem naked without a type.
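For context on the question above, a quick illustration (names invented) of how a checker like mypy treats Any-annotated attributes versus concretely typed ones:

# Illustration only; names are invented.
from typing import Any

from torch.utils.data import DataLoader


class Example:
    def __init__(self, loader: DataLoader, extras: Any) -> None:
        # Any propagates: every later use of self.extras type-checks silently,
        # so subsequent references are effectively untyped.
        self.extras: Any = extras
        # A concrete annotation lets the checker flag misuse downstream,
        # e.g. self.loader.nonexistent_method() would be reported.
        self.loader: DataLoader = loader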

self.output_dir: str = output_dir
self.tokenizer: any = tokenizer
self.tokenizer: Any = tokenizer
self.effective_batch_size: int = effective_batch_size
self.max_batch_len: int = max_batch_len
self.samples_per_gpu: int = samples_per_gpu
@@ -161,6 +162,7 @@
def _setup_model(self):
bnb_config = None
if self.lora_rank > 0 and self.lora_quant_bits == "nf4":
# Third Party
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
@@ -171,6 +173,7 @@
)

if self.is_padding_free:
# Third Party
from dolomite_engine.hf_models.models import GPTDolomiteForCausalLM

model = GPTDolomiteForCausalLM.from_pretrained(
@@ -212,8 +215,9 @@
# in the later stanza
if self.lora_rank > 0:
# if lora
# Third Party
from peft import LoraConfig
from utils import prepare_peft_model, patch_target_module
from utils import patch_target_module, prepare_peft_model

Check failure on line 220 in src/instructlab/training/trainer.py

GitHub Actions / lint

E0401: Unable to import 'utils' (import-error)
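Presumably this failure comes from the bare utils import a few lines above; the package-qualified form already used at the top of this file would likely clear it (an assumption, not something settled in this PR):

# Probable fix (assumption): import from the package rather than a bare `utils`.
from instructlab.training.utils import patch_target_module, prepare_peft_model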

if self.lora_target_modules is None:
self.lora_target_modules = [
@@ -236,9 +240,12 @@
)

# patch DS to work with quantized models
from deepspeed import DeepSpeedEngine
# Standard
from functools import partial

# Third Party
from deepspeed import DeepSpeedEngine

if self.lora_quant_bits is not None:
patch_target_module(
"deepspeed.DeepSpeedEngine",
@@ -250,10 +257,11 @@
# granite gradient checkpointing is handled uniformly
# for both lora and full here
if self.is_padding_free:
# Third Party
from dolomite_engine.enums import GradientCheckpointingMethod
from dolomite_engine.gradient_checkpointing import (
apply_gradient_checkpointing,
)
from dolomite_engine.enums import GradientCheckpointingMethod

block_name = model._no_split_modules[0]
apply_gradient_checkpointing(
@@ -293,7 +301,6 @@
return model

def _maybe_resume_training(self):

model = self.model

local_rank = int(os.environ["LOCAL_RANK"])
@@ -303,9 +310,7 @@
# so we need to disable load_module_strict
# - load checkpoint will find the latest checkpoint
# - it will also load the optimizer and scheduler states by default
load_module_strict = (
self.lora_rank == 0
) # can only be true if lora is not used
load_module_strict = self.lora_rank == 0 # can only be true if lora is not used
output_dir = Path(self.output_dir) / "ds_native"
model.load_checkpoint(output_dir, load_module_strict=load_module_strict)

@@ -333,14 +338,13 @@


class DeepSpeedTrainer:

def __init__(
self,
_args: dict,
model: any, #TODO type model obj
data_loader: any, # TODO: type the DataLoader obj.
model: Any, # TODO type model obj
data_loader: Any, # TODO: type the DataLoader obj.
output_dir: str,
tokenizer: any, # TODO: type the tokenizer obj.
tokenizer: Any, # TODO: type the tokenizer obj.
effective_batch_size: int,
max_batch_len: int,
samples_per_gpu: int,
@@ -362,10 +366,10 @@
lora_target_modules: list = None,
):
# NOTE: this is exhaustive. In review, let's decide how to split this up into config groups.
self.model: any = model
self.data_loader: any = data_loader
self.model: Any = model
self.data_loader: Any = data_loader
self.output_dir: str = output_dir
self.tokenizer: any = tokenizer
self.tokenizer: Any = tokenizer
self.effective_batch_size: int = effective_batch_size
self.max_batch_len: int = max_batch_len
self.samples_per_gpu: int = samples_per_gpu
@@ -386,9 +390,7 @@
self.world_size = world_size
self.global_step = 1
self.batch_size = effective_batch_size // grad_accum
self.save_samples = (
save_samples // self.batch_size
) * self.batch_size
self.save_samples = (save_samples // self.batch_size) * self.batch_size
self.last_step = last_step
self._args = _args
if save_samples_ds is not None:
@@ -405,7 +407,6 @@
print(f"\033[93mNumber of samples per save: {self.save_samples}\033[0m")

def _run_epoch(self, epoch: int):

self.data_loader.batch_sampler.set_epoch(epoch)

if self.local_rank == 0:
@@ -465,9 +466,7 @@
):
if self.local_rank == 0:
elapsed_time = time.time() - start
overall_throughput = (
self.samples_per_gpu * self.world_size / elapsed_time
)
overall_throughput = self.samples_per_gpu * self.world_size / elapsed_time
current_lr = self.model.lr_scheduler.get_last_lr()[0]
cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
@@ -523,7 +522,6 @@


def main(_args: argparse.ArgumentParser):

if os.environ["LOCAL_RANK"] == "0":
print(f"\033[38;5;120m{yaml.dump(vars(_args), sort_keys=False)}\033[0m")

@@ -536,7 +534,7 @@
data_path=_args.data_path,
effective_batch_size=_args.effective_batch_size,
max_batch_len=_args.max_batch_len,
world_size=int(os.getenv("WORLD_SIZE")),
world_size= int(os.getenv("WORLD_SIZE")),
Contributor
@fabianlim fabianlim Jun 20, 2024
In general it's better to get this as torch.distributed.get_world_size(), but that requires initializing torch.distributed first. Right now we rely on deepspeed.initialize to do that.

Contributor Author
We initialize distributed a bit earlier; do you suggest swapping it out?

Contributor Author
Also, there isn't a torch.distributed.get_local_rank(). Is it typical to just use

local_rank = torch.distributed.get_rank() // torch.distributed.get_world_size()

?

Contributor
@JamesKunstle No, that is not reliable. Actually, torch added a new method, get_node_local_rank, but it is only available on the bleeding edge (https://github.com/pytorch/pytorch/pull/123992/files), so we are out of luck. My suggestion is:

  • follow the official implementation above.
  • wrap this logic in a function or patch it onto torch.distributed (the latter is better for keeping the API future-proof); see the sketch below.
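A minimal sketch of that suggestion, assuming the helper mirrors the LOCAL_RANK-based logic in the linked PyTorch PR; the fallback parameter and the patching step are illustrative, not part of this PR:

# Sketch only: node-local rank helper in the spirit of the upstream
# torch.distributed.get_node_local_rank() referenced above.
import os
from typing import Optional

import torch.distributed as dist


def get_node_local_rank(fallback_rank: Optional[int] = None) -> int:
    # torchrun / deepspeed launchers export LOCAL_RANK for every worker.
    if "LOCAL_RANK" in os.environ:
        return int(os.environ["LOCAL_RANK"])
    if fallback_rank is not None:
        return int(fallback_rank)
    raise RuntimeError(
        "LOCAL_RANK is not set; launch with torchrun/deepspeed or pass fallback_rank."
    )


# Optional: patch it onto torch.distributed so call sites can already use the
# future upstream API name.
if not hasattr(dist, "get_node_local_rank"):
    dist.get_node_local_rank = get_node_local_rank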

is_padding_free=_args.is_granite,
Contributor
This introduces some coupling into the DataWrapper, since it doesn't really need to know whether the model is granite or not; it just needs some dependency injection.

Contributor Author
The DataWrapper does need to know whether the model is padding-free because the MultipackDataSampler uses that information for batch allocation w.r.t. padding. Do you intuit that there's a better split for this class? The sampler could be initialized elsewhere.

Contributor
@fabianlim fabianlim Jun 21, 2024
OK, I misunderstood; the flag is is_padding_free.

But bear in mind that it is typically not canon (at least in the huggingface world) for a dataloader to be prepared for a trainer; usually a dataset is passed to the trainer.

Also, it's not canon that a prepared model is passed to the trainer (at least in the huggingface world). Usually an unprepared model (i.e. not yet wrapped with deepspeed) is passed to the trainer, and the trainer does the wrapping itself.

Hence what I'm saying is that I feel you really only need one class, DeepSpeedTrainer.

  • and then MultipackDataWrapper and DeepSpeedModelWrapper can be converted to internal method calls (roughly outlined below).

Contributor Author
++ okay @fabianlim that makes a lot of sense.
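To make that proposed shape concrete, a rough outline of the single-class design (method names here are hypothetical; this PR still splits the logic across three classes):

# Hypothetical outline only, not this PR's code.
class DeepSpeedTrainer:
    def __init__(self, model_name_or_path, dataset, tokenizer, config) -> None:
        self.config = config
        self.tokenizer = tokenizer
        # Former MultipackDataWrapper: build the sampler/dataloader internally.
        self.data_loader = self._setup_dataloader(dataset)
        # Former DeepSpeedModelWrapper: load the model and wrap it with deepspeed.
        self.model = self._setup_model(model_name_or_path)

    def _setup_dataloader(self, dataset):
        ...

    def _setup_model(self, model_name_or_path):
        ...

    def train(self) -> None:
        ...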

seed=_args.seed,
)