feat(python): support checkpointing for moe #242

Merged (25 commits) on Oct 15, 2021
Changes from 14 commits
13 changes: 13 additions & 0 deletions .buildkite/scripts/benchmark_master.sh
@@ -144,3 +144,16 @@ CUDA_VISIBLE_DEVICES=0,1 python -m bagua.distributed.launch \
--set-deterministic \
2>&1 | tee ${logfile}
check_moe_log ${logfile} 0.000293

# 4. test moe checkpoint
logfile=$(mktemp /tmp/bagua_moe_checkpoint.XXXXXX.log)
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m bagua.distributed.launch \
--nproc_per_node 4 \
${MOE_SCRIPT} \
--algorithm gradient_allreduce \
--epochs 5 \
--num-local-experts 2 \
--set-deterministic \
--save-model \
2>&1 | tee ${logfile}
check_moe_log ${logfile} 0.000293
1 change: 1 addition & 0 deletions bagua/torch_api/__init__.py
@@ -57,4 +57,5 @@
from . import contrib # noqa: F401
from . import communication # noqa: F401
from . import algorithms # noqa: F401
from . import checkpoint # noqa: E402,F401
from .model_parallel import moe # noqa: E402,F401
2 changes: 2 additions & 0 deletions bagua/torch_api/checkpoint/__init__.py
@@ -0,0 +1,2 @@
from .checkpointing import load_checkpoint # noqa: F401
from .checkpointing import save_checkpoint # noqa: F401
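
The package re-exports the two public entry points added by this PR. A minimal usage sketch, assuming an already-built model and optimizer (`net`, `opt`, and the path below are illustrative placeholders, not part of this PR):

# Hypothetical objects: "net" is any torch.nn.Module, "opt" a torch.optim optimizer.
from bagua.torch_api.checkpoint import load_checkpoint, save_checkpoint

save_checkpoint(iteration=100, checkpoints_path="/tmp/ckpt", model=net, optimizer=opt)
start_iter = load_checkpoint(checkpoints_path="/tmp/ckpt", model=net, optimizer=opt)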
310 changes: 310 additions & 0 deletions bagua/torch_api/checkpoint/checkpointing.py
@@ -0,0 +1,310 @@
# Copyright (c) 2021 Kuaishou AI Platform & DS3 Lab
#
# All rights reserved.

import logging
import os
import re
import sys
import torch
import torch.distributed as dist
from collections import defaultdict
from bagua.torch_api.model_parallel.moe import MoE


def _has_moe_layers(model):
bagua_has_moe_layers = False
bagua_moe_num_experts = 0
for name, module in model.named_modules():
if isinstance(module, MoE):
bagua_has_moe_layers = True
bagua_moe_num_experts = module.num_experts
break
return bagua_has_moe_layers, bagua_moe_num_experts


def _ensure_directory_exists(filename):
dirname = os.path.dirname(filename)
if not os.path.exists(dirname):
os.makedirs(dirname)


def _get_optimizer_ckpt_name(
checkpoints_path, iteration, expp_rank, mp_rank=0, release=False
):
if release:
directory = "release"
else:
directory = "iter_{:07d}".format(iteration)
ckpt_name = os.path.join(
checkpoints_path,
directory,
f"expp_rank_{expp_rank}_mp_rank_{mp_rank:02d}_optim_states.pt",
)
return ckpt_name


def _get_expert_ckpt_name(
checkpoints_path, expert_id, iteration, mp_rank=0, release=False
):
if release:
directory = "release"
else:
directory = "iter_{:07d}".format(iteration)
ckpt_name = os.path.join(
checkpoints_path,
directory,
f"expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt",
)
return ckpt_name


def _get_model_ckpt_name(checkpoints_path, iteration, mp_rank=0, release=False):
if release:
directory = "release"
else:
directory = "iter_{:07d}".format(iteration)
return os.path.join(
checkpoints_path, directory, f"mp_rank_{mp_rank:02d}_model_states.pt"
)


def _get_checkpoint_tracker_filename(checkpoints_path):
return os.path.join(checkpoints_path, "latest_checkpointed_iteration.txt")


def _read_metadata(tracker_filename):
iteration = 0
release = False
with open(tracker_filename, "r") as f:
metastring = f.read().strip()
try:
iteration = int(metastring)
except ValueError:
release = metastring == "release"
if not release:
logging.error(
"Invalid metadata file {}. Exiting".format(tracker_filename)
)
sys.exit()
assert iteration >= 0 or release, "error parsing metadata file {}".format(
tracker_filename
)

return iteration, release


def save_checkpoint(
iteration, checkpoints_path, model, optimizer=None, lr_scheduler=None
Review comment (Contributor): lack type annotation, similar for other functions
Review comment (Member Author): done
):
"""Save model checkpoint.

Args:
        iteration: Training iteration.
        checkpoints_path: Path of checkpoints.
        model: The model to save.
        optimizer (optional): The optimizer to save. Default: ``None``.
        lr_scheduler (optional): The LR scheduler to save. Default: ``None``.
"""
logging.info(
"saving checkpoint at iterration {:7d} to {}".format(
iteration, checkpoints_path
)
)

bagua_has_moe_layers, bagua_moe_num_experts = _has_moe_layers(model)
if bagua_has_moe_layers:
_save_moe_checkpoint(
iteration, checkpoints_path, bagua_moe_num_experts, model, optimizer, lr_scheduler
)
else:
_save_checkpoint(iteration, checkpoints_path, model, optimizer, lr_scheduler)

logging.info(
"successfully saved checkpoint at iterration {:7d} to {}".format(
iteration, checkpoints_path
)
)

# update the latest iteration
if not dist.is_initialized() or dist.get_rank() == 0:
tracker_filename = _get_checkpoint_tracker_filename(checkpoints_path)
with open(tracker_filename, "w") as f:
f.write(str(iteration))

if dist.is_initialized():
dist.barrier()


def _save_checkpoint(
iteration, checkpoints_path, model, optimizer=None, lr_scheduler=None
):
if not dist.is_initialized() or dist.get_rank() == 0:
state_dict = {}
state_dict["iteration"] = iteration
state_dict["model"] = model.state_dict()
if optimizer is not None:
state_dict["optimizer"] = optimizer.state_dict()
if lr_scheduler is not None:
state_dict["lr_scheduler"] = lr_scheduler.state_dict()

checkpoint_name = _get_model_ckpt_name(checkpoints_path, iteration)
_ensure_directory_exists(checkpoint_name)
torch.save(state_dict, checkpoint_name)


def _save_moe_checkpoint(
iteration, checkpoints_path, num_experts, model, optimizer=None, lr_scheduler=None
Review comment (Contributor): lack type annotation
Review comment (Member Author): done
Review comment (Contributor): still lack type annotations on _xxx functions
Review comment (Member Author): done
):
world_size = 1 if not dist.is_initialized() else dist.get_world_size()
    # Fall back to rank 0 when torch.distributed is not initialized.
    expp_rank = 0 if not dist.is_initialized() else dist.get_rank()
num_local_experts = num_experts // world_size
experts_state_dict, model_state_dict = _get_moe_state_dict(
model.state_dict(), num_local_experts, expp_rank
)

# Each rank saves its local experts
for global_expert_id, expert_state_dict in experts_state_dict.items():
expert_save_dir = _get_expert_ckpt_name(
checkpoints_path, global_expert_id, iteration
)
logging.info(
f"Saving model expert {global_expert_id} checkpoint: {expert_save_dir}"
)
_ensure_directory_exists(expert_save_dir)
torch.save(expert_state_dict, expert_save_dir)

# Save optimizer states. They are different across each exp parallel rank.
optimizer_state = {"optimizer": optimizer.state_dict() if optimizer else None}
torch.save(
optimizer_state,
_get_optimizer_ckpt_name(checkpoints_path, iteration, expp_rank),
)

if expp_rank == 0:
state_dict = {}
state_dict["iteration"] = iteration
state_dict["model"] = model_state_dict
if lr_scheduler is not None:
state_dict["lr_scheduler"] = lr_scheduler.state_dict()

# Save.
checkpoint_name = _get_model_ckpt_name(checkpoints_path, iteration)
_ensure_directory_exists(checkpoint_name)
torch.save(state_dict, checkpoint_name)


def _get_moe_state_dict(full_state_dict, num_local_experts, expp_rank):
experts_state_dict, moe_state_dict = defaultdict(dict), {}
for key in list(full_state_dict.keys()):
if "expert" in key and "moe.gate.wg.weight" not in key:
moe_state_dict[key] = full_state_dict.pop(key)
non_moe_state_dict = full_state_dict

moe_str_prefix = ".bagua_moe.experts.bagua_experts."
for key in list(moe_state_dict.keys()):
        m = re.match(f".*{moe_str_prefix}([0-9]+).*", key)
        if not m:
            logging.warning(f"No expert found in key {key}.")
            continue
        local_expert_id = m.group(1)

global_expert_id = expp_rank * num_local_experts + int(local_expert_id)
expert_key = key.replace(
f"{moe_str_prefix}{local_expert_id}", f"{moe_str_prefix}{global_expert_id}"
)
experts_state_dict[str(global_expert_id)][expert_key] = moe_state_dict.pop(key)

return experts_state_dict, non_moe_state_dict
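# Illustration of the key remapping above (assumptions: 2 local experts per rank,
# expp_rank == 1, and a hypothetical parameter name "fc.weight"): the local key
#   "...bagua_moe.experts.bagua_experts.0.fc.weight"   (local expert 0)
# is stored in the per-expert checkpoint under the global key
#   "...bagua_moe.experts.bagua_experts.2.fc.weight"   (global expert 1*2+0 == 2),
# and _load_moe_state_dict below reverses this rename at load time.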


def load_checkpoint(
checkpoints_path, model, optimizer=None, lr_scheduler=None, strict=True
):
"""Load a model checkpoint and return the iteration.

Args:
checkpoints_path: Path of checkpoints.
        model: The model to load into.
        optimizer (optional): The optimizer to load into. Default: ``None``.
        lr_scheduler (optional): The LR scheduler to load into. Default: ``None``.
strict (bool, optional): whether to strictly enforce that the keys in
``state_dict`` of the checkpoint match the keys returned by this module's
state_dict() function. Default: ``True``.
"""

tracker_filename = _get_checkpoint_tracker_filename(checkpoints_path)

    # If no tracker file, return iteration zero.
if not os.path.isfile(tracker_filename):
logging.warning(f"could not find checkpoint metadata file {tracker_filename}")
return 0

iteration, release = _read_metadata(tracker_filename)
logging.info(f"loading checkpoint from {checkpoints_path} at iteration {iteration}")

    _load_checkpoint(
        iteration, checkpoints_path, model, optimizer, lr_scheduler, strict
    )

if dist.is_initialized():
dist.barrier()

    logging.info(
        f"successfully loaded checkpoint from {checkpoints_path} at iteration {iteration}"
    )
return iteration


def _load_checkpoint(
iteration, checkpoints_path, model, optimizer=None, lr_scheduler=None, strict=True
):
    expp_rank = 0 if not dist.is_initialized() else dist.get_rank()
checkpoint_name = _get_model_ckpt_name(checkpoints_path, iteration)

model_checkpoint = torch.load(checkpoint_name, map_location="cpu")
bagua_has_moe_layers, bagua_moe_num_experts = _has_moe_layers(model)
    if bagua_has_moe_layers:
        world_size = 1 if not dist.is_initialized() else dist.get_world_size()
        num_local_experts = bagua_moe_num_experts // world_size
_load_moe_state_dict(
checkpoints_path,
iteration,
num_local_experts,
expp_rank,
state_dict=model_checkpoint["model"],
)

    if bagua_has_moe_layers and optimizer is not None:
optim_load_path = _get_optimizer_ckpt_name(
checkpoints_path, iteration, expp_rank
)
optim_checkpoint = torch.load(optim_load_path, map_location=torch.device("cpu"))
else:
optim_checkpoint = model_checkpoint

model.load_state_dict(model_checkpoint["model"], strict=strict)

# Optimizer.
if optimizer is not None:
optimizer.load_state_dict(optim_checkpoint["optimizer"])
if lr_scheduler is not None:
lr_scheduler.load_state_dict(model_checkpoint["lr_scheduler"])


def _load_moe_state_dict(
checkpoint_path, iteration, num_local_experts, expp_rank, state_dict
):
for local_expert_id in range(num_local_experts):
global_expert_id = expp_rank * num_local_experts + local_expert_id
expert_state_dict = torch.load(
_get_expert_ckpt_name(checkpoint_path, global_expert_id, iteration),
map_location=torch.device("cpu"),
)

# Updating global -> local expert ids
moe_str_prefix = ".bagua_moe.experts.bagua_experts."
for key in list(expert_state_dict.keys()):
local_key = key.replace(
f"{moe_str_prefix}{global_expert_id}",
f"{moe_str_prefix}{local_expert_id}",
)
expert_state_dict[local_key] = expert_state_dict.pop(key)
state_dict.update(expert_state_dict)
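
For context, a hedged end-to-end sketch of how the new module could be driven from a training script. The `build_model` helper, the checkpoint directory, and the save interval are illustrative assumptions; only `save_checkpoint` and `load_checkpoint` come from this PR:

# Usage sketch only; "build_model" is a hypothetical factory returning a
# torch.nn.Module that may contain bagua MoE layers.
import torch
from bagua.torch_api.checkpoint import load_checkpoint, save_checkpoint

model = build_model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Resume from the latest tracked iteration if a checkpoint exists; returns 0 otherwise.
start_iteration = load_checkpoint("/tmp/bagua_ckpt", model, optimizer)

for iteration in range(start_iteration, 10000):
    # ... forward / backward / optimizer.step() ...
    if iteration > 0 and iteration % 1000 == 0:
        # Each expert-parallel rank writes its own expert and optimizer shards;
        # rank 0 also writes the shared (non-expert) state and the tracker file.
        save_checkpoint(iteration, "/tmp/bagua_ckpt", model, optimizer)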
2 changes: 1 addition & 1 deletion bagua/torch_api/distributed.py
@@ -19,6 +19,7 @@
import torch.nn
import itertools
from typing import List, Tuple
from bagua.torch_api.model_parallel.moe import MoE


@gorilla.patches(torch.nn.Module, filter=lambda name, obj: "bagua" in name)
@@ -382,7 +383,6 @@ def record_speed_metrics_event(self, _):
self._bagua_autotune_client = get_hyperparameters_service_client()

self._bagua_init_algorithm()
return self

def _bagua_autotune_register_tensors(self):
"""
6 changes: 3 additions & 3 deletions bagua/torch_api/model_parallel/moe/experts.py
@@ -17,21 +17,21 @@ class Experts(torch.nn.Module):
def __init__(self, expert, num_local_experts=1):
super(Experts, self).__init__()

-        self.deepspeed_experts = torch.nn.ModuleList(
+        self.bagua_experts = torch.nn.ModuleList(
[copy.deepcopy(expert) for i in range(num_local_experts)]
)
self.num_local_experts = num_local_experts

# TODO: revisit allreduce for moe.gate...
-        for expert in self.deepspeed_experts:
+        for expert in self.bagua_experts:
# TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
for name, param in expert.named_parameters():
param.allreduce = False

def forward(self, inputs):
chunks = inputs.chunk(self.num_local_experts, dim=1)
expert_outputs = []
-        for chunk, expert in zip(chunks, self.deepspeed_experts):
+        for chunk, expert in zip(chunks, self.bagua_experts):
out = expert(chunk)
if type(out) is tuple:
out = out[0] # Ignore the bias term for now
6 changes: 3 additions & 3 deletions bagua/torch_api/model_parallel/moe/layer.py
@@ -70,7 +70,7 @@ def __init__(
)

experts = Experts(expert, num_local_experts)
-        self.deepspeed_moe = MOELayer(
+        self.bagua_moe = MOELayer(
TopKGate(
hidden_size,
self.num_experts,
@@ -103,6 +103,6 @@ def forward(self, hidden_states, used_token=None):

* exp_counts (int): expert count
"""
-        output = self.deepspeed_moe(hidden_states, used_token)
+        output = self.bagua_moe(hidden_states, used_token)
output = self.dropout(output)
-        return output, self.deepspeed_moe.l_aux, self.deepspeed_moe.exp_counts
+        return output, self.bagua_moe.l_aux, self.bagua_moe.exp_counts