feature: fetch latest weights in distributed learning (#148)
This commit allows a new trainer to fetch the latest global weights in distributed learning. The weights can come either from the ring all-reduce algorithm or from a single trainer if there is only one trainer.
GaoxiangLuo committed Jun 28, 2022
1 parent 75d960a commit 77ed492
Showing 2 changed files with 37 additions and 17 deletions.
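
The hand-off described in the commit message boils down to a short request/response exchange: a newcomer announces itself, and the committer replies with the latest global weights. The sketch below is a minimal, self-contained illustration of that flow under assumed names, not flame's actual channel API; the Channel class and the end names "committer" and "newcomer" are hypothetical stand-ins, and only the MessageType values mirror the diff that follows.

    # Hypothetical sketch of the new-trainer weight hand-off; not flame's channel API.
    from enum import Enum
    from queue import Queue


    class MessageType(Enum):
        MEMBER_DIGEST = 5
        RING_WEIGHTS = 6   # global model weights in distributed learning
        NEW_TRAINER = 7    # arrival announcement from a new trainer
        IS_COMMITTER = 8   # whether the sender holds the latest weights


    class Channel:
        """Hypothetical per-end mailboxes standing in for the real backend."""

        def __init__(self, *ends):
            self._queues = {end: Queue() for end in ends}

        def send(self, end, msg):
            self._queues[end].put(msg)

        def recv(self, end):
            return self._queues[end].get()


    channel = Channel("committer", "newcomer")

    # 1. The new trainer has no ring weights yet, so it announces its arrival.
    channel.send("committer", {MessageType.NEW_TRAINER: True,
                               MessageType.IS_COMMITTER: False})

    # 2. The committer sees the announcement and ships the latest global weights
    #    (from ring all-reduce, or its own weights if it trained alone).
    ring_weights = {"layer0": [0.1, 0.2]}
    if MessageType.NEW_TRAINER in channel.recv("committer"):
        channel.send("newcomer", {MessageType.RING_WEIGHTS: ring_weights})

    # 3. The new trainer loads the fetched weights and joins the next round.
    msg = channel.recv("newcomer")
    assert msg[MessageType.RING_WEIGHTS] == ring_weights

In the actual change, the announcement is piggybacked on the member-check broadcast and the committer replies inside _handle_member_check, as the trainer.py hunks below show.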
51 changes: 34 additions & 17 deletions lib/python/flame/mode/distributed/trainer.py
@@ -33,11 +33,6 @@

TAG_RING_ALLREDUCE = 'ring_allreduce'

# the number of copies to save parameters and model artifact
# this is for redundancy
COMMIT_COUNT = 3


class Trainer(Role, metaclass=ABCMeta):
"""Trainer implements an ML training role."""

@@ -63,6 +58,7 @@ def internal_init(self) -> None:
if base_model and base_model.name != "" and base_model.version > 0:
self.model = self.registry_client.load_model(
base_model.name, base_model.version)
self.ring_weights = None # store the latest model weights from ring all-reduce

self.registry_client.setup_run(mlflow_runname(self.config))
self.metrics = dict()
@@ -71,7 +67,7 @@ def internal_init(self) -> None:
self._rounds = self.config.hyperparameters['rounds']
self._work_done = False

self.is_commiter = False
self.is_committer = False

self.framework = get_ml_framework_in_use()
if self.framework == MLFramework.UNKNOWN:
@@ -105,7 +101,6 @@ def _ring_allreduce(self, tag: str) -> None:
logger.debug(f"total_data_count: {total_data_count}")
if not success:
# members don't agree, we can't do ring-allreduce
self.is_commiter = True
return

self._update_weights()
@@ -164,8 +159,8 @@ def _do_ring_allreduce(self, channel):

recv_chunk_idx = (rank - i - 1 + size) % size
msg = channel.recv(ends[recv_from])
if MessageType.WEIGHTS not in msg:
logger.error(f"end_id: {ends[recv_from]}, msg: {msg}")
while MessageType.WEIGHTS not in msg:
msg = channel.recv(ends[recv_from])
recv_chunk = msg[MessageType.WEIGHTS]

logger.debug(
@@ -289,6 +284,11 @@ def _handle_member_check(self, channel, end, digest) -> tuple[bool, int]:

msg = channel.recv(end)

# check if a new trainer needs the latest weights
if MessageType.NEW_TRAINER in msg and self.is_committer:
logger.debug(f"{channel.get_backend_id()} sending weights to the new trainer {end}")
channel.send(end, {MessageType.RING_WEIGHTS: self.ring_weights})

logger.debug(f"end_id: {end}, msg: {msg}")
if MessageType.MEMBER_DIGEST not in msg:
logger.debug("no member digest found")
@@ -309,6 +309,16 @@ def _handle_member_check(self, channel, end, digest) -> tuple[bool, int]:

dataset_size = msg[MessageType.DATASET_SIZE]

# new trainer fetches the latest weights from the committer
if msg[MessageType.IS_COMMITTER]:
while self.ring_weights is None:
msg = channel.recv(end)
logger.debug(f"new trainer {channel.get_backend_id()} fetching weights from {end} ")
if MessageType.RING_WEIGHTS in msg:
self.weights = msg[MessageType.RING_WEIGHTS]
self._update_model()
break

if channel.is_rxq_empty(end):
break

@@ -320,13 +330,20 @@ def _member_check(self, channel) -> tuple[bool, int]:
# This means that there is no ends in the channel,
# so there is no point of broadcasting digest.
# If the empty digest is broadcast, it can cause a bug
self.is_committer = True
self._update_weights()
self._update_model()
return False, 0

msg = {
MessageType.MEMBER_DIGEST: digest,
MessageType.DATASET_SIZE: self.dataset_size,
MessageType.ROUND: self._round
MessageType.ROUND: self._round,
MessageType.IS_COMMITTER: self.is_committer
}
if self.ring_weights is None:
logger.debug("Sending arrival message...")
msg[MessageType.NEW_TRAINER] = True
logger.debug(f"member check msg = {msg}")
channel.broadcast(msg)

@@ -343,9 +360,8 @@ def _member_check(self, channel) -> tuple[bool, int]:
my_taskid = channel.get_backend_id()
ends.append(my_taskid)
ends.sort()
# if my taskid is in the first "COMMIT_COUNT" ends,
# then it's selected as a commiter
self.is_commiter = True if my_taskid in ends[:COMMIT_COUNT] else False
# if my taskid is in the first ends, then it's selected as a committer
self.is_committer = True if my_taskid in ends[:1] else False

# check if work is done by others
# if work is done, then no further distributed learning needed
@@ -380,6 +396,7 @@ def _update_model(self):
self.model.load_state_dict(self.weights)
elif self.framework == MLFramework.TENSORFLOW:
self.model.set_weights(self.weights)
self.ring_weights = self.weights

def _update_weights(self):
"""Save weights from model."""
@@ -405,14 +422,14 @@ def increment_round(self):

def save_params(self):
"""Save hyperparamets in a model registry."""
logger.debug(f"saving params: is_commiter: {self.is_commiter}")
if self.config.hyperparameters and self.is_commiter:
logger.debug(f"saving params: is_commiter: {self.is_committer}")
if self.config.hyperparameters and self.is_committer:
self.registry_client.save_params(self.config.hyperparameters)

def save_model(self):
"""Save model in a model registry."""
logger.debug(f"saving model: is_commiter: {self.is_commiter}")
if self.model and self.is_commiter:
logger.debug(f"saving model: is_commiter: {self.is_committer}")
if self.model and self.is_committer:
model_name = f"{self.config.job.name}-{self.config.job.job_id}"
self.registry_client.save_model(model_name, self.model)

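With COMMIT_COUNT removed, exactly one trainer is elected committer: every trainer sorts all end ids, including its own, and the first id wins. A tiny sketch of that election with made-up task ids (the helper name and ids are illustrative, not part of the codebase):

    def elect_committer(my_taskid: str, other_ends: list[str]) -> bool:
        """Return True if this trainer is the committer (first id after sorting)."""
        ends = sorted(other_ends + [my_taskid])
        return my_taskid in ends[:1]  # equivalently: my_taskid == ends[0]


    print(elect_committer("trainer-b", ["trainer-c", "trainer-a"]))  # False
    print(elect_committer("trainer-a", ["trainer-c", "trainer-b"]))  # True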
3 changes: 3 additions & 0 deletions lib/python/flame/mode/message.py
@@ -28,3 +28,6 @@ class MessageType(Enum):

# a digest of all the workers in distributed learning
MEMBER_DIGEST = 5
RING_WEIGHTS = 6 # global model weights in distributed learning
NEW_TRAINER = 7 # sending message for the arrival of a new trainer
IS_COMMITTER = 8 # is a trainer responsible to send weights to a new trainer in distributed learning
