Notes of a travel-loving software engineer

I've started keeping notes on programming and other technical topics.

【Python】Transcribing a VAE (Variational Autoencoder)

The book "Deep Learning with Python" includes a VAE implementation using MNIST, so I transcribed it here (the book writes everything in a single file, so I restructured it a little, e.g. by wrapping it in a VAE class).

The following is a good, detailed explanation of VAEs.
qiita.com
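
For reference (my own summary, not something from the book): the loss that the code below minimizes is a per-pixel binary cross-entropy reconstruction term plus a down-weighted KL divergence between the approximate posterior N(μ, σ²) and the standard normal prior, where z_log_var in the code corresponds to log σ²:

\mathcal{L}(x) = \mathrm{BCE}(x, \hat{x}) + \lambda \, D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu, \sigma^{2}) \,\|\, \mathcal{N}(0, I)\right), \qquad D_{\mathrm{KL}} = -\frac{1}{2} \sum_{i} \left(1 + \log \sigma_{i}^{2} - \mu_{i}^{2} - \sigma_{i}^{2}\right)

In the code, the weight λ is folded into the constant -5e-4, and the sum over latent dimensions is replaced by a mean, which only changes the effective weighting.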

For the implementation, first write the following in vae.py.

import numpy as np
from keras import Input
from keras.layers import Conv2D, Flatten, Dense, Lambda, Reshape, Conv2DTranspose, Layer
from keras.models import Model
from keras.metrics import binary_crossentropy
import keras.backend as K

class CustomVariationalLayer(Layer):
    # Custom layer whose only job is to attach the VAE loss to the graph via add_loss.
    def set_z_mean(self, z_mean):
        self._z_mean = z_mean

    def set_z_log_var(self, z_log_var):
        self._z_log_var = z_log_var

    def _vae_loss(self, x, z_decoded):
        # Reconstruction term: per-pixel binary cross-entropy between the input and the decoded output
        x = K.flatten(x)
        z_decoded = K.flatten(z_decoded)
        reconstruction_loss = binary_crossentropy(x, z_decoded)
        # Regularization term: down-weighted KL divergence that pulls the latent distribution towards N(0, I)
        regularization_parameter = -5e-4 * self._compute_KL_divergence(self._z_mean, self._z_log_var)
        return K.mean(reconstruction_loss + regularization_parameter)

    def _compute_KL_divergence(self, z_mean, z_log_var):
        return K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)

    def call(self, inputs):
        x = inputs[0]
        z_decoded = inputs[1]
        loss = self._vae_loss(x, z_decoded)
        self.add_loss(loss, inputs=inputs)
        # The output is not used for training; return the input image unchanged
        return x

class VAE(object):
    def __init__(self, image_shape, latent_dim):
        self._latent_dim = latent_dim

        # Encoding
        input_img = Input(shape=image_shape)
        x = Conv2D(32, 3, padding='same', activation='relu')(input_img)
        x = Conv2D(64, 3, padding='same', activation='relu', strides=(2, 2))(x)
        x = Conv2D(64, 3, padding='same', activation='relu')(x)
        x = Conv2D(64, 3, padding='same', activation='relu')(x)
        shape_before_flattening = K.int_shape(x)  # remembered so the decoder can undo the Flatten
        x = Flatten()(x)
        x = Dense(32, activation='relu')(x)
        z_mean = Dense(latent_dim)(x)
        z_log_var = Dense(latent_dim)(x)

        # Sampling
        z = Lambda(self._sampling)([z_mean, z_log_var])

        # Decoding
        decoder_input = Input(K.int_shape(z)[1:])
        x = Dense(np.prod(shape_before_flattening[1:]), activation='relu')(decoder_input)
        x = Reshape(shape_before_flattening[1:])(x)
        x = Conv2DTranspose(32, 3, padding='same', activation='relu', strides=(2, 2))(x)
        x = Conv2D(1, 3, padding='same', activation='sigmoid')(x)
        self._decoder = Model(inputs=decoder_input, outputs=x)
        z_decoded = self._decoder(z)
        l = CustomVariationalLayer()
        l.set_z_mean(z_mean)
        l.set_z_log_var(z_log_var)
        y = l([input_img, z_decoded])

        self._vae = Model(input_img, y)

    def _sampling(self, args):
        # Reparameterization trick: z = z_mean + exp(z_log_var) * epsilon, with epsilon ~ N(0, I)
        # (exp(z_log_var) is kept as in the book's listing, rather than exp(0.5 * z_log_var))
        z_mean, z_log_var = args
        epsilon = K.random_normal(shape=(K.shape(z_mean)[0], self._latent_dim), mean=0.0, stddev=1.0)
        return z_mean + K.exp(z_log_var)*epsilon

    def get_model(self):
        return self._vae

    def get_decoder(self):
        return self._decoder
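
As a quick sanity check (my own addition, not part of the book's listing), you can instantiate the class and print the model summaries to confirm that the encoder downsamples 28x28 to 14x14 before flattening and that the decoder restores the original image shape:

from vae import VAE

vae = VAE(image_shape=(28, 28, 1), latent_dim=2)
vae.get_model().summary()    # full model: encoder -> sampling -> decoder -> custom loss layer
vae.get_decoder().summary()  # decoder alone, used later for generating digits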


Then, if you write the following in main.py, you can have the VAE generate digits.

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from keras.optimizers import RMSprop
from keras.datasets import mnist
from vae import VAE

img_shape = (28, 28, 1)
batch_size = 32
latent_dim = 2

(x_train, _), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')/255.0
x_train = x_train.reshape(x_train.shape + (1,))
x_test = x_test.astype('float32')/255.0
x_test = x_test.reshape(x_test.shape + (1,))

vae = VAE(img_shape, latent_dim)
decoder = vae.get_decoder()
vae = vae.get_model()
vae.compile(optimizer=RMSprop(), loss=None)
history = vae.fit(x=x_train, y=None, shuffle=True, epochs=10, batch_size=batch_size,
                  validation_data=(x_test, None))
with open('loss.txt', 'a') as f:
    for loss in history.history['loss']:
        f.write(str(loss) + '\n')

# Decode a 15x15 grid of points in the 2D latent space into digit images.
# norm.ppf maps evenly spaced quantiles through the inverse Gaussian CDF,
# so the grid follows the Gaussian prior over the latent space.
n = 15
digit_size = 28
figure = np.zeros((digit_size*n, digit_size*n))
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

for i, yi in enumerate(grid_y):
    for j, xi in enumerate(grid_x):
        z_sample = np.array([[xi, yi]])
        z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)
        x_decoded = decoder.predict(z_sample, batch_size)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i*digit_size:(i+1)*digit_size,
               j*digit_size:(j+1)*digit_size] = digit
plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()

Running plt.show() draws a figure like the one below.
[Figure: a 15x15 grid of MNIST-style digits generated by decoding points sampled from the 2D latent space]
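
If you are running this in a non-interactive environment where plt.show() does not open a window, one option (my addition, with a hypothetical output filename) is to save the grid to a file instead:

# Save the generated grid to a PNG instead of (or in addition to) displaying it.
plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.savefig('vae_digits.png')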