Skip to content

Commit

Permalink
Fix up region size check
Browse files Browse the repository at this point in the history
Closes #146

note
  • Loading branch information
jeromekelleher committed May 1, 2024
1 parent 2b0d1a4 commit 6038115
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 23 deletions.
70 changes: 54 additions & 16 deletions bio2zarr/vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ class VcfPartition:
num_records: int = -1


# TODO bump this before current PR is done!
ICF_METADATA_FORMAT_VERSION = "0.2"
ICF_DEFAULT_COMPRESSOR = numcodecs.Blosc(
cname="zstd", clevel=7, shuffle=numcodecs.Blosc.NOSHUFFLE
Expand Down Expand Up @@ -903,6 +904,40 @@ def num_columns(self):
return len(self.columns)


@dataclasses.dataclass
class IcfPartitionMetadata:
    """
    Summary of one exploded partition, written to and read back from a
    per-partition JSON file.
    """

    # Number of records counted for the partition.
    num_records: int
    # POS of the final record seen in the partition.
    # NOTE(review): the writer initialises this to None before iterating, so
    # it is presumably None for a partition with no records - confirm.
    last_position: int
    # Per-field summaries, looked up by the field's full name.
    field_summaries: dict

    def asdict(self):
        """
        Return the metadata as a plain dictionary. Nested dataclass values
        are converted recursively by dataclasses.asdict.
        """
        return dataclasses.asdict(self)

    def asjson(self):
        """
        Return the dictionary form serialised as JSON with 4-space indent.
        """
        return json.dumps(self.asdict(), indent=4)

    @staticmethod
    def fromdict(d):
        """
        Build an instance from a dictionary with the same keys as the
        dataclass fields, converting every value of ``field_summaries``
        with VcfFieldSummary.fromdict.

        The converted summaries are stored in a new dictionary, so the
        mapping passed in by the caller is never modified and can safely be
        reused.
        """
        md = IcfPartitionMetadata(**d)
        # Rebind to a fresh dict rather than assigning into the existing one:
        # at this point md.field_summaries is the very object held by ``d``.
        md.field_summaries = {
            name: VcfFieldSummary.fromdict(raw)
            for name, raw in md.field_summaries.items()
        }
        return md


def check_overlapping_partitions(partitions):
    """
    Compare the region of each partition with that of the partition
    immediately before it and raise a ValueError if the two are on the same
    contig and the earlier region's end reaches the later region's start.

    Only neighbouring entries of ``partitions`` are compared. The earlier
    region of a same-contig pair must have its end set.
    """
    neighbours = zip(partitions, partitions[1:])
    for i, (before, after) in enumerate(neighbours, start=1):
        earlier = before.region
        later = after.region
        if earlier.contig != later.contig:
            # Regions on different contigs cannot overlap.
            continue
        assert earlier.end is not None
        # End coordinates are inclusive, so sharing a single position
        # already counts as an overlap.
        if earlier.end >= later.start:
            raise ValueError(
                f"Overlapping VCF regions in partitions {i - 1} and {i}: "
                f"{earlier} and {later}"
            )


class IntermediateColumnarFormatWriter:
def __init__(self, path):
self.path = pathlib.Path(path)
Expand Down Expand Up @@ -974,11 +1009,8 @@ def load_partition_summaries(self):
not_found = []
for j in range(self.num_partitions):
try:
with open(self.wip_path / f"p{j}_summary.json") as f:
summary = json.load(f)
for k, v in summary["field_summaries"].items():
summary["field_summaries"][k] = VcfFieldSummary.fromdict(v)
summaries.append(summary)
with open(self.wip_path / f"p{j}.json") as f:
summaries.append(IcfPartitionMetadata.fromdict(json.load(f)))
except FileNotFoundError:
not_found.append(j)
if len(not_found) > 0:
Expand All @@ -995,7 +1027,7 @@ def load_metadata(self):

def process_partition(self, partition_index):
self.load_metadata()
summary_path = self.wip_path / f"p{partition_index}_summary.json"
summary_path = self.wip_path / f"p{partition_index}.json"
# If someone is rewriting a summary path (for whatever reason), make sure it
# doesn't look like it's already been completed.
# NOTE to do this properly we probably need to take a lock on this file - but
Expand All @@ -1016,6 +1048,7 @@ def process_partition(self, partition_index):
else:
format_fields.append(field)

last_position = None
with IcfPartitionWriter(
self.metadata,
self.path,
Expand All @@ -1025,6 +1058,7 @@ def process_partition(self, partition_index):
num_records = 0
for variant in ivcf.variants(partition.region):
num_records += 1
last_position = variant.POS
tcw.append("CHROM", variant.CHROM)
tcw.append("POS", variant.POS)
tcw.append("QUAL", variant.QUAL)
Expand All @@ -1049,15 +1083,16 @@ def process_partition(self, partition_index):
f"flushing buffers"
)

partition_metadata = {
"num_records": num_records,
"field_summaries": {k: v.asdict() for k, v in tcw.field_summaries.items()},
}
partition_metadata = IcfPartitionMetadata(
num_records=num_records,
last_position=last_position,
field_summaries=tcw.field_summaries,
)
with open(summary_path, "w") as f:
json.dump(partition_metadata, f, indent=4)
f.write(partition_metadata.asjson())
logger.info(
f"Finish p{partition_index} {partition.vcf_path}__{partition.region}="
f"{num_records} records"
f"Finish p{partition_index} {partition.vcf_path}__{partition.region} "
f"{num_records} records last_pos={last_position}"
)

def explode(self, *, worker_processes=1, show_progress=False):
Expand Down Expand Up @@ -1099,8 +1134,9 @@ def finalise(self):
partition_summaries = self.load_partition_summaries()
total_records = 0
for index, summary in enumerate(partition_summaries):
partition_records = summary["num_records"]
partition_records = summary.num_records
self.metadata.partitions[index].num_records = partition_records
self.metadata.partitions[index].region.end = summary.last_position
total_records += partition_records
if not np.isinf(self.metadata.num_records):
# Note: this is just telling us that there's a bug in the
Expand All @@ -1110,9 +1146,11 @@ def finalise(self):
assert total_records == self.metadata.num_records
self.metadata.num_records = total_records

check_overlapping_partitions(self.metadata.partitions)

for field in self.metadata.fields:
for summary in partition_summaries:
field.summary.update(summary["field_summaries"][field.full_name])
field.summary.update(summary.field_summaries[field.full_name])

logger.info("Finalising metadata")
with open(self.path / "metadata.json", "w") as f:
Expand Down Expand Up @@ -1756,7 +1794,7 @@ def encode_partition(self, partition_index):
final_path = self.partition_path(partition_index)
logger.info(f"Finalising {partition_index} at {final_path}")
if final_path.exists():
logger.warning("Removing existing partition at {final_path}")
logger.warning(f"Removing existing partition at {final_path}")
shutil.rmtree(final_path)
os.rename(partition_path, final_path)

Expand Down
6 changes: 5 additions & 1 deletion bio2zarr/vcf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ def read_bytes_as_tuple(f: IO[Any], fmt: str) -> Sequence[Any]:

@dataclass
class Region:
"""
A htslib style region, where coordinates are 1-based and inclusive.
"""

contig: str
start: Optional[int] = None
end: Optional[int] = None
Expand All @@ -86,7 +90,7 @@ def __post_init__(self):
assert self.start > 0
if self.end is not None:
self.end = int(self.end)
assert self.end > self.start
assert self.end >= self.start

def __str__(self):
s = f"{self.contig}"
Expand Down
18 changes: 14 additions & 4 deletions tests/test_icf.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_finalise_missing_partition_fails(self, tmp_path, partition):
def test_explode_partition(self, tmp_path, partition):
icf_path = tmp_path / "x.icf"
vcf.explode_init(icf_path, [self.data_path])
summary_file = icf_path / "wip" / f"p{partition}_summary.json"
summary_file = icf_path / "wip" / f"p{partition}.json"
assert not summary_file.exists()
vcf.explode_partition(icf_path, partition)
assert summary_file.exists()
Expand All @@ -156,12 +156,12 @@ def test_double_explode_partition(self, tmp_path):
partition = 1
icf_path = tmp_path / "x.icf"
vcf.explode_init(icf_path, [self.data_path])
summary_file = icf_path / "wip" / f"p{partition}_summary.json"
summary_file = icf_path / "wip" / f"p{partition}.json"
assert not summary_file.exists()
vcf.explode_partition(icf_path, partition, worker_processes=0)
vcf.explode_partition(icf_path, partition)
with open(summary_file) as f:
s1 = f.read()
vcf.explode_partition(icf_path, partition, worker_processes=0)
vcf.explode_partition(icf_path, partition)
with open(summary_file) as f:
s2 = f.read()
assert s1 == s2
Expand All @@ -173,6 +173,16 @@ def test_explode_partition_out_of_range(self, tmp_path, partition):
with pytest.raises(ValueError, match="Partition index must be in the range"):
vcf.explode_partition(icf_path, partition)

def test_explode_same_file_twice(self, tmp_path):
icf_path = tmp_path / "x.icf"
with pytest.raises(ValueError, match="Duplicate path provided"):
vcf.explode(icf_path, [self.data_path, self.data_path])

def test_explode_same_data_twice(self, tmp_path):
icf_path = tmp_path / "x.icf"
with pytest.raises(ValueError, match="Overlapping VCF regions"):
vcf.explode(icf_path, [self.data_path, "tests/data/vcf/sample.bcf"])


class TestGeneratedFieldsExample:
data_path = "tests/data/vcf/field_type_combos.vcf.gz"
Expand Down
6 changes: 4 additions & 2 deletions tests/test_vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,8 @@ def test_call_GQ(self, schema):
[("1", 100, 200), ("1", 150, 250)],
# Overlap by one position
[("1", 100, 201), ("1", 200, 300)],
# End coord is *inclusive*
[("1", 100, 201), ("1", 201, 300)],
# Contained overlap
[("1", 100, 300), ("1", 150, 250)],
# Exactly equal
Expand All @@ -345,8 +347,8 @@ def test_check_overlap(regions):
vcf.VcfPartition("", region=vcf_utils.Region(contig, start, end))
for contig, start, end in regions
]
with pytest.raises(ValueError, match="Multiple VCFs have the region"):
vcf.check_overlap(partitions)
with pytest.raises(ValueError, match="Overlapping VCF regions"):
vcf.check_overlapping_partitions(partitions)


class TestVcfDescriptions:
Expand Down

0 comments on commit 6038115

Please sign in to comment.