学习资源站

RT-DETR改进策略【注意力机制篇】添加SE、CBAM、ECA、CA、SwinTransformer等注意力和多头注意力机制_rtdetr注意力改进-

RT-DETR改进策略【注意力机制篇】| 添加SE、CBAM、ECA、CA、Swin Transformer等注意力和多头注意力机制

前言

这篇文章带来一个经典注意力模块的汇总,虽然有些模块已经发布很久了,但后续的注意力模块也都是在此基础之上进行改进的,对于初学者来说还是有必要去学习了解一下,以加深对模块,模型的理解。



一、为什么要引入注意力机制?

来源 :注意力机制的设计灵感来源于人类视觉系统。当我们在观察外界事物时,会自动将注意力集中在重要或感兴趣的区域,而忽略无关信息。计算机视觉中的注意力机制就是在试图模拟这一过程,以提高模型的感知和理解能力。

问题 :随着图像数据量的增加,模型需要处理的信息量也随之增大。传统的卷积神经网络在处理大量数据时可能会遇到信息过载的问题,导致性能下降。注意力机制通过有选择地关注重要信息,帮助模型在海量数据中筛选出关键内容,从而提高检测精度。

好处

  • 注意力机制能够赋予输入数据的不同部分以不同的权重,使模型更加关注重要的特征信息。
  • 通过生成热力图,显示模型在做出决策时关注的具体区域,有助于更好地理解模型的决策过程,增强模型的可解释性。
  • 注意力机制使模型在处理不同数据集和任务时能够更灵活地调整其关注点。有助于提升模型的泛化能力,使其在面对新数据集或新任务时仍能保持较高的性能水平。

除了能够提升性能外,其最主要的还是其即插即用的特性,无论模块放在什么地方,都可以运行查看训练效果,更方便炼丹成功~

二、SE

2.1 SE的原理

通道注意力模块 关注于网络中每个通道的重要性, 通过为每个通道分配不同的权重,使得网络能够更加关注那些对任务更为关键的通道特征,从而提高模型的性能 。其中主要涉及 Squeeze Excitation 两个操作。

  • Squeeze操作 :通过 全局平均池化 将每个通道的特征图压缩为一个实数。
  • Excitation操作 :利用两个 全连接层 (先降维后升维)和一个 ReLU激活函数 来学习通道间的依赖关系,并通过 sigmoid函数 生成权重向量。
  • Scale操作 :将学习到的通道权重与原始特征图进行逐通道相乘,实现特征的重标定。
    在这里插入图片描述

论文: https://arxiv.org/abs/1709.01507
源码: https://github.com/hujie-frank/SENet

2.2 SE的实现代码

import torch.nn as nn

# SE
class SE(nn.Module):
    def __init__(self, c1, ratio=16):
        super(SE, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.l1 = nn.Linear(c1, c1 // ratio, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.l2 = nn.Linear(c1 // ratio, c1, bias=False)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avgpool(x).view(b, c)
        y = self.l1(y)
        y = self.relu(y)
        y = self.l2(y)
        y = self.sig(y)
        y = y.view(b, c, 1, 1)
        return x * y.expand_as(x)

注意❗:在 7.2、7.3小节 中的 __init__.py tasks.py 文件中需要声明的模块名称为: SE


三、CBAM

3.1 CBAM的原理

CBAM注意力模块 通道注意力模块 空间注意力模块 两部分组成。 它通过顺序地应用通道注意力和空间注意力,使得网络能够自适应地关注到输入特征图中最重要的通道和空间位置,从而提高模型的表征能力

在这里插入图片描述

  • 通道注意力模块(CAM)
    此部分的操作步骤与 SE通道注意力 模块的步骤一致。

  • 空间注意力模块(SAM)

    • 特征提取 :在通道注意力模块处理后的特征图上,分别进行基于通道维度的 最大池化 平均池化 操作,以生成两个新的特征图。
    • 特征融合 :将两个池化后的特征图在通道维度上进行 拼接 (concatenate),然后通过一个卷积层进行特征融合,生成空间注意力权重图。
    • 特征增强 :将空间注意力权重图与原始特征图进行 逐元素相乘 ,实现特征的增强,使得模型能够关注到更重要的空间位置信息。

在这里插入图片描述

论文: https://arxiv.org/abs/1807.06521
源码: https://github.com/Jongchan/attention-module

3.2 CBAM的实现代码

import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        return out
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        # 1*h*w
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        # 2*h*w
        x = self.conv(x)
        # 1*h*w
        return self.sigmoid(x)

class CBAM(nn.Module):
    def __init__(self, c1, c2, ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(c1, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        out = self.channel_attention(x) * x
        # c*h*w
        # c*h*w * 1*h*w
        out = self.spatial_attention(out) * out
        return out

注意❗:在 7.2、7.3小节 中的 __init__.py tasks.py 文件中需要声明的模块名称为: CBAM


四、ECA

4.1 ECA的原理

ECA注意力模块 的核心思想是在不增加过多计算成本和参数的情况下, 通过引入一种有效的通道注意力机制,来增强网络对关键特征的关注能力 。它避免了通道注意力机制中可能存在的降维操作带来的性能损失,通过一种自适应的跨通道交互策略来实现通道权重的生成。步骤如下:

  • 特征压缩 :这里和 SE注意力 中的 Squeeze操作 一致,省略啦。
  • 特征学习 ECA 使用一维卷积来代替 SE注意力 中的全连接层,来学习通道间的依赖关系。这里的一维卷积核大小是自适应的,与通道维度成正比,以确保不同通道数的特征图都能有效地进行跨通道交互。通过一维卷积, ECA 能够直接捕获局部跨通道交互信息,而无需进行复杂的降维和升维操作。
  • 特征重标定 :这里也和 SE注意力 中的 Scale操作 一致。

在这里插入图片描述

论文: https://arxiv.org/abs/1910.03151
源码: https://github.com/BangguWu/ECANet

4.2 ECA的实现代码

import torch
import torch.nn as nn

class ECA(nn.Module):
    def __init__(self, c1, c2, k_size=3):
        super(ECA, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(
            1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)

注意❗:在 7.2、7.3小节 中的 __init__.py tasks.py 文件中需要声明的模块名称为: ECA


五、CA

5.1 CA的原理

CA注意力模块 的核心思想是将位置信息嵌入到通道注意力中,以更精确地捕捉到图像中的空间分布特征。与普通的通道注意力机制不同的是,CA不仅关注通道间的依赖关系, 还通过引入坐标信息来增强模型对空间细节的敏感度 。步骤如下:

  • 特征分解 :输入特征图通常具有C(通道数)、H(高度)、W(宽度)三个维度。CA首先通过两个并行的 全局平均池化 操作,分别沿垂直(高度)和水平(宽度)方向聚合输入特征,生成两个包含方向特定信息的特征图。这两个特征图分别捕捉了高度和宽度方向上的空间信息。
  • 特征编码 :将两个池化后的特征图在通道维度上 拼接 ,并通过一个 1x1的二维卷积层 来融合和转换特征。然后,对卷积后的特征图进行 批量归一化 非线性激活
  • 注意力图生成 :将批量归一化和激活后的特征图分裂为两个特征图,分别对应于高度和宽度方向。接着,通过另外两个 1x1的二维卷积层 分别处理这两个特征图,并应用 Sigmoid激活函数 ,生成两个注意力图。这两个注意力图分别沿宽度和高度方向对输入特征图进行重标定。
  • 特征重加权 :将通过Sigmoid激活的注意力图与原始的输入特征图相乘,以重新加权原始特征。这样,重要的特征会被放大,而不重要的特征则会减弱,从而增强模型对关键信息的关注能力。

论文: https://arxiv.org/abs/2103.02907v1
源码: https://github.com/houqb/CoordAttention

5.2 CA的实现代码

import torch
import torch.nn as nn

class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6

class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)

class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        mip = max(8, inp // reduction)
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x
        n, c, h, w = x.size()
        # c*1*W
        x_h = self.pool_h(x)
        # c*H*1
        # C*1*h
        x_w = self.pool_w(x).permute(0, 1, 3, 2)
        y = torch.cat([x_h, x_w], dim=2)
        # C*1*(h+w)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)
        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()
        out = identity * a_w * a_h
        return out
        

注意❗:在 7.2、7.3小节 中的 __init__.py tasks.py 文件中需要声明的模块名称为: CoordAtt


六、Swin Transformer

6.1 Swin Transformer的原理

Swin Transformer 通过分层设计结合多个等级的窗口划分来降低计算复杂度,并提出位移窗口使相邻的窗口之间进行交互,从而达到全局建模的能力 。在 Swin Transformer 模型中最重要的是模块是 窗口多头自注意力(W-MSA) 移动窗口多头自注意力(SW-MSA) ,用于自注意力的计算。

  • 窗口多头自注意力(W-MSA)
    • 划分窗口 :将特征图划分为多个固定大小的窗口。
    • 自注意力计算 :在每个窗口内独立计算多头自注意力,此时计算复杂度与窗口内的小块数量成线性关系,从而降低了整体计算复杂度。
    • 输出 :得到每个窗口内的自注意力特征图。
  • 移动窗口多头自注意力(SW-MSA)
    • 窗口移动 :在W-MSA之后,通过移动窗口的方式改变窗口的划分,使得相邻的窗口之间能够产生交互。
    • 自注意力计算 :在新的窗口划分下再次计算多头自注意力。
    • 输出 :得到移动窗口后的自注意力特征图。

在这里插入图片描述

论文: https://arxiv.org/abs/2103.14030
源码: https://github.com/microsoft/Swin-Transformer

6.2 Swin Transformer的实现代码

class WindowAttention(nn.Module):

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):

        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # print(attn.dtype, v.dtype)
        try:
            x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        except:
            #print(attn.dtype, v.dtype)
            x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Mlp(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

def window_partition(x, window_size):

    B, H, W, C = x.shape
    assert H % window_size == 0, 'feature map h and w can not divide by window size'
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

class SwinTransformerLayer(nn.Module):

    def __init__(self, dim, num_heads, window_size=8, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.SiLU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        # if min(self.input_resolution) <= self.window_size:
        #     # if window size is larger than input resolution, we don't partition windows
        #     self.shift_size = 0
        #     self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def create_mask(self, H, W):
        # calculate attention mask for SW-MSA
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x):
        # reshape x[b c h w] to x[b l c]
        _, _, H_, W_ = x.shape

        Padding = False
        if min(H_, W_) < self.window_size or H_ % self.window_size!=0 or W_ % self.window_size!=0:
            Padding = True
            # print(f'img_size {min(H_, W_)} is less than (or not divided by) window_size {self.window_size}, Padding.')
            pad_r = (self.window_size - W_ % self.window_size) % self.window_size
            pad_b = (self.window_size - H_ % self.window_size) % self.window_size
            x = F.pad(x, (0, pad_r, 0, pad_b))

        # print('2', x.shape)
        B, C, H, W = x.shape
        L = H * W
        x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)  # b, L, c

        # create mask from init to forward
        if self.shift_size > 0:
            attn_mask = self.create_mask(H, W).to(x.device)
        else:
            attn_mask = None

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W)  # b c h w

        if Padding:
            x = x[:, :, :H_, :W_]  # reverse padding

        return x

class SwinTransformerBlock(nn.Module):
    def __init__(self, c1, c2, num_heads, num_layers, window_size=8):
        super().__init__()
        self.conv = None
        if c1 != c2:
            self.conv = Conv(c1, c2)

        # remove input_resolution
        self.blocks = nn.Sequential(*[SwinTransformerLayer(dim=c2, num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2) for i in range(num_layers)])

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        x = self.blocks(x)
        return x

class STCSPA(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(STCSPA, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformerBlock(c_, c_, num_heads, n)
        #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        y1 = self.m(self.cv1(x))
        y2 = self.cv2(x)
        return self.cv3(torch.cat((y1, y2), dim=1))

class STCSPB(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(STCSPB, self).__init__()
        c_ = int(c2)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformerBlock(c_, c_, num_heads, n)
        #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        x1 = self.cv1(x)
        y1 = self.m(x1)
        y2 = self.cv2(x1)
        return self.cv3(torch.cat((y1, y2), dim=1))

class STCSPC(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(STCSPC, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(c_, c_, 1, 1)
        self.cv4 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformerBlock(c_, c_, num_heads, n)
        #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        y1 = self.cv3(self.m(self.cv1(x)))
        y2 = self.cv2(x)
        return self.cv4(torch.cat((y1, y2), dim=1))

注意❗:在 7.2、7.3小节 中的 __init__.py tasks.py 文件中需要声明的模块名称为: STCSPA, STCSPB, STCSPC ,在模型中使用哪个选哪个就行。


七、添加步骤

此处在模型配置中以 SE通道注意力 为例,列举的其他注意力模块添加步骤与此 完全一致

1. 修改ultralytics/nn/modules/block.py

此处需要修改的文件是 ultralytics/nn/modules/block.py

block.py中定义了网络结构的通用模块 ,我们想要加入新的模块就只需要将模块代码放到这个文件内即可。

SE 添加后如下:

在这里插入图片描述

2. 修改ultralytics/nn/modules/ init .py

此处需要修改的文件是 ultralytics/nn/modules/__init__.py

__init__.py 文件中定义了所有模块的初始化,我们只需要将 block.py 中的新的模块命添加到对应的函数即可。

SE block.py 中实现,所有要添加在:

from .block import (
    C1,
    C2,
    ...
    SE
)

在这里插入图片描述

3. 修改ultralytics/nn/modules/tasks.py

tasks.py 文件中,需要在两处位置添加各模块类名称。

首先:在函数声明中引入 SE

在这里插入图片描述

在这里插入图片描述

其次:在 parse_model函数 中注册 SE模块

在这里插入图片描述

在这里插入图片描述


八、yaml模型文件

在代码配置完成后,配置模型的YAML文件。

此处以 ultralytics/cfg/models/rt-detr/rt-detr-l.yaml 为例,在同目录下创建一个用于自己数据集训练的模型文件 rt-detr-l-SE.yaml

rt-detr-l.yaml 中的内容复制到 rt-detr-l-SE.yaml 文件下,修改 nc 数量等于自己数据中目标的数量。
在骨干网络的深层中添加 SE模块 只需要填入一个参数,通道数,和前一层通道数一致 还需要注意的是,由于PAN+FPN的颈部模型结构存在,层之间的匹配也要记得修改,维度要匹配上

📌 放在此处的目的是让网络能够学习到更深层的语义信息,因为此时特征图尺寸小,包含全局信息。若是希望网络能够更加关注局部信息,可尝试将注意力模块添加到网络的浅层。

📌 当然由于其即插即用的特性,加在哪里都是可以的,但是想要真的有效,还需要根据模型结构,数据集特性等多方面因素,多做实验进行验证。

# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr

# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, HGStem, [32, 48]] # 0-P2/4
  - [-1, 6, HGBlock, [48, 128, 3]] # stage 1

  - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
  - [-1, 6, HGBlock, [96, 512, 3]] # stage 2

  - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P4/16
  - [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut
  - [-1, 6, HGBlock, [192, 1024, 5, True, True]]
  - [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3

  - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P5/32
  - [-1, 1, SE, [1024]] # stage 4
  - [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4

head:
  - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2
  - [-1, 1, AIFI, [1024, 8]]
  - [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1
  - [[-2, -1], 1, Concat, [1]]
  - [-1, 3, RepC3, [256]] # 16, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0
  - [[-2, -1], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1

  - [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0
  - [[-1, 18], 1, Concat, [1]] # cat Y4
  - [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0

  - [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1
  - [[-1, 13], 1, Concat, [1]] # cat Y5
  - [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1

  - [[22, 25, 28], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)


九、成功运行结果

打印网络模型可以看到 SE模块 已经加入到模型中,并可以进行训练了。

其他模块如: CBAM、ECA、CA、Swin Transformer 这些模块和 SE 的添加步骤完全一致。

并且对于这里未提到的注意力模块的添加步骤也是一样的,只要加入模块代码,并将其添加到模型中即可。

                   from  n    params  module                                       arguments                     
  0                  -1  1     25248  ultralytics.nn.modules.block.HGStem          [3, 32, 48]                   
  1                  -1  6    155072  ultralytics.nn.modules.block.HGBlock         [48, 48, 128, 3, 6]           
  2                  -1  1      1408  ultralytics.nn.modules.conv.DWConv           [128, 128, 3, 2, 1, False]    
  3                  -1  6    839296  ultralytics.nn.modules.block.HGBlock         [128, 96, 512, 3, 6]          
  4                  -1  1      5632  ultralytics.nn.modules.conv.DWConv           [512, 512, 3, 2, 1, False]    
  5                  -1  6   1695360  ultralytics.nn.modules.block.HGBlock         [512, 192, 1024, 5, 6, True, False]
  6                  -1  6   2055808  ultralytics.nn.modules.block.HGBlock         [1024, 192, 1024, 5, 6, True, True]
  7                  -1  6   2055808  ultralytics.nn.modules.block.HGBlock         [1024, 192, 1024, 5, 6, True, True]
  8                  -1  1     11264  ultralytics.nn.modules.conv.DWConv           [1024, 1024, 3, 2, 1, False]  
  9                  -1  1      2048  ultralytics.nn.AddModules.SE.SE              [1024, 1024]                  
 10                  -1  6   6708480  ultralytics.nn.modules.block.HGBlock         [1024, 384, 2048, 5, 6, True, False]
 11                  -1  1    524800  ultralytics.nn.modules.conv.Conv             [2048, 256, 1, 1, None, 1, 1, False]
 12                  -1  1    789760  ultralytics.nn.modules.transformer.AIFI      [256, 1024, 8]                
 13                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]              
 14                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 15                   7  1    262656  ultralytics.nn.modules.conv.Conv             [1024, 256, 1, 1, None, 1, 1, False]
 16            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 17                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 18                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]              
 19                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 20                   3  1    131584  ultralytics.nn.modules.conv.Conv             [512, 256, 1, 1, None, 1, 1, False]
 21            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 22                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 23                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 24            [-1, 18]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 25                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 26                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 27            [-1, 13]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 28                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 29        [22, 25, 28]  1   7303907  ultralytics.nn.modules.head.RTDETRDecoder    [1, [256, 256, 256]]          
rtdetr-l-SE summary: 687 layers, 32,810,179 parameters, 32,810,179 gradients, 108.0 GFLOPs