move process_data arg into TrainingArgs
Signed-off-by: Michael Clifford <mcliffor@redhat.com>
MichaelClifford committed Oct 5, 2024
1 parent aefde0e commit f97dca3
Showing 4 changed files with 29 additions and 12 deletions.
24 changes: 22 additions & 2 deletions README.md
@@ -283,7 +283,9 @@ run_training(

```

If the machines above have shared storage, users can preprocess the training dataset a single time so that it can then be distributed to each machine with the following update:
## Example training with separate data pre-processing

If the machines in the example above have shared storage, users can pre-process the training dataset a single time so that it can then be distributed to each machine by making the following updates.

```python
from instructlab.training import (
@@ -295,6 +297,25 @@ from instructlab.training import (
data_process as dp
)

training_args = TrainingArgs(
# define data-specific arguments
model_path = "ibm-granite/granite-7b-base",
data_path = "path/to/dataset.jsonl",
ckpt_output_dir = "data/saved_checkpoints",
data_output_dir = "data/outputs",

# define model-training parameters
max_seq_len = 4096,
max_batch_len = 60000,
num_epochs = 10,
effective_batch_size = 3840,
save_samples = 250000,
learning_rate = 2e-6,
warmup_steps = 800,
is_padding_free = True, # set this to true when using Granite-based models
random_seed = 42,
process_data = True,
)
...

data_process_args = DataProcessArgs(
@@ -309,6 +330,5 @@ dp.main(data_process_args)
run_training(
torch_args=torchrun_args,
train_args=training_args,
process_data = False
)
```
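
On the machines that only consume the dataset already written to shared storage, the follow-up call might look like the sketch below. This is a minimal, hedged example: it reuses the `TrainingArgs` values shown above with `process_data` flipped to `False`, and the `torchrun_args` values mirror the earlier README example (the hostnames, paths, and `node_rank` are placeholders that would differ per machine).

```python
from instructlab.training import (
    run_training,
    TorchrunArgs,
    TrainingArgs,
)

# Assumed torchrun settings, mirroring the earlier README example; node_rank
# would differ on each machine participating in the job.
torchrun_args = TorchrunArgs(
    nnodes = 2,
    nproc_per_node = 8,
    node_rank = 1,
    rdzv_id = 123,
    rdzv_endpoint = "10.0.0.1:12345",
)

training_args = TrainingArgs(
    # same data-specific arguments as above; data_output_dir points at the
    # shared location where the dataset was already pre-processed once
    model_path = "ibm-granite/granite-7b-base",
    data_path = "path/to/dataset.jsonl",
    ckpt_output_dir = "data/saved_checkpoints",
    data_output_dir = "data/outputs",

    # same model-training parameters as above
    max_seq_len = 4096,
    max_batch_len = 60000,
    num_epochs = 10,
    effective_batch_size = 3840,
    save_samples = 250000,
    learning_rate = 2e-6,
    warmup_steps = 800,
    is_padding_free = True, # set this to true when using Granite-based models
    random_seed = 42,

    # skip in-process data pre-processing, since it already ran on the shared volume
    process_data = False,
)

run_training(
    torch_args = torchrun_args,
    train_args = training_args,
)
```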
8 changes: 2 additions & 6 deletions src/instructlab/training/__init__.py
@@ -28,13 +28,9 @@


# defer import of main_ds
def run_training(
torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True
) -> None:
def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
"""Wrapper around the main training job that calls torchrun."""
# Local
from .main_ds import run_training

return run_training(
torch_args=torch_args, train_args=train_args, process_data=process_data
)
return run_training(torch_args=torch_args, train_args=train_args)
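
Under the new signature, callers that previously passed `process_data` directly to the wrapper would move the flag onto the config object instead. A minimal migration sketch, assuming `torchrun_args` and `training_args` objects already exist:

```python
# Previously the flag was a keyword argument on the wrapper:
#   run_training(torch_args=torchrun_args, train_args=training_args, process_data=False)

# With this commit the same intent is expressed through TrainingArgs:
training_args.process_data = False
run_training(torch_args=torchrun_args, train_args=training_args)
```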
3 changes: 3 additions & 0 deletions src/instructlab/training/config.py
@@ -199,3 +199,6 @@ class TrainingArgs(BaseModel):
# https://github.com/instructlab/training/issues/28
# quantize_dtype: QuantizeDataType = QuantizeDataType.NONE
lora: LoraOptions | None = None

# This field defines whether or not data processing will occur inside of `run_training()`
process_data: Optional[bool] = True
6 changes: 2 additions & 4 deletions src/instructlab/training/main_ds.py
@@ -635,9 +635,7 @@ def main(args):


# public API
def run_training(
torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True
) -> None:
def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
"""
Wrapper around the main training job that calls torchrun.
"""
@@ -647,7 +645,7 @@ def run_training(
f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}"
)

if process_data:
if train_args.process_data:
dp.main(
DataProcessArgs(
# XXX(osilkin): make a decision here, either:
