
YOLOv11 Improvement - Backbone: RepViT, a CNN Object-Detection Backbone Revisited from the Vision Transformer (ViT) Perspective (compatible with the full YOLOv11 series)

1. Introduction

The improvement introduced in this article is RepViT, which we use to replace the entire YOLOv11 backbone. RepViT is a recently proposed backbone network whose main idea is to apply the design principles of lightweight Vision Transformers (ViTs) to traditional lightweight convolutional neural networks (CNNs). Replacing the whole YOLOv11 backbone with it gave a substantial accuracy gain in my experiments. I trained and tested the modified network (using the lightest version) on a dataset of 1,000 images containing small, medium, and large detection targets across 20+ classes, and observed gains on all target sizes; a training comparison between the baseline and the modified version is attached below.

(The content of this article can be rescaled for the YOLOv11 N, S, M, L, and X variants, taking the lightweighting one step further.)



2. RepViT: Basic Principles

Official paper: RepViT: Revisiting Mobile CNN From ViT Perspective (arXiv:2307.09283)

Official code: https://github.com/THU-MIG/RepViT


The paper "RepViT: Revisiting Mobile CNN From ViT Perspective" examines how to improve lightweight CNNs to raise their performance and efficiency on mobile devices. The authors observe that although lightweight Vision Transformers (ViTs) perform well thanks to their ability to learn global representations, the architectural differences between lightweight CNNs and lightweight ViTs have not been fully studied. They therefore progressively improve a standard lightweight CNN (specifically MobileNetV3) by integrating the efficient architectural designs of lightweight ViTs, producing a new family of pure CNN models called RepViT. These models perform well on a variety of vision tasks and are more efficient than existing lightweight ViTs.

Its main improvement mechanisms include:

  1. Structural re-parameterization: structural re-parameterization (Structural Re-parameterization, SR) introduces a multi-branch topology to improve performance during training; the branches can later be fused into a single convolution for inference (see the sketch after this list).

  2. Expansion ratio adjustment: the expansion ratios of the convolutional layers are adjusted to reduce parameter redundancy and latency, while the network width is increased to strengthen model performance.

  3. Macro design optimization: the macro architecture is optimized, including the design of the early convolution layers, deeper downsampling layers, a simplified classifier, and adjusted overall stage ratios.

  4. Micro design adjustment: optimizations at the micro-architecture level, including the choice of convolution kernel size and the optimal placement of the Squeeze-and-Excitation (SE) layers.
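As a minimal illustration of point 1 (a sketch of my own, mirroring the Conv2d_BN.fuse_self method shown in Section 3, and assuming a convolution immediately followed by a BatchNorm), folding a BatchNorm into the preceding convolution is the basic operation behind SR fusion:

import torch

@torch.no_grad()
def fold_bn(conv: torch.nn.Conv2d, bn: torch.nn.BatchNorm2d) -> torch.nn.Conv2d:
    # BN(conv(x)) = scale * (W*x + b - mean) + beta with scale = gamma / sqrt(var + eps),
    # which is itself a convolution with weight W*scale and bias (b - mean)*scale + beta.
    scale = bn.weight / (bn.running_var + bn.eps).sqrt()
    fused = torch.nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                            stride=conv.stride, padding=conv.padding,
                            dilation=conv.dilation, groups=conv.groups, bias=True)
    fused.weight.copy_(conv.weight * scale[:, None, None, None])
    bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.copy_((bias - bn.running_mean) * scale + bn.bias)
    return fused

At inference time the fused convolution produces exactly the same output as the conv-BN pair, so the extra training-time branches cost nothing after deployment.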

Together, these mechanisms advance the performance and efficiency of lightweight CNNs, making them better suited to mobile devices. Below is the structure figure from the official paper, with a brief analysis.

The figure is Figure 3 of the paper and shows an overview of the RepViT architecture. RepViT has four stages; the stem first downsamples the input image by 4x and each subsequent stage halves the spatial resolution again, so the stages operate at 1/4, 1/8, 1/16, and 1/32 of the input resolution. The channel dimension of each stage is denoted Ci, and the batch size is denoted B.

  • Stem: the module that preprocesses the input image.
  • Stage 1-4: each stage consists of multiple RepViTBlocks, plus an optional RepViTSEBlock, built from 3×3 depthwise (DW) convolutions, 1×1 convolutions, Squeeze-and-Excitation (SE) modules, and a feed-forward network (FFN). Each stage reduces the spatial dimensions through downsampling.
  • Pooling: a global average pooling layer that reduces the spatial dimensions of the feature map.
  • FC: a fully connected layer for the final class prediction.

Summary: you can think of RepViT as an improved member of the MobileNet family.


3. RepViT Core Code

The code below is the complete RepViT core code. It provides several versions with different GFLOPs; see Section 4 for how to use it.

import torch
import torch.nn as nn
from timm.models.layers import SqueezeExcite

__all__ = ['repvit_m0_6', 'repvit_m0_9', 'repvit_m1_0', 'repvit_m1_1', 'repvit_m1_5', 'repvit_m2_3']


def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class Conv2d_BN(torch.nn.Sequential):
    """A Conv2d followed by BatchNorm; fuse_self() folds the BN into the conv for inference."""

    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1, resolution=-10000):
        super().__init__()
        self.add_module('c', torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', torch.nn.BatchNorm2d(b))
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)

    @torch.no_grad()
    def fuse_self(self):
        c, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / \
            (bn.running_var + bn.eps) ** 0.5
        m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(0), w.shape[2:],
                            stride=self.c.stride, padding=self.c.padding,
                            dilation=self.c.dilation, groups=self.c.groups,
                            device=c.weight.device)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m


class Residual(torch.nn.Module):
    """Residual wrapper with optional stochastic depth during training."""

    def __init__(self, m, drop=0.):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        if self.training and self.drop > 0:
            return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,
                                              device=x.device).ge_(self.drop).div(1 - self.drop).detach()
        else:
            return x + self.m(x)

    @torch.no_grad()
    def fuse_self(self):
        # Fold the identity shortcut into the wrapped conv where possible.
        if isinstance(self.m, Conv2d_BN):
            m = self.m.fuse_self()
            assert (m.groups == m.in_channels)
            identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
            identity = torch.nn.functional.pad(identity, [1, 1, 1, 1])
            m.weight += identity.to(m.weight.device)
            return m
        elif isinstance(self.m, torch.nn.Conv2d):
            m = self.m
            assert (m.groups != m.in_channels)
            identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
            identity = torch.nn.functional.pad(identity, [1, 1, 1, 1])
            m.weight += identity.to(m.weight.device)
            return m
        else:
            return self


class RepVGGDW(torch.nn.Module):
    """Re-parameterizable depthwise block: 3x3 DW conv + 1x1 DW conv + identity, fusible into one 3x3 conv."""

    def __init__(self, ed) -> None:
        super().__init__()
        self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
        self.conv1 = torch.nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
        self.dim = ed
        self.bn = torch.nn.BatchNorm2d(ed)

    def forward(self, x):
        return self.bn((self.conv(x) + self.conv1(x)) + x)

    @torch.no_grad()
    def fuse_self(self):
        conv = self.conv.fuse_self()
        conv1 = self.conv1
        conv_w = conv.weight
        conv_b = conv.bias
        conv1_w = conv1.weight
        conv1_b = conv1.bias
        # Pad the 1x1 kernel (and a 1x1 identity kernel) to 3x3 so all three branches share one shape.
        conv1_w = torch.nn.functional.pad(conv1_w, [1, 1, 1, 1])
        identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device),
                                           [1, 1, 1, 1])
        final_conv_w = conv_w + conv1_w + identity
        final_conv_b = conv_b + conv1_b
        conv.weight.data.copy_(final_conv_w)
        conv.bias.data.copy_(final_conv_b)
        # Finally fold the trailing BN into the merged conv.
        bn = self.bn
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = conv.weight * w[:, None, None, None]
        b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / \
            (bn.running_var + bn.eps) ** 0.5
        conv.weight.data.copy_(w)
        conv.bias.data.copy_(b)
        return conv


class RepViTBlock(nn.Module):
    """RepViT block: a token mixer (DW conv, optional SE) followed by a channel-mixer FFN."""

    def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
        super(RepViTBlock, self).__init__()
        assert stride in [1, 2]
        self.identity = stride == 1 and inp == oup
        assert (hidden_dim == 2 * inp)
        if stride == 2:
            self.token_mixer = nn.Sequential(
                Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp),
                SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
                Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
            )
            self.channel_mixer = Residual(nn.Sequential(
                # pw
                Conv2d_BN(oup, 2 * oup, 1, 1, 0),
                nn.GELU() if use_hs else nn.GELU(),
                # pw-linear
                Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
            ))
        else:
            self.token_mixer = nn.Sequential(
                RepVGGDW(inp),
                SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
            )
            self.channel_mixer = Residual(nn.Sequential(
                # pw
                Conv2d_BN(inp, hidden_dim, 1, 1, 0),
                nn.GELU() if use_hs else nn.GELU(),
                # pw-linear
                Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
            ))

    def forward(self, x):
        return self.channel_mixer(self.token_mixer(x))


class RepViT(nn.Module):
    def __init__(self, cfgs, factor):
        super(RepViT, self).__init__()
        # setting of inverted residual blocks; rescale every channel count by `factor`, keeping it divisible by 8
        cfgs = [sublist[:2] + [_make_divisible(int(sublist[2] * factor), 8)] + sublist[3:] for sublist in cfgs]
        self.cfgs = cfgs
        # building first layer
        input_channel = self.cfgs[0][2]
        patch_embed = torch.nn.Sequential(Conv2d_BN(3, input_channel // 2, 3, 2, 1), torch.nn.GELU(),
                                          Conv2d_BN(input_channel // 2, input_channel, 3, 2, 1))
        layers = [patch_embed]
        # building inverted residual blocks
        block = RepViTBlock
        for k, t, c, use_se, use_hs, s in self.cfgs:
            output_channel = _make_divisible(c, 8)
            exp_size = _make_divisible(input_channel * t, 8)
            layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs))
            input_channel = output_channel
        self.features = nn.ModuleList(layers)
        self.width_list = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640))]

    def forward(self, x):
        # Keep the last feature map produced at each channel width, i.e. one output per stage.
        results = [None, None, None, None]
        temp = None
        i = None
        for index, f in enumerate(self.features):
            x = f(x)
            if index == 0:
                temp = x.size(1)
                i = 0
            elif x.size(1) == temp:
                results[i] = x
            else:
                temp = x.size(1)
                i = i + 1
        return results


def repvit_m0_6(factor):
    """
    Constructs a RepViT-M0.6 model.
    """
    cfgs = [
        # k, t, c, SE, HS, s
        [3, 2, 40, 1, 0, 1],
        [3, 2, 40, 0, 0, 1],
        [3, 2, 80, 0, 0, 2],
        [3, 2, 80, 1, 0, 1],
        [3, 2, 80, 0, 0, 1],
        [3, 2, 160, 0, 1, 2],
        [3, 2, 160, 1, 1, 1],
        [3, 2, 160, 0, 1, 1],
        [3, 2, 160, 1, 1, 1],
        [3, 2, 160, 0, 1, 1],
        [3, 2, 160, 1, 1, 1],
        [3, 2, 160, 0, 1, 1],
        [3, 2, 160, 1, 1, 1],
        [3, 2, 160, 0, 1, 1],
        [3, 2, 160, 0, 1, 1],
        [3, 2, 320, 0, 1, 2],
        [3, 2, 320, 1, 1, 1],
    ]
    model = RepViT(cfgs, factor)
    return model


def repvit_m0_9(factor):
    """
    Constructs a RepViT-M0.9 model.
    """
    cfgs = [
        # k, t, c, SE, HS, s
        [3, 2, 48, 1, 0, 1],
        [3, 2, 48, 0, 0, 1],
        [3, 2, 48, 0, 0, 1],
        [3, 2, 96, 0, 0, 2],
        [3, 2, 96, 1, 0, 1],
        [3, 2, 96, 0, 0, 1],
        [3, 2, 96, 0, 0, 1],
        [3, 2, 192, 0, 1, 2],
        [3, 2, 192, 1, 1, 1],
        [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 1, 1, 1],
        [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 1, 1, 1],
        [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 1, 1, 1],
        [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 1, 1, 1],
        [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 1, 1, 1],
        [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 1, 1, 1],
        [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 0, 1, 1],
        [3, 2, 384, 0, 1, 2],
        [3, 2, 384, 1, 1, 1],
        [3, 2, 384, 0, 1, 1]
    ]
    model = RepViT(cfgs, factor)
    return model


def repvit_m1_0(factor):
    """
    Constructs a RepViT-M1.0 model.
    """
    cfgs = [
        # k, t, c, SE, HS, s
        [3, 2, 56, 1, 0, 1],
        [3, 2, 56, 0, 0, 1],
        [3, 2, 56, 0, 0, 1],
        [3, 2, 112, 0, 0, 2],
        [3, 2, 112, 1, 0, 1],
        [3, 2, 112, 0, 0, 1],
        [3, 2, 112, 0, 0, 1],
        [3, 2, 224, 0, 1, 2],
        [3, 2, 224, 1, 1, 1],
        [3, 2, 224, 0, 1, 1],
        [3, 2, 224, 1, 1, 1],
        [3, 2, 224, 0, 1, 1],
        [3, 2, 224, 1, 1, 1],
        [3, 2, 224, 0, 1, 1],
        [3, 2, 224, 1, 1, 1],
        [3, 2, 224, 0, 1, 1],
        [3, 2, 224, 1, 1, 1],
        [3, 2, 224, 0, 1, 1],
        [3, 2, 224, 1, 1, 1],
        [3, 2, 224, 0, 1, 1],
        [3, 2, 224, 1, 1, 1],
        [3, 2, 224, 0, 1, 1],
        [3, 2, 224, 0, 1, 1],
        [3, 2, 448, 0, 1, 2],
        [3, 2, 448, 1, 1, 1],
        [3, 2, 448, 0, 1, 1]
    ]
    model = RepViT(cfgs, factor=factor)
    return model


def repvit_m1_1(factor):
    """
    Constructs a RepViT-M1.1 model.
    """
    cfgs = [
        # k, t, c, SE, HS, s
        [3, 2, 64, 1, 0, 1],
        [3, 2, 64, 0, 0, 1],
        [3, 2, 64, 0, 0, 1],
        [3, 2, 128, 0, 0, 2],
        [3, 2, 128, 1, 0, 1],
        [3, 2, 128, 0, 0, 1],
        [3, 2, 128, 0, 0, 1],
        [3, 2, 256, 0, 1, 2],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 512, 0, 1, 2],
        [3, 2, 512, 1, 1, 1],
        [3, 2, 512, 0, 1, 1]
    ]
    model = RepViT(cfgs, factor=factor)
    return model


def repvit_m1_5(factor):
    """
    Constructs a RepViT-M1.5 model.
    """
    cfgs = [
        # k, t, c, SE, HS, s
        [3, 2, 64, 1, 0, 1],
        [3, 2, 64, 0, 0, 1],
        [3, 2, 64, 1, 0, 1],
        [3, 2, 64, 0, 0, 1],
        [3, 2, 64, 0, 0, 1],
        [3, 2, 128, 0, 0, 2],
        [3, 2, 128, 1, 0, 1],
        [3, 2, 128, 0, 0, 1],
        [3, 2, 128, 1, 0, 1],
        [3, 2, 128, 0, 0, 1],
        [3, 2, 128, 0, 0, 1],
        [3, 2, 256, 0, 1, 2],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 512, 0, 1, 2],
        [3, 2, 512, 1, 1, 1],
        [3, 2, 512, 0, 1, 1],
        [3, 2, 512, 1, 1, 1],
        [3, 2, 512, 0, 1, 1]
    ]
    model = RepViT(cfgs, factor=factor)
    return model


def repvit_m2_3(factor):
    """
    Constructs a RepViT-M2.3 model.
    """
    cfgs = [
        # k, t, c, SE, HS, s
        [3, 2, 80, 1, 0, 1],
        [3, 2, 80, 0, 0, 1],
        [3, 2, 80, 1, 0, 1],
        [3, 2, 80, 0, 0, 1],
        [3, 2, 80, 1, 0, 1],
        [3, 2, 80, 0, 0, 1],
        [3, 2, 80, 0, 0, 1],
        [3, 2, 160, 0, 0, 2],
        [3, 2, 160, 1, 0, 1],
        [3, 2, 160, 0, 0, 1],
        [3, 2, 160, 1, 0, 1],
        [3, 2, 160, 0, 0, 1],
        [3, 2, 160, 1, 0, 1],
        [3, 2, 160, 0, 0, 1],
        [3, 2, 160, 0, 0, 1],
        [3, 2, 320, 0, 1, 2],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        # [3, 2, 320, 1, 1, 1],
        # [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 640, 0, 1, 2],
        [3, 2, 640, 1, 1, 1],
        [3, 2, 640, 0, 1, 1],
        # [3, 2, 640, 1, 1, 1],
        # [3, 2, 640, 0, 1, 1]
    ]
    model = RepViT(cfgs, factor=factor)
    return model


if __name__ == '__main__':
    # Quick sanity check: build the lightest version and print the four feature-map sizes.
    model = repvit_m0_6(factor=0.25)
    inputs = torch.randn((1, 3, 640, 640))
    for i in model(inputs):
        print(i.size())
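As a usage note, the helper below is my own sketch (not part of the upstream code) of how the re-parameterizable branches could be collapsed for deployment by walking the model and calling fuse_self() where it exists. Run it in eval mode so the BN running statistics are final:

import torch

@torch.no_grad()
def fuse_model(module):
    # Replace every child that exposes fuse_self() with its fused single-conv form.
    for name, child in module.named_children():
        if hasattr(child, 'fuse_self'):
            fused = child.fuse_self()
            if fused is child:           # Residual over a Sequential returns itself...
                fuse_model(child.m)      # ...so fuse the Conv2d_BN layers inside its branch
            else:
                setattr(module, name, fused)
        else:
            fuse_model(child)
    return module

model = fuse_model(repvit_m0_6(factor=0.25).eval())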


4. Step-by-Step Guide to Adding the RepViT Network

4.1 Modification 1

Step one is to create the files. Go to the ultralytics/nn folder and create a directory named 'Addmodules' (if you are using the files from my group, it already exists and you don't need to create it)! Then create a new .py file inside it and paste in the core code above.


4.2 Modification 2

Step two: in that directory, create a new file named '__init__.py' (already present if you use the group files), then import our module in it, as shown in the figure below.
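For reference, a minimal sketch of what the '__init__.py' might contain, assuming the file created in 4.1 was saved as 'RepViT.py' (that file name is my assumption):

# ultralytics/nn/Addmodules/__init__.py (sketch)
from .RepViT import *  # re-exports repvit_m0_6 ... repvit_m2_3 via __all__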


4.3 Modification 3

Step three: open the file 'ultralytics/nn/tasks.py' and import and register our module there (if you are using the group files, it is already imported; skip straight to step four).
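A sketch of the import near the top of tasks.py (the exact placement may differ in your copy):

# ultralytics/nn/tasks.py (sketch)
from ultralytics.nn.Addmodules import *  # makes repvit_m0_6 ... repvit_m2_3 visible to parse_model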

From today on, all my tutorials will follow this format, because I assume everyone is working from the files shared in my group!!


4.4 Modification 4

Add the following two lines of code!!!
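The two lines are shown in a screenshot that is not reproduced here. Based on how the `backbone` flag is used in steps five through seven, one of them is almost certainly the flag's initialization before parse_model's layer loop; the sketch below is my reconstruction of that line (an assumption, not a verbatim copy of the screenshot), and the second line cannot be recovered from the text:

# ultralytics/nn/tasks.py, inside parse_model(), before the per-layer loop (sketch)
backbone = False  # set to True later when a whole-backbone module such as repvit_m0_6 is parsed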


4.5 Modification 5

Step five: go to roughly line 700 (check the screenshot for the exact spot) and add the part in the red box, following the figure. Note that there are no parentheses, only the function names.

elif m in {add the corresponding model names here; the rest below is the same}:
    m = m(*args)
    c2 = m.width_list  # return the channel list
    backbone = True
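For this article's backbone, the set would hold the six constructors exported by __all__ above:

elif m in {repvit_m0_6, repvit_m0_9, repvit_m1_0, repvit_m1_1, repvit_m1_5, repvit_m2_3}:
    m = m(*args)
    c2 = m.width_list  # list of the backbone's output channel widths
    backbone = True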


4.6 Modification 6

Both of the red boxes below need to be changed.

if isinstance(c2, list):  # our backbone returns a channel list, so keep the already-built module
    m_ = m
    m_.backbone = True
else:
    m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
t = str(m)[8:-2].replace('__main__.', '')  # module type
m.np = sum(x.numel() for x in m_.parameters())  # number params
m_.i, m_.f, m_.type = i + 4 if backbone else i, f, t  # attach index, 'from' index, type; offset by 4 because the single backbone entry stands for layers 0-4


4.7 Modification 7

The following also needs modifying; follow my version exactly.

Replace the original code with the code below.

if verbose:
    LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f}  {t:<45}{str(args):<30}')  # print
save.extend(x % (i + 4 if backbone else i) for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
layers.append(m_)
if i == 0:
    ch = []
if isinstance(c2, list):
    ch.extend(c2)
    if len(c2) != 5:
        ch.insert(0, 0)  # pad the channel list so later indexing stays aligned
else:
    ch.append(c2)


4.8 Modification 8

Modification 8 is a little different from the previous ones: it changes part of the forward pass, and we have now left the parse_model method.

You can check the line numbers in the screenshot; we are still in tasks.py, it is all the same file. Note that several forward-pass methods in this area look very similar, so don't pick the wrong one: it is the one around line 70!!! I have provided the code below, so you can copy and paste it directly; when I have time, I will make a video for this part.


The replacement code:

def _predict_once(self, x, profile=False, visualize=False, embed=None):
    """
    Perform a forward pass through the network.

    Args:
        x (torch.Tensor): The input tensor to the model.
        profile (bool): Print the computation time of each layer if True, defaults to False.
        visualize (bool): Save the feature maps of the model if True, defaults to False.
        embed (list, optional): A list of feature vectors/embeddings to return.

    Returns:
        (torch.Tensor): The last output of the model.
    """
    y, dt, embeddings = [], [], []  # outputs
    for m in self.model:
        if m.f != -1:  # if not from previous layer
            x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
        if profile:
            self._profile_one_layer(m, x, dt)
        if hasattr(m, 'backbone'):
            x = m(x)  # the backbone returns a list of feature maps
            if len(x) != 5:  # pad to 5 entries (indices 0-4)
                x.insert(0, None)
            for index, i in enumerate(x):
                if index in self.save:
                    y.append(i)
                else:
                    y.append(None)
            x = x[-1]  # pass the last output on to the next layer
        else:
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output
        if visualize:
            feature_visualization(x, m.type, m.i, save_dir=visualize)
        if embed and m.i in embed:
            embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
            if m.i == max(embed):
                return torch.unbind(torch.cat(embeddings, 1), dim=0)
    return x

That completes the modifications, but there are many details here. Be very careful not to replace extra code (which will cause errors) and not to skip any step; either mistake will make the run fail, and the resulting errors are very hard to debug!!! Very hard to debug!!!


Attention!!! An extra modification!

Those who follow me know that most of my modifications are the same across articles. This network needs one extra step: the parameter s. Change the s shown below to 640!!! and it will run perfectly!!
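For reference, the line in question sits in DetectionModel.__init__ in tasks.py, where the stride is probed with a dummy forward pass. A sketch from my memory of the ultralytics source (your version's surrounding code may differ slightly):

# ultralytics/nn/tasks.py, DetectionModel.__init__ (sketch)
s = 640  # was 256; probe at 640 because the RepViT backbone computes width_list at 640x640
m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))])  # forward probe
self.stride = m.stride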


Solving the GFLOPs printout problem

Open the file 'ultralytics/utils/torch_utils.py' and modify it as shown in the figure, otherwise the GFLOPs may fail to print.
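The referenced figure is not reproduced here. As a rough sketch of the kind of change involved (this is my assumption of the intent, not a verbatim patch; get_flops() in your ultralytics version may differ), the idea is to profile with a fixed 640x640 input rather than a stride-sized probe, since this backbone slices feature maps at fixed indices:

# sketch of the relevant lines inside get_flops() in ultralytics/utils/torch_utils.py
from copy import deepcopy
import thop
import torch

p = next(model.parameters())
im = torch.empty((1, p.shape[1], 640, 640), device=p.device)  # fixed 640x640 probe input
flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # GFLOPs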


Caveats!!!

If you hit a shape-mismatch error during validation, you can fix the validation-set image size as follows ->

Open the file ultralytics/models/yolo/detect/train.py. In the DetectionTrainer class, in the build_dataset function, change the parameter rect=mode == 'val' to rect=False.
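After the change, the method would look roughly like this (the surrounding lines follow the ultralytics source as I recall it and may differ in your version):

# ultralytics/models/yolo/detect/train.py, inside DetectionTrainer (sketch)
def build_dataset(self, img_path, mode="train", batch=None):
    gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
    return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode,
                              rect=False,  # was: rect=mode == "val"
                              stride=gs)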


5. RepViT YAML Files

Copy the yaml file below to run it!!!


5.1 RepViT yaml file, version 1

Training info for this version: YOLO11-RepViT summary: 559 layers, 2,118,115 parameters, 2,118,099 gradients, 5.4 GFLOPs

Usage note: in the entry [-1, 1, repvit_m0_6, [0.25]] below, the 0.25 in the args position is the channel scaling factor: 0.25 for YOLOv11n, 0.5 for YOLOv11s, 1.0 for YOLOv11m and YOLOv11l, 1.5 for YOLOv11x. Set it according to the YOLO variant you are training.

Supported versions in this article: __all__ = ['repvit_m0_6','repvit_m0_9', 'repvit_m1_0', 'repvit_m1_1', 'repvit_m1_5', 'repvit_m2_3']

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# In [-1, 1, repvit_m0_6, [0.25]] below, the 0.25 in the args position is the channel scaling factor:
# 0.25 for YOLOv11n, 0.5 for YOLOv11s, 1.0 for YOLOv11m/l, 1.5 for YOLOv11x. Set it per the variant you train.
# Supported versions: __all__ = ['repvit_m0_6','repvit_m0_9', 'repvit_m1_0', 'repvit_m1_1', 'repvit_m1_5', 'repvit_m2_3']

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, repvit_m0_6, [0.5]] # 0-4 P1/2  (this single entry stands for the backbone layers 0-4; don't let the yaml layout constrain your mental model; for the diagram, see the video in my group)
  - [-1, 1, SPPF, [1024, 5]] # 5
  - [-1, 2, C2PSA, [1024]] # 6

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 3], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 9

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 2], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 12 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 9], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 15 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 6], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 18 (P5/32-large)

  - [[12, 15, 18], 1, Detect, [nc]] # Detect(P3, P4, P5)
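If you want to sanity-check which channel widths a given scaling factor produces before training, a quick sketch (the import path assumes the file layout from Sections 4.1 and 4.2):

import torch
from ultralytics.nn.Addmodules import repvit_m0_6  # assumes the 4.1/4.2 layout

model = repvit_m0_6(factor=0.25)
print(model.width_list)  # channel widths of the four feature maps handed to the neck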


5.2 Training script

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO('replace with the path to the RepViT yaml from Section 5.1')
    # model.load('yolov8n.pt')  # loading pretrain weights
    model.train(data=r'replace with the path to your dataset yaml',
                # for other task types, open 'ultralytics/cfg/default.yaml' and change 'task' to detect, segment, classify, or pose
                cache=False,
                imgsz=640,
                epochs=150,
                single_cls=False,  # whether this is single-class detection
                batch=4,
                close_mosaic=10,
                workers=0,
                device='0',
                optimizer='SGD',  # using SGD
                # resume='',  # to resume training, set this to the path of last.pt
                amp=False,  # disable AMP if the training loss becomes NaN
                project='runs/train',
                name='exp',
                )


6. Successful Run Record

Below is a screenshot of a successful run; one epoch of training has completed (the image is too large to capture the second epoch). After the changes, there is a minor issue with the printout, but it does not affect any functionality; I will fix it when I find time.



7. Summary

That concludes this article's content. Let me recommend my YOLOv11 effective-improvement column: it is newly opened with an average quality score of 98, and I will keep reproducing papers from the latest top conferences as well as supplementing some older improvement mechanisms. The column is currently free to read (for now, so follow early and don't lose track~). If this article helped you, subscribe to the column and follow for more updates~
