본문 바로가기

자연어처리

Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

■ 트랜스포머의 self-attention은 입력 길이에 대한 2차 복잡도(quadratic complexity) 때문에 매우 긴 시퀀스에 대해서는 속도가 엄청 느리다는 한계를 가진다.

■ 논문에서는 이 한계를 해결하기 위해 self-attention을 커널 특징 맵(kernel feature maps)의 선형 내적(linear dot-product)으로 표현하고, 행렬 곱의 결합 법칙(associativity property)을 이용하여 복잡도를 \( O(N^2) \)에서 \( O(N) \)으로 줄이는 방법을 제안한다. 

[2006.16236] Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

 

Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

Transformers achieve remarkable performance in several tasks but due to their quadratic complexity, with respect to the input's length, they are prohibitively slow for very long sequences. To address this limitation, we express the self-attention as a line

arxiv.org

 

1. INTRODUCTION

■ 트랜스포머 모델은 자연어, 오디오, 이미지를 다루는 다양한 태스크에서 인상적인 결과를 보여주고 있다. 
■ 충분한 지도 데이터가 있는 태스크 외에도 트랜스포머는 autoregressive 또는 masked language modeling 목적 함수로 사전학습되었을 때, 지도 데이터가 제한적이거나 없는 태스로 사전학습에서 얻은 지식을 전달(파인튜닝)하는 데에도 효과적이다. 

■ 그러나 이러한 이점들은 매우 높은 계산 및 메모리 비용을 대가로 한다. 

■ 병목 현상은 주로 self-attention에 의해 야기되는데, \( N \)개의 입력 컨텍스트에 대해 \( N \times N \) 크기의 어텐션 행렬을 계산하고 저장해야 하기 때문이다. 이 때문에 시간과 메모리 복잡도가 모두 \( O(N^2) \)이 된다. 

■ 이에 대한 해결책 중 하나로 Transformer-XL에서는 이전 세그먼트의 계산 결과를 저장하여 더 긴 컨텍스트를 보려는 시도를 하였지만, 이 역시 추가적인 계산 비용을 유발하는 한계가 있었다. 

■ 이 논문에서는 메모리 사용량을 줄이고 컨텍스트 길이에 대해 선형적으로 확장되는 "linear transformer" 모델을 제안한다.

■ 이를 위해 커널 기반의 self-attention 공식과 행렬 곱의 결합 법칙을 사용하여, self-attention의 가중치(attention weight) 계산하는 방식을 도입하였다. 



2. Linear Transformers

■ 전통적인 softmax attention에서 softmax는 attention weight를 얻기 위해 사용한다. \( N \times N \) 크기의 attention score matrix의 행별로 softmax가 적용되므로, 한 행에 대한 계산량은 \( O(N) \)이라고 할 수 있다. 총 \( N \)개의 행이 있으므로, softmax 연산 자체의 총 계산량은 \( O(N^2) \)이 된다. 

softmax를 사용하는 이유는 attention score를 확률 분포와 유사한 형태로 변환(attention probability)하기 위함이지만, 반드시 softmax를 사용해야만 attention probability(또는 attention weight)를 계산할 수 있는 것은 아니다.
저자들은 이러한 점에 착안하여, 전통적인 softmax attention에서 특징 맵(feature map) 기반의 내적 어텐션으로 어텐션을 대체한다. 


2.1 Transformers

■ 입력을 \( x \in \mathbb{R}^{N \times F} \)라고 하자. 트랜스포머는 \( L \)개의 트랜스포머 레이어 \( T_1 (\cdot), \cdots, T_L(\cdot) \)의 합성으로 정의되는 함수 \( T : \mathbb{R}^{N \times F} \rightarrow \mathbb{R}^{N \times F} \)이며, 다음과 같다.

- 함수 \( f_l (\cdot) \)은 피드포워드 네트워크, \( A_l (\cdot) \)은 self-attention 함수

■ 입력 시퀀스 \( x \)는 세 개의 행렬 \( W_Q \in \mathbb{R}^{F \times D}, W_K \in \mathbb{R}^{F \times D}  , W_V \in \mathbb{R}^{F \times M}  \)에 의해 \( Q, K, V \)로 투영된다. 모든 위치에 대한 출력 \( A_l(x) = V' \)는 다음과 같이 계산된다. 

■ softmax 함수는 \( QK^T \)에 행별로 적용된다. 보통 \( Q, K, V \)를 각각 쿼리, 키, 벨류라고 부른다. 

■ 식 (2)는 쿼리와 키 사이의 내적의 지수 함수를 유사도 점수로 사용하는, softmax attention이라고 불리는 self-attention의 특정 형태를 표현한 것이다. 

행렬에 아래 첨자 \( i \)를 붙이는 것이 \( i \)번째 행을 벡터로 반환한다고 했을 때, 임의의 유사도 함수에 대한 일반화된 어텐션 식을 다음과 같이 쓸 수 있다. 

■ 식 (3)에서 유사도 함수(similarity function)를 \( \text{sim}(q, k) = \text{exp} \left( \dfrac{q^T k}{\sqrt{D}} \right) \)로 치환한다면, 식 (3)은 식 (2)와 동일한 식이다. 

- 식 (3)의 분자는 \( i \)번째 쿼리와 모든 키 \( K_j \) 의 유사도를 계산하고, 그 유사도를 벨류 \( V_j \)에 곱한 값들을 모두 더한다. 

- 식 (3)의 분모는 \( i \)번째 쿼리와 모든 키의 유사도의 총합이다.

- 식 (3)의 분자를 식 (3)의 분모로 나누는 것은, 유사도 총합이 1이 되도록 정규화하여 가중 평균을 구하는 과정이다. 


2.2 Linearized Attention

■ 커널의 관점에서 유사도 함수(sim)을 다른 함수로 얼마든지 바꿀 수 있다. 단, similarity function(kernel function)의 값 \( \text{sim} (\cdot) \)이 음수가 아니어야 한다는 제약 조건이 필요하다. 

- 즉, 간단히 말해 일반화된 어텐션(식 (3))이 잘 작동하려면, sim 값은 음수만 아니면 된다. 

■ 이 제약 조건은 확률의 공리와 간접적으로 깊은 관련이 있다고 할 수 있다. 이는 일반화된 어텐션 메커니즘(식 (3))이 기존  softmax attention 메커니즘처럼 합리적으로 작동하기 위한 필수 조건이기 때문이다. 

- 어텐션의 최종 목표는 중요도를 나타내는 attention weight를 각 벨류 벡터에 곱해서 모두 더하는 것이다. 

- 기존 방식의  attention weight는 softmax를 통과한 값이기 때문에 확률처럼 해석할 수 있다. 즉, attention weight는 확률처럼 0과 1 사이의 값을 가지며, 총합은 1이고, 각 가중치는 항상 0 또는 양수이다. 

- " \( \text{sim} (\cdot) \)이 음수가 아니다"라는 제약 조건이 없으면, 확률의 공리 중 \( P(A) \geq 0 \)(확률은 항상 0 또는 양수여야 함)를 만족하지 않는다. 또한, 모든 \( Q_i \)와 \( K_j \)의 사잇각이 90도라면, 식 (3)의 분모가 0이 될 수 있다. 

- 즉, 이 제약 조건은 어텐션 메커니즘이 '가중 평균'으로 작동하고, \( P(A) \geq 0 \)처럼 '중요도'가 항상 0 이상의 값을 가지도록 보장하며, 계산 과정에서 분모가 0이 되어 '0으로 나누기'와 같은 오류를 방지하기 위해 필요하다. 

feature representation map이 \( \phi \)로 주어졌을 때, 식 (2)를 다음과 같이 다시 쓸 수 있다.

■ 여기에 행렬 곱셈의 결합 법칙을 이용하여 다음과 같이 더 단순화시킬 수 있다. 

■ 아래 식 (6)은 결합 법칙을 전체 행렬(\( Q, K, V \))의 관점에서 보여주는 것이다. 

- 좌변은 기존 방식으로 쿼리와 키를 먼저 곱하기 때문에 \( N \times N \) 크기의 행렬이 먼저 생성된다.

- 우변은 결합 법칙을 적용했을 때를 나타내며, 단순히 계산 순서를 바꿈으로 인해 \( N \times N \) 행렬이 먼저 생성되지 않는다. 

- feature map은 행렬 \( Q \)와 \( K \)에 행별로 적용된다. 

식 (2)로부터 softmax attention의 계산 비용은 \( O(N^2) \)이 발생하며, \( Q, K, V \)에 대한 그래디언트를 계산하기 위해 \( N \times N \) 크기의 어텐션 행렬이 저장되어야 하므로, 메모리 복잡도 역시 \( O(N^2) \)이다. 

■ 반면, 식 (5)는 시간 및 메모리 복잡도가 \( O(N) \)이다. 

- 식 (5)의 분자 \( \displaystyle\sum_{j=1}^{N} \phi(K_j) V_j^{T}\)와 분모 \( \displaystyle\sum_{j=1}^{N} \phi(K_j)\)를 한 번 계산하고 모든 쿼리 \( Q_i \)에 대해 재사용할 수 있기 때문이다. 

-  식 (5)의 분자는 쿼리 \( Q_i \)와 무관하다. 따라서 단 한 번만 계산해두면 모든 \( i \)에 대해 재사용할 수 있다. 

■ 좀 더 구체적으로 기존 방식과 linear attention을 비교하면,

- (1) 기존 방식

- 먼저, 식 (4)의 \( \displaystyle\sum_{j=1}^{N} \phi(Q_i)^{T} \phi(K_j) V_j \)를 보자. \( \phi \)는 행렬의 크기와 상관없으니 무시하겠다. 

- \( i = 1 \)일 때, \( V'_1 \)을 계산하기 위해서는, \( Q_1 \)과 \( K_1 \)의 내적, \( Q_1 \)과 \( K_2 \)의 내적, \( \cdots \), \( Q_1 \)과 \( K_N \)의 내적, 여기까지 \( N \) 번의 내적 연산이 필요하다. 

- 계산된 \( N \)개의 스칼라 가중치를 사용하여 \( V \) 벡터들의 가중 평균을 계산한다.

- \( i = 2 \)일 때에도, \( Q_2 \)과 \( K_1 \)의 내적, \( Q_2 \)과 \( K_2 \)의 내적, \( \cdots \), \( Q_2 \)과 \( K_N \)의 내적, \( N \) 번의 내적 연산이 필요하다. 

- 이 과정을 \( N \)개의 모든 출력에 대해 반복한다. 

- 각 출력(\( V'_i \))마다. \( N \)번의 연산이 필요하고, 이런 출력이 \( N \)개 있으므로, 총 연산 횟수는 \( N \times N \)에 비례한다. 

- \( N \times N \)의 원인은 \( i \) 번째 출력을 계산할 때 했던 \( K_j \)와 \( V_j \)에 대한 계산을 \( i + 1 \) 번째 출력을 계산할 때 처음부터 다시 하기 때문이다. 

- (2) linear attention

- \( \displaystyle\sum_{j=1}^{N} \phi(K_j) V_j^{T}\)를 미리 계산한 결과가 \( S \)라고 하자. 이 계산은 \( j=1 \)부터 \( N \)까지, \( K_j  V_j \)를 계산하므로 복잡도는 \( O(N) \)이다. 이 과정에는 쿼리 \( Q_i \)가 전혀 관여하지 않기 때문이다. 

- \( i = 1 \)일 때, 미리 계산한 결과 \( S \)가 있다고 하자. \( V'_1 \)을 계산하기 위해서는 \( Q_1 \)과 \( S \)를 곱하기만 하면 된다. 

- \( Q \)가 \( N \times D \)이므로 각 \( Q_i \)는 \( 1 \times D \)이다. \( S \)는 \( D \times M \)이므로 

- \( V'_1 \)을 계산하기 위해 필요한 것은 \( D \)와 \( M \)이다. 즉, 시퀀스 길이 \( N \)에 전혀 영향을 받지 않으므로 \( N \)에 대한 복잡도는 상수이다.

- 따라서 각 쿼리 \( Q_i \)에 대한 어텐션 계산은 \( O(1) \)이 된다. 

- 이 \( O(1) \)에 대한 연산을 \( N \)번 진행하니 \( N \times O(1) = O(N) \)이 된다. 

- 그러므로, 총 시간 복잡도는 \( O(N) + N \times O(1) = O(N) \)이 된다. 

■ 이렇게 linear attention은 행렬 곱의 순서를 바꾸는 트릭을 통해 \( N \times N \) 행렬의 생성 자체를 회피한다.

대신, \( N \)에 선형적으로 비례하는 계산을 통해 모든 쿼리 \( Q_i \)가 공유할 수 있는 중간값을 한 번만 계산하고 재사용하므로, 시간과 메모리 복잡도가 모두 \( O(N) \)으로 감소한다. 

2.2.1 FEATURE MAPS AND COMPUTATIONAL COST

■ softmax attention의 경우, 곱셈과 덧셈 측면에서 총비용은 \( O \left( N^2 \text{max}(D, M) \right) \)로 확장된다.

- \( D \)는 쿼리와 키의 차원, \( M \)은 벨류의 차원

■ 반대로 linear attention의 경우, 먼저 \( C \) 차원의 feature map을 계산한다. 새로운 벨류를 계산하는 데는 \( O(NCM) \)의 덧셈과 곱셈이 필요하다. 

- \( C \)는 feature map \( \phi \)가 출력하는 벡터의 차원

■ 논문에 따르면, 다항식 커널을 사용했을 때 유한 차원의 feature map을 가지고 있으며, 지수(exp) 또는 RBF 커널과 동등하게 잘 동작한다고 한다. 2차 다항식 트랜스포머를 선형화했을 때의 계산 비용은 \( O(ND^2M) \)이며, \( N > D^2 \)일 때, 기존 방식보다 계산 복잡도가 유리하다고 한다.

- 대부분의 긴 시퀀스 문제(예: \( N = 4096, D = 512 \)에서 이 조건은 만족된다. 

이처럼, 어떤 feature map을 사용할지에 대해서는 여러 가지 선택지가 존재한다. 저자들은 더 짧은 시퀀스를 다루는 실험을 하기 위해, 다음과 같은 feature map을 사용하였다. 

\( \phi(x) = \text{elu}(x) + 1 = \begin{cases} x+1 & x>0 \\ e^x & x \le 0 \end{cases} \)

- \( \text{sim} (\cdot) \) 값이 양수가 되는 feature map이다. 

- elu(exponential linear unit)는 최소 -1의 값을 가지므로, elu(x) + 1의 값은 항상 0 이상이다. 이는 attention weight가 음수가 되지 않도록 하는 제약 조건을 만족한다. 

- relu() 대신 elu()를 사용한 이유는 \( x \)가 음수일 때, 그래디언트가 0이 되는 것을 방지하기 위해서라고 한다. 

  feature map을 사용할 경우 \( x \)의 차원 \( D \)를 바꾸지 않으므로(\( C = D \)), \( O(NDM) \)의 덧셈과 곱셈이 필요하다.

■ 논문에 따르면, 식 (7)의 feature map이 계산 및 메모리 요구사항을 상당히 줄이면서도, 기존 방식과 대등한 성능을 보인다고 한다. 


2.3 Causal Masking

■ 트랜스포머 아키텍처는 \( i \) 번째 위치가 오직 \( j \leq i \)인 위치 \( j \)에 의해서만 영향을 받을 수 있도록, 즉 미래 위치들의 영향을 받을 수 없도록 어텐션 계산을 마스킹함으로써 autoregressive 모델을 효율적으로 훈련시키는 데 사용될 수 있다. 

■ 인과적 마스킹(causal masking)은 식 (3)을 다음과 같이 변형시킨다.

Linearized Attention 섹션의 논리에 따라 식 (8)을 다음과 같이 나타낼 수 있다. 

마찬가지로 다음과 같이 분자의 합 \( S_i \)와 부모의 합 \( Z_i \)를 정의했을 때, 

식 (9)를 다음과 같이 단순화할 수 있다.

- 식 (10) \( S_i \)는 다음과 같이 \( S_{i-1} \)과 \( i \) 번째의 합으로 나타낼 수 있다.

- \( S_i = S_{i-1} + \phi(K_i)^T V_i \)

- 식 (11)도 마찬가지이다. \( Z_i = Z_{i-1} + \phi(K_i) \)

- 즉, \( i \) 번째 시점의 상태(\( S_i \)와 \( Z_i \))는 이전 시점의 상태(\( S_{i-1} \)와 \( Z_{i-1} \))에 현재 시점의 입력(각각 \( \phi(K_i)^T V_i \)와 \( \phi(K_i) \))만 더해주면 되므로 \( O(1) \)에 계산할 수 있다. 

- 각 \( i \) 번째 출력 \( V'_i \)를 계산하는 데 필요한 \( S_i, Z_i \)를 \( O(1) \)에 계산할 수 있으며, \( V'_i \)를 계산하는 것도 \( N \)과 무관한 연산이다.

- 따라서 \( i = 1 \)부터 \( N \)까지 전체 시퀀스를 처리하는 총 시간 복잡도는 \( O(N) \)이 된다. 

2.3.1 GRADIENT COMPUTATION

■ 어떤 딥러닝 프레임워크에서든, 그래디언트를 계산하기 위해서는 모든 중간값을 저장해야 한다. 

■ 즉, 식 (12)의 구현을 사용할 경우, 그래디언트를 계산하기 위해 모든 중간값 \( S_i \)를 저장해야 한다. 이것은 메모리 소비를 \( \text{max}(D, M) \)배만큼 증가시켜 메모리 복잡도가 매우 커질 수 있다. 

■ 이를 해결하기 위해, 저자들은 다음과 같은 방식으로 역전파를 수행한다. 인과적 선형 어텐션의 순전파와 역전파를 모두 선형 시간(linear time)과 상수 메모리(constant memory)로 계산할 수 있게 한다.

[출처] https://arxiv.org/abs/2006.16236

- (1) forward function

- forward는 앞서 본 것처럼, \( i-1 \) 시점의 상태 \( S \)에 현재 \( i \) 시점의 정보 \( \phi(K_i)V_i^T \)를 더하여 \( i \) 시점의 상태를 업데이트한다.

- 업데이트된 \( i \) 시점의 상태 \( S \)와 현재 쿼리 \( Q_i \)를 사용하여 \( i \)번째 출력 \( \bar{V}_i \)를 계산한다.

- 루프를 도는 동안  \( S_1, S_2, \cdots, \)를 모두 저장할 필요 없이, 오직 하나의 변수 \( S \)만 유지하는 것을 볼 수 있다.

- (2) backward function

- return 부분을 보면, 손실 \( \mathcal{L} \)에 대한 \( \phi(Q), \phi(K), V \)의 그래디언트 \( \nabla_{\phi(Q)} \mathcal{L}, \nabla_{\phi(K)} \mathcal{L}, \nabla_V \mathcal{L} \)을 계산하는 것이 목표임을 확인할 수 있다.

- \( G \)는 forward의 output(forward function의 return)에 대한 손실의 그래디언트이다. 

- backward function은 두 부분으로 나뉜다.

- 각 \( i \) 시점에서 계산된 상태 \( S \)와 주어진 그래디언트 \( G_i \)를 사용하여, \( \nabla_{\phi(Q_i)} \mathcal{L} \)을 계산한다. 식 (13)

- \( i = 1 \)부터 \( N \)까지 진행하기 때문에 시간 복잡도는 \( O(N) \)이며, 중간값을 저장할 필요 없이 상태 변수 \( S \) 하나만 유지하는데, 변수 \( S \)는 \( N \)과 무관하기 때문에 \( O(1) \) 메모리가 요구된다고 할 수 있다.

- \( \nabla_{\phi(K_i)} \mathcal{L} \)와 \( \nabla_{V_i} \mathcal{L} \)을 계산하는 과정에서도 동일하게 \( i \) 시점의 \( S \)를 사용한다. 즉, 이 역방향 루프 역시 \( O(N) \) 시간 복잡도와 \( O(1) \)의 메모리가 요구된다. 

2.3.2 TRAINING AND INFERENCE

■ 자기회귀적(autoregressive) 트랜스포머 모델을 학습할 때는, 전체 시퀀스가 주어지므로 모든 계산을 병렬로 처리할 수 있다. 

■ 그러나 추론 시에는, \( i \) 번째 timestep의 출력을 \( i+1 \)번째 timestep의 입력으로 사용한다. 즉, 한 토큰씩 생성해야 하므로 병렬화가 불가능하다.

■ 또한, 트랜스포머의 추론 과정에서 timestep당 비용은 상수가 아니다: 이전의 모든 timestep에 대해 어텐션이 계산되어야 하기 때문에, 현재 시퀀스 길이의 제곱에 따라 확장된다. 

- 예를 들어, 추론 과정에서 \( i \) 번째 토큰을 생성한다면 모든 이전 토큰들과의 어텐션을 처음부터 다시 계산해야 한다. 

- 따라서 timestep당 비용은 현재 생성하려는 \( i \) 길이의 제곱이 된다. 

■ 저자들이 제안한 선형 트랜스포머 모델의 경우, 하나의 예측에 대한 시간 및 메모리 비용이 상수(constant)이다. 이는 \( \phi(K_j)V_j^T \) 행렬(\( S \) 행렬)을 내부 상태로 간단히 저장하고, RNN처럼 매 timestep마다 업데이트할 수 있기 때문이다.

■ 이러한 RNN과 같은 추론 방식 덕분에, 추론 단계에서 기존 트랜스포머가 시퀀스가 길어질수록 점점 느려지는 것과 달리, 선형 트랜스포머는 일정한 속도로 토큰을 생성할 수 있다. 



3. Experiments

■ 선형 트랜스포머의 성능을 확인하기 위해 합성 데이터(synthetic data)에 대해 계산 비용, 메모리 소비, 수렴성 측면에서 선형화된 어텐션을 평가한다. 

■ 그리고 선형 트랜스포머의 효과를 더 보여주기 위해, 이미지 생성(image generation)과 자동 음성 인식(automatic speech recognition)에서 모델을 평가한다. 

■ 실험 전반에 걸쳐, 선형 트랜스포머와 두 가지 베이스라인(softmax attention을 사용하는 트랜스포머, Reformer)을 비교한다. 각 실험에 대한 결과는 다음과 같다. 

■ Fig 1은 시퀀스 길이가 길어짐에 따라 시간과 GPU 메모리가 어떻게 변화하는지 측정한 결과이다. 시퀀스가 길어짐에 따라 전통적인 방식인 softmax attention의 경우 시간과 GPU 메모리 모두 가파르게 증가하는 것을 볼 수 있다. 

■ 반면 Reformer(lsh-X)와 linear(선형 트랜스포머)는 시퀀스 길이가 길어져도 시간과 메모리가 거의 선형적으로 증가하는 것을 볼 수 있다. 즉, linear attention은 softmax attention에 비해 긴 시퀀스에 대한 확장성이 압도적으로 뛰어나다. 

■ Fig 2는 수렴 속도를 비교한 결과이다. \( x \)축은 gradient steps, \( y \)축은 cross entropy loss이다. softmax attention과 linear attention을 비교하면, 학습의 안정성이나 최종 성능 면에서 저하가 거의 없는 것을 확인할 수 있다. 

[출처] https://arxiv.org/abs/2006.16236

■ Table 1과 2는 MNIST, CIFAR-10에서 이미지 생성 성능을 비교한 것으로, softmax attention과 linear attention은 거의 동일한 수준의 Bits/dim을 달성했지만, 추론 속도는 linear attention이 142.8 Images/sec로, softmax attention보다 약 317배 빠른 것을 볼 수 있다. 

■ CIFAR-10에 대해서도 softmax attention과 linear attention의 Bits/dim은 큰 차이가 없지만, linear attention의 추론 속도가 softmax attention보다 약 4,462배 더 빠른 것을 확인할 수 있다.

■ Table 3는 자동 음성 인식에 대한 비교 결과로, softmax attention이 가장 좋은 성능을 보이지만, time/epoch가 2711로 가장 느리다.

linear attention의 경우 softmax attention과 비교했을 때, 약간의 성능 저하가 있지만 time/epoch가 824로 softmax attention보다 3배 이상 빠른 것을 볼 수 있다. 

[출처] https://arxiv.org/abs/2006.16236

■ 즉,  linear attention은 큰 성능 저하 없이 속도 면에서 압도적인 이점을 제공한다.