Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)
핵심 포인트
- 1이 문서는 TRL 라이브러리를 사용하여 Gemma 3와 같은 멀티모달 언어 모델을 단일 또는 다중 이미지 데이터셋으로 SFT(Supervised Fine-Tuning)하는 과정을 안내합니다.
- 2QLoRA 및 BitsAndBytes를 활용하여 메모리 효율적인 학습 설정을 다루며, 특히 멀티모달 입력을 처리하기 위한 사용자 정의 `collate_fn` 구현 방법을 상세히 설명합니다.
- 3HuggingFaceH4/llava-instruct-mix-vsft와 FanqingM/MMIU-Benchmark 데이터셋을 예시로 들어, 환경 설정, 데이터 로딩 및 전처리, 모델 및 학습 인자 준비, 그리고 모델 학습 및 저장까지의 전체 워크플로우를 제시합니다.
이 문서는 TRL 라이브러리의 SFT(Supervised Fine-Tuning) 기능을 사용하여 멀티모달 모델, 특히 Gemma 3 모델을 파인튜닝하는 상세 가이드를 제공합니다. 이 가이드는 단일 이미지와 텍스트(Single Image + Text) 조합 및 다중 이미지와 텍스트(Multi-Image + Text) 조합의 두 가지 시나리오를 다룹니다.
1. 개요 및 목적
이 가이드의 주요 목적은 멀티모달 언어 모델을 파인튜닝하는 과정을 안내하는 것입니다. 기존의 VLM SFT 스크립트를 보완하며, 다른 Vision-Language Models (VLMs) 및 데이터셋에도 적용될 수 있는 일반적인 원칙들을 제시합니다.
2. 데이터셋 이해
두 가지 시나리오를 위해 각각 적합한 데이터셋이 활용됩니다:
HuggingFaceH4/llava-instruct-mix-vsft(Single Image + Text): LLaVA Instruct Mix 데이터셋을 재구성한 것으로, 사용자 대화에 단일 이미지와 텍스트가 포함됩니다. 모델(assistant)은 시각 및 텍스트 정보를 기반으로 응답을 생성하도록 훈련됩니다. SFT를 위해 이미 포맷되어 있어 직접 로드하여 사용합니다.FanqingM/MMIU-Benchmark(Multi-Image + Text): 이 데이터셋은Context(시스템 프롬프트),Question(사용자 입력), 여러 장의Series of Images, 그리고 모델의 예상 응답인Answer로 구성됩니다. 모델이 여러 이미지를 추론하여 정보를 기반으로 응답을 생성하는 데 적합하며, 로드 후 추가적인 전처리가 필요합니다.
3. 멀티모달 SFT를 위한 파인튜닝 스크립트 개발
파인튜닝을 위한 환경 설정, 데이터 로딩, 모델 및 학습 준비 과정은 다음과 같습니다.
- 환경 설정 (Setting Up the Environment):
- 필수 라이브러리(
trl,bitsandbytes,peft,hf_xet,tensorboard)를pip을 통해 설치합니다. - Gemma 3와 같은 gated 모델에 접근하기 위해 Hugging Face Hub에 로그인(
huggingface-cli login)하고, 접근 토큰이 필요합니다.
- 필수 라이브러리(
- 데이터 로딩 (Loading the Data):
- Single Image + Text:
datasets라이브러리의load_dataset함수를 사용하여HuggingFaceH4/llava-instruct-mix-vsft를 직접 로드합니다. - Multi-Image + Text:
FanqingM/MMIU-Benchmark데이터셋을 로드한 후 전처리가 필요합니다. 이 과정은 데이터셋 내의 압축된 이미지 파일을 추출하고,format_data함수를 사용하여 데이터를 시스템, 사용자, 어시스턴트 역할이 명확하게 구분된 대화형 구조로 변환합니다. 변환된 데이터 구조는 텍스트와 이미지 요소가 번갈아 나타나는 리스트 형식의messages필드를 가집니다 (예:[{ "role": "system", "content": [{ "type": "text", ... }] }, { "role": "user", "content": [images_list + { "type": "text", ... }] }, { "role": "assistant", "content": [{ "type": "text", ... }] }]).
- Single Image + Text:
- 학습 준비 (Preparing for Training):
- 모델 및 프로세서 로딩:
google/gemma-3-4b-it모델과AutoProcessor를 로드합니다.- 메모리 사용 최적화를 위해
BitsAndBytesConfig를 사용하여 4-bit 양자화(quantization)를 설정합니다. 주요 설정은 , , , , 입니다. - 모델은
AutoModelForImageTextToText.from_pretrained를 통해 , , 설정과 함께 로드됩니다. - 프로세서의 토크나이저
padding_side는"right"로 설정됩니다.
- 메모리 사용 최적화를 위해
- QLoRA 설정:
peft라이브러리의LoraConfig를 사용하여 QLoRA(Quantized Low-Rank Adaptation)를 구성합니다. 주요 파라미터는 , , , , , , 입니다. - 학습 인자 (Training Arguments):
trl라이브러리의SFTConfig를 사용하여 학습 인자를 정의합니다. 주요 인자는 다음과 같습니다:output_dir: 모델 저장 및 허브 업로드 경로.num_train_epochs: 학습 에포크 수 (예: 1).per_device_train_batch_size: 디바이스당 배치 크기 (단일 이미지: 8, 다중 이미지: 1).gradient_accumulation_steps: 그래디언트 누적 단계 (단일 이미지: 4, 다중 이미지: 1).- : 메모리 절약을 위한 그래디언트 체크포인팅 활성화 ( 설정).
- : AdamW 옵티마이저 사용.
- : 에포크마다 체크포인트 저장.
- .
- : bfloat16 정밀도 사용.
- : 파인튜닝된 모델을 Hugging Face Hub에 자동 푸시.
- : TensorBoard로 메트릭 보고.
- : 데이터셋 수동 전처리 활성화.
- : 콜레이터에서 사용되지 않는 컬럼 제거 방지.
collate_fn정의:collate_fn은 배치 처리를 위해 개별 샘플을 준비하는 역할을 합니다.- 텍스트에
chat_template을 적용합니다. - 프로세서는 텍스트와 이미지를 토크나이징하고 텐서로 인코딩합니다.
- 학습을 위한
labels는input_ids로 설정됩니다. - 손실 계산 시 특정 스페셜 토큰(예:
pad_token_id, , ID 262144에 해당하는 )은-100으로 마스킹되어 손실 계산에서 제외됩니다. - 단일 이미지의 경우 이미지 리스트를 직접 처리하고, 다중 이미지의 경우
process_vision_info함수를 사용하여 이미지 리스트의 리스트를 처리합니다.
- 텍스트에
- 모델 및 프로세서 로딩:
- 모델 학습 (Training the Model):
SFTTrainer를 인스턴스화하고, 이전에 정의된model,args,data_collator,train_dataset,processing_class,peft_config를 전달합니다.trainer.train()을 호출하여 학습을 시작하고, 학습 완료 후trainer.save_model()을 통해 파인튜닝된 모델을 저장합니다. TRL은 학습 결과를 Weights & Biases(Wandb) 또는 TensorBoard에 자동으로 로깅합니다.
4. 결과 및 한계
학습 중 및 학습 후의 결과는 Wandb 또는 TensorBoard를 통해 확인할 수 있습니다. 현재 Gemma 모델의 파인튜닝에는 알려진 몇 가지 한계가 있으므로, 이 가이드에 명시된 절차를 따르는 것이 최상의 결과를 보장합니다.