+21

Sinh dữ liệu với mô hình diffusion và mô hình dạng SDE tổng quát.

Trong bài viết này, mình sẽ giới thiệu về mô hình diffusion, một mô hình sinh với sự đột phá gần đây, cùng với mô hình score matching đã vượt qua GAN trong việc sinh dữ liệu. Hai mô hình này có thể xem như trường hợp đặc biệt của phương trình vi phân ngẫu nhiên, và được tổng quát thành mô hình dạng phương trình vi phân ngẫu nhiên (Stochastic differential equation - SDE), đưa ra một góc nhìn mới cũng như việc kết hợp hai loại mô hình này. Mô hình diffusion cũng như mô hình dạng SDE khi sinh dữ liệu không điều kiện thậm chí còn cho kết quả tốt hơn GAN khi sinh dữ liệu với nhãn cho trước.

Do nội dung khá dài nên phần cài đặt mình sẽ để sang bài khác nếu có thời gian, các bạn có thể xem notebook tutorial của tác giả tại đây. Một số chứng minh chi tiết mình sẽ để ở cuối để tránh đi xa khỏi nội dung chính, bạn đọc quan tâm có thể đọc thêm.

Mô hình diffusion

Ý tưởng của phương pháp này là biến đổi phân bố dữ liệu thành một phân bố có thể lấy mẫu được. Việc sinh dữ liệu sẽ bắt đầu từ phân bố này, sau đó biến đổi ngược về phân bố ban đầu. Mô hình cần học ở đây sẽ là phép biến đổi ngược đó. Quá trình biến đổi này được mô tả bằng một chuỗi các phân bố, cụ thể hơn chúng ta sẽ sử dụng quá trình ngẫu nhiên để mô tả chuỗi này.

image.png

Quá trình ngẫu nhiên là một họ các biến ngẫu nhiên {Xt}tT\{X_t\}_{t\in T} từ cùng một không gian xác suất sang cùng một không gian trạng thái. Ở đây ta chỉ quan tâm đến tập chỉ số TT có thứ tự, được hiểu như trục thời gian, ví dụ T=R+T=\mathbb{R}^+ hoặc T=Z+T=\mathbb{Z}^+.

Quá trình ngẫu nhiên được gọi là quá trình Markov nếu nó thỏa mãn tính chất Markov. Một cách trực quan, xác suất của trạng thái tại tương lai khi biết trạng thái hiện tại không phụ thuộc vào quá khứ. Đối với chuỗi Markov, tính chất này có thể được viết thành

P(Xn+m=iX1,,Xn)=P(Xn+m=iXn)\mathbb{P}(X_{n+m}=i|X_1,\dots,X_n)=\mathbb{P}(X_{n+m}=i|X_n)

Quá trình thuận

Tại từng mốc thời gian, dữ liệu sẽ được biến đổi bằng cách thêm lần lượt nhiễu, quá trình này được gọi là quá trình thuận. Để cho đơn giản, xác suất chuyển từ thời điểm tt sang thời điểm ss sẽ được kí hiệu là q(xsxt)q(x_s|x_t).Từ tính chất Markov, xác suất liên hợp được phân tích thành

q(x0xt)=q(x0)i=1Tq(xixi1)q(x_0\dots x_t)=q(x_0)\prod_{i=1}^T q(x_i|x_{i-1})

Trong quá trình thuận, xác suất chuyển sẽ được xác định trước là một phân bố nào đó. Ở đây tham số của phân bố này sẽ được cố định trước, do đó không tham gia vào quá trình huấn luyện, tuy nhiên trong trường hợp khác chúng cũng có thể xem như tham số học được.

Đối với dữ liệu ảnh, chúng ta có thể xem chúng như dữ liệu liên tục, và xác suất chuyển q(xtxt1)q(x_t|x_{t-1}) được mô hình bởi N(xt;1βtxt1,βtI)\mathcal{N}(x_t;\sqrt{1-\beta_t} x_{t-1}, \beta_tI). Xác suất khi biết trạng thái x0x_0 cũng là phân bố Gaussian, đạt được nhờ tính chất Markov. Đặt αt=1βt,αtˉ=i=1tαi\alpha_t = 1-\beta_t,\,\bar{\alpha_t}=\prod_{i=1}^t \alpha_i, ta có q(xtx0)=N(xt;αtˉx0,(1αtˉ)I)q(x_t|x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha_t}}x_0, (1-\bar{\alpha_t})I), chứng minh công thức có thể xem ở cuối bài.

Phân bố tại trạng thái xTx_T được xem như prior, sao cho có thể lấy mẫu được. Nhờ vào tính chất của xác suất điều kiện trên, với βt\beta_t phù hợp, q(xT)N(0,I)q(x_T)\approx \mathcal{N}(0, I).

Quá trình nghịch

Khi đã có phân bố prior p(xT)p(x_T) dễ lấy mẫu rồi, việc sinh dữ liệu sẽ bắt đầu lấy mẫu từ phân bố này, sau đó biến đổi ngược trở về phân bố p(x0)p(x_0) ban đầu. Việc làm này cũng có thể mô tả bởi một quá trình ngẫu nhiên, gọi là quá trình nghịch, chúng ta sẽ cần một mô hình học quá trình ngẫu nhiên này. Quá trình này có thể xem như một chuỗi Markov với chiều ngược lại, do đó xác suất liên hợp được phân tích thành

p(x0xT)=p(xT)i=1Tp(xi1xi)p(x_0\dots x_T) = p(x_T)\prod_{i=1}^{T}p(x_{i-1}|x_i)

Mục tiêu lúc này là tìm xác suất chuyển p(xt1xt)p(x_{t-1}|x_t) của chuỗi Markov này. Ta sẽ mô hình xác suất này bởi phân bố Gaussian, có dạng N(xt1;μθ(xt,t),Σθ(xt,t))\mathcal{N}(x_{t-1};\mu_{\theta}(x_t,t), \Sigma_{\theta}(x_t, t)). Khi mô hình được quá trình nghịch rồi, ở bước sinh mẫu, dữ liệu từ phân bố của xTx_T sẽ được biến đổi thêm lần lượt nhiễu dựa trên xác suất chuyển này.

Huấn luyện

Mục tiêu của quá trình huấn luyện là cực đại likelihood của phân bố dữ liệu của mô hình sinh

p(x0)=p(x0xT)dx1xT=p(x0xT)q(x1xTx0)q(x1xTx0)dx1xT=p(xT)i=1Tp(xi1xi)q(xixi1)dQ(x1xTx0)\begin{aligned} p(x_0)&=\int p(x_0\dots x_T)dx_1\dots x_T\\ &=\int \frac{p(x_0\dots x_T)}{q(x_1\dots x_T|x_0)}q(x_1\dots x_T|x_0)dx_1\dots x_T\\ &= \int p(x_T)\prod_{i=1}^T \frac{p(x_{i-1}|x_i)}{q(x_i|x_{i-1})} dQ(x_1\dots x_T|x_0) \end{aligned}

Chúng ta sẽ muốn biến đổi likelihood sao cho có thể tối ưu trên từng thành phần (tương ứng với thời điểm) riêng. Áp dụng bất đẳng thức Jensen ta có

logp(x0)log(p(xT)i=1Tp(xi1xi)q(xixi1))dQ(x1xTx0)\begin{aligned} \log p(x_0) &\geq\int \log(p(x_T)\prod_{i=1}^T \frac{p(x_{i-1}|x_i)}{q(x_i|x_{i-1})}) dQ(x_1\dots x_T|x_0) \end{aligned}

với t>1t>1, ta có thể tính posterior như sau

q(xtxt1)=q(xtxt1,x0)tıˊnh chaˆˊt Markov=q(xt1xt,x0)q(xtx0)q(xt1x0)\begin{aligned} q(x_t|x_{t-1}) &= q(x_t|x_{t-1}, x_0)\quad \text{tính chất Markov}\\ &=\frac{q(x_{t-1}|x_t, x_0)q(x_t|x_0)}{q(x_{t-1}|x_0)} \end{aligned}

Chặn dưới của log likelihood trở thành

L(x0)=Eq(x1xTx0)[logp(XT)+t=2Tlogp(Xt1Xt)q(XtXt1)+logp(x0X1)+logq(X1x0)]=Eq(x1xTx0)[logp(XT)q(XTx0)+t=2Tlogp(Xt1XT)q(Xt1Xt,x0)+logp(x0X1)+logq(X1x0)]\begin{aligned} L(x_0)&=\mathbb{E}_{q(x_1\dots x_{T}|x_0)}[\log p(X_T)+\sum_{t=2}^T\log\frac{p(X_{t-1}|X_t)}{q(X_t|X_{t-1})} + \log p(x_0|X_1) + \log q(X_1|x_0)]\\ &=\mathbb{E}_{q(x_1\dots x_{T}|x_0)}[\log\frac{p(X_T)}{q(X_T|x_0)}+\sum_{t=2}^T\log \frac{p(X_{t-1}|X_T)}{q(X_{t-1}|X_t, x_0)} + \log p(x_0|X_1) + \log q(X_1|x_0)]\\ \end{aligned}

Thành phần logq(x1x0)\log q(x_1|x_0) là xác suất chuyển của quá trình thuận, do đó không có tham số và có thể loại bỏ trong quá trình huấn luyện. Mục tiêu của chúng ta là cực đại chặn dưới của log likelihood, tương đương với việc cực tiểu hàm mục tiêu sau

L=Eq[KL(q(xTX0)p(xT))+t=2TKL(q(xt1Xt,X0)p(xt1Xt))logp(X0X1)]L=\mathbb{E}_q[KL(q(x_T|X_0)||p(x_T)) +\sum_{t=2}^TKL(q(x_{t-1}|X_t, X_0)||p(x_{t-1}|X_t))-\log p(X_0|X_1)]

với kì vọng được lấy theo q(x0xT)q(x_0\dots x_T).

Các xác suất ở trên đều là phân bố Gaussian, do đó khoảng cách KL có thể tính từ kì vọng và phương sai. Đối với posterior, q(xt1xt,x0)q(x_{t-1}|x_t, x_0) sẽ là phân bố Gaussian N(xt1;αˉt1βt1αˉtx0+αt(1αˉt1)1αˉtxt,β~tI)\mathcal{N}(x_{t-1}; \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}x_0 +\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_{t}}x_t, \tilde\beta_tI), với β~t=1αˉt11αˉtβt\tilde \beta_t=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t, chứng minh công thức có thể xem ở cuối bài.

Mô hình denoise diffusion

Để cho đơn giản, Σθ(xt,t)\Sigma_{\theta}(x_t, t) sẽ được đặt là σt2I\sigma_t^2I, với σt\sigma_t được chọn trước, do đó không tham gia vào quá trình huấn luyện. Tác giả đưa ra hai lựa chọn σt2=βt\sigma_t^2=\beta_tσt2=β~t\sigma_t^2=\tilde \beta_t, tương đương với việc entropy H(q(xt1xt))H(q(x_{t-1}|x_t)) lớn nhất và nhỏ nhất (xem thêm ở phần cuối), qua thực nghiệm hai cách chọn này cho kết quả tương đương.

Kí hiệu μ~t(xt,x0)\tilde\mu_t(x_t, x_0) là kì vọng của q(xt1xt,x0)q(x_{t-1}|x_t, x_0), với khoảng cách KL giữa hai phân bố Gaussian ta có

Lt1=Eq[KL(q(xt1Xt,X0)p(xt1Xt))]=Ex0,xt[12σt2μ~t(Xt,X0)μθ(Xt,t)2]+C\begin{aligned} L_{t-1} &= \mathbb{E}_q[KL(q(x_{t-1}|X_t, X_0)||p(x_{t-1}|X_t))]\\ &=\mathbb{E}_{x_0, x_t}[\frac{1}{2\sigma_t^2}||\tilde\mu_t(X_t,X_0)-\mu_{\theta}(X_t,t)||^2] + C \end{aligned}

Hàm μθ(xt,t)\mu_{\theta}(x_t,t) dự đoán kì vọng μ~(xt,x0)=αˉt1βt1αˉtx0+αt(1αˉt1)1αˉtxt\tilde\mu(x_t,x_0)=\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}x_0 +\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_{t}}x_t của q(xt1xt,x0)q(x_{t-1}|x_t, x_0) khi biết xtx_ttt. Điều này tương đương với việc dự đoán x0x_0 khi biết xtx_t. Tuy nhiên, từ thực nghiệm, tác giả thấy việc tham số như vậy không đưa ra kết quả tốt. Từ xác suất chuyển của quá trình thuận, chúng ta có xt(x0,ϵ)=αtˉx0+1αtˉϵx_t(x_0,\epsilon) = \sqrt{\bar{\alpha_t}}x_0+ \sqrt{1-\bar{\alpha_t}}\epsilon, trong đó ϵN(0,I)\epsilon\sim\mathcal{N}(0,I). Nói cách khác, trong quá trình thuận, x0x_0 có thể được tham số bởi xt(x0,ϵ)x_t(x_0,\epsilon) và một biến ngẫu nhiên độc lập ϵ\epsilon thông qua x0=1αˉt(xt1αˉtϵ)x_0=\frac{1}{\sqrt{\bar \alpha_t}}(x_t-\sqrt{1-\bar\alpha_t}\epsilon). Như vậy, thay vì đoán x0x_0 khi biết xtx_t, chúng ta có thể xây dựng mô hình ϵθ(xt,t)\epsilon_{\theta}(x_t,t) đoán nhiễu ϵ\epsilon khi biết xtx_t (đây là lí do cho từ denoise trong tên gọi).

Từ cách tham số này, chúng ta có thể thay vào μ~(xt,x0)\tilde\mu(x_t,x_0) để được

μ~(xt,x0)=αt(1αˉt1)1αˉtxt+αˉt1βt1αˉt(1αˉt(xt1αˉtϵ))=1αt(xtβt1αˉtϵ)\begin{aligned} \tilde\mu(x_t,x_0)&=\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_{t}}x_t+\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}(\frac{1}{\sqrt{\bar \alpha_t}}(x_t-\sqrt{1-\bar\alpha_t}\epsilon))\\ &=\frac{1}{\sqrt{\alpha_t}}(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon) \end{aligned}

Tương tự như vậy, μθ(xt,t)\mu_{\theta}(x_t,t) lúc này sẽ được tham số như sau

μθ(xt,t)=1αt(xtβt1αˉtϵθ(xt,t))\mu_{\theta}(x_t,t)=\frac{1}{\sqrt{\alpha_t}}(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_{\theta}(x_t,t))

Nhắc lại, trong quá trình huấn luyện (quá trình thuận), xtx_t có thể tính từ x0x_0 thông qua xt(x0,ϵ)=αtˉx0+1αtˉϵx_t(x_0,\epsilon) = \sqrt{\bar{\alpha_t}}x_0+ \sqrt{1-\bar{\alpha_t}}\epsilon. Lúc này, hàm mục tiêu sẽ trở thành

Lt1=Ex0,ϵ[βt22σt2αt(1αtˉ)ϵϵθ(αtˉx0+1αtˉϵ,t)2]L_{t-1}=\mathbb{E}_{x_0,\epsilon}[\frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar{\alpha_t})}||\epsilon-\epsilon_{\theta}(\sqrt{\bar{\alpha_t}}x_0+ \sqrt{1-\bar{\alpha_t}}\epsilon,t)||^2]

và hàm mục tiêu cho tại toàn bộ vị trí sẽ là L=Et[Lt1]L=\mathbb{E}_t[L_{t-1}] với tt tuân theo phân bố đều U{1,T}\mathcal{U}\{1,T\}.

Để cho đơn giản, chúng ta có thể tối ưu với phiên bản không trọng số của hàm mục tiêu bên trên

L=Ex0,ϵ,t[ϵϵθ(αtˉx0+1αtˉϵ,t)2]L=\mathbb{E}_{x_0,\epsilon,t}[||\epsilon-\epsilon_{\theta}(\sqrt{\bar{\alpha_t}}x_0+ \sqrt{1-\bar{\alpha_t}}\epsilon,t)||^2]

Lấy mẫu

Thay vì mô hình trực tiếp kì vọng của p(xt1xt)p(x_{t-1}|x_t), chúng ta đã mô hình nhiễu ϵθ(xt,t)\epsilon_{\theta}(x_t,t). Do đó, ở bước lấy mẫu, giả sử đã biết xtx_t, chúng ta sẽ tính lại kì vọng này qua công thức

μ~(xt,t)=1αt(xtβt1αˉtϵθ(xt,t))\tilde \mu(x_t,t)=\frac{1}{\sqrt{\alpha_t}}(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_{\theta}(x_t,t))

Lúc này xt1x_{t-1} sẽ được tính bởi

xt1=μ~(xt,t)+σtz,zN(0,I)x_{t-1}=\tilde\mu(x_t,t)+\sigma_t z,\,z\sim\mathcal{N}(0,I)

Bắt đầu từ xTN(0,I)x_T\sim \mathcal{N}(0,I), chúng ta thực hiện tuần tự TT bước đến khi tìm được x0x_0.

Mô hình SDE tổng quát

Liên hệ giữa mô hình diffusion và score matching

Hàm mục tiêu của mô hình denoise diffusion có thể xem như denoise score matching. Với phân bố q(xtx0)=N(xt;αtˉx0,(1αtˉ)I)q(x_t|x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha_t}}x_0, (1-\bar{\alpha_t})I), score logq(xtx0)\nabla\log q(x_t|x_0) của phân bố này sẽ là αˉtx0xt1αˉt\frac{\sqrt{\bar\alpha_t}x_0-x_t}{1-\bar\alpha_t}. Chú ý αˉtx0xt1αˉt=ϵ-\frac{\sqrt{\bar\alpha_t}x_0-x_t}{\sqrt{1-\bar\alpha_t}}=\epsilon, nếu ta thay biến ngẫu nhiên này cho ϵ\epsilon trong thành phần Lt1L_{t-1} của hàm mục tiêu không trọng số trong mô hình denoise diffusion, ta có

Lt1=Ex0,xt[1αˉtlogq(xtx0)ϵθ(xt,t)2]=(1αˉt)Ex0,xt[logq(xtx0)sθ(xt,t)2]\begin{aligned} L_{t-1}&=\mathbb{E}_{x_0,x_t}[||-\sqrt{1-\bar\alpha_t}\nabla\log q(x_t|x_0)-\epsilon_{\theta}(x_t,t)||^2]\\ &=(1-\bar\alpha_t)\mathbb{E}_{x_0,x_t}[||\nabla\log q(x_t|x_0)-s_{\theta}(x_t,t)||^2] \end{aligned}

với sθ(xt,t)=ϵθ(xt,t)1αˉts_{\theta}(x_t,t)=-\frac{\epsilon_{\theta}(x_t,t)}{\sqrt{1-\bar\alpha_t}}. Lúc này, hàm mục tiêu sẽ là

L=t=1T(1αˉt)Ex0,xt[logq(xtx0)sθ(xt,t)2]L=\sum_{t=1}^T(1-\bar\alpha_t)\mathbb{E}_{x_0,x_t}[||\nabla\log q(x_t|x_0)-s_{\theta}(x_t,t)||^2]

Nhắc lại, với mô hình Noise conditional score network(NCSN), mô hình sẽ tối ưu khoảng cách Fisher giữa phân bố của dữ liệu khi thêm nhiễu và phân bố của mô hình trên nhiều mức độ nhiễu. Ta có thể thấy hàm mục tiêu của mô hình denoise diffusion chính là hàm mục tiêu của NCSN với trọng số (1αˉt)(1-\bar\alpha_t) khi sử dụng denoise score matching. Tương tự như NCSN, trọng số (1αˉt)(1-\bar\alpha_t) có tính chất (1αˉt)1/E[logq(xtx0)2](1-\bar\alpha_t)\propto1/\mathbb{E}[||\nabla\log q(x_t|x_0)||^2]. Cách nhìn này cho thấy sự liên hệ giữa phương pháp score matching và mô hình diffusion, đó là thay đổi phân bố dữ liệu bằng một họ các nhiễu, và học mô hình khử nhiễu lần lượt. Từ đây, ta có thể tổng quát cả hai phương pháp này, bằng cách mô hình họ các nhiễu bởi quá trình ngẫu nhiên liên tục, biểu diễn bởi một phương trình vi phân ngẫu nhiên (SDE).

Mô hình với SDE

image.png

Cụ thể hơn, với phân bố dữ liệu p0p_0 ban đầu, ta mong muốn biến đổi nó thành một phân bố đơn giản pTp_T, theo nghĩa có thể lấy mẫu một cách dễ dàng, ví dụ như N(0,I)\mathcal{N}(0,I) trong mô hình diffusion. Nói cách khác, ta cần một quá trình ngẫu nhiên Xt{X_t} với t[0,T]t\in[0,T] sao cho p(x0)=p0,p(xT)=pTp(x_0)=p_0, p(x_T)=p_T. Quá trình ngẫu nhiên này có thể mô tả bởi phương trình vi phân ngẫu nhiên Itô (từ bây giờ khi nhắc đến SDE, chúng ta sẽ hiểu đó là Itô SDE)

dxt=f(x,t)dt+g(t)dwdx_t=f(x,t)dt + g(t)dw

trong đó f(x,t):Rd×R+Rdf(x,t):\mathbb{R}^d\times\mathbb{R}^+\mapsto \mathbb{R}^d, g(t):R+Rg(t):\mathbb{R}^+\mapsto\mathbb{R}, dwdw kí hiệu một cách hình thức vi phân của chuyển động Brown. Một cách trực quan, dw=N(0,ΔtI)dw=\mathcal{N}(0,\Delta tI) với Δt0\Delta t\to 0. Để cho đơn giản, chúng ta chỉ xét g(t)g(t) có dạng trên, tuy nhiên tất cả kết quả bên dưới đều có thể mở rộng cho hàm g(t)g(t) trả về ma trận.

Quá trình thuận và nghịch lúc này sẽ là quá trình liên tục. Để phân bố không bị thay đổi quá nhiều theo thời gian, chúng ta sẽ mô hình bởi quá trình diffusion mô tả bởi phương trình vi phân ngẫu nhiên như trên. Phương trình này có thể hiểu như sau: Theo thời gian, trạng thái của biến ngẫu nhiên xx bị thay đổi dần dần theo hàm ff, tuy nhiên nếu chỉ như vậy, quá trình biến đổi này sẽ là tất định, do đó chúng ta sẽ thêm thành phần ngẫu nhiên dwdw vào phép biến đổi này. Chuyển động Brown có thể hiểu như một quá trình ngẫu nhiên mà tại mỗi thời điểm, xác suất để chuyển sang các trạng thái có vị trí gần với trạng thái hiện tại cao hơn. Một ví dụ minh họa là một người bước đi một cách ngẫu nhiên trên mặt phẳng, mỗi bước do đó không thể có độ dài quá lớn. Một ví dụ khác từ góc nhìn phân bố: Xác suất trạng thái biến ngẫu nhiên xx tại mỗi thời điểm có thể xem như nhiệt độ tại vị trí (trạng thái) đó, theo thời gian nhiệt độ sẽ tản ra từ từ sang các vị trí lân cận, cuối cùng hội tụ đến một phân bố nào đó.

SDE của mô hình diffusion

Nhắc lại quá trình thuận của mô hình diffusion có thể được mô tả bởi quá trình ngẫu nhiên {xt}t=0T\{x_t\}_{t=0}^T. Giả sử chúng ta dùng σt2=βt\sigma_t^2=\beta_t, chuỗi Markov có dạng

xt=1βtxt1+βtzt1,zN(0,1)x_t=\sqrt{1-\beta_t}x_{t-1}+\sqrt{\beta_t}z_{t-1},\, z\sim\mathcal{N}(0,1)

Quá trình ngẫu nhiên này có thể xem như rời rạc của một quá trình ngẫu nhiên liên tục, chúng ta sẽ tìm quá trình này bằng cách cho TT\to\infty. Đặt βˉt=Tβt\bar\beta_t=T\beta_t, chuỗi này sẽ tiến về một hàm β(t):[0,1]R\beta(t):[0,1]\mapsto\mathbb{R}, β(tT)=βˉt\beta(\frac{t}{T})=\bar\beta_t. Tương tự quá trình ngẫu nhiên của xix_iziz_i cũng tiến tới quá trình ngẫu nhiên liên tục x(tT)=xt,z(tT)=ztx(\frac{t}{T})=x_t, z(\frac{t}{T})=z_t. Đặt Δt=t/T\Delta t=t/T, dùng khai triển Taylor bậc 1, phương trình trên có thể viết lại thành

x(t+Δt)=1β(t+Δt)Δtx(t)+β(t+Δt)Δtz(t)x(t)12β(t)Δtx(t)+β(t)Δtz(t)\begin{aligned} x(t+\Delta t)&=\sqrt{1-\beta(t+\Delta t)\Delta t}x(t)+\sqrt{\beta(t+\Delta t)\Delta t}z(t)\\ &\approx x(t)-\frac{1}{2}\beta(t)\Delta tx(t)+\sqrt{\beta(t)\Delta t}z(t) \end{aligned}

Khi Δt0\Delta t\to 0, phương trình này hội tụ tới SDE

dxt=12β(t)xtdt+β(t)dwdx_t=-\frac{1}{2}\beta(t)x_tdt+\sqrt{\beta(t)}dw

SDE của mô hình NCSN

Nhắc lại, mô hình NCSN thêm lần lượt nhiễu với phương sai {σt}t=1N\{\sigma_t\}_{t=1}^N vào phân bố dữ liệu. Quá trình này có thể viết lại thành

xt=xt1σt2σt12z,zN(0,I)x_{t}=x_{t-1}-\sqrt{\sigma_{t}^2-\sigma_{t-1}^2}z,\qquad z\sim\mathcal{N}(0,I)

với σ0=0\sigma_0=0. Lập luận tương tự như trên, ta có thể tính giới hạn khi NN\to\infty

x(t+Δt)=x(t)+σ2(t+Δt)σ2(t)z(t)dσ2(t)dtΔtz(t)x(t+\Delta t)=x(t)+\sqrt{\sigma^2(t+\Delta t)-\sigma^2(t)}z(t)\approx\sqrt{\frac{d \sigma^2(t)}{dt}\Delta t}z(t)

sử dụng khai triển Taylor bậc 1 của σ2(t)\sigma^2(t). Khi Δt0\Delta t\to 0, chuỗi xtx_t hội tụ tới quá trình ngẫu nhiên mô tả bởi

dxt=dσ2(t)dtdwdx_t=\sqrt{\frac{d \sigma^2(t)}{dt}}dw

Lấy mẫu

Việc lấy mẫu tương đương với đảo chiều thời gian của quá trình ngẫu nhiên. Quá trình nghịch này được mô tả bởi SDE sau

dxt=(f(x,t)g(t)2xtlogpt(xt))dt+g(t)dwˉdx_t=(f(x,t)-g(t)^2\nabla_{x_t}\log p_t(x_t))dt + g(t)d\bar w

ở đây wˉ\bar w là chuyển động Brown theo chiều ngược lại, từ 11 về 00.

Nếu biết được score của pt(x)p_t(x), chúng ta có thể mô phỏng lại quá trình ngược này. Bắt đầu từ xTpTx_T\sim p_T, từ phương trình trên, chúng ta sẽ biến đổi xTx_T thành x0x_0 tuân theo phân bố p0p_0 của dữ liệu. Như vậy, mục tiêu của chúng ta là xây dựng mô hình sθ(x(t),t)s_{\theta}(x(t),t) xấp xỉ xtlogpt(xt))\nabla_{x_t}\log p_t(x_t)).

Giải SDE

Quá trình lấy mẫu được thực hiện bằng cách giải phương trình SDE nghịch. Tương tự như khi rời rạc hóa quá trình thuận, chúng ta có thể giải bằng cách rời rạc hóa quá trình nghịch

xti=xti+1f(ti+1)(xti+1)+g(ti+1)2sθ(xti+1,ti+1)+g(ti+1)z,zN(0,I)x_{t_i}=x_{t_{i+1}}-f(t_{i+1})(x_{t_{i+1}})+g(t_{i+1})^2s_{\theta}(x_{t_{i+1}},t_{i+1})+g(t_{i+1})z,\,z\sim\mathcal{N}(0,I)

với 0=t0<t1<<tT1<tT=10=t_0<t_1<\cdots<t_{T-1}<t_T=1.

Quay lại với cách cập nhật của mô hình denoise diffusion, giả sử ta dùng σt2=βt\sigma_t^2=\beta_t

xt1=1αt(xtβt1αˉtϵθ(xt,t))+βtzx_{t-1} = \frac{1}{\sqrt{\alpha_t}}(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_{\theta}(x_t,t)) + \sqrt{\beta_t}z

Đặt sθ(xt,t)=ϵθ(xt,t)1αˉts_{\theta}(x_t,t)=-\frac{\epsilon_{\theta}(x_t,t)}{\sqrt{1-\bar\alpha_t}}, ta có thể biến đổi như sau

xt1=11βt(xt+βts(xt,t))+βtz(1+12βt)(xt+βts(xt,t))+βtzkhai triển Taylor=(1+12βt)xt+βts(xt,t)+12βt2s(xt,t)+βtzxt+12βtxt+βts(xt,t)+βtz\begin{aligned} x_{t-1}&=\frac{1}{\sqrt{1-\beta_t}}(x_t+\beta_ts(x_t,t))+\sqrt{\beta_t}z\\ &\approx (1+\frac{1}{2}\beta_t)(x_t+\beta_ts(x_t,t)) +\sqrt{\beta_t}z\qquad \text{khai triển Taylor}\\ &=(1+\frac{1}{2}\beta_t)x_t+ \beta_ts(x_t,t)+\frac{1}{2}\beta_t^2s(x_t,t)+\sqrt{\beta_t}z\\ &\approx x_t+\frac{1}{2}\beta_tx_t+ \beta_ts(x_t,t)+\sqrt{\beta_t}z \end{aligned}

Quá trình nghịch của SDE ứng với mô hình diffusion là

dxt=(12β(t)xtβtxlogpt(xt))dt+β(t)dwdx_t=(-\frac{1}{2}\beta(t)x_t-\beta_t\nabla_x\log p_t(x_t))dt+\sqrt{\beta(t)}dw

Ta có thể thấy thuật toán lấy mẫu của mô hình denoise diffusion gần giống với việc giải quá trình nghịch thông qua rời rạc hóa.

Lấy mẫu với Predictor-Corrector

Ở phần trước, ta đã biết quá trình lấy mẫu có thể thực hiện bằng việc giải phương trình SDE nghịch, và thuật toán lấy mẫu của mô hình diffusion thuộc loại này. Mặt khác, ta đang mô hình score của pt(xt)p_t(x_t), do đó ta cũng có thể lấy mẫu với (annealed) Langevin dynamics.

Để có thể sinh dữ liệu tốt hơn, chúng ta có thể kết hợp hai phương pháp này. Lấy mẫu thông qua giải SDE sẽ được xem như thuật toán chính, gọi là Predictor. Ở bước thứ ii trong Predictor, sau khi cập nhật xTix_{T-i} qua xTi+1x_{T-i+1}, chúng ta sẽ thực hiện Langevin dynamics MM lần với s(xTi,Ti)s(x_{T-i},T-i)

xTi=xTi+ϵis(xTi,Ti)+2ϵiz,zN(0,I)x_{T-i}=x_{T-i}+\epsilon_i s(x_{T-i},T-i) + \sqrt{2\epsilon_i}z,\,z\sim\mathcal{N}(0,I)

Từ góc nhìn này, cách sinh dữ liệu của NCSN có thể xem như Predictor là hàm đồng nhất, Corrector là Langevin dynamics, cách sinh dữ liệu của mô hình denoise diffusion có thể xem như Predictor là giải quá trình nghịch, Corrector là hàm đồng nhất.

Huấn luyện

Tương tự như hàm mục tiêu của NCSN cũng như mô hình denoise diffusion, hàm mục tiêu của mô hình SDE sẽ có dạng score matching trên tất cả mức độ nhiễu. Điểm khác biệt là biến thời gian tt lúc này là biến ngẫu nhiên liên tục tuân theo phân bố đều U[0,1]\mathcal{U}[0,1]

L=Et[λ(t)Ex0,xt[logp(xtx0)sθ(xt,t)2]]L=\mathbb{E}_t[\lambda(t)\mathbb{E}_{x_0,x_t}[||\nabla\log p(x_t|x_0)-s_{\theta}(x_t,t)||^2]]

Ở đây λ(t)\lambda(t) là hàm trọng số, có thể chọn giống như NCSN và mô hình denoise diffusion là λ(t)1/E[logq(xtx0)2]\lambda(t)\propto1/\mathbb{E}[||\nabla\log q(x_t|x_0)||^2].

Việc tính hàm mất mát yêu cầu score của phân bố chuyển trong quá trình thuận. Đối với trường hợp SDE tổng quát, ta cần giải phương trình Kolmogorov tiến để tìm phân bố này. Khi f(x,t)=a(t)x+b(t)f(x,t)=a(t)x+b(t), phân bố chuyển là phân bố Gaussian, do đó chỉ cần biết kì vọng và phương sai để tính score. Kì vọng mtm_t và ma trận hiệp phương sai PtP_t sẽ thỏa mãn phương trình vi phân sau

dmtdt=a(t)mt+b(t)\frac{dm_t}{dt}=a(t)m_t+b(t)

dPtdt=2a(t)Pt+g(t)2\frac{dP_t}{dt}=2a(t)P_t+g(t)^2

Để tránh việc phải tính phân bố chuyển, chúng ta có thể dùng phương pháp score matching khác, ví dụ như sliced score matching, với hàm mục tiêu

L=Et[λ(t)Ex0ExtEv[12sθ(xt,t)2+vJs(.,t)(xt)v]]L=\mathbb{E}_t[\lambda(t)\mathbb{E}_{x_0}\mathbb{E}_{x_t}\mathbb{E}_v[\frac{1}{2}||s_{\theta}(x_t,t)||^2+v^\intercal J_{s(.,t)}(x_t)v]]

với Js(.,t)(xt)J_{s(.,t)}(x_t) là ma trận Jacobian của s(xt,t)s(x_t,t), vJs(.,t)(xt)vv^\intercal J_{s(.,t)}(x_t)v tính bởi v(vs(xt,t))v^\intercal\nabla(v^\intercal s(x_t,t)).

Kết luận

Trong bài này, mình đã giới thiệu mô hình diffusion và mô hình dạng SDE tổng quát mà trong đó score matching và mô hình diffusion là trường hợp đặc biệt. Cách tiếp cận này hiện đã cho kết quả tốt nhất hiện tại cho mô hình sinh.

Tuy nhiên cách tiếp cận này có các nhược điểm sau: Các trạng thái có cùng số chiều, do đó việc mô hình quá trình nghịch cần đảm bảo điều đó chứ không thể thay đổi số chiều dữ liệu. Việc lấy mẫu tốn khá nhiều thời gian, do cần phải đi từng bước để giải phương trình SDE nghịch, chưa tính đến việc kết hợp với Corrector trong quá trình lấy mẫu.

Tham khảo

Một số định nghĩa và chứng minh chi tiết

Công thức các phân bố trong quá trình thuận của mô hình diffusion

Tính chất: Với q(xtxt1)=N(xt;αtxt1,(1αt)I)q(x_t|x_{t-1})=\mathcal{N}(x_t;\sqrt{\alpha_t}x_{t-1}, (1-\alpha_t)I), ta có q(xtx0)=N(xt;αtˉx0,(1αtˉ)I)q(x_t|x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha_t}}x_0, (1-\bar{\alpha_t})I),

trong đó αˉt=i=1tαi\bar\alpha_t=\prod_{i=1}^t\alpha_i.

Chứng minh:

Quá trình Markov thỏa mãn tính chất sau

Mệnh đề: Với t1>t2>t3t_1>t_2>t_3, xác suất chuyển thỏa mãn phương trình Chapman-Kolmogorov

pt3t1(xt1xt3)=pt3t2(xt2xt3)pt2t1(xt1xt2)dxt2p_{t_3t_1}(x_{t_1}|x_{t_3})=\int p_{t_3t_2}(x_{t_2}|x_{t_3})p_{t_2t_1}(x_{t_1}|x_{t_2})dx_{t_2}

Tính chất trên có thể chứng minh dễ dàng bằng tính chất Markov.

Chúng ta chỉ cần chứng minh cho t=2t=2, các trường hợp còn lại có thể suy ra theo quy nạp. Hơn nữa, ma trận hiệp phương sai có dạng βtI\beta_tI, do đó ta chỉ cần chứng minh cho trường hợp xRx\in\mathbb{R}.

Từ phương trình Chapman-Kolmogorov, ta có

q(x2x0)=q(x2x1)q(x1x0)dx1=12(1α1)(1α2)πexp((x2α2x1)22(1α2))exp((x1α1x0)22(1α1))dx1=12(1α1)(1α2)πexp(12((x2α1α2x0)21α1α2+(1α1α2)(x1α2(1α1)x2+α1(1α2)x01α1α2)2(1α1)(1α2)))dx1=12π(1α1α2)exp(12(x2α1α2x0)21α1α2).\begin{aligned} q(x_2|x_0) &= \int q(x_2|x_1)q(x_1|x_0)dx_1\\ &=\frac{1}{2\sqrt{(1-\alpha_1)(1-\alpha_2)}\pi}\int \exp(-\frac{(x_2-\sqrt\alpha_2x_1)^2}{2(1-\alpha_2)})\exp(-\frac{(x_1-\sqrt\alpha_1x_0)^2}{2(1-\alpha_1)})dx_1\\ &=\frac{1}{2\sqrt{(1-\alpha_1)(1-\alpha_2)}\pi}\int\exp(-\frac{1}{2}(\frac{(x_2-\sqrt{\alpha_1\alpha_2}x_0)^2}{1-\alpha_1\alpha_2} +\frac{(1-\alpha_1\alpha_2)(x_1-\frac{\sqrt\alpha_2(1-\alpha_1)x_2+\sqrt\alpha_1(1-\alpha_2)x_0}{1-\alpha_1\alpha_2})^2}{(1-\alpha_1)(1-\alpha_2)}))dx_1\\ &=\frac{1}{\sqrt{2\pi(1-\alpha_1\alpha_2)}}\exp(-\frac{1}{2}\frac{(x_2-\sqrt{\alpha_1\alpha_2}x_0)^2}{1-\alpha_1\alpha_2}). \end{aligned}

\square

Tính chất: q(xt1xt,x0)=N(xt1;αˉt1βt1αˉtx0+αt(1αˉt1)1αˉtxt,β~tI)q(x_{t-1}|x_t, x_0) =\mathcal{N}(x_{t-1}; \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}x_0 +\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_{t}}x_t, \tilde\beta_tI), với β~t=1αˉt11αˉtβt\tilde \beta_t=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t

Chứng minh: Tương tự như trên, chúng ta cũng chỉ cần chứng minh cho trường hợp xRx\in\mathbb{R}.

q(xt1xt,x0)=q(xtxt1)q(xt1x0)q(xtx0)=(2πβt)1/2(2π(1αˉt1))1/2(2π(1αˉt))1/2exp(xtαtxt122βtxt1αˉt1x022(1αˉt1)+xtαˉtx022(1αˉt))=(2πβ~t)1/2exp(1β~txt1αˉt1βt1αˉtx0αt(1αˉt1)1αˉtxt2).\begin{aligned} q(x_{t-1}|x_t, x_0)&=\frac{q(x_t|x_{t-1})q(x_{t-1}|x_0)}{q(x_{t}|x_0)}\\ &=(2\pi\beta_t)^{-1/2}(2\pi(1-\bar\alpha_{t-1}))^{-1/2}(2\pi(1-\bar\alpha_t))^{1/2}\\ &\quad\exp\left(-\frac{||x_t-\sqrt{\alpha_t}x_{t-1}||^2}{2\beta_t}-\frac{||x_{t-1}-\sqrt{\bar\alpha_{t-1}}x_0||^2}{2(1-\bar\alpha_{t-1})}+\frac{||x_t-\sqrt{\bar\alpha_t}x_0||^2}{2(1-\bar\alpha_t)}\right)\\ &=(2\pi\tilde\beta_t)^{-1/2}\exp\left(-\frac{1}{\tilde\beta_t}||x_{t-1}-\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}x_0 -\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_{t}}x_t||^2\right). \end{aligned}

\square

Ràng buộc của entropy H(q(xt1xt))H(q(x_{t-1}|x_t))

Phần này sẽ chỉ ra chặn trên và chặn dưới của H(q(xt1xt))H(q(x_{t-1}|x_t)) bởi entropy của xác suất chuyển trong quá trình thuận.

Trước hết ta có

H(xt1xt)=H(xtxt1)+H(xt1)H(xt)H(x_{t-1}|x_t)=H(x_t|x_{t-1})+H(x_{t-1})-H(x_t)

Do xtx_t tính từ xt1x_{t-1} và nhiễu zz, H(xt)H(xtz)H(xt1z)H(xt1)H(x_t)\geq H(x_{t}|z)\geq H(x_{t-1}|z)\geq H(x_{t-1}). Suy ra H(xt1xt)H(xtxt1)H(x_{t-1}|x_t)\leq H(x_t|x_{t-1}). Nếu mô hình xác suất trong quá trình nghịch của xt1x_{t-1} khi biết xtx_t bởi N(xt1;μ(xt),σt2I)\mathcal{N}(x_{t-1};\mu(x_t),\sigma_t^2I), dấu bằng xảy ra khi σt2=βt\sigma_t^2=\beta_t.

Đối với chặn dưới, ta có H(x0xt)H(x0xt1)H(x_0|x_t)\geq H(x_0|x_{t-1}), suy ra

H(xt1)H(xt)H(x0,xt1)H(x0,xt)=H(xt1x0)H(xtx0)H(x_{t-1})-H(x_t)\geq H(x_0,x_{t-1})-H(x_0,x_t)=H(x_{t-1}|x_{0})-H(x_t|x_0)

Do đó

H(xt1xt)H(xtxt1)+H(xt1x0)H(xtx0)=H(xtx0,xt1)+H(xt1x0)H(xtx0)=H(xt1x0,xt)\begin{aligned} H(x_{t-1}|x_t)&\geq H(x_t|x_{t-1})+H(x_{t-1}|x_{0})-H(x_t|x_0)\\ &=H(x_t|x_0,x_{t-1})+H(x_{t-1}|x_{0})-H(x_t|x_0)\\ &=H(x_{t-1}|x_0,x_t) \end{aligned}

Dấu bằng xảy ra khi σt2=β~t\sigma_t^2=\tilde \beta_t.

Một số tính chất của phương trình vi phân ngẫu nhiên

Phương trình vi phân ngẫu nhiên Itô với điều kiện đầu x0=xx_0=x

dxt=f(xt,t)dt+g(t)dwdx_t=f(x_t,t)dt+g(t)dw

là biểu diễn hình thức của phương trình tích phân sau

xt=x+0tf(xt,t)dt+0tg(t)dwx_t=x+\int_0^tf(x_t,t)dt+\int_0^tg(t)dw

Tích phân đầu tiên là tích phân Riemann-Stieltjes thông thường. Tuy nhiên ta không thể tính tích phân thứ hai như vậy, do chuyển động Brown không thỏa mãn tính chất bounded variation. Thay vào đó, ta sẽ sử dụng tích phân Itô để tính đại lượng này

0tg(t)dw=limKk=0K1g(tk)(wtk+1wtk)\int_0^tg(t)dw=\lim_{K\to \infty}\sum_{k=0}^{K-1}g(t_k)(w_{t_{k+1}}-w_{t_k})

với tk=kΔt,t=KΔtt_k=k\Delta t, t=K\Delta t.

Từ đây chúng ta có tính chất sau

k=0K1E[g(tk)(wtk+1wtk)]=k=0K1E[g(tk)]E[wtk+1wtk]=0\sum_{k=0}^{K-1}\mathbb{E}[g(t_k)(w_{t_{k+1}}-w_{t_k})]=\sum_{k=0}^{K-1}\mathbb{E}[g(t_k)]\mathbb{E}[w_{t_{k+1}}-w_{t_k}]=0

theo định nghĩa của chuyển động Brown, do đó

E[0tg(t)dw]=0\mathbb{E}[\int_0^tg(t)dw]=0

Với quá trình ngẫu nhiên xtx_t và một hàm tất định u(x,t):Rd×R+Ru(x,t):\mathbb{R}^d\times \mathbb{R}^+\mapsto\mathbb{R}, chúng ta cũng không thể tính đạo hàm toàn phần du(xt,t)dt\frac{du(x_t,t)}{dt} bằng chain rule như thông thường, thay vào đó chúng ta sẽ dùng công thức Itô

du(xt,t)=u(xt,t)tdt+u(xt,t)dxt+12(dxt)Hxu(xt,t)dxtdu(x_t,t)=\frac{\partial u(x_t,t)}{\partial t}dt+\nabla u(x_t,t)^\intercal dx_t+\frac{1}{2}(dx_t)^\intercal H_xu(x_t,t)dx_t