旅行好きなソフトエンジニアの備忘録

プログラミングや技術関連のメモを始めました

【Python】 KerasでConditional DCGANを実装する

前回DCGANを実装しましたが、今回はConditional DCGAN([1411.1784] Conditional Generative Adversarial Nets)を実装します。

前回のDCGANの例では、入力からどの数字が生成されるかをコントロールできませんでしたが、Conditional DCGANでは付加情報(下図のy)を与えることで、生成する数字をコントロールできるようになります。
[f:id:ni4muraano:20171225180733p:plain]

実装にあたり、zとyはどちらもベクトルなのでGenerator側は両者を結合すればよいと見当が付きましたが、Discriminator側はxが画像であるため、ベクトルのyとどう組み合わせればよいか分かりませんでした。そんな時に参考になったのが下記サイトでした(本記事のDiscriminatorでは、yを画像と同じ高さ・幅に拡張し、チャネル方向に画像と結合しています)。

qiita.com

--- 学習環境 ---
Windows10 Home
Python 3.5.2
Keras 2.1.2
tensorflow-gpu 1.2.0

実装に移りますが、以下をconditional_dcgan.pyに記述します。

import numpy as np
import keras
from keras.layers import Input, Dense, Activation, BatchNormalization, Reshape, UpSampling2D, Conv2D, MaxPool2D, Flatten, concatenate, multiply
from keras.models import Model

class Generator(object):
    """Conditional generator: (latent vector z, condition vector y) -> image.

    z and y are concatenated and decoded from a 7x7x128 feature map that is
    up-sampled twice, giving a 28x28x1 image with a tanh output in [-1, 1].
    The condition input is also exposed as a second model output, so the
    pair [image, y] can be handed straight to the Discriminator.
    """

    def __init__(self, latent_dim, condition_dim):
        noise_in = Input(shape=(latent_dim,))    # latent vector z
        cond_in = Input(shape=(condition_dim,))  # condition vector y
        h = concatenate([noise_in, cond_in])
        # Layers are instantiated here in the same order as they are applied,
        # so Keras' auto-generated layer names are deterministic.
        decoder = (
            Dense(1024), Activation('tanh'),
            Dense(128 * 7 * 7), BatchNormalization(), Activation('tanh'),
            Reshape((7, 7, 128)),
            # 7x7 -> 14x14
            UpSampling2D(size=(2, 2)), Conv2D(64, 5, padding='same'), Activation('tanh'),
            # 14x14 -> 28x28, single output channel
            UpSampling2D(size=(2, 2)), Conv2D(1, 5, padding='same'), Activation('tanh'),
        )
        for layer in decoder:
            h = layer(h)
        # Return y alongside the image so the discriminator receives both.
        self.generator = Model(inputs=[noise_in, cond_in], outputs=[h, cond_in])

    def get_model(self):
        """Return the Keras model (inputs: [z, y]; outputs: [image, y])."""
        return self.generator


class Discriminator(object):
    """Conditional discriminator that scores an (image, condition) pair.

    The condition vector is broadcast over every pixel (reshaped to 1x1,
    then up-sampled to height x width) and concatenated with the image
    along the channel axis, so every convolution sees the condition.
    The output is one sigmoid score per sample.
    """

    def __init__(self, height, width, channels, condition_dim):
        image_in = Input(shape=(height, width, channels))  # real or generated image
        cond_in = Input(shape=(condition_dim,))            # condition vector
        # (batch, condition_dim) -> (batch, 1, 1, condition_dim)
        cond_map = Reshape((1, 1, condition_dim))(cond_in)
        # (batch, 1, 1, condition_dim) -> (batch, height, width, condition_dim)
        cond_map = UpSampling2D((height, width))(cond_map)
        h = concatenate([image_in, cond_map])
        # Layers are created in application order to keep layer names stable.
        head = (
            Conv2D(64, 5, padding='same'), Activation('tanh'), MaxPool2D(),
            Conv2D(128, 5), Activation('tanh'), MaxPool2D(),
            Flatten(),
            Dense(1024), Activation('tanh'),
            Dense(1, activation='sigmoid'),
        )
        for layer in head:
            h = layer(h)
        self.discriminator = Model(inputs=[image_in, cond_in], outputs=h)

    def get_model(self):
        """Return the Keras model (inputs: [image, condition]; output: score)."""
        return self.discriminator


class ConditionalDCGAN(object):
    """Conditional DCGAN: a Generator, a Discriminator and the stacked model.

    Label convention (inverted w.r.t. many GAN write-ups): the discriminator
    is trained towards ~1 for generated samples and ~0 for real samples, and
    the generator is trained so that its samples are scored as 0 ("real").
    """

    def __init__(self, latent_dim, height, width, channels, condition_dim):
        # set generator
        self._latent_dim = latent_dim
        g = Generator(latent_dim, condition_dim)
        self._generator = g.get_model()
        # set discriminator
        d = Discriminator(height, width, channels, condition_dim)
        self._discriminator = d.get_model()
        # Compile the discriminator before freezing it: its own compiled
        # model keeps training via train_on_batch ...
        discriminator_optimizer = keras.optimizers.SGD(lr=0.0005, momentum=0.9, nesterov=True)
        self._discriminator.compile(optimizer=discriminator_optimizer, loss='binary_crossentropy')
        # ... while the stacked model compiled below treats it as frozen and
        # only updates the generator.
        self._discriminator.trainable = False
        # set DCGAN: z, y -> generator -> [image, y] -> discriminator -> score
        dcgan_input1 = Input(shape=(latent_dim,))
        dcgan_input2 = Input(shape=(condition_dim,))
        dcgan_output = self._discriminator(self._generator([dcgan_input1, dcgan_input2]))
        self._dcgan = Model([dcgan_input1, dcgan_input2], dcgan_output)
        # compile DCGAN
        dcgan_optimizer = keras.optimizers.SGD(lr=0.0005, momentum=0.9, nesterov=True)
        self._dcgan.compile(optimizer=dcgan_optimizer, loss='binary_crossentropy')

    def train(self, real_images, conditions, batch_size):
        """Run one training step (discriminator, then generator) on a batch.

        Args:
            real_images: batch of real images; presumably already normalized
                to [-1, 1] like the generator's tanh output (as main.py does).
            conditions: condition vectors, one row per image.
            batch_size: number of rows in real_images/conditions; also the
                number of fake samples drawn.

        Returns:
            (d_loss, g_loss): the mean discriminator loss over the fake and
            real half-batches, and the generator loss through the frozen
            discriminator.
        """
        # Train discriminator so it can detect fake.
        random_latent_vectors = np.random.normal(size=(batch_size, self._latent_dim))
        # The generator returns [images, conditions], which is exactly the
        # discriminator's input list.
        fake_inputs = self._generator.predict([random_latent_vectors, conditions])
        # Fake -> ~1. The noise is subtracted (not added) so that the
        # binary cross-entropy targets stay within [0, 1].
        labels = np.ones((batch_size, 1))
        labels -= 0.05 * np.random.random(labels.shape)
        d_loss1 = self._discriminator.train_on_batch(fake_inputs, labels)
        # Train discriminator so it can detect real: real -> ~0.
        labels = np.zeros((batch_size, 1))
        labels += 0.05 * np.random.random(labels.shape)
        d_loss2 = self._discriminator.train_on_batch([real_images, conditions], labels)
        d_loss = (d_loss1 + d_loss2) / 2.0
        # Train generator so it can fool discriminator: aim for the "real" label.
        random_latent_vectors = np.random.normal(size=(batch_size, self._latent_dim))
        misleading_targets = np.zeros((batch_size, 1))
        g_loss = self._dcgan.train_on_batch([random_latent_vectors, conditions], misleading_targets)
        return d_loss, g_loss

    def predict(self, latent_vector, condition):
        """Generate images for the given latent vectors and conditions.

        Only the image output is returned; the generator's second output
        (the passed-through condition) is dropped.
        """
        return self._generator.predict([latent_vector, condition])[0]

    def load_weights(self, file_path, by_name=False):
        """Load weights into the stacked model (generator + discriminator)."""
        self._dcgan.load_weights(file_path, by_name)

    def save_weights(self, file_path, overwrite=True):
        """Save the stacked model's weights (generator + discriminator)."""
        self._dcgan.save_weights(file_path, overwrite)


次に以下をmain.pyに記述します(予めmain.pyと同じフォルダにgeneratedというフォルダを作成してください)。

import os
import numpy as np
import keras
from conditional_dcgan import ConditionalDCGAN
from keras.preprocessing import image
from keras.utils.np_utils import to_categorical

def normalize(X):
    """Rescale 8-bit pixel values from [0, 255] to [-1, 1].

    This matches the tanh output range of the generator.
    """
    centre = 127.5
    shifted = X - centre
    return shifted / centre

def denormalize(X):
    """Invert ``normalize``: map values in [-1, 1] back to [0, 255]."""
    scale = 127.5
    shifted = X + 1.0
    return shifted * scale

def train(latent_dim, height, width, channels, num_class):
    """Train a conditional DCGAN on MNIST.

    Every 10 iterations the losses are printed and appended to ``loss.txt``
    as ``d_loss,g_loss`` lines. Every 5 epochs the weights are saved to
    ``gan_epoch{N}.h5``, and one sample image per generated image is written
    to the ``generated`` directory (created if missing).
    """
    (X_train, Y_train), (_, _) = keras.datasets.mnist.load_data()
    Y_train = to_categorical(Y_train, num_class)
    X_train = X_train.reshape((X_train.shape[0],) + (height, width, channels)).astype('float32')
    X_train = normalize(X_train)
    epochs = 50
    batch_size = 128
    # Trailing samples that do not fill a whole batch are skipped.
    iterations = X_train.shape[0] // batch_size
    os.makedirs('generated', exist_ok=True)
    dcgan = ConditionalDCGAN(latent_dim, height, width, channels, num_class)
    for epoch in range(epochs):
        for iteration in range(iterations):
            start = iteration * batch_size
            real_images = X_train[start:start + batch_size]
            conditions = Y_train[start:start + batch_size]
            d_loss, g_loss = dcgan.train(real_images, conditions, batch_size)
            if (iteration + 1) % 10 == 0:
                print('discriminator loss:', d_loss)
                print('generator loss:', g_loss)
                print()
                # Use '\n' (not a bare '\r') so the log has one record per line.
                with open('loss.txt', 'a') as f:
                    f.write(str(d_loss) + ',' + str(g_loss) + '\n')
        if (epoch + 1) % 5 == 0:
            dcgan.save_weights('gan' + '_epoch' + str(epoch + 1) + '.h5')
            # Sample with the last batch's conditions so each image has a known label.
            random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))
            generated_images = dcgan.predict(random_latent_vectors, conditions)
            for i, generated_image in enumerate(generated_images):
                img = denormalize(generated_image)
                img = image.array_to_img(img, scale=False)
                condition = np.argmax(conditions[i])
                # Include the sample index: many samples share a digit and
                # would otherwise overwrite each other's file.
                file_name = str(epoch) + '_' + str(condition) + '_' + str(i) + '.png'
                img.save(os.path.join('generated', file_name))
        print('epoch' + str(epoch) + ' end')
        print()

def predict(latent_dim, height, width, channels, num_class,
            weights_path='gan_epoch50.h5', samples_per_class=10):
    """Generate sample digits for every class from saved weights.

    Args:
        latent_dim, height, width, channels, num_class: model configuration;
            must match the configuration used when the weights were saved.
        weights_path: weights file written by ``train``.
        samples_per_class: number of images generated per class.

    Writes ``generated/{class}_{k}.png`` for each class and sample index k,
    creating the ``generated`` directory if it does not exist.
    """
    dcgan = ConditionalDCGAN(latent_dim, height, width, channels, num_class)
    dcgan.load_weights(weights_path)
    os.makedirs('generated', exist_ok=True)
    for num in range(num_class):
        # One-hot condition selecting class ``num``; same for every sample.
        condition = np.zeros((1, num_class), dtype=np.float32)
        condition[0, num] = 1
        for sample_idx in range(samples_per_class):
            random_latent_vectors = np.random.normal(size=(1, latent_dim))
            generated_images = dcgan.predict(random_latent_vectors, condition)
            img = image.array_to_img(denormalize(generated_images[0]), scale=False)
            img.save(os.path.join('generated', str(num) + '_' + str(sample_idx) + '.png'))

if __name__ == '__main__':
    # MNIST setup: 28x28 single-channel digits in 10 classes, 100-dim noise.
    settings = dict(latent_dim=100, height=28, width=28, channels=1, num_class=10)
    train(**settings)
    predict(**settings)


生成された画像は以下のようになります。conditionを付加したことで、生成する数字をコントロールできています。
[f:id:ni4muraano:20171225184732p:plain]
最後に、学習時のDiscriminatorとGeneratorのロス値の推移も貼っておきます。
[f:id:ni4muraano:20171225184822p:plain]