Skip to content

Commit

Permalink
revert removing sequence_len
Browse files (browse the repository at this point in the history)
  • Loading branch information
seungduk-yanolja committed Oct 11, 2023
1 parent d78cb50 commit 46c7027
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/axolotl/prompt_strategies/completion.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -53,8 +53,8 @@ def tokenize_prompt(self, prompt):
tokenized_full_prompt = self._tokenize(full_prompt)

for key, val in tokenized_full_prompt.items():
- for i in range(0, len(val), self.max_length):
- res[key].append(val[i : i + self.max_length])
+ for i in range(0, len(val), self.sequence_len):
+ res[key].append(val[i : i + self.sequence_len])

return dict(res)

Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/prompt_strategies/metharme.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -31,7 +31,7 @@ def _tokenize(
result = self.tokenizer(
prompt,
truncation=True,
- max_length=self.max_length,
+ max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
Expand All @@ -43,7 +43,7 @@ def _tokenize(

if num_eos_tokens > 0 and add_eos_token and len(result["input_ids"]) > 0:
for _ in range(num_eos_tokens):
- if len(result["input_ids"]) < self.max_length:
+ if len(result["input_ids"]) < self.sequence_len:
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)

Expand Down
7 changes: 5 additions & 2 deletions src/axolotl/prompt_tokenizers.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -45,6 +45,9 @@ def __init__(
self.prompter = prompter
self.tokenizer: PreTrainedTokenizer = tokenizer
self.train_on_inputs = train_on_inputs
+ # sequence_len and max_length can be different for CompletionPromptTokenizingStrategy.
+ # TODO: Document how they are different.
+ self.sequence_len = sequence_len
self.max_length = sequence_len

@abc.abstractmethod
Expand Down Expand Up @@ -310,13 +313,13 @@ def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
result = self.tokenizer(
prompt,
truncation=True,
- max_length=self.max_length,
+ max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != self.tokenizer.eos_token_id
- and len(result["input_ids"]) < self.max_length
+ and len(result["input_ids"]) < self.sequence_len
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
Expand Down

0 comments on commit 46c7027

Please sign in to comment.