【Python】 MNIST手書き文字データを扱う
手書き数字のサンプルをダウンロードできる"THE MNIST DATABASE of handwritten digits"(MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges)について、gzファイル(IDX形式)をダウンロードした後のデータ読み込み処理をメモします。mnist_handling.pyというファイルを作成し、以下を記述します。
import gzip

import numpy as np

# Magic numbers that open the two MNIST IDX files (big-endian uint32 at
# offset 0).  Checking them catches swapped or unrelated files early.
_IMAGE_MAGIC = 2051  # 0x00000803: unsigned-byte data, 3 dimensions
_LABEL_MAGIC = 2049  # 0x00000801: unsigned-byte data, 1 dimension


def load_image_and_label(gz_image_file_path, gz_label_file_path):
    """Load one MNIST image file and its matching label file.

    Args:
        gz_image_file_path: Path to a gzip-compressed IDX3 image file
            (e.g. ``train-images-idx3-ubyte.gz``).
        gz_label_file_path: Path to a gzip-compressed IDX1 label file
            (e.g. ``train-labels-idx1-ubyte.gz``).

    Returns:
        A tuple ``(images, labels)`` where ``images`` is a ``uint8`` array
        of shape ``(image_count, rows, cols, 1)`` and ``labels`` is a
        one-dimensional ``uint8`` array.

    Raises:
        ValueError: If either file has an unexpected magic number or holds
            fewer data bytes than its header declares.
    """
    images = _load_image(gz_image_file_path)
    labels = _load_label(gz_label_file_path)
    return (images, labels)


def _read_uint32(content, offset):
    """Return the big-endian unsigned 32-bit integer stored at *offset*."""
    return int.from_bytes(content[offset:offset + 4], 'big')


def _check_magic(content, expected, gz_file_path):
    """Raise ValueError unless *content* starts with the *expected* magic."""
    magic = _read_uint32(content, 0)
    if magic != expected:
        raise ValueError(
            '{}: unexpected magic number {} (expected {})'.format(
                gz_file_path, magic, expected
            )
        )


def _load_image(gz_file_path):
    """Read a gzip-compressed IDX3 image file into a 4-D ``uint8`` array.

    The 16-byte header holds four big-endian uint32 values: magic number,
    image count, row count and column count.  Pixel bytes follow.

    Returns:
        Array of shape ``(image_count, rows, cols, 1)``; for the standard
        MNIST training set this is ``(60000, 28, 28, 1)``.  The array is a
        view over the decompressed bytes, so it is read-only.
    """
    with gzip.open(gz_file_path, 'rb') as file:
        content = file.read()

    _check_magic(content, _IMAGE_MAGIC, gz_file_path)
    # Number of images
    image_count = _read_uint32(content, 4)
    # Rows per image
    row = _read_uint32(content, 8)
    # Columns per image
    col = _read_uint32(content, 12)

    # Read exactly the number of pixels the header declares, so a truncated
    # file fails here and trailing bytes cannot corrupt the reshape.
    images = np.frombuffer(
        content, dtype=np.uint8, count=image_count * row * col, offset=16
    )
    # Trailing axis of length 1 is the (single, grayscale) channel.
    return images.reshape(image_count, row, col, 1)


def _load_label(gz_file_path):
    """Read a gzip-compressed IDX1 label file into a 1-D ``uint8`` array.

    The 8-byte header holds two big-endian uint32 values: magic number and
    label count.  One byte per label follows.

    Returns:
        Array of shape ``(label_count,)``.  The array is a view over the
        decompressed bytes, so it is read-only.
    """
    with gzip.open(gz_file_path, 'rb') as file:
        content = file.read()

    _check_magic(content, _LABEL_MAGIC, gz_file_path)
    label_count = _read_uint32(content, 4)
    # Honor the declared count instead of reading to the end of the buffer.
    labels = np.frombuffer(
        content, dtype=np.uint8, count=label_count, offset=8
    )
    return labels
以下はmnist_handlingモジュールの利用例です。
import mnist_handling

# Training set: gzip-compressed IDX files as downloaded from the MNIST page.
training_image_file = 'train-images-idx3-ubyte.gz'
training_label_file = 'train-labels-idx1-ubyte.gz'
training_images, training_labels = mnist_handling.load_image_and_label(
    training_image_file,
    training_label_file,
)

# Test set, loaded through the same entry point.
test_image_file = 't10k-images-idx3-ubyte.gz'
test_label_file = 't10k-labels-idx1-ubyte.gz'
test_images, test_labels = mnist_handling.load_image_and_label(
    test_image_file,
    test_label_file,
)