逻辑回归以及损失函数

发布于:2025-09-02 ⋅ 阅读:(17) ⋅ 点赞:(0)

什么是逻辑回归?

逻辑回归是一种用于二分类(多分类)问题的概率线性模型,常用于预测某样本属于某一类别的概率。

逻辑回归的基本思想

逻辑回归试图通过一个线性函数预测自变量x与因变量y的关系,再通过sigmoid函数将线性结果z映射为概率分布。

模型表示:

z=w^Tx+b

逻辑函数(sigmoid):

\sigma (z)=\frac{1}{1+e^{-z}}

输出概率:

P(y=1|x)=\sigma (z) = \frac{1}{1+e^{-(w^Tx+b)}}

P(y=0|x)=1-P(y=1|x)


损失函数

二分类

逻辑回归的目标是让预测概率尽可能贴近真实标签。
采用对数损失函数,即交叉熵损失:

L=-[ylog(\hat{y})+(1-y)log(1-\hat{y})]

  • y 为真实标签(0或1)
  • \hat{y}=P(y=1|x)是预测概率

对N个样本的平均损失:

Loss = -\frac{1}{N}\sum_{i=1}^{N}[y^{(i)}log(\hat{y}^{(i)})+(1-y^{(i)})log(1-\hat{y}^{(i)})]

多分类

多分类问题中,假设有 C 个类别,每个样本的标签y是一个整数(0~C-1),\hat{y}是长度为C的概率分布(softmax输出):

Loss = -\frac{1}{N}\sum_{i=1}^{N}log(p_{i,y_i})

p_{i,y_i}:是第i个样本对其真实类别y_i的预测概率。


逻辑回归损失函数代码实现

二分类

import numpy as np

def sigmoid(z):
    return 1/(1+np.exp(-z))

def cross_entropy_loss(pred, gt):
    '''
    :param pred: shape(N,),预测为某类的概率
    :param gt: shape(N,),真实标签,0 or 1
    '''
    eps = 1e-15
    pred = np.clip(pred, eps, 1-eps)  # 将pred限制在[eps,1-eps]中,超过区间上界的值会被设为上界,低于下界的值会被设为下界

    loss = -np.mean(gt*np.log(pred)+(1-gt)*np.log(1-pred))
    return loss

gt = np.array([1,0,1,0])
pred_logits = np.array([2,-1,0,-2])
pred = sigmoid(pred_logits)
loss = cross_entropy_loss(pred, gt)
print(loss)

多分类

import numpy as np

def sigmoid(z):
    return 1 / (1+np.exp(-z))

def cross_entropy_loss(pred,gt):
    '''
    :param pred: [BS,classes]  每行是各类概率(softmax的输出)
    :param gt: [BS,]  每个元素是类别编号
    :return:
    '''
    eps = 1e-15
    pred = np.clip(pred, eps, 1-eps)
    N = pred.shape[0]  # bs  几个样本
    correct_class_probs = pred[np.arange(N),gt]  # 第几个样本对应的gt真实类别的预测概率
    loss = -np.mean(np.log(correct_class_probs))

    return loss

def softmax(z):
    z = z-np.max(z)
    exp_z = np.exp(z)
    return exp_z/np.sum(exp_z)

gt = np.array([2,0,1,2]) # 4个样本,3分类
pred_logits = np.array([
                        [1.2,0.2,2.1],
                        [2.0,1.1,4.0],
                        [0.5,1.5,1.8],
                        [1.1,0.3,2.5]])
pred_probs = softmax(pred_logits)
loss = cross_entropy_loss(pred_probs,gt)
print(loss)