Tech blog
August 13, 2025

Optimizing GGUFs for Decoder-Only Embedding Models

4000 tokens/sec for a 3B-parameter embedding model on L4 GPU is probably as fast as you'll get with llama.cpp. Or is it?
Han Xiao • 15 minutes read

Two weeks ago, we released GGUF formats of jina-embeddings-v4 - a universal embedding model for multimodal multilingual retrieval - with various quantized versions. Our motivation was simple: as a 3.75B parameter model, the vanilla transformer version of jina-embeddings-v4 doesn't scale well on our GCP G2 (L4 GPU) API instances, so we wanted to speed up inference using these smaller, faster GGUF versions. During our experiments, we discovered some interesting findings while converting and running GGUF embedding models. Since most of the llama.cpp community focuses on LLMs, we thought it'd be valuable to share this from an embedding provider's perspective.

What's particularly relevant is that today's embedding models are almost identical to LLMs - for example, jina-embeddings-v4 is based on Qwen2.5-VL-3B-instruct and jina-reranker-m0 is based on Qwen2-VL-2B. The only real difference is the output: LLMs are generative, while embedding and reranker models are discriminative. This creates both opportunities and challenges: on one hand, we can leverage llama.cpp's efficient implementation (e.g. ubatch_size) to serve embedding/reranker models; on the other hand, llama.cpp's embedding implementations were mostly developed for older encoder-only architectures (like RoBERTa-based models) and haven't fully caught up with modern decoder-only embedding/reranker models. This article shares what we learned while adapting modern embedding models to work with the GGUF format and llama.cpp tooling, e.g. llama-embedding and llama-server.

Base and Quantized GGUFs

jina-embeddings-v4 is based on Qwen2.5-VL-3B-instruct with three LoRA adapters: retrieval (optimized for retrieval tasks), text-matching (optimized for sentence similarity tasks), and code (optimized for code retrieval tasks). It is also heavily trained for visual document retrieval and late-interaction-style multi-vector output. So the idea here is to leverage llama.cpp's existing graph implementation of Qwen2.5-VL-3B and use llama-embedding for inference.

However, the first thing we noticed was buggy behavior in the mmproj, or vision transformer, implementation in llama.cpp: given the same image input, it yields different embeddings than the torch implementation of Qwen2.5-VL-3B. While we're fixing this issue in our fork, we decided to exclude the vision tower from the GGUF versions for now. You can find more details in the discussion linked below.

Multi-modal embeddings for jinaai/jina-embeddings-v4 · ggml-org/llama.cpp · Discussion #14851

Multi-vector embedding output isn't supported out of the box either, but it's not as big an issue as the vision transformer. Multi-vector output comes from a trained MLP at the last transformer block, so at worst we can always export this MLP separately to numpy and apply it after getting the token-level embeddings from llama.cpp - which is what we did for jina-reranker-m0-GGUF. Sure, it's not very efficient, but it works without having to modify and recompile llama.cpp.
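To make that workaround concrete, here's a minimal numpy sketch of the idea. The file names are hypothetical, and we treat the projector as a single linear layer for illustration (the real MLP may differ); the only assumptions are that the projector weights were exported once from the torch checkpoint and that llama-embedding was run with --pooling none to produce token-level embeddings:

import numpy as np

# Hypothetical file names: projector weights exported once from the torch
# checkpoint, plus token-level embeddings from llama-embedding --pooling none.
W = np.load("multi_vector_projector_weight.npy")    # assumed shape: (hidden_dim, multi_vector_dim)
b = np.load("multi_vector_projector_bias.npy")      # assumed shape: (multi_vector_dim,)
token_embs = np.load("token_level_embeddings.npy")  # shape: (num_tokens, hidden_dim)

# Apply the projector outside llama.cpp to get late-interaction style
# multi-vector embeddings, one vector per token.
multi_vectors = token_embs @ W + b

# L2-normalize each token vector, as is typical for late-interaction scoring.
multi_vectors /= np.linalg.norm(multi_vectors, axis=-1, keepdims=True)
print(multi_vectors.shape)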

To fully comply with llama.cpp's existing Qwen2.5-VL-3B graph implementation, we stripped out the vision transformer and the multi-vector projector at the last transformer block and merged all LoRA adapters back into the base language model. This gave us three task-specific v4 models at 3.09B parameters each - down from the original v4's 3.75B parameters:

| HuggingFace Repo | Task |
|---|---|
| jinaai/jina-embeddings-v4-text-retrieval-GGUF | Text retrieval |
| jinaai/jina-embeddings-v4-text-code-GGUF | Code retrieval |
| jinaai/jina-embeddings-v4-text-matching-GGUF | Sentence similarity |

Then we used calibration_data_v5_rc.txt (which can be found here and is recommended by Unsloth) to calibrate all three base GGUF models and obtain three imatrix files, then used llama-quantize with these imatrix files to quantize the models from float16 as follows:

# build imatrix
llama-imatrix -m jina-embeddings-v4-text-retrieval-F16.gguf -f calibration_data_v5_rc.txt -ngl 99 --no-ppl -o imatrix-retrieval-512.dat

# quantize
./quantize.sh jina-embeddings-v4-text-retrieval-F16.gguf retrieval-i3 imatrix-retrieval-512.dat jinaai/jina-embeddings-v4-text-retrieval-GGUF

The quantize.sh script is shown below:

#!/bin/bash

F16_MODEL_FILE="$1"
OUTPUT_DIR="$2"
IMATRIX="$3"
HF_REPO="$4"

FILENAME="$(basename "$F16_MODEL_FILE")"
BASE_NAME="${FILENAME%-F16.gguf}"
BASE_NAME="${BASE_NAME%.gguf}"

mkdir -p "$OUTPUT_DIR"

# Array of quantization types
QUANT_TYPES=("IQ1_S" "IQ1_M" "IQ2_XXS" "IQ2_M" "Q2_K" "IQ4_NL" "IQ4_XS"  "IQ3_XXS" "IQ3_S" "IQ3_M" "IQ3_XS" "Q3_K_M" "Q4_K_M" "Q5_K_S" "Q5_K_M" "Q6_K" "Q8_0")

for quant_type in "${QUANT_TYPES[@]}"; do
    llama-quantize --imatrix "${IMATRIX}" "$F16_MODEL_FILE" "${OUTPUT_DIR}/${BASE_NAME}-${quant_type}.gguf" $quant_type 8
done

Eventually, we uploaded all quantizations to HuggingFace.

| Quantization | BPW | File Size (GB) |
|---|---|---|
| IQ1_S | 2.04 | 0.73 |
| IQ1_M | 2.19 | 0.79 |
| IQ2_XXS | 2.44 | 0.88 |
| IQ2_M | 2.94 | 1.06 |
| Q2_K | 3.29 | 1.18 |
| IQ3_XXS | 3.31 | 1.19 |
| IQ3_XS | 3.59 | 1.29 |
| IQ3_S | 3.76 | 1.35 |
| IQ3_M | 3.84 | 1.38 |
| Q3_K_M | 4.11 | 1.48 |
| IQ4_NL | 4.72 | 1.69 |
| IQ4_XS | 4.49 | 1.61 |
| Q4_K_M | 4.99 | 1.79 |
| Q5_K_S | 5.61 | 2.02 |
| Q5_K_M | 5.75 | 2.07 |
| Q6_K | 6.56 | 2.36 |
| Q8_0 | 8.50 | 3.05 |
| F16 | 16.00 | 5.75 |
| v3 (Transformers) | 16.00 | 1.10 |
| v4 (Transformers) | 16.00 | 7.40 |

Usage and Caveats

We can now use llama-server and llama-embedding to serve the GGUFs for embedding. Unlike with the Transformers library, where we have the flexibility to write custom input preprocessing code, here we have to handle that part manually (unless we want to recompile llama-server and llama-embedding). Specifically, to get results that are fully consistent with AutoModel.from_pretrained("jinaai/jina-embeddings-v4")..., you need to be very careful about prefixes and manually add them to your GGUF model inputs. Here's a reference table:

| Task | prompt_name in Transformers implementation | Actual input to the model |
|---|---|---|
| retrieval | query (default) | Query: {original_text} |
| retrieval | passage | Passage: {original_text} |
| text-matching | query (default) | Query: {original_text} |
| text-matching | passage | Query: {original_text} ⚠️ |
| code | query (default) | Query: {original_text} |
| code | passage | Passage: {original_text} |

Some users might find the ⚠️ row surprising: prompt_name='passage' gets overridden to "Query: " when using text-matching in the original AutoModel.from_pretrained("jinaai/jina-embeddings-v4").... But this actually makes sense, since text-matching is a sentence similarity task with no left/right roles: the inputs are symmetric.
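For convenience, the prefix rules above can be captured in a small lookup table. The helper below is our own illustration, not part of any Jina or llama.cpp API; it simply mimics what the Transformers implementation does automatically:

# Our own helper mirroring the prefix table above (illustrative only).
PREFIX_TABLE = {
    # (task, prompt_name) -> prefix actually fed to the GGUF model
    ("retrieval", "query"): "Query: ",
    ("retrieval", "passage"): "Passage: ",
    ("text-matching", "query"): "Query: ",
    ("text-matching", "passage"): "Query: ",  # symmetric task: passage is overridden
    ("code", "query"): "Query: ",
    ("code", "passage"): "Passage: ",
}

def add_prefix(text: str, task: str = "retrieval", prompt_name: str = "query") -> str:
    """Prepend the prefix that the Transformers implementation adds automatically."""
    return PREFIX_TABLE[(task, prompt_name)] + text

print(add_prefix("海滩上美丽的日落", task="text-matching", prompt_name="passage"))
# -> Query: 海滩上美丽的日落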

Via llama-server

After installing llama.cpp, you can run llama-server to host the embedding model as an OpenAI API-compatible HTTP server. For example, to use text-matching with F16, you can do:

llama-server -hf jinaai/jina-embeddings-v4-text-matching-GGUF:F16 --embedding --pooling mean -ub 8192

--pooling mean is required, as v4 uses mean pooling to produce embeddings.

Then send a request via:

curl -X POST "http://127.0.0.1:8080/v1/embeddings" \
  -H "Content-Type: application/json" \
  -d '{
    "input": [
      "Query: A beautiful sunset over the beach",
      "Query: Un beau coucher de soleil sur la plage",
      "Query: 海滩上美丽的日落",
      "Query: 浜辺に沈む美しい夕日"
    ]
  }'

When using retrieval and code models, add Query: or Passage: in front of your input, like this:

curl -X POST "http://127.0.0.1:8080/v1/embeddings" \
  -H "Content-Type: application/json" \
  -d '{
    "input": [
      "Query: A beautiful sunset over the beach",
      "Query: Un beau coucher de soleil sur la plage",
      "Passage: 海滩上美丽的日落",
      "Passage: 浜辺に沈む美しい夕日"
    ]
  }'
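If you prefer calling the server programmatically instead of via curl, here's a minimal Python sketch against the same OpenAI-compatible endpoint; it assumes the retrieval model is being served locally on the default port as above, uses the requests library, and scores passages against the query by cosine similarity:

import numpy as np
import requests

# Assumes llama-server is running locally as shown above (retrieval model).
URL = "http://127.0.0.1:8080/v1/embeddings"

query = "Query: A beautiful sunset over the beach"
passages = ["Passage: 海滩上美丽的日落", "Passage: 浜辺に沈む美しい夕日"]

resp = requests.post(URL, json={"input": [query] + passages}).json()
embs = np.array([item["embedding"] for item in resp["data"]])

# Cosine similarity between the query and each passage.
embs /= np.linalg.norm(embs, axis=1, keepdims=True)
print(embs[1:] @ embs[0])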

Via llama-embedding

For a quick sanity check, you can also use the precompiled llama-embedding for one-shot embedding. We don't recommend using it for bulk embedding since it has some performance issues that we'll discuss in the next section:

llama-embedding -hf jinaai/jina-embeddings-v4-text-matching-GGUF:F16 --pooling mean -p "Query: jina is awesome" --embd-output-format json  2>/dev/null

Read the next section for more performant bulk embedding with our build of llama-embedding, which includes some fixes and improvements.

Summary of Caveats

Before moving to a more performant implementation, let's summarize the caveats of GGUF models:

  • You must manually add Query: or Passage: in front of the text inputs.
  • They can't handle image input right now because we removed the vision transformer from the GGUF models. We had to remove it due to bugs in llama.cpp's vision transformer/mmproj implementation of Qwen2.5-VL-3B, which we're working to fix upstream.
  • They can't output multi-vector embeddings since it's not part of llama.cpp's Qwen2.5-VL-3B graph implementation. The easiest workaround without recompiling llama.cpp is to export and run the MLP separately after getting token-level embeddings by setting --pooling none in llama-embedding.
  • v4 is trained with Matryoshka representation learning, and converting to GGUF preserves this feature. If you get embeddings with shape NxD, you can simply use embeddings[:, :truncate_dim] to get smaller truncated embeddings. However, not every dimension is trained. For v4, we trained truncate_dim for these specific values: [128, 256, 512, 1024, 2048]. This means the quality of embeddings[:, :131] won't be some interpolation between the quality of embeddings[:, :128] and embeddings[:, :256], but will be significantly worse than either the 128-dim or the 256-dim embeddings, because 131-dim is not trained (see the sketch after this list).
  • Late chunking can still work as part of post-processing after getting token-level embeddings via --pooling none. Just like what we did with separating the MLP from the llama.cpp graph, this isn't super efficient but doesn't require recompiling. However, there's another caveat: since v4 is a causal model, late chunking won't be bidirectional anymore - earlier chunk embeddings won't contain contextual information from subsequent chunks. Remember that in v3, every chunk embedding had global context information because we used bidirectional attention masks in the transformer blocks. Internally, we discussed whether causality makes late chunking obsolete: some argue that "context is also causal" - meaning a reader processes text from left to right, so the context needed to interpret a sentence should come from preceding text. Others say restricting late chunking to be unidirectional blocks context sharing between chunks. Either way, the effectiveness of late chunking in v4 remains questionable and needs further study.
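To illustrate the last two caveats, here's a minimal numpy sketch of both post-processing steps: truncating to a trained Matryoshka dimension and late chunking by mean-pooling token-level embeddings per chunk. The file name and chunk boundaries are hypothetical, and re-normalizing after truncation is our own convention rather than anything prescribed by the model:

import numpy as np

# Token-level embeddings from llama-embedding with --pooling none;
# file name is hypothetical, shape (num_tokens, 2048) for v4.
token_embs = np.load("token_level_embeddings.npy")

TRAINED_DIMS = (128, 256, 512, 1024, 2048)

def truncate(embedding: np.ndarray, dim: int) -> np.ndarray:
    """Matryoshka truncation: only the trained dimensions give good quality."""
    assert dim in TRAINED_DIMS
    v = embedding[..., :dim]
    return v / np.linalg.norm(v, axis=-1, keepdims=True)  # re-normalize (our convention)

def late_chunk(token_embeddings: np.ndarray, boundaries: list[tuple[int, int]]) -> np.ndarray:
    """Late chunking as post-processing: mean-pool token embeddings per chunk."""
    return np.stack([token_embeddings[start:end].mean(axis=0) for start, end in boundaries])

chunk_embs = late_chunk(token_embs, [(0, 120), (120, 300)])  # example boundaries
print(truncate(chunk_embs, 128).shape)  # -> (2, 128)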

Efficient Embedding via llama-embedding

llama-embedding is a relatively simple C++ wrapper on top of llama.cpp for embedding text, with very clean I/O: stdin in, stdout out. We're focusing on improving it rather than llama-server right now, because llama-server involves a ton of other problems, such as network queuing, load balancing, multi-tenancy, and serialization, that we believe are out of scope at this point. Our question is straightforward: how much speed can we get from an L4 24GB GPU, and what's the peak VRAM usage for embedding long documents?

But why L4? Mainly because GCP offers pretty convenient Cloud Run functions on top of it, and it's the most widely available and economical GPU type you can get for serverless inference APIs. GCP does offer A100 and H100 on Cloud Run upon request, and we do get pitched by the GCP team from time to time to use better GPUs. But our philosophy is simple: if we need an A100/H100 to serve a 3B model, that's clearly a skill issue on our part.

For some background: in llama.cpp, the logical batch size (-b) is the maximum number of tokens submitted to the model in a single evaluation call. When processing long inputs, they are split into chunks up to this size. The physical batch size (-ub) is the actual number of tokens processed simultaneously in one forward pass through the hardware, constrained by available memory. The KV cache is updated after each physical batch completes. The context window (-c) is the hard limit on how many tokens the model can "see" at once - for v4 models this is 32,000 tokens, representing the model's maximum attention span. All tokens must fit within this context window to maintain coherent attention across the entire sequence. The following figure illustrates their relationships.
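To make these relationships concrete, here's a toy Python sketch of how a single input gets scheduled; this is our own illustration of the parameter semantics, not code from llama.cpp:

import math

def processing_schedule(n_tokens: int, c: int, b: int, ub: int):
    """Toy model of the -c / -b / -ub relationship: the whole input must fit in
    the context window, it is split into logical batches of at most b tokens,
    and each logical batch runs as ceil(tokens / ub) physical forward passes,
    with the KV cache updated after each pass."""
    assert n_tokens <= c, "input must fit within the context window (-c)"
    schedule = []
    for start in range(0, n_tokens, b):
        chunk = min(b, n_tokens - start)
        schedule.append((chunk, math.ceil(chunk / ub)))  # (tokens in logical batch, forward passes)
    return schedule

# e.g. a 7000-token document with -c 8192, -b 4096, -ub 512:
print(processing_schedule(7000, c=8192, b=4096, ub=512))  # -> [(4096, 8), (2904, 6)]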

Our Fixes

GitHub - hanxiao/llama.cpp: LLM inference in C/C++

In our fork above, we made several optimizations to make llama-embedding more efficient:

  • Simplified batch handling: We automatically set -b equal to -c, effectively making this parameter obsolete. Users no longer need to specify -b since we always leverage the model's full context length for logical batching.
  • Flexible memory control: Unlike the original implementation where -ub was forced to equal -b (since they assumed embedding models can't be causal), we allow users to independently set -ub. This gives fine-grained control over peak VRAM usage when encoding long contexts - you can process a 32K context with a small 512-token physical batch to stay within VRAM limits thanks to the KV cache implementation. Note that this change is only correct for causal embedding models like jina-embeddings-v4 - for encoder-only architectures like v3, this would be the wrong implementation.
  • Fixed mean pooling: We corrected the mean pooling calculation for embeddings when ub < b, which was previously broken in the original implementation.
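The mean-pooling fix is easiest to see as a running sum accumulated across physical batches and divided by the total token count at the end. The numpy sketch below is our own conceptual illustration of that invariant, not the actual C++ patch:

import numpy as np

rng = np.random.default_rng(0)
token_embs = rng.standard_normal((2048, 512))  # fake final hidden states for a 2048-token input

# Mean pooling over the full sequence in one shot...
full_mean = token_embs.mean(axis=0)

# ...must match mean pooling accumulated across physical batches (ub = 512):
# keep a running sum and divide by the total token count at the end.
ub = 512
running_sum = np.zeros(token_embs.shape[1])
n_seen = 0
for start in range(0, len(token_embs), ub):
    batch = token_embs[start:start + ub]
    running_sum += batch.sum(axis=0)
    n_seen += len(batch)

assert np.allclose(full_mean, running_sum / n_seen)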

This change makes it much easier to work with long-context decoder-only embedding models while managing memory constraints effectively. Users now only need to configure two parameters:

    • -c: The maximum context length (how many tokens the embedding model can process)
    • -ub: The physical batch size (how many tokens the GPU processes at once)

So the exact code for running our fork on L4 is as follows:

# Compile
git clone https://github.com/hanxiao/llama.cpp.git
cd llama.cpp
cmake -B build -DGGML_CUDA=ON
cmake --build build --config Release -j 8

# Run
INPUT_PREFIX="Query: "  # or "Passage: "

cat big_input.txt | sed "s/^/${INPUT_PREFIX}/" | \
./llama.cpp/build/bin/llama-embedding -f /dev/stdin \
    -hf "jinaai/jina-embeddings-v4-text-retrieval-GGUF:FP16" \
    --pooling mean \
    --no-escape \
    --embd-output-format array \
    --ubatch-size 512 \
    --ctx-size 8192 \
    --flash-attn \
    -ngl 99 \
    > "embeddings.txt" 2> "error.log"

Each line in big_input.txt is a sentence to be embedded. --no-escape should be set to prevent in-sentence \n from being interpreted as separators. --flash-attn and -ngl 99 should be set for best performance on L4 GPU.

Benchmark

We want to answer the following questions through benchmarking:

  • How good is our quantization compared to the original v4 in float16? At what point does quality degrade so much that we'd be better off just using v3 embeddings?
  • How fast can each quantization run on L4, and what's the peak VRAM usage?
  • How do the physical batch size (-ub) and context length (-c) affect speed and peak VRAM?

The datasets we used for benchmarking are:

| Task | Documents | Queries | Relevant Pairs | Avg Doc Length | Max Doc Length | Avg Query Length |
|---|---|---|---|---|---|---|
| NanoHotpotQA | 5,090 | 50 | 100 | 57.3 | 345 | 14.9 |
| NanoSciFact | 2,919 | 50 | 56 | 205.8 | 1524 | 13.5 |
| NanoArguAna | 3,635 | 50 | 50 | 164.5 | 1058 | 193.0 |
| NanoNFCorpus | 2,953 | 50 | 2,518 | 223.3 | 1460 | 3.3 |
| NanoFiQA2018 | 4,598 | 50 | 123 | 159.1 | 1882 | 10.2 |

We used our custom build of llama-embedding for benchmarking.

Quality of Quantizations

The best performing quantized version is IQ3_M (3.84 BPW) - quantizations below 2 bits perform worse than v3, so there's little point in using them.

| Quantization | NanoHotpotQA | NanoFiQA2018 | NanoArguAna | NanoNFCorpus | NanoSciFact |
|---|---|---|---|---|---|
| IQ1_S | 0.6369 | 0.3178 | 0.3798 | 0.2933 | 0.5934 |
| IQ1_M | 0.6316 | 0.3313 | 0.5167 | 0.3256 | 0.6114 |
| IQ2_XXS | 0.7236 | 0.4582 | 0.4584 | 0.4067 | 0.7392 |
| IQ2_M | 0.7427 | 0.5869 | 0.5090 | 0.4468 | 0.7880 |
| Q2_K | 0.7683 | 0.5744 | 0.5168 | 0.4183 | 0.7546 |
| IQ3_XXS | 0.7780 | 0.5991 | 0.4811 | 0.4267 | 0.7610 |
| IQ3_XS | 0.7727 | 0.5615 | 0.5195 | 0.4439 | 0.7726 |
| IQ3_S | 0.8002 | 0.5505 | 0.4886 | 0.4381 | 0.7690 |
| IQ3_M | 0.8106 | 0.5387 | 0.5091 | 0.4462 | 0.7760 |
| Q3_K_M | 0.7567 | 0.5267 | 0.4486 | 0.4092 | 0.7775 |
| IQ4_NL | 0.7930 | 0.5598 | 0.4911 | 0.4285 | 0.7794 |
| IQ4_XS | 0.7979 | 0.5627 | 0.4947 | 0.4258 | 0.7789 |
| Q4_K_M | 0.8029 | 0.5569 | 0.4883 | 0.4226 | 0.7877 |
| Q5_K_S | 0.7969 | 0.5581 | 0.4721 | 0.4288 | 0.7842 |
| Q5_K_M | 0.7927 | 0.5601 | 0.4745 | 0.4247 | 0.7873 |
| Q6_K | 0.7951 | 0.5636 | 0.4822 | 0.4337 | 0.7846 |
| Q8_0 | 0.7938 | 0.5687 | 0.4784 | 0.4335 | 0.7851 |
| F16 | 0.7940 | 0.5610 | 0.4931 | 0.4343 | 0.7963 |
| v3 (Transformers) | 0.7393 | 0.5144 | 0.4600 | 0.4068 | 0.7820 |
| v4 (Transformers) | 0.7977 | 0.5571 | 0.4844 | 0.4351 | 0.7963 |

Speed and VRAM

We now fix the benchmark dataset to NanoHotpotQA and plot all quantizations by their bits per weight versus speed (measured in tokens per second) and VRAM consumption. We found that GGUF versions are slightly faster than the vanilla Transformers version at FP16 (2023 vs 1865 tokens/sec). Most quantizations cluster around 2000-2100 tokens/sec. With flash attention enabled, we get a ~77% speedup across all quantizations (3000+ vs 2000+ tokens/sec). However, the best performing quantization, Q8_0 at around 3700 tokens/sec, is still far behind vanilla v3 (572M parameters), which hits 16000 tokens/sec. Quantized versions save considerable VRAM, with the IQ3 variants coming close to the footprint of the v3 FP16 model.

| Quantization | BPW | File Size (GB) | Peak VRAM (GB) | Token/s w/ FA | Token/s w/o FA |
|---|---|---|---|---|---|
| IQ1_S | 2.04 | 0.73 | 4.04 | 3625 | 2050 |
| IQ1_M | 2.19 | 0.79 | 4.09 | 3349 | 1997 |
| IQ2_XXS | 2.44 | 0.88 | 4.19 | 3701 | 2071 |
| IQ2_M | 2.94 | 1.06 | 4.37 | 3407 | 1989 |
| Q2_K | 3.29 | 1.18 | 4.49 | 3173 | 1905 |
| IQ3_XXS | 3.31 | 1.19 | 4.50 | 3668 | 2067 |
| IQ3_XS | 3.59 | 1.29 | 4.60 | 3604 | 2053 |
| IQ3_S | 3.76 | 1.35 | 4.66 | 3599 | 2049 |
| IQ3_M | 3.84 | 1.38 | 4.69 | 3603 | 2053 |
| Q3_K_M | 4.11 | 1.48 | 4.78 | 3450 | 2008 |
| IQ4_NL | 4.72 | 1.69 | 5.00 | 3571 | 2039 |
| IQ4_XS | 4.49 | 1.61 | 4.92 | 3585 | 2046 |
| Q4_K_M | 4.99 | 1.79 | 5.10 | 3558 | 2045 |
| Q5_K_S | 5.61 | 2.02 | 5.32 | 3567 | 2044 |
| Q5_K_M | 5.75 | 2.07 | 5.38 | 3528 | 2034 |
| Q6_K | 6.56 | 2.36 | 5.66 | 3334 | 1981 |
| Q8_0 | 8.50 | 3.05 | 6.36 | 3767 | 2101 |
| F16 | 16.00 | 5.75 | 9.70 | 3399 | 2023 |
| v3 (Transformers) | 16.00 | 1.10 | 2.82 | 16505 | |
| v4 (Transformers) | 16.00 | 7.40 | 14.45 | 1865 | |

System info:

load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors: offloading 36 repeating layers to GPU
load_tensors: offloading output layer to GPU
load_tensors: offloaded 37/37 layers to GPU
load_tensors: CUDA0 model buffer size = 3127.61 MiB
load_tensors: CPU_Mapped model buffer size = 315.30 MiB
...................................................................................
llama_context: constructing llama_context
llama_context: n_seq_max = 1
llama_context: n_ctx = 4096
llama_context: n_ctx_per_seq = 4096
llama_context: n_batch = 4096
llama_context: n_ubatch = 4096
llama_context: causal_attn = 1
llama_context: flash_attn = 1 // 1 for w/ FA in the table; 0 for w/o FA
llama_context: kv_unified = true
llama_context: freq_base = 1000000.0
llama_context: freq_scale = 1
llama_context: n_ctx_per_seq (4096) < n_ctx_train (128000) -- the full capacity of the model will not be utilized
llama_context: CUDA_Host output buffer size = 0.59 MiB
llama_kv_cache_unified: CUDA0 KV buffer size = 144.00 MiB
llama_kv_cache_unified: size = 144.00 MiB ( 4096 cells, 36 layers, 1/1 seqs), K (f16): 72.00 MiB, V (f16): 72.00 MiB
llama_context: CUDA0 compute buffer size = 2470.16 MiB
llama_context: CUDA_Host compute buffer size = 96.17 MiB
llama_context: graph nodes = 1234
llama_context: graph splits = 2
common_init_from_params: added <|endoftext|> logit bias = -inf
common_init_from_params: added <|im_end|> logit bias = -inf
common_init_from_params: added <|fim_pad|> logit bias = -inf
common_init_from_params: added <|repo_name|> logit bias = -inf
common_init_from_params: added <|file_sep|> logit bias = -inf
common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)

system_info: n_threads = 4 (n_threads_batch = 4) / 8 | CUDA : ARCHS = 890 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | AVX512 = 1 | AVX512_VNNI = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 |
main: n_tokens in batch = 0
main: number of embeddings = 5090

Optimal Physical Batch and Context Size

We now fix the quantization type to IQ3_S and examine how physical batch size (-ub) and context size (-c) affect speed and VRAM. The results on L4 GPU show that -ub=512 with -c=2048 provides the optimal configuration, delivering 4,143 tokens/sec while using 2,025MB VRAM. The takeaway is intuitive: when you know the maximum length of a single document in your input, use a smaller context size that's just enough to cover it. For physical batch size, 512 tokens seems to be the sweet spot on L4 GPU.

Tokens per Second Performance

| ubatch_size | ctx_size=64 | ctx_size=128 | ctx_size=256 | ctx_size=512 |
|---|---|---|---|---|
| 64 | 2233 | 2093 | 2128 | 2125 |
| 128 | N/A | 2866 | 2821 | 2877 |
| 256 | N/A | N/A | 3287 | 3349 |
| 512 | N/A | N/A | N/A | 3469 |

| ubatch_size | ctx_size=2048 | ctx_size=4096 | ctx_size=8192 | ctx_size=16384 |
|---|---|---|---|---|
| 256 | 3971 | 3630 | 3593 | 2766 |
| 512 | 4143 | 3797 | 3758 | 2852 |
| 1024 | 4059 | 3742 | 3707 | 2822 |
| 2048 | 3957 | 3631 | 3603 | 2762 |
| 4096 | N/A | 3450 | 3410 | 2625 |

Peak VRAM Usage (MB)

| ubatch_size | ctx_size=64 | ctx_size=128 | ctx_size=256 | ctx_size=512 |
|---|---|---|---|---|
| 64 | 1691 | 1689 | 1689 | 1697 |
| 128 | N/A | 1729 | 1727 | 1737 |
| 256 | N/A | N/A | 1803 | 1811 |
| 512 | N/A | N/A | N/A | 1963 |

| ubatch_size | ctx_size=2048 | ctx_size=4096 | ctx_size=8192 | ctx_size=16384 |
|---|---|---|---|---|
| 256 | 1885 | 1947 | 2099 | 2409 |
| 512 | 2025 | 2101 | 2257 | 2577 |
| 1024 | 2329 | 2407 | 2571 | 2917 |
| 2048 | 2933 | 3025 | 3203 | 3597 |
| 4096 | N/A | 4285 | 4497 | 4985 |

Conclusion

For v4 users who want to run quantized GGUF efficiently on budget GPUs, choose IQ3_S or IQ3_M with our custom build of llama-embedding - this should give you 4000 tokens/sec on regular datasets (where sentence length is <2048 tokens). For embedding longer documents, increase the context size -c and control the physical batch size -ub to reduce VRAM footprint. With our custom build, you can encode super long documents (>32K tokens) using only 3GB VRAM by setting -ub to a small number like 1024 - something that wasn't possible with the original implementation or vanilla transformers.

The quest for speed optimization never ends. There's always room for faster, leaner implementations with higher throughput. 4000 tokens/sec probably isn't our ceiling - there's plenty more work to be done. Beyond fixing the Qwen2.5-VL-3B mmproj/vision transformer implementation in llama.cpp, we're also exploring deeper llama.graph and KV-cache level optimizations, improving llama-server's batching logic, and adding streaming options to embedding APIs. Our goal is to make llama.cpp natively support modern decoder-only multimodal embeddings for our current and future embedding and reranker releases.
