【Python】 VAE (Variational Autoencoder) の写経
書籍「Deep Learning with Python」にMNISTを用いたVAEの実装があったので、写経します。書籍ではすべてのコードを1つのファイルに書くスタイルだったため、本記事ではVAEクラスを作るなどしてファイルを分けました。
VAEの詳しい解説は、以下のQiitaの記事(qiita.com)が参考になります。
実装ですが、まずは以下をvae.pyに書きます。
import numpy as np
from keras import Input
from keras.layers import Conv2D, Flatten, Dense, Lambda, Reshape, Conv2DTranspose, Layer
from keras.models import Model
from keras.metrics import binary_crossentropy
import keras.backend as K


class CustomVariationalLayer(Layer):
    """Keras layer that attaches the VAE loss to the model via ``add_loss``.

    The layer is called on ``[x, z_decoded]``. It registers the combined
    reconstruction + KL loss and passes ``x`` through unchanged. The encoder
    outputs ``z_mean`` and ``z_log_var`` must be supplied with
    ``set_z_mean`` / ``set_z_log_var`` before the layer is called.
    """

    # Weight of the KL term relative to the reconstruction loss.
    KL_WEIGHT = 5e-4

    def set_z_mean(self, z_mean):
        """Store the encoder's latent-mean tensor used by the KL term."""
        self._z_mean = z_mean

    def set_z_log_var(self, z_log_var):
        """Store the encoder's latent log-variance tensor used by the KL term."""
        self._z_log_var = z_log_var

    def _vae_loss(self, x, z_decoded):
        """Return the scalar VAE loss: binary cross-entropy plus weighted KL."""
        x = K.flatten(x)
        z_decoded = K.flatten(z_decoded)
        reconstruction_loss = binary_crossentropy(x, z_decoded)
        # _compute_KL_divergence returns a quantity proportional to the
        # *negative* KL divergence, so the minus sign turns it into a
        # non-negative penalty.
        kl_loss = -self.KL_WEIGHT * self._compute_KL_divergence(self._z_mean, self._z_log_var)
        return K.mean(reconstruction_loss + kl_loss)

    def _compute_KL_divergence(self, z_mean, z_log_var):
        """Return mean(1 + log(sigma^2) - mu^2 - sigma^2) over the latent axis.

        This is the bracketed term of the closed-form KL divergence between
        N(mu, sigma^2) and N(0, 1), where KL = -0.5 * sum(...). The result is
        therefore proportional to the negative per-sample KL divergence.
        """
        return K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)

    def call(self, inputs):
        """Register the VAE loss for ``inputs = [x, z_decoded]`` and return ``x``.

        The returned tensor only gives the layer an output. The training
        signal comes from ``add_loss`` (main.py compiles with ``loss=None``).
        """
        x = inputs[0]
        z_decoded = inputs[1]
        loss = self._vae_loss(x, z_decoded)
        self.add_loss(loss, inputs=inputs)
        return x


class VAE:
    """Convolutional variational autoencoder for single-channel images.

    Args:
        image_shape: Input image shape, e.g. ``(28, 28, 1)`` for MNIST.
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, image_shape, latent_dim):
        self._latent_dim = latent_dim

        # Encoding: only the second convolution has stride 2, so the spatial
        # size is halved once (28x28 -> 14x14 for MNIST).
        input_img = Input(shape=image_shape)
        x = Conv2D(32, 3, padding='same', activation='relu')(input_img)
        x = Conv2D(64, 3, padding='same', activation='relu', strides=(2, 2))(x)
        x = Conv2D(64, 3, padding='same', activation='relu')(x)
        x = Conv2D(64, 3, padding='same', activation='relu')(x)
        shape_before_flattening = K.int_shape(x)
        x = Flatten()(x)
        x = Dense(32, activation='relu')(x)
        z_mean = Dense(latent_dim)(x)
        z_log_var = Dense(latent_dim)(x)

        # Sampling (reparameterization trick).
        z = Lambda(self._sampling)([z_mean, z_log_var])

        # Decoding: mirror the encoder. The single stride-2 transposed
        # convolution doubles the spatial size back. NOTE(review): this
        # presumably requires even input height/width.
        decoder_input = Input(K.int_shape(z)[1:])
        x = Dense(np.prod(shape_before_flattening[1:]), activation='relu')(decoder_input)
        x = Reshape(shape_before_flattening[1:])(x)
        x = Conv2DTranspose(32, 3, padding='same', activation='relu', strides=(2, 2))(x)
        # Single-channel sigmoid output, matching the binary cross-entropy loss.
        x = Conv2D(1, 3, padding='same', activation='sigmoid')(x)
        self._decoder = Model(inputs=decoder_input, outputs=x)
        z_decoded = self._decoder(z)

        # The custom layer adds the VAE loss to the graph.
        loss_layer = CustomVariationalLayer()
        loss_layer.set_z_mean(z_mean)
        loss_layer.set_z_log_var(z_log_var)
        y = loss_layer([input_img, z_decoded])
        self._vae = Model(input_img, y)

    def _sampling(self, args):
        """Draw z ~ N(z_mean, exp(z_log_var)) via z = mu + sigma * epsilon."""
        z_mean, z_log_var = args
        epsilon = K.random_normal(shape=(K.shape(z_mean)[0], self._latent_dim),
                                  mean=0.0, stddev=1.0)
        # z_log_var is log(sigma^2), so the standard deviation is
        # exp(0.5 * z_log_var). Using exp(z_log_var) would scale the noise by
        # sigma^2 and disagree with the KL term, which treats z_log_var as
        # the log-variance.
        return z_mean + K.exp(0.5 * z_log_var) * epsilon

    def get_model(self):
        """Return the full VAE model (image -> image, loss attached internally)."""
        return self._vae

    def get_decoder(self):
        """Return the decoder model (latent vector -> image)."""
        return self._decoder
あとは以下のコードをmain.pyに書いて実行すれば、学習したVAEに手書き数字の画像を生成させることができます。
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from keras.optimizers import RMSprop
from keras.datasets import mnist

from vae import VAE

img_shape = (28, 28, 1)
batch_size = 32
latent_dim = 2  # the latent-grid plot below decodes 2-D points [xi, yi]

# Load MNIST, scale pixels to [0, 1] and add a trailing channel axis.
(x_train, _), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_train = x_train.reshape(x_train.shape + (1,))
x_test = x_test.astype('float32') / 255.0
x_test = x_test.reshape(x_test.shape + (1,))

vae_builder = VAE(img_shape, latent_dim)
decoder = vae_builder.get_decoder()
vae = vae_builder.get_model()
# The VAE loss is attached inside the model (add_loss), so there is no
# external loss function and no target array.
vae.compile(optimizer=RMSprop(), loss=None)
history = vae.fit(x=x_train, y=None, shuffle=True, epochs=10, batch_size=batch_size)

# Append the per-epoch training losses, one value per line.
with open('loss.txt', 'a', encoding='utf-8') as f:
    f.writelines(str(loss) + '\n' for loss in history.history['loss'])

# Decode an n x n grid of latent points and tile the images into one figure.
n = 15
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# Latent coordinates at evenly spaced quantiles of the N(0, 1) prior.
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))
for i, yi in enumerate(grid_y):
    for j, xi in enumerate(grid_x):
        z_sample = np.array([[xi, yi]])  # one latent point, shape (1, 2)
        x_decoded = decoder.predict(z_sample, batch_size=batch_size)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i * digit_size:(i + 1) * digit_size,
               j * digit_size:(j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()
plt.show()を実行すると、潜在空間上の格子点から生成された手書き数字が、以下のような図として描画されます。