diff --git a/.github/workflows/cubed.yml b/.github/workflows/cubed.yml index 2b9cde4e8..3b1ab5b3e 100644 --- a/.github/workflows/cubed.yml +++ b/.github/workflows/cubed.yml @@ -30,4 +30,4 @@ jobs: - name: Test with pytest run: | - pytest -v sgkit/tests/test_aggregation.py -k "test_count_call_alleles" --use-cubed + pytest -v sgkit/tests/test_aggregation.py -k 'test_count_call_alleles or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index f76dbe691..d3bd48596 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -183,11 +183,14 @@ def count_variant_alleles( variables.validate(ds, {call_genotype: variables.call_genotype_spec}) n_alleles = ds.sizes["alleles"] n_variant = ds.sizes["variants"] - G = da.asarray(ds[call_genotype]).reshape((n_variant, -1)) + G = da.asarray(ds[call_genotype]) + G = da.reshape(G, (n_variant, -1)) shape = (G.chunks[0], n_alleles) # use uint64 dummy array to return uin64 counts array N = np.empty(n_alleles, dtype=np.uint64) - AC = da.map_blocks(count_alleles, G, N, chunks=shape, drop_axis=1, new_axis=1) + AC = da.map_blocks( + count_alleles, G, N, chunks=shape, dtype=np.uint64, drop_axis=1, new_axis=1 + ) AC = xr.DataArray(AC, dims=["variants", "alleles"]) else: options = {variables.call_genotype, variables.call_allele_count} @@ -692,22 +695,23 @@ def variant_stats( using=variables.call_genotype, # improved performance merge=False, )[variant_allele_count] - G = da.array(ds[call_genotype].data) + G = da.asarray(ds[call_genotype].data) H = xr.DataArray( da.map_blocks( - count_hom, + lambda *args: count_hom(*args)[:, np.newaxis, :], G, np.zeros(3, np.uint64), - drop_axis=(1, 2), - new_axis=1, + drop_axis=2, + new_axis=2, dtype=np.int64, - chunks=(G.chunks[0], 3), + chunks=(G.chunks[0], 1, 3), ), - dims=["variants", "categories"], + dims=["variants", "samples", "categories"], ) + H = H.sum(axis=1) _, n_sample, _ = G.shape n_called = H.sum(axis=-1) - call_rate = n_called / n_sample + call_rate = n_called.astype(float) / float(n_sample) n_hom_ref = H[:, 0] n_hom_alt = H[:, 1] n_het = H[:, 2] @@ -723,7 +727,8 @@ def variant_stats( variables.variant_n_non_ref: n_non_ref, variables.variant_allele_count: AC, variables.variant_allele_total: allele_total, - variables.variant_allele_frequency: AC / allele_total, + variables.variant_allele_frequency: AC.astype(float) + / allele_total.astype(float), } ) # for backwards compatible behavior @@ -798,7 +803,7 @@ def sample_stats( mixed_ploidy = ds[call_genotype].attrs.get("mixed_ploidy", False) if mixed_ploidy: raise ValueError("Mixed-ploidy dataset") - G = da.array(ds[call_genotype].data) + G = da.asarray(ds[call_genotype].data) H = xr.DataArray( da.map_blocks( count_hom, diff --git a/sgkit/tests/test_aggregation.py b/sgkit/tests/test_aggregation.py index e2719c1c1..f67d1e658 100644 --- a/sgkit/tests/test_aggregation.py +++ b/sgkit/tests/test_aggregation.py @@ -144,7 +144,7 @@ def test_count_variant_alleles__chunked(using): chunks={"variants": 5, "samples": 5} ) ac2 = count_variant_alleles(ds, using=using) - assert isinstance(ac2["variant_allele_count"].data, da.Array) + assert hasattr(ac2["variant_allele_count"].data, "chunks") xr.testing.assert_equal(ac1, ac2) @@ -786,13 +786,14 @@ def test_variant_stats__tetraploid(): ) -@pytest.mark.parametrize( - "chunks", [(-1, -1, -1), (100, -1, -1), (100, 10, -1), (100, 10, 1)] -) -def test_variant_stats__chunks(chunks): +@pytest.mark.parametrize("precompute_variant_allele_count", [False, True]) +@pytest.mark.parametrize("chunks", [(-1, -1, -1), (100, -1, -1), (100, 10, -1)]) +def test_variant_stats__chunks(precompute_variant_allele_count, chunks): ds = simulate_genotype_call_dataset( n_variant=1000, n_sample=30, missing_pct=0.01, seed=0 ) + if precompute_variant_allele_count: + ds = count_variant_alleles(ds) expect = variant_stats(ds, merge=False).compute() ds["call_genotype"] = ds["call_genotype"].chunk(chunks) actual = variant_stats(ds, merge=False).compute()