旅行好きなソフトエンジニアの備忘録

プログラミングや技術関連のメモを始めました

【Python】 Kerasで中間層を可視化する

Kerasを利用してネットワーク中間層を可視化する方法をメモします。プログラムでは学習済みのモデルと重みがある事を想定し、それらを読み取って一層目の中間層であるConvolution2Dの重みを可視化しています。

import numpy as np
from keras.models import model_from_json
from matplotlib import pyplot as plt

with open('lenet.json', 'r') as file:
    model_json = file.read()
model = model_from_json(model_json)
model.load_weights('lenet_weights.hdf5')

# Get 1st layer Convolution2D weights.
# In this example, weights.shape is (6, 1, 5, 5)
weights = model.layers[0].get_weights()[0].transpose(3,2,0,1)
fig = plt.figure()
for i, weight_3d in enumerate(weights):
    for j, weight_2d in enumerate(weight_3d):
        sub = fig.add_subplot(weights.shape[0], weight_3d.shape[0], i*weight_3d.shape[0]+j+1)
        sub.axis('off')
        sub.imshow(weight_2d, 'Greys')
plt.show()

f:id:ni4muraano:20170216202753j:plain