旅行好きなソフトエンジニアの備忘録

プログラミングや技術関連のメモを始めました

【Python】データ拡張手法Mixupの擬似コード

PyTorchのカスタムデータセットにmixupをどう入れ込むかの擬似コードメモです。

# これをDatasetの__get_item__に入れ込めば良い
def _apply_mixup(self, image1, label1, idx1, image_size):
    # mixする画像のインデックスを拾ってくる
    idx2 = self._get_pair_index(idx1)
    # 画像の準備
    image2 = cv2.imread(self._image_paths[idx2]).astype(np.float32)
    image2 = cv2.resize(image2, (image_size, image_size))
    image2 = normalize(image2)
    # ラベルの準備(アノテーションファイルは1,0,0,0のように所属するクラスが記されている)
    label2 = np.loadtxt(self._annotation_paths[idx2], dtype=np.float32, delimiter=',')
    # 混ぜる割合を決めて
    r = np.random.beta(self._alpha, self._alpha, 1)[0]
    # 画像、ラベルを混ぜる(クリップしないと範囲外になることがある)
    mixed_image = np.clip(r*image1 + (1 - r)*image2, 0, 1)
    mixed_label = np.clip(r*label1 + (1 - r)*label2, 0, 1)
    return mixed_image, mixed_label

# Datasetの__get_item__のidx以外のindexを取得する
def _get_pair_index(self, idx):    
    r = list(range(0, idx)) + list(range(idx+1, len(self._image_paths)))
    return random.choice(r)

mixupについての説明は以下。 qiita.com