Add StableLM support #616
Conversation
StableLM support in Optimum is currently underway.
I know it's still a draft, but I added a commit to get it into a working state :) Example usage:

```js
import { pipeline } from '@xenova/transformers';

const generator = await pipeline('text-generation', 'Xenova/tiny-random-StableLmForCausalLM');
const output = await generator('hi');
console.log(output);
```

It's just a randomly initialized tiny model, so the output is gibberish. I will export some larger models too.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Awesome! 😎

Yes, that is how you do it - but luckily in this case, we already have a tiny stablelm model: https://huggingface.co/hf-internal-testing/tiny-random-StableLmForCausalLM (check out the org for the full list of them)

Here are some of the larger models which we can test with (see all).
Seems to work alright! 👍

```js
import { pipeline } from '@xenova/transformers';

const generator = await pipeline('text-generation', 'Xenova/stablelm-2-zephyr-1_6b');

const prompt = [{ role: 'user', content: 'Tell me a joke' }];
const inputs = generator.tokenizer.apply_chat_template(prompt, { add_generation_prompt: true, tokenize: false });

const output = await generator(inputs, { max_new_tokens: 20 });
console.log(output[0].generated_text);
// "<|user|>\nTell me a joke\n<|assistant|>\nWhy did the tomato turn red?\n\nBecause it saw the salad dressing!"
```
Really cool! Will turn this into a PR.

Thanks! I think the last thing to do is just export and test with some sequence classifier models. Is that something you'd like to work on?

Yes, I will do that. Just to check, what do you mean by export?

export / convert to onnx 👍
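For reference, the usual route here is Optimum's ONNX exporter; a minimal sketch for the causal-LM checkpoint, assuming the in-progress Optimum support mentioned above has landed (transformers.js also ships its own conversion script that wraps this and adds quantization):

```python
from optimum.onnxruntime import ORTModelForCausalLM

# Downloads the PyTorch checkpoint and converts it to ONNX on the fly
model = ORTModelForCausalLM.from_pretrained(
    "stabilityai/stablelm-2-zephyr-1_6b",
    export=True,
)
model.save_pretrained("stablelm-2-zephyr-1_6b-onnx")
```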
Looks like there aren't any stablelm text-classification models on the HF Hub (other than your test, of course), so I think it would be a good idea to move that to a separate PR and get this merged?
TODO: move stablelm text classification to separate PR
Wow, seems like you can read my mind 😆 Little update: since there is currently no stablelm model for text classification on the hub, I tried training my own.

```python
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, TrainingArguments, Trainer
import evaluate
from huggingface_hub import login

login()

dataset_id = "zeroshot/twitter-financial-news-sentiment"
model_id = "stabilityai/stablelm-2-1_6b"

dataset = load_dataset(dataset_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# StableLM has no pad token by default; reuse the EOS token for padding
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)

tokenized_dataset = dataset.map(preprocess_function, batched=True)

id2label = {0: "Bearish", 1: "Bullish", 2: "Neutral"}
label2id = {"Bearish": 0, "Bullish": 1, "Neutral": 2}

model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)
model.config.pad_token_id = model.config.eos_token_id

training_args = TrainingArguments(
    output_dir="stablelm-2-1_6b-sentiment",
    learning_rate=2e-5,
    # TODO: will it fit 16GB?
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=2,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # TODO: will it fit 16GB?
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim="adamw_bnb_8bit",
    load_best_model_at_end=True,
    push_to_hub=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()
trainer.push_to_hub()
```
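Once that run finishes, a quick sanity check could look like the following; a minimal sketch, assuming the push succeeded and using a hypothetical repo id (the actual id depends on the account `trainer.push_to_hub()` pushed to):

```python
from transformers import pipeline

# "your-username/stablelm-2-1_6b-sentiment" is a hypothetical repo id
classifier = pipeline("text-classification", model="your-username/stablelm-2-1_6b-sentiment")
print(classifier("Shares rallied after the company raised its full-year guidance."))
# e.g. [{'label': 'Bullish', 'score': ...}]
```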
will be added in separate PR Co-authored-by: Joshua Lochner <admin@xenova.com>
naming nits
Merged! Thanks for this @D4ve-R! 🤗

Example: Text generation with `Xenova/stablelm-2-zephyr-1_6b`:

```js
import { pipeline } from '@xenova/transformers';

// Create text generation pipeline
const generator = await pipeline('text-generation', 'Xenova/stablelm-2-zephyr-1_6b');

// Define the prompt and list of messages
const prompt = "Tell me a funny joke."
const messages = [
    { "role": "system", "content": "You are a helpful assistant." },
    { "role": "user", "content": prompt },
]

// Apply chat template
const inputs = generator.tokenizer.apply_chat_template(messages, {
    tokenize: false,
    add_generation_prompt: true,
});

// Generate text
const output = await generator(inputs, { max_new_tokens: 20 });
console.log(output[0].generated_text);
// "<|system|>\nYou are a helpful assistant.\n<|user|>\nTell me a funny joke.\n<|assistant|>\nHere's a joke for you:\n\nWhy don't scientists trust atoms?\n\nBecause they make up everything!"
```
Adds support for:
- StableLmForCausalLM
- StableLmForSequenceClassification

Closes #549
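For the follow-up PR: once a StableLM sequence-classification checkpoint is on the Hub, the export should mirror the causal-LM case; a minimal sketch, assuming Optimum's ONNX exporter accepts the architecture and using the same hypothetical repo id as above:

```python
from optimum.onnxruntime import ORTModelForSequenceClassification

# Hypothetical repo id - no public StableLM classifier existed at the time of this PR
model = ORTModelForSequenceClassification.from_pretrained(
    "your-username/stablelm-2-1_6b-sentiment",
    export=True,  # convert the PyTorch weights to ONNX on the fly
)
model.save_pretrained("stablelm-2-1_6b-sentiment-onnx")
```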