Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)
Key Points
- This guide details the process of fine-tuning multimodal language models, specifically Gemma 3, using Supervised Fine-Tuning (SFT) within the TRL library.
- It covers two scenarios: single-image + text datasets (like LLaVA Instruct Mix) and multi-image + text datasets (like MMIU-Benchmark), emphasizing the necessary data preprocessing and conversational formatting for multi-image inputs.
- The fine-tuning procedure involves setting up the environment, loading datasets, preparing the model with BitsAndBytes and QLoRA, configuring training arguments with SFTConfig, and implementing a custom `collate_fn` to handle and mask multimodal inputs.
This document details the process of fine-tuning multimodal language models, specifically exemplified with Gemma 3, using Supervised Fine-Tuning (SFT) within the TRL library. The guide addresses two distinct use cases: fine-tuning with single image and text data, and fine-tuning with multi-image and text data (interleaving).
For the single image + text scenario, the HuggingFaceH4/llava-instruct-mix-vsft dataset is utilized. This dataset comprises conversations where a user provides a single image and text, and the assistant responds based on both modalities. For multi-image + text, the FanqingM/MMIU-Benchmark dataset is employed, which features a context, a question, a series of related images, and an expected answer, requiring the model to reason over multiple visual inputs.
The fine-tuning methodology involves several key steps:
- Environment Setup: Dependencies such as `trl`, `bitsandbytes`, `peft`, `hf_xet`, and `tensorboard` are installed. Access to gated models like Gemma 3 requires logging into the Hugging Face Hub.
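As a sketch, assuming a standard pip workflow, the setup step might look like the following (package list taken from the guide; versions are not pinned here):

```shell
# Install the libraries used throughout the guide
pip install -U trl bitsandbytes peft hf_xet tensorboard

# Authenticate to access gated models such as Gemma 3
huggingface-cli login
```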
- Data Loading and Preprocessing:
  - Single Image + Text: The `HuggingFaceH4/llava-instruct-mix-vsft` dataset is loaded directly with `datasets.load_dataset`, as it is already formatted for SFT.
  - Multi-Image + Text: The raw `FanqingM/MMIU-Benchmark` dataset is loaded and then preprocessed to convert the raw data (including zipped image files) into a conversational format suitable for multimodal models. This involves extracting images from ZIP archives, converting them to `PIL.Image.Image` objects, and structuring each sample as a list of message dictionaries. Each message includes a `role` (e.g., "system", "user", "assistant") and a `content` list whose entries are objects of `type: "text"` or `type: "image"` carrying the actual image data. An example conversational structure:
```
[
  {
    "role": "system",
    "content": [{"type": "text", "text": "..."}]
  },
  {
    "role": "user",
    "content": [<images_list>, {"type": "text", "text": "..."}]
  },
  {
    "role": "assistant",
    "content": [{"type": "text", "text": "..."}]
  }
]
```

The `prepare_dataset` function handles downloading and extracting the image archives, while `format_data` structures the text and image inputs into the required conversational format.
- Model and Processor Preparation: The `google/gemma-3-4b-it` model is loaded using `transformers.AutoModelForImageTextToText` together with its corresponding `AutoProcessor`. To optimize memory usage, `BitsAndBytesConfig` is employed for 4-bit quantization. The `attn_implementation` is set to `"eager"`, and the tokenizer's `padding_side` is set to `"right"`.
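A minimal loading sketch follows. The guide only states that 4-bit quantization is used, so the specific `BitsAndBytesConfig` values below are common QLoRA-style defaults, not necessarily the guide's exact settings:

```python
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig

# Assumed 4-bit quantization settings; illustrative values only.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForImageTextToText.from_pretrained(
    "google/gemma-3-4b-it",
    quantization_config=bnb_config,
    attn_implementation="eager",  # as stated in the guide
)
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
processor.tokenizer.padding_side = "right"
```

Note that loading this gated model requires prior authentication with the Hugging Face Hub.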
- QLoRA Configuration: Quantized Low-Rank Adaptation (QLoRA) is configured via `peft.LoraConfig`. Key hyperparameters include the LoRA rank `r`, `lora_alpha`, `lora_dropout`, `target_modules`, `bias`, and `task_type`. Additionally, `modules_to_save` is specified so that essential modules (e.g., the embeddings and language-model head) are trained fully rather than quantized or adapted.
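A configuration sketch, assuming illustrative hyperparameter values (the guide names the parameters, but the concrete numbers and module names below are assumptions):

```python
from peft import LoraConfig

# Illustrative QLoRA hyperparameters; values are assumptions, not the guide's.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"],  # keep these fully trained
)
```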
- Training Arguments: `trl.SFTConfig` defines training parameters such as `output_dir` (e.g., "gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft"), `per_device_train_batch_size` (8 for single-image, 1 for multi-image), and `gradient_accumulation_steps` (4 for single-image, 1 for multi-image), along with the learning rate, number of epochs, gradient checkpointing, precision, logging, and checkpointing settings.
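A sketch of the single-image configuration; only the batch-size, accumulation, and output-directory values come from the text above, and the remaining fields are illustrative assumptions:

```python
from trl import SFTConfig

training_args = SFTConfig(
    output_dir="gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft",
    per_device_train_batch_size=8,   # 1 for the multi-image dataset
    gradient_accumulation_steps=4,   # 1 for the multi-image dataset
    gradient_checkpointing=True,
    learning_rate=2e-4,              # assumed value
    num_train_epochs=1,              # assumed value
    logging_steps=10,                # assumed value
    bf16=True,                       # assumed precision setting
    report_to="tensorboard",
    remove_unused_columns=False,     # keep image columns for the custom collator
)
```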
- Data Collator (`collate_fn`): This crucial function prepares batches for training:
  - It applies the chat template to convert message lists into raw text strings.
  - It extracts images, converting them to RGB format. In the multi-image case, the `process_vision_info` helper extracts `PIL.Image.Image` objects from the potentially nested `messages` structure.
  - It tokenizes both texts and images using the `processor`, returning PyTorch tensors with padding.
  - It sets `labels` to a clone of `input_ids`.
  - Crucially, it masks specific tokens by setting their corresponding label entries to -100, ensuring they are ignored during loss computation:
    - `processor.tokenizer.pad_token_id`
    - `processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"])` (the beginning-of-image token ID)
    - the hardcoded ID 262144 (the image soft token).
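The masking rule at the heart of `collate_fn` can be illustrated with plain Python lists (the real implementation operates on PyTorch tensors produced by the processor; the token IDs below are stand-ins, except 262144, which is the hardcoded ID from the guide):

```python
# Stand-in token IDs for illustration; only 262144 comes from the guide,
# the other two are hypothetical values.
PAD_TOKEN_ID = 0              # stand-in for processor.tokenizer.pad_token_id
BOI_TOKEN_ID = 255999         # stand-in for the "boi_token" ID
IMAGE_SOFT_TOKEN_ID = 262144  # hardcoded image token ID from the guide

IGNORED_IDS = {PAD_TOKEN_ID, BOI_TOKEN_ID, IMAGE_SOFT_TOKEN_ID}

def mask_labels(input_ids):
    """Clone input_ids into labels, replacing ignored token IDs with -100."""
    return [-100 if tok in IGNORED_IDS else tok for tok in input_ids]

print(mask_labels([2, 106, 255999, 262144, 262144, 4521, 9, 0]))
# → [2, 106, -100, -100, -100, 4521, 9, -100]
```

Tokens mapped to -100 contribute nothing to the cross-entropy loss, so the model is never trained to predict padding or image-placeholder positions.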
- Training Execution: The `trl.SFTTrainer` is instantiated with the model, training arguments, data collator, training dataset (`dataset["train"]` for single-image, `dataset["test"]` for multi-image), processor, and PEFT configuration. The `trainer.train()` method then initiates the fine-tuning process. After training, `trainer.save_model()` saves the fine-tuned model.
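Wiring these pieces together might look like the following sketch, assuming `model`, `training_args`, `collate_fn`, `dataset`, `processor`, and `peft_config` are defined as described above:

```python
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=dataset["train"],  # dataset["test"] for the multi-image case
    processing_class=processor,      # keyword name assumed for recent TRL versions
    peft_config=peft_config,
)

trainer.train()
trainer.save_model()
```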
The process leverages Hugging Face ecosystem tools for efficient multimodal model fine-tuning: `transformers` for model and processor handling, `bitsandbytes` for quantization, `peft` for LoRA, and `trl` for the SFT training loop. Results are automatically logged to tools like TensorBoard or Weights & Biases. The document also notes existing limitations specific to Gemma fine-tuning.