Change old padding-free and granite flags to use_dolomite
Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com>
Maxusmusti committed Oct 8, 2024
1 parent 9de3409 commit a6f4f22
Showing 4 changed files with 30 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/instructlab/training/config.py
@@ -171,7 +171,7 @@ class TrainingArgs(BaseModel):
     save_samples: int
     learning_rate: float
     warmup_steps: int
-    is_padding_free: bool
+    use_dolomite: bool
     random_seed: int = 42
     checkpoint_at_epoch: bool = True
     accelerate_full_state_at_epoch: bool = True
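For context, a minimal sketch of what the rename means for code that builds a training config; the model below mirrors only the TrainingArgs fields visible in this hunk, so it is illustrative rather than the real class.

    # Minimal sketch, not the real TrainingArgs: only the fields visible in the
    # hunk above are mirrored here; the actual class defines many more fields.
    from pydantic import BaseModel

    class TrainingArgsSketch(BaseModel):
        save_samples: int
        learning_rate: float
        warmup_steps: int
        use_dolomite: bool  # formerly `is_padding_free`
        random_seed: int = 42
        checkpoint_at_epoch: bool = True
        accelerate_full_state_at_epoch: bool = True

    # Callers that previously set `is_padding_free=True` now set `use_dolomite=True`.
    cfg = TrainingArgsSketch(
        save_samples=250000, learning_rate=2e-5, warmup_steps=25, use_dolomite=True
    )
    print(cfg.use_dolomite)  # True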
36 changes: 18 additions & 18 deletions src/instructlab/training/main_ds.py
@@ -42,7 +42,7 @@
     add_noisy_embeddings,
     apply_gradient_checkpointing,
     convert_loss_to_reduce_sum,
-    ensure_loadable_granite_checkpoint,
+    ensure_loadable_dolomite_checkpoint,
     get_projection_layer_names,
     load_latest_full_state,
     prepare_peft_model,
@@ -105,13 +105,13 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
     }
     if not args.disable_flash_attn:
         base_model_args["attn_implementation"] = "flash_attention_2"
-    elif args.is_granite:
+    elif args.use_dolomite:
         raise RuntimeError(
-            "ERROR: Trying to use padding-free transformer without flash attention is not supported"
+            "ERROR: Trying to use dolomite padding-free transformer without flash attention is not supported"
         )
 
-    if args.is_granite:
-        with ensure_loadable_granite_checkpoint(
+    if args.use_dolomite:
+        with ensure_loadable_dolomite_checkpoint(
             args.model_name_or_path, args.output_dir
         ) as path:
             base_model_args["pretrained_model_name_or_path"] = path
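The hunk above ties the dolomite path to flash attention; a self-contained sketch of that guard (the helper name `build_base_model_args` is hypothetical, only the branching on the two flags mirrors the commit):

    # Hypothetical helper illustrating the guard above; only the branching on
    # disable_flash_attn / use_dolomite mirrors the commit.
    def build_base_model_args(
        model_name_or_path: str, disable_flash_attn: bool, use_dolomite: bool
    ) -> dict:
        base_model_args = {"pretrained_model_name_or_path": model_name_or_path}
        if not disable_flash_attn:
            base_model_args["attn_implementation"] = "flash_attention_2"
        elif use_dolomite:
            # The dolomite padding-free transformer requires flash attention.
            raise RuntimeError(
                "ERROR: Trying to use dolomite padding-free transformer without flash attention is not supported"
            )
        return base_model_args

    print(build_base_model_args("my-base-model", disable_flash_attn=False, use_dolomite=True))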
@@ -169,7 +169,7 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
         "GraniteForCausalLM",
     ], f"Model class name: {model.__class__.__name__} is not supported."
 
-    model = convert_loss_to_reduce_sum(model, is_granite=args.is_granite)
+    model = convert_loss_to_reduce_sum(model, use_dolomite=args.use_dolomite)
     model = add_noisy_embeddings(model, noise_alpha=args.NEFTune_alpha)
 
     # handling of gradient checkpointing
@@ -214,15 +214,15 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
             target_modules=args.lora_target_modules,
         )
         model = prepare_peft_model(
-            model, peft_config, gradient_checkpointing=not args.is_granite
+            model, peft_config, gradient_checkpointing=not args.use_dolomite
         )
 
-    elif not args.is_granite:
+    elif not args.use_dolomite:
         model.gradient_checkpointing_enable()
 
     # granite gradient checkpointing is handled uniformly
     # for both lora and full here
-    if args.is_granite:
+    if args.use_dolomite:
         block_name = model._no_split_modules[0]
         apply_gradient_checkpointing(
             model,
@@ -387,7 +387,7 @@ def train(
             torch.tensor([batch.pop("num_loss_counted_tokens")])
         )
         micro_batch_size = float(torch.tensor([batch.pop("num_samples")]))
-        if not args.is_granite:
+        if not args.use_dolomite:
             for k in batch:
                 batch[k] = batch[k].to(local_rank)
         output = model(
@@ -552,7 +552,7 @@ def main(args):
         avg_sample_len=dataset.get_lengths().mean(),
         effective_batch_size=args.effective_batch_size,
         max_batch_len_per_gpu=args.max_batch_len,
-        is_padding=not (args.is_granite or supports_flash_attention()),
+        is_padding=not (args.use_dolomite or supports_flash_attention()),
         dataset=dataset,
         seed=args.seed,
     )
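The `is_padding` argument above decides whether the packing estimate must assume padded batches; a tiny sketch of that decision (the function name is illustrative):

    # Illustrative only: padding is assumed only when neither the dolomite
    # padding-free path nor flash attention is available.
    def needs_padding(use_dolomite: bool, flash_attn_supported: bool) -> bool:
        return not (use_dolomite or flash_attn_supported)

    assert needs_padding(use_dolomite=True, flash_attn_supported=False) is False
    assert needs_padding(use_dolomite=False, flash_attn_supported=False) is True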
@@ -575,7 +575,7 @@
         dataset,
         tokenizer.pad_token_id,
         num_workers=8,
-        is_granite=args.is_granite,
+        use_dolomite=args.use_dolomite,
         max_batch_len=args.max_batch_len,
         packing_max_batch_len=packing_max_batch_len,
         samples_per_gpu=args.samples_per_gpu,
@@ -594,7 +594,7 @@ def main(args):
         dataset,
         tokenizer.pad_token_id,
         num_workers=8,
-        is_granite=args.is_granite,
+        use_dolomite=args.use_dolomite,
         max_batch_len=args.max_batch_len,
         packing_max_batch_len=packing_max_batch_len,
         samples_per_gpu=args.samples_per_gpu,
@@ -704,11 +704,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     if train_args.mock_len:
         command.append(f"--mock_len={train_args.mock_len}")
 
-    if train_args.is_padding_free:
-        command.append("--is_granite")
+    if train_args.use_dolomite:
+        command.append("--use_dolomite")
 
     if train_args.disable_flash_attn:
-        if train_args.is_padding_free:
+        if train_args.use_dolomite:
            raise RuntimeError(
                "ERROR: Trying to use padding-free transformer without flash attention is not supported"
            )
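run_training forwards the config to the torchrun subprocess by appending CLI flags; a reduced sketch of how the renamed boolean becomes `--use_dolomite` (the `command` list and the namespace object are stand-ins for the real implementation):

    # Reduced, self-contained sketch of the flag translation above; the real
    # code builds a full torchrun command from TorchrunArgs/TrainingArgs.
    from types import SimpleNamespace

    train_args = SimpleNamespace(use_dolomite=True, disable_flash_attn=False)
    command = ["torchrun", "main_ds.py"]  # placeholder command prefix

    if train_args.use_dolomite:
        command.append("--use_dolomite")

    if train_args.disable_flash_attn and train_args.use_dolomite:
        raise RuntimeError(
            "ERROR: Trying to use padding-free transformer without flash attention is not supported"
        )

    print(command)  # ['torchrun', 'main_ds.py', '--use_dolomite']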
@@ -890,7 +890,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
         default="SHARD_GRAD_OP",
         help="Sharding strategy to be used for FSDP distributed training.",
     )
-    parser.add_argument("--is_granite", action="store_true")
+    parser.add_argument("--use_dolomite", action="store_true")
     parser.add_argument("--lora_r", type=int, default=0) # set to > 0 to activate lora
     parser.add_argument("--lora_alpha", type=int, default=32)
     parser.add_argument("--lora_dropout", type=float, default=0.1)
@@ -979,7 +979,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
 --save_samples=250000 \
 --log_level="INFO" \
 --fsdp_sharding_strategy="SHARD_GRAD_OP" \
---is_granite \
+--use_dolomite \
 --max_batch_len 70000 \
 --seed=42
 """
4 changes: 2 additions & 2 deletions src/instructlab/training/token_dataset.py
@@ -94,15 +94,15 @@ def setup_dataloader(
     dataset: Dataset,
     pad_token_id: int,
     num_workers: int = 8,
-    is_granite=False,
+    use_dolomite=False,
     max_batch_len=60000,
     packing_max_batch_len=60000,
     samples_per_gpu=None,
     sampler="multipack",
     seed=47,
 ) -> DataLoader:
     collate_fn = make_collate_fn(
-        pad_token_id, is_granite=is_granite, max_batch_len=max_batch_len
+        pad_token_id, use_dolomite=use_dolomite, max_batch_len=max_batch_len
     )
     rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
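With the signature above, call sites now pass `use_dolomite` instead of `is_granite`; a usage sketch under the assumption that `dataset`, `tokenizer`, and the RANK/WORLD_SIZE environment variables (normally set by torchrun) already exist:

    # Usage sketch only: `dataset` and `tokenizer` are assumed to be defined,
    # and setup_dataloader reads RANK/WORLD_SIZE from the environment.
    from instructlab.training.token_dataset import setup_dataloader

    train_loader = setup_dataloader(
        dataset,
        tokenizer.pad_token_id,
        num_workers=8,
        use_dolomite=True,  # formerly is_granite=True
        max_batch_len=60000,
        packing_max_batch_len=60000,
        samples_per_gpu=None,
        sampler="multipack",
        seed=47,
    )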
18 changes: 9 additions & 9 deletions src/instructlab/training/utils.py
@@ -120,9 +120,9 @@ def supports_flash_attention(device_id=0):
     return is_sm8x or is_sm90
 
 
-def make_collate_fn(pad_token_id, is_granite=False, max_batch_len=60000):
+def make_collate_fn(pad_token_id, use_dolomite=False, max_batch_len=60000):
     rank = int(os.environ["RANK"])
-    if is_granite:
+    if use_dolomite:
 
         def pad_collate_fn(batch):
             lens = np.array([len(item["input_ids"]) for item in batch])
@@ -246,11 +246,11 @@ def pad_collate_fn(batch):
     return pad_collate_fn
 
 
-def convert_loss_to_reduce_sum(model, is_granite=False):
+def convert_loss_to_reduce_sum(model, use_dolomite=False):
     """
     this is necessary because multipack changes the samples per gpu, which biases the gradients to be larger for batches with less samples but longer lengths.
     """
-    if is_granite:
+    if use_dolomite:
 
         def get_autoregressive_language_modeling_loss(
             lm_logits: torch.Tensor,
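The docstring above is the motivation: with multipack, ranks carry different numbers of samples and tokens, so a per-rank mean loss weights ranks unequally relative to their token counts; summing the token losses and normalizing by a global count restores per-token weighting. A toy illustration of the mean-versus-sum difference (all numbers are made up):

    # Toy illustration of mean-per-rank vs. reduce-sum loss weighting across two
    # ranks holding different numbers of loss-counted tokens (as with multipack).
    import torch

    rank0 = torch.tensor([2.0, 2.0])            # 2 tokens on rank 0
    rank1 = torch.tensor([1.0, 1.0, 1.0, 1.0])  # 4 tokens on rank 1

    # Per-rank mean: both ranks contribute equally regardless of token count.
    mean_of_means = (rank0.mean() + rank1.mean()) / 2          # 1.5

    # Reduce-sum: sum per rank, then normalize by the global token count.
    global_tokens = rank0.numel() + rank1.numel()
    reduce_sum = (rank0.sum() + rank1.sum()) / global_tokens   # ~1.33

    print(mean_of_means.item(), reduce_sum.item())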
@@ -536,7 +536,7 @@ class UniversalCheckpointArgs:
 
 
 @contextmanager
-def ensure_loadable_granite_checkpoint(
+def ensure_loadable_dolomite_checkpoint(
     model_name_or_path: str,
     tmpdir: str,
 ):
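The renamed context manager keeps its two-argument signature, so call sites change only the name, as in the setup_model hunk earlier; a usage sketch with placeholder paths:

    # Usage sketch with placeholder paths: the yielded `path` is what the caller
    # passes as pretrained_model_name_or_path (see the setup_model hunk above).
    from instructlab.training.utils import ensure_loadable_dolomite_checkpoint

    with ensure_loadable_dolomite_checkpoint(
        "/path/to/model_or_checkpoint", "/path/to/output_dir"
    ) as path:
        base_model_args = {"pretrained_model_name_or_path": path}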
@@ -709,7 +709,7 @@ def save_hf_format_accelerate(
     tokenizer,
     accelerator: Accelerator,
     samples_seen,
-    convert_granite=True,
+    convert_dolomite=True,
     is_lora=False,
 ):
     log_rank_0(
@@ -719,7 +719,7 @@ def save_hf_format_accelerate(
     start = time.time()
 
     final_output_dir = Path(args.output_dir) / "hf_format" / f"samples_{samples_seen}"
-    if args.is_granite and convert_granite:
+    if args.use_dolomite and convert_dolomite:
         tmpdir = TemporaryDirectory("w") # pylint: disable=consider-using-with
         output_dir = Path(tmpdir.name)
     else:
@@ -741,7 +741,7 @@ def _get_state_dict_patched(model, unwrap=False):
         model_state = model.module.state_dict()
 
     output_dir.mkdir(parents=True, exist_ok=True)
-    if not model.module.config.architectures and convert_granite:
+    if not model.module.config.architectures and convert_dolomite:
         model.module.config.architectures = ["LlamaForCausalLM"]
         warnings.warn(
             f"Adding architectures to ckpt: {model.module.config.architectures}",
@@ -767,7 +767,7 @@ def _get_state_dict_patched(model, unwrap=False):
             safe_serialization=True,
         )
 
-    if args.is_granite and convert_granite and accelerator.is_main_process:
+    if args.use_dolomite and convert_dolomite and accelerator.is_main_process:
         # export doesnt like the directory to exist
         if final_output_dir.exists():
             shutil.rmtree(final_output_dir)
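Taken together, the save-path hunks mean the dolomite case writes to a scratch directory first and only exports into hf_format/samples_{N} afterwards; a reduced sketch of that directory selection (the values are placeholders, and the actual dolomite-to-HF export step is omitted):

    # Illustrative sketch, not the real save logic: mirrors only the directory
    # selection from the hunks above; the dolomite export/conversion is omitted.
    from pathlib import Path
    from tempfile import TemporaryDirectory

    use_dolomite, convert_dolomite = True, True
    samples_seen = 250000
    final_output_dir = Path("/tmp/out") / "hf_format" / f"samples_{samples_seen}"

    if use_dolomite and convert_dolomite:
        # Write to a scratch directory first; export to final_output_dir afterwards.
        scratch = TemporaryDirectory("w")  # pylint: disable=consider-using-with
        output_dir = Path(scratch.name)
    else:
        output_dir = final_output_dir

    output_dir.mkdir(parents=True, exist_ok=True)
    print(output_dir, "->", final_output_dir)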
