Refactor variable names for consistency #253

Merged · 3 commits · Jun 19, 2024
10 changes: 5 additions & 5 deletions bio2zarr/vcf2zarr/icf.py
@@ -853,11 +853,11 @@ def __len__(self):

     def summary_table(self):
         data = []
-        for name, col in self.fields.items():
-            summary = col.vcf_field.summary
+        for name, icf_field in self.fields.items():
+            summary = icf_field.vcf_field.summary
             d = {
                 "name": name,
-                "type": col.vcf_field.vcf_type,
+                "type": icf_field.vcf_field.vcf_type,
                 "chunks": summary.num_chunks,
                 "size": core.display_size(summary.uncompressed_size),
                 "compressed": core.display_size(summary.compressed_size),
@@ -1009,8 +1009,8 @@ def mkdirs(self):
         self.path.mkdir()
         self.wip_path.mkdir()
         for field in self.metadata.fields:
-            col_path = get_vcf_field_path(self.path, field)
-            col_path.mkdir(parents=True)
+            field_path = get_vcf_field_path(self.path, field)
+            field_path.mkdir(parents=True)

     def load_partition_summaries(self):
         summaries = []
108 changes: 55 additions & 53 deletions bio2zarr/vcf2zarr/vcz.py
@@ -68,7 +68,7 @@ def from_field(
         num_samples,
         variants_chunk_size,
         samples_chunk_size,
-        variable_name=None,
+        array_name=None,
     ):
         shape = [num_variants]
         prefix = "variant_"
@@ -79,8 +79,8 @@
             shape.append(num_samples)
             chunks.append(samples_chunk_size)
             dimensions.append("samples")
-        if variable_name is None:
-            variable_name = prefix + vcf_field.name
+        if array_name is None:
+            array_name = prefix + vcf_field.name
         # TODO make an option to add in the empty extra dimension
         if vcf_field.summary.max_number > 1:
             shape.append(vcf_field.summary.max_number)
@@ -96,7 +96,7 @@
             dimensions.append(f"{vcf_field.category}_{vcf_field.name}_dim")
         return ZarrArraySpec.new(
             vcf_field=vcf_field.full_name,
-            name=variable_name,
+            name=array_name,
             dtype=vcf_field.smallest_dtype(),
             shape=shape,
             chunks=chunks,
@@ -226,14 +226,14 @@ def generate(icf, variants_chunk_size=None, samples_chunk_size=None):
             f"Generating schema with chunks={variants_chunk_size, samples_chunk_size}"
         )

-        def spec_from_field(field, variable_name=None):
+        def spec_from_field(field, array_name=None):
             return ZarrArraySpec.from_field(
                 field,
                 num_samples=n,
                 num_variants=m,
                 samples_chunk_size=samples_chunk_size,
                 variants_chunk_size=variants_chunk_size,
-                variable_name=variable_name,
+                array_name=array_name,
             )

         def fixed_field_spec(
@@ -249,10 +249,10 @@ def fixed_field_spec(
                 chunks=[variants_chunk_size],
             )

-        alt_col = icf.fields["ALT"]
-        max_alleles = alt_col.vcf_field.summary.max_number + 1
+        alt_field = icf.fields["ALT"]
+        max_alleles = alt_field.vcf_field.summary.max_number + 1

-        colspecs = [
+        array_specs = [
             fixed_field_spec(
                 name="variant_contig",
                 dtype=core.min_int_dtype(0, icf.metadata.num_contigs),
@@ -281,27 +281,29 @@ def fixed_field_spec(
         name_map = {field.full_name: field for field in icf.metadata.fields}

         # Only two of the fixed fields have a direct one-to-one mapping.
-        colspecs.extend(
+        array_specs.extend(
             [
-                spec_from_field(name_map["QUAL"], variable_name="variant_quality"),
-                spec_from_field(name_map["POS"], variable_name="variant_position"),
+                spec_from_field(name_map["QUAL"], array_name="variant_quality"),
+                spec_from_field(name_map["POS"], array_name="variant_position"),
             ]
         )
-        colspecs.extend([spec_from_field(field) for field in icf.metadata.info_fields])
+        array_specs.extend(
+            [spec_from_field(field) for field in icf.metadata.info_fields]
+        )

         gt_field = None
         for field in icf.metadata.format_fields:
             if field.name == "GT":
                 gt_field = field
                 continue
-            colspecs.append(spec_from_field(field))
+            array_specs.append(spec_from_field(field))

         if gt_field is not None:
             ploidy = gt_field.summary.max_number - 1
             shape = [m, n]
             chunks = [variants_chunk_size, samples_chunk_size]
             dimensions = ["variants", "samples"]
-            colspecs.append(
+            array_specs.append(
                 ZarrArraySpec.new(
                     vcf_field=None,
                     name="call_genotype_phased",
@@ -314,7 +316,7 @@ def fixed_field_spec(
             )
             shape += [ploidy]
             dimensions += ["ploidy"]
-            colspecs.append(
+            array_specs.append(
                 ZarrArraySpec.new(
                     vcf_field=None,
                     name="call_genotype",
@@ -325,7 +327,7 @@ def fixed_field_spec(
                     description="",
                 )
             )
-            colspecs.append(
+            array_specs.append(
                 ZarrArraySpec.new(
                     vcf_field=None,
                     name="call_genotype_mask",
@@ -341,7 +343,7 @@ def fixed_field_spec(
             format_version=ZARR_SCHEMA_FORMAT_VERSION,
             samples_chunk_size=samples_chunk_size,
             variants_chunk_size=variants_chunk_size,
-            fields=colspecs,
+            fields=array_specs,
             samples=icf.metadata.samples,
             contigs=icf.metadata.contigs,
             filters=icf.metadata.filters,
@@ -583,28 +585,28 @@ def encode_filter_id(self, root):
         )
         array.attrs["_ARRAY_DIMENSIONS"] = ["filters"]

-    def init_array(self, root, variable, variants_dim_size):
+    def init_array(self, root, array_spec, variants_dim_size):
         object_codec = None
-        if variable.dtype == "O":
+        if array_spec.dtype == "O":
             object_codec = numcodecs.VLenUTF8()
-        shape = list(variable.shape)
+        shape = list(array_spec.shape)
         # Truncate the variants dimension is max_variant_chunks was specified
         shape[0] = variants_dim_size
         a = root.empty(
-            variable.name,
+            array_spec.name,
             shape=shape,
-            chunks=variable.chunks,
-            dtype=variable.dtype,
-            compressor=numcodecs.get_codec(variable.compressor),
-            filters=[numcodecs.get_codec(filt) for filt in variable.filters],
+            chunks=array_spec.chunks,
+            dtype=array_spec.dtype,
+            compressor=numcodecs.get_codec(array_spec.compressor),
+            filters=[numcodecs.get_codec(filt) for filt in array_spec.filters],
             object_codec=object_codec,
             dimension_separator=self.metadata.dimension_separator,
         )
         a.attrs.update(
             {
-                "description": variable.description,
+                "description": array_spec.description,
                 # Dimension names are part of the spec in Zarr v3
-                "_ARRAY_DIMENSIONS": variable.dimensions,
+                "_ARRAY_DIMENSIONS": array_spec.dimensions,
             }
         )
         logger.debug(f"Initialised {a}")
@@ -644,9 +646,9 @@ def encode_partition(self, partition_index):
         self.encode_filters_partition(partition_index)
         self.encode_contig_partition(partition_index)
         self.encode_alleles_partition(partition_index)
-        for col in self.schema.fields:
-            if col.vcf_field is not None:
-                self.encode_array_partition(col, partition_index)
+        for array_spec in self.schema.fields:
+            if array_spec.vcf_field is not None:
+                self.encode_array_partition(array_spec, partition_index)
         if self.has_genotypes():
             self.encode_genotypes_partition(partition_index)

@@ -672,21 +674,21 @@ def init_partition_array(self, partition_index, name):
     def finalise_partition_array(self, partition_index, name):
         logger.debug(f"Encoded {name} partition {partition_index}")

-    def encode_array_partition(self, column, partition_index):
-        array = self.init_partition_array(partition_index, column.name)
+    def encode_array_partition(self, array_spec, partition_index):
+        array = self.init_partition_array(partition_index, array_spec.name)

         partition = self.metadata.partitions[partition_index]
         ba = core.BufferedArray(array, partition.start)
-        source_col = self.icf.fields[column.vcf_field]
-        sanitiser = source_col.sanitiser_factory(ba.buff.shape)
+        source_field = self.icf.fields[array_spec.vcf_field]
+        sanitiser = source_field.sanitiser_factory(ba.buff.shape)

-        for value in source_col.iter_values(partition.start, partition.stop):
+        for value in source_field.iter_values(partition.start, partition.stop):
             # We write directly into the buffer in the sanitiser function
             # to make it easier to reason about dimension padding
             j = ba.next_buffer_row()
             sanitiser(ba.buff, j, value)
         ba.flush()
-        self.finalise_partition_array(partition_index, column.name)
+        self.finalise_partition_array(partition_index, array_spec.name)

     def encode_genotypes_partition(self, partition_index):
         gt_array = self.init_partition_array(partition_index, "call_genotype")
@@ -700,8 +702,8 @@ def encode_genotypes_partition(self, partition_index):
         gt_mask = core.BufferedArray(gt_mask_array, partition.start)
         gt_phased = core.BufferedArray(gt_phased_array, partition.start)

-        source_col = self.icf.fields["FORMAT/GT"]
-        for value in source_col.iter_values(partition.start, partition.stop):
+        source_field = self.icf.fields["FORMAT/GT"]
+        for value in source_field.iter_values(partition.start, partition.stop):
             j = gt.next_buffer_row()
             icf.sanitise_value_int_2d(gt.buff, j, value[:, :-1])
             j = gt_phased.next_buffer_row()
@@ -723,12 +725,12 @@ def encode_alleles_partition(self, partition_index):
         alleles_array = self.init_partition_array(partition_index, array_name)
         partition = self.metadata.partitions[partition_index]
         alleles = core.BufferedArray(alleles_array, partition.start)
-        ref_col = self.icf.fields["REF"]
-        alt_col = self.icf.fields["ALT"]
+        ref_field = self.icf.fields["REF"]
+        alt_field = self.icf.fields["ALT"]

         for ref, alt in zip(
-            ref_col.iter_values(partition.start, partition.stop),
-            alt_col.iter_values(partition.start, partition.stop),
+            ref_field.iter_values(partition.start, partition.stop),
+            alt_field.iter_values(partition.start, partition.stop),
         ):
             j = alleles.next_buffer_row()
             alleles.buff[j, :] = constants.STR_FILL
@@ -744,9 +746,9 @@ def encode_id_partition(self, partition_index):
         partition = self.metadata.partitions[partition_index]
         vid = core.BufferedArray(vid_array, partition.start)
         vid_mask = core.BufferedArray(vid_mask_array, partition.start)
-        col = self.icf.fields["ID"]
+        field = self.icf.fields["ID"]

-        for value in col.iter_values(partition.start, partition.stop):
+        for value in field.iter_values(partition.start, partition.stop):
             j = vid.next_buffer_row()
             k = vid_mask.next_buffer_row()
             assert j == k
@@ -769,8 +771,8 @@ def encode_filters_partition(self, partition_index):
         partition = self.metadata.partitions[partition_index]
         var_filter = core.BufferedArray(array, partition.start)

-        col = self.icf.fields["FILTERS"]
-        for value in col.iter_values(partition.start, partition.stop):
+        field = self.icf.fields["FILTERS"]
+        for value in field.iter_values(partition.start, partition.stop):
             j = var_filter.next_buffer_row()
             var_filter.buff[j] = False
             for f in value:
@@ -790,9 +792,9 @@ def encode_contig_partition(self, partition_index):
         array = self.init_partition_array(partition_index, array_name)
         partition = self.metadata.partitions[partition_index]
         contig = core.BufferedArray(array, partition.start)
-        col = self.icf.fields["CHROM"]
+        field = self.icf.fields["CHROM"]

-        for value in col.iter_values(partition.start, partition.stop):
+        for value in field.iter_values(partition.start, partition.stop):
             j = contig.next_buffer_row()
             # Note: because we are using the indexes to define the lookups
             # and we always have an index, it seems that we the contig lookup
@@ -880,8 +882,8 @@ def get_max_encoding_memory(self):
         Return the approximate maximum memory used to encode a variant chunk.
         """
         max_encoding_mem = 0
-        for col in self.schema.fields:
-            max_encoding_mem = max(max_encoding_mem, col.variant_chunk_nbytes)
+        for array_spec in self.schema.fields:
+            max_encoding_mem = max(max_encoding_mem, array_spec.variant_chunk_nbytes)
         gt_mem = 0
         if self.has_genotypes:
             gt_mem = sum(
@@ -921,9 +923,9 @@ def encode_all_partitions(
         num_workers = min(max_num_workers, worker_processes)

         total_bytes = 0
-        for col in self.schema.fields:
+        for array_spec in self.schema.fields:
             # Open the array definition to get the total size
-            total_bytes += zarr.open(self.arrays_path / col.name).nbytes
+            total_bytes += zarr.open(self.arrays_path / array_spec.name).nbytes

         progress_config = core.ProgressConfig(
             total=total_bytes,
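The rename in vcz.py is mechanical: variables holding a ZarrArraySpec are now called array_spec / array_specs / array_name, and variables holding an ICF field are called icf_field or field, replacing the older col / column / variable names; behaviour is unchanged. A minimal standalone sketch of the resulting pattern, mirroring the renamed loop in get_max_encoding_memory() but using a made-up ArraySpec class and values rather than the real bio2zarr types:

```python
# Minimal sketch, not bio2zarr code: the ArraySpec class and numbers below are
# invented for illustration of the naming convention adopted in this PR.
from dataclasses import dataclass


@dataclass
class ArraySpec:
    name: str
    variant_chunk_nbytes: int  # hypothetical per-chunk memory footprint


def max_encoding_memory(array_specs):
    # Largest single-array chunk footprint, with the loop variable named after
    # what it holds (array_spec) rather than the legacy "col".
    return max((array_spec.variant_chunk_nbytes for array_spec in array_specs), default=0)


specs = [ArraySpec("variant_position", 4_000), ArraySpec("call_genotype", 64_000)]
print(max_encoding_memory(specs))  # -> 64000
```

Naming each variable after the kind of object it holds is the whole point of the change; no array shapes, chunking, or encoding logic is touched.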
22 changes: 11 additions & 11 deletions tests/test_icf.py
@@ -46,8 +46,8 @@ def test_mkschema(self, tmp_path, icf):

     def test_summary_table(self, icf):
         data = icf.summary_table()
-        cols = [d["name"] for d in data]
-        assert tuple(sorted(cols)) == self.fields
+        fields = [d["name"] for d in data]
+        assert tuple(sorted(fields)) == self.fields

     def test_inspect(self, icf):
         assert icf.summary_table() == vcf2zarr.inspect(icf.path)
@@ -111,10 +111,10 @@ def test_init_paths(self, tmp_path):
         assert icf_path.exists()
         wip_path = icf_path / "wip"
         assert wip_path.exists()
-        for column in self.fields:
-            col_path = icf_path / column
-            assert col_path.exists()
-            assert col_path.is_dir()
+        for field_name in self.fields:
+            field_path = icf_path / field_name
+            assert field_path.exists()
+            assert field_path.is_dir()

     def test_finalise_paths(self, tmp_path):
         icf_path = tmp_path / "x.icf"
@@ -427,8 +427,8 @@ def test_partition_record_index(self, icf):
         )

     def test_pos_values(self, icf):
-        col = icf["POS"]
-        pos = np.array([v[0] for v in col.values])
+        field = icf["POS"]
+        pos = np.array([v[0] for v in field.values])
         # Check the actual values here to make sure other tests make sense
         actual = np.hstack([1 + np.arange(933) for _ in range(5)])
         nt.assert_array_equal(pos, actual)
@@ -465,9 +465,9 @@ def test_pos_chunk_records(self, icf):
         ],
     )
     def test_slice(self, icf, start, stop):
-        col = icf["POS"]
-        pos = np.array(col.values)
-        pos_slice = np.array(list(col.iter_values(start, stop)))
+        field = icf["POS"]
+        pos = np.array(field.values)
+        pos_slice = np.array(list(field.iter_values(start, stop)))
         nt.assert_array_equal(pos[start:stop], pos_slice)


6 changes: 3 additions & 3 deletions tests/test_vcf_examples.py
@@ -238,9 +238,9 @@ def test_no_genotypes(self, ds, tmp_path):
         vcf2zarr.convert([path], out)
         ds2 = sg.load_dataset(out)
         assert len(ds2["sample_id"]) == 0
-        for col in ds:
-            if col != "sample_id" and not col.startswith("call_"):
-                xt.assert_equal(ds[col], ds2[col])
+        for field_name in ds:
+            if field_name != "sample_id" and not field_name.startswith("call_"):
+                xt.assert_equal(ds[field_name], ds2[field_name])

     @pytest.mark.parametrize(
         ("variants_chunk_size", "samples_chunk_size", "y_chunks", "x_chunks"),