PyTorch 损失函数

损失函数(Loss Function)衡量模型预测值与真实值之间的差距,是神经网络训练的核心指引——优化器通过最小化损失函数来更新模型参数。

PyTorch 在 torch.nn 模块中内置了十余种常用损失函数,覆盖分类、回归、排序等主要任务类型。


1. 损失函数基础

基本用法

所有 PyTorch 损失函数都是 nn.Module 的子类,使用方式统一:

实例

import torch
import torch.nn as nn

# 1. 实例化损失函数
criterion = nn.CrossEntropyLoss()

# 2. 计算损失(预测值在前,真实值在后)
loss = criterion(predictions, targets)

# 3. 反向传播
loss.backward()

预测值的形态约定

不同损失函数对输入的形态要求不同,这是初学者最容易出错的地方:

损失函数预测值(input)形态标签(target)形态
CrossEntropyLoss(N, C) 原始 logits(N,) 整数类别索引
BCELoss(N,) 经过 Sigmoid 的概率(N,) 0/1 浮点数
BCEWithLogitsLoss(N,) 原始 logits(N,) 0/1 浮点数
MSELoss(N,) 任意实数(N,) 任意实数
NLLLoss(N, C) 经过 log_softmax 的概率(N,) 整数类别索引

N = batch size,C = 类别数


2. 分类任务损失函数

2.1 CrossEntropyLoss 交叉熵损失

最常用的多分类损失函数,内部自动执行 Softmax + 对数 + 负号,无需手动对模型输出做 Softmax。

数学公式:

Loss = -sum(y_c * log(p_c))

其中 p_c = exp(x_c) / sum_j exp(x_j) 是 Softmax 输出。

实例

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# 模型输出:原始 logits,shape (batch_size, num_classes)
# 不需要提前做 Softmax!
predictions = torch.tensor([
    [2.0, 0.5, 0.3],   # 样本1,最可能是类别 0
    [0.1, 3.0, 0.2],   # 样本2,最可能是类别 1
    [0.2, 0.1, 4.0],   # 样本3,最可能是类别 2
])

# 标签:整数类别索引,shape (batch_size,)
targets = torch.tensor([0, 1, 2])

loss = criterion(predictions, targets)
print(f"Loss: {loss.item():.4f}")  # Loss: 0.1763

支持软标签(Label Smoothing):

实例

# 标签平滑,缓解过拟合,常用于图像分类竞赛
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# 也支持直接传入软标签(概率分布)
soft_targets = torch.tensor([
    [0.9, 0.05, 0.05],
    [0.05, 0.9, 0.05],
])
predictions = torch.randn(2, 3)
loss = criterion(predictions, soft_targets)

适用场景:多分类(猫/狗/鸟)、图像分类、文本分类等所有多分类任务。


2.2 BCELoss 二元交叉熵损失

专用于二分类多标签分类任务,输入必须是经过 Sigmoid 处理后的概率值(0~1)。

数学公式:

Loss = -[y * log(p) + (1-y) * log(1-p)]

实例

criterion = nn.BCELoss()

# 模型输出必须先经过 Sigmoid,取值范围 (0, 1)
raw_output = torch.tensor([2.0, -1.0, 0.5, -3.0])
predictions = torch.sigmoid(raw_output)   # [0.88, 0.27, 0.62, 0.05]

# 标签:浮点型 0.0 或 1.0
targets = torch.tensor([1.0, 0.0, 1.0, 0.0])

loss = criterion(predictions, targets)
print(f"Loss: {loss.item():.4f}")  # Loss: 0.2824

# 多标签分类(每个样本可属于多个类别)
# predictions shape: (batch_size, num_labels)
predictions_ml = torch.sigmoid(torch.randn(4, 5))
targets_ml     = torch.randint(0, 2, (4, 5)).float()
loss_ml = criterion(predictions_ml, targets_ml)

BCELoss 要求输入在 (0, 1) 范围内,传入原始 logits 会导致数值不稳定甚至 NaN。推荐使用下方的 BCEWithLogitsLoss


2.3 BCEWithLogitsLoss

BCELoss 的改进版,内部自动执行 Sigmoid,数值更稳定,推荐优先使用。

实例

criterion = nn.BCEWithLogitsLoss()

# 直接传入原始 logits,无需手动 Sigmoid
predictions = torch.tensor([2.0, -1.0, 0.5, -3.0])
targets     = torch.tensor([1.0,  0.0, 1.0,  0.0])

loss = criterion(predictions, targets)
print(f"Loss: {loss.item():.4f}")

# 等价于(但数值稳定性更好):
# loss = BCELoss(Sigmoid(predictions), targets)

带正样本权重(处理类别不平衡):

实例

# pos_weight:正样本权重,值越大越关注正样本
# 例如负样本是正样本的 10 倍,设 pos_weight=10
pos_weight = torch.tensor([10.0])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

适用场景:二分类(垃圾邮件判断)、多标签分类(文章多标签打标)、目标检测(前景/背景判断)。


2.4 NLLLoss 负对数似然损失

需要手动对模型输出做 log_softmax,灵活性更高。CrossEntropyLoss = LogSoftmax + NLLLoss

实例

criterion = nn.NLLLoss()

# 必须先手动做 log_softmax
raw_output   = torch.randn(4, 3)   # (batch, num_classes)
log_probs    = torch.log_softmax(raw_output, dim=1)

targets = torch.tensor([0, 2, 1, 0])
loss = criterion(log_probs, targets)

使用场景:需要在中间步骤使用 log 概率时(如 CTC、Beam Search);其他情况优先用 CrossEntropyLoss


3. 回归任务损失函数

3.1 MSELoss 均方误差

最经典的回归损失,对大误差非常敏感(因为平方会放大大误差的影响)。

数学公式:

MSELoss = (1/N) * sum((y_i - y_hat_i)^2)

实例

criterion = nn.MSELoss()

predictions = torch.tensor([2.5, 0.5, 2.0, 8.0])
targets     = torch.tensor([3.0, -0.5, 2.0, 7.0])

loss = criterion(predictions, targets)
print(f"MSE Loss: {loss.item():.4f}")  # MSE Loss: 0.3750

# 手动验证
manual = ((predictions - targets) ** 2).mean()
print(f"手动计算: {manual.item():.4f}")  # 0.3750

适用场景:房价预测、温度预测等连续值回归,数据中没有明显离群点时效果好。


3.2 L1Loss 平均绝对误差

离群点(outlier)更鲁棒,因为取绝对值而非平方,大误差不会被过度放大。

数学公式:

L1Loss = (1/N) * sum(|y_i - y_hat_i|)

实例

criterion = nn.L1Loss()

predictions = torch.tensor([2.5, 0.5, 2.0, 8.0])
targets     = torch.tensor([3.0, -0.5, 2.0, 7.0])

loss = criterion(predictions, targets)
print(f"L1 Loss: {loss.item():.4f}")  # L1 Loss: 0.5000

3.3 SmoothL1Loss Huber 损失

融合 MSE 和 L1 的优点:误差小时用 MSE(平滑,梯度稳定),误差大时用 L1(抗离群点)。目标检测(Faster R-CNN)中的标准损失。

数学公式:

SmoothL1(x) = 0.5*x^2 if |x| < 1, else |x| - 0.5

实例

criterion = nn.SmoothL1Loss()

predictions = torch.tensor([2.5, 0.5, 2.0, 8.0])
targets     = torch.tensor([3.0, -0.5, 2.0, 7.0])

loss = criterion(predictions, targets)
print(f"SmoothL1 Loss: {loss.item():.4f}")

三种回归损失对比:

实例

import torch
import torch.nn as nn

predictions = torch.tensor([0.0, 1.0, 5.0, 10.0])  # 模拟不同大小的误差
targets     = torch.zeros(4)

for name, fn in [("MSELoss", nn.MSELoss(reduction='none')),
                 ("L1Loss",  nn.L1Loss(reduction='none')),
                 ("SmoothL1",nn.SmoothL1Loss(reduction='none'))]:
    losses = fn(predictions, targets)
    print(f"{name:12s}: {[f'{l:.2f}' for l in losses.tolist()]}")
输出示例:
MSELoss     : ['0.00', '1.00', '25.00', '100.00']   # 大误差被平方放大
L1Loss      : ['0.00', '1.00',  '5.00',  '10.00']   # 线性增长
SmoothL1    : ['0.00', '0.50',  '4.50',   '9.50']   # 中间值

适用场景:目标检测的边框回归、深度估计等,同时存在小误差和离群点的回归任务。


4. 进阶损失函数

4.1 HuberLoss

SmoothL1Loss 的通用版,允许自定义切换阈值 delta(默认 1.0)。

实例

# delta 控制 MSE 和 L1 的切换点
criterion = nn.HuberLoss(delta=1.5)

predictions = torch.randn(10)
targets     = torch.randn(10)
loss = criterion(predictions, targets)

4.2 KLDivLoss KL 散度

衡量两个概率分布的差异,常用于知识蒸馏变分自编码器(VAE)

数学公式:

KL(P || Q) = sum(P(i) * log(P(i) / Q(i)))

实例

criterion = nn.KLDivLoss(reduction='batchmean')

# input 必须是 log 概率(对数概率),target 是普通概率
log_predictions = torch.log_softmax(torch.randn(4, 5), dim=1)
targets         = torch.softmax(torch.randn(4, 5), dim=1)

loss = criterion(log_predictions, targets)
print(f"KL Div Loss: {loss.item():.4f}")

知识蒸馏中的典型用法:

实例

temperature = 4.0  # 温度系数,越大越软化

# 教师模型输出
teacher_logits = torch.randn(32, 10)
# 学生模型输出
student_logits = torch.randn(32, 10)

soft_labels = torch.softmax(teacher_logits / temperature, dim=1)
soft_preds  = torch.log_softmax(student_logits / temperature, dim=1)

distill_loss = nn.KLDivLoss(reduction='batchmean')(soft_preds, soft_labels)
distill_loss *= temperature ** 2  # 还原梯度量级

4.3 MarginRankingLoss 排序损失

判断两个输入的相对顺序,常用于排序学习相似度学习

实例

criterion = nn.MarginRankingLoss(margin=0.5)

# x1 应该比 x2 更靠近 target(y=1 表示 x1 > x2)
x1 = torch.tensor([0.8, 0.3, 0.6])
x2 = torch.tensor([0.2, 0.7, 0.5])
y  = torch.tensor([1.0, -1.0, 1.0])  # 1: x1>x2, -1: x1<x2

loss = criterion(x1, x2, y)

4.4 TripletMarginLoss 三元组损失

用于度量学习,要求 anchor(锚点)positive(同类) 的距离小于与 negative(异类) 的距离。

实例

criterion = nn.TripletMarginLoss(margin=1.0)

# 每个向量维度为 embedding_dim
anchor   = torch.randn(32, 128)   # 锚点样本
positive = torch.randn(32, 128)   # 同类样本
negative = torch.randn(32, 128)   # 异类样本

loss = criterion(anchor, positive, negative)
# 目标:dist(anchor, positive) + margin < dist(anchor, negative)

适用场景:人脸识别、图像检索、Few-Shot Learning。


4.5 CTCLoss 序列标注损失

用于输入与输出长度不对齐的序列任务,如语音识别(声学序列 -> 文字序列)、手写识别。

实例

criterion = nn.CTCLoss(blank=0)  # blank 标签索引

# log_probs: (T, N, C) T=时间步, N=batch, C=类别数
T, N, C = 50, 4, 20
log_probs    = torch.log_softmax(torch.randn(T, N, C), dim=2)
targets      = torch.randint(1, C, (N * 10,))   # 拼接的目标序列
input_lengths  = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), 10, dtype=torch.long)

loss = criterion(log_probs, targets, input_lengths, target_lengths)

5. reduction 参数详解

所有损失函数都支持 reduction 参数,控制如何对样本损失进行汇总:

实例

predictions = torch.tensor([1.0, 2.0, 3.0, 4.0])
targets     = torch.tensor([1.5, 2.5, 2.0, 5.0])

# 单个样本的误差: [0.25, 0.25, 1.00, 1.00]

# mean(默认):对所有样本取平均
loss_mean = nn.MSELoss(reduction='mean')(predictions, targets)
print(f"mean:  {loss_mean.item():.4f}")   # 0.6250

# sum:对所有样本求和
loss_sum  = nn.MSELoss(reduction='sum')(predictions, targets)
print(f"sum:   {loss_sum.item():.4f}")    # 2.5000

# none:返回每个样本的独立损失(常用于加权)
loss_none = nn.MSELoss(reduction='none')(predictions, targets)
print(f"none:  {loss_none.tolist()}")     # [0.25, 0.25, 1.0, 1.0]

reduction='none' 的实际应用——对不同样本加权:

实例

# 给误差大的样本更高权重(焦点损失思路)
per_sample_loss = nn.MSELoss(reduction='none')(predictions, targets)
weights = torch.tensor([1.0, 1.0, 2.0, 2.0])   # 手动设置权重
weighted_loss = (per_sample_loss * weights).mean()

6. 类别权重与样本权重

类别权重(处理类别不平衡)

当数据集中某些类别样本极少时,可以给少数类更高的权重:

实例

# 假设 3 分类,类别 0 有 1000 个,类别 1 有 100 个,类别 2 有 50 个
# 权重与频率成反比
class_counts = torch.tensor([1000.0, 100.0, 50.0])
weights = 1.0 / class_counts
weights = weights / weights.sum() * len(weights)  # 归一化

criterion = nn.CrossEntropyLoss(weight=weights)

忽略特定标签

在语义分割等任务中,常需要忽略边界像素(标签为 255):

实例

# ignore_index:计算损失时忽略该标签
criterion = nn.CrossEntropyLoss(ignore_index=255)

# 语义分割场景
predictions = torch.randn(2, 21, 256, 256)   # (N, C, H, W)
targets     = torch.randint(0, 22, (2, 256, 256))
targets[targets == 21] = 255                  # 边界标注为 255

loss = criterion(predictions, targets)

7. 自定义损失函数

当内置损失函数无法满足需求时,可以通过两种方式自定义:

方式一:函数式(简单)

实例

import torch
import torch.nn.functional as F

def focal_loss(predictions, targets, gamma=2.0, alpha=0.25):
    """
    Focal Loss:解决目标检测中正负样本严重不平衡问题
    对容易分类的样本降低权重,让模型聚焦于困难样本
    """

    ce_loss = F.cross_entropy(predictions, targets, reduction='none')
    pt = torch.exp(-ce_loss)                          # 预测正确的概率
    focal_weight = alpha * (1 - pt) ** gamma          # 难样本权重更高
    return (focal_weight * ce_loss).mean()


# 使用
predictions = torch.randn(8, 10)
targets     = torch.randint(0, 10, (8,))
loss = focal_loss(predictions, targets)

方式二:继承 nn.Module(推荐)

实例

import torch
import torch.nn as nn

class DiceLoss(nn.Module):
    """
    Dice Loss:常用于图像分割,直接优化 Dice 系数
    对类别不平衡(如小目标分割)比 CrossEntropy 更鲁棒
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, predictions, targets):
        # predictions: (N, C, H, W) -> 经过 sigmoid 的概率
        # targets:     (N, C, H, W) -> one-hot 编码标签
        predictions = torch.sigmoid(predictions)

        # 展平到 (N, -1)
        pred_flat   = predictions.view(predictions.size(0), -1)
        target_flat = targets.view(targets.size(0), -1).float()

        intersection = (pred_flat * target_flat).sum(dim=1)
        dice = (2.0 * intersection + self.smooth) / (
            pred_flat.sum(dim=1) + target_flat.sum(dim=1) + self.smooth
        )
        return 1 - dice.mean()


class CombinedLoss(nn.Module):
    """
    组合损失:CrossEntropy + Dice,兼顾像素级分类和区域重叠
    图像分割常用组合
    """

    def __init__(self, ce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.ce_weight   = ce_weight
        self.dice_weight = dice_weight
        self.ce   = nn.CrossEntropyLoss()
        self.dice = DiceLoss()

    def forward(self, predictions, targets):
        return (self.ce_weight   * self.ce(predictions, targets) +
                self.dice_weight * self.dice(predictions, targets))


# 使用
criterion = CombinedLoss(ce_weight=0.4, dice_weight=0.6)

8. 损失函数选择指南

按任务类型选择

任务类型推荐损失函数备注
多分类CrossEntropyLoss最通用,优先选择
多分类(类别不平衡)CrossEntropyLoss(weight=...)给少数类加权
多分类(标签噪声大)CrossEntropyLoss(label_smoothing=0.1)防止过拟合
二分类BCEWithLogitsLoss比 BCELoss 更稳定
多标签分类BCEWithLogitsLoss每个标签独立判断
目标检测(分类头)CrossEntropyLoss / Focal Loss正负样本不平衡时用 Focal
目标检测(回归头)SmoothL1Loss / GIoULoss标准做法
普通回归MSELoss无离群点时首选
含离群点的回归HuberLoss / SmoothL1Loss鲁棒回归
图像分割CrossEntropyLoss + DiceLoss组合使用效果更好
语音识别CTCLoss序列对齐
度量学习 / 人脸识别TripletMarginLoss学习特征空间距离
知识蒸馏KLDivLoss学习软标签分布

常见误用与注意事项

实例

# 错误:CrossEntropyLoss 之前手动做了 Softmax
output = torch.softmax(model(x), dim=1)   # 多余的 softmax
loss   = nn.CrossEntropyLoss()(output, targets)  # 内部还会再做一次

# 正确:直接传入 logits
output = model(x)  # 原始 logits
loss   = nn.CrossEntropyLoss()(output, targets)


# 错误:BCELoss 传入未 sigmoid 的原始值
loss = nn.BCELoss()(model(x), targets)    # 可能超出 [0,1],数值不稳定

# 正确:使用 BCEWithLogitsLoss
loss = nn.BCEWithLogitsLoss()(model(x), targets)


# 错误:标签类型不匹配(整型 vs 浮点型)
targets = torch.tensor([1, 0, 1])                     # int 型
loss    = nn.BCEWithLogitsLoss()(preds, targets)       # 报错!

# 正确:BCEWithLogitsLoss 需要 float 标签
targets = torch.tensor([1.0, 0.0, 1.0])               # float 型
loss    = nn.BCEWithLogitsLoss()(preds, targets)       # 正确


# 错误:loss 没有 .item(),导致计算图不断积累,显存溢出
total_loss += loss      # loss 是张量,持有计算图

# 正确:用 .item() 取出标量
total_loss += loss.item()

完整训练示例

实例

import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 模型 & 损失函数 & 优化器
model     = MyModel().to(device)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(num_epochs):
    model.train()
    total_loss, correct = 0.0, 0

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)          # 原始 logits

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * inputs.size(0)    # .item() 取标量
        correct    += (outputs.argmax(1) == labels).sum().item()

    avg_loss = total_loss / len(train_loader.dataset)
    accuracy = correct / len(train_loader.dataset)
    print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f} | Acc: {accuracy:.4f}")