【Python】データ拡張手法Mixupの擬似コード
PyTorchのカスタムデータセットにmixupをどう入れ込むかの擬似コードメモです。
# これをDatasetの__get_item__に入れ込めば良い def _apply_mixup(self, image1, label1, idx1, image_size): # mixする画像のインデックスを拾ってくる idx2 = self._get_pair_index(idx1) # 画像の準備 image2 = cv2.imread(self._image_paths[idx2]).astype(np.float32) image2 = cv2.resize(image2, (image_size, image_size)) image2 = normalize(image2) # ラベルの準備(アノテーションファイルは1,0,0,0のように所属するクラスが記されている) label2 = np.loadtxt(self._annotation_paths[idx2], dtype=np.float32, delimiter=',') # 混ぜる割合を決めて r = np.random.beta(self._alpha, self._alpha, 1)[0] # 画像、ラベルを混ぜる(クリップしないと範囲外になることがある) mixed_image = np.clip(r*image1 + (1 - r)*image2, 0, 1) mixed_label = np.clip(r*label1 + (1 - r)*label2, 0, 1) return mixed_image, mixed_label # Datasetの__get_item__のidx以外のindexを取得する def _get_pair_index(self, idx): r = list(range(0, idx)) + list(range(idx+1, len(self._image_paths))) return random.choice(r)
mixupについての説明は以下。 qiita.com