キンサクプログラマー

お金儲けと技術のメモ

GANのソースを読む

巷で話題の生成モデル

自動で線画に色をつけてくれるPaintChainerでも使用されてる技術"GAN"
PaintsChainer -線画自動着色サービス-

ネットでもそこそこ情報がのっているようで、chainer,tensorflow,kerasを使って実装した例が転がっていた。
以前、チュートリアルを試した限りではkerasがわかりやすかったので、GANの実装を読んでみた。

ソースは以下を拝借
github.com
理論は以下を参考
elix-tech.github.io

全体
train側
def generator_model():
    model = Sequential()
    model.add(Dense(input_dim=100, output_dim=1024))
    model.add(Activation('tanh'))
    model.add(Dense(128*7*7))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    model.add(Reshape((128, 7, 7), input_shape=(128*7*7,)))
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Convolution2D(64, 5, 5, border_mode='same'))
    model.add(Activation('tanh'))
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Convolution2D(1, 5, 5, border_mode='same'))
    model.add(Activation('tanh'))
    return model

Seqeuntialはネットワークのクラス。
addすることで、モデルにスタックされていく。Denseは全結合レイヤ。入力次元100出力次元1024てことになる。
model.add(Activation('tanh'))活性化関数、model.add(Dense(128*7*7)) 6272次元の中間層、BatchNormalization正規化レイヤ(なんか値が徐々にシフトしていって、制度が下がるという課題があるらしく、その補正のために正規化が必要らしい。詳しくはよくわからない)、
同じ関数は飛ばすとして、UpSampling2D(size=(2,2)))2次元方向にサンプル数を増やす。Convolution2d畳み込みカネールの数、たて、よこらしい。Reshpaeで3次元になった後に、UpSampling2Dやった場合ってどうなるのかよくわからない。

【追記】

そもそもモデル選定部分については、 GANの本質ではなさそうなので真剣に読む必要はなさそう。

def discriminator_model():
    model = Sequential()
    model.add(Convolution2D(
                        64, 5, 5,
                        border_mode='same',
                        input_shape=(1, 28, 28)))
    model.add(Activation('tanh'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Convolution2D(128, 5, 5))
    model.add(Activation('tanh'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(1024))
    model.add(Activation('tanh'))
    model.add(Dense(1))
    model.add(Activation('sigmoid'))
    return model
def train(BATCH_SIZE):
    (X_train, y_train), (X_test, y_test) = mnist.load_data()  #X_train  
    X_train = (X_train.astype(np.float32) - 127.5)/127.5
    X_train = X_train.reshape((X_train.shape[0], 1) + X_train.shape[1:])
    discriminator = discriminator_model()  #判別モデル
    generator = generator_model()          #生成モデル
    discriminator_on_generator = \    #判別生成モデル
        generator_containing_discriminator(generator, discriminator)
    d_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True) #確率的勾配降下法
    g_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True) #確率的勾配降下法
    generator.compile(loss='binary_crossentropy', optimizer="SGD")#generaterモデルの交差エントロピーを確率的勾配降下法で最小化
    discriminator_on_generator.compile(
        loss='binary_crossentropy', optimizer=g_optim) #判別生成モデルの交差エントロピーを確率的勾配降下法で最小化
    discriminator.trainable = True
    discriminator.compile(loss='binary_crossentropy', optimizer=d_optim)#判別モデルの交差エントロピーを確率的勾配降下法で最小化
    noise = np.zeros((BATCH_SIZE, 100))
    for epoch in range(100):
        print("Epoch is", epoch)
        print("Number of batches", int(X_train.shape[0]/BATCH_SIZE))
        for index in range(int(X_train.shape[0]/BATCH_SIZE)):
            for i in range(BATCH_SIZE):
                noise[i, :] = np.random.uniform(-1, 1, 100)
            image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]
            generated_images = generator.predict(noise, verbose=0) #■予測
            if index % 20 == 0:
                image = combine_images(generated_images)
                image = image*127.5+127.5
                Image.fromarray(image.astype(np.uint8)).save(
                    str(epoch)+"_"+str(index)+".png")
            X = np.concatenate((image_batch, generated_images))#前半がtrainデータ、後半が生成モデルから出て来たでーた
            y = [1] * BATCH_SIZE + [0] * BATCH_SIZE#前半はtrainデータ(正解)後半は偽物
            d_loss = discriminator.train_on_batch(X, y)     #★判別モデルの学習。(学習データ=実画像、生成画像 正解データ=(正解、誤解)
            print("batch %d d_loss : %f" % (index, d_loss))
            for i in range(BATCH_SIZE):
                noise[i, :] = np.random.uniform(-1, 1, 100)
            discriminator.trainable = False
            g_loss = discriminator_on_generator.train_on_batch(
                noise, [1] * BATCH_SIZE)                                         #★学習  ノイズを与えて生成された画像が、本物であると判断するようにgeneratorを学習させる
            discriminator.trainable = True
            print("batch %d g_loss : %f" % (index, g_loss))
            if index % 10 == 9:
                generator.save_weights('generator', True)             #モデルセーブ
                discriminator.save_weights('discriminator', True)  #モデルセーブ
肝となる理論の部分を図にしてみた

f:id:pikurusux:20170517232015p:plain

  1. データ生成
  2. 判別器の学習
  3. 生成器の学習

というイテレーションを回すことで、判別器と生成器が切磋琢磨しあいながら精度を向上していくのが最大の特徴と言えそう。
今日はここまで。
*ご指摘等あれば、教えていただけるとありがたいです。