Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix OSX error and re-merge unhandled exceptions handling #14138

Merged
merged 5 commits into from
Feb 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,16 @@ cc_test(
],
)

cc_test(
name = "memory_store_test",
srcs = ["src/ray/core_worker/test/memory_store_test.cc"],
copts = COPTS,
deps = [
":core_worker_lib",
"@com_google_googletest//:gtest_main",
],
)

cc_test(
name = "direct_actor_transport_test",
srcs = ["src/ray/core_worker/test/direct_actor_transport_test.cc"],
Expand Down
25 changes: 22 additions & 3 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,20 @@ cdef void delete_spilled_objects_handler(
job_id=None)


cdef void unhandled_exception_handler(const CRayObject& error) nogil:
with gil:
worker = ray.worker.global_worker
data = None
metadata = None
if error.HasData():
data = Buffer.make(error.GetData())
if error.HasMetadata():
metadata = Buffer.make(error.GetMetadata()).to_pybytes()
# TODO(ekl) why does passing a ObjectRef.nil() lead to shutdown errors?
object_ids = [None]
worker.raise_errors([(data, metadata)], object_ids)


# This function introduces ~2-7us of overhead per call (i.e., it can be called
# up to hundreds of thousands of times per second).
cdef void get_py_stack(c_string* stack_out) nogil:
Expand Down Expand Up @@ -843,6 +857,7 @@ cdef class CoreWorker:
options.spill_objects = spill_objects_handler
options.restore_spilled_objects = restore_spilled_objects_handler
options.delete_spilled_objects = delete_spilled_objects_handler
options.unhandled_exception_handler = unhandled_exception_handler
options.get_lang_stack = get_py_stack
options.ref_counting_enabled = True
options.is_local_mode = local_mode
Expand Down Expand Up @@ -1453,9 +1468,13 @@ cdef class CoreWorker:
object_ref.native())

def remove_object_ref_reference(self, ObjectRef object_ref):
# Note: faster to not release GIL for short-running op.
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
object_ref.native())
cdef:
CObjectID c_object_id = object_ref.native()
# We need to release the gil since object destruction may call the
# unhandled exception handler.
with nogil:
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
c_object_id)

def serialize_and_promote_object_ref(self, ObjectRef object_ref):
cdef:
Expand Down
1 change: 1 addition & 0 deletions python/ray/includes/libcoreworker.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
(void(
const c_vector[c_string]&,
CWorkerType) nogil) delete_spilled_objects
(void(const CRayObject&) nogil) unhandled_exception_handler
(void(c_string *stack_out) nogil) get_lang_stack
c_bool ref_counting_enabled
c_bool is_local_mode
Expand Down
46 changes: 46 additions & 0 deletions python/ray/tests/test_failure.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,52 @@
get_error_message, Semaphore)


def test_unhandled_errors(ray_start_regular):
@ray.remote
def f():
raise ValueError()

@ray.remote
class Actor:
def f(self):
raise ValueError()

a = Actor.remote()
num_exceptions = 0

def interceptor(e):
nonlocal num_exceptions
num_exceptions += 1

# Test we report unhandled exceptions.
ray.worker._unhandled_error_handler = interceptor
x1 = f.remote()
x2 = a.f.remote()
del x1
del x2
wait_for_condition(lambda: num_exceptions == 2)

# Test we don't report handled exceptions.
x1 = f.remote()
x2 = a.f.remote()
with pytest.raises(ray.exceptions.RayError) as err: # noqa
ray.get([x1, x2])
del x1
del x2
time.sleep(1)
assert num_exceptions == 2, num_exceptions

# Test suppression with env var works.
try:
os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "1"
x1 = f.remote()
del x1
time.sleep(1)
assert num_exceptions == 2, num_exceptions
finally:
del os.environ["RAY_IGNORE_UNHANDLED_ERRORS"]


def test_failed_task(ray_start_regular, error_pubsub):
@ray.remote
def throw_exception_fct1():
Expand Down
79 changes: 19 additions & 60 deletions python/ray/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import logging
import os
import redis
from six.moves import queue
import sys
import threading
import time
Expand Down Expand Up @@ -69,6 +68,12 @@
logger = logging.getLogger(__name__)


# Visible for testing.
def _unhandled_error_handler(e: Exception):
logger.error("Unhandled error (suppress with "
"RAY_IGNORE_UNHANDLED_ERRORS=1): {}".format(e))


class Worker:
"""A class used to define the control flow of a worker process.

Expand Down Expand Up @@ -277,6 +282,14 @@ def put_object(self, value, object_ref=None):
self.core_worker.put_serialized_object(
serialized_value, object_ref=object_ref))

def raise_errors(self, data_metadata_pairs, object_refs):
context = self.get_serialization_context()
out = context.deserialize_objects(data_metadata_pairs, object_refs)
if "RAY_IGNORE_UNHANDLED_ERRORS" in os.environ:
return
for e in out:
_unhandled_error_handler(e)

def deserialize_objects(self, data_metadata_pairs, object_refs):
context = self.get_serialization_context()
return context.deserialize_objects(data_metadata_pairs, object_refs)
Expand Down Expand Up @@ -854,13 +867,6 @@ def custom_excepthook(type, value, tb):

sys.excepthook = custom_excepthook

# The last time we raised a TaskError in this process. We use this value to
# suppress redundant error messages pushed from the workers.
last_task_error_raise_time = 0

# The max amount of seconds to wait before printing out an uncaught error.
UNCAUGHT_ERROR_GRACE_PERIOD = 5


def print_logs(redis_client, threads_stopped, job_id):
"""Prints log messages from workers on all of the nodes.
Expand Down Expand Up @@ -1011,51 +1017,14 @@ def color_for(data: Dict[str, str]) -> str:
file=print_file)


def print_error_messages_raylet(task_error_queue, threads_stopped):
"""Prints message received in the given output queue.

This checks periodically if any un-raised errors occurred in the
background.

Args:
task_error_queue (queue.Queue): A queue used to receive errors from the
thread that listens to Redis.
threads_stopped (threading.Event): A threading event used to signal to
the thread that it should exit.
"""

while True:
# Exit if we received a signal that we should stop.
if threads_stopped.is_set():
return

try:
error, t = task_error_queue.get(block=False)
except queue.Empty:
threads_stopped.wait(timeout=0.01)
continue
# Delay errors a little bit of time to attempt to suppress redundant
# messages originating from the worker.
while t + UNCAUGHT_ERROR_GRACE_PERIOD > time.time():
threads_stopped.wait(timeout=1)
if threads_stopped.is_set():
break
if t < last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD:
logger.debug(f"Suppressing error from worker: {error}")
else:
logger.error(f"Possible unhandled error from worker: {error}")


def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
def listen_error_messages_raylet(worker, threads_stopped):
"""Listen to error messages in the background on the driver.

This runs in a separate thread on the driver and pushes (error, time)
tuples to the output queue.

Args:
worker: The worker class that this thread belongs to.
task_error_queue (queue.Queue): A queue used to communicate with the
thread that prints the errors found by this thread.
threads_stopped (threading.Event): A threading event used to signal to
the thread that it should exit.
"""
Expand Down Expand Up @@ -1094,8 +1063,9 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):

error_message = error_data.error_message
if (error_data.type == ray_constants.TASK_PUSH_ERROR):
# Delay it a bit to see if we can suppress it
task_error_queue.put((error_message, time.time()))
# TODO(ekl) remove task push errors entirely now that we have
# the separate unhandled exception handler.
pass
else:
logger.warning(error_message)
except (OSError, redis.exceptions.ConnectionError) as e:
Expand Down Expand Up @@ -1258,19 +1228,12 @@ def connect(node,
# temporarily using this implementation which constantly queries the
# scheduler for new error messages.
if mode == SCRIPT_MODE:
q = queue.Queue()
worker.listener_thread = threading.Thread(
target=listen_error_messages_raylet,
name="ray_listen_error_messages",
args=(worker, q, worker.threads_stopped))
worker.printer_thread = threading.Thread(
target=print_error_messages_raylet,
name="ray_print_error_messages",
args=(q, worker.threads_stopped))
args=(worker, worker.threads_stopped))
worker.listener_thread.daemon = True
worker.listener_thread.start()
worker.printer_thread.daemon = True
worker.printer_thread.start()
if log_to_driver:
global_worker_stdstream_dispatcher.add_handler(
"ray_print_logs", print_to_stdstream)
Expand Down Expand Up @@ -1323,8 +1286,6 @@ def disconnect(exiting_interpreter=False):
worker.import_thread.join_import_thread()
if hasattr(worker, "listener_thread"):
worker.listener_thread.join()
if hasattr(worker, "printer_thread"):
worker.printer_thread.join()
if hasattr(worker, "logger_thread"):
worker.logger_thread.join()
worker.threads_stopped.clear()
Expand Down Expand Up @@ -1436,13 +1397,11 @@ def get(object_refs, *, timeout=None):
raise ValueError("'object_refs' must either be an object ref "
"or a list of object refs.")

global last_task_error_raise_time
# TODO(ujvl): Consider how to allow user to retrieve the ready objects.
values, debugger_breakpoint = worker.get_objects(
object_refs, timeout=timeout)
for i, value in enumerate(values):
if isinstance(value, RayError):
last_task_error_raise_time = time.time()
if isinstance(value, ray.exceptions.ObjectLostError):
worker.core_worker.dump_object_store_memory_usage()
if isinstance(value, RayTaskError):
Expand Down
8 changes: 8 additions & 0 deletions src/ray/common/ray_object.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,20 @@ class RayObject {
/// large to return directly as part of a gRPC response).
bool IsInPlasmaError() const;

/// Mark this object as accessed before.
void SetAccessed() { accessed_ = true; };

/// Check if this object was accessed before.
bool WasAccessed() const { return accessed_; }

private:
std::shared_ptr<Buffer> data_;
std::shared_ptr<Buffer> metadata_;
const std::vector<ObjectID> nested_ids_;
/// Whether this class holds a data copy.
bool has_data_copy_;
/// Whether this object was accessed.
bool accessed_ = false;
};

} // namespace ray
2 changes: 1 addition & 1 deletion src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
return Status::OK();
},
options_.ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_,
options_.check_signals));
options_.check_signals, options_.unhandled_exception_handler));

auto check_node_alive_fn = [this](const NodeID &node_id) {
auto node = gcs_client_->Nodes().Get(node_id);
Expand Down
3 changes: 3 additions & 0 deletions src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ struct CoreWorkerOptions {
spill_objects(nullptr),
restore_spilled_objects(nullptr),
delete_spilled_objects(nullptr),
unhandled_exception_handler(nullptr),
get_lang_stack(nullptr),
kill_main(nullptr),
ref_counting_enabled(false),
Expand Down Expand Up @@ -146,6 +147,8 @@ struct CoreWorkerOptions {
/// Application-language callback to delete objects from external storage.
std::function<void(const std::vector<std::string> &, rpc::WorkerType)>
delete_spilled_objects;
/// Function to call on error objects never retrieved.
std::function<void(const RayObject &error)> unhandled_exception_handler;
/// Language worker callback to get the current call stack.
std::function<void(std::string *)> get_lang_stack;
// Function that tries to interrupt the currently running Python thread.
Expand Down
Loading