
Commit

fixing linting errors
Signed-off-by: aldo pareja-cardona <aldo.pareja@ibm.com>
aldo-pareja committed Sep 18, 2024
1 parent fc6db23 commit 0b4d516
Showing 5 changed files with 20 additions and 36 deletions.
src/instructlab/training/data_process.py: 16 changes (4 additions, 12 deletions)
@@ -22,7 +22,7 @@ def check_valid_sample(
system_tk: int,
assistant_tk: int,
user_tk: int,
eos_tk: int,
eos_tk: list[int],
max_len: int = 1024,
):
if len(whole_sentence_tk) >= max_len or len(whole_sentence_tk) < 20:
@@ -136,17 +136,14 @@ def find_longest_match(start_idx, sequences):
in_pretraining = True
i += 1
continue
elif sentence_tk[i] == pretrain_end_token:
if sentence_tk[i] == pretrain_end_token:
in_pretraining = False
i += 1
continue

match = find_longest_match(i, [user_tokens, assist_tokens, system_tokens])
if match:
if match == assist_tokens:
unmasking = True
else:
unmasking = False
unmasking = match == assist_tokens
i += len(match)
continue

@@ -185,8 +182,7 @@ def find_longest_match(start_idx, sequences):

# 3. The labels have to be aligned with the sentence_tk unless they are masked
assert all(
label == -100 or label == token
for label, token in zip(final_labels, final_sentence_tk)
label in (token, -100) for label, token in zip(final_labels, final_sentence_tk)
), "Labels are not aligned with sentence tokens"

return {"labels": final_labels, "input_ids": final_sentence_tk}
@@ -429,16 +425,12 @@ def main(args: DataProcessArgs):
print_masked_samples(
data_with_labels,
tokenizer,
pad_tk,
SPECIAL_TOKENS.pad,
is_pretrain=True,
num_proc=NUM_PROC,
)
print_masked_samples(
data_with_labels,
tokenizer,
pad_tk,
SPECIAL_TOKENS.pad,
is_pretrain=False,
num_proc=NUM_PROC,
)
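
The two logic rewrites in this file are behavior-preserving simplifications; below is a minimal standalone sketch of the equivalent checks, using made-up token ids rather than the project's real tokenizer output:

    # Made-up token ids purely for illustration; real values come from the tokenizer.
    assist_tokens = [32001, 32002]
    match = [32001, 32002]

    # Replaces the removed if/else: unmask exactly when the matched header is the assistant's.
    unmasking = match == assist_tokens

    # Label alignment: every label is either the mask value (-100) or equal to its token.
    final_sentence_tk = [5, 6, 7, 8]
    final_labels = [-100, -100, 7, 8]
    assert all(
        label in (token, -100) for label, token in zip(final_labels, final_sentence_tk)
    ), "Labels are not aligned with sentence tokens"
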
src/instructlab/training/main_ds.py: 16 changes (5 additions, 11 deletions)
@@ -2,8 +2,6 @@

# Standard
from copy import deepcopy
from datetime import timedelta
from functools import partial
from pathlib import Path
import argparse
import math
@@ -13,31 +11,29 @@
import time

# Third Party
from accelerate import Accelerator
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from deepspeed.runtime.zero.utils import ZeRORuntimeException
import torch.distributed

# pylint: disable=no-name-in-module
from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
from torch.distributed import ReduceOp, all_reduce
from tqdm import tqdm
from transformers import AutoModelForCausalLM, get_scheduler
from accelerate import Accelerator
import torch
import torch.distributed

# First Party
from instructlab.training import config
from instructlab.training.setup_accelerator import setup_accelerator
from instructlab.training.async_logger import AsyncStructuredLogger
from instructlab.training.config import (
from instructlab.training.config import ( # DeepSpeedOptions,
DataProcessArgs,
# DeepSpeedOptions,
TorchrunArgs,
TrainingArgs,
)
from instructlab.training.multipack_sampler import (
find_packing_max_batch_len_and_grad_accum,
)
from instructlab.training.setup_accelerator import setup_accelerator
from instructlab.training.token_dataset import setup_dataloader, setup_dataset
from instructlab.training.tokenizer_utils import setup_tokenizer
from instructlab.training.utils import (
@@ -46,7 +42,6 @@
apply_gradient_checkpointing,
convert_loss_to_reduce_sum,
ensure_loadable_granite_checkpoint,
patch_target_module,
prepare_peft_model,
prepare_universal_checkpoint_from_latest,
retrieve_chat_template,
@@ -353,7 +348,6 @@ def train(
torch.tensor([batch.pop("num_loss_counted_tokens")])
)
micro_batch_size = float(len(batch["input_ids"]))
samples_seen += micro_batch_size
if not args.is_granite:
for k in batch:
batch[k] = batch[k].to(local_rank)
@@ -375,7 +369,7 @@
reduction="sum",
),
)
samples_seen += micro_batch_size
samples_seen += int(micro_batch_size)

# num_loss_counted_tokens = aggregated_values[0]
loss = (
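
The import changes above only regroup and alphabetize the imports under the section comments the file already uses (presumably what the linter flagged); for reference, a minimal sketch of that layout, with module names taken from the hunk:

    # Standard
    from copy import deepcopy
    from pathlib import Path
    import argparse

    # Third Party
    from accelerate import Accelerator
    from tqdm import tqdm
    import torch

    # First Party
    from instructlab.training.config import TrainingArgs
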
src/instructlab/training/setup_accelerator.py: 20 changes (11 additions, 9 deletions)
@@ -1,22 +1,23 @@
# Standard
from functools import partial
import torch
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
)
from torch.distributed.fsdp import (
# FullyShardedDataParallel as FSDP,
MixedPrecision,

# Third Party
from accelerate import Accelerator
from torch.distributed.fsdp import ( # FullyShardedDataParallel as FSDP,
BackwardPrefetch,
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import torch

from accelerate import Accelerator

# First Party
from instructlab.training.config import DeepSpeedOptions
from instructlab.training.utils import get_module_class_from_name, patch_target_module


def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions):
# Third Party
from accelerate.utils import DeepSpeedPlugin

ds_config = {
@@ -51,6 +52,7 @@ def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions):


def get_fsdp_config(args, model):
# Third Party
from accelerate.utils import FullyShardedDataParallelPlugin

block_name = model._no_split_modules[0]
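
For context, a minimal sketch (not the repository's get_fsdp_config, and with a hypothetical helper name) of how the FSDP imports above are typically combined: the transformer block class is looked up by name and bound into transformer_auto_wrap_policy via functools.partial. A policy built this way is usually what gets handed to Accelerate's FullyShardedDataParallelPlugin as its auto_wrap_policy.

    # Standard
    from functools import partial

    # Third Party
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    import torch


    def demo_wrap_policy(model: torch.nn.Module, block_name: str):
        # block_name would come from model._no_split_modules[0], as in get_fsdp_config;
        # the lookup below stands in for the get_module_class_from_name helper imported
        # above from instructlab.training.utils.
        block_cls = next(
            type(m) for m in model.modules() if type(m).__name__ == block_name
        )
        # Bind the block class into FSDP's transformer auto-wrap policy.
        return partial(transformer_auto_wrap_policy, transformer_layer_cls={block_cls})
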
src/instructlab/training/tokenizer_utils.py: 3 changes (0 additions, 3 deletions)
@@ -11,9 +11,6 @@
from instructlab.training.utils import log_rank_0


from dataclasses import dataclass, field


@dataclass
class TokenInfo:
token: str
src/instructlab/training/utils.py: 1 change (0 additions, 1 deletion)
@@ -554,7 +554,6 @@ def apply_gradient_checkpointing(
model: torch.nn.Module,
**kwargs,
) -> None:

def block_checkpointing(
model: torch.nn.Module,
block_name: str,
