Sinh dữ liệu với mô hình diffusion và mô hình dạng SDE tổng quát
Báo cáo
Thêm vào series của tôi
Trong bài viết này, mình sẽ giới thiệu về mô hình diffusion, một mô hình sinh với sự đột phá gần đây, cùng với mô hình score matching đã vượt qua GAN trong việc sinh dữ liệu. Hai mô hình này có thể xem như trường hợp đặc biệt của phương trình vi phân ngẫu nhiên, và được tổng quát thành mô hình dạng phương trình vi phân ngẫu nhiên (Stochastic differential equation - SDE), đưa ra một góc nhìn mới cũng như việc kết hợp hai loại mô hình này. Mô hình diffusion cũng như mô hình dạng SDE khi sinh dữ liệu không điều kiện thậm chí còn cho kết quả tốt hơn GAN khi sinh dữ liệu với nhãn cho trước.
Do nội dung khá dài nên phần cài đặt mình sẽ để sang bài khác nếu có thời gian, các bạn có thể xem notebook tutorial của tác giả tại đây. Một số chứng minh chi tiết mình sẽ để ở cuối để tránh đi xa khỏi nội dung chính, bạn đọc quan tâm có thể đọc thêm.
Mô hình diffusion
Ý tưởng của phương pháp này là biến đổi phân bố dữ liệu thành một phân bố có thể lấy mẫu được. Việc sinh dữ liệu sẽ bắt đầu từ phân bố này, sau đó biến đổi ngược về phân bố ban đầu. Mô hình cần học ở đây sẽ là phép biến đổi ngược đó. Quá trình biến đổi này được mô tả bằng một chuỗi các phân bố, cụ thể hơn chúng ta sẽ sử dụng quá trình ngẫu nhiên để mô tả chuỗi này.
Quá trình ngẫu nhiên là một họ các biến ngẫu nhiên {Xt}t∈T từ cùng một không gian xác suất sang cùng một không gian trạng thái. Ở đây ta chỉ quan tâm đến tập chỉ số T có thứ tự, được hiểu như trục thời gian, ví dụ T=R+ hoặc T=Z+.
Quá trình ngẫu nhiên được gọi là quá trình Markov nếu nó thỏa mãn tính chất Markov. Một cách trực quan, xác suất của trạng thái tại tương lai khi biết trạng thái hiện tại không phụ thuộc vào quá khứ. Đối với chuỗi Markov, tính chất này có thể được viết thành
P(Xn+m=i∣X1,…,Xn)=P(Xn+m=i∣Xn)
Quá trình thuận
Tại từng mốc thời gian, dữ liệu sẽ được biến đổi bằng cách thêm lần lượt nhiễu, quá trình này được gọi là quá trình thuận. Để cho đơn giản, xác suất chuyển từ thời điểm t sang thời điểm s sẽ được kí hiệu là q(xs∣xt).Từ tính chất Markov, xác suất liên hợp được phân tích thành
q(x0…xt)=q(x0)i=1∏Tq(xi∣xi−1)
Trong quá trình thuận, xác suất chuyển sẽ được xác định trước là một phân bố nào đó. Ở đây tham số của phân bố này sẽ được cố định trước, do đó không tham gia vào quá trình huấn luyện, tuy nhiên trong trường hợp khác chúng cũng có thể xem như tham số học được.
Đối với dữ liệu ảnh, chúng ta có thể xem chúng như dữ liệu liên tục, và xác suất chuyển q(xt∣xt−1) được mô hình bởi N(xt;1−βtxt−1,βtI).
Xác suất khi biết trạng thái x0 cũng là phân bố Gaussian, đạt được nhờ tính chất Markov. Đặt αt=1−βt,αtˉ=∏i=1tαi, ta có q(xt∣x0)=N(xt;αtˉx0,(1−αtˉ)I), chứng minh công thức có thể xem ở cuối bài.
Phân bố tại trạng thái xT được xem như prior, sao cho có thể lấy mẫu được. Nhờ vào tính chất của xác suất điều kiện trên, với βt phù hợp, q(xT)≈N(0,I).
Quá trình nghịch
Khi đã có phân bố prior p(xT) dễ lấy mẫu rồi, việc sinh dữ liệu sẽ bắt đầu lấy mẫu từ phân bố này, sau đó biến đổi ngược trở về phân bố p(x0) ban đầu. Việc làm này cũng có thể mô tả bởi một quá trình ngẫu nhiên, gọi là quá trình nghịch, chúng ta sẽ cần một mô hình học quá trình ngẫu nhiên này. Quá trình này có thể xem như một chuỗi Markov với chiều ngược lại, do đó xác suất liên hợp được phân tích thành
p(x0…xT)=p(xT)i=1∏Tp(xi−1∣xi)
Mục tiêu lúc này là tìm xác suất chuyển p(xt−1∣xt) của chuỗi Markov này. Ta sẽ mô hình xác suất này bởi phân bố Gaussian, có dạng N(xt−1;μθ(xt,t),Σθ(xt,t)). Khi mô hình được quá trình nghịch rồi, ở bước sinh mẫu, dữ liệu từ phân bố của xT sẽ được biến đổi thêm lần lượt nhiễu dựa trên xác suất chuyển này.
Huấn luyện
Mục tiêu của quá trình huấn luyện là cực đại likelihood của phân bố dữ liệu của mô hình sinh
Thành phần logq(x1∣x0) là xác suất chuyển của quá trình thuận, do đó không có tham số và có thể loại bỏ trong quá trình huấn luyện. Mục tiêu của chúng ta là cực đại chặn dưới của log likelihood, tương đương với việc cực tiểu hàm mục tiêu sau
Các xác suất ở trên đều là phân bố Gaussian, do đó khoảng cách KL có thể tính từ kì vọng và phương sai. Đối với posterior, q(xt−1∣xt,x0) sẽ là phân bố Gaussian N(xt−1;1−αˉtαˉt−1βtx0+1−αˉtαt(1−αˉt−1)xt,β~tI), với β~t=1−αˉt1−αˉt−1βt, chứng minh công thức có thể xem ở cuối bài.
Mô hình denoise diffusion
Để cho đơn giản, Σθ(xt,t) sẽ được đặt là σt2I, với σt được chọn trước, do đó không tham gia vào quá trình huấn luyện. Tác giả đưa ra hai lựa chọn σt2=βt và σt2=β~t, tương đương với việc entropy H(q(xt−1∣xt)) lớn nhất và nhỏ nhất (xem thêm ở phần cuối), qua thực nghiệm hai cách chọn này cho kết quả tương đương.
Kí hiệu μ~t(xt,x0) là kì vọng của q(xt−1∣xt,x0), với khoảng cách KL giữa hai phân bố Gaussian ta có
Hàm μθ(xt,t) dự đoán kì vọng μ~(xt,x0)=1−αˉtαˉt−1βtx0+1−αˉtαt(1−αˉt−1)xt của q(xt−1∣xt,x0) khi biết xt và t. Điều này tương đương với việc dự đoán x0 khi biết xt. Tuy nhiên, từ thực nghiệm, tác giả thấy việc tham số như vậy không đưa ra kết quả tốt. Từ xác suất chuyển của quá trình thuận, chúng ta có xt(x0,ϵ)=αtˉx0+1−αtˉϵ, trong đó ϵ∼N(0,I). Nói cách khác, trong quá trình thuận, x0 có thể được tham số bởi xt(x0,ϵ) và một biến ngẫu nhiên độc lập ϵ thông qua x0=αˉt1(xt−1−αˉtϵ). Như vậy, thay vì đoán x0 khi biết xt, chúng ta có thể xây dựng mô hình ϵθ(xt,t) đoán nhiễu ϵ khi biết xt (đây là lí do cho từ denoise trong tên gọi).
Từ cách tham số này, chúng ta có thể thay vào μ~(xt,x0) để được
Tương tự như vậy, μθ(xt,t) lúc này sẽ được tham số như sau
μθ(xt,t)=αt1(xt−1−αˉtβtϵθ(xt,t))
Nhắc lại, trong quá trình huấn luyện (quá trình thuận), xt có thể tính từ x0 thông qua xt(x0,ϵ)=αtˉx0+1−αtˉϵ. Lúc này, hàm mục tiêu sẽ trở thành
và hàm mục tiêu cho tại toàn bộ vị trí sẽ là L=Et[Lt−1] với t tuân theo phân bố đều U{1,T}.
Để cho đơn giản, chúng ta có thể tối ưu với phiên bản không trọng số của hàm mục tiêu bên trên
L=Ex0,ϵ,t[∣∣ϵ−ϵθ(αtˉx0+1−αtˉϵ,t)∣∣2]
Lấy mẫu
Thay vì mô hình trực tiếp kì vọng của p(xt−1∣xt), chúng ta đã mô hình nhiễu ϵθ(xt,t). Do đó, ở bước lấy mẫu, giả sử đã biết xt, chúng ta sẽ tính lại kì vọng này qua công thức
μ~(xt,t)=αt1(xt−1−αˉtβtϵθ(xt,t))
Lúc này xt−1 sẽ được tính bởi
xt−1=μ~(xt,t)+σtz,z∼N(0,I)
Bắt đầu từ xT∼N(0,I), chúng ta thực hiện tuần tự T bước đến khi tìm được x0.
Mô hình SDE tổng quát
Liên hệ giữa mô hình diffusion và score matching
Hàm mục tiêu của mô hình denoise diffusion có thể xem như denoise score matching. Với phân bố q(xt∣x0)=N(xt;αtˉx0,(1−αtˉ)I), score ∇logq(xt∣x0) của phân bố này sẽ là 1−αˉtαˉtx0−xt. Chú ý −1−αˉtαˉtx0−xt=ϵ, nếu ta thay biến ngẫu nhiên này cho ϵ trong thành phần Lt−1 của hàm mục tiêu không trọng số trong mô hình denoise diffusion, ta có
Nhắc lại, với mô hình Noise conditional score network(NCSN), mô hình sẽ tối ưu khoảng cách Fisher giữa phân bố của dữ liệu khi thêm nhiễu và phân bố của mô hình trên nhiều mức độ nhiễu. Ta có thể thấy hàm mục tiêu của mô hình denoise diffusion chính là hàm mục tiêu của NCSN với trọng số (1−αˉt) khi sử dụng denoise score matching. Tương tự như NCSN, trọng số (1−αˉt) có tính chất (1−αˉt)∝1/E[∣∣∇logq(xt∣x0)∣∣2]. Cách nhìn này cho thấy sự liên hệ giữa phương pháp score matching và mô hình diffusion, đó là thay đổi phân bố dữ liệu bằng một họ các nhiễu, và học mô hình khử nhiễu lần lượt. Từ đây, ta có thể tổng quát cả hai phương pháp này, bằng cách mô hình họ các nhiễu bởi quá trình ngẫu nhiên liên tục, biểu diễn bởi một phương trình vi phân ngẫu nhiên (SDE).
Mô hình với SDE
Cụ thể hơn, với phân bố dữ liệu p0 ban đầu, ta mong muốn biến đổi nó thành một phân bố đơn giản pT, theo nghĩa có thể lấy mẫu một cách dễ dàng, ví dụ như N(0,I) trong mô hình diffusion. Nói cách khác, ta cần một quá trình ngẫu nhiên Xt với t∈[0,T] sao cho p(x0)=p0,p(xT)=pT. Quá trình ngẫu nhiên này có thể mô tả bởi phương trình vi phân ngẫu nhiên Itô (từ bây giờ khi nhắc đến SDE, chúng ta sẽ hiểu đó là Itô SDE)
dxt=f(x,t)dt+g(t)dw
trong đó f(x,t):Rd×R+↦Rd, g(t):R+↦R, dw kí hiệu một cách hình thức vi phân của chuyển động Brown. Một cách trực quan, dw=N(0,ΔtI) với Δt→0. Để cho đơn giản, chúng ta chỉ xét g(t) có dạng trên, tuy nhiên tất cả kết quả bên dưới đều có thể mở rộng cho hàm g(t) trả về ma trận.
Quá trình thuận và nghịch lúc này sẽ là quá trình liên tục. Để phân bố không bị thay đổi quá nhiều theo thời gian, chúng ta sẽ mô hình bởi quá trình diffusion mô tả bởi phương trình vi phân ngẫu nhiên như trên. Phương trình này có thể hiểu như sau: Theo thời gian, trạng thái của biến ngẫu nhiên x bị thay đổi dần dần theo hàm f, tuy nhiên nếu chỉ như vậy, quá trình biến đổi này sẽ là tất định, do đó chúng ta sẽ thêm thành phần ngẫu nhiên dw vào phép biến đổi này. Chuyển động Brown có thể hiểu như một quá trình ngẫu nhiên mà tại mỗi thời điểm, xác suất để chuyển sang các trạng thái có vị trí gần với trạng thái hiện tại cao hơn. Một ví dụ minh họa là một người bước đi một cách ngẫu nhiên trên mặt phẳng, mỗi bước do đó không thể có độ dài quá lớn. Một ví dụ khác từ góc nhìn phân bố: Xác suất trạng thái biến ngẫu nhiên x tại mỗi thời điểm có thể xem như nhiệt độ tại vị trí (trạng thái) đó, theo thời gian nhiệt độ sẽ tản ra từ từ sang các vị trí lân cận, cuối cùng hội tụ đến một phân bố nào đó.
SDE của mô hình diffusion
Nhắc lại quá trình thuận của mô hình diffusion có thể được mô tả bởi quá trình ngẫu nhiên {xt}t=0T. Giả sử chúng ta dùng σt2=βt, chuỗi Markov có dạng
xt=1−βtxt−1+βtzt−1,z∼N(0,1)
Quá trình ngẫu nhiên này có thể xem như rời rạc của một quá trình ngẫu nhiên liên tục, chúng ta sẽ tìm quá trình này bằng cách cho T→∞. Đặt βˉt=Tβt, chuỗi này sẽ tiến về một hàm β(t):[0,1]↦R, β(Tt)=βˉt. Tương tự quá trình ngẫu nhiên của xi và zi cũng tiến tới quá trình ngẫu nhiên liên tục x(Tt)=xt,z(Tt)=zt. Đặt Δt=t/T, dùng khai triển Taylor bậc 1, phương trình trên có thể viết lại thành
Nhắc lại, mô hình NCSN thêm lần lượt nhiễu với phương sai {σt}t=1N vào phân bố dữ liệu. Quá trình này có thể viết lại thành
xt=xt−1−σt2−σt−12z,z∼N(0,I)
với σ0=0.
Lập luận tương tự như trên, ta có thể tính giới hạn khi N→∞
x(t+Δt)=x(t)+σ2(t+Δt)−σ2(t)z(t)≈dtdσ2(t)Δtz(t)
sử dụng khai triển Taylor bậc 1 của σ2(t). Khi Δt→0, chuỗi xt hội tụ tới quá trình ngẫu nhiên mô tả bởi
dxt=dtdσ2(t)dw
Lấy mẫu
Việc lấy mẫu tương đương với đảo chiều thời gian của quá trình ngẫu nhiên. Quá trình nghịch này được mô tả bởi SDE sau
dxt=(f(x,t)−g(t)2∇xtlogpt(xt))dt+g(t)dwˉ
ở đây wˉ là chuyển động Brown theo chiều ngược lại, từ 1 về 0.
Nếu biết được score của pt(x), chúng ta có thể mô phỏng lại quá trình ngược này. Bắt đầu từ xT∼pT, từ phương trình trên, chúng ta sẽ biến đổi xT thành x0 tuân theo phân bố p0 của dữ liệu. Như vậy, mục tiêu của chúng ta là xây dựng mô hình sθ(x(t),t) xấp xỉ ∇xtlogpt(xt)).
Giải SDE
Quá trình lấy mẫu được thực hiện bằng cách giải phương trình SDE nghịch. Tương tự như khi rời rạc hóa quá trình thuận, chúng ta có thể giải bằng cách rời rạc hóa quá trình nghịch
Quay lại với cách cập nhật của mô hình denoise diffusion, giả sử ta dùng σt2=βt
xt−1=αt1(xt−1−αˉtβtϵθ(xt,t))+βtz
Đặt sθ(xt,t)=−1−αˉtϵθ(xt,t), ta có thể biến đổi như sau
xt−1=1−βt1(xt+βts(xt,t))+βtz≈(1+21βt)(xt+βts(xt,t))+βtzkhai triển Taylor=(1+21βt)xt+βts(xt,t)+21βt2s(xt,t)+βtz≈xt+21βtxt+βts(xt,t)+βtz
Quá trình nghịch của SDE ứng với mô hình diffusion là
dxt=(−21β(t)xt−βt∇xlogpt(xt))dt+β(t)dw
Ta có thể thấy thuật toán lấy mẫu của mô hình denoise diffusion gần giống với việc giải quá trình nghịch thông qua rời rạc hóa.
Lấy mẫu với Predictor-Corrector
Ở phần trước, ta đã biết quá trình lấy mẫu có thể thực hiện bằng việc giải phương trình SDE nghịch, và thuật toán lấy mẫu của mô hình diffusion thuộc loại này. Mặt khác, ta đang mô hình score của pt(xt), do đó ta cũng có thể lấy mẫu với (annealed) Langevin dynamics.
Để có thể sinh dữ liệu tốt hơn, chúng ta có thể kết hợp hai phương pháp này. Lấy mẫu thông qua giải SDE sẽ được xem như thuật toán chính, gọi là Predictor. Ở bước thứ i trong Predictor, sau khi cập nhật xT−i qua xT−i+1, chúng ta sẽ thực hiện Langevin dynamics M lần với s(xT−i,T−i)
xT−i=xT−i+ϵis(xT−i,T−i)+2ϵiz,z∼N(0,I)
Từ góc nhìn này, cách sinh dữ liệu của NCSN có thể xem như Predictor là hàm đồng nhất, Corrector là Langevin dynamics, cách sinh dữ liệu của mô hình denoise diffusion có thể xem như Predictor là giải quá trình nghịch, Corrector là hàm đồng nhất.
Huấn luyện
Tương tự như hàm mục tiêu của NCSN cũng như mô hình denoise diffusion, hàm mục tiêu của mô hình SDE sẽ có dạng score matching trên tất cả mức độ nhiễu. Điểm khác biệt là biến thời gian t lúc này là biến ngẫu nhiên liên tục tuân theo phân bố đều U[0,1]
Ở đây λ(t) là hàm trọng số, có thể chọn giống như NCSN và mô hình denoise diffusion là λ(t)∝1/E[∣∣∇logq(xt∣x0)∣∣2].
Việc tính hàm mất mát yêu cầu score của phân bố chuyển trong quá trình thuận. Đối với trường hợp SDE tổng quát, ta cần giải phương trình Kolmogorov tiến để tìm phân bố này. Khi f(x,t)=a(t)x+b(t), phân bố chuyển là phân bố Gaussian, do đó chỉ cần biết kì vọng và phương sai để tính score. Kì vọng mt và ma trận hiệp phương sai Pt sẽ thỏa mãn phương trình vi phân sau
dtdmt=a(t)mt+b(t)
dtdPt=2a(t)Pt+g(t)2
Để tránh việc phải tính phân bố chuyển, chúng ta có thể dùng phương pháp score matching khác, ví dụ như sliced score matching, với hàm mục tiêu
với Js(.,t)(xt) là ma trận Jacobian của s(xt,t), v⊺Js(.,t)(xt)v tính bởi v⊺∇(v⊺s(xt,t)).
Kết luận
Trong bài này, mình đã giới thiệu mô hình diffusion và mô hình dạng SDE tổng quát mà trong đó score matching và mô hình diffusion là trường hợp đặc biệt. Cách tiếp cận này hiện đã cho kết quả tốt nhất hiện tại cho mô hình sinh.
Tuy nhiên cách tiếp cận này có các nhược điểm sau: Các trạng thái có cùng số chiều, do đó việc mô hình quá trình nghịch cần đảm bảo điều đó chứ không thể thay đổi số chiều dữ liệu. Việc lấy mẫu tốn khá nhiều thời gian, do cần phải đi từng bước để giải phương trình SDE nghịch, chưa tính đến việc kết hợp với Corrector trong quá trình lấy mẫu.
Tính chất trên có thể chứng minh dễ dàng bằng tính chất Markov.
Chúng ta chỉ cần chứng minh cho t=2, các trường hợp còn lại có thể suy ra theo quy nạp. Hơn nữa, ma trận hiệp phương sai có dạng βtI, do đó ta chỉ cần chứng minh cho trường hợp x∈R.
Phần này sẽ chỉ ra chặn trên và chặn dưới của H(q(xt−1∣xt)) bởi entropy của xác suất chuyển trong quá trình thuận.
Trước hết ta có
H(xt−1∣xt)=H(xt∣xt−1)+H(xt−1)−H(xt)
Do xt tính từ xt−1 và nhiễu z, H(xt)≥H(xt∣z)≥H(xt−1∣z)≥H(xt−1). Suy ra H(xt−1∣xt)≤H(xt∣xt−1). Nếu mô hình xác suất trong quá trình nghịch của xt−1 khi biết xt bởi N(xt−1;μ(xt),σt2I), dấu bằng xảy ra khi σt2=βt.
Đối với chặn dưới, ta có H(x0∣xt)≥H(x0∣xt−1), suy ra
Một số tính chất của phương trình vi phân ngẫu nhiên
Phương trình vi phân ngẫu nhiên Itô với điều kiện đầu x0=x
dxt=f(xt,t)dt+g(t)dw
là biểu diễn hình thức của phương trình tích phân sau
xt=x+∫0tf(xt,t)dt+∫0tg(t)dw
Tích phân đầu tiên là tích phân Riemann-Stieltjes thông thường. Tuy nhiên ta không thể tính tích phân thứ hai như vậy, do chuyển động Brown không thỏa mãn tính chất bounded variation. Thay vào đó, ta sẽ sử dụng tích phân Itô để tính đại lượng này
Với quá trình ngẫu nhiên xt và một hàm tất định u(x,t):Rd×R+↦R, chúng ta cũng không thể tính đạo hàm toàn phần dtdu(xt,t) bằng chain rule như thông thường, thay vào đó chúng ta sẽ dùng công thức Itô
Thay u=xi, ta tính được kì vọng của xi, từ đó suy ra kì vọng của x. Với ma trận hiệp phương sai, ta thay u=xixj−mi(t)mj(t), với m(t) là kì vọng của xt.