목록으로
Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)
Blog2025.04.20

Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)

요약

이 문서는 TRL 라이브러리를 사용하여 Gemma 3와 같은 멀티모달 언어 모델을 단일 또는 다중 이미지 데이터셋으로 SFT(Supervised Fine-Tuning)하는 과정을 안내합니다.
️ QLoRA 및 BitsAndBytes를 활용하여 메모리 효율적인 학습 설정을 다루며, 특히 멀티모달 입력을 처리하기 위한 사용자 정의 collate_fn 구현 방법을 상세히 설명합니다.
HuggingFaceH4/llava-instruct-mix-vsft와 FanqingM/MMIU-Benchmark 데이터셋을 예시로 들어, 환경 설정, 데이터 로딩 및 전처리, 모델 및 학습 인자 준비, 그리고 모델 학습 및 저장까지의 전체 워크플로우를 제시합니다.

상세 내용

이 문서는 TRL 라이브러리의 SFT(Supervised Fine-Tuning) 기능을 사용하여 멀티모달 모델, 특히 Gemma 3 모델을 파인튜닝하는 상세 가이드를 제공합니다. 이 가이드는 단일 이미지와 텍스트(Single Image + Text) 조합 및 다중 이미지와 텍스트(Multi-Image + Text) 조합의 두 가지 시나리오를 다룹니다.

1. 개요 및 목적
이 가이드의 주요 목적은 멀티모달 언어 모델을 파인튜닝하는 과정을 안내하는 것입니다. 기존의 VLM SFT 스크립트를 보완하며, 다른 Vision-Language Models (VLMs) 및 데이터셋에도 적용될 수 있는 일반적인 원칙들을 제시합니다.

2. 데이터셋 이해
두 가지 시나리오를 위해 각각 적합한 데이터셋이 활용됩니다:
* HuggingFaceH4/llava-instruct-mix-vsft (Single Image + Text): LLaVA Instruct Mix 데이터셋을 재구성한 것으로, 사용자 대화에 단일 이미지와 텍스트가 포함됩니다. 모델(assistant)은 시각 및 텍스트 정보를 기반으로 응답을 생성하도록 훈련됩니다. SFT를 위해 이미 포맷되어 있어 직접 로드하여 사용합니다.
* FanqingM/MMIU-Benchmark (Multi-Image + Text): 이 데이터셋은 Context (시스템 프롬프트), Question (사용자 입력), 여러 장의 Series of Images, 그리고 모델의 예상 응답인 Answer로 구성됩니다. 모델이 여러 이미지를 추론하여 정보를 기반으로 응답을 생성하는 데 적합하며, 로드 후 추가적인 전처리가 필요합니다.

3. 멀티모달 SFT를 위한 파인튜닝 스크립트 개발
파인튜닝을 위한 환경 설정, 데이터 로딩, 모델 및 학습 준비 과정은 다음과 같습니다.

* 환경 설정 (Setting Up the Environment):
* 필수 라이브러리(trl, bitsandbytes, peft, hf_xet, tensorboard)를 pip을 통해 설치합니다.
* Gemma 3와 같은 gated 모델에 접근하기 위해 Hugging Face Hub에 로그인(huggingface-cli login)하고, 접근 토큰이 필요합니다.

* 데이터 로딩 (Loading the Data):
* Single Image + Text: datasets 라이브러리의 load_dataset 함수를 사용하여 HuggingFaceH4/llava-instruct-mix-vsft를 직접 로드합니다.
* Multi-Image + Text: FanqingM/MMIU-Benchmark 데이터셋을 로드한 후 전처리가 필요합니다. 이 과정은 데이터셋 내의 압축된 이미지 파일을 추출하고, format_data 함수를 사용하여 데이터를 시스템, 사용자, 어시스턴트 역할이 명확하게 구분된 대화형 구조로 변환합니다. 변환된 데이터 구조는 텍스트와 이미지 요소가 번갈아 나타나는 리스트 형식의 messages 필드를 가집니다 (예: [{ "role": "system", "content": [{ "type": "text", ... }] }, { "role": "user", "content": [images_list + { "type": "text", ... }] }, { "role": "assistant", "content": [{ "type": "text", ... }] }]).

* 학습 준비 (Preparing for Training):
* 모델 및 프로세서 로딩: google/gemma-3-4b-it 모델과 AutoProcessor를 로드합니다.
* 메모리 사용 최적화를 위해 BitsAndBytesConfig를 사용하여 4-bit 양자화(quantization)를 설정합니다. 주요 설정은 loadin4bit=Trueload_in_4bit=True, bnb4bitusedoublequant=Truebnb_4bit_use_double_quant=True, bnb4bitquanttype="nf4"bnb_4bit_quant_type="nf4", bnb4bitcomputedtype=torch.bfloat16bnb_4bit_compute_dtype=torch.bfloat16, bnb4bitquantstorage=torch.bfloat16bnb_4bit_quant_storage=torch.bfloat16 입니다.
* 모델은 AutoModelForImageTextToText.from_pretrained를 통해 devicemap="auto"device_map="auto", torchdtype=torch.bfloat16torch_dtype=torch.bfloat16, attnimplementation="eager"attn_implementation="eager" 설정과 함께 로드됩니다.
* 프로세서의 토크나이저 padding_side"right"로 설정됩니다.
* QLoRA 설정: peft 라이브러리의 LoraConfig를 사용하여 QLoRA(Quantized Low-Rank Adaptation)를 구성합니다. 주요 파라미터는 loraalpha=16lora_alpha=16, loradropout=0.05lora_dropout=0.05, r=16r=16, bias="none"bias="none", targetmodules="alllinear"target_modules="all-linear", tasktype="CAUSALLM"task_type="CAUSAL_LM", modulestosave=["lmhead","embedtokens"]modules_to_save=["lm_head", "embed_tokens"] 입니다.
* 학습 인자 (Training Arguments): trl 라이브러리의 SFTConfig를 사용하여 학습 인자를 정의합니다. 주요 인자는 다음과 같습니다:
* output_dir: 모델 저장 및 허브 업로드 경로.
* num_train_epochs: 학습 에포크 수 (예: 1).
* per_device_train_batch_size: 디바이스당 배치 크기 (단일 이미지: 8, 다중 이미지: 1).
* gradient_accumulation_steps: 그래디언트 누적 단계 (단일 이미지: 4, 다중 이미지: 1).
* gradientcheckpointing=Truegradient_checkpointing=True: 메모리 절약을 위한 그래디언트 체크포인팅 활성화 (usereentrant=Falseuse_reentrant=False 설정).
* optim="adamwtorchfused"optim="adamw_torch_fused": AdamW 옵티마이저 사용.
* savestrategy="epoch"save_strategy="epoch": 에포크마다 체크포인트 저장.
* learningrate=2e05learning_rate=2e-05.
* bf16=Truebf16=True: bfloat16 정밀도 사용.
* pushtohub=Truepush_to_hub=True: 파인튜닝된 모델을 Hugging Face Hub에 자동 푸시.
* reportto="tensorboard"report_to="tensorboard": TensorBoard로 메트릭 보고.
* datasetkwargs="skippreparedataset":Truedataset_kwargs={"skip_prepare_dataset": True}: 데이터셋 수동 전처리 활성화.
* removeunusedcolumns=Falseremove_unused_columns=False: 콜레이터에서 사용되지 않는 컬럼 제거 방지.
* collate_fn 정의: collate_fn은 배치 처리를 위해 개별 샘플을 준비하는 역할을 합니다.
* 텍스트에 chat_template을 적용합니다.
* 프로세서는 텍스트와 이미지를 토크나이징하고 텐서로 인코딩합니다.
* 학습을 위한 labelsinput_ids로 설정됩니다.
* 손실 계산 시 특정 스페셜 토큰(예: pad_token_id, <imagetokenid><image_token_id>, ID 262144에 해당하는 <imagesofttoken><image_soft_token>)은 -100으로 마스킹되어 손실 계산에서 제외됩니다.
* 단일 이미지의 경우 이미지 리스트를 직접 처리하고, 다중 이미지의 경우 process_vision_info 함수를 사용하여 이미지 리스트의 리스트를 처리합니다.

* 모델 학습 (Training the Model):
* SFTTrainer를 인스턴스화하고, 이전에 정의된 model, args, data_collator, train_dataset, processing_class, peft_config를 전달합니다.
* trainer.train()을 호출하여 학습을 시작하고, 학습 완료 후 trainer.save_model()을 통해 파인튜닝된 모델을 저장합니다. TRL은 학습 결과를 Weights & Biases(Wandb) 또는 TensorBoard에 자동으로 로깅합니다.

4. 결과 및 한계
학습 중 및 학습 후의 결과는 Wandb 또는 TensorBoard를 통해 확인할 수 있습니다. 현재 Gemma 모델의 파인튜닝에는 알려진 몇 가지 한계가 있으므로, 이 가이드에 명시된 절차를 따르는 것이 최상의 결과를 보장합니다.

원본 보기
Hugging Face
Shared by Anonymous