본문 바로가기

자연어처리/LM

Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models

■ recurrent neural network (RNN)은 inference 속도가 빠르고 long sequence에서도 효율적으로 확장할 수 있지만, training이 어렵고(예: 기울기 소실/폭주 문제 등) scale을 확장하기 힘든 단점이 있다.  

■ 논문에서는 "gated linear recurrences"를 사용하는 RNN인 Hawk와, "gated linear recurrences"와 "local attention"을 혼합한 hybrid model인 Griffin을 제안한다. 

■ Hawk는 downstream tasks에서 Mamba의 reported된 성능을 능가하며, Griffin은 Llama-2보다 6배 이상 적은 tokens로 학습되었음에도 불구하고 대등한 성능을 보인다.  

■ 그리고 Griffin은 학습 중에 본 것보다 훨씬 긴 sequence에 대해서도 extrapolate할 수 있음을 보여준다. 

[2402.19427] Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models

 

Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models

Recurrent neural networks (RNNs) have fast inference and scale efficiently on long sequences, but they are difficult to train and hard to scale. We propose Hawk, an RNN with gated linear recurrences, and Griffin, a hybrid model that mixes gated linear recu

arxiv.org

 

1. Introduction

■ RNN은 병렬 처리가 어렵고 기울기 소실 문제 등으로 scale을 확장시켜 학습시키기 어렵다는 고질적인 문제가 있어, MLP와 MHA를 교차로 배치하는 Transformer 아키텍처가 딥러닝과 NLP 분야를 지배해 왔다.  

■ Transformer는 RNN보다 더 나은 성능을 달성하며, 병렬 처리가 가능하기 때문에 현대 하드웨어를 활용하는 데에도 매우 효율적이다. web에서 수집된 massive datasets으로 학습된 Transformer-based LLM들은 놀라운 성능을 보여줬다.  

■ 그러나, Transformer는 global attention의 quadratic complexity때문에 long sequences로 확장하기 어렵다는 한계가 있다. 게다가, sequence length에 따라 선형적으로 증가하는 KV cache는 inference time에서 Transformer를 더 느리게 만든다. 

■ Multi-Query Attention (MQA)이 cache 크기를 constant factor만큼 줄여 이 문제를 완화하지만, cache는 여전히 sequence length에 따라 선형적으로 증가한다.  

■ 반면, RNN은 sequence를 반복적으로 업데이트되는 고정된 크기의 hidden state로 압축하기 때문에, sequence length가 길어져도 메모리 사용량이나 inference 비용이 증가하지 않는다.  
- 이런 장점이 있지만, sequence length에 상관없이 고정된 크기의 hidden state로 압축하기 때문에, long sequence에서는 정보 손실 문제가 있다.  
- RNN은 시간에 따라 정보를 전달하는 구조이기 때문에, 특히 long sequence를 다룰수록 backpropagation 과정에서 기울기 소실이 발생하기 쉽고, 그 결과 long-term dependency를 충분히 포착하지 못하는 한계가 있다.  

■ 그래서 Transformer를 RNN으로 대체하기 위해서는, 대규모에서도 대등한 성능을 보여줄 뿐만 아니라 유사한 하드웨어 효율성을 달성해야 한다.  

■ 저자들은 새로운 gated linear recurrent layer인 "RG-LRU layer"를 제안하고, 이를 중심으로 MQA를 대체할 새로운 recurrent block을 설계한다.  

■ 이 recurrent block을 사용하여 두 가지 새로운 모델을 구축한다: MLPs과 recurrent blocks을 교차 배치한 모델인 Hawk, MLPs과 recurrent blocks 및 local attention(Longformer)의 결합을 교차 배치한 hybrid model인 Griffin 

■ 실험을 통해 다음과 같은 결과들을 보여준다.
- (1) Hawk와 Griffin은 7B 파라미터 이상까지도 held-out loss와 training FLOPs 사이에 power law scaling을 보이며, 이는 OpenAI scaling laws에서 관찰된 것과 동일하다. 
- (2) Griffin은 모든 모델 scales에서 강력한 Transformer baselines보다 약간 더 낮은 held-out loss를 달성한다. 
- (3) 다양한 모델 scales에서 300B tokens으로 Hawk와 Griffin을 overtrain시킨다. Hawk-3B는 절반가량의 tokens로 학습되었음에도 불구하고 downstream tasks에서 Mamba-3B의 성능을 능가한다.  
- Griffin-7B와 Griffin-14B는 약 7배 더 적은 tokens(300B tokens)로 학습되었음에도 2T tokens로 학습된 Llama-2와 대등한 성능을 보인다.  
- (4) diagonal RNN layers은 memory bound이므로, Pallas로 구현된 RG-LRU layer용 kernel을 통해 메모리 전송을 최소화했다. 이를 통해 Hawk와 Griffin 모두 TPU-v3에서 Transformer와 견줄 말한 training efficiency를 달성한다. 
- (5) inference 시, Hawk와 Griffin 모두 MQA Transformer보다 훨씬 높은 throughput을 달성(Fig 1 (b))하며, long sequences을 샘플링할 때 더 낮은 latency를 달성한다. 
- (6) Griffin은 training에서 본 것보다 더 긴 sequences에서도 Transformer보다 더 나은 성능을 보이며, training data로부터 copying 및 retrieval tasks을 효율적으로 학습할 수도 있다.  
- 그러나 fine-tuning 없이 pre-trained models을 copying 및 exact-retrieval tasks에 대해 평가할 때는 Hawk와 Griffin은 Transformer보다 성능이 떨어진다.   

[출처] https://arxiv.org/abs/2402.19427



2. Model Architecture

■ 논문에서 제안하는 모든 models은 다음 구성 요소를 포함한다: (1) 잔차 블록(residual block) (2) MLP block (3) temporal-mixing block 

■ (1)과 (2)는 모든 models에서 동일하지만, (3)의 경우 세 가지의 temporal-mixing blocks을 고려한다: global Multi-Query Attention, local (sliding-window) MQA, 그리고 저자들이 제안하는 recurrent block 

■ 그리고 recurrent block의 일부로서, Linear Recurrent Unit에서 영감을 받은 새로운 recurrent layer인  Real-Gated Linear Recurrent Unit (RG-LRU)를 사용한다.  

■ Fig 2 (a)에서 볼 수 있듯이, residual block은 models의 전체적인 구조를 정의하는데 사용된다.

■ input sequence를 embedding한 후, 이를 \( N \)개의 blocks(여기서 \( N \)은 model의 depth)에 통과시키고, final activations 값을 생성하기 위해 RMSNorm을 사용한다.   

■ 그리고 token probabilities을 계산하기 위해 linear layer와 softmax를 사용한다. output layer의 가중치는 input embedding layer와 공유된다 (weight tying).  

[출처] https://arxiv.org/abs/2402.19427


2.1 Residual block

■ residual block에서는 두 가지 구성 요소(sub-blocks)가 순서대로 적용된다. (Fig 2 (a)) 
- (1) 첫 번째 구성 요소: 먼저 hidden state \( x \)를 받아 RMSNorm을 적용한 뒤, temporal-mixing block을 통과시킨다. 그런 다음, 그 output을 \( x \)로부터 온 skip connection과 덧셈을 통해 병합한다. 
- \( x = x + \text{Temporal-mixingBlock}(\text{RMSNorm}(x)) \) 
- (2) 두 번째 구성 요소: 첫 번째 구성 요소와 비슷하게 RMSNorm을 적용한 뒤 MLP block에 통과시키고, 그 output을 RMSNorm의 입력으로부터 온 skip connection과 덧셈을 통해 병합한다.  
- \( x = x + \text{MLPBlock}(\text{RMSNorm}(x)) \) 


2.2 MLP block

■ gated MLP block을 사용하며(Fig 2 (b)), 이는 dimension \( D \)인 input으로부터 두 개의 branch를 생성한다. 

■ 각 branch에 output dimension \( MD \)를 가진 linear layer를 적용하는데, 여기서 \( M \)은 expansion factor이다. 이 \( M \)을 통해 내부 차원을 키운다. 논문에서는 \( M = 3 \)을 사용한다. 

■ GeGeLU와 유사하게, 두 가지를 element-wise multiplication으로 병합하기 전에, 한쪽 branch에 GeLU로 non-linearity를 적용한다.  

■ dimension을 다시 원래 크기 \( D \)로 줄이기 위해, GeGeLU layer의 otputs에 dimension \( D \)를 가진 linear layer를 적용한다.  


2.3 Temporal-mixing blocks

■ temporal-mixing block은 sequence 내의 서로 다른 시간(위치)에 있는 정보들(hidden layer의 activations)을 합치는 역할을 한다.  

■ 논문에서는 세 가지 temporal-mixing blocks을 고려한다: global MQA, local MQA, 그리고 저자들이 제안하는 Recurrent block 

Global multi-query attention

■ 별도의 언급이 없는 한, 논문의 실험에서 Transformer baselines의 inference speed를 높이기 위해 MHA 대신 MQA를 사용한다.  

■ 고정된 head dimension \( D_{head}=128 \)을 사용하며, \( HD_{head} = D \)가 되도록 attention heads의 수 \( H \)를 고정한다.  
- 이를 위해 model dimension \( D \)는 128의 배수여야 한다. 

■ absolute positional embeddings을 사용하지 않고, 대신 Rotary Position Embedding (RoPE)을 relative positional embedding으로 사용한다.  

Local sliding window attention

■ global attention의 주요 단점 중 하나는 computational complexity가 sequence length에 따라 quadratically하게 증가한다는 것이다.  
- sequence length가 \( n \)이라고 할 때, standard self-attention으로 attention score를 계산하기 위해 \( n \times n \)을 해야 한다.  
- 그래서 sequence length가 증가하면 computational complexity가 \( n^2 \)에 비례하여 늘어난다. 예를 들어, sequence length를 2배 늘리면(\( 2n \)), 계산량은 4배(\( 4n^2 \))가 된다.  

■ 이를 해결하기 위해 여러 연구들에서 sliding window attention으로도 알려진 local attention(Longformer)을 채택하기 시작했다.  

■ 이는 각 위치가 고정된 수(window size)의 과거 시점 tokens에만 attend하도록 한다. 이는 computational FLOPs를 줄일 뿐만 아니라, KV cache의 크기를 window size로 제한하여, sequence length에 대해 더 이상 quadratic으로 증가하지 않게 만든다. 

■ 다른 details은 위의 global MQA와 동일하다.  

Recurrent block

■ recurrent block (Fig 2 (c))은 GSS block 및 Mamba에서 사용된 block과 유사하다. 

■ dimension \( D \)인 input을 받아 병렬로 output dimension \( D_{RNN} \)을 가진 두 개의 linear layer를 적용하여, 두 개의 branch를 생성한다. 

■ 첫 번째 branch에서는 temporal filter dimension이 4인 small separable Conv1D layer를 적용한다. 이 Conv1D layer는 단지 \( 4D_{RNN} \)개의 파라미터를 가진 매우 작은 layer이다.  

■ Conv1D layer 다음에 RG-LRU layer를 배치한다. 

■ 두 번째 branch에서는 GeLU nonlinearity를 적용한 후, element-wise multiplication으로 두 branch를 다시 병합한다. 그런 다음 output dimension \( D \)를 가진 linear layer를 적용한다.  


2.4 Real-Gated Linear Recurrent Unit (RG-LRU)

■ RG-LRU layer는 재귀 구조를 가지고 있으며, LSTM과 GRU의 gating mechanism을 통합한 layer이다. 이 layer를 설명하는 식은 다음과 같다.  

■ layer의 output은 \( y_t = h_t \)이며, 방정식 내의 non-linearity function \( \sigma \)는 sigmoid function이다. 

■ 첫 번째 식은 input \( x_t \)를 얼마나 오래 유지할지(즉, 얼마만큼 recurrence할지)를 조절하며, 두 번째 식은 LSTM/GRU의 input gate처럼 input \( x_t \)를 얼마나 받아들일지 조절한다. 식 (3)은 recurrent weight이다. \( r_t \) 값으로 조절하는 것을 볼 수 있다.   

■ 식 (4)의 recurrent weight \( a \)는 diagonal이다. 따라서 모든 연산은 element-wise로 이루어진다. 

■ 식 (3)의 \( a \)는 \( a = \sigma(\Lambda) \)로 파라미터화하며, \( 0 \leq a \leq 1 \) 범위를 유지하여 recurrence를 안정적이게 만든다. 여기서 \( \Lambda \)는 학습 가능한 파라미터이다.  

■ 변수 \( c \)는 8로 설정된 스칼라 상수이다. 수치적 안정성을 위해, 실제로는 \( a^{cr_t} \)를 log-space에서 계산한다. (Appendix A) 

■ LSTM이나 GRU는 게이트를 계산할 때 이전 시점의 hidden state \( h_{t-1} \)이 필요하지만, 이 레이어는 어떤 게이트에서도 \( h_{t-1} \)을 사용하지 않기 때문에 계산을 효율적으로 실행할 수 있다.  

■ \( W_a \)와 \( W_x \)는 LeCun 초기화 방식을 사용하여 초기화한다: 평균이 0이고 분산이 \( \dfrac{1}{n_{in} \)인 분포에서 가중치를 초기화한다. \( w \sim N(0, \dfrac{1}{n_{in}}) \), 여기서 \( n_{in} \)은 입력 뉴런의 개수이다.  

■ \( \Lambda \)는 training 시작 시, \( a^c \)가 0.9와 0.999 사이에 균일하게 분포되도록 초기화한다. 이렇게 \( a \)를 1에 가깝게 초기화하는 이유는, 식 (4)를 보면 알 수 있듯이 모델이 과거 정보를 아주 오랫동안 기억하도록 유도하기 위함이다.  

■ RG-LRU는 LRU layer에서 영감을 받은 것이지만, original LRU layer와 달리 recurrence에서 complex algebra를 사용하지 않는다.  

■ complex recurrences을 사용하면 layer의 표현력(주어진 데이터 분포를 정확하고 복잡하게 모델링)이 더 높아질 수 있지만, 실제 language modelling에서는 complex recurrences가 이점을 주지 않는다는 것을 발견했으며, 이는 이전 연구에서도 관찰된 바이다.  

Gate behaviour

■ input gate \( i_t \)는 input \( x_t \)를 필터링할 수 있다는 점에서 LSTM의 input gate와 유사하다.  그러나 recurrence gate \( r_t \)는 다른 연구들의 gating mechanisms과는 다르다. 

■ 예를 들어, Mamba에서 제안한 selection mechanism은 previous state와 현재의 관측값 \( x_t \) 사이를 보간한다는 점에서 GRU의 update gate와 비슷하다. hidden state에 미치는 이 효과는 LSTM의 forget gate와 유사하게 state를 reset하고 과거로부터 유지하고 있는 정보를 잊어버리게 할 수 있게 한다. 
- \( t-1 \) 시점과 \( t \) 시점에서 받아들일 정보의 양을 조절한다. 현재 \( t \) 시점의 정보를 많이 받아들이려면, 필연적으로 과거 \( t-1 \) 시점의 정보를 그만큼 잊어버려야 한다.   

■ 대조적으로, 논문의 recurrence gate는 이전 기록(history)의 모든 정보를 보존할 수 있게 해준다. (Appendix A)


A. RG-LRU Recurrence Gate

Implementation

■ 수치적 안정성을 위해 섹션 2.4에 정의된 recurrence gate를, 수학적으로 동일하지만 약간 다른 형태로 구현한다. 

■ \( a_t \)는 \( a^{cr_t} \)로 정의되지만, 실제 구현에서는 수치적 안정성을 위해, 다음과 같이 로그로 계산한 후, 지수를 취하는 방식을 사용한다.  

■ 이론적인 수식은 \( a_t = a^{cr_t} = \sigma(\Lambda)^{cr_t} \)이다. \( \sigma(\Lambda) \)는 0과 1 사이의 값이고, \( c \)는 고정된 상수(논문에서는 8 사용), 게이트 \( r_t \)는 0과 1 사이의 값을 갖기 때문에, \( a=\sigma(\Lambda) \)가 0에 아주 가깝거나 지수를 계산할 때, 값이 너무 작아져서 언더플로우가 발생할 위험이 있다.  

■ 그래서 저자들은 식 (6)처럼, 로그를 취해 \( a_t \)의 로그를 계산한 다음 지수화하는 방식을 사용한다.  

Gate behaviour

■ 다른 연구들의 gating mechanisms과 달리, 과거의 정보를 유지하는 쪽으로 편향되어 있으며 \( h_{t-1} \)을 완전히 버리는 것을 허용하지 않는다. 단, 이는 \( \Lambda \) 값에 따라 달라진다.  

■ 논문에서는 일반적인 recurrence에 대해 식 (7)과 같이 정의하여, output \( y_t \)에서 \( h_{t-1} \) 대비 \( x_t \)의 상대적 가중치를 분석한다.  

■ Fig 7의 \( x \)축 \( r_t \)는 게이트의 값(0~1)이다. \( r_t \)가 1에 가까울수록 \( x_t \)를 더 많이 반영하게 된다. 

■ \( y \)축은 \( \alpha / \alpha + \beta \)로, 새로운 입력 \( x_t \)와 과거 정보 \( h_{t-1} \)의 반영 비율을 나타낸 것이다.  

[출처] https://arxiv.org/abs/2402.19427

■ Mamba와 GRU는 선형적인 곡선 형태를 띈다. 즉, \( h_{t-1} \)과 \( x_t \) 사이를 완전히 보간한다. 

■ 반면, RG-LRU은 비선형적인 형태를 띈다. 즉, 정보를 보존하려는 성향이 강하게 나타난다. 게이트가 작동해도 \( h_{t-1} \)의 기여도를 완전히 0으로 만들지 않는다. (특히 \( a \) 값이 클 때) 



3. Recurrent Models Scale as Efficiently as Transformers

■ 모델의 크기를 키울 때(scaling) 성능이 어떻게 변하는지 비교하여, 제안하는 모델이 LLM로서 적합한지 검증한다. 비교 모델들은 baseline인 MQA-Transformer와  순수 RNN model인 Hawk, 그리고 hybrid model인 Griffin이다.  

■ 모델 scales에 따른 하이퍼파라미터는 다음과 같다. 

[출처] https://arxiv.org/abs/2402.19427

MQA Transformer baseline

■ 섹션 2에서 설명된 residual pattern과 gated MLP block을 사용하며, MQA 및 RoPE를 사용한다. 

Hawk

■ RG-LRU의 경우 \( \alpha(r_t) = a_t = a^{cr_t} \)이고, \( \beta (r_t) = \sqrt{1-\alpha(r_t)^2} \)이다. standard GRU 스타일의 gating은 \( \alpha(r_t) = 1-r_t \)이고, \( \beta(r_t) = r_t \)이다.  

■ Hawk의 아키텍처는 baseline과 동일한 residual pattern 및 MLP block을 사용하지만, MQA 대신 RG-LRU layer가 포함된 recurrent block을 temporal mixing block으로 사용한다. 

■ 일반적으로 RNN layer는 attention layer보다 파라미터 수가 적다. 그래서 성능 비교가 불공정할 수 있다. 이에 저자들은 Hawk의 \( D_{RNN} \)을 약 1.33배 키워서, Transformer와 동일한 파라미터 수를 갖도록 조정했다. (\( D_{RNN} \approx 4D / 3 \)) 

Griffin

■ recurrent blocks과 local attention을 혼합하여 사용하는 모델이다. sequence를 고정된 크기의 state로 압축하고, 최근의 window만큼만 KV cache를 유지하므로 메모리 사용량이 고정된다.  

■ 이러한 local attention으로 가까운 과거를 정확하게 모델링하고, recurrent layers로 long sequences에 걸쳐 정보를 전단할 수 있다.  

■ Griffin은 baseline과 동일한 residual pattern 및 MLP block을 사용한다. 그러나 baseline 및 Hawk와는 달리, recurrent blocks과 MQA blocks의 혼합을 사용한다.  

■ 구체적으로, recurrent block을 가진 두 개의 residual blocks 뒤에, 섹션 2.3에서 설명한 local (MQA) attention block을 사용하는 하나의 residual block이 온다. 그리고 이 순서로 반복해서 계층을 쌓는다.  
- 즉, 전체 layers의 2/3은 RNN(RG-LRU), 1/3은 local attention이다.  

■ 달리 명시되지 않는 한, local attention의 window size는 1024 tokens으로 고정한다.  


3.1 Scaling curves

■ Fig 1 (a)는 scaling 결과들을 나타낸 것이다. 세 가지 모델 모두 100M에서 7B 범위의 파라미터로 학습되었으며, 14B Griffin model도 포함되어 있다. 

■ Chinchilla scaling laws에서 규정한 대로, training tokens의 수를 model parameter 수에 대략 비례하도록 증가시켰다.  

■ models은 MassiveText dataset을 사용하하여 학습되었으며, 이때 2048 tokens의 sequence length가 사용되었다. (더 긴 sequence에 대한 결과는 섹션 6) 

■ small models(예: 100M)의 실행 결과를 바탕으로 전반적인 경향성을 파악한 뒤, 이를 바탕으로 7B 및 14B 모델의 하이퍼파라미터를 설정했다.  

■ 세 가지 모델 모두, valid loss와 training FLOPs 사이에 선형적인 스케일링 관계를 보여준다. (Fig 1 (a))

■ 주목할 점은, Griffin이 global attention layers을 전혀 사용하지 않음에도 불구하고, 모든 FLOPs 범위에서 Transformer baseline보다 더 낮은 valid loss를 달성했다는 것이다.  


3.2 Evaluation on downstream tasks

[출처] https://arxiv.org/abs/2402.19427

■ 다른 모델들(Mamba, Llama-2)과 비교하기 위해, downstream tasks에서 평가하기 전, 모든 모델을 300B tokens로 학습시켰다.  

■ Mamba는 300B tokens의 두 배인 600B tokens으로, Llama-2는 거의 7배인 2T tokens로 학습되었다. 

■ Mamba와 Llama-2는 저자들이 사용한 데이터셋과 다른 데이터로 학습되었고, 다른 하이퍼파라미터 튜닝 전략으로 학습되었다. 그래서 성능 차이가 있다면 아키텍처 때문인지 데이터 때문인지 불분명할 수 있다. 

■ 이를 해소하기 위해, 저자들은 Hawk 및 Griffin과 동일한 데이터 및 하이퍼파라미터 튜닝 예산을 들여서 MQA transformer baseline을 학습시켰다.  
- 만약, Griffin이 Mamba나 Llama-2와 성능 차이가 난다면, 어떤 요인에 의한 차이인지 불분명할 수 있지만, 데이터 조건과 튜닝 예산이 동일한 MQA transformer baseline을 이긴다면, 이는 Griffin의 아키텍처가 더 우월하기 때문이라고 해석할 수 있다. 

■ Table 1은 downstream tasks에 대한 평가 결과이다. Hawk와 Griffin 모두 강력한 성능을 달성한 것을 볼 수 있다. 

■ Hawk의 성능은 모델 크기를 키울수록 크게 향상되며, Hawk-3B는 절반의 tokens로 학습되었음에도 불구하고 Mamba-3B보다 우수한 성능을 보인다. 

■ Griffin-3B는 Mamba-3B를 능가하며, Griffin-7B와 Griffin-14B는 거의 7배 더 적은 tokens로 학습되었음에도 Llama-2와 경쟁력 있는 성능을 보인다.  

■ 또한, Hawk는 자체 MQA Transformer baseline과 경쟁력 있는 성능을 보이며, Griffin은 이 baseline을 능가한다.  

■ 더 적은 토큰으로 학습된 RNN(RG-LRU) + local attention 구조(Griffin)가 더 많은 토큰으로 학습된 global attention 구조보다 높은 성능을 보였다. 이는 Griffin의 학습 효율이 월등히 높다는 결과이다.


3.3 Training speed on longer sequences

[출처] https://arxiv.org/abs/2402.19427

■ 추가로, Griffin의 계산적 이점을 확인하기 위해 다양한 모델 크기와 sequence lengths에 걸쳐 training speed를 비교한다.

■ 각 모델 크기에 대해, 배치당 총 토큰 수를 고정한다. 즉, sequence length를 늘리면 sequence의 수(즉, batch size)를 그에 비례하여 줄인다. 

■ Fig 3은 2048 sequence length에서의 MQA baseline 대비 Griffin의 상대적 실행 시간을 나타낸 것이다.  

■ 가장 짧은 sequence length에서는 두 모델이 비슷한 training time을 가지지만, sequence length를 늘리면 Transformer는 느려지는 반면, Griffin의 실행 시간은 유지되는 것을 볼 수 있다. 

■ Transformer의 속도 저하는 모델 크기가 작을 때 더 두드러지며, 모델 크기가 커질수록 감소한다. 7B에서는 Griffin과 거의 비슷한 실행 시간을 보인다. 

■ 이는 모델 크기가 커질수록 여러 개의 linear layers을 포함하고 있기 때문이다.

■ \( T \)가 sequence length, \( D \)가 model의 width라고 할 때, linear layer의 계산은 \( O(TD^2) \)이며, RG-LRU는 \( O(TD) \), global attention은 \( O(T^2D) \)이다.  

■ 그래서 sequence length \( T \)에 비해 model width \( D \)를 늘릴수록, linear layer가 차지하는 비중 \( O(TD^2) \)이 더 커지므로, global attention \( O(T^2D) \)을 RNN(RG-LRU) \( O(TD) \)로 바꿔서 얻는 이득이 전체 시간에서 차지하는 비중이 미미해진다.   
- 모든 모델에는 수많은 linear layers이 존재하는데, sequence length \( T \)에 비해 model width \( D \)가 크다면, linear layer가 주된 계산 병목이 되므로 RNN block으로 인한 효율성이 최소화되는 것이다.  

■ 그러므로 Transformer를 Hawk나 Griffin으로 대체했을 때, 가장 큰 효율성 이득을 볼 수 있는 경우는 attention이 계산 시간의 대부분을 차지할 만큼 model width \( D \)에 비해 sequence length \( T \)가 충분히 길 때이다.  


 


4. Inference Speed

■ LLM의 inference는 prefill stage와 decode stage, 두 단계로 구성된다. 

■ prefill stage에서는 prompt를 받고 처리한다. 이 단계는 사실상 모델의 forward pass를 수행하는 단계이다. 

■ forward pass를 수행하는 단계이므로 병렬로 계산할 수 있기 때문에, 이 단계에서 대부분의 모델 연산은 compute bound이다.  

■ prefill stage의 다음 단계는 decode stage이다. 여기서는 모델에서 tokens을 auto-regressively하게 샘플링한다. 

■ decode stage에서 Transformer는 KV cache를 사용할 경우, sequence 길이에 비례하여 KV cache가 선형적으로 커진다. attention 연산을 위해 모델 파라미터와 KV cache를 device의 계산 유닛으로 이동시켜야 하는데, batch size가 최적의 batch size보다 작을 경우 memory bound가 되어 latency가 높아지고 throughput이 낮아진다.  

■ 그러나 recurrent models은 고정된 크기의 벡터에 sequence를 압축하기 때문에, sequence가 아무리 길어져도 메모리 사용량과 계산량이 늘어나지 않아, decode stage에서 Transformer보다 더 낮은 latency와 더 높은 throughput이 가능하다. 특히 attention에 사용되는 KV cache가 커질 수 있는 long sequence에서 더욱 그렇다.  

■ inference speed를 평가할 때 일반적으로 고려되는 두 가지 주요 지표는 latency와 throughput이다. 

■ 보통 latency 측정은, 특정 batch size에서 지정된 수의 tokens을 생성하는 데 걸리는 시간을 측정한다. throughput의 경우, 지정된 수의 tokens을 샘플링할 때 single device에서 초당 생성할 수 있는 최대 토큰 수를 측정한다.  

■ throughput은 "샘플링된 토큰 수 \( \times \) 배치 크기 / latency"로 계산되므로, latency를 줄이거나 메모리 사용량을 줄여 device에 더 큰 배치 크기를 사용할 수 있게 함으로써 throughput을 향상시킬 수 있다.  

■ latency 측정은 빠른 응답 시간을 필요로 하는 실시간 애플리케이션에서 유용한 지표로 사용되며, throughput은 주어진 시간 내에 모델에서 샘플링할 수 있는 최대 토큰 수를 알려주므로 역시 유용한 지표이다.  


4.1 A simple model of the decode step

■ batch size가 너무 크지 않은 한, decoding 중 언어 모델은 메모리 대역폭의 제한을 받는다. (memory bound)
- 이 섹션의 나머지 부분에서는 memory bound를 가정한다. 

■ Transformer의 가장 큰 메모리 오버헤드는 일반적으로 모델 파라미터와 KV cache에서 발생한다. 그러므로, decoding 중 batch \( B \) 내의 각 sequence에 대해 하나의 token을 생성하는 데 필요한 시간은 메모리에서 이 두 가지(모델 파라미터와 KV cache)를 로드하는 데 필요한 시간으로 다음과 같이 근사할 수 있다.  

- 즉, sequence의 next token 하나를 생성하기 위해서는 모델 파라미터와 KV cache가 연산 유닛으로 이동해야 하며, batch size가 작을 경우 next token을 연산하는 시간보다 이 시간이 지배적이므로 next token을 생성하는 시간은 위와 같이 근사될 수 있다는 것이다.   
- 식 (5)에서 cache size는 batch size 1일 때의 KV cache(Transformer의 경우), 또는 batch size 1일 때의 recurrent state의 size(RNN의 경우)를 의미한다.  

Cache sizes

■ recurrent blocks 및 local attention blocks에서는 cache size가 상당히 작기 때문에 모델 파라미터 로딩이 주된 병목이다. 

■ 대조적으로, global attention의 KV cache는 sequence length \( T \)에 비례하여 커지며, 모델 파라미터의 크기와 비슷하거나 심지어 이를 초과할 수 있다. 이는 sequence length \( T \)가 충분히 클 때 상당한 오버헤드를 발생시킨다.  

■ 이러한 차이로, recurrent model은 동일한 크기의 Transformer model보다 상당히 낮은 latency를 보일 수 있다.  

■ 그러나 모델 크기가 커질수록, 모델 파라미터 크기 > KV cache인 상황(\( T \)가 적당한 값인 상황)에서는 병목의 주범은 모델 파라미터가 된다. 
- KV Cache가 주된 병목이 되려면, \( T \)가 매우 커서 KV cache 크기가 모델 파라미터 크기만큼 커져야 한다.   

■ 반면, recurrent model은 작은 recurrent state를 가지기 때문에 메모리를 적게 차지한다. 메모리 공간이 남으면 batch size를 증가시킬 수 있으며, batch를 키우면 한 번에 처리하는 양이 늘어나므로, 결과적으로 더 높은 throughput으로 이어진다.  


4.2 Results

■ 여기서는 1B 크기의 모델들에 대한 latency와 throughput 결과들을 확인한다.  

■ baseline으로는 standard MHA Transformer보다 inference speed가 훨씬 빠른 MQA Transformer를 사용한다. 이 baseline과 Hawk, Griffin의 결과를 비교한다.  

Latency

[출처] https://arxiv.org/abs/2402.19427

■ Hawk와 Griffin은 긴 시퀀스에 대해 MQA Transformer보다 더 빠른 샘플링 속도(즉, 더 낮은 latency)를 달성한다. 이는 시퀀스 길이와 prefill 길이(KV cache 크기에 영향을 미침)가 증가할수록 더 두드러진다.  

■ 그리고 Griffin은 Hawk와 비슷한 수준의 latency를 달성한다. 이는 linear recurrences와 local attention의 결합이 속도 저하 없이 매우 효율적으로 작동함을 증명하는 결과이다.  

■ 정리하면, batch size가 고정된 상태에서(또는 최적의 batch size보다 더 낮은 batch size를 사용하는 상황에서) long context를 처리할 때 Transformer는 KV cache가 커져서 느려지지만, Hawk와 Griffin은 그 영향을 거의 받지 않아 훨씬 빠르게 반응하며, Hawk와 Griffin 간의 속도 차이도 거의 없다.  

Throughput

■ 512, 1024, 2048, 4196 tokens을 샘플링할 때 동일한 모델들의 최대 throughput (tokens/s)을 비교한다 (Fig 1 (b))

■ Griffin과 Hawk 모두 MQA Transformer보다 상당힌 높은 throughput을 달성한다. 이는 recurrent model의 latency가 더 낮기 때문이기도 하지만, 주된 이유는 Griffin과 Hawk의 cache size가 더 작아서, 남는 메모리 공간으로 MQA Transformer보다 더 큰 batch size를 사용할 수 있기 때문이다.  

■ batch size가 클 때는, local attention cache의 크기가 결국 모델 파라미터 크기와 비슷해지기 때문에, Hawk가 Griffin보다 더 높은 throughput을 달성한다.  
- Griffin은 local attention을 사용하기 때문에 Hawk보다는 메모리를 조금 더 쓰게 된다. 그래서 throughput이 Hawk보다는 약간 낮다.  
- 그럼에도 여전히 Transformer보다는 throughput이 월등히 높다.  



5. Long Context Modeling

■ 이 섹션에서는 Hawk와 Griffin이 longer contexts을 사용하여 next token prediction을 개선하는 효과를 확인하고, inference에서의 extrapolation capabilities을 확인한다.  
- 즉, Hawk와 Griffin이 RNN의 고질적인 약점인 long-term dependency 문제를 여전히 겪는지, 그리고 training에서 보지 못한 긴 길이를 inference에서 잘 처리할 수 있는지 확인한다.  


5.1 Improving next token prediction with longer contexts

[출처] https://arxiv.org/abs/2402.19427

■ 여기서는 다양한 시퀀스 길이에 걸쳐 held-out dataset에 대한 loss를 측정하여 학습된 모델을 평가한다. 

■ long documents을 사용함으로써 모델의 extrapolation 능력, 즉 training 중에 보았던 것보다 더 긴 contexts이 주어졌을 때에도 next token을 정확하게 예측하는지 평가할 수 있다.  

■ Transformer에서 extrapolation 능력은 주로 attention layer에 사용된 positional encoding에 의해 결정된다. 

■ recurrent models의 경우, state를 계속 업데이트하는 방식이므로 이론적으로는 무한한 길이를 처리할 수 있다. 그러므로, context가 길어짐에 따라 state에 저장된 표현을 계속 업데이트하는 모델의 용량(정확하게는 context를 담을 hidden state vector의 차원의 크기)에 의해 좌우된다.  

■ Fig 5의 left는 2048 tokens의 시퀀스로 학습된 모델들을 대상으로 평가한 결과이다.  

■ Hawk와 Griffin 모두 더 긴 contex가 주어져도 next token prediction의 loss가 꽤 유지되는 것을 볼 수 있다. 두 모델 모두 학습된 길이보다 훨씬 더 긴 시퀀스(최소 4배 더 긴)에 대해 extrapolate할 수 있다. 특히, local attention layers에 RoPE를 사용한 Griffin은 더 뛰어난 extrapolation을 보인다.  
- 2048 tokens의 길이로 학습된 Hawk는, 그 4배인 8K까지 next token prediction loss가 잘 유지되고, Griffin은 32K까지 계속 loss가 잘 유지되는 것을 볼 수 있다.  

■ 저자들은 2048 tokens보다 더 긴 길이로 학습한 모델이, 더 효과적으로 extrapolation을 가지는지 평가하기 위해, MassiveText dataset에서 8192(8K) tokens의 시퀀스로 1B 모델을 학습시키고, 이를 동일한 데이터셋에서 2048(2K) tokens로 학습된 모델과 비교하였다.  

■ 그리고 비교의 공정성을 위해, 8K 길이로 학습된 모델에 대해선 batch size를 4배 줄임으로써(단, training steps는 고정), 모델 간의 총 학습 토큰 수를 동일하게 유지했다. 

■ Fig 5 right에서, Hawk-8K와 Griffin-8K가 8192 이상의 시퀀스 길이에 대해 Hawk-2K 및 Griffin-2K보다 더 낮은 evaluation loss(더 좋은 성능)을 달성한 것을 볼 수 있다.  

■ 그리나 짧은 시퀀스 길이에서는 2K로 학습된 모델들이 오히려 더 나은 성능을 보인다. 

■ 이는 training sequence length는 downstream task에 따라(즉, 모델이 사용될 downstream이 긴 글을 다루는 task인지, 짧은 task인지에 따라) 신중하게 선택되어야 함을 시사한다.  


5.2 Copy and retrieval capabilities

■ Transformer가 RNN인 state space model (SSM)보다 context를 복사하거나 context에서 관련 토큰들을 검색하는 것과 같은 synthetic tasks를 학습하는 데 있어 훨씬 더 효율적일 수 있음을 보여준 연구가 있다.  그리고 pre-trained Transformer가 pre-trained SSM에 비해 해당 tasks에서 훨씬 더 우수함을 보여준 연구도 있다. 

■ 이 섹션에서는 Griffin과 Hawk가 context에서 tokens을 복사하고 검색하는 방법을 학습하는 효율성(즉, 얼마나 더 빨리 학습할 수 있는지, Transformer만큼 빨리 학습할 수 있는지)을 확인한다.  

복사 및 검색 능력 모두를 테스트하기 위해 설계된 phone number lookup task에서 사전학습된 Hawk 및 Griffin 모델을 평가한다.  

Training on synthetic tasks

■ context에서 relevant tokens을 복사하고 검색하는 방법을 학습하는 효율성을 조사하기 위해, 두 가지 synthetic tasks인 Selective Copying과 Induction Heads에 대해 학습을 진행한다.  

■ Transformer를 Hawk 및 Griffin과 비교할 수 있도록, model dimension 64를 가진 5 bolck 깊이의 네트워크(모델 크기는 약 250K 정도)를 고려하며, 여기서 Griffin은 세 번째 block에서만 local attention을 사용한다. (즉, 5개 block 중 RG-LRU 4개, local attention 1개) 

- (1) Selective copying task
- 이 task에서 모델은 noise tokens은 무시하면서 sequence 내의 data tokens을 복사하는 법을 배워야 한다. 
- 여기서는 16의 vocabulary size를 사용하고, vocabulary에서 무작위로 샘플링되고 무작위 위치에 배치된 16개의 data tokens을 포함하며, 나머지 tokens은 noise tokens으로 설정된, 1024의 sequences을 학습한다.  
- 즉, 1024개 tokens 중 복사해야 하는 tokens은 단 16개이다. 모델은 1000개가 넘은 노이즈를 무시하고 이 16개를 정확히 기억해내야 한다.  
- Griffin은 512의 local attention window size를 사용한다. 
- (2) Induction heads
- 이 task에서 모델은 special token 바로 뒤에 오는 token을 기억해야 한다.  
- 이를 위해서는 모델이 special token을 학습하고, context 내에서 그 뒤에 오는 token을 검색해야 한다. 그러므로 모델은, 학습된 것보다 훨씬 긴 sequence로 extrapolate할 수 있어야 한다.  
- vocabulary size 16을 사용하고, tokens이 무작위로 샘플링된 길이 256의 sequences을 학습하며, sequence 내 special token의 위치를 랜덤 샘플링한다.  
- Griffin은 128의 local attention window size를 사용한다.  

[출처] https://arxiv.org/abs/2402.19427

■ Fig 6에서, Selective Copying task의 경우 세 가지 모델 모두 해당 taks를 완벽히 해결할 수 있음을 볼 수 있다. 단, 이 task에 대한 학습 속도를 비교했을 때, Hawk가 Transformer보다 상당히 느리다.   

■ 반면, Griffin은 단 하나의 local attention layer를 사용함에도 불구하고 Transformer의 학습 속도와 필적한다. 

■ Induction Heads task의 경우, 세 가지 모델 모두 training sequence length까지는 task를 완벽하게 해결할 수 있지만, Transformer baseline은 그 길이를 넘어서는 순간부터 성능이 저하되는 것을 볼 수 있다. 이 baseline은 RoPE를 사용했으므로, RoPE의 한계로 볼 수 있다. 

■ 반면, Hawk는 training sequence length보다 몇 자릿수 더 긴 evaluation sequences에 대해서도 이 task를 완벽하게 수행한다. Griffin 또한 local attention을 가지고 있음에도 우수한 성능을 보인다. Induction Heads task에서 Hawk와 Griffin의 결과는 RNN의 장점 덕분으로 보인다.  

Evaluating pre-trained models

■ 다음으로, pre-trained models에서도 copying 및 retrieval 능력이 발현되는지 확인한다. 

■ MassiveText dataset에서 300B tokens로 학습된 7B Hawk 및 Griffin과 6B MQA Transformer를 비교한다.  

■ 이전 연구와 동일한 phonebook lookup task를 사용한다. 이 task에서는 이름과 전화번호가 포함된 전화번호부를 모델에게 제공하고, 특정 이름이 주어졌을 때 올바른 전화번호를 검색하도록 요청한다.  

■ 모델에 주어지는 prompt는 랜덤 샘플링된 특정 길이의 이름과 번호 목록으로 구성된 전화번호부, 그 뒤에는 랜덤 샘플링된 두 개의 task example, 그리고 마지막으로 모델이 전화번호를 찾아내야 할 타겟 이름으로 구성된다.  

■ Hawk는 작은 크기의 고정된 state vector를 사용한다. 그래서 전화번호부 길이가 매우 짧을 때는 잘 수행하지만, 길이가 길어지면 올바른 전화번호를 기억하고 검색하는 데 실패하는 것을 Fig 6 (c)에서 볼 수 있다. 

■ Transformer baseline은 training sequence length까지는 이 task를 거의 완벽하게 수행하지만, 결국 training sequence length보다 더 긴 길이의 context에 대해서는 올바른 전화번호를 검색하는 데 실패한다.  

■ Griffin은 단 하나의 local attention layer만 사용함에도 불구하고, local attention window size인 1024와 일치하는 길이까지는 이 task를 완벽하게 해결할 수 있다. context 길이가 window size로 커버할 수 없을 만큼 길어지면, 성능이 저하되기 시작한다.