Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

E2e device cuda #575

Merged
merged 3 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/e2e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install -e .
pip3 install flash-attn
pip3 install -r requirements-tests.txt

- name: Run e2e tests
Expand Down
13 changes: 8 additions & 5 deletions src/axolotl/utils/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pynvml
import torch
from pynvml.nvml import NVMLError


def gpu_memory_usage(device=0):
Expand All @@ -20,11 +21,13 @@ def gpu_memory_usage_smi(device=0):
device = device.index
if isinstance(device, str) and device.startswith("cuda:"):
device = int(device[5:])

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
return info.used / 1024.0**3
try:
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
return info.used / 1024.0**3
except NVMLError:
return 0.0


def log_gpu_memory_usage(log, msg, device):
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_device():
cfg.device_map = "auto"
else:
if cfg.device.startswith("cuda"):
cfg.device_map = {"": cfg.local_rank}
cfg.device_map = {"": torch.cuda.current_device()}
else:
cfg.device_map = {"": cfg.device}

Expand Down
42 changes: 42 additions & 0 deletions tests/e2e/test_lora_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,45 @@ def test_lora(self):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

def test_lora_packing(self):
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(),
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)