Set shuffle as default in pytorch, use new algorithm
ebezzi committed Jun 6, 2024
1 parent 5fc72e2 commit 95d8217
Showing 2 changed files with 26 additions and 20 deletions.
@@ -413,12 +413,12 @@ def __init__(
  var_query: Optional[soma.AxisQuery] = None,
  obs_column_names: Sequence[str] = (),
  batch_size: int = 1,
- shuffle: bool = False,
+ shuffle: bool = True,
  seed: Optional[int] = None,
  return_sparse_X: bool = False,
- soma_chunk_size: Optional[int] = None,
+ soma_chunk_size: Optional[int] = 64,
  use_eager_fetch: bool = True,
- shuffle_chunk_count: Optional[int] = None,
+ shuffle_chunk_count: Optional[int] = 2000,
  ) -> None:
  r"""Construct a new ``ExperimentDataPipe``.
@@ -443,18 +443,14 @@ def __init__(
      ``1`` will result in :class:`torch.Tensor` of rank 1 being returned (a single row); larger values will
      result in :class:`torch.Tensor`\ s of rank 2 (multiple rows).
  shuffle:
-     Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``False`` (no shuffling).
-     For performance reasons, shuffling is performed in two steps: 1) a global shuffling, where contiguous
-     rows are grouped into chunks and the order of the chunks is randomized, and then 2) a local
-     shuffling, where the rows within each chunk are shuffled. Since this class must retrieve data
-     in chunks (to keep memory requirements to a fixed size), global shuffling ensures that a given row in
-     the shuffled result can originate from any position in the non-shuffled result ordering. If shuffling
-     only occurred within each chunk (i.e. "local" shuffling), the first chunk's rows would always be
-     returned first, the second chunk's rows would always be returned second, and so on. The chunk size is
-     determined by the ``soma_chunk_size`` parameter. Note that rows within a chunk will maintain
-     proximity, even after shuffling, so some experimentation may be required to ensure the shuffling is
-     sufficient for the model training process. To this end, the ``soma_chunk_size`` can be treated as a
-     hyperparameter that can be tuned.
+     Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``.
+     For performance reasons, shuffling is not performed globally across all rows, but rather in chunks.
+     More specifically, we select ``shuffle_chunk_count`` non-contiguous chunks across the dataset,
+     concatenate them, and shuffle the resulting array.
+     The randomness of the shuffling is therefore determined by the
+     (``soma_chunk_size``, ``shuffle_chunk_count``) selection. The default values have been determined
+     to yield a good trade-off between randomness and performance. Further tuning may be required for
+     different types of models.
  seed:
      The random seed used for shuffling. Defaults to ``None`` (no seed). This *must* be specified when using
      :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker
@@ -468,10 +464,8 @@ def __init__(
      The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts two aspects of
      this class's behavior: 1) The maximum memory utilization, with larger values providing
      better read performance, but also requiring more memory; 2) The granularity of the global shuffling
-     step (see ``shuffle`` parameter for details). If not specified, the value is set to utilize ~1 GiB of
-     RAM per SOMA chunk read, based upon the number of ``var`` columns (cells/features) being requested
-     and assuming X data sparsity of 95%; the number of rows per chunk will depend on the number of
-     ``var`` columns being read.
+     step (see ``shuffle`` parameter for details). The default value of 64 works well in conjunction
+     with the default ``shuffle_chunk_count`` value.
  use_eager_fetch:
      Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made
      available for processing via the iterator. This allows network (or filesystem) requests to be made in
@@ -480,6 +474,7 @@ def __init__(
  shuffle_chunk_count:
      The number of contiguous blocks (chunks) of rows sampled to then concatenate and shuffle.
      Larger numbers correspond to more randomness per training batch.
+     If ``shuffle == False``, this parameter is ignored. Defaults to ``2000``.

  Lifecycle:
      experimental
@@ -499,7 +494,7 @@ def __init__(
  self._encoders = None
  self._obs_joinids = None
  self._var_joinids = None
- self._shuffle_chunk_count = (shuffle_chunk_count or 1) if shuffle else None
+ self._shuffle_chunk_count = shuffle_chunk_count if shuffle else None
  self._shuffle_rng = np.random.default_rng(seed) if shuffle else None
  self._initialized = False
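The chunk-based shuffle the new docstring describes is easy to picture in isolation. The sketch below is a minimal, self-contained illustration, not the library's implementation: the helper name `chunked_shuffle` and its standalone signature are hypothetical, while `soma_chunk_size`, `shuffle_chunk_count`, and `seed` mirror the documented parameters.

```python
from typing import Optional

import numpy as np


def chunked_shuffle(
    obs_joinids: np.ndarray,
    soma_chunk_size: int = 64,
    shuffle_chunk_count: int = 2000,
    seed: Optional[int] = None,
) -> np.ndarray:
    """Hypothetical standalone sketch of the chunk-based shuffle."""
    rng = np.random.default_rng(seed)

    # 1) Split the row ids into contiguous chunks of `soma_chunk_size` rows.
    chunks = [
        obs_joinids[i : i + soma_chunk_size]
        for i in range(0, len(obs_joinids), soma_chunk_size)
    ]

    # 2) Shuffle the chunk order, so each group of `shuffle_chunk_count`
    #    chunks is drawn non-contiguously from across the whole dataset.
    rng.shuffle(chunks)

    # 3) Concatenate each group of chunks and shuffle the rows within it.
    shuffled_groups = []
    for i in range(0, len(chunks), shuffle_chunk_count):
        group = np.concatenate(chunks[i : i + shuffle_chunk_count])
        rng.shuffle(group)
        shuffled_groups.append(group)

    return np.concatenate(shuffled_groups)
```

With the new defaults, each shuffle group spans 64 × 2000 = 128,000 rows drawn from all over the dataset, which is the randomness/performance trade-off the docstring refers to; it also explains why the `(shuffle_chunk_count or 1)` fallback in `__init__` is no longer needed.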
11 changes: 11 additions & 0 deletions api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py
@@ -142,6 +142,7 @@ def test_non_batched(soma_experiment: Experiment, use_eager_fetch: bool) -> None
measurement_name="RNA",
X_name="raw",
obs_column_names=["label"],
shuffle=False,
use_eager_fetch=use_eager_fetch,
)
row_iter = iter(exp_data_pipe)
@@ -164,6 +165,7 @@ def test_batching__all_batches_full_size(soma_experiment: Experiment, use_eager_
X_name="raw",
obs_column_names=["label"],
batch_size=3,
shuffle=False,
use_eager_fetch=use_eager_fetch,
)
batch_iter = iter(exp_data_pipe)
@@ -214,6 +216,7 @@ def test_batching__partial_final_batch_size(soma_experiment: Experiment, use_eag
X_name="raw",
obs_column_names=["label"],
batch_size=3,
shuffle=False,
use_eager_fetch=use_eager_fetch,
)
batch_iter = iter(exp_data_pipe)
@@ -239,6 +242,7 @@ def test_batching__exactly_one_batch(soma_experiment: Experiment, use_eager_fetc
X_name="raw",
obs_column_names=["label"],
batch_size=3,
shuffle=False,
use_eager_fetch=use_eager_fetch,
)
batch_iter = iter(exp_data_pipe)
@@ -286,6 +290,7 @@ def test_sparse_output__non_batched(soma_experiment: Experiment, use_eager_fetch
X_name="raw",
obs_column_names=["label"],
return_sparse_X=True,
shuffle=False,
use_eager_fetch=use_eager_fetch,
)
batch_iter = iter(exp_data_pipe)
@@ -309,6 +314,7 @@ def test_sparse_output__batched(soma_experiment: Experiment, use_eager_fetch: bo
obs_column_names=["label"],
batch_size=3,
return_sparse_X=True,
shuffle=False,
use_eager_fetch=use_eager_fetch,
)
batch_iter = iter(exp_data_pipe)
@@ -350,6 +356,7 @@ def test_encoders(soma_experiment: Experiment) -> None:
measurement_name="RNA",
X_name="raw",
obs_column_names=["label"],
shuffle=False,
batch_size=3,
)
batch_iter = iter(exp_data_pipe)
@@ -413,6 +420,7 @@ def test_distributed__returns_data_partition_for_rank(
X_name="raw",
obs_column_names=["label"],
soma_chunk_size=2,
shuffle=False,
)
full_result = list(iter(dp))

@@ -451,6 +459,7 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank(
X_name="raw",
obs_column_names=["label"],
soma_chunk_size=2,
shuffle=False,
)

full_result = list(iter(dp))
@@ -475,6 +484,7 @@ def test_experiment_dataloader__non_batched(soma_experiment: Experiment, use_eag
measurement_name="RNA",
X_name="raw",
obs_column_names=["label"],
shuffle=False,
use_eager_fetch=use_eager_fetch,
)
dl = experiment_dataloader(dp)
@@ -498,6 +508,7 @@ def test_experiment_dataloader__batched(soma_experiment: Experiment, use_eager_f
X_name="raw",
obs_column_names=["label"],
batch_size=3,
shuffle=False,
use_eager_fetch=use_eager_fetch,
)
dl = experiment_dataloader(dp)
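The behavioral change is easiest to see from the caller's side. Below is a hedged usage sketch assuming the public API exercised by the tests above (`ExperimentDataPipe` and `experiment_dataloader` from `cellxgene_census.experimental.ml`) and an already-open SOMA `Experiment`; the `cell_type` obs column is illustrative.

```python
from cellxgene_census.experimental.ml import ExperimentDataPipe, experiment_dataloader

# `experiment` is assumed to be an open SOMA Experiment, e.g. a Census
# organism such as census["census_data"]["homo_sapiens"].

# Shuffling is now on by default (shuffle=True, soma_chunk_size=64,
# shuffle_chunk_count=2000); pass a seed for reproducible runs.
dp = ExperimentDataPipe(
    experiment,
    measurement_name="RNA",
    X_name="raw",
    obs_column_names=["cell_type"],  # illustrative column name
    batch_size=3,
    seed=42,
)

# The updated tests opt out explicitly to keep row order deterministic.
dp_ordered = ExperimentDataPipe(
    experiment,
    measurement_name="RNA",
    X_name="raw",
    obs_column_names=["cell_type"],
    batch_size=3,
    shuffle=False,
)

dl = experiment_dataloader(dp)
for X_batch, obs_batch in dl:
    ...  # train on the (X, obs) tensors
```

This is why every test that asserts on row order now passes `shuffle=False`: with the new default, ordering assertions would otherwise fail nondeterministically.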
