【Python】CIFAR-10画像データを扱う
AlexNet等の実験でCIFAR-10(CIFAR-10 and CIFAR-100 datasets)を利用することがあります。幸いKerasにはデフォルトでCIFAR-10の画像データを取り込む関数があるのですが、諸事情によりこれが利用できなかったのでCIFAR-10を扱うための関数を作成しました。
def load_image_and_label(pickled_files): import numpy as np # Each file contains 10000 images IMAGE_COUNT_PER_FILE = 10000 # Image shape is 32x32x3 ROW = 32 COL = 32 DIM = 3 whole_images = np.empty((IMAGE_COUNT_PER_FILE*len(pickled_files), ROW, COL, DIM)) whole_labels = np.empty(IMAGE_COUNT_PER_FILE*len(pickled_files)) for i, pickled_file in enumerate(pickled_files): dict = _unpickle(pickled_file) images = dict['data'].reshape(IMAGE_COUNT_PER_FILE, DIM, ROW, COL).transpose(0, 2, 3, 1) whole_images[i*IMAGE_COUNT_PER_FILE:(i + 1)*IMAGE_COUNT_PER_FILE, :, :, :] = images labels = dict['labels'] whole_labels[i*IMAGE_COUNT_PER_FILE:(i + 1)*IMAGE_COUNT_PER_FILE] = labels return (whole_images, whole_labels) def _unpickle(pickled_file): import pickle with open(pickled_file, 'rb') as file: # You'll have an error without "encoding='latin1'" dict = pickle.load(file, encoding='latin1') return dict
上記をcifar10_handling.pyとして保存します。次にこの関数を利用する方法ですが、CIFAR-10のサイトからダウンロードできる"cifar-10-batches-py.tar.gz"を解凍しな中に入っている"data_batch_1"等のファイルになります。
import cifar10_handling (X_train, y_train) = cifar10_handling.load_image_and_label(['data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5']) (X_test, y_test) = cifar10_handling.load_image_and_label(['test_batch'])
あとは正規化等の前処理を加えれば実験に使えるデータとなります。