Merge pull request #220 from MichaelClifford/data_processing
Make data processing optional in run_training()
mergify[bot] authored Oct 7, 2024
2 parents 2273f08 + f97dca3 commit 99e833a
Showing 4 changed files with 72 additions and 17 deletions.
52 changes: 52 additions & 0 deletions README.md
@@ -280,3 +280,55 @@ run_training(
    torch_args=torchrun_args,
    train_args=training_args,
)

```

## Example training with separate data pre-processing

If the machines in the example above have shared storage, users can pre-process the training dataset once and then distribute it to each machine by applying the updates shown below.

```python
from instructlab.training import (
run_training,
TorchrunArgs,
TrainingArgs,
DeepSpeedOptions,
DataProcessArgs,
data_process as dp
)

training_args = TrainingArgs(
# define data-specific arguments
model_path = "ibm-granite/granite-7b-base",
data_path = "path/to/dataset.jsonl",
ckpt_output_dir = "data/saved_checkpoints",
data_output_dir = "data/outputs",

    # define model-training parameters
max_seq_len = 4096,
max_batch_len = 60000,
num_epochs = 10,
effective_batch_size = 3840,
save_samples = 250000,
learning_rate = 2e-6,
warmup_steps = 800,
is_padding_free = True, # set this to true when using Granite-based models
random_seed = 42,
    process_data = False, # the data is pre-processed separately below, so run_training() skips it
)
...

data_process_args = DataProcessArgs(
data_output_path = training_args.data_output_dir,
model_path = training_args.model_path,
data_path = training_args.data_path,
max_seq_len = training_args.max_seq_len,
chat_tmpl_path = training_args.chat_tmpl_path
)

dp.main(data_process_args)
run_training(
torch_args=torchrun_args,
train_args=training_args,
)
```
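The `...` above elides the `TorchrunArgs` setup from the earlier example. Purely as an illustration, assuming `TorchrunArgs` mirrors torchrun's standard distributed flags, a two-node configuration might look like this (all values are hypothetical and cluster-specific):

```python
from instructlab.training import TorchrunArgs

torchrun_args = TorchrunArgs(
    nnodes=2,               # total number of machines
    nproc_per_node=8,       # GPUs per machine
    node_rank=0,            # set to 1 on the second machine
    rdzv_id=123,            # any job id shared by both machines
    rdzv_endpoint="node0.example.com:29500",  # hypothetical head-node address
)
```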
3 changes: 3 additions & 0 deletions src/instructlab/training/config.py
@@ -199,3 +199,6 @@ class TrainingArgs(BaseModel):
# https://github.com/instructlab/training/issues/28
# quantize_dtype: QuantizeDataType = QuantizeDataType.NONE
lora: LoraOptions | None = None

# This field defines whether or not data processing will occur inside of `run_training()`
process_data: Optional[bool] = True
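Because the new field defaults to `True`, existing callers keep the old always-process behavior and only opt-in callers change anything. A minimal sketch of the intent, where `common_kwargs` is a hypothetical dict holding the other required `TrainingArgs` fields:

```python
from instructlab.training import TrainingArgs

# `common_kwargs` is hypothetical: the other required TrainingArgs fields.
args_default = TrainingArgs(**common_kwargs)
assert args_default.process_data is True    # old behavior preserved by default

args_skip = TrainingArgs(**common_kwargs, process_data=False)
assert args_skip.process_data is False      # run_training() will skip dp.main()
```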
2 changes: 2 additions & 0 deletions src/instructlab/training/data_process.py
@@ -221,6 +221,8 @@ def get_masked_and_orig_text(sample):


def main(args: DataProcessArgs):
if not os.path.exists(args.data_output_path):
os.makedirs(args.data_output_path, exist_ok=True)
print("\033[92m data arguments are:\033[0m")
print("\033[36m" + args.model_dump_json() + "\033[0m")
NUM_PROC = args.num_cpu_procs
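As an aside, `os.makedirs` with `exist_ok=True` is already idempotent, so the added guard is safe to run on every invocation:

```python
import os

out_dir = "data/outputs"             # illustrative path
os.makedirs(out_dir, exist_ok=True)  # creates the directory tree if missing
os.makedirs(out_dir, exist_ok=True)  # second call is a no-op, no FileExistsError
```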
32 changes: 15 additions & 17 deletions src/instructlab/training/main_ds.py
@@ -645,24 +645,22 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
             f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}"
         )
 
-    # process the training data
-    if not os.path.exists(train_args.data_output_dir):
-        os.makedirs(train_args.data_output_dir, exist_ok=True)
-    dp.main(
-        DataProcessArgs(
-            # XXX(osilkin): make a decision here, either:
-            # 1. the CLI is fully responsible for managing where the data is written
-            # 2. we never cache it and simply write it to a tmp file every time.
-            #
-            # An important reason for why #1 would be preferable is in the case of OpenShift/SELinux
-            # where the user has a defined place for new temporary data to be written.
-            data_output_path=train_args.data_output_dir,
-            model_path=train_args.model_path,
-            data_path=train_args.data_path,
-            max_seq_len=train_args.max_seq_len,
-            chat_tmpl_path=train_args.chat_tmpl_path,
+    if train_args.process_data:
+        dp.main(
+            DataProcessArgs(
+                # XXX(osilkin): make a decision here, either:
+                # 1. the CLI is fully responsible for managing where the data is written
+                # 2. we never cache it and simply write it to a tmp file every time.
+                #
+                # An important reason for why #1 would be preferable is in the case of OpenShift/SELinux
+                # where the user has a defined place for new temporary data to be written.
+                data_output_path=train_args.data_output_dir,
+                model_path=train_args.model_path,
+                data_path=train_args.data_path,
+                max_seq_len=train_args.max_seq_len,
+                chat_tmpl_path=train_args.chat_tmpl_path,
+            )
         )
-    )
 
     if not os.path.exists(train_args.ckpt_output_dir):
         os.makedirs(train_args.ckpt_output_dir, exist_ok=True)
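Taken together, the change supports two call patterns; a sketch, where `torchrun_args`, `training_args`, and `data_process_args` are as in the README example above:

```python
# Flow 1 (default): run_training() processes the data itself, since
# TrainingArgs.process_data defaults to True.
run_training(torch_args=torchrun_args, train_args=training_args)

# Flow 2 (pre-processed): run dp.main() once, e.g. on shared storage, and
# construct TrainingArgs with process_data=False so every machine skips
# reprocessing inside run_training().
dp.main(data_process_args)
run_training(torch_args=torchrun_args, train_args=training_args)
```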
