[Paper Explain] MetaFormer: Khi Attention is NOT all you need cho bài toán phân loại ảnh
Yêu cầu nhỏ
Hiểu về các lớp Norm khác nhau hoạt động như nào: BatchNorm (BN), GroupNorm (GN), LayerNorm (LN) và biết cách sử dụng Pytorch
Mở đầu
Từ khi Transformer được áp dụng cho bài toán phân loại ảnh qua ViT, đã có rất nhiều models mới tập trung vào cải thiện phép Self-Attention. Một khối encoder của Transformer có kiến trúc chung như ở hình 1, bao gồm 2 thành phần. Phần đầu là nơi chứa Attention module dùng để "mix" thông tin của tokens với nhau, do đó ta đặt nó là "Token Mixer". Phần còn lại là các module nhỏ kiểu như Channel MLP và Skip-connection. Với sự thành công của Transformer, người ta cho rằng đó là do sự vượt bậc của phép Self-Attention. Do đó, có rất nhiều models mới tập trung vào cải thiện phép Self-Attention. Tuy nhiên, một số nghiên cứu lại chỉ đơn thuần sử dụng phép Spatial MLP làm Token Mixer và đạt được hiệu năng tương đương với phép Self-Attention. Hay thậm chí, có một nghiên cứu khác sử dụng Fourier Transform thay thế cho phép Attention và đạt được hiệu năng tới 97% một model ViT thuần.
Tác giả của paper MetaFormer nhận ra một điểm chung giữa các models được phát triển bây giờ, đó là các khối của chúng đều có kiến trúc 2 phần tương tự nhau. Phần đầu chứa phép Token Mixer, phần sau là Channel MLP và các skip-connection. Và tác giả cho rằng, thay vì tập trung quá nhiều vài cải tiến Token Mixer, kiến trúc tổng thể của một khối, cũng góp phần vô cùng lớn trong việc tạo ra một model có hiệu năng mạnh mẽ. Tác giả đóng gói kiến trúc tổng thể của một khối này lại, và đặt tên cho nó là MetaFormer (Hình 1). Tác giả đã tạo ra một hướng nghiên cứu mới, thay vì tập trung vào cải thiện Token Mixer, thì ta sẽ cải thiện kiến trúc tổng thể của nó, tức cải thiện MetaFormer.
Chứng minh
Ừ thì tác giả đề ra rằng kiến trúc tổng thể cũng vô cùng quan trọng, nếu không có gì để backup cho lời nói đó thì cũng chỉ là kết luận xàm xí thôi. Đầu tiên, tác giả sẽ chứng minh cho câu nói đó.
Một khối MetaFormer được chia ra làm 2 phần (Hình 1), phần đầu có Token Mixer và phần sau cho Channel MLP:
Token Mixer
Để chứng minh rằng MetaFormer là thứ cực kì quan trọng, thì tất nhiên ta phải sử dụng một Token Mixer cực yếu để xem rằng hiệu năng của toàn bộ model với Token Mixer cực yếu này có còn tốt?
Vì vậy, Token Mixer được chọn sẽ là phép Pooling. cụ thể là Average Pooling. Cụ thể, phép Pooling sẽ được định nghĩa như sau:
Modified Layer Norm
Thông thường trong kiến trúc họ Transformer, Norm được chọn để sử dụng là LayerNorm (LN). Tuy nhiên, tác giả nhận ra Norm này có vẻ không phù hợp cho lắm.
Trước tiên, ta phải hiểu về cách hoạt động của một số lớp Norm thường sử dụng và Norm của Pytorch API đã. Khi sử dụng kiến trúc họ Transformer, Feature maps sẽ có dạng với là . Khi thực hiện LN, ta sẽ norm ở trên chiều , và weight của LN sẽ có chiều luôn. Tuy nhiên, việc không sử dụng phép Self-Attention, tức feature maps sẽ có dạng . Và nếu lúc này thực hiện LN thông qua nn.LayerNorm
của Pytorch thì ta sẽ thực hiện norm ở trên chiều thay vì chỉ , và weight của LN cũng sẽ có chiều là . Nếu ta muốn áp dụng được LN chuẩn mực, tức là chỉ Norm trên chiều kể cả là feature maps có dạng , thì vẫn hoàn toàn có thể, chỉ là giờ chúng ta sẽ phải tự viết code chứ không còn áp dụng được nn.LayerNorm
của Pytorch nữa mà thôi.
Tóm cái váy lại thì:
- LayerNorm chỉ trên chiều channel như Transformer: Phải tự viết code lại, chỉ thực hiện norm trên chiều channel và weight cũng có chiều channel
nn.LayerNorm
: thực hiện norm trên chiều và weight có chiều
Oke thế giờ thì 2 Norm kể trên LayerNorm và nn.LayerNorm
có gì chưa tốt nếu sử dụng với feature maps có dạng . Nếu ta sử dụng nn.LayerNorm
, ta sẽ phải khai báo một lớp LN như sau:
self.layer_norm = nn.LayerNorm([C, H, W])
Điều này yêu cầu nn.LayerNorm
của ta phải có và cố định thì mới có thể khai báo và forward được. Điều này khiến cho mạng chỉ nhận đầu vào là một kích cỡ ảnh cố định không khả thi cho các downstream task.
Còn sử dụng LayerNorm thì hiệu năng yếu.
Vì vậy, tác giả đã tạo ra Modified Layer Norm (MLN). MLN sẽ thực hiện Norm trên chiều , tuy nhiên, weight của MLN sẽ có chiều . Nó là sự kết hợp giữa nn.LayerNorm
và LayerNorm. Khi khai báo, ta sẽ chỉ phải khai báo như sau:
self.modified_layer_norm = ModifiedLayerNorm(C)
Tức là giờ ta không còn phải cố định chiều và nữa.
Dưới đây là bảng kết quả so sánh MLN với LN và BN:
Kiến trúc của mạng
Bảng kết quả
Phần đáng mong đợi nhất đây. Đây là điều khiến cho câu nói của tác giả không trở thành câu chém gió vô căn cứ.
Mình có một câu hỏi thú vị muốn đặt ra cho các bạn đọc bài này. Phía trên CNN cùi bắp thì không nói, nhưng cùng là MetaFormer, tại sao PoolFormer lại mạnh hơn PVT, ViT hay Swin với việc chỉ sử dụng phép Pooling làm Token Mixer?
Cải thiện MetaFormer
Như đã nói ở trên, thay vì tập trung sáng tạo ra một Token Mixer mới, thì ta sẽ sử dụng những Token Mixer đơn giản, và cải thiện những thứ còn lại của một MetaFormer Block.
StarReLU
Trong paper Transformer, ReLU được chọn làm activation function:
Activation function này có độ nặng tính toán là 1 FLOP. Sau đó, GELU được sử dụng làm activation function chính cho các model họ Transformer:
Phép tính tốn mất 14 FLOPs, lớn hơn gấp 14 lần ReLU. Có một paper đã tìm ra phép thay thế gần đúng của GELU, gọi là SquaredReLU như sau:
SquaredReLU chỉ tốn có 2 FLOPs, tuy nhiên, hiệu năng của SquaredReLU vẫn không thể sánh ngang với GELU trên bài toán phân loại ảnh. Nhóm tác giả của MetaFormer cho rằng, việc hiệu năng tụt giảm có thể là do distribution shift (sự dịch chuyển phân phối) trên output của phép tính. Giả dụ tuân theo phân phối chuẩn với mean 0 và variance 1, ~, ta có:
với và lần lượt là expectation và variance. Do đó, nhóm tác giả tạo ra StarReLU để có thể giải quyết distribution shift như sau:
Tuy nhiên, đấy là giả định khi sử dụng input là normal distribution. Để StarReLU phù hợp với nhiều distribution hơn thì, thay vì sử dụng 2 hệ số cố định và , ta sẽ để cho nó tự học. StarReLU bản mở rộng như sau:
với và là learnable params (như BN). StarReLU lúc này chỉ tốn có 4 FLOPs
Các thay đổi khác
Output của một phần trong một khối MetaFormer được tính như sau:
với là input, là output, là lớp Norm và là Token Mixer hoặc channel MLP Scaling branch output. Giống như implicit knowledge dạng nhân, ta có 3 cách thêm scaling branch như sau:
- Residual scale: thêm hệ số scale vào phần Residual:
- Layer Scale: thêm hệ số scale vào phần tính toán
- Branch Scale: Kết hợp cả Layer Scale và Residual Scale
Hiệu năng của Residual Scale, theo những thử nghiệm của tác giả cho bài toán phân loại ảnh, đang là tốt nhất. Tuy nhiên, mọi người có thể thử các cách scale khác cho phù hợp với bài toán
IdentityFormer và RandFormer
Để chứng minh hiệu năng của việc cải thiện MetaFormer thay vì Token Mixer, tác giả của MetaFormer đã tạo ra IdentityFormer và RandFormer
IdentityFormer. Loại bỏ hoàn toàn Token Mixer ra khỏi model, hay nói cách khác là sử dụng Identity Mapping làm Token Mixer, ta có:
hay:
RandFormer. Sử dụng một ma trận ngẫu nhiên (Random Matrix) làm Token Mixer, và Random Matrix này sẽ không được cập nhật trong quá trình backprop. Tức là ma trận được khởi tạo ngẫu nhiên ra sao thì giữ nguyên như thế đến cuối
có là số token (), là số channels;
ConvFormer và CAFormer
Với việc đã chứng minh được cải tiến MetaFormer thay vì Token Mixer có thể đem lại hiệu năng không tưởng, ta sẽ sử dụng Token Mixer đơn giản và có sẵn để tạo ra một model mạnh thay vì phải đau đầu suy nghĩ ra một Token Mixer mới và phức tạp ConvFormer. Sử dụng Depthwise Separable Convolution trong MobileNetV2 làm TokenMixer:
với là DWConv (hay còn gọi là Point wise Conv), là DWConv, trong paper , và là một hàm phi tuyến tính.
CAFormer. Sử dụng nửa đầu giống như ConvFormer, còn nửa sau của model sử dụng Self-Attention.
Các bạn có thể thử suy nghĩ tại sao nửa sau của CAFormer lại sử dụng Self-Attention mà không sử dụng từ đầu nhé
TL;DR
[W]hat
- Một hướng đi mới trong việc nghiên cứu model cho bài toán phân loại ảnh, có thể mở rộng thành các backbone chung chung
[W]hy
- Mình cũng không có đánh giá gì về phần này vì nó không phải cải tiến những cái gì chưa tốt từ những model cũ mà mở ra một hướng nghiên cứu mới khá là thú vị
Ho[W]
- Tập trung vào cải tiến khối tổng thể (MetaFormer) thay vì chỉ tập trung vào cải tiến một module (Token Mixer)
Reference
PoolFormer: https://arxiv.org/abs/2111.11418
MetaFormer: https://arxiv.org/abs/2210.13452
All Rights Reserved