Skip to content

Commit

Permalink
Rename obs_dim to dim
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzi committed Jul 13, 2023
1 parent 91d1df7 commit 4e66221
Showing 1 changed file with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ def iterate() -> Generator[Tuple[npt.NDArray[np.int64], Any], None, None]:
with futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
for arrow_tbl in _EagerIterator(query.X(layer).tables(), pool=pool):
if axis == 1:
obs_dim = indexer.by_obs(arrow_tbl["soma_dim_0"])
dim = indexer.by_obs(arrow_tbl["soma_dim_0"])
else:
obs_dim = indexer.by_var(arrow_tbl["soma_dim_1"])
dim = indexer.by_var(arrow_tbl["soma_dim_1"])
data = arrow_tbl["soma_data"].to_numpy()
yield obs_dim, data
yield dim, data

joinids = query.obs_joinids() if axis == 1 else query.var_joinids()

Expand All @@ -88,16 +88,16 @@ def iterate() -> Generator[Tuple[npt.NDArray[np.int64], Any], None, None]:

if calculate_variance:
mvn = MeanVarianceAccumulator(n_batches, n_samples, n_dim_0)
for obs_dim, data in iterate():
mvn.update_single_batch(obs_dim, data)
for dim, data in iterate():
mvn.update_single_batch(dim, data)
_, _, all_u, all_var = mvn.finalize()
if calculate_mean:
result["mean"] = all_u
result["variance"] = all_var
else:
mn = MeanAccumulator(n_dim_1, n_dim_0)
for obs_dim, data in iterate():
mn.update(obs_dim, data)
for dim, data in iterate():
mn.update(dim, data)
all_u = mn.finalize()
result["mean"] = all_u

Expand Down

0 comments on commit 4e66221

Please sign in to comment.