YOLOv5 Improvement Series (21): Replacing the Backbone with RepViT (Tsinghua's Recently Open-Sourced Mobile-Friendly ViT)
🚀1. About RepViT
- Paper: "RepViT: Revisiting Mobile CNN From ViT Perspective"
- Paper link: https://arxiv.org/pdf/2307.09283.pdf
- Code: https://github.com/THU-MIG/RepViT

1.1 Overview
In recent years, lightweight Vision Transformers (ViTs) have shown higher performance and lower latency than lightweight convolutional neural networks (CNNs) on resource-constrained mobile devices. This improvement is usually attributed to the multi-head self-attention module, which lets the model learn global representations. However, the architectural differences between lightweight ViTs and lightweight CNNs have not been thoroughly studied.

The paper revisits the design of lightweight CNNs and integrates the effective architectural choices of lightweight ViTs to improve performance on resource-constrained mobile devices; the resulting family of pure CNN models is called RepViT.

Extensive experiments show that RepViT outperforms existing lightweight ViTs and exhibits favorable latency across various vision tasks. On ImageNet, RepViT achieves over 80% top-1 accuracy at roughly 1 ms latency on an iPhone 12, which, to the authors' knowledge, is a first for a lightweight model.
🚀2. How to Add RepViT to YOLOv5
Step ①: Add the RepViT modules to common.py
Paste the following code at the end of models/common.py. Note that it depends on timm for SqueezeExcite (install with `pip install timm` if you don't already have it):
```python
import torch
import torch.nn as nn

# RepViT uses timm's SqueezeExcite (pip install timm)
from timm.models.layers import SqueezeExcite


class Conv2d_BN(torch.nn.Sequential):
    """Conv2d + BatchNorm2d that can be fused into a single Conv2d at inference time."""
    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1, resolution=-10000):
        super().__init__()
        self.add_module('c', torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', torch.nn.BatchNorm2d(b))
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)

    @torch.no_grad()
    def fuse(self):
        # Fold BN into the conv: w' = w * gamma / sigma, b' = beta - mu * gamma / sigma
        c, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / \
            (bn.running_var + bn.eps) ** 0.5
        m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(0), w.shape[2:],
                            stride=self.c.stride, padding=self.c.padding,
                            dilation=self.c.dilation, groups=self.c.groups,
                            device=c.weight.device)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m


class Residual(torch.nn.Module):
    """Residual wrapper with optional stochastic depth during training."""
    def __init__(self, m, drop=0.):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        if self.training and self.drop > 0:
            return x + self.m(x) * torch.rand(
                x.size(0), 1, 1, 1,
                device=x.device).ge_(self.drop).div(1 - self.drop).detach()
        else:
            return x + self.m(x)

    @torch.no_grad()
    def fuse(self):
        if isinstance(self.m, Conv2d_BN):
            m = self.m.fuse()
            assert m.groups == m.in_channels
            identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
            identity = torch.nn.functional.pad(identity, [1, 1, 1, 1])
            m.weight += identity.to(m.weight.device)
            return m
        elif isinstance(self.m, torch.nn.Conv2d):
            m = self.m
            assert m.groups != m.in_channels
            identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
            identity = torch.nn.functional.pad(identity, [1, 1, 1, 1])
            m.weight += identity.to(m.weight.device)
            return m
        else:
            return self


class RepVGGDW(torch.nn.Module):
    """RepVGG-style depthwise block: 3x3 DW conv + 1x1 DW conv + identity,
    fusible into a single 3x3 depthwise conv."""
    def __init__(self, ed) -> None:
        super().__init__()
        self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
        self.conv1 = Conv2d_BN(ed, ed, 1, 1, 0, groups=ed)
        self.dim = ed

    def forward(self, x):
        return self.conv(x) + self.conv1(x) + x

    @torch.no_grad()
    def fuse(self):
        conv = self.conv.fuse()
        conv1 = self.conv1.fuse()

        conv_w = conv.weight
        conv_b = conv.bias
        conv1_w = conv1.weight
        conv1_b = conv1.bias

        # Zero-pad the 1x1 kernel to 3x3 and add a 3x3 identity kernel,
        # so all three branches collapse into one 3x3 depthwise conv
        conv1_w = torch.nn.functional.pad(conv1_w, [1, 1, 1, 1])
        identity = torch.nn.functional.pad(
            torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device),
            [1, 1, 1, 1])

        final_conv_w = conv_w + conv1_w + identity
        final_conv_b = conv_b + conv1_b

        conv.weight.data.copy_(final_conv_w)
        conv.bias.data.copy_(final_conv_b)
        return conv


class RepViTBlock(nn.Module):
    """RepViT block: depthwise token mixer (+ optional SE) followed by a pointwise
    FFN channel mixer. `in1` is the input-channel slot that YOLOv5's parse_model
    prepends; the block itself only uses inp / hidden_dim / oup."""
    def __init__(self, in1, inp, hidden_dim, oup, kernel_size=3, stride=2, use_se=0, use_hs=0):
        super(RepViTBlock, self).__init__()
        assert stride in [1, 2]
        self.identity = stride == 1 and inp == oup
        assert hidden_dim == 2 * inp

        if stride == 2:
            # Downsampling variant: strided DW conv + optional SE + 1x1 projection
            self.token_mixer = nn.Sequential(
                Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp),
                SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
                Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
            )
            self.channel_mixer = Residual(nn.Sequential(
                # pw
                Conv2d_BN(oup, 2 * oup, 1, 1, 0),
                nn.GELU() if use_hs else nn.GELU(),  # GELU either way, as in the official code
                # pw-linear
                Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
            ))
        else:
            assert self.identity
            # Stride-1 variant: re-parameterizable DW block + optional SE
            self.token_mixer = nn.Sequential(
                RepVGGDW(inp),
                SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
            )
            self.channel_mixer = Residual(nn.Sequential(
                # pw
                Conv2d_BN(inp, hidden_dim, 1, 1, 0),
                nn.GELU() if use_hs else nn.GELU(),  # GELU either way, as in the official code
                # pw-linear
                Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
            ))

    def forward(self, x):
        return self.channel_mixer(self.token_mixer(x))
```
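Before wiring the block into YOLOv5, it can be worth sanity-checking the structural re-parameterization: in eval mode, `RepVGGDW.fuse()` should collapse the three branches into a single 3x3 depthwise conv that produces identical outputs. A minimal check, assuming the code above has been pasted into common.py (the tensor sizes are arbitrary):

```python
import torch

m = RepVGGDW(32).eval()          # eval mode so BN uses its running statistics
x = torch.randn(1, 32, 56, 56)   # arbitrary feature map: batch 1, 32 channels

fused = m.fuse()                 # single 3x3 depthwise conv
print(torch.allclose(m(x), fused(x), atol=1e-5))  # expect: True
```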

Step ②: Modify yolo.py
In models/yolo.py, find the line in the parse_model function that infers input/output channels for the built-in modules (the long `if m in (...)` membership test), and add RepViTBlock to that list, as shown in the sketch below.
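The exact contents of that line vary between YOLOv5 releases (a tuple in v6.x, a set in later code), but in the v7.0 release it looks roughly like this sketch; the only change needed is appending `RepViTBlock` to the membership test so parse_model computes its channels the same way it does for Conv and C3:

```python
# models/yolo.py, inside parse_model() -- a sketch based on the v7.0 release;
# your module list may differ slightly. RepViTBlock is the only addition.
if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv,
         MixConv2d, Focus, CrossConv, BottleneckCSP, C3, C3TR, C3SPP, C3Ghost,
         nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepViTBlock):
    c1, c2 = ch[f], args[0]
    if c2 != no:  # if c2 is not the final detection output channels
        c2 = make_divisible(c2 * gw, 8)
    args = [c1, c2, *args[1:]]
```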
Step ③: Create a custom yaml file
Full yaml (saved, for example, as models/yolov5s-repvit.yaml; note that this config keeps the stock CSP backbone and only inserts a RepViTBlock at the end of the head):
```yaml
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 1  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# CSPNet-v5 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
   [-1, 1, RepViTBlock, [1024, 1024, 512]],  # 24
   [[17, 20, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
```
Step ④: Verify the integration
Run models/yolo.py (e.g. `python models/yolo.py --cfg models/yolov5s-repvit.yaml` from the repo root); if the layer summary prints without errors and includes a RepViTBlock row at index 24, the module was added successfully.
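Alternatively, the model can be built directly from Python, which exercises the same parse_model path. A minimal sketch, assuming the yaml from step ③ was saved as models/yolov5s-repvit.yaml (hypothetical filename) and this is run from the yolov5 repository root:

```python
import torch
from models.yolo import Model

# Builds the model from the custom config; parse_model prints the layer table,
# which should now include a RepViTBlock row at index 24
model = Model('models/yolov5s-repvit.yaml', ch=3, nc=1)
model.eval()

# A dummy forward pass to confirm the graph runs end to end
preds = model(torch.zeros(1, 3, 640, 640))
```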

PS:
Same story as before: no accuracy gain on my dataset, and I didn't notice any real reduction in parameter count either. ( ╯#-_-)╯┴—┴