본문 바로가기

자연어처리/LM

RecurrentGemma: Moving Past Transformers for Efficient Open Language Models

■ RecurrentGemma는 Griffin 아키텍처를 사용한다. Griffin 기반이므로 고정된 크기의 state를 가져 메모리 사용량을 줄이고, long sequences에 대해서도 efficient inference가 가능하다. 

■ pre-trained 및 instruction tuned된 2B 및 9B 모델은 더 적은 tokens로 학습되었음에도 불구하고, 비슷한 크기의 Gemma models과 대등한 성능을 달성한다.  

[2404.07839] RecurrentGemma: Moving Past Transformers for Efficient Open Language Models

 

RecurrentGemma: Moving Past Transformers for Efficient Open Language Models

We introduce RecurrentGemma, a family of open language models which uses Google's novel Griffin architecture. Griffin combines linear recurrences with local attention to achieve excellent performance on language. It has a fixed-sized state, which reduces m

arxiv.org

 

1. Introduction

■ Griffin 아키텍처는 global attention 대신, linear recurrences와 local attention의 결합을 통해 sequence를 모델링한다. 

■ 2B 및 9B의 RecurrentGemma는 모두 2T tokens로 학습되었으며, 다양한 downstream tasks에서 Gemini를 기반으로 한 Gemma models과 경쟁할 만한 수준의 성능을 달성한다.   

■ transformer는 inference 시점에서 KV cache를 device memory에 로드해야 한다. 이 KV cache는 sequence length에 따라 선형적으로 증가한다.  

■ local attention을 사용하여(정확하게는 window size를 적절히 선택하여) cache size를 줄일 수는 있지만, 과거 문맥의 전체를 보지 못하므로 성능이 떨어지는 문제가 있다.  

■ 대조적으로, Griffin 아키텍처 기반의 RecurrentGemma는 성능 저하 없이 input sequences을 고정된 크기의 state로 압축하기 때문에, 메모리 사용량을 줄이고 input sequence에서 효율적인 추론을 가능하게 한다.  

■ 그리고 TPU에서 linear recurrence을 수행하기 위한 specialized Pallas kernel을 사용한다.  



2. Model architecture

■ 저자들은 Griffin 아키텍처에 단 하나의 수정 사항만을 적용했다. 그것은 input embeddings에 model width(즉, model dimension d_model)제곱근과 같은 상수를 곱하는 것이다.  
- 보통 embedding은 초기화 시 아주 작은 값(예: torch.nn.Embedding, 평균 0, 분산 1)으로 설정된다.  
- 여기에 positional encoding의 숫자가 input embedding보다 훨씬 크면, 단어의 의미는 무시하고 위치 정보에만 집중하게 될 위험이 있다.  
- 그래서 임베딩 벡터에 model width의 제곱근과 같은 상수를 곱해서 스케일링한다. 그러면 embedding 값들의 크기가 커져서 위치 인코딩 값들과 비슷한 수준의 영향력을 갖게 된다.  
- Attention Is All You Need에서도 embedding에 d_model의 제곱근을 곱한 것을 사용한다.  

■ RecurrentGemma에서는 input embeddings과 output embeddings이 서로 tied되어 가중치를 공유하지만, 이 상수는 output embeddings에는 적용하지 않았다.  

■ RNN 구조는 Transformer보다 학습이 불안정하다. 그래서 저자들은 recurrent (RG-LRU) layers의 파라미터에는 weight decay를 적용하지 않았다. 

■ 또한, recurrent layers 내부의 제곱근 연산을 backpropagation할 때, 학습 안정성을 위해 항상 미분값(derivative)을 최댓값 1000으로 클리핑한다.  
- RG-LRU 내부 연산에는 제곱근(\( \sqrt{x} \)) 연산이 있다. \( \sqrt{x} \)는 미분하면 \( \dfrac{1}{2 \sqrt{x}} \)가 된다. 그래서 \( x \)가 0에 가까워질수록 미분값이 커지며, 0이 되면 발산한다. 그래서 수치적 안정성을 위해 미분값이 1000을 넘지 않도록 클리핑한 것이다.  

■ RecurrentGemma-2B와 RecurrentGemma-9B의 하이퍼파라미터는 다음과 같다. 

[출처] https://arxiv.org/abs/2404.07839



3. Training details

Pre-training

■ 8192 tokens의 sequences로 모델을 학습시킨다. Gemma와 동일한 pre-training data를 사용하며, 이는 주로 web documents, mathematics, code로 구성된 English data이다.  

■ RecurrentGemma-2B와 RecurrentGemma-9B 모두 2T tokens로 pre-training한다. Gemma-2B는 3T tokens, Gemma-7B는 6T tokens로 pre-trained되었다.  

■ Gemma와 동일하게, 먼저 대규모의 일반적인 데이터의 mixture로 학습시킨 후, 더 작지만 더 높은 품질의 dataset으로 continuing training한다.  

■ 그리고 Gemma와 마찬가지로, 256K tokens로 구성된 vocabulary를 가진 SentencePiece tokenizer의 subset을 사용한다.  

Instruction tuning and RLHF

■ 모델이 높은 reward를 받는 responses을 출력하도록 fie-tune시키기 위해 새로운 RLHF algorithm을 포함하여, Gemma와 유사한 instruction tuning approach를 따른다.  

■ instruction tuned된 model은 Gemma처럼 Table 3의 format을 따르도록 학습되었다.  

[출처] https://arxiv.org/abs/2404.07839



4. Evaluation

Automated Benchmarks

[출처] https://arxiv.org/abs/2404.07839

■ RecurrentGemma-2B는 Gemma-2B보다 50% 더 적은 tokens로 학습되었음에도 불구하고, 비슷한 성능을 보인다.  

■ RecurrentGemma-9B도 Gemma-7B보다 3배 더 적은 tokens로 학습되었음에도 대등한 성능을 보인다. embedding layers을 고려하면 RecurrentGemma-9B는 Gemma-7B의 총 파라미터 수는 비슷하다. 즉, 모델의 실질적인 체급은 비슷하다.  

Human Evaluation

[출처] https://arxiv.org/abs/2404.07839

■ 최종적으로 instruction tuned된 두 개의 RecurrentGemma models (2B IT, 9B IT)을 Mistral 7B v0.2 Instruct model과 human evaluation으로 비교한다.  

■ Table 5에서 볼 수 있듯이 creative writing, coding tasks 전반에 걸쳐 모델의 지시 수행 능력을 묻는 약 1000개의 held-out prompts에 대해, RecurrentGemma-2B IT는 Mistral 7B를 상대로 43.7%의 승률을 달성했으며, RecurrentGemma-9B IT는 59.3%의 승률을 달성했다.  

■ 기본적인 safety protocols을 테스트하는 데 중점을 둔 약 400개의 held-out prompts에서, RecurrentGemma-2B IT는 Mistral 7B를 상대로 59.8%의 승률을 달성했으며, RecurrentGemma-9B IT는 59.9%의 승률을 달성했다.  

Inference Speed Benchmarks

[출처] https://arxiv.org/abs/2404.07839

■ RecurrentGemma의 장점은 long sequences에서 transformer보다 state의 size가 훨씬 작다는 것이다. 

■ transformer 기반인 Gemma의 KV cache는 sequence length에 비례하여 증가하는 반면, RecurrentGemma는 local attention window size 2K tokens보다 긴 sequences에서도 증가하지 않는다. state size는 제한되어 있기 때문이다.  
- local attention의 window size를 2k tokens로 설정했다면, 2K tokens의 cache만 있으면 된다.  

■ 언어 모델의 inference는 일반적으로 memory-bound이다. inference speed를 높이는 가장 좋은 방법은 batch size를 최적에 맞게 키우는 것이다.  
- 모델 파라미터 및 KV cache를 이동시키는 비용을 batch size로 나누어 내는 셈이므로, batch size를 최적의 크기에 맞게 키운다면 생성 비용이 낮아진다.  

■ Gemma의 경우, autoregressively하게 생성할 수 있는 가장 긴 샘플은 가용 메모리에 의해 제한된다. 반면, RecurrentGemma는 임의의 길이의 sequences을 생성할 수 있다. 

■ Griffin 기반 RecurrentGemma는, transformer 기반 Gemma보다 inference에서 샘플 생성 시 메모리 요구량이 더 작기 때문에, 훨씬 더 큰 batch size를 설정하여 inference를 수행할 수 있다.   

■ Fig 1 (a)와 1 (b)는 RecurrentGemma 2B 및 9B 모델이 달성한 inference throughput을 비슷한 크기의 Gemma 모델들과 비교한 결과이다.  

■ 2K tokens의 prompt로부터 다양한 생성 길이에 대해 샘플링할 때 달성된 throughput이다. 

■ RecurrentGemma-2B의 경우 single TPUv5e, RecurrentGemma-9B의 경우 single TPUv4에서 초당 샘플링할 수 있는 최대 토큰 수로 throughput을 계산하였다.  

■ 단, 이 결과(Fig 1 (a)와 1 (b)의 각 left)에는 prompt를 처리하는 데 걸리는 시간이나, 출력 시퀀스를 token id들의 리스트에서 최종 텍스트 문자열로 변환하는 데 소요되는 시간은 포함되어 있지 않다. 

■ RecurrentGemma는 모든 시퀀스 길이에서 Gemma보다 더 높은 throughput을 달성하였다. 

■ RecurrentGemma가 달성한 throughput은 시퀀스 길이가 증가해도 줄어들지 않는 반면, Gemma가 달성한 throughput은 cache가 커짐에 따라(즉, 시퀀스 길이가 증가함에 따라) 떨어지는 것을 볼 수 있다.  

■ 그리고 2B 모델끼리의 비교보다, Gemma-7B와 RecurrentGemma-9B의 비교에서 throughput 차이가 더 두드러지게 나타나는데, 이는 Gemma-7B가 MHA를 사용하는 반면, Gemma-2B는 MQA를 사용하기 때문이다.  

■ MQA는 key/value vectors을 공유하므로 그나마 효율적인 것이다. 

■ 추가로, 서로 다른 길이의 input prompts을 처리할 때 달성되는 throughput도 측정하였다. (Fig 1 (a)와 1 (b)의 각 right)

■ auto-regressive phase와 달리, promt phase에서는 tokens이 병렬로 처리된다. 
- auto-regressive phase는 주어진 tokens로 next token을 생성하는 단계이고, prompt phase에서는 input tokens을 단순히 계산하는 단계이기 때문에 autoregressively하게 생성할 필요 없이 병렬 처리한다.  

■ Gemma와 RecurrentGemma는 비슷한 속도로 input prompts을 처리한다. Gemma와 RecurrentGemma 모두 2B 모델의 경우 초당 약 40K tokens, 9B 모델의 경우 초당 약 12K tokens을 처리한다.  

■ autoregressive sampling에서는, RecurrentGemma는 초당 6K tokens의 throughput을 달성하는 반면, Gemma는 그보다 더 느리다. 그러므로, prompt가 생성하고자 하는 샘플보다 훨씬 길지 않은 이상, 샘플링 과정이 전체 소요 시간을 좌우할 것이다.  
- 즉, 샘플링 속도가 빠른 RecurrentGemma가 전체적인 서비스 속도 경쟁력에서 Gemma보다 훨씬 우위에 있다.  

■ Fig 1 (a)와 1 (b)는 RecurrentGemma의 Flax implementation을 사용한 결과이다. 여기에는 TPU에서의 실행을 위해 특수 제작된 Pallas kernel이 포함되어 있다.  

■ 그래서 PyTorch implementation을 사용하거나 GPU를 사용할 경우, 논문의 수치보다 더 낮은 throughput을 기록할 수 있다.  

Responsible Deployment

[출처] https://arxiv.org/abs/2404.07839

■ RecurrentGemma는 Gemma release에 기술된 것과 동일한 safety mitigations을 따른다.  

■ standard academic safety benchmarks에서 RecurrentGemma를 평가했으며, 그 결과는 Table 6에 제시되어 있다. 다만 이러한 벤치마크 평가는 모든 가능한 use cases를 포괄하지는 못한다.