Skip to content

Commit

Permalink
feat(python): add BAGUA_AUTOTUNE_SERVER_WAIT_TIME env (#474)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangraying authored Jan 8, 2022
1 parent e2c048a commit 4cebc59
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
12 changes: 10 additions & 2 deletions bagua/torch_api/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
get_node_rank,
get_default_bucket_size,
get_bagua_service_port,
get_autotune_server_wait_time,
find_free_network_port,
)
from enum import IntEnum
Expand Down Expand Up @@ -501,15 +502,22 @@ def init_process_group(store: Optional[torch.distributed.Store] = None):
start_autotune_server(_autotune_service_port)

AUTOTUNE_SERVER_WAIT_TIME = 30
wait_time = get_autotune_server_wait_time()
# at least wait 30 seconds
if wait_time < AUTOTUNE_SERVER_WAIT_TIME:
wait_time = AUTOTUNE_SERVER_WAIT_TIME

start = time.time()
service_ready = False
while (time.time() - start) < AUTOTUNE_SERVER_WAIT_TIME:
while (time.time() - start) < wait_time:
client = get_hyperparameters_service_client()
service_ready = client.health_check()
if service_ready:
break
if not service_ready:
raise Exception("Warning! autotune service not ready after {} seconds.".format(AUTOTUNE_SERVER_WAIT_TIME))
raise Exception("Warning! autotune service not ready after {} seconds. "
"You can adjust this duration through "
"`BAGUA_AUTOTUNE_SERVER_WAIT_TIME` environment variable.".format(wait_time))

# TODO remove the dependency on torch process group
if not dist.is_initialized():
Expand Down
4 changes: 4 additions & 0 deletions bagua/torch_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ def get_is_output_autotune_log() -> bool:
return bool(os.environ.get("BAGUA_IS_OUTPUT_AUTOTUNE_LOG", 0))


def get_autotune_server_wait_time() -> int:
    """Return the configured autotune server wait time, in seconds.

    The value is read from the ``BAGUA_AUTOTUNE_SERVER_WAIT_TIME``
    environment variable. When the variable is not set, 300 is returned.

    Returns:
        The wait time as an integer.

    Raises:
        ValueError: If the variable is set to a value that cannot be
            parsed as an integer.
    """
    raw_value = os.environ.get("BAGUA_AUTOTUNE_SERVER_WAIT_TIME", 300)
    try:
        return int(raw_value)
    except ValueError as err:
        # Re-raise with the variable name so a misconfiguration is easy to
        # trace; the bare int() message only shows the offending literal.
        raise ValueError(
            "BAGUA_AUTOTUNE_SERVER_WAIT_TIME must be an integer number of "
            "seconds, got {!r}".format(raw_value)
        ) from err


def find_free_network_port() -> int:
"""Finds a free port on localhost."""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Expand Down

0 comments on commit 4cebc59

Please sign in to comment.