debug trainer, datawrangler, modelwrangler
Signed-off-by: James Kunstle <jkunstle@redhat.com>
JamesKunstle committed Jun 14, 2024
1 parent c5fa2f0 commit d3be728
Showing 4 changed files with 193 additions and 8 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
__pycache__/
165 changes: 165 additions & 0 deletions requirements.txt
@@ -0,0 +1,165 @@
accelerate==0.31.0
aiohttp==3.9.5
aiosignal==1.3.1
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.4.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
async-timeout==4.0.3
attrs==23.2.0
Babel==2.15.0
beautifulsoup4==4.12.3
bleach==6.1.0
blessed==1.20.0
certifi==2024.6.2
cffi==1.16.0
charset-normalizer==3.3.2
comm==0.2.2
contourpy==1.2.1
cycler==0.12.1
datasets==2.20.0
debugpy==1.8.1
decorator==5.1.1
deepspeed==0.14.3
defusedxml==0.7.1
dill==0.3.8
dolomite-engine @ file:///home/jkunstle/dolomite-engine
einops==0.8.0
exceptiongroup==1.2.1
executing==2.0.1
fastjsonschema==2.19.1
filelock==3.13.1
flash-attn==2.5.6
fonttools==4.53.0
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.2.0
gpustat==1.1.1
h11==0.14.0
hjson==3.1.0
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.23.4
hydra-core==1.3.2
idna==3.7
importlib_metadata==7.1.0
importlib_resources==6.4.0
ipdb==0.13.13
ipykernel==6.29.4
ipython==8.18.1
isoduration==20.11.0
jedi==0.19.1
Jinja2==3.1.3
json5==0.9.25
jsonpointer==3.0.0
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter_client==8.6.2
jupyter_core==5.7.2
jupyter_server==2.14.1
jupyter_server_terminals==0.5.3
jupyterlab==4.2.2
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.2
kiwisolver==1.4.5
llvmlite==0.43.0
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.0
matplotlib-inline==0.1.7
mdurl==0.1.2
mistune==3.0.2
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.2.1
ninja==1.11.1.1
notebook_shim==0.2.4
numba==0.60.0
numpy==1.26.4
nvidia-cublas-cu11==11.11.3.6
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-nvrtc-cu11==11.8.89
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cudnn-cu11==8.7.0.84
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.3.0.86
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusparse-cu11==11.7.5.86
nvidia-ml-py==12.555.43
nvidia-nccl-cu11==2.19.3
nvidia-nvtx-cu11==11.8.86
omegaconf==2.3.0
overrides==7.7.0
packaging==24.1
pandas==2.2.2
pandocfilters==1.5.1
parso==0.8.4
peft==0.11.1
pexpect==4.9.0
pillow==10.3.0
platformdirs==4.2.2
prometheus_client==0.20.0
prompt_toolkit==3.0.47
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==9.0.0
pyarrow==16.1.0
pyarrow-hotfix==0.6
pycparser==2.22
pydantic==2.7.4
pydantic_core==2.18.4
Pygments==2.18.0
pyparsing==3.1.2
python-dateutil==2.9.0.post0
python-json-logger==2.0.7
pytz==2024.1
PyYAML==6.0.1
pyzmq==26.0.3
referencing==0.35.1
regex==2024.5.15
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.1
rpds-py==0.18.1
safetensors==0.4.3
Send2Trash==1.8.3
six==1.16.0
sniffio==1.3.1
soupsieve==2.5
stack-data==0.6.3
sympy==1.12
terminado==0.18.1
tinycss2==1.3.0
tokenizers==0.15.2
tomli==2.0.1
torch==2.2.2+cu118
tornado==6.4.1
tqdm==4.66.4
traitlets==5.14.3
transformers==4.39.3
triton==2.2.0
types-python-dateutil==2.9.0.20240316
typing_extensions==4.9.0
tzdata==2024.1
uri-template==1.3.0
urllib3==2.2.1
wcwidth==0.2.13
webcolors==24.6.0
webencodings==0.5.1
websocket-client==1.8.0
xxhash==3.4.1
yarl==1.9.4
zipp==3.19.2
16 changes: 16 additions & 0 deletions startup.sh
@@ -0,0 +1,16 @@
torchrun --standalone \
--rdzv_endpoint="localhost:8889" \
--nproc_per_node=gpu ./trainer.py \
--model_name_or_path="/home/jkunstle/model" \
--data_path="/home/jkunstle/data/data.jsonl" \
--output_dir="/home/jkunstle/out" \
--num_epochs=100 \
--learning_rate=1e-05 \
--num_warmup_steps=400 \
--effective_batch_size=128 \
--save_samples=250000 \
--log_level="INFO" \
--sharding_strategy="HYBRID_SHARD" \
--is_granite \
--max_batch_len 10000 \
--seed=42
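
For context on the launch: --nproc_per_node=gpu tells torchrun to spawn one worker process per visible GPU, and every worker inherits its coordinates from environment variables that torchrun sets. A minimal sketch of what a spawned worker can read (the variable names are torchrun's standard ones; the snippet is illustrative, not part of this commit):

    import os

    # torchrun exports these for every process it launches.
    rank = int(os.environ["RANK"])              # global rank across all nodes
    local_rank = int(os.environ["LOCAL_RANK"])  # this worker's GPU index on its node
    world_size = int(os.environ["WORLD_SIZE"])  # total number of workers in the job

    print(f"worker {rank} of {world_size} (local GPU {local_rank})")

The final trainer.py hunk below leans on exactly this: it populates args.world_size from WORLD_SIZE instead of expecting it on the command line.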
19 changes: 11 additions & 8 deletions trainer.py
@@ -32,7 +32,7 @@ class DataWrapper:
 
     def __init__(self, _args: argparse.ArgumentParser):
 
-        self._args = _args.copy()
+        self._args = _args
         self.dataset = setup_dataset(data_path=_args.data_path)
         self.tokenizer = setup_tokenizer(_args.model_name_or_path)
         self.packing_max_batch_len, self.grad_accum = (
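
The three self._args = _args.copy() lines were the crash being debugged: parse_args() returns an argparse.Namespace (despite the argparse.ArgumentParser annotations, a pre-existing mislabel in the hints), and Namespace has no .copy() method, so each constructor raised AttributeError at startup. A small reproduction, plus the stdlib alternatives one could use if an independent copy were actually wanted (field names are illustrative):

    import argparse
    import copy

    ns = argparse.Namespace(seed=42, data_path="/tmp/data.jsonl")

    try:
        ns.copy()
    except AttributeError as err:
        print(err)  # 'Namespace' object has no attribute 'copy'

    shallow = copy.copy(ns)                   # stdlib shallow copy
    rebuilt = argparse.Namespace(**vars(ns))  # rebuild from the attribute dict

The fix simply stores the shared namespace, which is fine as long as none of the wrappers mutate it.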
@@ -62,7 +62,7 @@ def __init__(self, _args: argparse.ArgumentParser):
 class DSModelWrapper:
 
     def __init__(self, _args: argparse.ArgumentParser, dataw: DataWrapper):
-        self._args = _args.copy()
+        self._args = _args
         self.dataw = dataw
         self._ds_config = self._get_ds_config(
             world_size=self._args.world_size,
@@ -269,7 +269,7 @@ class DeepSpeedTrainer:
     def __init__(
         self, _args: argparse.ArgumentParser, modelw: DSModelWrapper, dataw: DataWrapper
     ):
-        self._args = _args.copy()
+        self._args = _args
         self.dataw = dataw
         self.modelw = modelw
         self.model = self.modelw.model
@@ -288,6 +288,9 @@ def __init__(
             print(
                 f"\033[93mNumber of samples per DS save: {self.save_samples_ds}\033[0m"
             )
+        else:
+            self.save_samples_ds = None
+
         if self.local_rank == 0:
             print(f"\033[93mNumber of samples per save: {self.save_samples}\033[0m")
 
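The added else branch fixes a one-path initialization bug: self.save_samples_ds was only assigned when the DeepSpeed save interval was configured, so any later read on the other path raised AttributeError. A minimal sketch of the failure mode (class and parameter names are illustrative):

    class Example:
        def __init__(self, save_samples_ds=None):
            if save_samples_ds is not None:
                self.save_samples_ds = save_samples_ds
            # without an else, the attribute is never created on this path

    e = Example()
    print(hasattr(e, "save_samples_ds"))  # False -- reading it would raise AttributeError

Defaulting the attribute to None lets later code test it for None instead of crashing.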
@@ -353,7 +356,7 @@ def _try_save_checkpoint(
         if self.local_rank == 0:
             elapsed_time = time.time() - start
             overall_throughput = (
-                self._args.samples_per_gpu * self.world_size / elapsed_time
+                self.dataw.samples_per_gpu * self.world_size / elapsed_time
             )
             current_lr = self.model.lr_scheduler.get_last_lr()[0]
             cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
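
The throughput fix is an ownership change, not a formula change: samples_per_gpu is derived by the DataWrapper from its batch-packing setup rather than parsed from the command line, so every read moves from self._args to self.dataw (and, in the main hunk further down, the startup banner stops printing the stale args field). The arithmetic is unchanged; a worked example with made-up numbers:

    # Throughput: samples each GPU handled, times the number of workers,
    # divided by wall-clock seconds. All values below are illustrative.
    samples_per_gpu = 16
    world_size = 8
    elapsed_time = 4.0

    overall_throughput = samples_per_gpu * world_size / elapsed_time
    print(overall_throughput)  # 32.0 samples per second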
@@ -370,12 +373,12 @@
                 f"total loss: {aggregated_values[2]/num_loss_counted_tokens}"
             )
 
-        if self.global_step * self.batch_size % args.save_samples == 0:
+        if self.global_step * self.batch_size % self._args.save_samples == 0:
             save_hf_format_ds(
                 args,
                 self.model,
                 self.dataw.tokenizer,
-                self.global_step * self._args.samples_per_gpu * self.world_size,
+                self.global_step * self.dataw.samples_per_gpu * self.world_size,
             )
 
         if (
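Switching args.save_samples to self._args.save_samples removes a hidden dependency on the module-level global args; note the first positional argument of save_hf_format_ds still passes that global, which keeps working only while the script runs as __main__. A minimal sketch of why the pattern is fragile (assuming the parser block sits under the usual __main__ guard):

    class Saver:
        def maybe_save(self):
            # a bare name resolves to the module global, not to self
            return args.save_samples

    if __name__ == "__main__":
        import argparse
        args = argparse.Namespace(save_samples=250000)  # stands in for parse_args()
        print(Saver().maybe_save())  # 250000 -- but importing this module and
                                     # calling maybe_save() would raise NameError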
@@ -386,7 +389,7 @@ def _try_save_checkpoint(
                 self._args,
                 self.model,
                 self.dataw.tokenizer,
-                self.global_step * self._args.samples_per_gpu * self.world_size,
+                self.global_step * self.dataw.samples_per_gpu * self.world_size,
             )
 
     def train(self):
@@ -428,7 +431,6 @@ def main(_args: argparse.ArgumentParser):
         f"grad_accum: {dataw.grad_accum}\n"
         f"num batches: {len(dataw.train_loader)}\n"
         f"avg_samples_per_batch: {len(dataw.dataset)/len(dataw.train_loader)}\n"
-        f"samples_per_gpu: {_args.samples_per_gpu}\033[0m"
     )
 
     modelw = DSModelWrapper(_args=_args, dataw=dataw)
@@ -480,5 +482,6 @@ def main(_args: argparse.ArgumentParser):
     parser.add_argument("--lora_target_modules", nargs="+", default=None)
     parser.add_argument("--max_batch_len", type=int, default=60000)
     args = parser.parse_args()
+    args.world_size = int(os.environ["WORLD_SIZE"])
     set_random_seed(args.seed)
     main(args)
