【YOLOv11 Multimodal Fusion Improvement】| BMVC 2024 MASAG Module (Multi-Scale Adaptive Spatial Attention Gate): Dynamic Receptive Fields and Spatial Attention for More Accurate Multi-Scale Object Detection
1. Introduction
This article documents how the MASAG module from $MSA^{2}Net$ is used to improve the multimodal fusion part of YOLOv11. The MASAG (Multi-Scale Adaptive Spatial Attention Gate) module dynamically modulates spatial attention weights and multi-scale receptive fields to intelligently aggregate local detail and global semantics across feature-map levels. Applied to YOLOv11, it targets the complementary roles that shallow boundary features and deep semantic information play in object detection, strengthening the feature representation of multi-scale targets and improving detection accuracy and boundary localization in complex scenes.
2. About MASAG
MSA^2Net: Multi-scale Adaptive Attention-guided Network for Medical Image Segmentation
2.1 Motivation
- Traditional convolutional neural network (CNN) based methods struggle to capture long-range dependencies and are constrained by static receptive fields;
- Transformer-based methods fall short at capturing local representations and context, and are computationally less efficient;
- Although many improved variants exist, most focus only on the encoder or decoder and leave the original skip-connection design effectively unchanged.

A new module is therefore needed that optimizes the skip connections, fuses local and global features more effectively, and improves medical image segmentation accuracy.
2.2 Structure
- **Multi-scale feature fusion**: the module combines local and global context extraction. Depthwise and dilated convolutions widen the spatial extent of the encoder's high-resolution detail features $X$, while channel pooling gathers broad context from the decoder's semantic features $G$; the two are integrated into a single feature map: $U = Conv_{1\times1}(DW\text{-}D(DW(X))) + Conv_{1\times1}([P_{Avg}(G); P_{Max}(G)])$
- **Spatial selection**: the fused map $U$ is projected to two channels, and a softmax across those channels yields spatial selection weights that produce the spatially selected $X'$ and $G'$; two residual connections improve gradient flow and feature reuse: $SW_{i} = S(Conv_{1\times1}(U)) \;\forall i \in [1, 2]$, with $X' = SW_{1} \otimes X + X$ and $G' = SW_{2} \otimes G + G$
- **Spatial interaction and cross-modulation**: $X'$ and $G'$ are dynamically enhanced so that $X'$ absorbs the global context of $G'$ and $G'$ absorbs the detailed context of $X'$, after which the two are fused by multiplication: $X'' = X' \otimes \sigma(G')$, $G'' = G' \otimes \sigma(X')$, $U' = X'' \otimes G''$
- **Recalibration**: a pointwise convolution and sigmoid applied to the fused map produce an attention map that recalibrates the encoder's original input $X$ (a minimal code sketch of these equations follows this list): $X = Conv_{1\times1}(\sigma(Conv_{1\times1}(U')) \otimes X)$
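The following minimal PyTorch sketch traces the selection, cross-modulation, and fusion equations above. The tensor sizes are illustrative, and `U2` stands in for the 2-channel projection $Conv_{1\times1}(U)$; this is not the full module, which appears in Section 3.

```python
import torch
import torch.nn.functional as F

B, C, H, W = 1, 64, 32, 32
X = torch.randn(B, C, H, W)   # encoder detail features
G = torch.randn(B, C, H, W)   # decoder semantic features
U2 = torch.randn(B, 2, H, W)  # stand-in for Conv1x1(U) projected to 2 channels

# Spatial selection: per-pixel softmax over the two paths, plus residuals
SW1, SW2 = F.softmax(U2, dim=1).split(1, dim=1)  # each [B, 1, H, W]
Xp = SW1 * X + X   # X' = SW1 ⊗ X + X
Gp = SW2 * G + G   # G' = SW2 ⊗ G + G

# Spatial interaction and cross-modulation
Xpp = Xp * torch.sigmoid(Gp)  # X'' = X' ⊗ σ(G')
Gpp = Gp * torch.sigmoid(Xp)  # G'' = G' ⊗ σ(X')
Up = Xpp * Gpp                # U' = X'' ⊗ G''
```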
2.3 Advantages
- **Dynamic receptive-field adjustment**: the receptive field adapts to the input, highlighting spatially relevant features and suppressing irrelevant background detail, which improves segmentation accuracy for organs and lesions of varying size, shape, and structure.
- **Effective feature fusion**: multi-scale fusion and spatial interaction combine local and global features, so the model weighs fine detail and overall structure together and segments target structures precisely in complex medical images.
- **Improved performance**: experiments on several medical datasets (e.g. Synapse and ISIC 2018) show that $MSA^{2}Net$, which includes the MASAG module, matches or surpasses state-of-the-art methods on multiple metrics; on Synapse, its Dice similarity coefficient (DSC) is 0.48% higher than the 2D version of D-LKA and 2.21% higher than DAE-Former, demonstrating the module's effectiveness.
Paper: https://arxiv.org/abs/2407.21640
Source code: https://github.com/xmindflow/MSA-2Net
3. MASAG Implementation Code
The implementation of the MASAG module is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F


def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p
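# Example: autopad(3) -> 1, so a 3x3 stride-1 conv keeps the spatial size;
# autopad(3, None, 2) -> 2, since dilation 2 inflates the effective kernel to 5.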
class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Apply convolution and activation without batch normalization (used after BN fusion)."""
        return self.act(self.conv(x))
def num_trainable_params(model):
    """Return the number of trainable parameters, in millions."""
    nums = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
    return nums
class GlobalExtraction(nn.Module):
    """Extract global context via channel-wise average and max pooling (output has 1 channel)."""

    def __init__(self, dim=None):
        super().__init__()
        self.avgpool = self.globalavgchannelpool
        self.maxpool = self.globalmaxchannelpool
        self.proj = nn.Sequential(
            nn.Conv2d(2, 1, 1, 1),
            nn.BatchNorm2d(1)
        )

    def globalavgchannelpool(self, x):
        x = x.mean(1, keepdim=True)
        return x

    def globalmaxchannelpool(self, x):
        x = x.max(dim=1, keepdim=True)[0]
        return x

    def forward(self, x):
        x_ = x.clone()
        x = self.avgpool(x)
        x2 = self.maxpool(x_)
        cat = torch.cat((x, x2), dim=1)  # [B, 2, H, W]
        proj = self.proj(cat)            # [B, 1, H, W]
        return proj
class ContextExtraction(nn.Module):
    """Extract local context with a depthwise 3x3 conv followed by a dilated 3x3 conv."""

    def __init__(self, dim, reduction=None):
        super().__init__()
        self.reduction = 1 if reduction is None else 2
        self.dconv = self.DepthWiseConv2dx2(dim)
        self.proj = self.Proj(dim)

    def DepthWiseConv2dx2(self, dim):
        dconv = nn.Sequential(
            nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=1, groups=dim),
            nn.BatchNorm2d(num_features=dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=2, dilation=2),
            nn.BatchNorm2d(num_features=dim),
            nn.ReLU(inplace=True)
        )
        return dconv

    def Proj(self, dim):
        proj = nn.Sequential(
            nn.Conv2d(in_channels=dim, out_channels=dim // self.reduction, kernel_size=1),
            nn.BatchNorm2d(num_features=dim // self.reduction)
        )
        return proj

    def forward(self, x):
        x = self.dconv(x)
        x = self.proj(x)
        return x
class MultiscaleFusion(nn.Module):
    """Fuse local context from x with 1-channel global context from g (broadcast add)."""

    def __init__(self, dim):
        super().__init__()
        self.local = ContextExtraction(dim)
        self.global_ = GlobalExtraction()
        self.bn = nn.BatchNorm2d(num_features=dim)

    def forward(self, x, g):
        x = self.local(x)    # [B, dim, H, W]
        g = self.global_(g)  # [B, 1, H, W], broadcast across channels on addition
        fuse = self.bn(x + g)
        return fuse
class MultiScaleGatedAttn(nn.Module):
    """MASAG: multi-scale fusion, spatial selection, bidirectional interaction, and recalibration."""

    def __init__(self, dims):
        super().__init__()
        dim = min(dims)
        if dims[0] != dims[1]:
            # Project both inputs to the smaller channel count when they differ.
            self.conv1 = Conv(dims[0], dim)
            self.conv2 = Conv(dims[1], dim)
        self.multi = MultiscaleFusion(dim)
        self.selection = nn.Conv2d(dim, 2, 1)
        self.proj = nn.Conv2d(dim, dim, 1)
        self.bn = nn.BatchNorm2d(dim)
        self.bn_2 = nn.BatchNorm2d(dim)
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, stride=1)
        )

    def forward(self, inputs):
        x, g = inputs
        if x.size(1) != g.size(1):
            x = self.conv1(x)
            g = self.conv2(g)
        x_ = x.clone()
        g_ = g.clone()

        # Multi-scale feature fusion: U = local(x) + global(g)
        multi = self.multi(x, g)  # [B, C, H, W]

        # Spatial selection: per-pixel softmax over the two paths
        multi = self.selection(multi)            # [B, 2, H, W]
        attention_weights = F.softmax(multi, dim=1)
        A, B = attention_weights.split(1, dim=1)  # each [B, 1, H, W]

        x_att = A.expand_as(x_) * x_
        g_att = B.expand_as(g_) * g_
        x_att = x_att + x_  # residual connections
        g_att = g_att + g_

        # Bidirectional interaction and cross-modulation
        x_sig = torch.sigmoid(x_att)
        g_att_2 = x_sig * g_att
        g_sig = torch.sigmoid(g_att)
        x_att_2 = g_sig * x_att
        interaction = x_att_2 * g_att_2

        # Recalibration of the original detail features
        projected = torch.sigmoid(self.bn(self.proj(interaction)))
        weighted = projected * x_
        y = self.conv_block(weighted)
        y = self.bn_2(y)
        return y
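A quick shape check can be appended to the end of the same file (a minimal sketch; the feature-map sizes are illustrative):

```python
if __name__ == "__main__":
    # Two same-resolution feature maps, e.g. P4-level outputs of the two branches.
    x = torch.randn(1, 128, 40, 40)
    g = torch.randn(1, 128, 40, 40)
    attn = MultiScaleGatedAttn(dims=[128, 128])
    y = attn([x, g])
    print(y.shape)  # expected: torch.Size([1, 128, 40, 40])
```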
4. Integration Steps
4.1 Modification 1
① Create an AddModules folder under the ultralytics/nn/ directory to hold the module code.
② Create MultiScaleGatedAttn.py inside the AddModules folder and paste the code from Section 3 into it.
4.2 Modification 2
Create __init__.py in the AddModules folder (skip this if it already exists) and import the module inside it:
from .MultiScaleGatedAttn import *
4.3 Modification 3
In the ultralytics/nn/tasks.py file, the module class name needs to be added in two places.
First, import the module at the top of the file.
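A minimal example (assuming the package layout from steps 4.1–4.2, whose wildcard export in __init__.py makes the class importable):

```python
from ultralytics.nn.AddModules import MultiScaleGatedAttn
```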
Then, add the following code in the parse_model function:
elif m in {MultiScaleGatedAttn}:
    c1 = [ch[x] for x in f]
    c2 = min(c1)
    args = [c1]
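In this branch, `f` is the layer's `from` field (here a list of two source layers), so `c1` gathers the input channel counts of both sources. `c2 = min(c1)` mirrors `dim = min(dims)` inside the module and is what `parse_model` records as the layer's output channel count, while `args = [c1]` hands both channel counts to the module's constructor (visible as `[[128, 128]]` in the printout of Section 6).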
5. YAML Model File
5.1 Mid-stage fusion ⭐
📌 This model modification applies the MASAG module to YOLOv11's mid-stage fusion.
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Parameters
ch: 6 # number of input channels (two stacked modalities)
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs
# YOLO11n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, IN, []] # 0
- [-1, 1, Multiin, [1]] # 1
- [-2, 1, Multiin, [2]] # 2
- [1, 1, Conv, [64, 3, 2]] # 3-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 4-P2/4
- [-1, 2, C3k2, [256, False, 0.25]]
- [-1, 1, Conv, [256, 3, 2]] # 6-P3/8
- [-1, 2, C3k2, [512, False, 0.25]]
- [-1, 1, Conv, [512, 3, 2]] # 8-P4/16
- [-1, 2, C3k2, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 10-P5/32
- [-1, 2, C3k2, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 12
- [-1, 2, C2PSA, [1024]] # 13
- [2, 1, Conv, [64, 3, 2]] # 14-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 15-P2/4
- [-1, 2, C3k2, [256, False, 0.25]]
- [-1, 1, Conv, [256, 3, 2]] # 17-P3/8
- [-1, 2, C3k2, [512, False, 0.25]]
- [-1, 1, Conv, [512, 3, 2]] # 19-P4/16
- [-1, 2, C3k2, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 21-P5/32
- [-1, 2, C3k2, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 23
- [-1, 2, C2PSA, [1024]] # 24
- [[7, 18], 1, MultiScaleGatedAttn, []] # 25 MASAG fusion of backbone P3 features
- [[9, 20], 1, MultiScaleGatedAttn, []] # 26 MASAG fusion of backbone P4 features
- [[13, 24], 1, MultiScaleGatedAttn, []] # 27 MASAG fusion of backbone P5 features
# YOLO11n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 26], 1, Concat, [1]] # cat backbone P4
- [-1, 2, C3k2, [512, False]] # 30
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 25], 1, Concat, [1]] # cat backbone P3
- [-1, 2, C3k2, [256, False]] # 33 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 30], 1, Concat, [1]] # cat head P4
- [-1, 2, C3k2, [512, False]] # 36 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 27], 1, Concat, [1]] # cat head P5
- [-1, 2, C3k2, [1024, True]] # 39 (P5/32-large)
- [[33, 36, 39], 1, Detect, [nc]] # Detect(P3, P4, P5)
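To build and inspect the model, a minimal sketch (the filename `yolo11-mid-MSGA.yaml` is illustrative; save the YAML above under any name):

```python
from ultralytics import YOLO

model = YOLO("yolo11-mid-MSGA.yaml")  # build the model from the YAML above
model.info()  # prints a layer table like the one in Section 6
```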
6. Successful Run Results
Printing the network model shows that the fusion layers have been added, and the model is ready for training.
**YOLOv11-mid-MSGA**:
from n params module arguments
0 -1 1 0 ultralytics.nn.AddModules.multimodal.IN []
1 -1 1 0 ultralytics.nn.AddModules.multimodal.Multiin [1]
2 -2 1 0 ultralytics.nn.AddModules.multimodal.Multiin [2]
3 1 1 464 ultralytics.nn.modules.conv.Conv [3, 16, 3, 2]
4 -1 1 4672 ultralytics.nn.modules.conv.Conv [16, 32, 3, 2]
5 -1 1 6640 ultralytics.nn.modules.block.C3k2 [32, 64, 1, False, 0.25]
6 -1 1 36992 ultralytics.nn.modules.conv.Conv [64, 64, 3, 2]
7 -1 1 26080 ultralytics.nn.modules.block.C3k2 [64, 128, 1, False, 0.25]
8 -1 1 147712 ultralytics.nn.modules.conv.Conv [128, 128, 3, 2]
9 -1 1 87040 ultralytics.nn.modules.block.C3k2 [128, 128, 1, True]
10 -1 1 295424 ultralytics.nn.modules.conv.Conv [128, 256, 3, 2]
11 -1 1 346112 ultralytics.nn.modules.block.C3k2 [256, 256, 1, True]
12 -1 1 164608 ultralytics.nn.modules.block.SPPF [256, 256, 5]
13 -1 1 249728 ultralytics.nn.modules.block.C2PSA [256, 256, 1]
14 2 1 464 ultralytics.nn.modules.conv.Conv [3, 16, 3, 2]
15 -1 1 4672 ultralytics.nn.modules.conv.Conv [16, 32, 3, 2]
16 -1 1 6640 ultralytics.nn.modules.block.C3k2 [32, 64, 1, False, 0.25]
17 -1 1 36992 ultralytics.nn.modules.conv.Conv [64, 64, 3, 2]
18 -1 1 26080 ultralytics.nn.modules.block.C3k2 [64, 128, 1, False, 0.25]
19 -1 1 147712 ultralytics.nn.modules.conv.Conv [128, 128, 3, 2]
20 -1 1 87040 ultralytics.nn.modules.block.C3k2 [128, 128, 1, True]
21 -1 1 295424 ultralytics.nn.modules.conv.Conv [128, 256, 3, 2]
22 -1 1 346112 ultralytics.nn.modules.block.C3k2 [256, 256, 1, True]
23 -1 1 164608 ultralytics.nn.modules.block.SPPF [256, 256, 5]
24 -1 1 249728 ultralytics.nn.modules.block.C2PSA [256, 256, 1]
25 [7, 18] 1 200199 ultralytics.nn.AddModules.MultiScaleGatedAttn.MultiScaleGatedAttn[[128, 128]]
26 [9, 20] 1 200199 ultralytics.nn.AddModules.MultiScaleGatedAttn.MultiScaleGatedAttn[[128, 128]]
27 [13, 24] 1 793607 ultralytics.nn.AddModules.MultiScaleGatedAttn.MultiScaleGatedAttn[[256, 256]]
28 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
29 [-1, 26] 1 0 ultralytics.nn.modules.conv.Concat [1]
30 -1 1 111296 ultralytics.nn.modules.block.C3k2 [384, 128, 1, False]
31 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
32 [-1, 25] 1 0 ultralytics.nn.modules.conv.Concat [1]
33 -1 1 32096 ultralytics.nn.modules.block.C3k2 [256, 64, 1, False]
34 -1 1 36992 ultralytics.nn.modules.conv.Conv [64, 64, 3, 2]
35 [-1, 30] 1 0 ultralytics.nn.modules.conv.Concat [1]
36 -1 1 86720 ultralytics.nn.modules.block.C3k2 [192, 128, 1, False]
37 -1 1 147712 ultralytics.nn.modules.conv.Conv [128, 128, 3, 2]
38 [-1, 27] 1 0 ultralytics.nn.modules.conv.Concat [1]
39 -1 1 378880 ultralytics.nn.modules.block.C3k2 [384, 256, 1, True]
40 [33, 36, 39] 1 430867 ultralytics.nn.modules.head.Detect [1, [64, 128, 256]]
YOLO11-mid-MSGA summary: 543 layers, 5,149,512 parameters, 5,149,496 gradients, 13.4 GFLOPs