Skip to content

Commit

Permalink
Ready
Browse files Browse the repository at this point in the history
  • Loading branch information
TJ-Solergibert committed Jul 18, 2024
1 parent 9cfc5ea commit eed7bce
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
20 changes: 10 additions & 10 deletions examples/config_multilingual_nanoset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ data_stages:
dataset:
training_folder: datasets/c4-es/train
validation_folder: datasets/c4-es/validation
dataset_tokens:
- 15
lang_to_ids:
es: 128002
num_loading_workers: 1
seed: 42
name: General purpose training (Single dataset)
Expand All @@ -25,10 +25,10 @@ data_stages:
- datasets/c4-es/validation
- datasets/c4-en/validation
- datasets/c4-fr/validation
dataset_tokens:
- 15
- 16
- 17
lang_to_ids:
es: 128002
en: 128003
fr: 128004
num_loading_workers: 1
seed: 42
name: Second purpose training (> 1 dataset)
Expand All @@ -43,10 +43,10 @@ data_stages:
- datasets/c4-es/validation
- datasets/c4-en/validation
- datasets/c4-fr/validation
dataset_tokens:
- 15
- 16
- 17
lang_to_ids:
es: 128002
en: 128003
fr: 128004

num_loading_workers: 1
seed: 42
Expand Down
11 changes: 8 additions & 3 deletions src/nanotron/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def __post_init__(self):
class MultilingualNanosetDatasetsArgs:
training_folder: Union[str, dict, List[str]]
validation_folder: Union[str, List[str]]
dataset_tokens: List[int] # Set token for each language previously defined
lang_to_ids: dict # Mapping from the previously defined folders to tokens. Respect the order

def __post_init__(self):
if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder
Expand All @@ -125,8 +125,13 @@ def __post_init__(self):
self.training_folder = list(tmp_training_folder.keys())
self.dataset_weights = list(tmp_training_folder.values())

assert len(self.training_folder) == len(self.validation_folder)
assert len(self.training_folder) == len(self.dataset_tokens)
self.dataset_tokens = list(self.lang_to_ids.values())
assert len(self.training_folder) == len(
self.validation_folder
), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})"
assert len(self.training_folder) == len(
self.dataset_tokens
), f"The sizes of training_folder and lang_to_ids mismatch ({len(self.training_folder)} vs {len(self.dataset_tokens)})"


@dataclass
Expand Down

0 comments on commit eed7bce

Please sign in to comment.