Callback trong fastai (P3)
Intro
Tiếp tục chuỗi bài viết về thư viện fastai, trong bài viết hôm nay, chúng ta sẽ cùng nhau tìm hiểu về hệ thống callback - nguyên liệu chính của training loop trong class Learner.
Một chút về Callback
Callback là gì? Callback về cơ bản chỉ là một function được gọi khi một sự kiện nào đó xảy ra.
Ví dụ khi các bạn code 1 trang web bằng HTML với một nút trên đó. Nếu bạn muốn có 1 tác vụ nào đó được thực hiện khi người dùng bấm nút, khi đó bạn sẽ viết một function làm những việc cần thiết và truyền vào thuộc tính onClick của thẻ HTML. Hàm này được gọi là callback.
Callback trong fastai
Basic
Callback trong fastai được sử dụng để customize training loop của mô hình. Thông thường, training loop viết bằng pytorch sẽ giống như đoạn code ở dưới
def train(train_dl, model, epochs, optimizer, loss_func):
    for _ in range(epochs):
        model.train()
        for xb, yb in train_dl:
            out = model(xb)
            loss = loss_func(out, yb)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        for xb, yb in val_dl: 
            validate(xb, yb)
            model.eval()
Trong nhiều trường hợp ta sẽ cần phải thêm tính năng cho training loop, ví dụ như:
- Thêm regularization
- Hyperparameter scheduling (learning rate, momentum, ...)
- Log metrics
Cho mối trường hợp ta sẽ phải viết lại training loop để thực hiện những chức năng trên. Fastai giải quyết vấn đề này bằng 1 hệ thống callback. Sau mỗi bước của training loop trong fastai (hàm fit của Learner) sẽ có 1 đoạn code gọi tới hàm callback.
 
Dưới đây là một ví dụ cực đơn giản về cách tạo và sử dụng callback:
from fastai.test_utils import synth_learner
from fastai.callback.core import Callback
# 
class CountParamCallback(Callback):
    def before_fit(self):
        print("Num param:", self.count_parameters(self.learn.model))
    def count_parameters(self, model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
# khởi tạo learner
learn = synth_learner(cbs=[CountParamCallback()])
with learn.no_mbar():
    learn.fit(2)
Trong ví dụ trên mình đã tạo một callback đơn giản để đếm số parameter của mô hình trước khi train.
Để định nghĩa một callback, ta kế thừa class Callback và định nghĩa một số method đặc biệt :
class CountParamCallback(Callback):
    def before_fit(self):
        print("Num param:", self.count_parameters(self.learn.model))
Tên method cũng khá là dễ hiểu: before_fit nghĩa là trước khi sự kiện fit (train mô hình) thì hãy làm các việc trong method này. Để sử dụng callback thì khi khởi tạo Learner, chúng ta chỉ cần set tham số cbs là một list các callback cần thiết:
learn = synth_learner(cbs=[CountParamCallback(), ...])
#hoặc 
learn = Learner(dls, model, cbs=[...])
Về cơ bản thì hệ thống callback trong fastai cho phép ta truy cập và sửa đổi tất cả mọi thứ trong quá trình huấn luyện mô hình (dữ liệu , optimizer, learning rate, ...), một trong những tác giả của thư viện đã gọi đây là "infinitely customizable training loop".
Các sự kiện trong fastai training loop
Mọi điều chỉnh đối với training loop đều được thực hiện thông qua Callback với các method có tên tương ứng với các sự kiện trong training loop. Ta cũng có thể dễ dàng kết hợp các kỹ thuật khác nhau được định nghĩa trong các callback khác nhau. Một callback có thể implement các sự kiện sau:
- after_create: gọi sau khi khởi tạo- Learner
- before_fit: gọi trước khi bắt đầu training hoặc inference
- before_epoch: gọi ở đầu mỗi epoch, hữu ích khi cần reset trạng thái nào đó sau mỗi epoch
- before_train: gọi trước khi bắt đầu quá trình train của mỗi epoch
- before_batch: gọi ở đầu mỗi batch, sau khi lấy batch ra từ data loader. Có thể dùng để thay đổi input trước khi đi qua mô hình (data augmentation chẳng hạn).
- after_pred: gọi sau khi gọi phương thức- forwardcủa mô hình. Có thể dùng để thay đổi output trước khi cho qua hàm loss (reshape, ...)
- after_loss: gọi sau khi tính loss nhưng trước khi gọi- backward. Có thể dùng để thêm regularization cho loss (L2, L1, ...)
- before_backward: gọi sau khi tính loss
- after_backward: gọi sau khi gọi- backwardcủa hàm loss, nhưng trước khi update tham số mô hình.
- before_step: tương tự- after_backwardnhưng trong docs khuyến khích dùng cái này thay vì- after_backward. Có thể dùng để cập nhật lại gradient (gradient clipping, ...)
- after_step: gọi sau khi cập nhật tham số mô hình (- opimizer.step()) và trước khi gọi- optimizer.zero_grad()
- after_batch: gọi ở cuối mỗi batch
- after_train: gọi ở cuối mỗi epoch
- before_validate: gọi ở đầu quá trình validation của mỗi epoch
- after_validate: gọi ở cuối quá trình validation của mỗi epoch
- after_epoch: gọi ở cuối mỗi epoch
- after_fit: gọi ở cuối quá trình training
Các attribute có thể truy cập trong callback
Khi viết callback, ta có thể truy cập một số attribute của class Learner. Sử dụng bằng cách viết: self.learn.attr
( thay attr bằng attribute tương ứng.
- model: mô hình hiện dùng để train hoặc validate
- dls: object DataLoaders
- loss_func: hàm loss truyền vào khi khởi tạo Learner
- opt: object optimizer
- cbs: danh sách tất cả callback
- dl: dataloader hiện đang sử dụng (train hoặc val dataloader)
- x/xb: input của mô hình lấy từ- dl. Chỉ có thể assign giá trị cho attribute- xb
- y/yb: output của mô hình lấy từ- dl. Chỉ có thể assign giá trị cho attribute- yb
- pred: prediction của- model
- loss_grad: giá trị hàm loss
- loss: bản copy của- loss_grad. Dùng cho logging
- n_epoch: số epoch
- n_iter: độ dài của- dl
- epoch: epoch hiện tại (từ 0 - n_epoch-1)
- iter: index hiện tại của- dl(từ 0 - n_iter - 1)
Một số callback có sẵn trong fastai
Gradient clipping
from fastai.test_utils import synth_learner
from fastai.callback.training import GradientClip
learn = synth_learner()
learn.fit(3, cbs=[GradientClip])
Mix Precision training
Chắc sẽ hữu ích cho bạn nào máy ít VRAM
from fastai.test_utils import synth_learner
from fastai.callback.fp16 import MixedPrecision
fp16 = MixedPrecision()
learn = synth_learner(lr=1.1)
learn.fit(3, cbs=[fp16])
Một cách khác để dùng mix precision trong fastai
learn = synth_learner()
learn.to_fp16()
Weights & Biases
W&B là một công cụ dùng để visualize và theo dõi các thí nghiệm học máy. Chỉ cần thêm lệnh khởi tạo callback bạn có thể log tất tần tật về mô hình của bạn lên W&B
from fastai.callback.wandb import *
wb = WandbCallback(
   log_preds=True,
   log_model=True,
   log_dataset=True
)
# To log only during one training phase
learn.fit(..., cbs=[wb])
# To log continuously for all training phases
learn = learner(..., cbs=[wb])
Ngoài ra, còn có rất nhiều callback hữu ích khác. Các bạn có thể tham khảo trong documentation: https://docs.fast.ai/
Bonus: GAN training loop
Phần này không có sẵn trong fastai nhưng mình thấy khá hay nên giới thiệu ở đây. Hệ thống callback trong fastai khá là linh hoạt nên ta có thể tận dụng nó để implement các training loop phức tạp hơn một chút, điển hình là GAN. Trong repo https://github.com/tmabraham/UPIT, tác giả đã viết một callback chỉ để train phần discriminator của mạng GAN (cụ thể là CycleGAN) và để training loop bình thường lo việc train phần generator.
class CycleGANTrainer(Callback):
   """`Learner` Callback for training a CycleGAN model."""
   run_before = Recorder
   def before_train(self, **kwargs):
       self.crit = self.learn.loss_func.crit
       if not getattr(self,'opt_G',None):
           self.opt_G = self.learn.opt_func(self.learn.splitter(nn.Sequential(*flatten_model(self.G_A), *flatten_model(self.G_B))), self.learn.lr)
       else:
           self.opt_G.hypers = self.learn.opt.hypers
       if not getattr(self, 'opt_D',None):
           self.opt_D = self.learn.opt_func(self.learn.splitter(nn.Sequential(*flatten_model(self.D_A), *flatten_model(self.D_B))), self.learn.lr)
       else:
           self.opt_D.hypers = self.learn.opt.hypers
       self.learn.opt = self.opt_G
   def before_batch(self, **kwargs):
       self._set_trainable()
       self._training = self.learn.model.training
       self.learn.xb = (self.learn.xb[0],self.learn.yb[0]),
       self.learn.loss_func.set_input(*self.learn.xb)
   def after_step(self):
       self.opt_D.hypers = self.learn.opt.hypers
   def after_batch(self, **kwargs):
       "Discriminator training loop"
       if self._training:
           # Obtain images
           fake_A, fake_B = TensorBase(self.learn.pred[0].detach()), TensorBase(self.learn.pred[1].detach())
           (real_A, real_B), =self.learn.xb
           real_A, real_B = TensorBase(real_A), TensorBase(real_B)
           self._set_trainable(disc=True)
           # D_A loss calc. and backpropagation
           loss_D_A = 0.5 * (self.crit(self.D_A(real_A), 1) + self.crit(self.D_A(fake_A), 0))
           loss_D_A.backward()
           self.learn.loss_func.D_A_loss = loss_D_A.detach().cpu()
           # D_B loss calc. and backpropagation
           loss_D_B = 0.5 * (self.crit(self.D_B(real_B), 1) + self.crit(self.D_B(fake_B), 0))
           loss_D_B.backward()
           self.learn.loss_func.D_B_loss = loss_D_B.detach().cpu()
           # Optimizer stepping (update D_A and D_B)
           self.opt_D.step()
           self.opt_D.zero_grad()
           self._set_trainable()
Reference
All rights reserved
 
  
  
 
