RT-DETR改进策略【SPPF】| NeuralPS-2022 Focal Modulation : 使用焦点调制模块优化空间金字塔池化SPPF
一、本文介绍
本文记录的是
利用焦点调制模块Focal Modulation改进RT-DETR的方法研究
。
Focal Modulation
利用
深度可分离卷积层
实现的焦点语境化来编码从短到长范围的视觉语境,通过
门控聚合
有选择性地为每个查询标记收集语境到调制器中,并利用
逐元素仿射变换
将调制器注入查询,优化了对视觉任务中标记交互的建模能力,提高模型性能。
二、Focal Modulation结构详解
Focal Modulation Networks
2.1 设计出发点
- 对自注意力机制的思考 :自注意力(SA)机制在视觉任务中虽有优势,但存在计算复杂度高的问题,尤其是对于高分辨率输入。许多研究通过各种方法改进SA,但作者思考是否存在比SA更好的方式来建模输入相关的长程交互。
- 现有相关工作的启发 :一些研究通过在SA中增加卷积操作来捕捉长程依赖并兼顾局部结构,但作者希望探索一种全新的机制。受焦点注意力的启发,作者尝试先聚集每个查询周围的上下文,然后用聚集的上下文自适应地调制查询,从而提出Focal Modulation机制。
2.2 原理
2.2.1 从自注意力到焦点调制
- 自注意力(SA) :使用晚期聚合程序,先计算查询和目标之间的注意力分数,然后对上下文进行聚合。
- 焦点调制(Focal Modulation) :采用早期聚合程序,先在每个位置聚合上下文特征,然后查询与聚合后的特征进行交互。
2.2.2 上下文聚合
- 分层语境化(Hierarchical Contextualization) :通过一系列深度可分离卷积层,从局部到全局范围提取不同粒度级别的上下文,每层的输出通过线性层投影和激活函数得到。
- 门控聚合(Gated Aggregation) :根据查询内容,使用线性层获取空间和层级感知的门控权重,对不同粒度级别的上下文特征进行加权求和,得到单个特征图,再通过另一个线性层得到调制器。
2.2.3 焦点调制操作
- 在得到调制器后,通过查询投影函数和元素级乘法将调制器注入到查询中,实现焦点调制。
2.3 结构
- 网络架构 :使用与Swin和Focal Transformers相同的阶段布局和隐藏维度,但将SA模块替换为Focal Modulation模块。通过指定焦点级别数量和每个级别的内核大小来构建不同的Focal Modulation Network(FocalNet)变体。
-
模块组成
- 深度可分离卷积层 :用于分层语境化,提取不同层次的上下文特征。
- 线性层 :用于投影、获取门控权重以及生成调制器等操作。
2.4 优势
-
计算效率
- 参数数量 :整体可学习参数数量主要由几个线性投影和深度可分离卷积决定,相较于一些对比模型,模型大小可通过调整相关参数得到控制。
- 时间复杂度 :除了线性投影和深度可分离卷积层,元素级乘法对每个视觉标记引入的复杂度相对较低,相比Swin Transformer的窗口注意力和ViT的普通自注意力,具有一定优势。
-
性能优势
- 在多个任务上超越对比模型 :在图像分类、目标检测和语义分割等任务上,FocalNets始终显著优于SoTA SA相关模型(如Swin和Focal Transformers),在不同的数据集和评估指标上均有体现。
- 模型解释性强 :通过可视化调制器、门控权重等,可以直观地看到模型对不同区域的关注和信息聚合方式,为模型解释提供了新的途径。
论文: https://arxiv.org/pdf/2203.11926
源码: https://github.com/microsoft/FocalNet
三、FocalModulation模块的实现代码
FocalModulation
的实现代码如下:
import torch
import torch.nn as nn
class FocalModulation(nn.Module):
def __init__(self, dim, focal_window=3, focal_level=2, focal_factor=2, bias=True, proj_drop=0.,
use_postln_in_modulation=False, normalize_modulator=False):
super().__init__()
self.dim = dim
self.focal_window = focal_window
self.focal_level = focal_level
self.focal_factor = focal_factor
self.use_postln_in_modulation = use_postln_in_modulation
self.normalize_modulator = normalize_modulator
self.f_linear = nn.Conv2d(dim, 2 * dim + (self.focal_level + 1), kernel_size=1, bias=bias)
self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=bias)
self.act = nn.GELU()
self.proj = nn.Conv2d(dim, dim, kernel_size=1)
self.proj_drop = nn.Dropout(proj_drop)
self.focal_layers = nn.ModuleList()
self.kernel_sizes = []
for k in range(self.focal_level):
kernel_size = self.focal_factor * k + self.focal_window
self.focal_layers.append(
nn.Sequential(
nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1,
groups=dim, padding=kernel_size // 2, bias=False),
nn.GELU(),
)
)
self.kernel_sizes.append(kernel_size)
if self.use_postln_in_modulation:
self.ln = nn.LayerNorm(dim)
def forward(self, x):
"""
Args:
x: input features with shape of (B, H, W, C)
"""
C = x.shape[1]
# pre linear projection
x = self.f_linear(x).contiguous()
q, ctx, gates = torch.split(x, (C, C, self.focal_level + 1), 1)
# context aggreation
ctx_all = 0.0
for l in range(self.focal_level):
ctx = self.focal_layers[l](ctx)
ctx_all = ctx_all + ctx * gates[:, l:l + 1]
ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
ctx_all = ctx_all + ctx_global * gates[:, self.focal_level:]
# normalize context
if self.normalize_modulator:
ctx_all = ctx_all / (self.focal_level + 1)
# focal modulation
x_out = q * self.h(ctx_all)
x_out = x_out.contiguous()
if self.use_postln_in_modulation:
x_out = self.ln(x_out)
# post linear porjection
x_out = self.proj(x_out)
x_out = self.proj_drop(x_out)
return x_out
四、修改步骤
4.1 修改一
① 在
ultralytics/nn/
目录下新建
AddModules
文件夹用于存放模块代码
② 在
AddModules
文件夹下新建
FocalModulation.py
,将
第三节
中的代码粘贴到此处
4.2 修改二
在
AddModules
文件夹下新建
__init__.py
(已有则不用新建),在文件内导入模块:
from .FocalModulation import *
4.3 修改三
在
ultralytics/nn/modules/tasks.py
文件中,需要在两处位置添加各模块类名称。
① 首先:导入模块
② 接着,在此函数下添加如下代码:
elif m in {FocalModulation}:
args = [ch[f], *args]
至此就修改完成了,可以配置模型开始训练了
五、yaml模型文件
5.1 模型改进⭐
在代码配置完成后,配置模型的YAML文件。
此处以
ultralytics/cfg/models/rt-detr/rtdetr-l.yaml
为例,在同目录下创建一个用于自己数据集训练的模型文件
rtdetr-l-FocalModulation.yaml
。
将
rtdetr-lm.yaml
中的内容复制到
rtdetr-l-FocalModulation.yaml
文件下,修改
nc
数量等于自己数据中目标的数量。
📌 模型的修改方法是将
AIFI
替换成
FocalModulation
。
# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
# [depth, width, max_channels]
l: [1.00, 1.00, 1024]
backbone:
# [from, repeats, module, args]
- [-1, 1, HGStem, [32, 48]] # 0-P2/4
- [-1, 6, HGBlock, [48, 128, 3]] # stage 1
- [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
- [-1, 6, HGBlock, [96, 512, 3]] # stage 2
- [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16
- [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut
- [-1, 6, HGBlock, [192, 1024, 5, True, True]]
- [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3
- [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32
- [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4
head:
- [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2
- [-1, 1, FocalModulation, []]
- [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1
- [[-2, -1], 1, Concat, [1]]
- [-1, 3, RepC3, [256]] # 16, fpn_blocks.0
- [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0
- [[-2, -1], 1, Concat, [1]] # cat backbone P4
- [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1
- [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0
- [[-1, 17], 1, Concat, [1]] # cat Y4
- [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0
- [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1
- [[-1, 12], 1, Concat, [1]] # cat Y5
- [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1
- [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
六、成功运行结果
分别打印网络模型可以看到
FocalModulation模块
已经加入到模型中,并可以进行训练了。
rtdetr-l-FocalModulation :
rtdetr-l-FocalModulation summary: 683 layers, 32,291,014 parameters, 32,291,014 gradients, 107.8 GFLOPs
from n params module arguments
0 -1 1 25248 ultralytics.nn.modules.block.HGStem [3, 32, 48]
1 -1 6 155072 ultralytics.nn.modules.block.HGBlock [48, 48, 128, 3, 6]
2 -1 1 1408 ultralytics.nn.modules.conv.DWConv [128, 128, 3, 2, 1, False]
3 -1 6 839296 ultralytics.nn.modules.block.HGBlock [128, 96, 512, 3, 6]
4 -1 1 5632 ultralytics.nn.modules.conv.DWConv [512, 512, 3, 2, 1, False]
5 -1 6 1695360 ultralytics.nn.modules.block.HGBlock [512, 192, 1024, 5, 6, True, False]
6 -1 6 2055808 ultralytics.nn.modules.block.HGBlock [1024, 192, 1024, 5, 6, True, True]
7 -1 6 2055808 ultralytics.nn.modules.block.HGBlock [1024, 192, 1024, 5, 6, True, True]
8 -1 1 11264 ultralytics.nn.modules.conv.DWConv [1024, 1024, 3, 2, 1, False]
9 -1 6 6708480 ultralytics.nn.modules.block.HGBlock [1024, 384, 2048, 5, 6, True, False]
10 -1 1 524800 ultralytics.nn.modules.conv.Conv [2048, 256, 1, 1, None, 1, 1, False]
11 -1 1 272643 ultralytics.nn.AddModules.FocalModulation.FocalModulation[256]
12 -1 1 66048 ultralytics.nn.modules.conv.Conv [256, 256, 1, 1]
13 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
14 7 1 262656 ultralytics.nn.modules.conv.Conv [1024, 256, 1, 1, None, 1, 1, False]
15 [-2, -1] 1 0 ultralytics.nn.modules.conv.Concat [1]
16 -1 3 2232320 ultralytics.nn.modules.block.RepC3 [512, 256, 3]
17 -1 1 66048 ultralytics.nn.modules.conv.Conv [256, 256, 1, 1]
18 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
19 3 1 131584 ultralytics.nn.modules.conv.Conv [512, 256, 1, 1, None, 1, 1, False]
20 [-2, -1] 1 0 ultralytics.nn.modules.conv.Concat [1]
21 -1 3 2232320 ultralytics.nn.modules.block.RepC3 [512, 256, 3]
22 -1 1 590336 ultralytics.nn.modules.conv.Conv [256, 256, 3, 2]
23 [-1, 17] 1 0 ultralytics.nn.modules.conv.Concat [1]
24 -1 3 2232320 ultralytics.nn.modules.block.RepC3 [512, 256, 3]
25 -1 1 590336 ultralytics.nn.modules.conv.Conv [256, 256, 3, 2]
26 [-1, 12] 1 0 ultralytics.nn.modules.conv.Concat [1]
27 -1 3 2232320 ultralytics.nn.modules.block.RepC3 [512, 256, 3]
28 [21, 24, 27] 1 7303907 ultralytics.nn.modules.head.RTDETRDecoder [1, [256, 256, 256]]
rtdetr-l-FocalModulation summary: 683 layers, 32,291,014 parameters, 32,291,014 gradients, 107.8 GFLOPs