Ứng dụng Deep Learning cho bài toán nhận diện chữ viết (OCR)

Introduction

Optical character recognition là một bài toán rất thú vị trong lĩnh vực Computer Vision, trên thực tế có rất nhiều hướng tiếp cận để giải quyết bài toán này. Một trong những phương pháp gây được sự chú ý trong thời gian gần đây đó là sử dụng Deep Learning kết hợp với hàm CTC loss. Ưu điểm của phương pháp này đó là chúng ta không cần bước segment ảnh ra như khi sử dụng các phương pháp Machine Learning truyền thống. Trong bài viết này mình sẽ hướng dẫn các bạn cách xây dựng 1 mô hình Deep Learning với CTC loss để giải quyết bài tóan OCR. Cụ thể trong bài viết này mình sẽ làm về bài toán nhận diện chữ in cho tiếng Anh. Bài viết này mình hướng đến đối tượng là những bạn đã có kiến thức cơ bản về Deep Learning, đặc biệt là 2 kiến trúc mạng Convolutional Neural Network (CNN), Recurrent Neural Network (RNN) và đã sử dụng tốt được thư viện Keras. Nếu như bạn là người mới bắt đầu với Deep Learning thì bạn nên học thêm về lý thuyết về CNN, RNN. Còn nếu bạn vẫn còn bỡ ngỡ với Keras thì các bạn có thể tham khảo series giới thiệu về Keras của mình tại link sau.

Before we start

Trước khi bắt đầu, các bạn clone repo này của mình về:

[email protected]:~/project$ git clone [email protected]:hoangdinhthoi95/OCR_demo.git

Tạo môi trưởng ảo với virtualenv:

[email protected]:~/project$ virtualenv -p python3 ~/ocr_demo_env
[email protected]:~/project$ source ~/ocr_demo_env/bin/activate

Cài đặt các thư viện cần thiết theo requirements.txt:

(ocr_demo_env) [email protected]:~/project$ cd OCR_demo/
(ocr_demo_env) [email protected]:~/project/OCR_demo$ pip install -r requirements.txt 

Create data for training & testing

Vấn đề muôn thuở với bất kì bài toán Machine Learning nào đó là dữ liệu để phục vụ quá trình training. Với ngôn ngữ là tiếng Anh, các bạn hoàn toàn có thể tìm kiếm được các dataset có sẵn, tuy nhiên trong bài viết này mình sẽ hướng dẫn các bạn sử dụng một công cụ giúp sinh dữ liệu một cách tự động và hoàn toàn theo ý muốn. Tool này mình được một người bạn giới thiệu cho, các bạn có thể xem thêm trên trang Github của tác giả. Một cách ngắn gọn thì tool cho phép bạn sinh ảnh từ một đoạn text tương ứng. Ngoài tiếng Anh, tool còn hỗ trợ sinh dữ liệu với một vài các ngôn ngữ khác như: tiếng Trung, tiếng Pháp, tiếng Nhật, etc. Tuy nhiên trong khuôn khổ bài viết này mình sẽ chỉ làm với tiếng Anh thôi, các bạn hoàn toàn có thể thử với các ngôn ngữ khác nếu muốn. Ví dụ, bạn có 1 đoạn text: Work and social obligations demand a portion of it., thì ảnh được generate tương ứng sẽ là:

Một ví dụ khác: To learn something well you need to study it for a while and then take a break.

Vấn đề tiếp theo là làm sao có được một từ điển gồm nhiều câu tiếng Anh có nghĩa để phục vụ quá trình sinh dữ liệu. Đợt vừa rồi trên Kaggle có một cuộc thì về phân loại câu hỏi trên trang Quora, data do Kaggle cung cấp có hơn 1 triệu câu hỏi nên mình quyết định sử dụng luôn tập dữ liệu này để sinh ảnh. Trong bài viết này, vì mục đích demo là chủ yếu, mình sẽ chỉ sử dụng hơn 21.000 câu để sinh ra 50.000 ảnh để vừa training và validate luôn. Các bạn có thể sinh ảnh với số lượng nhiều hơn tùy vào cấu hình phần cứng mà bạn có, tuy nhiên với số lượng ảnh lớn, quá trình training sẽ mất rất nhiều thời gian. Với data để test, mình lấy luôn 1 bài trên medium (khoảng 50 câu) để sinh dữ liệu. Ok bắt tay vào việc thôi nào.

Sau khi cài đặt các thư viện cần thiết xong, các bạn cd vào thư mục TextRecognitionDataGenerator

(ocr_demo_env) [email protected]:~/project/OCR_demo$ cd TextRecognitionDataGenerator/

Để sinh ảnh các bạn chỉ cần chạy file run.py và truyền các arguments tùy chỉnh. Có 1 vài arguments các bạn cần lưu ý:

  • --output_dir: thư mục chứa ảnh trong quá trình sinh dữ liệu.
  • --input_file: file text chứa các câu mà bạn muốn sinh ảnh.
  • --language: ngôn ngữ được sử dụng, cụ thể như sau: fr (French), en (English), es (Spanish), de (German), or cn (Chinese).
  • --count: lượng ảnh mà các bạn muốn tạo (có thể nhiều hơn số câu trong input_file do tool cho phép sử dụng nhiều fonts chữ và background khác nhau).
  • --thread_count: số luồng sử dụng (có hỗ trợ sinh ảnh đa luồng), các bạn setup giá trị tùy cấu hình máy.
  • --format: chiều cao của ảnh đầu ra, tham số này sẽ được fix cố định, còn chiều dài cuẩ ảnh sẽ phụ thuộc vào chiều dài của câu tương ứng.
  • --space_width: kích thước khoảng trắng giữa các kí tự (1.0 là khoảng cách tiêu chuẩn)
  • --length: số từ trong câu, tuy nhiên trong file dữ liệu của mình, mỗi dòng đã là 1 câu hoàn chỉnh rồi nên khi sinh ảnh chỉ cần chọn length = 1.

Tạo data train từ file quora.txt: 50K ảnh ứng với 21K câu hỏi (trung bình mỗi 1 câu sẽ có 2 ảnh với 2 fonts chữ khác nhau), chiều cao mỗi ảnh là 64 pixels.

(ocr_demo_env) [email protected]:~/project/OCR_demo/TextRecognitionDataGenerator$ python run.py \
> --output_dir ../data_quora \
> --input_file ../quora.txt \
> --language en \
> --count 50000 \
> --thread_count 6 \
> --format 64 \
> --space_width 0.7 \
> --length 1

Tạo data để test từ file medium.txt: 100 ảnh ứng với 48 câu văn (trung bình 1 câu sẽ có 2 ảnh với 2 fonts chữ khác nhau), chiều cao của ảnh vẫn là 64 pixels.

(ocr_demo_env) [email protected]:~/project/OCR_demo/TextRecognitionDataGenerator$ python run.py \
> --output_dir ../data_test \
> --input_file ../medium.txt \
> --language en \
> --count 100 \
> --thread_count 6 \
> --format 64 \
> --space_width0.95 \
> --length 1

Ok, như vậy là mình đã tạo xong data: 50K ảnh cho train và validata, 100 ảnh để anh em test thử, bắt tay vào xây dựng mô hình thôi nào.

Build & train the model

Với bài toán OCR, có rất nhiều hướng tiếp cận khác nhau. Trong bài viết này mình sẽ hướng dẫn các bạn xây dựng một mô hình Deep Learning với sliding window (cửa sổ trượt) và sử dụng CTC làm hàm loss. Ưu điểm của hướng tiếp cận này là, nhờ vào CTC, các bạn không cần phải segment ảnh thành các ảnh con của từng kí tự như các phương pháp truyền thống. Trong khuôn khổ của bài viết này mình sẽ không nói kĩ hơn về CTC, các bạn có thể tìm hiểu kĩ hơn về CTC trong paper này. Một cách tổng quát thì mô hình của mình sẽ sử dụng một mạng CNN như 1 bước trích chọn đặc trưng cơ bản (cụ thể thì mình dùng 1 biến thể của mạng Xception), sau đó sử dụng sliding window trích chọn đặc trưng cho từng vùng ảnh nhỏ liền kề để sinh feature, sau đó các feature với các time step liền kề nhau sẽ được đưa vào mạng RNN. Cuối cùng, ta sử dụng CTC làm hàm loss để tối ưu trong quá trình học.

Với CTC loss, các bạn cần có 4 input cho mạng: ảnh đầu vào, true label cho ảnh đó, số lượng time step của ảnh đầu vào khi đưa qua sliding window, chiều dài của true label, do đó khi build model ta cần định nghĩa 4 input layer. Phần code để tạo mô hình sẽ như sau, các bạn xem chi tiết trong file utils.py

def create_model(input_shape=(IMAGE_HEIGHT, IMAGE_WIDTH, NO_CHANNEL)):
    input_image = Input(shape=input_shape, name='input_image')
    input_true_label = Input(shape=(None,), name='input_true_label')
    input_time_step = Input(shape=(1,), name='input_time_step')
    input_label_length = Input(shape=(1,), name='input_label_length')
    # xception model
    # sliding window
    # RNN network
    # CTC loss
    ...
    model = Model([input_image, input_true_label,
                   input_time_step, input_label_length], loss_out)
    print(model.summary())
    return model

Hàm CTC loss (output của 2 time step đầu thường là rác nên sẽ bị bỏ đi):

def ctc_loss(args):
    y_pred, y_true, input_length, label_length = args
    y_pred = y_pred[:, 2:, :]
    return K.ctc_batch_cost(y_true, y_pred, input_length, label_length)

Sau khi định nghĩa mô hình, bước tiếp theo là viết 1 hàm để generate dữ liệu. Do ở đây mình có 50.000 ảnh với kích thước chiều cao 64 pixels, chiều rộng không cố định nên ta cần phải generate data theo từng batch và padding sao cho các ảnh trong cùng 1 batch có cùng kích thước chiều rộng. Mình sẽ viết 1 class có tên DataGenerator để làm việc này, chi tiết các bạn xem trong file utils.py

class DataGenerator():
    def __init__(self, train_image_list, val_image_list, batch_size=BATCH_SIZE):
    def load_image(self, image_path):
    def load_label_encoder(self):
    def get_batch(self, partition='train'):
    def next_train(self):
    def next_val(self):

Và bước cuối cùng trước khi bắt tay vào train model là encode các kí tự với class tương ứng, ví dụ: a -> 0, b -> 1, c -> 2, etc. Trong sk-learn hỗ trợ sẵn 1 class Label encoder để làm việc này, các bạn chỉ cần gọi từ thư viện ra và dùng thôi.

Để train mô hình các bạn chỉ cần chạy file train.py, mình lười viết arguments parser nên các bạn chịu khó sửa các thông số như: learning rate, số epochs, etc trong code trước khi chạy nha.😆😆😆

(ocr_demo_env) [email protected]:~/project/OCR_demo/TextRecognitionDataGenerator$ cd ..
(ocr_demo_env) [email protected]:~/project/OCR_demo$ python train.py

Máy mình chỉ có 1 GTX 1080Ti nên thời gian train hơi lâu, rơi vào khoảng 700s 1 epoch (cả train và validate). Trong bài toán OCR cho tiếng Anh, số lượng label khá ít (ít hơn rất nhiều so với tiếng Nhật, tiếng Trung) nên mô hình của mình hội tú khá nhanh (sau 3 epochs CTC loss đã giảm rất sâu) nên mình quyết đinh dừng train ngay sau 3 epochs.

Train on   40000 images
Validate on 10000 images
Epoch 1/200
2500/2500 [==============================] - 776s 310ms/step - loss: 3.5979 - val_loss: 0.1970

Epoch 00001: val_loss improved from inf to 0.19696, saving model to model/xception_model_0.19696.h5
Epoch 2/200
2500/2500 [==============================] - 713s 285ms/step - loss: 0.0980 - val_loss: 0.1063

Epoch 00002: val_loss improved from 0.19696 to 0.10629, saving model to model/xception_model_0.10629.h5
Epoch 3/200
2500/2500 [==============================] - 731s 292ms/step - loss: 0.0569 - val_loss: 0.0570

Epoch 00003: val_loss improved from 0.10629 to 0.05701, saving model to model/xception_model_0.05701.h5

Predict on new data

Sau bước training, ta đã có được 1 mô hình được huấn luyện rồi, bây giờ là lúc đem em nó đi predict trên dữ liệu mới. Trong phần trước, ta có tạo thêm 100 ảnh từ 1 bài viết trên medium. Các bạn nhớ sửa code chỗ MODEL_PATH = ... sao cho phù hợp với tên model được lưu trong máy của mình.

MODEL_PATH = 'model/xception_model_0.05701.h5'

Chạy thử file predict.py, ta có thể thấy mô hình dự đoán khá đúng với true label, có một vài trường hợp model dự đoán sai như dấu () bị nhầm thành các kí tự f, J do trong dữ liệu train không có label này 😂😂😂.

True label: As long as youre occupying your mind with (mostly) high quality content whats the harm?
Predicted : As long as youre occupying your mind with fmostlyJ high quality content whats the harm?
_________________________________________________________________________________________________________
True label: How much free time? That depends.
Predicted : How much free time? That depends.
_________________________________________________________________________________________________________
True label: Work and social obligations demand a portion of it.
Predicted : Work and social obligations demand a portion of it.
_________________________________________________________________________________________________________
True label: Confronted later with the same labyrinth the rats find their way through it more quickly.
Predicted : Confronted later with the same labyrinth the rats find their way through it more quickly.

Conclusion

Như vậy trong bài viết này mình đã hướng dẫn các bạn xây dựng cũng như huấn luyện 1 mô hình Deep Learnanh, kết hợp với CTC loss trong bài toán nhận diện kí tự chữ in cho tiếng anh. Các bạn có thể mở rộng bài toán trên cho các ngôn ngữ khác hoặc làm với dữ liệu dạng viết tay. Rất hy vọng bài viết này có thể giúp được cho những bạn đang muốn làm về OCR, hẹn gặp lại các bạn trong các bài viết tiếp theo.