学习资源站

RT-DETR改进策略【SPPF】NeuralPS-2022FocalModulation使用焦点调制模块优化空间金字塔池化SPPF_focalmodulation-

RT-DETR改进策略【SPPF】| NeuralPS-2022 Focal Modulation : 使用焦点调制模块优化空间金字塔池化SPPF

一、本文介绍

本文记录的是 利用焦点调制模块Focal Modulation改进RT-DETR的方法研究 Focal Modulation 利用 深度可分离卷积层 实现的焦点语境化来编码从短到长范围的视觉语境,通过 门控聚合 有选择性地为每个查询标记收集语境到调制器中,并利用 逐元素仿射变换 将调制器注入查询,优化了对视觉任务中标记交互的建模能力,提高模型性能。



二、Focal Modulation结构详解

Focal Modulation Networks

2.1 设计出发点

  • 对自注意力机制的思考 :自注意力(SA)机制在视觉任务中虽有优势,但存在计算复杂度高的问题,尤其是对于高分辨率输入。许多研究通过各种方法改进SA,但作者思考是否存在比SA更好的方式来建模输入相关的长程交互。
  • 现有相关工作的启发 :一些研究通过在SA中增加卷积操作来捕捉长程依赖并兼顾局部结构,但作者希望探索一种全新的机制。受焦点注意力的启发,作者尝试先聚集每个查询周围的上下文,然后用聚集的上下文自适应地调制查询,从而提出Focal Modulation机制。

2.2 原理

2.2.1 从自注意力到焦点调制

  • 自注意力(SA) :使用晚期聚合程序,先计算查询和目标之间的注意力分数,然后对上下文进行聚合。
  • 焦点调制(Focal Modulation) :采用早期聚合程序,先在每个位置聚合上下文特征,然后查询与聚合后的特征进行交互。

在这里插入图片描述

2.2.2 上下文聚合

  • 分层语境化(Hierarchical Contextualization) :通过一系列深度可分离卷积层,从局部到全局范围提取不同粒度级别的上下文,每层的输出通过线性层投影和激活函数得到。
  • 门控聚合(Gated Aggregation) :根据查询内容,使用线性层获取空间和层级感知的门控权重,对不同粒度级别的上下文特征进行加权求和,得到单个特征图,再通过另一个线性层得到调制器。

在这里插入图片描述

2.2.3 焦点调制操作

  • 在得到调制器后,通过查询投影函数和元素级乘法将调制器注入到查询中,实现焦点调制。

2.3 结构

  • 网络架构 :使用与Swin和Focal Transformers相同的阶段布局和隐藏维度,但将SA模块替换为Focal Modulation模块。通过指定焦点级别数量和每个级别的内核大小来构建不同的Focal Modulation Network(FocalNet)变体。
  • 模块组成
    • 深度可分离卷积层 :用于分层语境化,提取不同层次的上下文特征。
    • 线性层 :用于投影、获取门控权重以及生成调制器等操作。

2.4 优势

  • 计算效率
    • 参数数量 :整体可学习参数数量主要由几个线性投影和深度可分离卷积决定,相较于一些对比模型,模型大小可通过调整相关参数得到控制。
    • 时间复杂度 :除了线性投影和深度可分离卷积层,元素级乘法对每个视觉标记引入的复杂度相对较低,相比Swin Transformer的窗口注意力和ViT的普通自注意力,具有一定优势。
  • 性能优势
    • 在多个任务上超越对比模型 :在图像分类、目标检测和语义分割等任务上,FocalNets始终显著优于SoTA SA相关模型(如Swin和Focal Transformers),在不同的数据集和评估指标上均有体现。
    • 模型解释性强 :通过可视化调制器、门控权重等,可以直观地看到模型对不同区域的关注和信息聚合方式,为模型解释提供了新的途径。

论文: https://arxiv.org/pdf/2203.11926
源码: https://github.com/microsoft/FocalNet

三、FocalModulation模块的实现代码

FocalModulation 的实现代码如下:

import torch
import torch.nn as nn

class FocalModulation(nn.Module):
    def __init__(self, dim, focal_window=3, focal_level=2, focal_factor=2, bias=True, proj_drop=0.,
                 use_postln_in_modulation=False, normalize_modulator=False):
        super().__init__()
 
        self.dim = dim
        self.focal_window = focal_window
        self.focal_level = focal_level
        self.focal_factor = focal_factor
        self.use_postln_in_modulation = use_postln_in_modulation
        self.normalize_modulator = normalize_modulator
 
        self.f_linear = nn.Conv2d(dim, 2 * dim + (self.focal_level + 1), kernel_size=1, bias=bias)
        self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=bias)
 
        self.act = nn.GELU()
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)
        self.proj_drop = nn.Dropout(proj_drop)
        self.focal_layers = nn.ModuleList()
 
        self.kernel_sizes = []
        for k in range(self.focal_level):
            kernel_size = self.focal_factor * k + self.focal_window
            self.focal_layers.append(
                nn.Sequential(
                    nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1,
                              groups=dim, padding=kernel_size // 2, bias=False),
                    nn.GELU(),
                )
            )
            self.kernel_sizes.append(kernel_size)
        if self.use_postln_in_modulation:
            self.ln = nn.LayerNorm(dim)
 
    def forward(self, x):
        """
        Args:
            x: input features with shape of (B, H, W, C)
        """
        C = x.shape[1]
 
        # pre linear projection
        x = self.f_linear(x).contiguous()
        q, ctx, gates = torch.split(x, (C, C, self.focal_level + 1), 1)
 
        # context aggreation
        ctx_all = 0.0
        for l in range(self.focal_level):
            ctx = self.focal_layers[l](ctx)
            ctx_all = ctx_all + ctx * gates[:, l:l + 1]
        ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
        ctx_all = ctx_all + ctx_global * gates[:, self.focal_level:]
 
        # normalize context
        if self.normalize_modulator:
            ctx_all = ctx_all / (self.focal_level + 1)
 
        # focal modulation
        x_out = q * self.h(ctx_all)
        x_out = x_out.contiguous()
        if self.use_postln_in_modulation:
            x_out = self.ln(x_out)
 
        # post linear porjection
        x_out = self.proj(x_out)
        x_out = self.proj_drop(x_out)
        return x_out


四、修改步骤

4.1 修改一

① 在 ultralytics/nn/ 目录下新建 AddModules 文件夹用于存放模块代码

② 在 AddModules 文件夹下新建 FocalModulation.py ,将 第三节 中的代码粘贴到此处

在这里插入图片描述

4.2 修改二

AddModules 文件夹下新建 __init__.py (已有则不用新建),在文件内导入模块: from .FocalModulation import *

在这里插入图片描述

4.3 修改三

ultralytics/nn/modules/tasks.py 文件中,需要在两处位置添加各模块类名称。

① 首先:导入模块

在这里插入图片描述

② 接着,在此函数下添加如下代码:

elif m in {FocalModulation}:
     args = [ch[f], *args] 

在这里插入图片描述

至此就修改完成了,可以配置模型开始训练了


五、yaml模型文件

5.1 模型改进⭐

在代码配置完成后,配置模型的YAML文件。

此处以 ultralytics/cfg/models/rt-detr/rtdetr-l.yaml 为例,在同目录下创建一个用于自己数据集训练的模型文件 rtdetr-l-FocalModulation.yaml

rtdetr-lm.yaml 中的内容复制到 rtdetr-l-FocalModulation.yaml 文件下,修改 nc 数量等于自己数据中目标的数量。

📌 模型的修改方法是将 AIFI 替换成 FocalModulation

# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr

# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, HGStem, [32, 48]] # 0-P2/4
  - [-1, 6, HGBlock, [48, 128, 3]] # stage 1

  - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
  - [-1, 6, HGBlock, [96, 512, 3]] # stage 2

  - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16
  - [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut
  - [-1, 6, HGBlock, [192, 1024, 5, True, True]]
  - [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3

  - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32
  - [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4

head:
  - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2
  - [-1, 1, FocalModulation, []]
  - [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1
  - [[-2, -1], 1, Concat, [1]]
  - [-1, 3, RepC3, [256]] # 16, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0
  - [[-2, -1], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1

  - [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0
  - [[-1, 17], 1, Concat, [1]] # cat Y4
  - [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0

  - [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1
  - [[-1, 12], 1, Concat, [1]] # cat Y5
  - [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1

  - [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)


六、成功运行结果

分别打印网络模型可以看到 FocalModulation模块 已经加入到模型中,并可以进行训练了。

rtdetr-l-FocalModulation

rtdetr-l-FocalModulation summary: 683 layers, 32,291,014 parameters, 32,291,014 gradients, 107.8 GFLOPs

                  from  n    params  module                                       arguments                     
  0                  -1  1     25248  ultralytics.nn.modules.block.HGStem          [3, 32, 48]                   
  1                  -1  6    155072  ultralytics.nn.modules.block.HGBlock         [48, 48, 128, 3, 6]           
  2                  -1  1      1408  ultralytics.nn.modules.conv.DWConv           [128, 128, 3, 2, 1, False]    
  3                  -1  6    839296  ultralytics.nn.modules.block.HGBlock         [128, 96, 512, 3, 6]          
  4                  -1  1      5632  ultralytics.nn.modules.conv.DWConv           [512, 512, 3, 2, 1, False]    
  5                  -1  6   1695360  ultralytics.nn.modules.block.HGBlock         [512, 192, 1024, 5, 6, True, False]
  6                  -1  6   2055808  ultralytics.nn.modules.block.HGBlock         [1024, 192, 1024, 5, 6, True, True]
  7                  -1  6   2055808  ultralytics.nn.modules.block.HGBlock         [1024, 192, 1024, 5, 6, True, True]
  8                  -1  1     11264  ultralytics.nn.modules.conv.DWConv           [1024, 1024, 3, 2, 1, False]  
  9                  -1  6   6708480  ultralytics.nn.modules.block.HGBlock         [1024, 384, 2048, 5, 6, True, False]
 10                  -1  1    524800  ultralytics.nn.modules.conv.Conv             [2048, 256, 1, 1, None, 1, 1, False]
 11                  -1  1    272643  ultralytics.nn.AddModules.FocalModulation.FocalModulation[256]                         
 12                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]              
 13                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 14                   7  1    262656  ultralytics.nn.modules.conv.Conv             [1024, 256, 1, 1, None, 1, 1, False]
 15            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 16                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 17                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]              
 18                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 19                   3  1    131584  ultralytics.nn.modules.conv.Conv             [512, 256, 1, 1, None, 1, 1, False]
 20            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 21                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 22                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 23            [-1, 17]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 24                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 25                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 26            [-1, 12]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 27                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 28        [21, 24, 27]  1   7303907  ultralytics.nn.modules.head.RTDETRDecoder    [1, [256, 256, 256]]          
rtdetr-l-FocalModulation summary: 683 layers, 32,291,014 parameters, 32,291,014 gradients, 107.8 GFLOPs