[builder] ongoing refactoring for Dask (#1040)
* first cut at fixed budget anndata handling

* memory

* refactor consolidate

* checkpoint refactoring for memory budget

* always have at least one worker

* smaller strides

* improve memory diagnostics

* autoupdate precommit modules

* fix bug in no-consolidate

* update test to match new manifest field requirements

* remove unused code

* further memory budget refinement and tuning

* add missing __len__ to AnnDataProxy

* further memory usage reduction

* preserve column ordering in dataframe loading

* comments and cleanup

* add extra verbose logging level

* back out parallel consolidation for now

* added a todo reminder

* a few more memory tuning tweaks

* simplify open_anndata interface

* pr review

* clean up logger

* lint

* snapshot initial dask explorations

* pr feedback

* additional dask refactoring

* fix empty slice bug

* additional refactoring to use dask

* refine async consolidator

* checkpoint progress

* additional X layer processing refinement

* fix pytest

* fix mocks in test

* update package deps for builder

* comment

* improve dataset shuffle

* tuning

* update to latest tiledb

* update to latest tiledb

* cleanup

* additional scale updates

* fix numpy cast error

* shorten step count for async consolidator

* additional cleanup

* update to latest cellxgene_census

* update tiledbsoma dep

* lint

* tune thread count cap

* update to latest tiledbsoma

* lint

* remove debugging code

* checkpoint partial refactoring

* second checkpoint

* clean up logging

* add docstring

* third checkpoint

* further refinement of validation refactoring

* dep update and cleanup

* lint

* fix builder test

* additional cleanup of heartbeat and exit hang

* cleanup minor detritus from refactoring

* additional cleanup

* fix tests and missing parameterization

* remove dead code

* remove dead code
Bruce Martin authored Mar 12, 2024
1 parent b72dac1 commit 44c01f9
Showing 17 changed files with 944 additions and 1,114 deletions.
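The hunks below repeatedly call a `create_dask_client` helper with `threads_per_worker=1` and `memory_limit=0`, and later resize the cluster explicitly. That helper lives in `.mp` and is not part of the diff shown here, so the following is only a hedged sketch of what such a helper typically does, assuming it wraps `dask.distributed.LocalCluster`; the name `make_local_client` is hypothetical. Passing `memory_limit=0` disables the per-worker memory limit in `distributed`, leaving the builder's own fixed memory budget (see the build hunk below) in charge of sizing the cluster.

```python
# Minimal sketch of a create_dask_client-style helper (assumption: the real helper in the
# builder's .mp module wraps dask.distributed roughly like this; it is not shown in this diff).
from dask.distributed import Client, LocalCluster


def make_local_client(n_workers: int, threads_per_worker: int = 1, memory_limit: int | str = 0) -> Client:
    """Start a process-based local cluster; memory_limit=0 disables per-worker memory limits."""
    cluster = LocalCluster(
        n_workers=n_workers,
        threads_per_worker=threads_per_worker,
        processes=True,  # one process per worker, matching the threads_per_worker=1 usage below
        memory_limit=memory_limit,
    )
    return Client(cluster)


if __name__ == "__main__":
    client = make_local_client(n_workers=4)
    try:
        print(client.dashboard_link)  # the cluster can later be resized with client.cluster.scale(n)
    finally:
        client.shutdown()
```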
12 changes: 6 additions & 6 deletions tools/cellxgene_census_builder/pyproject.toml
@@ -25,14 +25,14 @@ classifiers = [
"Programming Language :: Python :: 3.11",
]
dependencies= [
"typing_extensions==4.9.0",
"typing_extensions==4.10.0",
"pyarrow==15.0.0",
"pandas[performance]==2.2.0",
"pandas[performance]==2.2.1",
"anndata==0.10.5.post1",
"numpy==1.24.4",
# IMPORTANT: consider TileDB format compat before advancing this version. It is important that
# IMPORTANT: the tiledbsoma version lag the version used in the cellxgene-census package.
"tiledbsoma==1.7.2",
"tiledbsoma==1.7.3",
"cellxgene-census==1.10.2",
"scipy==1.12.0",
"fsspec[http]==2024.2.0",
@@ -42,13 +42,13 @@ dependencies= [
"Cython", # required by owlready2
"wheel", # required by owlready2
"owlready2==0.44",
"gitpython==3.1.41",
"gitpython==3.1.42",
"attrs==23.2.0",
"psutil==5.9.8",
"pyyaml==6.0.1",
"numba==0.58.1",
"dask==2024.2.0",
"distributed==2024.2.0",
"dask==2024.2.1",
"distributed==2024.2.1",
]

[project.urls]
@@ -10,7 +10,8 @@

from .build_soma import build as build_a_soma
from .build_state import CENSUS_BUILD_CONFIG, CENSUS_BUILD_STATE, CensusBuildArgs, CensusBuildConfig, CensusBuildState
from .util import log_process_resource_status, process_init, start_resource_logger, urlcat
from .process_init import process_init
from .util import log_process_resource_status, start_resource_logger, urlcat

logger = logging.getLogger(__name__)

@@ -9,7 +9,8 @@
import attrs

from ..build_state import CensusBuildArgs, CensusBuildConfig
from ..util import log_process_resource_status, process_init, start_resource_logger
from ..process_init import process_init
from ..util import log_process_resource_status, start_resource_logger

logger = logging.getLogger(__name__)

@@ -48,6 +49,7 @@ def main() -> int:
cc = validate(args)

log_process_resource_status(level=logging.INFO)
logger.info("Fini")
return cc


@@ -4,14 +4,15 @@
from datetime import UTC, datetime
from typing import cast

import dask.distributed
import pandas as pd
import psutil
import tiledbsoma as soma

from ..build_state import CensusBuildArgs
from ..util import cpu_count
from ..util import clamp, cpu_count
from .census_summary import create_census_summary
from .consolidate import consolidate, start_async_consolidation, stop_async_consolidation
from .consolidate import submit_consolidate
from .datasets import Dataset, assign_dataset_soma_joinids, create_dataset_manifest
from .experiment_builder import (
ExperimentBuilder,
@@ -27,11 +28,11 @@
SOMA_TileDB_Context,
)
from .manifest import load_manifest
from .mp import create_dask_client
from .mp import create_dask_client, shutdown_dask_cluster
from .source_assets import stage_source_assets
from .summary_cell_counts import create_census_summary_cell_counts
from .util import get_git_commit_sha, is_git_repo_dirty, shuffle
from .validate_soma import validate as go_validate
from .util import get_git_commit_sha, is_git_repo_dirty
from .validate_soma import validate_consolidation, validate_soma

logger = logging.getLogger(__name__)

@@ -72,34 +73,30 @@ def build(args: CensusBuildArgs, *, validate: bool = True) -> int:

prepare_file_system(args)

with create_dask_client(args, n_workers=cpu_count(), threads_per_worker=2):
# Step 1 - get all source datasets
datasets = build_step1_get_source_datasets(args)
try:
with create_dask_client(args, n_workers=cpu_count(), threads_per_worker=1, memory_limit=0) as client:
# Step 1 - get all source datasets
datasets = build_step1_get_source_datasets(args)

# Step 2 - create root collection, and all child objects, but do not populate any dataframes or matrices
root_collection = build_step2_create_root_collection(args.soma_path.as_posix(), experiment_builders)
# Step 2 - create root collection, and all child objects, but do not populate any dataframes or matrices
root_collection = build_step2_create_root_collection(args.soma_path.as_posix(), experiment_builders)

# Step 3 - populate axes
filtered_datasets = build_step3_populate_obs_and_var_axes(
args.h5ads_path.as_posix(), datasets, experiment_builders, args
)
# Step 3 - populate axes
filtered_datasets = build_step3_populate_obs_and_var_axes(
args.h5ads_path.as_posix(), datasets, experiment_builders, args
)

# Constraining parallelism is critical at this step, as each worker utilizes ~128GiB+ of memory to
# process the X array (partitions are large to reduce TileDB fragment count).
#
# TODO: when global order writes are supported, processing of much smaller slices will be
# possible, and this budget should drop considerably. When that is implemented, n_workers should
# be much larger (e.g., use the default value of #CPUs or some such).
# https://github.com/single-cell-data/TileDB-SOMA/issues/2054
MEM_BUDGET = 128 * 1024**3
total_memory = psutil.virtual_memory().total
n_workers = max(1, int(total_memory // MEM_BUDGET) - 1) # reserve one for main thread
with create_dask_client(args, n_workers=n_workers, threads_per_worker=1, memory_limit=0):
try:
if args.config.consolidate:
consolidator = start_async_consolidation(root_collection.uri)
else:
consolidator = None
# Constraining parallelism is critical at this step, as each worker utilizes (max) ~64GiB+ of memory to
# process the X array (partitions are large to reduce TileDB fragment count, which reduces consolidation time).
#
# TODO: when global order writes are supported, processing of much smaller slices will be
# possible, and this budget should drop considerably. When that is implemented, n_workers should
# be much larger (e.g., use the default value of #CPUs or some such).
# https://github.com/single-cell-data/TileDB-SOMA/issues/2054
MEM_BUDGET = 64 * 1024**3
n_workers = clamp(int(psutil.virtual_memory().total // MEM_BUDGET), 1, args.config.max_worker_processes)
logger.info(f"Scaling cluster to {n_workers} workers.")
client.cluster.scale(n_workers)

# Step 4 - populate X layers
build_step4_populate_X_layers(args.h5ads_path.as_posix(), filtered_datasets, experiment_builders, args)
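The comment block above explains the worker-count arithmetic: each X-layer worker is budgeted roughly 64 GiB, so the cluster is scaled to total RAM divided by that budget, clamped to the configured maximum. A minimal sketch of that sizing follows, assuming `clamp()` behaves like the helper imported from `..util` and that `max_worker_processes` comes from the build config; `workers_for_budget` is a hypothetical wrapper, not the builder's API.

```python
# Sketch of the memory-budgeted worker sizing described in the comments above.
import psutil


def clamp(value: int, low: int, high: int) -> int:
    """Constrain value to the inclusive range [low, high] (assumed ..util.clamp semantics)."""
    return max(low, min(value, high))


MEM_BUDGET = 64 * 1024**3  # ~64 GiB reserved per worker while writing X layers


def workers_for_budget(max_worker_processes: int) -> int:
    total = psutil.virtual_memory().total
    return clamp(int(total // MEM_BUDGET), 1, max_worker_processes)


# Example: a 512 GiB host with max_worker_processes=16 yields 512 // 64 = 8 workers.
print(workers_for_budget(16))
```

Once the memory-hungry X phase is done, the next hunk scales the cluster back up toward `cpu_count()` workers for consolidation and validation.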
@@ -112,22 +109,31 @@ def build(args: CensusBuildArgs, *, validate: bool = True) -> int:
root_collection, experiment_builders, filtered_datasets, args.config.build_tag
)

finally:
if consolidator:
stop_async_consolidation(consolidator) # blocks until any running consolidate finishes
del consolidator

# Temporary work-around. Can be removed when single-cell-data/TileDB-SOMA#1969 fixed.
tiledb_soma_1969_work_around(root_collection.uri)
# Temporary work-around. Can be removed when single-cell-data/TileDB-SOMA#1969 fixed.
tiledb_soma_1969_work_around(root_collection.uri)

# TODO: consolidation and validation can be done in parallel. Goal: do this work
# when refactoring validation to use Dask.
# Scale the cluster up as we are no longer memory constrained in the following phases
n_workers = clamp(cpu_count(), 1, args.config.max_worker_processes)
logger.info(f"Scaling cluster to {n_workers} workers.")
client.cluster.scale(n=n_workers)

if args.config.consolidate:
consolidate(args, root_collection.uri)

if validate:
go_validate(args)
if args.config.consolidate:
for f in dask.distributed.as_completed(
submit_consolidate(root_collection.uri, pool=client, vacuum=True)
):
assert f.result()
if validate:
for f in dask.distributed.as_completed(validate_soma(args, client)):
assert f.result()
if args.config.consolidate and validate:
validate_consolidation(args)
logger.info("Validation & consolidation complete.")

shutdown_dask_cluster(client)

except TimeoutError:
# quiet tornado race conditions (harmless) on shutdown
pass

return 0
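The rewritten block above fans consolidation and validation out to the cluster and drains the resulting futures with `dask.distributed.as_completed`, asserting each result. Below is a minimal sketch of that submit-and-drain pattern; `do_step` and the example URIs are hypothetical stand-ins for work such as consolidating or validating one SOMA object (the real `submit_consolidate`/`validate_soma` signatures are only partially visible in this diff).

```python
# Sketch of the submit-and-drain pattern used above; do_step and the URIs are placeholders,
# not the builder's real task functions.
import dask.distributed


def do_step(uri: str) -> bool:
    # ... consolidate or validate one object at `uri` ...
    return True


def run_all(client: dask.distributed.Client, uris: list[str]) -> None:
    futures = [client.submit(do_step, uri) for uri in uris]
    for future in dask.distributed.as_completed(futures):
        # Surface the first failed step rather than silently continuing.
        assert future.result()


if __name__ == "__main__":
    with dask.distributed.Client(processes=False) as client:
        run_all(client, ["example://array-a", "example://array-b"])
```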

@@ -172,24 +178,18 @@ def build_step1_get_source_datasets(args: CensusBuildArgs) -> list[Dataset]:
logger.error("No H5AD files in the manifest (or we can't find the files)")
raise RuntimeError("No H5AD files in the manifest (or we can't find the files)")

# sort encourages (does not guarantee) largest files processed first
datasets = sorted(all_datasets, key=lambda d: d.asset_h5ad_filesize)

# Testing/debugging hook - hidden option
if args.config.test_first_n is not None and abs(args.config.test_first_n) > 0:
datasets = sorted(all_datasets, key=lambda d: d.asset_h5ad_filesize)
if args.config.test_first_n > 0:
# Process the N smallest datasets
datasets = datasets[: args.config.test_first_n]
else:
# Process the N largest datasets
datasets = datasets[args.config.test_first_n :]

else:
# Shuffle datasets by size.
# TODO: it is unclear if this shuffle has material impact. Needs more benchmarking.
# Nothing magical about a step of 16, other than the observation that the dataset
# distribution is disproportionately populated by small datasets.
datasets = shuffle(sorted(all_datasets, key=lambda d: d.asset_h5ad_filesize), step=16)
assert {d.dataset_id for d in all_datasets} == {d.dataset_id for d in datasets}

# Stage all files
stage_source_assets(datasets, args)
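For the hidden `test_first_n` hook in the hunk above, the dataset list is already sorted ascending by H5AD file size, so a positive value keeps the N smallest datasets and a negative value keeps the |N| largest. A tiny worked example of that slicing, with illustrative sizes:

```python
# Worked example of the test_first_n slicing on sizes sorted ascending (illustrative values).
sizes = [1, 2, 5, 40, 900]

test_first_n = 2
print(sizes[:test_first_n])  # [1, 2]  -> the 2 smallest datasets

test_first_n = -2
print(sizes[test_first_n:])  # [40, 900]  -> the 2 largest datasets
```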
