Commit
fix(python): fix ci random fail (#445)
shjwudp authored Dec 24, 2021
1 parent 1e033d6 commit eff5ee4
Showing 2 changed files with 55 additions and 36 deletions.
27 changes: 27 additions & 0 deletions bagua/service/autotune_service.py
@@ -293,6 +293,10 @@ def report_tensor_execution_order():

return json.dumps({})

@app.route("/api/v1/health_check", methods=["GET"])
def health_check():
return json.dumps({"status": "ok"})

# set secret-key
app.config.update(SECRET_KEY=os.urandom(24))

@@ -313,6 +317,16 @@ def __init__(
self.session = requests.Session()
self.proxies = proxies

import socket
from urllib3.connection import HTTPConnection

HTTPConnection.default_socket_options = HTTPConnection.default_socket_options + [
(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),  # Enable TCP keepalive on the socket
(socket.SOL_TCP, socket.TCP_KEEPIDLE, 45),  # Idle time (seconds) before the stack starts sending keepalive probes on a persistent connection
(socket.SOL_TCP, socket.TCP_KEEPINTVL, 10),  # Interval (seconds) between keepalive probes
(socket.SOL_TCP, socket.TCP_KEEPCNT, 6),  # Number of unanswered probes before the connection is dropped
]
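With these options, a broken but idle connection is detected after roughly TCP_KEEPIDLE + TCP_KEEPCNT * TCP_KEEPINTVL seconds. A minimal sketch of that arithmetic, using the values set above (an editorial illustration, not part of the diff):

TCP_KEEPIDLE = 45   # seconds of idle time before the first keepalive probe
TCP_KEEPINTVL = 10  # seconds between unanswered probes
TCP_KEEPCNT = 6     # unanswered probes tolerated before the connection is dropped

# Worst-case time to declare an idle peer dead: 45 + 6 * 10 = 105 seconds
detection_time = TCP_KEEPIDLE + TCP_KEEPCNT * TCP_KEEPINTVL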

def report_metrics(
self,
model_name: str,
@@ -383,6 +397,19 @@ def report_tensor_execution_order(
)
return rsp

def health_check(self) -> bool:
try:
# any successful response means the service is reachable
self.session.get(
"http://{}/api/v1/health_check".format(
self.autotune_service_addr
),
proxies=self.proxies
)

return True
except requests.exceptions.ConnectionError:
return False

if __name__ == "__main__":
import argparse
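A minimal sketch of how the new endpoint and client method might be used to wait for the autotune service before work begins; the AutotuneClient import path, service address, and retry loop are assumptions for illustration, not part of the diff:

import time

from bagua.service.autotune_service import AutotuneClient  # assumed import path

client = AutotuneClient("127.0.0.1:8123")  # hypothetical service address
for _ in range(30):
    if client.health_check():  # True once GET /api/v1/health_check succeeds
        break
    time.sleep(1)
else:
    raise RuntimeError("autotune service did not become healthy in time")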
64 changes: 28 additions & 36 deletions bagua/torch_api/data_parallel/bagua_distributed.py
@@ -78,7 +78,7 @@ class BaguaDistributedDataParallelStates:

self.module._bagua_states = BaguaDistributedDataParallelStates()
bagua_states = self.module._bagua_states
bagua_states._bagua_algorithm_hooks = []
bagua_states._bagua_autograd_hooks = []
bagua_states._bagua_framework_hooks = []

self._bagua_backend = get_backend(self.bagua_module_name)
@@ -345,7 +345,7 @@ def _bagua_autotune_step(self):
assert rsp.status_code == 200, "Unexpected rsp={}".format(rsp)

# update parameters
self._bagua_reset_algorithm_buckets()
self._reset_buckets()
self._bagua_autotune_last_report_time = time.time()

logging.debug("autotune overhead=%s", time.time() - start_time)
@@ -390,42 +390,33 @@ def _bagua_autotune_get_buckets(self):
)
return list(recommended_buckets)

def rebuild_buckets(self):
def _bagua_init_algorithm(self):
self._bagua_broadcast_parameters()

self._bagua_tensors = self.bagua_algorithm.init_tensors(self)
self._bagua_tensor_map = dict(
[(tensor.bagua_tensor_name, tensor) for tensor in self._bagua_tensors]
)
self._bagua_autotune_register_tensors()
self._bagua_reset_algorithm_buckets()
self.params_in_use = set([name for name, _ in self.bagua_build_params()])

def _bagua_init_algorithm(self):
self._bagua_cleanup_algorithm()
self._bagua_broadcast_parameters()
self.rebuild_buckets()

def _bagua_cleanup_algorithm(self):
bagua_states = self.module._bagua_states
for hook in bagua_states._bagua_algorithm_hooks:
hook.remove()
bagua_states._bagua_algorithm_hooks.clear()
self._reset_buckets()

self.bagua_buckets.clear()
self._register_autograd_hooks()
self._register_optimizer_hooks()

def delay_allreduce(self):
def _delay_allreduce(self):
for param_name, parameter in self.bagua_build_params():
self.bagua_algorithm.init_backward_hook(self)(param_name, parameter)
self.bagua_algorithm.init_post_backward_hook(self)()

def _bagua_reset_algorithm_buckets(self):
def _cleanup_autograd_hooks(self):
bagua_states = self.module._bagua_states
self._bagua_cleanup_algorithm()
raw_buckets = self._bagua_autotune_get_buckets()
self.bagua_buckets.extend(
self.bagua_algorithm.tensors_to_buckets(
raw_buckets, self.gradient_as_bucket_view
)
)
for hook in bagua_states._bagua_autograd_hooks:
hook.remove()
bagua_states._bagua_autograd_hooks.clear()

def _register_autograd_hooks(self):
bagua_states = self.module._bagua_states
self._cleanup_autograd_hooks()

for name, param in self.module.named_parameters():

@@ -447,13 +438,9 @@ def real_post_backward_hook(*unused):
)

if self.find_unused_parameters:
if (
set(self.autograd_graph_params.keys())
!= self.params_in_use
):
self.rebuild_buckets()
self.delay_allreduce()

if set(self.autograd_graph_params.keys()) != self.params_in_use:
self._reset_buckets()
self._delay_allreduce()
if not self._is_post_backward_callback_queued:
torch.autograd.Variable._execution_engine.queue_callback(
real_post_backward_hook
@@ -467,8 +454,9 @@ def real_post_backward_hook(*unused):
grad_acc = param_tmp.grad_fn.next_functions[0][0]
hook = grad_acc.register_hook(real_hook_factory(name, param))
hook.grad_acc = grad_acc
bagua_states._bagua_algorithm_hooks.append(hook)
bagua_states._bagua_autograd_hooks.append(hook)

def _register_optimizer_hooks(self):
optimizer_hook = self.bagua_algorithm.init_post_optimizer_step_hook(self)

from types import MethodType
@@ -488,6 +476,9 @@ def new_step(self, *args, **kwargs):

optimizer.step = new_step_factory(optimizer)

def _reset_buckets(self):
raw_buckets = self._bagua_autotune_get_buckets()
self.bagua_buckets = self.bagua_algorithm.tensors_to_buckets(raw_buckets, self.gradient_as_bucket_view)
for bucket in self.bagua_buckets:
self.bagua_algorithm.init_operations(
self,
@@ -496,12 +487,13 @@ def new_step(self, *args, **kwargs):
self._bagua_backend.register_ordered_buckets(
[bucket.backend_bucket for bucket in self.bagua_buckets]
)
self.params_in_use = set([name for name, _ in self.bagua_build_params()])

def _reset_algorithm_state(self):
bagua_states = self.module._bagua_states
if hasattr(bagua_states, "_bagua_framework_hooks"):
for hook in bagua_states._bagua_framework_hooks:
hook.remove()

if hasattr(bagua_states, "_bagua_algorithm_hooks"):
self._bagua_cleanup_algorithm()
if hasattr(bagua_states, "_bagua_autograd_hooks"):
self._cleanup_autograd_hooks()
