【Python】 Kerasで中間層を可視化する
Kerasを利用してネットワーク中間層を可視化する方法をメモします。プログラムでは学習済みのモデルと重みがある事を想定し、それらを読み取って一層目の中間層であるConvolution2Dの重みを可視化しています。
import numpy as np from keras.models import model_from_json from matplotlib import pyplot as plt with open('lenet.json', 'r') as file: model_json = file.read() model = model_from_json(model_json) model.load_weights('lenet_weights.hdf5') # Get 1st layer Convolution2D weights. # In this example, weights.shape is (6, 1, 5, 5) weights = model.layers[0].get_weights()[0].transpose(3,2,0,1) fig = plt.figure() for i, weight_3d in enumerate(weights): for j, weight_2d in enumerate(weight_3d): sub = fig.add_subplot(weights.shape[0], weight_3d.shape[0], i*weight_3d.shape[0]+j+1) sub.axis('off') sub.imshow(weight_2d, 'Greys') plt.show()