feat(python): support checkpointing for moe #242

Merged · 25 commits · Oct 15, 2021
Changes from 2 commits
1 change: 1 addition & 0 deletions bagua/torch_api/__init__.py
@@ -57,4 +57,5 @@
from . import contrib  # noqa: F401
from . import communication  # noqa: F401
from . import algorithms  # noqa: F401
+from . import checkpointing  # noqa: F401
from .model_parallel import moe  # noqa: E402,F401
265 changes: 265 additions & 0 deletions bagua/torch_api/checkpointing.py
@@ -0,0 +1,265 @@
# Copyright (c) 2021 Kuaishou AI Platform & DS3 Lab
#
# All rights reserved.

import logging
import os
import random
import re
import sys
import torch
import torch.distributed as dist
import bagua.torch_api as bagua
import numpy as np
from collections import defaultdict, OrderedDict


def _ensure_directory_exists(filename):
    dirname = os.path.dirname(filename)
    if not os.path.exists(dirname):
        os.makedirs(dirname)


def _get_optimizer_ckpt_name(
    checkpoints_path, iteration, expp_rank, mp_rank=0, release=False
):
    if release:
        directory = "release"
    else:
        directory = "iter_{:07d}".format(iteration)
    ckpt_name = os.path.join(
        checkpoints_path,
        directory,
        f"expp_rank_{expp_rank}_mp_rank_{mp_rank:02d}_optim_states.pt",
    )
    return ckpt_name


def _get_expert_ckpt_name(
    checkpoints_path, expert_id, iteration, mp_rank=0, release=False
):
    if release:
        directory = "release"
    else:
        directory = "iter_{:07d}".format(iteration)
    ckpt_name = os.path.join(
        checkpoints_path,
        directory,
        f"expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt",
    )
    return ckpt_name


def _get_model_ckpt_name(checkpoints_path, iteration, mp_rank=0, release=False):
    if release:
        directory = "release"
    else:
        directory = "iter_{:07d}".format(iteration)
    return os.path.join(
        checkpoints_path, directory, f"mp_rank_{mp_rank:02d}_model_states.pt"
    )


def get_checkpoint_tracker_filename(checkpoints_path):
    return os.path.join(checkpoints_path, "latest_checkpointed_iteration.txt")


def read_metadata(tracker_filename):
    iteration = 0
    release = False
    with open(tracker_filename, "r") as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == "release"
            if not release:
                logging.error(
                    "Invalid metadata file {}. Exiting".format(tracker_filename)
                )
                sys.exit()
    assert iteration >= 0 or release, "error parsing metadata file {}".format(
        tracker_filename
    )

    return iteration, release


def save_checkpoint(
    iteration, checkpoints_path, model, optimizer=None, lr_scheduler=None
):
    logging.info(
        "saving checkpoint at iteration {:7d} to {}".format(
            iteration, checkpoints_path
        )
    )

    if model.has_moe_layers:
        _save_moe_checkpoint(
            iteration, checkpoints_path, model, optimizer, lr_scheduler
        )
    else:
        _save_checkpoint(iteration, checkpoints_path, model, optimizer, lr_scheduler)

    logging.info(
        "successfully saved checkpoint at iteration {:7d} to {}".format(
            iteration, checkpoints_path
        )
    )

    # Update the latest iteration in the tracker file.
    if not dist.is_initialized() or dist.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(checkpoints_path)
        with open(tracker_filename, "w") as f:
            f.write(str(iteration))


def _save_checkpoint(
    iteration, checkpoints_path, model, optimizer=None, lr_scheduler=None
):
    if not dist.is_initialized() or dist.get_rank() == 0:
        state_dict = {}
        state_dict["iteration"] = iteration
        state_dict["model"] = model.state_dict()
        if optimizer is not None:
            state_dict["optimizer"] = optimizer.state_dict()
        if lr_scheduler is not None:
            state_dict["lr_scheduler"] = lr_scheduler.state_dict()

        checkpoint_name = _get_model_ckpt_name(checkpoints_path, iteration)
        _ensure_directory_exists(checkpoint_name)
        torch.save(state_dict, checkpoint_name)

    if dist.is_initialized():
        dist.barrier()


def _save_moe_checkpoint(
    iteration, checkpoints_path, model, optimizer=None, lr_scheduler=None
):
    world_size = 1 if not dist.is_initialized() else dist.get_world_size()
    expp_rank = 1 if not dist.is_initialized() else dist.get_rank()
    num_local_experts = model.num_experts // world_size
    experts_state_dict, model_state_dict = _get_moe_state_dict(
        model.state_dict(), num_local_experts, expp_rank
    )

    # Each rank saves its local experts
    for global_expert_id, expert_state_dict in experts_state_dict.items():
        expert_save_dir = _get_expert_ckpt_name(
            checkpoints_path, global_expert_id, iteration
        )
        logging.info(
            f"Saving model expert {global_expert_id} checkpoint: {expert_save_dir}"
        )
        _ensure_directory_exists(expert_save_dir)
        torch.save(expert_state_dict, expert_save_dir)

    # Save optimizer states. They are different across each exp parallel rank.
    optimizer_state = {"optimizer": optimizer.state_dict() if optimizer else None}
    torch.save(
        optimizer_state,
        _get_optimizer_ckpt_name(checkpoints_path, iteration, expp_rank),
    )

    if expp_rank == 0:
        state_dict = {}
        state_dict["iteration"] = iteration
        state_dict["model"] = model_state_dict
        if lr_scheduler is not None:
            state_dict["lr_scheduler"] = lr_scheduler.state_dict()

        # Save.
        checkpoint_name = _get_model_ckpt_name(checkpoints_path, iteration)
        _ensure_directory_exists(checkpoint_name)
        torch.save(state_dict, checkpoint_name)


def _get_moe_state_dict(full_state_dict, num_local_experts, expp_rank):
    experts_state_dict, moe_state_dict = defaultdict(dict), {}
    for key in list(full_state_dict.keys()):
        if "expert" in key and "moe.gate.wg.weight" not in key:
            moe_state_dict[key] = full_state_dict.pop(key)
    non_moe_state_dict = full_state_dict

    moe_str_prefix = ".bagua_moe.experts.bagua_experts."
    for key in list(moe_state_dict.keys()):
        m = re.match(f".*{moe_str_prefix}([0-9]+).*", key)
        if not m:
            # Keys without a local expert id cannot be remapped; skip them.
            logging.warning(f"No expert found in key {key}.")
            continue
        local_expert_id = m.group(1)

        global_expert_id = expp_rank * num_local_experts + int(local_expert_id)
        expert_key = key.replace(
            f"{moe_str_prefix}{local_expert_id}",
            f"{moe_str_prefix}{global_expert_id}",
        )
        experts_state_dict[str(global_expert_id)][expert_key] = moe_state_dict.pop(key)

    return experts_state_dict, non_moe_state_dict


def load_checkpoint(
    checkpoints_path, model, optimizer=None, lr_scheduler=None, strict=True
):
    """Load a model checkpoint and return the iteration.

    strict (bool): whether to strictly enforce that the keys in
        :attr:`state_dict` of the checkpoint match the names of
        parameters and buffers in model.
    """

    tracker_filename = get_checkpoint_tracker_filename(checkpoints_path)

    # If no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        logging.warning(f"could not find checkpoint metadata file {tracker_filename}")
        return 0

    iteration, release = read_metadata(tracker_filename)
    logging.info(f"loading checkpoint from {checkpoints_path} at iteration {iteration}")

    _load_checkpoint(iteration, checkpoints_path, model, optimizer, lr_scheduler, strict=strict)

    if dist.is_initialized():
        dist.barrier()

    logging.info(f"successfully loaded checkpoint from {checkpoints_path} at iteration {iteration}")
    return iteration


def _load_checkpoint(
    iteration, checkpoints_path, model, optimizer=None, lr_scheduler=None, strict=True
):
    expp_rank = 1 if not dist.is_initialized() else dist.get_rank()
    checkpoint_name = _get_model_ckpt_name(checkpoints_path, iteration)

    model_checkpoint = torch.load(checkpoint_name, map_location="cpu")
    if model.has_moe_layers:
        num_local_experts = model.num_experts // dist.get_world_size()
        _load_moe_state_dict(
            checkpoints_path,
            iteration,
            num_local_experts,
            expp_rank,
            state_dict=model_checkpoint["model"],
        )

    if model.has_moe_layers and optimizer is not None:
        optim_load_path = _get_optimizer_ckpt_name(checkpoints_path, iteration, expp_rank)
        optim_checkpoint = torch.load(optim_load_path, map_location=torch.device("cpu"))
    else:
        optim_checkpoint = model_checkpoint

    model.load_state_dict(model_checkpoint["model"], strict=strict)

    # Optimizer.
    if optimizer is not None:
        optimizer.load_state_dict(optim_checkpoint["optimizer"])
    if lr_scheduler is not None:
        lr_scheduler.load_state_dict(model_checkpoint["lr_scheduler"])


def _load_moe_state_dict(
    checkpoint_path, iteration, num_local_experts, expp_rank, state_dict
):
    for local_expert_id in range(num_local_experts):
        global_expert_id = expp_rank * num_local_experts + local_expert_id
        expert_state_dict = torch.load(
            _get_expert_ckpt_name(checkpoint_path, global_expert_id, iteration),
            map_location=torch.device("cpu"),
        )

        # Update global -> local expert ids.
        moe_str_prefix = ".bagua_moe.experts.bagua_experts."
        for key in list(expert_state_dict.keys()):
            local_key = key.replace(
                f"{moe_str_prefix}{global_expert_id}",
                f"{moe_str_prefix}{local_expert_id}",
            )
            expert_state_dict[local_key] = expert_state_dict.pop(key)
        state_dict.update(expert_state_dict)
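
As a point of reference, here is a minimal usage sketch of the new checkpointing API (not part of this diff). It assumes the usual Bagua setup; `MyMoEModel` and `train_step` are placeholder names.

import torch
import bagua.torch_api as bagua
from bagua.torch_api.algorithms import gradient_allreduce
from bagua.torch_api.checkpointing import save_checkpoint, load_checkpoint

# Standard Bagua initialization.
torch.cuda.set_device(bagua.get_local_rank())
bagua.init_process_group()

model = MyMoEModel().cuda()  # placeholder model containing Bagua MoE layers
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
model = model.with_bagua([optimizer], gradient_allreduce.GradientAllReduceAlgorithm())

# Resume from the latest checkpoint if the tracker file exists, otherwise start at iteration 0.
start_iteration = load_checkpoint("checkpoints", model, optimizer)

for iteration in range(start_iteration, 10000):
    train_step(model, optimizer)  # placeholder training step
    if iteration % 1000 == 0:
        # Expert weights, per-rank optimizer states and the tracker file go under "checkpoints/".
        save_checkpoint(iteration, "checkpoints", model, optimizer)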
9 changes: 9 additions & 0 deletions bagua/torch_api/distributed.py
@@ -19,6 +19,7 @@
import torch.nn
import itertools
from typing import List, Tuple
+from bagua.torch_api.model_parallel.moe import MoE


@gorilla.patches(torch.nn.Module, filter=lambda name, obj: "bagua" in name)
@@ -382,6 +383,14 @@ def record_speed_metrics_event(self, _):
        self._bagua_autotune_client = get_hyperparameters_service_client()

        self._bagua_init_algorithm()

+        self.has_moe_layers = False
+        self.num_experts = 0
+        for name, module in self.named_modules():
+            if isinstance(module, MoE):
+                self.has_moe_layers = True
+                self.num_experts = module.num_experts
+                break
        return self

    def _bagua_autotune_register_tensors(self):
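
In effect, after `with_bagua()` a model that contains a Bagua `MoE` layer carries the two attributes that `save_checkpoint` and `_load_checkpoint` branch on. A rough sketch, continuing the hypothetical setup above:

model = model.with_bagua([optimizer], gradient_allreduce.GradientAllReduceAlgorithm())
if model.has_moe_layers:
    # num_experts is taken from the first MoE module found in the model.
    print(f"MoE model detected, num_experts = {model.num_experts}")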
6 changes: 3 additions & 3 deletions bagua/torch_api/model_parallel/moe/experts.py
@@ -17,21 +17,21 @@ class Experts(torch.nn.Module):
    def __init__(self, expert, num_local_experts=1):
        super(Experts, self).__init__()

-        self.deepspeed_experts = torch.nn.ModuleList(
+        self.bagua_experts = torch.nn.ModuleList(
            [copy.deepcopy(expert) for i in range(num_local_experts)]
        )
        self.num_local_experts = num_local_experts

        # TODO: revisit allreduce for moe.gate...
-        for expert in self.deepspeed_experts:
+        for expert in self.bagua_experts:
            # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
            for name, param in expert.named_parameters():
                param.allreduce = False

    def forward(self, inputs):
        chunks = inputs.chunk(self.num_local_experts, dim=1)
        expert_outputs = []
-        for chunk, expert in zip(chunks, self.deepspeed_experts):
+        for chunk, expert in zip(chunks, self.bagua_experts):
            out = expert(chunk)
            if type(out) is tuple:
                out = out[0]  # Ignore the bias term for now
6 changes: 3 additions & 3 deletions bagua/torch_api/model_parallel/moe/layer.py
@@ -70,7 +70,7 @@ def __init__(
        )

        experts = Experts(expert, num_local_experts)
-        self.deepspeed_moe = MOELayer(
+        self.bagua_moe = MOELayer(
            TopKGate(
                hidden_size,
                self.num_experts,
@@ -103,6 +103,6 @@ def forward(self, hidden_states, used_token=None):

        * exp_counts (int): expert count
        """
-        output = self.deepspeed_moe(hidden_states, used_token)
+        output = self.bagua_moe(hidden_states, used_token)
        output = self.dropout(output)
-        return output, self.deepspeed_moe.l_aux, self.deepspeed_moe.exp_counts
+        return output, self.bagua_moe.l_aux, self.bagua_moe.exp_counts