+7

Callback trong fastai (P3)

Intro

Tiếp tục chuỗi bài viết về thư viện fastai, trong bài viết hôm nay, chúng ta sẽ cùng nhau tìm hiểu về hệ thống callback - nguyên liệu chính của training loop trong class Learner.

Một chút về Callback

Callback là gì? Callback về cơ bản chỉ là một function được gọi khi một sự kiện nào đó xảy ra. Ví dụ khi các bạn code 1 trang web bằng HTML với một nút trên đó. Nếu bạn muốn có 1 tác vụ nào đó được thực hiện khi người dùng bấm nút, khi đó bạn sẽ viết một function làm những việc cần thiết và truyền vào thuộc tính onClick của thẻ HTML. Hàm này được gọi là callback.

Callback trong fastai

Basic

Callback trong fastai được sử dụng để customize training loop của mô hình. Thông thường, training loop viết bằng pytorch sẽ giống như đoạn code ở dưới

def train(train_dl, model, epochs, optimizer, loss_func):
    for _ in range(epochs):
        model.train()
        for xb, yb in train_dl:
            out = model(xb)
            loss = loss_func(out, yb)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        for xb, yb in val_dl: 
            validate(xb, yb)
            model.eval()

Trong nhiều trường hợp ta sẽ cần phải thêm tính năng cho training loop, ví dụ như:

  • Thêm regularization
  • Hyperparameter scheduling (learning rate, momentum, ...)
  • Log metrics

Cho mối trường hợp ta sẽ phải viết lại training loop để thực hiện những chức năng trên. Fastai giải quyết vấn đề này bằng 1 hệ thống callback. Sau mỗi bước của training loop trong fastai (hàm fit của Learner) sẽ có 1 đoạn code gọi tới hàm callback.

fastai_training_loop_callbacks.png

Fastai training loop với callback

Dưới đây là một ví dụ cực đơn giản về cách tạo và sử dụng callback:

from fastai.test_utils import synth_learner
from fastai.callback.core import Callback

# 
class CountParamCallback(Callback):
    def before_fit(self):
        print("Num param:", self.count_parameters(self.learn.model))

    def count_parameters(self, model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

# khởi tạo learner
learn = synth_learner(cbs=[CountParamCallback()])
with learn.no_mbar():
    learn.fit(2)

Trong ví dụ trên mình đã tạo một callback đơn giản để đếm số parameter của mô hình trước khi train. Để định nghĩa một callback, ta kế thừa class Callback và định nghĩa một số method đặc biệt :

class CountParamCallback(Callback):
    def before_fit(self):
        print("Num param:", self.count_parameters(self.learn.model))

Tên method cũng khá là dễ hiểu: before_fit nghĩa là trước khi sự kiện fit (train mô hình) thì hãy làm các việc trong method này. Để sử dụng callback thì khi khởi tạo Learner, chúng ta chỉ cần set tham số cbs là một list các callback cần thiết:

learn = synth_learner(cbs=[CountParamCallback(), ...])
#hoặc 
learn = Learner(dls, model, cbs=[...])

Về cơ bản thì hệ thống callback trong fastai cho phép ta truy cập và sửa đổi tất cả mọi thứ trong quá trình huấn luyện mô hình (dữ liệu , optimizer, learning rate, ...), một trong những tác giả của thư viện đã gọi đây là "infinitely customizable training loop".

Các sự kiện trong fastai training loop

Mọi điều chỉnh đối với training loop đều được thực hiện thông qua Callback với các method có tên tương ứng với các sự kiện trong training loop. Ta cũng có thể dễ dàng kết hợp các kỹ thuật khác nhau được định nghĩa trong các callback khác nhau. Một callback có thể implement các sự kiện sau:

  • after_create: gọi sau khi khởi tạo Learner
  • before_fit: gọi trước khi bắt đầu training hoặc inference
  • before_epoch: gọi ở đầu mỗi epoch, hữu ích khi cần reset trạng thái nào đó sau mỗi epoch
  • before_train: gọi trước khi bắt đầu quá trình train của mỗi epoch
  • before_batch: gọi ở đầu mỗi batch, sau khi lấy batch ra từ data loader. Có thể dùng để thay đổi input trước khi đi qua mô hình (data augmentation chẳng hạn).
  • after_pred: gọi sau khi gọi phương thức forward của mô hình. Có thể dùng để thay đổi output trước khi cho qua hàm loss (reshape, ...)
  • after_loss: gọi sau khi tính loss nhưng trước khi gọi backward. Có thể dùng để thêm regularization cho loss (L2, L1, ...)
  • before_backward: gọi sau khi tính loss
  • after_backward: gọi sau khi gọi backward của hàm loss, nhưng trước khi update tham số mô hình.
  • before_step: tương tự after_backward nhưng trong docs khuyến khích dùng cái này thay vì after_backward. Có thể dùng để cập nhật lại gradient (gradient clipping, ...)
  • after_step: gọi sau khi cập nhật tham số mô hình (opimizer.step()) và trước khi gọi optimizer.zero_grad()
  • after_batch: gọi ở cuối mỗi batch
  • after_train: gọi ở cuối mỗi epoch
  • before_validate: gọi ở đầu quá trình validation của mỗi epoch
  • after_validate: gọi ở cuối quá trình validation của mỗi epoch
  • after_epoch: gọi ở cuối mỗi epoch
  • after_fit: gọi ở cuối quá trình training

Các attribute có thể truy cập trong callback

Khi viết callback, ta có thể truy cập một số attribute của class Learner. Sử dụng bằng cách viết: self.learn.attr ( thay attr bằng attribute tương ứng.

  • model: mô hình hiện dùng để train hoặc validate
  • dls: object DataLoaders
  • loss_func: hàm loss truyền vào khi khởi tạo Learner
  • opt: object optimizer
  • cbs: danh sách tất cả callback
  • dl: dataloader hiện đang sử dụng (train hoặc val dataloader)
  • x/xb: input của mô hình lấy từ dl. Chỉ có thể assign giá trị cho attribute xb
  • y/yb: output của mô hình lấy từ dl. Chỉ có thể assign giá trị cho attribute yb
  • pred: prediction của model
  • loss_grad: giá trị hàm loss
  • loss: bản copy của loss_grad. Dùng cho logging
  • n_epoch: số epoch
  • n_iter: độ dài của dl
  • epoch: epoch hiện tại (từ 0 - n_epoch-1)
  • iter: index hiện tại của dl (từ 0 - n_iter - 1)

Một số callback có sẵn trong fastai

Gradient clipping

from fastai.test_utils import synth_learner
from fastai.callback.training import GradientClip

learn = synth_learner()
learn.fit(3, cbs=[GradientClip])

Mix Precision training

Chắc sẽ hữu ích cho bạn nào máy ít VRAM

from fastai.test_utils import synth_learner
from fastai.callback.fp16 import MixedPrecision

fp16 = MixedPrecision()
learn = synth_learner(lr=1.1)
learn.fit(3, cbs=[fp16])

Một cách khác để dùng mix precision trong fastai

learn = synth_learner()
learn.to_fp16()

Weights & Biases

W&B là một công cụ dùng để visualize và theo dõi các thí nghiệm học máy. Chỉ cần thêm lệnh khởi tạo callback bạn có thể log tất tần tật về mô hình của bạn lên W&B

from fastai.callback.wandb import *

wb = WandbCallback(
   log_preds=True,
   log_model=True,
   log_dataset=True
)
# To log only during one training phase
learn.fit(..., cbs=[wb])

# To log continuously for all training phases
learn = learner(..., cbs=[wb])

Ngoài ra, còn có rất nhiều callback hữu ích khác. Các bạn có thể tham khảo trong documentation: https://docs.fast.ai/

Bonus: GAN training loop

Phần này không có sẵn trong fastai nhưng mình thấy khá hay nên giới thiệu ở đây. Hệ thống callback trong fastai khá là linh hoạt nên ta có thể tận dụng nó để implement các training loop phức tạp hơn một chút, điển hình là GAN. Trong repo https://github.com/tmabraham/UPIT, tác giả đã viết một callback chỉ để train phần discriminator của mạng GAN (cụ thể là CycleGAN) và để training loop bình thường lo việc train phần generator.

class CycleGANTrainer(Callback):
   """`Learner` Callback for training a CycleGAN model."""
   run_before = Recorder

   def before_train(self, **kwargs):
       self.crit = self.learn.loss_func.crit
       if not getattr(self,'opt_G',None):
           self.opt_G = self.learn.opt_func(self.learn.splitter(nn.Sequential(*flatten_model(self.G_A), *flatten_model(self.G_B))), self.learn.lr)
       else:
           self.opt_G.hypers = self.learn.opt.hypers
       if not getattr(self, 'opt_D',None):
           self.opt_D = self.learn.opt_func(self.learn.splitter(nn.Sequential(*flatten_model(self.D_A), *flatten_model(self.D_B))), self.learn.lr)
       else:
           self.opt_D.hypers = self.learn.opt.hypers

       self.learn.opt = self.opt_G

   def before_batch(self, **kwargs):
       self._set_trainable()
       self._training = self.learn.model.training
       self.learn.xb = (self.learn.xb[0],self.learn.yb[0]),
       self.learn.loss_func.set_input(*self.learn.xb)

   def after_step(self):
       self.opt_D.hypers = self.learn.opt.hypers

   def after_batch(self, **kwargs):
       "Discriminator training loop"
       if self._training:
           # Obtain images
           fake_A, fake_B = TensorBase(self.learn.pred[0].detach()), TensorBase(self.learn.pred[1].detach())
           (real_A, real_B), =self.learn.xb
           real_A, real_B = TensorBase(real_A), TensorBase(real_B)
           self._set_trainable(disc=True)
           # D_A loss calc. and backpropagation
           loss_D_A = 0.5 * (self.crit(self.D_A(real_A), 1) + self.crit(self.D_A(fake_A), 0))
           loss_D_A.backward()
           self.learn.loss_func.D_A_loss = loss_D_A.detach().cpu()
           # D_B loss calc. and backpropagation
           loss_D_B = 0.5 * (self.crit(self.D_B(real_B), 1) + self.crit(self.D_B(fake_B), 0))
           loss_D_B.backward()
           self.learn.loss_func.D_B_loss = loss_D_B.detach().cpu()
           # Optimizer stepping (update D_A and D_B)
           self.opt_D.step()
           self.opt_D.zero_grad()
           self._set_trainable()

Reference


All rights reserved

Viblo
Hãy đăng ký một tài khoản Viblo để nhận được nhiều bài viết thú vị hơn.
Đăng kí