diff --git a/README.md b/README.md
index 40869a5..645583f 100644
--- a/README.md
+++ b/README.md
@@ -283,7 +283,9 @@ run_training(
 ```
 
 
-If the machines above have shared storage, users can preprocess the training dataset a single time so that it can then be distributed to each machine with the following update:
+## Example training with separate data pre-processing
+
+If the machines in the example above have shared storage, users can pre-process the training dataset a single time so that it can then be distributed to each machine by making the following updates.
 
 ```python
 from instructlab.training import (
@@ -295,6 +297,25 @@ from instructlab.training import (
     data_process as dp
 )
 
+training_args = TrainingArgs(
+    # define data-specific arguments
+    model_path = "ibm-granite/granite-7b-base",
+    data_path = "path/to/dataset.jsonl",
+    ckpt_output_dir = "data/saved_checkpoints",
+    data_output_dir = "data/outputs",
+
+    # define model-training parameters
+    max_seq_len = 4096,
+    max_batch_len = 60000,
+    num_epochs = 10,
+    effective_batch_size = 3840,
+    save_samples = 250000,
+    learning_rate = 2e-6,
+    warmup_steps = 800,
+    is_padding_free = True, # set this to true when using Granite-based models
+    random_seed = 42,
+    process_data = True,
+)
 ...
 
 data_process_args = DataProcessArgs(
@@ -309,6 +330,5 @@ dp.main(data_process_args)
 run_training(
     torch_args=torchrun_args,
     train_args=training_args,
-    process_data = False
 )
 ```
diff --git a/src/instructlab/training/__init__.py b/src/instructlab/training/__init__.py
index 7b4bf5e..a2ed292 100644
--- a/src/instructlab/training/__init__.py
+++ b/src/instructlab/training/__init__.py
@@ -28,13 +28,9 @@
 
 
 # defer import of main_ds
-def run_training(
-    torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True
-) -> None:
+def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     """Wrapper around the main training job that calls torchrun."""
     # Local
     from .main_ds import run_training
 
-    return run_training(
-        torch_args=torch_args, train_args=train_args, process_data=process_data
-    )
+    return run_training(torch_args=torch_args, train_args=train_args)
diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py
index 03e963f..05fe479 100644
--- a/src/instructlab/training/config.py
+++ b/src/instructlab/training/config.py
@@ -199,3 +199,6 @@ class TrainingArgs(BaseModel):
     # https://github.com/instructlab/training/issues/28
     # quantize_dtype: QuantizeDataType = QuantizeDataType.NONE
     lora: LoraOptions | None = None
+
+    # This field defines whether or not data processing will occur inside of `run_training()`
+    process_data: Optional[bool] = True
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index 710fcd5..c5cdb2b 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -635,9 +635,7 @@ def main(args):
 
 
 # public API
-def run_training(
-    torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True
-) -> None:
+def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     """
     Wrapper around the main training job that calls torchrun.
     """
@@ -647,7 +645,7 @@ def run_training(
             f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}"
         )
 
-    if process_data:
+    if train_args.process_data:
         dp.main(
             DataProcessArgs(
                 # XXX(osilkin): make a decision here, either:
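
Taken together, the patch moves the `process_data` switch off of `run_training()` and onto `TrainingArgs`, so the decision about in-process data preparation travels with the rest of the training configuration. Below is a minimal sketch of the resulting call pattern when the dataset has already been pre-processed on shared storage. The model path, data paths, and hyperparameter values are copied from the README example in the diff; the `TorchrunArgs` values are single-node placeholders mirroring the README's launcher example, and the choice of `process_data=False` is an illustrative assumption for the skip-processing case, not something this patch sets.

```python
# Sketch only: assumes the post-patch API where `process_data` is a TrainingArgs field.
from instructlab.training import (
    run_training,
    TorchrunArgs,
    TrainingArgs,
)

# Placeholder single-node launcher settings, in the style of the README example.
torchrun_args = TorchrunArgs(
    nnodes=1,
    nproc_per_node=8,
    node_rank=0,
    rdzv_id=123,
    rdzv_endpoint="127.0.0.1:12345",
)

training_args = TrainingArgs(
    model_path="ibm-granite/granite-7b-base",
    data_path="path/to/dataset.jsonl",
    ckpt_output_dir="data/saved_checkpoints",
    data_output_dir="data/outputs",  # already populated by a separate dp.main() run
    max_seq_len=4096,
    max_batch_len=60000,
    num_epochs=10,
    effective_batch_size=3840,
    save_samples=250000,
    learning_rate=2e-6,
    warmup_steps=800,
    is_padding_free=True,
    random_seed=42,
    # Assumption: skip re-processing inside run_training() because the data was
    # prepared ahead of time; the patch itself defaults this field to True.
    process_data=False,
)

run_training(torch_args=torchrun_args, train_args=training_args)
```

One effect of this design is that callers who serialize their training configuration now carry the processing decision with it, rather than threading an extra keyword argument through `run_training()`.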