学习资源站

RT-DETR改进策略【Neck】ASF-YOLO注意力尺度序列融合模块改进颈部网络,提高小目标检测精度_rtdetr小目标-

RT-DETR改进策略【Neck】| ASF-YOLO 注意力尺度序列融合模块改进颈部网络,提高小目标检测精度

一、本文介绍

本文记录的是 利用 ASF-YOLO 提出的颈部结构优化RT-DETR的目标检测网络模型 。将 RT-DETR 的颈部网络改进成 ASF-YOLO 的结构, 使模型能够有效的融合多尺度特征,捕获小目标精细信息,并根据注意力机制关注小目标相关特征,显著提高模型精度。



二、ASF-YOLO介绍

ASF-YOLO 是一种基于YOLO的新颖框架,结合了空间和尺度特征以实现准确和快速的分割。其中, 注意力尺度序列融合模块 的设计包含以下几个关键方面:

2.1 出发点

  • 解决小目标分割挑战 :细胞实例分割因细胞的小、密集、重叠以及边界模糊等特点,对分割精度要求高。传统基于CNN的方法及一些现有架构在处理此类小目标分割时存在不足,需要一种能更好融合多尺度特征并关注小目标相关信息的方法。
  • 优化YOLO架构 :尽管YOLO系列在实时实例分割中具有优势,但对于医学图像中的小目标(如细胞)分割,其架构可进一步优化。通过设计注意尺度序列融合模块,提升模型对不同尺度小目标的处理能力和分割性能。

2.2 原理

2.2.1 多尺度特征融合

  • SSFF模块 :通过对不同尺度的特征图(P3、P4、P5)进行归一化、上采样和堆叠,然后利用3D卷积将多尺度特征组合起来,从而能够在尺度空间表示中有效处理不同大小、方向和宽高比的目标,增强了模型对小目标尺度变化的鲁棒性。
  • TFE模块 :将大、中、小三种不同尺寸的特征图在空间维度上拼接,以捕获不同尺度下小目标的精细空间信息,克服了 FPN 在YOLOv5中无法充分利用金字塔特征图相关性的局限。

2.2.2 注意力机制

  • CPAM模块 :整合 SSFF TFE 模块的特征信息,通过通道注意力网络和位置注意力网络,分别捕获与小目标相关的有信息通道和细化空间定位,使模型能够自适应地调整对不同尺度小目标相关通道和空间位置的关注,从而提高检测和分割精度。

2.3 结构

2.3.1 SSFF模块结构

  • 首先对P4和P5特征层进行 1 × 1 1×1 1 × 1 卷积,将通道数变为256,再使用最近邻插值法调整其大小与P3层相同。
  • 然后使用unsqueeze方法增加特征层维度,从3D张量变为4D张量,并沿深度维度将4D特征图拼接形成3D特征图。
  • 最后使用3D卷积、3D批归一化和SiLU激活函数完成尺度序列特征提取。

在这里插入图片描述

2.3.2 TFE模块结构

  • 对于大尺寸特征图(Large),经卷积模块处理后调整通道数为1C,然后采用最大池化+平均池化的混合结构进行下采样。
  • 对于小尺寸特征图(Small),卷积模块调整通道数后使用最近邻插值法进行上采样。
  • 最后将大、中、小三种尺寸相同的特征图在通道维度上拼接输出。

在这里插入图片描述

2.3.3 CPAM模块结构

  • 包含通道注意力网络和位置注意力网络。通道注意力网络接收TFE模块输出的特征图,采用无维度缩减的注意力机制,通过考虑每个通道及其k最近邻来捕获非线性跨通道交互。
  • 位置注意力网络接收通道注意力机制输出与SSFF模块输出叠加后的特征图,通过在水平和垂直轴上进行池化、卷积、分裂等操作,提取每个细胞的关键位置信息。

在这里插入图片描述

2.4 优势

  • 提高分割精度 :通过 SSFF模块 有效融合多尺度特征, TFE模块 捕获小目标精细信息,以及 CPAM模块 的注意力机制关注小目标相关特征,显著提高了细胞实例分割的精度,在DSB2018和BCC数据集上均取得了优于其他先进方法的结果。
  • 增强模型鲁棒性 SSFF模块 对多尺度特征的融合方式使模型对不同条件下细胞图像中小目标的尺度变化具有更强的鲁棒性。
  • 平衡精度与速度 :在实现高精度分割的同时,保持了较快的推理速度,如在DSB2018数据集上达到了47.3 FPS的推理速度,满足实时处理的需求。

论文: https://arxiv.org/pdf/2312.06458
源码: https://github.com/mkang315/ASF-YOLO

三、ASF-YOLO的实现代码

ASF-YOLO模块 的实现代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
def autopad(k, p=None, d=1):  # kernel, padding, dilation
    # Pad to 'same' shape outputs
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p

class Conv(nn.Module):
    # Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
    default_act = nn.SiLU()  # default activation
 
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
 
    def forward(self, x):
        return self.act(self.bn(self.conv(x)))
 
    def forward_fuse(self, x):
        return self.act(self.conv(x))

class Zoom_cat(nn.Module):
    def __init__(self):
        super().__init__()
        # self.conv_l_post_down = Conv(in_dim, 2*in_dim, 3, 1, 1)
 
    def forward(self, x):
        """l,m,s表示大中小三个尺度,最终会被整合到m这个尺度上"""
        l, m, s = x[0], x[1], x[2]
        tgt_size = m.shape[2:]
        l = F.adaptive_max_pool2d(l, tgt_size) + F.adaptive_avg_pool2d(l, tgt_size)
        # l = self.conv_l_post_down(l)
        # m = self.conv_m(m)
        # s = self.conv_s_pre_up(s)
        s = F.interpolate(s, m.shape[2:], mode='nearest')
        # s = self.conv_s_post_up(s)
        lms = torch.cat([l, m, s], dim=1)
        return lms

class ScalSeq(nn.Module):
    def __init__(self, inc, channel):
        super(ScalSeq, self).__init__()
        self.conv0 = Conv(inc[0], channel, 1)
        self.conv1 = Conv(inc[1], channel, 1)
        self.conv2 = Conv(inc[2], channel, 1)
        self.conv3d = nn.Conv3d(channel, channel, kernel_size=(1, 1, 1))
        self.bn = nn.BatchNorm3d(channel)
        self.act = nn.LeakyReLU(0.1)
        self.pool_3d = nn.MaxPool3d(kernel_size=(3, 1, 1))
 
    def forward(self, x):
        p3, p4, p5 = x[0], x[1], x[2]
        p3 = self.conv0(p3)
        p4_2 = self.conv1(p4)
        p4_2 = F.interpolate(p4_2, p3.size()[2:], mode='nearest')
        p5_2 = self.conv2(p5)
        p5_2 = F.interpolate(p5_2, p3.size()[2:], mode='nearest')
        p3_3d = torch.unsqueeze(p3, -3)
        p4_3d = torch.unsqueeze(p4_2, -3)
        p5_3d = torch.unsqueeze(p5_2, -3)
        combine = torch.cat([p3_3d, p4_3d, p5_3d], dim=2)
        conv_3d = self.conv3d(combine)
        bn = self.bn(conv_3d)
        act = self.act(bn)
        x = self.pool_3d(act)
        x = torch.squeeze(x, 2)
        return x

class Add(nn.Module):
    # Concatenate a list of tensors along dimension
    def __init__(self, ch=256):
        super().__init__()
 
    def forward(self, x):
        input1, input2 = x[0], x[1]
        x = input1 + input2
        return x

class channel_att(nn.Module):
    def __init__(self, channel, b=1, gamma=2):
        super(channel_att, self).__init__()
        kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
        kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
 
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()
 
    def forward(self, x):
        y = self.avg_pool(x)
        y = y.squeeze(-1)
        y = y.transpose(-1, -2)
        y = self.conv(y).transpose(-1, -2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)

class local_att(nn.Module):
    def __init__(self, channel, reduction=16):
        super(local_att, self).__init__()
 
        self.conv_1x1 = nn.Conv2d(in_channels=channel, out_channels=channel // reduction, kernel_size=1, stride=1,
                                  bias=False)
 
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(channel // reduction)
 
        self.F_h = nn.Conv2d(in_channels=channel // reduction, out_channels=channel, kernel_size=1, stride=1,
                             bias=False)
        self.F_w = nn.Conv2d(in_channels=channel // reduction, out_channels=channel, kernel_size=1, stride=1,
                             bias=False)
 
        self.sigmoid_h = nn.Sigmoid()
        self.sigmoid_w = nn.Sigmoid()
 
    def forward(self, x):
        _, _, h, w = x.size()
 
        x_h = torch.mean(x, dim=3, keepdim=True).permute(0, 1, 3, 2)
        x_w = torch.mean(x, dim=2, keepdim=True)
 
        x_cat_conv_relu = self.relu(self.bn(self.conv_1x1(torch.cat((x_h, x_w), 3))))
 
        x_cat_conv_split_h, x_cat_conv_split_w = x_cat_conv_relu.split([h, w], 3)
 
        s_h = self.sigmoid_h(self.F_h(x_cat_conv_split_h.permute(0, 1, 3, 2)))
        s_w = self.sigmoid_w(self.F_w(x_cat_conv_split_w))
 
        out = x * s_h.expand_as(x) * s_w.expand_as(x)
        return out

class attention_model(nn.Module):
    # Concatenate a list of tensors along dimension
    def __init__(self, ch=256):
        super().__init__()
        self.channel_att = channel_att(ch)
        self.local_att = local_att(ch)
 
    def forward(self, x):
        input1, input2 = x[0], x[1]
        input1 = self.channel_att(input1)
        x = input1 + input2
        x = self.local_att(x)
        return x



四、添加步骤

4.1 修改一

① 在 ultralytics/nn/ 目录下新建 AddModules 文件夹用于存放模块代码

② 在 AddModules 文件夹下新建 ASF.py ,将 第三节 中的代码粘贴到此处

在这里插入图片描述

4.2 修改二

AddModules 文件夹下新建 __init__.py (已有则不用新建),在文件内导入模块: from .ASF import *

在这里插入图片描述

4.3 修改三

ultralytics/nn/modules/tasks.py 文件中,需要在两处位置添加各模块类名称。

首先:导入模块

在这里插入图片描述

其次:在 parse_model函数 中的相同位置处添加如下代码

在这里插入图片描述

elif m is Zoom_cat:
    c2 = sum(ch[x] for x in f)
elif m is Add:
    c2 = ch[f[-1]]
elif m is ScalSeq:
    c1 = [ch[x] for x in f]
    c2 = make_divisible(args[0] * width, 8)
    args = [c1, c2]
elif m is attention_model:
    args = [ch[f[-1]]]

在这里插入图片描述


五、yaml模型文件

5.1 模型改进版本⭐

此处以 ultralytics/cfg/models/rt-detr/rtdetr-l.yaml 为例,在同目录下创建一个用于自己数据集训练的模型文件 rtdetr-l-ASF.yaml

rtdetr-l.yaml 中的内容复制到 rtdetr-l-ASF.yaml 文件下,修改 nc 数量等于自己数据中目标的数量。

📌 模型的修改方法是将 颈部网络 进行修改。

# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr

# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, HGStem, [32, 48]] # 0-P2/4
  - [-1, 6, HGBlock, [48, 128, 3]] # stage 1

  - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
  - [-1, 6, HGBlock, [96, 512, 3]] # 3 stage 2

  - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P4/16
  - [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut
  - [-1, 6, HGBlock, [192, 1024, 5, True, True]]
  - [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3

  - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P5/32
  - [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4

head:
  - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 10 input_proj.2
  - [-1, 1, AIFI, [1024, 8]] # 11
  - [-1, 1, Conv, [256, 1, 1]]  # 12, Y5, lateral_convs.0

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 13
  - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 14 input_proj.1
  - [[-1, 7, -2], 1, Zoom_cat, []] # 15
  - [-1, 3, RepC3, [256]]  # 16, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]]   # 17, Y4, lateral_convs.1

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 18
  - [1, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 19 input_proj.0
  - [[-1, 3, -2], 1, Zoom_cat, []]  # 20 cat backbone P4
  - [-1, 3, RepC3, [256]]    # X3 (21), fpn_blocks.1

  - [-1, 1, Conv, [256, 3, 2]]   # 22, downsample_convs.0
  - [[-1, 17], 1, Concat, [1]]  # 23 cat Y4
  - [-1, 3, RepC3, [256]]    # F4 (24), pan_blocks.0

  - [-1, 1, Conv, [256, 3, 2]]   # 25, downsample_convs.1
  - [[-1, 12], 1, Concat, [1]]  # 26 cat Y5
  - [-1, 3, RepC3, [256]]    # F5 (27), pan_blocks.1

  - [[3, 7, 10], 1, ScalSeq, [256]] # 28
  - [[21, -1], 1, Add, []] # 29
  # - [[21, -1], 1, asf_attention_model, []] # 29

  - [[28, 24, 27], 1, RTDETRDecoder, [nc, 256, 300, 4, 8, 3]]  # Detect(P3, P4, P5)


六、成功运行结果

打印网络模型可以看到颈部网络已经修改完成,并可以进行训练了。

rtdetr-l-ASF-AFS

rtdetr-l-ASF summary: 622 layers, 30,505,556 parameters, 30,505,556 gradients, 114.1 GFLOPs

                   from  n    params  module                                       arguments                     
  0                  -1  1     25248  ultralytics.nn.modules.block.HGStem          [3, 32, 48]                   
  1                  -1  6    155072  ultralytics.nn.modules.block.HGBlock         [48, 48, 128, 3, 6]           
  2                  -1  1      1408  ultralytics.nn.modules.conv.DWConv           [128, 128, 3, 2, 1, False]    
  3                  -1  6    839296  ultralytics.nn.modules.block.HGBlock         [128, 96, 512, 3, 6]          
  4                  -1  1      5632  ultralytics.nn.modules.conv.DWConv           [512, 512, 3, 2, 1, False]    
  5                  -1  6   1695360  ultralytics.nn.modules.block.HGBlock         [512, 192, 1024, 5, 6, True, False]
  6                  -1  6   2055808  ultralytics.nn.modules.block.HGBlock         [1024, 192, 1024, 5, 6, True, True]
  7                  -1  6   2055808  ultralytics.nn.modules.block.HGBlock         [1024, 192, 1024, 5, 6, True, True]
  8                  -1  1     11264  ultralytics.nn.modules.conv.DWConv           [1024, 1024, 3, 2, 1, False]  
  9                  -1  6   6708480  ultralytics.nn.modules.block.HGBlock         [1024, 384, 2048, 5, 6, True, False]
 10                  -1  1    524800  ultralytics.nn.modules.conv.Conv             [2048, 256, 1, 1, None, 1, 1, False]
 11                  -1  1    789760  ultralytics.nn.modules.transformer.AIFI      [256, 1024, 8]                
 12                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]              
 13                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 14                   3  1    131584  ultralytics.nn.modules.conv.Conv             [512, 256, 1, 1, None, 1, 1, False]
 15         [-1, 7, -2]  1         0  ultralytics.nn.AddModules.ASF.Zoom_cat       []                            
 16                  -1  3   2756608  ultralytics.nn.modules.block.RepC3           [1536, 256, 3]                
 17                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]              
 18                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 19                   1  1     33280  ultralytics.nn.modules.conv.Conv             [128, 256, 1, 1, None, 1, 1, False]
 20         [-1, 3, -2]  1         0  ultralytics.nn.AddModules.ASF.Zoom_cat       []                            
 21                  -1  3   2494464  ultralytics.nn.modules.block.RepC3           [1024, 256, 3]                
 22                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 23            [-1, 17]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 24                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 25                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 26            [-1, 12]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 27                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 28          [3, 7, 10]  1    526592  ultralytics.nn.AddModules.ASF.ScalSeq        [[512, 1024, 256], 256]       
 29            [21, -1]  1         0  ultralytics.nn.AddModules.ASF.Add            []                            
 30        [28, 24, 27]  1   3917684  ultralytics.nn.modules.head.RTDETRDecoder    [1, [256, 256, 256], 256, 300, 4, 8, 3]
rtdetr-l-ASF summary: 622 layers, 30,505,556 parameters, 30,505,556 gradients, 114.1 GFLOPs