Skip to content

Commit

Permalink
split completion text to sequence_len (#616)
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Sep 22, 2023
1 parent 2844eb2 commit 97d3776
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 59 deletions.
5 changes: 5 additions & 0 deletions src/axolotl/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,15 @@ def __init__( # pylint: disable=super-init-not-called
def process(self, dataset):
    """Tokenize every example in *dataset* with the configured strategy.

    Drops all original columns so only the tokenizer's output remains,
    and enables batched mapping when the strategy supports it (a batched
    strategy may emit multiple output rows per input row, e.g. when
    splitting long completions into sequence_len chunks).
    """
    features = dataset.features.keys()
    # os.cpu_count() can return None when the count is undeterminable;
    # fall back to 1 so min() never compares int with None.
    num_proc = min(64, os.cpu_count() or 1)
    map_kwargs = {}
    if self.prompt_tokenizer.supports_batched:
        map_kwargs["batched"] = True
        map_kwargs["batch_size"] = 100
    return dataset.map(
        self.prompt_tokenizer.tokenize_prompt,
        num_proc=num_proc,
        remove_columns=features,
        **map_kwargs,
    )


Expand Down
78 changes: 75 additions & 3 deletions src/axolotl/prompt_strategies/completion.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,81 @@
"""
Basic completion text
"""
from typing import Any, Dict, Optional
from collections import defaultdict
from typing import Any, Dict, Generator, Optional, Tuple

from axolotl.prompt_tokenizers import CompletionPromptTokenizingStrategy
from axolotl.prompters import CompletionPrompter
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy


class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for Completion prompts.

    Tokenizes raw text up to ``max_length`` tokens and then splits the
    token stream into chunks of at most ``sequence_len``, so a single long
    document can yield several training rows instead of being truncated.
    """

    # Dataset column holding the completion text; overridable via ``field``.
    _field: str = "text"

    def __init__(self, *args, max_length=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Raise the tokenizer truncation limit above sequence_len; the
        # chunking back down to sequence_len happens in tokenize_prompt.
        if max_length is not None:
            self.max_length = max_length

    @property
    def supports_batched(self):
        # Batched mapping is required: one input row can produce many
        # output rows once a long text is split into chunks.
        return True

    @property
    def field(self) -> str:
        """Name of the dataset column to tokenize."""
        return self._field

    @field.setter
    def field(self, new_field: str):
        self._field = new_field

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        # Completion data has no separate input/response parts, only text.
        return (
            prompt[self.field],
            "",
            "",
        )

    def tokenize_prompt(self, prompt):
        """Tokenize a batched prompt dict and chunk outputs to sequence_len.

        ``prompt`` maps feature names to equal-length lists (HF batched
        format); returns the same mapping shape where every tokenized
        field is split into sequence_len-sized slices.
        """
        # defaultdict(list) is the idiomatic form of defaultdict(lambda: []).
        res = defaultdict(list)
        feature_names = list(prompt.keys())
        for row in zip(*prompt.values()):
            prompt_row = dict(zip(feature_names, row))
            (
                instruction,
                _,
                _,
            ) = self.parse_instruction_fields(prompt_row)

            full_prompt = self._build_full_prompt(instruction, None, None)
            tokenized_full_prompt = self._tokenize(full_prompt)

            # Split every tokenized field (e.g. input_ids, attention_mask)
            # into windows of at most sequence_len tokens; the final window
            # may be shorter.
            for key, val in tokenized_full_prompt.items():
                for i in range(0, len(val), self.sequence_len):
                    res[key].append(val[i : i + self.sequence_len])

        return dict(res)

    def _build_full_prompt(
        self, instruction, input, response
    ):  # pylint: disable=redefined-builtin
        # The completion prompter yields exactly one string.
        return next(iter(self.prompter.build_prompt(instruction, input, response)))


class CompletionPrompter:
    """Pass-through prompter: a completion prompt is just the raw text."""

    def build_prompt(
        self,
        instruction: str,
        input=None,  # pylint: disable=redefined-builtin, unused-argument
        output=None,  # pylint: disable=unused-argument
    ) -> Generator[str, None, None]:
        # No template is applied; the instruction IS the whole prompt.
        yield from (instruction,)


def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
Expand All @@ -13,6 +84,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
max_length=cfg.sequence_len * 64,
)
if ds_cfg and "field" in ds_cfg:
strat.field = ds_cfg["field"]
Expand Down
49 changes: 7 additions & 42 deletions src/axolotl/prompt_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,16 @@ def __init__(
self.tokenizer: PreTrainedTokenizer = tokenizer
self.train_on_inputs = train_on_inputs
self.sequence_len = sequence_len
self.max_length = sequence_len

@abc.abstractmethod
def tokenize_prompt(self, prompt):
    """Convert one prompt example into tokenized model inputs (subclass hook)."""
    pass

@property
def supports_batched(self):
    # Default: tokenize_prompt handles a single example. Strategies whose
    # tokenize_prompt accepts HF batched dicts override this to return True.
    return False

@functools.lru_cache(maxsize=128)
def _get_user_token(self):
try:
Expand Down Expand Up @@ -77,7 +82,7 @@ def _tokenize(
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
max_length=self.max_length,
padding=False,
return_tensors=None,
)
Expand All @@ -86,7 +91,7 @@ def _tokenize(
if (
len(result["input_ids"]) > 0
and result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and len(result["input_ids"]) < self.max_length
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
Expand Down Expand Up @@ -247,46 +252,6 @@ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
)


class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for Completion prompts.
    """

    # Dataset column holding the completion text; overridable via ``field``.
    _field: str = "text"

    @property
    def field(self) -> str:
        """Name of the dataset column to tokenize."""
        return self._field

    @field.setter
    def field(self, new_field: str):
        self._field = new_field

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        # Completion data has no separate input/response parts, only text.
        return (
            prompt[self.field],
            "",
            "",
        )

    def tokenize_prompt(self, prompt):
        """Tokenize a single example dict into model inputs."""
        (
            instruction,
            _,
            _,
        ) = self.parse_instruction_fields(prompt)

        full_prompt = self._build_full_prompt(instruction, None, None)
        tokenized_full_prompt = self._tokenize(full_prompt)

        return tokenized_full_prompt

    def _build_full_prompt(
        self, instruction, input, response
    ):  # pylint: disable=redefined-builtin
        # The completion prompter yields exactly one string.
        return next(iter(self.prompter.build_prompt(instruction, input, response)))


class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for Reflection prompts.
Expand Down
14 changes: 0 additions & 14 deletions src/axolotl/prompters.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,20 +135,6 @@ def match_prompt_style(self):
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"


class CompletionPrompter:
    """Prompter that emits completion text verbatim, with no formatting."""

    def build_prompt(
        self,
        instruction: str,
        input=None,  # pylint: disable=redefined-builtin, unused-argument
        output=None,  # pylint: disable=unused-argument
    ) -> Generator[str, None, None]:
        # Yield the text as the sole piece of the prompt.
        for piece in (instruction,):
            yield piece


class GPTeacherPrompter(AlpacaPrompter):
"""
Prompter for GPTeacher
Expand Down

0 comments on commit 97d3776

Please sign in to comment.