Support device_map=sequential & max_memory config parameters #903

Merged · 3 commits · Dec 4, 2023
Changes from 2 commits
5 changes: 5 additions & 0 deletions README.md
@@ -612,6 +612,11 @@ eval_sample_packing:
 sample_packing_eff_est:
 total_num_tokens:
 
+# Passed through to transformers when loading the model, when launched without accelerate
+device_map:
+# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
+max_memory:
+
 # If you want to use 'lora' or 'qlora', or leave blank to train all parameters in the original model
 adapter: lora
 # If you already have a lora model trained that you want to load, put that here.
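For context, a minimal sketch (not part of this diff) of how the two new keys might be filled in and what Python values they deserialize to. It assumes PyYAML; the device indices and memory figures are illustrative, not values from this PR:

```python
# A minimal sketch, assuming PyYAML; the device indices and memory figures
# below are illustrative placeholders.
import yaml

cfg_text = """
device_map: sequential
max_memory:
  0: 20GiB
  1: 20GiB
  cpu: 96GiB
"""

cfg = yaml.safe_load(cfg_text)
print(cfg["device_map"])  # sequential
print(cfg["max_memory"])  # {0: '20GiB', 1: '20GiB', 'cpu': '96GiB'}
```

transformers/accelerate accept `max_memory` in exactly this shape: a dict mapping a device id (GPU index or `"cpu"`) to a cap, given either as a byte count or a string like `"20GiB"`.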
2 changes: 1 addition & 1 deletion src/axolotl/utils/config.py
@@ -27,7 +27,7 @@ def get_device():
 
     cfg.device = get_device()
     if cfg.world_size == 1:
-        cfg.device_map = "auto"
+        cfg.device_map = cfg.device_map or "auto"
     else:
         if cfg.device.startswith("cuda"):
             cfg.device_map = {"": torch.cuda.current_device()}
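A minimal sketch of the fallback behavior this one-line change introduces: in the single-process path, a user-supplied `device_map` is now preserved, and only an unset value falls back to `"auto"`. `resolve_device_map` is a hypothetical helper, not a function in axolotl, and a plain integer stands in for `torch.cuda.current_device()` so the example runs without a GPU:

```python
# Hypothetical helper mirroring the changed branch above; not axolotl code.
def resolve_device_map(user_device_map, world_size, current_device=0):
    if world_size == 1:
        # Before this PR the value was always overwritten with "auto";
        # now a user-supplied value (e.g. "sequential") is preserved.
        return user_device_map or "auto"
    # Multi-process launches pin each rank to its own device.
    return {"": current_device}

assert resolve_device_map(None, world_size=1) == "auto"
assert resolve_device_map("sequential", world_size=1) == "sequential"
assert resolve_device_map("sequential", world_size=2) == {"": 0}
```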
1 change: 1 addition & 0 deletions src/axolotl/utils/models.py
@@ -216,6 +216,7 @@ def load_model(
     model_kwargs = {}
 
     model_kwargs["device_map"] = cfg.device_map
+    model_kwargs["max_memory"] = cfg.max_memory
     model_kwargs["torch_dtype"] = cfg.torch_dtype
 
     if cfg.model_revision:
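Finally, a hedged sketch of where these kwargs land: `from_pretrained` forwards `device_map` and `max_memory` to accelerate's weight-placement logic when loading the checkpoint. The model id and memory caps below are placeholders, not values from this PR:

```python
# Illustrative only: the model id and memory caps are placeholders.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "openlm-research/open_llama_3b",          # placeholder model id
    device_map="sequential",                  # fill GPUs in order instead of balancing
    max_memory={0: "20GiB", "cpu": "96GiB"},  # per-device placement caps
)
```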