【Python】 PyTorchで自前のロス関数を定義する
Kerasと違ってPyTorchで自前のロス関数を定義するのは大変かなと思ったのですが、Kerasとほぼ同じやり方で出来ました。
#1. ロス関数を定義して def dice_coef_loss(input, target): small_value = 1e-4 input_flattened = input.view(-1) target_flattened = target.view(-1) intersection = torch.sum(input_flattened * target_flattened) dice_coef = (2.0*intersection + small_value)/(torch.sum(input_flattened) + torch.sum(target_flattened) + small_value) return 1.0 - dice_coef #2. backwardするだけ(outputs, labelsは共にVariable) outputs = model(X) loss = dice_coef_loss(outputs, labels) loss.backward()