
mHC: Manifold-Constrained Hyper-Connections
핵심 포인트
- 1최근 Hyper-Connections(HC)는 성능 향상을 제공하지만, identity mapping 속성을 손상시켜 훈련 불안정성과 확장성 제약을 야기합니다.
- 2이러한 문제를 해결하기 위해, Manifold-Constrained Hyper-Connections(mHC)는 HC의 residual connection 공간을 특정 manifold에 투영하여 identity mapping 속성을 복원합니다.
- 3mHC는 kernel fusion 및 recomputing과 같은 엄격한 인프라 최적화를 통해 효율성을 확보하여, 대규모 훈련에서 뛰어난 성능과 안정성, 그리고 우수한 확장성을 입증했습니다.
"mHC: Manifold-Constrained Hyper-Connections" 논문은 Hyper-Connections (HC)의 장점을 유지하면서도 불안정성과 확장성 문제를 해결하기 위해 Manifold-Constrained Hyper-Connections (mHC)라는 새로운 프레임워크를 제안합니다.
1. HC의 한계점:
기존의 잔차 연결(residual connection)은 과 같이 표현되며, x_l이라는 항을 통해 identity mapping 속성을 유지하여 학습 안정성과 효율성을 제공했습니다. 하지만 최근 HC(Zhu et al., 2024)는 잔차 스트림(residual stream)의 width를 확장하고 연결 패턴을 다양화하여 성능 향상을 이뤘습니다. HC에서 단일 레이어 전파는 로 정의됩니다. 여기서 x_l과 의 특징 차원은 C에서 n x C로 확장되며, 는 잔차 스트림 내 특징을 혼합하는 학습 가능한 매핑, 는 nC-dim 스트림을 C-dim 레이어 입력으로 집계하는 매핑, 는 레이어 출력을 스트림으로 다시 매핑하는 역할을 합니다.
하지만 HC는 다음과 같은 문제를 야기합니다:
- 수치적 불안정성(Numerical Instability): 이 제약 없이 학습되면서, 여러 레이어를 거칠 때()
identity mapping속성이 손상됩니다. 이는 신호의 무한 증폭 또는 감쇠를 초래하여 학습 불안정성(unbounded signal amplification or attenuation)을 유발합니다. 특히 의 연속적인 적용은gradient explosion문제를 야기할 수 있습니다. - 시스템 오버헤드(System Overhead):
nC로 확장된 잔차 스트림은 상당한memory access(I/O) 비용을 발생시켜training throughput을 저하시키고GPU memory footprint를 증가시킵니다. 특히 의 중간 활성화값(intermediate activations)을 저장해야 하므로gradient checkpointing이 필요합니다.
2. mHC의 핵심 방법론:
mHC는 HC의 문제점을 해결하기 위해 다음 두 가지 핵심 아이디어를 제안합니다:
- Manifold Constraint: 잔차 매핑 을 특정 manifold에 제약하여
identity mapping속성을 복원하고 안정성을 확보합니다. 구체적으로, 을doubly stochastic matrix(이중 확률 행렬) 집합인Birkhoff polytope상에 투영합니다. 이는 모든 엔트리가 비음수이며, 각 행과 열의 합이 1인 행렬을 의미합니다.- Norm Preservation:
doubly stochastic matrix의spectral norm은 1로 제한됩니다(). 이는 학습 가능한 매핑이non-expansive하며gradient explosion문제를 효과적으로 완화합니다. - Compositional Closure:
doubly stochastic matrix의 집합은 행렬 곱셈에 대해 닫혀 있습니다. 즉, 여러 레이어를 통한 복합 매핑() 또한doubly stochastic하게 유지되어 모델의 전체 깊이에 걸쳐 안정성이 보존됩니다. - Geometric Interpretation:
Birkhoff polytope는 순열 행렬(permutation matrices) 집합의convex hull을 형성합니다. 이는 잔차 매핑이 순열의convex combination으로 작동하며, 스트림 간 정보 혼합(mixing of information)을 단조롭게 증가시켜 강력한 특징 융합(feature fusion) 메커니즘으로 기능함을 의미합니다.
- Norm Preservation:
- Non-negativity Constraints: 입력 매핑 및 출력 매핑 에 대해서도 비음수 제약을 부과합니다. 이는 양수 및 음수 계수의 조합으로 인해 발생할 수 있는 신호 상쇄(signal cancellation)를 방지합니다.
3. 매개변수화 및 Manifold 투영 (Parameterization and Manifold Projection):
mHC는 입력 를 벡터 로 평탄화(flatten)한 후, 동적 매핑(dynamic mappings)과 정적 매핑(static mappings)을 다음과 같이 계산합니다:
여기서 는 학습 가능한 스칼라 gating factor이며, 는 동적 매핑을 위한 선형 투영, b는 정적 매핑을 위한 학습 가능한 bias입니다. 은 에서 으로 재구성하는 함수입니다.
최종 제약된 매핑은 다음을 통해 얻어집니다:
여기서 는 Sigmoid 함수입니다. 연산자는 먼저 의 모든 요소를 exponent operator를 통해 양수로 만들고, 그 다음 행과 열의 합이 1이 되도록 교대로 rescale하는 반복 정규화 프로세스를 수행합니다. 이 과정은 doubly stochastic matrix로 수렴합니다.
4. 효율적인 인프라 설계 (Efficient Infrastructure Design):
mHC는 효율성을 위해 다음과 같은 인프라 최적화를 포함합니다:
- Kernel Fusion: 고차원 에 대한
RMSNorm의latency를 줄이기 위해dividing-by-norm연산의 순서를matrix multiplication뒤로 재배치합니다. 또한,mixed-precision전략을 사용하고, 여러 연산들을 공유된memory access를 가진 단일compute kernel로 융합하여memory bandwidth bottleneck을 줄입니다. 예를 들어, 에 대한 두 번의scan을 융합하고,backward pass의 두matrix multiplication을 단일kernel로 통합합니다. 가벼운 계수 연산(lightweight coefficient operations)도 단일kernel로 융합하여kernel launch overhead를 줄입니다.Sinkhorn-Knopp반복 또한 단일kernel내에서 구현됩니다. 이러한kernel구현에는TileLang프레임워크가 활용되었습니다. - Recomputing:
n-stream residual design으로 인한 상당한memory overhead를 완화하기 위해,forward pass후mHC kernel의 중간 활성화값(intermediate activations)을 버리고,backward pass에서 필요할 때on-the-fly로 다시 계산(recompute)합니다. 이는heavy layer function F를 제외한mHC kernel을 재실행함으로써 이루어집니다. - Overlapping Communication in DualPipe:
pipeline parallelism(Qi et al., 2024)에서 HC로 인해 발생하는 더 큰bubbles및 저하된training throughput을 개선하기 위해,DualPipe schedule(Liu et al., 2024b) 내에서communication을 신중하게overlap합니다.
5. 실험 결과:
언어 모델 사전 학습(language model pretraining)에 대한 광범위한 실험은 mHC가 HC의 성능 이점을 유지하면서 뛰어난 안정성(exceptional stability)과 확장성(scalability)을 보여줍니다. 의 확장률(expansion rate)을 가질 때, mHC는 대규모 학습을 지원하며 추가 training overhead는 단 6.7%에 불과했습니다.