Atcoder Beginner Contest 120 (Union-Find Tree)
AtCoder Beginner Contest 120のふりかえり
結果はA、B、Cの三完。Cに時間をかけすぎてスコアがやばいことになった
C問題
0
と1
がそれぞれ1つ以上あったら、どこかで0
と1
が隣り合っているのだから結局S中の0
の総数と1
の総数の少ない方回取り除ける。
それに気づかずごちゃごちゃやってしまった。
S = input() save = [] flag = False modori_flag = False count = 0 for idx, s in enumerate(S): if modori_flag: modori_flag = False elif flag: flag = False continue this_ = int(s) if len(save) > 0 and this_ + save[-1] == 1: modori_flag = True del save[-1] count += 2 continue if idx == len(S) -1: break next_ = int(S[idx+1]) if this_ + next_ == 1: flag = True count += 2 continue save.append(this_) print(count)
D問題
やったこと
まっさらな状態から橋を1本ずつかけていくという発想まではいけたが、それ以上は時間切れ
解法
キーポイントは
i + 1 番目の辺 ei+1 = (ai+1, bi+1) によって 1. 元々頂点 ai+1, bi+1 が連結でなかった (互いに行き来可能でなかった) 場合 辺 e を加える前のグラフにおいて、頂点 ai+1 を含む連結成分の大きさを N1、頂点 bi+1 を含む連結成分の大きさを N2 とすると、 ans(i) = ans(i + 1) − N1 × N2 です。 2. 元々連結だった場合 不便さに変化はなく、ans(i) = ans(i + 1) です
ここの部分だと思う。 N1とN2を高速に求めるために、Union-Findというデータ構造を利用する。
Union-Findはグループ分けをうまく扱えるデータ構造で、
unite(u, v): 頂点 u が属するグループと頂点 v が属するグループを併合し、同じグループにする
(O(α(n)))
find(v): 頂点 v が属するグループ番号を得る (O(α(n)))
size(v): 頂点 v が属するグループと同じグループに属する頂点数を得る (O(1))
という操作を定義する。
www.slideshare.net (解説スライドがあったので神) (こういう頻出のデータ構造を知っておきたい感がある...と思ったらあった↓) プログラミングコンテストでのデータ構造
実装上は木構造を用いることができる。
class UnionFind(): def __init__(self, size): # self.group holds a list of idx for parant node of i-th node # if self.group[i] = i, i-th node is a root. # if self.group[i] = j, j-th node is a paranet of i-th node. self.group = [i for i in range(size)] # self.height holds a list of trees' heights self.height = [0] * (size) def find(self, i): """ Search for i-th node's group. Args : i : index of target node. Return : index of root of target node """ if self.group[i] == i: # if root return i else: # if not root, search by the parent's node self.group[i] = self.find(self.group[i]) return self.group[i] def unite(self, i, j): """ Unite two groups Args : i, j : indices of target nodes. """ # search roots i = self.find(i) j = self.find(j) if i == j: return if self.height[i] < self.height[j]: self.group[i] = j else: self.group[j] = i if self.height[i] == self.height[j]: self.height[i] += 1 def same(self, i, j): return self.find(i) == self.find(j) def size(self, i): """ Get num of nodes in same group Args : i : index of target node. """ ans = 1 root = self.find(i) for node_idx, par in enumerate(self.group): if root == par: ans += 1 elif root == self.find(par): ans += 1 else: pass ans -= 1 return ans import numpy as np N, M = map(int, input().split()) ab_list = [list(map(int, input().split())) for _ in range(M)] a_list = np.array(ab_list)[:, 0][-1::-1] b_list = np.array(ab_list)[:, 1][-1::-1] huben = [int((N) * (N-1) * 0.5)] uni = UnionFind(size=N+1) for a, b in zip(a_list, b_list): a_group = uni.find(a) b_group = uni.find(b) if a_group == b_group: huben.append(huben[-1]) else: huben.append(huben[-1] - uni.size(a)*uni.size(b)) uni.unite(a, b) for ans in huben[-1::-1][1:]: print(ans)
という感じなったが、結局RTEで通らない。。。 pythonで通した人のコードを見せていただきたいものです。
==追記== ルートノードのインデックスでそのグループに属する要素数を持てばsize()がO(1)で行える。そうすれば間に合う
class UnionFind(): def __init__(self, size): # self.parent holds a list of idx for parant node of i-th node # if i-th node is a root, self.parent[i] holds the - (number of nodes) in the parent. # if self.parent[i] = j, j-th node is a paranet of i-th node. self.parent = [-1 for _ in range(size)] # self.height holds a list of trees' heights self.height = [0] * (size) def find(self, i): """ Search for i-th node's parent. Args : i : index of target node. Return : index of root of target node """ if self.parent[i] < 0: # root node return i else: # if not root, search by the parent's node self.parent[i] = self.find(self.parent[i]) return self.parent[i] def unite(self, i_, j_): """ Unite two parents Args : i, j : indices of target nodes. """ # search roots i = self.find(i_) j = self.find(j_) if self.height[i] < self.height[j]: self.parent[j] += self.parent[i] self.parent[i] = j_ else: self.parent[i] += self.parent[j] self.parent[j] = i_ if self.height[i] == self.height[j]: self.height[i] += 1 def size(self, i): """ Get num of nodes in same parent Args : i : index of target node. """ return - self.parent[self.find(i)] import numpy as np N, M = map(int, input().split()) ab_list = [list(map(int, input().split())) for _ in range(M)] a_list = np.array(ab_list)[:, 0][-1::-1] b_list = np.array(ab_list)[:, 1][-1::-1] huben = [int((N) * (N-1) * 0.5)] uni = UnionFind(size=N+1) for a, b in zip(a_list, b_list): a_group = uni.find(a) b_group = uni.find(b) if a_group == b_group: huben.append(huben[-1]) else: huben.append(huben[-1] - uni.size(a)*uni.size(b)) uni.unite(a, b) for ans in huben[-1::-1][1:]: print(ans)
他にもわかりやすい記事がありました Python:Union-Find木について - あっとのTECH LOG