学习资源站

YOLOv11改进-损失函数改进篇-QualityFocalLoss质量焦点损失(含代码加详细修改教程)

一、本文介绍

本文给大家带来的改进机制是 QualityFocalLoss ,其是一种CLS分类 损失函数 ,它的主要创新是 将目标的定位质量(如边界框与真实对象的重叠度量,例如IoU得分)直接融合到分类损失中 ,形成一个联合表示。这种方法能够解决传统 目标检测 中分类与定位任务之间存在的不一致性问题。QFL通过为每个类别的得分赋予根据定位质量调整的权重,使得检测 模型 在训练过程中能够更加关注那些难以定位或分类的样本。

在开始之前给大家推荐一下我的专栏,本专栏每周更新3-10篇最新前沿机制 | 包括二次创新全网无重复,以及融合改进(大家拿到之后添加另外一个改进机制在你的 数据集 上实现涨点即可撰写论文),还有各种前沿顶会改进机制 |,更有包含我所有附赠的文件(文件内集成我所有的改进机制全部注册完毕可以直接运行)和交流群和视频讲解提供给大家。

欢迎大家订阅我的专栏一起学习YOLO!


目录

一、本文介绍

二、Quality Focal Loss原理

2.1 Quality Focal Loss的基本原理

2.2 联合表示法

2.3 连续标签支持

2.4 动态调整难度

三、核心代码

三、使用方式

3.1 修改一

3.2 修改二

3.3 修改三

四 、本文总结


二、Quality Focal Loss原理

论文地址: 官方论文地址

代码地址: 官方代码地址


2.1 Quality Focal Loss的基本原理

Quality Focal Loss (QFL) 是一种用于目标检测的改进损失函数。它的主要创新是 将目标的定位质量(如边界框与真实对象的重叠度量,例如IoU得分)直接融合到分类损失中 ,形成一个联合表示。这种方法能够解决传统目标检测中分类与定位任务之间存在的不一致性问题。QFL通过为每个类别的得分赋予根据定位质量调整的权重,使得检测模型在训练过程中能够更加关注那些难以定位或分类的样本。

Quality Focal Loss(QFL)的基本原理 可以分为以下几个要点:

1. 联合表示法: QFL将定位质量(如 IoU 分数)与分类得分融合为一个联合表示,这种表示在训练和推理过程中保持一致性,有助于解决在训练和测试阶段对质量估计和分类得分使用不一致的问题。

2. 连续标签支持: 传统的Focal Loss仅支持离散的{0, 1}标签,而QFL扩展了这一概念,支持连续的标签(如IoU分数),从0到1的浮点数,更好地反映了实际数据中的情况。

3. 动态调整难度: QFL通过动态调整损失函数,使得模型在训练过程中更多地关注难以分类或定位的样本,从而提高模型的整体 性能


2.2 联合表示法

联合表示法 是Quality Focal Loss中的核心概念,它将 分类得分和定位质量(例如IoU得分)整合到单一的预测向量中 。这种表示方法解决了传统目标检测方法中训练和推理阶段质量估计与分类评分分离使用的不一致性问题。具体来说,它允许模型在预测分类的同时,估计每个检测框的定位质量,从而在非最大抑制(NMS)处理中提供更准确的排序得分,改善检测性能。

下图展示了传统方法(Existing Work)和我们的方法之间在分类和定位质量估计方面的不同表示形式的 对比:

图(a) 中,即现有工作,训练和测试阶段分别独立处理分类得分、 边界框 回归和IoU/centerness得分,这导致了训练和推理之间的不一致性。

而在 图(b) 中,我们的方法则将分类得分和IoU得分结合为一个联合表示,即在训练和测试时都使用的分类与IoU联合得分。这种联合表示提高了训练和推理之间的一致性,与Quality Focal Loss的基本原理中提到的“联合表示法”紧密相关。在Quality Focal Loss中,通过这种方式,模型能够在训练过程中考虑到每个样本的定位质量,使得损失函数能够更加关注那些定位或分类困难的样本。


2.3 连续标签支持

连续标签支持 是指在 质量焦点损失(Quality Focal Loss, QFL) 中,分类的输出标签不再是传统的0或1(如在one-hot编码中),而是可以 取任意在0到1之间的连续值 。这些连续值代表了目标定位的质量,通常是指与真实边界框的交并比(IoU)。通过这种方式,QFL可以直接在损失函数中整合定位质量,使得损失函数能够对定位不准确的样本施加更大的权重,从而激励模型学习更准确地预测边界框。

下面这张图 比较了传统目标检测方法和提出的广义焦点损失(GFL)方法之间的差异:

在传统方法(Existing Work)中,分类分支使用one-hot标签进行正类和负类的区分,而回归分支则采用Dirac delta 分布进行边界框的预测。相比之下,GFL方法引入了质量焦点损失(QFL)和分布焦点损失(DFL)。QFL通过软one-hot标签(IoU标签)进行学习,这些标签反映了边界框的定位质量。同时,DFL使用一般分布 P_{x} 来模拟边界框位置的概率分布。


2.4 动态调整难度

动态调整难度 是质量焦点损失(Quality Focal Loss, QFL)的一个特点,它允许模型在训练过程中更多地关注那些难以分类或定位的样本。这是通过调整损失函数中的一个参数来实现的,该参数会增加对模型预测不确定性较高的样本的损失值,使得模型更加集中于这些难以预测的样本上。这种方法旨在提高模型对困难样本的敏感性,帮助模型更加精确地进行分类和定位,尤其是在面对复杂或模糊的检测场景时。

下面这张图展示了Quality Focal Loss (QFL)在不同β参数下的损失曲线,以及不同分布对于相同积分目标的表示和实际边界框回归目标的分布直方图:

图(a) 展示的QFL损失曲线与基本原理中的“动态调整难度”相关,因为它展示了如何通过调整β参数来调节模型对于难以预测样本的关注度。

图(b) 展示了不同的概率分布如何针对相同的积分目标(即回归目标)进行调整,这与“表示任意分布的边界框位置”的原理有关。

图(c) 则是实际数据集中回归目标的分布,这有助于我们理解和验证QFL在实际应用中的效果。


三、核心代码

使用方式看章节四

  1. from .tal import bbox2dist
  2. import torch.nn.functional as F
  3. import math
  4. class QualityfocalLoss(nn.Module):
  5. def __init__(self, beta=2.0):
  6. super().__init__()
  7. self.beta = beta
  8. def forward(self, pred_score, gt_score, gt_target_pos_mask):
  9. # negatives are supervised by 0 quality score
  10. pred_sigmoid = pred_score.sigmoid()
  11. scale_factor = pred_sigmoid
  12. zerolabel = scale_factor.new_zeros(pred_score.shape)
  13. with torch.cuda.amp.autocast(enabled=False):
  14. loss = F.binary_cross_entropy_with_logits(pred_score, zerolabel, reduction='none') * scale_factor.pow(
  15. self.beta)
  16. scale_factor = gt_score[gt_target_pos_mask] - pred_sigmoid[gt_target_pos_mask]
  17. with torch.cuda.amp.autocast(enabled=False):
  18. loss[gt_target_pos_mask] = F.binary_cross_entropy_with_logits(pred_score[gt_target_pos_mask],
  19. gt_score[gt_target_pos_mask],
  20. reduction='none') * scale_factor.abs().pow(
  21. self.beta)
  22. return loss
  23. class SlideLoss(nn.Module):
  24. def __init__(self, loss_fcn):
  25. super(SlideLoss, self).__init__()
  26. self.loss_fcn = loss_fcn
  27. self.reduction = loss_fcn.reduction
  28. self.loss_fcn.reduction = 'none' # required to apply SL to each element
  29. def forward(self, pred, true, auto_iou=0.5):
  30. loss = self.loss_fcn(pred, true)
  31. if auto_iou < 0.2:
  32. auto_iou = 0.2
  33. b1 = true <= auto_iou - 0.1
  34. a1 = 1.0
  35. b2 = (true > (auto_iou - 0.1)) & (true < auto_iou)
  36. a2 = math.exp(1.0 - auto_iou)
  37. b3 = true >= auto_iou
  38. a3 = torch.exp(-(true - 1.0))
  39. modulating_weight = a1 * b1 + a2 * b2 + a3 * b3
  40. loss *= modulating_weight
  41. if self.reduction == 'mean':
  42. return loss.mean()
  43. elif self.reduction == 'sum':
  44. return loss.sum()
  45. else: # 'none'
  46. return loss
  47. class Focal_Loss(nn.Module):
  48. # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
  49. def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
  50. super().__init__()
  51. self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
  52. self.gamma = gamma
  53. self.alpha = alpha
  54. self.reduction = loss_fcn.reduction
  55. self.loss_fcn.reduction = 'none' # required to apply FL to each element
  56. def forward(self, pred, true):
  57. loss = self.loss_fcn(pred, true)
  58. # p_t = torch.exp(-loss)
  59. # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
  60. # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
  61. pred_prob = torch.sigmoid(pred) # prob from logits
  62. p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
  63. alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
  64. modulating_factor = (1.0 - p_t) ** self.gamma
  65. loss *= alpha_factor * modulating_factor
  66. if self.reduction == 'mean':
  67. return loss.mean()
  68. elif self.reduction == 'sum':
  69. return loss.sum()
  70. else: # 'none'
  71. return loss
  72. def reduce_loss(loss, reduction):
  73. """Reduce loss as specified.
  74. Args:
  75. loss (Tensor): Elementwise loss tensor.
  76. reduction (str): Options are "none", "mean" and "sum".
  77. Return:
  78. Tensor: Reduced loss tensor.
  79. """
  80. reduction_enum = F._Reduction.get_enum(reduction)
  81. # none: 0, elementwise_mean:1, sum: 2
  82. if reduction_enum == 0:
  83. return loss
  84. elif reduction_enum == 1:
  85. return loss.mean()
  86. elif reduction_enum == 2:
  87. return loss.sum()
  88. def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
  89. """Apply element-wise weight and reduce loss.
  90. Args:
  91. loss (Tensor): Element-wise loss.
  92. weight (Tensor): Element-wise weights.
  93. reduction (str): Same as built-in losses of PyTorch.
  94. avg_factor (float): Avarage factor when computing the mean of losses.
  95. Returns:
  96. Tensor: Processed loss values.
  97. """
  98. # if weight is specified, apply element-wise weight
  99. if weight is not None:
  100. loss = loss * weight
  101. # if avg_factor is not specified, just reduce the loss
  102. if avg_factor is None:
  103. loss = reduce_loss(loss, reduction)
  104. else:
  105. # if reduction is mean, then average the loss by avg_factor
  106. if reduction == 'mean':
  107. loss = loss.sum() / avg_factor
  108. # if reduction is 'none', then do nothing, otherwise raise an error
  109. elif reduction != 'none':
  110. raise ValueError('avg_factor can not be used with reduction="sum"')
  111. return loss
  112. def varifocal_loss(pred,
  113. target,
  114. weight=None,
  115. alpha=0.75,
  116. gamma=2.0,
  117. iou_weighted=True,
  118. reduction='mean',
  119. avg_factor=None):
  120. """`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_
  121. Args:
  122. pred (torch.Tensor): The prediction with shape (N, C), C is the
  123. number of classes
  124. target (torch.Tensor): The learning target of the iou-aware
  125. classification score with shape (N, C), C is the number of classes.
  126. weight (torch.Tensor, optional): The weight of loss for each
  127. prediction. Defaults to None.
  128. alpha (float, optional): A balance factor for the negative part of
  129. Varifocal Loss, which is different from the alpha of Focal Loss.
  130. Defaults to 0.75.
  131. gamma (float, optional): The gamma for calculating the modulating
  132. factor. Defaults to 2.0.
  133. iou_weighted (bool, optional): Whether to weight the loss of the
  134. positive example with the iou target. Defaults to True.
  135. reduction (str, optional): The method used to reduce the loss into
  136. a scalar. Defaults to 'mean'. Options are "none", "mean" and
  137. "sum".
  138. avg_factor (int, optional): Average factor that is used to average
  139. the loss. Defaults to None.
  140. """
  141. # pred and target should be of the same size
  142. assert pred.size() == target.size()
  143. pred_sigmoid = pred.sigmoid()
  144. target = target.type_as(pred)
  145. if iou_weighted:
  146. focal_weight = target * (target > 0.0).float() + \
  147. alpha * (pred_sigmoid - target).abs().pow(gamma) * \
  148. (target <= 0.0).float()
  149. else:
  150. focal_weight = (target > 0.0).float() + \
  151. alpha * (pred_sigmoid - target).abs().pow(gamma) * \
  152. (target <= 0.0).float()
  153. loss = F.binary_cross_entropy_with_logits(
  154. pred, target, reduction='none') * focal_weight
  155. loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
  156. return loss
  157. class Vari_focalLoss(nn.Module):
  158. def __init__(self,
  159. use_sigmoid=True,
  160. alpha=0.75,
  161. gamma=2.0,
  162. iou_weighted=True,
  163. reduction='sum',
  164. loss_weight=1.0):
  165. """`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_
  166. Args:
  167. use_sigmoid (bool, optional): Whether the prediction is
  168. used for sigmoid or softmax. Defaults to True.
  169. alpha (float, optional): A balance factor for the negative part of
  170. Varifocal Loss, which is different from the alpha of Focal
  171. Loss. Defaults to 0.75.
  172. gamma (float, optional): The gamma for calculating the modulating
  173. factor. Defaults to 2.0.
  174. iou_weighted (bool, optional): Whether to weight the loss of the
  175. positive examples with the iou target. Defaults to True.
  176. reduction (str, optional): The method used to reduce the loss into
  177. a scalar. Defaults to 'mean'. Options are "none", "mean" and
  178. "sum".
  179. loss_weight (float, optional): Weight of loss. Defaults to 1.0.
  180. """
  181. super(Vari_focalLoss, self).__init__()
  182. assert use_sigmoid is True, \
  183. 'Only sigmoid varifocal loss supported now.'
  184. assert alpha >= 0.0
  185. self.use_sigmoid = use_sigmoid
  186. self.alpha = alpha
  187. self.gamma = gamma
  188. self.iou_weighted = iou_weighted
  189. self.reduction = reduction
  190. self.loss_weight = loss_weight
  191. def forward(self,
  192. pred,
  193. target,
  194. weight=None,
  195. avg_factor=None,
  196. reduction_override=None):
  197. """Forward function.
  198. Args:
  199. pred (torch.Tensor): The prediction.
  200. target (torch.Tensor): The learning target of the prediction.
  201. weight (torch.Tensor, optional): The weight of loss for each
  202. prediction. Defaults to None.
  203. avg_factor (int, optional): Average factor that is used to average
  204. the loss. Defaults to None.
  205. reduction_override (str, optional): The reduction method used to
  206. override the original reduction method of the loss.
  207. Options are "none", "mean" and "sum".
  208. Returns:
  209. torch.Tensor: The calculated loss
  210. """
  211. assert reduction_override in (None, 'none', 'mean', 'sum')
  212. reduction = (
  213. reduction_override if reduction_override else self.reduction)
  214. if self.use_sigmoid:
  215. loss_cls = self.loss_weight * varifocal_loss(
  216. pred,
  217. target,
  218. weight,
  219. alpha=self.alpha,
  220. gamma=self.gamma,
  221. iou_weighted=self.iou_weighted,
  222. reduction=reduction,
  223. avg_factor=avg_factor)
  224. else:
  225. raise NotImplementedError
  226. return loss_cls


三、使用方式

3.1 修改一

找到' ultralytics /utils/loss.py'文件,将上面的代码复制在文件的开头 (注意是在模块导入的后方)


3.2 修改二

按照图示进行修改即可。

  1. "下面的代码注释掉就是正常的损失函数,如果不注释使用的就是使用对应的损失失函数"
  2. # self.bce = Focal_Loss(nn.BCEWithLogitsLoss(reduction='none')) # Focal
  3. # self.bce = Vari_focalLoss() # VFLoss
  4. # self.bce = SlideLoss(nn.BCEWithLogitsLoss(reduction='none')) # SlideLoss
  5. # self.bce = QualityfocalLoss() # 目前仅持者目标检测需要注意!


3.3 修改三

按照图片进行修改,代码在图片的下方,完全替换call 内的代码即可。

替换代码再次处。

  1. def __call__(self, preds, batch):
  2. """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
  3. loss = torch.zeros(3, device=self.device) # box, cls, dfl
  4. feats = preds[1] if isinstance(preds, tuple) else preds
  5. pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
  6. (self.reg_max * 4, self.nc), 1
  7. )
  8. pred_scores = pred_scores.permute(0, 2, 1).contiguous()
  9. pred_distri = pred_distri.permute(0, 2, 1).contiguous()
  10. dtype = pred_scores.dtype
  11. batch_size = pred_scores.shape[0]
  12. imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
  13. anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
  14. # Targets
  15. targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
  16. targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
  17. gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
  18. mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
  19. # pboxes
  20. pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
  21. target_labels, target_bboxes, target_scores, fg_mask, _ = self.assigner(
  22. pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
  23. anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
  24. target_scores_sum = max(target_scores.sum(), 1)
  25. # Cls loss
  26. # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
  27. if isinstance(self.bce, (nn.BCEWithLogitsLoss, Vari_focalLoss, Focal_Loss)):
  28. loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE VFLoss Focal
  29. elif isinstance(self.bce, SlideLoss):
  30. if fg_mask.sum():
  31. auto_iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True).mean()
  32. else:
  33. auto_iou = 0.1
  34. loss[1] = self.bce(pred_scores, target_scores.to(dtype), auto_iou).sum() / target_scores_sum # SlideLoss
  35. elif isinstance(self.bce, QualityfocalLoss):
  36. if fg_mask.sum():
  37. pos_ious = bbox_iou(pred_bboxes, target_bboxes / stride_tensor, xywh=False).clamp(min=1e-6).detach()
  38. # 10.0x Faster than torch.one_hot
  39. targets_onehot = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  40. dtype=torch.int64,
  41. device=target_labels.device) # (b, h*w, 80)
  42. targets_onehot.scatter_(2, target_labels.unsqueeze(-1), 1)
  43. cls_iou_targets = pos_ious * targets_onehot
  44. fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.nc) # (b, h*w, 80)
  45. targets_onehot_pos = torch.where(fg_scores_mask > 0, targets_onehot, 0)
  46. cls_iou_targets = torch.where(fg_scores_mask > 0, cls_iou_targets, 0)
  47. else:
  48. cls_iou_targets = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  49. dtype=torch.int64,
  50. device=target_labels.device) # (b, h*w, 80)
  51. targets_onehot_pos = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  52. dtype=torch.int64,
  53. device=target_labels.device) # (b, h*w, 80)
  54. loss[1] = self.bce(pred_scores, cls_iou_targets.to(dtype), targets_onehot_pos.to(torch.bool)).sum() / max(
  55. fg_mask.sum(), 1)
  56. else:
  57. loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # 确保有损失可用
  58. # Bbox loss
  59. if fg_mask.sum():
  60. target_bboxes /= stride_tensor
  61. loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
  62. target_scores_sum, fg_mask)
  63. loss[0] *= self.hyp.box # box gain
  64. loss[1] *= self.hyp.cls # cls gain
  65. loss[2] *= self.hyp.dfl # dfl gain
  66. return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)


四 、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv11改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

希望大家阅读完以后可以给文章点点赞和评论支持一下这样购买专栏的人越多群内人越多大家交流的机会就更多了。