YOLOv5 Improvement Series (24): Replacing the Backbone with MobileViTv3 (a Further Upgrade for Lightweight Mobile Networks)

🚀 I. Introduction to MobileViT v3

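In earlier work, CNN models were lightweight enough but left room for improvement in accuracy, while ViT models recognized well but had many parameters and a high computational cost. Neither met the need for efficient, real-time detection on mobile devices.

MobileViT is a lightweight ViT-based model proposed by Apple in 2021. It combines the detection accuracy of ViT models with the light weight of CNN models, greatly reduces the parameter count, and is mobile-friendly enough to be deployed on mobile devices.

MobileViT v3, released in September 2022 by Wadekar and Chaurasia as a follow-up to Apple's MobileViT, is the third version of the architecture. Compared with the initial version, it makes four changes:

  • First, the 3×3 convolution in the fusion block is replaced with a 1×1 convolution;
  • Second, the features of the local representation block and the global representation block are fused together;
  • Third, the input features are added in the fusion block as a final step before the MobileViT block output is produced;
  • Finally, in the local representation block, the standard 3×3 convolution is replaced with a depthwise 3×3 convolution.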
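To get a feel for why the two convolution swaps in the list above shrink the model, the weight counts can be worked out by hand. The sketch below is only illustrative: the helper conv_weight_count and the channel count of 128 are assumptions chosen for the example, not values taken from the paper.

```python
def conv_weight_count(c_in, c_out, kernel_size, groups=1):
    """Number of weights in a bias-free 2D convolution."""
    assert c_in % groups == 0 and c_out % groups == 0
    return (c_in // groups) * c_out * kernel_size * kernel_size


C = 128  # hypothetical channel count

# Fusion block: 3x3 conv over the concatenated 2C channels vs. a 1x1 conv
fusion_3x3 = conv_weight_count(2 * C, C, 3)
fusion_1x1 = conv_weight_count(2 * C, C, 1)

# Local representation block: standard 3x3 conv vs. depthwise 3x3 conv
local_standard = conv_weight_count(C, C, 3)
local_depthwise = conv_weight_count(C, C, 3, groups=C)

print(fusion_3x3, fusion_1x1)            # 294912 32768
print(local_standard, local_depthwise)   # 147456 1152
```

With these example sizes the 1×1 fusion convolution has 9× fewer weights than the 3×3 one, and the depthwise convolution has 128× fewer than the standard one.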

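[Figure: MobileViTv3 network structure, with (a) the MV2 block and (b) the MobileViTv3 block]

The model processes an input as follows:

(1) The input image goes through a standard 3×3 convolution with 2× downsampling, then through five MV2 blocks (figure (a)). MV2 blocks with stride 1 extract features, and those with stride 2 downsample by 2×.

(2) The resulting feature map is passed alternately through MobileViTv3 blocks (figure (b)) and stride-2 MV2 blocks.

(3) A standard 3×3 convolution then compresses the channels.

(4) Finally, global average pooling produces the prediction.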
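The downsampling in steps (1) to (4) can be traced with the standard convolution output-size formula. The sketch below is a rough illustration only: the 256×256 input and the exact order of stride-1 and stride-2 stages are assumptions, since the text above does not spell them out, and every stage is simplified to a 3×3 convolution with padding 1.

```python
def conv_out_size(size, kernel_size, stride, padding):
    """Spatial output size of a convolution: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


# Hypothetical stage list: (name, stride). Only the strides matter here.
stages = [
    ("3x3 conv", 2),
    ("MV2", 1), ("MV2", 2), ("MV2", 1), ("MV2", 1), ("MV2", 2),
    ("MobileViTv3 block", 1), ("MV2", 2),
    ("MobileViTv3 block", 1), ("MV2", 2),
    ("MobileViTv3 block", 1),
]

size = 256  # assumed input resolution
sizes = []
for name, stride in stages:
    # each stage is treated as a 3x3 convolution with padding 1
    size = conv_out_size(size, kernel_size=3, stride=stride, padding=1)
    sizes.append(size)
    print(f"{name:<18} stride {stride} -> {size}x{size}")
```

Under these assumptions the five stride-2 stages give an overall stride of 32, so a 256×256 input ends up as an 8×8 feature map before pooling.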


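The MobileViT v3 block is the core of MobileViT v3. It consists of three parts: a local representation block, a global representation block, and a fusion block, described below: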

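(1) Local representation block

Given an input feature map $X \in R^{H\times W\times C}$ ($H$ and $W$ are the height and width of the map; $C$ is the number of input channels), a 3×3 depthwise convolution followed by a 1×1 convolution produces $X_{L} \in R^{H\times W\times d}$ ($d$ is the number of output channels). The local representation block thus maps the local spatial information of $X$ into a $d$-dimensional space.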

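(2) Global representation block

The output $X_{L}$ of the local representation block is the input to the global representation block. $X_{L}$ is first unfolded into non-overlapping patches $X_{U} \in R^{P\times N\times d}$ ($P = w\times h$ is the number of pixels per patch, with $w$ and $h$ the patch width and height; $N = W\times H / P$ is the number of patches). Then, for each position $p \in \{1,\dots,P\}$, the pixels at that same position across all patches are fed to the Transformer for encoding, giving $X_{G} \in R^{P\times N\times d}$. So that no information is lost between patches, the encoded patches are finally folded back into $X_{F} \in R^{H\times W\times d}$, which is passed to the next block.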
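The unfolding and folding described above is a pure re-indexing, so it can be checked on a tiny example. The sketch below uses plain Python lists and a single channel (d = 1) for readability; unfold and fold are hypothetical helper names, and the index order is meant to follow the rearrange patterns used in the code later in this article.

```python
def unfold(x, ph, pw):
    """Map an H x W grid to a P x N layout (P = ph * pw pixels per patch, N patches).

    Entry [p][n] holds the pixel at in-patch position p of patch n.
    """
    H, W = len(x), len(x[0])
    assert H % ph == 0 and W % pw == 0, "H and W must be divisible by the patch size"
    nh, nw = H // ph, W // pw
    out = [[None] * (nh * nw) for _ in range(ph * pw)]
    for row in range(H):
        for col in range(W):
            p = (row % ph) * pw + (col % pw)    # position inside the patch
            n = (row // ph) * nw + (col // pw)  # which patch
            out[p][n] = x[row][col]
    return out


def fold(u, H, W, ph, pw):
    """Inverse of unfold: rebuild the H x W grid from the P x N layout."""
    nw = W // pw
    return [[u[(row % ph) * pw + (col % pw)][(row // ph) * nw + (col // pw)]
             for col in range(W)] for row in range(H)]


# 4 x 4 feature map with 2 x 2 patches: P = 4, N = 4
x = [[r * 4 + c for c in range(4)] for r in range(4)]
u = unfold(x, 2, 2)
print(u[0])  # top-left pixel of each patch: [0, 2, 8, 10]
restored = fold(u, 4, 4, 2, 2)
print(restored == x)  # True: folding restores the original map
```

Each row u[p] is the sequence the Transformer attends over: the pixels at the same in-patch position p across all N patches.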

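(3) Fusion block

$X_{F}$ is fed into the fusion block and projected to a lower-dimensional space, giving $X^{'} \in R^{H\times W\times C}$. The output $X_{L}$ of the local representation block and $X^{'}$ are then concatenated along the channel axis, giving $X^{''} \in R^{H\times W\times 2C}$ (taking $d = C$). A 1×1 convolution then fuses the local and global features contained in $X^{''}$, giving $X^{*} \in R^{H\times W\times C}$. Finally, $X^{*}$ is added to the original input $X$ to produce the output $Y \in R^{H\times W\times C}$.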
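The fusion step can be summarized as a short chain of equations. This is only a restatement of the description above: $\mathrm{Conv}_{1\times 1}$ (a 1×1 convolution, which is also what the code later in this article uses for the projection) and $\mathrm{Concat}$ (channel-wise concatenation) are notations introduced here for convenience, and $d = C$ is assumed. Note that $X^{*}$ needs $C$ channels for the final addition with $X$ to be well defined.

```latex
\begin{aligned}
X^{'}  &= \mathrm{Conv}_{1\times 1}(X_{F}),      & X^{'}  &\in R^{H\times W\times C} \\
X^{''} &= \mathrm{Concat}(X_{L},\, X^{'}),       & X^{''} &\in R^{H\times W\times 2C} \\
X^{*}  &= \mathrm{Conv}_{1\times 1}(X^{''}),     & X^{*}  &\in R^{H\times W\times C} \\
Y      &= X^{*} + X,                             & Y      &\in R^{H\times W\times C}
\end{aligned}
```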

If you would like pseudocode for this process, feel free to message me~


🚀 II. How to Add It

Step ①: Add the MobileViT v3 modules to common.py

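First, define the convolution layers.

There are two: a 1×1 convolution layer and an n×n (n=3) convolution layer.

def conv_1x1_bn(inp, oup):
    # 1x1 (pointwise) convolution -> BatchNorm -> SiLU
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )


def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
    # n x n convolution -> BatchNorm -> SiLU
    # padding = kernel_size // 2 keeps the spatial size for odd kernels at stride 1
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size, stride, kernel_size // 2, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )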

Next, build the ViT modules.

These perform the encoding in the Transformer Encoder.

class PreNorm(nn.Module):
    # Applies LayerNorm before the wrapped module (pre-norm Transformer)
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class Attention(nn.Module):
    # Multi-head self-attention over the patch axis n
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # x: (batch, pixels per patch, patches, dim)
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b p h n d -> b p n (h d)')
        return self.to_out(out)


class FeedForward(nn.Module):
    # Two-layer MLP with SiLU activation
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class MBTransformer(nn.Module):
    # Stack of `depth` pre-norm attention + feed-forward layers with residual connections
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
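As a rough illustration of what Attention.forward computes for a single head, the sketch below reimplements scaled dot-product attention with plain Python lists. It is a simplified stand-in, not the module above: there are no learned projections and no batching, and the helper names softmax and attention are hypothetical.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v):
    """Scaled dot-product attention for one head.

    q, k, v are lists of n vectors (n tokens); q and k have head size d.
    """
    d = len(q[0])
    scale = d ** -0.5  # same scaling as self.scale = dim_head ** -0.5
    out = []
    for qi in q:
        dots = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        attn = softmax(dots)  # attention weights over all tokens, summing to 1
        out.append([sum(w * vj[c] for w, vj in zip(attn, v))
                    for c in range(len(v[0]))])
    return out


# Two tokens with identical keys: every query attends to both equally,
# so each output is the average of the two value vectors.
q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 1.0], [1.0, 1.0]]
v = [[2.0, 0.0], [0.0, 4.0]]
result = attention(q, k, v)
print(result)
```

In the MobileViTv3_block shown later, the Transformer is built with heads=4 and dim_head=8, so inner_dim is 32 and the scale is 8 ** -0.5.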

Then, the MV2 block.

It comes in two variants: stride=1 and stride=2.

class MV2Block(nn.Module):
    # MobileNetV2-style inverted residual block
    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]
        hidden_dim = int(inp * expansion)
        # The skip connection is only valid when input and output shapes match
        self.use_res_connect = self.stride == 1 and inp == oup
        if expansion == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)
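To see what the expansion factor costs, the convolution weights of one MV2Block can be tallied by hand. The sketch below mirrors the three convolutions of the branch taken when expansion is not 1; the function names and the example size of 128 channels are assumptions for illustration, and BatchNorm parameters are ignored.

```python
def mv2_conv_weights(inp, oup, expansion=4):
    """Weights in the three bias-free convolutions of an inverted residual block."""
    hidden_dim = int(inp * expansion)
    pw_expand = inp * hidden_dim     # 1x1 conv: inp -> hidden_dim
    dw = hidden_dim * 3 * 3          # 3x3 depthwise conv: one filter per channel
    pw_project = hidden_dim * oup    # 1x1 linear conv: hidden_dim -> oup
    return pw_expand, dw, pw_project


def uses_residual(inp, oup, stride):
    """The skip connection is only used when input and output shapes match."""
    return stride == 1 and inp == oup


parts = mv2_conv_weights(128, 128)
print(parts, sum(parts))  # (65536, 4608, 65536) 135680
print(uses_residual(128, 128, 1), uses_residual(128, 128, 2))  # True False
```

Almost all of the weights sit in the two pointwise convolutions; the depthwise 3×3 convolution in the middle is comparatively cheap.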

Finally, the core module, MobileViTv3_block.

See the introduction above for how it works~

class MobileViTv3_block(nn.Module):
    def __init__(self, channel, dim, depth=2, kernel_size=3, patch_size=(2, 2), mlp_dim=int(64*2), dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size
        self.mv01 = MV2Block(channel, channel)
        # conv1 + conv2: local representation (this implementation uses a standard n x n conv)
        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        # conv3: projects the Transformer output from dim back to channel
        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv2 = conv_1x1_bn(channel, dim)
        # Transformer with 4 heads of size 8
        self.transformer = MBTransformer(dim, depth, 4, 8, mlp_dim, dropout)
        # conv4: fuses the concatenated features (n x n here, 1x1 in the description above)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x):
        y = x.clone()  # input features, added back at the end
        # Local representation
        x = self.conv1(x)
        x = self.conv2(x)
        z = x.clone()  # local features, concatenated in the fusion step
        # Global representation: unfold into patches, encode, fold back
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
        # Fusion
        x = self.conv3(x)
        x = torch.cat((x, z), 1)  # channel + dim channels, so conv4 requires dim == channel
        x = self.conv4(x)
        x = x + y
        x = self.mv01(x)
        return x
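Two constraints follow from the forward method above: the feature map's height and width must be divisible by the patch size for rearrange to succeed, and because torch.cat((x, z), 1) yields channel + dim channels while conv4 expects 2 * channel, dim has to equal channel. A small stand-alone checker makes this explicit; check_block_config is a hypothetical helper, not part of YOLOv5 or of the code above.

```python
def check_block_config(channel, dim, h, w, patch_size=(2, 2)):
    """Return a list of problems that would make MobileViTv3_block.forward fail."""
    ph, pw = patch_size
    problems = []
    if h % ph != 0 or w % pw != 0:
        problems.append(
            f"feature map {h}x{w} is not divisible by patch size {ph}x{pw}")
    if channel + dim != 2 * channel:
        problems.append(
            f"concat gives {channel + dim} channels but conv4 expects {2 * channel}")
    return problems


# 80x80 is the P3/8 feature map size for a 640x640 input
ok = check_block_config(128, 128, 80, 80)
bad = check_block_config(128, 96, 81, 80)
print(ok)   # []
print(bad)  # two problems: odd height, and dim != channel
```

With the yaml used later in this article, both channel and dim come out as 128, so the second constraint holds there.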

The complete code follows:

Copy and paste the following code at the end of common.py. It needs the einops package (pip install einops); torch and nn are already imported in common.py.

from einops import rearrange


def conv_1x1_bn(inp, oup):
    # 1x1 (pointwise) convolution -> BatchNorm -> SiLU
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )


def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
    # n x n convolution -> BatchNorm -> SiLU
    # padding = kernel_size // 2 keeps the spatial size for odd kernels at stride 1
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size, stride, kernel_size // 2, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )


class PreNorm(nn.Module):
    # Applies LayerNorm before the wrapped module (pre-norm Transformer)
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class Attention(nn.Module):
    # Multi-head self-attention over the patch axis n
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # x: (batch, pixels per patch, patches, dim)
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b p h n d -> b p n (h d)')
        return self.to_out(out)


class FeedForward(nn.Module):
    # Two-layer MLP with SiLU activation
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class MBTransformer(nn.Module):
    # Stack of `depth` pre-norm attention + feed-forward layers with residual connections
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class MV2Block(nn.Module):
    # MobileNetV2-style inverted residual block
    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]
        hidden_dim = int(inp * expansion)
        # The skip connection is only valid when input and output shapes match
        self.use_res_connect = self.stride == 1 and inp == oup
        if expansion == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileViTv3_block(nn.Module):
    def __init__(self, channel, dim, depth=2, kernel_size=3, patch_size=(2, 2), mlp_dim=int(64*2), dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size
        self.mv01 = MV2Block(channel, channel)
        # conv1 + conv2: local representation (this implementation uses a standard n x n conv)
        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        # conv3: projects the Transformer output from dim back to channel
        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv2 = conv_1x1_bn(channel, dim)
        # Transformer with 4 heads of size 8
        self.transformer = MBTransformer(dim, depth, 4, 8, mlp_dim, dropout)
        # conv4: fuses the concatenated features (n x n here, 1x1 in the description above)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x):
        y = x.clone()  # input features, added back at the end
        # Local representation
        x = self.conv1(x)
        x = self.conv2(x)
        z = x.clone()  # local features, concatenated in the fusion step
        # Global representation: unfold into patches, encode, fold back
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
        # Fusion
        x = self.conv3(x)
        x = torch.cat((x, z), 1)  # channel + dim channels, so conv4 requires dim == channel
        x = self.conv4(x)
        x = x + y
        x = self.mv01(x)
        return x

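Step ②: Modify yolo.py

Next, modify yolo.py. In the parse_model function, find the elif m is Concat: branch and add the following code after it (that is, after the branch's body, at the same indentation as the elif):

        elif m in [MobileViTv3_block]:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # scale the width unless this is the output layer
                c2 = make_divisible(c2 * gw, 8)
            args = [c1, c2]  # channel, dim
            if m in [MobileViTv3_block]:
                args.insert(2, n)  # pass the repeat count as the Transformer depth
                n = 1  # build the block itself only once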
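To see which arguments this branch ends up passing to MobileViTv3_block, the scaling can be replayed by hand. The sketch below re-implements the relevant arithmetic in plain Python under stated assumptions: make_divisible rounds up to a multiple of the divisor, and the repeat count is scaled as max(round(n * gd), 1), as in YOLOv5 v6.x. block_args is a hypothetical helper, the c2 != no check is left out, and the values come from the yaml in Step ③.

```python
import math


def make_divisible(x, divisor):
    """Round x up to the nearest multiple of divisor."""
    return math.ceil(x / divisor) * divisor


def block_args(c1, yaml_args, n, gd, gw):
    """Replay the new parse_model branch: returns (args, n) for MobileViTv3_block."""
    n = max(round(n * gd), 1) if n > 1 else n   # depth scaling, done earlier in parse_model
    c2 = make_divisible(yaml_args[0] * gw, 8)   # width scaling
    args = [c1, c2]
    args.insert(2, n)   # the scaled repeat count becomes the Transformer depth
    return args, 1      # the block itself is instantiated once


# Layer 4 of the yaml: [-1, 6, MobileViTv3_block, [256]] with depth 0.33 and width 0.50.
# The preceding Conv layer outputs make_divisible(256 * 0.50, 8) = 128 channels.
c1 = make_divisible(256 * 0.50, 8)
args, n = block_args(c1, [256], 6, gd=0.33, gw=0.50)
print(args, n)  # [128, 128, 2] 1
```

So with this configuration the block is built as MobileViTv3_block(128, 128, 2): channel and dim are both 128, and the Transformer has depth 2.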

Step ③: Create a custom yaml file

The complete yaml configuration is:

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, MobileViTv3_block, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)

   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

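Step ④: Verify that the module was added successfully

Run yolo.py with its --cfg argument pointing to the new yaml file.

And that's it~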