YOLOv5改进系列(22)——替换主干网络之MobileViTv1(一种轻量级的、通用的移动设备 ViT)
🚀一、MobileViT v1介绍
- 论文题目:《MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer》
- 论文地址:https://arxiv.org/abs/2110.02178
- 源码地址:GitHub - apple/ml-cvnets: CVNets: A library for training computer vision networks

(Apple(⊙o⊙) )
1.1 简介
MobileViT网络是由苹果公司提出了一种轻量级的、通用的移动设备 vision transformer,将CNN和ViT的优势相结合,提高了在移动视觉任务中的性能。
以往的研究主要集中在轻量级卷积神经网络和自注意力ViTs,其中CNN具有局部感知性,参数较少,ViTs具有全局感知性,但参数较多。然而,这些方法在移动视觉任务中存在一些问题,如性能不够理想、延迟较高等。
本篇论文提出了MobileViT的研究方法,将transformers作为卷积的方式进行全局信息处理,实现了轻量级和低延迟的移动视觉任务网络。
研究结果表明,MobileViT在不同任务和数据集上明显优于基于CNN和ViT的网络。在ImageNet-1k数据集上取得了最佳结果。
1.2 网络结构

上面那个图展示就是标准视觉ViT模型,下面就是今天要介绍的MobileViT的网络结构。
主要由MV2和MobileViTblock两个模块组成,下面我们来介绍下这两个模块:
(1)MV2
MV2就是MobileNet v2(直通车:)里面Inverted Residual Block,即下面的图所示的结构。

图中MV2是当stride等于1时的MV2结构,上图中标有向下箭头的MV2结构代表stride等于2的情况,即需要进行下采样。
(2)MobileViTblock

首先将特征图通过一个卷积层,卷积核大小为n×n,然后再通过一个卷积核大小为1×1的卷积层进行通道调整。
接着依次通过Unfold、Transformer、Fold结构进行全局特征建模,然后再通过一个卷积核大小为1×1的卷积层将通道调整为原始大小。
接着通过shortcut捷径分支与原始输入特征图按通道concat拼接。
最后再通过一个卷积核大小为n×n的卷积层进行特征融合得到最终的输出。
1.3 实验
(1)和CNN对比

(2)和ViT对比

(3)移动端目标检测

(4)移动端实例分割

(5)移动设备的性能

🚀二、具体添加方法
第①步:在common.py中添加MobileViTv1模块
将以下代码复制粘贴到common.py文件的末尾
- from einops import rearrange
- class TAttention(nn.Module):
- def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
- super().__init__()
- inner_dim = dim_head * heads
- project_out = not (heads == 1 and dim_head == dim)
- self.heads = heads
- self.scale = dim_head ** -0.5
- self.attend = nn.Softmax(dim=-1)
- self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
- self.to_out = nn.Sequential(
- nn.Linear(inner_dim, dim),
- nn.Dropout(dropout)
- ) if project_out else nn.Identity()
- def forward(self, x):
- qkv = self.to_qkv(x).chunk(3, dim=-1)
- q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
- dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
- attn = self.attend(dots)
- out = torch.matmul(attn, v)
- out = rearrange(out, 'b p h n d -> b p n (h d)')
- return self.to_out(out)
- class MoblieTrans(nn.Module):
- def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
- super().__init__()
- self.layers = nn.ModuleList([])
- for _ in range(depth):
- self.layers.append(nn.ModuleList([
- PreNorm(dim, TAttention(dim, heads, dim_head, dropout)),
- PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
- ]))
- def forward(self, x):
- for attn, ff in self.layers:
- x = attn(x) + x
- x = ff(x) + x
- return x
- class MV2B(nn.Module):
- def __init__(self, ch_in, ch_out, stride=1, expansion=4):
- super().__init__()
- self.stride = stride
- assert stride in [1, 2]
- hidden_dim = int(ch_in * expansion)
- self.use_res_connect = self.stride == 1 and ch_in == ch_out
- if expansion == 1:
- self.conv = nn.Sequential(
- # dw
- nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
- nn.BatchNorm2d(hidden_dim),
- nn.SiLU(),
- # pw-linear
- nn.Conv2d(hidden_dim, ch_out, 1, 1, 0, bias=False),
- nn.BatchNorm2d(ch_out),
- )
- else:
- self.conv = nn.Sequential(
- nn.Conv2d(ch_in, hidden_dim, 1, 1, 0, bias=False),
- nn.BatchNorm2d(hidden_dim),
- nn.SiLU(),
- nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
- nn.BatchNorm2d(hidden_dim),
- nn.SiLU(),
- nn.Conv2d(hidden_dim, ch_out, 1, 1, 0, bias=False),
- nn.BatchNorm2d(ch_out),
- )
- def forward(self, x):
- if self.use_res_connect:
- return x + self.conv(x)
- else:
- return self.conv(x)
- class MobileViT_Block(nn.Module):
- def __init__(self, ch_in, dim=64, depth=2, kernel_size=3, patch_size=(2, 2), mlp_dim=int(64 * 2), dropout=0.):
- super().__init__() # mg
- self.ph, self.pw = patch_size
- self.conv1 = conv_nxn_bn(ch_in, ch_in, kernel_size)
- self.conv2 = conv_1x1_bn(ch_in, dim)
- self.transformer = MoblieTrans(dim, depth, 4, 8, mlp_dim, dropout)
- self.conv3 = conv_1x1_bn(dim, ch_in)
- self.conv4 = conv_nxn_bn(2 * ch_in, ch_in, kernel_size)
- def forward(self, x):
- y = x.clone() # torch.Size for
- # Local representations
- x = self.conv1(x)
- x = self.conv2(x) # torch.Size for
- # Global representations
- _, _, h, w = x.shape
- x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
- x = self.transformer(x)
- x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
- pw=self.pw)
- x = self.conv3(x)
- x = torch.cat((x, y), 1)
- x = self.conv4(x)
- return x
- def conv_1x1_bn(ch_in, ch_out):
- return nn.Sequential(
- nn.Conv2d(ch_in, ch_out, 1, 1, 0, bias=False),
- nn.BatchNorm2d(ch_out),
- nn.SiLU()
- )
- def conv_nxn_bn(ch_in, ch_out, kernal_size=3, stride=1):
- return nn.Sequential(
- nn.Conv2d(ch_in, ch_out, kernal_size, stride, 1, bias=False),
- nn.BatchNorm2d(ch_out),
- nn.SiLU()
- )
- class PreNorm(nn.Module):
- def __init__(self, dim, fn):
- super().__init__()
- self.norm = nn.LayerNorm(dim)
- self.fn = fn
- def forward(self, x, **kwargs):
- return self.fn(self.norm(x), **kwargs)
- class FeedForward(nn.Module):
- def __init__(self, dim, hidden_dim, dropout=0.):
- super().__init__()
- self.net = nn.Sequential(
- nn.Linear(dim, hidden_dim),
- nn.SiLU(),
- nn.Dropout(dropout),
- nn.Linear(hidden_dim, dim),
- nn.Dropout(dropout)
- )
- def forward(self, x):
- return self.net(x)
第②步:修改yolo.py文件
再来修改yolo.py,在parse_model函数中找到 elif m is Concat: 语句,在其后面加上下面代码:
- # mobilevit v1
- elif m in [MobileViT_Block]:
- c1, c2 = ch[f], args[0]
- if c2 != no: # if not outputss
- c2 = make_divisible(c2 * gw, 8)
- args = [c1, c2, *args[1:]]
如下图所示:

第③步:创建自定义的yaml文件
yaml文件配置完整代码如下:
- # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
- # Parameters
- nc: 80 # number of classes
- depth_multiple: 0.33 # model depth multiple
- width_multiple: 0.50 # layer channel multiple
- anchors:
- - [10,13, 16,30, 33,23] # P3/8
- - [30,61, 62,45, 59,119] # P4/16
- - [116,90, 156,198, 373,326] # P5/32
- # YOLOv5 v6.0 backbone
- backbone:
- # [from, number, module, args]
- [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
- [-1, 3, C3, [128]],
- [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
- [-1, 6, C3, [256]],
- [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
- [-1, 9, C3, [512]],
- [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
- [-1, 3, C3, [1024]],
- [-1, 1, SPPF, [1024, 5]], # 9
- ]
- # YOLOv5 v6.0 head
- head:
- [[-1, 1, Conv, [512, 1, 1]],
- [-1, 1, nn.Upsample, [None, 2, 'nearest']],
- [[-1, 6], 1, Concat, [1]], # cat backbone P4
- [-1, 3, C3, [512, False]], # 13
- [-1, 1, Conv, [256, 1, 1]],
- [-1, 1, nn.Upsample, [None, 2, 'nearest']],
- [[-1, 4], 1, Concat, [1]], # cat backbone P3
- [-1, 3, C3, [256, False]], # 17 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]],
- [[-1, 14], 1, Concat, [1]], # cat head P4
- [-1, 3, MobileViT_Block, [512, False]], # 20 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]],
- [[-1, 10], 1, Concat, [1]], # cat head P5
- [-1, 3, C3, [1024, False]], # 23 (P5/32-large)
- [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
- ]
第④步 验证是否加入成功
运行yolo.py

这样就OK啦!