旅行好きなソフトエンジニアの備忘録

プログラミングや技術関連のメモを始めました

【Python】 MNIST手書き文字データを扱う

手書き文字のサンプルがダウンロードできる"THE MNIST DATABASE of handwritten digits"(MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges)ですが、gzファイルをダウンロードした後のデータ処理ロジックをメモします。mnist_handling.pyというファイルを作成し、以下を記述します。

import gzip
import numpy as np

def load_image_and_label(gz_image_file_path, gz_label_file_path):
    images = _load_image(gz_image_file_path)
    labels = _load_label(gz_label_file_path)
    return (images, labels)    

def _load_image(gz_file_path):
    with gzip.open(gz_file_path, 'rb') as file:
        content = file.read()
        # 手書き画像枚数
        image_count = int.from_bytes(content[4:8], 'big')
        # 手書き画像行数
        row = int.from_bytes(content[8:12], 'big')
        # 手書き画像列数
        col = int.from_bytes(content[12:16], 'big')
        # 手書き画像データの読み込み
        images = np.frombuffer(content, np.uint8, -1, 16)
    # imagesのshapeは(60000, 28, 28, 1)
    images = images.reshape(image_count, row, col, 1)
    return images

def _load_label(gz_file_path):
    with gzip.open(gz_file_path, 'rb') as file:
        content = file.read()
        label_count = int.from_bytes(content[4:8], 'big')
        labels = np.frombuffer(content, np.uint8, -1, 8)
    return labels

以下はmnist_handlingモジュールの利用例です。

import mnist_handling

training_image_file = 'train-images-idx3-ubyte.gz'
training_label_file = 'train-labels-idx1-ubyte.gz'
training_images, training_labels = mnist_handling.load_image_and_label(training_image_file, training_label_file)
    
test_image_file = 't10k-images-idx3-ubyte.gz'
test_label_file = 't10k-labels-idx1-ubyte.gz'
test_images, test_labels = mnist_handling.load_image_and_label(test_image_file, test_label_file)