Skip to content

Commit

Permalink
switch instructlab dolomite
Browse files Browse the repository at this point in the history
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
  • Loading branch information
fabianlim committed Jun 20, 2024
1 parent c6bffd4 commit 5d57574
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 7 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ datasets>=2.15.0
numba
numpy
rich
dolomite-engine @ git+https://github.com/ibm-granite/dolomite-engine.git@main
instructlab-dolomite @ git+https://github.com/instructlab/GPTDolomite.git@initial

trl==0.9.4
peft
pydantic>=2.7.0
Expand Down
8 changes: 3 additions & 5 deletions src/instructlab/training/main_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
setup_logger,
)
import instructlab.training.data_process as dp
from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
from instructlab.dolomite.enums import GradientCheckpointingMethod
from instructlab.dolomite.gradient_checkpointing import apply_gradient_checkpointing


def get_ds_config(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions):
Expand Down Expand Up @@ -88,8 +91,6 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
)

if args.is_granite:
# Third Party
from dolomite_engine.hf_models.models import GPTDolomiteForCausalLM

model = GPTDolomiteForCausalLM.from_pretrained(
args.model_name_or_path,
Expand Down Expand Up @@ -201,9 +202,6 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
# granite gradient checkpointing is handled uniformly
# for both lora and full here
if args.is_granite:
# Third Party
from dolomite_engine.enums import GradientCheckpointingMethod
from dolomite_engine.gradient_checkpointing import apply_gradient_checkpointing

block_name = model._no_split_modules[0]
apply_gradient_checkpointing(
Expand Down
2 changes: 1 addition & 1 deletion src/instructlab/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import warnings

# Third Party
from instructlab.dolomite.hf_models import export_to_huggingface
from rich.logging import RichHandler
from torch import distributed as dist
from torch.distributed import get_rank, is_initialized
Expand Down Expand Up @@ -539,7 +540,6 @@ def save_hf_format_ds(args, model, tokenizer, samples_seen, convert_granite=True
from tempfile import TemporaryDirectory

# Third Party
from dolomite_engine.hf_models import export_to_huggingface
from safetensors.torch import save_file

with TemporaryDirectory("w") as tmpdir:
Expand Down

0 comments on commit 5d57574

Please sign in to comment.