Implementing the BGE-M3 Model
Key Points
- This paper presents a hands-on guide for implementing the BGE-M3 multilingual embedding model from scratch using TensorFlow-Keras, focusing on its core architecture, which is composed mainly of Dense and LayerNormalization layers.
- It details the step-by-step construction of the model's components, including word, position, and token type embeddings, as well as the Transformer Block's Multi-Head Attention and Feed-Forward Network, adhering to the XLM-RoBERTa base structure.
- The complete TensorFlow implementation enables versatile deployment for inference across various platforms, highlighting its utility for tasks like Retrieval-Augmented Generation (RAG) and efficient multilingual search.
The paper details the implementation of the BGE-M3 (BAAI General Embedding M3: Multi-Linguality, Multi-Functionality, Multi-Granularity) model using TensorFlow-Keras, emphasizing its lean architecture that relies primarily on Dense layers and LayerNormalization for inference. BGE-M3 is a multilingual embedding model supporting over 100 languages, noted for its strong performance on Korean MTEB benchmarks and its utility in Retrieval-Augmented Generation (RAG) tasks.
The model's core architecture is based on the XLMRobertaModel, which is characterized by its simplicity, avoiding more recent complexities like Rotary Position Embedding (RoPE), Pre-Normalization, or linear bias removal. The paper highlights that the inference structure can be realized with just nine fundamental linear layers (Dense, Linear, MLP) and three LayerNormalizations per block, repeated across 24 Transformer blocks.
The implementation consists of three main parts:
- Embedding Layer:
  - Word Embedding: Maps 250,002 tokens to 1024-dimensional vectors. Implemented via `tf.keras.layers.Embedding`.
  - Position Embedding: Encodes positional information for up to 8,194 positions into 1024-dimensional vectors. Implemented via `tf.keras.layers.Embedding`. Position IDs are generated dynamically using `create_position_ids_from_input_ids`, which computes cumulative sums of a mask derived from `input_ids` (padding tokens are ignored).
  - Token Type Embedding: A single 1024-dimensional constant vector applied to all tokens, as BGE-M3 uses only one token type. Implemented via `tf.keras.layers.Embedding` with a vocabulary size of 1.
  - All three embeddings are summed: `embeddings = word_embeds + position_embeds + token_type_embeds`.
  - Finally, the combined embedding output undergoes LayerNormalization (with `epsilon=1e-5`, the XLM-RoBERTa default): `embedding_output = LayerNorm(embeddings)`.
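The embedding stage described above can be sketched framework-agnostically in NumPy. Here `padding_idx = 1` follows the XLM-RoBERTa convention (real tokens get positions starting at `padding_idx + 1`, which is why the position table holds 8,194 = 8,192 + 2 rows); the random tables are stand-ins for the pretrained weights, and the LayerNorm sketch omits the learned gamma/beta parameters.

```python
import numpy as np

PAD_ID = 1  # XLM-RoBERTa padding token id

def create_position_ids_from_input_ids(input_ids, padding_idx=PAD_ID):
    # Non-padding tokens get consecutive positions starting at padding_idx + 1;
    # padding tokens keep position padding_idx, so they share one embedding row.
    mask = (input_ids != padding_idx).astype(np.int64)
    return np.cumsum(mask, axis=1) * mask + padding_idx

rng = np.random.default_rng(0)
word_table = rng.normal(size=(250_002, 1024))  # stand-in for pretrained weights
pos_table = rng.normal(size=(8_194, 1024))
token_type_vec = rng.normal(size=(1024,))      # single token type vector

def embed(input_ids):
    position_ids = create_position_ids_from_input_ids(input_ids)
    x = word_table[input_ids] + pos_table[position_ids] + token_type_vec
    # LayerNormalization over the hidden dimension (epsilon = 1e-5);
    # learned scale/shift omitted in this sketch
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mean) / np.sqrt(var + 1e-5)
```

Note how padding positions all map to position id 1, so a sequence `[CLS, tok, tok, EOS, PAD, PAD]` yields positions `[2, 3, 4, 5, 1, 1]`.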
- Transformer Block: Each block is composed of Multi-Head Self-Attention (MHA) and a Feed-Forward Neural Network (FFNN), with residual connections and LayerNormalization applied after each sub-layer. The model uses 24 such blocks.
  - Multi-Head Attention (MHA):
    - Input `inputs` (from the embedding layer or the previous block) is linearly transformed into Query (Q), Key (K), and Value (V) tensors using distinct `tf.keras.layers.Dense` layers, each with `units=1024`.
    - Q, K, and V are then split into 16 heads (`num_heads = 16`), where each head has a `depth` of `1024 / 16 = 64`. This involves reshaping and transposing: `(batch, seq_len, 1024)` → `(batch, seq_len, 16, 64)` → `(batch, 16, seq_len, 64)`.
    - Scaled Dot-Product Attention: Attention scores are computed as `softmax(Q @ K^T / sqrt(d_k)) @ V`. Here, `d_k` is the dimension of the keys (64), so the scores are divided by `sqrt(64) = 8`.
    - The attention `mask` is applied by adding a large negative value (e.g., -10,000) to padding token positions before the softmax, effectively zeroing out their attention weights. The `extended_attention_mask` is derived from `attention_mask_origin` by `(1.0 - attention_mask_origin) * -10000.0`.
    - The output from all heads is concatenated and passed through a final `tf.keras.layers.Dense` layer (`units=1024`).
    - A residual connection adds the initial `inputs` to the attention output, followed by LayerNormalization: `attention_output = LayerNorm(inputs + attn_out)`.
  - Feed-Forward Neural Network (FFNN):
    - The output of the attention sub-layer (`attention_output`) is passed through an `intermediate` `tf.keras.layers.Dense` layer that expands the dimension to 4,096.
    - A GELU (Gaussian Error Linear Unit) approximation activation function is applied: `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
    - The result then passes through an `output_dense` `tf.keras.layers.Dense` layer, projecting back to 1,024 dimensions.
    - A second residual connection adds `attention_output` to the FFNN output, followed by LayerNormalization: `block_output = LayerNorm(attention_output + ffnn_out)`.
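The attention and feed-forward math described above can be sketched in NumPy (shapes as in the text: 16 heads of depth 64; the `-10000.0` mask value follows the usual BERT-style convention):

```python
import numpy as np

def gelu_approx(x):
    # tanh approximation of GELU, as used in BERT/XLM-RoBERTa
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def split_heads(x, num_heads=16):
    # (batch, seq, 1024) -> (batch, num_heads, seq, depth=64)
    b, s, h = x.shape
    return x.reshape(b, s, num_heads, h // num_heads).transpose(0, 2, 1, 3)

def scaled_dot_product_attention(q, k, v, attention_mask):
    # q, k, v: (batch, heads, seq, depth); attention_mask: (batch, seq) of 0/1
    d_k = q.shape[-1]
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_k)  # (batch, heads, seq, seq)
    # extended mask: large negative at padding keys, broadcast over heads/queries
    extended = (1.0 - attention_mask)[:, None, None, :] * -10000.0
    scores = scores + extended
    # numerically stable softmax over the key axis
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights = weights / weights.sum(-1, keepdims=True)
    return weights @ v, weights
```

After the softmax, padded key positions carry (effectively) zero weight, so their values never leak into the context vectors; the Dense projections, residual additions, and LayerNorms around this core are as described in the bullets above.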
- Model Forward Flow and Output:
  - Input `input_ids` and `attention_mask` are processed through the embedding layer and then iteratively through the 24 Transformer blocks. The output of the last block is `hidden_states`.
  - Dense Retrieval Output: The pooled output for dense retrieval is extracted as the first token's hidden state (the CLS token): `dense_vecs = hidden_states[:, 0]`.
  - Multi-Vector Retrieval Output: An additional `colbert_linear` `tf.keras.layers.Dense` layer is applied to the non-CLS tokens of `hidden_states` (i.e., `hidden_states[:, 1:]`). This output is then masked by the original attention mask (excluding the CLS token) to zero out padding tokens: `colbert_vecs = colbert_linear(hidden_states[:, 1:]) * attention_mask[:, 1:, None]`.
  - The model returns a dictionary containing `dense_vecs` and `colbert_vecs`.
  - The paper also provides instructions for loading the `colbert_linear` weights, which are typically distributed as separate PyTorch files.
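Given `hidden_states` from the last block, the two retrieval outputs reduce to a slice and a masked linear projection. A NumPy sketch, with a random matrix standing in for the pretrained `colbert_linear` kernel:

```python
import numpy as np

rng = np.random.default_rng(0)
colbert_w = rng.normal(size=(1024, 1024)) * 0.02  # stand-in for colbert_linear kernel
colbert_b = np.zeros(1024)                        # stand-in for colbert_linear bias

def model_outputs(hidden_states, attention_mask):
    # hidden_states: (batch, seq, 1024); attention_mask: (batch, seq) of 0/1
    dense_vecs = hidden_states[:, 0]  # CLS token -> dense retrieval vector
    # per-token projection of the non-CLS tokens -> multi-vector representation
    colbert_vecs = hidden_states[:, 1:] @ colbert_w + colbert_b
    # zero out padding tokens (the CLS position is excluded from the mask too)
    colbert_vecs = colbert_vecs * attention_mask[:, 1:, None]
    return {"dense_vecs": dense_vecs, "colbert_vecs": colbert_vecs}
```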
The paper concludes by demonstrating how to save the implemented TensorFlow-Keras model with a serving signature for deployment, enabling its use across various platforms and applications, from large-scale Hadoop/Spark jobs to mobile inference via TensorFlow Lite.
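Exporting with an explicit serving signature can be sketched as follows. The dummy `tf.Module` stands in for the full 24-block model (its body is a placeholder); the point is the `tf.function` input signature over `input_ids`/`attention_mask` and the `tf.saved_model.save` call that registers it as `serving_default`.

```python
import tempfile
import tensorflow as tf

class BGEM3Serving(tf.Module):
    """Stand-in for the implemented model; replace the body of `serve`
    with the real embedding + Transformer forward pass."""

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, None], dtype=tf.int32, name="input_ids"),
        tf.TensorSpec(shape=[None, None], dtype=tf.int32, name="attention_mask"),
    ])
    def serve(self, input_ids, attention_mask):
        # placeholder computation in place of embeddings + 24 Transformer blocks
        dense_vecs = tf.zeros([tf.shape(input_ids)[0], 1024])
        return {"dense_vecs": dense_vecs}

export_dir = tempfile.mkdtemp()
module = BGEM3Serving()
tf.saved_model.save(module, export_dir,
                    signatures={"serving_default": module.serve})
```

The exported directory can then be consumed by TensorFlow Serving, Spark executors, or converted with the TensorFlow Lite converter for mobile inference.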