旅行好きなソフトエンジニアの備忘録

プログラミングや技術関連のメモを始めました

【Python】CIFAR-10画像データを扱う

AlexNet等の実験でCIFAR-10(CIFAR-10 and CIFAR-100 datasets)を利用することがあります。幸いKerasにはデフォルトでCIFAR-10の画像データを取り込む関数があるのですが、諸事情によりこれが利用できなかったのでCIFAR-10を扱うための関数を作成しました。

def load_image_and_label(pickled_files):
    import numpy as np

    # Each file contains 10000 images
    IMAGE_COUNT_PER_FILE = 10000
    # Image shape is 32x32x3
    ROW = 32
    COL = 32
    DIM = 3
    whole_images = np.empty((IMAGE_COUNT_PER_FILE*len(pickled_files), ROW, COL, DIM))
    whole_labels = np.empty(IMAGE_COUNT_PER_FILE*len(pickled_files))
    for i, pickled_file in enumerate(pickled_files):
        dict = _unpickle(pickled_file)
        images = dict['data'].reshape(IMAGE_COUNT_PER_FILE, DIM, ROW, COL).transpose(0, 2, 3, 1)
        whole_images[i*IMAGE_COUNT_PER_FILE:(i + 1)*IMAGE_COUNT_PER_FILE, :, :, :] = images
        labels = dict['labels']
        whole_labels[i*IMAGE_COUNT_PER_FILE:(i + 1)*IMAGE_COUNT_PER_FILE] = labels
    return (whole_images, whole_labels)

def _unpickle(pickled_file):
    import pickle

    with open(pickled_file, 'rb') as file:
        # You'll have an error without "encoding='latin1'"
        dict = pickle.load(file, encoding='latin1')
    return dict

上記をcifar10_handling.pyとして保存します。次にこの関数を利用する方法ですが、CIFAR-10のサイトからダウンロードできる"cifar-10-batches-py.tar.gz"を解凍しな中に入っている"data_batch_1"等のファイルになります。

import cifar10_handling

(X_train, y_train) = cifar10_handling.load_image_and_label(['data_batch_1',
                                                            'data_batch_2',
                                                            'data_batch_3',
                                                            'data_batch_4',
                                                            'data_batch_5'])
(X_test, y_test) = cifar10_handling.load_image_and_label(['test_batch'])

あとは正規化等の前処理を加えれば実験に使えるデータとなります。