diff --git a/continuous_integration/environment-3.7.yaml b/continuous_integration/environment-3.7.yaml index bb7838fe20..bfb1c0aac6 100644 --- a/continuous_integration/environment-3.7.yaml +++ b/continuous_integration/environment-3.7.yaml @@ -40,6 +40,6 @@ dependencies: - zict - zstandard - pip: - - git+https://github.com/dask/dask + - git+https://github.com/madsbk/dask.git@formalization_of_computation - git+https://github.com/jcrist/crick # Only tested here - keras diff --git a/continuous_integration/environment-3.8.yaml b/continuous_integration/environment-3.8.yaml index 118d4231e4..f48dec6e3e 100644 --- a/continuous_integration/environment-3.8.yaml +++ b/continuous_integration/environment-3.8.yaml @@ -38,5 +38,5 @@ dependencies: - zict - zstandard - pip: - - git+https://github.com/dask/dask + - git+https://github.com/madsbk/dask.git@formalization_of_computation - keras diff --git a/continuous_integration/environment-3.9.yaml b/continuous_integration/environment-3.9.yaml index 0dcd972db5..ee8a33e230 100644 --- a/continuous_integration/environment-3.9.yaml +++ b/continuous_integration/environment-3.9.yaml @@ -45,7 +45,7 @@ dependencies: - zict # overridden by git tip below - zstandard - pip: - - git+https://github.com/dask/dask + - git+https://github.com/madsbk/dask.git@formalization_of_computation - git+https://github.com/dask/s3fs - git+https://github.com/dask/zict - git+https://github.com/intake/filesystem_spec diff --git a/distributed/client.py b/distributed/client.py index 39153bd622..712eabea3b 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -68,7 +68,8 @@ from .metrics import time from .objects import HasWhat, SchedulerInfo, WhoHas from .protocol import to_serialize -from .protocol.pickle import dumps, loads +from .protocol.computation import PickledObject +from .protocol.pickle import dumps from .publish import Datasets from .pubsub import PubSubClientExtension from .security import Security @@ -1325,7 +1326,7 @@ def _handle_key_in_memory(self, key=None, type=None, workers=None): if state is not None: if type and not state.type: # Type exists and not yet set try: - type = loads(type) + type = PickledObject.deserialize(type) except Exception: type = None # Here, `type` may be a str if actual type failed diff --git a/distributed/diagnostics/eventstream.py b/distributed/diagnostics/eventstream.py index 34805e9085..0a12792c6a 100644 --- a/distributed/diagnostics/eventstream.py +++ b/distributed/diagnostics/eventstream.py @@ -1,7 +1,7 @@ import logging from ..core import coerce_to_address, connect -from ..worker import dumps_function +from ..protocol.computation import PickledCallable from .plugin import SchedulerPlugin logger = logging.getLogger(__name__) @@ -62,10 +62,10 @@ async def eventstream(address, interval): await comm.write( { "op": "feed", - "setup": dumps_function(EventStream), - "function": dumps_function(swap_buffer), + "setup": PickledCallable.serialize(EventStream), + "function": PickledCallable.serialize(swap_buffer), "interval": interval, - "teardown": dumps_function(teardown), + "teardown": PickledCallable.serialize(teardown), } ) return comm diff --git a/distributed/diagnostics/progress_stream.py b/distributed/diagnostics/progress_stream.py index 57b0cb3839..7f274867b2 100644 --- a/distributed/diagnostics/progress_stream.py +++ b/distributed/diagnostics/progress_stream.py @@ -3,9 +3,9 @@ from tlz import merge, valmap from ..core import coerce_to_address, connect +from ..protocol.computation import PickledCallable from ..scheduler import Scheduler 
from ..utils import color_of, key_split -from ..worker import dumps_function from .progress import AllProgress logger = logging.getLogger(__name__) @@ -50,10 +50,10 @@ async def progress_stream(address, interval): await comm.write( { "op": "feed", - "setup": dumps_function(AllProgress), - "function": dumps_function(counts), + "setup": PickledCallable.serialize(AllProgress), + "function": PickledCallable.serialize(counts), "interval": interval, - "teardown": dumps_function(remove_plugin), + "teardown": PickledCallable.serialize(remove_plugin), } ) return comm diff --git a/distributed/diagnostics/progressbar.py b/distributed/diagnostics/progressbar.py index a85b72d4b0..c41226e4ff 100644 --- a/distributed/diagnostics/progressbar.py +++ b/distributed/diagnostics/progressbar.py @@ -8,12 +8,10 @@ from tlz import valmap from tornado.ioloop import IOLoop -import dask - from ..client import default_client, futures_of from ..core import CommClosedError, coerce_to_address, connect -from ..protocol.pickle import dumps -from ..utils import LoopRunner, is_kernel, key_split +from ..protocol.computation import PickledCallable +from ..utils import LoopRunner, is_kernel, key_split, parse_timedelta from .progress import MultiProgress, Progress, format_time logger = logging.getLogger(__name__) @@ -36,7 +34,7 @@ def __init__(self, keys, scheduler=None, interval="100ms", complete=True): break self.keys = {k.key if hasattr(k, "key") else k for k in keys} - self.interval = dask.utils.parse_timedelta(interval, default="s") + self.interval = parse_timedelta(interval, default="s") self.complete = complete self._start_time = default_timer() @@ -71,8 +69,8 @@ def function(scheduler, p): await self.comm.write( { "op": "feed", - "setup": dumps(setup), - "function": dumps(function), + "setup": PickledCallable.serialize(setup), + "function": PickledCallable.serialize(function), "interval": self.interval, }, serializers=self.client()._serializers if self.client else None, @@ -263,8 +261,8 @@ def function(scheduler, p): await self.comm.write( { "op": "feed", - "setup": dumps(setup), - "function": dumps(function), + "setup": PickledCallable.serialize(setup), + "function": PickledCallable.serialize(function), "interval": self.interval, } ) diff --git a/distributed/diagnostics/tests/test_widgets.py b/distributed/diagnostics/tests/test_widgets.py index e47c1bd5bc..7f605bf536 100644 --- a/distributed/diagnostics/tests/test_widgets.py +++ b/distributed/diagnostics/tests/test_widgets.py @@ -74,16 +74,14 @@ def record_display(*args): import re from operator import add -from tlz import valmap - from distributed.client import wait from distributed.diagnostics.progressbar import ( MultiProgressWidget, ProgressWidget, progress, ) +from distributed.protocol.computation import typeset_dask_graph from distributed.utils_test import dec, gen_cluster, gen_tls_cluster, inc, throws -from distributed.worker import dumps_task @gen_cluster(client=True) @@ -146,8 +144,7 @@ async def test_multi_progressbar_widget(c, s, a, b): @gen_cluster() async def test_multi_progressbar_widget_after_close(s, a, b): s.update_graph( - tasks=valmap( - dumps_task, + tasks=typeset_dask_graph( { "x-1": (inc, 1), "x-2": (inc, "x-1"), @@ -232,8 +229,7 @@ def test_progressbar_cancel(client): @gen_cluster() async def test_multibar_complete(s, a, b): s.update_graph( - tasks=valmap( - dumps_task, + tasks=typeset_dask_graph( { "x-1": (inc, 1), "x-2": (inc, "x-1"), diff --git a/distributed/protocol/computation.py b/distributed/protocol/computation.py new file mode 100644 
index 0000000000..7280c5a684
--- /dev/null
+++ b/distributed/protocol/computation.py
@@ -0,0 +1,270 @@
+"""
+This module implements graph computations based on the specification in Dask[1]:
+> A computation may be one of the following:
+> - Any key present in the Dask graph like `'x'`
+> - Any other value like `1`, to be interpreted literally
+> - A task like `(inc, 'x')`
+> - A list of computations, like `[1, 'x', (inc, 'x')]`
+
+In order to support efficient and flexible task serialization, this module introduces
+classes for computations, tasks, data, functions, etc.
+
+Notable Classes
+---------------
+
+- `PickledObject` - An object that is serialized using `protocol.pickle`.
+  This object isn't a computation by itself; instead, users can build pickled
+  computations that contain pickled objects. This object is automatically
+  deserialized by the Worker before execution.
+
+- `Computation` - A computation that the Worker can execute. The Scheduler sees
+  this as a black box. A computation **cannot** contain pickled objects but it may
+  contain `Serialize` and/or `Serialized` objects, which are deserialized
+  automatically when they arrive on the Worker.
+
+- `PickledComputation` - A computation that is serialized using `protocol.pickle`.
+  The class is derived from `Computation` but **can** contain pickled objects.
+  Both the computation and any pickled objects it contains are deserialized by the
+  Worker before execution.
+
+Notable Functions
+-----------------
+
+- `typeset_dask_graph()` - Used to typeset a Dask graph, wrapping each computation
+  in either the `Computation` or the `PickledComputation` class. This should be done
+  before communicating the graph. Note that this replaces the old
+  `tlz.valmap(dumps_task, dsk)` operation.
+
+[1]
+"""
+
+import threading
+import warnings
+from typing import Any, Callable, Dict, Iterable, Mapping, MutableMapping, Tuple
+
+import tlz
+
+from dask.core import istask
+from dask.utils import apply, format_bytes
+
+from ..utils import LRU
+from . 
import pickle + + +def identity(x, *args_ignored): + return x + + +def execute_task(task, *args_ignored): + """Evaluate a nested task + + >>> inc = lambda x: x + 1 + >>> execute_task((inc, 1)) + 2 + >>> execute_task((sum, [1, 2, (inc, 3)])) + 7 + """ + if istask(task): + func, args = task[0], task[1:] + return func(*map(execute_task, args)) + elif isinstance(task, list): + return list(map(execute_task, task)) + else: + return task + + +class PickledObject: + _value: bytes + + def __init__(self, value: bytes): + self._value = value + + def __reduce__(self): + return (type(self), (self._value,)) + + @classmethod + def msgpack_decode(cls, state: Mapping): + return cls(state["value"]) + + def msgpack_encode(self) -> dict: + return { + f"__{type(self).__name__}__": True, + "value": self._value, + } + + @classmethod + def serialize(cls, obj) -> "PickledObject": + return cls(pickle.dumps(obj)) + + def deserialize(self): + return pickle.loads(self._value) + + def __eq__(self, other): + return isinstance(other, PickledObject) and self._value == other._value + + def __ne__(self, other): + return not (self == other) + + +class PickledCallable(PickledObject): + cache_dumps: MutableMapping[int, bytes] = LRU(maxsize=100) + cache_loads: MutableMapping[int, Callable] = LRU(maxsize=100) + cache_max_sized_obj = 1_000_000 + cache_dumps_lock = threading.Lock() + + @classmethod + def dumps_function(cls, func: Callable) -> bytes: + """Dump a function to bytes, cache functions""" + + try: + with cls.cache_dumps_lock: + ret = cls.cache_dumps[func] + except KeyError: + ret = pickle.dumps(func) + if len(ret) <= cls.cache_max_sized_obj: + with cls.cache_dumps_lock: + cls.cache_dumps[func] = ret + except TypeError: # Unhashable function + ret = pickle.dumps(func) + return ret + + @classmethod + def loads_function(cls, dumped_func: bytes): + """Load a function from bytes, cache bytes""" + if len(dumped_func) > cls.cache_max_sized_obj: + return pickle.loads(dumped_func) + + try: + ret = cls.cache_loads[dumped_func] + except KeyError: + cls.cache_loads[dumped_func] = ret = pickle.loads(dumped_func) + return ret + + @classmethod + def serialize(cls, func: Callable) -> "PickledCallable": + if isinstance(func, cls): + return func + else: + return cls(cls.dumps_function(func)) + + def deserialize(self) -> Callable: + return self.loads_function(self._value) + + def __call__(self, *args, **kwargs): + return self.deserialize()(*args, **kwargs) + + +class Computation: + def __init__(self, value, is_a_task: bool): + self._value = value + self._is_a_task = is_a_task + + @classmethod + def msgpack_decode(cls, state: Mapping): + return cls(state["value"], state["is_a_task"]) + + def msgpack_encode(self) -> dict: + return { + f"__{type(self).__name__}__": True, + "value": self._value, + "is_a_task": self._is_a_task, + } + + def get_func_and_args(self) -> Tuple[Callable, Iterable, Mapping]: + if self._is_a_task: + return (execute_task, (self._value,), {}) + else: + return (identity, (self._value,), {}) + + def get_computation(self) -> "Computation": + return self + + +class PickledComputation(Computation): + _size_warning_triggered: bool = False + _size_warning_limit: int = 1_000_000 + + @classmethod + def serialize(cls, value, is_a_task: bool): + data = pickle.dumps(value) + ret = cls(data, is_a_task) + if not cls._size_warning_triggered and len(data) > cls._size_warning_limit: + cls._size_warning_triggered = True + s = str(value) + if len(s) > 70: + s = s[:50] + " ... 
" + s[-15:] + warnings.warn( + "Large object of size %s detected in task graph: \n" + " %s\n" + "Consider scattering large objects ahead of time\n" + "with client.scatter to reduce scheduler burden and \n" + "keep data on workers\n\n" + " future = client.submit(func, big_data) # bad\n\n" + " big_future = client.scatter(big_data) # good\n" + " future = client.submit(func, big_future) # good" + % (format_bytes(len(data)), s) + ) + return ret + + def deserialize(self): + def inner_deserialize(obj): + if isinstance(obj, list): + return [inner_deserialize(o) for o in obj] + elif istask(obj): + return tuple(inner_deserialize(o) for o in obj) + elif isinstance(obj, PickledObject): + return obj.deserialize() + else: + return obj + + return inner_deserialize(pickle.loads(self._value)) + + def get_computation(self) -> Computation: + return Computation(self.deserialize(), self._is_a_task) + + def get_func_and_args(self) -> Tuple[Callable, Iterable, Mapping]: + return self.get_computation().get_func_and_args() + + +def typeset_computation(computation) -> Computation: + from .serialize import Serialize, Serialized, to_serialize + + if isinstance(computation, Computation): + return computation # Already a computation + + contain_pickled = [False] + contain_tasks = [False] + + def serialize_callables(obj): + if isinstance(obj, list): + return [serialize_callables(o) for o in obj] + elif istask(obj): + contain_tasks[0] = True + if obj[0] is apply: + return (apply, PickledCallable.serialize(obj[1])) + tuple( + map(serialize_callables, obj[2:]) + ) + else: + return (PickledCallable.serialize(obj[0]),) + tuple( + map(serialize_callables, obj[1:]) + ) + elif isinstance(obj, PickledObject): + contain_pickled[0] = True + return obj + else: + assert not isinstance(obj, (Serialize, Serialized)), obj + return obj + + computation = serialize_callables(computation) + if contain_tasks[0]: + # We found a task thus this is a (nested) task + return PickledComputation.serialize(computation, is_a_task=True) + elif contain_pickled[0]: + # We found no tasks but it contains pickled objects + return PickledComputation.serialize(computation, is_a_task=False) + else: + # No tasks or pickled objects found. Still, it might contain non-msgpack + # serializable objects thus we wrap it in a `to_serialize` + return Computation(to_serialize(computation), is_a_task=False) + + +def typeset_dask_graph(dsk: Mapping[str, Any]) -> Dict[str, Computation]: + return tlz.valmap(typeset_computation, dsk) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 51815677f5..10cc0157b5 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -12,6 +12,7 @@ from ..utils import ensure_bytes, has_keyword, typename from . import pickle from .compression import decompress, maybe_compress +from .computation import Computation, PickledCallable, PickledComputation, PickledObject from .utils import frame_split_size, msgpack_opts, pack_frames_prelude, unpack_frames lazy_registrations = {} @@ -118,13 +119,17 @@ def msgpack_decode_default(obj): if "__Set__" in obj: return set(obj["as-list"]) - if "__Serialized__" in obj: - # Notice, the data here is marked a Serialized rather than deserialized. This - # is because deserialization requires Pickle which the Scheduler cannot run - # because of security reasons. - # By marking it Serialized, the data is passed through to the workers that - # eventually will deserialize it. 
- return Serialized(*obj["data"]) + if "__PickledCallable__" in obj: + return PickledCallable.msgpack_decode(obj) + + if "__PickledObject__" in obj: + return PickledObject.msgpack_decode(obj) + + if "__PickledComputation__" in obj: + return PickledComputation.msgpack_decode(obj) + + if "__Computation__" in obj: + return Computation.msgpack_decode(obj) return obj @@ -148,6 +153,9 @@ def msgpack_encode_default(obj): if isinstance(obj, set): return {"__Set__": True, "as-list": list(obj)} + if isinstance(obj, (PickledObject, Computation)): + return obj.msgpack_encode() + return obj @@ -541,32 +549,21 @@ def nested_deserialize(x): {'op': 'update', 'data': 123} """ - def replace_inner(x): - if type(x) is dict: - x = x.copy() - for k, v in x.items(): - typ = type(v) - if typ is dict or typ is list: - x[k] = replace_inner(v) - elif typ is Serialize: - x[k] = v.data - elif typ is Serialized: - x[k] = deserialize(v.header, v.frames) - - elif type(x) is list: - x = list(x) - for k, v in enumerate(x): - typ = type(v) - if typ is dict or typ is list: - x[k] = replace_inner(v) - elif typ is Serialize: - x[k] = v.data - elif typ is Serialized: - x[k] = deserialize(v.header, v.frames) - + typ = type(x) + if typ is dict: + return {k: nested_deserialize(v) for k, v in x.items()} + elif typ is list or typ is tuple: + return typ(nested_deserialize(o) for o in x) + elif typ is Serialize: + return x.data + elif typ is Serialized: + return deserialize(x.header, x.frames) + elif isinstance(x, Computation): + x = x.get_computation() + x._value = nested_deserialize(x._value) + return x + else: return x - - return replace_inner(x) def serialize_bytelist(x, **kwargs): diff --git a/distributed/recreate_tasks.py b/distributed/recreate_tasks.py index ec596bc461..12967803ca 100644 --- a/distributed/recreate_tasks.py +++ b/distributed/recreate_tasks.py @@ -3,9 +3,9 @@ from dask.utils import stringify from .client import futures_of, wait +from .protocol.computation import PickledComputation from .utils import sync from .utils_comm import pack_data -from .worker import _deserialize logger = logging.getLogger(__name__) @@ -83,12 +83,9 @@ async def _get_raw_components_from_future(self, future): key = future.key spec = await self.scheduler.get_runspec(key=key) deps, task = spec["deps"], spec["task"] - if isinstance(task, dict): - function, args, kwargs = _deserialize(**task) - return (function, args, kwargs, deps) - else: - function, args, kwargs = _deserialize(task=task) - return (function, args, kwargs, deps) + assert isinstance(task, PickledComputation) + function, args, kwargs = task.get_func_and_args() + return (function, args, kwargs, deps) async def _prepare_raw_components(self, raw_components): """ diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 566f149fc3..59b1c8eb58 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -59,6 +59,7 @@ from .multi_lock import MultiLockExtension from .node import ServerNode from .proctitle import setproctitle +from .protocol.computation import Computation from .protocol.pickle import loads from .publish import PublishExtension from .pubsub import PubSubSchedulerExtension @@ -389,7 +390,7 @@ class WorkerState: .. attribute:: processing: {TaskState: cost} A dictionary of tasks that have been submitted to this worker. 
- Each task state is asssociated with the expected cost in seconds + Each task state is associated with the expected cost in seconds of running that task, summing both the task's expected computation time and the expected communication time of its result. @@ -404,7 +405,7 @@ class WorkerState: .. attribute:: executing: {TaskState: duration} A dictionary of tasks that are currently being run on this worker. - Each task state is asssociated with the duration in seconds which + Each task state is associated with the duration in seconds which the task has been running. .. attribute:: has_what: {TaskState} @@ -768,7 +769,7 @@ def ncores(self): @final @cclass -class Computation: +class _Computation: """ Collection tracking a single compute or persist call @@ -2132,7 +2133,7 @@ def __pdict__(self): @ccall @exceptval(check=False) def new_task( - self, key: str, spec: object, state: str, computation: Computation = None + self, key: str, spec: object, state: str, computation: _Computation = None ) -> TaskState: """Create a new task, and associated states""" ts: TaskState = TaskState(key, spec) @@ -4408,9 +4409,9 @@ def update_graph( if parent._total_occupancy > 1e-9 and parent._computations: # Still working on something. Assign new tasks to same computation - computation = cast(Computation, parent._computations[-1]) + computation = cast(_Computation, parent._computations[-1]) else: - computation = Computation() + computation = _Computation() parent._computations.append(computation) if code and code not in computation._code: # add new code blocks @@ -6715,12 +6716,6 @@ async def feed( interval = parse_timedelta(interval) with log_errors(): - if function: - function = pickle.loads(function) - if setup: - setup = pickle.loads(setup) - if teardown: - teardown = pickle.loads(teardown) state = setup(self) if setup else None if inspect.isawaitable(state): state = await state @@ -7729,11 +7724,15 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) -> if duration < 0: duration = state.get_task_duration(ts) + run_spec = ts._run_spec + assert run_spec is None or isinstance(run_spec, Computation) + msg: dict = { "op": "compute-task", "key": ts._key, "priority": ts._priority, "duration": duration, + "runspec": run_spec, } if ts._resource_restrictions: msg["resource_restrictions"] = ts._resource_restrictions @@ -7750,12 +7749,6 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) -> if state._validate: assert all(msg["who_has"].values()) - task = ts._run_spec - if type(task) is dict: - msg.update(task) - else: - msg["task"] = task - if ts._annotations: msg["annotations"] = ts._annotations return msg diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index e642e6f6ae..79e5c09881 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -4660,8 +4660,7 @@ async def test_recreate_error_delayed(c, s, a, b): error_f = await c._get_errored_future(f) function, args, kwargs = await c._get_components_from_future(error_f) assert f.status == "error" - assert function.__name__ == "div" - assert args == (1, 0) + assert args == ((div, 1, 0),) with pytest.raises(ZeroDivisionError): function(*args, **kwargs) @@ -4680,8 +4679,7 @@ async def test_recreate_error_futures(c, s, a, b): error_f = await c._get_errored_future(f) function, args, kwargs = await c._get_components_from_future(error_f) assert f.status == "error" - assert function.__name__ == "div" - assert args == (1, 0) + assert args == ((div, 1, 0),) with 
pytest.raises(ZeroDivisionError): function(*args, **kwargs) @@ -4768,8 +4766,7 @@ async def test_recreate_task_delayed(c, s, a, b): function, args, kwargs = await c._get_components_from_future(f) assert f.status == "finished" - assert function.__name__ == "sum" - assert args == ([1, 1],) + assert args == ((sum, [1, 1]),) assert function(*args, **kwargs) == 2 @@ -4786,8 +4783,7 @@ async def test_recreate_task_futures(c, s, a, b): function, args, kwargs = await c._get_components_from_future(f) assert f.status == "finished" - assert function.__name__ == "sum" - assert args == ([1, 1],) + assert args == ((sum, [1, 1]),) assert function(*args, **kwargs) == 2 @@ -5581,7 +5577,7 @@ async def test_warn_when_submitting_large_values(c, s, a, b): assert "2.00 MB" in text or "1.91 MiB" in text assert "large" in text assert "..." in text - assert "'000" in text + assert "... 000" in text assert "000'" in text assert len(text) < 2000 diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 536fd3fd45..0a764e6b06 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -14,17 +14,18 @@ import cloudpickle import psutil import pytest -from tlz import concat, first, frequencies, merge, valmap +from tlz import concat, first, frequencies, merge import dask from dask import delayed -from dask.utils import apply, parse_timedelta, stringify +from dask.utils import parse_timedelta, stringify from distributed import Client, Nanny, Worker, fire_and_forget, wait from distributed.comm import Comm from distributed.compatibility import LINUX, WINDOWS from distributed.core import ConnectionPool, Status, connect, rpc from distributed.metrics import time +from distributed.protocol.computation import PickledCallable, typeset_dask_graph from distributed.protocol.pickle import dumps from distributed.scheduler import MemoryState, Scheduler from distributed.utils import TimeoutError, tmpfile, typename @@ -43,7 +44,7 @@ tls_only_security, varying, ) -from distributed.worker import dumps_function, dumps_task, get_worker +from distributed.worker import get_worker if sys.version_info < (3, 8): try: @@ -296,7 +297,7 @@ async def test_retire_workers_empty(s): @gen_cluster() async def test_remove_client(s, a, b): s.update_graph( - tasks={"x": dumps_task((inc, 1)), "y": dumps_task((inc, "x"))}, + tasks=typeset_dask_graph({"x": (inc, 1), "y": (inc, "x")}), dependencies={"x": [], "y": ["x"]}, keys=["y"], client="ident", @@ -323,7 +324,7 @@ async def test_server_listens_to_other_ops(s, a, b): async def test_remove_worker_from_scheduler(s, a, b): dsk = {("x-%d" % i): (inc, i) for i in range(20)} s.update_graph( - tasks=valmap(dumps_task, dsk), + tasks=typeset_dask_graph(dsk), keys=list(dsk), dependencies={k: set() for k in dsk}, ) @@ -389,7 +390,7 @@ async def test_add_worker(s, a, b): dsk = {("x-%d" % i): (inc, i) for i in range(10)} s.update_graph( - tasks=valmap(dumps_task, dsk), + tasks=typeset_dask_graph(dsk), keys=list(dsk), client="client", dependencies={k: set() for k in dsk}, @@ -409,7 +410,9 @@ def func(scheduler): return dumps(dict(scheduler.worker_info)) comm = await connect(s.address) - await comm.write({"op": "feed", "function": dumps(func), "interval": 0.01}) + await comm.write( + {"op": "feed", "function": PickledCallable.serialize(func), "interval": 0.01} + ) response = await comm.read() @@ -435,7 +438,9 @@ def func(scheduler): return dumps(dict(scheduler.worker_info)) comm = await connect(s.address) - await comm.write({"op": "feed", "function": 
dumps(func), "interval": 0.01}) + await comm.write( + {"op": "feed", "function": PickledCallable.serialize(func), "interval": 0.01} + ) for i in range(5): response = await comm.read() @@ -461,9 +466,9 @@ def teardown(scheduler, state): await comm.write( { "op": "feed", - "function": dumps(func), - "setup": dumps(setup), - "teardown": dumps(teardown), + "function": PickledCallable.serialize(func), + "setup": PickledCallable.serialize(setup), + "teardown": PickledCallable.serialize(teardown), "interval": 0.01, } ) @@ -490,7 +495,9 @@ def func(scheduler): return True comm = await connect(s.address) - await comm.write({"op": "feed", "function": dumps(func), "interval": 0.05}) + await comm.write( + {"op": "feed", "function": PickledCallable.serialize(func), "interval": 0.05} + ) for i in range(5): response = await comm.read() @@ -549,7 +556,7 @@ async def test_filtered_communication(s, a, b): await c.write( { "op": "update-graph", - "tasks": {"x": dumps_task((inc, 1)), "y": dumps_task((inc, "x"))}, + "tasks": typeset_dask_graph({"x": (inc, 1), "y": (inc, "x")}), "dependencies": {"x": [], "y": ["x"]}, "client": "c", "keys": ["y"], @@ -559,10 +566,12 @@ async def test_filtered_communication(s, a, b): await f.write( { "op": "update-graph", - "tasks": { - "x": dumps_task((inc, 1)), - "z": dumps_task((operator.add, "x", 10)), - }, + "tasks": typeset_dask_graph( + { + "x": (inc, 1), + "z": (operator.add, "x", 10), + } + ), "dependencies": {"x": [], "z": ["x"]}, "client": "f", "keys": ["z"], @@ -576,37 +585,10 @@ async def test_filtered_communication(s, a, b): assert msg["key"] == "z" -def test_dumps_function(): - a = dumps_function(inc) - assert cloudpickle.loads(a)(10) == 11 - - b = dumps_function(inc) - assert a is b - - c = dumps_function(dec) - assert a != c - - -def test_dumps_task(): - d = dumps_task((inc, 1)) - assert set(d) == {"function", "args"} - - f = lambda x, y=2: x + y - d = dumps_task((apply, f, (1,), {"y": 10})) - assert cloudpickle.loads(d["function"])(1, 2) == 3 - assert cloudpickle.loads(d["args"]) == (1,) - assert cloudpickle.loads(d["kwargs"]) == {"y": 10} - - d = dumps_task((apply, f, (1,))) - assert cloudpickle.loads(d["function"])(1, 2) == 3 - assert cloudpickle.loads(d["args"]) == (1,) - assert set(d) == {"function", "args"} - - @gen_cluster() async def test_ready_remove_worker(s, a, b): s.update_graph( - tasks={"x-%d" % i: dumps_task((inc, i)) for i in range(20)}, + tasks=typeset_dask_graph({"x-%d" % i: (inc, i) for i in range(20)}), keys=["x-%d" % i for i in range(20)], client="client", dependencies={"x-%d" % i: [] for i in range(20)}, @@ -752,11 +734,13 @@ async def test_file_descriptors_dont_leak(s): @gen_cluster() async def test_update_graph_culls(s, a, b): s.update_graph( - tasks={ - "x": dumps_task((inc, 1)), - "y": dumps_task((inc, "x")), - "z": dumps_task((inc, 2)), - }, + tasks=typeset_dask_graph( + { + "x": (inc, 1), + "y": (inc, "x"), + "z": (inc, 2), + } + ), keys=["y"], dependencies={"y": "x", "x": [], "z": []}, client="client", diff --git a/distributed/worker.py b/distributed/worker.py index 21188d2141..444cde1610 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -10,7 +10,7 @@ import threading import warnings import weakref -from collections import defaultdict, deque, namedtuple +from collections import defaultdict, deque from collections.abc import MutableMapping from contextlib import suppress from datetime import timedelta @@ -22,9 +22,8 @@ from tornado.ioloop import IOLoop, PeriodicCallback import dask -from dask.core import istask from 
dask.system import CPU_COUNT -from dask.utils import apply, format_bytes, funcname, parse_bytes, parse_timedelta +from dask.utils import format_bytes, funcname, parse_bytes, parse_timedelta from . import comm, preloading, profile, system, utils from .batched import BatchedSend @@ -46,16 +45,21 @@ from .metrics import time from .node import ServerNode from .proctitle import setproctitle -from .protocol import pickle, to_serialize +from .protocol import pickle +from .protocol.computation import ( + Computation, + PickledCallable, + PickledComputation, + PickledObject, +) +from .protocol.serialize import to_serialize from .pubsub import PubSubWorkerExtension from .security import Security from .sizeof import safe_sizeof as sizeof from .threadpoolexecutor import ThreadPoolExecutor from .threadpoolexecutor import secede as tpe_secede from .utils import ( - LRU, TimeoutError, - _maybe_complex, get_ip, has_arg, import_file, @@ -96,8 +100,6 @@ dask.config.get("distributed.scheduler.default-data-size") ) -SerializedTask = namedtuple("SerializedTask", ["function", "args", "kwargs", "task"]) - class TaskState: """Holds volatile state relating to an individual Dask task @@ -148,15 +150,13 @@ class TaskState: Parameters ---------- key: str - runspec: SerializedTask - A named tuple containing the ``function``, ``args``, ``kwargs`` and - ``task`` associated with this `TaskState` instance. This defaults to - ``None`` and can remain empty if it is a dependency that this worker + runspec: Computation, optional + This defaults to ``None`` and can remain empty if it is a dependency that this worker will receive from another worker. """ - def __init__(self, key, runspec=None): + def __init__(self, key, runspec: Computation = None): assert key is not None self.key = key self.runspec = runspec @@ -409,7 +409,7 @@ def __init__( lifetime_restart=None, **kwargs, ): - self.tasks = dict() + self.tasks: Dict[str, TaskState] = dict() self.waiting_for_data_count = 0 self.has_what = defaultdict(set) self.pending_data_per_worker = defaultdict(deque) @@ -1064,7 +1064,7 @@ def func(data): if load: try: import_file(out_filename) - cache_loads.data.clear() + PickledCallable.cache_loads.clear() except Exception as e: logger.exception(e) raise e @@ -1571,10 +1571,7 @@ def cancel_compute(self, key, reason): def add_task( self, key, - function=None, - args=None, - kwargs=None, - task=no_value, + runspec: Computation = None, who_has=None, nbytes=None, priority=None, @@ -1585,7 +1582,6 @@ def add_task( **kwargs2, ): try: - runspec = SerializedTask(function, args, kwargs, task) if key in self.tasks: ts = self.tasks[key] ts.scheduler_holds_ref = True @@ -1608,9 +1604,7 @@ def add_task( self.transition(ts, "waiting", runspec=runspec) else: self.log.append((key, "new")) - self.tasks[key] = ts = TaskState( - key=key, runspec=SerializedTask(function, args, kwargs, task) - ) + self.tasks[key] = ts = TaskState(key=key, runspec=runspec) self.transition(ts, "waiting") # TODO: move transition of `ts` to end of `add_task` # This will require a chained recommendation transition system like @@ -2203,7 +2197,7 @@ def send_task_state_to_scheduler(self, ts): typ = ts.type = type(value) del value try: - typ_serialized = dumps_function(typ) + typ_serialized = PickledObject.serialize(typ) except PicklingError: # Some types fail pickling (example: _thread.lock objects), # send their name as a best effort. 
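
For reference, a minimal sketch of the round trip these worker-side changes rely on, assuming only the `distributed.protocol.computation` module added earlier in this diff (`inc` is a stand-in function, not part of the change):

from distributed.protocol.computation import typeset_computation

def inc(x):
    # Stand-in task function used only for this illustration.
    return x + 1

# Client/scheduler side: a raw Dask task is typeset into a PickledComputation.
comp = typeset_computation((inc, 1))

# Worker side: instead of unpacking separate function/args/kwargs fields,
# the worker asks the runspec for them and calls the returned function.
func, args, kwargs = comp.get_func_and_args()
assert func(*args, **kwargs) == 2  # execute_task evaluates the nested task
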
@@ -2800,27 +2794,29 @@ def meets_resource_constraints(self, key): return True - async def _maybe_deserialize_task(self, ts): - if not isinstance(ts.runspec, SerializedTask): - return ts.runspec - try: - start = time() - # Offload deserializing large tasks - if sizeof(ts.runspec) > OFFLOAD_THRESHOLD: - function, args, kwargs = await offload(_deserialize, *ts.runspec) - else: - function, args, kwargs = _deserialize(*ts.runspec) - stop = time() + async def _maybe_deserialize_task(self, ts: TaskState) -> Computation: + if isinstance(ts.runspec, PickledComputation): + try: + start = time() + # Offload deserializing large tasks + if sizeof(ts.runspec) > OFFLOAD_THRESHOLD: + runspec = await offload(_deserialize, ts.runspec) + else: + runspec = _deserialize(ts.runspec) + stop = time() - if stop - start > 0.010: - ts.startstops.append( - {"action": "deserialize", "start": start, "stop": stop} - ) - return function, args, kwargs - except Exception: - logger.error("Could not deserialize task", exc_info=True) - self.log.append((ts.key, "deserialize-error")) - raise + if stop - start > 0.010: + ts.startstops.append( + {"action": "deserialize", "start": start, "stop": stop} + ) + return runspec + except Exception: + logger.error("Could not deserialize task", exc_info=True) + self.log.append((ts.key, "deserialize-error")) + raise + else: + assert isinstance(ts.runspec, Computation), ts.runspec + return ts.runspec async def ensure_computing(self): if self.paused: @@ -2880,7 +2876,7 @@ async def execute(self, key): assert ts.state == "executing" assert ts.runspec is not None - function, args, kwargs = await self._maybe_deserialize_task(ts) + function, args, kwargs = ts.runspec.get_func_and_args() args2, kwargs2 = self._prepare_args_for_execution(ts, args, kwargs) @@ -3677,130 +3673,9 @@ async def _get_data(): job_counter = [0] -cache_loads = LRU(maxsize=100) - - -def loads_function(bytes_object): - """Load a function from bytes, cache bytes""" - if len(bytes_object) < 100000: - try: - result = cache_loads[bytes_object] - except KeyError: - result = pickle.loads(bytes_object) - cache_loads[bytes_object] = result - return result - return pickle.loads(bytes_object) - - -def _deserialize(function=None, args=None, kwargs=None, task=no_value): - """Deserialize task inputs and regularize to func, args, kwargs""" - if function is not None: - function = loads_function(function) - if args and isinstance(args, bytes): - args = pickle.loads(args) - if kwargs and isinstance(kwargs, bytes): - kwargs = pickle.loads(kwargs) - - if task is not no_value: - assert not function and not args and not kwargs - function = execute_task - args = (task,) - - return function, args or (), kwargs or {} - - -def execute_task(task): - """Evaluate a nested task - - >>> inc = lambda x: x + 1 - >>> execute_task((inc, 1)) - 2 - >>> execute_task((sum, [1, 2, (inc, 3)])) - 7 - """ - if istask(task): - func, args = task[0], task[1:] - return func(*map(execute_task, args)) - elif isinstance(task, list): - return list(map(execute_task, task)) - else: - return task - - -cache_dumps = LRU(maxsize=100) - -_cache_lock = threading.Lock() - - -def dumps_function(func): - """Dump a function to bytes, cache functions""" - try: - with _cache_lock: - result = cache_dumps[func] - except KeyError: - result = pickle.dumps(func, protocol=4) - if len(result) < 100000: - with _cache_lock: - cache_dumps[func] = result - except TypeError: # Unhashable function - result = pickle.dumps(func, protocol=4) - return result - - -def dumps_task(task): - """Serialize a 
dask task - - Returns a dict of bytestrings that can each be loaded with ``loads`` - - Examples - -------- - Either returns a task as a function, args, kwargs dict - - >>> from operator import add - >>> dumps_task((add, 1)) # doctest: +SKIP - {'function': b'\x80\x04\x95\x00\x8c\t_operator\x94\x8c\x03add\x94\x93\x94.' - 'args': b'\x80\x04\x95\x07\x00\x00\x00K\x01K\x02\x86\x94.'} - - Or as a single task blob if it can't easily decompose the result. This - happens either if the task is highly nested, or if it isn't a task at all - - >>> dumps_task(1) # doctest: +SKIP - {'task': b'\x80\x04\x95\x03\x00\x00\x00\x00\x00\x00\x00K\x01.'} - """ - if istask(task): - if task[0] is apply and not any(map(_maybe_complex, task[2:])): - d = {"function": dumps_function(task[1]), "args": warn_dumps(task[2])} - if len(task) == 4: - d["kwargs"] = warn_dumps(task[3]) - return d - elif not any(map(_maybe_complex, task[1:])): - return {"function": dumps_function(task[0]), "args": warn_dumps(task[1:])} - return to_serialize(task) - - -_warn_dumps_warned = [False] - - -def warn_dumps(obj, dumps=pickle.dumps, limit=1e6): - """Dump an object to bytes, warn if those bytes are large""" - b = dumps(obj, protocol=4) - if not _warn_dumps_warned[0] and len(b) > limit: - _warn_dumps_warned[0] = True - s = str(obj) - if len(s) > 70: - s = s[:50] + " ... " + s[-15:] - warnings.warn( - "Large object of size %s detected in task graph: \n" - " %s\n" - "Consider scattering large objects ahead of time\n" - "with client.scatter to reduce scheduler burden and \n" - "keep data on workers\n\n" - " future = client.submit(func, big_data) # bad\n\n" - " big_future = client.scatter(big_data) # good\n" - " future = client.submit(func, big_future) # good" - % (format_bytes(len(b)), s) - ) - return b +def _deserialize(runspec: PickledComputation) -> Computation: + """Deserialize computation""" + return runspec.get_computation() def apply_function(
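
A rough usage sketch of the scheduler-facing side, mirroring the updated tests above and assuming only the classes and functions added in this diff:

from distributed.protocol.computation import Computation, typeset_dask_graph

def inc(x):
    # Stand-in task function used only for this illustration.
    return x + 1

# Previously: tasks = tlz.valmap(dumps_task, dsk)
dsk = {"x": (inc, 1), "y": (inc, "x")}
tasks = typeset_dask_graph(dsk)

# Each value is a Computation (tasks become PickledComputation); the scheduler
# treats it as a black box and msgpack-encodes it via msgpack_encode().
assert all(isinstance(v, Computation) for v in tasks.values())
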