YOLOv5改进系列(24)——替换主干网络之MobileViTv3(移动端轻量化网络的进一步升级)
🚀一、MobileViT v3介绍
- 论文题目:《MOBILEVITV3: MOBILE-FRIENDLY VISION TRANS-
FORMER WITH SIMPLE AND EFFECTIVE FUSION OF LOCAL, GLOBAL AND INPUT FEATURES》- 论文原文: https://arxiv.org/abs/2209.15159
- 源码地址:https://github.com/micronDLA/MobileViTv3
- 论文精读:
在之前的研究中,CNN模型足够轻量化但是精准度有待提高,ViT模型具有较好的识别能力但是模型参数量大,计算复杂,都不能满足移动端实时高效检测的需求。
MobileViT模型是2021年苹果公司提出的基于轻量化的ViT模型,该模型既具备ViT模型准确检测的优越性能,也具备CNN模型的轻量化优点,能极大程度上减少模型参数,对移动端友好,具备部署于移动设备的可能性。
MobileViT v3是该公司2022年9月推出的第3个版本,该模型相较于初始版本有以下四个改进:
- 首先,将3×3卷积层替换为1×1卷积层;
- 第二,将局部表示块和全局表示块的特征融合在一起;
- 第三,在生成MobileViT Block输出之前,在融合块中添加输入特征作为最后一步;
- 最后,在局部表示块中,将普通的3×3卷积层替换为深度3×3卷积层。
MobileViTv3网络结构图如下所示:

模型预测处理过程如下:
(1)将输入图像连接3×3标准卷积并做2倍下采样;之后通过5个MV2模块(如图(a)所示),其中步长为1的MV2模块进行特征提取,步长为2的MV2模块做2倍下采样;
(2)将得到的特征图间隔传入MobileViTV3 Block(如图(b)所示)和步长为2的MV2模块;
(3)接着使用3×3标准卷积进行通道压缩;
(4)最后进行全局平均池化来获取预测结果 。
MobileViT v3 Block模块是MobileViT v3核心部分,由局部表征模块、全局表征模块、融合模块三部分组成,具体介绍如下:
(1)局部表征模块
对于输入的特征图像
(
、
为图像的高、宽;
为输入特征图像的通道) ,首先采用一个3×3的深度卷积层和1×1的卷积层得到输出
(
为输出特征图像的通道)。经过局部表征模块可以将特征图像
的局部空间信息映射到特定的维度
中。
(2)全局表征模块
将局部表征模块中的输出
作为全局表征模块的输入,然后将
展平为无重叠的图片块
(
为展平图片块的大小,
,
、
分别为展平图片块的宽、高,
为图片块的个数,
)。随后,将每个图片块上相同位置上的像素
送入Transformer模块中进行编码,得到
。 为了避免图片块之间的信息丢失,在全局表征模块的最后会将编码后的所有图片块重组还原,得到
,然后输出到下一个模块。
(3)融合模块
将
送入特征融合模块中,使其映射到一个低维空间,得到
。 然后通过拼接操作,将局部表征模块的输出
与
拼接起来,得到
。接着采用1×1的卷积层来融合局部特征
与全局特征
,得到
。最后把刚获取的
和原始
进行相加操作,得到输出
。
该过程需要伪代码可以私聊我~
🚀二、具体添加方法
第①步:在common.py中添加MobileViT v3模块
首先,定义卷积层。
分为1×1卷积层和n×n(n=3)卷积层
- def conv_1x1_bn(inp, oup):
- return nn.Sequential(
- nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
- nn.BatchNorm2d(oup),
- nn.SiLU()
- )
- def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
- return nn.Sequential(
- nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
- nn.BatchNorm2d(oup),
- nn.SiLU()
- )
接着,构造ViT模块。
Transformer Encoder模块中编码
- class PreNorm(nn.Module):
- def __init__(self, dim, fn):
- super().__init__()
- self.norm = nn.LayerNorm(dim)
- self.fn = fn # mg
- def forward(self, x, **kwargs):
- return self.fn(self.norm(x), **kwargs)
- class Attention(nn.Module):
- def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
- super().__init__()
- inner_dim = dim_head * heads
- project_out = not (heads == 1 and dim_head == dim)
- self.heads = heads
- self.scale = dim_head ** -0.5
- self.attend = nn.Softmax(dim = -1)
- self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
- self.to_out = nn.Sequential(
- nn.Linear(inner_dim, dim),
- nn.Dropout(dropout)# mg
- ) if project_out else nn.Identity()
- def forward(self, x):
- qkv = self.to_qkv(x).chunk(3, dim=-1)
- q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h = self.heads), qkv)
- dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
- attn = self.attend(dots)
- out = torch.matmul(attn, v)
- out = rearrange(out, 'b p h n d -> b p n (h d)')
- return self.to_out(out)
- class FeedForward(nn.Module):
- def __init__(self, dim, hidden_dim, dropout=0.):
- super().__init__()
- self.net = nn.Sequential(
- nn.Linear(dim, hidden_dim),
- nn.SiLU(),
- nn.Dropout(dropout),
- nn.Linear(hidden_dim, dim),
- nn.Dropout(dropout)
- )
- def forward(self, x):
- return self.net(x)
- class MBTransformer(nn.Module):
- def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
- super().__init__()
- self.layers = nn.ModuleList([])
- for _ in range(depth):
- self.layers.append(nn.ModuleList([
- PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
- PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
- ]))
- def forward(self, x):
- for attn, ff in self.layers:
- x = attn(x) + x
- x = ff(x) + x
- return x
然后,MV2模块。

分为stride=1和stride=2两种。
- class MV2Block(nn.Module):
- def __init__(self, inp, oup, stride=1, expansion=4):
- super().__init__()
- self.stride = stride
- assert stride in [1, 2]
- hidden_dim = int(inp * expansion)
- self.use_res_connect = self.stride == 1 and inp == oup
- if expansion == 1:
- self.conv = nn.Sequential(
- # dw
- nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
- nn.BatchNorm2d(hidden_dim),
- nn.SiLU(),
- # pw-linear
- nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
- nn.BatchNorm2d(oup),
- )
- else:
- self.conv = nn.Sequential(
- # pw
- nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
- nn.BatchNorm2d(hidden_dim),
- nn.SiLU(),
- # dw
- nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
- nn.BatchNorm2d(hidden_dim),
- nn.SiLU(),
- # pw-linear
- nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
- nn.BatchNorm2d(oup),
- )
- def forward(self, x):
- if self.use_res_connect:
- return x + self.conv(x)
- else:
- return self.conv(x)
最后,核心模块 MobileViTv3_block。
介绍部分看上面就行~
- class MobileViTv3_block(nn.Module):
- def __init__(self, channel, dim, depth=2, kernel_size=3, patch_size=(2, 2), mlp_dim=int(64*2), dropout=0.):
- super().__init__()
- self.ph, self.pw = patch_size
- self.mv01 = MV2Block(channel, channel)
- self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
- self.conv3 = conv_1x1_bn(dim, channel)
- self.conv2 = conv_1x1_bn(channel, dim)
- self.transformer = MBTransformer(dim, depth, 4, 8, mlp_dim, dropout)
- self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
- def forward(self, x):
- y = x.clone()
- x = self.conv1(x)
- x = self.conv2(x)
- z = x.clone()
- _, _, h, w = x.shape
- x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
- x = self.transformer(x)
- x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
- x = self.conv3(x)
- x = torch.cat((x, z), 1)
- x = self.conv4(x)
- x = x + y
- x = self.mv01(x)
- return x
以下是完整代码:
将以下代码复制粘贴到common.py文件的末尾
- from einops import rearrange
- def conv_1x1_bn(inp, oup):
- return nn.Sequential(
- nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
- nn.BatchNorm2d(oup),
- nn.SiLU()
- )
- def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
- return nn.Sequential(
- nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
- nn.BatchNorm2d(oup),
- nn.SiLU()
- )
- class PreNorm(nn.Module):
- def __init__(self, dim, fn):
- super().__init__()
- self.norm = nn.LayerNorm(dim)
- self.fn = fn # mg
- def forward(self, x, **kwargs):
- return self.fn(self.norm(x), **kwargs)
- class Attention(nn.Module):
- def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
- super().__init__()
- inner_dim = dim_head * heads
- project_out = not (heads == 1 and dim_head == dim)
- self.heads = heads
- self.scale = dim_head ** -0.5
- self.attend = nn.Softmax(dim = -1)
- self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
- self.to_out = nn.Sequential(
- nn.Linear(inner_dim, dim),
- nn.Dropout(dropout)# mg
- ) if project_out else nn.Identity()
- def forward(self, x):
- qkv = self.to_qkv(x).chunk(3, dim=-1)
- q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h = self.heads), qkv)
- dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
- attn = self.attend(dots)
- out = torch.matmul(attn, v)
- out = rearrange(out, 'b p h n d -> b p n (h d)')
- return self.to_out(out)
- class FeedForward(nn.Module):
- def __init__(self, dim, hidden_dim, dropout=0.):
- super().__init__()
- self.net = nn.Sequential(
- nn.Linear(dim, hidden_dim),
- nn.SiLU(),
- nn.Dropout(dropout),
- nn.Linear(hidden_dim, dim),
- nn.Dropout(dropout)
- )
- def forward(self, x):
- return self.net(x)
- class MBTransformer(nn.Module):
- def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
- super().__init__()
- self.layers = nn.ModuleList([])
- for _ in range(depth):
- self.layers.append(nn.ModuleList([
- PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
- PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
- ]))
- def forward(self, x):
- for attn, ff in self.layers:
- x = attn(x) + x
- x = ff(x) + x
- return x
- class MV2Block(nn.Module):
- def __init__(self, inp, oup, stride=1, expansion=4):
- super().__init__()
- self.stride = stride
- assert stride in [1, 2]
- hidden_dim = int(inp * expansion)
- self.use_res_connect = self.stride == 1 and inp == oup
- if expansion == 1:
- self.conv = nn.Sequential(
- # dw
- nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
- nn.BatchNorm2d(hidden_dim),
- nn.SiLU(),
- # pw-linear
- nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
- nn.BatchNorm2d(oup),
- )
- else:
- self.conv = nn.Sequential(
- # pw
- nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
- nn.BatchNorm2d(hidden_dim),
- nn.SiLU(),
- # dw
- nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
- nn.BatchNorm2d(hidden_dim),
- nn.SiLU(),
- # pw-linear
- nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
- nn.BatchNorm2d(oup),
- )
- def forward(self, x):
- if self.use_res_connect:
- return x + self.conv(x)
- else:
- return self.conv(x)
- class MobileViTv3_block(nn.Module):
- def __init__(self, channel, dim, depth=2, kernel_size=3, patch_size=(2, 2), mlp_dim=int(64*2), dropout=0.):
- super().__init__()
- self.ph, self.pw = patch_size
- self.mv01 = MV2Block(channel, channel)
- self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
- self.conv3 = conv_1x1_bn(dim, channel)
- self.conv2 = conv_1x1_bn(channel, dim)
- self.transformer = MBTransformer(dim, depth, 4, 8, mlp_dim, dropout)
- self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
- def forward(self, x):
- y = x.clone()
- x = self.conv1(x)
- x = self.conv2(x)
- z = x.clone()
- _, _, h, w = x.shape
- x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
- x = self.transformer(x)
- x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
- x = self.conv3(x)
- x = torch.cat((x, z), 1)
- x = self.conv4(x)
- x = x + y
- x = self.mv01(x)
- return x
第②步:修改yolo.py文件
再来修改yolo.py,在parse_model函数中找到 elif m is Concat: 语句,在其后面加上下面代码:
- elif m in [MobileViTv3_block]:
- c1, c2 = ch[f], args[0]
- if c2 != no:
- c2 = make_divisible(c2 * gw, 8)
- args = [c1, c2]
- if m in [MobileViTv3_block]:
- args.insert(2, n)
- n = 1
第③步:创建自定义的yaml文件
yaml文件配置完整代码如下:
- # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
- # Parameters
- nc: 80 # number of classes
- depth_multiple: 0.33 # model depth multiple
- width_multiple: 0.50 # layer channel iscyy multiple
- anchors:
- - [10,13, 16,30, 33,23] # P3/8
- - [30,61, 62,45, 59,119] # P4/16
- - [116,90, 156,198, 373,326] # P5/32
- # YOLOv5 v6.0 backbone
- backbone:
- # [from, number, module, args]
- [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
- [-1, 3, C3, [128]],
- [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
- [-1, 6, MobileViTv3_block, [256]],
- [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
- [-1, 9, C3, [512]],
- [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
- [-1, 3, C3, [1024]],
- [-1, 1, SPPF, [1024, 5]], # 9
- ]
- # YOLOv5 v6.0 head
- head:
- [[-1, 1, Conv, [512, 1, 1]],
- [-1, 1, nn.Upsample, [None, 2, 'nearest']],
- [[-1, 6], 1, Concat, [1]], # cat backbone P4
- [-1, 3, C3, [512, False]], # 13
- [-1, 1, Conv, [256, 1, 1]],
- [-1, 1, nn.Upsample, [None, 2, 'nearest']],
- [[-1, 4], 1, Concat, [1]], # cat backbone P3
- [-1, 3, C3, [256, False]], # 17 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]],
- [[-1, 14], 1, Concat, [1]], # cat head P4
- [-1, 3, C3, [512, False]], # 20 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]],
- [[-1, 10], 1, Concat, [1]], # cat head P5
- [-1, 3, C3, [1024, False]], # 23 (P5/32-large)
- [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
- ]
第④步 验证是否加入成功
运行yolo.py

这样就OK啦~