【RT-DETR Multimodal Fusion Improvement】| BMVC 2024 MASAG module (Multi-Scale Adaptive Spatial Attention Gate): dynamic receptive fields and spatial attention for more accurate multi-scale fusion
1. Introduction
This post documents how the MASAG module from MSA²Net is used to improve the multimodal fusion part of RT-DETR.

The MASAG (Multi-Scale Adaptive Spatial Attention Gate) module dynamically modulates spatial attention weights and multi-scale receptive fields to intelligently aggregate the local details and global semantics of cross-level feature maps. Applied to RT-DETR, it addresses the complementary nature of shallow boundary features and deep semantic information in object detection, strengthening the feature representation of multi-scale targets and improving detection accuracy and boundary localization in complex scenes.
2. MASAG Overview
MSA^2Net: Multi-scale Adaptive Attention-guided Network for Medical Image Segmentation
2.1 Design Motivation
- Traditional CNN-based methods struggle to capture long-range dependencies and are constrained by static receptive fields.
- Transformer-based methods fall short in capturing local representations and context, and are computationally inefficient.
- Although many improved variants exist, most focus on the encoder or decoder and leave the original skip-connection design essentially unchanged.

A new module is therefore needed that optimizes the skip connections, fuses local and global features more effectively, and improves medical image segmentation accuracy.
2.2 Architecture
- Multi-scale feature fusion: the module combines local and global context extraction. Depth-wise and dilated convolutions widen the spatial extent of the encoder's high-resolution detail features $X$, while channel pooling gathers broad contextual information from the decoder's semantic features $G$; the two are integrated into a joint feature map:
$U = \mathrm{Conv}_{1\times1}(\mathrm{DW\text{-}D}(\mathrm{DW}(X))) + \mathrm{Conv}_{1\times1}([P_{Avg}(G); P_{Max}(G)])$
- Spatial selection: the fused map $U$ is projected onto two channels, and a channel-wise softmax produces spatial selection weights, yielding the spatially selected $X'$ and $G'$; two residual connections improve gradient flow and feature utilization:
$SW_i = S(\mathrm{Conv}_{1\times1}(U)) \quad \forall i \in [1, 2]$
$X' = SW_1 \otimes X + X, \qquad G' = SW_2 \otimes G + G$
- Spatial interaction and cross-modulation: $X'$ and $G'$ dynamically enhance each other, with $X'$ absorbing the global context of $G'$ and $G'$ absorbing the detailed context of $X'$, after which the two are fused by element-wise multiplication:
$X'' = X' \otimes \sigma(G'), \qquad G'' = G' \otimes \sigma(X'), \qquad U' = X'' \otimes G''$
- Recalibration: the fused map from the interaction stage goes through a point-wise convolution and a sigmoid to produce an attention map, which recalibrates the encoder's original input $X$:
$X = \mathrm{Conv}_{1\times1}(\sigma(\mathrm{Conv}_{1\times1}(U')) \otimes X)$
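To make the gating arithmetic concrete, here is a minimal runnable sketch of the selection and interaction stages, using random tensors as stand-ins for the convolution outputs (toy sizes of my choosing; the faithful implementation is in Section 3):

import torch
import torch.nn.functional as F

B, C, H, W = 2, 8, 16, 16                        # assumed toy dimensions
X, G = torch.randn(B, C, H, W), torch.randn(B, C, H, W)
SW = F.softmax(torch.randn(B, 2, H, W), dim=1)   # stand-in for S(Conv1x1(U)): two per-pixel path weights
Xp = SW[:, 0:1] * X + X                          # X' = SW1 ⊗ X + X (weights broadcast over channels)
Gp = SW[:, 1:2] * G + G                          # G' = SW2 ⊗ G + G
Up = (Xp * torch.sigmoid(Gp)) * (Gp * torch.sigmoid(Xp))  # U' = X'' ⊗ G''
print(Up.shape)                                  # torch.Size([2, 8, 16, 16])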
2.3 Advantages
- Dynamically adjusted receptive field: the module adapts its receptive field to the input, highlighting spatially relevant features while suppressing irrelevant background detail, which improves segmentation accuracy for organs and lesions of varying size, shape, and structure.
- Effective feature fusion: through multi-scale fusion and spatial interaction, local and global features are fused effectively, so the model weighs fine detail against overall structure and can precisely segment target structures in complex medical images.
- Improved model performance: experiments on several medical datasets (such as Synapse and ISIC 2018) show that MSA²Net, which contains the MASAG module, matches or surpasses state-of-the-art methods on multiple metrics; on Synapse, its Dice similarity coefficient (DSC) is 0.48% higher than the 2D version of D-LKA and 2.21% higher than DAE-Former, demonstrating the module's effectiveness.
Paper: https://arxiv.org/abs/2407.21640
Code: https://github.com/xmindflow/MSA-2Net
3. MASAG Implementation Code
The implementation of the MASAG module is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F


def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Apply convolution and activation without batch normalization (used after BN fusion)."""
        return self.act(self.conv(x))


def num_trainable_params(model):
    """Return the number of trainable parameters in millions."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6


class GlobalExtraction(nn.Module):
    """Extract global context from the semantic branch G via channel-wise average and max pooling."""

    def __init__(self, dim=None):
        super().__init__()
        self.avgpool = self.globalavgchannelpool
        self.maxpool = self.globalmaxchannelpool
        self.proj = nn.Sequential(
            nn.Conv2d(2, 1, 1, 1),
            nn.BatchNorm2d(1)
        )

    def globalavgchannelpool(self, x):
        return x.mean(1, keepdim=True)

    def globalmaxchannelpool(self, x):
        return x.max(dim=1, keepdim=True)[0]

    def forward(self, x):
        x_ = x.clone()
        x = self.avgpool(x)
        x2 = self.maxpool(x_)
        cat = torch.cat((x, x2), dim=1)  # [P_Avg(G); P_Max(G)]
        return self.proj(cat)


class ContextExtraction(nn.Module):
    """Extract local context from the detail branch X: a 3x3 depth-wise conv followed by a 3x3 dilated conv."""

    def __init__(self, dim, reduction=None):
        super().__init__()
        self.reduction = 1 if reduction is None else 2
        self.dconv = self.DepthWiseConv2dx2(dim)
        self.proj = self.Proj(dim)

    def DepthWiseConv2dx2(self, dim):
        return nn.Sequential(
            nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=1, groups=dim),
            nn.BatchNorm2d(num_features=dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=2, dilation=2),
            nn.BatchNorm2d(num_features=dim),
            nn.ReLU(inplace=True)
        )

    def Proj(self, dim):
        return nn.Sequential(
            nn.Conv2d(in_channels=dim, out_channels=dim // self.reduction, kernel_size=1),
            nn.BatchNorm2d(num_features=dim // self.reduction)
        )

    def forward(self, x):
        x = self.dconv(x)
        x = self.proj(x)
        return x


class MultiscaleFusion(nn.Module):
    """Fuse local context from X and global context from G into the joint feature map U."""

    def __init__(self, dim):
        super().__init__()
        self.local = ContextExtraction(dim)
        self.global_ = GlobalExtraction()
        self.bn = nn.BatchNorm2d(num_features=dim)

    def forward(self, x, g):
        x = self.local(x)
        g = self.global_(g)  # single-channel map, broadcast over the channels of x
        return self.bn(x + g)


class MultiScaleGatedAttn(nn.Module):
    """MASAG: multi-scale fusion, spatial selection, bidirectional interaction, and recalibration."""

    def __init__(self, dims):
        super().__init__()
        dim = min(dims)
        if dims[0] != dims[1]:
            # Project both inputs to the smaller channel dimension
            self.conv1 = Conv(dims[0], dim)
            self.conv2 = Conv(dims[1], dim)
        self.multi = MultiscaleFusion(dim)
        self.selection = nn.Conv2d(dim, 2, 1)
        self.proj = nn.Conv2d(dim, dim, 1)
        self.bn = nn.BatchNorm2d(dim)
        self.bn_2 = nn.BatchNorm2d(dim)
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, stride=1)
        )

    def forward(self, inputs):
        x, g = inputs
        if x.size(1) != g.size(1):
            x = self.conv1(x)
            g = self.conv2(g)
        x_ = x.clone()
        g_ = g.clone()

        # Multi-scale fusion: U = Conv1x1(DW-D(DW(X))) + Conv1x1([P_Avg(G); P_Max(G)])
        multi = self.multi(x, g)  # B, C, H, W

        # Spatial selection: channel-wise softmax over two per-pixel path weights
        multi = self.selection(multi)  # B, 2, H, W
        attention_weights = F.softmax(multi, dim=1)
        A, B = attention_weights.split(1, dim=1)  # each B, 1, H, W

        x_att = A.expand_as(x_) * x_  # expand weights to match channel dimensions
        g_att = B.expand_as(g_) * g_
        x_att = x_att + x_  # residual connections
        g_att = g_att + g_

        # Bidirectional interaction / cross-modulation
        x_sig = torch.sigmoid(x_att)
        g_att_2 = x_sig * g_att
        g_sig = torch.sigmoid(g_att)
        x_att_2 = g_sig * x_att
        interaction = x_att_2 * g_att_2  # U' = X'' ⊗ G''

        # Recalibration of the original detail features
        projected = torch.sigmoid(self.bn(self.proj(interaction)))
        weighted = projected * x_
        y = self.conv_block(weighted)
        y = self.bn_2(y)
        return y
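A quick smoke test (my own addition, not part of the original repo) to confirm that shapes flow through the module; the sizes below are arbitrary:

if __name__ == "__main__":
    x = torch.randn(2, 128, 80, 80)  # P3-level features from one modality (assumed shape)
    g = torch.randn(2, 128, 80, 80)  # P3-level features from the other modality
    attn = MultiScaleGatedAttn(dims=[128, 128])
    print(attn((x, g)).shape)  # expected: torch.Size([2, 128, 80, 80])
    print(f"{num_trainable_params(attn):.3f} M trainable parameters")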
4. Integration Steps
4.1 Step 1
① Create an AddModules folder under the ultralytics/nn/ directory to hold the module code.
② Inside the AddModules folder, create MultiScaleGatedAttn.py and paste the code from Section 3 into it.
4.2 Step 2
In the AddModules folder, create __init__.py (skip this if it already exists) and import the module inside it:
from .MultiScaleGatedAttn import *
4.3 Step 3
In the ultralytics/nn/tasks.py file, the module class name must be registered in two places.
First, import the module, as sketched below.
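A minimal import sketch, assuming the AddModules package layout from Steps 1 and 2 (adjust the path if your layout differs):

from ultralytics.nn.AddModules import *  # makes MultiScaleGatedAttn visible to parse_model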
Then, add the following branch inside the parse_model function:
elif m in {MultiScaleGatedAttn}:
    c1 = [ch[x] for x in f]
    c2 = min(c1)
    args = [c1]
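Here `f` is the module's source-layer list from the YAML (e.g. `[8, 16]`), so `c1` collects the channel counts of both inputs; `c2 = min(c1)` mirrors the `dim = min(dims)` logic inside MultiScaleGatedAttn and records the fused output width, while the full list is passed on as the module's `dims` argument.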
5. YAML Model File
5.1 Mid-level Fusion ⭐
📌 This configuration applies the MASAG module to the mid-level (feature-level) multimodal fusion of RT-DETR.
# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-ResNet18 object detection model with P3-P5 outputs.

# Parameters
ch: 6 # two 3-channel modalities stacked at the input
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, IN, []] # 0
  - [-1, 1, Multiin, [1]] # 1
  - [-2, 1, Multiin, [2]] # 2
  - [1, 1, ConvNormLayer, [32, 3, 2, 1, 'relu']] # 3-P1
  - [-1, 1, ConvNormLayer, [32, 3, 1, 1, 'relu']] # 4
  - [-1, 1, ConvNormLayer, [64, 3, 1, 1, 'relu']] # 5
  - [-1, 1, nn.MaxPool2d, [3, 2, 1]] # 6-P2
  - [-1, 2, Blocks, [64, BasicBlock, 2, False]] # 7
  - [-1, 2, Blocks, [128, BasicBlock, 3, False]] # 8-P3
  - [-1, 2, Blocks, [256, BasicBlock, 4, False]] # 9-P4
  - [-1, 2, Blocks, [512, BasicBlock, 5, False]] # 10-P5
  - [2, 1, ConvNormLayer, [32, 3, 2, 1, 'relu']] # 11-P1
  - [-1, 1, ConvNormLayer, [32, 3, 1, 1, 'relu']] # 12
  - [-1, 1, ConvNormLayer, [64, 3, 1, 1, 'relu']] # 13
  - [-1, 1, nn.MaxPool2d, [3, 2, 1]] # 14-P2
  - [-1, 2, Blocks, [64, BasicBlock, 2, False]] # 15
  - [-1, 2, Blocks, [128, BasicBlock, 3, False]] # 16-P3
  - [-1, 2, Blocks, [256, BasicBlock, 4, False]] # 17-P4
  - [-1, 2, Blocks, [512, BasicBlock, 5, False]] # 18-P5
  - [[8, 16], 1, MultiScaleGatedAttn, []] # 19 MASAG fusion of both P3 branches
  - [[9, 17], 1, MultiScaleGatedAttn, []] # 20 MASAG fusion of both P4 branches
  - [[10, 18], 1, MultiScaleGatedAttn, []] # 21 MASAG fusion of both P5 branches

head:
  - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 22 input_proj.2
  - [-1, 1, AIFI, [1024, 8]] # 23
  - [-1, 1, Conv, [256, 1, 1]] # 24, Y5, lateral_convs.0
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 25
  - [20, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 26 input_proj.1
  - [[-2, -1], 1, Concat, [1]] # 27 cat backbone P4
  - [-1, 3, RepC3, [256, 0.5]] # 28, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]] # 29, Y4, lateral_convs.1
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 30
  - [19, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 31 input_proj.0
  - [[-2, -1], 1, Concat, [1]] # 32 cat backbone P3
  - [-1, 3, RepC3, [256, 0.5]] # X3 (33), fpn_blocks.1
  - [-1, 1, Conv, [256, 3, 2]] # 34, downsample_convs.0
  - [[-1, 29], 1, Concat, [1]] # 35 cat Y4
  - [-1, 3, RepC3, [256, 0.5]] # F4 (36), pan_blocks.0
  - [-1, 1, Conv, [256, 3, 2]] # 37, downsample_convs.1
  - [[-1, 24], 1, Concat, [1]] # 38 cat Y5
  - [-1, 3, RepC3, [256, 0.5]] # F5 (39), pan_blocks.1
  - [[33, 36, 39], 1, RTDETRDecoder, [nc, 256, 300, 4, 8, 3]] # 40 Detect(P3, P4, P5)
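With the YAML saved, training can be launched as usual; a minimal sketch (the file name and dataset config below are placeholders of my choosing, not from the original post):

from ultralytics import RTDETR

model = RTDETR("rtdetr-resnet18-mid-MSGA.yaml")  # path to the config above
model.train(data="multimodal-data.yaml", epochs=100, imgsz=640)  # hypothetical dataset config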
6. Verifying the Run
Printing the network confirms that the fusion layers have been added to the model and that training can proceed.
**rtdetr-resnet18-mid-MSGA**:
from n params module arguments
0 -1 1 0 ultralytics.nn.AddModules.multimodal.IN []
1 -1 1 0 ultralytics.nn.AddModules.multimodal.Multiin [1]
2 -2 1 0 ultralytics.nn.AddModules.multimodal.Multiin [2]
3 1 1 960 ultralytics.nn.AddModules.ResNet.ConvNormLayer[3, 32, 3, 2, 1, 'relu']
4 -1 1 9312 ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 32, 3, 1, 1, 'relu']
5 -1 1 18624 ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 64, 3, 1, 1, 'relu']
6 -1 1 0 torch.nn.modules.pooling.MaxPool2d [3, 2, 1]
7 -1 2 152512 ultralytics.nn.AddModules.ResNet.Blocks [64, 64, 2, 'BasicBlock', 2, False]
8 -1 2 526208 ultralytics.nn.AddModules.ResNet.Blocks [64, 128, 2, 'BasicBlock', 3, False]
9 -1 2 2100992 ultralytics.nn.AddModules.ResNet.Blocks [128, 256, 2, 'BasicBlock', 4, False]
10 -1 2 8396288 ultralytics.nn.AddModules.ResNet.Blocks [256, 512, 2, 'BasicBlock', 5, False]
11 2 1 960 ultralytics.nn.AddModules.ResNet.ConvNormLayer[3, 32, 3, 2, 1, 'relu']
12 -1 1 9312 ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 32, 3, 1, 1, 'relu']
13 -1 1 18624 ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 64, 3, 1, 1, 'relu']
14 -1 1 0 torch.nn.modules.pooling.MaxPool2d [3, 2, 1]
15 -1 2 152512 ultralytics.nn.AddModules.ResNet.Blocks [64, 64, 2, 'BasicBlock', 2, False]
16 -1 2 526208 ultralytics.nn.AddModules.ResNet.Blocks [64, 128, 2, 'BasicBlock', 3, False]
17 -1 2 2100992 ultralytics.nn.AddModules.ResNet.Blocks [128, 256, 2, 'BasicBlock', 4, False]
18 -1 2 8396288 ultralytics.nn.AddModules.ResNet.Blocks [256, 512, 2, 'BasicBlock', 5, False]
19 [8, 16] 1 200199 ultralytics.nn.AddModules.MultiScaleGatedAttn.MultiScaleGatedAttn[[128, 128]]
20 [9, 17] 1 793607 ultralytics.nn.AddModules.MultiScaleGatedAttn.MultiScaleGatedAttn[[256, 256]]
21 [10, 18] 1 3160071 ultralytics.nn.AddModules.MultiScaleGatedAttn.MultiScaleGatedAttn[[512, 512]]
22 -1 1 131584 ultralytics.nn.modules.conv.Conv [512, 256, 1, 1, None, 1, 1, False]
23 -1 1 789760 ultralytics.nn.modules.transformer.AIFI [256, 1024, 8]
24 -1 1 66048 ultralytics.nn.modules.conv.Conv [256, 256, 1, 1]
25 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
26 20 1 66048 ultralytics.nn.modules.conv.Conv [256, 256, 1, 1, None, 1, 1, False]
27 [-2, -1] 1 0 ultralytics.nn.modules.conv.Concat [1]
28 -1 3 657920 ultralytics.nn.modules.block.RepC3 [512, 256, 3, 0.5]
29 -1 1 66048 ultralytics.nn.modules.conv.Conv [256, 256, 1, 1]
30 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
31 19 1 33280 ultralytics.nn.modules.conv.Conv [128, 256, 1, 1, None, 1, 1, False]
32 [-2, -1] 1 0 ultralytics.nn.modules.conv.Concat [1]
33 -1 3 657920 ultralytics.nn.modules.block.RepC3 [512, 256, 3, 0.5]
34 -1 1 590336 ultralytics.nn.modules.conv.Conv [256, 256, 3, 2]
35 [-1, 29] 1 0 ultralytics.nn.modules.conv.Concat [1]
36 -1 3 657920 ultralytics.nn.modules.block.RepC3 [512, 256, 3, 0.5]
37 -1 1 590336 ultralytics.nn.modules.conv.Conv [256, 256, 3, 2]
38 [-1, 24] 1 0 ultralytics.nn.modules.conv.Concat [1]
39 -1 3 657920 ultralytics.nn.modules.block.RepC3 [512, 256, 3, 0.5]
40 [33, 36, 39] 1 3927956 ultralytics.nn.modules.head.RTDETRDecoder [9, [256, 256, 256], 256, 300, 4, 8, 3]
rtdetr-resnet18-mid-MSGA summary: 550 layers, 35,456,745 parameters, 35,456,745 gradients, 99.9 GFLOPs