旅行好きなソフトエンジニアの備忘録

プログラミングや技術関連のメモを始めました

【Python】画像データ拡張ライブラリAlbumentationsを使ってみる

PyTorch版のYOLO v3を作っている人がいたので試してみようと思っています。 github.com

ただ、Trainにデータ拡張が入っていないのでデータ拡張ロジックを追加したいと思ったところ、 Albumentationsというライブラリを見つけました。 github.com

物体検出やセグメンテーションにも利用可能そうなので早速試してみました。 使い方は以下を実施すれば良さげです。

  1. Composeを作って、Composeの中に実施したいデータ拡張を記述
  2. Composeに画像、ラベル、クラスIDを含むディクショナリを投入

以下はComposeを作るコードになります。

from albumentations import Compose
from albumentations.augmentations.transforms import Resize, HorizontalFlip, RandomSizedCrop, HueSaturationValue

def get_compose(crop_min_max, image_height, image_width, hue_shift, saturation_shift, value_shift):
    # Resize image to (image_height, image_width) with 100% probability
    # Flip LR with 50% probability
    # Crop image and resize image to (image_height, image_width) with 100% probability
    # Change HSV from -hue_shift to +hue_shift and so on with 100% probability
    # Format 'pascal_voc' means label is given like [x_min, y_min, x_max, y_max]
    return Compose([Resize(image_height, image_width, p=1.0),
                     HorizontalFlip(p=0.5),
                     RandomSizedCrop(crop_min_max, image_height, image_width, p=1.0),
                     HueSaturationValue(hue_shift, saturation_shift, value_shift, p=1.0)],
                    bbox_params={'format':'pascal_voc', 'label_fields':['category_id']})

このComposeは以下のように使います。

# Image size for YOLO
image_size = 416
# Crop 80 - 100% of image
crop_min = image_size*80//100
crop_max = image_size
crop_min_max = (crop_min, crop_max)
# HSV shift limits
hue_shift = 10
saturation_shift = 10
value_shift = 10
# Get compose
compose = get_compose(crop_min_max, image_size, image_size, hue_shift, saturation_shift, value_shift)
# image: numpy array like return value of cv2.imread
# labels: bounding box lists like [[366.7, 80.84, 132.8, 181.84], [5.66, 138.95, 147.09, 164.88]]
# classes: class of each bounding box like [0, 1]
annotation = {'image': image, 'bboxes': labels, 'category_id': classes}
# Do augmentation
augmented = compose(**annotation)
augmented_image = augmented['image']
augmented_labels = augmented['bboxes']

f:id:ni4muraano:20181120235803j:plain:w300 f:id:ni4muraano:20181121000404j:plain:w200

他にも様々な拡張が用意されているっぽい。API Referenceはここ