Style Transfer - explain and code

Bạn nghĩ sao về một bức ảnh chụp (máy ảnh) Hà Nội nhưng lại mang phong cách tranh thiên tài Picasso. Với sự ra đời của thuật toán Style Transfer, chuyện đó là hoàn toàn có thể."

1. Thuật toán.

Dưới đây là hình minh họa cho thuật toán. Chúng ta có 3 ảnh gồm:

  • input_image: Được khởi tạo random, lúc đầu là ảnh nhiễu bất kì, sau quá trình update, tối ưu thành kết quả ta muốn (output image)
  • content_image: chứa nội dung mà ảnh output_image sẽ chứa.
  • style_image: chứa style (phong cách) mà output_image sẽ chứa sau quá trình update sẽ thành ảnh mong muốn.

Ý tưởng thuật toán rất đơn giản: cả 3 ảnh cùng đưa vào 1 pretrained-CNN để trích xuất ra các feature_map. Những Feature map này là những thông tin đặc trưng, chứa đựng thông tin về nội dung, đường nét, màu sắc của ảnh, hay còn gọi là content_feature. Style của một họa sĩ, 1 bức tranh thực tế chính là mối quan hệ giữa các đường nét, màu sắc trong tranh. Như vậy, bằng một phép biến đổi (gram_matrix), ta có thể tính ra được style_feature dựa trên content_feature. Thuật toán cụ thể như sau:

  • Gọi 3 ảnh input_image, content_image, style_image lần lượt là A, B, C
  • Đưa cả 3 ảnh vào cùng 1 pretrained_CNN để trích xuất ra các feature cho 3 ảnh, lần lượt là: A_content, B_content, C_content. Các feature này chính là feature đặc trưng về nội dung (content)
  • Để biến đổi content_feature sang style_feature, ta dùng 1 thuật toán gọi là hàm GramMatrix()
  • A_style = GramMatrix(A_content), C_style = GramMatrix(C_content)
  • Ta có ContentLoss = MSE(A_content, B_content), StyleLoss = MSE(A_style, C_style). Với MSE là hàm MeanSquareError được dùng phổ biến trong Machine learning.
  • Tính CombineLoss = w1ContentLoss + w2 StyleLoss
  • Tính đạo hàm input_image theo CombineLoss. Update input_image với thuật toán gradient descent. Chú ý, thứ chúng ta tối ưu ở đây là input_image chứ không phải các params của pretrained_model.
  • Quy trình sẽ lặp đi lặp lại nhằm tối ưu giá trị CombineLoss. Ảnh input_image dần được update sao cho ContentLoss và StyleLoss giảm dần. Tức có nội dung giống với content_image, phong cách giống với style_image. Dừng thuật toán khi input_image đủ tốt

1.2 Style_feature

Mình sẽ nói kĩ hơn về Style_feature. Như bạn đã biết, content_feature đã bao hàm thông tin về đường nét, nội dung, hình ảnh của 1 ảnh. style_feature chính là mối quan hệ tương quan giữa các thông tin này với nhau. Trong Đại số, ta có khái niệm GramMatrix. GramMatrix chính là thứ chúng ta cần.

GramMatrix(A)=AATGramMatrix(A) = A * A^T

Phép nhân trong công thức này là phép nhân ma trận (dot product). Thuật toán về cơ bản là thế, bắt tay vào code bạn sẽ dễ hiểu hơn.

2. Thực hành code

2.1 Chuẩn bị

Để tiện cho việc code và train, mình sẽ code trên Google Colab. Hãy download và đọc code ở link sau:

Torch khá tiện khi tính gradient của 1 biến và update biến đó theo các thuật toán optimizer. Vì vậy, mình quyết định dùng Torch thay vì tensorflow. Cũng bài này, 2 năm trước mình code bài này bằng Tensorflow thuần khá vất vả nên mình quyết định đổi sang torch.

2.2 Gõ code 😄

Chắc hẳn đây là phần được mong đợi nhất. Mình sẽ tiến hành code.

Trước hết, do cả project bao gồm code và dataset được lưu trên drive, ta cần mount drive vào môi trường ảo của colab. Bước này có thể bỏ qua nếu bạn run trên máy local

## mount drive to notebook (dataset is saved in mydrive)
from google.colab import drive
drive.mount('/content/drive')
%cd '/content/drive/MyDrive/Project by me/Colab/Code in AI Blog/Style_transfer_1'

Import lib and init some variable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from PIL import Image
from torchvision import transforms as T
from torchvision import models
from matplotlib import pyplot as plt

img_size = 512
device = torch.device('cuda')
img_transforms = T.Compose([
    T.Resize((img_size, img_size)),
    T.ToTensor()
])

Một vài hàm phụ trợ:

def load_img(img_fn):
    global device
    global img_transforms
    image = Image.open(img_fn)
    image = img_transforms(image)
    # to batch
    image = image.unsqueeze(0).to(device, torch.float)
    return image

def to_image(img_tensor):
    image = img_tensor.cpu().clone()
    image = image.squeeze(0)
    image = T.ToPILImage()(image)
    return image

def plot_imgs(imgs, cols=3, size=7, title=""):
    rows = len(imgs)//cols + 1
    fig = plt.figure(figsize=(cols*size, rows*size))
    for i, img in enumerate(imgs):
        image = to_image(img) if isinstance(img, torch.Tensor) else img
        fig.add_subplot(rows, cols, i+1)
        plt.imshow(image)
    plt.suptitle(title)
    plt.show()

Các function và module quan trọng, bao gồm: ContentLoss, StyleLoss, GramMatrix, Normalization

def compute_gram_matrix(feature_map):
    batchsize, c, w, h = feature_map.size()
    feature = feature_map.view(batchsize*c, w*h)
    gram_matrix = torch.mm(feature, feature.T)
    gram_matrix = gram_matrix/(batchsize*c*w*h) #normalize
    return gram_matrix

# ContentLoss (perceptual loss) is MSE between input_feature and target feature
class ContentLoss(nn.Module):
    def __init__(self, target_feature):
        super(ContentLoss, self).__init__()
        # target_feature is an constant, not variable     
        self.target_feature = target_feature.detach()

    def forward(self, input_feature):
        self.loss = F.mse_loss(input_feature, self.target_feature)
        return input_feature

# styleLoss is MSE between gram_matrix of input_feature and target feature
class StyleLoss(nn.Module):
    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        gram_matrix = compute_gram_matrix(target_feature)
        self.target_gram_matrix = gram_matrix.detach()

    def forward(self, input_feature):
        input_gram_matrix = compute_gram_matrix(input_feature)
        self.loss = F.mse_loss(input_gram_matrix, self.target_gram_matrix)
        return input_feature

# Additionally, VGG networks are trained on images with
# each channel normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
# We will use them to normalize the image before sending it into the network.
class Normalization(nn.Module):
    def __init__(self):
        super(Normalization, self).__init__()
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1).to(device)
        self.std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1).to(device)

    def forward(self, image_tensor):
        return (image_tensor - self.mean) / self.std

Load pretrained-CNN model, build Style Transfer model


def build_ST_model(
    cnn_model, normalization, 
    content_layer_names, style_layer_names,
    content_img, style_img):
    """Build a style_transfer model with specify style and content"""

    cnn_model.eval()
    ST_model = nn.Sequential(normalization)
    all_content_loss = []
    all_style_loss = []
    i = 0

    for layer in cnn_model.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)

        ST_model.add_module(name=name, module=layer)

        if name in content_layer_names:
            target_feature = ST_model(content_img)
            c_loss = ContentLoss(target_feature=target_feature)
            ST_model.add_module(f'content_loss_{i}', module=c_loss)
            all_content_loss.append(c_loss)

        if name in style_layer_names:
            target_feature = ST_model(style_img)
            s_loss = StyleLoss(target_feature=target_feature)
            ST_model.add_module(f'style_loss_{i}', module=s_loss)
            all_style_loss.append(s_loss)

    return ST_model, all_content_loss, all_style_loss

    
# load pretrain-model and build StyleTransfer model
cnn_model = models.vgg19(pretrained=True).features.to(device).eval()
normalization = Normalization()
content_layers = ['conv_4']
style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
ST_model, all_content_loss, all_style_loss = build_ST_model(
    cnn_model, normalization, content_layers, style_layers, content_img, style_img)

Function chính, chứa trương trình chính để transfer style từ một ảnh sang ảnh khác.

def run_style_transfer(
    cnn_model, normalization, content_layers,
    style_layers, content_img, style_img, num_steps=100,
    style_weight=10000, content_weight=1):
  
    """Run main program to transfer style from style_img to content_img"""
    ST_model, content_losses, style_losses = build_ST_model(
        cnn_model, normalization, 
        content_layers, style_layers,
        content_img, style_img)

    input_image = content_img.clone()
    input_image.requires_grad_()
    optimizer = optim.RMSprop([input_image])

    for i in range(num_steps):
        input_image.data.clamp_(0,1)
        optimizer.zero_grad()
        content_score, style_score = 0, 0
        ST_model(input_image)

        for c_loss in content_losses:
            content_score += c_loss.loss

        for s_loss in style_losses:
            style_score += s_loss.loss

        total_loss = content_score*content_weight + style_score*style_weight
        total_loss.backward()
        optimizer.step()

        if (i+1)%50 == 0:
            print("content_loss: ", content_score.item(), "style_loss: ", style_score.item())

    input_image.data.clamp_(0, 1)
    return input_image

Thay vì khởi tạo random ảnh input_image, mình quyết định gán ảnh content_image cho input_image. Như thế input_image sẽ dễ dàng và học nhanh hơn, không tốn thời gian trong việc khôi phục content. Cũng bởi vậy, trong số của StyleLoss được chọn lớn hơn rất nhiều trọng số của ContentLoss.

style_img = load_img('dataset/style-3.jpg')
content_img = load_img('dataset/content_1.jpg') 
output_img = run_style_transfer(
    cnn_model, normalization, content_layers,
    style_layers, content_img, style_img, num_steps=700,
    style_weight=1000000, content_weight=1)

plot_imgs([content_img, style_img, output_img], size=10)

3 Kết luận

Như bạn thấy, thuật toán StyleTransfer này thực sự không khó, ý tưởng và cách làm rất đơn giản. Trong đề tài StyleTransfer, người ta còn áp dụng cả GAN. Trong thời gian tới, mình sẽ cố gắng viết bài về StyleTransfer với GAN. Hiện tại mình mới lập blog cá nhân tại https://trungthanhnguyen0502.github.io . Bạn có thể follow mình trên Viblo hoặc đón đọc bài trực tiếp từ Blog cá nhân. Cảm ơn bạn đã đọc

Tham khảo

Bài viết cách đây 2 năm của chính mình: https://forum.machinelearningcoban.com/t/style-transfer-tutorial/4026

https://pytorch.org/tutorials/advanced/neural_style_tutorial.html

https://www.tensorflow.org/tutorials/generative/style_transfer


All Rights Reserved