Atcoder Beginner Contest 120 (Union-Find Tree)

AtCoder Beginner Contest 120のふりかえり

atcoder.jp

結果はA、B、Cの三完。Cに時間をかけすぎてスコアがやばいことになった

C問題

01がそれぞれ1つ以上あったら、どこかで01が隣り合っているのだから結局S中の0の総数と1の総数の少ない方回取り除ける。

それに気づかずごちゃごちゃやってしまった。

S = input()

save = []
flag = False
modori_flag = False
count = 0
for idx, s in enumerate(S):
    if modori_flag:
        modori_flag = False
    elif flag:
        flag = False
        continue
            
    this_ = int(s)
    if len(save) > 0 and this_ + save[-1] == 1:
        modori_flag = True
        del save[-1]
        count += 2
        continue
    if idx == len(S) -1:
        break
    next_ = int(S[idx+1])
    if this_ + next_ == 1:
        flag = True
        count += 2
        continue
        
    save.append(this_)
    
print(count)

D問題

やったこと

まっさらな状態から橋を1本ずつかけていくという発想まではいけたが、それ以上は時間切れ

解法

キーポイントは

i + 1 番目の辺 ei+1 = (ai+1, bi+1) によって 1. 元々頂点 ai+1, bi+1 が連結でなかった (互いに行き来可能でなかった) 場合 辺 e を加える前のグラフにおいて、頂点 ai+1 を含む連結成分の大きさを N1、頂点 bi+1 を含む連結成分の大きさを N2 とすると、 ans(i) = ans(i + 1) − N1 × N2 です。 2. 元々連結だった場合 不便さに変化はなく、ans(i) = ans(i + 1) です

ここの部分だと思う。 N1とN2を高速に求めるために、Union-Findというデータ構造を利用する。

Union-Findはグループ分けをうまく扱えるデータ構造で、

unite(u, v): 頂点 u が属するグループと頂点 v が属するグループを併合し、同じグループにする
(O(α(n)))
find(v): 頂点 v が属するグループ番号を得る (O(α(n)))
size(v): 頂点 v が属するグループと同じグループに属する頂点数を得る (O(1))

という操作を定義する。

www.slideshare.net (解説スライドがあったので神) (こういう頻出のデータ構造を知っておきたい感がある...と思ったらあった↓) プログラミングコンテストでのデータ構造

実装上は木構造を用いることができる。

class UnionFind():
    def __init__(self, size):
        # self.group holds a list of idx for parant node of i-th node
        # if self.group[i] = i, i-th node is a root.
        # if self.group[i] = j, j-th node is a paranet of i-th node. 
        self.group = [i for i in range(size)]
        
        # self.height holds a list of trees' heights
        self.height = [0] * (size)
        
    def find(self, i):
        """
        Search for i-th node's group.
        Args : 
            i : index of target node.
        Return :
            index of root of target node
        """ 
        
        if self.group[i] == i:
            # if root
            return i
        else:
            # if not root, search by the parent's node
            self.group[i] = self.find(self.group[i])
            return self.group[i]     
        
    def unite(self, i, j):
        """
        Unite two groups
        Args : 
            i, j : indices of target nodes.
        """ 
        # search roots
        i = self.find(i)
        j = self.find(j)
        
        if i == j:
            return 
        
        if self.height[i] < self.height[j]:
            self.group[i] = j
        else:
            self.group[j] = i
            if self.height[i] == self.height[j]:
                self.height[i] += 1
                
    def same(self, i, j):
        return self.find(i) == self.find(j)
       
    def size(self, i):
        """
        Get num of nodes in same group
        Args : 
            i : index of target node.
        """
        ans = 1
        root = self.find(i)
        for node_idx, par in enumerate(self.group):
            if root == par:
                ans += 1
            elif root == self.find(par):
                ans += 1
            else:
                pass
        ans -= 1
        return ans
        
        

import numpy as np
N, M = map(int, input().split())
ab_list = [list(map(int, input().split())) for _ in range(M)]

a_list = np.array(ab_list)[:, 0][-1::-1]
b_list = np.array(ab_list)[:, 1][-1::-1]

huben = [int((N) * (N-1) * 0.5)]
        
uni = UnionFind(size=N+1)

for a, b in zip(a_list, b_list):
    a_group = uni.find(a)
    b_group = uni.find(b)
    
    if a_group == b_group:
        huben.append(huben[-1])
    else:
        huben.append(huben[-1] - uni.size(a)*uni.size(b))
        uni.unite(a, b)
        
for ans in huben[-1::-1][1:]:
    print(ans)

という感じなったが、結局RTEで通らない。。。 pythonで通した人のコードを見せていただきたいものです。

==追記== ルートノードのインデックスでそのグループに属する要素数を持てばsize()がO(1)で行える。そうすれば間に合う

class UnionFind():
    def __init__(self, size):
        # self.parent holds a list of idx for parant node of i-th node
        # if i-th node is a root, self.parent[i] holds the - (number of nodes) in the parent.
        # if self.parent[i] = j, j-th node is a paranet of i-th node. 
        self.parent = [-1 for _ in range(size)]
        
        # self.height holds a list of trees' heights
        self.height = [0] * (size)
        
    def find(self, i):
        """
        Search for i-th node's parent.
        Args : 
            i : index of target node.
        Return :
            index of root of target node
        """ 
        
        if self.parent[i] < 0:
            # root node
            return i
        else:
            # if not root, search by the parent's node
            self.parent[i] = self.find(self.parent[i])
            return self.parent[i]     
        
    def unite(self, i_, j_):
        """
        Unite two parents
        Args : 
            i, j : indices of target nodes.
        """ 
        # search roots
        i = self.find(i_)
        j = self.find(j_)
        
        if self.height[i] < self.height[j]:
            self.parent[j] += self.parent[i]
            self.parent[i] = j_
            
        else:
            self.parent[i] += self.parent[j]
            self.parent[j] = i_
            if self.height[i] == self.height[j]:
                self.height[i] += 1
       
    def size(self, i):
        """
        Get num of nodes in same parent
        Args : 
            i : index of target node.
        """
        return - self.parent[self.find(i)]
        
        

import numpy as np
N, M = map(int, input().split())
ab_list = [list(map(int, input().split())) for _ in range(M)]

a_list = np.array(ab_list)[:, 0][-1::-1]
b_list = np.array(ab_list)[:, 1][-1::-1]

huben = [int((N) * (N-1) * 0.5)]
        
uni = UnionFind(size=N+1)

for a, b in zip(a_list, b_list):
    a_group = uni.find(a)
    b_group = uni.find(b)
    
    if a_group == b_group:
        huben.append(huben[-1])
    else:
        huben.append(huben[-1] - uni.size(a)*uni.size(b))
        uni.unite(a, b)
        
for ans in huben[-1::-1][1:]:
    print(ans)

他にもわかりやすい記事がありました Python:Union-Find木について - あっとのTECH LOG