diff --git a/minillm/data_utils/lm_datasets.py b/minillm/data_utils/lm_datasets.py index 65c14b8b..3a810128 100644 --- a/minillm/data_utils/lm_datasets.py +++ b/minillm/data_utils/lm_datasets.py @@ -57,6 +57,10 @@ def _process_lm(self, i, samp, model_data, no_model_data, gen_data): prompt = None if 65535 in input_ids: source_len = np.where(input_ids==65535)[0][0] + prompt = input_ids[:source_len] #for uint16 (others) + input_ids = np.concatenate([input_ids[:source_len], input_ids[source_len+1:]], axis=0) + elif 4294967295 in input_ids: #for uint32 (qwen, gemma, and etc) + source_len = np.where(input_ids==4294967295)[0][0] prompt = input_ids[:source_len] input_ids = np.concatenate([input_ids[:source_len], input_ids[source_len+1:]], axis=0) input_ids = input_ids[:self.max_length]