+6

Giới thiệu về Connectionist Temporal Classification (CTC) (Phần 1)

Giới thiệu

Trong nhận diện giọng nói hay nhận diện chữ viết tay, đầu ra sẽ là một câu, nhưng chưa hoàn chỉnh vì có các ký tự lặp lại như "heelllo", "toooo", ... hay các chữ có những khoảng trống (blank) như "he l l oo", "t o o", ... . Nguyên nhân dẫn tới những hiện tượng này là giọng nói dài (các đoạn ngân nga trong các bài hát, ...), giọng bị ngắt quãng, kích thước chữ viết tay lớn, nhỏ, ...

Do đó, để cho ra được một câu hoàn chỉnh thì ta cần phải căn chỉnh lại đầu ra ấy, loại bỏ các ký tự lặp lại và khoảng trống. Vấn đề này được gọi là alignment problem và nó được giải quyết bằng CTC.

Đầu tiên ta sẽ nói về temporal classification.

Temporal Classification là gì?

Gọi SS là tập huấn luyện gồm các mẫu từ một phân phối cố định DX×ZD_{X \times Z}, trong đó:

  • X=(Rm)X = (\mathbb{R}^m)^* là một tập gồm tất cả các chuỗi (sequence) của các vectors số thực có độ dài là mm.

  • Z=LZ = L^* là một tập gồm tất cả các chuỗi chỉ có các ký tự alphabet có hạn LL của các labels. Các phần tử của LL^* còn được gọi là label sequences hay labellings.

Mỗi mẫu trong SS bao gồm một cặp chuỗi (x,z)(x,z). Chuỗi mục tiêu (target sequence) z=(z1,z2,...,zU)z = (z_1, z_2, ..., z_U) có độ dài gần như chuỗi đầu vào (input sequence) x=(x1,x2,...,xT)x = (x_1, x_2, ..., x_T), tức là UTU \le T.

Mục tiêu của ta là dùng SS để huấn luyện bộ temporal classifier h:XZh: X \mapsto Z để phân loại các chuỗi đầu vào chưa thấy trước đó (previously unseen input sequences) sao cho tối thiểu hóa một lượng mất mát (error measure) nào đó.

Một trong các error measure là label error rate (LER) được tính bằng trung bình của edit distance của đầu ra dự đoán h(x)h(x) và nhãn zz trong tập SS' lấy từ SS:

LER(h,S)=1S(x,z)SED(h(x),z)z \text{LER}(h, S') = \frac{1}{\lvert S' \lvert} \sum_{(x,z) \in S'} \frac{\text{ED}(h(x), z)}{\lvert z \lvert}

Edit distance ED(p,q)\text{ED}(p,q) là số nhỏ nhất của số thêm vào (insertions), số thay thế (substitutions) và số xóa đi (deletions) để chuyển từ pp sang qq.

Vậy, temporal classifier hh là một mô hình phân loại nào đó mang tính thời gian (temporal). Và cũng vì thế, ở đây ta sẽ lấy RNN làm một temporal classifier cho gần gũi (thật ra ta có thể lấy bất kỳ model nào cho ra output theo thời gian).

Thế thì, connectionist ở đâu ra? Để trả lời cho điều đó, ta hãy qua cách thức hoạt động của CTC.

CTC hoạt động như thế nào?

Trước tiên, ta cần phải xác định đầu ra của RNN sao cho RNN là một temporal classifier.

Đầu ra của CTC Network: RNN

Như ta đã nói ở trên, LL bao gồm các ký tự trong bảng chữ cái của ngôn ngữ ta đang sử dụng, ZZ là tất cả các labellings có thể có.

Đầu ra của một CTC Network là kết quả của lớp softmax có số unit bằng số ký tự trong LL cộng thêm một ký tự trống (blank), nghĩa là xác suất phân loại của các ký tự trong L+1|L| + 1 tại một thời điểm nhất định.

Như vậy, tổ hợp của tất cả các đầu ra theo thời gian là ZZ. Dưới đây là hình minh họa.

Xác suất của một labelling

Xác suất của một labelling là tổng xác suất của tất cả các alignment cho ra labelling đó.

Gọi yy là một chuỗi của đầu ra của network, trong đó ykty_k^t là kết quả của unit (ký tự - label) kk tại thời điểm tt. Khi đó ykty_k^t thuộc phân phối trên tập LTL'^T của các chuỗi có độ dài TT trên tập chữ cái L=L{blank}L' = L \cup \{\text{blank}\}. Và xác suất của một alignment như sau:

p(πx)=t=1Tyπtt,xLTp(\pi \lvert x) = \prod_{t=1}^T y_{\pi_t}^t, \forall x \in L'^T

Trong đó π\pipath hay còn gọi là alignment. Và công thức ở trên giả định rằng đầu ra của network tại các thời điểm là độc lập với nhau.

Tiếp theo, ta định nghĩa một many-to-one map B:LTLTB: L'^T \mapsto L^{\le T} bằng cách loại bỏ các ký tự blank và ký tự lặp lại (B(abb)=B(aabbb)=abb)(B(a-b-b) = B(aa-bb-b) = abb) và dùng BB để tính xác suất của một labelling ll trong ZZ:

p(lx)=πB1(l)p(πx)p(l \lvert x) = \sum_{\pi \in B^{-1}(l)} p(\pi \lvert x)

Xây dựng bộ phân loại

Từ công thức ở trên, đầu ra của bộ phân loại sẽ là labelling có vẻ đúng nhất.

h(x)=argmaxlLTp(lx)h(x) = arg \max_{l \in L^{\le T}} p(l \lvert x)

Vậy làm sao để tìm h(x)h(x)?

Trong bài báo này, họ đã đưa ra hai phương pháp:

  • Best path decoding: h(x)=B(π)h(x) = B(\pi^\star) trong đó π=argmaxπNtp(πx)\pi^\star = arg \max_{\pi \in N^t} p(\pi \lvert x) là sự kết hợp của các unit có xác suất cao nhất của mỗi time-step, do đó không đảm bảo sẽ tìm thấy labelling đúng nhất.
  • Prefix search decoding (PSD): phương pháp này dựa trên forward-backward algorithm, nếu có đủ thời gian, PSD sẽ luôn tìm thấy labelling phù hợp nhất, nhưng số lượng prefix tối đa sẽ tăng theo hàm mũ, phức tạp nên phải áp dụng heuristic.

Xây dựng hàm mất mát (CTC Loss function)

Ta sẽ xây dựng hàm mất mát để có thể train bằng gradient descent. Hàm mất mát được lấy theo maximum likelihood. Nghĩa là khi tối thiểu hóa nó thì sẽ cực đại hóa log likelihood.

Như vậy hàm mất mát (hàm mục tiêu) sẽ là negative log likelihood:

OML(S,Nw)=(x,z)Sln(p(lx))O^{ML}(S, N_w) = - \sum_{(x,z) \in S} ln(p(l \lvert x))

Tổng kết

Như vậy, một CTC Network chẳng qua là một network phân loại thông thường có output theo thời gian (temporal classifier), ta tính toán xác suất của các alignments bằng cách connect các xác suất của output của các thời điểm lại với nhau và chọn alignment phù hợp nhất, tính sai số của nó và cho network học lại. Vì thế họ gọi là Connectionist Temporal Classification.

Cách áp dụng các giải thuật decoding để tìm h(x)h(x) và tính hàm mất mát thì mình sẽ nói trong phần 2.

Tham khảo

  1. https://www.cs.toronto.edu/~graves/icml_2006.pdf
  2. https://towardsdatascience.com/intuitively-understanding-connectionist-temporal-classification-3797e43a86c

All rights reserved

Viblo
Hãy đăng ký một tài khoản Viblo để nhận được nhiều bài viết thú vị hơn.
Đăng kí