旅行好きなソフトエンジニアの備忘録

プログラミングや技術関連のメモを始めました

【Python】 KerasでResNet等のショートカット構造を実装する

Kerasでは学習済みのResNetが利用できるため、ResNetを自分で作ることは無いと思います。ただ、ResNet以外にも下の写真のようなショートカット構造を持つネットワークがあり、これらを実装したい時にどのように作成するかをメモします。
f:id:ni4muraano:20170325221400p:plain

単純なネットワークの場合、KerasではSequentialを生成して、レイヤーをaddしていくのが通常ですが(Sequentialモデルのガイド - Keras Documentation)、少し複雑なネットワークを作成する場合はFunctional APIを利用します(Functional APIのガイド - Keras Documentation)。
Functional APIはSequentialを利用するプログラミングと比較した場合は多少直観性に劣りますが、別段難しいものではありません。例えば下図のようにInput⇒Convolutionというシーケンスは以下のように書くことができます。

from keras.layers import Input, Convolution2D

# 入力の形状を指定する(CIFAR-10のような32×32×3の画像を想定)
in_ = Input((32, 32, 3))
# Input⇒Convolution2D
in_conv = Convolution2D(10, 3, 3, border_mode='same')(in_)

f:id:ni4muraano:20170325223257p:plain

Input⇒Convolutionに更にConvolutionを接続するのも一行追加するだけです。

from keras.layers import Input, Convolution2D

# 入力の形状を指定する(CIFAR-10のような32×32×3の画像を想定)
in_ = Input((32, 32, 3))
# Input⇒Convolution2D
in_conv = Convolution2D(10, 3, 3, border_mode='same')(in_)
# Input⇒Convolution2D⇒Convolution2D
in_conv_conv = Convolution2D(10, 3, 3, border_mode='same')(in_conv)

f:id:ni4muraano:20170325224113p:plain

Input⇒Convolution2Dをショートカットさせます。これはInput⇒Convolution2DとInput⇒Convolution2D⇒Convolution2Dをマージさせることで実現できます。keras.layersにはMergeとmergeがあるのですが、mergeを利用します。間違えてMergeをimportしないよう注意して下さい。

from keras.layers import Input, Convolution2D, merge

# 入力の形状を知らせる(ここではCIFAR-10のような32×32×3としている)
in_ = Input((32, 32, 3))
# Input⇒Convolution2D
in_conv = Convolution2D(10, 3, 3, border_mode='same')(in_)
# Input⇒Convolution2D⇒Convolution2D
in_conv_conv = Convolution2D(10, 3, 3, border_mode='same')(in_conv)
# Input⇒Convolution2DとInput⇒Convolution2D⇒Convolution2Dをマージする
merged = merge([in_conv_conv, in_conv], mode='sum')

f:id:ni4muraano:20170325224731p:plain

以上でショートカット構造を実現できました。最後に、最初に示した図の構造を実現するためのコードを下記にまとめます。

from keras.layers import Input, Convolution2D, merge
from keras.models import Model
from keras.utils.visualize_util import plot

# 入力の形状を知らせる(ここではCIFAR-10のような32×32×3としている)
in_ = Input((32, 32, 3))
# Input⇒Convolution2D
in_conv = Convolution2D(10, 3, 3, border_mode='same')(in_)
# Input⇒Convolution2D⇒Convolution2D
in_conv_conv = Convolution2D(10, 3, 3, border_mode='same')(in_conv)
# Input⇒Convolution2DとInput⇒Convolution2D⇒Convolution2Dをマージする
merged = merge([in_conv_conv, in_conv], mode='sum')
# 更にConvolution2Dを接続する
merged_conv = Convolution2D(10, 3, 3, border_mode='same')(merged)
# input, outputを指定してモデルを作成する
model = Model(input=in_, output=merged_conv)
# モデルを図示する
plot(model, 'shortcut_structure_example.png')