Skip to content

Commit

Permalink
feat: Elastic training (#31)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: `baguaelastic/distributed/launch.py` has moved to `bagua/distributed/run.py`
  • Loading branch information
shjwudp authored Jun 18, 2021
1 parent e761cc6 commit 1a5964c
Show file tree
Hide file tree
Showing 7 changed files with 659 additions and 430 deletions.
59 changes: 5 additions & 54 deletions bagua/distributed/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,8 @@
import sys
import subprocess
import os
import json
from argparse import ArgumentParser, REMAINDER
from typing import Optional, IO, List, Any
from bagua.service import (
AutotuneService,
generate_and_broadcast_server_addr,
pick_n_free_ports,
)
from bagua.autotune import autotune_system_hyperparameters
from flask import Flask
import multiprocessing
import logging

node_local_rank_stdout_filename = "node_{}_local_rank_{}_stdout"
Expand Down Expand Up @@ -150,20 +141,10 @@ def parse_args():
return parser.parse_args()


def set_autotune_env(args, current_env):
    """Update ``current_env`` in place according to ``args.autotune_level``.

    Args:
        args: Parsed launcher arguments. ``autotune_level`` is always read;
            ``host_list``, ``nproc_per_node`` and ``ssh_port`` are read only
            when the level is 2.
        current_env: Mutable mapping of environment variables; modified in
            place.

    Returns:
        None.
    """
    if args.autotune_level == 1:
        current_env.update(
            {
                "BAGUA_AUTOTUNE": "1",
            }
        )
    elif args.autotune_level == 2:
        # Merge whatever settings the system-level autotuner recommends.
        recommend_env = autotune_system_hyperparameters(
            args.host_list, args.nproc_per_node, args.ssh_port
        )
        current_env.update(recommend_env)
    elif args.autotune_level == 3:  # TODO
        pass
    # Any other level leaves current_env untouched.
def set_bagua_env(args, current_env):
    """Copy Bagua launcher settings from ``args`` into ``current_env``.

    Each value is converted with ``str()`` before being stored under
    ``BAGUA_SERVICE_PORT``, ``BAGUA_DEFAULT_BUCKET_SIZE`` and
    ``BAGUA_AUTOTUNE``.

    Args:
        args: Parsed launcher arguments providing ``bagua_service_port``,
            ``default_bucket_size`` and ``autotune_level``.
        current_env: Mutable mapping of environment variables; modified in
            place.

    Returns:
        None.
    """
    current_env["BAGUA_SERVICE_PORT"] = str(args.bagua_service_port)
    current_env["BAGUA_DEFAULT_BUCKET_SIZE"] = str(args.default_bucket_size)
    current_env["BAGUA_AUTOTUNE"] = str(args.autotune_level)


def main():
Expand All @@ -184,37 +165,7 @@ def main():
current_env["MASTER_PORT"] = str(args.master_port)
current_env["WORLD_SIZE"] = str(dist_world_size)

# launch autotune server
server_addr = args.master_addr
server_port = args.bagua_service_port
current_env["AUTO_TUNE_SERVER_ADDR"] = "{}:{}".format(server_addr, server_port)

set_autotune_env(args, current_env)

print("server_addr={}, server_port={}".format(server_addr, server_port))
if args.node_rank == 0:
from logging.config import dictConfig

autotune_service = AutotuneService(
dist_world_size,
max_samples=args.autotune_max_samples,
sampling_confidence_time_s=args.autotune_sampling_confidence_time,
warmup_time_s=args.autotune_warmup_time,
autotune_level=args.autotune_level,
default_bucket_size=args.default_bucket_size,
)
app = Flask(__name__)
app = autotune_service.setup_app(app)
server = multiprocessing.Process(
target=app.run,
kwargs={
"host": "0.0.0.0",
"port": server_port,
"debug": False,
},
)
server.daemon = True
server.start()
set_bagua_env(args, current_env)

processes: List[Any] = []

Expand Down
Loading

0 comments on commit 1a5964c

Please sign in to comment.