Skip to content

Commit

Permalink
Refactor and fix connected component tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Jan 7, 2022
1 parent 2fa2ede commit 134f6a7
Show file tree
Hide file tree
Showing 14 changed files with 324 additions and 568 deletions.
8 changes: 3 additions & 5 deletions cluster_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from .workflows import MulticutSegmentationWorkflow
from .workflows import LiftedMulticutSegmentationWorkflow
from .workflows import AgglomerativeClusteringWorkflow
from .workflows import SimpleStitchingWorkflow
from .thresholded_components import ThresholdedComponentsWorkflow, ThresholdAndWatershedWorkflow
from .workflows import (AgglomerativeClusteringWorkflow, LiftedMulticutSegmentationWorkflow,
MulticutSegmentationWorkflow, SimpleStitchingWorkflow)
from .connected_components import ConnectedComponentsWorkflow, ConnectedComponentsAndWatershedWorkflow

from .version import __version__
1 change: 1 addition & 0 deletions cluster_tools/connected_components/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .connected_components_workflow import ConnectedComponentsWorkflow, ConnectedComponentsAndWatershedWorkflow
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
# Block-wise connected components tasks
#

class BlockComponentsBase(luigi.Task):
""" BlockComponents base class
class ConnectedComponentBlocksBase(luigi.Task):
""" ConnectedComponentBlocks base class
"""

task_name = 'block_components'
task_name = "connected_component_blocks"
src_file = os.path.abspath(__file__)
allow_retry = False

Expand All @@ -32,19 +32,19 @@ class BlockComponentsBase(luigi.Task):
output_key = luigi.Parameter()
# task that is required before running this task
dependency = luigi.TaskParameter()
threshold = luigi.FloatParameter()
threshold_mode = luigi.Parameter(default='greater')
mask_path = luigi.Parameter(default='')
mask_key = luigi.Parameter(default='')
threshold = luigi.FloatParameter(default=None)
threshold_mode = luigi.Parameter(default=None)
mask_path = luigi.Parameter(default="")
mask_key = luigi.Parameter(default="")
channel = luigi.Parameter(default=None)

threshold_modes = ('greater', 'less', 'equal', 'none')
threshold_modes = ("greater", "less", "equal", None)

@staticmethod
def default_task_config():
# we use this to get also get the common default config
config = LocalTask.default_task_config()
config.update({'sigma_prefilter': 0})
config.update({"sigma_prefilter": 0})
return config

def requires(self):
Expand All @@ -59,23 +59,22 @@ def run_impl(self):

assert self.threshold_mode in self.threshold_modes
config = self.get_task_config()
config.update({'input_path': self.input_path,
'input_key': self.input_key,
'output_path': self.output_path,
'output_key': self.output_key,
'block_shape': block_shape,
'tmp_folder': self.tmp_folder,
'threshold': self.threshold,
'threshold_mode': self.threshold_mode})
config.update({"input_path": self.input_path,
"input_key": self.input_key,
"output_path": self.output_path,
"output_key": self.output_key,
"block_shape": block_shape,
"threshold": self.threshold,
"threshold_mode": self.threshold_mode,
"tmp_folder": self.tmp_folder})

# check if we have a mask and add to the config if we do
if self.mask_path != '':
assert self.mask_key != ''
config.update({'mask_path': self.mask_path,
'mask_key': self.mask_key})
if self.mask_path != "":
assert self.mask_key != ""
config.update({"mask_path": self.mask_path, "mask_key": self.mask_key})

# get chunks
chunks = config.pop('chunks', None)
chunks = config.pop("chunks", None)
if chunks is None:
chunks = tuple(bs // 2 for bs in block_shape)

Expand All @@ -94,16 +93,15 @@ def run_impl(self):
assert all(isinstance(chan, int) for chan in self.channel)
assert shape[0] > max(self.channel), "%i, %i" % (shape[0], max(self.channel))
shape = shape[1:]
config.update({'channel': self.channel})
config.update({"channel": self.channel})

# clip chunks
chunks = tuple(min(ch, sh) for ch, sh in zip(chunks, shape))

# make output dataset
compression = config.pop('compression', 'gzip')
with vu.file_reader(self.output_path) as f:
f.require_dataset(self.output_key, shape=shape, dtype='uint64',
compression=compression, chunks=chunks)
compression = config.pop("compression", "gzip")
with vu.file_reader(self.output_path, "a") as f:
f.require_dataset(self.output_key, shape=shape, dtype="uint64", compression=compression, chunks=chunks)

block_list = vu.blocks_in_volume(shape, block_shape,
roi_begin, roi_end)
Expand All @@ -119,23 +117,23 @@ def run_impl(self):
self.check_jobs(n_jobs)


class BlockComponentsLocal(BlockComponentsBase, LocalTask):
class ConnectedComponentBlocksLocal(ConnectedComponentBlocksBase, LocalTask):
"""
BlockComponents on local machine
ConnectedComponentsBlocks on local machine
"""
pass


class BlockComponentsSlurm(BlockComponentsBase, SlurmTask):
class ConnectedComponentBlocksSlurm(ConnectedComponentBlocksBase, SlurmTask):
"""
BlockComponents on slurm cluster
ConnectedComponentsBlocks on slurm cluster
"""
pass


class BlockComponentsLSF(BlockComponentsBase, LSFTask):
class ConnectedComponentBlocksLSF(ConnectedComponentBlocksBase, LSFTask):
"""
BlockComponents on lsf cluster
ConnectedComponentsBlocks on lsf cluster
"""
pass

Expand All @@ -155,18 +153,18 @@ def _load_input(ds_in, bb, channel):


def _threshold_impl(input_, threshold, threshold_mode, sigma):
input_ = input_ if threshold_mode == 'none' else vu.normalize(input_)
if sigma > 0 and threshold_mode != 'none':
input_ = vu.apply_filter(input_, 'gaussianSmoothing', sigma)
input_ = input_ if threshold_mode is None else vu.normalize(input_)
if sigma > 0 and threshold_mode is not None:
input_ = vu.apply_filter(input_, "gaussianSmoothing", sigma)
input_ = vu.normalize(input_)

if threshold_mode == 'greater':
if threshold_mode == "greater":
input_ = input_ > threshold
elif threshold_mode == 'less':
elif threshold_mode == "less":
input_ = input_ < threshold
elif threshold_mode == 'equal':
elif threshold_mode == "equal":
input_ = input_ == threshold
elif threshold_mode == 'none':
elif threshold_mode is None:
pass
else:
raise RuntimeError("Thresholding Mode %s not supported" % threshold_mode)
Expand All @@ -175,7 +173,7 @@ def _threshold_impl(input_, threshold, threshold_mode, sigma):

def _cc_block(block_id, blocking,
ds_in, ds_out, threshold,
threshold_mode, channel, sigma):
threshold_mode, channel, sigma, tmp_folder):
fu.log("start processing block %i" % block_id)
block = blocking.getBlock(block_id)
bb = vu.block_to_bb(block)
Expand All @@ -187,21 +185,30 @@ def _cc_block(block_id, blocking,
return 0

components = label(input_)
# add global offset to make ids unique between blocks
offset = block_id * int(np.prod(blocking.blockShape))
assert offset < np.iinfo('uint64').max, "Id overflow"
components[components != 0] += offset

this_ids = np.unique(components)
if not len(this_ids) == 1 and this_ids[0] == 0:
id_path = os.path.join(tmp_folder, f"ids_{block_id}.npy")
np.save(id_path, this_ids)

ds_out[bb] = components
fu.log_block_success(block_id)
return int(components.max()) + 1


def _cc_block_with_mask(block_id, blocking,
ds_in, ds_out, threshold,
threshold_mode, mask,
channel, sigma):
channel, sigma, tmp_folder):
fu.log("start processing block %i" % block_id)
block = blocking.getBlock(block_id)
bb = vu.block_to_bb(block)

# get the mask and check if we have any pixels
in_mask = mask[bb].astype('bool')
in_mask = mask[bb].astype("bool")
if np.sum(in_mask) == 0:
fu.log_block_success(block_id)
return 0
Expand All @@ -214,38 +221,47 @@ def _cc_block_with_mask(block_id, blocking,
return 0

components = label(input_)
# add global offset to make ids unique between blocks
offset = block_id * int(np.prod(blocking.blockShape))
assert offset < np.iinfo('uint64').max, "Id overflow"
components[components != 0] += offset

this_ids = np.unique(components)
if not len(this_ids) == 1 and this_ids[0] == 0:
id_path = os.path.join(tmp_folder, f"ids_{block_id}.npy")
np.save(id_path, this_ids)

ds_out[bb] = components
fu.log_block_success(block_id)
return int(components.max()) + 1


def block_components(job_id, config_path):
def connected_components_block(job_id, config_path):

fu.log("start processing job %i" % job_id)
fu.log("reading config from %s" % config_path)

with open(config_path, 'r') as f:
with open(config_path, "r") as f:
config = json.load(f)
input_path = config['input_path']
input_key = config['input_key']
output_path = config['output_path']
output_key = config['output_key']
block_list = config['block_list']
tmp_folder = config['tmp_folder']
block_shape = config['block_shape']
threshold = config['threshold']
threshold_mode = config['threshold_mode']
input_path = config["input_path"]
input_key = config["input_key"]
output_path = config["output_path"]
output_key = config["output_key"]
block_list = config["block_list"]
block_shape = config["block_shape"]
threshold = config["threshold"]
threshold_mode = config["threshold_mode"]
tmp_folder = config["tmp_folder"]

sigma = config.get('sigma_prefilter', 0)
sigma = config.get("sigma_prefilter", 0)

mask_path = config.get('mask_path', '')
mask_key = config.get('mask_key', '')
mask_path = config.get("mask_path", "")
mask_key = config.get("mask_key", "")

channel = config.get('channel', None)
channel = config.get("channel", None)

fu.log("Applying threshold %f with mode %s" % (threshold, threshold_mode))

with vu.file_reader(input_path, 'r') as f_in, vu.file_reader(output_path) as f_out:
with vu.file_reader(input_path, "r") as f_in, vu.file_reader(output_path) as f_out:

ds_in = f_in[input_key]
ds_out = f_out[output_key]
Expand All @@ -257,28 +273,22 @@ def block_components(job_id, config_path):

blocking = nt.blocking([0, 0, 0], list(shape), block_shape)

if mask_path != '':
if mask_path != "":
mask = vu.load_mask(mask_path, mask_key, shape)
offsets = [_cc_block_with_mask(block_id, blocking,
ds_in, ds_out, threshold,
threshold_mode, mask, channel,
sigma) for block_id in block_list]
for block_id in block_list:
_cc_block_with_mask(block_id, blocking, ds_in, ds_out, threshold,
threshold_mode, mask, channel, sigma, tmp_folder)

else:
offsets = [_cc_block(block_id, blocking,
ds_in, ds_out, threshold,
threshold_mode, channel, sigma) for block_id in block_list]

offset_dict = {block_id: off for block_id, off in zip(block_list, offsets)}
save_path = os.path.join(tmp_folder,
'connected_components_offsets_%i.json' % job_id)
with open(save_path, 'w') as f:
json.dump(offset_dict, f)
for block_id in block_list:
_cc_block(block_id, blocking, ds_in, ds_out, threshold,
threshold_mode, channel, sigma, tmp_folder)

fu.log_job_success(job_id)


if __name__ == '__main__':
if __name__ == "__main__":
path = sys.argv[1]
assert os.path.exists(path), path
job_id = int(os.path.split(path)[1].split('.')[0].split('_')[-1])
block_components(job_id, path)
job_id = int(os.path.split(path)[1].split(".")[0].split("_")[-1])
connected_components_block(job_id, path)
Loading

0 comments on commit 134f6a7

Please sign in to comment.