learningAI
article thumbnail

BERT에서 제시한 Masked LM 기법을 XLNet에서는 AE(Auto Encoding)이라고 하고, GPT에서 사용하는 모델링 기법을 AR(Auto Regressive)라고 한다. AE는 양방향 문맥에 대한 학습이 가능하다는 이점이 있지만, 마스킹된 토큰에 대한 연관성은 알 수 없다는 단점이 존재한다. AR은 양방향 문맥에 대한 판단이 필요한 문제에는 성능이 떨어진다. 따라서 XLNet은 AR과 AE의 장점만을 추출하여 AR이 양방향 문맥 학습을 가능하게 만든 Permutation Language Modeling을 제시한다. XLNet은 BERT보다 20개의 tasks에서 뛰어난 성능을 보인다.

 

Introduction

Unsupervised representation 학습은 자연어 처리 모델의 성능을 높이는데 아주 큰 역할을 한다. 그 중에서도 AR(Auto Regressive)와 AE(Auto Encoding)은 가장 효과적인 pre-training objective이다. 

AR

AR은 대표적으로 GPT에서 사용된 pre-training objective이며, 입력 x가 주어질 때, 다음 토큰을 예측하는 language modeling objective이다. 이때, 입력이 $x = (x_1, \cdots, x_{<t})$에 따라 다음 확률 분포를 예측한다:
$$p(x) = \Pi_{t=1}^{T} p(x_t | x_{<t})$$
$$p(x) = \Pi_{t=T}^{1} p(x_t | x_{>t})$$
위에 위치한 수식은 문장 토큰의 순서 그대로 다음 토큰을 예측하는 확률 분포이고, 아래 수식은 반대로 역순의 토큰들에 대한 $x_t$를 학습하는 분포이다. AR은 다음 factorization을 최대화하도록 학습된다:
$$\max_{\theta} log p_{\theta}(x) = \sum_{t=1}^T log p_{\theta}(x_t | x_{<t}) = \sum_{t=1}^T log \frac{exp(h_{\theta}(x_{1:t-1})^{\top} e(x_t))}{\sum_{x’} exp (h_{\theta} (x_{1:t-1}^{\top} e(x’)))}$$
AR은 uni-directional representation을 학습하도록 설계되어 있어 deep bidirectional context를 학습하기에 비효율적이다. 따라서 Downstream tasks에서 양방향 문맥에 대한 정보를 요구할 때, AR은 최적의 pre-training objective라고 할 수는 없다.

AE

AE는 AR의 약점인 양방향 문맥 학습을 가능하게 만든 pre-training objective이다. AE를 사용한 대표적인 예시가 BERT이다.
AE는 마스킹 과정을 통해 손상된 입력 시퀀스를 복구시키는 pre-training 방식이다. 즉, 입력 시퀀스의 일정 부분을 [MASK] 토큰으로 대체하고, 마스킹된 토큰을 예측하는 objective를 사용함으로써 양방향 문맥을 학습할 수 있다.
원본 데이터가 $\bar{x}$이고 마스킹된 데이터를 $\hat{x}$로 나타낼 때, AE objective는 다음과 같다.
$$\max_{\theta} log p_{\theta}(\bar{x} | \hat{x}) \approx \sum_{t=1}^{T} m_{t} log p_{\theta}(x_{t} | \hat{x}) = \sum_{t=1}^{T} m_{t} log \frac{exp(H_{\theta}(\hat{x})_t^{\top} e(x_t))}{\sum_{x'} exp(H_{\theta}(\hat{x})_{t}^{\top} e(x'))}$$
위 수식에서 $m_t = 1$는 $x_t$가 마스킹되었음을 나타낸다. 그리고 $H_{\theta}$는 length-T인 입력 $x$를 hidden vectors $H_{\theta}(x) = [H_{\theta}(x)_1, H_{\theta}(x)_2, \cdots, H_{\theta}(x)_T]$와 맵핑한다.

AE와 AR의 다음 세 가지의 차이점을 가진다.

  • Independence Assumption: 위의 수식 표현에서 $\approx$로 강조되었듯이, BERT는 결합 조건부 확률 $p(\bar{x} | \hat{x})$를 인수분해한다. AE는 마스킹된 토큰에 따라 독립적으로 예측되는 특징이 있으며, 반대로 AR은 입력 텍스트를 그대로 사용할 수 있기 때문에 independence assumption(독립적 추정)이 일어나지 않는다.
  • Input noise: BERT에서와 같이, AE는 [MASK]와 같은 추가적인 토큰을 필요로 하는데, 이는 pre-training 과정에서 input 텍스트를 손상시키고 복구시키는 과정에서 fine-tuning과의 괴리감을 발생시켜 pretrain-finetune discrepancy와 같은 문제가 생기게 된다. 그와 반대로 AR은 이와 같은 문제가 발생하지 않는다.
  • Context dependency: AR representation인 $h_{\theta}(x_{1:t-1})$은 position t까지의 단방향에 대한 문맥 정보를 활용하고, BERT representation $H_{\theta}(x)_{t}$는 양쪽 문맥의 접근을 가능하게 한다. 따라서, BERT에서 사용된 AE objective가 양방향 문맥을 capturing하는 것에 더 효과적이다.

 

해당 논문은 AR과 AE의 장점만을 추출한 permutation language modeling을 소개하며, 이를 구현하기 위해 target-aware representation과 two-stream self-attention에 대해 설명한다.

 

Permutation LM

AR과 AE는 서로만의 강점을 가지고 있다. XLNet은 둘의 강점만을 사용할 수 있는 pre-training objective를 사용하기 위해 permutation language modeling objective를 제시한다. 

AR은 AE가 입력에 노이즈를 발생시키면서 생기는 불일치와 독립적 추정이 일어나지 않지만 uni-directional representation에서 발생하는 비효율성이 존재한다. AR이 bidirectional context 학습을 가능하게 하기 위해 입력 시퀀스 x의 가능한 모든 순열에 대한 pre-training을 수행하며, 이를 permutation language modeling이라고 칭한다. 입력 텍스트 x의 길이가 T라고 했을 때, 가능한 모든 순열의 조합은 T! 이다. 모든 순열 조합에 대한 AR은 모델이 bidirectional context를 학습할 수 있게 한다. 그림과 예시는 아래와 같다.

  • $x = [1, 2, 3, 4]$
  • $Z_{t} = [[1, 2, 3, 4], [1, 2, 4, 3], [2, 1, 3, 4], \cdots, [4, 3, 2, 1]]$
  • $Z_{t}$의 길이 = 4! = 24

Factorization 순서가 $z = [1, 4, 2, 3]$일 때, 3을 예측하기 위해 1, 4, 2번째 토큰들이 사용되며 오른쪽 그림은 $z = [4, 3, 1, 2]$ 이므로 4번째 토큰이 입력으로 사용된다.

$z_t$와 $z_{<t}$는 각각 예측 토큰과 $(t-1)$까지의 요소를 나타내며, $z \in Z_{t}$이다. 이때, permutation language modeling objective의 수식 표현은 다음과 같다:

$$\max_{\theta} E_{z \sim Z_{t}}[\sum_{t=1}^{T} log p_{\theta}(x_{z_t} | x_{z_{<t}})]$$

※ 이때, 모델링 과정에서 입력 시퀀스 자체의 순서를 바꾸는 것이 아닌 factorization 순서만 바꿔야 한다. 따라서 기존의 입력 순서는 그대로 가지고 있으며, 섞인 순서에 대응되는 positional encoding을 사용해야 한다. Fine-tuning 과정에서는 원래 형태의 입력만 적용시킨다.

 

Architecture

Target-aware representations

$$p_{\theta}(X_{z_t} = x | x_{z_{<t}}) = \sum_{t=1}^T log \frac{exp(e(x)^{\top} h_{\theta}(x_{z_{<t}})}{\sum_{x’} exp(e(x)^{\top} h_{\theta}(x_{z_{<t}})}$$

기존의 Transformer의 구조를 그대로 사용할 때, 앞에서 설명한 permutation language modeling objective가 학습을 제대로 이끌어내지 못할 수 있다. 기본 Softmax formulation을 활용하여 next-token 분포 $p_{\theta}$를 예측하는 수식은 위와 같다.

  • $h_{\theta}(x_{z_{<t}})$은 target position을 신경쓰지 않는다.
  • Problem: $[x_t | x_{<t}]$가 $[6 | 1, 3, 2]$일때와 $[5 | 1, 3, 2]$일때를 구분하지 못한다.

 

문제를 해결하기 위해서 타겟 포지션 $z_t$를 추가로 입력받는 $g_{\theta}(x_{z_{<t}}, z_t)$ query representation을 사용한다.

$$p_{\theta}(X_{z_t} = x | x_{z_{<t}}) = \sum_{t=1}^T log \frac{exp(e(x)^{\top} g_{\theta}(x_{z_{<t}}, z_t)}{\sum_{x'} exp(e(x)^{\top} g_{\theta}(x_{z_{<t}}, z_t)}$$

 

Two-Stream Self-Attention

앞서 설명한 target-aware representation은 타깃 예측에 대한 모호함을 해결하지만, $g_{\theta}(x_{z_{<t}}, z_t)$를 어떻게 구현할지에 대한 문제가 아직 남아있다. 이것을 구현하기 위해 타깃 위치 $z_t$를 기준으로 attention과 context $x_{z_{<t}}$을 통해 정보를 축적하는 방식을 사용하게 된다. 하지만 이 방법을 적용하기 전에 두 가지 모순을 먼저 해결해야 한다.

 

  1. 타깃 토큰 $x_{z_t}$, $g_{\theta}(x_{z_{<t}}, z_t)$를 예측하기 위해서는 $x_{z_t}$의 내용을 포함하지 않는 위치 정보인 $z_t$와 이전의 토큰 정보 $x_{z_{<t}}$만을 사용해야 한다.
  2. 이후의 토큰 $x_{z_j} (j > t)$을 예측하기 위해서는  $g_{\theta}(x_{z_{<t}}, z_t)$가 $x_{z_t}$를 포함한 전체 정보를 가지고 있어야 한다.

$x_{z_t}$의 context를 포함하지 않는 hidden layer와 다음 토큰의 예측을 위해 포함하는 layer를 구분하기 위해 다음 두 가지 종류(Context representation, Query representation)의 hidden representation을 사용한다.

  • Context representation: $h_{\theta}(x_{z_{<t}})$는 기존의 Transformer의 hidden state와 유사하며, $x_{z_t}$를 포함한 context를 인코딩한다.

$$h_{z_t}^{(m)} \leftarrow Attention(Q = h_{z_t}^{(m-1)}, KV = h_{z_{\leq t}}; \theta)$$

  • Query representation: $g_{\theta}(x_{z_{<t}}, z_t)$는 $x_{z_t}$를 제외한 타깃 토큰 이전의 정보인 $x_{z_{<t}}$와 위치 정보 $z_t$에 접근할 수 있다.

$$g_{z_t}^{(m)} \leftarrow Attention(Q = g_{z_t}^{(m-1)}, KV = h_{z_{<t}}^{(m-1)}; \theta)$$

 

rTwo-stream self-attention의 전반적인 구조는 위 그림과 같다. (a)는 context representation, (b)는 query representation, 그리고 (c)는 전반적인 구조를 나타낸다. 

 

Partial Prediction

Permutation language modeling은 모든 factorization order에 대한 AR을 다루면서 bidirectional representation을 학습할 수 있지만, 그만큼 늘어나는 연산량을 어떻게 감당할 것인지에 대해 고민해야 한다. XLNet은 sequence $z (\in Z_t)$를 기준점 c($|z| / (|z| - c) \approx K$)를 기준으로 non-target subsequence $z_{\leq c}$와 target subsequence $z_{c>}$로 나누어 target subsequence만 예측하는 방식으로 연산량을 감소시켰다.

 

Transformer-XL

Permutation LM이 AR을 기반으로 하기 때문에, AR에서 좋은 성능을 보인 Transformer-XL 모델을 pre-training 구조에 사용했으며, 모델의 이름을 따서 XLNet이라는 이름을 지었다. Transformer-XL 중에서 두 가지 중요한 기술인 relative positional encoding과 segment recurrence mechanism을 추출하여 pre-training 모델에 적용한다. 

  • Segment recurrence mechanism: 이전 segment를 처리한 hidden state를 현재 hidden state와 결합하여 사용 가능한 정보의 범위를 넓힐 수 있는 방법이다.
  • Relative positional encoding: 이 방법은 segment recurrence mechanism을 사용하기 때문에, 기존의 절대적인 위치를 표현했던 positional embedding을 사용할 수 없기 때문에 상대적인 위치를 나타내는 positional encoding을 사용한다.

 

Long sequence에서 추출한 두 세그먼트 $\tilde{x} = s_{1:T}$와 $x = s_{T+1:2T}$가 있을 때, $\tilde{z} = [1, \cdots, T], z = [T + 1, \cdots, 2T]$이다. 이 때, 두번째 segment $x$에서 attention update로 활용되는 memory는 다음과 같이 표현한다. 

$$h_{z_t}^{(m)} \leftarrow Attention(Q = h_{z_t}^{(m-1)}, KV = [\tilde{h}^{(m-1)}, h_{z_{\leq t}}^{(m-1)}; \theta]$$

[] 로 표현된 sequence는 연결한다는 뜻이다. 결론적으로, 모델은 이전 segment의 모든 factorization order에 대한 메모리를 학습에 사용할 수 있게 된다.

 

Discussion

BERT는 [MASK] 토큰에 대한 예측, 그리고 XLNet은 슬라이싱된 target token에 대한 예측을 통해 partial prediction을 사용하게 된다. BERT는 모든 토큰이 마스킹되었을 경우, 유의미한 예측이 불가하기 때문에 필수적으로 partial prediction을 하게 된다. 또한, BERT와 XLNet에서 partial prediction은 모델이 최적화하기 더 쉽게 만들어준다. 하지만 BERT에서 사용하는 AE는 독립적 예측(independence assumption) 문제가 발생하게 되는데, 이는 모델 학습에서 비효율적이라고 할 수 있다.

  • x = [That, ball, is, red], log p(That ball | is red)
  • BERT = log p(That | is red) + log p(ball | is red)
  • XLNet = log p(That | is red) + log p(ball | That is red)

 

위의 objective 연산에 따라 XLNet은 [That, ball] 간의 연관성을 학습할 수 있게 된다. BERT는 두 토큰 간의 연관성을 판단할 수 없어 independence assumtion이 발생한다. 따라서 XLNet이 항상 더 많은 연관성을 판단할 수 있게 된다.

profile

learningAI

@YyunS

인공지능 공부하는 학생입니다!