Skip to content

Commit

Permalink
[LLM Runtime] Refine Python API (#665)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenwei-intel authored Nov 15, 2023
1 parent d8799cb commit 91511d2
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 74 deletions.
82 changes: 67 additions & 15 deletions intel_extension_for_transformers/llm/runtime/graph/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,30 +98,43 @@ outputs = model.generate(inputs, streamer=streamer, max_new_tokens=300, ctx_size

https://github.com/intel/intel-extension-for-transformers/assets/109187816/1698dcda-c9ec-4f44-b159-f4e9d67ab15b

Argument description of WeightOnlyQuantConfig:
| Argument | Type | Description |
| -------------- | ---------- | ----------------------------------------------------------------------- |
| compute_dtype | String | Data type of Gemm computation: int8/bf16/fp32 (default: int8) |
| weight_dtype | String | Data type of quantized weight: int4/int8 (default int4) |
| alg | String | Quantization algorithm: sym/asym (default sym) |
| group_size | Int | Group size: Int (default: 32) |
| scale_dtype | String | Data type of scales: fp32/bf16 (dafault fp32) |
| use_ggml | Bool | Enable ggml for quantization and inference (default: False) |
| not_quant | Bool | Determine whether or not the model will be quantized. (default: False) |
| use_cache | Bool | Use local quantized model if file exists (default: False) |

Argument description of generate function:
| Argument | Type | Description |
| -------------- | ---------- | ----------------------------------------------------------------------- |
| inputs | Lists[Int] | Input ids after tokenizer |
| streamer | Class | Streamer object that will be used to stream the generated sequences. (default: None) |
| interactive | Bool | Interactive mode, use history commands when True (default: False) |
| n_keep | Int | Number of tokens to keep from the initial prompt (default: 0, -1 = all) |
| n_discard | Int | Number of tokens will be discarded (default: -1, -1 = half of tokens will be discarded) |
| shift_roped_k | Bool | Use ring-buffer and thus do not re-computing after reaching ctx_size (default: False) |
| ignore_prompt | Bool | Generate outputs w/o prompt (default: False) |
| max_new_tokens | Int | Number of tokens to predict (default: -1, -1 = infinity) |
| batch_size | Int | Batch size for prompt processing (default: 512) |
| ctx_size | Int | Size of the prompt context (default: 512) |
| seed | Int | NG seed (default: -1, use random seed for < 0) |
| threads | Int | Number of threads to use during computation (default: 8) |
| repetition_penalty| Float | Penalize repeat sequence of tokens (default: 1.1, 1.0 = disabled) |
| num_beams | Int | Number of beams for beam_search (default: 1) |
| do_sample | Int | Whether or not to use sampling ; use greedy decoding otherwise. (default: False) |
| top_k | Int | Top-k sampling (default: 40, 0 = disabled) |
| top_p | Int | Top-p sampling (default: 0.95, 1.0 = disabled) |
| temperature | Float | Temperature (default: 0.8) |
| min_new_tokens | Int | The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt. |
| length_penalty | Float | Exponential penalty to the length that is used with beam-based generation. |
| early_stopping | Bool | Controls the stopping condition for beam-based methods, like beam-search. |
| n_keep | Int | Number of tokens to keep from the initial prompt (default: 0, -1 = all) |
| n_discard | Int | Number of tokens will be discarded (default: -1, -1 = half of tokens will be discarded) |
| shift_roped_k | Bool | Use ring-buffer and thus do not re-computing after reaching ctx_size (default: False) |
| repetition_penalty| Float | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
| num_beams | Int | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
| do_sample | Int | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
| top_k | Int | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
| top_p | Int | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
| temperature | Float | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
| min_new_tokens | Int | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
| length_penalty | Float | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
| early_stopping | Bool | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
| max_new_tokens | Int | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
| streamer | Class | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
| stopping_criteria | Class | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |

### 3. Multi-Round Chat

Expand All @@ -130,7 +143,8 @@ Chat with LLaMA2:
from transformers import AutoTokenizer, TextStreamer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

model_name = "meta-llama/Llama-2-7b-chat-hf" # or local path to model
# Please change to local path to model, llama2 does not support online conversion, currently.
model_name = "meta-llama/Llama-2-7b-chat-hf"
woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
streamer = TextStreamer(tokenizer)
Expand Down Expand Up @@ -316,3 +330,41 @@ We support tensor parallelism strategy for distributed inference/training on mul
### 4. Contribution

You can consider adding your own models via [graph developer document](./developer_document.md).

### 5. Custom Stopping Criteria

You can customize the stopping criteria according to your own needs by processing the input_ids to determine if text generation needs to be stopped.
Here is a simple example, which requires a minimum generation length of 80 tokens. Once the `min_length` is met, encountering a terminator `eos_token_id` will end the generation.

```python
import torch
from typing import List
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
def __init__(self, min_length: int, start_length: int, stop_token_id: List[int]):
self.min_length = min_length
self.start_length = start_length
self.stop_token_id = stop_token_id

def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
if input_ids.shape[-1] - self.start_length > self.min_length:
for stop_id in self.stop_token_id:
if input_ids[0][input_ids.shape[-1] - 1] == stop_id:
return True
return False

stopping_criteria = StoppingCriteriaList(
[
StopOnTokens(
min_length=80,
start_length=inputs.shape[1],
stop_token_id=[tokenizer.eos_token_id],
)
]
)

outputs = model.generate(inputs, streamer=streamer, stopping_criteria=stopping_criteria)
```
101 changes: 63 additions & 38 deletions intel_extension_for_transformers/llm/runtime/graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from transformers import AutoConfig
from transformers import AutoConfig, AutoTokenizer
from intel_extension_for_transformers.llm.runtime.graph.scripts.convert import convert_model
import torch
model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder"}
Expand Down Expand Up @@ -61,44 +61,62 @@ def __import_package(self, model_name):
raise TypeError("Unspported model type {}!".format(model_name))
self.module = cpp_model

def init(self, model_name, **kwargs):
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
model_type = model_maps.get(config.model_type, config.model_type)
if model_type == "chatglm" and "chatglm2" in config._name_or_path:
def init(self, model_name, not_quant=False, use_cache=False, **quant_kwargs):
self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model_type = model_maps.get(self.config.model_type, self.config.model_type)
if model_type == "chatglm" and "chatglm2" in self.config._name_or_path:
model_type = "chatglm2"
self.__import_package(model_type)

# 1. convert model
fp32_bin = "ne_{}_f32.bin".format(model_type)
# check cache and quantization
output_path = "runtime_outs"
if not os.path.exists(output_path):
os.makedirs(output_path)
fp32_bin = "{}/ne_{}_f32.bin".format(output_path, model_type)
quant_bin = "{}/ne_{}_q.bin".format(output_path, model_type)

if not_quant:
self.bin_file = fp32_bin
else:
self.bin_file = quant_bin
if use_cache and os.path.exists(self.bin_file):
return

convert_model(model_name, fp32_bin, "f32")
assert os.path.exists(fp32_bin), "Fail to convert pytorch model"

# 2. quant model
quant_bin = "ne_{}_q.bin".format(model_type)
self.module.Model.quant_model(model_path = fp32_bin, out_path = quant_bin, **kwargs)
if not_quant:
print("FP32 model will be used.")
return
self.module.Model.quant_model(model_path = fp32_bin, out_path = quant_bin, **quant_kwargs)
assert os.path.exists(quant_bin), "Fail to quantize model"

self.model_type = model_type
self.bin_file = quant_bin

# clean
os.remove(fp32_bin)

def init_from_bin(self, model_name, model_path, **kwargs):
def init_from_bin(self, model_name, model_path, **generate_kwargs):
self.__import_package(model_name)
self.model = self.module.Model()
self.model.init_model(model_path, **kwargs)
if "threads" not in generate_kwargs:
threads = os.getenv("OMP_NUM_THREADS")
if threads is None:
generate_kwargs["threads"] = len(os.sched_getaffinity(0))
else:
generate_kwargs["threads"] = int(threads)
self.model.init_model(model_path, **generate_kwargs)

def quant_model(self, model_name, model_path, out_path, **kwargs):
def quant_model(self, model_name, model_path, out_path, **quant_kwargs):
self.__import_package(model_name)
self.module.Model.quant_model(model_path = model_path,
out_path = out_path, **kwargs)
out_path = out_path, **quant_kwargs)


def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, **kwargs):
def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs):
max_new_tokens = generate_kwargs.get("max_new_tokens", -1)
if self.model is None:
self.init_from_bin(self.model_type, self.bin_file, batch_size=input_ids.shape[0],
**kwargs)
**generate_kwargs)
self.generate_round = 0
elif not interactive:
self.model.reinit()
Expand All @@ -109,34 +127,41 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
ret = input_ids.tolist()

beam_search = False
if ("num_beams" in kwargs and kwargs["num_beams"] > 1) and not \
kwargs.get("do_sample", False):
if ("num_beams" in generate_kwargs and generate_kwargs["num_beams"] > 1) and not \
generate_kwargs.get("do_sample", False):
beam_search = True
if not beam_search:
# TODO support multi batch
assert input_ids.shape[0] == 1, "Unsupport multi-batch input ids."

if streamer:
if beam_search:
print("ERROR, can not use streamer when use beam search for generation!")
import sys
sys.exit(1)
assert input_ids.shape[0] == 1, "Streamer only supports batch size 1."
assert beam_search == False, "ERROR, can not use streamer when use beam search for generation! \
Make sure that `num_beams` is set to 1."
if self.generate_round == 0 and not ignore_prompt:
streamer.put(input_ids)
if interactive:
self.model.reset_token_end()
while not self.is_token_end():
out = self.model.generate(input_ids = input_ids.tolist()[0])
if len(out) == 0:
break
streamer.put(torch.tensor([out]))
ret[0].extend(out)
streamer.end()
else:
response = self.model.generate_tokens(input_ids = input_ids.tolist())
assert (len(ret) == len(response))

if interactive:
self.model.reset_token_end()
out_count = 0
while True:
response = self.model.generate(input_ids = input_ids.tolist())
if len(response) == 0:
break
if streamer:
streamer.put(torch.tensor([response[0]]))
for i in range(len(response)):
ret[i].extend(response[i])

if stopping_criteria is not None:
if stopping_criteria(torch.tensor(ret), None):
break
elif ret[0][-1] == self.tokenizer.eos_token_id or \
(max_new_tokens != -1 and out_count > max_new_tokens):
break
out_count += 1
if streamer:
streamer.end()

self.generate_round += 1
return ret

Expand Down
Loading

0 comments on commit 91511d2

Please sign in to comment.