diff --git a/README.md b/README.md
index f57937e..645583f 100644
--- a/README.md
+++ b/README.md
@@ -280,3 +280,55 @@ run_training(
     torchrun_args=torchrun_args,
     training_args=training_args,
 )
+
+```
+
+## Example training with separate data pre-processing
+
+If the machines in the example above have shared storage, users can pre-process the training dataset a single time and then distribute it to each machine. To do so, update the example as follows:
+
+```python
+from instructlab.training import (
+    run_training,
+    TorchrunArgs,
+    TrainingArgs,
+    DeepSpeedOptions,
+    DataProcessArgs,
+    data_process as dp
+)
+
+training_args = TrainingArgs(
+    # define data-specific arguments
+    model_path = "ibm-granite/granite-7b-base",
+    data_path = "path/to/dataset.jsonl",
+    ckpt_output_dir = "data/saved_checkpoints",
+    data_output_dir = "data/outputs",
+
+    # define model-training parameters
+    max_seq_len = 4096,
+    max_batch_len = 60000,
+    num_epochs = 10,
+    effective_batch_size = 3840,
+    save_samples = 250000,
+    learning_rate = 2e-6,
+    warmup_steps = 800,
+    is_padding_free = True, # set this to true when using Granite-based models
+    random_seed = 42,
+    process_data = False, # the data is pre-processed separately below, so skip it in run_training()
+)
+...
+
+data_process_args = DataProcessArgs(
+    data_output_path = training_args.data_output_dir,
+    model_path = training_args.model_path,
+    data_path = training_args.data_path,
+    max_seq_len = training_args.max_seq_len,
+    chat_tmpl_path = training_args.chat_tmpl_path
+)
+
+dp.main(data_process_args)
+run_training(
+    torch_args=torchrun_args,
+    train_args=training_args,
+)
+```
diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py
index 03e963f..05fe479 100644
--- a/src/instructlab/training/config.py
+++ b/src/instructlab/training/config.py
@@ -199,3 +199,6 @@ class TrainingArgs(BaseModel):
     # https://github.com/instructlab/training/issues/28
     # quantize_dtype: QuantizeDataType = QuantizeDataType.NONE
     lora: LoraOptions | None = None
+
+    # This field defines whether or not data processing will occur inside of `run_training()`
+    process_data: Optional[bool] = True
diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py
index 2e6cd39..4bd7c78 100644
--- a/src/instructlab/training/data_process.py
+++ b/src/instructlab/training/data_process.py
@@ -221,6 +221,8 @@ def get_masked_and_orig_text(sample):
 
 
 def main(args: DataProcessArgs):
+    if not os.path.exists(args.data_output_path):
+        os.makedirs(args.data_output_path, exist_ok=True)
     print("\033[92m data arguments are:\033[0m")
     print("\033[36m" + args.model_dump_json() + "\033[0m")
     NUM_PROC = args.num_cpu_procs
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index 07684a9..c5cdb2b 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -645,24 +645,22 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
             f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}"
         )
 
-    # process the training data
-    if not os.path.exists(train_args.data_output_dir):
-        os.makedirs(train_args.data_output_dir, exist_ok=True)
-    dp.main(
-        DataProcessArgs(
-            # XXX(osilkin): make a decision here, either:
-            # 1. the CLI is fully responsible for managing where the data is written
-            # 2. we never cache it and simply write it to a tmp file every time.
-            #
-            # An important reason for why #1 would be preferable is in the case of OpenShift/SELinux
-            # where the user has a defined place for new temporary data to be written.
-            data_output_path=train_args.data_output_dir,
-            model_path=train_args.model_path,
-            data_path=train_args.data_path,
-            max_seq_len=train_args.max_seq_len,
-            chat_tmpl_path=train_args.chat_tmpl_path,
+    if train_args.process_data:
+        dp.main(
+            DataProcessArgs(
+                # XXX(osilkin): make a decision here, either:
+                # 1. the CLI is fully responsible for managing where the data is written
+                # 2. we never cache it and simply write it to a tmp file every time.
+                #
+                # An important reason for why #1 would be preferable is in the case of OpenShift/SELinux
+                # where the user has a defined place for new temporary data to be written.
+                data_output_path=train_args.data_output_dir,
+                model_path=train_args.model_path,
+                data_path=train_args.data_path,
+                max_seq_len=train_args.max_seq_len,
+                chat_tmpl_path=train_args.chat_tmpl_path,
+            )
         )
-    )
 
     if not os.path.exists(train_args.ckpt_output_dir):
         os.makedirs(train_args.ckpt_output_dir, exist_ok=True)
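
With the new `process_data` flag, a machine that receives the already pre-processed dataset only needs to launch training. Below is a minimal sketch of what such a worker script could look like. It reuses the paths and `TrainingArgs` values from the README example above; the `TorchrunArgs` fields (`nnodes`, `nproc_per_node`, `node_rank`, `rdzv_id`, `rdzv_endpoint`) mirror the multi-node example earlier in the README, and the specific values shown (node count, rank, endpoint) are illustrative placeholders, not values taken from this diff.

```python
from instructlab.training import run_training, TorchrunArgs, TrainingArgs

# Assumes data/outputs already holds the dataset that was pre-processed
# once on the shared-storage machine via dp.main(...).
training_args = TrainingArgs(
    model_path = "ibm-granite/granite-7b-base",
    data_path = "path/to/dataset.jsonl",
    ckpt_output_dir = "data/saved_checkpoints",
    data_output_dir = "data/outputs",  # pre-processed data distributed to this machine
    max_seq_len = 4096,
    max_batch_len = 60000,
    num_epochs = 10,
    effective_batch_size = 3840,
    save_samples = 250000,
    learning_rate = 2e-6,
    warmup_steps = 800,
    is_padding_free = True,
    random_seed = 42,
    process_data = False,  # skip re-processing inside run_training()
)

# Illustrative settings for the second of two 8-GPU machines.
torchrun_args = TorchrunArgs(
    nnodes = 2,
    nproc_per_node = 8,
    node_rank = 1,  # 0 on the first machine, 1 here
    rdzv_id = 123,
    rdzv_endpoint = "node0.example.com:54321",
)

run_training(torch_args=torchrun_args, train_args=training_args)
```

Because `process_data` defaults to `True` in `TrainingArgs`, existing callers of `run_training()` keep the old behavior unchanged; only scripts that opt in with `process_data=False` skip the in-training data processing step.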