[GAN series-3] Tự sinh nhân vật Anime với Deep learning - GAN

1. Dataset

Trong hai bài trước, mình đã nói qua về khái niệm GAN và thực hành GAN với bộ dataset đơn giản: Mnist. Trong bài này, mình sẽ tiến hành code một GAN phức tạp hơn, trên dataset phức tạp hơn: bộ chân dung các nhân vật Anime. Mọi người có thể tải theo link sau: Anime Dataset .

Dataset được download và đặt trong thư mục "dataset" trong project folder.

2. Tiến hành code

Đoạn code dưới đây được mình code bằng Keras. Do mình từng code trên cả pytorch => có một vài đoạn code mình sử dụng lại code pytorch trong việc load data. Mình sẽ show từng đoạn code kèm giải thích cho từng đoạn.

Import các thư viện cần thiết:

import torch
from torch.utils.data import Dataset, DataLoader
from skimage import io, transform
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
import pandas
import os
from matplotlib import pyplot as plt
from keras.utils import plot_model
from keras.optimizers import Adam
from random import random
import tensorflow as tf
import keras
from keras.layers import *
from keras.models import Model, Sequential
import matplotlib.gridspec as gridspec    

Viết hàm định nghĩa các model trong GAN:

2.1 ĐỊnh nghĩa Discriminator

Có một điểm đặc biệt trong mạng này, mình tạo một discriminator với 2 outputs. Trong đó 1 output có shape = (:, 2, 2, 2) => Đại diện cho (2 x 2) vùng "local" và 1 oput có shape = (:, 2) đại diện cho toàn bộ ảnh -> "global". Mình sẽ giải thích lí do của việc multi-outputs trong bài viết tiếp theo (bạn hoàn toàn có thể bỏ đi output đầu tiên).

def create_discriminator(input_shape=(96,96,3), base_c=32):
    input = Input(shape=input_shape)
    # shape: 96*96*32
    x1 = Conv2D(filters=base_c*4, strides=1, kernel_size=4, padding='same')(input)
    x1 = BatchNormalization()(x1)
    x1 = LeakyReLU(alpha=0.2)(x1)
    
    x1 = Conv2D(filters=base_c*8, strides=2, kernel_size=4, padding='same')(x1)
    x1 = BatchNormalization()(x1)
    x1 = LeakyReLU(alpha=0.2)(x1)
    
    x2 = Conv2D(filters=base_c*4, strides=2, kernel_size=4, padding="same")(x1)
    x2 = BatchNormalization()(x2)
    x2 = LeakyReLU(alpha=0.2)(x2)
    
    x3 = Conv2D(filters=base_c*4, strides=2, kernel_size=5, padding="same")(x2)
    x3 = BatchNormalization()(x3)
    x3 = LeakyReLU(alpha=0.2)(x3)
    
    x4 = Conv2D(filters=base_c*2, strides=2, kernel_size=5, padding="same")(x3)
    x4 = BatchNormalization()(x4)
    x4 = LeakyReLU(alpha=0.2)(x4)

    x5 = Conv2D(filters=base_c, strides=2, kernel_size=3, padding="same")(x4)
    x6 = Conv2D(2, activation='softmax', kernel_size=2, name='output_1')(x5)
    
    x7 = Flatten()(x5)
    x8 = Dense(2, activation="softmax", name="output_2")(x7)
    model = Model(input, outputs=[x6, x8])
    return model

2.2 ĐỊnh nghĩa Generator

def create_generator(z_dim=100, base_wh=6, base_c=32):
    input = Input(shape=(z_dim,))
    x1 = Dense(base_wh*base_wh*base_c, activation='relu')(input)
    x1 = Reshape(target_shape=(base_wh, base_wh, base_c))(x1)
    
    # shape: 12*12*256
    x2 = Conv2DTranspose(filters=base_c*8, kernel_size=4, strides=2, padding="same")(x1)
    x2 = BatchNormalization(momentum = 0.8)(x2)
    x2 = Activation('relu')(x2)
    
    # shape: 24*24*128
    x3 = Conv2DTranspose(filters=base_c*5, kernel_size=4, strides=2, padding="same")(x2)
    x3 = BatchNormalization(momentum = 0.8)(x3)
    x3 = Activation('relu')(x3)
    
    # shape: 48*48*64
    x4 = Conv2DTranspose(filters=base_c*3, kernel_size=3, strides=2, padding="same")(x3)
    x4 = BatchNormalization(momentum = 0.8)(x4)
    x4 = Activation('relu')(x4)
    
    # shape: 96*96*64
    x5 = Conv2DTranspose(filters=base_c*2, kernel_size=3, strides=2, padding="same")(x4)
    x5 = BatchNormalization(momentum = 0.8)(x5)
    x5 = Activation('relu')(x5)

    # shape: 96*96*32
    x6 = Conv2D(base_c, padding='same', kernel_size=3)(x5)
    x6 = BatchNormalization(momentum=0.8)(x6)
    x6 = Activation('relu')(x6)
    
    # shape: 96*96*3
    x7 = Conv2D(3, padding='same', kernel_size=3)(x6)
    x7 = Activation('tanh')(x7)
    
    model = Model(input, x7)
    return model

2.3 ĐỊnh nghĩa Dataset và một vài hàm phụ:

class AnimeDataset(Dataset):
    def __init__(self, img_paths, transforms=None):
        self.img_paths = img_paths
        self.transforms = transforms
        
    def __len__(self):
        return len(self.img_paths)
    
    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        img = Image.open(img_path)
        if self.transforms:
            img = self.transforms(img)
        img = img.permute(1,2,0)
        return img
    
def create_hparams(hparams_string=None, verbose=False):
    hparams = tf.contrib.training.HParams(
        img_size=96,
        batch_size=64,
        num_worker=2
    )
    return hparams
    
def plot_img(img, size=2):
    img = img * 0.5 + 0.5
    plt.figure(figsize=(size, size))
    plt.imshow(img)
    plt.show()


# validate discriminator
# việc validate được tiến hành trên cả bộ dữ liệu thật và dữ liệu generate ra bởi gen_model.
# y1 -> output đại diện (2x2) vùng local
# y2 -> output đại diện toàn bộ ảnh
def val_dis_model(dis_model, gen_model, val_loader, out1_shape=2):
    y1_real = []
    y1_fake = []
    y2_real = []
    y2_fake = []
    
    for real_batch in val_loader:
        real_batch = real_batch.numpy()
        noise = np.random.normal(0, 1, (hparams.batch_size, z_dim ))
        fake_batch = gen_model.predict(noise)
        
        y1, y2 = dis_model.predict(real_batch)
        y1_real.append(y1)
        y2_real.append(y2)
        
        y1, y2 = dis_model.predict(fake_batch)
        y1_fake.append(y1)
        y2_fake.append(y2)
        
    y1_real = np.array(y1_real).reshape((-1, out1_shape, out1_shape, 2))
    y1_fake = np.array(y1_fake).reshape((-1, out1_shape, out1_shape, 2))
    y2_real = np.array(y2_real).reshape((-1, 2))
    y2_fake = np.array(y2_fake).reshape((-1, 2))
    
    # tính giá trị trung bình của y1_real, y1_fake
    # => accuracy tương ứng của discriminator trong tập real_data, fake_data
    real = y1_real[:,:,:,0].mean() + y2_real[:,0].mean()
    fake = y1_fake[:,:,:,0].mean() + y2_fake[:,0].mean()
    return (real/2, fake/2)
    
    
# show ảnh kết quả
def plot_img_batch(img_batch, iteration):
    plt.figure(figsize=(4,4))
    gs1 = gridspec.GridSpec(4, 4)
    gs1.update(wspace=0, hspace=0)
    for i in range(16):
        ax1 = plt.subplot(gs1[i])
        ax1.set_aspect('equal')
        image = img_batch[i]
        fig = plt.imshow(image)
        plt.axis('off')
        fig.axes.get_xaxis().set_visible(False)
        fig.axes.get_yaxis().set_visible(False)
    plt.tight_layout()
    plt.savefig("./result/{}.png".format(iteration))    
    plt.show()  

Định nghĩa hàm sinh label cho quá trình train, label thay vì là nhãn chính xác,. Việc thêm nhiễu này giúp cho discriminator tránh được hiện tượng học quá nhanh, vượt mặt generator. (Mình sẽ giải thích kĩ hơn trong bài 4). Mình sẽ tiến hành thêm nhiễu vào dữ liệu bằng hai cách:

  • Với một tỷ lệ nhất đinh (5%), mình hoán đổi giá trị 1 -> 0, 0 -> 1.
  • Thay vì để giá trị label chính xác 0 và 1, mình đổi thành giá trị xấp xỉ nằm trong một khoảng gần đúng với giá trị 0 -> [0, 0.1] ; giá trị 1 - > [0.9, 1].
def create_label(batch_size=hparams.batch_size, is_real_label=True, noise_ratio=0.1, swap_ratio=0.05, out1_shape=2):
    label_1 = np.zeros((batch_size, out1_shape, out1_shape, 2))
    label_2 = np.zeros((batch_size,2))
    
    noise_1 = np.random.uniform(0,1, size=(batch_size, out1_shape, out1_shape))*noise_ratio
    noise_2 = np.random.uniform(0,1, size=(batch_size))*noise_ratio
    
    if is_real_label:
        label_1[:,:,:,0] = 1 - noise_1
        label_1[:,:,:,1] = noise_1
        label_2[:,0] = 1 - noise_2
        label_2[:,1] = noise_2
    else:
        label_1[:,:,:,0] = noise_1
        label_1[:,:,:,1] = 1- noise_1
        label_2[:,0] = noise_2
        label_2[:,1] = 1- noise_2
    swap_index_1 = np.random.choice(batch_size, size=int(batch_size*swap_ratio))
    swap_index_2 = np.random.choice(batch_size, size=int(batch_size*swap_ratio))
    
    label_1[swap_index_1,:,:,0], label_1[swap_index_1,:,:,1] = label_1[swap_index_1,:,:,1], label_1[swap_index_1,:,:,0] 
    label_2[swap_index_2, 0], label_2[swap_index_2, 1] = label_2[swap_index_2, 1], label_2[swap_index_2, 0] 
    return [label_1, label_2]

2.4 Khởi tạo các giá trị và model

from random import shuffle
import tqdm 

hparams = create_hparams()
root_folder = "./dataset"
csv_link = "annotations.csv"
img_info = pandas.read_csv(csv_link)
img_paths = list(img_info["image"])
img_paths = [os.path.join(root_folder, img_path) for img_path in img_paths]
img_paths = [path for path in tqdm.tqdm(img_paths) if os.path.exists(path)]
shuffle(img_paths)

# định nghĩa các phép biến đổi ảnh => data augmentation - làm giàu dữ liệu
my_trainsform = transforms.Compose([
    transforms.Scale(hparams.img_size),
    transforms.RandomRotation(degrees=(-10,10)),
    transforms.CenterCrop(85),
    transforms.Scale(hparams.img_size),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]),
])

Khởi tạo dataset

train_data, val_data = train_test_split(img_paths, test_size=0.03)
val_set = AnimeDataset(val_data, my_trainsform)
train_set = AnimeDataset(train_data, my_trainsform)

train_loader = DataLoader(
    train_set,
    shuffle=True,
    drop_last=True,
    batch_size=hparams.batch_size,
    num_workers=hparams.num_worker,
)

val_loader = DataLoader(
    val_set,
    shuffle=True,
    drop_last=True,
    batch_size=hparams.batch_size,
    num_workers=hparams.num_worker,
)

Khởi tạo model. Do discriminator là multi-outputs => quá trình compile model phức tạp hơn bình thường. (có định nghĩa loss_function cùng trọng số riêng cho từng output).

Khởi tạo generator và discriminator

alpha = 0.4
dis_model = create_discriminator()
dis_optim = Adam(lr=0.000022, beta_1=0.3, beta_2=0.99)
dis_model.compile(optimizer = dis_optim,
                  loss = {"output_1": "categorical_crossentropy", "output_2": "categorical_crossentropy"},
                  loss_weights={"output_1": alpha, "output_2": 1 - alpha})

gen_model = create_generator()
gen_optim = Adam(lr=0.000022, beta_1=0.3, beta_2=0.99)
gen_model.compile(loss="categorical_crossentropy", optimizer = gen_optim)


input_gen = Input(gen_model.input.shape[1:])
fake_img = gen_model(input_gen)
dis_model.trainable=False

Khởi tạo gan_model

output_1, output_2 =  dis_model(fake_img)
outname_1 = "output_1_1"
outname_2 = "output_2_2"

output_1 = Layer(name=outname_1)(output_1)
output_2 = Layer(name=outname_2)(output_2)
gan_model = Model(input_gen, outputs=[output_1, output_2])
loss_weights={outname_1: alpha, outname_2: 1 - alpha}
gan_model.compile(optimizer = dis_optim,
                  loss = {outname_1: "categorical_crossentropy",
                          outname_2: "categorical_crossentropy"},
                  loss_weights=loss_weights)

2.5 Tiến hành train.

Trong quá trình train, mình có viết 1 đoạn code nhằm thay đổi batch_size đối với từng loại dữ liệu, từng model sau mỗi 50 step. Đoạn code này do mình tự code, không bắt buộc phải có. Mình nhận thấy khi có đoạn code đó, xác suất thất bại khi train đã được giảm xuống đáng kế.

gan_batch_size = hparams.batch_size

for epoch in range(epoch_nb):
    print("epoch: ", epoch)
    for real_batch in train_loader:
        iteration += 1 
        ##### Train Discriminator
        dis_model.trainable=True
        
        ## train dis_model với dữ liệu thật
        real_batch = real_batch.numpy()
        real_label = create_label(real_batch_size, is_real_label=True)
        dis_model.train_on_batch(
            real_batch,
            {"output_1": real_label[0], "output_2": real_label[1]})
        
        ## train dis_model với dữ liệu fake
        noise = np.random.normal(0, 1, (fake_batch_size, z_dim))
        fake_batch = gen_model.predict(noise)
        fake_label = create_label(fake_batch_size, is_real_label=False)
        dis_model.train_on_batch(
            fake_batch,
            {"output_1": fake_label[0], "output_2": fake_label[1]})
        
        ##### train Generator với dữ liệu fake
        dis_model.trainable=False
        noise = np.random.normal(0, 1, (gan_batch_size, z_dim))
        gan_label = create_label(gan_batch_size, is_real_label=True)
        gan_model.train_on_batch(
            noise, {outname_1: gan_label[0], outname_2: gan_label[1]
        })
        
        if iteration % 50 == 0:
            real, fake = val_dis_model(dis_model, gen_model, val_loader)
            acc = (real + fake)/2
            gan_acc = 1 - fake
            print("Step {}: real_acc {:4f}, fake_acc {:4f}, gan_acc {:4f}".format(iteration, real, fake, gan_acc))
            
            ### Điều chỉnh batch size từng loại dữ liệu khi discriminator bị mất cân bằng (unbalance)
            if real - gan_acc > b_size_thresh:
                gan_batch_size = int(real_batch_size*up_scale)
            elif gan_acc - real > b_size_thresh:
                gan_batch_size = int(real_batch_size/up_scale)
            else:
                gan_batch_size = real_batch_size

            if real - fake > b_size_thresh:
                fake_batch_size = int(real_batch_size*up_scale)
            elif fake - real > b_size_thresh:
                fake_batch_size = int(real_batch_size/up_scale)
            else:
                fake_batch_size = real_batch_size
                    
        if iteration % 100 == 0:
            test_batch = gen_model.predict(test_noise)
            test_batch = test_batch*0.5 + 0.5
            plot_img_batch(test_batch[:16], iteration)
        if iteration % 1000 == 0:
            dis_model.save("dis_model_{}.h5".format(iteration))
            gen_model.save("gen_model_{}.h5".format(iteration))            

Và đây là kết quả sau 3 giờ, chưa đẹp và hoàn hảo lắm nhưng cũng bắt đầu ra hình hài, khuôn mặt. Do vấn đề chỉ còn phụ thuộc vào thời gian train nên mình dừng lại, không train tiếp. Mình nhớ tay xóa hết ảnh kết quả nên chỉ còn vài cái minh họa (phát cuối nhảy hơi nhanh 🤣 )

Note: Các bài viết được viết trên trang https://viblo.asia , nếu trích dẫn hãy ghi nguồn trang cùng tên tác giả đầy đủ.

facebook: trungthanhnguyen0502