Skip to content

Commit

Permalink
Merge branch 'pytorch-shuffle-multiple-chunks' of github.com:chanzuck…
Browse files Browse the repository at this point in the history
…erberg/cell-census into pytorch-shuffle-multiple-chunks
  • Loading branch information
ebezzi committed Jun 3, 2024
2 parents 5a22e2f + 2e5cdda commit a97de8a
Showing 1 changed file with 26 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,17 @@ def __init__(
self.obs_column_names = obs_column_names
if shuffle_chunk_count:
assert shuffle_rng is not None
chunk_count = len(obs_joinids_chunked)
grouped_chunks_count = chunk_count // min(chunk_count, shuffle_chunk_count)

# At the start of this step, `obs_joinids_chunked` is a list of one dimensional
# numpy arrays. Each numpy array corresponds to a chunk of contiguous rows in `obs`.
# Critically, `obs_joinids_chunked` is randomly ordered where each chunk is
# from a random section of `obs`.
# We then take `shuffle_chunk_count` of these in order, concatenate them into
# a larger numpy array and shuffle this larger numpy array.
# The result is again a list of numpy arrays.
self.obs_joinids_chunks_iter = (
shuffle_rng.permutation(np.concatenate(grouped_chunks))
for grouped_chunks in np.array_split(obs_joinids_chunked, grouped_chunks_count)
for grouped_chunks in list_split(obs_joinids_chunked, shuffle_chunk_count)
)
else:
self.obs_joinids_chunks_iter = iter(obs_joinids_chunked)
Expand Down Expand Up @@ -185,6 +191,21 @@ def __next__(self) -> _SOMAChunk:
return _SOMAChunk(obs=obs_batch, X=X_batch, stats=stats)


def list_split(arr_list: List[Any], sublist_len: int) -> List[List[Any]]:
"""Splits a python list into a list of sublists where each sublist is of size `sublist_len`."""
i = 0
result = []
while i < len(arr_list):
if (i + sublist_len) >= len(arr_list):
result.append(arr_list[i:])
else:
result.append(arr_list[i : i + sublist_len])

i += sublist_len

return result


def run_gc() -> Tuple[Tuple[Any, Any, Any], Tuple[Any, Any, Any]]: # noqa: D103
proc = psutil.Process(os.getpid())

Expand Down Expand Up @@ -455,7 +476,8 @@ def __init__(
parallel with client-side processing of the SOMA data, potentially improving overall performance at the
cost of doubling memory utilization. Defaults to ``True``.
shuffle_chunk_count:
TODO
The number contiguous blocks (chunks) of rows to read at random and then concatenated and shuffled.
Larger number for `shuffle_chunk_count` correspond to more randomness in the shuffling.
Lifecycle:
experimental
Expand Down

0 comments on commit a97de8a

Please sign in to comment.