Cài đặt bộ lọc Random Forest để giải bài toán OCR trong môi trường Ruby

Bài viết này sẽ hướng dẫn cách cài đặt giải bài toán OCR (optical character recognition - nhận diện chữ viết) sử dụng thuật toán phân loại Random Forest bằng Ruby. Bộ dữ liệu sử dụng trong bài là bộ dữ liệu chuẩn MNIST cho chữ số viết tay và thuật toán Random Forest được cài đặt sẵn trong thư viện học máy sci-kit viết bằng Python. Bạn sẽ thấy việc sử dụng các công cụ tính toán khoa học vào Ruby có thể được thực hiện rất dễ dàng.

Code đầy đủ của chương trình có thể được tìm thấy ở đây

Bộ dữ liệu

Bộ dữ liệu được sử dụng để huấn luyện là bộ MNIST về chữ số viết tay. Bộ dữ liệu này là một bộ dữ liệu con trong một bộ dữ liệu lớn về chữ số viết tay của Viện Tiêu chuẩn và Kỹ thuật quốc gia Hoa Kỳ(NIST). Ảnh trong bộ dữ liệu MNIST đã được chuẩn hóa về kích thước và căn chỉnh giữa để có thể sử dụng trực tiếp mà không cần thực hiện tiền xử lý.

Tập dữ liệu bao gồm 60.000 ảnh riêng biệt của các chữ số trong tập training, và 10.000 ảnh trong tập test.

Mỗi ảnh đều có kích thước 28x28, đã được grayscale, giống như một vài ví dụ dưới đây.

Thuật toán Random Forest

Thuật toán Random Forest là một thuật toán học máy có thể sử dụng để giải cả bài toán phân loại(classification) và hồi quy(regression). Nó làm việc bằng cách xây dựng một tập hợp các cây quyết định trong quá trình training, sau đó kết hợp kết quả trả về của mỗi cây để đưa ra quyết định dự đoán. cuối cùng. Thuật toán Random Forest có một số lợi thế như khả năng miễn nhiễm với những dữ liệu vô nghĩa, các đặc trưng không quan trọng và nhầm lẫn trong dữ liệu đầu vào.

Thuật toán được mô tả khác trực quan và sinh động ở video dưới đây:

Cài đặt thuật toán Random Forest trong Ruby

Không may là không có nhiều động lực và lý do để thực hiện cài đặt Random Forest bằng ngôn ngữ Ruby. Ngược lại có rất nhiều người đã thực hiện cài đặt bằng Python, và với việc sử dụng gem PyCall ta có thể sử dụng các thư viện Python trong Ruby mà không gặp trở ngại gì.

Bước đầu tiên ta cần phải cài đặt môi trường Python và thư viện sciki learn. Sau đó ta cài gem PyCall vào Ruby:

$ gem install pycall

Sau đó trong file code Ruby ta gọi đến thư viện PyCall và một class DatasetReader có tác dụng đọc dữ liệu ảnh:

require 'pycall/import'
require './dataset_reader.rb'

Ta import thuật toán RandomForestClassifier của scikit learn sử dụng PyCall:

include PyCall::Import
pyfrom :'sklearn.ensemble', import: :RandomForestClassifier

Ta cần phải có cách để đọc dữ liệu vào Ruby. Ở đây ta dùng class DatasetReader. Class này bao gồm hai method chính: read_labels và read_images, các method này sẽ đọc file dữ liệu mẫu và trả về Ruby array.

Bộ dữ liệu được đọc dưới đây:

test_labels = DatasetReader.read_labels( "data/t10k-labels.idx1-ubyte" )
test_images = DatasetReader.read_images( "data/t10k-images.idx3-ubyte" )
rows, columns = DatasetReader.read_rows_columns( "data/t10k-images.idx3-ubyte" )
puts "Labels: #{test_labels.size}, Images: #{test_images.size}, Rows: #{rows}, Columns: #{columns}"
train_labels = DatasetReader.read_labels( "data/train-labels.idx1-ubyte" )
train_images = DatasetReader.read_images( "data/train-images.idx3-ubyte" )
puts "Labels: #{train_labels.size}, Images: #{train_images.size}"

Với việc bộ dữ liệu đã được đọc vào ta có thể gọi thuật toán học máy và thực hiện "fit" với tập training:

# Initialize a RandomForestClassifier
clf = RandomForestClassifier.new()
# Fit with training data
clf.fit(train_images, train_labels)

Sau khi huấn luyện xong ta có thể thử với tập test và tính toán tỉ lệ chính xác:

# Score our fit using the test data
classification_score = clf.score(test_images,test_labels)
puts "Prediction score for Random Forest classifier #{(classification_score*100).round(2)}%"

Cuối cùng ta có thể thử chạy một dự đoán cụ thể trên tập test. Trong ví dụ này, ta sử dụng mẫu test thứ 8, và in ra tỉ lệ đoán có thể là 1 trong 10 chữ số (0, 1, 2, 3, 4, 5, 6, 7, 8, 9).

# Do a prediction for one sample
sample = 8
puts clf.predict([test_images[sample]])
puts clf.predict_proba([test_images[sample]])
puts "Correct label: #{test_labels[sample]}"

Sau khi chạy ta sẽ in ra được kết quả:

Labels: 10000, Images: 10000, Rows: 28, Columns: 28
Labels: 60000, Images: 60000
Prediction score for Random Forest classifier 95.06%
[5]
[[ 0. 0. 0.2 0. 0.2 0.6 0. 0. 0. 0. ]]
Correct label: 5

Ta có thể thấy rằng với việc sử dụng các tham số mặc định của skikit learn, thuật toán Random Forest đã cho độ chính xác lên đến 95.06% trên tập test.

Kết quả của mẫu số 8 được dự đoán là 5, với tỉ lệ 60% số cây quyết định là số 5, 20% là số 4 và 20% là số 3. Nhãn chính xác là 5.

Mẫu test số 8

Ta có thể thấy từ ảnh của mẫu test số 8 phía trên là nó khá giống số 5, mặc dù khá là khó cho máy tính có thể đoán được.

Bài hướng dẫn này cho thấy việc sử dụng thư viện scikit learn trong Ruby khả dễ dàng với gem PyCall. Và việc kết nối giữa Python và Ruby dễ dàng hơn ta tưởng.