
YOLOv11 Improvements, Detection Head Series: Improving the YOLOv11 Detection Head with the Asymptotic Feature Pyramid Network (AFPN), Fused with a P2 Level (Leaving Small Objects Nowhere to Hide; Exclusive First Release)

1. Introduction

This article presents a new improvement mechanism that uses the recently proposed AFPN (Asymptotic Feature Pyramid Network) to optimize the detection head. The core idea of AFPN is a progressive (asymptotic) feature-fusion strategy that gradually integrates low-level, high-level, and top-level features into the object-detection process. This progressive fusion narrows the semantic gap between features at different levels, improves fusion quality, and lets the detection model adapt better to semantic information at every level. I previously promised a four-head version of Detect_FPN; this article delivers that detection head, using it for a strong accuracy boost that leaves small objects nowhere to hide.

You are also welcome to subscribe to this column: it is updated with 3-5 new mechanisms every week, and subscribers get all of my modified files plus access to the discussion group.



2. AFPN: Framework and Principles

Paper: official paper link

Code: official code link


The core idea of AFPN is a progressive feature-fusion strategy that gradually integrates low-level, high-level, and top-level features into the object-detection process. This progressive fusion narrows the semantic gap between features at different levels, improves fusion quality, and lets the detection model adapt better to semantic information at every level.

Main improvement mechanisms:

1. Low-level feature fusion: AFPN fuses features step by step, starting with the low-level features, then bringing in the deeper features, and finally integrating the top-level features. This hierarchical fusion makes better use of the semantic information at each level and improves detection performance.

2. Adaptive spatial fusion: AFPN introduces an adaptive spatial fusion mechanism (ASFF) that applies spatially varying weights during multi-level fusion, emphasizing the most informative level at each location while suppressing contradictory information coming from different objects. This improves detection, especially where conflicting information appears (a minimal sketch follows this list).

3. Low-level feature alignment: with progressive fusion, features from different levels move gradually closer together during fusion, shrinking the semantic gap between them. The step-by-step integration of low-level features improves fusion quality and helps the model understand and exploit information at every level.
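To make the adaptive spatial fusion in point 2 concrete, here is a minimal standalone sketch (my own illustration, not the official code; the full ASFF_2/ASFF_3/ASFF_4 modules appear in Section 3). Two same-resolution feature maps are fused with per-pixel softmax weights:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyASFF2(nn.Module):
    """Illustrative two-input ASFF: per-pixel softmax weights decide, at every
    spatial location, how much each input level contributes."""
    def __init__(self, channels, compress_c=8):
        super().__init__()
        self.w1 = nn.Conv2d(channels, compress_c, 1)    # compress input 1
        self.w2 = nn.Conv2d(channels, compress_c, 1)    # compress input 2
        self.weights = nn.Conv2d(compress_c * 2, 2, 1)  # 2 fusion logits per pixel

    def forward(self, x1, x2):
        w = self.weights(torch.cat((self.w1(x1), self.w2(x2)), dim=1))
        w = F.softmax(w, dim=1)  # weights sum to 1 across the two inputs at each pixel
        return x1 * w[:, 0:1] + x2 * w[:, 1:2]

x1, x2 = torch.randn(1, 64, 40, 40), torch.randn(1, 64, 40, 40)
print(TinyASFF2(64)(x1, x2).shape)  # torch.Size([1, 64, 40, 40])
```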

Personal summary: AFPN works like building with blocks. Instead of stacking everything at once, it integrates the different levels step by step, so the information at each level is better understood and used, yielding a more robust detector. On top of that, an adaptive mechanism adjusts attention across levels as needed, handling contradictory information more gracefully.

Above is the AFPN network structure. Features extracted by the backbone are fed into AFPN, which fuses features across the different levels (its main idea), and the fused results are then passed to the detection head for prediction.
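To see what "progressive" means in practice, the toy sketch below walks the same three-stage schedule as BlockBody.forward in Section 3, but with plain averaging standing in for the learned ASFF weights and uniform channel counts for simplicity (all names and numbers here are illustrative):

```python
import torch
import torch.nn.functional as F

# Four pyramid levels for a 640x640 input; p2 is the highest resolution, p5 the top.
p2, p3, p4, p5 = (torch.randn(1, 8, s, s) for s in (160, 80, 40, 20))

def up(x, r):    # stands in for AFPN's conv + bilinear Upsample modules
    return F.interpolate(x, scale_factor=r, mode="bilinear", align_corners=False)

def down(x, r):  # stands in for AFPN's strided-conv Downsample modules
    return F.avg_pool2d(x, r)

# Stage 1: only the two adjacent low levels are fused.
p2, p3 = (p2 + up(p3, 2)) / 2, (down(p2, 2) + p3) / 2
# Stage 2: the third level joins the fusion.
p2 = (p2 + up(p3, 2) + up(p4, 4)) / 3
# Stage 3: the top level is integrated last, once the semantic gap has narrowed.
p2 = (p2 + up(p3, 2) + up(p4, 4) + up(p5, 8)) / 4
print(p2.shape)  # torch.Size([1, 8, 160, 160])
```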


3. Complete AFPN4Head Code

See Section 4 for how to use it!

```python
import copy
import math
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from ultralytics.nn.modules import DFL
from ultralytics.nn.modules.conv import Conv
from ultralytics.utils.tal import dist2bbox, make_anchors

__all__ = ['AFPN4Head']


def BasicConv(filter_in, filter_out, kernel_size, stride=1, pad=None):
    """Conv2d + BatchNorm2d + ReLU, with 'same' padding unless pad is given."""
    if not pad:
        pad = (kernel_size - 1) // 2 if kernel_size else 0
    return nn.Sequential(OrderedDict([
        ("conv", nn.Conv2d(filter_in, filter_out, kernel_size=kernel_size, stride=stride, padding=pad, bias=False)),
        ("bn", nn.BatchNorm2d(filter_out)),
        ("relu", nn.ReLU(inplace=True)),
    ]))


class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions (filter_in must equal filter_out)."""

    expansion = 1

    def __init__(self, filter_in, filter_out):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(filter_in, filter_out, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(filter_out, momentum=0.1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(filter_out, filter_out, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(filter_out, momentum=0.1)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        return out


class Upsample(nn.Module):
    """1x1 channel projection followed by bilinear upsampling."""

    def __init__(self, in_channels, out_channels, scale_factor=2):
        super(Upsample, self).__init__()
        self.upsample = nn.Sequential(
            BasicConv(in_channels, out_channels, 1),
            nn.Upsample(scale_factor=scale_factor, mode='bilinear')
        )

    def forward(self, x):
        return self.upsample(x)


class Downsample_x2(nn.Module):
    """2x spatial downsampling via a stride-2 conv."""

    def __init__(self, in_channels, out_channels):
        super(Downsample_x2, self).__init__()
        self.downsample = nn.Sequential(
            BasicConv(in_channels, out_channels, 2, 2, 0)
        )

    def forward(self, x):
        return self.downsample(x)


class Downsample_x4(nn.Module):
    """4x spatial downsampling via a stride-4 conv."""

    def __init__(self, in_channels, out_channels):
        super(Downsample_x4, self).__init__()
        self.downsample = nn.Sequential(
            BasicConv(in_channels, out_channels, 4, 4, 0)
        )

    def forward(self, x):
        return self.downsample(x)


class Downsample_x8(nn.Module):
    """8x spatial downsampling via a stride-8 conv."""

    def __init__(self, in_channels, out_channels):
        super(Downsample_x8, self).__init__()
        self.downsample = nn.Sequential(
            BasicConv(in_channels, out_channels, 8, 8, 0)
        )

    def forward(self, x):
        return self.downsample(x)


class ASFF_2(nn.Module):
    """Adaptively fuses two same-resolution inputs with per-pixel softmax weights."""

    def __init__(self, inter_dim=512):
        super(ASFF_2, self).__init__()
        self.inter_dim = inter_dim
        compress_c = 8

        self.weight_level_1 = BasicConv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = BasicConv(self.inter_dim, compress_c, 1, 1)
        self.weight_levels = nn.Conv2d(compress_c * 2, 2, kernel_size=1, stride=1, padding=0)
        self.conv = BasicConv(self.inter_dim, self.inter_dim, 3, 1)

    def forward(self, input1, input2):
        level_1_weight_v = self.weight_level_1(input1)
        level_2_weight_v = self.weight_level_2(input2)

        levels_weight_v = torch.cat((level_1_weight_v, level_2_weight_v), 1)
        levels_weight = self.weight_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)  # weights sum to 1 across inputs at every pixel

        fused_out_reduced = input1 * levels_weight[:, 0:1, :, :] + \
                            input2 * levels_weight[:, 1:2, :, :]
        out = self.conv(fused_out_reduced)
        return out


class ASFF_3(nn.Module):
    """Adaptively fuses three same-resolution inputs with per-pixel softmax weights."""

    def __init__(self, inter_dim=512):
        super(ASFF_3, self).__init__()
        self.inter_dim = inter_dim
        compress_c = 8

        self.weight_level_1 = BasicConv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = BasicConv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_3 = BasicConv(self.inter_dim, compress_c, 1, 1)
        self.weight_levels = nn.Conv2d(compress_c * 3, 3, kernel_size=1, stride=1, padding=0)
        self.conv = BasicConv(self.inter_dim, self.inter_dim, 3, 1)

    def forward(self, input1, input2, input3):
        level_1_weight_v = self.weight_level_1(input1)
        level_2_weight_v = self.weight_level_2(input2)
        level_3_weight_v = self.weight_level_3(input3)

        levels_weight_v = torch.cat((level_1_weight_v, level_2_weight_v, level_3_weight_v), 1)
        levels_weight = self.weight_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)

        fused_out_reduced = input1 * levels_weight[:, 0:1, :, :] + \
                            input2 * levels_weight[:, 1:2, :, :] + \
                            input3 * levels_weight[:, 2:, :, :]
        out = self.conv(fused_out_reduced)
        return out


class ASFF_4(nn.Module):
    """Adaptively fuses four same-resolution inputs with per-pixel softmax weights."""

    def __init__(self, inter_dim=512):
        super(ASFF_4, self).__init__()
        self.inter_dim = inter_dim
        compress_c = 8

        self.weight_level_0 = BasicConv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_1 = BasicConv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = BasicConv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_3 = BasicConv(self.inter_dim, compress_c, 1, 1)
        self.weight_levels = nn.Conv2d(compress_c * 4, 4, kernel_size=1, stride=1, padding=0)
        self.conv = BasicConv(self.inter_dim, self.inter_dim, 3, 1)

    def forward(self, input0, input1, input2, input3):
        level_0_weight_v = self.weight_level_0(input0)
        level_1_weight_v = self.weight_level_1(input1)
        level_2_weight_v = self.weight_level_2(input2)
        level_3_weight_v = self.weight_level_3(input3)

        levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v, level_3_weight_v), 1)
        levels_weight = self.weight_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)

        fused_out_reduced = input0 * levels_weight[:, 0:1, :, :] + \
                            input1 * levels_weight[:, 1:2, :, :] + \
                            input2 * levels_weight[:, 2:3, :, :] + \
                            input3 * levels_weight[:, 3:, :, :]
        out = self.conv(fused_out_reduced)
        return out


class BlockBody(nn.Module):
    """Progressive fusion body: stage 1 fuses two levels, stage 2 three, stage 3 all four."""

    def __init__(self, channels=[64, 128, 256, 512]):
        super(BlockBody, self).__init__()

        self.blocks_scalezero1 = nn.Sequential(
            BasicConv(channels[0], channels[0], 1),
        )
        self.blocks_scaleone1 = nn.Sequential(
            BasicConv(channels[1], channels[1], 1),
        )
        self.blocks_scaletwo1 = nn.Sequential(
            BasicConv(channels[2], channels[2], 1),
        )
        self.blocks_scalethree1 = nn.Sequential(
            BasicConv(channels[3], channels[3], 1),
        )

        # Stage 1: fuse scales 0 and 1
        self.downsample_scalezero1_2 = Downsample_x2(channels[0], channels[1])
        self.upsample_scaleone1_2 = Upsample(channels[1], channels[0], scale_factor=2)

        self.asff_scalezero1 = ASFF_2(inter_dim=channels[0])
        self.asff_scaleone1 = ASFF_2(inter_dim=channels[1])

        self.blocks_scalezero2 = nn.Sequential(
            BasicBlock(channels[0], channels[0]),
            BasicBlock(channels[0], channels[0]),
            BasicBlock(channels[0], channels[0]),
            BasicBlock(channels[0], channels[0]),
        )
        self.blocks_scaleone2 = nn.Sequential(
            BasicBlock(channels[1], channels[1]),
            BasicBlock(channels[1], channels[1]),
            BasicBlock(channels[1], channels[1]),
            BasicBlock(channels[1], channels[1]),
        )

        # Stage 2: fuse scales 0-2
        self.downsample_scalezero2_2 = Downsample_x2(channels[0], channels[1])
        self.downsample_scalezero2_4 = Downsample_x4(channels[0], channels[2])
        self.downsample_scaleone2_2 = Downsample_x2(channels[1], channels[2])
        self.upsample_scaleone2_2 = Upsample(channels[1], channels[0], scale_factor=2)
        self.upsample_scaletwo2_2 = Upsample(channels[2], channels[1], scale_factor=2)
        self.upsample_scaletwo2_4 = Upsample(channels[2], channels[0], scale_factor=4)

        self.asff_scalezero2 = ASFF_3(inter_dim=channels[0])
        self.asff_scaleone2 = ASFF_3(inter_dim=channels[1])
        self.asff_scaletwo2 = ASFF_3(inter_dim=channels[2])

        self.blocks_scalezero3 = nn.Sequential(
            BasicBlock(channels[0], channels[0]),
            BasicBlock(channels[0], channels[0]),
            BasicBlock(channels[0], channels[0]),
            BasicBlock(channels[0], channels[0]),
        )
        self.blocks_scaleone3 = nn.Sequential(
            BasicBlock(channels[1], channels[1]),
            BasicBlock(channels[1], channels[1]),
            BasicBlock(channels[1], channels[1]),
            BasicBlock(channels[1], channels[1]),
        )
        self.blocks_scaletwo3 = nn.Sequential(
            BasicBlock(channels[2], channels[2]),
            BasicBlock(channels[2], channels[2]),
            BasicBlock(channels[2], channels[2]),
            BasicBlock(channels[2], channels[2]),
        )

        # Stage 3: fuse all four scales
        self.downsample_scalezero3_2 = Downsample_x2(channels[0], channels[1])
        self.downsample_scalezero3_4 = Downsample_x4(channels[0], channels[2])
        self.downsample_scalezero3_8 = Downsample_x8(channels[0], channels[3])
        self.upsample_scaleone3_2 = Upsample(channels[1], channels[0], scale_factor=2)
        self.downsample_scaleone3_2 = Downsample_x2(channels[1], channels[2])
        self.downsample_scaleone3_4 = Downsample_x4(channels[1], channels[3])
        self.upsample_scaletwo3_4 = Upsample(channels[2], channels[0], scale_factor=4)
        self.upsample_scaletwo3_2 = Upsample(channels[2], channels[1], scale_factor=2)
        self.downsample_scaletwo3_2 = Downsample_x2(channels[2], channels[3])
        self.upsample_scalethree3_8 = Upsample(channels[3], channels[0], scale_factor=8)
        self.upsample_scalethree3_4 = Upsample(channels[3], channels[1], scale_factor=4)
        self.upsample_scalethree3_2 = Upsample(channels[3], channels[2], scale_factor=2)

        self.asff_scalezero3 = ASFF_4(inter_dim=channels[0])
        self.asff_scaleone3 = ASFF_4(inter_dim=channels[1])
        self.asff_scaletwo3 = ASFF_4(inter_dim=channels[2])
        self.asff_scalethree3 = ASFF_4(inter_dim=channels[3])

        self.blocks_scalezero4 = nn.Sequential(
            BasicBlock(channels[0], channels[0]),
            BasicBlock(channels[0], channels[0]),
            BasicBlock(channels[0], channels[0]),
            BasicBlock(channels[0], channels[0]),
        )
        self.blocks_scaleone4 = nn.Sequential(
            BasicBlock(channels[1], channels[1]),
            BasicBlock(channels[1], channels[1]),
            BasicBlock(channels[1], channels[1]),
            BasicBlock(channels[1], channels[1]),
        )
        self.blocks_scaletwo4 = nn.Sequential(
            BasicBlock(channels[2], channels[2]),
            BasicBlock(channels[2], channels[2]),
            BasicBlock(channels[2], channels[2]),
            BasicBlock(channels[2], channels[2]),
        )
        self.blocks_scalethree4 = nn.Sequential(
            BasicBlock(channels[3], channels[3]),
            BasicBlock(channels[3], channels[3]),
            BasicBlock(channels[3], channels[3]),
            BasicBlock(channels[3], channels[3]),
        )

    def forward(self, x):
        x0, x1, x2, x3 = x

        x0 = self.blocks_scalezero1(x0)
        x1 = self.blocks_scaleone1(x1)
        x2 = self.blocks_scaletwo1(x2)
        x3 = self.blocks_scalethree1(x3)

        # Stage 1: two-level fusion
        scalezero = self.asff_scalezero1(x0, self.upsample_scaleone1_2(x1))
        scaleone = self.asff_scaleone1(self.downsample_scalezero1_2(x0), x1)

        x0 = self.blocks_scalezero2(scalezero)
        x1 = self.blocks_scaleone2(scaleone)

        # Stage 2: three-level fusion
        scalezero = self.asff_scalezero2(x0, self.upsample_scaleone2_2(x1), self.upsample_scaletwo2_4(x2))
        scaleone = self.asff_scaleone2(self.downsample_scalezero2_2(x0), x1, self.upsample_scaletwo2_2(x2))
        scaletwo = self.asff_scaletwo2(self.downsample_scalezero2_4(x0), self.downsample_scaleone2_2(x1), x2)

        x0 = self.blocks_scalezero3(scalezero)
        x1 = self.blocks_scaleone3(scaleone)
        x2 = self.blocks_scaletwo3(scaletwo)

        # Stage 3: four-level fusion
        scalezero = self.asff_scalezero3(x0, self.upsample_scaleone3_2(x1), self.upsample_scaletwo3_4(x2), self.upsample_scalethree3_8(x3))
        scaleone = self.asff_scaleone3(self.downsample_scalezero3_2(x0), x1, self.upsample_scaletwo3_2(x2), self.upsample_scalethree3_4(x3))
        scaletwo = self.asff_scaletwo3(self.downsample_scalezero3_4(x0), self.downsample_scaleone3_2(x1), x2, self.upsample_scalethree3_2(x3))
        scalethree = self.asff_scalethree3(self.downsample_scalezero3_8(x0), self.downsample_scaleone3_4(x1), self.downsample_scaletwo3_2(x2), x3)

        scalezero = self.blocks_scalezero4(scalezero)
        scaleone = self.blocks_scaleone4(scaleone)
        scaletwo = self.blocks_scaletwo4(scaletwo)
        scalethree = self.blocks_scalethree4(scalethree)

        return scalezero, scaleone, scaletwo, scalethree


class AFPN(nn.Module):
    """Compresses four backbone levels by 8x, fuses them progressively, and projects all to out_channels."""

    def __init__(self,
                 in_channels=[256, 512, 1024, 2048],
                 out_channels=128):
        super(AFPN, self).__init__()

        self.fp16_enabled = False

        self.conv0 = BasicConv(in_channels[0], in_channels[0] // 8, 1)
        self.conv1 = BasicConv(in_channels[1], in_channels[1] // 8, 1)
        self.conv2 = BasicConv(in_channels[2], in_channels[2] // 8, 1)
        self.conv3 = BasicConv(in_channels[3], in_channels[3] // 8, 1)

        self.body = nn.Sequential(
            BlockBody([in_channels[0] // 8, in_channels[1] // 8, in_channels[2] // 8, in_channels[3] // 8])
        )

        self.conv00 = BasicConv(in_channels[0] // 8, out_channels, 1)
        self.conv11 = BasicConv(in_channels[1] // 8, out_channels, 1)
        self.conv22 = BasicConv(in_channels[2] // 8, out_channels, 1)
        self.conv33 = BasicConv(in_channels[3] // 8, out_channels, 1)
        self.conv44 = nn.MaxPool2d(kernel_size=1, stride=2)  # defined but unused in forward

        # init weight
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight, gain=0.02)
            elif isinstance(m, nn.BatchNorm2d):
                torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
                torch.nn.init.constant_(m.bias.data, 0.0)

    def forward(self, x):
        x0, x1, x2, x3 = x

        x0 = self.conv0(x0)
        x1 = self.conv1(x1)
        x2 = self.conv2(x2)
        x3 = self.conv3(x3)

        out0, out1, out2, out3 = self.body([x0, x1, x2, x3])

        out0 = self.conv00(out0)
        out1 = self.conv11(out1)
        out2 = self.conv22(out2)
        out3 = self.conv33(out3)
        return out0, out1, out2, out3


class DWConv(Conv):
    """Depth-wise convolution."""

    def __init__(self, c1, c2, k=1, s=1, d=1, act=True):  # ch_in, ch_out, kernel, stride, dilation, activation
        """Initialize depth-wise convolution with given parameters."""
        super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)


class AFPN4Head(nn.Module):
    """YOLO Detect head that first fuses its four input levels (P2-P5) through AFPN. CSDNSnu77"""

    dynamic = False  # force grid reconstruction
    export = False  # export mode
    format = None  # export format, set by the exporter
    end2end = False  # end2end
    max_det = 300  # max detections per image
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, channel=256, ch=()):
        """Initializes the detection layer with the given number of classes and per-level input channels."""
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(channel, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        )
        self.cv3 = nn.ModuleList(
            nn.Sequential(
                nn.Sequential(DWConv(channel, channel, 3), Conv(channel, c3, 1)),
                nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
                nn.Conv2d(c3, self.nc, 1),
            )
            for x in ch
        )
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
        self.AFPN = AFPN(ch, channel)  # fuses the 4 input levels before the detect branches

        if self.end2end:
            self.one2one_cv2 = copy.deepcopy(self.cv2)
            self.one2one_cv3 = copy.deepcopy(self.cv3)

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        x = list(self.AFPN(x))
        if self.end2end:
            return self.forward_end2end(x)

        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:  # Training path
            return x
        y = self._inference(x)
        return y if self.export else (y, x)

    def forward_end2end(self, x):
        """
        Performs forward pass of the v10Detect module.

        Args:
            x (tensor): Input tensor.

        Returns:
            (dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and
                one2one detections. If in training mode, returns a dictionary containing the outputs of one2many and
                one2one detections separately.
        """
        x_detach = [xi.detach() for xi in x]
        one2one = [
            torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
        ]
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:  # Training path
            return {"one2many": x, "one2one": one2one}

        y = self._inference(one2one)
        y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
        return y if self.export else (y, {"one2many": x, "one2one": one2one})

    def _inference(self, x):
        """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""
        # Inference path
        shape = x[0].shape  # BCHW
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
            box = x_cat[:, : self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4 :]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)

        if self.export and self.format in {"tflite", "edgetpu"}:
            # Precompute normalization factor to increase numerical stability
            # See https://github.com/ultralytics/ultralytics/issues/7371
            grid_h = shape[2]
            grid_w = shape[3]
            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
            norm = self.strides / (self.stride[0] * grid_size)
            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
        else:
            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides

        return torch.cat((dbox, cls.sigmoid()), 1)

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
        if self.end2end:
            for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride):  # from
                a[-1].bias.data[:] = 1.0  # box
                b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

    def decode_bboxes(self, bboxes, anchors):
        """Decode bounding boxes."""
        return dist2bbox(bboxes, anchors, xywh=not self.end2end, dim=1)

    @staticmethod
    def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
        """
        Post-processes YOLO model predictions.

        Args:
            preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
                format [x, y, w, h, class_probs].
            max_det (int): Maximum detections per image.
            nc (int, optional): Number of classes. Default: 80.

        Returns:
            (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
                dimension format [x, y, w, h, max_class_prob, class_index].
        """
        batch_size, anchors, _ = preds.shape  # i.e. shape(16, 8400, 84)
        boxes, scores = preds.split([4, nc], dim=-1)
        index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
        boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
        scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
        scores, index = scores.flatten(1).topk(min(max_det, anchors))
        i = torch.arange(batch_size)[..., None]  # batch indices
        return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)
```
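Before wiring the head into the framework, you can sanity-check shapes by running a quick test at the bottom of the file you just created. This is my own smoke test, not part of the original code; the channel counts are arbitrary placeholders (any values divisible by 8 work, since AFPN compresses each input level by a factor of 8):

```python
import torch

# Four pyramid levels for a 640x640 input at strides 4/8/16/32 (P2-P5).
ch = (64, 128, 256, 512)
feats = [torch.randn(1, c, s, s) for c, s in zip(ch, (160, 80, 40, 20))]

head = AFPN4Head(nc=80, channel=64, ch=ch)
head.train()  # the training path returns the raw per-level predictions
for out in head(feats):
    print(out.shape)  # (1, 4*reg_max + nc, H, W) = (1, 144, H, W) at each level
```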


4. Step by Step: Adding the AFPN4Head Detection Head

4.1 Modification 1

First, create a new .py file under the 'ultralytics/nn' directory and paste the code above into it. The name is up to you; I called it AFPN4Head.py.


4.2 Modification 2

Second, in the same directory create a new file named '__init__.py' (if you are using the files from the group, it already exists and there is no need to create it), then import our detection head inside it, as in the sketch below.
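Assuming the file name AFPN4Head.py from step 4.1 (the original screenshot is not reproduced here, so treat this as a sketch of the intent):

```python
# ultralytics/nn/__init__.py
from .AFPN4Head import AFPN4Head  # file name from step 4.1
```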



4.3 Modification 3

Third, open 'ultralytics/nn/tasks.py' to import and register our module (if you are using the files from the group, this is already done; skip straight to step four).
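The import near the top of 'ultralytics/nn/tasks.py' would look like this, again assuming the file placement from step 4.1:

```python
# Near the top of ultralytics/nn/tasks.py, next to the existing module imports:
from ultralytics.nn.AFPN4Head import AFPN4Head
```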



4.4 Modification 4

Fourth, still in 'ultralytics/nn/tasks.py', find the code below and add the detection head to it. A quick way to locate it: press Ctrl+F and search for 'Detect'.
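The branch you are looking for handles the built-in detection heads; the sketch below shows where AFPN4Head is added. The exact set of head classes in this branch varies across ultralytics versions, so match it against your own tasks.py rather than copying verbatim:

```python
# Inside parse_model() in ultralytics/nn/tasks.py (found via Ctrl+F "Detect"):
elif m in {Detect, Segment, Pose, OBB, AFPN4Head}:  # AFPN4Head added to the head branch
    args.append([ch[x] for x in f])  # pass the channel list of the four 'from' layers
```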



4.5 Modification 5


4.6 Modification 6

Same as the previous step.



4.7 Modification 7

This step is slightly different: we need to add one line of code.

```python
else:
    return 'detect'
```

Why is this one different? Because the module name m is automatically converted to lowercase while the code runs, and 'afpn4head' does not contain the substring 'detect', so falling through to a plain else branch is the simplest fix. When later tutorials cover segmentation or other task heads, I will provide the corresponding modifications.
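For context, this lives in the task-guessing helper in tasks.py. The surrounding branches vary by ultralytics version, so the sketch below is only illustrative; the point is the final else:

```python
# Sketch of the head-name check in ultralytics/nn/tasks.py (guess_model_task);
# m is the lowercased name of the last head module, e.g. "afpn4head".
m = cfg["head"][-1][-2].lower()
if m in {"classify", "classifier", "cls", "fc"}:
    return "classify"
if "detect" in m:
    return "detect"
if m == "segment":
    return "segment"
else:
    return "detect"  # added: "afpn4head" contains no "detect", so fall back to detect
```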



4.8 Modification 8

Same as the previous step.



That completes all the modifications. You can now copy the yaml file below and run it.


5. The yaml File for the AFPN4Head Detection Head

Training info for this version: YOLO11-AFPN4Head summary: 960 layers, 2,715,185 parameters, 2,715,169 gradients, 12.0 GFLOPs

```yaml
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P2-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11-P2 head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 2], 1, Concat, [1]] # cat backbone P2
  - [-1, 2, C3k2, [128, False]] # 19 (P2/4-xsmall); for small objects, try setting this False to True

  - [-1, 1, Conv, [128, 3, 2]]
  - [[-1, 16], 1, Concat, [1]] # cat head P3
  - [-1, 2, C3k2, [256, False]] # 22 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 25 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 28 (P5/32-large)

  # the 64 below can also be tried as 32, 128, or 256
  - [[19, 22, 25, 28], 1, AFPN4Head, [nc, 64]] # Detect(P2, P3, P4, P5)
```
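After saving the yaml (the file name below is my assumption; use whatever name you saved it under), training proceeds like any other ultralytics model:

```python
from ultralytics import YOLO

model = YOLO("yolo11-AFPN4Head.yaml")  # assumed file name for the yaml above
model.train(data="coco128.yaml", epochs=100, imgsz=640, batch=16)
```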


6. Run Record

Finally, here is a screenshot of a complete, successful run.



7. Summary

That concludes the main content of this article. Here I recommend my YOLOv11 effective-improvement column: it is newly opened with an average quality score of 98, and going forward I will reproduce papers from the latest top conferences and keep backfilling older improvement mechanisms. If this article helped you, consider subscribing to the column and following for more updates~