
StableLM support #3586

Merged - Galunid merged 33 commits from stablelm-support into ggerganov:master on Nov 14, 2023

Conversation

@Galunid (Collaborator) commented Oct 11, 2023

  • Converting model
  • Add warning when trying to convert .safetensors model
  • Support .safetensors model conversion
  • Loading model
  • GPU support - partial (offloading whole model runs into
    GGML_ASSERT: ggml-cuda.cu:6402: ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet")
  • Text generation
  • Coherent text generation
  • Verify the gpt2 tokenizer produces the same results as GPTNeoXTokenizerFast from transformers
  • Code style fixes
  • Investigate why the run sometimes crashes with std::unordered_map
  • Push added_tokens fixes to conversion script
  • Convert .bin models too

closes #3456

@cebtenzzre (Collaborator)

Why warn for safetensors? The current version of the falcon script supports both formats; it's not hard to do.

@goerch (Collaborator) commented Oct 11, 2023

> Why warn for safetensors?

He might not have tested it. Fine for me.

@Galunid: tokenizers are bad, did you check?

@Galunid (Collaborator, author) commented Oct 11, 2023

> Why warn for safetensors? The current version of the falcon script supports both formats; it's not hard to do.

I wasn't aware we already support it and don't need to pull in any new dependencies; I had just looked at convert-gptneox-hf-to-gguf.py. The model is shipped in .safetensors format, so it would be great if we could support that directly, without first converting it to pytorch_model.bin.

@Galunid (Collaborator, author) commented Oct 11, 2023

> @Galunid: tokenizers are bad, did you check?

@goerch: No, not yet. Right now I can load the model and it starts generating gibberish; I wanted to get something that runs first. I'm going to look at the tokenizer now. Could you confirm whether the output below looks like a tokenizer issue to you?

system_info: n_threads = 6 / 12 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1,100000, presence_penalty = 0,000000, frequency_penalty = 0,000000, top_k = 40, tfs_z = 1,000000, top_p = 0,950000, typical_p = 1,000000, temp = 0,800000, mirostat = 0, mirostat_lr = 0,100000, mirostat_ent = 5,000000
generate: n_ctx = 512, n_batch = 512, n_predict = -1, n_keep = 0


User: Test\nAssistant:ittinailresholdew่edsHs sendingwedgeformationsÄ"_ordingnai ØFIighsinkingnoseighs fir Hann Hann mountain gnTraceeldgger MeasurementighturchpossatsßEvaluationoseättFALSEaughsetto Ladenichte))org)ose knitmatterschenrettoseesteiked ted lightningoseasper Hed<|endoftext|> [end of text]

llama_print_timings:        load time =    1088,18 ms
llama_print_timings:      sample time =      59,66 ms /    60 runs   (    0,99 ms per token,  1005,68 tokens per second)
llama_print_timings: prompt eval time =     335,30 ms /     8 tokens (   41,91 ms per token,    23,86 tokens per second)
llama_print_timings:        eval time =    8119,81 ms /    59 runs   (  137,62 ms per token,     7,27 tokens per second)
llama_print_timings:       total time =    8576,62 ms
Log end

@mmnga (Contributor) commented Oct 12, 2023

After a little modification, the output was correct.

make -j && ./main -m 'ggml-model-q4_0.gguf' -p "My Best Music is" -n 64
Selected 3b model!My Best Music is a lot of music is all best song Best Song Best Music Best music music music. Kind Music with the least musical and songs in me! Best S B The F,
Best Good and Music of our is buying was. video buy Us Born Birth The Best Buy Me Best For St will best and the. It.
llama_print_timings: load time = 227.01 ms

This model has llama-like parameters, but the logic is GPT-NeoX.

https://huggingface.co/stabilityai/stablelm-3b-4e1t
Library: GPT-NeoX

add attn-norm bias
add ffn-norm bias
rope gpt-neox mode

    struct ggml_tensor * inpSA = inpL;

    // norm
    {
        cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
        offload_func(cur);
        ggml_set_name(cur, "rms_norm_0");

        // cur = cur*attn_norm(broadcasted)
        cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
        offload_func(cur);
        ggml_set_name(cur, "attention_norm_0");

        // add bias
        cur = ggml_add(ctx0, cur, model.layers[il].attn_norm_b);
        offload_func(cur);
        ggml_set_name(cur, "attention_norm_0");
    }


        // mode 2 (gpt-neox)
        struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale);
        offload_func_kq(Kcur);
        ggml_set_name(Kcur, "Kcur");

        // mode 2 (gpt-neox)
        struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head,    n_tokens), KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale);
        offload_func_kq(Qcur);
        ggml_set_name(Qcur, "Qcur");

        // norm
        {
            cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
            offload_func(cur);
            ggml_set_name(cur, "rms_norm_1");

            // cur = cur*ffn_norm(broadcasted)
            cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
            offload_func(cur);
            ggml_set_name(cur, "ffn_norm");

            // add bias
            cur = ggml_add(ctx0, cur, model.layers[il].ffn_norm_b);
            offload_func(cur);
            ggml_set_name(cur, "ffn_norm");

        }

@Galunid (Collaborator, author) commented Oct 12, 2023

Thanks! I realized I forgot to add the biases, but it would have taken me quite a while to realize the rope mode was set wrong!

@niranjanakella

@Galunid Hello, just wanted to ask if there was any update on this? Were you able to convert the stablelm-3b-4e1t model to GGUF? Thank you.

@Galunid (Collaborator, author) commented Oct 17, 2023

@niranjanakella There's no update. I had a family emergency and didn't have time to touch code. I got back today and I'm planning to look at this more tomorrow. You can download the converted model here. You won't get anything useful running it yet, though, just some random nonsense.

@Green-Sky (Collaborator)

@Galunid how did you convert 4e1t? I don't see any safetensors code in the conversion script.

@Green-Sky (Collaborator) commented Oct 17, 2023

I tried hacking in safetensors loading (by copying from the falcon convert script), but it fails with a size mismatch. I don't know why; I'm just dumping my diff here for reference.

diff --git a/convert-stablelm-hf-to-gguf.py b/convert-stablelm-hf-to-gguf.py
index 4a6fc66a..e163bb87 100755
--- a/convert-stablelm-hf-to-gguf.py
+++ b/convert-stablelm-hf-to-gguf.py
@@ -4,6 +4,7 @@
 from __future__ import annotations
 
 import argparse
+import contextlib
 import json
 import os
 import struct
@@ -20,17 +21,16 @@ if 'NO_LOCAL_GGUF' not in os.environ:
 import gguf
 
 
-def count_model_parts(dir_model: Path) -> int:
+def count_model_parts(dir_model: Path, prefix: str) -> int:
     num_parts = 0
     for filename in os.listdir(dir_model):
-        if filename.startswith("pytorch_model-"):
+        if filename.startswith(prefix):
             num_parts += 1
 
     if num_parts > 0:
         print("gguf: found " + str(num_parts) + " model parts")
     return num_parts
 
-
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(description="Convert a stablelm model to a GGML compatible file")
     parser.add_argument(
@@ -80,10 +80,17 @@ with open(dir_model / "config.json", "r", encoding="utf-8") as f:
 if hparams["architectures"][0] != "StableLMEpochForCausalLM":
     print("Model architecture not supported: " + hparams["architectures"][0])
 
-    sys.exit()
+    sys.exit(1)
 
 # get number of model parts
-num_parts = count_model_parts(dir_model)
+#num_parts = count_model_parts(dir_model, "model-00")
+#if num_parts:
+num_parts = 0
+is_safetensors = True
+from safetensors import safe_open
+#else:
+    #is_safetensors = False
+    #num_parts = count_model_parts(dir_model, "pytorch_model-")
 
 ARCH=gguf.MODEL_ARCH.STABLELM
 gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
@@ -140,13 +147,20 @@ special_vocab.add_to_gguf(gguf_writer)
 # TENSORS
 
 tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
-print(tensor_map)
+#print(tensor_map)
 
 # tensor info
 print("gguf: get tensor metadata")
 
 if num_parts == 0:
-    part_names = iter(("pytorch_model.bin",))
+    if is_safetensors:
+        part_names = iter(("model.safetensors",))
+    else:
+        part_names = iter(("pytorch_model.bin",))
+elif is_safetensors:
+    part_names = (
+        f"model-{n:05}-of-{num_parts:05}.safetensors" for n in range(1, num_parts + 1)
+    )
 else:
     part_names = (
         f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
@@ -156,47 +170,55 @@ for part_name in part_names:
     if args.vocab_only:
         break
     print("gguf: loading model part '" + part_name + "'")
-    model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
+    if is_safetensors:
+        ctx = safe_open(dir_model / part_name, framework="pt", device="cpu")
+    else:
+        ctx = contextlib.nullcontext(torch.load(dir_model / part_name, map_location="cpu"))
+
+    #model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
+
+    with ctx as model_part:
+        for name in model_part.keys():
+            #data = model_part[name]
+            data = model_part.get_tensor(name) if is_safetensors else model_part[name]
 
-    for name in model_part.keys():
-        data = model_part[name]
 
-        # we don't need these
-        if name.endswith(".attention.masked_bias") or name.endswith(".attention.bias") or name.endswith(".attention.rotary_emb.inv_freq"):
-            continue
+            # we don't need these
+            if name.endswith(".attention.masked_bias") or name.endswith(".attention.bias") or name.endswith(".attention.rotary_emb.inv_freq"):
+                continue
 
-        old_dtype = data.dtype
+            old_dtype = data.dtype
 
-        # convert any unsupported data types to float32
-        if data.dtype != torch.float16 and data.dtype != torch.float32:
-            data = data.to(torch.float32)
+            # convert any unsupported data types to float32
+            if data.dtype != torch.float16 and data.dtype != torch.float32:
+                data = data.to(torch.float32)
 
-        data = data.squeeze().numpy()
+            data = data.squeeze().numpy()
 
-        # map tensor names
-        new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
-        if new_name is None:
-            print("Can not map tensor '" + name + "'")
-            sys.exit()
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
+            if new_name is None:
+                print("Can not map tensor '" + name + "'")
+                sys.exit()
 
-        n_dims = len(data.shape)
-        data_dtype = data.dtype
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
 
-        # if f32 desired, convert any float16 to float32
-        if ftype == 0 and data_dtype == np.float16:
-            data = data.astype(np.float32)
+            # if f32 desired, convert any float16 to float32
+            if ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
 
-        # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
-        if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
-            data = data.astype(np.float32)
+            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+            if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
 
-        # if f16 desired, convert any float32 2-dim weight tensors to float16
-        if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
-            data = data.astype(np.float16)
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
 
-        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+            print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
 
-        gguf_writer.add_tensor(new_name, data)
+            gguf_writer.add_tensor(new_name, data)
 
 
 print("gguf: write header")

the error:

blk.24.attn_k.weight, n_dims = 2, torch.bfloat16 --> float16
blk.24.attn_output.weight, n_dims = 2, torch.bfloat16 --> float16
blk.24.attn_q.weight, n_dims = 2, torch.bfloat16 --> float16
Traceback (most recent call last):
  File "/home/green/workspace/llama.cpp/./convert-stablelm-hf-to-gguf.py", line 221, in <module>
    gguf_writer.add_tensor(new_name, data)
  File "/home/green/workspace/llama.cpp/gguf-py/gguf/gguf.py", line 825, in add_tensor
    tensor.tofile(self.temp_file)
OSError: 6553600 requested and 5595136 written

@Galunid (Collaborator, author) commented Oct 17, 2023

I used the transformers library to load the model and save it as pytorch_model.bin.
Something like:

from transformers import AutoModelForCausalLM

token = "<your token>"
model = AutoModelForCausalLM.from_pretrained(
  "stabilityai/stablelm-3b-4e1t",
  trust_remote_code=True,
  torch_dtype="auto",
  token=token
)
model.save_pretrained("output")

@Galunid (Collaborator, author) commented Oct 18, 2023

@Green-Sky I changed the conversion script to use safetensors. It works for me; could you give it a try?
As for the error from yesterday, it looks like you're running out of disk space ;)

@Green-Sky (Collaborator)

> @Green-Sky I changed the conversion script to use safetensors. It works for me; could you give it a try? As for the error from yesterday, it looks like you're running out of disk space ;)

Thanks for the update. I still have the same issue with the safetensors file. I also compared the sha256 of model.safetensors to the one on huggingface and it's a match. And no, I have >300 GiB of free space on the disk :)

blk.24.ffn_norm.bias, n_dims = 1, torch.bfloat16 --> float32
blk.24.ffn_norm.weight, n_dims = 1, torch.bfloat16 --> float32
blk.24.attn_k.weight, n_dims = 2, torch.bfloat16 --> float16
blk.24.attn_output.weight, n_dims = 2, torch.bfloat16 --> float16
blk.24.attn_q.weight, n_dims = 2, torch.bfloat16 --> float16
Traceback (most recent call last):
  File "/home/green/workspace/llama.cpp/./convert-stablelm-hf-to-gguf.py", line 182, in <module>
    gguf_writer.add_tensor(new_name, data)
  File "/home/green/workspace/llama.cpp/gguf-py/gguf/gguf.py", line 825, in add_tensor
    tensor.tofile(self.temp_file)
OSError: 6553600 requested and 5609472 written

@Green-Sky (Collaborator) commented Oct 18, 2023

Running in f32 mode, the error is somewhat different and more explicit:

blk.15.ffn_down.weight, n_dims = 2, torch.bfloat16 --> float32
blk.15.ffn_gate.weight, n_dims = 2, torch.bfloat16 --> float32
Traceback (most recent call last):
  File "/home/green/workspace/llama.cpp/models/stablelm-3b-4e1t/../../convert-stablelm-hf-to-gguf.py", line 182, in <module>
    gguf_writer.add_tensor(new_name, data)
  File "/home/green/workspace/llama.cpp/models/stablelm-3b-4e1t/../../gguf-py/gguf/gguf.py", line 825, in add_tensor
    tensor.tofile(self.temp_file)
OSError: Not enough free space to write 70778880 bytes

How? Where is this temp file located?

$ df -h
Filesystem      Size  Used Avail Use% Mounted on
devtmpfs        1.6G     0  1.6G   0% /dev
tmpfs            16G  140K   16G   1% /dev/shm
tmpfs           7.9G  7.0M  7.9G   1% /run
tmpfs            16G  460K   16G   1% /run/wrappers
/dev/nvme1n1p2  916G  514G  356G  60% /
/dev/nvme1n1p1  511M   85M  427M  17% /boot
tmpfs           3.2G   96K  3.2G   1% /run/user/1000

@cebtenzzre (Collaborator) commented Oct 18, 2023

> How? Where is this temp file located?

It's putting the temp file on tmpfs. You need to prefix the command with TMPDIR=/var/tmp (or any other folder on disk) so you can convert models that are larger than 50% of your system RAM. See #3433
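
For illustration, a minimal sketch of why the TMPDIR variable helps, assuming the gguf writer buffers tensor data through Python's standard tempfile module (which resolves its directory from TMPDIR):

```python
# Minimal sketch: Python's tempfile module derives its default directory from
# TMPDIR, so pointing it at a real disk keeps large buffers out of RAM-backed tmpfs.
import os
import tempfile

os.environ["TMPDIR"] = "/var/tmp"  # any directory on a real disk works
tempfile.tempdir = None            # clear the cached default so TMPDIR is re-read
print(tempfile.gettempdir())       # -> /var/tmp

with tempfile.TemporaryFile() as f:  # this buffer now lives on disk, not in tmpfs
    f.write(b"\0" * (1 << 20))
```

In practice it is simpler to set the variable in the shell for the whole run, e.g. TMPDIR=/var/tmp python convert-stablelm-hf-to-gguf.py ..., exactly as suggested above.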

@Green-Sky (Collaborator)

> How? Where is this temp file located?
>
> It's putting the temp file on tmpfs. You need to prefix the command with TMPDIR=/var/tmp (or any other folder on disk) so you can convert models that are larger than 50% of your system RAM. See #3433

Thanks, that worked.

Here is the current state:

system_info: n_threads = 12 / 24 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = -1, n_keep = 0


The meaning of life is to be of Life for every
of life for me or existence in life
your life and life – life.ит life life life: life.
meaning Life life is the
life. world me being an meaning life in you life of life to your one life_ the of " life at. m of. (and. The. life to the see what This and life is' is" [. … of this a it/o her family as is it can life which is here the life.. life Mean me is Lifee death in my ors
 meaning is. earth finding person of p which you.[ m. life of. if search_? " " is by meaning is means. The is meaning ," but . said is mean
, me for meet it you after. to that the l as the women in he him
 a and w s.. you..:. Meaning You A. this people event " find an is men life. who which see having This. is you Of .. meaning place of its at. Me. being mess // The mean off someone is if or .
 “ " on mif is If" be has of the[ in and [| •[ , b me ; mina he there itself â find this / \\ ;. e will s � ] is f dr l set a . . I it . his for | you its that those is as. see if meaning get pm ” us // on ! .” the in sex but* // okot .orying  why .. what. & internet 1 is ,[ p and// c has? your please away the mâ\\\\& ;\\\\\\\\\\ym \\ her%1.|12ines%oiseg[ thep as never find/+ meet thi ] s the \ you f c she \terminate called after throwing an instance of 'std::out_of_range'
  what():  unordered_map::at
Aborted (core dumped)

Looks like the vocabulary is missing.

@Galunid (Collaborator, author) commented Oct 19, 2023

I think it's more that the model architecture is incorrectly implemented, so the model goes nuts and "hallucinates" non-existent tokens. I mostly copied the implementation from the existing llama/gptneox ones and stitched them together, but I've seen that Stability made their own modifications I haven't checked yet, so it'll probably be a pain. On a different note, I verified that the gpt2 tokenizer seems to be working as expected here (stablelm uses gptneox, but under the hood they seem identical; I'll have to verify this). Initial results using GPTNeoXTokenizerFast from transformers and llama.cpp's gpt2 tokenizer produce the same output, so we should be good, at least as a PoC.
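
For reference, a minimal sketch of that cross-check - dumping reference token ids from the HF fast tokenizer so they can be compared against whatever ids llama.cpp prints for the same prompts (the prompts here are arbitrary; the repo was gated at the time, so an access token may be needed as in the earlier snippet):

```python
# Minimal sketch: print reference token ids from the HF fast tokenizer
# (GPTNeoXTokenizerFast under the hood for stablelm-3b-4e1t) for manual
# comparison against the ids llama.cpp produces for the same prompts.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("stabilityai/stablelm-3b-4e1t")

prompts = ["Hello world", "The quick brown fox", "User: Test\nAssistant:"]
for p in prompts:
    ids = tok.encode(p, add_special_tokens=False)
    print(repr(p))
    print("  ids:   ", ids)
    print("  pieces:", tok.convert_ids_to_tokens(ids))
```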

mmnga pushed a commit to mmnga/llama.cpp that referenced this pull request Oct 21, 2023

@Galunid (Collaborator, author) commented Oct 21, 2023

I think I've found where the problem is:

> Normalization: LayerNorm (Ba et al., 2016) with learned bias terms as opposed to RMSNorm (Zhang & Sennrich, 2019).

We use RMSNorm here; I'll look into this more tomorrow.
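
For context, a NumPy sketch of the difference between the two (textbook formulas, not the ggml code): LayerNorm centres the activations and adds a learned bias, while RMSNorm only rescales by the root mean square.

```python
# NumPy sketch of the two normalizations; 2560 is the hidden size reported for the 3B model.
import numpy as np

def rms_norm(x, weight, eps=1e-5):
    # what the current graph computes: scale by 1/RMS(x), then per-channel weight
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight

def layer_norm(x, weight, bias, eps=1e-5):
    # what stablelm-3b-4e1t expects: subtract the mean too, and add a learned bias
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * weight + bias

x = np.random.randn(4, 2560).astype(np.float32)
w = np.random.randn(2560).astype(np.float32)
b = np.random.randn(2560).astype(np.float32)
print(np.abs(rms_norm(x, w) - layer_norm(x, w, b)).max())  # clearly non-zero: the ops differ
```

In ggml terms that roughly means building these blocks with ggml_norm (which also subtracts the mean) instead of ggml_rms_norm, keeping the weight multiply and bias add that are already there.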

@Galunid (Collaborator, author) commented Oct 22, 2023

The author of Hamlet is The author of Hamlet is unknown. Although Shakespeare was known to have written the play, it wasn't until after his death in 1616 that his identity as its author became generally accepted by scholars; a theory advanced in the 17th century by Francis Meres, a playwright and poet who had been born near Stratford-upon-Avon, England. This was because Hamlet is so full of allusions to Shakespeare's plays and life that it would be very unusual for any other writer to have made them. Shakespeare wrote many different types of plays over his long career -- tragedies like Macbeth and Romeo and Juliet; comedies like Much Ado About Nothing and A Midsummer Night's Dream; and historical dramas such as Henry V, Richard III and Julius Caesar. The most famous of all is Hamlet, which Shakespeare probably wrote between 1600 and 1602, though critics debate this point. In the play Hamlet, Prince of Denmark, a Danish prince who was heir to the throne but was not allowed to marry his father's widow because she had been married before. After his uncle killed Hamlet’s father, Hamlet swore revenge on Claudius, Hamlet’s new stepfather. In the end, it is Hamlet, however, who kills himself in despair over his failure to avenge his father and restore order to Denmark. Shakespeare wrote Hamlet when he was around 30 years old, but it wasn't performed until more than 200 years after its composition because of censorship rules that prohibited performance of plays with religious or political themes.<|endoftext|> [end of text]
The best music is The best music is often found in the most unlikely places. This week on The Sound of Young America, we're bringing you some classic tracks that are all about finding love and romance...in the strangest of ways! We'll start with a song from one of our favorite bands - Death Cab For Cutie - who has an album out this month (which is also being released in a limited-ed
On this day humanity received a grim reminder trigger warning (mass shooting) On this day humanity received a grim reminder that it is not, and never will be, invincible. The world suffered the loss of thousands of lives due to one man’s hatred for people who don’t look like him or think like he does. It was an act so atrocious that it would make even the strongest person question their faith in humanity—and that is exactly what happened. After taking down a building, this sick individual set fire to himself with his own lighter and ended up burning alive on live television during the attack. What makes this all that more disturbing? The fact that it was supposed to be a terroristic attack, but due to the suspect’s poor aim, no one else died in the incident. This terrorist attack is now known as the worst act of terrorism ever committed by an individual on U.S soil and has been ranked the deadliest for a single attacker in world history—ever since it occurred. The man who started this tragedy was named Omar Mateen. He was born in New York to Afghan parents, but moved to Florida after his birth. His father is reportedly a former member of Afghanistan’s Communist Party and was considered an enemy to the U.S government when he came to America. This terrorist attack occurred at Pulse Nightclub, which is located in Orlando, Florida. The nightclub is very popular with people who identify as LGBTQ+—a group that is often discriminated against by religious extremists like Omar Mateen. The nightclub was hosting a Latin Night on the night of June 12 when this tragedy took place. It was open to both men and women, but all patrons were required to purchase admission tickets in order to enter. Mateen entered Pulse with an AR-15 semi-automatic rifle which he had purchased legally at a gun store less than a week prior. He also brought two handguns to the club during this attack—all of these weapons came from his father, who was unaware that his son would use them for such a horrific purpose. After entering Pulse, Omar Mateen began shooting people inside the nightclub and eventually set fire to the main entrance. The police arrived at the scene within minutes after the first shot rang out and confronted Mateen outside of the building—at which point he pulled out his own weapon and opened fire on them too! The man killed in this terrorist attack was named Eddie Justice, who worked as a security guard at Pulse nightclub during its opening hours. He was able to save lives by wrestling away Omar Mateen’s gun from him so that other people could escape without being harmed. During the attack, many patrons of the club were forced outside into a nearby parking lot where they hid under cars or stood on top of them in order not to be shot at from above—all while watching helplessly as their friends and family members inside continued to get killed one by one! The police used explosives to enter Pulse nightclub after Mateen’s gunfire had stopped. It was then that they found the bodies of 49 people who were murdered during this attack, which makes it second only in number to 9/11 when it comes down to deadliest terrorist attacks ever committed on U.S soil! After he died from his injuries, Omar Mateen left a note behind which contained quotes from Islamic holy texts as well as some personal messages directed towards America’s LGBTQ community… but most importantly—he mentioned that Allah had instructed him to commit this act of terror against them all in order for God’s will be done! 
Omar Mateen was only 29 years old when he committed suicide after being fatally shot by police officers during the Orlando shooting. He is reported to have been married with two children, and his wife still lives in Florida where they were previously living together before this tragedy occurred on June 12th 2016.<|endoftext|> [end of text]

I'd say not bad. It's still crashing sometimes, and I'm not sure why.

@ggerganov (Owner) left a review comment:

I am still hitting a CUDA assert with -ngl 35.

I think you need to ggml_cont either qrot or krot (or both), since CUDA does not yet support non-contiguous rope. Not ideal, but maybe we can fix this later.

Let's merge and support this from master.
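
To illustrate what non-contiguous means here, a NumPy analogy (not the ggml code): ggml_view_3d behaves like a strided slice over the same memory, and ggml_cont like an explicit contiguous copy.

```python
# NumPy analogy: a slice of the last axis is a strided view (non-contiguous),
# and a kernel that assumes dense memory needs an explicit contiguous copy first.
import numpy as np

q = np.random.randn(8, 32, 80).astype(np.float32)  # (n_tokens, n_head, head_dim); illustrative shape

qrot_view = q[..., :20]                      # like ggml_view_3d: shares q's storage
print(qrot_view.flags["C_CONTIGUOUS"])       # False - strides still skip the unroped tail

qrot_cont = np.ascontiguousarray(qrot_view)  # like ggml_cont: copies into dense memory
print(qrot_cont.flags["C_CONTIGUOUS"])       # True
```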

(Outdated review thread on convert-hf-to-gguf.py - resolved)

@daaain commented Nov 11, 2023

Using the aforementioned preconverted model files and this branch, we can now use an even smaller vision model :)

Any idea why it doesn't stop output? Is it misconfigured control tokens in the (converted) model metadata?

@Galunid (Collaborator, author) commented Nov 12, 2023

It looks like there are some missing ggml CUDA functions (I think it's CONCAT); offloading <35 layers works fine. It fails on
GGML_ASSERT: ggml.c:14322: tensor->src[0] == NULL || tensor->src[0]->backend == GGML_BACKEND_CPU. I guess we should just leave it as is and hope it gets implemented in the future.

I'll merge master tomorrow and we can merge this then.

@Galunid (Collaborator, author) commented Nov 13, 2023

@Green-Sky Would you mind giving it one final look? If all is good feel free to merge (if CI runs green) :)

@Green-Sky (Collaborator) left a review comment:

Reconverted the model and re-ran some perplexities - the values differ slightly, but within variance. Looks pretty good.

The CI oracle gives the green light :)

@Galunid Galunid merged commit 36eed0c into ggerganov:master Nov 14, 2023
32 checks passed
@Galunid Galunid deleted the stablelm-support branch November 14, 2023 10:17

@Green-Sky (Collaborator)

I was toying around with it a bit more and realized it's very slow now.

llama-bench on master (36eed0c)

stablelm-3b-4e1t

$ llama-bench -ngl 34 -m models/stablelm-3b-4e1t/ggml-model-f16.gguf -m models/stablelm-3b-4e1t/ggml-model-Q8_0.gguf -m models/stablelm-3b-4e1t/ggml-model-Q4_0.gguf -m models/stablelm-3b-4e1t/ggml-model-Q6_K.gguf -m models/stablelm-3b-4e1t/ggml-model-Q3_K_M.gguf
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 2070, compute capability 7.5
| model | size | params | backend | ngl | test | master t/s | be2ac38 t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| stablelm 3B mostly F16 | 5.21 GiB | 2.80 B | CUDA | 34 | pp 512 | 341.39 ± 7.78 | 457.85 ± 3.54 |
| stablelm 3B mostly F16 | 5.21 GiB | 2.80 B | CUDA | 34 | tg 128 | 41.62 ± 1.02 | 42.91 ± 0.35 |
| stablelm 3B mostly Q8_0 | 2.77 GiB | 2.80 B | CUDA | 34 | pp 512 | 204.19 ± 3.34 | 289.87 ± 14.80 |
| stablelm 3B mostly Q8_0 | 2.77 GiB | 2.80 B | CUDA | 34 | tg 128 | 60.15 ± 1.14 | 62.53 ± 0.18 |
| stablelm 3B mostly Q4_0 | 1.50 GiB | 2.80 B | CUDA | 34 | pp 512 | 322.43 ± 9.71 | 303.46 ± 7.25 |
| stablelm 3B mostly Q4_0 | 1.50 GiB | 2.80 B | CUDA | 34 | tg 128 | 84.48 ± 0.91 | 85.05 ± 0.92 |
| stablelm 3B mostly Q6_K | 2.14 GiB | 2.80 B | CUDA | 34 | pp 512 | 332.80 ± 6.27 | 382.22 ± 8.95 |
| stablelm 3B mostly Q6_K | 2.14 GiB | 2.80 B | CUDA | 34 | tg 128 | 63.43 ± 0.30 | 64.63 ± 0.21 |
| stablelm 3B mostly Q3_K_M | 1.29 GiB | 2.80 B | CUDA | 34 | pp 512 | 101.37 ± 1.20 | 107.47 ± 1.52 |
| stablelm 3B mostly Q3_K_M | 1.29 GiB | 2.80 B | CUDA | 34 | tg 128 | 70.83 ± 0.62 | 71.78 ± 0.62 |

TinyLlama-1.1B

$ llama-bench -ngl 34 -m models/TinyLlama-1.1B-intermediate-step-240k-503b/ggml-model-F16.gguf -m models/TinyLlama-1.1B-intermediate-step-240k-503b/ggml-model-Q8_0.gguf -m models/TinyLlama-1.1B-intermediate-step-240k-503b/ggml-model-Q4_0.gguf -m models/TinyLlama-1.1B-intermediate-step-240k-503b/ggml-model-Q6_K.gguf -m models/TinyLlama-1.1B-intermediate-step-240k-503b/ggml-model-Q3_K_M.gguf
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 2070, compute capability 7.5
| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| llama ?B mostly F16 | 1.98 GiB | 1.10 B | CUDA | 34 | pp 512 | 6114.76 ± 385.59 |
| llama ?B mostly F16 | 1.98 GiB | 1.10 B | CUDA | 34 | tg 128 | 124.62 ± 0.73 |
| llama ?B mostly Q8_0 | 1.09 GiB | 1.10 B | CUDA | 34 | pp 512 | 4907.85 ± 115.38 |
| llama ?B mostly Q8_0 | 1.09 GiB | 1.10 B | CUDA | 34 | tg 128 | 181.02 ± 1.81 |
| llama ?B mostly Q4_0 | 606.53 MiB | 1.10 B | CUDA | 34 | pp 512 | 5022.45 ± 81.78 |
| llama ?B mostly Q4_0 | 606.53 MiB | 1.10 B | CUDA | 34 | tg 128 | 257.43 ± 1.02 |
| llama ?B mostly Q6_K | 860.86 MiB | 1.10 B | CUDA | 34 | pp 512 | 5523.69 ± 137.63 |
| llama ?B mostly Q6_K | 860.86 MiB | 1.10 B | CUDA | 34 | tg 128 | 182.86 ± 0.76 |
| llama ?B mostly Q3_K - Medium | 523.67 MiB | 1.10 B | CUDA | 34 | pp 512 | 5471.36 ± 90.28 |
| llama ?B mostly Q3_K - Medium | 523.67 MiB | 1.10 B | CUDA | 34 | tg 128 | 199.72 ± 1.25 |

I know comparing 1.1B to 2.8B is not very fair, but the token generation speed seems to scale properly. Or am I just looking at the difference in architecture (mostly GQA)?

Next problem:

I re-ran the benchmark on master... and now the prompt processing speeds are vastly inferior, so I can't trust my own numbers...

$ llama-bench -ngl 34 -m models/stablelm-3b-4e1t/ggml-model-f16.gguf -m models/stablelm-3b-4e1t/ggml-model-Q8_0.gguf -m models/stablelm-3b-4e1t/ggml-model-Q4_0.gguf -m models/stablelm-3b-4e1t/ggml-model-Q6_K.gguf -m models/stablelm-3b-4e1t/ggml-model-Q3_K_M.gguf
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 2070, compute capability 7.5
| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| stablelm 3B mostly F16 | 5.21 GiB | 2.80 B | CUDA | 34 | pp 512 | 100.38 ± 0.41 |
| stablelm 3B mostly F16 | 5.21 GiB | 2.80 B | CUDA | 34 | tg 128 | 43.44 ± 0.14 |
| stablelm 3B mostly Q8_0 | 2.77 GiB | 2.80 B | CUDA | 34 | pp 512 | 145.05 ± 2.77 |
| stablelm 3B mostly Q8_0 | 2.77 GiB | 2.80 B | CUDA | 34 | tg 128 | 61.39 ± 1.48 |
| stablelm 3B mostly Q4_0 | 1.50 GiB | 2.80 B | CUDA | 34 | pp 512 | 147.17 ± 2.62 |
| stablelm 3B mostly Q4_0 | 1.50 GiB | 2.80 B | CUDA | 34 | tg 128 | 85.65 ± 0.58 |
| stablelm 3B mostly Q6_K | 2.14 GiB | 2.80 B | CUDA | 34 | pp 512 | 241.34 ± 2.15 |
| stablelm 3B mostly Q6_K | 2.14 GiB | 2.80 B | CUDA | 34 | tg 128 | 64.47 ± 0.39 |
| stablelm 3B mostly Q3_K - Medium | 1.29 GiB | 2.80 B | CUDA | 34 | pp 512 | 153.08 ± 1.14 |
| stablelm 3B mostly Q3_K - Medium | 1.29 GiB | 2.80 B | CUDA | 34 | tg 128 | 71.97 ± 0.41 |

... now I have no idea what is going on q.q

@Galunid (Collaborator, author) commented Nov 14, 2023

Could you try 6be3356? It should be faster. Right now there are more matrix operations we need to compute (for rope). Previously we just computed rope on the whole tensor; now we need to split the tensor into two halves (the one that gets roped and the one that doesn't), calculate rope for the first one, and concatenate both at the end (which also requires permuting both tensors). We also use ggml_cont, which adds a bit of overhead.
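
For anyone following along, a rough NumPy sketch of the partial rope being described (illustrative only - the shapes, the n_rot value, and the function name are assumptions, not the ggml implementation): only the first n_rot dimensions of each head are rotated, GPT-NeoX style, and the untouched tail has to be concatenated back on.

```python
# Illustrative NumPy sketch of partial rotary embedding (GPT-NeoX style pairing):
# rope the first n_rot dims of each head, pass the rest through, then concat.
import numpy as np

def partial_rope(x, pos, n_rot, base=10000.0):
    # x: (n_tokens, n_head, head_dim)
    rot, rest = x[..., :n_rot], x[..., n_rot:]
    half = n_rot // 2
    inv_freq = base ** (-np.arange(half) / half)             # theta_i = base^(-2i/n_rot)
    theta = pos[:, None] * inv_freq[None, :]                  # (n_tokens, half)
    cos, sin = np.cos(theta)[:, None, :], np.sin(theta)[:, None, :]
    x1, x2 = rot[..., :half], rot[..., half:]                 # neox mode: split halves, don't interleave
    rotated = np.concatenate([x1 * cos - x2 * sin,
                              x1 * sin + x2 * cos], axis=-1)
    return np.concatenate([rotated, rest], axis=-1)           # the concat step the current graph needs

q = np.random.randn(8, 32, 80).astype(np.float32)             # n_tokens=8, n_head=32, head_dim=80
print(partial_rope(q, pos=np.arange(8), n_rot=20).shape)      # (8, 32, 80); n_rot=20 assumes rope_pct=0.25
```

The permute/ggml_cont/ggml_concat round trip in the current graph is exactly the overhead described above.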

@Green-Sky (Collaborator)

Did some more testing on master, and it turns out that watching YouTube and moving windows around is positively correlated with both performance AND variance...

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| stablelm 3B mostly F16 (guessed) | 5.21 GiB | 2.80 B | CUDA | 34 | pp 512 | 344.08 ± 8.68 |
| stablelm 3B mostly F16 (guessed) | 5.21 GiB | 2.80 B | CUDA | 34 | tg 128 | 40.90 ± 0.20 |
| stablelm 3B mostly Q8_0 | 2.77 GiB | 2.80 B | CUDA | 34 | pp 512 | 332.28 ± 5.05 |
| stablelm 3B mostly Q8_0 | 2.77 GiB | 2.80 B | CUDA | 34 | tg 128 | 59.24 ± 0.25 |
| stablelm 3B mostly Q4_0 | 1.50 GiB | 2.80 B | CUDA | 34 | pp 512 | 103.72 ± 1.37 |
| stablelm 3B mostly Q4_0 | 1.50 GiB | 2.80 B | CUDA | 34 | tg 128 | 80.40 ± 0.91 |
| stablelm 3B mostly Q6_K | 2.14 GiB | 2.80 B | CUDA | 34 | pp 512 | 104.29 ± 0.51 |
| stablelm 3B mostly Q6_K | 2.14 GiB | 2.80 B | CUDA | 34 | tg 128 | 60.75 ± 0.56 |
| stablelm 3B mostly Q3_K - Medium | 1.29 GiB | 2.80 B | CUDA | 34 | pp 512 | 98.88 ± 1.62 |
| stablelm 3B mostly Q3_K - Medium | 1.29 GiB | 2.80 B | CUDA | 34 | tg 128 | 68.08 ± 0.68 |

sooo... my system seems to be unsuited for llama.cpp / benchmarking.

@Green-Sky (Collaborator)

Oh yea, 6be3356 is significantly faster:

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| stablelm 3B mostly F16 (guessed) | 5.21 GiB | 2.80 B | CUDA | 34 | pp 512 | 935.57 ± 14.04 |
| stablelm 3B mostly F16 (guessed) | 5.21 GiB | 2.80 B | CUDA | 34 | tg 128 | 44.55 ± 0.26 |
| stablelm 3B mostly Q8_0 | 2.77 GiB | 2.80 B | CUDA | 34 | pp 512 | 856.02 ± 7.24 |
| stablelm 3B mostly Q8_0 | 2.77 GiB | 2.80 B | CUDA | 34 | tg 128 | 67.80 ± 0.72 |
| stablelm 3B mostly Q4_0 | 1.50 GiB | 2.80 B | CUDA | 34 | pp 512 | 844.57 ± 15.42 |
| stablelm 3B mostly Q4_0 | 1.50 GiB | 2.80 B | CUDA | 34 | tg 128 | 95.70 ± 1.11 |
| stablelm 3B mostly Q6_K | 2.14 GiB | 2.80 B | CUDA | 34 | pp 512 | 879.32 ± 23.63 |
| stablelm 3B mostly Q6_K | 2.14 GiB | 2.80 B | CUDA | 34 | tg 128 | 69.02 ± 0.25 |
| stablelm 3B mostly Q3_K - Medium | 1.29 GiB | 2.80 B | CUDA | 34 | pp 512 | 893.89 ± 10.10 |
| stablelm 3B mostly Q3_K - Medium | 1.29 GiB | 2.80 B | CUDA | 34 | tg 128 | 77.37 ± 0.29 |

@maddes8cht (Contributor) commented Nov 15, 2023

Thanks for making this happen!

There is now a growing list of many of the mentioned fine-tuned models converted to GGUF format in this Hugging Face collection.

Some of the mentioned models will not convert at all, or produce incorrect files. Do you want a list of them?

@Galunid (Collaborator, author) commented Nov 15, 2023

Yes, I'll take a look if you could list them.

Comment on lines +4702 to +4786
// self-attention
{
    // compute Q and K and RoPE them
    struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
    cb(tmpq, "tmpq", il);

    struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
    cb(tmpk, "tmpk", il);

    struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
    cb(Vcur, "Vcur", il);

    // RoPE the first n_rot of q/k, pass the other half, and concat.
    struct ggml_tensor * qrot = ggml_cont(ctx0, ggml_view_3d(
        ctx0, tmpq, hparams.n_rot, n_head, n_tokens,
        ggml_element_size(tmpq) * n_embd_head,
        ggml_element_size(tmpq) * n_embd_head * n_head,
        0
    ));
    cb(qrot, "qrot", il);

    struct ggml_tensor * krot = ggml_cont(ctx0, ggml_view_3d(
        ctx0, tmpk, hparams.n_rot, n_head, n_tokens,
        ggml_element_size(tmpk) * n_embd_head,
        ggml_element_size(tmpk) * n_embd_head * n_head_kv,
        0
    ));
    cb(krot, "krot", il);

    // get the second half of tmpq, e.g tmpq[n_rot:, :, :]
    struct ggml_tensor * qpass = ggml_view_3d(
        ctx0, tmpq, (n_embd_head - hparams.n_rot), n_head, n_tokens,
        ggml_element_size(tmpq) * n_embd_head,
        ggml_element_size(tmpq) * n_embd_head * n_head,
        ggml_element_size(tmpq) * hparams.n_rot
    );
    cb(qpass, "qpass", il);

    struct ggml_tensor * kpass = ggml_view_3d(
        ctx0, tmpk, (n_embd_head - hparams.n_rot), n_head_kv, n_tokens,
        ggml_element_size(tmpk) * (n_embd_head),
        ggml_element_size(tmpk) * (n_embd_head) * n_head_kv,
        ggml_element_size(tmpk) * hparams.n_rot
    );
    cb(kpass, "kpass", il);

    struct ggml_tensor * qrotated = ggml_rope_custom(
        ctx0, qrot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
    );
    cb(qrotated, "qrotated", il);

    struct ggml_tensor * krotated = ggml_rope_custom(
        ctx0, krot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
    );
    cb(krotated, "krotated", il);

    // ggml currently only supports concatenation on dim=2
    // so we need to permute qrot, qpass, concat, then permute back.
    qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3));
    cb(qrotated, "qrotated", il);

    krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3));
    cb(krotated, "krotated", il);

    qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3));
    cb(qpass, "qpass", il);

    kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3));
    cb(kpass, "kpass", il);

    struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass);
    cb(Qcur, "Qcur", il);

    struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass);
    cb(Kcur, "Kcur", il);

    struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 2, 1, 0, 3));
    cb(Q, "Q", il);

    Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3));
    cb(Kcur, "Kcur", il);

    llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);

@ggerganov (Owner) commented:

Similar to Persimmon, we should look into simplifying this.
Either we introduce some custom operation, or extend rope to support these kinds of cases. Or, if necessary, we can prepare the model data upon conversion to be more friendly to ggml ops.

@Galunid (Collaborator, author) commented Nov 16, 2023

We could revert a371a8b. I think that's the most sensible option, since it requires just one extra rope implementation, compared to the multiple operations we need to implement now (similar to Persimmon). This should also allow simplifying Persimmon.

A Collaborator replied:

> Or if necessary, we can prepare the model data upon conversion to be more friendly to ggml ops.

The disadvantage of doing that is that it makes it harder to convert LoRAs from HF, since the tensors no longer match, and we would need to apply the same conversions to the LoRAs (if that's possible at all). Related: #3519

KerfuffleV2 pushed a commit to KerfuffleV2/llama.cpp that referenced this pull request Nov 17, 2023
* Add support for stablelm-3b-4e1t
* Supports GPU offloading of (n-1) layers
olexiyb pushed a commit to Sanctum-AI/llama.cpp that referenced this pull request Nov 23, 2023
* Add support for stablelm-3b-4e1t
* Supports GPU offloading of (n-1) layers

@husnoo commented Dec 20, 2023

> Using the aforementioned preconverted model files and this branch, we can now use an even smaller vision model :)
>
> Any idea why it doesn't stop output? Is it misconfigured control tokens in the (converted) model metadata?

Did you make any progress with this? It blabbers on.

@daaain commented Dec 22, 2023

> Any idea why it doesn't stop output? Is it misconfigured control tokens in the (converted) model metadata?
>
> Did you make any progress with this? It blabbers on.

Seems to have stopped doing that now, using nisten/obsidian-3b-multimodal-q6-gguf/obsidian-q6.gguf.

If anything, it's very terse now; I can't get it to respond with more than a sentence 🤷

[screenshot of the model's chat output]


Linked issue: [User] How to convert Stability 3B model to ggml/ggul