Skip to content

Commit

Permalink
Fix deadlock in unhandled exception handler and re-merge (#3) (ray-pr…
Browse files Browse the repository at this point in the history
  • Loading branch information
ericl authored Feb 19, 2021
1 parent 3ffe375 commit cc156f7
Show file tree
Hide file tree
Showing 10 changed files with 154 additions and 68 deletions.
10 changes: 10 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,16 @@ cc_test(
],
)

cc_test(
name = "memory_store_test",
srcs = ["src/ray/core_worker/test/memory_store_test.cc"],
copts = COPTS,
deps = [
":core_worker_lib",
"@com_google_googletest//:gtest_main",
],
)

cc_test(
name = "direct_actor_transport_test",
srcs = ["src/ray/core_worker/test/direct_actor_transport_test.cc"],
Expand Down
25 changes: 22 additions & 3 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,20 @@ cdef void delete_spilled_objects_handler(
job_id=None)


cdef void unhandled_exception_handler(const CRayObject& error) nogil:
with gil:
worker = ray.worker.global_worker
data = None
metadata = None
if error.HasData():
data = Buffer.make(error.GetData())
if error.HasMetadata():
metadata = Buffer.make(error.GetMetadata()).to_pybytes()
# TODO(ekl) why does passing a ObjectRef.nil() lead to shutdown errors?
object_ids = [None]
worker.raise_errors([(data, metadata)], object_ids)


# This function introduces ~2-7us of overhead per call (i.e., it can be called
# up to hundreds of thousands of times per second).
cdef void get_py_stack(c_string* stack_out) nogil:
Expand Down Expand Up @@ -843,6 +857,7 @@ cdef class CoreWorker:
options.spill_objects = spill_objects_handler
options.restore_spilled_objects = restore_spilled_objects_handler
options.delete_spilled_objects = delete_spilled_objects_handler
options.unhandled_exception_handler = unhandled_exception_handler
options.get_lang_stack = get_py_stack
options.ref_counting_enabled = True
options.is_local_mode = local_mode
Expand Down Expand Up @@ -1453,9 +1468,13 @@ cdef class CoreWorker:
object_ref.native())

def remove_object_ref_reference(self, ObjectRef object_ref):
# Note: faster to not release GIL for short-running op.
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
object_ref.native())
cdef:
CObjectID c_object_id = object_ref.native()
# We need to release the gil since object destruction may call the
# unhandled exception handler.
with nogil:
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
c_object_id)

def serialize_and_promote_object_ref(self, ObjectRef object_ref):
cdef:
Expand Down
1 change: 1 addition & 0 deletions python/ray/includes/libcoreworker.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
(void(
const c_vector[c_string]&,
CWorkerType) nogil) delete_spilled_objects
(void(const CRayObject&) nogil) unhandled_exception_handler
(void(c_string *stack_out) nogil) get_lang_stack
c_bool ref_counting_enabled
c_bool is_local_mode
Expand Down
46 changes: 46 additions & 0 deletions python/ray/tests/test_failure.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,52 @@
get_error_message, Semaphore)


def test_unhandled_errors(ray_start_regular):
@ray.remote
def f():
raise ValueError()

@ray.remote
class Actor:
def f(self):
raise ValueError()

a = Actor.remote()
num_exceptions = 0

def interceptor(e):
nonlocal num_exceptions
num_exceptions += 1

# Test we report unhandled exceptions.
ray.worker._unhandled_error_handler = interceptor
x1 = f.remote()
x2 = a.f.remote()
del x1
del x2
wait_for_condition(lambda: num_exceptions == 2)

# Test we don't report handled exceptions.
x1 = f.remote()
x2 = a.f.remote()
with pytest.raises(ray.exceptions.RayError) as err: # noqa
ray.get([x1, x2])
del x1
del x2
time.sleep(1)
assert num_exceptions == 2, num_exceptions

# Test suppression with env var works.
try:
os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "1"
x1 = f.remote()
del x1
time.sleep(1)
assert num_exceptions == 2, num_exceptions
finally:
del os.environ["RAY_IGNORE_UNHANDLED_ERRORS"]


def test_failed_task(ray_start_regular, error_pubsub):
@ray.remote
def throw_exception_fct1():
Expand Down
79 changes: 19 additions & 60 deletions python/ray/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import logging
import os
import redis
from six.moves import queue
import sys
import threading
import time
Expand Down Expand Up @@ -69,6 +68,12 @@
logger = logging.getLogger(__name__)


# Visible for testing.
def _unhandled_error_handler(e: Exception):
logger.error("Unhandled error (suppress with "
"RAY_IGNORE_UNHANDLED_ERRORS=1): {}".format(e))


class Worker:
"""A class used to define the control flow of a worker process.
Expand Down Expand Up @@ -277,6 +282,14 @@ def put_object(self, value, object_ref=None):
self.core_worker.put_serialized_object(
serialized_value, object_ref=object_ref))

def raise_errors(self, data_metadata_pairs, object_refs):
context = self.get_serialization_context()
out = context.deserialize_objects(data_metadata_pairs, object_refs)
if "RAY_IGNORE_UNHANDLED_ERRORS" in os.environ:
return
for e in out:
_unhandled_error_handler(e)

def deserialize_objects(self, data_metadata_pairs, object_refs):
context = self.get_serialization_context()
return context.deserialize_objects(data_metadata_pairs, object_refs)
Expand Down Expand Up @@ -865,13 +878,6 @@ def custom_excepthook(type, value, tb):

sys.excepthook = custom_excepthook

# The last time we raised a TaskError in this process. We use this value to
# suppress redundant error messages pushed from the workers.
last_task_error_raise_time = 0

# The max amount of seconds to wait before printing out an uncaught error.
UNCAUGHT_ERROR_GRACE_PERIOD = 5


def print_logs(redis_client, threads_stopped, job_id):
"""Prints log messages from workers on all of the nodes.
Expand Down Expand Up @@ -1022,51 +1028,14 @@ def color_for(data: Dict[str, str]) -> str:
file=print_file)


def print_error_messages_raylet(task_error_queue, threads_stopped):
"""Prints message received in the given output queue.
This checks periodically if any un-raised errors occurred in the
background.
Args:
task_error_queue (queue.Queue): A queue used to receive errors from the
thread that listens to Redis.
threads_stopped (threading.Event): A threading event used to signal to
the thread that it should exit.
"""

while True:
# Exit if we received a signal that we should stop.
if threads_stopped.is_set():
return

try:
error, t = task_error_queue.get(block=False)
except queue.Empty:
threads_stopped.wait(timeout=0.01)
continue
# Delay errors a little bit of time to attempt to suppress redundant
# messages originating from the worker.
while t + UNCAUGHT_ERROR_GRACE_PERIOD > time.time():
threads_stopped.wait(timeout=1)
if threads_stopped.is_set():
break
if t < last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD:
logger.debug(f"Suppressing error from worker: {error}")
else:
logger.error(f"Possible unhandled error from worker: {error}")


def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
def listen_error_messages_raylet(worker, threads_stopped):
"""Listen to error messages in the background on the driver.
This runs in a separate thread on the driver and pushes (error, time)
tuples to the output queue.
Args:
worker: The worker class that this thread belongs to.
task_error_queue (queue.Queue): A queue used to communicate with the
thread that prints the errors found by this thread.
threads_stopped (threading.Event): A threading event used to signal to
the thread that it should exit.
"""
Expand Down Expand Up @@ -1105,8 +1074,9 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):

error_message = error_data.error_message
if (error_data.type == ray_constants.TASK_PUSH_ERROR):
# Delay it a bit to see if we can suppress it
task_error_queue.put((error_message, time.time()))
# TODO(ekl) remove task push errors entirely now that we have
# the separate unhandled exception handler.
pass
else:
logger.warning(error_message)
except (OSError, redis.exceptions.ConnectionError) as e:
Expand Down Expand Up @@ -1269,19 +1239,12 @@ def connect(node,
# temporarily using this implementation which constantly queries the
# scheduler for new error messages.
if mode == SCRIPT_MODE:
q = queue.Queue()
worker.listener_thread = threading.Thread(
target=listen_error_messages_raylet,
name="ray_listen_error_messages",
args=(worker, q, worker.threads_stopped))
worker.printer_thread = threading.Thread(
target=print_error_messages_raylet,
name="ray_print_error_messages",
args=(q, worker.threads_stopped))
args=(worker, worker.threads_stopped))
worker.listener_thread.daemon = True
worker.listener_thread.start()
worker.printer_thread.daemon = True
worker.printer_thread.start()
if log_to_driver:
global_worker_stdstream_dispatcher.add_handler(
"ray_print_logs", print_to_stdstream)
Expand Down Expand Up @@ -1334,8 +1297,6 @@ def disconnect(exiting_interpreter=False):
worker.import_thread.join_import_thread()
if hasattr(worker, "listener_thread"):
worker.listener_thread.join()
if hasattr(worker, "printer_thread"):
worker.printer_thread.join()
if hasattr(worker, "logger_thread"):
worker.logger_thread.join()
worker.threads_stopped.clear()
Expand Down Expand Up @@ -1447,13 +1408,11 @@ def get(object_refs, *, timeout=None):
raise ValueError("'object_refs' must either be an object ref "
"or a list of object refs.")

global last_task_error_raise_time
# TODO(ujvl): Consider how to allow user to retrieve the ready objects.
values, debugger_breakpoint = worker.get_objects(
object_refs, timeout=timeout)
for i, value in enumerate(values):
if isinstance(value, RayError):
last_task_error_raise_time = time.time()
if isinstance(value, ray.exceptions.ObjectLostError):
worker.core_worker.dump_object_store_memory_usage()
if isinstance(value, RayTaskError):
Expand Down
8 changes: 8 additions & 0 deletions src/ray/common/ray_object.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,20 @@ class RayObject {
/// large to return directly as part of a gRPC response).
bool IsInPlasmaError() const;

/// Mark this object as accessed before.
void SetAccessed() { accessed_ = true; };

/// Check if this object was accessed before.
bool WasAccessed() const { return accessed_; }

private:
std::shared_ptr<Buffer> data_;
std::shared_ptr<Buffer> metadata_;
const std::vector<ObjectID> nested_ids_;
/// Whether this class holds a data copy.
bool has_data_copy_;
/// Whether this object was accessed.
bool accessed_ = false;
};

} // namespace ray
7 changes: 6 additions & 1 deletion src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,12 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
return Status::OK();
},
options_.ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_,
options_.check_signals));
options_.check_signals,
[this](const RayObject &obj) {
// Run this on the event loop to avoid calling back into the language runtime
// from the middle of user operations.
io_service_.post([this, obj]() { options_.unhandled_exception_handler(obj); });
}));

auto check_node_alive_fn = [this](const NodeID &node_id) {
auto node = gcs_client_->Nodes().Get(node_id);
Expand Down
3 changes: 3 additions & 0 deletions src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ struct CoreWorkerOptions {
spill_objects(nullptr),
restore_spilled_objects(nullptr),
delete_spilled_objects(nullptr),
unhandled_exception_handler(nullptr),
get_lang_stack(nullptr),
kill_main(nullptr),
ref_counting_enabled(false),
Expand Down Expand Up @@ -146,6 +147,8 @@ struct CoreWorkerOptions {
/// Application-language callback to delete objects from external storage.
std::function<void(const std::vector<std::string> &, rpc::WorkerType)>
delete_spilled_objects;
/// Function to call on error objects never retrieved.
std::function<void(const RayObject &error)> unhandled_exception_handler;
/// Language worker callback to get the current call stack.
std::function<void(std::string *)> get_lang_stack;
// Function that tries to interrupt the currently running Python thread.
Expand Down
Loading

0 comments on commit cc156f7

Please sign in to comment.