An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
Photo by Ryan Stone / Unsplash

오늘은 2020년 10월에 발표된 Image Classification SOTA급 논문에 대해서 리뷰해보도록 하겠습니다. 이 논문에서 제시한 모델인 ViT(Vision Transformer) 는 ImageNet dataset에 대해서 top-1 error 88.55%를 기록하였습니다. 현재 2번째 순위이며 SOTA는 EfficientNet의 변형버전인 SAM이 88.61%로 차지하고 있습니다.

RELATED WORK

Vision Transformer는 기존의 classification domain model들과 다르게 아예 새로운 접근 방법을 선택하였습니다. NLP task에서 많이 쓰던 Transformer라는 개념을 Vision domain에 도입시킨 것입니다.

Transformer개념이 등장한 것은 2016년이고 현재로부터 4년 동안 많은 시도들이 있었습니다. 먼저 주변 픽셀끼리만 self-attention을 하는 Parmar et al.(2018), self-attention block으로 convolutions를 대체한 Ramachandran et al.(2019) 같은 접근이 있었습니다. 그 뒤 image 자체에 attention을 하는 것이 아닌 feature maps에서 self-attention을 한 Beelo et al.(2019), Carion et al.(2020) 등이 있었습니다. 지금까지는 image를 따로따로 보거나 feature map으로 보는 경우였다면 이미지 전체를 한번에 attention을 적용하는 경우가 iGPT(Chen et al. 2020)입니다. 이 모델은 ImageNet에 대해 top-1 error를 28%를 달성했습니다. SOTA를 따라잡기엔 아쉬운 결과였고 저자는 이 아이디어를 더 발전시켜 12% 이하로 만들게 됩니다.

METHOD

기본적으로 Transformer의 구조를 그대로 따릅니다. 다만 이미지 자체를 여러개로 나눠서 Sequence Data로 만들어야 하기 때문에 P*P 크기의 patch로 잘라 데이터를 생성합니다. 이 과정이 Transformer에서 Embedding이라고 생각할 수 있습니다.

이후 Positional Encoding을 해주는데 이미지 데이터의 위치는 1d가 아닌 2d이므로 2d Positional Encoding을 합니다. 2d로 할 때에는 vector의 반은 x축 Encoding, 나머지 반은 y축 Encoding을 합니다. Encoding 자체는 Transformer와 동일하게 진행합니다.

한가지 신기한 점은 맨 처음 셀에는 BERT의 class token처럼 이 이미지의 class를 나타내주는 벡터가 들어갑니다. 이 벡터는 ViT의 최종 feature maps에서 MLP를 통해 사이즈를 맞춘 벡터입니다. 즉 전체를 다 보고 하나하나씩 따로 보고 이 그림이 어떤 class인지 맞추게 되는 것입니다.

이렇게 Embedding을 하게 되면 이후 Transformer Encoder로 들어가게 됩니다.

이 Encoder는 Transformer의 Encoder와 거의 비슷하지만 다른 점이 몇개 있습니다. 먼저 after-norm이 아닌 pre-norm을 한다는 것입니다. 또한 MLP에서 activation function으로 ReLU가 아닌 GeLU를 사용하고 있습니다. 이것을 제외하고는 전부 같은 방식이며 Encoder만 쓰는 것이 BERT와 비슷하다고 볼 수 있습니다.

RESULT

기존 모델보다 더 좋은 성능을 확인하였습니다.

Attention 또한 잘 잡는 모습을 볼 수 있었습니다.

Question

class token이 정확히 어떻게 동작할까?
고양이가 패치들간의 사이에서 정확히 잘리면 그것을 어떻게 판단할까?
pre-norm 하는 이유?
GeLU 사용이유?
attention distance 계산방법?