旅行好きなソフトエンジニアの備忘録

プログラミングや技術関連のメモを始めました

【Python】不均衡な2クラスセグメンテーション問題に適用するロス関数のメモ

この論文で不均衡な2クラスセグメンテーション問題に適用するロス関数が提案されていたのでメモします。ディープラーニングを使ったセグメンテーションでデータが極端に不均衡(例えば画像のほとんどが0で、1はちょっとだけ)の場合、工夫をしないと学習が上手くいかないのですが、論文ではロス関数の工夫によりこの問題を回避しようとしています。

下記の記事ではセグメンテーションのロス関数に以下のダイス係数を利用しました。

def dice_coef(y_true, y_pred):
    y_true = K.flatten(y_true)
    y_pred = K.flatten(y_pred)
    intersection = K.sum(y_true * y_pred)
    return 2.0 * intersection / (K.sum(y_true) + K.sum(y_pred) + 1)

def dice_coef_loss(y_true, y_pred):
    return 1.0 - dice_coef(y_true, y_pred)

ni4muraano.hatenablog.com

論文ではTversky loss functionという関数を提案しており、以下のようになります。ただこれどこかで見たと思ったらIOUの修正バージョンですね。

ALPHA = 0.3 # 0~1.0の値、Precision重視ならALPHAを大きくする
BETA = 1.0 - ALPHA # 0~1.0の値、Recall重視ならALPHAを小さくする

def tversky_index(y_true, y_pred):
    y_true = K.flatten(y_true)
    y_pred = K.flatten(y_pred)
    intersection = K.sum(y_true * y_pred)
    false_positive = K.sum((1.0 - y_true) * y_pred)
    false_negative = K.sum(y_true * (1.0 - y_pred))
    return intersection / (intersection + ALPHA*false_positive + BETA*false_negative)

def tversky_loss(y_true, y_pred):
    return 1.0 - tversky_index(y_true, y_pred)