본문 바로가기

자연어처리/Speculative Decoding

Fast Inference from Transformers via Speculative Decoding

■ 이 논문에서는 autoregressive model의 추론 속도를 향상시키기 위한 방법으로 speculative decoding을 소개한다.

■ 이 방법은 re-training이나 아키텍처 변경 없이 기존의 off-the shelf models에도 적용하여 추론 속도를 가속화할 수 있다.

[2211.17192] Fast Inference from Transformers via Speculative Decoding

 

Fast Inference from Transformers via Speculative Decoding

Inference from large autoregressive models like Transformers is slow - decoding K tokens takes K serial runs of the model. In this work we introduce speculative decoding - an algorithm to sample from autoregressive models faster without any changes to the

arxiv.org

 

1. Introduction

■ 대표적인 대규모 autoregressive 모델로는 GPT-3, LaMDA, PaLM 등 large model은 small model보다 훨씬 더 나은 성능을 보여주고 있다. 그러나 large autoregressive model의 디코딩 속도는 small model보다 훨씬 느리다. 

autoregressive는 한 단어씩 순차적으로 생성하기 때문에, \( K \)개의 tokens을 생성하기 위해 \( K \)번의 순차적인 실행이 필요하다.  

■ large autoregressive model의 추론을 더 빠르게 만들기 위해 여러 접근법이 등장하였다. 
- (1) 모든 input에 대해 동일하게 추론 비용을 동등하게 줄이는 것
- (2) 모든 추론 단계가 동등하게 만들어지지 않았다는 관찰, 모든 예측 단계가 똑같이 어렵지는 않다. 어떤 단계는 매우 큰 모델을 요구하는 반면, 다른 단계들은 더 효율적인 모델로도 잘 근사될 수 있다.  
- 예를 들어, "나는 어제 도서관에" 다음에는 "갔다"가 나올 확률이 매우 높다. 즉, 쉬운 예측이 가능하다.
- 이렇게 쉬운 예측 단계에서는 더 작고 효율적인 모델을 사용하여 해결하는 것이다. 

■ 이러한 적응적 계산 방법(adaptive computation methods)은 더 쉬운 추론 단계를 위해 더 적은 계산 자원을 사용하는 것을 목표로 한다. 

■ 이러한 해결책 중 다수가 실제로 매우 효과적임이 입증되었지만 모델 아키텍처 변경, training procedures 변경 및 retraining을 요구하며, 동일한 output을 유지하지 못한다(원래 결과와 output이 달라진다).  

■ 위에서 일부 추론 단계는 "harder", 일부는 "easier"는 관찰이 이 논문의 핵심 동기이다. 

추가적으로 대규모 모델의 추론의 병목 현상이 종종 연산 과정이 아닌 메모리 대역폭과 communication에 의해 발생한다. 

■ 이는 GPU의 계산 유닛은 놀고 있으며, 데이터가 도착하기를 기다리는 시간이 길다는 의미이다. 즉, 남아도는 계산 자원이 있을 수 있다. 저자들은 이러한 남는 계산 자원을 활용하여 병렬성을 높이는 방향으로 추론을 가속화하고자 하였다. 

■ 구체적으로, 모델 아키텍처나 training procedure를 변경하거나, 모델을 retraining하지 않고도, 모델이 출력하는 output distribution을 유지하면서 추론 속도를 높이고자 하였다. 이는 추측 실행(speculative execution)을 통해 달성할 수 있다. 

speculative execution은 최적화 기술로, 어떤 작업이 실제로 필요한지 확인하는 것과 병렬로 그 작업을 수행하는 것이다.
- speculative execution의 대표적인 예로 branch prediction이 있다.

- 예를 들어 다음과 같은 코드가 있다고 하자.

if (x > 0):
    실행문1
    실행문2
    ...
    ...

- 예를 들어 if (x > 0) 문을 만났을 때, 프로그램은 일단 조건이 참(true) 이라고 가정(speculate) 하고 이후의 코드 블록(예: 실행문 1, 2, …)을 미리 실행한다. 이후 이 가정이 맞았으면 실행 결과를 그대로 사용하고, 틀렸다면 실행 결과를 폐기한다. 
- 단순한 조건인 x > 0이 아니라, 병목이 발생할 수 있는 조건문—예를 들어 x가 1천만 개의 원소를 가진 리스트이고, 그 안에 100이 존재하는지를 검사하는 경우를 생각해 보자. 이러한 조건문은 평가에 긴 시간이 소요될 수 있다. 
- 이때 speculative execution은 자원(예: cpu)가 단순이 대기하는 대신 미리 실행문들을 수행하는, 즉 자원이 유휴 상태로 남지 않게 하는 방법이다.  

■ 저자들은 speculative execution을 기반으로 한 speculative decoding을 제안하며, 이를 트랜스포머와 같은 autoregressive model의 디코딩에 적용하였다. 이에 대한 예시는 아래의 Fig 1에서 확인할 수 있다.  

[출처] https://arxiv.org/abs/2211.17192

■더 효율적인 모델(6M parameters)을 사용하여 더 큰 target model(97M parameters)에서 단 9번의 순차적 실행으로 38개의 토큰으로 이루어진 문장을 생성하는 것을 볼 수 있다.  
- 논문에서는 작기 때문에 더 빠른(즉, 효율적인) 작은 모델을 approximation model, 더 큰 모델(그렇기 때문에 추론 속도가 더 느린)을  target model이라고 표기하고 있다.  
기존 방식의 경우, autoregressive model에서 38개의 토큰을 생성하기 위해서는 매 토큰마다 순차적으로 38번의 실행이 필요하다. 
반면, 저자들이 제안한 "speculative decoding"을 사용하면, small model(6M)이 후보 토큰을 추측하고, large model(97M)이 이를 검증하는 방식을 통해, 결과적으로 large model을 단 9번만 순차적으로 실행하면 된다. (Fig 1) 
- 초록색 토큰은 approximation model이 생성한 토큰으로, 이 결과를 traget model에게 제안한다. 
- target model은 초록색 토큰을 보고 자신의 판단과 일치한다면, 이를 받아들여 사용할 수도 있고 사용하지 않을 수도 있다. Fig. 1에서 빨간색 토큰과 파란색 토큰은 각각 거부된 제안과 그에 대한 수정 사항을 나타낸다.
Fig 1 예시의 경우 생성 수행 횟수가 38 / 9 \( \approx \) 4.2배 정도 줄어듦에도 불구하고, 생성된 최종 문장의 확률분포는 원본과 동일하게 유지된다.  



2. Speculative Decoding


2.1 Overview

■ \( M_p \)를 추론 속도를 가속하려는 target model이라 하고, prefix \( x_{<t} \)에 대해 모델로부터 얻는 분포를 \( p(x_t \mid x_{x<t}) \)라고 하자.  

■그리고 \( M_q \)를 동일한 task/dataset에 사용되는 approximation model이라 하고, prefix \( x_{<t} \)에 대해 모델로부터 얻는 분포를 \( q(x_t \mid x_{<t}) \)라고 하자.  
■ 핵심 아이디어는 다음과 같다.
- (1) 더 효율적인 모델인 approximation model \( M_q \)를 사용하여 \( \gamma \in \mathbb{Z}^+ \)개의 completions을 생성하고, (\( \gamma \)는 \( M_q \)가 생성한 토큰의 개수)
- (2) target model \( M_p \)를 사용하여 \( M_q \)로부터 나온 모든 추측(guess)들과 각각의 확률들을 parallel하게 평가하여, 동일한 분포로 이어질 수 있는 것들을 accept하고, 
- (3) reject된 토큰이 있다면 해당 토큰을 수정하기 위해 조정된 분포에서 토큰을 샘플링한다. 만약, \( M_p \)가 생성한 모든 토큰들(\( \gamma \)개의 토큰들)이 accept되었다면, \( \gamma + 1 \)번째 위치에 해당하는 토큰 하나를 추가로 생성한다. 

■ 이렇게 하면 target model \( M_p \)의 각 병렬 실행은 최소 하나의 새로운 토큰을 생성할 것이기 때문에, target model의 순차적인 실행 횟수는 최악의 경우에도 단순한 autoregressive model보다 결코 많을 수 없다.  

■ 운이 좋으면(\( M_q \)의 어림짐작(speculative)이 모두 맞으면), \( M_p \)를 한 번 실행하고도 최대 \( \gamma + 1 \)개의 토큰을 한꺼번에 얻을 수 있어 속도가 크게 향상된다.   


2.2 Standardized Sampling

■ 언어 모델이 여러 후보 중에서 next token을 선택(샘플링)하기 위해 사용되는 argmax, top-k, nucleus, 그리고 temperature 설정과 같은 대중적인 방법들은 보통 로짓(logits) 수준에서 output을 조정한다.  

이 방법들은 원래의 확률분포를 어떤 규칙에 따라 조정된 확률분포로 변환한뒤, 그 새로운 분포에서 표준적인 샘플링을 수행하는 것과 동일하다.  

■ 예를 들어, argmax 샘플링은 가장 높은 값을 가진 값의 인덱스를 100% 확률로 선택하는 것이다. 이는 최댓값은 1로, 최댓값을 제외한 나머지 모든 요소들을 0으로 만들고 정규화하여, 정규화된(조정된) 새로운 분포에서 샘플링하는 것과 같다고 볼 수 있다.  

■ top-k도 마찬가지이다. top \( k \)개의 토큰을 제외한 모든 토큰의 확률을 0으로 만들고, 남아있는 \( k \)개의 토큰의 확률값들이 합이 1이 되도록 정규화한 분포에서 샘플링하는 것으로 볼 수 있다.  

■ 저자들은 이처럼 샘플링을 수행하는 방식이 다르더라도, 확률분포를 조정하여 샘플링하는 방식으로 통일할 수 있다고 주장한다. 그래서 \( p(x) \)와 \( q(x) \)가 각각 \( M_p \)와 \( M_q \)로부터 나온, 샘플링 방법에 맞게 조정된 분포라고 가정하였다.  

- 즉, \( M_p \)와 \( M_q \)의 샘플링 방식이 달라서 확률분포가 서로 다르더라도, 샘플링을 표준화하는 방식으로 동일하게 확률분포를 맞출 수 있다는 의미이다.

- 여기서 \( p(x) \)는 \( p(x_t \mid x_{x<t}) \)이며, \( q(x) \)는 \( q(x_t \mid x_{<t}) \)이다. 


2.3 Speculative Sampling

■ \( x \sim p(x) \)를 샘플링하기 위해, \( x \sim q(x) \)를 샘플링한 다음, \( q(x) \leq p(x) \)이면 그 샘플을 유지하고, \( q(x) > p(x) \)이면 \( 1 - \dfrac{p(x)}{q(x)} \)의 확률로 샘플을 거절하고,

조정된 분포 \( p' (x) = norm(max(0, p(x)-q(x))) \)에서 \( x \)를 다시 샘플링한다.   

- 샘플링 규칙을 정리하면 다음과 같다.

- (1) 먼저, \( M_q \)로부터 토큰 \( x \)를 하나 샘플링한다.

- (2) \( q(x) \leq p(x) \)라면, \( M_q \)가 \( M_p \)보다 \( x \)가 나올 확률이 더 높다고 동의(또는 확신)하는 것으로 볼 수 있다. 그러므로, 이 경우에는 \( M_q \)의 추측을 수락한다.

- (3) \( q(x) > p(x) \)는, target model \( M_p \)가 \( x \)가 나올 확률이 그렇게 높지 않다고 판단한 것으로 볼 수 있다. 이 경우에는 \( M_q \)의 추측을 거절할 수 있다. 

- 예를 들어, 만약 \( q(x) = 0.8, p(x) = 0.2 \)이라면, \( 1 - \dfrac{0.2}{0.8} = 0.75 = 75 \)%의 확률로 거절한다.

- \( q(x) > p(x) \)인 상황에서 \( p(x) \)와 \( q(x) \)의 차이가 클수록 거절될 확률이 높아진다. 

- (4) 추측이 거부되면, \( p' (x) = norm(max(0, p(x)-q(x))) \)에서 \( x \)를 다시 샘플링한다.  

■ 임의의 분포 \( p(x) \)와 \( q(x) \)에 대해, 이런 방식으로 샘플링된 \( x \)는 실제로 \( x \sim p(x) \)를 따른다는 것을 다음과 같이 증명할 수 있다. (Appendix A.1) 

- 임의의 분포 \( p(x) \)와 \( q(x) \)에 대해, \( p(x) \)와 \( q(x) \)로부터 speculative sampling을 통해 토큰 \( x' \)가 샘플링될 확률 \( P(x=x') \)가 처음부터 target model로 뽑았을 때의 확률 \( p(x') \)와 동일하다는 것을 증명하는 것이 목표이다.

- 여기서 조정된 확률분포 \( p'(x) = norm(max(0, p(x)-q(x))) = \dfrac{p(x)-\min(q(x),p(x))}{\sum_{x'} (p(x')-\min(q(x'),p(x')))} = \dfrac{p(x)-\min(q(x),p(x))}{1-\beta} \)이며, \( p'(x) \)의 정규화 상수는 \( 1 - \beta \)이다. 

- \( \beta \)는 speculative sampling에 의해 \( q(x) \)를 수락할 확률이다. (섹션 3.1)

- 최종적으로 토큰 \( x' \)가 선택되는 경우는 ① approximation model의 guess가 \( x' \)이며 이것이 수락되는 경우 ② approximation model의 guess가 거절되고, 재샘플링 과정에서 \( x' \)가 뽑히는 경우이다. 

- 그러므로, 토큰 \( x' \)가 선택될 확률은 다음과 같이 나타낼 수 있다.  

- 여기서 \( P(guess accepted, x = x' ) \), 이 사건이 일어나려면 (1) 먼저, \( M_q \)가 \( x' \)를 제안해야 한다. 이 확률은 \( q(x') \)이다. (2) 그리고 제안된 \( x' \)가 수락되어야 한다. 섹션 2.3의 규칙에 따라 수락될 확률을 \( \min\left(1, \dfrac{p(x')}{q(x')}\right) \)로 표현할 수 있다. 

- 두 사건이 동시에 일어날 확률은 두 확률의 곱 \( q(x') \times \min\left(1, \dfrac{p(x')}{q(x')}\right) \)가 된다. 

- 여기서 min 함수를 2.3의 규칙에 따라 경우를 나눠보면,

- (1) \( q(x') \leq p(x') \), 즉 \( \dfrac{p(x')}{q(x')} \geq 1 \)이면 target model이 확신하는 경우이며 min = 1이되므로, 확률은 \( q(x') \times 1 = q(x') \)가 된다. 

- (2) \( q(x') > p(x') \), 즉 \( \dfrac{p(x')}{q(x')} < 1 \)이면 target model이 \( M_q \)의 결과를 확신하지 않는 경우이며 \( \min = \dfrac{p(x')}{q(x')} \)이므로, 확률은 \( q(x') \times \dfrac{p(x')}{q(x')} = p(x') \)가 된다. 

- \( P(guess rejected, x = x' ) \)는 (1) 먼저, \( M_q \)의 guess가 거절되어야 한다. (2) 그리고, 거절된 후 진행되는 샘플링 과정에서 \( x' \)가 선택되어야 한다.

-  (1)은 \( M_q \)가 어떤 토큰을 제안하든, 그 제안이 거부될 확률의 총합이다. 이는 전체 확률 1에서 \( \beta \)를 뺀 것과 같다. 즉, \( 1 - \beta \)로 나타낼 수 있다.

- 그리고 (2)는 거절되었을 때 사용하는 조정된 분포는 \( p'(x) \)이다. 따라서 \( p'(x) \)에서 \( x' \)가 뽑힐 확률은 \( p'(x') \)이다. 

- 그러므로 \( P(guess rejected, x = x' ) = (1-\beta)p'(x') \)이다.

- 이때 \( p'(x) = \dfrac{p(x)-\min(q(x),p(x))}{1-\beta} \)이므로 \( (1-\beta)p'(x') = (1- \beta) \times \dfrac{p(x')-\min(q(x'),p(x'))}{1-\beta} = p(x') - \min(q(x'),p(x')) \)가 된다. 

- 두 경우의 확률을 더하면 \( P(x=x') = \min(q(x'),p(x')) + p(x') - \min(q(x'),p(x')) = p(x') \)가 된다.  

- \( P(x=x') = p(x') \)가 성립하므로, speculative sampling을 통해 \( x' \)가 선택될 확률은, 처음부터 target model \( M_p \)를 사용했을 때 \( x' \)가 뽑힐 확률과 동일하다고 할 수 있다.  

- 이 증명으로 저자들 speculative decoding의 기반인 speculative sampling이 단순한 근사(approximation)가 아니라, 원본 모델의 출력 분포를 전혀 훼손하지 않는 방법임을 보여주었다.  

■ \( p^{'} (x) = norm(max(0, p(x)-q(x))) \)의 효과를 더 자세히 보기 위해, 예를 들어 토큰 a, b, ,c ,d가 있고 각 토큰에 대한 확률이 다음과 같다고 하자. 

token \( q(x) \) \( p(x) \)
a 0.3 0.5
b 0.4 0.2
c 0.1 0.1
d 0.2 0.3

■ 토큰 b는 \( q(x) > p(x) \)인 경우이다. \( M_q \)의 제안인 b를 \( M_p \)가 거절했다고 하자. 거절되면 조정된 확률분포 \( p^{'} (x) = norm(max(0, p(x)-q(x))) \)에서 다시 샘플링해야 한다. 

■ 확률분포를 조정하기 위해 \( max(0, p(x)-q(x)) \)를 적용하면 토큰 b가 나올 확률은 다음과 같이 조정된다. 

token \( p(x) - q(x) \) \( max(0, p(x) - q(x)) \)
a 0.2 0.2
b -0.2 \( max(0, -0.2) = 0 \)
c 0 0
d 0.1 0.1

■ 이렇게 조정을 거치면, 토큰 b가 조정된 확률분포에서 샘플링될 확률이 0이 되는 것을 볼 수 있다. 이미 거절한 토큰 b가 등장할 확률을 0%로 만들어 다시 등장하지 못하도록 하는 것이다. 

■ conditioning prefix에 \( M_q \)를 실행하여 얻은 분포 \( q(x) \)가 주어지면, 토큰 \( x_1 \sim q(x) \)를 샘플링할 수 있다. 그런 다음 prefix에 \( M_p \)를 실행하여 분포 \( p(x) \)를 계산하는 동시에, prefix \( + [x_1] \)에 \( M_p \)를 실행하여 다음 토큰인 \( x_2 \)의 분포를 병렬로 추측적으로 계산한다. 

■ 두 계산이 모두 완료되면 위와 같이 진행한다: 만약 \( x_1 \)이 거절되면 계산된 \( x_2 \)를 버리고 조정된 분포에서 \( x_1 \)을 다시 샘플링하며, 만약 \( x_1 \)이 수락되면, 두 토큰 \( x_1 \)과 \( x_2 \)를 모두 유지한다. 

■ 아래의 Algorithm 1은 이 아이디어를 일반화한 것으로, 한 번에 1개에서 \( \gamma+1 \)개의 토큰을 샘플링한다.  

[출처] https://arxiv.org/abs/2211.17192

- \( M_q \)를 자기회귀적으로 \( \gamma \)번 실행하여, \( \gamma \)개의 추측 토큰 \( x_1, \cdots, x_{\gamma} \)를 샘플링한다. 

- \( M_p \)를 병렬로 실행하여 prefix, prefix \( + [x_1] \), ..., prefix \( + [x_1, \cdots, x_{\gamma}] \)에 대한, \( \gamma + 1 \)개의 확률분포 \( p_1, \cdots, p_{\gamma +1} \)를 얻는다. 

- 그런 다음, sepeculative sampling을 수행한다. 

- 예를 들어, \( M_q \)가 5개의 토큰 \( x_1, \cdots, x_5 \)를 추측했다고 하자. \( \gamma = 5 \)

- target model \( M_p \)는 병렬로 생성된 토큰들을 prefix와 더하여 각각 6개의 토큰에 대한 확률분포 \( p_1, \cdots, p_6 \)를 계산한다. 예를 들어 \( p_6 \)은 \( M_p (prefix + [x_1, \cdots , x_5] ) \)의 결과이다.  

- 그런 다음, 각 추측 토큰 \( x_i \)에 대해, 해당 추측을 수락할지 확률적으로 결정한다. 

- 이를 위해 [0, 1] 구간의 균등분포에서 난수 \( r_i \)를 샘플링하고, 다음 식을 통해 수락할 추측의 개수 \( n \)을 결정한다.  

\( n \leftarrow \min \left(\{i - 1 \mid 1 \leq i \leq \gamma, r_i > \frac{p_i(x)}{q_i(x)}\} \cup \{\gamma\}\right) \) 

- \( n \)은 \( r_i > \dfrac{p_i(x)}{q_i(x)} \), 즉 추측이 거절된 위치(\( i \)) 중 1을 뺀 값에서의 최솟값이다. 다시 말해 "처음으로" 거절된 위치의 바로 이전 위치를 찾는 것이다. 

- 예를 들어 \( i = 1, 2, 4 \)에서 \( r_i \leq \dfrac{p_i(x)}{q_i(x)} \)로 추측이 수락되고, \( i = 3, 5 \)에서는 \( r_i > \dfrac{p_i(x)}{q_i(x)} \)로 추측이 거절되었다고 하자. 

- 이 예시의 경우, 추측이 거절된 위치는 \( i = 3 \)과 \( i = 5 \)이다. 그러므로 \( n = min(3-1, 5-1) = 2 \)가 된다.  

- \( n = 2 \)이므로 \( x_1, x_2 \)는 그대로 확정(수락)된다. 수정할 위치는 거절이 발생한 \( n + 1 = 3 \)번째 위치이다.  

- 이 3번째 위치를 채우기 위해 \( M_p \)가 계산해 둔 \( p_3 \)를 그대로 사용하는 대신, \( p_3 \)에서 \( M_q \)가 계산한 \( q_3 \)를 뺀 조정된 분포 \( p' \)를 만든다.  

- 그런 다음, 이 \( p' \)에서 토큰 \( t \)를 샘플링한다.

- 최종적으로 생성된 시퀀스는 \( [x_1, x_2, t] \)가 된다. 한 번의 병렬 연산으로 3개의 토큰이 생성된 것이다. 

- 만약 모든 위치 \( i \)에서 추측 토큰들이 수락되었다면, 한 번의 병렬 연산으로 \( \gamma + 1 = 6 \)개의 토큰이 생성되었을 것이다.  



3. Analysis


3.1 Number of Generated Tokens

■ 저자들은 알고리즘 1의 1회 실행에 의해 생성되는 expected number of tokens를 분석하였다. 

■ 이를 위해 acceptance rate \( \beta_{x < t } \)를 정의하였다. prefix \( x_{<t} \)가 주어졌을 때, \( \beta_{x < t } \)는 섹션 2.3에 따라, speculative sampling에 의해 \( x_t \sim q(x_t \mid x_{<t}) \)를 수락할 확률이다.  

■ 저자들은 \( \beta \)에 대한 기댓값 \( E(\beta) \)를 \( M_q \)가 \( M_p \)를 얼마나 잘 근사(모방)하는지에 대한 척도로 사용하였다. 

■ 만약 \( \beta \)들이 i.i.d.라는 가정을 하고(\( beta \)값은 매 토큰을 생성할 때마다 문맥에 따라 계속 변하므로), \( \alpha = E(\beta) \)라고 하면,  

■ 알고리즘 1의 1회 실행에 의해 생성되는 토큰의 수는 성공 확률이 \( 1-\alpha \)이고, 상한이 \( \gamma + 1 \)인 제한된 기하분포를 따르며, 알고리즘 1에 의해 생성되는 기대 토큰 수는 아래의 방정식 (1)을 만족한다.  

- \( \alpha = E(\beta) \)이므로, \( M_q \)가 target model \( M_p \)를 평균적으로 얼마나 잘 모방하는가를 나타내는 척도가 된다.  
- \( \alpha \)가 1에 가까울수록 \( M_q \)가 \( M_p \)를 더 잘 근사하여, 두 모델이 매우 유사하다는 의미가 된다. 
- i.i.d. 가정 하에서 첫 번째 추측을 시도했을 때, \( \alpha \) 확률로 성공(수락), \( 1- \alpha \) 확률로 실패(거부)이며, 두 번째 추측도 마찬가지로 \( \alpha \) 확률로 성공(수락), \( 1- \alpha \) 확률로 실패(거부)가 된다. 
- speculative decoding은 이 과정을 \( \gamma \)번 반복하거나, 중간에 실패하면 멈추는 것으로 볼 수 있다. 
- 이것은 성공 확률이 \( \alpha \)인 동전 던지기를 계속하다가, 처음으로 앞면(\( 1 - \alpha \) 확률의 실패)이 나오면 멈추는 기하분포와 같다. 단, 최대 \( \gamma \)번 시도한다는 상한이 존재한다.  
- 관심 있는 것은 "생성된 토큰의 수"이며, 모든 토큰이 거절 없이 수락되면 최대 \( \gamma + 1 \)개의 토큰을 생성, 거절이 발생하면 \( n + 1 \)개의 토큰이 생성된다.  
- 즉, \( \gamma + 1 \)개 토큰이 생성될 확률은, 성공 확률이 \( \alpha \)일 때 \( \gamma \)번 모두 성공하는 것으로 볼 수 있다.  
- 식 (1)은 평균 수락률 \( \alpha \)와 최대 추측 개수 \( \gamma \)라는 두 개의 파라미터를 알면, speculative decoding 한 번에 평균적으로 몇 개의 토큰이 생성될지 기댓값을 예측하는 공식이다. 

- 만약 \( \alpha \)가 0에 가깝다면, \( E( \text{#} generated tokens) \)는 1에 가까워진다. 즉, 평균적으로 한 번의 speculative decoding 실행에 1개의 토큰만 생성하게 되어, 기존 자기회귀 decoding 방식과 동일하게 동작하므로 속도 이점이 없다. 

- 만약 \( \alpha \)가 1에 가깝다면, 분모가 0에 가까워져 기댓값이 \( \gamma + 1 \)에 근접한다. 즉, 거의 항상 \( \gamma + 1 \)개의 토큰이 생성되어 속도가 크게 향상될 수 있다.  

- 여기서 \( \gamma \)는 \( M_q \)가 생성하는 추측 토큰의 수를 의미하므로, \( \alpha \) 값과 \( \gamma \) 값이 클수록 기대되는 토큰 생성 수도 증가한다. 이는 더 빠른 속도 향상으로 이어진다.  

[출처] https://arxiv.org/abs/2211.17192

■ Fig 2는 \( x \)축이 \( \alpha \), \( y \)축은 평균 생성 토큰 수로, 추측 토큰 개수 \( \gamma \)를 1, 3, 5, 7, 그리고 무한대로 바꿨을 때의 변화를 나타낸 것이다.

■ \( \alpha \)가 1에 가까울수록, 그리고 \( \gamma \)가 클수록 평균 생성 토큰 수(즉, 평균 속도 향상)이 증가하는 것을 볼 수 있다. 


3.2 Calculating \( \alpha \)

■ 저자들은 주어진 prefix와 두 모델 \( M_P \), \( M_q \)에 대해 \( \alpha \)를 계산하기 위한 공식을 유도하였다. 이는 divergence \( D_{LK} \)를 정의하는 것으로 시작한다. 

- \( D_{L, K} \)는 \( p \)와 \( q \)의 평균 분포(\( M \))로부터, 각 분포가 얼마나 떨어져 있는지를 나타낸다. 

- \( D_{LK}(p, q) = \sum_{x} |p(x) - M(x)| = \sum_{x} |q(x) - M(x)| \), 여기서 \( M(x) = \dfrac{p(x)+q(x)}{2} \) 

- 이 \( D_{LK} \)는 \( p \)와 \( q \)의 평균 분포(\( M \))로부터, 각 분포가 얼마나 떨어져 있는지를 나타낸다. 

- \( M(x) = \dfrac{p(x)+q(x)}{2} \)임을 이용하면, \( D_{LK}(p, q) = \sum_{x} |p(x) - M(x)| = \sum_{x} | \dfrac{2p(x)-p(x)-q(x)}{2} | = \dfrac{1}{2} \sum_{x} |p(x) - q(x)| \) 

- 여기서 \( \sum_{x} |p(x) - q(x)| = \sum_{x} \max(p(x), q(x)) - \sum_{x} \min(p(x), q(x)) \)로 나타낼 수 있다. \( |a-b| = \max(a, b) - \min(a, b) \)임을 이용한 것이다. 

- 예를 들어, \( a = 5, b= 2 \)이면, \( | a - b | = \max(5, 2) - \min(5, 2) = 3 \)이고 \( a = 2, b = 5 \)여도 \( |a - b| = \max(2, 5) - \min(2, 5) = 3 \)이 된다.  

- \( \max(p(x), q(x)) - \min(p(x), q(x)) \)를 모든 \( x \)에 대해 합산하면 \( \sum_{x} |p(x) - q(x)| = \sum_{x} \max(p(x), q(x)) - \sum_{x} \min(p(x), q(x)) \) 

- 이때 \( p(x) \)와 \( q(x) \)는 확률분포이므로, 모든 \( x \)에 대해 합하면 1이 된다. 즉 \( \sum_{x} p(x) = 1 \)이고 \( \sum_{x} q(x) = 1 \)이다.  

- 그러므로, \( \sum_{x} \max(p(x), q(x)) + \sum_{x} \min(p(x), q(x)) = 2 \)라고 할 수 있다. 

- \( \sum_{x} |p(x) - q(x)| = \sum_{x} \max(p(x), q(x)) - \sum_{x} \min(p(x), q(x)) \)와 \( \sum_{x} \max(p(x), q(x)) + \sum_{x} \min(p(x), q(x)) = 2 \)라는 두 식을 연립하여 \( \sum_{x} |p(x) - q(x)| \)를 \( \min \)으로 나타낼 수 있다.  

- \( \sum_{x} \max = 2 - \sum_{x} \min \)로 정리하여 첫 번째 식에 대입하면, \( \sum | p-q | = 2 - \sum \min - \sum \min = 2 - 2 \sum \min \)이 된다.  

- 따라서 \( D_{LK} \)를, \( D_{\text{LK}}(p, q) = \dfrac{1}{2} \sum_{x} | p(x) - q(x) | = \dfrac{1}{2} \left( 2- 2 \sum_{x} \min(p(x), q(x)) \right) = 1 - \sum_{x} \min(p(x), q(x)) \)로 나타낼 수 있다.  

- \( \sum_{x} \min(p(x), q(x)) \)는 0에서 1의 값을 갖으므로, \( D_{LK}(p, q) \)의 값도 0에서 1의 값을 갖게 된다.

- 두 분포가 완전히 같으면(\( p = q \)) \( D_{LK} = 0 \)이 되고, 두 분포가 전혀 겹치지 않으면 \( D_{LK} = 1 \)이 된다.

(예: \( p(x) = [1.0, 0.0], \; q(x) = [0.0, 1.0] \)이면, \( \frac{1}{2} [ |1.0-0.0| + |0.0-1.0| ] = 1 \))

- 섹션 2.3의 규칙을 따라 \( \beta = \sum_{x} \min(p(x), q(x)) \)로 나타낼 수 있다.

- Lemma 3.3의 \( D_{LK} (p, q) = 1 - \sum_{x} \min(p(x), q(x)) \)에 대입하면, \( \beta = 1 - D_{LK} (p, q) \)가 유도된다. \( \alpha = E(\beta) = E(1-D_{LK}(p, q) = 1 - E(D_{LK}(p,q)) \)

- 즉, \( \alpha \)는 두 분포의 평균적인 차이 \( D_{LK} \)가 클수록/작을수록 작아진다/커진다. 

- 또한, Theorem 3.5의 \( \beta = \sum_{x} \min(p(x), q(x)) \)를 이용하면, \( \alpha = E(\beta) = E(\sum_{x} \min(p(x), q(x))) \)가 된다. 

- Lemma 3.3에 의해 \( \sum_x \min(p, q) = 1 - D_{LK}(p, q) \)이며, 두 값은 모두 \( \beta \)와 같다. 

- 전체 생성 과정에서 평균 수락률 \( \alpha = E(\beta) \)이므로, \( \alpha \)에 대해서도 동일한 관계가 성립된다. 이것이 바로 Corollary 3.6이다. 

■ 아래의 Table 3은 저자들이 실험에서 경험적으로 측정하여 관찰한 \( \alpha \) 값들이다. 

[출처] https://arxiv.org/abs/2211.17192


3.3 Walltime Improvement

■ 저자들은 i.i.d 가정 하에 알고리즘 1이 target model에 대한 호출 횟수를 \( \dfrac{1-\alpha^{\gamma+1}}{1-\alpha} \) 배만큼 줄일 수 있다는 것을 보여주었다. 

■ 이 섹션에서는 speculative decoding을 사용할 때와 사용하지 않을 때의 전체 시간 개선(speculative decoding을 적용했을 때 실제 시간이 얼마나 단축되는지)을 확인한다.   

■ 이를 위해 \( M_q \)의 1회 실행 시간과 \( M_p \)의 1회 실행 시간 사이의 비율을 cost efficient \( c \)로 정의하였다.  

■ 사용하는 모델과 해결하고자 하는 task에 의해 달라지는 \( \alpha \)와 달리, \( c \)의 값은 하드웨어와 소프트웨어(구현 세부사항)에 따라 달라진다.  

■ 논문에 따르면, 저자들 실험(일반적으로 \( M_q \)가 \( M_p \)보다 수십 배 작은 상황)에서 \( c \)는 항상 0.05 미만이었고, 종종 0에 가까운 값을 가졌다고 한다.  

■ 생성되는 토큰 수에 \( c \) 값을 결합한, 알고리즘 1에 의한 총 실제 시간의 기대 개선율은 \( \dfrac{1-\alpha^{\gamma+1}}{(1-\alpha)(\gamma c +1)} \)로, 섹션 3.1에서 유도한 이론적으로 생성되는 평균 토큰 수와 speculative decoding 한 스텝에 드는 비용으로 구성되어 있다. 

■ \( gamma c + 1 \)은 \( M_q \)를 \( \gamma \)번 실행하는 비용 \( \gamma c \)와 \( M_p \)를 병렬로 한 번 실행하는 비용 1의 합이다. 

■ \( \gamma \)를 키우면(즉, 더 많은 토큰들을 추측하면), 생성되는 평균 토큰 수가 증가하지만, 분모(비용)도 함께 증가한다. 따라서 \( \alpha \)와 \( c \) 값에 따라, 이 개선율을 최대화하는 최적의 \( \gamma \)를 찾을 수 있다.

■ Theorem 3.8을 따르면, \( \alpha \)가 \( c \)보다 커야만, speculative decoding을 사용하는 것이 실행 시간에 있어 이득이다. 

■ \( \alpha > c \) 조건이 만족될 때, \( \gamma = 1 \)로 설정하면, 최소 \( \dfrac{1+\alpha}{1+c} \) 배의 속도 향상을 보장할 수 있다. 


3.4 Number of Arithmetic Operations

■ 알고리즘 1은 \( M_p \)를 \( \gamma + 1 \)번 병렬로 실행하므로, concurrent 산술 연산의 수는 \( \gamma + 1 \)배 증가한다. 

■ 알고리즘 1이 실행당 최대 \( \gamma + 1 \)개의 토큰을 생성하므로, 총 계산량은 표준 디코딩 알고리즘보다 더 높을 수 있다. 병렬 처리를 한다고 해서 총 계산량이 줄어드는 것이 아니기 때문이다.  

■ speculative decoding은 \( M_q \)가 제안한 샘플을 수락할 경우, 총 계산량은 표준 디코딩보다 증가하지 않지만, 거부할 경우 더 많은 계산 비용이 발생할 수 있다.   

■ \( M_q \)와 \( M_p \)의 토큰당 산술 연산의 비율을 \( \hat{c} \)라고 했을 때, 알고리즘 1의 총 연산 수의 기대 증가율은 \( \dfrac{(1-\alpha)(\gamma \hat{c} + \gamma + 1)}{1 - \alpha^{\gamma + 1}} \)이 된다. 

■ 이 증가율은 한 스텝의 총 계산량과 평균 생성 토큰 수(식 (1))를 나눈 것이다. 
- 분자 \( \gamma \hat{c} + \gamma + 1 \)은 speculative decoding 한 스텝에 필요한 계산량으로, \( M_q \)를 \( \gamma \)번 실행하는 계산량 \( \hat{c} \gamma \)와 \( M_p \)를 \( \gamma + 1 \)번 병렬로 실행하는 계산량 \( \gamma + 1 \)의 합이다. 
- 분모는 섹션 3.1에서 유도한 한 스텝에서 생성되는 평균 토큰 수이다.  
- 직관적으로, \( \alpha \)가 낮으면(즉, \( M_q \)의 추측이 자주 틀리면), 분모는 1에 가까워지고 분자는 커지게 된다. 즉, speculative decoding을 사용함으로써 연산량이 증가한 것이다.  
- 반대로 \( \alpha \)가 높으면, 분모가 커지고 분자는 작아진다. speculative decoding을 사용함으로써 연산량이 줄어들었다는 의미이며, 계산량 낭비 없이 속도 향상을 기대할 수 있다.  

■ 총 산술 연산 수와 달리, 총 메모리 접근 횟수는 줄어들 수 있다. 

■ 구체적으로, target model의 가중치와 KV 캐시는 알고리즘 1의 실행당 한 번만 읽어오면 된다. 

■ 그리고 이 한 번의 읽기로 평균 \( \dfrac{1-\alpha^{\gamma+1}}{1-\alpha} \) 개의 토큰을 생성할 수 있다. 결과적으로, 토큰을 생성하기 위해 필요한 평균 메모리 접근 횟수가 \( \dfrac{1-\alpha^{\gamma+1}}{1-\alpha} \) 배만큼 감소한다.  


3.5 Choosing \( \gamma \)

■ \( c \)와 \( \alpha \)가 주어지고 충분한 계산 자원이 있다고 가정했을 때, 최적의 \( \gamma \)는 Theorem 3.8의 \( \dfrac{1-\alpha^{\gamma+1}}{(1-\alpha)(\gamma c + 1)} \)을 최대화하는 값이다.  

■ \( \gamma \)는 정수이므로, Fig 3에서 보듯이 수치적으로 쉽게 찾을 수 있다.  

- \( \gamma \)를 늘리면(즉, 더 많이 추측하면), 분자의 '토큰 이득'은 증가하지만, 동시에 분모의 '실행 비용'도 증가한다. 즉, 무작정 \( \gamma \)를 키운다고 해서 효율이 증가하는 것이 아니며, 비용 대비 편익이 가장 커지는 최적의 \( \gamma \)를 찾아야 한다. 

- \( \gamma \)는 정수이므로 Fig 3에서 볼 수 있듯이, \( \gamma = 1, 2, 3, \cdots \)을 순서대로 대입하여, Theorem 3.8의 식을 최대화하는 \( \gamma \)를 찾으면 된다.

■ Table 1과 Fig 4는 \( c = \hat{c} = 0 \)이라고 가정할 때, 추론 속도와 총 산술 연산 수 사이의 trade-off를 나타낸 것이다.  

[출처] https://arxiv.org/abs/2211.17192
[출처] https://arxiv.org/abs/2211.17192


 


4. Experiments


4.1 Empirical Walltime Improvement

Setup

■ 저자들은 T5 논문의 두 가지 태스크에서 T5 version 1.1 모델을 테스트하였다: (1) WMT EnDe로 파인튜닝된 영어-독일어 번역 (2) CCN/DM으로 파인튜닝된 텍스트 요약 

■ 두 태스크 모두에 대해 \( M_p \)로 11B T5-XXL을 사용하고, \( M_q \)로는 800M T5-large, 250M T5-base, 77M T5-small과 같은 기존의 off-the-shelf 모델을 사용하였다.  

■ argmax sampling(temp=0)과 standard sampling(temp=1) 모두에 대해 하나의 TPU-v4에서 배치 크기 1로 실제 시간 개선을 측정하였다.  

Results

[출처] https://arxiv.org/abs/2211.17192

■ Table 2에서 \( c \)와 \( \alpha \)의 균형이 좋은 T5-small이 테스트된 approximation model들 중에서 가장 높은 속도 향상을 달성한 것을 확인할 수 있다.  

■ approximation model의 크기가 커짐에 따라 \( \alpha \)가 증가하고 속도 향상이 줄어드는 것을 볼 수 있다.

approximation model의 크기가 커질수록 \( c \)도 함께 커져서 \( c \gamma + 1 \) 항 때문에 전체 속도 향상이 줄어든다.  

■ 또한, \( \alpha \)와 실제 시간 개선은standard sampling(temp=1)보다 argmax sampling(temp=0)에서 더 높은 것을 볼 수 있다.