We introduced several fixes in our fork of llama.cpp so that it can generate multimodal embeddings with jina-embeddings-v4.
jina-embeddings-v4 introduces state-of-the-art multimodal embeddings that can process text, images, and complex visual documents for vector search. A few weeks ago, we released v4's GGUFs and dynamic quantizations for text-only tasks, which offer a smaller VRAM footprint and improved performance. However, multimodal embedding support on GGUF was still missing. To complete the picture, we've now figured out how to generate multimodal embeddings with llama.cpp and GGUF. Check out this README file for the full walkthrough.
To be fair, llama.cpp upstream does have some support for multimodal input, but since most of the llama.cpp community focuses on LLMs and text generation, support for multimodal embedding output is completely missing. In this article, we'll explain how we implemented multimodal embeddings in llama.cpp and examine how it performs (along with two quantized versions) compared to the PyTorch version of jina-embeddings-v4. Throughout this article, we'll refer to the PyTorch version as our "reference model."
Understanding Image Input in llama.cpp
Let’s first recap how multimodal embeddings are generated with our reference model. First, you pair each image input with a special prompt:
<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n
Then the model preprocesses the image, encodes it (via its ViT), and then processes the entire interleaved sequence in a single forward pass.
When it comes to llama.cpp, however, things are trickier. While it supports image inputs for chat completion, its embeddings endpoint doesn’t support multimodal inputs, namely inputs (like the one above) that combine text and an image. This is exactly why we forked llama.cpp: we changed the embedding handler to accept base64-encoded images, letting us process multimodal content in a similar way to the chat completion handler.
So now, to work with multimodal inputs in llama.cpp, we can start with a prompt similar to the one used by the reference model:
<|im_start|>user\n<__image__>Describe the image.<|im_end|>\n
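For illustration only, the snippet below pairs this prompt with a base64-encoded image in a JSON request body. The helper build_embedding_request and the field names content and image_data are hypothetical placeholders, not the fork's actual request schema; the fork's README documents the real one.

```python
import base64
import json

IMAGE_MARKER = "<__image__>"


def build_embedding_request(image_bytes: bytes, instruction: str = "Describe the image.") -> str:
    """Return a JSON body pairing the image prompt with a base64-encoded image.

    NOTE: "content" and "image_data" are hypothetical field names used for
    illustration; consult the fork's README for the real request schema.
    """
    prompt = f"<|im_start|>user\n{IMAGE_MARKER}{instruction}<|im_end|>\n"
    payload = {
        "content": prompt,
        "image_data": [base64.b64encode(image_bytes).decode("ascii")],
    }
    return json.dumps(payload)


body = build_embedding_request(b"\x89PNG\r\n\x1a\n...")  # placeholder bytes, not a real image
```

Base64 keeps the binary image safe inside a JSON string; the server side decodes it back to raw bytes before preprocessing.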
The process works as follows:
- llama.cpp surrounds the <__image__> token with the vision markers <|vision_start|> and <|vision_end|>, giving us something like this:
  <|im_start|>user\n<|vision_start|><__image__><|vision_end|>Describe the image.<|im_end|>\n
- The tokenizer replaces the special <__image__> token with -1 when tokenizing the prompt. This internally signals that the sequence contains an image, which must be encoded before it can be processed.
- The text tokens before the <__image__> marker (namely, <|im_start|>user\n<|vision_start|>) are decoded via the LLM and injected into the KVCache.
- The image is encoded via the ViT component, which outputs a series of image tokens that are then decoded via the LLM. Although these tokens are processed separately from the text tokens in the previous step, the attention layers can still attend to those text tokens via the KVCache. However, at this point, the attention layers can’t attend to any later text tokens (<|vision_end|>Describe the image.<|im_end|>\n).
- The LLM decodes any remaining text tokens (<|vision_end|>Describe the image.<|im_end|>\n). Now the attention layers can attend to all the earlier tokens (both text and image) via the KVCache.
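The steps above can be sketched in a few lines of Python. This is a toy model, not llama.cpp code: toy_tokenize is a hypothetical stand-in that splits on special tokens and words instead of using the real vocabulary, and the -1 placeholder mirrors the image-marker replacement described in the second step.

```python
import re

IMAGE_MARKER = "<__image__>"
IMAGE_TOKEN_ID = -1  # placeholder the tokenizer emits for the image marker


def wrap_image_marker(prompt: str) -> str:
    """Surround the image marker with the vision start/end tokens."""
    return prompt.replace(IMAGE_MARKER, f"<|vision_start|>{IMAGE_MARKER}<|vision_end|>")


def toy_tokenize(prompt: str) -> list:
    """Hypothetical tokenizer: special tokens, newlines and words; image marker -> -1."""
    pieces = re.findall(r"<[^>]+>|\n|[^<\s]+", prompt)
    return [IMAGE_TOKEN_ID if p == IMAGE_MARKER else p for p in pieces]


def split_chunks(tokens: list) -> list:
    """Split into ("text", tokens) and ("image", None) chunks, in decoding order."""
    chunks, current = [], []
    for tok in tokens:
        if tok == IMAGE_TOKEN_ID:
            if current:
                chunks.append(("text", current))
                current = []
            chunks.append(("image", None))
        else:
            current.append(tok)
    if current:
        chunks.append(("text", current))
    return chunks


prompt = "<|im_start|>user\n<__image__>Describe the image.<|im_end|>\n"
for kind, toks in split_chunks(toy_tokenize(wrap_image_marker(prompt))):
    print(kind, toks)
```

Each text chunk would be decoded into the KVCache in order, while the image chunk is replaced by the ViT's output tokens before decoding.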
The embedding inference process (image encoding and text/image token decoding) is shown in the figure below:

Attention on Image Tokens
Because of the attention mechanism, this multi-step process can be problematic for some models. Let’s quickly recap the different types of attention used in models:
- Causal attention: the attention mechanism for a token at position k attends only to itself and the tokens before it, i.e. positions 0 to k.
- Non-causal attention: the attention mechanism for a token at position k attends to all n tokens in the sequence, i.e. positions 0 to n-1.
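The two mask types can be written down as boolean matrices, where mask[q][k] says whether the token at position q may attend to the token at position k. The function names are illustrative and not taken from llama.cpp.

```python
def causal_mask(n: int) -> list:
    """mask[q][k] is True if position q may attend to position k (k <= q)."""
    return [[k <= q for k in range(n)] for q in range(n)]


def non_causal_mask(n: int) -> list:
    """Every position attends to all n positions."""
    return [[True] * n for _ in range(n)]


# Visualize the causal mask: "x" = allowed, "." = masked out
for row in causal_mask(4):
    print("".join("x" if allowed else "." for allowed in row))
```

The causal mask is lower-triangular, which is what makes it safe to process a prefix first and append later tokens afterwards: no earlier row ever depends on a later column.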
The figure below shows the tokens that the attention mechanism would attend to when processing img_tok_n in the second step (image decoding):

When processing img_tok_n, the state of the model is as follows:
- All previous text tokens (<|im_start|>, user, \n, <|vision_start|>) have already been processed and saved in the KVCache.
- All image tokens (img_tok_1 to img_tok_n) are being processed at this point, in the same forward pass.
- All following text tokens (<|vision_end|>, Describe, etc.) will be processed later.
In the case of causal attention, only the current and previous tokens are considered when calculating attention scores, with past tokens being retrieved via the KVCache.
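In the case of non-causal attention, all tokens should be considered. However, the future text tokens (<|vision_end|>, Describe, etc.) have not been processed yet; they will only be processed in a later step. The attention scores for the image tokens would therefore be computed over an incomplete sequence, and the results would be wrong.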
Since jina-embeddings-v4 uses causal attention, the multi-step process works without problems, but for other models, this might not be the case.
In terms of embeddings, the hidden state of each token is captured as soon as that token is processed, and all captured states are combined into a single sequence at the end. Currently, normalization and pooling are handled in Python, but (given some extra work) this could also be done on the llama.cpp side.
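A minimal sketch of the Python-side post-processing, assuming mean pooling over all captured token states followed by L2 normalization; the exact pooling used by jina-embeddings-v4 should be checked against the model's own configuration. Hidden states are plain lists here, grouped per decoded chunk (text, image, text).

```python
import math


def mean_pool(hidden_states):
    """Average a list of equal-length token vectors."""
    n, dim = len(hidden_states), len(hidden_states[0])
    return [sum(vec[d] for vec in hidden_states) / n for d in range(dim)]


def l2_normalize(vec):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else vec


def pool_and_normalize(chunk_states):
    """Concatenate per-chunk token states in order, then mean-pool and normalize."""
    all_states = [vec for chunk in chunk_states for vec in chunk]
    return l2_normalize(mean_pool(all_states))


# Hidden states captured per chunk: text before the image, image tokens, text after.
embedding = pool_and_normalize([[[0.1, 0.3]], [[0.5, 0.2], [0.4, 0.4]], [[0.0, 0.6]]])
```

Because the chunks are simply concatenated before pooling, it does not matter that they were produced by separate decode calls.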
Our Fixes
After enabling image inputs for the embeddings endpoint in the llama.cpp server, we started testing the implementation with benchmarks and saw surprisingly large differences compared to our reference model. We suspected there had to be something wrong with llama.cpp’s implementation of the ViT used by Qwen2.5-VL to encode images into image patch embeddings (dense vector representations of square image regions) that the Qwen2.5 LLM can process.
Here’s an example of how the ViT outputs differ between the reference model and llama.cpp implementation:
=== vit_out reference === Shape: [1008, 1280]
Logging patch 0, dimensions 0-9
Patch 0: -0.375000 -0.250000 -4.281250 -5.968750 2.953125 -8.125000 8.625000 -9.250000 8.937500 -0.332031 ... (dims 10-1279)
... (patches 1-1007 not shown)
=== vit_out llama.cpp === Shape: [1280, 1008, 1, 1]
Logging patch 0, dimensions 0-9
Patch 0: -2.998136 -2.226554 0.233671 -7.486460 0.596918 -12.889042 8.904849 -8.6
... (patches 1-1007 not shown)
As you can see, the differences are quite noticeable. To check whether this was the only problem, we precomputed the image tokens in Python with the reference model, then decoded them using llama.cpp’s implementation of Qwen2.5 (using only the LLM), expecting the resulting embeddings to match the reference model’s values much more closely. However, this was not the case.
Fix #1: Causal Attention Mask for Attention Layers
We continued debugging by looking at the attention layers, the most probable cause of the numerical differences. We noticed that the attention mask used by the attention layers wasn’t calculated properly for image tokens. To see this, let’s go back to our example sequence:
<|im_start|>user\n<|vision_start|><__image__><|vision_end|>Describe the image.<|im_end|>\n
When processing image tokens, the <__image__> marker gets unpacked into something like img_tok_1 img_tok_2 ... img_tok_last. So the full sequence would be:
<|im_start|>user\n<|vision_start|> img_tok_1 img_tok_2 ... img_tok_last <|vision_end|>Describe the image.<|im_end|>\n
Note that both the text tokens (<|im_start|>, user, etc.) and the image tokens (img_tok_1, img_tok_2, etc.) are dense vectors, not literal text tokens. We use this form of the sequence to keep the explanation simple.
When decoding img_tok_2, the attention mechanism should attend only to img_tok_2 itself and the tokens before it, namely:
<|im_start|>user\n<|vision_start|> img_tok_1 img_tok_2
However, a bug in the attention mask was instead causing the mechanism to attend to the entire image sequence, like so:
<|im_start|>user\n<|vision_start|> img_tok_1 img_tok_2 ... img_tok_last
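To make the bug concrete, the hypothetical batch_mask function below builds the mask rows for a batch of image tokens decoded after n_past cached text tokens. With causal_within_batch=True, each image token sees the cached text plus itself and the image tokens before it; with False, it reproduces the buggy behavior of seeing the whole image batch. This models the masking logic only, not llama.cpp's actual implementation.

```python
def batch_mask(n_past: int, n_batch: int, causal_within_batch: bool = True) -> list:
    """Mask rows for n_batch new tokens decoded after n_past cached tokens.

    Row i is the i-th new token (absolute position n_past + i); column j is
    absolute key position j in [0, n_past + n_batch).
    """
    rows = []
    for i in range(n_batch):
        q_pos = n_past + i
        row = []
        for k_pos in range(n_past + n_batch):
            if k_pos < n_past:
                row.append(True)              # cached text tokens are always visible
            elif causal_within_batch:
                row.append(k_pos <= q_pos)    # correct: no peeking at later image tokens
            else:
                row.append(True)              # buggy: sees the entire image batch
        rows.append(row)
    return rows


# 4 cached text tokens (<|im_start|>, user, \n, <|vision_start|>) and 3 image tokens;
# row 1 corresponds to img_tok_2.
correct = batch_mask(4, 3)[1]
buggy = batch_mask(4, 3, causal_within_batch=False)[1]
print(sum(correct), "of", len(correct), "keys visible (correct)")
print(sum(buggy), "of", len(buggy), "keys visible (buggy)")
```

In the buggy variant, img_tok_2 can see img_tok_3, a token that comes after it, which a causal model never sees during training.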
After we fixed this bug, our llama.cpp model’s embeddings (still using the image tokens precomputed by the reference model’s ViT) finally matched the reference model’s embeddings within a small margin of error.
Fix #2: Image Processing and Patch Embeddings
llama.cpp’s ViT encoder also produced different image embeddings from the reference model, with numbers diverging immediately after pre-processing. This was particularly evident during the initial patch-creation step, where both our reference model and llama.cpp split the raw image (pixel values) into patches that are encoded via convolutional layers. The differences between the raw patches (before ViT processing) can be seen below:
=== raw_patches reference === Shape: [1008, 1176]
Logging patches 0-4, dimensions 0-9
Patch 0: 0.484375 0.484375 0.500000 0.500000 0.470703 0.470703 0.470703 0.484375 0.470703 0.484375 ... (dims 10-1175)
... (patches 1-1007 not shown)
=== raw_patches llama.cpp === Shape: [1176, 1008, 1, 1]
Logging patches 0-4, dimensions 0-9
Patch 0: 0.455895 0.455895 0.455895 0.455895 0.455895 0.455895 0.470494 0.470494 0.470494 0.470494 ... (dims 10-1175)
... (patches 1-1007 not shown)
Our reference model and llama.cpp process these patches in different ways:
- The reference model groups the pixel values using reshape operations and then uses a single conv3d layer to encode the pre-grouped pixel patches.
- The llama.cpp model creates and encodes these patches with two conv2d layers.
To bring the llama.cpp model’s embeddings closer to those of the reference model, we thought it would be simpler to use the reference model’s exact operations rather than debug llama.cpp’s approach.
Our reference model generates pixel patches using complex reshape and transpose operations that require 9-dimensional tensors. ggml, the low-level tensor library used by llama.cpp, supports at most 4 dimensions and cannot express these operations. To get around this, we generated the patches in a separate Python service that calls the llama.cpp server via HTTP.
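To show what those operations compute, here is a loop-based sketch of the patch arrangement, assuming the layout used by the Hugging Face Qwen2-VL image processor (reshape to 9 dimensions, transpose with axes (0, 3, 6, 4, 7, 2, 1, 5, 8), then flatten). flatten_patches is a hypothetical helper: resizing and pixel normalization are omitted, and the image height and width are assumed to be multiples of patch_size × merge_size. With the default sizes, each row holds 3 × 2 × 14 × 14 = 1176 values, matching the raw_patches shapes logged above.

```python
def flatten_patches(frames, patch_size=14, temporal_patch_size=2, merge_size=2):
    """Arrange pixels into rows of C * temporal_patch_size * patch_size**2 values.

    frames: nested lists indexed frames[t][c][y][x]. A single image is repeated
    along t to fill temporal_patch_size (assumed Qwen2-VL behavior).
    Returns (rows, (grid_t, grid_h, grid_w)).
    """
    if len(frames) == 1:
        frames = frames * temporal_patch_size  # repeat the single image in time
    n_t, n_c = len(frames), len(frames[0])
    height, width = len(frames[0][0]), len(frames[0][0][0])
    block = patch_size * merge_size
    if n_t % temporal_patch_size or height % block or width % block:
        raise ValueError("frame count and image size must be multiples of the patch grid")
    grid_t, grid_h, grid_w = n_t // temporal_patch_size, height // patch_size, width // patch_size
    rows = []
    # Row order: time, merge-block row, merge-block col, row and col inside the block
    for gt in range(grid_t):
        for bh in range(grid_h // merge_size):
            for bw in range(grid_w // merge_size):
                for mh in range(merge_size):
                    for mw in range(merge_size):
                        y0 = (bh * merge_size + mh) * patch_size
                        x0 = (bw * merge_size + mw) * patch_size
                        row = []
                        # Column order: channel, time step, pixel row, pixel col
                        for c in range(n_c):
                            for tp in range(temporal_patch_size):
                                frame = frames[gt * temporal_patch_size + tp]
                                for py in range(patch_size):
                                    for px in range(patch_size):
                                        row.append(frame[c][y0 + py][x0 + px])
                        rows.append(row)
    return rows, (grid_t, grid_h, grid_w)


image = [[[0.0] * 56 for _ in range(56)] for _ in range(3)]  # one 3 x 56 x 56 image
rows, grid = flatten_patches([image])
print(len(rows), len(rows[0]), grid)
```

Note that rows come out in 2 × 2 merge-block order rather than plain raster order, which matters when mapping patch embeddings back onto the image.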
ggml also lacks support for conv3d layers. In our reference model, the conv3d layer configuration looks like this:
import torch.nn as nn

kernel_size = [
    2,   # temporal_patch_size
    14,  # patch_size
    14,  # patch_size
]
proj = nn.Conv3d(
    3,     # in_channels
    1152,  # embed_dim
    kernel_size=kernel_size,
    stride=kernel_size,  # stride == kernel_size: windows never overlap
    bias=False,
)
You can see that stride and kernel_size are the same, so the convolution windows never overlap: each output vector depends on exactly one 2 × 14 × 14 block of pixels across 3 channels (1176 values). This means we can simply flatten the inputs and the weights of the conv3d layer and perform a single matrix multiplication instead. To do this, we modified the conversion script in llama.cpp (convert_hf_to_gguf.py) to export a flattened version of the conv3d weights for the patch projection layer:
if 'patch_embed.proj.weight' in name:
    c1, c2, kt, kh, kw = data_torch.shape
    # Note: this part of the script also exports other versions of this layer
    # Only showing the relevant parts
    # Flat matmul weight: row-major [out, in*kT*kH*kW] = [embed_dim, 1176]
    W_flat = data_torch.contiguous().view(c1, -1)
    outputs.append(("v.patch_embd.weight_flat", W_flat))
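The equivalence that makes this export valid can be checked numerically. The sketch below, in plain Python with hypothetical helper names, computes a naive conv3d whose stride equals its kernel size and compares it against a matrix-vector product of the flattened weights with each flattened input window. Both flatten in (channel, time, row, column) order, which is also the element order of weight.view(out_channels, -1); the precomputed patches are assumed to use the same order.

```python
import random


def conv3d_no_overlap(x, w):
    """Naive conv3d with stride == kernel_size (no bias, no padding).

    x is indexed x[c][t][y][col]; w is indexed w[o][c][dt][dy][dx].
    Returns out[o][t][i][j].
    """
    n_out, n_in = len(w), len(w[0])
    kt, kh, kw = len(w[0][0]), len(w[0][0][0]), len(w[0][0][0][0])
    gt, gh, gw = len(x[0]) // kt, len(x[0][0]) // kh, len(x[0][0][0]) // kw
    out = [[[[0.0] * gw for _ in range(gh)] for _ in range(gt)] for _ in range(n_out)]
    for o in range(n_out):
        for t in range(gt):
            for i in range(gh):
                for j in range(gw):
                    acc = 0.0
                    for c in range(n_in):
                        for dt in range(kt):
                            for dy in range(kh):
                                for dx in range(kw):
                                    acc += (w[o][c][dt][dy][dx]
                                            * x[c][t * kt + dt][i * kh + dy][j * kw + dx])
                    out[o][t][i][j] = acc
    return out


def flatten_weight(w):
    """Same element order as weight.view(out_channels, -1): (c, dt, dy, dx)."""
    return [[v for w_c in w_o for w_t in w_c for w_y in w_t for v in w_y] for w_o in w]


def extract_window(x, t, i, j, kt, kh, kw):
    """Flatten one non-overlapping input window in (c, dt, dy, dx) order."""
    return [x[c][t * kt + dt][i * kh + dy][j * kw + dx]
            for c in range(len(x)) for dt in range(kt)
            for dy in range(kh) for dx in range(kw)]


def matvec(matrix, vec):
    return [sum(m * v for m, v in zip(row, vec)) for row in matrix]


def max_abs_diff(seed=0, n_in=3, n_out=4, k=(2, 2, 2), size=(2, 4, 4)):
    """Largest difference between the conv3d output and the flattened matmul."""
    rng = random.Random(seed)
    kt, kh, kw = k
    T, H, W = size
    x = [[[[rng.random() for _ in range(W)] for _ in range(H)] for _ in range(T)]
         for _ in range(n_in)]
    w = [[[[[rng.random() for _ in range(kw)] for _ in range(kh)] for _ in range(kt)]
          for _ in range(n_in)] for _ in range(n_out)]
    conv, w_flat, worst = conv3d_no_overlap(x, w), flatten_weight(w), 0.0
    for t in range(T // kt):
        for i in range(H // kh):
            for j in range(W // kw):
                y = matvec(w_flat, extract_window(x, t, i, j, kt, kh, kw))
                for o in range(n_out):
                    worst = max(worst, abs(y[o] - conv[o][t][i][j]))
    return worst


print(max_abs_diff())  # only floating-point rounding differences remain
```

If the windows overlapped (stride smaller than the kernel), each input pixel would contribute to several outputs and this one-window-per-row flattening would no longer be valid.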
To apply the matmul operation instead of the two conv2d layers in llama.cpp, we modified the code that builds the graph of the Qwen2.5-VL ViT:
ggml_tensor * build_inp_raw_precomputed() {
    ggml_tensor * inp_raw = ggml_new_tensor_2d(
        ctx0,
        GGML_TYPE_F32,
        img.p_dim,
        img.npx * img.npy
    );
    ggml_set_name(inp_raw, "inp_raw");
    ggml_set_input(inp_raw);
    return inp_raw;
}

ggml_cgraph * build_qwen2vl() {
    // NOTE: only showing the code we've added for using pre-arranged image patches
    const bool uses_precomputed_image = img.is_precomputed;
    ggml_tensor * inp = nullptr;
    if (uses_precomputed_image) {
        ggml_tensor * inp_raw = build_inp_raw_precomputed();
        cb(inp_raw, "inp_raw", -1);
        inp = ggml_mul_mat(ctx0, model.patch_embeddings_flat, inp_raw);
    } else {
        // Usual 2x conv2d path
    }
    // rest of the code
}
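As a shape sanity check, the sketch below models two ggml conventions in Python. First, ggml lists tensor dimensions innermost-first, which is why the llama.cpp logs print [1280, 1008, 1, 1] where PyTorch prints [1008, 1280]. Second, ggml_mul_mat(ctx, a, b) requires a->ne[0] == b->ne[0] and returns a tensor with ne = {a->ne[1], b->ne[1], b->ne[2], b->ne[3]}. The helper names are hypothetical and only the shape rules are modeled, not the computation; the example uses the 1152-channel configuration shown earlier.

```python
def ggml_ne_from_torch(shape):
    """ggml stores dimensions innermost-first (ne[0] is contiguous), padded to 4."""
    ne = list(reversed(shape))
    return ne + [1] * (4 - len(ne))


def ggml_mul_mat_ne(a_ne, b_ne):
    """Result dimensions of ggml_mul_mat(ctx, a, b)."""
    if a_ne[0] != b_ne[0]:
        raise ValueError("ggml_mul_mat requires a->ne[0] == b->ne[0]")
    return [a_ne[1], b_ne[1], b_ne[2], b_ne[3]]


EMBED_DIM, P_DIM, N_PATCHES = 1152, 1176, 1008
w_flat = ggml_ne_from_torch([EMBED_DIM, P_DIM])   # exported flat weight
inp_raw = [P_DIM, N_PATCHES, 1, 1]                # ggml_new_tensor_2d(p_dim, npx * npy)
print(ggml_mul_mat_ne(w_flat, inp_raw))           # one embedding per patch
```

Because the flat weight's contiguous dimension is the 1176-value patch axis, it lines up with inp_raw's ne[0] without any transpose.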
With these changes, the image embeddings were within a 2% margin of error compared to the reference model (as can be seen in the evaluation table in the next section).
Evaluation
After making these changes, we evaluated the llama.cpp model, as well as two quantized versions of it, against our reference model on the ViDoRe tasks of the MTEB benchmark. The script and instructions to replicate these results are available in our llama.cpp fork.
Task | Reference model | llama.cpp (F16) | llama.cpp (Q4_K_M) | llama.cpp (IQ4_XS) |
---|---|---|---|---|
VidoreArxivQARetrieval | 83.55 | 85.00 | 84.38 | 84.34 |
VidoreDocVQARetrieval | 50.53 | 52.02 | 51.93 | 51.57 |
VidoreInfoVQARetrieval | 87.77 | 87.31 | 87.61 | 87.28 |
VidoreShiftProjectRetrieval | 84.07 | 82.25 | 82.56 | 81.73 |
VidoreSyntheticDocQAAIRetrieval | 97.52 | 96.71 | 97.28 | 97.15 |
VidoreSyntheticDocQAEnergyRetrieval | 91.22 | 90.34 | 90.47 | 90.30 |
VidoreSyntheticDocQAGovernmentReportsRetrieval | 91.61 | 93.84 | 93.47 | 94.47 |
VidoreSyntheticDocQAHealthcareIndustryRetrieval | 95.42 | 96.08 | 95.67 | 96.05 |
VidoreTabfquadRetrieval | 94.52 | 94.94 | 94.83 | 94.72 |
VidoreTatdqaRetrieval | 65.52 | 64.85 | 64.63 | 64.76 |
Average | 84.17 | 84.33 | 84.28 | 84.23 |
On average, the llama.cpp model and its quantized variants don’t diverge much from the reference model: their average scores (84.23 to 84.33) are within 0.2 points of the reference model’s 84.17.
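The per-task gaps can be checked directly from the table. The sketch below copies the scores as listed and reports each variant's average and its largest absolute deviation from the reference model. Because the table shows rounded scores, averages recomputed from it may differ from the reported ones in the last digit.

```python
VARIANTS = ("Reference model", "llama.cpp (F16)", "llama.cpp (Q4_K_M)", "llama.cpp (IQ4_XS)")
SCORES = {  # task -> scores in VARIANTS order, copied from the table
    "VidoreArxivQARetrieval": (83.55, 85.00, 84.38, 84.34),
    "VidoreDocVQARetrieval": (50.53, 52.02, 51.93, 51.57),
    "VidoreInfoVQARetrieval": (87.77, 87.31, 87.61, 87.28),
    "VidoreShiftProjectRetrieval": (84.07, 82.25, 82.56, 81.73),
    "VidoreSyntheticDocQAAIRetrieval": (97.52, 96.71, 97.28, 97.15),
    "VidoreSyntheticDocQAEnergyRetrieval": (91.22, 90.34, 90.47, 90.30),
    "VidoreSyntheticDocQAGovernmentReportsRetrieval": (91.61, 93.84, 93.47, 94.47),
    "VidoreSyntheticDocQAHealthcareIndustryRetrieval": (95.42, 96.08, 95.67, 96.05),
    "VidoreTabfquadRetrieval": (94.52, 94.94, 94.83, 94.72),
    "VidoreTatdqaRetrieval": (65.52, 64.85, 64.63, 64.76),
}


def average(variant: int) -> float:
    return sum(row[variant] for row in SCORES.values()) / len(SCORES)


def largest_gap(variant: int):
    """Task with the largest absolute score difference from the reference (column 0)."""
    return max(((task, abs(row[variant] - row[0])) for task, row in SCORES.items()),
               key=lambda item: item[1])


print(f"{VARIANTS[0]}: average {average(0):.2f}")
for i in range(1, len(VARIANTS)):
    task, gap = largest_gap(i)
    print(f"{VARIANTS[i]}: average {average(i):.2f}, largest gap {gap:.2f} ({task})")
```

For the F16 model, the largest gap is on VidoreSyntheticDocQAGovernmentReportsRetrieval (2.23 points, in favor of llama.cpp), so individual tasks move more than the averages suggest, in both directions.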
To compare the models in more depth, we used images from different domains and with different resolutions, plotting the distance between image patch embeddings (before pooling/normalization). The redder the patch, the greater the cosine distance between the vectors of the reference model and llama.cpp model for that particular patch.
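A minimal sketch of such a per-patch comparison, assuming both models' patch embeddings are available as [n_patches][dim] lists in the same order. patch_position is a hypothetical helper that maps a patch index back to its grid cell, assuming Qwen2-VL-style ordering in 2 × 2 merge blocks for a single frame; the plotting code behind the figures is not shown in this article.

```python
import math


def cosine_distance(u, v):
    """1 - cosine similarity of two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)


def patch_distances(reference, candidate):
    """Cosine distance for each patch pair; inputs are [n_patches][dim]."""
    if len(reference) != len(candidate):
        raise ValueError("patch counts differ")
    return [cosine_distance(r, c) for r, c in zip(reference, candidate)]


def patch_position(index, grid_w, merge_size=2):
    """Map a patch index to (row, col), assuming merge-block ordering (single frame)."""
    mw = index % merge_size
    mh = (index // merge_size) % merge_size
    block = index // (merge_size * merge_size)
    bh, bw = divmod(block, grid_w // merge_size)
    return bh * merge_size + mh, bw * merge_size + mw
```

Undoing the merge-block order matters for this kind of heatmap: plotting patches in raw index order would scatter neighboring patches across the image.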


Figure 3: Page from jina-embeddings-v4 technical report, 372 × 526 resolution (left), 2481 × 3508 resolution (right)


Figure 4: Screenshot from Jina AI website, 594 × 428 resolution (left), 1982 × 1428 resolution (right)


Figure 5: Tokyo, Shibuya by S K on Unsplash, 383 × 255 resolution (left), 5472 × 3649 resolution (right)
We aimed to spot any patterns that went beyond numerical precision differences—patterns that may reveal further bugs or differences between our models. However, no particular patterns were visible, except that the number of diverging patches increases with image resolution. These differences most likely appear due to backend differences, and not because of any particular bugs in the implementation of Qwen2.5-VL (the backbone model of jina-embeddings-v4).
Nonetheless, we need to reiterate that these differences are minimal, and the benchmark results reflect this. Overall, the llama.cpp models perform as well as the reference model, albeit with slightly different embedding vectors.
Remaining Issues
There are several potential areas of improvement for multimodal embeddings in llama.cpp:
- Quantizing the vision encoder. Currently, llama.cpp only supports quantization for LLMs, but to achieve better scaling we would also like to quantize the vision encoder. Earlier versions of llama.cpp provided llama-llava-quantize-cli, but the relevant resources have been removed since the introduction of the mtmd library.
- Separating the vision encoder into a dedicated service. Vision encoders typically use non-causal masking, meaning any given image needs to be encoded within a single forward call, so we can’t make use of continuous batching. Instead, we could move the vision encoder into its own service, which would batch together multiple images (even from separate sources) and encode them all in a single forward pass. This would mean higher VRAM requirements, but would also be much faster than encoding each image one by one. It would also let us auto-scale the vision encoder independently of the language model.
- Enabling multi-vector embeddings. In this article, we’ve only worked with single-vector embeddings. But to make full use of jina-embeddings-v4, we’d also like to enable multi-vector embeddings to achieve higher accuracy on complex images. This would be an easy addition, since these embeddings are generated with a single linear layer on top of the base model.
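As a rough sketch of why multi-vector output should be a small addition: it amounts to projecting every token's hidden state through one linear layer instead of pooling them. The code below assumes per-token L2 normalization after the projection and uses hypothetical names and toy dimensions; the real projection weights and output size come from the model checkpoint.

```python
import math


def multi_vector_embed(hidden_states, proj_weight, proj_bias=None):
    """Project each token's hidden state with one linear layer, then L2-normalize it.

    hidden_states: [n_tokens][hidden_dim]; proj_weight: [out_dim][hidden_dim].
    Returns one out_dim-sized vector per token.
    """
    vectors = []
    for h in hidden_states:
        v = [sum(w * x for w, x in zip(row, h)) for row in proj_weight]
        if proj_bias is not None:
            v = [a + b for a, b in zip(v, proj_bias)]
        norm = math.sqrt(sum(a * a for a in v))
        vectors.append([a / norm for a in v] if norm > 0 else v)
    return vectors


# Toy example: 2 tokens, hidden_dim 3, projected to out_dim 2
token_vectors = multi_vector_embed([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]],
                                   [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
```

Since the per-token hidden states are already captured for pooling, the extra work is mainly exporting the projection weights and skipping the pooling step.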
Conclusion
Despite the initial bugs and setbacks, integrating multimodal embeddings into llama.cpp now yields results closely matching our reference PyTorch model, including on a range of benchmark tasks. Fixes to the attention mask and image patch processing removed the main sources of divergence, and even the quantized variants achieve similar accuracy while using far fewer resources. The remaining differences at higher image resolutions appear minor and are likely due to backend variations rather than the core model implementation.
Looking ahead, extending quantization to the vision encoder, enabling batched processing through a separate service, and supporting multi-vector embeddings would further improve both efficiency and accuracy. These additions would make multimodal embeddings in llama.cpp more scalable and better suited for real-world use cases.