学习资源站

22-替换主干网络之RepViT(清华 ICCV 2023_最新开源移动端ViT)_改进yolov5系列_21t添加

YOLOv5改进系列(21)——替换主干网络之RepViT(清华 ICCV 2023|最新开源移动端ViT)

🚀一、RepViT介绍 


1.1 简介

近年来,与轻量级卷积神经网络(cnn)相比,轻量级视觉变压器(ViTs)在资源受限的移动设备上表现出了更高的性能和更低的延迟。这种改进通常归功于多头自注意模块,它使模型能够学习全局表示。然而,轻量级vit和轻量级cnn之间的架构差异还没有得到充分的研究。

本文重点探讨了在资源有限的移动设备上,通过重新审视轻量级卷积神经网络的设计,并整合轻量级 ViTs 的有效架构选择,来提升轻量级 CNNs 的性能,即RepViT

大量的实验表明,RepViT优于现有的轻型ViT,并在各种视觉任务中表现出良好的延迟。在ImageNet上,RepViT在iPhone 12上以近1ms的延迟实现了超过80%的top-1精度,据我们所知,这是轻量级模型的第一次。


🚀二、具体添加方法 

第①步:在common.py中添加RepViT模块

将以下代码复制粘贴到common.py文件的末尾

  1. class Conv2d_BN(torch.nn.Sequential):
  2. def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
  3. groups=1, bn_weight_init=1, resolution=-10000):
  4. super().__init__()
  5. self.add_module('c', torch.nn.Conv2d(
  6. a, b, ks, stride, pad, dilation, groups, bias=False))
  7. self.add_module('bn', torch.nn.BatchNorm2d(b))
  8. torch.nn.init.constant_(self.bn.weight, bn_weight_init)
  9. torch.nn.init.constant_(self.bn.bias, 0)
  10. @torch.no_grad()
  11. def fuse(self):
  12. c, bn = self._modules.values()
  13. w = bn.weight / (bn.running_var + bn.eps)**0.5
  14. w = c.weight * w[:, None, None, None]
  15. b = bn.bias - bn.running_mean * bn.weight / \
  16. (bn.running_var + bn.eps)**0.5
  17. m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
  18. 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups,
  19. device=c.weight.device)
  20. m.weight.data.copy_(w)
  21. m.bias.data.copy_(b)
  22. return m
  23. class Residual(torch.nn.Module):
  24. def __init__(self, m, drop=0.):
  25. super().__init__()
  26. self.m = m
  27. self.drop = drop
  28. def forward(self, x):
  29. if self.training and self.drop > 0:
  30. return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,
  31. device=x.device).ge_(self.drop).div(1 - self.drop).detach()
  32. else:
  33. return x + self.m(x)
  34. @torch.no_grad()
  35. def fuse(self):
  36. if isinstance(self.m, Conv2d_BN):
  37. m = self.m.fuse()
  38. assert(m.groups == m.in_channels)
  39. identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
  40. identity = torch.nn.functional.pad(identity, [1,1,1,1])
  41. m.weight += identity.to(m.weight.device)
  42. return m
  43. elif isinstance(self.m, torch.nn.Conv2d):
  44. m = self.m
  45. assert(m.groups != m.in_channels)
  46. identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
  47. identity = torch.nn.functional.pad(identity, [1,1,1,1])
  48. m.weight += identity.to(m.weight.device)
  49. return m
  50. else:
  51. return self
  52. class RepVGGDW(torch.nn.Module):
  53. def __init__(self, ed) -> None:
  54. super().__init__()
  55. self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
  56. self.conv1 = Conv2d_BN(ed, ed, 1, 1, 0, groups=ed)
  57. self.dim = ed
  58. def forward(self, x):
  59. return self.conv(x) + self.conv1(x) + x
  60. @torch.no_grad()
  61. def fuse(self):
  62. conv = self.conv.fuse()
  63. conv1 = self.conv1.fuse()
  64. conv_w = conv.weight
  65. conv_b = conv.bias
  66. conv1_w = conv1.weight
  67. conv1_b = conv1.bias
  68. conv1_w = torch.nn.functional.pad(conv1_w, [1,1,1,1])
  69. identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device), [1,1,1,1])
  70. final_conv_w = conv_w + conv1_w + identity
  71. final_conv_b = conv_b + conv1_b
  72. conv.weight.data.copy_(final_conv_w)
  73. conv.bias.data.copy_(final_conv_b)
  74. return conv
  75. class RepViTBlock(nn.Module):
  76. def __init__(self,in1, inp, hidden_dim, oup, kernel_size=3, stride=2, use_se=0, use_hs=0):
  77. super(RepViTBlock, self).__init__()
  78. assert stride in [1, 2]
  79. self.identity = stride == 1 and inp == oup
  80. print(inp)
  81. print(hidden_dim)
  82. print(oup)
  83. assert(hidden_dim == 2 * inp)
  84. if stride == 2:
  85. self.token_mixer = nn.Sequential(
  86. Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp),
  87. SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
  88. Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
  89. )
  90. self.channel_mixer = Residual(nn.Sequential(
  91. # pw
  92. Conv2d_BN(oup, 2 * oup, 1, 1, 0),
  93. nn.GELU() if use_hs else nn.GELU(),
  94. # pw-linear
  95. Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
  96. ))
  97. else:
  98. assert(self.identity)
  99. self.token_mixer = nn.Sequential(
  100. RepVGGDW(inp),
  101. SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
  102. )
  103. self.channel_mixer = Residual(nn.Sequential(
  104. # pw
  105. Conv2d_BN(inp, hidden_dim, 1, 1, 0),
  106. nn.GELU() if use_hs else nn.GELU(),
  107. # pw-linear
  108. Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
  109. ))
  110. def forward(self, x):
  111. return self.channel_mixer(self.token_mixer(x))

如下图所示:


第②步:修改yolo.py文件

首先找到yolo.py里面parse_model函数的这一行

加入 RepViTBlock这个模块


 第③步:创建自定义的yaml文件  

yaml文件完整代码:

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. # Parameters
  3. nc: 1 # number of classes
  4. depth_multiple: 0.33 # model depth multiple
  5. width_multiple: 0.50 # layer channel multiple
  6. anchors:
  7. - [10,13, 16,30, 33,23] # P3/8
  8. - [30,61, 62,45, 59,119] # P4/16
  9. - [116,90, 156,198, 373,326] # P5/32
  10. # CSPNet-v5
  11. backbone:
  12. # [from, number, module, args]
  13. [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
  14. [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
  15. [-1, 3, C3, [128]],
  16. [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
  17. [-1, 6, C3, [256]],
  18. [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
  19. [-1, 9, C3, [512]],
  20. [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
  21. [-1, 3, C3, [1024]],
  22. [-1, 1, SPPF, [1024, 5]], # 9
  23. ]
  24. # YOLOv5 v6.0 head
  25. head:
  26. [[-1, 1, Conv, [512, 1, 1]],
  27. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  28. [[-1, 6], 1, Concat, [1]], # cat backbone P4
  29. [-1, 3, C3, [512, False]], # 13
  30. [-1, 1, Conv, [256, 1, 1]],
  31. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  32. [[-1, 4], 1, Concat, [1]], # cat backbone P3
  33. [-1, 3, C3, [256, False]], # 17 (P3/8-small)
  34. [-1, 1, Conv, [256, 3, 2]],
  35. [[-1, 14], 1, Concat, [1]], # cat head P4
  36. [-1, 3, C3, [512, False]], # 20 (P4/16-medium)
  37. [-1, 1, Conv, [512, 3, 2]],
  38. [[-1, 10], 1, Concat, [1]], # cat head P5
  39. [-1, 3, C3, [1024, False]], # 23 (P5/32-large)
  40. [-1, 1, RepViTBlock, [1024,1024,512]], # 23 (P5/32-large)
  41. [[17, 20, 24], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  42. ]

第④步:验证是否加入成功

运行yolo.py  


PS: 

这个也是,对我的数据集没涨点,而且我没觉得参数量小了。( ╯#-_-)╯┴—┴