diff --git a/BUILD.bazel b/BUILD.bazel
index c1745e4688526..7dbd8fadb5266 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -702,6 +702,16 @@ cc_test(
     ],
 )
 
+cc_test(
+    name = "memory_store_test",
+    srcs = ["src/ray/core_worker/test/memory_store_test.cc"],
+    copts = COPTS,
+    deps = [
+        ":core_worker_lib",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
 cc_test(
     name = "direct_actor_transport_test",
     srcs = ["src/ray/core_worker/test/direct_actor_transport_test.cc"],
diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx
index 99cc6588ed427..57e9e2e770db2 100644
--- a/python/ray/_raylet.pyx
+++ b/python/ray/_raylet.pyx
@@ -734,6 +734,20 @@ cdef void delete_spilled_objects_handler(
             job_id=None)
 
 
+cdef void unhandled_exception_handler(const CRayObject& error) nogil:
+    with gil:
+        worker = ray.worker.global_worker
+        data = None
+        metadata = None
+        if error.HasData():
+            data = Buffer.make(error.GetData())
+        if error.HasMetadata():
+            metadata = Buffer.make(error.GetMetadata()).to_pybytes()
+        # TODO(ekl) why does passing an ObjectRef.nil() lead to shutdown errors?
+        object_ids = [None]
+        worker.raise_errors([(data, metadata)], object_ids)
+
+
 # This function introduces ~2-7us of overhead per call (i.e., it can be called
 # up to hundreds of thousands of times per second).
 cdef void get_py_stack(c_string* stack_out) nogil:
@@ -843,6 +857,7 @@ cdef class CoreWorker:
         options.spill_objects = spill_objects_handler
         options.restore_spilled_objects = restore_spilled_objects_handler
         options.delete_spilled_objects = delete_spilled_objects_handler
+        options.unhandled_exception_handler = unhandled_exception_handler
         options.get_lang_stack = get_py_stack
         options.ref_counting_enabled = True
         options.is_local_mode = local_mode
@@ -1453,9 +1468,13 @@ cdef class CoreWorker:
                 object_ref.native())
 
     def remove_object_ref_reference(self, ObjectRef object_ref):
-        # Note: faster to not release GIL for short-running op.
-        CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
-            object_ref.native())
+        cdef:
+            CObjectID c_object_id = object_ref.native()
+        # We need to release the GIL since object destruction may call the
+        # unhandled exception handler.
+        with nogil:
+            CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
+                c_object_id)
 
     def serialize_and_promote_object_ref(self, ObjectRef object_ref):
         cdef:
diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd
index 6114b9e7d58c0..2eb5f109bf654 100644
--- a/python/ray/includes/libcoreworker.pxd
+++ b/python/ray/includes/libcoreworker.pxd
@@ -250,6 +250,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
         (void(
             const c_vector[c_string]&,
             CWorkerType) nogil) delete_spilled_objects
+        (void(const CRayObject&) nogil) unhandled_exception_handler
         (void(c_string *stack_out) nogil) get_lang_stack
         c_bool ref_counting_enabled
         c_bool is_local_mode
diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py
index 0e4405092be39..35e7088d6a3d4 100644
--- a/python/ray/tests/test_failure.py
+++ b/python/ray/tests/test_failure.py
@@ -20,6 +20,52 @@
     get_error_message, Semaphore)
 
 
+def test_unhandled_errors(ray_start_regular):
+    @ray.remote
+    def f():
+        raise ValueError()
+
+    @ray.remote
+    class Actor:
+        def f(self):
+            raise ValueError()
+
+    a = Actor.remote()
+    num_exceptions = 0
+
+    def interceptor(e):
+        nonlocal num_exceptions
+        num_exceptions += 1
+
+    # Test that we report unhandled exceptions.
+    ray.worker._unhandled_error_handler = interceptor
+    x1 = f.remote()
+    x2 = a.f.remote()
+    del x1
+    del x2
+    wait_for_condition(lambda: num_exceptions == 2)
+
+    # Test that we don't report handled exceptions.
+    x1 = f.remote()
+    x2 = a.f.remote()
+    with pytest.raises(ray.exceptions.RayError) as err:  # noqa
+        ray.get([x1, x2])
+    del x1
+    del x2
+    time.sleep(1)
+    assert num_exceptions == 2, num_exceptions
+
+    # Test that suppression with the env var works.
+    try:
+        os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "1"
+        x1 = f.remote()
+        del x1
+        time.sleep(1)
+        assert num_exceptions == 2, num_exceptions
+    finally:
+        del os.environ["RAY_IGNORE_UNHANDLED_ERRORS"]
+
+
 def test_failed_task(ray_start_regular, error_pubsub):
     @ray.remote
     def throw_exception_fct1():
diff --git a/python/ray/worker.py b/python/ray/worker.py
index 57bca1858df8d..7239b80a982e0 100644
--- a/python/ray/worker.py
+++ b/python/ray/worker.py
@@ -9,7 +9,6 @@
 import logging
 import os
 import redis
-from six.moves import queue
 import sys
 import threading
 import time
@@ -69,6 +68,12 @@
 logger = logging.getLogger(__name__)
 
 
+# Visible for testing.
+def _unhandled_error_handler(e: Exception):
+    logger.error("Unhandled error (suppress with "
+                 "RAY_IGNORE_UNHANDLED_ERRORS=1): {}".format(e))
+
+
 class Worker:
     """A class used to define the control flow of a worker process.
@@ -277,6 +282,14 @@ def put_object(self, value, object_ref=None):
             self.core_worker.put_serialized_object(
                 serialized_value, object_ref=object_ref))
 
+    def raise_errors(self, data_metadata_pairs, object_refs):
+        context = self.get_serialization_context()
+        out = context.deserialize_objects(data_metadata_pairs, object_refs)
+        if "RAY_IGNORE_UNHANDLED_ERRORS" in os.environ:
+            return
+        for e in out:
+            _unhandled_error_handler(e)
+
     def deserialize_objects(self, data_metadata_pairs, object_refs):
         context = self.get_serialization_context()
         return context.deserialize_objects(data_metadata_pairs, object_refs)
@@ -854,13 +867,6 @@ def custom_excepthook(type, value, tb):
 sys.excepthook = custom_excepthook
 
 
-# The last time we raised a TaskError in this process. We use this value to
-# suppress redundant error messages pushed from the workers.
-last_task_error_raise_time = 0
-
-# The max amount of seconds to wait before printing out an uncaught error.
-UNCAUGHT_ERROR_GRACE_PERIOD = 5
-
 
 def print_logs(redis_client, threads_stopped, job_id):
     """Prints log messages from workers on all of the nodes.
@@ -1011,42 +1017,7 @@ def color_for(data: Dict[str, str]) -> str:
         file=print_file)
 
 
-def print_error_messages_raylet(task_error_queue, threads_stopped):
-    """Prints message received in the given output queue.
-
-    This checks periodically if any un-raised errors occurred in the
-    background.
-
-    Args:
-        task_error_queue (queue.Queue): A queue used to receive errors from the
-            thread that listens to Redis.
-        threads_stopped (threading.Event): A threading event used to signal to
-            the thread that it should exit.
-    """
-
-    while True:
-        # Exit if we received a signal that we should stop.
-        if threads_stopped.is_set():
-            return
-
-        try:
-            error, t = task_error_queue.get(block=False)
-        except queue.Empty:
-            threads_stopped.wait(timeout=0.01)
-            continue
-        # Delay errors a little bit of time to attempt to suppress redundant
-        # messages originating from the worker.
-        while t + UNCAUGHT_ERROR_GRACE_PERIOD > time.time():
-            threads_stopped.wait(timeout=1)
-            if threads_stopped.is_set():
-                break
-        if t < last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD:
-            logger.debug(f"Suppressing error from worker: {error}")
-        else:
-            logger.error(f"Possible unhandled error from worker: {error}")
-
-
-def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
+def listen_error_messages_raylet(worker, threads_stopped):
     """Listen to error messages in the background on the driver.
 
-    This runs in a separate thread on the driver and pushes (error, time)
-    tuples to the output queue.
+    This runs in a separate thread on the driver and logs error messages
+    from workers as they arrive.
@@ -1054,8 +1025,6 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
 
     Args:
         worker: The worker class that this thread belongs to.
-        task_error_queue (queue.Queue): A queue used to communicate with the
-            thread that prints the errors found by this thread.
         threads_stopped (threading.Event): A threading event used to signal to
             the thread that it should exit.
     """
@@ -1094,8 +1063,9 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
             error_message = error_data.error_message
             if (error_data.type == ray_constants.TASK_PUSH_ERROR):
-                # Delay it a bit to see if we can suppress it
-                task_error_queue.put((error_message, time.time()))
+                # TODO(ekl) remove task push errors entirely now that we have
+                # the separate unhandled exception handler.
+                pass
             else:
                 logger.warning(error_message)
     except (OSError, redis.exceptions.ConnectionError) as e:
@@ -1258,19 +1228,12 @@ def connect(node,
     # temporarily using this implementation which constantly queries the
     # scheduler for new error messages.
     if mode == SCRIPT_MODE:
-        q = queue.Queue()
         worker.listener_thread = threading.Thread(
             target=listen_error_messages_raylet,
             name="ray_listen_error_messages",
-            args=(worker, q, worker.threads_stopped))
-        worker.printer_thread = threading.Thread(
-            target=print_error_messages_raylet,
-            name="ray_print_error_messages",
-            args=(q, worker.threads_stopped))
+            args=(worker, worker.threads_stopped))
         worker.listener_thread.daemon = True
         worker.listener_thread.start()
-        worker.printer_thread.daemon = True
-        worker.printer_thread.start()
     if log_to_driver:
         global_worker_stdstream_dispatcher.add_handler(
             "ray_print_logs", print_to_stdstream)
@@ -1323,8 +1286,6 @@ def disconnect(exiting_interpreter=False):
             worker.import_thread.join_import_thread()
         if hasattr(worker, "listener_thread"):
             worker.listener_thread.join()
-        if hasattr(worker, "printer_thread"):
-            worker.printer_thread.join()
         if hasattr(worker, "logger_thread"):
             worker.logger_thread.join()
         worker.threads_stopped.clear()
@@ -1436,13 +1397,11 @@ def get(object_refs, *, timeout=None):
             raise ValueError("'object_refs' must either be an object ref "
                              "or a list of object refs.")
 
-    global last_task_error_raise_time
     # TODO(ujvl): Consider how to allow user to retrieve the ready objects.
     values, debugger_breakpoint = worker.get_objects(
         object_refs, timeout=timeout)
     for i, value in enumerate(values):
         if isinstance(value, RayError):
-            last_task_error_raise_time = time.time()
             if isinstance(value, ray.exceptions.ObjectLostError):
                 worker.core_worker.dump_object_store_memory_usage()
             if isinstance(value, RayTaskError):
diff --git a/src/ray/common/ray_object.h b/src/ray/common/ray_object.h
index 633a5d787c7e0..c036550a86529 100644
--- a/src/ray/common/ray_object.h
+++ b/src/ray/common/ray_object.h
@@ -92,12 +92,20 @@ class RayObject {
   /// large to return directly as part of a gRPC response).
   bool IsInPlasmaError() const;
 
+  /// Mark this object as accessed before.
+  void SetAccessed() { accessed_ = true; }
+
+  /// Check if this object was accessed before.
+  bool WasAccessed() const { return accessed_; }
+
  private:
   std::shared_ptr<Buffer> data_;
   std::shared_ptr<Buffer> metadata_;
   const std::vector<ObjectID> nested_ids_;
   /// Whether this class holds a data copy.
   bool has_data_copy_;
+  /// Whether this object was accessed.
+  bool accessed_ = false;
 };
 
 }  // namespace ray
diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc
index ac4e22ac23da1..111bf5c4c9343 100644
--- a/src/ray/core_worker/core_worker.cc
+++ b/src/ray/core_worker/core_worker.cc
@@ -422,7 +422,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_id)
         return Status::OK();
       },
       options_.ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_,
-      options_.check_signals));
+      options_.check_signals, options_.unhandled_exception_handler));
 
   auto check_node_alive_fn = [this](const NodeID &node_id) {
     auto node = gcs_client_->Nodes().Get(node_id);
diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h
index ffa3f24969d87..863439a51ba7b 100644
--- a/src/ray/core_worker/core_worker.h
+++ b/src/ray/core_worker/core_worker.h
@@ -82,6 +82,7 @@ struct CoreWorkerOptions {
         spill_objects(nullptr),
         restore_spilled_objects(nullptr),
         delete_spilled_objects(nullptr),
+        unhandled_exception_handler(nullptr),
         get_lang_stack(nullptr),
         kill_main(nullptr),
         ref_counting_enabled(false),
@@ -146,6 +147,8 @@ struct CoreWorkerOptions {
   /// Application-language callback to delete objects from external storage.
   std::function<void(const std::vector<std::string> &, rpc::WorkerType)>
       delete_spilled_objects;
+  /// Function to call on error objects that are never retrieved.
+  std::function<void(const RayObject &)> unhandled_exception_handler;
   /// Language worker callback to get the current call stack.
   std::function<void(std::string *)> get_lang_stack;
   // Function that tries to interrupt the currently running Python thread.
diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc
index 6dad1b37be724..7897b6504e82c 100644
--- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc
+++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc
@@ -93,6 +93,7 @@ void GetRequest::Set(const ObjectID &object_id, std::shared_ptr<RayObject> object) {
   if (is_ready_) {
     return;  // We have already hit the number of objects to return limit.
   }
+  object->SetAccessed();
   objects_.emplace(object_id, object);
   if (objects_.size() == num_objects_ ||
       (abort_if_any_object_is_exception_ && object->IsException() &&
@@ -106,6 +107,7 @@
 std::shared_ptr<RayObject> GetRequest::Get(const ObjectID &object_id) const {
   std::unique_lock<std::mutex> lock(mutex_);
   auto iter = objects_.find(object_id);
   if (iter != objects_.end()) {
+    iter->second->SetAccessed();
     return iter->second;
   }
 
@@ -116,11 +118,13 @@ CoreWorkerMemoryStore::CoreWorkerMemoryStore(
     std::function<void(const RayObject &, const ObjectID &)> store_in_plasma,
     std::shared_ptr<ReferenceCounter> counter,
     std::shared_ptr<raylet::RayletClient> raylet_client,
-    std::function<Status()> check_signals)
+    std::function<Status()> check_signals,
+    std::function<void(const RayObject &)> unhandled_exception_handler)
     : store_in_plasma_(store_in_plasma),
       ref_counter_(counter),
      raylet_client_(raylet_client),
-      check_signals_(check_signals) {}
+      check_signals_(check_signals),
+      unhandled_exception_handler_(unhandled_exception_handler) {}
 
 void CoreWorkerMemoryStore::GetAsync(
     const ObjectID &object_id, std::function<void(std::shared_ptr<RayObject>)> callback) {
@@ -136,6 +140,7 @@ void CoreWorkerMemoryStore::GetAsync(
   }
   // It's important for performance to run the callback outside the lock.
   if (ptr != nullptr) {
+    ptr->SetAccessed();
     callback(ptr);
   }
 }
@@ -146,6 +151,7 @@ std::shared_ptr<RayObject> CoreWorkerMemoryStore::GetOrPromoteToPlasma(
   auto iter = objects_.find(object_id);
   if (iter != objects_.end()) {
     auto obj = iter->second;
+    obj->SetAccessed();
     if (obj->IsInPlasmaError()) {
       return nullptr;
     }
@@ -210,6 +216,8 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_id) {
     if (should_add_entry) {
       // If there is no existing get request, then add the `RayObject` to map.
       objects_.emplace(object_id, object_entry);
+    } else {
+      OnErase(object_entry);
     }
   }
 
@@ -223,6 +231,7 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_id) {
 
   // It's important for performance to run the callbacks outside the lock.
   for (const auto &cb : async_callbacks) {
+    object_entry->SetAccessed();
     cb(object_entry);
   }
 
@@ -257,6 +266,7 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector<ObjectID> &object_ids,
     const auto &object_id = object_ids[i];
     auto iter = objects_.find(object_id);
     if (iter != objects_.end()) {
+      iter->second->SetAccessed();
       (*results)[i] = iter->second;
       if (remove_after_get) {
         // Note that we cannot remove the object_id from `objects_` now,
@@ -426,6 +436,7 @@ void CoreWorkerMemoryStore::Delete(const absl::flat_hash_set<ObjectID> &object_ids,
       if (it->second->IsInPlasmaError()) {
         plasma_ids_to_delete->insert(object_id);
       } else {
+        OnErase(it->second);
         objects_.erase(it);
       }
     }
   }
 }
 
@@ -435,7 +446,11 @@
 void CoreWorkerMemoryStore::Delete(const std::vector<ObjectID> &object_ids) {
   absl::MutexLock lock(&mu_);
   for (const auto &object_id : object_ids) {
-    objects_.erase(object_id);
+    auto it = objects_.find(object_id);
+    if (it != objects_.end()) {
+      OnErase(it->second);
+      objects_.erase(it);
+    }
   }
 }
 
@@ -451,6 +466,14 @@ bool CoreWorkerMemoryStore::Contains(const ObjectID &object_id, bool *in_plasma) {
   return false;
 }
 
+void CoreWorkerMemoryStore::OnErase(std::shared_ptr<RayObject> obj) {
+  // TODO(ekl) note that this doesn't warn on errors that are stored in plasma.
+  if (obj->IsException() && !obj->IsInPlasmaError() && !obj->WasAccessed() &&
+      unhandled_exception_handler_ != nullptr) {
+    unhandled_exception_handler_(*obj);
+  }
+}
+
 MemoryStoreStats CoreWorkerMemoryStore::GetMemoryStoreStatisticalData() {
   absl::MutexLock lock(&mu_);
   MemoryStoreStats item;
diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.h b/src/ray/core_worker/store_provider/memory_store/memory_store.h
index 709227f65206d..0ca94ef6cc022 100644
--- a/src/ray/core_worker/store_provider/memory_store/memory_store.h
+++ b/src/ray/core_worker/store_provider/memory_store/memory_store.h
@@ -35,7 +35,8 @@ class CoreWorkerMemoryStore {
       std::function<void(const RayObject &, const ObjectID &)> store_in_plasma = nullptr,
      std::shared_ptr<ReferenceCounter> counter = nullptr,
       std::shared_ptr<raylet::RayletClient> raylet_client = nullptr,
-      std::function<Status()> check_signals = nullptr);
+      std::function<Status()> check_signals = nullptr,
+      std::function<void(const RayObject &)> unhandled_exception_handler = nullptr);
   ~CoreWorkerMemoryStore(){};
 
   /// Put an object with specified ID into object store.
@@ -143,6 +144,9 @@ class CoreWorkerMemoryStore {
       std::vector<std::shared_ptr<RayObject>> *results,
       bool abort_if_any_object_is_exception);
 
+  /// Called when an object is erased from the store.
+  void OnErase(std::shared_ptr<RayObject> obj);
+
   /// Optional callback for putting objects into the plasma store.
   std::function<void(const RayObject &, const ObjectID &)> store_in_plasma_;
 
@@ -173,6 +177,9 @@ class CoreWorkerMemoryStore {
   /// Function passed in to be called to check for signals (e.g., Ctrl-C).
   std::function<Status()> check_signals_;
+
+  /// Function called to report unhandled exceptions.
+  std::function<void(const RayObject &)> unhandled_exception_handler_;
 };
 
 }  // namespace ray
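
Behavior sketch (illustrative, not part of the patch): with this change, an error object that is destroyed without ever being read is reported through the new unhandled-exception handler at the moment it is erased from the in-memory store, replacing the old grace-period printer thread. The report is asynchronous, which is why the test above polls with wait_for_condition(). The example assumes a local `ray.init()`; the task name `fails` is invented for illustration.

    import ray

    ray.init()

    @ray.remote
    def fails():
        raise ValueError("boom")

    # Never read: after the ref is dropped, the error object is erased from
    # the in-memory store with accessed_ == false, so OnErase() invokes the
    # handler and the driver eventually logs:
    #   Unhandled error (suppress with RAY_IGNORE_UNHANDLED_ERRORS=1): ...
    ref = fails.remote()
    del ref

    # Read first: ray.get() goes through GetRequest::Set/Get, which call
    # SetAccessed(), so dropping the ref afterwards logs nothing.
    ref = fails.remote()
    try:
        ray.get(ref)
    except ray.exceptions.RayTaskError:
        pass  # the error was handled explicitly
    del ref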
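
Suppression sketch: Worker.raise_errors() returns early when RAY_IGNORE_UNHANDLED_ERRORS is present in the driver's environment. Note that the check is membership in os.environ rather than a value test, so any value (even "0") suppresses the report; the variable must be set before the error object is destroyed.

    import os

    import ray

    os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "1"
    ray.init()

    @ray.remote
    def fails():
        raise ValueError("boom")

    ref = fails.remote()
    del ref  # no "Unhandled error ..." line is logged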
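
Interception sketch: tests can capture reports instead of logging them by swapping the module-level hook, exactly as test_unhandled_errors does above. raise_errors() resolves `_unhandled_error_handler` at call time, so assigning the module attribute is sufficient; the leading underscore marks this as a private, testing-only hook rather than a public API.

    import ray

    ray.init()
    seen = []

    # Replace the default logging handler with a collector.
    ray.worker._unhandled_error_handler = seen.append

    @ray.remote
    def fails():
        raise ValueError("boom")

    ref = fails.remote()
    del ref
    # `seen` will eventually contain the RayTaskError; nothing is logged.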