Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)
요약
collate_fn 구현 방법을 상세히 설명합니다.상세 내용
1. 개요 및 목적
이 가이드의 주요 목적은 멀티모달 언어 모델을 파인튜닝하는 과정을 안내하는 것입니다. 기존의 VLM SFT 스크립트를 보완하며, 다른 Vision-Language Models (VLMs) 및 데이터셋에도 적용될 수 있는 일반적인 원칙들을 제시합니다.
2. 데이터셋 이해
두 가지 시나리오를 위해 각각 적합한 데이터셋이 활용됩니다:
* HuggingFaceH4/llava-instruct-mix-vsft (Single Image + Text): LLaVA Instruct Mix 데이터셋을 재구성한 것으로, 사용자 대화에 단일 이미지와 텍스트가 포함됩니다. 모델(assistant)은 시각 및 텍스트 정보를 기반으로 응답을 생성하도록 훈련됩니다. SFT를 위해 이미 포맷되어 있어 직접 로드하여 사용합니다.
* FanqingM/MMIU-Benchmark (Multi-Image + Text): 이 데이터셋은 Context (시스템 프롬프트), Question (사용자 입력), 여러 장의 Series of Images, 그리고 모델의 예상 응답인 Answer로 구성됩니다. 모델이 여러 이미지를 추론하여 정보를 기반으로 응답을 생성하는 데 적합하며, 로드 후 추가적인 전처리가 필요합니다.
3. 멀티모달 SFT를 위한 파인튜닝 스크립트 개발
파인튜닝을 위한 환경 설정, 데이터 로딩, 모델 및 학습 준비 과정은 다음과 같습니다.
* 환경 설정 (Setting Up the Environment):
* 필수 라이브러리(trl, bitsandbytes, peft, hf_xet, tensorboard)를 pip을 통해 설치합니다.
* Gemma 3와 같은 gated 모델에 접근하기 위해 Hugging Face Hub에 로그인(huggingface-cli login)하고, 접근 토큰이 필요합니다.
* 데이터 로딩 (Loading the Data):
* Single Image + Text: datasets 라이브러리의 load_dataset 함수를 사용하여 HuggingFaceH4/llava-instruct-mix-vsft를 직접 로드합니다.
* Multi-Image + Text: FanqingM/MMIU-Benchmark 데이터셋을 로드한 후 전처리가 필요합니다. 이 과정은 데이터셋 내의 압축된 이미지 파일을 추출하고, format_data 함수를 사용하여 데이터를 시스템, 사용자, 어시스턴트 역할이 명확하게 구분된 대화형 구조로 변환합니다. 변환된 데이터 구조는 텍스트와 이미지 요소가 번갈아 나타나는 리스트 형식의 messages 필드를 가집니다 (예: [{ "role": "system", "content": [{ "type": "text", ... }] }, { "role": "user", "content": [images_list + { "type": "text", ... }] }, { "role": "assistant", "content": [{ "type": "text", ... }] }]).
* 학습 준비 (Preparing for Training):
* 모델 및 프로세서 로딩: google/gemma-3-4b-it 모델과 AutoProcessor를 로드합니다.
* 메모리 사용 최적화를 위해 BitsAndBytesConfig를 사용하여 4-bit 양자화(quantization)를 설정합니다. 주요 설정은 , , , , 입니다.
* 모델은 AutoModelForImageTextToText.from_pretrained를 통해 , , 설정과 함께 로드됩니다.
* 프로세서의 토크나이저 padding_side는 "right"로 설정됩니다.
* QLoRA 설정: peft 라이브러리의 LoraConfig를 사용하여 QLoRA(Quantized Low-Rank Adaptation)를 구성합니다. 주요 파라미터는 , , , , , , 입니다.
* 학습 인자 (Training Arguments): trl 라이브러리의 SFTConfig를 사용하여 학습 인자를 정의합니다. 주요 인자는 다음과 같습니다:
* output_dir: 모델 저장 및 허브 업로드 경로.
* num_train_epochs: 학습 에포크 수 (예: 1).
* per_device_train_batch_size: 디바이스당 배치 크기 (단일 이미지: 8, 다중 이미지: 1).
* gradient_accumulation_steps: 그래디언트 누적 단계 (단일 이미지: 4, 다중 이미지: 1).
* : 메모리 절약을 위한 그래디언트 체크포인팅 활성화 ( 설정).
* : AdamW 옵티마이저 사용.
* : 에포크마다 체크포인트 저장.
* .
* : bfloat16 정밀도 사용.
* : 파인튜닝된 모델을 Hugging Face Hub에 자동 푸시.
* : TensorBoard로 메트릭 보고.
* : 데이터셋 수동 전처리 활성화.
* : 콜레이터에서 사용되지 않는 컬럼 제거 방지.
* collate_fn 정의: collate_fn은 배치 처리를 위해 개별 샘플을 준비하는 역할을 합니다.
* 텍스트에 chat_template을 적용합니다.
* 프로세서는 텍스트와 이미지를 토크나이징하고 텐서로 인코딩합니다.
* 학습을 위한 labels는 input_ids로 설정됩니다.
* 손실 계산 시 특정 스페셜 토큰(예: pad_token_id, , ID 262144에 해당하는 )은 -100으로 마스킹되어 손실 계산에서 제외됩니다.
* 단일 이미지의 경우 이미지 리스트를 직접 처리하고, 다중 이미지의 경우 process_vision_info 함수를 사용하여 이미지 리스트의 리스트를 처리합니다.
* 모델 학습 (Training the Model):
* SFTTrainer를 인스턴스화하고, 이전에 정의된 model, args, data_collator, train_dataset, processing_class, peft_config를 전달합니다.
* trainer.train()을 호출하여 학습을 시작하고, 학습 완료 후 trainer.save_model()을 통해 파인튜닝된 모델을 저장합니다. TRL은 학습 결과를 Weights & Biases(Wandb) 또는 TensorBoard에 자동으로 로깅합니다.
4. 결과 및 한계
학습 중 및 학습 후의 결과는 Wandb 또는 TensorBoard를 통해 확인할 수 있습니다. 현재 Gemma 모델의 파인튜닝에는 알려진 몇 가지 한계가 있으므로, 이 가이드에 명시된 절차를 따르는 것이 최상의 결과를 보장합니다.