旅行好きなソフトエンジニアの備忘録

プログラミングや技術関連のメモを始めました

【Keras】fit_generatorに使うgeneratorの雛形メモ

クラス分類用のfit_generatorに使うgeneratorの雛形をメモします。
画像が格納されているフォルダが以下のような構造であることを前提とします。

  • トップフォルダ
    • class1フォルダ
      • class1に属する画像ファイル
    • class2フォルダ
      • class2に属する画像ファイル
    • 以下同様


また、opencvやデータ拡張ライブラリのalbumentationsに依存しています。
ni4muraano.hatenablog.com

data_loader.py

import os, random, cv2
import numpy as np
from albumentations import Compose
from albumentations.augmentations.transforms import HorizontalFlip, Normalize

class DataLoader(object):
    def __init__(self, data_folder, batch_size, input_shape, do_augmentation, gray_scale=False):
        self._file_paths = []
        self._annotations = []
        folders = os.listdir(data_folder)
        folders.sort()
        # 画像のパスとクラスIDを取得する
        for class_id, class_name in enumerate(folders):
            folder_path = data_folder + os.sep + class_name
            file_paths = [folder_path + os.sep + fn for fn in os.listdir(folder_path + os.sep)]
            self._file_paths += file_paths
            self._annotations += [class_id]*len(file_paths)
        # クラス数
        self._class_count = class_id + 1
        self._batch_size = batch_size
        self._input_shape = input_shape
        self._gray_scale = gray_scale
        if len(self._file_paths)%self._batch_size == 0:
            self.iterations = len(self._file_paths) // self._batch_size
            self._has_extra_data = False
        else:
            self.iterations = len(self._file_paths) // self._batch_size + 1
            self._has_extra_data = True
        self._compose = self._define_augment(input_shape, do_augmentation)

    def _define_augment(self, input_shape, do_augmentation):
        # mean, std, max_pixel_valueは適宜変更してください
        mean = (0.485*255, 0.456*255, 0.406*255)
        std = (0.229*255, 0.224*255, 0.225*255)
        normalize = Normalize(mean=mean, std=std, max_pixel_value=1)
        # データ拡張内容は適宜変更してください
        if do_augmentation:
            return Compose([normalize, HorizontalFlip(p=0.5)])
        else:
            return Compose([normalize])

    def get_data_loader(self):
        while True:
            file_paths, annotations = self._shuffle(self._file_paths, self._annotations)
            for iteration in range(self.iterations):
                if iteration == self.iterations - 1 and self._has_extra_data:
                    shape = (len(file_paths)%self._batch_size, self._input_shape[0],
                             self._input_shape[1], self._input_shape[2])
                else:
                    shape = (self._batch_size, self._input_shape[0], self._input_shape[1], self._input_shape[2])
                # バッチサイズ分のデータを取得する
                X = np.zeros(shape, dtype=np.float32)
                y = np.zeros((shape[0], self._class_count), dtype=np.float32)
                for i in range(X.shape[0]):
                    index = self._batch_size*iteration + i
                    if self._gray_scale:
                        image = cv2.imread(file_paths[index], cv2.IMREAD_GRAYSCALE)
                        image = image[:,:,np.newaxis]
                    else:
                        image = cv2.imread(file_paths[index])
                    image = cv2.resize(image, (self._input_shape[1], self._input_shape[0]))
                    image = image.astype(np.float32)
                    # データ拡張を実行する
                    X[i] = self._augment(image)
                    y[i, annotations[index]] = 1
                yield X, y

    def _shuffle(self, x, y):
        p = list(zip(x, y))
        random.shuffle(p)
        return zip(*p)

    def _augment(self, image):
        dict = {'image': image}
        augmented = self._compose(**dict)
        return augmented['image']


このdata_loader.pyは以下のように使います。

from data_loader import DataLoader

train_data_loader = DataLoader('train', batch_size, input_shape, do_augmentation=True)
train_generator = train_data_loader.get_data_loader()
h = model.fit_generator(train_generator, train_data_loader.iterations, epochs)