모델이 vision-and-language task를 학습하기 위해서는 이미지와 언어 입력을 이해할 줄 알아야 하며, 무엇보다 가장 중요한 것은 두 modality(vision feature & language context)를 align시키는 것이다. 본 논문에서는 vision-and-language alignment를 학습할 수 있는 framework인 LXMERT를 제시한다. LXMERT는 3개의 encoder를 포함하고 있으며, 5개의 pre-training task를 정의하여 모델이 alignment를 더욱 잘 이해하도록 설계했다.
Fine-tuning 과정을 거친 후 LXMERT는 VQA와 GQA 데이터셋에서 SOTA(State-of-the-art)를 달성할 수 있었다. 또한 도전적인 visual-reasoning task인 NLVR 데이터셋에서도 22% 성능 개선을 보여줬다.
Model Architecture
LXMERT는 Transformer 모델을 기반으로 한 self-attention과 cross-attention 레이어로 cross-modality model을 구현했다. 위의 그림처럼 입력은 이미지와 그와 관련된 문장으로 구성된다. 각 이미지는 sequence of objects로 표현되고, 각 문장은 sequence of words로 나타낸다. Self-attention과 cross-attention layer를 세밀하게 디자인하고 조합함으로써 LXMERT는 입력에 따라 (Language, image, cross-modality) representations을 생성할 수 있다.
Input Embeddings
LXMERT의 input embedding layer는 입력을 word-level sentence embeddings와 object-level image embeddings로 변환한다.
Word-Level Sentence Embeddings
가장 먼저 sentence input은 WordPiece tokenizer에 의해 길이 $n$인 sequence ${w_1, \cdots, w_n}$으로 나눠진다. 그 다음으로 위의 그림과 같이 단어 $w_i$와 인덱스 $i$는 임베딩 레이어를 지나 벡터로 변형되며, index-aware word embedding으로 더해진다.
Object-Level Image Embeddings
CNN 네트워크의 출력인 feature map을 사용하는 대신, 해당 논문에서는 탐색된 object의 feature를 이미지 임베딩으로 사용한다. Object detector는 이미지에서 m개의 object ${o_1, \cdots, o_m}$를 추출하며, 각 object $o_j$는 positional feature $p_j$와 region-of-interest(RoI)로 나타낸다. 이를 직접적으로 사용하는 대신 full-connected layer를 연결하여 더한 position-aware embedding로 계산된다.
Encoders
LXMERT는 세 가지 encoder(language, object-relationship, cross-modality)를 사용하며, self-attention layer와 cross-attention layer를 포함하고 있다.
Attention layer는 query vector $x$와 context vector(key, value) ${y_j}$ 간의 연관성을 검색하는 것을 겨냥한다. 가장 먼저 $x$와 $y_i$ 간의 연관성 벡터인 $a_j$를 계산하고, softmax 함수를 통해 정규화된다.
Attention layer의 출력은 소프트맥스 정규화된 score에 따른 context vector의 가중합이다: $Attn_{X \rightarrow Y} = \sum_j\alpha_jy_j$ 이때, $x$와 $y_j$가 같을 경우, 이 레이어는 self-attention layer이다.
Single-Modality Encoders
Embedding layer 이후 image embeddings, word embeddings에 각각 object-relationship encoder와 language encoder를 적용한다. 해당 encoder는 single-modality를 인코딩하는 것에 집중한다. BERT와 다른 점은 word input에만 적용하는 것이 아니라 vision input에도 따로 적용된다는 것이다. 인코더의 각 레이어는 self-attention과 feed-forward sub-layer가 포함되고 있으며, language encoder의 레이어 수는 $N_L$으로 나타내며, object-relationship encoder의 레이어 수는 $N_R$로 나타낸다.
Cross-Modality Encoder
[그림 1]를 보면 cross-modality encoder는 "Cross"로 나타난 cross-attention과 self-attention, 그리고 feed-forward sub-layers로 이뤄져 있다. k-번째 레이어에서 계산을 위해 사용되는 query와 context vector는 이전 (k-1)번째 레이어의 출력이며, language feature $\{h_i^{k-1}\}$와 vision feature $\{v_j^{k-1}\}$이 입력으로 사용된다.
위는 cross-attention을 계산하는 수식으로, context vector를 교차해서 입력함으로써 two-modality connection을 학습하게 되는 것이다. Cross-attention 계산 이후는 이전의 single-modality encoder와 동일하게 계산되어 출력된다.
Output Representations
[그림 1]의 가장 오른쪽 부분을 보면 LXMERT는 language, vision, 그리고 cross-modality의 세 가지 종류의 출력을 하고 있는 것을 확인할 수 있다. Language와 vision output은 cross-modality에 의해 계산된 feature sequence이다. Cross-modality output의 경우는 BERT에서처럼 word stream의 입력 첫부분에 [CLS] 토큰을 삽입하여 계산된 representation을 사용한다.
Pre-training Strategies
LXMERT가 보다 나은 초기 상태와 vision-and-language connection을 잘 이해할 수 있도록 본 논문에서는 큰 데이터셋에서 각 modality에서 총 5 가지 pre-training을 진행한다. Pre-training task의 대략적인 구조는 아래 그림에서 확인할 수 있다.
Language Task: Masked Cross-Modality LM
Language modality에서는 1가지 task에 대해 pre-training을 수행한다. [그림 2]에서 오른쪽 아래에 명시된 것처럼 BERT와 동일하게 마스킹된 입력을 reconstructing하는 method를 사용한다. 하지만 BERT와 차이점이라고 할 수 있는 것은 cross-modality encoder를 통해 visual feature를 활용할 수 있도록 학습했기 때문에 language representation에서 부족한 정보를 visual feature를 활용하여 보충할 수 있다. 예를 들어 language feature만 사용한다면 위의 그림에서 "carrot"이라는 단어를 유추하기에 모호한 부분이 있지만 이미지 정보를 활용함으로써 보다 정확하게 예측할 수 있게 된다.
Vision Task: Masked Object Prediction
[그림 2]의 오른쪽 윗부분을 보면 object masking을 통해 두 개의 vision task를 수행하는 것을 볼 수 있다. Language task에서처럼 모델은 vision/language feature를 모두 활용하여 예측을 수행한다. Object feature를 활용하는 것은 object-relationship을 파악하는데 도움을 줄 것이고, language feature를 활용함으로써 모델은 cross-modality alignment를 학습할 수 있게 된다. Vision task는 RoI-feature Regression과 Detected Label Classification sub-tasks로 이뤄져 있다.
RoI-Feature Regression은 L2 loss를 통해 RoI feature $f_j$를 예측하는 vision sub-task이다.
Detected Label Classification은 마스킹된 object의 class를 예측하는 sub-task이며, cross-entropy loss를 사용한다. 이때, 데이터셋에 따라 라벨의 종류 및 명시된 이름이 다를 수 있다. 이런 이유로 해당 논문에서는 Faster R-CNN에서 탐색된 label의 출력을 사용한다.
Cross-Modality Tasks
[그림 2]의 오른쪽 가운데에서 볼 수 있는 것처럼 cross-modality task 또한 language/vision modality의 정보를 모두 필요로 하는 두 개의 pre-training task를 포함하고 있다.
Cross-Modality Matching은 이미지와 문장이 매칭되는지를 예측하는 task이다. 50%의 확률로 이미지와 align되는 문장이 입력될 수도 있고 랜덤한 문장이 입력될 수 있으며, 이는 BERT의 NSP(Next Sentence Prediction)와 비슷하다.
Image QA는 이미지에 대한 질문에 대답하는 task이다. [그림 2]에서는 "Who is eating the carrot?"에 대해 이미지에서 토끼가 당근을 먹고 있는 정보를 capture하여 "Rabbit"이라는 출력을 했다. 이때, 이미지와 문장이 매칭되지 않는 경우에는 해당 task를 스킵한다. Image QA task를 학습함으로써 모델은 더 나은 cross-modality representation을 학습할 수 있다.
Pre-training Data
위의 그림은 pre-training에서 사용된 데이터이다. 데이터는 MS사의 COCO 혹은 Visual Genome 이미지를 사용하며, VQA v2.0, GQA, VG-QA의 train/dev split이 사용된다. 최종적으로 9.18M개의 image-sentence pair가 수집되었다.
Pre-training procedure
- Input sentence는 WordPiece tokenizer로 나뉘고, object는 Faster R-CNN에 의해 탐색된다.
- Faster R-CNN은 pre-trained된 그대로 사용하며, fine-tuning을 하지 않는다.
- Padding을 사용하지 않음으로써 최적화를 위해 object 개수를 36개로 유지한다.
- $N_L = 9$, $N_R = 5$, $N_X = 5$, $H = 768$
- Image QA task를 위해 9500 답변에 대한 joint answer table를 형성한다.
Experiments
LXMERT는 세 가지 dataset에 대해 평가된다: VQA v2.0, GQA, NLVR. 논문에서 소개한 5개의 pre-training task에서 학습된 모델은 세 데이터셋에 대해 모두 좋은 성능을 보이며, SOTA 성능을 뛰어넘는 것을 확인할 수 있다.
Pre-train + BERT라고 되어있는 모델은 BERT의 pre-trained 파라미터를 LXMERT에 불러와 pre-training을 진행한 방식으로 오히려 악영향을 미치는 것을 볼 수 있으며, 이에 따라 LXMERT는 BERT의 파라미터를 불러오지 않고 처음부터 pre-training을 진행한다.
위의 표는 Image QA와 fine-tuning 영향력에 대한 ablation study이다. P10+QA10은 (10 epochs w/o QA + 10 epochs w/ QA)를 의미하고, P20은 Image QA task 없이 20 epochs 동안 학습하는 것이다. DA와 FT는 각각 data augmentation과 fine-tuning을 의미하며, data augmentation은 이전 논문들에서 VQA 데이터셋을 처리할 수 있는 모델을 구현하기 위해 종종 사용했던 방법이다. 결과적으로 위의 표에서 2와 4를 보면 Image QA task의 영향력이 큰 것을 알 수 있고, 또한 DA보다 FT가 효율적임을 알 수 있다.
Conclusion
LXMERT은 cross-modality connection을 학습하기 위한 3개의 encoder 구조와 5개의 pre-training task를 제시했다. Image QA task에 대해 학습함으로써 모델은 alignment 정보를 더 잘 학습할 수 있었고, experiment에서 결과로 확인할 수 있었다. 또한 다른 BERT 기반의 multi-modality 모델들과 다르게 BERT의 파라미터를 불러오는 것이 오히려 안 좋은 결과로 이어진다는 흥미로운 결과를 보여줬다.
논문