Skip to content

Commit

Permalink
fixes for merger
Browse files Browse the repository at this point in the history
  • Loading branch information
guipenedo committed Jul 15, 2024
1 parent 359d9fa commit 1687ba3
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/datatrove/pipeline/stats/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> Do
else:
# each key in this dictionary is a suffix for the main stat
for suffix, val in value.items():
- stat_name = stat if not suffix else f"{stat}_{suffix}"
+ stat_name = stat if not suffix else f"{stat}__{suffix}"
counters[stat_name][key] += val

doc.metadata.update(doc_stats)
Expand Down
2 changes: 2 additions & 0 deletions src/datatrove/pipeline/stats/merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> Do

with self.output_folder.open(f"{folder}/{STATS_MERGED_NAME}", "wt") as f:
group_name = Path(folder).parent.name
+ if "__" in group_name:
+ group_name = group_name.split("__")[0]
if group_name in self.top_k_config.top_k_groups:
top_k_keys = heapq.nlargest(self.top_k_config.top_k, stat, key=lambda x: stat.get(x).n)
stat = MetricStatsDict(init={s: stat.get(s) for s in top_k_keys})
Expand Down

0 comments on commit 1687ba3

Please sign in to comment.