0

Paper reading | Expanding Language-Image Pretrained Models for General Video Recognition

Giới thiệu chung

Video recognition là một lĩnh vực trong trí tuệ nhân tạo (AI) và thị giác máy tính tập trung vào việc phân tích và nhận dạng nội dung trong các video. Mục tiêu của video recognition là hiểu về các hình ảnh, đối tượng, hành động và sự tương tác trong video giống như cách con người làm. Công nghệ video recognition đã phát triển đáng kể nhờ sự tiến bộ trong AI, xử lý ảnh và tài nguyên tính toán.

Có nhiều ứng dụng quan trọng của video recognition trong thế giới thực, có thể kể đến một số ví dụ điển hình sau:

  • Giám sát an ninh: Video recognition được sử dụng trong hệ thống giám sát an ninh để phát hiện và nhận dạng các hoạt động đáng ngờ, như xâm nhập, vật thể nghi vấn và hành vi không phù hợp. Việc này giúp cải thiện đáng kể khả năng phát hiện và giám sát trong các khu vực như ngân hàng, sân bay, trung tâm mua sắm và các cơ sở quân sự.

  • Xử lý video tự động: Video recognition được sử dụng để tự động phân loại và gắn nhãn các video dựa trên nội dung chúng. Điều này giúp tạo ra các công cụ tìm kiếm video thông minh và hệ thống gợi ý video, đồng thời cải thiện trải nghiệm người dùng và khả năng quản lý nội dung trên các nền tảng video trực tuyến.

  • Xe tự hành: Video recognition cũng đóng vai trò quan trọng trong xe tự hành. Các hệ thống xe tự hành sử dụng video recognition để phát hiện và nhận dạng các vật thể xung quanh, như người đi bộ, xe đạp, ô tô và biển báo giao thông. Điều này giúp xe tự hành đưa ra quyết định an toàn và tương tác thông minh với môi trường xung quanh.

  • Quảng cáo và truyền thông: Video recognition cung cấp khả năng phân tích nội dung video. Các công ty quảng cáo có thể sử dụng thông tin này để tạo ra các chiến dịch quảng cáo được cá nhân hóa hơn và đưa ra đề xuất sản phẩm phù hợp dựa trên sở thích và hành vi xem video của khách hàng.

Đó chỉ là một số ứng dụng phổ biến của video recognition, lĩnh vực này đang phát triển và mở ra nhiều cơ hội mới trong nhiều ngành công nghiệp khác nhau.

Với nhiều ứng dụng thực tiễn, nhiều nghiên cứu đã được thực hiện để cung cấp giải pháp cho bài toán này.

Đóng góp của bài báo

Đóng góp của bài báo:

  • Thiết kế một kiến trúc mô hình mới cho việc mô hình video temporal.
  • Xây dựng kĩ thuật video-specific prompting để trả về biểu diễn văn bản ở mức instance-level một cách tự động. Kĩ thuật này sử dụng thông tin nội dung video để nâng cao chất lượng tạo prompt.
  • Đề xuất một cách mới để mở rộng các mô hình language-image pretrained cho bài toán video recognition và các task về video khác.

Phương pháp

Tổng quan

Các phương pháp trước đây giải quyết bài toán Video recognition theo hướng là học feature embedding riêng biệt được supervise theo các label one-hot. Nhược điểm của cách này là khả năng dự đoán bị đóng khung theo các label có sẵn, do đó sẽ rất khó để train những label khác mà không có trong tập label hiện tại. Chính vì vậy, giống như các mô hình contrastive language-image pretraining, nhóm tác giả sử dụng text là supervision vì text cung cấp nhiều ngữ nghĩa thông tin hơn.

image.png

Phương pháp đề xuất trong bài báo là học cách căn chỉnh biểu diễn video và biểu diễn text tương ứng bằng cách train cả video encoder và text encoder. Thay vì tốn tài nguyên và thời gian để train lại từ đầu cũng như tận dụng được sức mạnh của các model pretraining trước đó, phương pháp tận dùng các model pretraining này và mở rộng với dạng video và các textual prompt.

Cụ thể, cho một video clip VVV \in \mathcal{V} và text description tương ứng là CCC \in \mathcal{C} trong đó V\mathcal{V} là tập các video và C\mathcal{C} là tập tên các category. Đầu tiên ta sẽ truyền video VV vào video encoder fθvf_{\theta_v} và text CC vào text encoder fθcf_{\theta_c} để nhận biểu diễn video v\mathbf{v} và biểu diễn text c\mathbf{c} tương ứng, trong đó

image.png

Sau đó, ta sử dụng prompt generator fθpf_{\theta_p} để trả về instance-level biểu diễn text cho mỗi video, cụ thể như sau:

image.png

Cuối cùng, ta sử dụng consine similarity để tính độ tương đồng giữa biểu diện hình ảnh và text.

image.png

Mục tiêu của phương pháp này là tối đa hóa sim(v,c^)\operatorname{sim}(\mathbf{v}, \hat{\mathbf{c}}) nếu như VVCC khớp nhau, ngược lại tất nhiên là tối thiểu hóa 😄

Video Encoder

Video encoder bao gồm 2 thành phần:

  • Cross-frame communication transformer có nhiệm vụ nhận các frame làm input, thông qua pretrained language-image model, output là các biểu diễn frame-level có chứa thông tin trao đổi giữa các frame.
  • Multi-frame integration transformer có nhiệm vụ tích hợp các biểu diễn frame-level với các video feature.

Cụ thể, cho một video clip VRT×H×W×3V \in \mathbb{R}^{T \times H \times W \times 3} trong đó TT là số frame được lấy mẫu, HHWW là chiều cao và chiều rộng của frame, theo model ViT ta sẽ chia frame thành NN patch {xt,i}i=1NRP2×3\left\{\mathbf{x}_{t, i}\right\}_{i=1}^N \in \mathbb{R}^{P^2 \times 3} không chồng chéo nhau, mỗi patch sẽ có kích thước là P×PP \times P pixel và N=HW/P2N = HW/P^2. Sau đó, ý tưởng như ViT 😄 ta sẽ nhúng các patch vào patch embedding sử dụng linear projection ER3P2×D\mathbf{E} \in \mathbb{R}^{3 P^2 \times D}. Tiếp theo, ta sẽ thêm một learnable embedding là Xclass \mathbf{X}_{\text {class }} (hay class token) vào chuỗi các patch được embedding. Vậy ta có đầu vào của cross-frame communication transformer tại frame tt được biểu diễn như sau:

image.png

trong đó espa\mathbf{e}^{s p a} là spatial position encoding.

Tiếp theo ta sẽ truyền các patch embedding trên vào một Lc-layer Cross-frame Communication Transformer (CCT) để nhận biểu diễn frame-level ht\mathbf{h}_t:

image.png

trong đó ll là block index của CCT, zt,0(Lc)\mathbf{z}_{t, 0}^{\left(L_c\right)} biểu diễn final output của class token.

Cuối cùng, LmL_m - layer Multi-frame Integration Transformer (MIT) nhận tất cả các biểu diễn frame H=[h1,h2,,hT]\mathbf{H}=\left[\mathbf{h}_1, \mathbf{h}_2, \cdots, \mathbf{h}_T\right] làm input và output là video-level representation v\mathbf{v} được biểu diễn như sau:

image.png

trong đó AvgPool và etemp\mathbf{e}^{temp} lần lượt là average pooling và temporal position encoding. Multi-frame integration transformer được xây dựng bởi multi-head self-attention và feed-forward networks tiêu chuẩn 😄

image.png

Để có thể có được thông tin trao đổi giữa các frame với nhau, nhóm tác giả đề xuất một module attention mới. Thành phần của module này gồm 2 loại attention là cross-frame fusion attention (CFA) và intra-frame diffusion attention (IFA), với một feed-forward network (FFN). Nhóm cũng giới thiệu cơ chế message token cho mỗi frame có vai trò trừu tượng, gửi và nhận thông tin, do đó có thể trao đổi thông tin visual giữa các frame như hình trên.

Cụ thể, message token mt(l)\mathbf{m}_t^{(l)} cho frame thứ tt tại layer thứ ll được tạo bằng cách sử dụng một linear transformation trên class token zt,0(l1)\mathbf{z}_{t, 0}^{(l-1)}. Điều này cho phép các message token có thể trừu tượng thông tin visual của frame hiện tại.

Sau đó, ta sẽ tổng hợp các message token để học các phụ thuộc toàn cục spatio-temporal của video đầu vào. Cụ thể, quá trình tại block thứ ll như sau:

image.png

trong đó, M^(l)=[m^1(l),m^2(l),,m^T(l)]\hat{\mathbf{M}}^{(l)}=\left[\hat{\mathbf{m}}_1^{(l)}, \hat{\mathbf{m}}_2^{(l)}, \cdots, \hat{\mathbf{m}}_T^{(l)}\right] và LN là layer normalization.

Sau đó, IFA nhận các frame token với message token liên kết (xem hình trên) để học biểu diễn visual, trong đó message token liên quan cũng có thể "khuếch tán" phụ thuộc spatio-temporal toàn cục cho quá trình học. Quá trình tại block thứ ll được biểu diễn như sau:

image.png

trong đó [,][\cdot, \cdot] concat các feature của frame token và message token.

Cuối cùng, ta cho các frame token qua feed-forward network (FFN) như sau

image.png

Chú ý rằng, message token được bỏ qua trước FFN layer và không được truyền vào block sau, lý do là message token được tạo liên tục và được sử dụng cho frame communication trong mỗi block.

Bằng cách thực hiện đan xen việc kết hợp và phân tán các attention qua các LcL_c block, CCT có thể encode thông tin spatial và temporal toàn cục của các video frames. Mặt khác, điều này cũng giảm đáng kể chi phí tính toán (xem hình dưới).

image.png

Về việc khởi tạo, thay vì train từ đầu, mô hình tận dụng các pretrained image encoder vào video encoder và có 2 chỉnh sửa chính:

  • IFA kết thừa trọng số trực tiếp từ các pretrained model, trong khi CFA được khởi tạo ngẫu nhiên.
  • MIT được khởi tạo ngẫu nhiên.

Text Encoder

Nhóm tác giả sử dụng pretrained text encoder và mở rộng cho việc xây dựng nội dung mô tả cho video. Gọi CC là mô tả của một video và biểu diễn text c\mathbf{c} (c=fθc(C)\mathbf{c}=f_{\theta_c}(C)) tạo bởi text encoder. Nhóm tác giả chỉ sử dụng tên nhãn cơ bản, ngắn gọn làm text description CC😄 và đề xuất một text prompting scheme có thể học được.

Để hiểu ảnh hoặc video, ta thường cần một ngữ cảnh để hỗ trợ phân biệt. Ví dụ như ngữ cảnh "in the water" sẽ giúp ta dễ dàng phân biệt "swimming" và "running". Tuy nhiên, rất khó để có được ngữ nghĩa trực quan như vậy trong các tác vụ nhận dạng video, lý do là dataset chỉ cung cấp tên các category cố định và video có cùng class sẽ có cùng category nhưng visual context và content có thể khác nhau. Để giải quyết vấnd dề này, nhóm tác giả đề xuất một learnable prompting scheme để sinh biểu diễn text tự động. Cụ thể như sau:

image.png

trong đó c\mathbf{c} là text embedding, MHSA là multi-head self-attention, ZRN×d\overline{\mathbf{Z}} \in \mathbb{R}^{N \times d} là trung bình của {zt(Lc)}t=1T\left\{\mathbf{z}_t^{\left(L_c\right)}\right\}_{t=1}^Tc~\tilde{\mathbf{c}} là prompt của video. Nhóm tác giả sử dụng biểu diễn text c\mathbf{c} là query và biểu diễn nội dung video z~\tilde{\mathbf{z}} là key và value. Cách cài đặt này giúp cho biểu diễn text có thể trích xuất thông tin visual context từ video.

Sau đó, nhóm tác giả cài đặt c^=c+αc~\hat{\mathbf{c}}=\mathbf{c}+\boldsymbol{\alpha} \tilde{\mathbf{c}}, trong đó α\alpha là learnable parameter được khởi tạo giá trị là 0.1. Giá trị c^\hat{\mathbf{c}} cuối cùng được sử dụng cho việc phân loại.

Coding

Khối CCT được xây dựng như sau:

from collections import OrderedDict
from timm.models.layers import trunc_normal_
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential
import sys
sys.path.append("../")
from clip.model import LayerNorm, QuickGELU, DropPath


class CrossFramelAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, droppath = 0., T=0, ):
        super().__init__()
        self.T = T

        self.message_fc = nn.Linear(d_model, d_model)
        self.message_ln = LayerNorm(d_model)
        self.message_attn = nn.MultiheadAttention(d_model, n_head,)
           
        self.attn = nn.MultiheadAttention(d_model, n_head,)
        self.ln_1 = LayerNorm(d_model)
        
        self.drop_path = DropPath(droppath) if droppath > 0. else nn.Identity()
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]


    def forward(self, x):
        l, bt, d = x.size()
        b = bt // self.T
        x = x.view(l, b, self.T, d) 

        msg_token = self.message_fc(x[0,:,:,:]) 
        msg_token = msg_token.view(b, self.T, 1, d) 
        
        msg_token = msg_token.permute(1,2,0,3).view(self.T, b, d) 
        msg_token = msg_token + self.drop_path(self.message_attn(self.message_ln(msg_token),self.message_ln(msg_token),self.message_ln(msg_token),need_weights=False)[0])
        msg_token = msg_token.view(self.T, 1, b, d).permute(1,2,0,3)
        
        x = torch.cat([x, msg_token], dim=0)
        
        x = x.view(l+1, -1, d)
        x = x + self.drop_path(self.attention(self.ln_1(x)))
        x = x[:l,:,:]
        x = x + self.drop_path(self.mlp(self.ln_2(x)))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, droppath=None, use_checkpoint=False, T=8):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        if droppath is None:
            droppath = [0.0 for i in range(layers)] 
        self.width = width
        self.layers = layers
        
        self.resblocks = nn.Sequential(*[CrossFramelAttentionBlock(width, heads, attn_mask, droppath[i], T) for i in range(layers)])
       
    def forward(self, x: torch.Tensor):
        if not self.use_checkpoint:
            return self.resblocks(x)
        else:
            return checkpoint_sequential(self.resblocks, 3, x)


class CrossFrameCommunicationTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int,
                 droppath = None, T = 8, use_checkpoint = False,):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        ## Attention Blocks
        self.transformer = Transformer(width, layers, heads, droppath=droppath, use_checkpoint=use_checkpoint, T=T,)
        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))


    def init_weights(self):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)
        x = self.transformer(x)
        x = x.permute(1, 0, 2)

        cls_x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            cls_x = cls_x @ self.proj
        
        return cls_x, x[:,1:,:]

Khối MIT được xây dựng như sau:

import torch
from torch import nn
from collections import OrderedDict
from timm.models.layers import trunc_normal_
import sys
sys.path.append("../")
from clip.model import QuickGELU


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = nn.LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class MultiframeIntegrationTransformer(nn.Module):
    def __init__(self, T, embed_dim=512, layers=1,):
        super().__init__()
        self.T = T
        transformer_heads = embed_dim // 64
        self.positional_embedding = nn.Parameter(torch.empty(1, T, embed_dim))
        trunc_normal_(self.positional_embedding, std=0.02)
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(d_model=embed_dim, n_head=transformer_heads) for _ in range(layers)])

        self.apply(self._init_weights)
    
    def _init_weights(self, m):
        if isinstance(m, (nn.Linear,)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)

    def forward(self, x):
        ori_x = x
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)
        x = self.resblocks(x)
        x = x.permute(1, 0, 2)  
        x = x.type(ori_x.dtype) + ori_x
        
        return x.mean(dim=1, keepdim=False)

Tiếp theo, ta có module text encoder

from timm.models.layers import trunc_normal_
import torch
from torch import nn
import sys
sys.path.append("../")
from clip.model import QuickGELU


class MulitHeadAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.scale = qk_scale or head_dim ** -0.5

        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)


        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, k, v):
        B, N, C = q.shape
        B, M, C = k.shape
        q = self.q_proj(q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0,2,1,3)
        k = self.k_proj(k).reshape(B, M, self.num_heads, C // self.num_heads).permute(0,2,1,3)
        v = self.v_proj(v).reshape(B, M, self.num_heads, C // self.num_heads).permute(0,2,1,3)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class PromptGeneratorLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dropout=0.,
    ):
        super().__init__()
        self.cross_attn = MulitHeadAttention(d_model, nhead, proj_drop=dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            QuickGELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model)
        )

    def forward(self, x, visual):
        q = k = v = self.norm1(x)
        x = x + self.cross_attn(q, visual, visual)
        x = x + self.dropout(self.mlp(self.norm3(x)))
        return x


class VideoSpecificPrompt(nn.Module):
    def __init__(self, layers=2, embed_dim=512, alpha=0.1,):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.decoder = nn.ModuleList([PromptGeneratorLayer(embed_dim, embed_dim//64) for _ in range(layers)])
        self.alpha = nn.Parameter(torch.ones(embed_dim) * alpha)
        self.apply(self._init_weights)


    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    
    def forward(self, text, visual):
        B, N, C = visual.shape
        visual = self.norm(visual)
        for layer in self.decoder:
            text = layer(text, visual)
        
        return self.alpha * text

Tổng hợp lại, ta có model hoàn chỉnh:

from typing import Tuple, Union
import torch
from torch import nn
import numpy as np
from .mit import MultiframeIntegrationTransformer
from .prompt import VideoSpecificPrompt
from .cct import CrossFrameCommunicationTransformer
import sys
import warnings
sys.path.append("../")
from clip.model import CLIP,LayerNorm,Transformer
import clip

class XCLIP(CLIP):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int, 
                 # video
                 T=8, 
                 droppath=0.,
                 mit_layers=1,
                 # prompt 
                 prompts_alpha=1e-4,
                 prompts_layers=1,
                 # other
                 use_cache=True,
                 use_checkpoint=False,
                 ):
        super().__init__(
            embed_dim,
            image_resolution, vision_layers, vision_width, vision_patch_size,
            context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
        )
        
        self.prompts_generator = VideoSpecificPrompt(layers=prompts_layers, embed_dim=embed_dim, alpha=prompts_alpha,)
        self.use_cache=use_cache
        self.mit = MultiframeIntegrationTransformer(T=T, embed_dim=embed_dim, layers=mit_layers,)

        dpr = [x.item() for x in torch.linspace(0, droppath, vision_layers)] if droppath > 0. else None

        vision_heads = vision_width // 64
        self.visual = CrossFrameCommunicationTransformer(
            input_resolution=image_resolution,
            patch_size=vision_patch_size,
            width=vision_width,
            layers=vision_layers,
            heads=vision_heads,
            output_dim=embed_dim,
            droppath=dpr,
            T=T,
            use_checkpoint=use_checkpoint,
        )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )
        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.cache_text_features = None
        self.prompts_visual_ln = LayerNorm(vision_width)
        self.prompts_visual_proj = nn.Parameter(torch.randn(vision_width, embed_dim))
        
        self.initialize_parameters()
    
    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'positional_embedding'}

    def encode_image(self, image):
        return self.visual(image)

    def encode_text(self, text):
        x = self.token_embedding(text)
        eos_indx = text.argmax(dim=-1)
        K, N1, C = x.shape

        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), eos_indx] @ self.text_projection
        x = x.reshape(K, -1)
        return x

    def encode_video(self, image):
        b,t,c,h,w = image.size()
        image = image.reshape(-1,c,h,w)

        cls_features, img_features = self.encode_image(image)
        img_features = self.prompts_visual_ln(img_features)
        img_features = img_features @ self.prompts_visual_proj
        
        cls_features = cls_features.view(b, t, -1)
        img_features = img_features.view(b,t,-1,cls_features.shape[-1])
        
        video_features = self.mit(cls_features)

        return video_features, img_features

    def cache_text(self, text):
        self.eval()
        with torch.no_grad():
            if self.cache_text_features is None:
                self.cache_text_features = self.encode_text(text)
        self.train()
        return self.cache_text_features

    def forward(self, image, text):
        b = image.shape[0]
        video_features, img_features = self.encode_video(image) 
        img_features = img_features.mean(dim=1, keepdim=False)

        if self.use_cache:
            text_features = self.cache_text(text)
        else:
            text_features = self.encode_text(text)
        
        text_features = text_features.unsqueeze(0).expand(b, -1, -1)
        text_features = text_features + self.prompts_generator(text_features, img_features)
           
        video_features = video_features / video_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        logit_scale = self.logit_scale.exp()
        logits = torch.einsum("bd,bkd->bk", video_features, logit_scale * text_features)
        
        return logits


def build_model(state_dict: dict, T=8, droppath=0., use_checkpoint=False, logger=None, prompts_alpha=1e-1, prompts_layers=2, use_cache=True, mit_layers=4,):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
    
    model = XCLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,  
        T=T, droppath=droppath, mit_layers=mit_layers,
        prompts_alpha=prompts_alpha, prompts_layers=prompts_layers,
        use_checkpoint=use_checkpoint, use_cache=use_cache,
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    msg = model.load_state_dict(state_dict,strict=False)
    logger.info(f"load pretrained CLIP: {msg}")
    
    return model.eval()


def load(model_path, name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 
         jit=True, T=8, droppath=0., use_checkpoint=False, logger=None, use_cache=True, prompts_alpha=1e-1, prompts_layers=2, mit_layers=1,
):
    if model_path is None:
        model_path = clip._download(clip._MODELS[name])
    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    model = build_model(state_dict or model.state_dict(), T=T, droppath=droppath, 
                        use_checkpoint=use_checkpoint, logger=logger,
                        prompts_alpha=prompts_alpha, 
                        prompts_layers=prompts_layers,
                        use_cache=use_cache,
                        mit_layers=mit_layers,
                        )
    if str(device) == "cpu":
        model.float()
    return model, model.state_dict()

Thực nghiệm

Bảng so sánh kết quả với các SOTA trên bộ data Kinetics-600.

image.png

Kết quả khi thực hiện zero shot trên tập HMDB51, UCF101 và Kinetic.

image.png

image.png

Tham khảo

[1] Expanding Language-Image Pretrained Models for General Video Recognition

[2] https://github.com/microsoft/VideoX/tree/master/X-CLIP


All rights reserved

Viblo
Hãy đăng ký một tài khoản Viblo để nhận được nhiều bài viết thú vị hơn.
Đăng kí