Như chúng ta đã biết thì các phương pháp semi-supervised learning đã góp công không nhỏ trong việc cải thiện hơn nữa các model state-of-the-art trong rất nhiều computer vision tasks như image classification, object detection, và semantic segmentation. Các phương pháp như Pseudo Labels hay self-training chắc cũng khá quen thuộc với những người từng làm về semi-supervised learning. Hôm nay mình sẽ giới thiệu đến các bạn một phiên bản nâng cấp của Pseudo Labels, giúp đưa semi-supervised learning lên đỉnh của Imagenet. Nói sơ qua thì cách hoạt động của Pseudo Label khá đơn giản: chúng ta cần 2 model, một gọi là teacher và một là student. Đầu tiên, ta cần huấn luyện teacher model với dữ liệu có nhãn, sau đó sử dụng model teacher để predict ra nhãn giả - pseudo label của dữ liệu chưa có nhãn, từ bây giờ ta sẽ gọi dữ liệu có nhãn chuẩn là labeled data và dữ liệu được sinh nhãn giả là pseudo data cho ngắn gọn nhé. Pseudo data sẽ được kết hợp với labeled data để huấn luyện cho model student, nhờ có sự bổ sung này mà student có thể mang lại kết quả tốt hơn so với teacher.
Mặc dù phương pháp kể trên khá hiệu quả nhưng vẫn tồn tại một nhược điểm lớn là sẽ có những pseudo label mà teacher sinh ra không chính xác, kéo theo student cũng sẽ học từ dữ liệu sai lệch đấy và kết quả là performance của student bị giảm sút. Điểm yếu này được gọi là confirmation bias trong pseudo-labeling.
Paper mà hôm nay mình muốn thảo luận với mọi người là một phiên bản nâng cấp xịn xò của Pseudo Label - Meta Pseudo Labels. Những gì mà Meta Pseudo Label muốn làm là cải thiện nhược điểm kể trên của teacher thông qua việc quan sát pseudo label mà nó sinh ra sẽ có ảnh hưởng gì đến student, nghĩa là nó sẽ nhận lại feedback của student sau khi học từ pseudo label và tự chỉnh sửa lại bản thân để cho ra những phiên bản pseudo label tốt hơn. Vì phần chứng minh của paper khá nhiều toán nên mình sẽ cố gắng đi chậm, nếu các bạn phát hiện ra mình sai thì đừng ngần ngại góp ý nhé 😁
2. Meta Pseudo Labels
Trái: Pseudo Labels, teacher được cố định sau khi train với labeled data sau đó sinh pseudo label cho student học. Phải: Meta Pseudo Labels, teacher được train song song với student.
Ký hiệu:
T,S : mô hình teacher và student
θt,θs : tham số của teacher và student
θsPL: tham số của mô hình student được train với pseudo label tạo bởi teacher
∣T∣,∣S∣ : dimension của student và teacher
(xL,yL) : labeled data gồm image và label
xU : unlabeled data chỉ gồm image
T(xU,θT): soft prediction của teacher với unlabeled data
S(xU,θS),S(xL,θS): soft prediction của student với xU và xL
CE(q,p): cross-entropy loss giữa 2 phân phối q, p với q là label
Ex[f] : giá trị kỳ vọng của phương trình f với biến ngẫu nhiên x.
∇ : gradient
2.1 Revisit Pseudo Label
Trước khi nói về Meta Pseudo Label, ta sẽ mở đầu với việc ôn lại 1 chút về Pseudo Label nhé. Như đã giới thiệu ở phần mở đầu, Pseudo Label huấn luyện student model với unlabeled data để tối thiểu hóa hàm cross-entropy:
Lu(θT,θS) : loss của student khi train với pseudo label tạo bởi teacher trên unlabeled data.
Giả sử ta đã có model teacher được train tốt với tập labeled, mục tiêu của pseudo label là tạo ra θsPL tối ưu trên tập labeled data:
Exl,yl[CE(yl,S(xl;θSPL))]:=Ll(θSPL)(2)
2.2 Solution for confirmation-bias
Với Pseudo Labels, muốn cho student θsPL tối ưu thì bắt buộc phải phụ thuộc vào teacher θT thông qua pseudo label T(xU,θT). Để miêu tả sự phụ thuộc này ta sẽ dùng ký hiệu θSPL(θT). Như vậy hàm loss của student trên labeled data có thể được viết gọn lại như sau: Ll(θSPL(θT)) và tất nhiên nhiệm vụ của hàm này sẽ là tối ưu 2 tham số θSPL và θT. Từ đó ta có thể tối ưu hóa Ll theo θT như sau:
θTmin trong đoˊLl(θSPL(θT)),θSPL(θT)=θSargminLu(θT,θS)(3)
Theo như công thức trên thì ta có thể tối ưu hóa teacher thông qua biểu hiện của student, từ đó pseudo label dùng để train student cũng sẽ dần được cải thiện. Tuy nhiên do mối phụ thuộc θSPL(θT) và θT là vô cùng phức tạp nên việc tính gradient ∇θT(θSPL(θT)) nếu muốn diễn ra thì bắt buộc phải thay đổi toàn bộ quá trình training của student.
Để đơn giản hóa việc này, ta sẽ áp dụng ý tưởng của meta-learning : xấp xỉ θSargmin bằng cách update từng bước gradient của θT:
θSPL(θT)≈θS−ηS⋅∇θSLu(θT,θS)(4)
với ηS laˋ learning rate của student
Thay biểu thức trên vào phương trình (3) ta sẽ có hàm tối ưu của teacher trong Meta Pseudo Labels:
θTminLl(θS−ηS⋅∇θSLu(θT,θS))(5)
Về cơ bản thì quá trình training của student vẫn phụ thuộc vào phương trình (1) của Pseudo Labels, ngoại trừ việc tham số của teacher sẽ không còn cố định mà thay đổi dần dựa vào student. Từ đó chúng ta sẽ rút ra được quá trình tối ưu hóa song song teacher - student:
Student: sử dụng pseudo label từ teacher - T(xU,θT) để tối ưu hóa hàm mục tiêu với SGD:
θS′=θS−ηS⋅∇θSLu(θT,θS)(6)
Teacher: sử dụng labeled data kết hợp với feedback của student để cải thiện pseudo label và tối ưu hóa hàm mục tiêu với SGD:
θT′=θT−ηT⋅∇θTLl(θS−∇θSLu(θT,θS))(7)
2.3 Teacher's auxiliary losses
Các tác giả thấy rằng Meta Pseudo Labels tự thân nó đã khá tốt rồi, tuy nhiên nếu thêm một task phụ vào quá trình training của teacher thì performance sẽ còn tốt hơn. Do đó khi train teacher với labeled data, ta có thể thêm một auxiliary task dạng self-supervised để tận dụng unlabeled data giúp tăng độ generalization của model teacher. Auxiliary task này được thực hiện theo paper UDA (Unsupervised Data Augmentation for Consistency Training) với tổng quan như sau:
Ta có thể mô tả 1 cách đơn giản về UDA như sau:
B1 : Với labeled data (x,y), ta để model predict label y^=Pθ(y∣x) và tính supervised loss Lsup=CE(y,y^)
B2 : Với unlabeled data (x), ta tiến hành augment x để có x^, sau đó để model predict label cho x và x^ : Pθ(y∣x) và Pθ(y∣x^) rồi tính unsupervised loss với 2 label trên : Lunsup(Pθ(y∣x),Pθ(y∣x^))
B3: tính loss tổng : Lfinal=Lsup+α⋅Lunsup và optimize model dựa trên loss tổng
2.4 Derivation of the Teacher’s Update Rule
Nhắc lại một số ký hiệu toán học:
cho hàm khả vi f:Rm→Rn,x↦f(x),x∈Rm, ta sẽ tìm được ma trận jacobi của f dựa trên đạo hàm từng phần hàm f với x:
Giờ ta sẽ vào món chính: tính gradient cho quá trình cập nhật teacher. Giả sử với một batch unlabeled data xu, teacher sẽ sinh pseudo label y^u∼T(xu;θT), sau đó student sử dụng (xu,y^u) để cập nhật tham số θS của nó. Chúng ta kỳ vọng tham số mới của student sẽ có dạng Eyu∼T(xu;θT)[θS−ηS∇θSCE(yu,S(xu;θS))]. Ta sẽ cập nhật tham số của teacher trên tập labeled data thông qua cross-entropy của sự thay đổi giữa tham số của student cũ và student mới:
Xét phương trình (8), phần A chính là quá trình train student θS′ với labeled data sau khi đã train θS với pseudo data để có θS′, phần này hoàn toàn có thể tính thông qua backprop thông thường.
Chú ý : với pt (11), jacobian của CE(yu,S(xu;θS)) có dim=1×∣S∣ cần được chuyển vị để khớp với dimθS=∣S∣×1.
Vậy thì tại sao θS có dim=∣S∣×1 và ∇θSCE có dim=1×∣S∣ ? Ở đây ∣S∣ chính là số lượng tham số có trong student. Với θS là tham số của student nên dĩ nhiên dim của nó sẽ là ∣S∣ và mỗi tham số trong student là duy nhất nên dimθS=∣S∣×1. Còn ∇θSCE là gradient của hàm loss với biến là θS và chỉ có 1 θS được xét đến, trong θS có ∣S∣ lượng tham số nên dim∇θSCE=1×∣S∣.
Xét phương trình (11), để đơn giản thì ta đặt gS là ký hiệu gradient của student:
Bây giờ chúng ta sẽ đi giải quyết "củ khoai" này nhé, theo như paper thì đúng ra sẽ dùng REINFORCE algorithm, nhưng mà mình có đọc qua paper gốc được viết năm 1992 thì thấy khó nuốt quá nên có thử tự giải theo cách "dễ nhai" hơn. Mọi người xem qua và cho ý kiến về cách giải của mình nhé.
Với phương trình (13) thì ta sẽ đi giải quyết đạo hàm của hàm kỳ vọng Eyu[gS(yu)]. Một cách tổng quát thì kỳ vọng của hàm f(x) với biến ngẫu nhiên rời rạc x sẽ có dạng:
E[f(x)]=x∑P(x)f(x)
Áp dụng công thức trên vào phương trình (13) với Eyu[gS(yu)]:
Cuối cùng, ta sẽ sử dụng phép xấp xỉ Monte-Carlo cho mọi biểu thức trong pt(18) với y^u đã tính từ trước. Cụ thể hơn thì ta sẽ tính xấp xỉ θˉS′ với θS bằng cách cập nhật tham số student với (xu,yu): θS′=θS−ηS⋅∇θSCE(y^u,S(xu;θs)). Đồng thời ước lượng E cũng với y^u. Với kết quả ước lượng vừa rồi, ta sẽ tính được gradient của ∇θTLu(θT,θS).
Pt(18) là dạng tổng quát cho 1 batch dữ liệu. Để tường minh hơn ta sẽ lấy 1 mẫu ngẫu nhiên trong batch để tính gradient:
Đến đây là hết phần diễn giải cách cập nhật của teacher dựa trên gradient của student rồi nhỉ, các bạn thấy scalar h bên trên chứ ? Đấy chính là thứ mà chúng ta mong muốn từ đầu đến giờ : feedback của student để teacher cải thiện performance. Khi các bạn xem phần pseudo code với UDA bên dưới thì sẽ thấy một h tương tự:
Tuy nhiên, khi xem code của Meta Pseudo Label thì các bạn sẽ thấy h được tính như thế này:
. Biến dot_product chính là công thức tính h lằng nhằng phía trên đấy :v
Nếu viết lại theo công thức toán học dựa trên đoạn code thì h sẽ được tính như sau: h=L(θS)−L(θS′). Vậy tại sao từ h dài dòng lại có thể biến đổi thành phép trừ 2 hàm loss đơn giản như vậy? Thử chứng minh 1 chút nhé:
θS′=θS−ηS∇θSCE(y^u,S(xu;θS))
Đặt η=ηS∇θSCE(y^u,S(xu;θS)) ta coˊ:
θS′=θS−η
Áp dụng công thức xấp xỉ taylor: f(x+h)=f(x)+hf′(x)
Dưới đây là toàn bộ quá trình train teacher với UDA và feedback từ student:
Và kết quả SoTA của MPL với EfficientNet-L2:
Lời kết
Bài viết của mình đến đây là đã hoàn thành mục đích ban đầu: cố gắng thử thách bản thân với một paper kinh điển do các idol người Việt viết và mang paper này đến với mọi người một cách dễ hiểu nhất. Nếu có thắc mắc thì các bạn có thể comment bên dưới, mình sẽ cố gắng trả lời trong tầm kiến thức của bản thân. Hoặc nếu các bạn phát hiện lỗi sai thì cứ thẳng thắn góp ý nhé. Cảm ơn các bạn đã đọc bài.