
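YOLOv11 Improvements, Detection Head Series: Improving the YOLOv11 Detection Head with the Adaptive Spatial Feature Fusion (ASFF) Module (adapted to YOLOv11, an original innovation)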

1. Introduction

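The improvement presented in this article uses the Adaptive Spatial Feature Fusion (ASFF) module to build a new YOLOv11 detection head. Its main innovation is an adaptive spatial feature fusion scheme that filters out conflicting information and thereby improves scale invariance. In my experiments, the modified head gave large accuracy gains on every detection target. This is the three-head version; later I will build on it to create a four-head Detect_ASFF that helps with small-object detection. I highly recommend this head (the code in this article is my own original work).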



2. Basic Principles of ASFF

Official paper: click the link to open it

Official code: click the link to open it


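ASFF (Adaptive Spatial Feature Fusion) was proposed for single-shot object detection to address the inconsistency between feature scales. Its key idea is an adaptive spatial fusion scheme that filters out conflicting information and thereby improves scale invariance. The paper shows that applying ASFF to YOLOv3 significantly improves detection performance on MS COCO while keeping a good balance between speed and accuracy. ASFF is trained end to end with backpropagation, is model-agnostic, and adds very little computational overhead, which makes it a practical enhancement for existing object detectors.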

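The main contributions of ASFF are:

1. Adaptive spatial feature fusion: a new pyramid feature fusion strategy that spatially filters out conflicting information and suppresses inconsistency between features of different scales.

2. Improved scale invariance: the ASFF strategy clearly improves the scale invariance of the features, which helps detection accuracy.

3. Low inference overhead: detection performance improves while adding almost no extra inference cost.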

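These contributions make ASFF a notable advance in single-shot object detection, especially in handling objects of different scales. For the same reason, it is not suitable for necks that output only a single detection scale; keep this in mind.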

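This figure shows how the Adaptive Spatial Feature Fusion (ASFF) mechanism works for single-shot object detection. In this structure, features from different levels (shown as differently coloured layers) are first downsampled or upsampled according to their strides so that all of them have the same spatial size.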

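- Level 1, Level 2 and Level 3 are features from different levels of the feature pyramid, each with a different spatial resolution.
- ASFF-1, ASFF-2 and ASFF-3 denote the ASFF fusion applied at each of these levels.
- In the zoomed-in view of ASFF-3, the features from the other levels (x^{1→3}, x^{2→3}) are resized to the same size as the third level (x^{3→3}), then combined by a weighted sum using learned weight maps to produce the fused feature y^3 used for prediction.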
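
The resizing step can be sketched on its own. The rules below are the ones the ASFFV5 class in Section 3 follows: a stride-2 3x3 conv halves the resolution, an extra stride-2 max pool in front of it quarters the resolution, and a 1x1 conv (to match channels) followed by nearest-neighbour interpolation upsamples. This is a minimal PyTorch sketch, not the author's code: the helper names `down`/`up` and the tensor sizes are hypothetical, and a plain `nn.Conv2d` stands in for the Conv+BN+SiLU block.

```python
import torch
import torch.nn as nn


def down(in_c: int, out_c: int, factor: int) -> nn.Module:
    """Downsample by 2 with a stride-2 3x3 conv, or by 4 with max-pool + stride-2 conv."""
    conv = nn.Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1)
    if factor == 2:
        return conv
    if factor == 4:
        return nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=2, padding=1), conv)
    raise ValueError("factor must be 2 or 4")


def up(in_c: int, out_c: int, factor: int) -> nn.Module:
    """Compress channels with a 1x1 conv, then upsample with nearest-neighbour interpolation."""
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=1),
        nn.Upsample(scale_factor=factor, mode="nearest"),
    )


# Hypothetical P3/P4/P5 maps for a 256x256 input (strides 8/16/32).
p3 = torch.randn(1, 64, 32, 32)
p4 = torch.randn(1, 128, 16, 16)
p5 = torch.randn(1, 256, 8, 8)

# Target = P5 (coarsest): bring P4 and P3 down to 8x8 with 256 channels.
to_p5 = [down(128, 256, 2)(p4), down(64, 256, 4)(p3), p5]
# Target = P3 (finest): bring P5 and P4 up to 32x32 with 64 channels.
to_p3 = [up(256, 64, 4)(p5), up(128, 64, 2)(p4), p3]

for maps in (to_p5, to_p3):
    print([tuple(m.shape) for m in maps])
```

After resizing, every list holds three maps of identical shape, which is the precondition for the element-wise weighted sum.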

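In this way, ASFF adaptively selects the most useful features at each spatial position, which improves detection accuracy. The model can decide, based on the context at each location and scale, which pyramid levels matter most for the final prediction.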
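
The per-position selection can be illustrated with a small sketch: each level is squeezed by a 1x1 conv, the results are concatenated, another 1x1 conv predicts three logits per pixel, and a softmax over the level dimension turns them into weights that sum to 1 at every location. This mirrors the `weight_level_*`, `weight_levels` and `F.softmax(..., dim=1)` steps in ASFFV5 (Section 3), but it is a simplified, hypothetical module (`TinyASFFFusion`): it uses plain `nn.Conv2d` instead of the Conv+BN+SiLU block and assumes the three inputs are already resized to the same shape.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyASFFFusion(nn.Module):
    """Fuse three same-shaped maps with per-pixel softmax weights (simplified ASFF)."""

    def __init__(self, c: int, compress_c: int = 16):
        super().__init__()
        # One 1x1 conv per level squeezes its features into a small weight embedding.
        self.embed = nn.ModuleList(nn.Conv2d(c, compress_c, 1) for _ in range(3))
        # A 1x1 conv maps the concatenated embeddings to 3 logits per pixel.
        self.logits = nn.Conv2d(3 * compress_c, 3, 1)

    def forward(self, maps):
        emb = torch.cat([e(m) for e, m in zip(self.embed, maps)], dim=1)
        w = F.softmax(self.logits(emb), dim=1)  # (B, 3, H, W); sums to 1 over the level dim
        fused = sum(w[:, i:i + 1] * maps[i] for i in range(3))  # per-pixel weighted sum
        return fused, w


torch.manual_seed(0)
maps = [torch.randn(2, 64, 20, 20) for _ in range(3)]
fused, w = TinyASFFFusion(64)(maps)
# Index (0, 1 or 2) of the level with the largest weight at each position.
dominant = w.argmax(dim=1)
print(fused.shape, w.shape, dominant.shape)
```

Inspecting `dominant` (or `w` itself, which ASFFV5 also returns when `vis=True`) shows which pyramid level the model relies on at each location.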


3. Core Code of ASFFHead

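This is the three-head detection version; later I will release a four-head version with an additional small-object detection layer. See Section 4 for how to use it.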

import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from ultralytics.utils.tal import dist2bbox, make_anchors

__all__ = ["ASFFHead"]


def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Apply convolution and activation (used after BN has been fused into the conv)."""
        return self.act(self.conv(x))


class DFL(nn.Module):
    """
    Integral module of Distribution Focal Loss (DFL). CSDN: Snu77
    Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
    """

    def __init__(self, c1=16):
        """Initialize a frozen 1x1 conv whose weights are 0..c1-1 (expectation over the bins)."""
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        x = torch.arange(c1, dtype=torch.float)
        self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
        self.c1 = c1

    def forward(self, x):
        """Turn the per-side bin distributions into 4 box distances per anchor."""
        b, _, a = x.shape  # batch, channels, anchors
        return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
        # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)


class ASFFV5(nn.Module):
    def __init__(self, level, ch, multiplier=1, rfb=False, vis=False, act_cfg=True):
        """
        Adaptive spatial feature fusion for one output level. CSDN: Snu77
        level: 0 -> P5 (largest stride), 1 -> P4, 2 -> P3 (smallest stride).
        ch: input channels ordered (P3, P4, P5).
        """
        super().__init__()
        self.level = level
        self.dim = [int(ch[2] * multiplier), int(ch[1] * multiplier), int(ch[0] * multiplier)]
        self.inter_dim = self.dim[self.level]
        if level == 0:
            self.stride_level_1 = Conv(int(ch[1] * multiplier), self.inter_dim, 3, 2)
            self.stride_level_2 = Conv(int(ch[0] * multiplier), self.inter_dim, 3, 2)
            self.expand = Conv(self.inter_dim, int(ch[2] * multiplier), 3, 1)
        elif level == 1:
            self.compress_level_0 = Conv(int(ch[2] * multiplier), self.inter_dim, 1, 1)
            self.stride_level_2 = Conv(int(ch[0] * multiplier), self.inter_dim, 3, 2)
            self.expand = Conv(self.inter_dim, int(ch[1] * multiplier), 3, 1)
        elif level == 2:
            self.compress_level_0 = Conv(int(ch[2] * multiplier), self.inter_dim, 1, 1)
            self.compress_level_1 = Conv(int(ch[1] * multiplier), self.inter_dim, 1, 1)
            self.expand = Conv(self.inter_dim, int(ch[0] * multiplier), 3, 1)

        # when adding rfb, we use half number of channels to save memory
        compress_c = 8 if rfb else 16
        self.weight_level_0 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_1 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_levels = Conv(compress_c * 3, 3, 1, 1)
        self.vis = vis

    def forward(self, x):
        """
        x: [P3, P4, P5] feature maps (small stride -> large stride).
        Internally level 0 = P5 (l), level 1 = P4 (m), level 2 = P3 (s).
        """
        x_level_0 = x[2]  # l
        x_level_1 = x[1]  # m
        x_level_2 = x[0]  # s

        # Resize the other two levels to this level's spatial size and channel count.
        if self.level == 0:
            level_0_resized = x_level_0
            level_1_resized = self.stride_level_1(x_level_1)
            level_2_downsampled_inter = F.max_pool2d(x_level_2, 3, stride=2, padding=1)
            level_2_resized = self.stride_level_2(level_2_downsampled_inter)
        elif self.level == 1:
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(level_0_compressed, scale_factor=2, mode="nearest")
            level_1_resized = x_level_1
            level_2_resized = self.stride_level_2(x_level_2)
        elif self.level == 2:
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(level_0_compressed, scale_factor=4, mode="nearest")
            x_level_1_compressed = self.compress_level_1(x_level_1)
            level_1_resized = F.interpolate(x_level_1_compressed, scale_factor=2, mode="nearest")
            level_2_resized = x_level_2

        # Per-pixel fusion weights: squeeze each level, predict 3 logits, softmax over levels.
        level_0_weight_v = self.weight_level_0(level_0_resized)
        level_1_weight_v = self.weight_level_1(level_1_resized)
        level_2_weight_v = self.weight_level_2(level_2_resized)
        levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v), 1)
        levels_weight = self.weight_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)

        fused_out_reduced = (
            level_0_resized * levels_weight[:, 0:1, :, :]
            + level_1_resized * levels_weight[:, 1:2, :, :]
            + level_2_resized * levels_weight[:, 2:, :, :]
        )
        out = self.expand(fused_out_reduced)

        if self.vis:
            return out, levels_weight, fused_out_reduced.sum(dim=1)
        return out


class DWConv(Conv):
    """Depth-wise convolution."""

    def __init__(self, c1, c2, k=1, s=1, d=1, act=True):  # ch_in, ch_out, kernel, stride, dilation, activation
        """Initialize Depth-wise convolution with given parameters."""
        super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)


class ASFFHead(nn.Module):
    """YOLOv8/YOLO11-style Detect head with ASFF fusion in front of it. CSDN: Snu77"""

    dynamic = False  # force grid reconstruction
    export = False  # export mode
    format = None  # export format (read by _inference when export=True)
    end2end = False  # end2end
    max_det = 300  # max_det
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=(), multiplier=1, rfb=False):
        """Initialize the ASFF detection head with the number of classes and input channels (P3, P4, P5)."""
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        )
        self.cv3 = nn.ModuleList(
            nn.Sequential(
                nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
                nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
                nn.Conv2d(c3, self.nc, 1),
            )
            for x in ch
        )
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
        self.l0_fusion = ASFFV5(level=0, ch=ch, multiplier=multiplier, rfb=rfb)
        self.l1_fusion = ASFFV5(level=1, ch=ch, multiplier=multiplier, rfb=rfb)
        self.l2_fusion = ASFFV5(level=2, ch=ch, multiplier=multiplier, rfb=rfb)
        if self.end2end:
            self.one2one_cv2 = copy.deepcopy(self.cv2)
            self.one2one_cv3 = copy.deepcopy(self.cv3)

    def forward(self, x):
        """Fuse P3/P4/P5 with ASFF, then return predicted bounding boxes and class probabilities."""
        x1 = self.l0_fusion(x)  # fused P5
        x2 = self.l1_fusion(x)  # fused P4
        x3 = self.l2_fusion(x)  # fused P3
        x = [x3, x2, x1]  # back to (P3, P4, P5) order
        if self.end2end:
            return self.forward_end2end(x)
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:  # Training path
            return x
        y = self._inference(x)
        return y if self.export else (y, x)

    def forward_end2end(self, x):
        """
        Perform the end-to-end (YOLOv10-style one2many + one2one) forward pass.
        Args:
            x (tensor): Input tensor.
        Returns:
            (dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and one2one detections.
                If in training mode, returns a dictionary containing the outputs of one2many and one2one detections separately.
        """
        x_detach = [xi.detach() for xi in x]
        one2one = [
            torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
        ]
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:  # Training path
            return {"one2many": x, "one2one": one2one}
        y = self._inference(one2one)
        y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
        return y if self.export else (y, {"one2many": x, "one2one": one2one})

    def _inference(self, x):
        """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""
        # Inference path
        shape = x[0].shape  # BCHW
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape
        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
            box = x_cat[:, : self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4 :]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
        if self.export and self.format in {"tflite", "edgetpu"}:
            # Precompute normalization factor to increase numerical stability
            # See https://github.com/ultralytics/ultralytics/issues/7371
            grid_h = shape[2]
            grid_w = shape[3]
            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
            norm = self.strides / (self.stride[0] * grid_size)
            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
        else:
            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
        return torch.cat((dbox, cls.sigmoid()), 1)

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
        if self.end2end:
            for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride):  # from
                a[-1].bias.data[:] = 1.0  # box
                b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

    def decode_bboxes(self, bboxes, anchors):
        """Decode bounding boxes."""
        return dist2bbox(bboxes, anchors, xywh=not self.end2end, dim=1)

    @staticmethod
    def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
        """
        Post-processes YOLO model predictions.
        Args:
            preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
                format [x, y, w, h, class_probs].
            max_det (int): Maximum detections per image.
            nc (int, optional): Number of classes. Default: 80.
        Returns:
            (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
                dimension format [x, y, w, h, max_class_prob, class_index].
        """
        batch_size, anchors, _ = preds.shape  # i.e. shape(16,8400,84)
        boxes, scores = preds.split([4, nc], dim=-1)
        index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
        boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
        scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
        scores, index = scores.flatten(1).topk(min(max_det, anchors))
        i = torch.arange(batch_size)[..., None]  # batch indices
        return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)


if __name__ == "__main__":
    # Dummy P3/P4/P5 feature maps
    image1 = torch.rand(1, 64, 32, 32)
    image2 = torch.rand(1, 128, 16, 16)
    image3 = torch.rand(1, 256, 8, 8)
    image = [image1, image2, image3]
    channel = (64, 128, 256)

    # Model (training mode by default, so the raw per-level outputs are returned)
    model = ASFFHead(nc=80, ch=channel)
    out = model(image)
    print([o.shape for o in out])


4. Step-by-Step: Adding the ASFFHead Detection Head

4.1 Modification 1

First, create a new .py file in the 'ultralytics/nn' directory and paste the code above into it. The file name is up to you; I named mine ASFFHead.py.


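4.2 Modification 2

Second, create a new file named '__init__.py' in that directory (if you are using the files from the group, it already exists and you do not need to create it), then import our detection head in it as shown in the figure below.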


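4.3 Modification 3

Third, open the file 'ultralytics/nn/tasks.py' and import and register our module there (if you are using the files from the group, this is already done; skip straight to step 4).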


4.4 Modification 4

Fourth, in 'ultralytics/nn/tasks.py', find the code below and add the detection head to it. A quick way to find it is to press Ctrl+F and search for Detect.
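
Why this registration matters: in recent Ultralytics versions, `parse_model` has a branch for Detect-style heads that appends the output channels of each `from` layer to the head's arguments (`args.append([ch[x] for x in f])`), and that list becomes the `ch` argument of `ASFFHead.__init__(nc, ch)`. The function below is a hypothetical, heavily simplified stand-in for that branch (`build_head_args` and `HEADS_WITH_CH` are illustrative names, not Ultralytics code), and the layer channels assume the YOLO11n scale.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical registry: heads that should receive a per-input channel list.
HEADS_WITH_CH = {"Detect", "ASFFHead"}


def build_head_args(spec: Sequence, out_channels: Dict[int, int]) -> Tuple[str, List]:
    """Simplified sketch of the Detect branch of parse_model (not the Ultralytics source)."""
    f, _repeats, module, args = spec
    args = list(args)
    if module in HEADS_WITH_CH:
        # e.g. from=[16, 19, 22] -> ch=[C_P3, C_P4, C_P5]
        args.append([out_channels[i] for i in f])
    return module, args


# Assumed YOLO11n output channels of layers 16/19/22.
channels = {16: 64, 19: 128, 22: 256}
print(build_head_args([[16, 19, 22], 1, "ASFFHead", [80]], channels))
print(build_head_args([[16, 19, 22], 1, "MyHead", [80]], channels))
```

A head that is not registered receives no channel list, so ASFFHead would be built with the default `ch=()` and fail.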


4.5 Modification 5

Modify the code as marked by the red box.


4.6 Modification 6

Same as above.


4.7 Modification 7

This step is a little different: we need to add one line of code.

else:
    return 'detect'

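Why is this step different? Here m is the name of the last head module, converted to lower case during execution (for our head it becomes 'asffhead'), so it matches none of the existing task branches; a plain else that returns 'detect' is the simplest fix. I will provide separate modification steps when later tutorials cover segmentation or other tasks.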
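
To see why the fallback is needed, here is a simplified, hypothetical re-creation of the YAML branch of Ultralytics' task guessing (`guess_model_task` in tasks.py). The exact set of branches differs between Ultralytics versions, so treat `cfg2task_sketch` as an illustration rather than the library source; the final else is the line added in this step.

```python
def cfg2task_sketch(cfg: dict) -> str:
    """Simplified sketch of YAML-based task guessing (not the real Ultralytics code)."""
    m = cfg["head"][-1][-2].lower()  # name of the last head module, lower-cased
    if m in {"classify", "classifier", "cls", "fc"}:
        return "classify"
    if "detect" in m:
        return "detect"
    if m == "segment":
        return "segment"
    if m == "pose":
        return "pose"
    if m == "obb":
        return "obb"
    else:
        # Added fallback: any other head (e.g. "asffhead") is treated as detection.
        return "detect"


print(cfg2task_sketch({"head": [[[16, 19, 22], 1, "ASFFHead", ["nc"]]]}))
```

Without the else, "asffhead" would match no branch and the task could not be inferred from the YAML.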


4.8 Modification 8

Same as above.


That completes the modifications; you can now copy the yaml file below and run it.


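5. YAML File for the ASFFHead Detection Head

The yaml file for this head also needs a small change compared with the standard one: the last line of the head uses ASFFHead instead of Detect, as shown below.

Model summary for this version: YOLO11-ASFFHead summary: 386 layers, 3,962,549 parameters, 3,962,533 gradients, 8.6 GFLOPs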

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)

  - [[16, 19, 22], 1, ASFFHead, [nc]] # Detect(P3, P4, P5)
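
The `ch` tuple that ASFFHead receives depends on the model scale. The sketch below computes it for layers 16, 19 and 22, assuming channel arguments are scaled the way recent Ultralytics `parse_model` does it, i.e. `make_divisible(min(c2, max_channels) * width, 8)`; `make_divisible` here is a re-implementation for illustration and `head_channels` is a hypothetical helper.

```python
import math

# [depth, width, max_channels] per scale, copied from the YAML above.
SCALES = {
    "n": (0.50, 0.25, 1024),
    "s": (0.50, 0.50, 1024),
    "m": (0.50, 1.00, 512),
    "l": (1.00, 1.00, 512),
    "x": (1.00, 1.50, 512),
}
# Nominal output channels of layers 16, 19, 22 (P3, P4, P5) in the YAML.
NOMINAL = (256, 512, 1024)


def make_divisible(x: float, divisor: int = 8) -> int:
    """Round x up to the nearest multiple of divisor."""
    return math.ceil(x / divisor) * divisor


def head_channels(scale: str) -> tuple:
    """Channels of the three head inputs for a given scale (assumed scaling rule)."""
    _, width, max_ch = SCALES[scale]
    return tuple(make_divisible(min(c, max_ch) * width) for c in NOMINAL)


for s in SCALES:
    print(s, head_channels(s))
```

For scale n this gives (64, 128, 256), the same channels used in the `__main__` test of Section 3. For m, l and x, P4 and P5 end up with the same width because of the 512 cap; ASFFV5 reads each level's channel count from `ch`, so this case is handled.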


6. Successful Run Log

Finally, here is a screenshot of a successful run.


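7. Summary

That concludes this article. I recommend my YOLOv11 effective-improvements column: it is newly launched and has an average quality score of 98. I will keep reproducing papers from the latest top conferences and will also add older improvement mechanisms. If this article helped you, subscribe to the column and follow it for future updates.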