+52

[Deep Learning] Key Information Extraction from document using Graph Convolution Network - Bài toán trích rút thông tin từ hóa đơn với Graph Convolution Network

Các nội dung sẽ được đề cập trong bài blog lần này

  • Tổng quan về GNN, GCN
  • Bài toán Key Information Extraction, trích rút thông tin trong văn bản từ ảnh
  • Mô hình GNN
  • Tập dữ liệu hóa đơn - SROIE / ICDAR 2019
  • Invoice-GCN
  • Huấn luyện mô hình với GCN
  • 1 số cách tiếp cận và hướng phát triển khác cho bài toán KIE
  • Kết luận
  • Tài liệu tham khảo

  • UPDATED (31-10-2021): link project về trích rút thông tin từ hóa đơn với GCN, tập dữ liệu sử dụng là tập hóa đơn của Việt Nam MC-OCR: https://github.com/huyhoang17/KIE_invoice_minimal

  • Kết quả cuối cùng thu được cho bài toán trích rút thông tin từ hóa đơn, với các text box màu đỏ thể hiện các thực thể (entity) mà mô hình phân loại được với nhãn tương ứng (company, address, date, total)

Imgur

Imgur

Tổng quan về GNN, GCN

  • Trong năm 2020, cùng với Transformer, GNN hay Graph Neural Network nhận được nhiều sự chú ý và quan tâm hơn từ cộng đồng. Gần đây, mình có viết 1 bài review về các phương pháp và hướng ứng dụng của Graph Neural Network, các bạn có thể tham khảo thêm tại đường link sau: [Deep Learning] Graph Neural Network - A literature review and applications

  • 1 số lớp bài toán điển hình của GNN bao gồm:

    • Node classification
    • Graph classification
    • Link prediction
    • Graph clustering
    • ...
  • Trong thực tế, GNN có thể được áp dụng vào nhiều kiểu dữ liệu và bài toán khác nhau. 1 số ứng dụng như:

    • Các hệ thống gợi ý sản phẩm, 2 ví dụ tiêu biểu nhất là mô hình PinSage (Pinterest) và UberEat (Uber)
    • Decagon, mô hình giúp chẩn đoán các tác dụng phụ của thuốc, hay side-effect khi sử dụng chung nhiều loại thuốc với nhau
    • Feature Matching, liên quan tới lớp bài toán Deep Graph Matching, 1 mô hình SOTA trong thời gian gần đây có thể kể tới như SuperGlue
    • Text Classification, bài toán phân loại văn bản khá quen thuộc khi có thể kết hợp sử dụng thêm module GNN để extract được các mối liên hệ giữa word và document, 1 số mô hình tiêu biểu như: TextGCN, HeteGCN, Every Document Owns Its Structure, ...
    • NLP, các bài toán có mối liên hệ giữa các từ như Dependency Parsing, Relation Extraction, 1 số paper có thể kể tới như: SpanBERT, GraphREL, ...
    • Computer Vision: GNN cũng được sử dụng trong các bài toán về xử lý ảnh như: Image Segmentation, Scene Text Detection, Scene Graph Generation, Pose Estimation, ...
    • Reinforcement Learning: 1 bài toán điển hình là Goal-Directed Generation, với việc hình thành các cấu trúc phân tử dựa trên 1 số mục tiêu và điều kiện / quy tắc cho trước. Hay 1 hướng nghiên cứu được quan tâm gần đây như việc kết hợp reinforcement learning và GNN cho bài toán Recommender System.
    • Key Information Extraction: bài toán trích rút thông tin từ văn bản, cũng sẽ là chủ đề được mình đề cập tới trong bài blog lần này.
    • .. và còn rất nhiều các hướng phát triển và ứng dụng khác nữa..

Bài toán Key Information Extraction, trích rút thông tin trong văn bản từ ảnh

  • Bài toán Information Extraction là 1 bài toán không mới, nhưng trong bài hướng dẫn này, dạng dữ liệu mà mình muốn hướng tới là hóa đơn (invoice). Nhiệm vụ đặt ra là làm sao phân loại được các text box vào các trường thông tin tương ứng, bao gồm: company (tên cty, nhà phân phối sản phẩm), address (địa chỉ), date (ngày giao dịch), total (tổng giá tiền) và other (không thuộc 4 trường trên).

  • Có 1 chú ý rằng bài toán này được thực hiện với 1 yêu cầu rằng cần thực hiện 2 bài toán con trước đó là Scene Text Detection và Scene Text Recognition. Đầu ra của 2 bài toán này sẽ được sử dụng để xây dựng các feature và đồ thị cho bài toán thứ 3 là Key Information Extraction. Đầu vào của mô hình là ảnh, đầu ra ứng với mỗi text box sẽ được phân loại thuộc 4 trường thông tin tương ứng.

  • Thực ra, với bài toán trích rút thông tin từ ảnh này, ta hoàn toàn có thể sử dụng các phương pháp dễ tiếp cận và quen thuộc hơn như: template-based hoặc NLP-based. Tuy nhiên, mỗi phương pháp đều có những hạn chế tương ứng:

    • Template-based: đơn giản là việc áp dụng các rule (luật), được định nghĩa từ trước lên các form, văn bản có layout / structure cố định, không thay đổi nhiều. Tiếp đó, sử dụng các phương pháp về text / keyword matching để xác định các trường thông tin tương ứng. Tuy nhiên, nhược điểm lớn nhất của phương pháp này là chúng ta phải định nghĩa từng luật riêng ứng với từng form, không có khả năng adapt sang dạng form mới và bị phụ thuộc hoàn toàn vào domain knowledge của từng người.
    • NLP-based: với phương pháp này, các nội dung thu được từ text-box có thể đưa vào 1 mô hình text classification hoặc NER để tiến hành phân loại hoặc xác định các thực thể thuộc từng trường thông tin tương ứng. Ưu điểm của phương pháp này so với Template-based là có khả năng adapt được với dữ liệu mới. Tuy nhiên, 1 số nhược điểm có thể kể tới như: bị phụ thuộc rất nhiều vào layout của form, hạn chế với dữ liệu được biểu diễn dưới dạng bảng / table, hoàn toàn không sử dụng các thông tin / feature về vị trí của text-box, cho dù các thông tin về layout như vậy cũng sẽ giúp ích rất nhiều trong việc xác định các trường tương ứng.
  • Việc thay thế và áp dụng graph-based method cho bài toán này đến từ 1 số lý do sau:

    • Local pattern: tương tự như mô hình CNN, nhưng thay vì là các điểm pixel, các node có kết nối với nhau cũng sẽ có mối liên hệ cao hơn với các node xa hơn trong đồ thị.
    • Positional feature: các thông tin về vị trí / tọa độ của nút trên ảnh cũng sẽ giúp mô hình dễ dàng phân biệt các trường thông tin hơn. Ví dụ như thông tin về tên của siêu thị / cửa hàng thực phẩm thường được ghi ngay trên đầu của hóa đơn
    • Textual feature: tương tự như positional feature, các thông tin về text cũng rất quan trọng. Ví dụ như việc phân biệt trường thông tin address với các trường dữ liệu khác
    • Việc stack nhiều các module GCN lên nhau giúp model học được các high level feature tốt hơn

Mô hình hóa bài toán với GCN

  • Trong phần này, để dễ hình dung, mình sẽ đề cập tới phương pháp trong paper Invoice-GCN, bao gồm các bước về Feature Engineering, Graph Modeling và Model Training. 2 phần code mẫu thực hiện bởi Pytorch và Pytorch-Geometric sẽ được đề cập tại các phần bên dưới

Invoice GCN / An Invoice Reading System Using a Graph Convolutional Network

  • Paper đầu tiên sử dụng GCN để trích rút các trường thông tin từ tập dữ liệu hóa đơn. Các điểm chính trong paper bao gồm:
    • Graph Modeling: cách thức xây dựng graph dựa trên các bounding box text đã được OCR
    • Feature Engineering: cách xây dựng các đặc trưng ban đầu ứng với từng nút
    • Mô hình hóa với Chebyshev GCN và GCN model
    • Dataset, Experiment setup

Feature Engineering

  • Khi áp dụng các mô hình về GNN cho từng bài toán riêng biệt, điều đầu tiên ta cần quan tâm là làm thế nào để biểu diễn dữ liệu hay các feature dữ liệu dưới dạng đồ thị để đưa vào mô hình GCN sau này. 1 ví dụ đơn giản với tập dữ liệu CORA dataset, tập dữ liệu về academic paper thuộc 7 class. Các node của đồ thị là các paper, các cạnh thể hiện việc giữa 2 paper có cite lẫn nhau, ta chỉ xét đơn giản với đồ thị vô hướng. Các node feature của đồ thị được xây dựng khá đơn giản khi ứng với 1 node (paper), node feature sẽ được thể hiện bởi 1 vector 1433 chiều, ứng với index của 1433 từ hay xuất hiện trong vocab. Vector thu được là 1 binary vector (0 và 1), với 1 thể hiện rằng 1 từ có xuất hiện trong paper và ngược lại.

  • Còn đối với dữ liệu hiện tại là ảnh hóa đơn thì chúng ta sẽ encode graph dựa trên các thông tin sau:

    • Các bounding box ứng với từng dòng text của ảnh. Phần text detection này có thể sử dụng các mô hình Object Detection phổ biến hoặc dùng các mô hình chuyên biệt cho bài toán Scene Text Detection như: CTPN, EAST, Differentiable Binarization, CRAFT,...
    • Nội dung của các text box đó. Phần text recognition này có thể sử dụng các tool mì ăn liền như Tesseract hoặc các mô hình về Scene Text Recognition như: CRNN-CTC loss, Attention-OCR,...
  • Ví dụ 1 ảnh sau khi thực hiện qua 2 bước detect và OCR (ảnh demo mình sử dụng từ trang: https://clova.ai/ocr của clova-AI)

Imgur

  • Phần pipeline của mô hình được mô tả như sau

Imgur

  • Phần Word Extractor bao gồm 2 phần là Text Detection và OCR như bên trên mình đã có đề cập.
  • Phần Feature Calculator hay feature engineering, là phần tạo feature cho các node (nút, đỉnh) trên đồ thị. Các nút ở đây chính là các bounding box thu được sau bước text detect. Việc định nghĩa các cạnh của graph thuộc phần Graph Modeler, sẽ được mình đề cập kĩ hơn ở bên dưới. Giờ việc cần làm tiếp theo là xây dựng feature ban đầu cho các nút của đồ thị. Có nhiều cách nhưng mình sẽ ví dụ cách thực hiện trong paper Invoice-GCN, với việc xây dựng và tổng hợp feature từ nhiều kiểu / thuộc tính khác nhau:
    • Boolean feature: dựa trên đầu ra từ mô hình text recognition, ta xây dựng các thuộc tính như:
      • isDate: có phải là ngày tháng ko (1 / 0)
      • isZipCode: 6 kí tự có thuộc 1 zipcode mã vùng có sẵn ko (1 / 0)
      • isKnownCity, isKnownDept, isKnownCountry: lần lượt kiểm tra xem phần nội dung text có phải là tên của cục, sở, thành phố hay đất nước nào không (1 / 0)
      • nature: gồm 8 phần tử, lần lượt kiểm tra các thuộc tính bao gồm: isAlphabetic, isNumeric, isAlphaNumeric, isNumberwithDecimal, isRealNumber, isCurrency, hasRealandCurrency, mix (except these categories), mixc (mix and currency word). Mình thực ra cũng không hiểu hết các thuộc tính mang ý nghĩa gì, trong paper cũng không đề cập rõ nhưng với nature feature, ta sẽ thu được 1 binary vector 8 chiều ứng với 8 thuộc tính con.
    • Numeric feature: khoảng cách tương đối từ text box hiện tại tới 4 box tương ứng (Top, bottom, left, right). Việc xác định các box tương ứng sẽ được đề cập tại phần Graph Modeling.
    • Text feature: dựa trên đầu ra của mô hình text recognize, ta có thể sử dụng các mô hình thông dụng như Word2Vec, Glove,... để lấy vector embedding của từ (vì phần text detect trong paper là dựa theo word-based). Tuy nhiên, các mô hình này có hạn chế là OOV - Out of vocabulary hay các từ không xuất hiện lúc training sẽ không có embedding. Điều này có thể được cải thiện bằng cách sử dụng các phương pháp khác như: FastText hay BPE (InvoiceGCN). Hoặc nếu với 1 text line thì có thể sử dụng các mô hình về Bert-based để lấy embedding cho câu hiện tại!

==> Sau cùng, ta "nối" tất cả các thuộc tính đó lại và thu được 1 feature vector 317 chiều (1 + 1 + 3 + 8 + 4 + 300) làm node feature ban đầu ứng với từng nút (từng text box) trong graph!

Graph Modeling

  • Như bên trên mình có đề cập tới numeric feature, 4 features về tọa độ này được dựa trên vị trí tương đối với 4 text box trên, dưới, trái, phải như hình bên dưới

Imgur

  • Với 4 thông số RDLRD_L, RDRRD_R, RDTRD_T, RDBRD_B được tính toán như sau

Imgur

ví dụ với RDLRD_L sẽ được tính toán dựa trên tọa độ của các bounding box (output của model text detection), hay chính bằng khoảng cách từ bounding box Source tới bounding box Left, rồi chia cho độ rộng của ảnh, tương tự với các thông số các cũng như vậy.

Imgur

  • Ngoài ra, còn 1 điều chú ý khi xây dựng graph cho từng văn bản (invoice). Ví dụ như hình bên trên, các đường nối giữa các text box đã được thể hiện khá rõ ràng. Nhưng nếu để ý, các bạn có thể thấy rằng không có đường nối giữa 2 text box là anticipéle. Đơn giản vì việc xác định các text box nào được nối với nhau sẽ theo 1 số luật như sau:

    • Xét theo 4 phía (trên, dưới, trái, phải) và xác định các RDLRD_L, RDRRD_R, RDTRD_T, RDBRD_B tương ứng bằng việc chọn các text box có khoảng cách gần nhất
    • Thứ tự ưu tiên thực hiện sẽ từ trên xuống dưới, từ trái sang phải. Và 1 hướng chỉ có 1 đường nối với 1 text box khác! Như ví dụ bên trên, do text box anticipéfois đã kết nối với nhau từ trước nên sẽ không có đường nối giữa anticipéle nữa. Mặc dù 2 text box foisle đều nằm ngay dưới và có khoảng cách tới text box anticipé là ngang nhau.
  • Dưới đây là 1 ảnh minh họa thể hiện các node và edge trên 1 hóa đơn trong tập dữ liệu SROIE

Imgur

  • Tất nhiên, về luật để xây dựng graph cho từng văn bản không hề có hạn chế, hoàn toàn có thể thử nghiệm các cách xây dựng khác nhau, ví dụ việc text box anticipé sẽ có 2 đường nối tới 2 text box foisle, vì đều nằm ngay bên dưới và có khoảng cách tới text box anticipé là ngang nhau.

  • Phần tạo graph cho từng hóa đơn các bạn có thể tham khảo 1 bài hướng dẫn sau:

Modeling

  • Trong paper Invoice-GCN có đề cập tới việc sử dụng Chebyshev-GCN model, là 1 spectral graph neural network. Về phần mô hình, có thể tóm gọn như ảnh dưới:

Imgur

  • Model Chebyshev-GCN được xây dựng với 5 layers như hình trên, bao gồm 4 hidden layers và 1 output layer. Tham số KK trong Chebyshev model được chọn mặc định = 3 tại tất cả các hidden layer với số node input được quy định lần lượt là: 16, 32, 64, 128 (với nf = 16) và output layer gồm 28 output ứng với 28 nhãn / thực thể cần phân biệt. Mũi tên xanh là Relu activation function, mũi tên tím là Softmax, loss function là cross-entropy.

Dataset & Experiment

  • Vì tập dữ liệu khá đặc thù nên không public, bao gồm khoảng 3100 hóa đơn với 27 trường thông tin cần trích rút (product description, unit price, quantity, total, ...)

Imgur

  • Trên đây là bảng kết quả ứng với 27 trường thông tin với f1 trung bình = 0.93. Có thể thấy trong paper lựa chọn cách thức extract text theo word-based, không phải theo sentence-based, sử dụng BPE để encode embedding và xử lý với những case OOV. Ví dụ ảnh visualize kết quả đầu ra của mô hình

Imgur

Huấn luyện mô hình với GCN

Tập dữ liệu hóa đơn - SROIE

  • Trong phần này, mình sẽ thực hiện định nghĩa và huấn luyện mô hình trên tập dữ liệu SROIE-2019. SROIE hay Scanned Receipts OCR and Information Extraction là tập dữ liệu được sử dụng trong RRC Competition - ICDAR 2019. Gồm 3 task con: text detection, text recognition và key information extraction. Các bạn có thể download tập dữ liệu tại trang chủ hoặc dữ liệu đã được xử lý qua link driver sau: preprocessed_SROIE_2019

Định nghĩa mô hình với thư viện Torch-Geometric.

GCN vs Chebyshev GCN

  • Để nhanh gọn, tại phần đầu này mình sẽ sử dụng các module có sẵn trong thư viện Torch-Geometric để xây dựng 1 mô hình đơn giản cho bài toán Invoice Information Extraction. Torch-Geometric, cùng với DGL, là 2 trong số rất nhiều thư viện về Graph Network được xây dựng và sử dụng hiện nay, tính đến thời điểm mình viết bài blog này đã có hơn 10k star trên repo. Torch-Geometric được contribute và cập nhật thường xuyên các mô hình về GNN mới, kèm theo các file example, danh sách các mô hình được support có thể xem thêm tại trang chủ: https://github.com/rusty1s/pytorch_geometric . Việc định nghĩa các thành phần mới cũng khá dễ dàng.

  • Đoạn code bên dưới mô phỏng mô hình GCN được sử dụng trong bài toán này, bao gồm module GCN (trong paper Semi-Supervised Node classification) và Chebyshev-GCN (trong paper Invoice-GCN):

class InvoiceGCN(nn.Module):

    def __init__(self, input_dim, chebnet=False, n_classes=5, dropout_rate=0.2, K=3):
        super().__init__()

        self.input_dim = input_dim
        self.n_classes = n_classes
        self.dropout_rate = dropout_rate

        if chebnet:
            self.conv1 = ChebConv(self.input_dim, 64, K=K)
            self.conv2 = ChebConv(64, 32, K=K)
            self.conv3 = ChebConv(32, 16, K=K)
            self.conv4 = ChebConv(16, self.n_classes, K=K)
        else:
            self.conv1 = GCNConv(self.first_dim, 64, improved=True, cached=True)
            self.conv2 = GCNConv(64, 32, improved=True, cached=True)
            self.conv3 = GCNConv(32, 16, improved=True, cached=True)
            self.conv4 = GCNConv(16, self.n_classes, improved=True, cached=True)

    def forward(self, data):
        # for transductive setting with full-batch update
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr

        x = F.dropout(F.relu(self.conv1(x, edge_index, edge_weight)), p=self.dropout_rate, training=self.training)
        x = F.dropout(F.relu(self.conv2(x, edge_index, edge_weight)), p=self.dropout_rate, training=self.training)
        x = F.dropout(F.relu(self.conv3(x, edge_index, edge_weight)), p=self.dropout_rate, training=self.training)
        x = self.conv4(x, edge_index, edge_weight)

        return F.log_softmax(x, dim=1)

Định nghĩa dataset

  • Để thuận tiện trong việc lưu giữ và lấy các thông tin, mình sử dụng luôn module dataset có sẵn trong Torch-Geometric. Về phần xây dựng node feature cho từng node, mình sử dụng thêm vector embedding từ mô hình Sentence Transformer cho text feature. Đoạn code mô tả như bên dưới:
from torch_geometric.utils.convert import from_networkx
from bpemb import BPEmb
from sentence_transformers import SentenceTransformer

bpemb_en = BPEmb(lang="en", dim=100)
sent_model = SentenceTransformer('distilbert-base-nli-stsb-mean-tokens')

def make_sent_bert_features(text):
    emb = sent_model.encode([text])[0]
    return emb

def get_data(save_fd):
    """
    returns one big graph with unconnected graphs with the following:
    - x (Tensor, optional) – Node feature matrix with shape [num_nodes, num_node_features]. (default: None)
    - edge_index (LongTensor, optional) – Graph connectivity in COO format with shape [2, num_edges]. (default: None)
    - edge_attr (Tensor, optional) – Edge feature matrix with shape [num_edges, num_edge_features]. (default: None)
    - y (Tensor, optional) – Graph or node targets with arbitrary shape. (default: None)
    - validation mask, training mask and testing mask 
    """
    path = "/gdrive/MyDrive/workspace/data/sroie2019/data/raw/box"
    files = [i.split('.')[0] for i in os.listdir(path)]
    files.sort()
    all_files = files[1:]

    list_of_graphs = []
    train_list_of_graphs, test_list_of_graphs = [], []

    files = all_files.copy()
    random.shuffle(files)

    """Resulting in 550 receipts for training"""
    training, testing = files[:550], files[550:]

    for file in tqdm_notebook(all_files):
        connect = Grapher(file, data_fd)
        G,_,_ = connect.graph_formation()
        df = connect.relative_distance() 
        individual_data = from_networkx(G)

        feature_cols = ['rd_b', 'rd_r', 'rd_t', 'rd_l','line_number', \
                'n_upper', 'n_alpha', 'n_spaces', 'n_numeric','n_special']

        text_features = np.array(df["Object"].map(make_sent_bert_features).tolist()).astype(np.float32)
        numeric_features = df[feature_cols].values.astype(np.float32)

        features = np.concatenate((numeric_features, text_features), axis=1)mak
        features = torch.tensor(features)

        for col in df.columns:
            try:
                df[col] = df[col].str.strip()
            except AttributeError as e:
                pass

        df['labels'] = df['labels'].fillna('undefined')
        df.loc[df['labels'] == 'company', 'num_labels'] = 1
        df.loc[df['labels'] == 'address', 'num_labels'] = 2
        df.loc[df['labels'] == 'date', 'num_labels'] = 3
        df.loc[df['labels'] == 'total', 'num_labels'] = 4
        df.loc[df['labels'] == 'undefined', 'num_labels'] = 5
 
        assert df['num_labels'].isnull().values.any() == False, f'labeling error! Invalid label(s) present in {file}.csv'
        labels = torch.tensor(df['num_labels'].values.astype(np.int))
        text = df['Object'].values

        individual_data.x = features
        individual_data.y = labels
        individual_data.text = text
        individual_data.img_id = file

        if file in training:
            train_list_of_graphs.append(individual_data)
        elif file in testing:
            test_list_of_graphs.append(individual_data)

    train_data = torch_geometric.data.Batch.from_data_list(train_list_of_graphs)
    train_data.edge_attr = None
    test_data = torch_geometric.data.Batch.from_data_list(test_list_of_graphs)
    test_data.edge_attr = None

    torch.save(train_data, os.path.join(save_fd, 'train_data.dataset'))
    torch.save(test_data, os.path.join(save_fd, 'test_data.dataset'))

get_data(save_fd="/gdrive/MyDrive/workspace/data/sroie2019/data/processed" )
  • Về phần tạo graph cho dữ liệu, các bạn có thể tham khảo file sau: graph.py
def load_train_test_split(save_fd):
    train_data = torch.load(os.path.join(save_fd, 'train_data.dataset'))
    test_data = torch.load(os.path.join(save_fd, 'test_data.dataset'))
    return train_data, test_data

train_data, test_data = load_train_test_split(save_fd="/gdrive/MyDrive/workspace/data/sroie2019/data/processed")
print(train_data, test_data)
# Batch(batch=[29704], edge_index=[2, 40638], img_id=[550], text=[550], x=[29707, 778], y=[29707])
# Batch(batch=[3919], edge_index=[2, 5347], img_id=[76], text=[76], x=[3919, 778], y=[3919])

Huấn luyện mô hình

from sklearn.utils.class_weight import compute_class_weight

model = InvoiceGCN(input_dim=train_data.x.shape[1], chebnet=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = torch.optim.AdamW(
    model.parameters(), lr=0.001, weight_decay=0.9
)
train_data = train_data.to(device)
test_data = test_data.to(device)

# class weights for imbalanced data
_class_weights = compute_class_weight(
    "balanced", train_data.y.unique().cpu().numpy(), train_data.y.cpu().numpy()
)
print(_class_weights)

no_epochs = 2000
for epoch in range(1, no_epochs + 1):
    model.train()
    optimizer.zero_grad()

    # NOTE: just use boolean indexing to filter out test data, and backward after that!
    # the same holds true with test data :D
    # https://github.com/rusty1s/pytorch_geometric/issues/1928
    loss = F.nll_loss(
        model(train_data), train_data.y - 1, weight=torch.FloatTensor(_class_weights).to(device)
    )
    loss.backward()
    optimizer.step()

    # calculate acc on 5 classes
    with torch.no_grad():
        if epoch % 200 == 0:
            model.eval()

            # forward model
            for index, name in enumerate(['train', 'test']):
                _data = eval("{}_data".format(name))
                y_pred = model(_data).max(dim=1)[1]
                y_true = (_data.y - 1)
                acc = y_pred.eq(y_true).sum().item() / y_pred.shape[0]

                y_pred = y_pred.cpu().numpy()
                y_true = y_true.cpu().numpy()
                print("\t{} acc: {}".format(name, acc))
                # confusion matrix
                if name == 'test':
                    cm = confusion_matrix(y_true, y_pred)
                    class_accs = cm.diagonal() / cm.sum(axis=1)
                    print(classification_report(y_true, y_pred))

            loss_val = F.nll_loss(model(test_data), test_data.y - 1
            )
            fmt_log = "Epoch: {:03d}, train_loss:{:.4f}, val_loss:{:.4f}"
            print(fmt_log.format(epoch, loss, loss_val))
            print(">" * 50)

Inference

  • Việc sử dụng API của Torch-Geometric 1 phần khiến cho việc inference trên 1 ảnh (1 graph) gặp khó khăn hơn 1 chút. Ví dụ đoạn code inference trên tập test và visualize được mình định nghĩa như bên dưới:
test_output_fd = "/gdrive/MyDrive/workspace/data/sroie2019/test_output"
shutil.rmtree(test_output_fd)
if not os.path.exists(test_output_fd):
    os.mkdir(test_output_fd)

def make_info(img_id='584'):
    connect = Grapher(img_id, data_fd)
    G, _, _ = connect.graph_formation()
    df = connect.relative_distance()
    individual_data = from_networkx(G)
    img = cv2.imread(os.path.join(img_fd, "{}.jpg".format(img_id)))[:, :, ::-1]

    return G, df, individual_data, img

y_preds = model(test_data).max(dim=1)[1].cpu().numpy()
LABELS = ["company", "address", "date", "total", "other"]
test_batch = test_data.batch.cpu().numpy()
indexes = range(len(test_data.img_id))
for index in tqdm_nb(indexes):
    start = time.time()
    img_id = test_data.img_id[index]  # not ordering by img_id
    sample_indexes = np.where(test_batch == index)[0]
    y_pred = y_preds[sample_indexes]

    print("Img index: {}".format(index))
    print("Img id: {}".format(img_id))
    print("y_pred: {}".format(Counter(y_pred)))
    G, df, individual_data, img = make_info(img_id)

    assert len(y_pred) == df.shape[0]

    img2 = np.copy(img)
    for row_index, row in df.iterrows():
        x1, y1, x2, y2 = row[['xmin', 'ymin', 'xmax', 'ymax']]
        true_label = row["labels"]

        if isinstance(true_label, str) and true_label != "invoice":
            cv2.rectangle(img2, (x1, y1), (x2, y2), (0, 255, 0), 2)

        _y_pred = y_pred[row_index]
        if _y_pred != 4:
            cv2.rectangle(img2, (x1, y1), (x2, y2), (255, 0, 0), 3)
            _label = LABELS[_y_pred]
            cv2.putText(
                img2, "{}".format(_label), (x1, y1),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2
            )

    end = time.time()
    print("\tImage {}: {}".format(img_id, end - start))
    plt.imshow(img2)
    plt.savefig(os.path.join(test_output_fd, '{}_result.png'.format(img_id)), bbox_inches='tight')
    plt.savefig('{}_result.png'.format(img_id), bbox_inches='tight')

Kết quả trên tập test

  • 1 vài kết quả trên tập test, còn 1 số ảnh bị dự đoán nhầm trên từng text box nhưng nhìn chung kết quả thu được khá ổn:

Imgur

Định nghĩa mô hình đơn giản với Pytorch

  • Mô hình dưới đây được mô phỏng theo paper: Semi-supervised classification with GCN. Tuy nhiên, mình có tùy chỉnh bằng việc thêm 1 số các module linear trong mô hình, code minh họa như bên dưới:
class GraphConvolution(nn.Module):

    def __init__(
        self,
        input_dim,
        output_dim,
        dropout=0.2,
        bias=True,
        activation=F.relu,
    ):
        super().__init__()

        if dropout:
            self.dropout = dropout
        else:
            self.dropout = 0.0

        self.bias = bias
        self.activation = activation

        def glorot(shape, name=None):
            """Glorot & Bengio (AISTATS 2010) init."""
            init_range = np.sqrt(6.0 / (shape[0] + shape[1]))
            init = torch.FloatTensor(shape[0], shape[1]).uniform_(
                -init_range, init_range
            )
            return init

        self.weight = nn.Parameter(glorot((input_dim, output_dim)))
        self.bias = None
        if bias:
            self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, inputs):
        # node feature, adj matrix
        # D^(-1/2).A.D^(-1/2).H_i.W_i
        # with H_0 = X (init node features)
        # V, A
        x, support = inputs

        x = F.dropout(x, self.dropout)
        xw = torch.mm(x, self.weight)
        out = torch.sparse.mm(support, xw)

        if self.bias is not None:
            out += self.bias

        if self.activation is None:
            return out, support
        else:
            return self.activation(out), support

class LinearEmbedding(torch.nn.Module):

    def __init__(self, input_size, output_size, use_act="relu"):
        super().__init__()
        self.C = output_size
        self.F = input_size

        self.W = nn.Parameter(torch.FloatTensor(self.F, self.C))
        self.B = nn.Parameter(torch.FloatTensor(self.C))

        if use_act == "relu":
            self.act = torch.nn.ReLU()
        elif use_act == "softmax":
            self.act = torch.nn.Softmax(dim=-1)
        else:
            self.act = None

        nn.init.xavier_normal_(self.W)
        nn.init.normal_(self.B, mean=1e-4, std=1e-5)

    def forward(self, V):
        # V shape B,N,F
        # V: node features
        V_out = torch.matmul(V, self.W) + self.B
        if self.act:
            V_out = self.act(V_out)

        return V_out
        
class GCN(nn.Module):

    def __init__(self, input_dim, output_dim, hidden_dims=[256, 128, 64],
                 bias=True, dropout_rate=0.1):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims
        self.bias = bias
        self.dropout_rate = dropout_rate

        gcn_layers = []
        for index, (h1, h2) in enumerate(
            zip(self.hidden_dims[:-1], self.hidden_dims[1:])):
            gcn_layers.append(
                GraphConvolution(
                    h1,
                    h2,
                    activation=None if index == len(self.hidden_dims) else F.relu,
                    bias=self.bias,
                    dropout=self.dropout_rate,
                    is_sparse_inputs=False
                )
            )

        self.layers = nn.Sequential(*gcn_layers)
        self.linear1 = LinearEmbedding(input_dim, self.hidden_dims[0], use_act='relu')
        self.linear2 = LinearEmbedding(self.hidden_dims[-1], self.output_dim, use_act='relu')

    def forward(self, inputs):
        # features, adj
        x, support = inputs
        x = self.linear1(x)
        x = F.dropout(x, p=self.dropout_rate)
        x, _ = self.layers((x, support))
        x = self.linear2(x)
        return x, support

Mini-batch training on multiple unique graphs

  • 1 vấn đề cần lưu ý nữa là việc training theo batch cho mô hình GCN với multiple unique graphs, tức nhiều graph độc lập với số lượng node trên đồ thị là khác nhau

Imgur

  • Ta lần lượt kí hiệu:

    • AA: adjacency matrix, ma trận kề của đồ thị
    • XX: ma trận lưu giữ node feature ứng với từng nút trong đồ thị
    • YY: ma trận ứng với nhãn của từng nút (company, address, date, total, other)
  • Đặc điểm của dạng dữ liệu đồ thị là số lượng nút (node), cạnh (edge) không cố định và đồ thị thường "thưa thớt" (sparse). Với dạng dữ liệu này, ta có thể thực hiện training theo batch theo phương pháp sau:

Contruct big diagonal matrix

  • Cách thứ hai không cần padding mà sẽ thực hiện tạo 1 ma trận adjacency A mới là sự kết hợp của các ma trận AiA_i con, độc lập theo đường chéo chính của ma trận A. Hơi khó giải thích 1 chút nhưng các bạn xem hình mô tả bên dưới sẽ dễ hiểu hơn

Imgur

  • hay theo kí hiệu toán học

Imgur

Huấn luyện mô hình

  • Phần huấn luyện mô hình cũng không quá khó khăn. Loss function sử dụng cho bài toán node classification là cross-entropy. Vì số lượng class other là lớn hơn rất nhiều so với các class khác, nên trong quá trình training mình còn thực hiện đánh thêm trọng số cho các class ít sample, giúp mô hình học tốt hơn. Ngoài ra, trong phần mô hình bên trên, mình cũng sử dụng thêm các layer Dropout để hạn chế overfit trong quá trình training mô hình.

  • Có nhiều hướng xử lý với bài toán mất cân bằng dữ liệu, 1 cách đơn giản nhất là ta tác động vào loss function. Ví dụ cách đơn giản nhất là đánh thêm trọng số hoặc sử dụng các hàm loss thông dụng hơn cho bài toán imbalanced data như focal loss.

  • Xây dựng các hàm utility

def weight_mask(labels):
    label_classes = copy.deepcopy(LABELS)
    weight_dict = {}
    for k in label_classes:
        if k == "other" or k == 'invoice':
            weight_dict[k] = 0.8
        else:
            weight_dict[k] = 1.0
    tmp_list = []
    for arr in labels:
        index = np.argmax(arr)
        tmp_list.append(weight_dict[label_classes[index]])
    return np.array(tmp_list)

def weighted_loss(preds, labels, weight=None, class_weight=False, device='cuda'):
    """Softmax cross-entropy loss with weights."""
    if class_weight:
        if weight is not None:
            weight = torch.tensor(weight).float().to(device)
        loss = F.cross_entropy(preds, labels, reduction='none', weight=weight)
    else:
        # sample weight
        # https://discuss.pytorch.org/t/how-to-weight-the-loss/66372/3
        # https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264
        loss = F.cross_entropy(preds, labels, reduction='none')
        if weight is not None:
            weight = weight.float()
            loss *= weight
    loss = loss.mean()
    return loss

def load_single_graph(img_id):
    adj = scipy.sparse.load_npz(os.path.join(matrix_dir, img_id + "_adj.npz"))
    features = np.load(os.path.join(matrix_dir, img_id + "_feature.npy"), allow_pickle=True)
    labels = np.load(os.path.join(matrix_dir, img_id + "_label.npy"), allow_pickle=True)
    weights_mask = weight_mask(labels)
    return adj, features, labels, weights_mask

def cal_accuracy(out, label):
    "Accuracy in single graph."
    pred = out.argmax(dim=1)
    correct = torch.eq(pred, label).float()
    acc = correct.mean()
    return acc

def convert_sparse_input(adj, features):
    supports = preprocess_adj(adj)
    # coords, values in coord
    m = torch.from_numpy(supports[0]).long()
    n = torch.from_numpy(supports[1])
    support = torch.sparse.FloatTensor(m.t(), n, supports[2]).float()

    features = [
        torch.tensor(idxs, dtype=torch.float32).to(device)
        if torch.cuda.is_available()
        else torch.tensor(idxs, dtype=torch.float32)
        for idxs in features
    ]
    features = torch.stack(features).to(device)

    if torch.cuda.is_available():
        m = m.to(device)
        n = n.to(device)
        support = support.to(device)
    return features, support

def normalize_adj(adj):
    """Symmetrically normalize adjacency matrix."""
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1))  # D
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()  # D^-0.5
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)  # D^-0.5
    return (
        adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
    )  # D^-0.5AD^0.5

def sparse_to_tuple(sparse_mx):
    """Convert sparse matrix to tuple representation."""

    def to_tuple(mx):
        if not sp.isspmatrix_coo(mx):
            mx = mx.tocoo()
        coords = np.vstack((mx.row, mx.col)).transpose()
        values = mx.data
        shape = mx.shape
        return coords, values, shape

    if isinstance(sparse_mx, list):
        for i in range(len(sparse_mx)):
            sparse_mx[i] = to_tuple(sparse_mx[i])
    else:
        sparse_mx = to_tuple(sparse_mx)

    return sparse_mx

def preprocess_adj(adj):
    """Preprocessing of adjacency matrix for simple GCN model and conversion to tuple representation."""
    adj_normalized = normalize_adj(adj + scipy.sparse.eye(adj.shape[0]))
    return sparse_to_tuple(adj_normalized)

def convert_loss_input(y_train, weight_mask):
    train_label = torch.from_numpy(y_train).long()
    weight_mask = torch.from_numpy(weight_mask)

    if torch.cuda.is_available():
        train_label = train_label.to(device)
        weight_mask = weight_mask.to(device)
    train_label = train_label.argmax(dim=1)

    return train_label, weight_mask

def normalize_adj(adj):
    """Symmetrically normalize adjacency matrix."""
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1))  # D
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()  # D^-0.5
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)  # D^-0.5
    return (
        adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
    )  # D^-0.5AD^-0.5
  • Chú ý, ta normalize Adjacency matrix đúng như trong công thức của GCN model, với:

Hi+1=f(Hi,A)=σ(AHiWi)=σ(D1AHiWi)H^{i + 1} = f(H^{i}, A) = \sigma(A H^{i} W^{i}) = \sigma(D^{-1}A H^{i} W^{i})

với Symmetric Normalization, đổi công thức thành

σ(D1AHiWi)=σ(D1/2AD1/2HiWi)\sigma(D^{-1}A H^{i} W^{i}) = \sigma(D^{-1/2}A D^{-1/2} H^{i} W^{i})

với H0=XH^{0} = X

Reference http://web.stanford.edu/class/cs224w/slides/08-GNN.pdf

hoặc các bạn có thể đọc thêm tại bài hướng dẫn sau: Graph Neural Network - A literature review and applications

  • Phần code thực hiện training mô hình
# training model
for epoch in range(20):
    net.train()
    random.shuffle(train_ids)
    t1 = time.time()
    train_losses = 0
    train_accs = []

    batch_losses = []
    # simple training with batch_size = 1
    for img_index, img_id in tqdm.notebook.tqdm(enumerate(train_ids)):
        adj, features, train_labels, weight_mask_ = load_single_graph(img_id)
        features, support = convert_sparse_input(adj, features)
        train_labels, weight_mask_ = convert_loss_input(train_labels, weight_mask_)
        support = support.to(device)
        out = net((features, support))[0]
        loss = weighted_loss(out, train_labels, _class_weights, class_weight=True)

        train_losses += loss.item()
        batch_losses.append(loss.item())
        if img_index % 100 == 0:
            print("\ttrain loss: {:.5f} ".format(np.mean(batch_losses)))
            batch_losses = []

        acc = cal_accuracy(out, train_labels)
        train_accs.append(acc.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    train_losses /= (img_index + 1)
    acc = np.mean(train_accs)
    t2 = time.time()
    print(
        "Epoch:",
        "%04d" % (epoch + 1),
        "time: {:.5f}, loss: {:.5f}, acc: {:.5f}".format(
            (t2 - t1), train_losses, acc.item()
        ),
    )

Hậu xử lý

  • Tập dữ liệu SROIE 2019 vẫn tồn tại 1 số trường hợp mà dữ liệu annotate chưa đúng, ví dụ như label của các text box. Điều này là khó tránh khỏi và dễ gây nhầm lẫn cho mô hình. Để hạn chế việc đó, mình có tiến hành hậu xử lý để hạn chế các case false-positive và lọc false-negative:
    • Với tất cả các text box trong ảnh, set 1 ngưỡng (threshold) để lọc bớt các case có confidence score không cao, ví dụ 0.7. Nếu thấp hơn threshold thì gán lại nhãn là other
    • Với tất cả các hóa đơn, chỉ tồn tại 1 text box với nhãn là total hoặc date. Vậy nên việc lấy output có xác suất lớn nhất ứng với từng text box có thể dẫn đến xuất hiện nhiều các text box của total / date. Đơn giản là ta sẽ lấy text box có confidence score cao nhất.
    • 1 số trường hợp, các text box bị dự đoán nhầm là total (ví dụ như các text-box với nội dung là Total, Total Cost,...) mà đáng lẽ phải là text box số nằm cùng hàng bên phải. Việc dự đoán nhầm vậy cũng 1 phần do tập dữ liệu annotate không chính xác. Cách đơn giản nhất để xử lý case này là ta lấy text box có confidence score cao nhất, ứng với nhãn total và dóng sang phải, nếu tồn tại 1 text box khác thì gán lại label là total cho text box đó.

Kết quả

  • 1 số kết quả thu được sau khi training mô hình với Pytorch

Imgur

1 số cách tiếp cận và hướng phát triển khác cho bài toán KIE

  • Như hướng tiếp cận mình vừa có trình bày bên trên, việc sử dụng Adjacency matrix kèm theo node feature giúp chúng ta sử dụng được các thông tin về vị trí và texture của từng text box. Tuy nhiên, các bạn có thể tham khảo thêm 1 số các cải tiến từ các paper dưới đây:
    • Kết hợp graph embedding, text embedding, edge feature và Bi-LSTM - CRF để tiến hành phân loại các trường thông tin như trong paper GCN for multi-modal visual rich document
    • Kết hợp các thông tin của cả global context và local context như: ảnh, text embedding, vị trí,... để xây dựng và kết hợp nhiều loại feature với nhau. 1 số paper có thể kể tới như: LayoutLM, PICK, Attention-based GNN with Global Context
    • Sử dụng thêm các thông tin về text như: font size, font type để xây dựng font embedding cho từng text box như trong paper Robust-layout IE for VRD

GCN libraries

  • Dưới đây là 1 vài các thư viện về GNN các bạn có thể tham khảo thử, cá nhân mình thường dùng Pytorch-Geometric và DGL, Stellagraph cũng đã từng sử dụng nhưng API không được thân thiện lắm 😦

Kết luận

  • Trên đây là bài chia sẻ về việc áp dụng mô hình Graph Neural Network vào 1 bài toán khá đặc thù là Key Informatrion Extraction. Hi vọng qua bài viết này sẽ giúp các bạn hiểu và có thể áp dụng vào các vấn đề tương tự. Đừng quên upvote và chia sẻ rộng rãi bài viết này tới mọi người nhé ^^

  • Mọi ý kiến phản hồi và góp ý vui lòng comment bên dưới bài viết hoặc gửi mail về địa chỉ: hoangphan0710@gmail.com 😄 Hẹn gặp lại các bạn trong những bài viết sắp tới!

Reference


All rights reserved

Viblo
Hãy đăng ký một tài khoản Viblo để nhận được nhiều bài viết thú vị hơn.
Đăng kí