Amplify GraphQLで[Int]型のフィールドでContainsオペレーターを利用する

結論

AWS Appsyncのスキーマで該当する型(ModelIntInput)を手動でアップデートし、contains: Intを追加する。

はじめに

AWS AmplifyではデータベースとしてGraphQLを利用することができ、ユーザはGraphQLスキーマを定義するだけでAppSyncやDynamoDBの設定をすることなく利用することができます。 しかし、クエリのfilterで、[Int]型のフィールドに対して、本来利用ができるはずのcontainsオペレータが利用ができず苦しんだので備忘録を残しておきます。

Version

"aws-amplify": "^5.3.3"

Example schema

以下のようなUserテーブルがschema.graphqlとして定義されているとします。

type User @model @auth(rules: [{ allow: public }]) {
  id: ID!
  favoriteNumbers: [Int]!
}

ここで、

  • id: ユーザに一意に付与されるID
  • favoriteNumbers: ユーザが好きな数字のリスト

とします。 こちらのスキーマ定義をamplify push apiすると、User tableに対するqueryやmutationを実行することができます。

Contains operator

ここで、User tableからUserのリストを取得することを考えます。 GraphQLクエリは以下のようになるでしょう。

const listUsers = /* GraphQL */ `
  query ListUsers($filter: ModelUserFilterInput, $limit: Int, $nextToken: String) {
    listUsers(filter: $filter, limit: $limit, nextToken: $nextToken) {
      items {
        id
        favoriteNumbers
      }
    }
  }
`;

ここで、「ある特定の数字が好きなユーザのリストを取得する」ことを考えます。すると、各レコードのfavoriteNumbersに格納されている[Int]型の値の中に特定のInt値が含まれているかどうかを確認する必要があります。

AmplifyでGraphQLを利用する時には裏側でDynamoDBを利用することになりますので、DynamoDBでこのような演算子が利用できるかを確認します。

Comparison operator and function reference - Amazon DynamoDB

こちらに、

A List that contains a particular element within the list.

とあることから、containsというオペレータを利用することで、[Int]に特定の要素が含まれているレコードを取得することができそうです。 しかし、上記のlistUsersを用いてクエリを実行しても、うまくfliterが効かないです。 以下は、ある特定の数字numをお気に入りにしているユーザの一覧を取得するための関数fetchUsersです。

import { API } from "aws-amplify";

const listUsers = /* GraphQL */ `
  query ListUsers($filter: ModelUserFilterInput, $limit: Int, $nextToken: String) {
    listUsers(filter: $filter, limit: $limit, nextToken: $nextToken) {
      items {
        id
        favoriteNumbers
      }
    }
  }
`;

const fetchUsers = async (num)=> {
  const variables = {
    filter: { favoriteNumbers: { contains: num } },
  };
  const response = await API.graphql({ query: listUsers, variables: variables });
}

AppSync

AppSyncは、AWSフルマネージドサーバーレス GraphQL API サーバーです。 該当するAPIの「スキーマ」を確認し、フィールド favoriteNumbersに対応するスキーマを確認すると

input ModelIntInput {
    ne: Int
    eq: Int
    le: Int
    lt: Int
    ge: Int
    gt: Int
    between: [Int]
    attributeExists: Boolean
    attributeType: ModelAttributeTypes
}

となっていることが分かります。どうやら、ここで利用可能なオペレータを定義しているようです。 ここでcontainsオペレータを手で追加してください。そうすればcontains オペレータを用いたfilterが有効になります。

input ModelIntInput {
    ne: Int
    eq: Int
    le: Int
    lt: Int
    ge: Int
    gt: Int
    between: [Int]
    contains: Int
    attributeExists: Boolean
    attributeType: ModelAttributeTypes
}

注意

このAppSyncスキーマamplify push apiを実行するたびにリセットがされるため、containsオペレータを利用するためには毎回手動でAppSyncのスキーマを編集しなければいけません。 なお、Int型以外のString型に対しては自動でcontainsオペレータが追加されますので上記の操作は必要ありません。

「良いコード悪いコードで学ぶ設計入門」メモ

これ何

良いコード悪いコードで学ぶ設計入門を読んで大事だなと思ったことを忘れないようにメモしておく。

良いコード/悪いコードで学ぶ設計入門 ―保守しやすい 成長し続けるコードの書き方:書籍案内|技術評論社

個人的にはリーダブルコードよりも読みやすく、全てのコードを書く人におすすめできる本だと思いました。

各章ごと

3:クラス設計

  • クラスは単体で作動するように設計する
    • 自己防衛責務がある
    • ベストプラクティス
      • コンストラクタで正常に全ての初期値を設定すること。
        • 不正な値を許さないバリデーションも備えておくべき
      • 計算ロジックをクラスのメソッドとして提供しておく
        • 凝縮性が上がってメンテナンスしやすい
      • インスタンス変数を上書きできないようにする(できれば)
        • finalのような式があれば利用する
      • インスタンスを更新したい場合には、上書きするのではなく、新しいインスタンスを作成する
      • 静的型付け言語であれば、型によるバリデーションも提供できる
        • intやstringといったプリミティブ型ではなく、独自の型を定義しておくほうがバグを減らせる
    • よい設計パターン
      • 完全constructor
      • 値object
        • 値をクラス(type)として表現する方法

大事なこと:データとデータ操作ロジックを一箇所にまとめておくこと。必要な操作だけを後悔すること。

# example: 物体検出問題において正解データとなるbounding boxデータを表現するクラス
class BoundingBox:
    def __init__(self, xmin: int, xmax: int, ymin: int, ymax: int, label: str):

        if not self._are_valid_positions(xmin, xmax, ymin, ymax):
            raise Exception("Invalid positions.")
        if len(label) == 0:
            raise Exception("Empty string is passed as a label name.")

        self.xmin = xmin
        self.xmax = xmax
        self.ymin = ymin
        self.ymax = ymax
        self.label = label
        return

    def _are_valid_positions(self, xmin: int, xmax: int, ymin: int, ymax: int) -> bool:
        # 座標のバリデーション
        if xmin < 0 or ymin < 0:
            return False
        if xmin >= xmax or ymin >= ymax:
            return False
        return True

    def get_bbox_area(self) -> int:
        # bboxの面積を返すメソッド
        # クラスに関連する処理はクラスのメソッドとして提供することで凝縮性を高める
         return (self.xmax - self.xmin) * (self.ymax - self.ymin)

    def slide_bbox(self, stride_x: int, stride_y: int):
        # bboxを移動するメソッド
        # 副作用を防ぐために、クラス変数を上書きするのではなく新しいインスタンスを作成する
        return BoundingBox(
            self.xmin + stride_x,
            self.xmax + stride_x,
            self.ymin + stride_y,
            self.ymax + stride_y,
            self.label
        )

4:不変の活用

  • 再代入をできるだけ許さない
  • 可変であることで生じること
    • 可変インスタンスがあちこちで変更されて収集がつかなくなる
    • 副作用のデメリット
      • 主作用と副作用
    • 関数の影響範囲を限定することが重要。引数で状態を受け取り、状態変更なしに値を返す関数が理想。
      • データを引数で受け取る
      • 状態を変更しない
      • 値は関数の戻り値で返却する

7: コレクション

  • 言語がコレクションに対する処理を提供している場合は自前で実装しない。
  • ループ中に条件分岐が多くある場合には早期continueやbreakを活用する
  • コレクションに関する実装が散らばってくるときにはFirst class collectionを検討する(カプセル化)
    • コレクション型のインスタンス変数と、それらを不正状態から防御し正常に制御するためのメソッドを提供する。
@dataclass
class Member:
    name: str
    hp: int


# PartyはFirst Class Collectionとして機能する
class Party:
    def __init__(self, members: List[Member]):
        self.members = members

    def add_member(self, new_member: Member) -> Party:
        # self.membersを上書きするのではなく、新しいPartyインスタンスを返却する
        new_members = copy.deepcopy(self.members)
        new_members.append(new_member)
        return Party(members=new_members)

8:密結合

  • 単一責任の原則(クラスが担う責任はただ一つにするべき)を強く意識しておくことが疎結合性を担保するコツ
  • 重複コードを恐れない。ほとんど同じコードであっても責務や概念が異なる場合にはコードを分割することは悪いことではない。

12:メソッド

  • 他のクラスのインスタンス変数を変更しない。 変更するのは自身のインスタンス変数のみに絞る。
  • コマンド・クエリ分離(Command-Query Separation:CQS)の原則を守る
    • 状態の取得及び状態の変更のどちらかを責務とするメソッドにする。どちらも同時に行わない。
  • 引数が多くなりそうなら別クラスにまとめることを考える
    • 概念ごとにクラスに分割すれば引数が多くなりすぎることがなくなるはず
  • 戻り値
    • プリミティブ型で返すよりも独自の型を定義して返す方が安全
    • エラーは戻り値で返すのではなく例外をthrowする
      • 負数などでエラーを表現しない

14:リファクタリング

  • リファクタリングの流れ
    • ネストの解消
      • 条件の反転や早期returnでネストを解消できる
    • ロジックの入れ替え
    • 条件を読みやすくする
    • ベタガキの条件文を、目的を表すメソッドでまとめる
  • ユニットテストでバグを防ぐ
  • リファクタリングの注意点
    • 機能追加とリファクタリングを同時にやらない。レビューも実装も訳がわからなくなる
    • スモールステップで実施する

Unbiased learning to rank: introduction

The Web Conference 2020でのチュートリアルUnbiased Learning to Rank: Counterfactual and Online Approaches

をCounterfactual LTRまで読んで、メモがてら日本語でスライドを作ったので置いておきます。 speakerdeck.com

ヒルベルト空間における凸射影定理・直交射影定理

ヒルベルト空間とは

完備性

点列の収束先が自身に含まれる空間 のことです.

つまり,ある空間Xが完備である,とは自身の中にある任意のコーシー列 {\bf x}_n \in X \:\: (n=1,2,\cdots)の収束先 {\bf x}について, {\bf x} \in Xが成り立つことです.

内積空間

内積を備えた空間内積空間と呼びます. 内積 - Wikipedia

内積空間は内積から誘導されるノルムを持つノルム空間でもあり,そのノルムを距離関数として採用した距離空間でもあります.

ヒルベルト空間

ヒルベルト空間は完備な内積空間です.

完備なノルム空間はバナッハ空間と呼ばれ,完備な距離空間は単に完備距離空間と呼ばれます.

凸射影定理

ヒルベルト空間 \mathcal{H}の空でない閉凸集合 C \subset \mathcal{H}について考えます.この時,Cの中で,点{\bf x} \in \mathcal{H}にもっとも近い点は唯一存在するというのが凸射影定理です.

主張1

 \mathcal{H}の任意の点{\bf x} \in \mathcal{H}に対して,


d\left( {\bf x}, C \right) := \inf_{{\bf y} \in C}{ \| {\bf x} - {\bf y} \|}
            =  \| {\bf x} - P_{C}\left({\bf x} \right) \| = \min_{y \in C} \| {\bf x} - {\bf y} \|

を満たす唯一の点 P_C \left( {\bf x} \right) \in C が存在する. P_{C}:\mathcal{H} \rightarrow CCへの距離射影・凸射影・または単に射影と呼ばれます.

主張2

また, {\bf x}^{*} \in Cに対して,


{\bf x}^{*} = P_{C}\left( {\bf x} \right) \iff 
            \left< {\bf x} - {\bf x}^{*}, {\bf y} - {\bf x}^{*} \right> \leq 0 \:\:\: \left( \forall{{\bf y}} \in C \right)

が成立する.

P_{C}\left( {\bf x} \right)Cの中で{\bf x}をもっともよく近似する点として,最良近似点と呼ばれる.

直交射影定理

ヒルベルト空間 \mathcal{H}の空でない閉部分空間 M \subset \mathcal{H}について考えます.凸集合ならば部分空間でもあるので,凸射影定理は Mについても成り立つことに注意してください.

この時,Mの中で,点{\bf x} \in \mathcal{H}にもっとも近い点は唯一存在し,そのような点と {\bf x}を結んだベクトルは Mに直交する,というのが直交射影定理です.

主張1

\forall{{\bf x}} \in \mathcal{H}に対して,


            \inf_{{\bf y} \in M}{ \| {\bf x} - {\bf y} \|}
            =  \| {\bf x} - P_{M}\left({\bf x} \right) \| = \min_{y \in M} \| {\bf x} - {\bf y} \|

を満たす唯一の点P_{M} \left( {\bf x} \right) \in Mが存在する.

主張2

また,{\bf x}^{*} \in Mに対して,


            {\bf x}^{*} = P_{M}\left( {\bf x} \right) \iff 
            \left< {\bf x} - {\bf x}^{*}, {\bf y} \right> = 0 \:\:\: \left( \forall{{\bf y}} \in M \right)

が成立する.P_{M}:\mathcal{H} \rightarrow MMへの距離射影であるが,この性質から特別に直交射影と呼ばれる.

主張3

直交射影について,以下の一般化されたピタゴラスの定理が成立する.

            \|{\bf x} \|^{2} = \|{\bf x} - P_{M}\left( {\bf x} \right) \|^{2} + \|P_{M}\left( {\bf x} \right) \|^{2}

直交分解

ここで,ヒルベルト空間の任意の点が,ある閉部分空間の1点と,その直交補空間の1点の和の形に一意に分解することができる,ことを紹介します. 以下では柄部分空間をMとおきます.

直交補空間

Mの直交補空間としてM^{\perp}

  
M^{\perp} := \{ {\bf x} \in \mathcal{H} | \left<{\bf x},y \right> = 0 \;\; \left( \forall{y} \in M \right) \}

のように定義すると,M^{\perp}も閉部分空間となり,かつM \cap M^{\perp} = \{ {\bf 0} \}となります.

直交分解の一意性

 \forall{{\bf x}} \in\mathcal{H}は,

              {\bf x} = {\bf x}_1 + {\bf x}_2 \:\: \left( {\bf x}_1 \in M, \; {\bf x}_2 \in M^{\perp}  \right)

のように一意に分解可能です.

Adversarial validation as density ratio estimation

Adversarial validation

Adversarial validation is a technique mainly used in Kaggle.

Kaggle: Your Home for Data Science

When the distribution of train set and that of public test set differs, to make validation set randomly from train set leads to low correlation between public leaderboard and local validation. Adversarial validation chooses data from train set that has high density in public test set, and makes nice validation set that has high correlation to public test set.

Steps of adversarial validation are like following :

  1. Give pseudo negative labels -1 to train set labels and give pseudo positive labels +1 to test set.
  2. Train any probabilistic binary classifier which discriminates train set and test set using cross-validation.
  3. Infer train set using above classifier, employ top-N data which have high score for class +1 as validation set.

(Ref)

Density ratio

Here, "density ratio" : r(x) is introduced.

density ratio \displaystyle{
r(x) = \frac{q(x)}{p(x)}
}

Density ratio is rewritten using Bayes' rule as :

\displaystyle{
r(x) = \frac{q(x)}{p(x)} = \frac{p (x | y = +1)} {p (x | y = -1) } = \frac {p(y=-1)p (x , y = +1)} {p(y=+1)p (x , y = -1)}  = \frac {p(y=-1)p (y = +1|x)} {p(y=+1)p (y = -1|x)}
}

p (y = -1|x) and p (y = +1|x) can be approximated by a binary classifier that discriminates train set and test set. Also, p(y=-1) and p(y=+1) are constant value, and can be approximated by the ratio between the number of train data and that of test data.

Thus, density ratio is :

density ratio \displaystyle{
r(x) =\frac{N}{M}  \frac{f(x)}{1-f(x)}
}

where N and M is the number of train data and that of test data. f(x) is the probability for class +1 (target set) of data x that is outputted by the classifier.

Importantly, density ratio estimation can be boiled down to a binary classification problem. That is the same process as the adversarial validation. (Remind that we make a binary classifier between train set and test set in adversarial validation!)

(Introduction to density ratio estimation) Density Ratio Estimation for KL Divergence Minimization between Implicit Distributions | Louis Tiao

Density ratior(x) is the monotonically increasing function w.r.t. f(x). For simplicity, N=M is assumed in the following discussion.

Shape of density ratio is like :

f:id:daiki-yosky:20191028082734p:plain
Horizontal axis:f(x), Vertical axis:density ratio

Relationship between adversarial Validation and density ratio

Shown above, adversarial validation chooses data from train data that have high density ratio value.

Data which have f(x) > 0.5 are high density in test data and low density in train set. Such data is so useful for validation because they are rare in train set, while frequently appeared in test data.

Data which have around f(x) = 0.5 are difficult to classify into train set and test set. These are also useful for validation data.

Data which have 0.5 > f(x) are high density in train data and low density in test set. It is dangerous to use this kind of data for validation data because they are rare in test set.

Conclusion

In this article, I introduce the relationship between adversarial validation technique and density ratio estimation. Adversarial validation is theoretically equivalent to density ratio estimation between train set and test set.

密度比推定としてのAdversarial Validation

Link to English page


Adversarial Validationとは

Adversarial Validationとは、主にKaggleの文脈で使われる言葉です。 TrainデータとTestデータの分布が大きく異なる時に、出来るだけTestデータでの密度が高いデータをTrainデータから選んでValidationデータとして使うための技術です。(適当にValidationを作るとPublic LBと相関が低いValidationデータを作ることになってしまう)

手順としては、

  1. trainデータに仮のラベル-1、testデータに仮のラベル+1をつける。
  2. trainデータとtestデータを分類する2値分類器を訓練する。(cross-validation等で)
  3. 先ほどの分類器にtrainデータを推論させ、trainデータの中でtestデータである確率が高いtopN件をValidationデータとして採用する。

というようなものです。

(参考)

密度比(density ratio)

ここで、次のような式で定義される密度比(density ratio) : r(x)を導入することにします。

density ratio \displaystyle{
r(x) = \frac{q(x)}{p(x)}
}

また、密度比は、ベイズの定理を使って以下のように変形できます。

\displaystyle{
r(x) = \frac{q(x)}{p(x)} = \frac{p (x | y = +1)} {p (x | y = -1) } = \frac {p(y=-1)p (x , y = +1)} {p(y=+1)p (x , y = -1)}  = \frac {p(y=-1)p (y = +1|x)} {p(y=+1)p (y = -1|x)}
}

ここで、trainデータに仮のラベル-1、testデータに仮のラベル+1を振れば、 p (y = -1|x)p (y = +1|x)はちょうどtrainデータとtestデータを分類する2値分類器によって近似することができます。 また、p(y=-1)p(y=+1)は定数で、それぞれtrainデータの数とtestデータの数の割合で近似できます。

結局、密度比は

density ratio \displaystyle{
r(x) =\frac{N}{M}  \frac{f(x)}{1-f(x)}
}

となります。 NとMはそれぞれtrainデータの数、testデータの数です。また、f(x)はtrainデータとtestデータを分類する2値分類器が出力する、データxのクラス+1(targetデータ)の確率です。

ここで重要なのは、密度比推定は結局2値分類問題として帰着できることです。

Density Ratio Estimation for KL Divergence Minimization between Implicit Distributions | Louis Tiao (わかりやすい説明がのっています)

密度比r(x)f(x)に対して単調増加関数で、f(x)=0.5\frac{N}{M}の値をとります。以下では簡単のためN=Mを想定します。密度比は以下のような形をしています。

f:id:daiki-yosky:20191028082734p:plain
横軸:f(x)、縦軸:密度比

Adversarial Validationと密度比の関係

よって、Adversarial Validationが行なっているのは、密度比が高い点をtrainデータから抽出していることになります。

確かに、f(x) > 0.5のようなデータはtestデータの密度が高く、trainデータの密度が低いデータで、「trainデータにはあまり現れないが、testデータにはよく現れるデータ」であると言え、Validationデータとしての価値はかなり高いです。

そして、f(x) = 0.5付近の点はモデルがtrainデータ由来なのかtestデータ由来なのかを見分けることができなかったデータであり、こちらもValidationデータとしての価値があると思われます。

逆に、 0.5 > f(x)のようなデータは「trainデータにはよく現れるが、testデータにはあまり現れないデータ」であり、これをValidationデータとして利用するのは危険です。

Atcoder beginner contest125 [python]

C:GCD on Blackboard

atcoder.jp


数列があって、ある1つの要素を抜いた時のGCDの最大値を求める問題。
愚直に1個ずつ抜いて計算してTLEでした。小手先の高速化はしてみましたが、TLE。

解法

ある値 x_iを抜く時のGCDは、x_iよりも左のGCDとx_iより右のGCDのGCDを取れば良い。
全てのiについてそれらを事前に計算しておくことで、高速に最大GCDを求めることができる。

import numpy as np

def gcd(a,b):    
    if a < b:
        a, b = b, a
    if b == 0:
        return a
    c = a % b
    return gcd(b, c)
    
N = int(input())
A_list = list(map(int, input().split()))

# =====
# initialize L and R
# =====
L = np.zeros(N+1)
R = np.zeros(N+1)

for idx, element in enumerate(A_list):
    if L[idx] == 0:
        L[idx+1] = element
    else:
        L[idx+1] = gcd(L[idx], element)

for idx, element in enumerate(A_list[-1::-1]):
    idx = N-idx
    if R[idx] == 0:
        R[idx-1] = element
    else:
        R[idx-1] = gcd(R[idx], element)

ans = 0
for i in range(N):
    if L[i] == 0:
        M = R[i+1]
    elif R[i+1] == 0:
        M = L[i]
    else:
        M = gcd(L[i], R[i+1])
    if ans < M:
        ans = M

print(int(ans))

D : Flipping Signs

atcoder.jp

練習のため、dpで解く。
i = 0,1,...,Nに対して、dp[i][j]を、i番目の要素の正負を確定した上で、i番目の要素をj=1の時はそのまま(つまりi+1番目の要素もそのまま)、j=0の時は反転させた(つまりi+1番目の要素も反転)とした時の最大値とします。

こういうdpの定義をスパッとできるようになりたい。

dpテーブルの初期条件は、

  • dp[0][1] = -inf (0番目の要素と1番目の要素を反転させる操作に相当し、1番左の要素を反転させることはできないため、有効な操作ではありません。)
  • dp[0][0] = 0

更新は、
https://img.atcoder.jp/abc125/editorial.pdf
の通り。

最終的にほしいのは、dp[N][0]です。

import numpy as np
N = int(input())
A_list = list(map(int, input().split()))

dp = np.zeros((len(A_list)+1, 2))
dp[0][1] = - 10**5

for idx, a in enumerate(A_list):
    dp[idx+1][0] = max(dp[idx][0] + a, dp[idx][1] - a)
    dp[idx+1][1] = max(dp[idx][0] - a, dp[idx][1] + a)
print(int(dp[N][0]))

GANからWasserstein GANへ

generative adversarial network(GAN)からWasserstein generative adversarial network(WGAN)への道の整理をします。 こちらを参考にしました:

目次

Kullback–Leibler Divergence (KL divergence) と Jensen–Shannon Divergence (JS divergence)

まず、確率密度関数の類似度をはかる2つの指標を導入します。

Kullback–Leibler Divergence

2つの確率密度関数p(x)q(x)を考えます。KL divergenceはpがqからどれだけ異なるか、をはかる指標です。

KL Divergence \displaystyle{
D _ {KL}(p||q) = \int _ x p(x) \log \frac{p(x)}{q(x)} dx
}

KL divergenceの性質

  • KL divergenceはpとqに対して非対称(D _ {KL}(p||q) \neq D _ {KL}(q||p))です。すなわち、距離として使うことはできません。
  • p(x)がほぼ0で、q(x)が0でない場所ではqの影響が無視されます。

Jensen–Shannon Divergence

JS divergenceは2つの確率密度関数の類似度をはかるもう一つの指標です。また、範囲は[0,1]です。

JS Divergence \displaystyle{
D _ {JS}(p||q) = \frac{1}{2} D _ {KL} \left( p||\frac{p+q}{2} \right) +  \frac{1}{2} D _ {KL} \left( q||\frac{p+q}{2} \right)
}

JS divergenceはpとqに関して対称です。GANではこちらのJS divergenceによって p _ {generator}(x)p _ {data}(x)の類似度を測ります。

GAN

GANは、現実のデータ集合が与えられたとき、それらに似たデータを生成することを目指します。

GANは2つのモデルからできています。

  • Discriminator(D) : Discriminatorは入力されたサンプルが現実のデータかどうかを識別する、2値分類器です。
  • Generator (G) : Generatorはノイズz \sim p(z)を入力として受け取り、人工的なデータを出力します。その際、現実のデータの分布と似た分布を学習します。つまり、Discriminatorを騙すような(人工的なデータではあるが、現実のデータだと識別させるような)データを生成することを目指します。

これらの2つのモデルが互いを見抜く・騙すように訓練されて、十分学習が進めばGeneratorが現実のデータと見分けがつかないようなデータを生成できるようになる、というわけです。欲しいのは良いGeneratorです。

f:id:daiki-yosky:20190424180631p:plain
GANの概観

ここで、

  • p _ z : ノイズzの分布(一様分布を使うことが多いです)
  • p _ g : Generatorが生成するデータの分布
  • p _ r : 現実のデータの分布
  • D(x) : Discriminatorが、入力されたデータxを実際のデータだと判断する確率
  • G(x) : Generatorが、入力されたノイズzから生成するデータ

とします。

GANの目的関数

まず、Discriminatorは現実のデータを正しく本物だと識別してほしいです。つまり、 $$ \mathbb{E} _ {x \sim p _ r(x)} \left[ \log D(x) \right] $$ を最大化したいです。一方で、Generatorが生成したデータ G(z)を正しく偽物だと識別して欲しいので、 $$ \mathbb{E} _ {z \sim p _ z(z)} \left[ \log \left( 1 - D(G(z) \right) \right] $$ を最大化して欲しいです。

次に、Generatorに関しては生成したデータをDiscriminatorが本物だと誤分類させたいので、 $$ \mathbb{E} _ {z \sim p _ z(z)} \left[ \log \left( 1 - D(G(z) \right) \right] $$ を最小化したいです。

これらを組み合わせると、以下のようなmin-max lossになります。

Loss of GAN $$ \min_G \max_D L(D,G) = \mathbb{E} _ {x \sim p _ r(x)} \left[ \log D(x) \right] + \mathbb{E} _ {z \sim p _ z(z)} \left[ \log \left( 1 - D(G(z) \right) \right] \\ = \mathbb{E} _ {x \sim p _ r(x)} \left[ \log D(x) \right] + \mathbb{E} _ {x \sim p _ g(x)} \left[ \log \left( 1 - D(x) \right) \right] \tag{1}$$

密度比推定との関連

Discriminatorの学習は、密度比推定と深い関係があります。密度比とは、2つの確率密度関数( p _ r(x)p _ g(x))の比で、

$$ r(x) = \frac{p _ r(x)}{p _ g(x)} $$

です。

密度比を推定する方法

現実のデータ集合に仮にラベル+1を割り当て、Generatorが生成したデータに仮にラベル-1を割り当てることにします。 この時、ラベルがgivenという条件下のもとでデータの分布を表すことができて、

$$ p _ r(x) = p (x | y = +1) $$ $$ p _ g(x) = p (x | y = -1) $$

です。

密度比は、ベイズの定理から

\displaystyle{
r(x) = \frac{p _ r(x)}{p _ g(x)} \\
= \frac{p (x | y = +1)} {p _ g(x) = p (x | y = -1) } \\
= \frac {p(y=-1)p (x , y = +1)} {p(y=+1)p (x , y = -1)}  \\
= \frac {p(y=-1)p (y = +1|x)} {p(y=+1)p (y = -1|x)}
}

となります。p (y = +1|x)p (y = -1|x)は任意の2値分類器で求めることができて、それはまさにDiscriminatorです。\frac{p(y=-1)}{p(y=-1)}はデータ数の比で近似出来ます。 Discriminatorの損失にはBinary Cross Entropyを用いればよくて、それを変形すると(1)の目的関数になります。 つまり、結果としてDiscriminatorの学習はp _ r(x)p _ g(x)の密度比を推定するように行われることになります。


似たようなことは以下の論文にも記述されています。

[1610.02920] Generative Adversarial Nets from a Density Ratio Estimation Perspective

こちらはDiscriminatorが密度比推定を行なっていることに注目し、f-divergenceを最小化するGANを提案しています。

Discriminatorの最適解

先ほどの目的関数を最大化するDiscriminatorの最適解をまず求めてみます。 L(G,D)は期待値の部分を書き直せば

$$ L(G,D) = \int \left( p _ {r} (x) \log(D(x)) + p _ {g}(x) \log(1 - D(x)) \right) dx $$

とかけます。今我々の興味はL(G,Dを最大化するようなD(x)なので、

$$ \hat{x}=D(x), A = p _ r(x), B = p _ {g}(x) $$とおきます。

すると、

$$ f(\hat{x}) = A\log \hat{x} + B \log (1- \hat{x}) $$ とかけて、\hat{x}について微分すれば

$$ \frac{d f(\hat{x})}{d\hat{x}} = \frac{A-(A+B)\hat{x}} {\hat{x} (1- \hat{x})} $$ となります。これを0とおくと、最適なD(x)

$$ D^{\ast}(x) = \frac{A}{A+B} = \frac{ p _ r(x)} { p _ r(x)+p _ {g}(x)} $$

になります。

さらに、Generatorが最適に学習すれば、p _ gp _ rに近しいものになり、p _ g = p _ rのような状況では D^{\ast}(x) = \frac{1}{2} になります。これは、完璧なGeneratorができれば、Discriminatorはもはや機能しなくなる、ということです。

What is global optimal?

DiscriminatorとGeneratorが最適な学習をするとp _ g = p _ rD^{\ast}(x) = \frac{1}{2}になることは上で確認しました。 この時、GAN のlossは、

$$ L(G^{\ast}, D^{\ast}) = \int \left( p _ {r} (x) \log(D^{\ast}(x)) + p _ {g}(x) \log(1 - D^{\ast}(x)) \right) dx \tag{2}\\ = \log \frac{1}{2} \int p _ {r} (x) dx + \log \frac{1}{2} \int p _ {g} (x) dx = -2 \log 2 $$

なお(2)は  L(G, D^{\ast})に対応します。

GANの目的関数が意味すること

p _ rp _ gの間のJS divergenceは、

\displaystyle{
D _ {JS}(p _ r||p _ g) = \frac{1}{2} D _ {KL} \left( p _ r||\frac{p _ r+p _ g}{2} \right) +  \frac{1}{2} D _ {KL} \left( p _ g||\frac{p _ r+p _ g}{2} \right) \\
= \frac{1}{2} \left(  \int p _ r(x) \log \frac{2p _ r(x)}{p _ r(x)+p _ g(x) }dx    \right)   +   \frac{1}{2} \left(  \int p _ g(x) \log \frac{2p _ g(x)}{p _ r(x)+p _ g(x) }dx    \right) \\
= \frac{1}{2} \left( \int p _ r(x) \log 2 dx +  \int p _ r(x) \log \frac{p _ r(x)}{p _ r(x)+p _ g(x) }dx    \right)   +   \frac{1}{2} \left(   \int p _ g(x) \log 2 dx + \int p _ g(x) \log \frac{p _ g(x)}{p _ r(x)+p _ g(x) }dx    \right) \\
= \frac{1}{2} \left( \log 4 + L(G, D^{\ast}) \right)
}

と変形できて、

 L(G, D^{\ast}) = 2  D _ {JS}(p _ r||p _ g) - 2 \log 2

と表せます。 この式から、Discriminatorが最適である時、GANの目的関数 L(G, D^{\ast})p _ dp _ gの間のJS divergenceを定量化します。なお、Generatorが最適である時、JS divergenceは0になって、 L(G^{\ast}, D^{\ast})=-2\log 2と一致します。

GANの問題点

  • ナッシュ均衡を達成するのが困難

  • low dimensional supports

  • 勾配消失

  • mode collapse

  • 適切な評価指標が存在しない

Wasserstein GAN (WGAN)

Wasserstein distance

Wasserstein distanceとは、JS divergenceと同じように2つの確率密度関数の距離をはかる指標です。Wasserstein distanceはEarth Mover's distanceとも呼ばれ、短くEM distanceと呼ばれることもあります。

Wasserstein distanceは、ある確率密度関数を動かしてもう一つの確率密度関数に一致させるときの最小コストです。 以下では、確率密度を「土」として表現し、「土」の最適な輸送としてWasserstein distanceを考えます。

2つの確率密度関数 p _ r p _ gのWasserstein distanceは以下のように与えられます。

Wasserstein distance $$ W(p _ r, p _ g) = \inf _ {\gamma \sim \Pi(p _ r, p _ g)} \mathbb{E} _ {(x,y) \sim \gamma} \left[ ||x - y || \right] $$

infは下限で、wasserstein distanceを求めること自体が最適化問題になっています。

\gamma(x,y)p _ rある地点xからp _ gある地点yに動かす土の量です。正確には地点xから、全土の量\int p _ r(x) dxのうちどれだけを地点yへ輸送するか、という量です。 土を動かし、 p _ r p _ gに一致させることから、直ちに

\displaystyle{
\sum _ {x} \gamma(x,y) = p _ g(y)
}

が成り立ちます。(地点yへ動かされた土の量をxについて和をとると動かし終わった土の量p _ g(y)と一致するはず)

逆に、

\displaystyle{
\sum _ {y} \gamma(x,y) = p _ r(x)
}

も成り立ちます。(地点xから動かされた土の量をyについて和を取るともともとxにあった土の量p _ r(x)と一致するはず)

土の量に動かす距離||x-y||をかけることでコスト\gamma(x,y)||x - y||を算出します。 全てのx,yについてコストの平均をとると、

\displaystyle{
\sum _ {x,y} \gamma(x,y) ||x - y|| = \mathbb{E} _ {x,y \sim \gamma} ||x-y||
}

候補となる土の動かし方戦略\gammaのうち、総コストがもっとも小さいものをとればwasserstein distanceが求まります。

Wasserstein GAN がJS divergenceとKL divergenceよりも良い理由

確率密度関数が低次元かつ2つの確率密度関数に重なりが場合でもWasserstein distanceはより滑らかな表現を提供してくれます。 例えば、以下のような2つの2次元の確率密度PQを考えます。Pのx成分は0に固定し、y成分は[0,1]の一様分布に従います。一方でQのx成分は\thetaに固定しyは[0,1]の一様分布に従います。

f:id:daiki-yosky:20190424220931p:plain
PとQの概観

\theta \neq 0の時

  • \displaystyle{D _ {KL}(P||Q) = \int _ {x=0, y \sim U(0,1)} P \log \frac{P}{Q} dxdy =  \infty}
  • \displaystyle{D _ {KL}(Q||P) = \int _ {x=\theta, y \sim U(0,1)} Q \log \frac{Q}{P} dxdy =  \infty}
  • \displaystyle{D _ {JS}(P,Q) =  \frac{1}{2} D _ {KL} \left( P||\frac{P+Q}{2} \right) +  \frac{1}{2} D _ {KL} \left( Q||\frac{Q+P}{2} \right)  = \log 2}
  • \displaystyle{ W(Q||P) = |\theta|}

一方\theta = 0の時、PとQはx=0で完全に重なっていて、

  • \displaystyle{D _ {KL}(P||Q) ={D _ {KL}(Q||P)} = {D _ {JS}(P,Q)}} = 0
  • \displaystyle{ W(Q||P) = 0 = |\theta|}

このように、KL divergenceは2つの確率密度に重なりがない場合\inftyに発散してしまいます。 JS divergenceは\theta=0で突然ジャンプし、微分不可能になってしまいます。 Wasserstein distanceは\thetaの変化に対して滑らかで、勾配降下法で学習する場合に安定すると考えられます。

GANの損失としてのWasserstein distance

Wasserstein distanceはKantorovich-Rubinstein双対性を使って、

$$ W(p _ r, p _ g) = \frac{1}{K} \sup _ {||f|| _ {L} \leq K} \mathbb{E} _ {x \sim p _ {r}} [f(x)] - \mathbb{E} _ {x \sim p _ {g}} [f(x)] $$

と変換することができます。

Lipschitz 連続性

Wasserstein distanceのfには、||f|| _ {L} \leq Kという制約がついています。つまりfはK-リプシッツ連続である必要があります。 関数f:\mathbb{R} \to \mathbb{R}は以下の条件を満たす時にK-リプシッツ連続です。

ある定数K \geq 0が存在して、全てのx _ {1}, x _ {2} \in \mathbb{R}に対して、 $$ |f(x _ {1} - f(x _ {2})| \leq K |x _ {1} - x _ {2}| $$ これは直感的には、任意の区間で傾きがある値Kで抑えられるということを意味します。(Kはリプシッツ定数と呼ばれます)

任意の場所で微分可能な関数はリプシッツ連続です。なぜなら\frac{|f(x _ {1} - f(x _ {2})|}{|x _ {1} - x _ {2}|}にはboundが存在するからです。 しかし、リプシッツ連続だからと言って任意の場所で微分可能である訳ではありません。例えば、f(x)=|x|は原点で微分不可能です。

Wasserstein loss

f _ wがパラメータwをもつK-リプシッツ関数とします。Wasserstein GANでは、Discriminatorは良いf _ wを求めます。WGANの損失としては p_r (現実のデータの分布)とp _ g (Generatorが生むデータの分布)間のWasserstein distanceを採用します。つまり、学習が進むにつれてGeneratorは現実のデータの分布に近いデータの分布を出力できるようになります。

Loss of Wasserstein GAN $$ L(p _ {r}, p _ {g}) = W(p _ {r}, p _ {g}) = \max _ {w \in W} \mathbb{E} _ {x \sim p _ {r}} [f _ {w}(x)] - \mathbb{E} _ {z \sim p _ {z}} [f _ {w}(g _ {\theta}(z))] $$

\inf\maxで近似されています。

WGAN全体としては、こちらのLossを最小化することを目指します。

ここで重要なのが、f _ wのK-リプシッツ性を維持する方法です。簡単かつ強力な方法として、重みwを更新した後、wを[-0.01, 0.01]といった小さな範囲でクリップします。 それにより、パラメータ空間Wは小さくなり、f _ wの傾きはboundで抑えられます。WGANの著者らは、clipingよりも良いK-リプシッツ性を維持する方法があるはずだ、とも述べています。

Wasserstein GANの学習

Wasserstein lossのGeneratorのパラメータ  \theta に関する微分は、

$$ \frac{\partial}{\partial \theta} L(p _ {r}, p _ {g}) = \frac{\partial}{\partial \theta} - \mathbb{E} _ {z \sim p _ {z}} [f _ {w}(g _ {\theta}(z))] $$

であり、こちらはサンプル近似によって

$$ \frac{\partial}{\partial \theta} - \mathbb{E} _ {z \sim p _ {z}} [f _ {w}(g _ {\theta}(z))] = \frac{1}{M} \sum_{m=1}^{M} \frac{\partial}{\partial \theta} - f _ {w}(g _ {\theta}(z_m)) $$

と近似できます。 Mはバッチサイズです。

よって、WGAN全体の学習は

  1. Discriminatorのパラメータ{\bf w}に関して、WGANのLossを微分し、Wasserstein distanceの良い近似を求めるように{\bf w}を更新する

  2. Discriminatorのパラメータ{\bf w}のリプシッツ連続性を保つため、クリッピングを行う

  3. Generatorのパラメータ  \theta に関して、Lossを微分し、Wasserstein distanceを小さくするように \theta を更新する

以上を繰り返します。

エクサウィザーズ2019

ExaWizards 2019 - AtCoder

C - Snuke the Wizard

愚直にやってしまいTLEでした..

左に落ちるゴーレムのうち一番みぎにいるやつと、右に落ちるゴーレムのうち一番左にいるやつを求めれば良い。二分探索で求めると早くできる。 らしいけど、コードが全く通らないです、誰かのコードを見たすぎる

N, Q = map(int, input().split())
s = input()
t_list, d_list = [], []
for _ in range(Q):
    t, d = input().split()
    t_list.append(t)
    d_list.append(d)

    
def if_survive(idx):
    """
    Args :
        idx : position of gorem
        max_ : max value of position idx
        min_ : min value of position idx
    Return : 
        +1 : 右側に落ちた時
        -1 : 左側に落ちた時
        0 : 落ちなかった時
    """
    max_ = len(s) - 1
    min_ = 0
    target_moji = s[idx]
    current_position = idx
    
    for t, d in zip(t_list, d_list):
        if t == target_moji:
            if d == "L":
                current_position -= 1
            elif d == "R":
                current_position += 1
                
            if current_position > max_:
                return 1
            elif current_position < min_:
                return -1
            target_moji = s[current_position]

        else:
            pass
    return 0

# 二分探索
def bisection_search_right(x_max, x_min):
    if if_survive(x_max) < 1:
        return x_max + 1
        
    while(True):
        x_mid = int((x_max + x_min)/2)
        max_survive = if_survive(x_max)
        mid_survive = if_survive(x_mid)

        if mid_survive + max_survive == 2:
            x_max = x_mid
        elif mid_survive + max_survive < 2:
            x_min = x_mid
        
        if x_max - x_min == 1:
            return x_max
        
def bisection_search_left(x_max, x_min):
    if if_survive(x_min) > -1:
        return x_min - 1
    
    while(True):
        x_mid = int((x_max + x_min)/2)
                
        min_survive = if_survive(x_min)
        mid_survive = if_survive(x_mid)

        if mid_survive + min_survive == -2:
            x_min = x_mid
        elif mid_survive + min_survive > -2:
            x_max = x_mid
        
        if x_max - x_min == 1:
            return x_min
    
right_idx = bisection_search_right(N-1, 0) 
left_idx = bisection_search_left(N-1, 0) 

print(N - (left_idx+1) - (N-right_idx))

Atcoder Beginner Contest 120 (Union-Find Tree)

AtCoder Beginner Contest 120のふりかえり

atcoder.jp

結果はA、B、Cの三完。Cに時間をかけすぎてスコアがやばいことになった

C問題

01がそれぞれ1つ以上あったら、どこかで01が隣り合っているのだから結局S中の0の総数と1の総数の少ない方回取り除ける。

それに気づかずごちゃごちゃやってしまった。

S = input()

save = []
flag = False
modori_flag = False
count = 0
for idx, s in enumerate(S):
    if modori_flag:
        modori_flag = False
    elif flag:
        flag = False
        continue
            
    this_ = int(s)
    if len(save) > 0 and this_ + save[-1] == 1:
        modori_flag = True
        del save[-1]
        count += 2
        continue
    if idx == len(S) -1:
        break
    next_ = int(S[idx+1])
    if this_ + next_ == 1:
        flag = True
        count += 2
        continue
        
    save.append(this_)
    
print(count)

D問題

やったこと

まっさらな状態から橋を1本ずつかけていくという発想まではいけたが、それ以上は時間切れ

解法

キーポイントは

i + 1 番目の辺 ei+1 = (ai+1, bi+1) によって 1. 元々頂点 ai+1, bi+1 が連結でなかった (互いに行き来可能でなかった) 場合 辺 e を加える前のグラフにおいて、頂点 ai+1 を含む連結成分の大きさを N1、頂点 bi+1 を含む連結成分の大きさを N2 とすると、 ans(i) = ans(i + 1) − N1 × N2 です。 2. 元々連結だった場合 不便さに変化はなく、ans(i) = ans(i + 1) です

ここの部分だと思う。 N1とN2を高速に求めるために、Union-Findというデータ構造を利用する。

Union-Findはグループ分けをうまく扱えるデータ構造で、

unite(u, v): 頂点 u が属するグループと頂点 v が属するグループを併合し、同じグループにする
(O(α(n)))
find(v): 頂点 v が属するグループ番号を得る (O(α(n)))
size(v): 頂点 v が属するグループと同じグループに属する頂点数を得る (O(1))

という操作を定義する。

www.slideshare.net (解説スライドがあったので神) (こういう頻出のデータ構造を知っておきたい感がある...と思ったらあった↓) プログラミングコンテストでのデータ構造

実装上は木構造を用いることができる。

class UnionFind():
    def __init__(self, size):
        # self.group holds a list of idx for parant node of i-th node
        # if self.group[i] = i, i-th node is a root.
        # if self.group[i] = j, j-th node is a paranet of i-th node. 
        self.group = [i for i in range(size)]
        
        # self.height holds a list of trees' heights
        self.height = [0] * (size)
        
    def find(self, i):
        """
        Search for i-th node's group.
        Args : 
            i : index of target node.
        Return :
            index of root of target node
        """ 
        
        if self.group[i] == i:
            # if root
            return i
        else:
            # if not root, search by the parent's node
            self.group[i] = self.find(self.group[i])
            return self.group[i]     
        
    def unite(self, i, j):
        """
        Unite two groups
        Args : 
            i, j : indices of target nodes.
        """ 
        # search roots
        i = self.find(i)
        j = self.find(j)
        
        if i == j:
            return 
        
        if self.height[i] < self.height[j]:
            self.group[i] = j
        else:
            self.group[j] = i
            if self.height[i] == self.height[j]:
                self.height[i] += 1
                
    def same(self, i, j):
        return self.find(i) == self.find(j)
       
    def size(self, i):
        """
        Get num of nodes in same group
        Args : 
            i : index of target node.
        """
        ans = 1
        root = self.find(i)
        for node_idx, par in enumerate(self.group):
            if root == par:
                ans += 1
            elif root == self.find(par):
                ans += 1
            else:
                pass
        ans -= 1
        return ans
        
        

import numpy as np
N, M = map(int, input().split())
ab_list = [list(map(int, input().split())) for _ in range(M)]

a_list = np.array(ab_list)[:, 0][-1::-1]
b_list = np.array(ab_list)[:, 1][-1::-1]

huben = [int((N) * (N-1) * 0.5)]
        
uni = UnionFind(size=N+1)

for a, b in zip(a_list, b_list):
    a_group = uni.find(a)
    b_group = uni.find(b)
    
    if a_group == b_group:
        huben.append(huben[-1])
    else:
        huben.append(huben[-1] - uni.size(a)*uni.size(b))
        uni.unite(a, b)
        
for ans in huben[-1::-1][1:]:
    print(ans)

という感じなったが、結局RTEで通らない。。。 pythonで通した人のコードを見せていただきたいものです。

==追記== ルートノードのインデックスでそのグループに属する要素数を持てばsize()がO(1)で行える。そうすれば間に合う

class UnionFind():
    def __init__(self, size):
        # self.parent holds a list of idx for parant node of i-th node
        # if i-th node is a root, self.parent[i] holds the - (number of nodes) in the parent.
        # if self.parent[i] = j, j-th node is a paranet of i-th node. 
        self.parent = [-1 for _ in range(size)]
        
        # self.height holds a list of trees' heights
        self.height = [0] * (size)
        
    def find(self, i):
        """
        Search for i-th node's parent.
        Args : 
            i : index of target node.
        Return :
            index of root of target node
        """ 
        
        if self.parent[i] < 0:
            # root node
            return i
        else:
            # if not root, search by the parent's node
            self.parent[i] = self.find(self.parent[i])
            return self.parent[i]     
        
    def unite(self, i_, j_):
        """
        Unite two parents
        Args : 
            i, j : indices of target nodes.
        """ 
        # search roots
        i = self.find(i_)
        j = self.find(j_)
        
        if self.height[i] < self.height[j]:
            self.parent[j] += self.parent[i]
            self.parent[i] = j_
            
        else:
            self.parent[i] += self.parent[j]
            self.parent[j] = i_
            if self.height[i] == self.height[j]:
                self.height[i] += 1
       
    def size(self, i):
        """
        Get num of nodes in same parent
        Args : 
            i : index of target node.
        """
        return - self.parent[self.find(i)]
        
        

import numpy as np
N, M = map(int, input().split())
ab_list = [list(map(int, input().split())) for _ in range(M)]

a_list = np.array(ab_list)[:, 0][-1::-1]
b_list = np.array(ab_list)[:, 1][-1::-1]

huben = [int((N) * (N-1) * 0.5)]
        
uni = UnionFind(size=N+1)

for a, b in zip(a_list, b_list):
    a_group = uni.find(a)
    b_group = uni.find(b)
    
    if a_group == b_group:
        huben.append(huben[-1])
    else:
        huben.append(huben[-1] - uni.size(a)*uni.size(b))
        uni.unite(a, b)
        
for ans in huben[-1::-1][1:]:
    print(ans)

他にもわかりやすい記事がありました Python:Union-Find木について - あっとのTECH LOG

AtCoder Beginner Contest 119

AtCoder Beginner Contest 119のふりかえり

atcoder.jp

結果はA、Bの二完 (AとBしか解けなかったけど緑にいけた)

C問題

やったこと

どうやるのか全く思いつかなかった。どの材料より短いやつは先にのけとく??

解法

N < 8 なので全探索可能(!)。 たけの扱いは4通り。

  • 長さA の竹の “材料” とする

  • 長さ B の竹の材料とする

  • 長さ C の竹の材料とする

  • 使わない

この4N通りを全部試す。深さ優先探索で探索すれば良い。

N, A,B,C = map(int, input().split())
l = sorted([int(input()) for _ in range(N)])[-1::-1]

INF = 1e+10

def dfs(cur, a, b, c):
    if cur == N:
#         print("a:{}, b:{}, c:{}".format(a,b,c))
        # 材料が何も使われないことを防ぐ為にa=b=c=0の時INFを返す
        # 30は余分に数えているMP
        return abs(a - A) + abs(b - B) + abs(c - C) - 30 if min(a, b, c) > 0 else INF
    # cur本めを使わない場合
    ret0 = dfs(cur + 1, a, b, c)
    # cur本めをAに使う場合
    ret1 = dfs(cur + 1, a + l[cur], b, c) + 10
    # cur本めをBに使う場合
    ret2 = dfs(cur + 1, a, b + l[cur], c) + 10
    # cur本めをCに使う場合
    ret3 = dfs(cur + 1, a, b, c + l[cur]) + 10
    
    print("cur:{}, a:{},b:{},c:{},ret0:{}, ret1:{}, ret2:{}, ret3:{}".format(cur,a,b,c,ret0,ret1,ret2,ret3))
    return min(ret0, ret1, ret2, ret3)

print(dfs(0, 0, 0, 0))

D問題

やったこと

一番近い神社・寺をみて、それらからもっとも近い寺・神社までの距離を足す。その2つのうち短い方を先に訪れるという戦略をとったが、TLE。

import numpy as np
A,B,Q = map(int, input().split())

s_list = [int(input()) for _ in range(A)]
t_list = [int(input()) for _ in range(B)]
x_list = [int(input()) for _ in range(Q)]

# print("s : ",s_list)
# print("t : ",t_list)
# print("x : ",x_list)

s_nearest_t = []
t_nearest_s = []

for s in s_list:
    t_dis = [abs(s-t) for t in t_list]
    s_nearest_t.append(np.min(t_dis))
for t in t_list:
    s_dis = [abs(s-t) for s in s_list]
    t_nearest_s.append(np.min(s_dis))


def get_ans(s_list, t_list, x):
    s_res = [abs(s-x) for s in s_list]
    t_res = [abs(t-x) for t in t_list]
    s_nearest = np.min(s_res)
    t_nearest = np.min(t_res)
    
    s_nearest_idx = s_res.index(s_nearest)
    t_nearest_idx = t_res.index(t_nearest)
    
    if s_nearest_t[s_nearest_idx] + s_nearest < t_nearest_s[t_nearest_idx] + t_nearest:
        ans = s_nearest_t[s_nearest_idx] + s_nearest
    else:
        ans = t_nearest_s[t_nearest_idx] + t_nearest

    return ans
        
for x in x_list:
    print(get_ans(s_list, t_list, x))

解法

二分探索で挿入位置を探すと早い。pythonでは bisect という標準ライブラリがある。 bisect_right で挿入位置の一個右のインデックスを得られる。

import bisect
A,B,Q = map(int, input().split())

INF = 1e+18

s_list = [-INF] + [int(input()) for _ in range(A)] + [INF]
t_list = [-INF] + [int(input()) for _ in range(B)] + [INF]
x_list = [int(input()) for _ in range(Q)]

def get_ans(s_list, t_list, x):
    b, d = bisect.bisect_right(s_list, x), bisect.bisect_right(t_list, x)
    ans = INF
    for S in [s[b - 1], s[b]]:
        for T in [t[d - 1], t[d]]:
            d1, d2 = abs(S - x) + abs(T - S), abs(T - x) + abs(S - T)
            ans = min(ans, d1, d2)
    return ans
        
for x in x_list:
    print(get_ans(s_list, t_list, x))

トポロジカルソートについて

atcoder.jp

 

こちらに参加したのですが,D問題で「トポロジカルソート」というテクニックが必要とのことで,調べました.

 

トポロジカルソートtopological sort)とは、グラフ理論において、有向非巡回グラフdirected acyclic graph, DAG)の各ノードを順序付けして、どのノードもその出力辺の先のノードより前にくるように並べることである。有向非巡回グラフは必ずトポロジカルソートすることができる。 

 とのこと. 

非巡回グラフというのは、有向サイクル(矢印の向きにたどってノードを一周できるようなサイクル)が無いことを表します.

 

わかりやすく言えば, 任意の矢印に対して,

矢印が出ているノードの番号 < 矢印が入ってくるノードの番号

という関係が成り立つようにノードに番号を振り分けることができる, ということですかね。

 

証明は:cited from 

有向非巡回グラフ(DAG)の意味とトポロジカルソート - 具体例で学ぶ数学

f:id:daiki-yosky:20190128021829p:plain

proof

 

直感的でわかりやすい。

 

 

 

Atcoder ABC-116 (python)

AtCoder Beginner Contest 116のpythonコード.

  • A問題 省略.

  • B問題

s = int(input())

a = s
idx = 1
a_list = [a]
while(idx < 1000000):
    idx += 1
    if a % 2 == 0:
        a = int(a / 2)
    else:
        a = 3*a+1
    if a in a_list:
        print(idx)
        break
    else:
        a_list.append(a)

計算時間やばいかと思ったがぎり行けた.

  • C問題
import numpy as np
N = int(input())
h_list = list(map(int, input().split()))

def shorten(idx_list, current_h):
    for idx in idx_list:
        current_h[idx] -= 1
    return current_h

def get_highest_idx(current_h):
    max_h = np.max(current_h)
    max_flower_idx = []
    for idx, h in enumerate(current_h):
        if h == max_h:
            max_flower_idx.append(idx)
    return max_flower_idx

ans = 0
while(np.max(h_list) > 0):
    highest_idx_list = get_highest_idx(h_list)
    mizunuki_idx = []
    
    for i, idx in enumerate(highest_idx_list):
        hight = h_list[idx]
        if i==len(highest_idx_list)-1:
            break
        else:
            if idx+1 == highest_idx_list[i+1]:
                j = 0
                while(True):
                    try:
                        if idx+j == highest_idx_list[i+j]:
                            j += 1
                        else:
                            break
                    except:
                        break

                mizunuki = np.array(highest_idx_list)[np.arange(i, i+j,1)]
                mizunuki_idx.append(mizunuki)
    if len(mizunuki_idx) == 0:
        mizunuki_ = [highest_idx_list[0]]
    else:
        mizunuki_ = mizunuki_idx[0]
    
    new_h = shorten(mizunuki_, h_list)
    h_list = new_h
    ans += 1

コードがきたない. 隣り合った一番高い花たちを探して, そこの高さを1減らすのを繰り返す.

  • D問題 TLEとREが出ている. 優先度付きキュー(ヒープ)でやるらしいがうまく行かない.
import numpy as np
import heapq
import copy

N, K = list(map(int, input().split()))
t_list, d_list = [], []
for _ in range(N):
    t_, d_ = list(map(int, input().split()))
    t_list.append(t_)
    d_list.append(d_)
    
# Sort by delicoiousness.
c = zip(t_list, d_list)
c = sorted(c, key=lambda x: x[1])[-1::-1]
t_list, d_list = zip(*c)

first_sushi_set = []
nokori_set = []

iter_ = 0
for neta, delicious in zip(t_list, d_list):
    if iter_ < K:
        heapq.heappush(first_sushi_set, (delicious, neta))
    else:
        heapq.heappush(nokori_set, (-delicious, neta))
    iter_ += 1
    
first_neta_num = len(list(set(t_list[:K])))

def get_next_sushi_set(current_sushi, nokori_sushi):
    current_sushi_cp = copy.deepcopy(current_sushi)

    while(len(current_sushi) > 0):
        #種類数が減らないかつ一番美味しくない寿司を選ぶ
        remove_sushi = heapq.heappop(current_sushi)
        score, neta = remove_sushi
        if len(current_sushi) == 0:
            break
        if neta in np.array(current_sushi)[:, 1]:
            break
        else:
            pass
        
    current_sushi_cp.remove(remove_sushi)
    heapq.heappush(nokori_sushi, (-score, neta))
    nokori_sushi_cp = copy.deepcopy(nokori_sushi)

    while(len(nokori_sushi) > 0):
        #種類数が増えるかつ一番美味しい寿司を選ぶ
        add_sushi = heapq.heappop(nokori_sushi)
        negative_score, neta = add_sushi
        score = - negative_score
        if len(nokori_sushi) == 0:
            break
        if neta in np.array(current_sushi_cp)[:, 1]:
            pass
        else:
            break        
    heapq.heappush(current_sushi_cp, (score, neta))
    nokori_sushi_cp.remove(add_sushi)
            
    return current_sushi_cp, nokori_sushi_cp
        
current_sushi_set = first_sushi_set
nokori_sushi_set = nokori_set
score_list = []
for neta_num in range(first_neta_num, N+1, 1):
    score =  len(set(np.array(current_sushi_set)[:, 1]))**2 + sum(np.array(current_sushi_set)[:, 0])
    score_list.append(score)
    current_sushi_set, nokori_sushi_set = get_next_sushi_set(current_sushi_set, nokori_sushi_set)
print(max(score_list))