+1

[New Idea] - Data Free Model Pruning with Genetic Algorithm

Phương pháp pruning model hiện tại

Các phương pháp model pruning hiện tại được sinh ra với mục đích giảm kích thước của model bằng cách loại bỏ đi các trọng số không quan trọng trong mạng:

Chiến lược để pruning trong các phương pháp hiện tại như sau:

  • Sử dụng một binary bitmask cho mỗi layer (chứa giá trị 0 nếu các weight tương ứng không được sử dụng trong quá trình forward)

  • Huấn luyện một model lớn với độ chính xác đủ tốt trên orignal dataset
  • Chọn ra các weight không quan trọng bằng một thuật toán cụ thể (ví dụ đơn giản có thể chọn theo threshold hoặc theo sparsity của mạng)
  • Finetune lại mạng mới (với chỉ các weight được giữ lại) trên original dataset

Tham khảo: To prune, or not to prune: exploring the efficacy of pruning for model compression

Vấn đề phát sinh trong thực tế

  • Đôi khi chúng ta muốn pruning một pretrained model nhưng không thể truy cập vào dataset gốc
  • Dataset gốc quá lớn khiến cho việc finetuning khó khăn (ví dụ như Imagenet chẳng hạn)
  • Vấn đề đặt ra là Liệu có thể thực hiện pruning trực tiếp từ pre-trained model mà không cần đến dataset gốc hay không?

Ý tưởng ban đầu

  • Thay vì tự định nghĩa sub-mask các weight không quan trọng thì sẽ tự học các sub-mask đó
  • Thay vì fine-tuning lại các weight thì không fine-tuning nữa mà giữ nguyên các weight của pre-trained model thay vào đó sẽ đi tìm các kết nối giữa các weight (hay sub-network)
  • Không sử dụng tập dữ liệu ban đầu nên có thể không cần tiếp cận theo gradient-based mà sử dụng các thuật toán evolution để tìm kiếm sub-network

Thực hiện thuật toán

Thuật toán sẽ có các bước chính như sau:

  • Xuất phát từ một pre-trained
  • Xây dựng một model có kiến trúc giống pre-trained model nhưng có thêm một lớp scores mask cho mỗi layer
  • Lớp scores mask sẽ lưu điểm của mỗi weight cho từng layer. Việc lựa chọn các weight sẽ được sử dụng trong scores mask đơn giản là lựa chọn top k% các score trên từng layer
  • Sử dụng GA để khởi tạo một population ban đầu gồm nhiều score khác nhau.
  • Tiến hành chạy GA để learning ra các scores mask

Code thử

Xây dựng pre-trained model

Thử với tập MNIST có accuracy 98% làm pre-trained model. Xây dựng một mô hình CNN đơn giản

  • Import modules cần thiết
import os
import math

import torch
import numpy as np 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.autograd as autograd
  • Định nghĩa device
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
  • Định nghĩa kiến trúc mạng
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
        self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128, bias=False)
        self.fc2 = nn.Linear(128, 10, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        
        return output
  • Load pre-trained model
model = Net().to(device)
model.load_state_dict(torch.load('mnist_cnn.pt'))
  • Lưu các module_list và weight tương ứng để lát nữa copy sang mạng mới
module_list = [module for module in model.modules() if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear)]

module_shape = [m.weight.shape for m in module_list]

original_weights = [m.weight for m in module_list]

Xây dựng mạng mới với sub-mask

Xây dựng phần GetSubmassk tính toán trên top k%.

class GetSubnet(autograd.Function):
    @staticmethod
    def forward(ctx, scores, k):
        # Get the supermask by sorting the scores and using the top k%
        out = scores.clone()
        _, idx = scores.flatten().sort()
        j = int((1 - k) * scores.numel())

        # flat_out and out access the same memory.
        flat_out = out.flatten()
        flat_out[idx[:j]] = 0
        flat_out[idx[j:]] = 1

        return out

Xây dựng supermask cho layer convolution

class SupermaskConv(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # initialize the scores
        self.scores = nn.Parameter(torch.Tensor(self.weight.size()))
        nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))

        # NOTE: initialize the weights like this.
        nn.init.kaiming_normal_(self.weight, mode="fan_in", nonlinearity="relu")

        # NOTE: turn the gradient on the weights off
        self.weight.requires_grad = False
        self.scores.requires_grad = False

    def forward(self, x):
        subnet = GetSubnet.apply(self.scores.abs(), sparsity)
        w = self.weight * subnet
        x = F.conv2d(
            x, w, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        
        return x

Xây dựng supermask cho layer Linear

class SupermaskLinear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # initialize the scores
        self.scores = nn.Parameter(torch.Tensor(self.weight.size()))
        nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))

        # NOTE: initialize the weights like this.
        nn.init.kaiming_normal_(self.weight, mode="fan_in", nonlinearity="relu")

        # NOTE: turn the gradient on the weights off
        self.weight.requires_grad = False
        self.scores.requires_grad = False

    def forward(self, x):
        subnet = GetSubnet.apply(self.scores.abs(), sparsity)
        w = self.weight * subnet
        
        return F.linear(x, w, self.bias)

Định nghĩa mạng mới tương đương với mạng pre-trained nhưng với các layer super mask

class GANet(nn.Module):
    def __init__(self):
        super(GANet, self).__init__()
        self.conv1 = SupermaskConv(1, 32, 3, 1, bias=False)
        self.conv2 = SupermaskConv(32, 64, 3, 1, bias=False)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = SupermaskLinear(9216, 128, bias=False)
        self.fc2 = SupermaskLinear(128, 10, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        
        return output

Tiến hành copy weight của pre-trained model vào model mới

ga_model = GANet().to(device)
ga_module_list = [module for module in ga_model.modules() if isinstance(module, SupermaskConv) or isinstance(module, SupermaskLinear)]

# Update weight of ga model with original trained model 

for i, weight in enumerate(original_weights):
    ga_module_list[i].weight = weight

Xây dựng thuật toán GA

Build Agent

class Agent:
    def __init__(self, params):
        self.params = params
        self.fitness = 0
        
    def set_fitness(self, fitness):
        self.fitness = fitness

Khởi tạo quần thể

# Init population

def init_pop(pop_size=100):
    population = []
    for _ in range(pop_size):
        params = []
        for shape in module_shape:
            scores = nn.Parameter(torch.Tensor(shape))
            nn.init.kaiming_uniform_(scores, a=math.sqrt(5))
            params.append(scores)
        agent = Agent(params=params)
        population.append(agent)
    return population

Update scores từng agent cho ga_model

def change_scores(module_list, agent):
    for i, m_scores in enumerate(agent.params):
        module_list[i].scores = m_scores

Đột biến

def mutation(agent, mut_rate=0.1):
    params = []
    for param in agent.params:
        out = param.clone()
        # flat_out and out share the same memory
        flat_out = out.flatten().to(device)
        # Get index mutation 
        indexes = np.where(np.random.uniform(low=0, high=1, size=(len(flat_out))) < mut_rate)
        replace_values = np.random.uniform(low=-1, high=1, size=(len(flat_out)))[indexes]
        # Mutation
        flat_out.index_copy_(0, torch.LongTensor(indexes[0]).to(device), torch.FloatTensor(replace_values).to(device))
        params.append(nn.Parameter(out))
    return Agent(params=params)

Lai ghép - tái tổ hợp

def recombine_agent(agent_1, agent_2):
    params_1 = []
    params_2 = []
    for i, param in enumerate(agent_1.params):
        param_1 = param.clone()
        param_2 = agent_2.params[i].clone()
        # Flatten 
        flat_1 = param_1.flatten().to(device)
        flat_2 = param_2.flatten().to(device)
        # Define children
        child_1 = torch.zeros(len(flat_1))
        child_2 = torch.zeros(len(flat_1))
        # Select cross point
        cross_pt = random.randint(0, len(flat_1))
        # Swap
        child_1[cross_pt:len(flat_1)] = flat_1[cross_pt:len(flat_1)]
        child_1[0:cross_pt] = flat_2[0:cross_pt]
        child_2[cross_pt:len(flat_1)] = flat_2[cross_pt:len(flat_1)]
        child_2[0:cross_pt] = flat_1[0:cross_pt]
        # Append to params 
        params_1.append(nn.Parameter(child_1.reshape(module_shape[i])))
        params_2.append(nn.Parameter(child_2.reshape(module_shape[i])))

    return Agent(params_1), Agent(params_2)

Đánh giá quần thể

from tqdm import tqdm 

def evaluate_population(pop):
    avg_fit = 0
    
    for agent in tqdm(pop):
        change_scores(ga_module_list, agent)
        fit = test(ga_model.to(device), device)
        agent.fitness = fit
        avg_fit += fit
    avg_fit /= len(pop)
    
    return pop, avg_fit

Next generation

def next_generation(pop, size=100, mut_rate=0.01):
    new_pop = []
    while len(new_pop) < size:
        parents = random.choices(pop, k=2, weights=[x.fitness for x in pop])
        offspring_ = recombine_agent(parents[0],parents[1])
        offspring = [mutation(offspring_[0], mut_rate=mut_rate), mutation(offspring_[1], mut_rate=mut_rate)]
        new_pop.extend(offspring) #add offspring to next generation
    return new_pop

Định nghĩa data test

Do không sử dụng tập train nên chỉ cần load data test để đánh giá

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        os.path.join("./data", "mnist"),
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=10000,
    shuffle=True
)

Xây dựng hàm test

def test(model, device):
    model.eval()
    
    with torch.no_grad():
        output = model(test_data)
        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct = pred.eq(test_target.view_as(pred)).sum().item()
        out = correct / len(test_data)
    return out

Hyperparams

batch_size = 1024
momentum = 0.9
wd = 0.0005
lr = 0.01
epochs = 20
sparsity = 0.3
log_interval = 1000
seed = 1507

Huấn luyện

num_generations = 10000
population_size = 100

pop = init_pop(population_size)

mutation_rate = 0.1 # 0.1% mutation rate

pop_fit = []

pop = init_pop(population_size) #initial population

for gen in range(num_generations):
    # trainning
    pop, avg_fit = evaluate_population(pop)
    print('Generation {} with pop_fit {}'.format(gen, avg_fit))
    pop_fit.append(avg_fit) #record population average fitness
    new_pop = next_generation(pop, size=population_size, mut_rate=mutation_rate)
    pop = new_pop

Một vài kết quả

100%|██████████| 100/100 [00:05<00:00, 18.65it/s]
Generation 0 with pop_fit 0.49528699999999987
100%|██████████| 100/100 [00:05<00:00, 19.14it/s]
Generation 1 with pop_fit 0.5210469999999999
100%|██████████| 100/100 [00:05<00:00, 19.04it/s]
Generation 2 with pop_fit 0.554441
100%|██████████| 100/100 [00:05<00:00, 19.19it/s]
Generation 3 with pop_fit 0.550083
100%|██████████| 100/100 [00:05<00:00, 19.06it/s]
Generation 4 with pop_fit 0.5807529999999997
100%|██████████| 100/100 [00:05<00:00, 18.60it/s]
Generation 5 with pop_fit 0.5960270000000002
100%|██████████| 100/100 [00:05<00:00, 19.20it/s]
Generation 6 with pop_fit 0.6354899999999999
100%|██████████| 100/100 [00:05<00:00, 18.87it/s]
Generation 7 with pop_fit 0.6630680000000001
100%|██████████| 100/100 [00:05<00:00, 18.96it/s]
Generation 8 with pop_fit 0.6598939999999999
100%|██████████| 100/100 [00:05<00:00, 19.11it/s]
Generation 9 with pop_fit 0.6729839999999999
100%|██████████| 100/100 [00:05<00:00, 18.28it/s]
Generation 10 with pop_fit 0.6746130000000001
100%|██████████| 100/100 [00:05<00:00, 19.19it/s]
Generation 11 with pop_fit 0.6668120000000001
100%|██████████| 100/100 [00:05<00:00, 19.05it/s]
Generation 12 with pop_fit 0.6691490000000002
100%|██████████| 100/100 [00:05<00:00, 18.85it/s]
Generation 13 with pop_fit 0.6474719999999999
100%|██████████| 100/100 [00:05<00:00, 19.12it/s]
Generation 14 with pop_fit 0.6627349999999999
100%|██████████| 100/100 [00:05<00:00, 18.79it/s]
Generation 15 with pop_fit 0.6822900000000003
100%|██████████| 100/100 [00:05<00:00, 18.74it/s]
Generation 16 with pop_fit 0.6807899999999999
100%|██████████| 100/100 [00:05<00:00, 19.00it/s]
Generation 17 with pop_fit 0.7026549999999999
100%|██████████| 100/100 [00:05<00:00, 18.45it/s]
Generation 18 with pop_fit 0.6868509999999998
100%|██████████| 100/100 [00:05<00:00, 19.11it/s]
Generation 19 with pop_fit 0.7088399999999999
100%|██████████| 100/100 [00:05<00:00, 18.90it/s]
Generation 20 with pop_fit 0.7220369999999999
100%|██████████| 100/100 [00:05<00:00, 19.04it/s]
Generation 21 with pop_fit 0.713581
100%|██████████| 100/100 [00:05<00:00, 19.00it/s]
Generation 22 with pop_fit 0.7366479999999997
100%|██████████| 100/100 [00:05<00:00, 19.00it/s]
Generation 23 with pop_fit 0.7478739999999998
100%|██████████| 100/100 [00:05<00:00, 18.96it/s]
Generation 24 with pop_fit 0.7208869999999997
100%|██████████| 100/100 [00:05<00:00, 19.06it/s]
Generation 25 with pop_fit 0.7302569999999998
100%|██████████| 100/100 [00:05<00:00, 18.97it/s]
Generation 26 with pop_fit 0.7252170000000003
100%|██████████| 100/100 [00:05<00:00, 18.76it/s]
Generation 27 with pop_fit 0.7199049999999997
100%|██████████| 100/100 [00:05<00:00, 19.07it/s]
Generation 28 with pop_fit 0.730054
100%|██████████| 100/100 [00:05<00:00, 18.33it/s]
Generation 29 with pop_fit 0.738942
100%|██████████| 100/100 [00:05<00:00, 18.94it/s]
Generation 30 with pop_fit 0.7318150000000001
100%|██████████| 100/100 [00:05<00:00, 19.08it/s]
Generation 31 with pop_fit 0.7219880000000001
100%|██████████| 100/100 [00:05<00:00, 18.45it/s]
Generation 32 with pop_fit 0.7238969999999999
100%|██████████| 100/100 [00:05<00:00, 19.07it/s]
Generation 33 with pop_fit 0.7283080000000001
100%|██████████| 100/100 [00:05<00:00, 18.71it/s]
Generation 34 with pop_fit 0.71085
100%|██████████| 100/100 [00:05<00:00, 19.10it/s]
Generation 35 with pop_fit 0.7257509999999999
100%|██████████| 100/100 [00:05<00:00, 19.06it/s]
Generation 36 with pop_fit 0.726268
100%|██████████| 100/100 [00:05<00:00, 17.18it/s]
Generation 37 with pop_fit 0.7375379999999998
100%|██████████| 100/100 [00:05<00:00, 19.07it/s]
Generation 38 with pop_fit 0.7513240000000004
100%|██████████| 100/100 [00:05<00:00, 18.53it/s]
Generation 39 with pop_fit 0.7483460000000001
100%|██████████| 100/100 [00:05<00:00, 19.11it/s]
Generation 40 with pop_fit 0.7527380000000001
100%|██████████| 100/100 [00:05<00:00, 18.99it/s]
Generation 41 with pop_fit 0.7640440000000001
100%|██████████| 100/100 [00:05<00:00, 18.77it/s]
Generation 42 with pop_fit 0.7598250000000002
100%|██████████| 100/100 [00:05<00:00, 19.01it/s]
Generation 43 with pop_fit 0.7414619999999998
100%|██████████| 100/100 [00:05<00:00, 18.78it/s]
Generation 44 with pop_fit 0.7588380000000001
100%|██████████| 100/100 [00:05<00:00, 19.01it/s]
Generation 45 with pop_fit 0.7678069999999999
100%|██████████| 100/100 [00:05<00:00, 17.95it/s]
Generation 46 with pop_fit 0.7691239999999996
100%|██████████| 100/100 [00:05<00:00, 17.38it/s]
Generation 47 with pop_fit 0.793402
100%|██████████| 100/100 [00:05<00:00, 18.59it/s]
Generation 48 with pop_fit 0.7898880000000004
100%|██████████| 100/100 [00:05<00:00, 17.82it/s]
Generation 49 with pop_fit 0.7741570000000004
100%|██████████| 100/100 [00:05<00:00, 18.31it/s]
Generation 50 with pop_fit 0.780985
100%|██████████| 100/100 [00:05<00:00, 18.55it/s]
Generation 51 with pop_fit 0.781711
100%|██████████| 100/100 [00:05<00:00, 18.78it/s]
Generation 52 with pop_fit 0.7949269999999999
100%|██████████| 100/100 [00:05<00:00, 19.02it/s]
Generation 53 with pop_fit 0.790513
100%|██████████| 100/100 [00:05<00:00, 18.21it/s]
Generation 54 with pop_fit 0.779757
100%|██████████| 100/100 [00:05<00:00, 18.11it/s]
Generation 55 with pop_fit 0.7727679999999999
100%|██████████| 100/100 [00:05<00:00, 18.93it/s]
Generation 56 with pop_fit 0.7883189999999998
100%|██████████| 100/100 [00:05<00:00, 19.09it/s]
Generation 57 with pop_fit 0.791412
100%|██████████| 100/100 [00:05<00:00, 18.74it/s]
Generation 58 with pop_fit 0.794981
100%|██████████| 100/100 [00:05<00:00, 18.28it/s]
Generation 59 with pop_fit 0.7836850000000003
100%|██████████| 100/100 [00:05<00:00, 18.66it/s]
Generation 60 with pop_fit 0.7869880000000002
100%|██████████| 100/100 [00:05<00:00, 19.13it/s]
Generation 61 with pop_fit 0.793687
100%|██████████| 100/100 [00:05<00:00, 18.90it/s]
Generation 62 with pop_fit 0.7843660000000001
100%|██████████| 100/100 [00:05<00:00, 18.26it/s]
Generation 63 with pop_fit 0.7861880000000006
100%|██████████| 100/100 [00:05<00:00, 18.99it/s]
Generation 64 with pop_fit 0.7954170000000002
100%|██████████| 100/100 [00:05<00:00, 18.70it/s]
Generation 65 with pop_fit 0.807838
100%|██████████| 100/100 [00:05<00:00, 19.09it/s]
Generation 66 with pop_fit 0.8058349999999997
100%|██████████| 100/100 [00:05<00:00, 17.87it/s]
Generation 67 with pop_fit 0.812714
100%|██████████| 100/100 [00:05<00:00, 18.62it/s]
Generation 68 with pop_fit 0.8091359999999999
100%|██████████| 100/100 [00:05<00:00, 18.19it/s]
Generation 69 with pop_fit 0.8095650000000002
100%|██████████| 100/100 [00:05<00:00, 19.25it/s]
Generation 70 with pop_fit 0.8086880000000002
100%|██████████| 100/100 [00:05<00:00, 17.88it/s]
Generation 71 with pop_fit 0.8138010000000002
100%|██████████| 100/100 [00:05<00:00, 18.97it/s]
Generation 72 with pop_fit 0.8169339999999998
100%|██████████| 100/100 [00:05<00:00, 18.20it/s]
Generation 73 with pop_fit 0.8242490000000001
100%|██████████| 100/100 [00:05<00:00, 19.17it/s]
Generation 74 with pop_fit 0.8266410000000001
100%|██████████| 100/100 [00:05<00:00, 19.12it/s]
Generation 75 with pop_fit 0.8304610000000002
100%|██████████| 100/100 [00:05<00:00, 19.22it/s]
Generation 76 with pop_fit 0.8254399999999997
100%|██████████| 100/100 [00:05<00:00, 19.07it/s]
Generation 77 with pop_fit 0.8176640000000001
100%|██████████| 100/100 [00:05<00:00, 19.21it/s]
Generation 78 with pop_fit 0.8264109999999996
100%|██████████| 100/100 [00:05<00:00, 19.32it/s]
Generation 79 with pop_fit 0.8273800000000001
100%|██████████| 100/100 [00:05<00:00, 19.27it/s]
Generation 80 with pop_fit 0.83148
100%|██████████| 100/100 [00:05<00:00, 17.61it/s]
Generation 81 with pop_fit 0.8278819999999999
100%|██████████| 100/100 [00:05<00:00, 18.97it/s]
Generation 82 with pop_fit 0.826982
100%|██████████| 100/100 [00:05<00:00, 18.91it/s]
Generation 83 with pop_fit 0.8198989999999998
100%|██████████| 100/100 [00:05<00:00, 19.11it/s]
Generation 84 with pop_fit 0.8339390000000002
100%|██████████| 100/100 [00:05<00:00, 18.41it/s]
Generation 85 with pop_fit 0.8264149999999997
100%|██████████| 100/100 [00:05<00:00, 17.95it/s]
Generation 86 with pop_fit 0.8333849999999999
100%|██████████| 100/100 [00:05<00:00, 17.06it/s]

........
100%|██████████| 100/100 [00:05<00:00, 18.24it/s]
Generation 337 with pop_fit 0.8917819999999995
100%|██████████| 100/100 [00:05<00:00, 18.97it/s]
Generation 338 with pop_fit 0.8906439999999998
100%|██████████| 100/100 [00:05<00:00, 19.03it/s]
Generation 339 with pop_fit 0.8893889999999998
100%|██████████| 100/100 [00:05<00:00, 19.16it/s]
Generation 340 with pop_fit 0.8890430000000005
100%|██████████| 100/100 [00:05<00:00, 19.13it/s]
Generation 341 with pop_fit 0.8893149999999999
100%|██████████| 100/100 [00:05<00:00, 18.87it/s]
Generation 342 with pop_fit 0.8903490000000001

All rights reserved

Viblo
Hãy đăng ký một tài khoản Viblo để nhận được nhiều bài viết thú vị hơn.
Đăng kí