added minhash tests and bugfixes
guipenedo committed Jul 25, 2023
1 parent b480195 commit cb8cf37
Showing 3 changed files with 186 additions and 22 deletions.
1 change: 0 additions & 1 deletion src/datatrove/io/base.py
@@ -108,7 +108,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):

def open(self, mode: str = "w", gzip: bool = False, overwrite: bool = False):
if not self.file_handler or overwrite:
print(self.local_path, os.path.dirname(self.local_path))
os.makedirs(os.path.dirname(self.local_path), exist_ok=True)
self.file_handler = open(self.local_path, mode) if not gzip else gzip_lib.open(self.local_path, mode)
return self
50 changes: 29 additions & 21 deletions src/datatrove/pipeline/dedup/minhash.py
@@ -19,14 +19,21 @@
_max_hash = np.uint64((1 << 32) - 1)
_hash_range = 1 << 32

DEFAULT_NR_BUCKETS = 20
DEFAULT_PER_BUCKET = 20
"""
n_grams -> roughly nr of words (this should be small enough to catch fuzzy matches but big enough to not have each shingle be too common)
threshold is (1/14)^(1/8)~0.72
threshold is real minhash similarity cutoff for high probability inclusion by LSH minhash
probability of inclusion for s=0.8: 1-(1-0.8^8)^14=0.924
"""

DEFAULT_NR_BUCKETS = 14
DEFAULT_PER_BUCKET = 8
DEFAULT_N_GRAMS = 5
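The figures quoted in the new docstring follow from the standard LSH analysis: with b buckets of r hashes each, a pair of documents with true MinHash similarity s shares at least one full bucket with probability 1 - (1 - s^r)^b, and the similarity at which that curve turns steep is roughly (1/b)^(1/r). A quick standalone check of the new defaults (plain Python, independent of the committed code):

# Sanity-check the LSH figures from the docstring above.
b, r = 14, 8  # DEFAULT_NR_BUCKETS, DEFAULT_PER_BUCKET

threshold = (1 / b) ** (1 / r)
print(f"approx. threshold: {threshold:.2f}")  # ~0.72

def p_inclusion(s: float) -> float:
    # Probability that a pair with true similarity s collides in at least one bucket.
    return 1 - (1 - s ** r) ** b

print(f"P(inclusion | s=0.8): {p_inclusion(0.8):.3f}")  # ~0.924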


@dataclass
class HashSig:
sig: list[int]
sig: tuple[int]
doc_id: int
file_id: int

@@ -79,20 +86,23 @@ def get_signature(self, shingles):
def set_up_dl_locks(self, dl_lock, up_lock):
self.output_folder.set_lock(up_lock)

def get_shingles(self, text):
return np.array(
[
[sha1_hash32(" ".join(x).encode("utf-8"))]
for x in ngrams(word_tokenize(simplify_content(text)), self.n_grams)
],
dtype=np.uint64,
)

def __call__(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1):
buckets = [
self.output_folder.open(f"bucket_{bi:03d}/{rank:05d}.minhash.sig", mode="wb")
for bi in range(self.num_buckets)
]
for doc_idx, doc in enumerate(data):
self.stat_update(StatHints.total)
shingles = np.array(
[
[sha1_hash32(" ".join(x).encode("utf-8"))]
for x in ngrams(word_tokenize(simplify_content(doc.content)), self.n_grams)
],
dtype=np.uint64,
)
shingles = self.get_shingles(doc.content)
if shingles.size != 0:
sig = self.get_signature(shingles)
for bi, (bucket, bucket_sig) in enumerate(zip(buckets, sig)):
@@ -132,17 +142,18 @@ def read_sigs(self, file: InputDataFile, file_id: int) -> Generator:
n = self.hashes_per_bucket + 1
with file.open(binary=True) as f:
while True:
data = f.read(n * 4)
data = f.read(n * struct.calcsize("I"))
if not data:
return
data = struct.unpack("<%sI" % n, n)[0]
data = struct.unpack("<%sI" % n, data)
yield HashSig(sig=data[:-1], doc_id=data[-1], file_id=file_id)

def set_up_dl_locks(self, dl_lock, up_lock):
self.input_folder.set_lock(dl_lock)
self.output_folder.set_lock(up_lock)

def __call__(self, data: DocumentsPipeline, bucket: int = 0, world_size: int = 1):
assert data is None, "You should not use an input block before MinhashDedupBuckets"
assert world_size == self.num_buckets, "You must run exactly one task per bucket"
sig_files = self.input_folder.list_files(suffix=f"bucket_{bucket:03d}")
sig_readers = [self.read_sigs(file, file_i) for file_i, file in enumerate(sig_files)]
@@ -197,12 +208,13 @@ def parent(x):

for dup_file in dup_files:
with dup_file.open(binary=True) as df:
while data := df.read(4 * 4):
f1, d1, f2, d2 = struct.unpack("<4I", data)[0]
while data := df.read(4 * struct.calcsize("I")):
f1, d1, f2, d2 = struct.unpack("<4I", data)
a, b = (f1, d1), (f2, d2)
union_set[parent(a)] = parent(b)

for node, p in sorted(union_set.items()):
for node in sorted(union_set.keys()):
p = parent(node)
if node != p:
file, doc = node
self.output_folder.open(f"{file:06d}.remove", mode="wb").write(struct.pack("<I", doc))
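The rewritten clustering loop resolves each node to its root with parent() instead of trusting the parent recorded in union_set, which can be stale after later unions; only the root identifies the cluster representative. A small self-contained illustration, using plain integers instead of (file_id, doc_id) pairs and a simplified parent() written for this sketch rather than the one in the file:

# Minimal union-find to show why the root, not the stored parent, must be used.
union_set = {}

def parent(x):
    # Follow parent pointers until reaching a root (a node with no entry).
    while union_set.get(x, x) != x:
        x = union_set[x]
    return x

union_set[parent(0)] = parent(1)  # link 0 -> 1
union_set[parent(1)] = parent(2)  # link 1 -> 2; the stored parent of 0 is now stale

assert union_set[0] == 1            # stored parent: an intermediate node
assert parent(0) == parent(1) == 2  # resolved root: the actual cluster representative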
@@ -216,14 +228,10 @@ class MinhashDedupFilter(PipelineStep):
def __init__(
self,
input_folder: BaseInputDataFolder,
output_folder: BaseOutputDataFolder,
num_buckets: int = DEFAULT_NR_BUCKETS,
**kwargs,
):
super().__init__(**kwargs)
self.data_folder = input_folder
self.output_folder = output_folder
self.num_buckets = num_buckets

def set_up_dl_locks(self, dl_lock, up_lock):
self.data_folder.set_lock(dl_lock)
@@ -235,9 +243,9 @@ def __call__(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1):
with remove_data[0].open_binary() as f:

def get_next():
data = f.read(4)
data = f.read(struct.calcsize("I"))
if data:
return struct.unpack("<I", data)
return struct.unpack("<I", data)[0]

next_removal = get_next()
for idx, doc in enumerate(data):
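Several of the fixes in this file touch the same on-disk convention: signatures and duplicate pairs are streams of fixed-size records of little-endian uint32 values, read n values at a time with f.read(n * struct.calcsize("I")) and decoded with struct.unpack, which must be given the bytes just read rather than n itself (the old reader unpacked n and kept only the first element). A minimal sketch of that record round-trip, using a throwaway temporary file rather than the datatrove I/O classes:

import os
import struct
import tempfile

RECORD_FIELDS = 4  # e.g. a duplicate pair: (file_1, doc_1, file_2, doc_2)
RECORD_SIZE = RECORD_FIELDS * struct.calcsize("I")

def read_records(path):
    # Yield one tuple of uint32 values per fixed-size little-endian record.
    with open(path, "rb") as f:
        while data := f.read(RECORD_SIZE):
            yield struct.unpack(f"<{RECORD_FIELDS}I", data)

pairs = [(0, 1, 0, 7), (0, 2, 0, 7)]
with tempfile.NamedTemporaryFile(suffix=".dups", delete=False) as tmp:
    for p in pairs:
        tmp.write(struct.pack("<4I", *p))
assert list(read_records(tmp.name)) == pairs
os.remove(tmp.name)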
157 changes: 157 additions & 0 deletions tests/pipeline/test_minhash.py
@@ -0,0 +1,157 @@
import os
import shutil
import struct
import tempfile
import unittest
from collections import defaultdict
from math import floor

from datatrove.data import Document
from datatrove.io import LocalInputDataFolder, LocalOutputDataFolder
from datatrove.pipeline.dedup.minhash import (
MinhashDedupBuckets,
MinhashDedupCluster,
MinhashDedupFilter,
MinhashDedupSignature,
)


lorem_ipsum = """Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aliquam euismod vel ante vitae rhoncus. Curabitur eu lectus et magna maximus facilisis eu non magna. Maecenas sed velit vitae est ornare placerat. Vestibulum quis consectetur nunc, a feugiat lorem. Cras in ipsum fringilla, vestibulum urna sit amet, viverra tortor. Orci varius natoque penatibus et magnis dis parturient montes, nascetur ridiculus mus. Morbi euismod vestibulum elit id placerat. Fusce malesuada ultricies condimentum. Cras tincidunt eget lorem nec hendrerit. Aenean mattis arcu dolor, id semper velit ullamcorper malesuada. Aliquam non ipsum et eros venenatis aliquet. Proin eleifend interdum scelerisque. Interdum et malesuada fames ac ante ipsum primis in faucibus. Mauris nunc sapien, molestie eget convallis at, maximus nec ipsum. Morbi quam diam, blandit ut mollis at, varius eu tellus. Maecenas sem justo, porttitor at odio nec, interdum posuere ex.
Aliquam pretium ac nulla et porttitor. Nunc quis felis posuere, lobortis magna quis, imperdiet nulla. Maecenas tempor, mi vel vestibulum tempus, arcu elit scelerisque erat, eu molestie velit eros id metus. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Proin fringilla viverra urna eu dictum. Suspendisse interdum, leo non sagittis pulvinar, mauris est euismod ante, et luctus nisi est id odio. Proin purus nunc, feugiat id condimentum eu, efficitur at diam. Quisque aliquet felis non risus rutrum, eu porttitor justo venenatis. Donec ut est felis. Proin risus augue, gravida fermentum elementum eu, varius eget tortor. Sed a sem quis tortor ultrices elementum. Nullam tempor rutrum ipsum id eleifend. Donec a elit tincidunt, sagittis magna a, pretium felis.
Sed ut faucibus dui, a sodales nunc. Praesent fermentum diam quis augue porttitor vulputate. Mauris pretium ipsum ut erat hendrerit eleifend. Maecenas consectetur elit a ligula auctor porta. Duis vitae diam ac velit porttitor tempor. Quisque at arcu a augue dictum molestie et sed erat. Donec fermentum pulvinar elit. Morbi orci nisi, condimentum non tincidunt sit amet, hendrerit non justo. Ut fringilla dolor neque, ut porttitor ante congue vitae. In aliquam augue a sapien sodales ornare. In in maximus nunc. Quisque turpis nibh, commodo non semper non, hendrerit a massa.
Suspendisse potenti. Ut feugiat nibh ex. Nunc eget ligula ut massa tempus pretium vitae et mauris. Suspendisse potenti. Vivamus euismod ipsum est, id consectetur lorem suscipit non. Integer ac felis egestas risus ornare luctus nec vel massa. Donec scelerisque enim eu nulla commodo fringilla. Donec et pulvinar dolor, sit amet tristique risus. Cras et est id leo malesuada sollicitudin.
Quisque et aliquet diam. Aenean euismod efficitur enim, non semper eros. Nullam molestie vehicula eros, nec porttitor justo feugiat nec. Maecenas fringilla eleifend augue, eu mollis arcu vulputate ac. Quisque ullamcorper turpis sed tristique dapibus. Etiam imperdiet pulvinar fringilla. Nulla sed est eget odio dictum pretium. Cras ultricies nibh libero, efficitur consequat neque semper id. Donec porttitor lacus nunc, vitae gravida lorem consectetur sit amet. Pellentesque mollis, dui nec molestie consectetur, massa enim tempus ipsum, quis pretium felis massa congue felis. Donec efficitur pretium diam, quis elementum felis eleifend quis. Nullam vehicula tortor et quam eleifend, maximus dignissim nisi feugiat. """


class TestMinhash(unittest.TestCase):
def setUp(self):
# Create a temporary directory
self.test_dir = tempfile.mkdtemp()

def tearDown(self):
# Remove the directory after the test
shutil.rmtree(self.test_dir)

def test_signatures(self):
minhash = MinhashDedupSignature(
output_folder=LocalOutputDataFolder(os.path.join(self.test_dir, "signatures1")),
)
shingles = minhash.get_shingles(lorem_ipsum)
sig = minhash.get_signature(shingles)

minhash2 = MinhashDedupSignature(
output_folder=LocalOutputDataFolder(os.path.join(self.test_dir, "signatures2"))
)
# check consistency
assert sig == minhash2.get_signature(shingles)

# check correct number of outputs
assert len(sig) == minhash.num_buckets
assert all([len(x) == minhash.hashes_per_bucket for x in sig])

# check similarity approximation
for pctd in range(0, 100, 5):
dec = pctd / 100
endp = floor(len(lorem_ipsum) * dec)
textd = lorem_ipsum[:endp] + lorem_ipsum[len(lorem_ipsum) - 1 : endp : -1]
sigd = minhash.get_signature(minhash.get_shingles(textd))
simil = sum([1 if a == b else 0 for ba, bb in zip(sig, sigd) for a, b in zip(ba, bb)]) / minhash.num_hashes
assert dec - 0.2 < simil < dec + 0.2

# check output file format and order
samples = [Document(f"sample {i}, {lorem_ipsum[i:: 10]}", data_id="test") for i in range(100)]
minhash(samples)
for bi in range(minhash.num_buckets):
with open(os.path.join(minhash.output_folder.path, f"bucket_{bi:03d}", "00000.minhash.sig"), "rb") as f:
prev = None
doc_ids = set()
S = struct.calcsize("I")
for di in range(100):
data = struct.unpack("<%sI" % minhash.hashes_per_bucket, f.read(minhash.hashes_per_bucket * S))
doc_id = struct.unpack("<I", f.read(S))[0]
# ensure sorted order
assert prev is None or data >= prev
prev = data
assert 0 <= doc_id < 100
doc_ids.add(doc_id)
assert len(doc_ids) == 100

def test_buckets_and_cluster(self):
sigs_folder = os.path.join(self.test_dir, "b_signatures")
buckets_folder = os.path.join(self.test_dir, "b_buckets")
clusters_folder = os.path.join(self.test_dir, "b_clusters")

signatures_block = MinhashDedupSignature(output_folder=LocalOutputDataFolder(sigs_folder))
buckets_block = MinhashDedupBuckets(
input_folder=LocalInputDataFolder(sigs_folder),
output_folder=LocalOutputDataFolder(buckets_folder),
)

clusters = [[0, 20, 50], [400, 420], [800, 810, 820, 840, 860], [1200, 1215, 1225, 1245], [1600], [2000]]

cluster_samples = [
Document(content=lorem_ipsum[x : x + 300], data_id=f"{ci}_{xi}", metadata={"ci": ci, "xi": xi})
for ci, cluster in enumerate(clusters)
for xi, x in enumerate(cluster)
]

signatures_block(cluster_samples)
# test file read
for fi, file in enumerate(buckets_block.input_folder.list_files()):
last = None
for sig in buckets_block.read_sigs(file, fi):
assert 0 <= sig.doc_id < 100
assert last is None or sig.sig >= last
assert len(sig.sig) == buckets_block.hashes_per_bucket
last = sig.sig

# test duplicate pairs
for b in range(buckets_block.num_buckets):
buckets_block(None, bucket=b, world_size=buckets_block.num_buckets)
bucket_results_folder = LocalInputDataFolder(buckets_folder)
dup_files = bucket_results_folder.list_files(extension=".dups")
pairs = defaultdict(set)
for dup_file in dup_files:
with dup_file.open(binary=True) as df:
while data := df.read(4 * struct.calcsize("I")):
f1, d1, f2, d2 = struct.unpack("<4I", data)
assert f1 == f2 == 0
assert cluster_samples[d1].metadata["ci"] == cluster_samples[d2].metadata["ci"]
pairs[d1].add(d2)
pairs[d2].add(d1)
doc_id = 0
for cluster in clusters:
for a in range(doc_id, doc_id + len(cluster)):
assert len(cluster) < 2 or any(a in pairs[b] for b in range(doc_id, doc_id + len(cluster)) if a != b)
doc_id += len(cluster)

# clustering
cluster_block = MinhashDedupCluster(bucket_results_folder, LocalOutputDataFolder(clusters_folder))
cluster_block(None)

cluster_results_folder = LocalInputDataFolder(clusters_folder)
remove_ids = set()
with cluster_results_folder.list_files()[0].open_binary() as df:
while data := df.read(struct.calcsize("I")):
remove_ids.add(struct.unpack("<I", data)[0])
doc_id = 0
kept = set()
for ci, cluster in enumerate(clusters):
to_remove = 0
for xi, a in enumerate(range(doc_id, doc_id + len(cluster))):
if a in remove_ids:
to_remove += 1
else:
kept.add(f"{ci}_{xi}")
doc_id += len(cluster)
assert to_remove == len(cluster) - 1

# filtering
filter_block = MinhashDedupFilter(cluster_results_folder)
filtered = filter_block(cluster_samples)
filtered_ids = {x.data_id for x in filtered}
assert filtered_ids == kept
