【PyTorch】マルチラベル問題で使われているFocalLossを見つけたのでメモ
マルチラベル+不均衡データを扱うのでマルチラベル問題で利用されているFocalLossの実装を探したのですが見つけました。感謝!
import torch.nn as nn import torch.nn.functional as F class FocalLoss(nn.Module): def __init__(self, gamma=2): super(FocalLoss, self).__init__() self.gamma = gamma def forward(self, input, target): target = target.float() # BCELossWithLogits max_val = (-input).clamp(min=0) loss = input - input * target + max_val + \ ((-max_val).exp() + (-input - max_val).exp()).log() invprobs = F.logsigmoid(-input * (target * 2.0 - 1.0)) loss = (invprobs * self.gamma).exp() * loss if len(loss.size()) == 2: loss = loss.sum(dim=1) return loss.mean()