Merge pull request #698 from tclose/hash-change-guards
Hash change guards
tclose authored Mar 8, 2024
2 parents 0e66136 + ff281aa commit ff01e4c
Showing 8 changed files with 286 additions and 92 deletions.
111 changes: 90 additions & 21 deletions pydra/engine/core.py
@@ -9,7 +9,7 @@
import sys
from pathlib import Path
import typing as ty
from copy import deepcopy
from copy import deepcopy, copy
from uuid import uuid4
from filelock import SoftFileLock
import shutil
@@ -281,13 +281,15 @@ def checksum_states(self, state_index=None):
"""
if is_workflow(self) and self.inputs._graph_checksums is attr.NOTHING:
self.inputs._graph_checksums = [nd.checksum for nd in self.graph_sorted]
self.inputs._graph_checksums = {
nd.name: nd.checksum for nd in self.graph_sorted
}

if state_index is not None:
inputs_copy = deepcopy(self.inputs)
inputs_copy = copy(self.inputs)
for key, ind in self.state.inputs_ind[state_index].items():
val = self._extract_input_el(
inputs=inputs_copy, inp_nm=key.split(".")[1], ind=ind
inputs=self.inputs, inp_nm=key.split(".")[1], ind=ind
)
setattr(inputs_copy, key.split(".")[1], val)
# setting files_hash again in case it was cleaned by setting specific element
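Why the change from a list to a dict here: keying `_graph_checksums` by node name lets later code (such as the blocked-task diagnosis added to the submitter below) re-check the recorded checksum of one specific predecessor. A runnable toy sketch of the idea, using an illustrative stand-in class rather than pydra's own:

class Node:  # stand-in for a workflow node (illustrative only)
    def __init__(self, name, checksum):
        self.name = name
        self.checksum = checksum

graph_sorted = [Node("preprocess", "abc123"), Node("fit", "def456")]

# Before this commit: an order-dependent list with no node attribution
checksums = [nd.checksum for nd in graph_sorted]

# After: keyed by node name, so one node's recorded hash can be looked up later
checksums = {nd.name: nd.checksum for nd in graph_sorted}
assert checksums["preprocess"] == "abc123"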
@@ -462,13 +464,25 @@ def __call__(
return res

def _modify_inputs(self):
"""Update and preserve a Task's original inputs"""
"""This method modifies the inputs of the task ahead of its execution:
- links/copies upstream files and directories into the destination tasks
working directory as required select state array values corresponding to
state index (it will try to leave them where they are unless specified or
they are on different file systems)
- resolve template values (e.g. output_file_template)
- deepcopy all inputs to guard against in-place changes during the task's
execution (they will be replaced after the task's execution with the
original inputs to ensure the tasks checksums are consistent)
"""
orig_inputs = {
k: deepcopy(v) for k, v in attr.asdict(self.inputs, recurse=False).items()
k: v
for k, v in attr.asdict(self.inputs, recurse=False).items()
if not k.startswith("_")
}
map_copyfiles = {}
for fld in attr_fields(self.inputs):
value = getattr(self.inputs, fld.name)
input_fields = attr.fields(type(self.inputs))
for name, value in orig_inputs.items():
fld = getattr(input_fields, name)
copy_mode, copy_collation = parse_copyfile(
fld, default_collation=self.DEFAULT_COPY_COLLATION
)
@@ -483,12 +497,22 @@ def _modify_inputs(self):
supported_modes=self.SUPPORTED_COPY_MODES,
)
if value is not copied_value:
map_copyfiles[fld.name] = copied_value
map_copyfiles[name] = copied_value
modified_inputs = template_update(
self.inputs, self.output_dir, map_copyfiles=map_copyfiles
)
if modified_inputs:
self.inputs = attr.evolve(self.inputs, **modified_inputs)
assert all(m in orig_inputs for m in modified_inputs), (
"Modified inputs contain fields not present in original inputs. "
"This is likely a bug."
)
for name, orig_value in orig_inputs.items():
try:
value = modified_inputs[name]
except KeyError:
# Ensure we pass a copy not the original just in case inner
# attributes are modified during execution
value = deepcopy(orig_value)
setattr(self.inputs, name, value)
return orig_inputs

def _populate_filesystem(self, checksum, output_dir):
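The guard introduced in `_modify_inputs`, reduced to a runnable toy: hand the task deepcopies of its inputs so in-place mutation cannot leak back, and keep the originals for restoration after execution so re-hashing the inputs gives a consistent checksum. Names here are illustrative, not pydra's API:

from copy import deepcopy

orig_inputs = {"ids": [3, 1, 2]}                   # kept pristine for re-hashing
inputs = {k: deepcopy(v) for k, v in orig_inputs.items()}

inputs["ids"].sort()                               # task mutates its input in place

inputs.update(orig_inputs)                         # originals restored after execution
assert inputs["ids"] == [3, 1, 2]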
@@ -548,13 +572,14 @@ def _run(self, rerun=False, environment=None, **kwargs):
save(output_dir, result=result, task=self)
# removing the additional file with the checksum
(self.cache_dir / f"{self.uid}_info.json").unlink()
# # function etc. shouldn't change anyway, so removing
orig_inputs = {
k: v for k, v in orig_inputs.items() if not k.startswith("_")
}
self.inputs = attr.evolve(self.inputs, **orig_inputs)
# Restore original values to inputs
for field_name, field_value in orig_inputs.items():
setattr(self.inputs, field_name, field_value)
os.chdir(cwd)
self.hooks.post_run(self, result)
# Check for any changes to the input hashes that have occurred during the execution
# of the task
self._check_for_hash_changes()
return result

def _collect_outputs(self, output_dir):
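What the new post-run `_check_for_hash_changes()` call catches, sketched with plain-dict stand-ins (the hashing helper below is illustrative, not pydra's `hash_function`):

import hashlib, json

def field_hash(value):
    return hashlib.blake2b(json.dumps(value).encode(), digest_size=8).hexdigest()

inputs = {"ids": [3, 1, 2]}
hashes_before = {k: field_hash(v) for k, v in inputs.items()}

inputs["ids"].sort()        # in-place mutation during the run

changed = [k for k, v in inputs.items() if field_hash(v) != hashes_before[k]]
assert changed == ["ids"]   # the guard would raise, naming this field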
@@ -816,8 +841,8 @@ def result(self, state_index=None, return_inputs=False):
Returns
-------
result :
result : Result
the result of the task
"""
# TODO: check if result is available in load_result and
# return a future if not
@@ -884,6 +909,47 @@ def _reset(self):
for task in self.graph.nodes:
task._reset()

def _check_for_hash_changes(self):
hash_changes = self.inputs.hash_changes()
details = ""
for changed in hash_changes:
field = getattr(attr.fields(type(self.inputs)), changed)
val = getattr(self.inputs, changed)
field_type = type(val)
if issubclass(field.type, FileSet):
details += (
f"- {changed}: value passed to the {field.type} field is of type "
f"{field_type} ('{val}'). If it is intended to contain output data "
"then the type of the field in the interface class should be changed "
"to `pathlib.Path`. Otherwise, if the field is intended to be an "
"input field but it gets altered by the task in some way, then the "
"'copyfile' flag should be set to 'copy' in the field metadata of "
"the task interface class so copies of the files/directories in it "
"are passed to the task instead.\n"
)
else:
details += (
f"- {changed}: the {field_type} object passed to the {field.type}"
f"field appears to have an unstable hash. This could be due to "
"a stochastic/non-thread-safe attribute(s) of the object\n\n"
f"The {field.type}.__bytes_repr__() method can be implemented to "
"bespoke hashing methods based only on the stable attributes for "
f"the `{field_type.__module__}.{field_type.__name__}` type. "
f"See pydra/utils/hash.py for examples. Value: {val}\n"
)
if hash_changes:
raise RuntimeError(
f"Input field hashes have changed during the execution of the "
f"'{self.name}' {type(self).__name__}.\n\n{details}"
)
logger.debug(
"Input values and hashes for '%s' %s node:\n%s\n%s",
self.name,
type(self).__name__,
self.inputs,
self.inputs._hashes,
)

SUPPORTED_COPY_MODES = FileSet.CopyMode.any
DEFAULT_COPY_COLLATION = FileSet.CopyCollation.any

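For the "unstable hash" branch of the message above, the suggested fix is a `__bytes_repr__()` dunder that feeds only stable attributes to the hasher in a deterministic order. A minimal sketch, assuming the generator protocol used by `bytes_repr()` in pydra/utils/hash.py (yield `bytes` chunks; the `cache` argument tracks already-visited objects):

import typing as ty

class SeedBank:
    """Toy type whose set attribute would otherwise hash in arbitrary order."""

    def __init__(self, ids: ty.Iterable[int]):
        self.ids = set(ids)
        self._scratch = object()            # unstable attribute, deliberately excluded

    def __bytes_repr__(self, cache) -> ty.Iterator[bytes]:
        yield b"SeedBank:"
        for i in sorted(self.ids):          # deterministic order
            yield str(i).encode()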
@@ -1076,7 +1142,9 @@ def checksum(self):
"""
# if checksum is called before run the _graph_checksums is not ready
if is_workflow(self) and self.inputs._graph_checksums is attr.NOTHING:
self.inputs._graph_checksums = [nd.checksum for nd in self.graph_sorted]
self.inputs._graph_checksums = {
nd.name: nd.checksum for nd in self.graph_sorted
}

input_hash = self.inputs.hash
if not self.state:
@@ -1256,8 +1324,9 @@ async def _run(self, submitter=None, rerun=False, **kwargs):
(self.cache_dir / f"{self.uid}_info.json").unlink()
os.chdir(cwd)
self.hooks.post_run(self, result)
if result is None:
raise Exception("This should never happen, please open new issue")
# Check for any changes to the input hashes that have occurred during the execution
# of the task
self._check_for_hash_changes()
return result

async def _run_task(self, submitter, rerun=False):
45 changes: 31 additions & 14 deletions pydra/engine/specs.py
@@ -15,7 +15,7 @@
)
import pydra
from .helpers_file import template_update_single
from ..utils.hash import hash_function
from ..utils.hash import hash_function, Cache

# from ..utils.misc import add_exc_note

@@ -73,21 +73,22 @@ class SpecInfo:
class BaseSpec:
"""The base dataclass specs for all inputs and outputs."""

# def __attrs_post_init__(self):
# self.files_hash = {
# field.name: {}
# for field in attr_fields(
# self, exclude_names=("_graph_checksums", "bindings", "files_hash")
# )
# if field.metadata.get("output_file_template") is None
# }

def collect_additional_outputs(self, inputs, output_dir, outputs):
"""Get additional outputs."""
return {}

@property
def hash(self):
hsh, self._hashes = self._compute_hashes()
return hsh

def hash_changes(self):
"""Detects any changes in the hashed values between the current inputs and the
previously calculated values"""
_, new_hashes = self._compute_hashes()
return [k for k, v in new_hashes.items() if v != self._hashes[k]]

def _compute_hashes(self) -> ty.Tuple[bytes, ty.Dict[str, bytes]]:
"""Compute a basic hash for any given set of fields."""
inp_dict = {}
for field in attr_fields(
@@ -101,10 +102,13 @@ def hash(self):
if "container_path" in field.metadata:
continue
inp_dict[field.name] = getattr(self, field.name)
inp_hash = hash_function(inp_dict)
hash_cache = Cache({})
field_hashes = {
k: hash_function(v, cache=hash_cache) for k, v in inp_dict.items()
}
if hasattr(self, "_graph_checksums"):
inp_hash = hash_function((inp_hash, self._graph_checksums))
return inp_hash
field_hashes["_graph_checksums"] = self._graph_checksums
return hash_function(sorted(field_hashes.items())), field_hashes

def retrieve_values(self, wf, state_index: ty.Optional[int] = None):
"""Get values contained by this spec."""
@@ -984,8 +988,21 @@ def get_value(
if result is None:
raise RuntimeError(
f"Could not find results of '{node.name}' node in a sub-directory "
f"named '{node.checksum}' in any of the cache locations:\n"
f"named '{node.checksum}' in any of the cache locations.\n"
+ "\n".join(str(p) for p in set(node.cache_locations))
+ f"\n\nThis is likely due to hash changes in '{self.name}' node inputs. "
f"Current values and hashes: {self.inputs}, "
f"{self.inputs._hashes}\n\n"
"Set loglevel to 'debug' in order to track hash changes "
"throughout the execution of the workflow.\n\n "
"These issues may have been caused by `bytes_repr()` methods "
"that don't return stable hash values for specific object "
"types across multiple processes (see bytes_repr() "
'"singledispatch "function in pydra/utils/hash.py).'
"You may need to implement a specific `bytes_repr()` "
'"singledispatch overload"s or `__bytes_repr__()` '
"dunder methods to handle one or more types in "
"your interface inputs."
)
_, split_depth = TypeParser.strip_splits(self.type)

65 changes: 48 additions & 17 deletions pydra/engine/submitter.py
@@ -183,24 +183,55 @@ async def expand_workflow(self, wf, rerun=False):
# don't block the event loop!
await asyncio.sleep(1)
if ii > 60:
blocked = _list_blocked_tasks(graph_copy)
# get_runnable_tasks(graph_copy) # Uncomment to debug `get_runnable_tasks`
raise Exception(
"graph is not empty, but not able to get more tasks "
"- something may have gone wrong when retrieving the results "
"of predecessor tasks. This could be caused by a file-system "
"error or a bug in the internal workflow logic, but is likely "
"to be caused by the hash of an upstream node being unstable."
" \n\nHash instability can be caused by an input of the node being "
"modified in place, or by psuedo-random ordering of `set` or "
"`frozenset` inputs (or nested attributes of inputs) in the hash "
"calculation. To ensure that sets are hashed consistently you can "
"you can try set the environment variable PYTHONHASHSEED=0 for "
"all processes, but it is best to try to identify where the set "
"objects are occurring and manually hash their sorted elements. "
"(or use list objects instead)"
"\n\nBlocked tasks\n-------------\n" + "\n".join(blocked)
msg = (
f"Graph of '{wf}' workflow is not empty, but not able to get "
"more tasks - something has gone wrong when retrieving the "
"results predecessors:\n\n"
)
# Get blocked tasks and the predecessors they are waiting on
outstanding = {
t: [
p for p in graph_copy.predecessors[t.name] if not p.done
]
for t in graph_copy.sorted_nodes
}

hashes_have_changed = False
for task, waiting_on in outstanding.items():
if not waiting_on:
continue
msg += f"- '{task.name}' node blocked due to\n"
for pred in waiting_on:
if (
pred.checksum
!= wf.inputs._graph_checksums[pred.name]
):
msg += (
f" - hash changes in '{pred.name}' node inputs. "
f"Current values and hashes: {pred.inputs}, "
f"{pred.inputs._hashes}\n"
)
hashes_have_changed = True
elif pred not in outstanding:
msg += (
f" - undiagnosed issues in '{pred.name}' node, "
"potentially related to file-system access issues "
)
msg += "\n"
if hashes_have_changed:
msg += (
"Set loglevel to 'debug' in order to track hash changes "
"throughout the execution of the workflow.\n\n "
"These issues may have been caused by `bytes_repr()` methods "
"that don't return stable hash values for specific object "
"types across multiple processes (see bytes_repr() "
'"singledispatch "function in pydra/utils/hash.py).'
"You may need to implement a specific `bytes_repr()` "
'"singledispatch overload"s or `__bytes_repr__()` '
"dunder methods to handle one or more types in "
"your interface inputs."
)
raise RuntimeError(msg)
for task in tasks:
# grab inputs if needed
logger.debug(f"Retrieving inputs for {task}")
4 changes: 2 additions & 2 deletions pydra/engine/task.py
@@ -337,8 +337,8 @@ def command_args(self, root=None):
raise NotImplementedError

modified_inputs = template_update(self.inputs, output_dir=self.output_dir)
if modified_inputs is not None:
self.inputs = attr.evolve(self.inputs, **modified_inputs)
for field_name, field_value in modified_inputs.items():
setattr(self.inputs, field_name, field_value)

pos_args = [] # list for (position, command arg)
self._positions_provided = []
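Same pattern as in core.py above: updating fields with `setattr` mutates the existing inputs instance instead of swapping in the new object that `attr.evolve` returns, presumably so state held on the instance (e.g. the `_hashes` cache used by `hash_changes()`) is not discarded during template resolution. A small illustration with a hypothetical spec:

import attr

@attr.define
class Inputs:                                  # hypothetical spec
    out_file: str = "out.txt"

inputs = Inputs()
assert attr.evolve(inputs, out_file="resolved.txt") is not inputs  # new object

for name, value in {"out_file": "resolved.txt"}.items():
    setattr(inputs, name, value)               # same object, updated in place
assert inputs.out_file == "resolved.txt"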