学习资源站

RT-DETR改进策略【注意力机制篇】CVPR-2023FSAS基于频域的自注意力求解器结合频域计算和卷积操作降低噪声影响_fsas注意力模块-

RT-DETR改进策略【注意力机制篇】| CVPR-2023 FSAS 基于频域的自注意力求解器 结合频域计算和卷积操作 降低噪声影响

一、本文介绍

本文记录的是 利用 FSAS 模块优化 RT-DETR 的目标检测网络模型 FSAS 全称为: Frequency domain-based Self-Attention Solver ,其结合了 频域计算 的高效性和 卷积操作 的特性,有效地 降低了注意力计算的复杂度 。在加入到 RT-DETR 网络中, 提升图像特征的表示能力 ,特别是在 处理高分辨率图像时能够减少计算成本并降低噪声影响。



二、FSAS介绍

Efficient Frequency Domain-based Transformers for High-Quality Image Deblurring

2.1 出发点

  • 降低复杂度 :在视觉Transformer中,计算缩放点积注意力的空间复杂度和时间复杂度较高,分别为 O ( N 2 ) O(N^{2}) O ( N 2 ) O ( N 2 C ) O(N^{2}C) O ( N 2 C ) ,当图像分辨率和提取的补丁数量较大时,计算成本难以承受。虽然可以通过下采样操作或减少补丁数量来缓解,但会导致信息丢失。因此需要一种更高效的方法来估计注意力图,以降低复杂度。
  • 利用频域特性 :基于卷积定理,即空间域中两个信号的相关性或卷积等同于频域中它们的逐元素乘积。作者思考能否在频域中通过逐元素乘积操作而不是在空间域中计算矩阵乘法来有效地估计注意力图。

2.2 原理

2.2.1 基于卷积操作的转换

首先注意到缩放点积注意力计算中 Q K ⊤ QK^{\top} Q K 的每个元素是通过内积获得的,基于此,如果对 q i q_{i} q i 和所有的 k j k_{j} k j 分别应用重塑函数, Q K ⊤ QK^{\top} Q K 的所有第 i i i 列元素可以通过 卷积操作 获得。

2.2.2 频域转换与计算

根据卷积定理,进一步在 频域 中进行计算。通过对估计的特征 F q F_{q} F q F k F_{k} F k 应用 快速傅里叶变换(FFT) ,并在 频域 中估计它们的相关性 A = F − 1 ( F ( F q ) F ( F k ) ‾ ) A=\mathcal{F}^{-1}\left(\mathcal{F}\left(F_{q}\right) \overline{\mathcal{F}\left(F_{k}\right)}\right) A = F 1 ( F ( F q ) F ( F k ) ) 其中 F ( . ) \mathcal{F}(.) F ( . ) 表示 FFT F − 1 ( ⋅ ) \mathcal{F}^{-1}(\cdot) F 1 ( ) 表示 逆FFT F ( . ) ‾ \overline{\mathcal{F}(.)} F ( . ) 表示共轭转置操作。

最后通过 V a t t = L ( A ) F v V_{a t t}=\mathcal{L}(A) F_{v} V a tt = L ( A ) F v 估计聚合特征,其中 L ( ⋅ ) \mathcal{L}(\cdot) L ( ) 层归一化 操作。

2.3结构

图(b) 所示,详细网络结构如下:

  • 卷积层获取特征 :首先通过 1×1 点卷积 3×3 深度卷积 获得 F q F_{q} F q F k F_{k} F k F v F_{v} F v
  • 频域转换层 :对 F q F_{q} F q F k F_{k} F k 应用 FFT 进行 频域转换
  • 相关性计算层 :在频域中计算相关性 A A A
  • 特征聚合层 :通过 层归一化 和与 F v F_{v} F v 的计算得到 V a t t V_{a t t} V a tt
  • 输出层 :最后通过 X a t t = X + C o n v 1 × 1 ( V a t t ) X_{a t t}=X + Conv_{1×1}(V_{a t t}) X a tt = X + C o n v 1 × 1 ( V a tt ) 生成 FSAS 的输出特征,其中 C o n v 1 × 1 ( ⋅ ) Conv_{1×1}(\cdot) C o n v 1 × 1 ( ) 是1×1像素的卷积。

在这里插入图片描述

2.4 优势

  • 复杂度降低
    • 空间复杂度 FSAS 的空间复杂度降低到 O ( N ) O(N) O ( N ) ,相比原始缩放点积注意力的 O ( N 2 ) O(N^{2}) O ( N 2 ) 显著降低。
    • 时间复杂度 :时间复杂度降低到 O ( N l o g N ) O(N log N) O ( Nl o g N ) ,对于每个特征通道,大大减少了计算量。
  • 性能提升
    • 与空间域方法比较 :在GoPro数据集上的实验表明,与在空间域计算缩放点积注意力的方法(如Swin Transformer)相比, FSAS 能生成更好的去模糊结果。例如,空间域方法的PSNR值比FSAS + DFFN低0.27dB。
    • 与仅使用FFN的方法比较 :与仅使用FFN的基线方法相比,使用 FSAS 的PSNR值提高了0.42dB,能更好地去除模糊,边界恢复得更好。

论文: https://arxiv.org/pdf/2211.12250
源码: https://github.com/kkkls/FFTformer

三、FSAS的实现代码

FSAS 及其改进的实现代码如下:

import torch
import torch.nn as nn
import numbers
from einops import rearrange

from ultralytics.nn.modules.conv import LightConv 
 
def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')
 
def to_4d(x, h, w):
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
 
class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape) 
 
        assert len(normalized_shape) == 1  
 
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape
 
    def forward(self, x):
        mu = x.mean(-1, keepdim=True)  
        sigma = x.var(-1, keepdim=True, unbiased=False)  
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
 
class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
 
        assert len(normalized_shape) == 1
 
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape
 
    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)  

        return x / torch.sqrt(sigma + 1e-5) * self.weight
 
class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)  
        else:
            self.body = WithBias_LayerNorm(dim) 

    def forward(self, x):
        h, w = x.shape[-2:] 
        return to_4d(self.body(to_3d(x)), h, w)
 
class FSAS(nn.Module):
    def __init__(self, dim, bias=False):
        super(FSAS, self).__init__()

        self.to_hidden = nn.Conv2d(dim, dim * 6, kernel_size=1, bias=bias)

        self.to_hidden_dw = nn.Conv2d(dim * 6, dim * 6, kernel_size=3, stride=1, padding=1, groups=dim * 6, bias=bias)
 
        self.project_out = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias)
 
        self.norm = LayerNorm(dim * 2, LayerNorm_type='WithBias')
 
        self.patch_size = 8  

    def forward(self, x):

        hidden = self.to_hidden(x)
 
        q, k, v = self.to_hidden_dw(hidden).chunk(3, dim=1)
 
        q_patch = rearrange(q, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size,
                            patch2=self.patch_size)
        k_patch = rearrange(k, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size,
                            patch2=self.patch_size)
 
        q_fft = torch.fft.rfft2(q_patch.float())
        k_fft = torch.fft.rfft2(k_patch.float())
 
        out = q_fft * k_fft
        out = torch.fft.irfft2(out, s=(self.patch_size, self.patch_size))
 
        out = rearrange(out, 'b c h w patch1 patch2 -> b c (h patch1) (w patch2)', patch1=self.patch_size,
                        patch2=self.patch_size)
 
        out = self.norm(out)
        output = v * out
        output = self.project_out(output)
 
        return output
 
def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p

class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
 
    default_act = nn.SiLU()  # default activation
 
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
 
    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))
 
    def forward_fuse(self, x):
        """Perform transposed convolution of 2D data."""
        return self.act(self.conv(x))
 
class HGBlock_FSAS(nn.Module):
    """
    HG_Block of PPHGNetV2 with 2 convolutions and LightConv.

    https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
    """

    def __init__(self, c1, cm, c2, k=3, n=6, lightconv=False, shortcut=False, act=nn.ReLU()):
        """Initializes a CSP Bottleneck with 1 convolution using specified input and output channels."""
        super().__init__()
        block = LightConv if lightconv else Conv
        self.m = nn.ModuleList(block(c1 if i == 0 else cm, cm, k=k, act=act) for i in range(n))
        self.sc = Conv(c1 + n * cm, c2 // 2, 1, 1, act=act)  # squeeze conv
        self.ec = Conv(c2 // 2, c2, 1, 1, act=act)  # excitation conv
        self.add = shortcut and c1 == c2
        self.cv = FSAS(c2)
        
    def forward(self, x):
        """Forward pass of a PPHGNetV2 backbone layer."""
        y = [x]
        y.extend(m(y[-1]) for m in self.m)
        y = self.cv(self.ec(self.sc(torch.cat(y, 1))))
        return y + x if self.add else y


四、创新模块

4.1 改进点⭐

模块改进方法 :基于 FSAS模块 HGBlock 第五节讲解添加步骤 )。

第二种改进方法是对 RT-DETR 中的 HGBlock模块 进行改进,并将 FSAS 在加入到 HGBlock 模块中。

改进代码如下:

HGBlock 模块进行改进,加入 FSAS模块 ,并重命名为 HGBlock_FSAS

class HGBlock_FSAS(nn.Module):
    """
    HG_Block of PPHGNetV2 with 2 convolutions and LightConv.

    https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
    """

    def __init__(self, c1, cm, c2, k=3, n=6, lightconv=False, shortcut=False, act=nn.ReLU()):
        """Initializes a CSP Bottleneck with 1 convolution using specified input and output channels."""
        super().__init__()
        block = LightConv if lightconv else Conv
        self.m = nn.ModuleList(block(c1 if i == 0 else cm, cm, k=k, act=act) for i in range(n))
        self.sc = Conv(c1 + n * cm, c2 // 2, 1, 1, act=act)  # squeeze conv
        self.ec = Conv(c2 // 2, c2, 1, 1, act=act)  # excitation conv
        self.add = shortcut and c1 == c2
        self.cv = FSAS(c2)
        
    def forward(self, x):
        """Forward pass of a PPHGNetV2 backbone layer."""
        y = [x]
        y.extend(m(y[-1]) for m in self.m)
        y = self.cv(self.ec(self.sc(torch.cat(y, 1))))
        return y + x if self.add else y
 

在这里插入图片描述

注意❗:在 第五小节 中需要声明的模块名称为: HGBlock_FSAS


五、添加步骤

5.1 修改一

① 在 ultralytics/nn/ 目录下新建 AddModules 文件夹用于存放模块代码

② 在 AddModules 文件夹下新建 FSAS.py ,将 第三节 中的代码粘贴到此处

在这里插入图片描述

5.2 修改二

AddModules 文件夹下新建 __init__.py (已有则不用新建),在文件内导入模块: from .FSAS import *

在这里插入图片描述

5.3 修改三

ultralytics/nn/modules/tasks.py 文件中,需要在两处位置添加各模块类名称。

首先:导入模块

在这里插入图片描述

其次:在 parse_model函数 中注册 HGBlock_FSAS 模块

在这里插入图片描述

在这里插入图片描述


六、yaml模型文件

6.1 模型改进版本⭐

此处以 ultralytics/cfg/models/rt-detr/rtdetr-l.yaml 为例,在同目录下创建一个用于自己数据集训练的模型文件 rtdetr-l-HGBlock_FSAS.yaml

rtdetr-l.yaml 中的内容复制到 rtdetr-l-HGBlock_FSAS.yaml 文件下,修改 nc 数量等于自己数据中目标的数量。

📌 模型的修改方法是将 骨干网络 中的 HGBlock模块 替换成 HGBlock_FSAS模块

# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr

# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, HGStem, [32, 48]] # 0-P2/4
  - [-1, 6, HGBlock, [48, 128, 3]] # stage 1

  - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
  - [-1, 6, HGBlock, [96, 512, 3]] # stage 2

  - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P4/16
  - [-1, 6, HGBlock_FSAS, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut
  - [-1, 6, HGBlock_FSAS, [192, 1024, 5, True, True]]
  - [-1, 6, HGBlock_FSAS, [192, 1024, 5, True, True]] # stage 3

  - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P5/32
  - [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4

head:
  - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2
  - [-1, 1, AIFI, [1024, 8]]
  - [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1
  - [[-2, -1], 1, Concat, [1]]
  - [-1, 3, RepC3, [256]] # 16, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0
  - [[-2, -1], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1

  - [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0
  - [[-1, 17], 1, Concat, [1]] # cat Y4
  - [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0

  - [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1
  - [[-1, 12], 1, Concat, [1]] # cat Y5
  - [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1

  - [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)


七、成功运行结果

打印网络模型可以看到 HGBlock_FSAS 已经加入到模型中,并可以进行训练了。

rtdetr-l-HGBlock_FSAS

rtdetr-l-HGBlock_FSAS summary: 700 layers, 58,152,131 parameters, 58,152,131 gradients, 189.1 GFLOPs

                   from  n    params  module                                       arguments                     
  0                  -1  1     25248  ultralytics.nn.modules.block.HGStem          [3, 32, 48]                   
  1                  -1  6    155072  ultralytics.nn.modules.block.HGBlock         [48, 48, 128, 3, 6]           
  2                  -1  1      1408  ultralytics.nn.modules.conv.DWConv           [128, 128, 3, 2, 1, False]    
  3                  -1  6    839296  ultralytics.nn.modules.block.HGBlock         [128, 96, 512, 3, 6]          
  4                  -1  1      5632  ultralytics.nn.modules.conv.DWConv           [512, 512, 3, 2, 1, False]    
  5                  -1  6  10143360  ultralytics.nn.AddModules.FSAS.HGBlock_FSAS  [512, 192, 1024, 5, 6, True, False]
  6                  -1  6  10503808  ultralytics.nn.AddModules.FSAS.HGBlock_FSAS  [1024, 192, 1024, 5, 6, True, True]
  7                  -1  6  10503808  ultralytics.nn.AddModules.FSAS.HGBlock_FSAS  [1024, 192, 1024, 5, 6, True, True]
  8                  -1  1     11264  ultralytics.nn.modules.conv.DWConv           [1024, 1024, 3, 2, 1, False]  
  9                  -1  6   6708480  ultralytics.nn.modules.block.HGBlock         [1024, 384, 2048, 5, 6, True, False]
 10                  -1  1    524800  ultralytics.nn.modules.conv.Conv             [2048, 256, 1, 1, None, 1, 1, False]
 11                  -1  1    789760  ultralytics.nn.modules.transformer.AIFI      [256, 1024, 8]                
 12                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]              
 13                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 14                   7  1    262656  ultralytics.nn.modules.conv.Conv             [1024, 256, 1, 1, None, 1, 1, False]
 15            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 16                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 17                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]              
 18                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 19                   3  1    131584  ultralytics.nn.modules.conv.Conv             [512, 256, 1, 1, None, 1, 1, False]
 20            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 21                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 22                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 23            [-1, 17]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 24                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 25                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 26            [-1, 12]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 27                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 28        [21, 24, 27]  1   7303907  ultralytics.nn.modules.head.RTDETRDecoder    [1, [256, 256, 256]]          
rtdetr-l-HGBlock_FSAS summary: 700 layers, 58,152,131 parameters, 58,152,131 gradients, 189.1 GFLOPs