旅行好きなソフトエンジニアの備忘録

プログラミングや技術関連のメモを始めました

【Python】 PyTorchで自前のロス関数を定義する

Kerasと違ってPyTorchで自前のロス関数を定義するのは大変かなと思ったのですが、Kerasとほぼ同じやり方で出来ました。

#1. ロス関数を定義して
def dice_coef_loss(input, target):
    small_value = 1e-4

    input_flattened = input.view(-1)
    target_flattened = target.view(-1)
    intersection = torch.sum(input_flattened * target_flattened)
    dice_coef = (2.0*intersection + small_value)/(torch.sum(input_flattened) + torch.sum(target_flattened) + small_value)
    return 1.0 - dice_coef

#2. backwardするだけ(outputs, labelsは共にVariable)
outputs = model(X)
loss = dice_coef_loss(outputs, labels)
loss.backward()