【Keras】ArcFaceとUmapを使って特徴量を可視化する
ディープラーニングを用いたMetric Learningの一手法であるArcFaceで特徴抽出を行い、その特徴量をUmapを使って2次元に落とし込み可視化しました。KerasでArcFaceを用いる例としてメモしておきます。
実装は以下を引っ張ってきました。元とほぼ一緒なのですが一部以下の変更を入れています。 github.com
1. archs.pyの簡素化
# archs.py -- a small VGG-style backbone with an ArcFace classification head.
from keras.models import Model
from keras.layers import Dense, Conv2D, BatchNormalization, Activation
from keras.layers import Input, MaxPooling2D, Dropout, Flatten
from keras import regularizers

from metrics import *  # noqa: F401,F403 -- supplies the ArcFace layer used below

# L2 regularisation strength shared by every conv / dense / ArcFace kernel.
weight_decay = 1e-4


def vgg_block(x, filters, layers):
    """Stack `layers` repetitions of Conv2D(3x3, same) -> BatchNorm -> ReLU.

    Args:
        x: Input tensor.
        filters: Number of convolution filters used by every conv in the block.
        layers: How many conv/BN/ReLU triplets to apply.

    Returns:
        The output tensor of the last ReLU.
    """
    for _ in range(layers):
        x = Conv2D(filters, (3, 3), padding='same',
                   kernel_initializer='he_normal',
                   kernel_regularizer=regularizers.l2(weight_decay))(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    return x


def vgg8_arcface(num_features, s, m, num_classes=10, input_shape=(28, 28, 1)):
    """Build the two-input model: image + one-hot label -> ArcFace output.

    Args:
        num_features: Width of the embedding (Dense) layer feeding ArcFace.
        s: Passed through to ArcFace as `s`.
        m: Passed through to ArcFace as `m`.
        num_classes: Number of classes; also the width of the label input.
            Defaults to 10, the value that was previously hard-coded.
        input_shape: Shape of one input image. Defaults to (28, 28, 1),
            the value that was previously hard-coded.

    Returns:
        A Keras `Model` taking `[image, label]` and returning the ArcFace
        layer's output.
    """
    # Named `image` rather than `input` so the builtin is not shadowed.
    image = Input(shape=input_shape)
    # The labels are a model input because the ArcFace layer consumes them.
    y = Input(shape=(num_classes,))

    x = vgg_block(image, 16, 2)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = vgg_block(x, 32, 2)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = vgg_block(x, 64, 2)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    x = BatchNormalization()(x)
    x = Dropout(0.5)(x)
    x = Flatten()(x)
    x = Dense(num_features, kernel_initializer='he_normal',
              kernel_regularizer=regularizers.l2(weight_decay))(x)
    x = BatchNormalization()(x)

    output = ArcFace(num_classes, s=s, m=m,
                     regularizer=regularizers.l2(weight_decay))([x, y])

    return Model([image, y], output)
2. train.pyをtrain_arcface.pyと名前を変更し、以下のように変更する(ここでArcFaceの学習を実行します)
# train_arcface.py -- train the ArcFace model on MNIST and save the best weights.
import os

import numpy as np
import keras
from keras.datasets import mnist
from keras.optimizers import RMSprop
from keras.callbacks import ModelCheckpoint, CSVLogger, TerminateOnNaN

from archs import vgg8_arcface


def main(epochs, batch_size, num_features, s, m):
    """Train `vgg8_arcface` on MNIST, then evaluate the best checkpoint.

    Args:
        epochs: Number of training epochs.
        batch_size: Mini-batch size.
        num_features: Embedding width passed to `vgg8_arcface`; also part of
            the output directory name.
        s: ArcFace `s` argument, forwarded to `vgg8_arcface`.
        m: ArcFace `m` argument, forwarded to `vgg8_arcface`.

    Side effects:
        Writes `model.hdf5` and `log.csv` under `models/<run name>/` and
        prints the test loss and accuracy.
    """
    # train_umap.py and test.py load 'models/mnist_vgg8_arcface_30d/model.hdf5',
    # so the run name must carry the 'vgg8_' part or they cannot find the model.
    name = 'mnist_vgg8_arcface_%dd' % num_features
    model_dir = os.path.join('models', name)
    os.makedirs(model_dir, exist_ok=True)
    weights_path = os.path.join(model_dir, 'model.hdf5')

    (X, y), (X_test, y_test) = mnist.load_data()

    # Add a trailing channel axis and scale pixel values to [0, 1].
    X = X[:, :, :, np.newaxis].astype('float32') / 255
    X_test = X_test[:, :, :, np.newaxis].astype('float32') / 255

    y = keras.utils.to_categorical(y, 10)
    y_test = keras.utils.to_categorical(y_test, 10)

    lr = 1e-3
    optimizer = RMSprop(lr=lr)

    model = vgg8_arcface(num_features, s, m)
    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['accuracy'])
    model.summary()

    callbacks = [
        ModelCheckpoint(weights_path, verbose=1, save_best_only=True),
        CSVLogger(os.path.join(model_dir, 'log.csv')),
        TerminateOnNaN(),
    ]

    # The one-hot labels are fed both as the second model input and as the
    # training target: the model built by vgg8_arcface takes [image, label].
    model.fit([X, y], y,
              validation_data=([X_test, y_test], y_test),
              batch_size=batch_size,
              epochs=epochs,
              callbacks=callbacks,
              verbose=1)

    # Evaluate the checkpointed weights rather than the last-epoch weights.
    model.load_weights(weights_path)
    score = model.evaluate([X_test, y_test], y_test, verbose=1)

    print("Test loss:", score[0])
    print("Test accuracy:", score[1])


if __name__ == '__main__':
    epochs = 15
    batch_size = 128
    num_features = 30
    s = 30
    m = 0.5
    main(epochs, batch_size, num_features, s, m)
3. train_umap.pyを追加(ここでUmapの学習を実行します)
# train_umap.py -- fit UMAP on ArcFace embeddings of the MNIST test set.
import numpy as np
import pickle

from keras.datasets import mnist
from keras.models import load_model, Model
from metrics import *  # noqa: F401,F403 -- supplies ArcFace for custom_objects
from scipy.sparse.csgraph import connected_components
from umap import UMAP


def main():
    """Extract L2-normalised ArcFace features and fit a 2-D UMAP on them.

    Side effects:
        Reads 'models/mnist_vgg8_arcface_30d/model.hdf5' and writes the fitted
        UMAP object to 'umap.pkl' in the current directory.
    """
    # dataset: only the test images are needed here
    (_, _), (X_test, _) = mnist.load_data()

    # Add a trailing channel axis and scale pixel values to [0, 1].
    X_test = X_test[:, :, :, np.newaxis].astype('float32') / 255

    # feature extraction
    arcface_model = load_model('models/mnist_vgg8_arcface_30d/model.hdf5',
                               custom_objects={'ArcFace': ArcFace})
    # Keep only the image input and cut the network at layers[-3].
    # NOTE(review): layers[-3] is presumably the BatchNormalization that
    # follows the embedding Dense layer -- confirm with model.summary().
    arcface_model = Model(inputs=arcface_model.input[0],
                          outputs=arcface_model.layers[-3].output)
    arcface_features = arcface_model.predict(X_test, verbose=1)
    # Scale every feature vector to unit L2 norm.
    arcface_features /= np.linalg.norm(arcface_features, axis=1, keepdims=True)

    reducer = UMAP(n_components=2)
    reducer.fit(arcface_features)

    # UMAP is persisted with pickle (https://github.com/lmcinnes/umap/issues/178).
    # The context manager guarantees the file is flushed and closed.
    with open('umap.pkl', 'wb') as f:
        pickle.dump(reducer, f)


if __name__ == '__main__':
    main()
4. test.pyを以下のように変更(ここで学習済みのArcFaceとUmapを用いて2次元特徴空間にマッピングします)
# test.py -- project ArcFace embeddings to 2-D with the saved UMAP and plot them.
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import pickle

from keras.datasets import mnist
from keras.models import load_model, Model
from metrics import *  # noqa: F401,F403 -- supplies ArcFace for custom_objects
from scipy.sparse.csgraph import connected_components


def main():
    """Plot the MNIST test set in the 2-D space of the pre-fitted UMAP.

    Side effects:
        Reads 'models/mnist_vgg8_arcface_30d/model.hdf5' and 'umap.pkl', then
        opens a matplotlib window with a scatter plot coloured by digit label.
    """
    # dataset: only the test split is used
    (_, _), (X_test, y_test) = mnist.load_data()

    # Add a trailing channel axis and scale pixel values to [0, 1].
    X_test = X_test[:, :, :, np.newaxis].astype('float32') / 255

    # feature extraction
    arcface_model = load_model('models/mnist_vgg8_arcface_30d/model.hdf5',
                               custom_objects={'ArcFace': ArcFace})
    # Keep only the image input and cut the network at layers[-3].
    # NOTE(review): layers[-3] is presumably the BatchNormalization that
    # follows the embedding Dense layer -- confirm with model.summary().
    arcface_model = Model(inputs=arcface_model.input[0],
                          outputs=arcface_model.layers[-3].output)
    arcface_features = arcface_model.predict(X_test, verbose=1)
    # Scale every feature vector to unit L2 norm.
    arcface_features /= np.linalg.norm(arcface_features, axis=1, keepdims=True)

    # Only unpickle a file you produced yourself: pickle can execute
    # arbitrary code on load. The context manager closes the file handle.
    with open('umap.pkl', 'rb') as f:
        umap = pickle.load(f)
    embedding = umap.transform(arcface_features)

    plt.scatter(embedding[:, 0], embedding[:, 1], c=y_test, cmap=cm.tab10)
    plt.colorbar()
    plt.show()


if __name__ == '__main__':
    main()
準備が出来たらtrain_arcface.py ⇒ train_umap.py ⇒ test.pyの順番に走らせると以下のような画像が出力されます。 使う場面によってはクラス分類問題に落とし込むよりこのように視覚化した方がユーザーに理解されやすいケースがあるかもしれません。