Disjoint set cấu trúc dữ liệu đơn giản mà hiệu quả

Bất kì bài toán tin học nào cũng được giải quyết dựa trên thuật toán (Algorithm) và cấu trúc dữ liệu biểu diễn nó (Data Structure).
Algorithms + Data Structures = Programs
Tuy nhiên, trong quá trình giải quết bài toán ta thường quá chú tâm tới giải thuật mà quên mất rằng việc lựa chọn cấu trúc dữ liệu hợp lý ảnh hưởng rất nhiều tới thuật toán. Hay hiệu quả của thuật toán phụ thuộc vào cấu trúc dữ liệu được sử dụng.

Cấu trúc dữ liệu cũng rất đa dạng từ những cấu trúc dữ liệu đơn giản như mảng một chiều, nhiều chiều tới ngăn xếp, hàng đợi, bảng băm, cấu trúc dữ liệu dạng cây như cây nhị phân, heap, ... và những cấu trúc dữ liệu nâng cao khác.

Ngoài ra còn có một dạng cấu trúc dữ liệu cài đặt khá đơn giản nhưng rất hiệu quả trong nhiều bài toán đó là Disjoint set.

Disjoint set là gì ?

Theo wikipedia: https://en.wikipedia.org/wiki/Disjoint-set_data_structure

Disjoint set là một cấu trúc dữ liệu theo dõi (tracking) một tập các phần tử được phân chia thành các tập con khác nhau không chồng chéo nhau (non-overlapping).
Hay đơn giản hơn ta có ví dụ sau:
Bài toán 1
Ta có 6 thành phố là A, B, C, D, E, F. Thành phố A có thể đi tới thành phố B và C. Thành phố B có thể đi tới thành phố D. Thành phố E có thể đi tới thành phố F. Các thành phố xảy ra chiến tranh những thành phố có đường đi trưc tiếp ( chẳng hạn từ A tới B và C) hoặc gián tiếp ( từ A tới D qua B) thì tướng lĩnh của thành phố đó có thể đem quân đánh chiếm thành phố khác sáp nhập làm của mình tạo thành một vùng mới và có 1 vị vua cai quản. Vậy khi tàn cuộc có bao nhiêu vùng đất khác nhau được cai quản bởi các vị vua.
Để tìm số vùng đất được cai quản bởi các vị vua khác nhau ta sẽ chia tập hợp các thành phố {A, B, C, D, E, F} thành các tập con sao cho hai tập con bất kì không thể có đường đi từ thành phố của tập con này tới thành phố của tập con kia.
Dễ thấy các vùng đất {A, B, C, D} có đường đi với nhau cuối cùng sẽ gộp chung lại thành 1 vùng đất được cai quản bởi 1 vị vua. Tuy nhiên vị vua đó không thể tiến đánh các thành phố E, F do không có đường đi. Cuối cùng sẽ chỉ có 2 vùng đất được tạo thành bởi các thành phố {A, B, C, D} và {E, F}.

Bài toán 2
Với một bài toán tương tự nhưng ta xét một ví dụ dễ biểu diễn trong máy tính hơn đó là: Cho tập hợp gồm N phần tử A = {1, 2, 3, 4, 5, 6, 7, 8 } các phần tử (1, 2) (1, 3) (2, 3) (4, 7) (4, 8) (5, 6) có liên kết với nhau. Tìm các tập hợp con mà giữa 2 tập hợp không có liên kết trực tiếp hoặc gián tiếp.

Nếu không sử dụng ý tưởng cấu trúc dữ liệu disjoint set ta có cách tiếp cận sau
Sử dụng mảng link[] để theo dõi liên kết giữa 2 phần tử tức phần tử 1 và 8 có liên kết với nhau nếu link[1] = link[8].
Giả sử ban đầu chưa có liên kết nào được thêm vào thì mỗi phần tử sẽ là 1 tập con link[] = {1, 2, 3, 4, 5, 6, 7, 8}.
Bây giờ ta sẽ duyệt qua tất cả phần tử và kiểm tra liên kết trực tiếp của nó. Khi thêm liên kết ta sẽ thực hiện 2 thao tác đó là

  • find(x, y) kiểm tra x và y có liên kết chưa (link[x] == link[y])
  • union(x, y) hợp x vào y và tất cả các phần tử liên kết với x cũng đều liên kết với y.

Chẳng hạn:
Bắt đầu với 1 có

  • liên kết (1, 2)
    find(1, 2) trả về false do link[1] != link[2]
    -> union(1, 2) link[1] = link[2] = 2.
  • liên kết (1,3)
    find(1, 3) trả về false -> link[1] = link[3] = 3 và các phần tử liên kết với 1 (phần tử 2) cũng liên kết với 3 link[2] = link[3] = 3

Tiếp theo với phần tử 2

  • liên kết (2, 3)
    find(2, 3) trả về true -> không cần union nữa

Tương tự với phần tử 3, 4, 5, 6, 7, 8. Cuối cùng ta có mảng link[] = {3, 3, 3, 8, 6, 6, 8, 8} như vậy trong mảng link[] có 3 phần tử khác nhau -> có 3 tập hợp con

    void init( int link[], int N)
    {
        for(int i = 1; i <= N; i++)
        link[i] = i ;
    }

    bool find( int link[], int x, int y)                           
    {
        // return true nếu x có liên kết với y
        if(link[x] == link[y])
            return true;
        return false;   
    }
    
    
    void union(int link[], int N, int x, int y)
    {
        // cập nhật lại x và các phần tử liên kết với x liên kết với y
        int temp = link[x];
        for(int i = 1; i <= N; i++)
        {
            if(link[i] == temp)
            link[i] = link[y]; 
        }
    }
    
    void solve()
    {
         for (int i = 1; i <= N; i++)
         {
             // với j là các cạnh có liên kết trực tiếp với i
             if (find(i, j) == false)
                 union(i, j);
         }
    }
   

Tuy nhiên có thể thấy độ phức tạp thuật toán của lời giải trên là O(n2). Đây là độ phức tạp khá lớn => Sử dụng disjoint set để đạt hiệu năng tốt hơn

Ý tưởng của disjoint set

Thay vì sử dụng mảng link[] như cách tiếp cận trên ta sử dụng mảng parent[] ta có parent[x] là cha của x. Nếu x có liên kết với y thì parent của y bằng x => Tất cả các phần tử có liên kết với nhau đều có chung tổ tiên (root).
Giả sử ta có tập A = {1, 2, 3, 4 ,5, 6, 7, 8, 9} và có các liên kết (6, 1) (6, 8) (4, 2) (4, 3) (4, 9) (9, 7)
(Nguồn: Google)

Lúc này ta có 3 cây với 3 root là 5, 6, 4 là 3 tập hợp con. Do ta xây dựng 3 tập hợp con là cấu trúc dạng cây nên để kiểm tra 2 phần tử có liên kết hay không ta chỉ cần kiểm tra xem chúng có chung root hay không
find(x, y) sẽ thực hiện findSet(x) == findSet(y) với findSet(x) trả về phần tử root của tập hợp chứa x. Ví dụ findSet(2) = 4, findSet(7) = 4. Độ phức tạp thuật toán là O(logN) hàm findSet sẽ đệ quy theo chiều cao của cây.

union(x,y) khi find(x, y) trả về false -> gắn root của cây chứa y vào root của cây chứa x hoặc ngược lại gắn root của cây chứa x vào root của cây chứa y. Tuy nhiên nếu luôn luôn chọn gắn root của cây chứa y vào root của cây chứa x hay cách ngược lại có thể dẫn đến chiều cao của cây đạt 0(N). Hay nói cách khác là cây bị suy biến dẫn đến hiệu quả thao tác trên cây bị sụt giảm (Nguồn: Google)

=> findSet(x) rơi vào worst case cũng là O(N). Để tránh điều này ta sử dụng kỹ thuật union by rank (hoặc by size). Union by rank luôn gắn root của cây có độ cao thấp hơn vào root của cây cao hơn. Như thế độ cao của cây thấp hơn sẽ tăng lên 1( chính là đỉnh root của cây cao hơn). Đây cũng là kĩ thuật mình sẽ implement ở dưới đây.

Cài đặt disjoint set

Ở bước khởi tạo ta coi mỗi phần tử là 1 cây với 1 nút root là chính nó thông qua makeSet(x)

    void makeSet(int x){
        p[x] = x;
        rank[x] = 0;
    }

Tiếp theo ta viết hàm xác định tập hợp (cây) chứa x

    int findSet(int x){
        if(x != parent[x])
            parent[x] = findSet(parent[x]);
        return parent[x];
    }

Cuối cùng là union thay vì luôn gắn root của cây chứa y vào root của cây chứa x ta thực hiện union by rank

    void union(int x, int y){
        if(rank[x] > rank[y]) parent[y] = x; // gắn root của cây thấp (rank[y] < rank[x]) vào cây cao
        else {
            parent[x] = y;
            if(rank[x] == rank[y]) rank[y] = rank[y] + 1; // nếu 2 cây cao bằng nhau tăng rank của cây được gán lên 1
        }
    }

Với các ví dụ trên việc áp dụng cấu trúc dữ liệu disjoint set thay vì độ phức tạp là O(N2) sẽ có độ phức tạp là 0(NlogN)



Ứng dụng

Ứng dụng được biết tới nhiều nhất của disjoint set đó là nó theo dõi các thành phần liên thông (connected components) của đồ thị vô hướng đóng vài trò quan trọng xây dựng thuật toán Kruskal để tìm cây khung nhỏ nhất ( minimum spanning tree - MST). MST giúp giải quyết 1 số bài toán thực tế như các ứng dụng trực tiếp trong thiết kế mạng, bao gồm mạng máy tính, mạng viễn thông, mạng giao thông, mạng cấp nước và lưới điện. Chúng được gọi như một chương trình con trong các thuật toán cho các vấn đề khác bao gồm thuật toán Christofides để tính gần đúng cho vấn đề nhân viên bán hàng du lịch, ...
Tham khảo thêm https://en.wikipedia.org/wiki/Minimum_spanning_tree
Chẳng hạn với bài toán giảm chi phí đi cáp ngầm TV hay điện thoại

Ta cần xây dựng 1 hệ thống cáp nối tất cả các điểm trong thành phố A, B, C, D, ... Tuy nhiên chi phí cáp giữa các điểm là khác nhau. MST giúp ta xác định đươc 1 hệ thống (dạng cây) nối tất cả các điểm với chi phí nhỏ nhất như hình vẽ.

Ý tưởng của thuật toán Kruskal dựa trên disjoint set và tham lam đã được trình bày ở đây: https://viblo.asia/p/cac-dang-bai-su-dung-thuat-toan-tham-lam-greedy-algorithm-problems-924lJARYZPM#_cac-bai-toan-lien-quan-toi-do-thi-6

Xây dựng hệ thống đếm số lượng bạn chung (mutual friends)

Hiện tại một nhóm sinh viên trong trường đang có trang social networking giúp kết nối bạn bè trong trường học tương tác với nhau, những người e ngại việc sử dụng facebook do bị leak thông tin. Một trong số những tính năng mà adminstrator đang cần bạn phát triển đó là hiển thị số lượng bạn chung (mutual friend) của giữa X và Y ở đây tất những người bạn trực tiếp với X hay gián tiếp qua một người bạn Z của X lúc này có thể coi là bạn chung với Y. Các requirement bạn nhận được như sau:

  1. Nếu user A gửi request cho user B thì A sẽ following B. A và B chỉ là bạn bè khi B cũng following A
  2. Khi A và B đã là bạn bè, có thể xem được số lượng bạn chung giữa A và B.

Phân tích
Đối với yêu cầu 1 ta xây dựng 1 mảng vector following[] cho từng user quản lí id của các user khác đang following A. Chẳng hạn following[3] sẽ là 1 vector chứa các id của user đang following user có id là 3. A và B là bạn bè khi following của B chứa id của A và ngược lại.

Khi A và B đã là bạn bè ta sẽ hợp (union) X và Y.

Đối với yêu cầu 2 khi nhận được yêu cầu đếm số lượng bạn chung giữa A và B ta chỉ việc kiểm tra xem A và B có chung root hay không nó có trả về số lượng thành viên của root đó.

Note: Ở trên ta cài đặt union by rank theo độ cao của cây. Tuy nhiên đối với bài toán này ta cần lưu trữ số lượng thành viên của root (số lượng phần tử trong cây của root) nên ta thực hiện union by size thay vì union by rank

Để dễ biểu diễn nhất ta quy định yêu cầu 1 sẽ truyền và params là 'following' A_id B_id với A_id là user_id người gửi request following cho B_id. yêu cầu 2 sẽ truyền params là 'mutual friends' A_id B_id.

Ta có implement đơn giản trong C như sau

    void makeSet(int x) 
    {
        parent[x] = x;
        size[x] = 1;
    }
    
    int findSet(int x){
        if(x != parent[x])
            parent[x] = findSet(parent[x]);
        return parent[x];
    }
    
    
    void union(int x, int y) 
    {
        if(size[x]<size[y])
        {
            parent[x] = y;
            siz[y] += siz[x];
        } else
        {
            parent[y] = x;
            size[x] += size[y];
        }
}
    int mutual_friend(int x,int y)
    {
       if(x != y)
            return 0;
        return size[x]-2;
}

int main(void)
{
    for(int i = 1; i <= n; i++)
    {
        makeSet(i);
    }
    
    while(true) // đọc vào các yêu cầu
    {
        cin >> option; // đọc vào option 
        cin >> x;
        cin >> y;
        if(option == "following")
        {
            following[y].push_back(x);
            if(following của x cũng chứa y) 
            {
                int root_x = findSet(x);
                int root_y = findSet(y);
                 union(root_x, root_y);
            }
               
        }
        else {
            int root_x = findSet(x);
            int root_y = findSet(y);
            cout << mutual_friend(root_x, root_y) << "\n";
        }
    }
    
    return 0;
}