Two weeks ago, we released GGUF formats of jina-embeddings-v4 - a universal embedding model for multimodal, multilingual retrieval - along with various quantized versions. Our motivation was simple: as a 3.75B-parameter model, the vanilla transformer version of jina-embeddings-v4 doesn't scale well on our GCP G2 (L4 GPU) API instances, so we wanted to speed up inference with these smaller, faster GGUF versions. Along the way, we ran into some interesting findings while converting and running GGUF embedding models. Since most of the llama.cpp community focuses on LLMs, we thought it'd be valuable to share them from an embedding provider's perspective.
What's particularly relevant is that today's embedding models are almost identical to LLMs: for example, jina-embeddings-v4 is based on Qwen2.5-VL-3B-instruct and jina-reranker-m0 is based on Qwen2-VL-2B. The only real difference is the output: LLMs are generative, whereas embedding and reranker models are discriminative. This creates both opportunities and challenges. On one hand, we can leverage llama.cpp's efficient implementation (e.g. ubatch_size) to serve embedding and reranker models; on the other hand, llama.cpp's embedding implementations were mostly developed for older encoder-only architectures (like RoBERTa-based models) and haven't fully caught up with modern decoder-only embedding and reranker models. This article shares what we learned while adapting modern embedding models to work with the GGUF format and llama.cpp tooling, e.g. llama-embedding and llama-server.
tagBase and Quantized GGUFs
jina-embeddings-v4 is based on Qwen2.5-VL-3B-instruct with three LoRA adapters: retrieval (optimized for retrieval tasks), text-matching (optimized for sentence similarity tasks), and code (optimized for code retrieval tasks). It is also heavily trained for visual document retrieval and late-interaction-style multi-vector output. So the idea here is to leverage llama.cpp's existing graph implementation of Qwen2.5-VL-3B and use llama-embedding for inference.
However, the first thing we noticed was buggy behavior in the mmproj, i.e. vision transformer, implementation in llama.cpp: given the same image input, it yields different embeddings than the torch implementation of Qwen2.5-VL-3B. While we're fixing this issue in our fork, we decided to exclude the vision tower from the GGUF versions for now. You can find more details about this discussion here.
Multi-vector embedding output isn't supported out of the box either, but it's not as big an issue as the vision transformer. The multi-vector output comes from a trained MLP at the last transformer block, so at worst we can export this MLP separately to numpy and apply it after getting the token-level embeddings from llama.cpp - which is what we did for jina-reranker-m0-GGUF. Sure, it's not very efficient, but it works without having to modify and recompile llama.cpp.
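As a rough illustration of that workaround, here is a minimal NumPy sketch. The file names and the single-linear-layer structure are assumptions for illustration only, not the actual exported artifacts:

```python
import numpy as np

# Hypothetical exported weights of the multi-vector projection MLP,
# plus token-level embeddings obtained from llama.cpp with --pooling none.
W = np.load("multi_vector_projector_weight.npy")    # (proj_dim, hidden_dim), assumed layout
b = np.load("multi_vector_projector_bias.npy")      # (proj_dim,)
token_embeddings = np.load("token_embeddings.npy")  # (num_tokens, hidden_dim)

# Project every token embedding to get late-interaction style multi-vectors,
# then L2-normalize each vector, as is typical for late-interaction scoring.
multi_vectors = token_embeddings @ W.T + b
multi_vectors /= np.linalg.norm(multi_vectors, axis=-1, keepdims=True)
print(multi_vectors.shape)  # (num_tokens, proj_dim)
```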
So, to fully comply with llama.cpp's existing Qwen2.5-VL-3B graph implementation, we stripped out the vision transformer and the multi-vector projector at the last transformer block, and merged all LoRA adapters back into the base language model. This gave us three task-specific base GGUF models in F16, at 3.09B parameters each - down from the original v4's 3.75B parameters:
HuggingFace Repo | Task |
---|---|
jinaai/jina-embeddings-v4-text-retrieval-GGUF | Text retrieval |
jinaai/jina-embeddings-v4-text-code-GGUF | Code retrieval |
jinaai/jina-embeddings-v4-text-matching-GGUF | Sentence similarity |
Then we used calibration_data_v5_rc.txt (which can be found here and is recommended by Unsloth) to calibrate all three base GGUF models, producing three imatrix files, and then used llama-quantize with the imatrix to quantize the models from float16 as follows:
# build imatrix
llama-imatrix -m jina-embeddings-v4-text-retrieval-F16.gguf -f calibration_data_v5_rc.txt -ngl 99 --no-ppl -o imatrix-retrieval-512.dat
# quantize
./quantize.sh jina-embeddings-v4-text-retrieval-F16.gguf retrieval-i3 imatrix-retrieval-512.dat jinaai/jina-embeddings-v4-text-retrieval-GGUF
The quantize.sh script is shown below:
#!/bin/bash
# Usage: ./quantize.sh <f16-model.gguf> <output-dir> <imatrix.dat> <hf-repo>
F16_MODEL_FILE="$1"
OUTPUT_DIR="$2"
IMATRIX="$3"
HF_REPO="$4"

FILENAME="$(basename "$F16_MODEL_FILE")"
BASE_NAME="${FILENAME%-F16.gguf}"
BASE_NAME="${BASE_NAME%.gguf}"
mkdir -p "$OUTPUT_DIR"

# Array of quantization types to produce
QUANT_TYPES=("IQ1_S" "IQ1_M" "IQ2_XXS" "IQ2_M" "Q2_K" "IQ4_NL" "IQ4_XS" "IQ3_XXS" "IQ3_S" "IQ3_M" "IQ3_XS" "Q3_K_M" "Q4_K_M" "Q5_K_S" "Q5_K_M" "Q6_K" "Q8_0")

for quant_type in "${QUANT_TYPES[@]}"; do
  # The trailing 8 is the number of threads used by llama-quantize
  llama-quantize --imatrix "${IMATRIX}" "$F16_MODEL_FILE" "${OUTPUT_DIR}/${BASE_NAME}-${quant_type}.gguf" "$quant_type" 8
done
Finally, we uploaded all quantizations to HuggingFace.
Quantization | BPW | File Size (GB) |
---|---|---|
IQ1_S | 2.04 | 0.73 |
IQ1_M | 2.19 | 0.79 |
IQ2_XXS | 2.44 | 0.88 |
IQ2_M | 2.94 | 1.06 |
Q2_K | 3.29 | 1.18 |
IQ3_XXS | 3.31 | 1.19 |
IQ3_XS | 3.59 | 1.29 |
IQ3_S | 3.76 | 1.35 |
IQ3_M | 3.84 | 1.38 |
Q3_K_M | 4.11 | 1.48 |
IQ4_NL | 4.72 | 1.69 |
IQ4_XS | 4.49 | 1.61 |
Q4_K_M | 4.99 | 1.79 |
Q5_K_S | 5.61 | 2.02 |
Q5_K_M | 5.75 | 2.07 |
Q6_K | 6.56 | 2.36 |
Q8_0 | 8.50 | 3.05 |
F16 | 16.00 | 5.75 |
v3 (Transformers) | 16.00 | 1.10 |
v4 (Transformers) | 16.00 | 7.40 |

tagUsage and Caveats
We can now use llama-server and llama-embedding to serve the GGUFs for embedding. Unlike with the Transformers library, where we have the flexibility to write custom input preprocessing code, here we have to handle that part manually (unless we want to recompile llama-server and llama-embedding). Specifically, to get results that are fully consistent with AutoModel.from_pretrained("jinaai/jina-embeddings-v4"), you need to be very careful about prefixes and add them to your GGUF model inputs manually. Here's a reference table:
Task | prompt_name in Transformers implementation | Actual input to the model |
---|---|---|
retrieval | query (default) | Query: {original_text} |
retrieval | passage | Passage: {original_text} |
text-matching | query (default) | Query: {original_text} |
text-matching | passage | Query: {original_text} ⚠️ |
code | query (default) | Query: {original_text} |
code | passage | Passage: {original_text} |
Some users might find it surprising (⚠️) that prompt_name='passage' gets overridden to "Query: " when using text-matching in the original AutoModel.from_pretrained("jinaai/jina-embeddings-v4"). But this actually makes sense: text-matching is a sentence similarity task with no left/right roles, so the inputs are symmetric.
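If you want to replicate this prefixing logic on the client side, a small helper like the sketch below works. This is our own illustration, not part of llama.cpp or the official v4 code:

```python
# Minimal sketch: reproduce the prefixing of the Transformers implementation
# before sending text to a GGUF model. The mapping mirrors the table above;
# the helper name is our own.
PREFIXES = {
    ("retrieval", "query"): "Query: ",
    ("retrieval", "passage"): "Passage: ",
    ("text-matching", "query"): "Query: ",
    ("text-matching", "passage"): "Query: ",  # symmetric task: "passage" is overridden
    ("code", "query"): "Query: ",
    ("code", "passage"): "Passage: ",
}

def add_prefix(task: str, prompt_name: str, text: str) -> str:
    return PREFIXES[(task, prompt_name)] + text

print(add_prefix("text-matching", "passage", "A beautiful sunset over the beach"))
# -> Query: A beautiful sunset over the beach
```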
tagVia llama-server
After installing llama.cpp, you can run llama-server to host the embedding model as an OpenAI API-compatible HTTP server. For example, to use text-matching with F16, you can do:
llama-server -hf jinaai/jina-embeddings-v4-text-matching-GGUF:F16 --embedding --pooling mean -ub 8192
--pooling mean is required, since v4 produces mean-pooled embeddings.
Then send a request via:
curl -X POST "http://127.0.0.1:8080/v1/embeddings" \
-H "Content-Type: application/json" \
-d '{
"input": [
"Query: A beautiful sunset over the beach",
"Query: Un beau coucher de soleil sur la plage",
"Query: 海滩上美丽的日落",
"Query: 浜辺に沈む美しい夕日"
]
}'
When using the retrieval and code models, add Query: or Passage: in front of your input, like this:
curl -X POST "http://127.0.0.1:8080/v1/embeddings" \
-H "Content-Type: application/json" \
-d '{
"input": [
"Query: A beautiful sunset over the beach",
"Query: Un beau coucher de soleil sur la plage",
"Passage: 海滩上美丽的日落",
"Passage: 浜辺に沈む美しい夕日"
]
}'
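You can also call the same endpoint from Python. Here's a minimal sketch using the requests package; it mirrors the curl calls above and assumes the response follows the standard OpenAI embeddings schema:

```python
import requests

# Call the OpenAI-compatible /v1/embeddings endpoint of a local llama-server.
# Note: you must add the "Query: " / "Passage: " prefix yourself.
texts = ["A beautiful sunset over the beach", "海滩上美丽的日落"]
payload = {"input": ["Query: " + t for t in texts]}

resp = requests.post("http://127.0.0.1:8080/v1/embeddings", json=payload)
resp.raise_for_status()
embeddings = [item["embedding"] for item in resp.json()["data"]]
print(len(embeddings), len(embeddings[0]))  # number of inputs, embedding dimension
```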
tagVia llama-embedding
For a quick sanity check, you can also use the precompiled llama-embedding for one-shot embedding. We don't recommend it for bulk embedding, since it has some performance issues that we'll discuss in the next section:
llama-embedding -hf jinaai/jina-embeddings-v4-text-matching-GGUF:F16 --pooling mean -p "Query: jina is awesome" --embd-output-format json 2>/dev/null
Read the next section for more performant bulk embedding with our build of llama-embedding, which includes some fixes and improvements.
tagSummary of Caveats
Before moving to a more performant implementation, let's summarize the caveats of GGUF models:
- You must manually add Query: or Passage: in front of the text inputs.
- They can't handle image input right now, because we removed the vision transformer from the GGUF models. We had to remove it due to bugs in llama.cpp's vision transformer/mmproj implementation of Qwen2.5-VL-3B, which we're working to fix upstream.
- They can't output multi-vector embeddings, since that isn't part of llama.cpp's Qwen2.5-VL-3B graph implementation. The easiest workaround without recompiling llama.cpp is to export and run the MLP separately after getting token-level embeddings with --pooling none in llama-embedding.
- v4 is trained with Matryoshka representation learning, and converting to GGUF preserves this feature. If you get embeddings of shape NxD, you can simply use embeddings[:, :truncate_dim] to get smaller truncated embeddings (see the sketch after this list). However, not every dimension is trained. For v4, we trained truncate_dim for these specific values: [128, 256, 512, 1024, 2048]. This means the quality of embeddings[:, :131] won't be some interpolation between the quality of embeddings[:, :128] and embeddings[:, :256]; it will be significantly worse than either the 128-dim or the 256-dim embeddings, because 131 dimensions were never trained.
- Late chunking still works as a post-processing step after getting token-level embeddings via --pooling none. Just like separating the MLP from the llama.cpp graph, this isn't super efficient but doesn't require recompiling. However, there's another caveat: since v4 is a causal model, late chunking is no longer bidirectional - earlier chunk embeddings won't contain contextual information from subsequent chunks. Remember that in v3, every chunk embedding had global context information because we used bidirectional attention masks in the transformer blocks. Internally, we discussed whether causality makes late chunking obsolete: some argue that "context is also causal" - a reader processes text from left to right, so the context needed to interpret a sentence should come from the preceding text. Others say restricting late chunking to be unidirectional blocks context sharing between chunks. Either way, the effectiveness of late chunking in v4 remains questionable and needs further study.
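Here's a minimal sketch of the Matryoshka truncation mentioned above - our own illustration on top of whatever mean-pooled embeddings you get back from llama-server or llama-embedding. Re-normalizing after truncation is an assumption we make because cosine similarity is the usual downstream metric:

```python
import numpy as np

# Placeholder for an N x 2048 matrix of mean-pooled v4 embeddings
# (e.g. parsed from llama-server's JSON response).
embeddings = np.random.rand(4, 2048).astype(np.float32)

truncate_dim = 128  # pick one of the trained dims: 128, 256, 512, 1024, 2048
truncated = embeddings[:, :truncate_dim]

# Re-normalize so cosine similarity behaves as expected after truncation.
truncated /= np.linalg.norm(truncated, axis=1, keepdims=True)
print(truncated.shape)  # (4, 128)
```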
tagEfficient Embedding via llama-embedding
llama-embedding is a relatively simple C++ wrapper on top of llama.cpp for embedding text, with very clean I/O: stdin in, stdout out. We're focusing on improving it rather than llama-server right now, because llama-server comes with a ton of other problems - network queuing, load balancing, multi-tenancy, serialization - that we believe are out of scope at this point. Our question is straightforward: how much speed can we get out of an L4 24GB GPU, and what's the peak VRAM usage for embedding long documents?
But why the L4? Mainly because GCP offers pretty convenient Cloud Run functions on top of it, and it's the most widely available and economical GPU type you can get for serverless inference APIs. GCP does offer A100 and H100 on Cloud Run on request, and we do get pitched by the GCP team from time to time to use better GPUs. But our philosophy is simple: if we need an A100/H100 to serve a 3B model, that's clearly a skill issue on our part.
For some background: in llama.cpp, the logical batch size (-b) is the maximum number of tokens submitted to the model in a single evaluation call. When processing long inputs, they are split into chunks up to this size. The physical batch size (-ub) is the number of tokens actually processed simultaneously in one forward pass through the hardware, constrained by available memory; the KV-cache is updated after each physical batch completes. The context window (-c) is the hard limit on how many tokens the model can "see" at once - for v4 models this is 32,000 tokens, i.e. the model's maximum attention span. All tokens must fit within this context window to maintain coherent attention across the entire sequence. The following figure illustrates their relationships.

tagOur Fixes
In our fork above, we made several optimizations to make llama-embedding more efficient:
- Simplified batch handling: we automatically set -b equal to -c, effectively making -b obsolete. Users no longer need to specify -b, since we always use the model's full context length for logical batching.
- Flexible memory control: unlike the original implementation, where -ub was forced to equal -b (on the assumption that embedding models can't be causal), we let users set -ub independently. This gives fine-grained control over peak VRAM usage when encoding long contexts - thanks to the KV-cache, you can process a 32K context with a small 512-token physical batch and stay within VRAM limits. Note that this change is only correct for causal embedding models like jina-embeddings-v4; for encoder-only architectures like v3, it would be the wrong implementation.
- Fixed mean pooling: we corrected the mean pooling calculation for embeddings when ub < b, which was broken in the original implementation (see the sketch after this list).
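Conceptually, mean pooling across several physical batches has to accumulate token sums and divide once by the total token count; averaging per-batch means would weight a short final batch incorrectly. A minimal NumPy sketch of the idea (not the actual C++ fix):

```python
import numpy as np

def mean_pool(batches):
    """Mean-pool token embeddings that arrive split across physical batches.

    batches: list of (ub_tokens, hidden_dim) arrays from consecutive forward passes.
    """
    total, n_tokens = None, 0
    for batch in batches:
        s = batch.sum(axis=0)                      # accumulate token sums per batch
        total = s if total is None else total + s
        n_tokens += batch.shape[0]
    return total / n_tokens                        # divide once by the total token count

hidden_dim = 8
batches = [np.ones((512, hidden_dim)), np.ones((512, hidden_dim)), np.ones((200, hidden_dim))]
print(mean_pool(batches)[:3])  # correct mean over all 1224 tokens
```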
These changes make it much easier to work with long-context decoder-only embedding models while managing memory constraints effectively. Users now only need to configure two parameters:
- -c: the maximum context length (how many tokens the embedding model can process at once)
- -ub: the physical batch size (how many tokens the GPU processes in one forward pass)
So the exact code for running our fork on L4 is as follows:
# Compile
git clone https://github.com/hanxiao/llama.cpp.git
cd llama.cpp
cmake -B build -DGGML_CUDA=ON
cmake --build build --config Release -j 8
# Run
INPUT_PREFIX="Query: " # or "Passage: "
cat big_input.txt | sed "s/^/${INPUT_PREFIX}/" | \
./llama.cpp/build/bin/llama-embedding -f /dev/stdin \
-hf "jinaai/jina-embeddings-v4-text-retrieval-GGUF:FP16" \
--pooling mean \
--no-escape \
--embd-output-format array \
--ubatch-size 512 \
--ctx-size 8192 \
--flash-attn \
-ngl 99 \
> "embeddings.txt" 2> "error.log"
Each line in big_input.txt is a sentence to be embedded. --no-escape should be set so that in-sentence \n isn't interpreted as a separator. --flash-attn and -ngl 99 should be set for the best performance on the L4 GPU.
tagBenchmark
We want to answer the following questions through benchmarking:
- How good is each quantization compared to the original v4 in Float16? At what point does quality degrade so much that we'd be better off just using v3 embeddings?
- How fast does each quantization run on the L4, and what's the peak VRAM usage?
- How do the physical batch size (-ub) and context length (-c) affect speed and peak VRAM?
The datasets we used for benchmarking are:
Task | Documents | Queries | Relevant Pairs | Avg Doc Length | Max Doc Length | Avg Query Length |
---|---|---|---|---|---|---|
NanoHotpotQA | 5,090 | 50 | 100 | 57.3 | 345 | 14.9 |
NanoSciFact | 2,919 | 50 | 56 | 205.8 | 1524 | 13.5 |
NanoArguAna | 3,635 | 50 | 50 | 164.5 | 1058 | 193.0 |
NanoNFCorpus | 2,953 | 50 | 2,518 | 223.3 | 1460 | 3.3 |
NanoFiQA2018 | 4,598 | 50 | 123 | 159.1 | 1882 | 10.2 |
We used our custom build of llama-embedding for benchmarking.
tagQuality of Quantizations
The best performing quantized version is IQ3_M (3.84 BPW) - quantizations below 2 bits perform worse than v3, so there's little point in using them.

Quantization | NanoHotpotQA | NanoFiQA2018 | NanoArguAna | NanoNFCorpus | NanoSciFact |
---|---|---|---|---|---|
IQ1_S | 0.6369 | 0.3178 | 0.3798 | 0.2933 | 0.5934 |
IQ1_M | 0.6316 | 0.3313 | 0.5167 | 0.3256 | 0.6114 |
IQ2_XXS | 0.7236 | 0.4582 | 0.4584 | 0.4067 | 0.7392 |
IQ2_M | 0.7427 | 0.5869 | 0.5090 | 0.4468 | 0.7880 |
Q2_K | 0.7683 | 0.5744 | 0.5168 | 0.4183 | 0.7546 |
IQ3_XXS | 0.7780 | 0.5991 | 0.4811 | 0.4267 | 0.7610 |
IQ3_XS | 0.7727 | 0.5615 | 0.5195 | 0.4439 | 0.7726 |
IQ3_S | 0.8002 | 0.5505 | 0.4886 | 0.4381 | 0.7690 |
IQ3_M | 0.8106 | 0.5387 | 0.5091 | 0.4462 | 0.7760 |
Q3_K_M | 0.7567 | 0.5267 | 0.4486 | 0.4092 | 0.7775 |
IQ4_NL | 0.7930 | 0.5598 | 0.4911 | 0.4285 | 0.7794 |
IQ4_XS | 0.7979 | 0.5627 | 0.4947 | 0.4258 | 0.7789 |
Q4_K_M | 0.8029 | 0.5569 | 0.4883 | 0.4226 | 0.7877 |
Q5_K_S | 0.7969 | 0.5581 | 0.4721 | 0.4288 | 0.7842 |
Q5_K_M | 0.7927 | 0.5601 | 0.4745 | 0.4247 | 0.7873 |
Q6_K | 0.7951 | 0.5636 | 0.4822 | 0.4337 | 0.7846 |
Q8_0 | 0.7938 | 0.5687 | 0.4784 | 0.4335 | 0.7851 |
F16 | 0.7940 | 0.5610 | 0.4931 | 0.4343 | 0.7963 |
v3 (Transformers) | 0.7393 | 0.5144 | 0.4600 | 0.4068 | 0.7820 |
v4 (Transformers) | 0.7977 | 0.5571 | 0.4844 | 0.4351 | 0.7963 |
tagSpeed and VRAM
We now fix the benchmark dataset to NanoHotpotQA and plot all quantizations by their bits per weight versus speed (in tokens per second) and VRAM consumption. We found that the GGUF versions are slightly faster than the vanilla version at FP16 (2023 vs 1865 tokens/sec), and most quantizations cluster around 2000-2100 tokens/sec. With flash attention enabled, we get a roughly 77% speedup across all quantizations (3000+ vs 2000+ tokens/sec). However, even the fastest quantization, Q8_0 at around 3700 tokens per second, is still far behind vanilla v3 (572M parameters), which hits 16000 tokens/sec. The quantized versions save considerable VRAM, and with IQ3 they nearly approach the level of the v3 FP16 model.

Quantization | BPW | File Size (GB) | Peak VRAM (GB) | Token/s w FA | Token/s w/o FA |
---|---|---|---|---|---|
IQ1_S | 2.04 | 0.73 | 4.04 | 3625 | 2050 |
IQ1_M | 2.19 | 0.79 | 4.09 | 3349 | 1997 |
IQ2_XXS | 2.44 | 0.88 | 4.19 | 3701 | 2071 |
IQ2_M | 2.94 | 1.06 | 4.37 | 3407 | 1989 |
Q2_K | 3.29 | 1.18 | 4.49 | 3173 | 1905 |
IQ3_XXS | 3.31 | 1.19 | 4.50 | 3668 | 2067 |
IQ3_XS | 3.59 | 1.29 | 4.60 | 3604 | 2053 |
IQ3_S | 3.76 | 1.35 | 4.66 | 3599 | 2049 |
IQ3_M | 3.84 | 1.38 | 4.69 | 3603 | 2053 |
Q3_K_M | 4.11 | 1.48 | 4.78 | 3450 | 2008 |
IQ4_NL | 4.72 | 1.69 | 5.00 | 3571 | 2039 |
IQ4_XS | 4.49 | 1.61 | 4.92 | 3585 | 2046 |
Q4_K_M | 4.99 | 1.79 | 5.10 | 3558 | 2045 |
Q5_K_S | 5.61 | 2.02 | 5.32 | 3567 | 2044 |
Q5_K_M | 5.75 | 2.07 | 5.38 | 3528 | 2034 |
Q6_K | 6.56 | 2.36 | 5.66 | 3334 | 1981 |
Q8_0 | 8.50 | 3.05 | 6.36 | 3767 | 2101 |
F16 | 16.00 | 5.75 | 9.70 | 3399 | 2023 |
v3 (Transformers) | 16.00 | 1.10 | 2.82 | 16505 | |
v4 (Transformers) | 16.00 | 7.40 | 14.45 | 1865 |
System info:
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors: offloading 36 repeating layers to GPU
load_tensors: offloading output layer to GPU
load_tensors: offloaded 37/37 layers to GPU
load_tensors: CUDA0 model buffer size = 3127.61 MiB
load_tensors: CPU_Mapped model buffer size = 315.30 MiB
...................................................................................
llama_context: constructing llama_context
llama_context: n_seq_max = 1
llama_context: n_ctx = 4096
llama_context: n_ctx_per_seq = 4096
llama_context: n_batch = 4096
llama_context: n_ubatch = 4096
llama_context: causal_attn = 1
llama_context: flash_attn = 1 // 1 for w/ FA in the table; 0 for w/o FA
llama_context: kv_unified = true
llama_context: freq_base = 1000000.0
llama_context: freq_scale = 1
llama_context: n_ctx_per_seq (4096) < n_ctx_train (128000) -- the full capacity of the model will not be utilized
llama_context: CUDA_Host output buffer size = 0.59 MiB
llama_kv_cache_unified: CUDA0 KV buffer size = 144.00 MiB
llama_kv_cache_unified: size = 144.00 MiB ( 4096 cells, 36 layers, 1/1 seqs), K (f16): 72.00 MiB, V (f16): 72.00 MiB
llama_context: CUDA0 compute buffer size = 2470.16 MiB
llama_context: CUDA_Host compute buffer size = 96.17 MiB
llama_context: graph nodes = 1234
llama_context: graph splits = 2
common_init_from_params: added <|endoftext|> logit bias = -inf
common_init_from_params: added <|im_end|> logit bias = -inf
common_init_from_params: added <|fim_pad|> logit bias = -inf
common_init_from_params: added <|repo_name|> logit bias = -inf
common_init_from_params: added <|file_sep|> logit bias = -inf
common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
system_info: n_threads = 4 (n_threads_batch = 4) / 8 | CUDA : ARCHS = 890 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | AVX512 = 1 | AVX512_VNNI = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 |
main: n_tokens in batch = 0
main: number of embeddings = 5090
tagOptimal Physical Batch and Context Size
We now fix the quantization type to IQ3_S and examine how the physical batch size (-ub) and context size (-c) affect speed and VRAM. The results on the L4 GPU show that -ub=512 with -c=2048 is the optimal configuration, delivering 4,143 tokens/sec while using 2,025 MB of VRAM. The takeaway is intuitive: when you know the maximum length of a single document in your input, use a context size that's just large enough to cover it. For the physical batch size, 512 tokens seems to be the sweet spot on the L4 GPU.

Tokens per Second Performance
ubatch_size | ctx_size=64 | ctx_size=128 | ctx_size=256 | ctx_size=512 |
---|---|---|---|---|
64 | 2233 | 2093 | 2128 | 2125 |
128 | N/A | 2866 | 2821 | 2877 |
256 | N/A | N/A | 3287 | 3349 |
512 | N/A | N/A | N/A | 3469 |
ubatch_size | ctx_size=2048 | ctx_size=4096 | ctx_size=8192 | ctx_size=16384 |
---|---|---|---|---|
256 | 3971 | 3630 | 3593 | 2766 |
512 | 4143 | 3797 | 3758 | 2852 |
1024 | 4059 | 3742 | 3707 | 2822 |
2048 | 3957 | 3631 | 3603 | 2762 |
4096 | N/A | 3450 | 3410 | 2625 |
Peak VRAM Usage (MB)
ubatch_size | ctx_size=64 | ctx_size=128 | ctx_size=256 | ctx_size=512 |
---|---|---|---|---|
64 | 1691 | 1689 | 1689 | 1697 |
128 | N/A | 1729 | 1727 | 1737 |
256 | N/A | N/A | 1803 | 1811 |
512 | N/A | N/A | N/A | 1963 |
ubatch_size | ctx_size=2048 | ctx_size=4096 | ctx_size=8192 | ctx_size=16384 |
---|---|---|---|---|
256 | 1885 | 1947 | 2099 | 2409 |
512 | 2025 | 2101 | 2257 | 2577 |
1024 | 2329 | 2407 | 2571 | 2917 |
2048 | 2933 | 3025 | 3203 | 3597 |
4096 | N/A | 4285 | 4497 | 4985 |
tagConclusion
For v4 users who want to run quantized GGUFs efficiently on budget GPUs, choose IQ3_S or IQ3_M with our custom build of llama-embedding - this should give you around 4000 tokens/sec on regular datasets (where sentences are shorter than 2048 tokens). For embedding longer documents, increase the context size -c and control the physical batch size -ub to keep the VRAM footprint down. With our custom build, you can encode very long documents (>32K tokens) using only 3GB of VRAM by setting -ub to a small value like 1024 - something that wasn't possible with the original implementation or vanilla transformers.
The quest for speed never ends. There's always room for faster, leaner implementations with higher throughput, and 4000 tokens/sec probably isn't our ceiling - there's plenty more work to be done. Beyond fixing the Qwen2.5-VL-3B mmproj/vision transformer implementation in llama.cpp, we're also exploring deeper llama.graph- and KV-cache-level optimizations, improving llama-server's batching logic, and adding streaming options to embedding APIs. Our goal is to make llama.cpp natively support modern decoder-only multimodal embeddings for our current and future reranker releases.