목록으로
Paper2026.01.10

mHC: Manifold-Constrained Hyper-Connections

요약

최근 Hyper-Connections(HC)는 성능 향상을 제공하지만, identity mapping 속성을 손상시켜 훈련 불안정성과 확장성 제약을 야기합니다.
이러한 문제를 해결하기 위해, Manifold-Constrained Hyper-Connections(mHC)는 HC의 residual connection 공간을 특정 manifold에 투영하여 identity mapping 속성을 복원합니다.
mHC는 kernel fusion 및 recomputing과 같은 엄격한 인프라 최적화를 통해 효율성을 확보하여, 대규모 훈련에서 뛰어난 성능과 안정성, 그리고 우수한 확장성을 입증했습니다.

상세 내용

"mHC: Manifold-Constrained Hyper-Connections" 논문은 Hyper-Connections (HC)의 장점을 유지하면서도 불안정성과 확장성 문제를 해결하기 위해 Manifold-Constrained Hyper-Connections (mHC)라는 새로운 프레임워크를 제안합니다.

1. HC의 한계점:
기존의 잔차 연결(residual connection)은 xl+1=xl+F(xl,Wl)x_{l+1} = x_l + F(x_l, W_l)과 같이 표현되며, x_l이라는 항을 통해 identity mapping 속성을 유지하여 학습 안정성과 효율성을 제공했습니다. 하지만 최근 HC(Zhu et al., 2024)는 잔차 스트림(residual stream)의 width를 확장하고 연결 패턴을 다양화하여 성능 향상을 이뤘습니다. HC에서 단일 레이어 전파는 xl+1=Hlresxl+(Hlpost)TF(Hlprexl,Wl)x_{l+1} = H^{res}_l x_l + (H^{post}_l)^T F(H^{pre}_l x_l, W_l)로 정의됩니다. 여기서 x_lxl+1x_{l+1}의 특징 차원은 C에서 n x C로 확장되며, HlresRn×nH^{res}_l \in \mathbb{R}^{n \times n}는 잔차 스트림 내 특징을 혼합하는 학습 가능한 매핑, HlpreR1×nH^{pre}_l \in \mathbb{R}^{1 \times n}nC-dim 스트림을 C-dim 레이어 입력으로 집계하는 매핑, HlpostR1×nH^{post}_l \in \mathbb{R}^{1 \times n}는 레이어 출력을 스트림으로 다시 매핑하는 역할을 합니다.

하지만 HC는 다음과 같은 문제를 야기합니다:
* 수치적 불안정성(Numerical Instability): HlresH^{res}_l이 제약 없이 학습되면서, 여러 레이어를 거칠 때(i=lL1Hires\prod_{i=l}^{L-1} H^{res}_i) identity mapping 속성이 손상됩니다. 이는 신호의 무한 증폭 또는 감쇠를 초래하여 학습 불안정성(unbounded signal amplification or attenuation)을 유발합니다. 특히 HlresH^{res}_l의 연속적인 적용은 gradient explosion 문제를 야기할 수 있습니다.
* 시스템 오버헤드(System Overhead): nC로 확장된 잔차 스트림은 상당한 memory access (I/O) 비용을 발생시켜 training throughput을 저하시키고 GPU memory footprint를 증가시킵니다. 특히 Hlpre,Hlpost,HlresH^{pre}_l, H^{post}_l, H^{res}_l의 중간 활성화값(intermediate activations)을 저장해야 하므로 gradient checkpointing이 필요합니다.

2. mHC의 핵심 방법론:
mHC는 HC의 문제점을 해결하기 위해 다음 두 가지 핵심 아이디어를 제안합니다:
* Manifold Constraint: 잔차 매핑 HlresH^{res}_l을 특정 manifold에 제약하여 identity mapping 속성을 복원하고 안정성을 확보합니다. 구체적으로, HlresH^{res}_ldoubly stochastic matrix (이중 확률 행렬) 집합인 Birkhoff polytope 상에 투영합니다. 이는 모든 엔트리가 비음수이며, 각 행과 열의 합이 1인 행렬을 의미합니다.
* Norm Preservation: doubly stochastic matrixspectral norm은 1로 제한됩니다(Hlres21||H^{res}_l||_2 \le 1). 이는 학습 가능한 매핑이 non-expansive하며 gradient explosion 문제를 효과적으로 완화합니다.
* Compositional Closure: doubly stochastic matrix의 집합은 행렬 곱셈에 대해 닫혀 있습니다. 즉, 여러 레이어를 통한 복합 매핑(Hres\prod H^{res}) 또한 doubly stochastic하게 유지되어 모델의 전체 깊이에 걸쳐 안정성이 보존됩니다.
* Geometric Interpretation: Birkhoff polytope는 순열 행렬(permutation matrices) 집합의 convex hull을 형성합니다. 이는 잔차 매핑이 순열의 convex combination으로 작동하며, 스트림 간 정보 혼합(mixing of information)을 단조롭게 증가시켜 강력한 특징 융합(feature fusion) 메커니즘으로 기능함을 의미합니다.
* Non-negativity Constraints: 입력 매핑 HlpreH^{pre}_l 및 출력 매핑 HlpostH^{post}_l에 대해서도 비음수 제약을 부과합니다. 이는 양수 및 음수 계수의 조합으로 인해 발생할 수 있는 신호 상쇄(signal cancellation)를 방지합니다.

3. 매개변수화 및 Manifold 투영 (Parameterization and Manifold Projection):
mHC는 입력 xlRn×Cx_l \in \mathbb{R}^{n \times C}를 벡터 x^l=vec(xl)R1×nC\hat{x}_l = \text{vec}(x_l) \in \mathbb{R}^{1 \times nC}로 평탄화(flatten)한 후, 동적 매핑(dynamic mappings)과 정적 매핑(static mappings)을 다음과 같이 계산합니다:
x^l=RMSNorm(x^l)\hat{x}'_l = \text{RMSNorm}(\hat{x}_l)
H~lpre=αlpre(x^lφlpre)+blpre\tilde{H}^{pre}_l = \alpha^{pre}_l \cdot (\hat{x}'_l \varphi^{pre}_l) + b^{pre}_l
H~lpost=αlpost(x^lφlpost)+blpost\tilde{H}^{post}_l = \alpha^{post}_l \cdot (\hat{x}'_l \varphi^{post}_l) + b^{post}_l
H~lres=αlresmat(x^lφlres)+blres\tilde{H}^{res}_l = \alpha^{res}_l \cdot \text{mat}(\hat{x}'_l \varphi^{res}_l) + b^{res}_l
여기서 α\alpha는 학습 가능한 스칼라 gating factor이며, φ\varphi는 동적 매핑을 위한 선형 투영, b는 정적 매핑을 위한 학습 가능한 bias입니다. mat()mat(\cdot)R1×n2\mathbb{R}^{1 \times n^2}에서 Rn×n\mathbb{R}^{n \times n}으로 재구성하는 함수입니다.

최종 제약된 매핑은 다음을 통해 얻어집니다:
Hlpre=σ(H~lpre)H^{pre}_l = \sigma(\tilde{H}^{pre}_l)
Hlpost=2σ(H~lpost)H^{post}_l = 2\sigma(\tilde{H}^{post}_l)
Hlres=Sinkhorn-Knopp(exp(H~lres))H^{res}_l = \text{Sinkhorn-Knopp}(\text{exp}(\tilde{H}^{res}_l))
여기서 σ()\sigma(\cdot)는 Sigmoid 함수입니다. SinkhornKnopp()Sinkhorn-Knopp(\cdot) 연산자는 먼저 H~lres\tilde{H}^{res}_l의 모든 요소를 exponent operator를 통해 양수로 만들고, 그 다음 행과 열의 합이 1이 되도록 교대로 rescale하는 반복 정규화 프로세스를 수행합니다. 이 과정은 doubly stochastic matrix로 수렴합니다.

4. 효율적인 인프라 설계 (Efficient Infrastructure Design):
mHC는 효율성을 위해 다음과 같은 인프라 최적화를 포함합니다:
* Kernel Fusion: 고차원 hiddenstatex^lhidden state \hat{x}_l에 대한 RMSNormlatency를 줄이기 위해 dividing-by-norm 연산의 순서를 matrix multiplication 뒤로 재배치합니다. 또한, mixed-precision 전략을 사용하고, 여러 연산들을 공유된 memory access를 가진 단일 compute kernel로 융합하여 memory bandwidth bottleneck을 줄입니다. 예를 들어, x^l\hat{x}_l에 대한 두 번의 scan을 융합하고, backward pass의 두 matrix multiplication을 단일 kernel로 통합합니다. 가벼운 계수 연산(lightweight coefficient operations)도 단일 kernel로 융합하여 kernel launch overhead를 줄입니다. Sinkhorn-Knopp 반복 또한 단일 kernel 내에서 구현됩니다. 이러한 kernel 구현에는 TileLang 프레임워크가 활용되었습니다.
* Recomputing: n-stream residual design으로 인한 상당한 memory overhead를 완화하기 위해, forward passmHC kernel의 중간 활성화값(intermediate activations)을 버리고, backward pass에서 필요할 때 on-the-fly로 다시 계산(recompute)합니다. 이는 heavy layer function F를 제외한 mHC kernel을 재실행함으로써 이루어집니다.
* Overlapping Communication in DualPipe: pipeline parallelism (Qi et al., 2024)에서 HC로 인해 발생하는 더 큰 bubbles 및 저하된 training throughput을 개선하기 위해, DualPipe schedule (Liu et al., 2024b) 내에서 communication을 신중하게 overlap합니다.

5. 실험 결과:
언어 모델 사전 학습(language model pretraining)에 대한 광범위한 실험은 mHC가 HC의 성능 이점을 유지하면서 뛰어난 안정성(exceptional stability)과 확장성(scalability)을 보여줍니다. n=4n=4의 확장률(expansion rate)을 가질 때, mHC는 대규모 학습을 지원하며 추가 training overhead는 단 6.7%에 불과했습니다.

원본 보기
Arxiv
Shared by Anonymous