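Preface
To deepen the understanding of Focal Loss, this article provides a simple hand-written demo.
1. Focal Loss
Plenty of articles already introduce Focal Loss, so only a brief recap is given here: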
1.1. Formula definition
The Focal Loss formula is:
$$\text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$$
Here $p_t = p$ and $\alpha_t = \alpha$ when $y = 1$, while $p_t = 1 - p$ and $\alpha_t = 1 - \alpha$ when $y = 0$. Depending on the ground-truth label $y$, Focal Loss therefore splits into two cases:
1) When the ground-truth label is $y = 1$, the formula becomes:
$$\text{FL}(p) = -\alpha (1 - p)^{\gamma} \log(p)$$
2) When the ground-truth label is $y = 0$, the formula becomes:
$$\text{FL}(p) = -(1 - \alpha) p^{\gamma} \log(1 - p)$$
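To make the two cases concrete, here is a minimal sketch in plain Python that evaluates each branch for a single prediction. The function names `focal_loss_pos` and `focal_loss_neg` are illustrative and not part of the original demo, and the defaults `alpha=0.25`, `gamma=2.0` are assumed only because they match the defaults used in this article's code.

```python
import math


def focal_loss_pos(p, alpha=0.25, gamma=2.0):
    """Case y = 1: -alpha * (1 - p)^gamma * log(p)."""
    return -alpha * (1.0 - p) ** gamma * math.log(p)


def focal_loss_neg(p, alpha=0.25, gamma=2.0):
    """Case y = 0: -(1 - alpha) * p^gamma * log(1 - p)."""
    return -(1.0 - alpha) * p ** gamma * math.log(1.0 - p)


# An easy positive (p close to 1) is down-weighted far more than a hard one,
# because the modulating factor (1 - p)^gamma shrinks as p approaches 1.
easy = focal_loss_pos(0.9)
hard = focal_loss_pos(0.1)
print('easy positive:', easy)
print('hard positive:', hard)
```

With `gamma=0` the modulating factor disappears and each branch reduces to alpha-weighted cross entropy, which is a quick way to sanity-check an implementation.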
The complete Focal Loss formula can be written as:
$$\text{FL}(y, p) = -\left[ y \cdot \alpha (1 - p)^{\gamma} \log(p) + (1 - y) \cdot (1 - \alpha) p^{\gamma} \log(1 - p) \right]$$
where $p$ is the predicted value after sigmoid. This article implements the complete formula directly, without introducing any extra wrapper functions.
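The complete formula can also be evaluated as a single tensor expression, since the factors $y$ and $(1 - y)$ switch each term on or off per element. The following is a minimal sketch under that reading, not the article's own implementation; the function name `focal_loss_full` is hypothetical, and it assumes every probability lies strictly inside (0, 1), because `log(0)` multiplied by zero would produce `nan`.

```python
import torch


def focal_loss_full(p, y, alpha=0.25, gamma=2.0):
    """Element-wise FL(y, p) for probabilities p in (0, 1) and labels y in {0, 1}."""
    pos_term = y * alpha * (1 - p) ** gamma * torch.log(p)
    neg_term = (1 - y) * (1 - alpha) * p ** gamma * torch.log(1 - p)
    return -(pos_term + neg_term)


# Example logits and labels, chosen here purely for illustration.
p = torch.sigmoid(torch.tensor([0.1, 0.9, 0.2, 0.8, 0.7]))
y = torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0])
per_element = focal_loss_full(p, y)
print('per-element loss:', per_element)
print('sum:', per_element.sum())
```

Because no indexing is involved, the per-element losses stay in the same order as the inputs.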
2. Code
import torch
import torch.nn as nn


# focal_loss = pos_loss + neg_loss (the alpha weighting is applied afterwards)
# if y == 1: pos_loss = -|1-p|^gamma * log(p)
# if y == 0: neg_loss = -|0-p|^gamma * log(1-p)
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduce='sum'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduce = reduce

    def forward(self, classifications, targets):
        alpha = self.alpha
        gamma = self.gamma
        # Flatten the raw logits and convert them to probabilities
        classifications = classifications.view(-1)
        p = torch.sigmoid(classifications)
        targets = targets.view(-1)
        # Get the indices of the positive and negative samples
        pos_idx = torch.nonzero(targets == 1).view(-1)
        neg_idx = torch.nonzero(targets == 0).view(-1)
        # step 1: compute the positive loss
        pos_loss = -(1 - p[pos_idx]).abs() ** gamma * torch.log(p[pos_idx])
        # step 2: compute the negative loss
        neg_loss = -(0 - p[neg_idx]).abs() ** gamma * torch.log(1 - p[neg_idx])
        loss = torch.cat((pos_loss, neg_loss), dim=0)
        # targets must be reordered the same way so they line up with loss
        concat_idx = torch.cat((pos_idx, neg_idx), dim=0)
        targets = targets[concat_idx]
        if alpha >= 0:
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            loss = alpha_t * loss
        if self.reduce == 'sum':
            loss = loss.sum()
        elif self.reduce == 'mean':
            loss = loss.mean()
        else:
            raise ValueError('reduce type is wrong!')
        return loss


# --- test unit --- #
def main():
    # single-class focal loss
    focal_loss = FocalLoss()
    # raw logits (sigmoid is applied inside forward); nb_anchors: 5
    pred = torch.FloatTensor([0.1, 0.9, 0.2, 0.8, 0.7])
    tgt = torch.FloatTensor([0, 1, 0, 1, 1])  # neg: 0, pos: 1; no ignore label
    loss = focal_loss(pred, tgt)
    print('loss:', loss)


if __name__ == '__main__':
    main()
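One practical caveat of evaluating `torch.log` on sigmoid outputs is that very large logits can saturate the probability to exactly 0 or 1 in float32, which yields infinite or `nan` losses. As a hedged alternative, the sketch below computes the cross-entropy part from raw logits with `F.binary_cross_entropy_with_logits` and then applies the focal and alpha weights. The function name `focal_loss_from_logits` is an assumption made for illustration, and this swaps in a different technique from the hand-written version in this article; the last lines compare it against a direct evaluation of the complete formula on moderate logits.

```python
import torch
import torch.nn.functional as F


def focal_loss_from_logits(logits, targets, alpha=0.25, gamma=2.0):
    """Sum-reduced binary focal loss computed from raw logits (illustrative sketch)."""
    logits = logits.view(-1)
    targets = targets.view(-1).float()
    p = torch.sigmoid(logits)
    # Per-element cross entropy -log(p_t), computed stably from the logits.
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    # p_t and alpha_t select the branch that matches each label.
    p_t = p * targets + (1 - p) * (1 - targets)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * ce).sum()


logits = torch.tensor([0.1, 0.9, 0.2, 0.8, 0.7])
targets = torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0])
stable = focal_loss_from_logits(logits, targets)

# Direct evaluation of the complete formula, for comparison.
p = torch.sigmoid(logits)
direct = -(targets * 0.25 * (1 - p) ** 2 * torch.log(p)
           + (1 - targets) * 0.75 * p ** 2 * torch.log(1 - p)).sum()
print('stable:', stable.item())
print('direct:', direct.item())
```

On moderate logits the two values should agree up to floating-point error; the difference only matters when logits become extreme.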
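Summary
This article only implements a simple binary-classification Focal Loss, with the aim of deepening the reader's understanding of it. Corrections and criticism are welcome.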