RT-DETR改进策略【SPPF】| SimSPPF,简化空间金字塔池化设计,提高计算效率
一、本文介绍
本文记录的是
基于SimSPPF模块的RT-DETR目标检测改进方法研究
。介绍了
SPP
,
SPPF
以及
SimSPPF
。
SimSPPF
的设计更加简化,计算效率更高。
二、空间金字塔池化
2.1 SPP
空间金字塔池化(Spatial Pyramid Pooling,SPP)
主要是为了解决输入图像尺寸不固定的问题。在传统的卷积神经网络中,通常要求输入图像具有固定的尺寸,这在实际应用中会带来很多不便。而空间金字塔池化能够对不同尺寸的输入图像产生固定长度的输出,使得网络可以接受任意尺寸的图像输入。
原理
- 对于输入的特征图,空间金字塔池化将其划分成不同尺度的子区域。例如,可以将特征图划分成多个不同大小的网格,形成多个层次的空间金字塔结构。
- 对每个子区域分别进行池化操作,例如最大池化或平均池化。这样可以得到不同尺度的特征表示。
- 将不同尺度的池化结果进行拼接,得到一个固定长度的特征向量。这个特征向量可以作为后续全连接层的输入。
特点
- 灵活性:可以处理任意尺寸的输入图像,无需对图像进行裁剪或缩放等预处理操作,保留了图像的原始信息。
- 多尺度特征提取:通过不同尺度的空间金字塔结构,能够提取图像在不同尺度下的特征,增强了网络对不同大小目标的适应能力。
- 减少过拟合:由于可以接受不同尺寸的图像输入,增加了数据的多样性,有助于减少过拟合的风险。
2.2 SPPF
SPPF(Spatial Pyramid Pooling - Fast)
是一种空间金字塔池化的改进版本。
原理
- 首先,将输入特征图划分成多个不同大小的区域。这些区域可以是固定大小的网格,也可以是根据特定规则划分的区域。
- 对每个区域进行快速池化操作,例如最大池化或平均池化。快速池化可以通过一些高效的算法实现,以减少计算时间。
- 将不同尺度的池化结果进行融合,可以通过拼接、相加或其他融合方式。这样得到的融合特征包含了多尺度的信息。
- 最后,将融合后的特征输出,作为后续网络层的输入。
特点
-
高效性:相比传统的空间金字塔池化,
SPPF在保持相似性能的同时,具有更高的计算效率。它通过对特征图进行更快速的池化操作,减少了计算量和处理时间。 -
多尺度特征融合:与
SPP一样,SPPF也能够提取多尺度的特征信息。它将输入特征图划分成不同大小的区域,并进行池化操作,然后将这些不同尺度的池化结果进行融合,得到更丰富的特征表示。 - 灵活性:可以很容易地集成到各种卷积神经网络架构中,适用于不同的任务和应用场景。
2.3 SimSPPF
SimSPPF(Simplified Spatial Pyramid Pooling - Fast)模块
是
YOLOv6
中提出的一种简化的空间金字塔池化模块,主要用于计算机视觉任务中的特征提取。以下是其设计原理及特点。
SimSPPF模块
由两个主要部分组成:
-
一系列卷积操作:包括一个初始的
SimConv卷积层用于将输入特征图进行初步处理,降低通道数为原来的一半。其中SimConv是一个自定义的卷积模块,包含卷积操作(nn.Conv2d)、批归一化(nn.BatchNorm2d)和ReLU激活函数。它的作用是对输入特征图进行卷积操作以提取特征,并通过批归一化来加速训练过程和提高模型的稳定性,ReLU激活函数则引入非线性,增强模型的表达能力。 -
多次最大池化和拼接操作:通过多次最大池化操作和拼接操作,实现对不同尺度特征的融合,最后再经过一个
SimConv卷积层将融合后的特征图转换为指定的输出通道数。
通过以上设计,
SimSPPF模块
能够有效地提取多尺度特征,并融合这些特征以增强模型对不同大小物体的识别能力。同时,简化的设计使得计算效率更高,适用于对实时性要求较高的计算机视觉任务。
论文: https://arxiv.org/abs/2209.02976
源码: https://github.com/meituan/YOLOv6
三、SimSPPF模块的实现代码
SimSPPF模块
的实现代码如下:
import warnings
import torch
import torch.nn as nn
class SimConv(nn.Module):
'''Normal Conv with ReLU activation'''
def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, bias=False):
super().__init__()
padding = kernel_size // 2
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=bias,
)
self.bn = nn.BatchNorm2d(out_channels)
self.act = nn.ReLU()
def forward(self, x):
return self.act(self.bn(self.conv(x)))
def forward_fuse(self, x):
return self.act(self.conv(x))
class SimSPPF(nn.Module):
'''Simplified SPPF with ReLU activation'''
def __init__(self, in_channels, out_channels, kernel_size=5):
super().__init__()
c_ = in_channels // 2 # hidden channels
self.cv1 = SimConv(in_channels, c_, 1, 1)
self.cv2 = SimConv(c_ * 4, out_channels, 1, 1)
self.m = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore')
y1 = self.m(x)
y2 = self.m(y1)
return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
四、创新模块
4.1 改进点⭐
模块改进方法
:加入
SimSPPF模块
(
第五节讲解添加步骤
)。
SimSPPF模块
添加后如下:
五、添加步骤
5.1 修改一
① 在
ultralytics/nn/
目录下新建
AddModules
文件夹用于存放模块代码
② 在
AddModules
文件夹下新建
SimSPPF.py
,将
第三节
中的代码粘贴到此处
5.2 修改二
在
AddModules
文件夹下新建
__init__.py
(已有则不用新建),在文件内导入模块:
from .SimSPPF import *
5.3 修改三
在
ultralytics/nn/modules/tasks.py
文件中,需要在两处位置添加各模块类名称。
首先:导入模块
其次:在
parse_model函数
中注册
SimSPPF
模块:
六、yaml模型文件
6.1 模型改进版本
此处以
ultralytics/cfg/models/rt-detr/rtdetr-l.yaml
为例,在同目录下创建一个用于自己数据集训练的模型文件
rtdetr-l-SimSPPF.yaml
。
将
rtdetr-l.yaml
中的内容复制到
rtdetr-l-SimSPPF.yaml
文件下,修改
nc
数量等于自己数据中目标的数量。
📌 模型的修改方法是将
AIFI
替换成
SimSPPF模块
。
# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
# [depth, width, max_channels]
l: [1.00, 1.00, 1024]
backbone:
# [from, repeats, module, args]
- [-1, 1, HGStem, [32, 48]] # 0-P2/4
- [-1, 6, HGBlock, [48, 128, 3]] # stage 1
- [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
- [-1, 6, HGBlock, [96, 512, 3]] # stage 2
- [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16
- [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut
- [-1, 6, HGBlock, [192, 1024, 5, True, True]]
- [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3
- [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32
- [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4
head:
- [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2
- [-1, 1, SimSPPF, [1024, 5]]
- [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1
- [[-2, -1], 1, Concat, [1]]
- [-1, 3, RepC3, [256]] # 16, fpn_blocks.0
- [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0
- [[-2, -1], 1, Concat, [1]] # cat backbone P4
- [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1
- [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0
- [[-1, 17], 1, Concat, [1]] # cat Y4
- [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0
- [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1
- [[-1, 12], 1, Concat, [1]] # cat Y5
- [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1
- [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
七、成功运行结果
打印网络模型可以看到
SimSPPF
已经加入到模型中,并可以进行训练了。
rtdetr-l-SimSPPF :
rtdetr-l-SimSPPF summary: 680 layers, 32,774,339 parameters, 32,774,339 gradients, 108.2 GFLOPs
from n params module arguments
0 -1 1 25248 ultralytics.nn.modules.block.HGStem [3, 32, 48]
1 -1 6 155072 ultralytics.nn.modules.block.HGBlock [48, 48, 128, 3, 6]
2 -1 1 1408 ultralytics.nn.modules.conv.DWConv [128, 128, 3, 2, 1, False]
3 -1 6 839296 ultralytics.nn.modules.block.HGBlock [128, 96, 512, 3, 6]
4 -1 1 5632 ultralytics.nn.modules.conv.DWConv [512, 512, 3, 2, 1, False]
5 -1 6 1695360 ultralytics.nn.modules.block.HGBlock [512, 192, 1024, 5, 6, True, False]
6 -1 6 2055808 ultralytics.nn.modules.block.HGBlock [1024, 192, 1024, 5, 6, True, True]
7 -1 6 2055808 ultralytics.nn.modules.block.HGBlock [1024, 192, 1024, 5, 6, True, True]
8 -1 1 11264 ultralytics.nn.modules.conv.DWConv [1024, 1024, 3, 2, 1, False]
9 -1 6 6708480 ultralytics.nn.modules.block.HGBlock [1024, 384, 2048, 5, 6, True, False]
10 -1 1 524800 ultralytics.nn.modules.conv.Conv [2048, 256, 1, 1, None, 1, 1, False]
11 -1 1 559360 ultralytics.nn.AddModules.SimSPPF.SimSPPF [256, 1024]
12 -1 1 262656 ultralytics.nn.modules.conv.Conv [1024, 256, 1, 1]
13 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
14 7 1 262656 ultralytics.nn.modules.conv.Conv [1024, 256, 1, 1, None, 1, 1, False]
15 [-2, -1] 1 0 ultralytics.nn.modules.conv.Concat [1]
16 -1 3 2232320 ultralytics.nn.modules.block.RepC3 [512, 256, 3]
17 -1 1 66048 ultralytics.nn.modules.conv.Conv [256, 256, 1, 1]
18 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
19 3 1 131584 ultralytics.nn.modules.conv.Conv [512, 256, 1, 1, None, 1, 1, False]
20 [-2, -1] 1 0 ultralytics.nn.modules.conv.Concat [1]
21 -1 3 2232320 ultralytics.nn.modules.block.RepC3 [512, 256, 3]
22 -1 1 590336 ultralytics.nn.modules.conv.Conv [256, 256, 3, 2]
23 [-1, 17] 1 0 ultralytics.nn.modules.conv.Concat [1]
24 -1 3 2232320 ultralytics.nn.modules.block.RepC3 [512, 256, 3]
25 -1 1 590336 ultralytics.nn.modules.conv.Conv [256, 256, 3, 2]
26 [-1, 12] 1 0 ultralytics.nn.modules.conv.Concat [1]
27 -1 3 2232320 ultralytics.nn.modules.block.RepC3 [512, 256, 3]
28 [21, 24, 27] 1 7303907 ultralytics.nn.modules.head.RTDETRDecoder [1, [256, 256, 256]]
rtdetr-l-SimSPPF summary: 680 layers, 32,774,339 parameters, 32,774,339 gradients, 108.2 GFLOPs