learningAI
article thumbnail

Summary

최근 연구된 language modeling 논문에서는 학습되는 모델의 크기가 커짐에 따라 NLP tasks에 적용했을 때의 성능이 나아질 것이라고 말한다. 하지만, 큰 모델은 메모리 한계와 같은 문제로 학습되기 쉽지 않다. Megatron-LM은 Billions of parameters(몇 십억개의 파라미터)를 학습시키기 위해 모델 병렬화 기법을 소개한다. 논문에서 소개한 병렬화 기법은 새로운 컴파일러를 필요로 하거나 라이브러리를 수정할 필요가 없으며, 간단한 계산 수식을 추가하여 구현할 수 있다. Megatron-LM은 새로 소개한 병렬화 기법을 활용하여 8.3B(83억)개의 파라미터를 가진 transformer 모델을 512개의 GPU로 구현한다. Baseline으로 사용된 모델은 단일 GPU로 학습된 모델이며, 최대 FLOPs의 30%인 39 TeraFLOPs를 가지며, 8.3B개라는 엄청난 파라미터 수를 가진 모델은 그보다 76% 높은 확장 효율성을 가진 15.1 PetaFLOPs를 유지한다. 

논문은 큰 language model이 SOTA를 달성할 수 있다는 것을 증명하기 위해 GPT-2와 비슷한 구조를 가진 8.3B개의 파라미터 수를 가진 모델과 BERT와 비슷한 3.9B개의 파라미터 수를 가진 모델을 학습시킨다. BERT 기반의 모델은 확장함에 따라 layer normalization의 위치가 중요했으며 이렇게 확장된 모델들은 WikiText103, LAMBADA, RACE에서 SOTA를 달성한다.

 

Introduction

NLP 분야는 가능한 컴퓨팅 자원과 데이터셋 크기의 증가를 통해 빠르게 발전하고 있다. 컴퓨팅 자원과 데이터가 풍부함에 따라 BERT와 GPT-2와 같은 더 큰 모델의  unsupervised pretraining을 수행할 수 있게 된다. 이전 연구에 따라 language model은 클수록 여러 task에 대한 성능이 향상된다는 것을 알 수 있으며, pretrained 모델은 finetuning함으로써 downstream tasks에서 SOTA를 찍을 수 있었다. 

모델의 크기가 커지면서 필요한 메모리는 점점 커졌고, 이를 해결하기 위해 activation checkpointing와 같은 기법이 추가로 사용되었다. 큰 모델의 또 다른 문제점은 널리 사용하는 ADAM optimizer은 파라미터 개수에 따른 추가 메모리가 필요하기 때문에 실질적으로 학습할 수 있는 모델의 사이즈를 축소시킨다. 기존의 여러 모델 병렬화 기법들은 가중치와 관련된 optimizer state를 분할함으로써 이러한 문제를 해소했다. 그 예로 GPipe, Mesh-Tensorflow가 있으며 이러한 기법은 모델을 수정하고, 라이브러리의 c++을 건드리는 귀찮은 작업을 포함한다.

따라서 이 논문은 intra-layer 모델 병렬화를 활용하여 간단한 병렬 처리 방법을 제시한다. Transformer의 내부적인 구조를 활용한 병렬화를 구현했으며, c++ 코드나 추가 컴파일러를 건드릴 필요 없이 Pytorch로 간단히 구현할 수 있다는 장점이 있다!!

모델의 확장성을 보여주기 위해 하나의 NVDIA V100 32GB GPU에서 학습된 1.2B개의 baseline 모델을 학습시키며, 이는 39 TeraFLOPs를 유지한다. 이는 이론적으로 DGX-2H 서버의 단일 GPU에 대한 최대 FLOPs의 30%이기 때문에 강력한 baseline 모델이라고 할 수 있다. 512개의 GPU에서 8-way 모델 병렬화를 사용하여 모델을 8.3 billion parameter까지 확장하면, 전체에서 15.1 PetaFLOPS/s를 지속적으로 달성한다. 이는 baseline 모델의 76%에 달하는 확장 효율성을 가진다.

모델 확장에 따른 성능을 분석하기 위해 GPT-2와 BERT의 큰 모델을 구현한다. BERT는 단순히 scaling만 했을 경우 성능이 감소했는데, layer normalization와 residual connection의 배치를 수정함으로써 모델 확장에 따른 성능이 증가했다. 결과적으로 확장된 모델은 여러 task에서 SOTA 성능을 보인다.

 

Parallel methods

그림 1. Transformer layer

 

Megatron-LM은 transformer 네트워크를 내부적으로 병렬 계산을 함으로써 간단한 parallelism을 구현한다. 위의 그림 1을 보면 transformer network를 attention layer와 MLP(Multi-Layer Perceptron)으로 구성된다. 우선 MLP 병렬화를 먼저 살펴보자. 아래의 그림 2처럼 MLP를 두 파트로 나눴을 때, 첫번째 부분은 GeLU 활성화 함수를 사용한 GEMM(GEneral Matrix Multiplication)이다:

이때, GEMM을 병렬화 할 수 있는 선택지는 두가지이다.

 

1. $X$를 열을 따라 나누고 가중치 $A$를 행을 기준으로 나눈다. 이렇게 나누게 된다면 계산은 $Y = GeLU(X_1 A_1 + X_2 A_2)$가 되며, 이는 $GeLU(X_1 A_1) + GeLU(X_2 A_2)$로 변환할 수 없으며 GeLU 이전의 동기화 포인트가 필요하게 된다.

2. 또 다른 방법은 가중치 $A$를 열을 기준으로 $A = [A_1, A_2]$ 나누는 것이다. 이렇게 되면 GeLU 함수를 독립적으로 적용시킬 수 있게 되면서 나뉘어진 GEMM을 수행할 수 있다.

이렇게 하면 1번 선택지와 달리 동기화 포인트를 제거할 수 있으며, GeLU를 따로 적용하기 위한 추가적인 communication이 필요하지 않다. 다음으로 두번째 GEMM의 출력은 dropout layer을 통과하기 전에 GPU에서 reduced된다. MLP의 두 개의 블럭 각각에서 순전파 과정에서 all-reduce 계산을 수행하는 $g$ 연산자와 역전파 과정에서 all-reduce 계산을 수행하는 $f$ 연산자가 있으며 이는 PyTorch 코드로 간단하게 구현할 수 있다.

 

그림 2. 연산자 $f$ 구현 코드

 

다음으로 attention block에 대한 병렬화는 Q, K, V에서 각각 이뤄진다. 이는 MLP 병렬화와 동일하게 열에 따라서 나누는 작업을 진행하며, $Q = [Q_1, Q_2]$, $K = [K_1, K_2]$, $V = [V_1, V_2]$와 같이 나뉜다. Attention block을 이와 같이 나눔으로써 attention head을 여러 GPU에서 작업할 수 있게 되며, self-attention에 대한 추가 communication을 필요로 하지 않는다. Self-attention 다음의 GEMM은 행에 따라 병렬화하여 이전의 attention 출력에 계산된다.

 

그림 3. 모델 병렬화

 

위 그림3에서와 같이 MLP에서도 동일하게 두번째 GEMM에서는 가중치를 행을 기준으로 나눔으로써 이전 GEMM의 출력과 곧바로 계산할 수 있게 한다. 최종적으로 정리하자면, 이러한 병렬화 기법은 syncronization point(동기화 포인트)를 필요로 하지 않으며, 순전파와 역전파에서의 all-reduce만으로 간단히 구현할 수 있게 해준다. 이때 all-reduce란, 각 GPU에서 연산된 결과를 다른 GPU에게 전달하여 결과를 동기화하고 다음 연산을 위해 준비하는 것이다.

Transformer 모델은 [H, V] 크기의 embedding을 사용한다. Embedding은  큰 사이즈를 가지기 때문에, 이 또한 병렬화를 해준다. Embedding 가중치는 output과 input에서 공유하기 때문에 두 embedding 가중치에 대한 modification 작업이 진행된다. $E_{H \times v}$은 $E = [E_1, E_2]$로 앞에서 설명한 병렬화와 동일하게 embedding matrix를 열을 기준으로 나뉜다. 나뉘어진 matrix는 embedding matrix의 일부만을 포함하고 있기 때문에 all-reduce($g$ 연산)는 input embedding 뒤에 위치한다. Output embedding의 경우, all-gather $Y = all-gather([Y_1, Y_2])$를 추가하여 logits를 도출한다. All-gather는 $b \times s \times v$를 communicate하는데, 이는 vocabulary size가 크기 때문에 아주 많은 계산을 요구한다. Communication size을 줄이기 위해 $GEMM[Y_1, Y_2]$의 출력에 대한 cross entropy loss를 계산함으로써 크기를 $b \times s$로 줄일 수 있게 된다. 

논문에서 제시한 병렬화 기법은 communcation을 줄이기 위한 method라고 할 수 있다. 단일 GPU가 dropout, 정규화, residual을 계산하여 다른 GPU로 broadcast하는 대신, 여러 GPU에서 계산을 하는 방법을 사용한다. 최적화하기 위해서 병렬화된 worker가 각각의 파라미터를 업데이트할 수 있도록 하며, 각 worker에 포함된 파라미터는 고유한 특성을 가지거나 GPU에 대하여 중복되기 때문에 이 또한 특별한 communication이 필요하지 않다.

 

Model settings

모델 pre-training은 NLP에서 아주 중요한 task이다. 대표적인 pretrained 모델인 GPT-2와 BERT를 통해 논문에서 제시한 병렬화 method를 적용한다. BERT와 GPT-2 간단한 리뷰는 제 지난 피드 참고해주세요! (BERT, GPT-2)

 

Datasets

Long-term dependency를 가진 학습 데이터를 수집하기 위해 Wikipedia, CC-stories, RealNews, 그리고 OpenWebtext와 같이 큰 모델들을 조합한다. Downstream task에 학습 데이터가 영향을 주지 않기 위해 WikiText103과 겹치는 Wikipedia 부분을 제거하고, CC-stories에서 불필요한 부분을 제거한다. BERT 학습에서는 BooksCorpus가 사용되지만, GPT-2에선 LAMBADA benchmark와 겹치기에 사용되지 않는다.

학습을 위해 모든 데이터셋을 결합한 다음, 통합된 데이터셋에서 콘텐츠 길이가 128 토큰 미만인 모든 문서를 필터링한다. 통합된 데이터셋에서 비슷한 문맥이 여러번 추출될 수 있기에 LSH(Locality-Sensitive Hashing)을 사용한다. LSH는 데이터에서 jaccard similarity가 0.7 이상인 문맥을 제거한다. 결과적으로 174GB의 통합 corpus 데이터가 된다.

 

Optimization and Hyperparameters

효율적으로 모델을 학습시키기 위해 mixed precision 학습 기법과 dynamic loss scaling을 사용하여 V100's Tensor Cores의 장점을 최대한 활용하고자 한다. 가장 먼저 가중치 $W$를 간단한 normal distribution $W$ ~ $N (0, 0.02)$로 초기화한다. 그 다음, residual connection 직전에 $\frac{1}{\sqrt{2N}}$으로 scaling된다. $N$은 self-attention, MLP 블럭으로 구성된 transformer layer의 개수이다. Optimizer는 Adam을 사용하고 weight decay를 $\lambda = 0.01$으로 설정된다. 추가적으로 큰 모델을 안정적으로 학습시키기 위해서 global gradient norm clipping를 1.0으로 사용한다. Dropout은 0.1으로 설정된다 (모델 전체에서). 최종적으로 메모리 추적의 보다 나은 관리를 위해 모든 transformer layer 이후에 activation checkpointing이 사용된다. 

GPT-2 모델은 학습률 1.5e-4으로 시작되어 300k iterations 동안 512 사이즈의 배치로 구성된 1024 sub-word sequence로 학습된다. 학습률은 3k iteration 동안 warmup period을 거친 후에 나머지 297k iteration 동안 single cycle cosine decay가 일어난다. Decay는 학습률이 1e-5가 되면 멈춘다. 

BERT는 기존의 vocab size인 30,522가 사용되었고, whole word n-gram  마스킹이 적용되었다. 학습 전체에서 배치 사이즈는 1024로 통일되었고, 학습률은 1.0e-4로 시작되어 10k iteration 동안 warmup session을 거쳐 나머지 2M iteration 동안 선형적으로 감소한다.

 

Experiments

모든 실험 과정은 32개의 DGX-2H server(512개의 Tesla V100 SXM3 32GB GPU)를 사용하여 진행된다. 모델의 확장성을 평가하기 위해서 4 가지 파라미터를 설정한다.

 

표 1. 모델 파라미터 세팅

 

Self-attention layer의 GEMM 크기를 유지하기 위해 (Attention heads) * 96 = (Hidden size)가 유지되며, 헤드와 레이어의 수를 변화시켜 모델은 1B에서 8B(10억에서 80억)까지의 파라미터 수를 가진다. 1.2B개의 파라미터 수를 가지는 모델은 하나의 GPU에서 구동되는 반면, 8B의 파라미터 수를 가진 모델은 8개의 GPU가 사용된다(8-way model parallelism). 기존의 vocabulary 크기는 50,257이었지만 GEMM의 계산 효율성을 최대로 끌어올리기 위해서 각 GPU당 vocab size를 128의 배수로 만들어야 한다. 따라서 8개의 GPU가 구동되므로 $128 \times 8 = 1024$이며 따라서 padded vocab size는 51,200이 된다. Model과 mode-data 병렬화 방법 모두 학습되었으며 batch size는 8로 고정된다. 한편, model+data 실험에서는 배치 크기를 512로 고정한다(GPU도 훨씬 많이 사용됨).

 

Weak scaling

본 논문은 model에 대한 병렬화와 model+data에 대한 병렬화의 weak scaling을 평가한다. Weak scaling은 컴퓨팅 자원이 늘어남에 따라 작업에 할당되는 자원의 비율을 말하며, 스케일링 수치가 높을수록 병렬 처리 시스템의 효율성이 높아진다. 

 

그림 4. Weak scaling. Linear scaling에 대한 percentage이다.

 

그림 4는 각 병렬화 기법에 대한 스케일링 값이다. 두 가지 세팅 모두 훌륭한 스케일링 수치를 나타내는 것을 볼 수 있는데, 8개의 GPU를 사용하는 8.3B 모델은 linear scaling의 77%를 달성하고 512개의 GPU를 사용하는 모델 또한 가장 큰 컴퓨팅 자원을 사용함에도 불구하고 baseline모델에 비해 74%의 스케일링 수치를 나타낸다.

 

GPT-2 scaling

표 2. GPT-2 모델 파라미터 세팅

GPT-2 based 모델은 표 2와 같이 스케일링 되었다. 355M 모델은 BERT-Large 모델과 같은 크기이며, 2.5B 모델은 기존 GPT-2 모델보다도 크며, 8.3B 모델은 가장 큰 모델이다. 이전에 설명된 학습 세팅을 따라 학습되었으며, 8.3B 크기의 모델은 한 epoch당 무려 2일 이상이 소요된다.

 

표 3. GPT-2 SOTA 달성

당연하게도 모델 크기가 커질수록 정확도와 perplexity 수치가 개선된다. Wikitext103과 LAMBADA에서 모두 SOTA를 달성한다.

 

BERT scaling

BERT를 스케일링하는 과정에서 336M개의 파라미터 수가 넘어가는 모델은 예기치 못한 성능 저하가 발생하였다. 이 문제는 레이어 내의 배치를 변화시켜 해결했다.

 

그림 5. BERT 스케일링

 

그림 5를 보면 (a)구조를 사용한 752M 크기의 모델의 loss가 발산하는 것을 확인할 수 있다. (b) 구조를 사용한 모델은 문제없이 loss가 잘 줄어드는 것을 볼 수 있다. 결과적으로 다음과 같이 SOTA를 달성한다.

 

표 4. BERT SOTA 달성

 

Discussions

해당 논문이 발행되기 이전에 진행되었던 연구들이 언급한 큰 사이즈의 모델은 더 나은 성능을 보일 것이라는 말을 직접 실행에 옮긴 연구가 Megatron-LM이라고 할 수 있겠다. 몇 십억개의 파라미터를 가진 모델을 훈련시킨 것도 충격적이었지만 가장 큰 모델인 GPT-2 8.3B 모델은 한 epoch당 이틀이 소요한다는 부분이 특히나 놀라웠다. 더 큰 모델 스케일링에 대해 다룬 논문들도 추후 읽어보고 어떻게 GPT-4까지 가게 된 것인지 알아봐야 할 것 같다.

논문은 아래 링크 참고!

https://arxiv.org/pdf/1909.08053.pdf

profile

learningAI

@YyunS

인공지능 공부하는 학생입니다!