Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)

2025.04.20
· Hugging Face · by Anonymous
#Multimodal Model #SFT #Fine-tuning #VLM #LLM

Key Points

  1. This guide details the process of fine-tuning multimodal language models, specifically Gemma 3, using Supervised Fine-Tuning (SFT) within the TRL library.
  2. It covers two scenarios: single-image + text datasets (like LLaVA Instruct Mix) and multi-image + text datasets (like MMIU-Benchmark), emphasizing the necessary data preprocessing and conversational formatting for multi-image inputs.
  3. The fine-tuning procedure involves setting up the environment, loading datasets, preparing the model with BitsAndBytes and QLoRA, configuring training arguments with SFTConfig, and implementing a custom `collate_fn` to handle and mask multimodal inputs.

This document details the process of fine-tuning multimodal language models, specifically exemplified with Gemma 3, using Supervised Fine-Tuning (SFT) within the TRL library. The guide addresses two distinct use cases: fine-tuning with single image and text data, and fine-tuning with multi-image and text data (interleaving).

For the single image + text scenario, the HuggingFaceH4/llava-instruct-mix-vsft dataset is utilized. This dataset comprises conversations where a user provides a single image and text, and the assistant responds based on both modalities. For multi-image + text, the FanqingM/MMIU-Benchmark dataset is employed, which features a context, a question, a series of related images, and an expected answer, requiring the model to reason over multiple visual inputs.

The fine-tuning methodology involves several key steps:

  1. Environment Setup: Dependencies such as trl, bitsandbytes, peft, hf_xet, and tensorboard are installed. Access to gated models like Gemma 3 requires logging into the Hugging Face Hub.
  2. Data Loading and Preprocessing:
    • Single Image + Text: The dataset is directly loaded using datasets.load_dataset as it's already formatted for SFT.
    • Multi-Image + Text: The raw FanqingM/MMIU-Benchmark dataset is loaded. It then undergoes a preprocessing step to convert raw data (including zipped image files) into a conversational format suitable for multimodal models. This involves extracting images from ZIP files, converting them to PIL.Image.Image objects, and structuring samples as a list of message dictionaries. Each message typically includes a role (e.g., "system", "user", "assistant") and content, where content can be a list containing objects of type: "text" or type: "image" with the actual image data. An example conversational structure is provided:
```json
[
  { "role": "system", "content": [{"type": "text", "text": "..."}] },
  { "role": "user", "content": [<images_list>, {"type": "text", "text": "..."}] },
  { "role": "assistant", "content": [{"type": "text", "text": "..."}] }
]
```

The prepare_dataset function handles downloading and extracting image archives, while format_data structures the text and image inputs into the required conversational format.
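A minimal sketch of what a format_data step like this could look like is shown below. The field names ("context", "question", "output") are assumptions based on the dataset description above, not the guide's exact schema, and the images are expected to arrive as already-extracted PIL.Image.Image objects (any object works for illustration):

```python
# Hypothetical sketch of format_data: turn one raw MMIU-Benchmark sample
# plus its extracted images into the conversational structure shown above.
# Field names ("context", "question", "output") are illustrative assumptions.
def format_data(sample, images):
    return {
        "messages": [
            {
                "role": "system",
                "content": [{"type": "text", "text": sample["context"]}],
            },
            {
                "role": "user",
                "content": [
                    # one {"type": "image"} entry per image, placed before the question
                    *[{"type": "image", "image": img} for img in images],
                    {"type": "text", "text": sample["question"]},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": sample["output"]}],
            },
        ]
    }
```

The resulting list of message dictionaries matches the conversational layout the processor's chat template expects, with multiple image entries interleaved in the user turn.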

  3. Model and Processor Preparation: The google/gemma-3-4b-it model is loaded using transformers.AutoModelForImageTextToText and its corresponding AutoProcessor. To optimize memory usage, BitsAndBytesConfig is employed for 4-bit quantization (load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16). The attn_implementation is set to "eager". The tokenizer's padding_side is set to "right".
  4. QLoRA Configuration: Quantized Low-Rank Adaptation (QLoRA) is configured using peft.LoraConfig. Key hyperparameters include lora_alpha=16, lora_dropout=0.05, r=16, bias="none", target_modules="all-linear", and task_type="CAUSAL_LM". Additionally, modules_to_save=["lm_head", "embed_tokens"] is specified so that these essential components are trained fully rather than only adapted.
  5. Training Arguments: trl.SFTConfig is used to define training parameters such as output_dir (e.g., "gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft"), num_train_epochs=1, per_device_train_batch_size (8 for single-image, 1 for multi-image), gradient_accumulation_steps (4 for single-image, 1 for multi-image), optim="adamw_torch_fused", save_strategy="epoch", learning_rate=2e-05, bf16=True, push_to_hub=True, report_to="tensorboard", gradient_checkpointing=True (with use_reentrant=False), dataset_kwargs={"skip_prepare_dataset": True}, and remove_unused_columns=False.
  6. Data Collator (collate_fn): This crucial function prepares batches for training.
    • It applies the chat template to convert message lists into raw text strings.
    • It extracts images, converting them to RGB format. The process_vision_info helper function handles the extraction of PIL.Image.Image objects from the potentially nested messages structure for multi-image cases.
    • It tokenizes both texts and images using the processor, returning PyTorch tensors with padding.
    • It sets labels to a clone of input_ids.
    • Crucially, it masks specific tokens by setting their corresponding label entries to -100, ensuring they are ignored during loss computation:
      • processor.tokenizer.pad_token_id
      • processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"]) (beginning of image token ID)
      • The hardcoded ID 262144 (corresponding to the <image_soft_token>).
  7. Training Execution: The trl.SFTTrainer is instantiated with the model, training arguments, data collator, training dataset (dataset["train"] for single-image, dataset["test"] for multi-image), processor, and PEFT configuration. The trainer.train() method then initiates the fine-tuning process. After training, trainer.save_model() saves the fine-tuned model.
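Steps 3–5 and 7 can be condensed into the following configuration-and-wiring sketch. It mirrors the parameters listed above but is an untested outline: it assumes recent transformers/peft/trl APIs (e.g., the `processing_class` argument of SFTTrainer), a GPU, access to the gated Gemma 3 checkpoint, and the `dataset` and `collate_fn` objects produced in the earlier steps.

```python
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

model_id = "google/gemma-3-4b-it"

# 4-bit quantization settings from step 3
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
)
processor = AutoProcessor.from_pretrained(model_id)
processor.tokenizer.padding_side = "right"

# QLoRA hyperparameters from step 4
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"],
)

# Training arguments from step 5 (single-image batch sizes shown)
training_args = SFTConfig(
    output_dir="gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft",
    num_train_epochs=1,
    per_device_train_batch_size=8,   # 1 for the multi-image dataset
    gradient_accumulation_steps=4,   # 1 for the multi-image dataset
    optim="adamw_torch_fused",
    save_strategy="epoch",
    learning_rate=2e-05,
    bf16=True,
    push_to_hub=True,
    report_to="tensorboard",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,        # the collator from step 6
    train_dataset=dataset["train"],  # dataset["test"] for MMIU-Benchmark
    processing_class=processor,
    peft_config=peft_config,
)
trainer.train()
trainer.save_model()
```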
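The label-masking rule at the heart of step 6 can be illustrated with a small, self-contained helper. The pad and boi token IDs below are made-up placeholders; only 262144 is the hardcoded value the guide uses for the image soft token:

```python
# Illustrative stand-in for the masking done inside collate_fn.
PAD_ID = 0                     # placeholder for processor.tokenizer.pad_token_id
BOI_ID = 255999                # placeholder for the "boi_token" ID
IMAGE_SOFT_TOKEN_ID = 262144   # hardcoded in the guide

def mask_special_tokens(input_ids):
    """Clone input_ids into labels, replacing pad and image tokens with -100
    so they are ignored by the cross-entropy loss."""
    ignore = {PAD_ID, BOI_ID, IMAGE_SOFT_TOKEN_ID}
    return [-100 if tok in ignore else tok for tok in input_ids]

batch_ids = [10, 262144, 262144, 42, 0, 0]
labels = mask_special_tokens(batch_ids)
# -> [10, -100, -100, 42, -100, -100]
```

In the real collator the same substitution is applied to a PyTorch tensor cloned from input_ids, but the masking logic is identical.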

The process leverages Hugging Face ecosystem tools for efficient multimodal model fine-tuning, including transformers for model and processor handling, bitsandbytes for quantization, peft for LoRA, and trl for the SFT training loop. Results are automatically logged to tools like TensorBoard or Weights & Biases. The document also notes existing limitations specific to Gemma fine-tuning.