旅行好きなソフトエンジニアの備忘録

プログラミングや技術関連のメモを始めました

【PyTorch】マルチラベル問題で使われているFocalLossを見つけたのでメモ

マルチラベル+不均衡データを扱うのでマルチラベル問題で利用されているFocalLossの実装を探したのですが見つけました。感謝!

import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    def __init__(self, gamma=2):
        super(FocalLoss, self).__init__()
        self.gamma = gamma

    def forward(self, input, target):
        target = target.float()

        # BCELossWithLogits
        max_val = (-input).clamp(min=0)
        loss = input - input * target + max_val + \
               ((-max_val).exp() + (-input - max_val).exp()).log()

        invprobs = F.logsigmoid(-input * (target * 2.0 - 1.0))
        loss = (invprobs * self.gamma).exp() * loss
        if len(loss.size()) == 2:
            loss = loss.sum(dim=1)
        return loss.mean()

becominghuman.ai