【Python】 KerasでResNet等のショートカット構造を実装する
Kerasでは学習済みのResNetが利用できるため、ResNetを自分で作ることは無いと思います。ただ、ResNet以外にも下の写真のようなショートカット構造を持つネットワークがあり、これらを実装したい時にどのように作成するかをメモします。 単純なネットワークの場合、KerasではSequentialを生成して、レイヤーをaddしていくのが通常ですが(Sequentialモデルのガイド - Keras Documentation)、少し複雑なネットワークを作成する場合はFunctional APIを利用します(Functional APIのガイド - Keras Documentation)。 Functional APIはSequentialを利用するプログラミングと比較した場合は多少直観性に劣りますが、別段難しいものではありません。例えば下図のようにInput⇒Convolutionというシーケンスは以下のように書くことができます。
from keras.layers import Input, Convolution2D # 入力の形状を指定する(CIFAR-10のような32×32×3の画像を想定) in_ = Input((32, 32, 3)) # Input⇒Convolution2D in_conv = Convolution2D(10, 3, 3, border_mode='same')(in_)
Input⇒Convolutionに更にConvolutionを接続するのも一行追加するだけです。
from keras.layers import Input, Convolution2D # 入力の形状を指定する(CIFAR-10のような32×32×3の画像を想定) in_ = Input((32, 32, 3)) # Input⇒Convolution2D in_conv = Convolution2D(10, 3, 3, border_mode='same')(in_) # Input⇒Convolution2D⇒Convolution2D in_conv_conv = Convolution2D(10, 3, 3, border_mode='same')(in_conv)
Input⇒Convolution2Dをショートカットさせます。これはInput⇒Convolution2DとInput⇒Convolution2D⇒Convolution2Dをマージさせることで実現できます。keras.layersにはMergeとmergeがあるのですが、mergeを利用します。間違えてMergeをimportしないよう注意して下さい。
from keras.layers import Input, Convolution2D, merge # 入力の形状を知らせる(ここではCIFAR-10のような32×32×3としている) in_ = Input((32, 32, 3)) # Input⇒Convolution2D in_conv = Convolution2D(10, 3, 3, border_mode='same')(in_) # Input⇒Convolution2D⇒Convolution2D in_conv_conv = Convolution2D(10, 3, 3, border_mode='same')(in_conv) # Input⇒Convolution2DとInput⇒Convolution2D⇒Convolution2Dをマージする merged = merge([in_conv_conv, in_conv], mode='sum')
以上でショートカット構造を実現できました。最後に、最初に示した図の構造を実現するためのコードを下記にまとめます。
from keras.layers import Input, Convolution2D, merge from keras.models import Model from keras.utils.visualize_util import plot # 入力の形状を知らせる(ここではCIFAR-10のような32×32×3としている) in_ = Input((32, 32, 3)) # Input⇒Convolution2D in_conv = Convolution2D(10, 3, 3, border_mode='same')(in_) # Input⇒Convolution2D⇒Convolution2D in_conv_conv = Convolution2D(10, 3, 3, border_mode='same')(in_conv) # Input⇒Convolution2DとInput⇒Convolution2D⇒Convolution2Dをマージする merged = merge([in_conv_conv, in_conv], mode='sum') # 更にConvolution2Dを接続する merged_conv = Convolution2D(10, 3, 3, border_mode='same')(merged) # input, outputを指定してモデルを作成する model = Model(input=in_, output=merged_conv) # モデルを図示する plot(model, 'shortcut_structure_example.png')