学习资源站

YOLOv11改进-检测头篇-利用AFPN辅助YOLOv11检测头进行特征融合识别(全网独家创新)

一、本文介绍

本文给大家带来的最新改进机制是利用 AFPN(渐近特征金字塔网络) 来优化检测头, AFPN的核心思想 是通过引入一种渐近的 特征融合 策略,将底层、高层和顶层的特征逐渐整合到目标检测过程中。这种渐近融合方式有助于减小不同层次特征之间的语义差距,提高特征融合效果,使得检测 模型 能更好地适应不同层次的语义信息。本文在AFPN的结构基础上,为了适配YOLOv11改进AFPN结构,同时将AFPN融合到YOLOv11中 (因为AFPN需要四个检测头,我们只有三个,下一篇文章我会出YOLOv11适配AFPN增加小目标检测头) 实现有效涨点。



二、AFPN基本框架原理

论文地址: 官方论文地址

代码地址: 官方代码地址


2.1 AFPN的基本原理

AFPN的核心思想是通过引入一种渐近的特征融合策略,将底层、高层和顶层的特征逐渐整合到目标检测过程中。这种渐近融合方式有助于减小不同层次特征之间的语义差距,提高特征融合效果,使得检测模型能更好地适应不同层次的语义信息。

主要改进机制:
1. 底层特征融合: AFPN通过引入底层特征的逐步融合,首先融合底层特征,接着深层特征,最后整合顶层特征。这种层级融合的方式有助于更好地利用不同层次的语义信息,提高检测 性能

2. 自适应空间融合: 引入自适应空间融合机制(ASFF),在多级特征融合过程中引入变化的空间权重,加强关键级别的重要性,同时抑制来自不同对象的矛盾信息的影响。这有助于提高检测性能,尤其在处理矛盾信息时更为有效。

3. 底层特征对齐: AFPN采用渐近融合的思想,使得不同层次的特征在融合过程中逐渐接近,减小它们之间的语义差距。通过底层特征的逐步整合,提高了特征融合的效果,使得模型更能理解和利用不同层次的信息。

个人总结: AFPN的灵感就像是搭积木一样,它不是一下子把所有的积木都放到一起,而是逐步地将不同层次的积木慢慢整合在一起。这样一来,我们可以更好地理解和利用每一层次的积木,从而构建一个更牢固的目标检测系统。同时,引入了一种智能的机制,能够根据不同情况调整注意力,更好地处理矛盾信息。

上面上AFPN的网络结构,可以看出从Backbone中提取出特征之后,将特征输入到AFPN中进行处理,然后它可以获得不同层级的特征进行融合,这也是它的主要思想质疑,同时将结果输入到检测头中进行预测。

(需要注意的是本文砍掉了最下面那一条线适应YOLOv8因为我们是三个检测头,下一篇文章我会出增加小目标检测头的然后四个头的yolov8改进,从而适应AFPN的结构)。


三、AFPNHead完整代码

核心代码的使用方式看章节四!

  1. import copy
  2. import math
  3. from collections import OrderedDict
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from ultralytics.nn.modules import DFL
  8. from ultralytics.nn.modules.conv import Conv
  9. from ultralytics.utils.tal import dist2bbox, make_anchors
  10. __all__ =['AFPNHead']
  11. def BasicConv(filter_in, filter_out, kernel_size, stride=1, pad=None):
  12. if not pad:
  13. pad = (kernel_size - 1) // 2 if kernel_size else 0
  14. else:
  15. pad = pad
  16. return nn.Sequential(OrderedDict([
  17. ("conv", nn.Conv2d(filter_in, filter_out, kernel_size=kernel_size, stride=stride, padding=pad, bias=False)),
  18. ("bn", nn.BatchNorm2d(filter_out)),
  19. ("relu", nn.ReLU(inplace=True)),
  20. ]))
  21. class BasicBlock(nn.Module):
  22. expansion = 1
  23. # CSDN Snu77
  24. def __init__(self, filter_in, filter_out):
  25. super(BasicBlock, self).__init__()
  26. self.conv1 = nn.Conv2d(filter_in, filter_out, 3, padding=1)
  27. self.bn1 = nn.BatchNorm2d(filter_out, momentum=0.1)
  28. self.relu = nn.ReLU(inplace=True)
  29. self.conv2 = nn.Conv2d(filter_out, filter_out, 3, padding=1)
  30. self.bn2 = nn.BatchNorm2d(filter_out, momentum=0.1)
  31. def forward(self, x):
  32. residual = x
  33. out = self.conv1(x)
  34. out = self.bn1(out)
  35. out = self.relu(out)
  36. out = self.conv2(out)
  37. out = self.bn2(out)
  38. out += residual
  39. out = self.relu(out)
  40. return out
  41. class Upsample(nn.Module):
  42. # CSDN Snu77
  43. def __init__(self, in_channels, out_channels, scale_factor=2):
  44. super(Upsample, self).__init__()
  45. self.upsample = nn.Sequential(
  46. BasicConv(in_channels, out_channels, 1),
  47. nn.Upsample(scale_factor=scale_factor, mode='bilinear')
  48. )
  49. def forward(self, x):
  50. x = self.upsample(x)
  51. return x
  52. class Downsample_x2(nn.Module):
  53. # CSDN Snu77
  54. def __init__(self, in_channels, out_channels):
  55. super(Downsample_x2, self).__init__()
  56. self.downsample = nn.Sequential(
  57. BasicConv(in_channels, out_channels, 2, 2, 0)
  58. )
  59. def forward(self, x, ):
  60. x = self.downsample(x)
  61. return x
  62. class Downsample_x4(nn.Module):
  63. def __init__(self, in_channels, out_channels):
  64. super(Downsample_x4, self).__init__()
  65. self.downsample = nn.Sequential(
  66. BasicConv(in_channels, out_channels, 4, 4, 0)
  67. )
  68. def forward(self, x, ):
  69. x = self.downsample(x)
  70. return x
  71. class Downsample_x8(nn.Module):
  72. def __init__(self, in_channels, out_channels):
  73. super(Downsample_x8, self).__init__()
  74. self.downsample = nn.Sequential(
  75. BasicConv(in_channels, out_channels, 8, 8, 0)
  76. )
  77. def forward(self, x, ):
  78. x = self.downsample(x)
  79. return x
  80. class ASFF_2(nn.Module):
  81. def __init__(self, inter_dim=512):
  82. super(ASFF_2, self).__init__()
  83. self.inter_dim = inter_dim
  84. compress_c = 8
  85. self.weight_level_1 = BasicConv(self.inter_dim, compress_c, 1, 1)
  86. self.weight_level_2 = BasicConv(self.inter_dim, compress_c, 1, 1)
  87. self.weight_levels = nn.Conv2d(compress_c * 2, 2, kernel_size=1, stride=1, padding=0)
  88. self.conv = BasicConv(self.inter_dim, self.inter_dim, 3, 1)
  89. def forward(self, input1, input2):
  90. level_1_weight_v = self.weight_level_1(input1)
  91. level_2_weight_v = self.weight_level_2(input2)
  92. levels_weight_v = torch.cat((level_1_weight_v, level_2_weight_v), 1)
  93. levels_weight = self.weight_levels(levels_weight_v)
  94. levels_weight = F.softmax(levels_weight, dim=1)
  95. fused_out_reduced = input1 * levels_weight[:, 0:1, :, :] + \
  96. input2 * levels_weight[:, 1:2, :, :]
  97. out = self.conv(fused_out_reduced)
  98. return out
  99. class ASFF_3(nn.Module):
  100. def __init__(self, inter_dim=512):
  101. super(ASFF_3, self).__init__()
  102. self.inter_dim = inter_dim
  103. compress_c = 8
  104. self.weight_level_1 = BasicConv(self.inter_dim, compress_c, 1, 1)
  105. self.weight_level_2 = BasicConv(self.inter_dim, compress_c, 1, 1)
  106. self.weight_level_3 = BasicConv(self.inter_dim, compress_c, 1, 1)
  107. self.weight_levels = nn.Conv2d(compress_c * 3, 3, kernel_size=1, stride=1, padding=0)
  108. self.conv = BasicConv(self.inter_dim, self.inter_dim, 3, 1)
  109. def forward(self, input1, input2, input3):
  110. level_1_weight_v = self.weight_level_1(input1)
  111. level_2_weight_v = self.weight_level_2(input2)
  112. level_3_weight_v = self.weight_level_3(input3)
  113. levels_weight_v = torch.cat((level_1_weight_v, level_2_weight_v, level_3_weight_v), 1)
  114. levels_weight = self.weight_levels(levels_weight_v)
  115. levels_weight = F.softmax(levels_weight, dim=1)
  116. fused_out_reduced = input1 * levels_weight[:, 0:1, :, :] + \
  117. input2 * levels_weight[:, 1:2, :, :] + \
  118. input3 * levels_weight[:, 2:, :, :]
  119. out = self.conv(fused_out_reduced)
  120. return out
  121. class ASFF_4(nn.Module):
  122. def __init__(self, inter_dim=512):
  123. super(ASFF_4, self).__init__()
  124. self.inter_dim = inter_dim
  125. compress_c = 8
  126. self.weight_level_0 = BasicConv(self.inter_dim, compress_c, 1, 1)
  127. self.weight_level_1 = BasicConv(self.inter_dim, compress_c, 1, 1)
  128. self.weight_level_2 = BasicConv(self.inter_dim, compress_c, 1, 1)
  129. self.weight_levels = nn.Conv2d(compress_c * 3, 3, kernel_size=1, stride=1, padding=0)
  130. self.conv = BasicConv(self.inter_dim, self.inter_dim, 3, 1)
  131. def forward(self, input0, input1, input2):
  132. level_0_weight_v = self.weight_level_0(input0)
  133. level_1_weight_v = self.weight_level_1(input1)
  134. level_2_weight_v = self.weight_level_2(input2)
  135. levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v), 1)
  136. levels_weight = self.weight_levels(levels_weight_v)
  137. levels_weight = F.softmax(levels_weight, dim=1)
  138. fused_out_reduced = input0 * levels_weight[:, 0:1, :, :] + \
  139. input1 * levels_weight[:, 1:2, :, :] + \
  140. input2 * levels_weight[:, 2:3, :, :]
  141. out = self.conv(fused_out_reduced)
  142. return out
  143. class BlockBody(nn.Module):
  144. def __init__(self, channels=[64, 128, 256, 512]):
  145. super(BlockBody, self).__init__()
  146. self.blocks_scalezero1 = nn.Sequential(
  147. BasicConv(channels[0], channels[0], 1),
  148. )
  149. self.blocks_scaleone1 = nn.Sequential(
  150. BasicConv(channels[1], channels[1], 1),
  151. )
  152. self.blocks_scaletwo1 = nn.Sequential(
  153. BasicConv(channels[2], channels[2], 1),
  154. )
  155. self.downsample_scalezero1_2 = Downsample_x2(channels[0], channels[1])
  156. self.upsample_scaleone1_2 = Upsample(channels[1], channels[0], scale_factor=2)
  157. self.asff_scalezero1 = ASFF_2(inter_dim=channels[0])
  158. self.asff_scaleone1 = ASFF_2(inter_dim=channels[1])
  159. self.blocks_scalezero2 = nn.Sequential(
  160. BasicBlock(channels[0], channels[0]),
  161. BasicBlock(channels[0], channels[0]),
  162. BasicBlock(channels[0], channels[0]),
  163. BasicBlock(channels[0], channels[0]),
  164. )
  165. self.blocks_scaleone2 = nn.Sequential(
  166. BasicBlock(channels[1], channels[1]),
  167. BasicBlock(channels[1], channels[1]),
  168. BasicBlock(channels[1], channels[1]),
  169. BasicBlock(channels[1], channels[1]),
  170. )
  171. self.downsample_scalezero2_2 = Downsample_x2(channels[0], channels[1])
  172. self.downsample_scalezero2_4 = Downsample_x4(channels[0], channels[2])
  173. self.downsample_scaleone2_2 = Downsample_x2(channels[1], channels[2])
  174. self.upsample_scaleone2_2 = Upsample(channels[1], channels[0], scale_factor=2)
  175. self.upsample_scaletwo2_2 = Upsample(channels[2], channels[1], scale_factor=2)
  176. self.upsample_scaletwo2_4 = Upsample(channels[2], channels[0], scale_factor=4)
  177. self.asff_scalezero2 = ASFF_3(inter_dim=channels[0])
  178. self.asff_scaleone2 = ASFF_3(inter_dim=channels[1])
  179. self.asff_scaletwo2 = ASFF_3(inter_dim=channels[2])
  180. self.blocks_scalezero3 = nn.Sequential(
  181. BasicBlock(channels[0], channels[0]),
  182. BasicBlock(channels[0], channels[0]),
  183. BasicBlock(channels[0], channels[0]),
  184. BasicBlock(channels[0], channels[0]),
  185. )
  186. self.blocks_scaleone3 = nn.Sequential(
  187. BasicBlock(channels[1], channels[1]),
  188. BasicBlock(channels[1], channels[1]),
  189. BasicBlock(channels[1], channels[1]),
  190. BasicBlock(channels[1], channels[1]),
  191. )
  192. self.blocks_scaletwo3 = nn.Sequential(
  193. BasicBlock(channels[2], channels[2]),
  194. BasicBlock(channels[2], channels[2]),
  195. BasicBlock(channels[2], channels[2]),
  196. BasicBlock(channels[2], channels[2]),
  197. )
  198. self.downsample_scalezero3_2 = Downsample_x2(channels[0], channels[1])
  199. self.downsample_scalezero3_4 = Downsample_x4(channels[0], channels[2])
  200. self.upsample_scaleone3_2 = Upsample(channels[1], channels[0], scale_factor=2)
  201. self.downsample_scaleone3_2 = Downsample_x2(channels[1], channels[2])
  202. self.upsample_scaletwo3_4 = Upsample(channels[2], channels[0], scale_factor=4)
  203. self.upsample_scaletwo3_2 = Upsample(channels[2], channels[1], scale_factor=2)
  204. self.asff_scalezero3 = ASFF_4(inter_dim=channels[0])
  205. self.asff_scaleone3 = ASFF_4(inter_dim=channels[1])
  206. self.asff_scaletwo3 = ASFF_4(inter_dim=channels[2])
  207. self.blocks_scalezero4 = nn.Sequential(
  208. BasicBlock(channels[0], channels[0]),
  209. BasicBlock(channels[0], channels[0]),
  210. BasicBlock(channels[0], channels[0]),
  211. BasicBlock(channels[0], channels[0]),
  212. )
  213. self.blocks_scaleone4 = nn.Sequential(
  214. BasicBlock(channels[1], channels[1]),
  215. BasicBlock(channels[1], channels[1]),
  216. BasicBlock(channels[1], channels[1]),
  217. BasicBlock(channels[1], channels[1]),
  218. )
  219. self.blocks_scaletwo4 = nn.Sequential(
  220. BasicBlock(channels[2], channels[2]),
  221. BasicBlock(channels[2], channels[2]),
  222. BasicBlock(channels[2], channels[2]),
  223. BasicBlock(channels[2], channels[2]),
  224. )
  225. def forward(self, x):
  226. x0, x1, x2 = x
  227. x0 = self.blocks_scalezero1(x0)
  228. x1 = self.blocks_scaleone1(x1)
  229. x2 = self.blocks_scaletwo1(x2)
  230. scalezero = self.asff_scalezero1(x0, self.upsample_scaleone1_2(x1))
  231. scaleone = self.asff_scaleone1(self.downsample_scalezero1_2(x0), x1)
  232. x0 = self.blocks_scalezero2(scalezero)
  233. x1 = self.blocks_scaleone2(scaleone)
  234. scalezero = self.asff_scalezero2(x0, self.upsample_scaleone2_2(x1), self.upsample_scaletwo2_4(x2))
  235. scaleone = self.asff_scaleone2(self.downsample_scalezero2_2(x0), x1, self.upsample_scaletwo2_2(x2))
  236. scaletwo = self.asff_scaletwo2(self.downsample_scalezero2_4(x0), self.downsample_scaleone2_2(x1), x2)
  237. x0 = self.blocks_scalezero3(scalezero)
  238. x1 = self.blocks_scaleone3(scaleone)
  239. x2 = self.blocks_scaletwo3(scaletwo)
  240. scalezero = self.asff_scalezero3(x0, self.upsample_scaleone3_2(x1), self.upsample_scaletwo3_4(x2))
  241. scaleone = self.asff_scaleone3(self.downsample_scalezero3_2(x0), x1, self.upsample_scaletwo3_2(x2))
  242. scaletwo = self.asff_scaletwo3(self.downsample_scalezero3_4(x0), self.downsample_scaleone3_2(x1), x2)
  243. scalezero = self.blocks_scalezero4(scalezero)
  244. scaleone = self.blocks_scaleone4(scaleone)
  245. scaletwo = self.blocks_scaletwo4(scaletwo)
  246. return scalezero, scaleone, scaletwo
  247. class AFPN(nn.Module):
  248. # CSDN Snu77
  249. def __init__(self,
  250. in_channels=[256, 512, 1024, 2048],
  251. out_channels=128):
  252. super(AFPN, self).__init__()
  253. self.fp16_enabled = False
  254. self.conv0 = BasicConv(in_channels[0], in_channels[0] // 8, 1)
  255. self.conv1 = BasicConv(in_channels[1], in_channels[1] // 8, 1)
  256. self.conv2 = BasicConv(in_channels[2], in_channels[2] // 8, 1)
  257. # self.conv3 = BasicConv(in_channels[3], in_channels[3] // 8, 1)
  258. self.body = nn.Sequential(
  259. BlockBody([in_channels[0] // 8, in_channels[1] // 8, in_channels[2] // 8])
  260. )
  261. self.conv00 = BasicConv(in_channels[0] // 8, out_channels, 1)
  262. self.conv11 = BasicConv(in_channels[1] // 8, out_channels, 1)
  263. self.conv22 = BasicConv(in_channels[2] // 8, out_channels, 1)
  264. # self.conv33 = BasicConv(in_channels[3] // 8, out_channels, 1)
  265. # init weight
  266. for m in self.modules():
  267. if isinstance(m, nn.Conv2d):
  268. nn.init.xavier_normal_(m.weight, gain=0.02)
  269. elif isinstance(m, nn.BatchNorm2d):
  270. torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
  271. torch.nn.init.constant_(m.bias.data, 0.0)
  272. def forward(self, x):
  273. x0, x1, x2 = x
  274. x0 = self.conv0(x0)
  275. x1 = self.conv1(x1)
  276. x2 = self.conv2(x2)
  277. # x3 = self.conv3(x3)
  278. out0, out1, out2 = self.body([x0, x1, x2])
  279. out0 = self.conv00(out0)
  280. out1 = self.conv11(out1)
  281. out2 = self.conv22(out2)
  282. return out0, out1, out2
  283. class DWConv(Conv):
  284. """Depth-wise convolution."""
  285. def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation
  286. """Initialize Depth-wise convolution with given parameters."""
  287. super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
  288. class AFPNHead(nn.Module):
  289. """YOLOv8 Detect head for detection models. CSDNSnu77"""
  290. dynamic = False # force grid reconstruction
  291. export = False # export mode
  292. end2end = False # end2end
  293. max_det = 300 # max_det
  294. shape = None
  295. anchors = torch.empty(0) # init
  296. strides = torch.empty(0) # init
  297. def __init__(self, nc=80, channel=256, ch=()):
  298. """Initializes the YOLOv8 detection layer with specified number of classes and channels."""
  299. super().__init__()
  300. self.nc = nc # number of classes
  301. self.nl = len(ch) # number of detection layers
  302. self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
  303. self.no = nc + self.reg_max * 4 # number of outputs per anchor
  304. self.stride = torch.zeros(self.nl) # strides computed during build
  305. c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels
  306. self.cv2 = nn.ModuleList(
  307. nn.Sequential(Conv(channel, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
  308. )
  309. self.cv3 = nn.ModuleList(
  310. nn.Sequential(
  311. nn.Sequential(DWConv(channel, channel, 3), Conv(channel, c3, 1)),
  312. nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
  313. nn.Conv2d(c3, self.nc, 1),
  314. )
  315. for x in ch
  316. )
  317. self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
  318. self.AFPN = AFPN(ch, channel)
  319. if self.end2end:
  320. self.one2one_cv2 = copy.deepcopy(self.cv2)
  321. self.one2one_cv3 = copy.deepcopy(self.cv3)
  322. def forward(self, x):
  323. x = list(self.AFPN(x))
  324. """Concatenates and returns predicted bounding boxes and class probabilities."""
  325. if self.end2end:
  326. return self.forward_end2end(x)
  327. for i in range(self.nl):
  328. x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
  329. if self.training: # Training path
  330. return x
  331. y = self._inference(x)
  332. return y if self.export else (y, x)
  333. def forward_end2end(self, x):
  334. """
  335. Performs forward pass of the v10Detect module.
  336. Args:
  337. x (tensor): Input tensor.
  338. Returns:
  339. (dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and one2one detections.
  340. If in training mode, returns a dictionary containing the outputs of one2many and one2one detections separately.
  341. """
  342. x_detach = [xi.detach() for xi in x]
  343. one2one = [
  344. torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
  345. ]
  346. for i in range(self.nl):
  347. x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
  348. if self.training: # Training path
  349. return {"one2many": x, "one2one": one2one}
  350. y = self._inference(one2one)
  351. y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
  352. return y if self.export else (y, {"one2many": x, "one2one": one2one})
  353. def _inference(self, x):
  354. """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""
  355. # Inference path
  356. shape = x[0].shape # BCHW
  357. x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
  358. if self.dynamic or self.shape != shape:
  359. self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
  360. self.shape = shape
  361. if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops
  362. box = x_cat[:, : self.reg_max * 4]
  363. cls = x_cat[:, self.reg_max * 4 :]
  364. else:
  365. box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
  366. if self.export and self.format in {"tflite", "edgetpu"}:
  367. # Precompute normalization factor to increase numerical stability
  368. # See https://github.com/ultralytics/ultralytics/issues/7371
  369. grid_h = shape[2]
  370. grid_w = shape[3]
  371. grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
  372. norm = self.strides / (self.stride[0] * grid_size)
  373. dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
  374. else:
  375. dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
  376. return torch.cat((dbox, cls.sigmoid()), 1)
  377. def bias_init(self):
  378. """Initialize Detect() biases, WARNING: requires stride availability."""
  379. m = self # self.model[-1] # Detect() module
  380. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
  381. # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
  382. for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
  383. a[-1].bias.data[:] = 1.0 # box
  384. b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
  385. if self.end2end:
  386. for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride): # from
  387. a[-1].bias.data[:] = 1.0 # box
  388. b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
  389. def decode_bboxes(self, bboxes, anchors):
  390. """Decode bounding boxes."""
  391. return dist2bbox(bboxes, anchors, xywh=not self.end2end, dim=1)
  392. @staticmethod
  393. def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
  394. """
  395. Post-processes YOLO model predictions.
  396. Args:
  397. preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
  398. format [x, y, w, h, class_probs].
  399. max_det (int): Maximum detections per image.
  400. nc (int, optional): Number of classes. Default: 80.
  401. Returns:
  402. (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
  403. dimension format [x, y, w, h, max_class_prob, class_index].
  404. """
  405. batch_size, anchors, _ = preds.shape # i.e. shape(16,8400,84)
  406. boxes, scores = preds.split([4, nc], dim=-1)
  407. index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
  408. boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
  409. scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
  410. scores, index = scores.flatten(1).topk(min(max_det, anchors))
  411. i = torch.arange(batch_size)[..., None] # batch indices
  412. return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)
  413. if __name__ == "__main__":
  414. # Generating Sample image
  415. image1 = (1, 64, 32, 32)
  416. image2 = (1, 128, 16, 16)
  417. image3 = (1, 256, 8, 8)
  418. image1 = torch.rand(image1)
  419. image2 = torch.rand(image2)
  420. image3 = torch.rand(image3)
  421. image = [image1, image2, image3]
  422. channel = (64, 128, 256)
  423. # Model
  424. mobilenet_v1 = AFPNHead(nc=80, channel=128, ch=channel)
  425. out = mobilenet_v1(image)
  426. print(out)


四、手把手教你添加AFPNHead检测头

4.1 修改一

首先我们将上面的代码复制粘贴到' ultralytics /nn' 目录下新建一个py文件复制粘贴进去,具体名字自己来定,我这里起名为AFPNHead.py。


4.2 修改二

第二步我们在该目录下创建一个新的py文件名字为'__init__.py'( 用群内的文件的话已经有了无需新建) ,然后在其内部导入我们的检测头如下图所示。

​​


4.3 修改三

第三步我门中到如下文件'ultralytics/nn/tasks.py'进行导入和注册我们的模块( 用群内的文件的话已经有了无需重新导入直接开始第四步即可)

​​


4.4 修改四

第四步我门找到如下文件'ultralytics/nn/tasks.py,找到如下的代码进行将检测头添加进去,这里给大家推荐个快速搜索的方法用ctrl+f然后搜索Detect然后就能快速查找了。

​​​


4.5 修改五


4.6 修改六

同理

​​​


4.7 修改七

这里有一些不一样,我们需要加一行代码

  1. else:
  2. return 'detect'

为啥呢不一样,因为这里的m在代码执行过程中会将你的代码自动转换为小写,所以直接else方便一点,以后出现一些其它分割或者其它的教程的时候在提供其它的修改教程。

​​​


4.8 修改八

同理.

​​​


到此就修改完成了,大家可以复制下面的yaml文件运行。


五、AFPNHead检测头的yaml文件

此版本的训练信息:YOLO11-AFPNHead summary: 788 layers, 2,707,033 parameters, 2,707,017 gradients, 6.8 GFLOPs

  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. # YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
  3. # Parameters
  4. nc: 80 # number of classes
  5. scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  6. # [depth, width, max_channels]
  7. n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  8. s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  9. m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  10. l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  11. x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs
  12. # YOLO11n backbone
  13. backbone:
  14. # [from, repeats, module, args]
  15. - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  16. - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  17. - [-1, 2, C3k2, [256, False, 0.25]]
  18. - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  19. - [-1, 2, C3k2, [512, False, 0.25]]
  20. - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  21. - [-1, 2, C3k2, [512, True]]
  22. - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  23. - [-1, 2, C3k2, [1024, True]]
  24. - [-1, 1, SPPF, [1024, 5]] # 9
  25. - [-1, 2, C2PSA, [1024]] # 10
  26. # YOLO11n head
  27. head:
  28. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  29. - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  30. - [-1, 2, C3k2, [512, False]] # 13
  31. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  32. - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  33. - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  34. - [-1, 1, Conv, [256, 3, 2]]
  35. - [[-1, 13], 1, Concat, [1]] # cat head P4
  36. - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  37. - [-1, 1, Conv, [512, 3, 2]]
  38. - [[-1, 10], 1, Concat, [1]] # cat head P5
  39. - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  40. # 64 可以更换为 128 256
  41. - [[16, 19, 22], 1, AFPNHead, [nc, 64]] # Detect(P3, P4, P5)


六、完美运行记录

最后提供一下完美运行的图片。


七、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv11改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~