Skip to content

Commit

Permalink
add token and char count to histogram stats
Browse files Browse the repository at this point in the history
  • Loading branch information
guipenedo committed Jul 15, 2024
1 parent c279f26 commit 359d9fa
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions src/datatrove/pipeline/stats/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,16 @@ def extract_stats(self, doc: Document) -> dict[str, int | float]:
"""
raise NotImplementedError()

def get_kv(self, doc: Document, value: STAT_TYPE, group_name: GROUP) -> tuple[str, STAT_TYPE]:
def get_kv(
self, doc: Document, value: STAT_TYPE, group_name: GROUP
) -> tuple[str, STAT_TYPE | dict[str, STAT_TYPE]]:
if group_name == "histogram":
# Use rounding to reduce then number of values for histogram
return str(round(value, self.histogram_round_digits)), 1
return str(round(value, self.histogram_round_digits)), {
"": 1,
"chars": len(doc.text),
**({"tokens": doc.metadata["token_count"]} if "token_count" in doc.metadata else {}),
}
elif group_name == "summary":
return "summary", value
elif group_name == "fqdn":
Expand Down Expand Up @@ -96,7 +102,13 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> Do
for group, counters in groups_dicts.items():
for stat, value in doc_stats.items():
key, value = self.get_kv(doc, value, group)
counters[stat][key] += value
if not isinstance(value, dict):
counters[stat][key] += value
else:
# each key in this dictionary is a suffix for the main stat
for suffix, val in value.items():
stat_name = stat if not suffix else f"{stat}_{suffix}"
counters[stat_name][key] += val

doc.metadata.update(doc_stats)
yield doc
Expand Down

0 comments on commit 359d9fa

Please sign in to comment.