Giới thiệu về Graph Neural Networks (GNNs)
Dữ liệu đồ thị chắc hẳn các bạn đã và đang tìm hiểu về học sâu và học máy cũng đã từng nghe qua khái niệm và các bài toán về đồ thị. Nhưng không có quá nhiều bạn thực sự hiểu và triển khai các bài toán trên dữ liệu đồ thị một các hiệu quả. Vậy nên việc hiểu về cách mà các mô hình học sâu được xây dựng trên dữ liệu đồ thị hoạt động như thế nào và triển khai một bài toán đơn giản trên dữ liệu đồ thị sao cho hiệu quả lần rất quan trọng và cần thiết. Dữ liệu đồ thị đã trở thành một phần quan trọng trong lĩnh vực máy học và khai phá dữ liệu, với khả năng biểu diễn và phân tích các mối quan hệ phức tạp giữa các đối tượng. Đồ thị có thể mô tả mạng xã hội, mạng lưới giao thông, sự tương tác giữa các phân tử trong hóa học, và nhiều hình thái khác. Bằng cách sử dụng các đỉnh (nút) và các cạnh (liên kết), đồ thị cung cấp khung cho việc phân tích mô hình, dự đoán và hiểu các quy luật ẩn trong dữ liệu. Trong quá trình phát triển AI hiện nay, việc hiểu và làm việc với dữ liệu đồ thị trở thành một kỹ năng quan trọng cho các nhà nghiên cứu và nhà phát triển. Trong bài viết này, mình sẽ giới thiệu về cách làm việc với dữ liệu đồ thị và các nhiệm vụ phổ biến trong lĩnh vực này nhé.
What is a Graph?
Graph hay đồ thị là những cái tên thị là loại cấu trúc dữ liệu chứa các nút và cạnh. Một nút có thể là một người, địa điểm hoặc vật thể, và các cạnh xác định mối quan hệ giữa các nút. Các cạnh có thể có hướng và không hướng dựa trên các phụ thuộc tương ứng . Trong ví dụ dưới đây, các hình tròn màu đỏ là các nút và các mũi tên là các cạnh. Hướng của các cạnh xác định các phụ thuộc giữa hai nút.
Một ví dụ thực tế hơn: Mạng lưới Nhạc Jazz - Jazz Musicians Network. Nó chứa 198 nút và 2742 cạnh. Trong biểu đồ đồ thị cộng đồng dưới đây, các màu sắc khác nhau của các nút đại diện cho các cộng đồng khác nhau của các nhạc sĩ Jazz và các cạnh kết nối giữa chúng. Có một mạng lưới hợp tác, trong đó một nhạc sĩ đơn lẻ có thể có mối quan hệ bên trong và bên ngoài cộng đồng.
Tạo một đồ thị với NetworkX
Trong phần này, chúng ta sẽ tìm hiểu cách tạo một đồ thị sử dụng NetworkX.
Tạo đối tượng DiGraph
của NetworkX với tên là "H".
Thêm các nút chứa các nhãn, màu sắc và kích thước khác nhau.
Thêm các cạnh để tạo mối quan hệ giữa hai nút. Ví dụ: "(0,1)" có nghĩa là nút 0 có phụ thuộc hướng đến nút 1. Chúng ta sẽ tạo các mối quan hệ hai chiều bằng cách thêm "(1,0)".
Trích xuất màu sắc và kích thước dưới dạng danh sách.
Vẽ đồ thị bằng cách sử dụng hàm draw
của NetworkX.
import networkx as nx
H = nx.DiGraph()
#adding nodes
H.add_nodes_from([
(0, {"color": "blue", "size": 250}),
(1, {"color": "yellow", "size": 400}),
(2, {"color": "orange", "size": 150}),
(3, {"color": "red", "size": 600})
])
#adding edges
H.add_edges_from([
(0, 1),
(1, 2),
(1, 0),
(1, 3),
(2, 3),
(3,0)
])
node_colors = nx.get_node_attributes(H, "color").values()
colors = list(node_colors)
node_sizes = nx.get_node_attributes(H, "size").values()
sizes = list(node_sizes)
#Plotting Graph
nx.draw(H, with_labels=True, node_color=colors, node_size=sizes)
Tổng quan về Graph Neural Network (GNN)?
Mạng Nơ-ron Đồ Thị (Graph Neural Network - GNN) là một loại mô hình học máy được thiết kế đặc biệt để làm việc với dữ liệu đồ thị. GNN có khả năng mở rộng và áp dụng trên các đồ thị có cấu trúc phức tạp, như mạng xã hội, mạng lưới giao thông, hay bất kỳ hệ thống nào có mối quan hệ giữa các đối tượng. GNN hoạt động bằng cách truyền thông tin qua các đỉnh và cạnh trong đồ thị. Mô hình học thông qua việc cập nhật và kết hợp thông tin từ các hàng xóm của mỗi đỉnh, cho phép nắm bắt thông tin cấu trúc và tương tác giữa các đối tượng trong đồ thị. Một trong những đặc điểm đáng chú ý của GNN là khả năng tích hợp thông tin từ cả đặc trưng của các đỉnh và cấu trúc đồ thị. Điều này cho phép GNN học mô hình phức tạp và biểu diễn các mối quan hệ phức tạp giữa các đối tượng trong đồ thị. GNN đã chứng tỏ được hiệu quả trong nhiều nhiệm vụ, bao gồm phân loại đồ thị, phân loại nút, dự đoán liên kết và nhúng đồ thị. Các ứng dụng của GNN rất đa dạng, từ phân tích mạng xã hội, gợi ý người dùng, cho đến phát hiện và kiểm soát các hiện tượng trong các hệ thống phức tạp.
Đồ thị đầu vào được đi qua một loạt mạng neural. Cấu trúc đồ thị đầu vào được chuyển đổi thành nhúng đồ thị, cho phép chúng ta duy trì thông tin về các nút, cạnh và ngữ cảnh toàn cục. Sau đó, vectơ đặc trưng của các nút A và C được thông qua lớp mạng neural. Nó tổng hợp những đặc trưng này và truyền chúng vào lớp tiếp theo.
Tuy GNN đã mang lại nhiều tiến bộ, nhưng vẫn còn nhiều thách thức trong việc khai thác toàn bộ tiềm năng của dữ liệu đồ thị và tăng tính khả chuyển của mô hình. Các nghiên cứu về GNN đang tiếp tục phát triển để nâng cao hiệu suất và ứng dụng của mô hình trong các lĩnh vực khác nhau.
Một số loại Graph Neural Networks
Có nhiều loại Mạng Nơ-ron Đồ Thị (Graph Neural Networks - GNN) được phát triển để làm việc với dữ liệu đồ thị. Dưới đây là một số loại GNN phổ biến:
-
Graph Convolutional Networks (GCN): GCN là một dạng GNN đơn giản và phổ biến. Nó sử dụng cơ chế tích chập đồ thị để truyền thông tin qua các đỉnh và cạnh trong đồ thị. GCN kết hợp thông tin đặc trưng của các đỉnh và cấu trúc đồ thị để thực hiện phân loại hoặc dự đoán trên đồ thị.
-
GraphSAGE: GraphSAGE (Graph Sample and Aggregated) sử dụng một quá trình lấy mẫu và tổng hợp thông tin từ hàng xóm của các đỉnh để cập nhật đặc trưng của mỗi đỉnh. Điều này giúp GraphSAGE học được biểu diễn đồ thị tổng thể và khả năng xử lý các đồ thị lớn.
-
Graph Attention Networks (GAT): GAT sử dụng cơ chế chú ý (attention mechanism) để xác định mức độ quan trọng của các hàng xóm đối với mỗi đỉnh trong đồ thị. Bằng cách trọng số hóa thông tin từ các hàng xóm, GAT tập trung vào các đỉnh quan trọng và xử lý đồ thị một cách linh hoạt.
-
Graph Autoencoders (GAE): GAE là một dạng GNN được sử dụng để học biểu diễn nhúng (embedding) của đồ thị. GAE cố gắng tái tạo đồ thị gốc từ biểu diễn nhúng, giúp học các đặc trưng ngầm của đồ thị và khám phá cấu trúc ẩn.
-
Graph Recurrent Neural Networks (GRNN): GRNN sử dụng kiến trúc mạng nơ-ron hồi quy để xử lý dữ liệu đồ thị có thứ tự thời gian. GRNN có khả năng mô hình các quá trình động trên đồ thị như dự đoán chuỗi thời gian hoặc phân tích dữ liệu đồ thị theo thời gian.
Một số nhiệm vụ của Graph Neural Networks
- Graph Classification: chúng ta sử dụng phương pháp này để phân loại đồ thị vào các danh mục khác nhau. Ứng dụng của nó bao gồm phân tích mạng xã hội và phân loại văn bản.
- Node Classification: nhiệm vụ này sử dụng nhãn của các nút láng giềng để dự đoán nhãn của các nút bị thiếu trong đồ thị.
- Link Classification: dự đoán liên kết giữa một cặp nút trong đồ thị với ma trận kề không đầy đủ. Phương pháp này thường được sử dụng trong mạng xã hội.
- Community detection: chia các nút thành các cụm khác nhau dựa trên cấu trúc cạnh. Nó học từ trọng số cạnh, khoảng cách và đối tượng đồ thị tương tự.
- Graph Embedding: ánh xạ đồ thị thành các vectơ, bảo tồn thông tin quan trọng về các nút, cạnh và cấu trúc.
- Graph Generation: học từ phân phối đồ thị mẫu để tạo ra một cấu trúc đồ thị mới nhưng tương tự.
Triển khai một bài toán Graph Neural Networks
Cài đặt môi trường
!pip install -q torch
%%capture
import os
import torch
os.environ['TORCH'] = torch.__version__
os.environ['PYTHONWARNINGS'] = "ignore"
!pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install git+https://github.com/pyg-team/pytorch_geometric.git
Dữ liệu
Bộ dữ liệu Planetoid Cora Planetoid là một bộ dữ liệu mạng trích dẫn từ Cora, CiteSeer và PubMed. Các nút là các tài liệu với vectơ đặc trưng bag-of-words 1433 chiều, và các cạnh là các liên kết trích dẫn giữa các bài báo nghiên cứu. Có 7 lớp, và mình sẽ huấn luyện mô hình để dự đoán nhãn bị thiếu.
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
dataset = Planetoid(root='data/Planetoid', name='Cora', transform=NormalizeFeatures())
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
data = dataset[0] # Get the first graph object.
print(data)
Dataset: Cora():
======================
Number of graphs: 1
Number of features: 1433
Number of classes: 7
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
Xây dựng mô hình phân loại GNN
Chúng ta sẽ tạo ra một kiến trúc GCN với 2 khối GCNConv, hàm kích hoạt relu và một hệ số dropout là 0.5.
Công thức của GCN layer:
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
class GCN(torch.nn.Module):
def __init__(self, hidden_channels):
super().__init__()
torch.manual_seed(1234567)
self.conv1 = GCNConv(dataset.num_features, hidden_channels)
self.conv2 = GCNConv(hidden_channels, dataset.num_classes)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = x.relu()
x = F.dropout(x, p=0.5, training=self.training)
x = self.conv2(x, edge_index)
return x
model = GCN(hidden_channels=16)
print(model)
>>> GCN(
(conv1): GCNConv(1433, 16)
(conv2): GCNConv(16, 7)
)
Trực quan hóa mô hình trước khi huấn luyện
%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
def visualize(h, color):
z = TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())
plt.figure(figsize=(10,10))
plt.xticks([])
plt.yticks([])
plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")
plt.show()
model.eval()
out = model(data.x, data.edge_index)
visualize(out, color=data.y)
Huấn luyện GNN
Chúng ta sẽ huấn luyện mô hình trong 100 Epochs bằng cách sử dụng optimizer là Adam và hàm loss là Cross-Entropy.
Code của phần huấn luyện mô hình:
Tạo tỷ lệ độ chính xác bằng cách tính tổng số dự đoán chính xác chia cho tổng số nút.
model = GCN(hidden_channels=16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
def train():
model.train()
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = criterion(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
return loss
def test(mask):
model.eval()
out = model(data.x, data.edge_index)
pred = out.argmax(dim=1)
correct = pred[mask] == data.y[mask]
acc = int(correct.sum()) / int(mask.sum())
return acc
val_acc_all = []
test_acc_all = []
for epoch in range(1, 101):
loss = train()
val_acc = test(data.val_mask)
test_acc = test(data.test_mask)
val_acc_all.append(val_acc)
test_acc_all.append(test_acc)
print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')
Kết quả huấn luyện
test_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')
>>> Test Accuracy: 0.8150
Trực quan hóa mô hình sau khi huấn luyện
model.eval()
out = model(data.x, data.edge_index)
visualize(out, color=data.y)
Kết luận
Tóm tắt lại thì bài viết mình đã mô tả khái niệm và cách triển khai một bài toán sử dụng kiến trúc mạng neural networks cho một bộ dữ liệu được mã hóa theo dạng đồ thị một cách đơn giản. Tuy nhiên các mô hình mạng Graph Neural Networks có rất nhiều biến thể và các kiến trúc độc đáo khác nhau tùy thuộc vào mục tiêu bài toán, hy vọng bài viết có thể giúp các bạn tiếp cận nhanh nhất với một bài toán liên quan tới kiểu dữ liệu đồ thị.
Tài liệu tham khảo
[1] A Gentle Introduction to Graph Neural Networks
[2] Graph neural networks: A review of methods and applications
[3] A Comprehensive Introduction to Graph Neural Networks (GNNs)
All rights reserved