手撕FocalLoss

发布于:2025-02-23 ⋅ 阅读:(19) ⋅ 点赞:(0)


前言

 为了加深对Focal Loss理解,本文提供了一个简单的手写Demo。

1、FocalLoss

 介绍FocalLoss的文章已经很多了,这里简单提一下:

1.1.公式定义

 Focal Loss 的公式如下:

FL ( p t ) = − α t ( 1 − p t ) γ log ⁡ ( p t ) \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t) FL(pt)=αt(1pt)γlog(pt)

 ;根据真实标签 y y y 的不同,Focal Loss 可以分为两种情况:

 1) 当真实标签 y = 1 y = 1 y=1 时,公式变为:

FL ( p ) = − α ( 1 − p ) γ log ⁡ ( p ) \text{FL}(p) = -\alpha (1 - p)^{\gamma} \log(p) FL(p)=α(1p)γlog(p)

 2) 当真实标签 y = 0 y = 0 y=0 时,公式变为:

FL ( p ) = − ( 1 − α ) p γ log ⁡ ( 1 − p ) \text{FL}(p) = -(1 - \alpha) p^{\gamma} \log(1 - p) FL(p)=(1α)pγlog(1p)

 Focal Loss 的完整公式可以写为:

FL ( y , p ) = − [ y ⋅ α ( 1 − p ) γ log ⁡ ( p ) + ( 1 − y ) ⋅ ( 1 − α ) p γ log ⁡ ( 1 − p ) ] \text{FL}(y, p) = -\left[ y \cdot \alpha (1 - p)^{\gamma} \log(p) + (1 - y) \cdot (1 - \alpha) p^{\gamma} \log(1 - p) \right] FL(y,p)=[yα(1p)γlog(p)+(1y)(1α)pγlog(1p)]

其中 p p p表示经过sigmoid的预测值。本文实现的是完整版的公式,而且没有引入额外的封装函数。

2、代码

import torch
import torch.nn as nn
import torch.nn.functional as F

# focal_loss = pos_loss + neg_loss 
# if y == 1: pos_loss = -|1-p|^gamma * log(p)  
# if y == 0: neg_loss = -|0-p|^gamma * log(1-p)
class FocalLoss(nn.Module):
    def __init__(self,alpha=0.25,gamma=2.0,reduce='sum'):
        super(FocalLoss,self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduce = reduce

    def forward(self,classifications,targets):
        alpha = self.alpha
        gamma = self.gamma
        classifications = classifications.view(-1)
        p = torch.sigmoid(classifications)
        targets = targets.view(-1)
        # 获取pos 和 neg 的索引
        pos_idx = torch.nonzero(targets==1).view(-1)
        neg_idx = torch.nonzero(targets==0).view(-1)
        # step1: cpt pos loss       
        pos_loss = -(1-p[pos_idx]).abs() ** gamma * torch.log(p[pos_idx])
        # step2: cpt neg loss 
        neg_loss = -(0-p[neg_idx]).abs() ** gamma * torch.log(1-p[neg_idx])
        loss = torch.cat((pos_loss, neg_loss), dim=0)
        # targets 也需要重新排序 来跟loss值对应 
        concat_idx = torch.cat((pos_idx, neg_idx), dim=0)
        targets = targets[concat_idx]
        if alpha >= 0:
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            loss = alpha_t * loss
        if self.reduce=='sum':
            loss = loss.sum()
        elif self.reduce=='mean':
            loss = loss.mean()
        else:
            raise ValueError('reduce type is wrong!')
        return loss


# ---test unit --- #
def main():
    # single cls focal loss 
    focal_loss = FocalLoss()
    pred = torch.FloatTensor([0.1,0.9,0.2,0.8,0.7]) # nb_anchors :5
    tgt  = torch.FloatTensor([0,1,0,1,1])           # neg:0 pos:1 ; no ignore
    loss = focal_loss(pred, tgt)
    print('loss:', loss) 

总结

 本文只是简单实现了一个二分类的FocalLoss,旨在加深读者对其理解。欢迎批评指正。


网站公告

今日签到

点亮在社区的每一天
去签到