
YOLOv5 Improvement Series (22): Replacing the Backbone with MobileViT v1 (a Lightweight, General-Purpose Vision Transformer for Mobile Devices)

🚀 1. Introduction to MobileViT v1

(by Apple (⊙o⊙) )


1.1 Overview

MobileViT is a lightweight, general-purpose vision transformer for mobile devices proposed by Apple. It combines the strengths of CNNs and ViTs to improve performance on mobile vision tasks.

Previous work focused either on lightweight convolutional neural networks or on self-attention-based ViTs: CNNs have local receptive fields and few parameters, while ViTs capture global context but carry many parameters. Both approaches run into problems on mobile vision tasks, such as sub-par accuracy or high latency.

The paper proposes MobileViT, which uses transformers as convolutions to process global information, producing a lightweight, low-latency network for mobile vision tasks.

The results show that MobileViT clearly outperforms CNN- and ViT-based networks across different tasks and datasets. On ImageNet-1k, for example, MobileViT-S reaches about 78.4% top-1 accuracy with roughly 5.6 million parameters.


1.2 Network Structure

In the paper's architecture figure, the upper diagram is the standard vision transformer (ViT) model, and the lower one is the MobileViT network introduced here.

MobileViT is mainly built from two modules, MV2 and the MobileViT block. Let's look at each in turn:

(1) MV2

MV2 is simply the Inverted Residual Block from MobileNet v2, i.e. the structure shown in the figure below.

In the figure, MV2 denotes the block with stride 1; the MV2 blocks marked with a downward arrow use stride 2, i.e. they perform downsampling. A small usage sketch follows.
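To make the two variants concrete, here is a minimal sketch using the MV2B implementation from Step ① below; the channel counts and tensor sizes are illustrative only:

```python
import torch

# stride=1 with matching channels: the residual (shortcut) connection is active
mv2_s1 = MV2B(ch_in=64, ch_out=64, stride=1)
# stride=2: the downsampling variant (the "down-arrow" MV2 in the figure); no residual
mv2_s2 = MV2B(ch_in=64, ch_out=128, stride=2)

x = torch.randn(1, 64, 32, 32)
print(mv2_s1(x).shape)  # torch.Size([1, 64, 32, 32])
print(mv2_s2(x).shape)  # torch.Size([1, 128, 16, 16])
```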


(2) MobileViT block

The feature map is first passed through a convolution layer with an n×n kernel, and then through a 1×1 convolution that adjusts the number of channels.

Next, global features are modeled by an Unfold → Transformer → Fold sequence, after which another 1×1 convolution restores the channel count to the original size.

The result is then concatenated channel-wise (concat) with the original input feature map via a shortcut branch.

Finally, an n×n convolution fuses the concatenated features into the final output. The unfold/fold step is illustrated in the sketch below.
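The following is a minimal shape walkthrough (a sketch, not the full block) of the unfold → transformer → fold step, assuming a 2×2 patch size; the tensor sizes are illustrative only:

```python
import torch
from einops import rearrange

x = torch.randn(1, 64, 32, 32)  # (batch, dim, H, W); H and W must be divisible by the patch size
ph, pw = 2, 2                   # patch size

# unfold: group pixels by their offset within each 2x2 patch
tokens = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=ph, pw=pw)
print(tokens.shape)             # torch.Size([1, 4, 256, 64])
# ... the transformer attends over the 256 positions within each of the 4 pixel offsets ...

# fold: restore the original feature-map layout
x2 = rearrange(tokens, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=16, w=16, ph=ph, pw=pw)
assert x2.shape == x.shape
```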


1.3 Experiments

(1) Comparison with CNNs

(2) Comparison with ViTs

(3) Mobile object detection

(4) Mobile instance segmentation

(5) Performance on mobile devices


🚀 2. How to Add MobileViT v1 to YOLOv5

Step ①: Add the MobileViT v1 modules to common.py

Copy and paste the following code at the end of common.py. Note that it depends on the einops package (pip install einops); torch and torch.nn are already imported at the top of common.py.

```python
from einops import rearrange


class TAttention(nn.Module):
    # Multi-head self-attention over the unfolded patch tokens, layout (b, p, n, d)
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b p h n d -> b p n (h d)')
        return self.to_out(out)


class MoblieTrans(nn.Module):
    # Stack of `depth` pre-norm transformer encoder layers with residual connections
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, TAttention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class MV2B(nn.Module):
    # MobileNetV2 Inverted Residual Block (MV2)
    def __init__(self, ch_in, ch_out, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]
        hidden_dim = int(ch_in * expansion)
        # the residual connection is only used when the output shape matches the input
        self.use_res_connect = self.stride == 1 and ch_in == ch_out
        if expansion == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, ch_out, 1, 1, 0, bias=False),
                nn.BatchNorm2d(ch_out),
            )
        else:
            self.conv = nn.Sequential(
                # pw (expand)
                nn.Conv2d(ch_in, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear (project)
                nn.Conv2d(hidden_dim, ch_out, 1, 1, 0, bias=False),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileViT_Block(nn.Module):
    def __init__(self, ch_in, dim=64, depth=2, kernel_size=3, patch_size=(2, 2), mlp_dim=int(64 * 2), dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size
        self.conv1 = conv_nxn_bn(ch_in, ch_in, kernel_size)
        self.conv2 = conv_1x1_bn(ch_in, dim)
        self.transformer = MoblieTrans(dim, depth, 4, 8, mlp_dim, dropout)
        self.conv3 = conv_1x1_bn(dim, ch_in)
        self.conv4 = conv_nxn_bn(2 * ch_in, ch_in, kernel_size)

    def forward(self, x):
        y = x.clone()  # keep the input for the shortcut branch
        # Local representations
        x = self.conv1(x)
        x = self.conv2(x)
        # Global representations: unfold -> transformer -> fold
        _, _, h, w = x.shape  # h and w must be divisible by the patch size
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw,
                      ph=self.ph, pw=self.pw)
        x = self.conv3(x)
        # Fusion: concatenate with the shortcut, then fuse with an nxn convolution
        x = torch.cat((x, y), 1)
        x = self.conv4(x)
        return x


def conv_1x1_bn(ch_in, ch_out):
    # 1x1 convolution + BN + SiLU, used for channel projection
    return nn.Sequential(
        nn.Conv2d(ch_in, ch_out, 1, 1, 0, bias=False),
        nn.BatchNorm2d(ch_out),
        nn.SiLU()
    )


def conv_nxn_bn(ch_in, ch_out, kernel_size=3, stride=1):
    # nxn convolution + BN + SiLU (padding=1 assumes kernel_size=3)
    return nn.Sequential(
        nn.Conv2d(ch_in, ch_out, kernel_size, stride, 1, bias=False),
        nn.BatchNorm2d(ch_out),
        nn.SiLU()
    )


class PreNorm(nn.Module):
    # Apply LayerNorm before the wrapped module (attention or feed-forward)
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    # Transformer MLP: Linear -> SiLU -> Dropout -> Linear -> Dropout
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
```
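As a quick sanity check (a minimal sketch, not part of the required changes), you can run a dummy tensor through the block; the channel count and spatial size here are illustrative:

```python
import torch

m = MobileViT_Block(ch_in=256, dim=256)
x = torch.randn(1, 256, 20, 20)  # H and W must be divisible by the 2x2 patch size
print(m(x).shape)  # torch.Size([1, 256, 20, 20]) -- output channels equal ch_in
```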

Step ②: Modify yolo.py

Next, edit yolo.py: in the parse_model function, find the elif m is Concat: branch and add the following code after it:

```python
        # MobileViT v1
        elif m in [MobileViT_Block]:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)
            args = [c1, c2, *args[1:]]
```
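A quick note on what this does: c2 = args[0] is scaled by the width multiple gw and rounded to a multiple of 8, so with width_multiple: 0.50 and args [512] in the yaml below, make_divisible(512 * 0.5, 8) returns 256. Also be aware that MobileViT_Block returns ch_in channels (its conv4 maps 2 * ch_in back to ch_in), so c2 effectively sets the transformer embedding dimension dim; in the yaml below, c1 and c2 both resolve to 256, so the channel count parse_model records stays correct.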



Step ③: Create a custom yaml file

The complete yaml configuration is as follows:

```yaml
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, MobileViT_Block, [512]],  # 20 (P4/16-medium); no C3-style False flag here, it would be passed on as the block's depth argument
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
```
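With these settings, parse_model resolves layer 20 to MobileViT_Block(256, 256): the Concat of layers 18 and 14 yields c1 = 256, c2 = make_divisible(512 * 0.5, 8) = 256, and depth_multiple: 0.33 reduces the repeat count from 3 to 1.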

Step ④: Verify the module was added successfully

Run yolo.py (for example, python models/yolo.py --cfg with the path to your custom yaml) and check that the printed layer table shows MobileViT_Block at layer 20.
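Alternatively, here is a minimal sketch that builds the model in Python; the filename models/yolov5s-mobilevitv1.yaml is an assumption, so use whatever path you saved the config from Step ③ to:

```python
import torch
from models.yolo import Model  # run this from the YOLOv5 repo root

# assumed filename for the yaml from Step ③
model = Model('models/yolov5s-mobilevitv1.yaml')
x = torch.zeros(1, 3, 640, 640)
y = model(x)  # a dummy forward pass confirms the block is wired in correctly
```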

That's it, you're done!