Callback trong fastai (P3)
Intro
Tiếp tục chuỗi bài viết về thư viện fastai, trong bài viết hôm nay, chúng ta sẽ cùng nhau tìm hiểu về hệ thống callback - nguyên liệu chính của training loop trong class Learner
.
Một chút về Callback
Callback là gì? Callback về cơ bản chỉ là một function được gọi khi một sự kiện nào đó xảy ra.
Ví dụ khi các bạn code 1 trang web bằng HTML với một nút trên đó. Nếu bạn muốn có 1 tác vụ nào đó được thực hiện khi người dùng bấm nút, khi đó bạn sẽ viết một function làm những việc cần thiết và truyền vào thuộc tính onClick
của thẻ HTML. Hàm này được gọi là callback.
Callback trong fastai
Basic
Callback trong fastai được sử dụng để customize training loop của mô hình. Thông thường, training loop viết bằng pytorch sẽ giống như đoạn code ở dưới
def train(train_dl, model, epochs, optimizer, loss_func):
for _ in range(epochs):
model.train()
for xb, yb in train_dl:
out = model(xb)
loss = loss_func(out, yb)
loss.backward()
optimizer.step()
optimizer.zero_grad()
for xb, yb in val_dl:
validate(xb, yb)
model.eval()
Trong nhiều trường hợp ta sẽ cần phải thêm tính năng cho training loop, ví dụ như:
- Thêm regularization
- Hyperparameter scheduling (learning rate, momentum, ...)
- Log metrics
Cho mối trường hợp ta sẽ phải viết lại training loop để thực hiện những chức năng trên. Fastai giải quyết vấn đề này bằng 1 hệ thống callback. Sau mỗi bước của training loop trong fastai (hàm fit của Learner
) sẽ có 1 đoạn code gọi tới hàm callback.
Dưới đây là một ví dụ cực đơn giản về cách tạo và sử dụng callback:
from fastai.test_utils import synth_learner
from fastai.callback.core import Callback
#
class CountParamCallback(Callback):
def before_fit(self):
print("Num param:", self.count_parameters(self.learn.model))
def count_parameters(self, model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# khởi tạo learner
learn = synth_learner(cbs=[CountParamCallback()])
with learn.no_mbar():
learn.fit(2)
Trong ví dụ trên mình đã tạo một callback đơn giản để đếm số parameter của mô hình trước khi train.
Để định nghĩa một callback, ta kế thừa class Callback
và định nghĩa một số method đặc biệt :
class CountParamCallback(Callback):
def before_fit(self):
print("Num param:", self.count_parameters(self.learn.model))
Tên method cũng khá là dễ hiểu: before_fit
nghĩa là trước khi sự kiện fit (train mô hình) thì hãy làm các việc trong method này. Để sử dụng callback thì khi khởi tạo Learner
, chúng ta chỉ cần set tham số cbs
là một list các callback cần thiết:
learn = synth_learner(cbs=[CountParamCallback(), ...])
#hoặc
learn = Learner(dls, model, cbs=[...])
Về cơ bản thì hệ thống callback trong fastai cho phép ta truy cập và sửa đổi tất cả mọi thứ trong quá trình huấn luyện mô hình (dữ liệu , optimizer, learning rate, ...), một trong những tác giả của thư viện đã gọi đây là "infinitely customizable training loop".
Các sự kiện trong fastai training loop
Mọi điều chỉnh đối với training loop đều được thực hiện thông qua Callback với các method có tên tương ứng với các sự kiện trong training loop. Ta cũng có thể dễ dàng kết hợp các kỹ thuật khác nhau được định nghĩa trong các callback khác nhau. Một callback có thể implement các sự kiện sau:
after_create
: gọi sau khi khởi tạoLearner
before_fit
: gọi trước khi bắt đầu training hoặc inferencebefore_epoch
: gọi ở đầu mỗi epoch, hữu ích khi cần reset trạng thái nào đó sau mỗi epochbefore_train
: gọi trước khi bắt đầu quá trình train của mỗi epochbefore_batch
: gọi ở đầu mỗi batch, sau khi lấy batch ra từ data loader. Có thể dùng để thay đổi input trước khi đi qua mô hình (data augmentation chẳng hạn).after_pred
: gọi sau khi gọi phương thứcforward
của mô hình. Có thể dùng để thay đổi output trước khi cho qua hàm loss (reshape, ...)after_loss
: gọi sau khi tính loss nhưng trước khi gọibackward
. Có thể dùng để thêm regularization cho loss (L2, L1, ...)before_backward
: gọi sau khi tính lossafter_backward
: gọi sau khi gọibackward
của hàm loss, nhưng trước khi update tham số mô hình.before_step
: tương tựafter_backward
nhưng trong docs khuyến khích dùng cái này thay vìafter_backward
. Có thể dùng để cập nhật lại gradient (gradient clipping, ...)after_step
: gọi sau khi cập nhật tham số mô hình (opimizer.step()
) và trước khi gọioptimizer.zero_grad()
after_batch
: gọi ở cuối mỗi batchafter_train
: gọi ở cuối mỗi epochbefore_validate
: gọi ở đầu quá trình validation của mỗi epochafter_validate
: gọi ở cuối quá trình validation của mỗi epochafter_epoch
: gọi ở cuối mỗi epochafter_fit
: gọi ở cuối quá trình training
Các attribute có thể truy cập trong callback
Khi viết callback, ta có thể truy cập một số attribute của class Learner
. Sử dụng bằng cách viết: self.learn.attr
( thay attr
bằng attribute tương ứng.
model
: mô hình hiện dùng để train hoặc validatedls
: object DataLoadersloss_func
: hàm loss truyền vào khi khởi tạo Learneropt
: object optimizercbs
: danh sách tất cả callbackdl
: dataloader hiện đang sử dụng (train hoặc val dataloader)x/xb
: input của mô hình lấy từdl
. Chỉ có thể assign giá trị cho attributexb
y
/yb: output của mô hình lấy từdl
. Chỉ có thể assign giá trị cho attributeyb
pred
: prediction củamodel
loss_grad
: giá trị hàm lossloss
: bản copy củaloss_grad
. Dùng cho loggingn_epoch
: số epochn_iter
: độ dài củadl
epoch
: epoch hiện tại (từ 0 - n_epoch-1)iter
: index hiện tại củadl
(từ 0 - n_iter - 1)
Một số callback có sẵn trong fastai
Gradient clipping
from fastai.test_utils import synth_learner
from fastai.callback.training import GradientClip
learn = synth_learner()
learn.fit(3, cbs=[GradientClip])
Mix Precision training
Chắc sẽ hữu ích cho bạn nào máy ít VRAM
from fastai.test_utils import synth_learner
from fastai.callback.fp16 import MixedPrecision
fp16 = MixedPrecision()
learn = synth_learner(lr=1.1)
learn.fit(3, cbs=[fp16])
Một cách khác để dùng mix precision trong fastai
learn = synth_learner()
learn.to_fp16()
Weights & Biases
W&B là một công cụ dùng để visualize và theo dõi các thí nghiệm học máy. Chỉ cần thêm lệnh khởi tạo callback bạn có thể log tất tần tật về mô hình của bạn lên W&B
from fastai.callback.wandb import *
wb = WandbCallback(
log_preds=True,
log_model=True,
log_dataset=True
)
# To log only during one training phase
learn.fit(..., cbs=[wb])
# To log continuously for all training phases
learn = learner(..., cbs=[wb])
Ngoài ra, còn có rất nhiều callback hữu ích khác. Các bạn có thể tham khảo trong documentation: https://docs.fast.ai/
Bonus: GAN training loop
Phần này không có sẵn trong fastai nhưng mình thấy khá hay nên giới thiệu ở đây. Hệ thống callback trong fastai khá là linh hoạt nên ta có thể tận dụng nó để implement các training loop phức tạp hơn một chút, điển hình là GAN. Trong repo https://github.com/tmabraham/UPIT, tác giả đã viết một callback chỉ để train phần discriminator của mạng GAN (cụ thể là CycleGAN) và để training loop bình thường lo việc train phần generator.
class CycleGANTrainer(Callback):
"""`Learner` Callback for training a CycleGAN model."""
run_before = Recorder
def before_train(self, **kwargs):
self.crit = self.learn.loss_func.crit
if not getattr(self,'opt_G',None):
self.opt_G = self.learn.opt_func(self.learn.splitter(nn.Sequential(*flatten_model(self.G_A), *flatten_model(self.G_B))), self.learn.lr)
else:
self.opt_G.hypers = self.learn.opt.hypers
if not getattr(self, 'opt_D',None):
self.opt_D = self.learn.opt_func(self.learn.splitter(nn.Sequential(*flatten_model(self.D_A), *flatten_model(self.D_B))), self.learn.lr)
else:
self.opt_D.hypers = self.learn.opt.hypers
self.learn.opt = self.opt_G
def before_batch(self, **kwargs):
self._set_trainable()
self._training = self.learn.model.training
self.learn.xb = (self.learn.xb[0],self.learn.yb[0]),
self.learn.loss_func.set_input(*self.learn.xb)
def after_step(self):
self.opt_D.hypers = self.learn.opt.hypers
def after_batch(self, **kwargs):
"Discriminator training loop"
if self._training:
# Obtain images
fake_A, fake_B = TensorBase(self.learn.pred[0].detach()), TensorBase(self.learn.pred[1].detach())
(real_A, real_B), =self.learn.xb
real_A, real_B = TensorBase(real_A), TensorBase(real_B)
self._set_trainable(disc=True)
# D_A loss calc. and backpropagation
loss_D_A = 0.5 * (self.crit(self.D_A(real_A), 1) + self.crit(self.D_A(fake_A), 0))
loss_D_A.backward()
self.learn.loss_func.D_A_loss = loss_D_A.detach().cpu()
# D_B loss calc. and backpropagation
loss_D_B = 0.5 * (self.crit(self.D_B(real_B), 1) + self.crit(self.D_B(fake_B), 0))
loss_D_B.backward()
self.learn.loss_func.D_B_loss = loss_D_B.detach().cpu()
# Optimizer stepping (update D_A and D_B)
self.opt_D.step()
self.opt_D.zero_grad()
self._set_trainable()
Reference
All rights reserved