【YOLOv13 Multimodal Fusion Improvement】| BMVC 2024 MASAG Module (Multi-Scale Adaptive Spatial Attention Gate): Dynamic Receptive Fields and Spatial Attention for More Accurate Multi-Scale Fusion
1. Introduction
This post documents using the MASAG module from MSA²Net to improve the multimodal fusion stage of YOLOv13. MASAG (Multi-Scale Adaptive Spatial Attention Gate) dynamically modulates spatial attention weights and multi-scale receptive fields to intelligently aggregate local detail and global semantics across feature-map levels. Applied to YOLOv13, it addresses the complementary roles of shallow boundary features and deep semantic information in object detection, strengthening feature representation for multi-scale targets and improving detection accuracy and boundary localization in complex scenes.
2. MASAG Overview
MSA^2Net: Multi-scale Adaptive Attention-guided Network for Medical Image Segmentation
2.1 Motivation
- Traditional CNN-based methods struggle to capture long-range dependencies and are constrained by static receptive fields;
- Transformer-based methods fall short at capturing local representations and context, and are computationally inefficient;
- Although many improvements exist, most focus on the encoder or decoder and leave the original skip-connection design essentially unchanged.
A new module is therefore needed to optimize the skip connections, better fuse local and global features, and improve medical image segmentation accuracy.
2.2 Structure
- Multi-scale feature fusion: the module combines local and global context extraction. Depthwise and dilated convolutions widen the spatial extent of the encoder's high-resolution detail features $X$, while channel pooling draws broad contextual information from the decoder's semantic features $G$; the two are combined into one comprehensive feature map:
  $U = \mathrm{Conv}_{1\times1}(DW\text{-}D(DW(X))) + \mathrm{Conv}_{1\times1}([P_{Avg}(G); P_{Max}(G)])$
- Spatial selection: the fused map $U$ is projected onto two channels and a channel-wise softmax $\mathcal{S}$ yields spatial selection weights, producing the spatially selected $X'$ and $G'$; two residual connections improve gradient flow and feature reuse:
  $SW_{i} = \mathcal{S}(\mathrm{Conv}_{1\times1}(U)) \quad \forall i \in [1, 2]$
  $X' = SW_{1} \otimes X + X, \quad G' = SW_{2} \otimes G + G$
- Spatial interaction and cross-modulation: $X'$ and $G'$ dynamically enhance each other, so that $X'$ absorbs the global context of $G'$ and $G'$ absorbs the detailed context of $X'$; the two are then fused multiplicatively:
  $X'' = X' \otimes \sigma(G'), \quad G'' = G' \otimes \sigma(X')$
  $U' = X'' \otimes G''$
- Recalibration: the fused map from the interaction step passes through a pointwise convolution and a sigmoid to produce an attention map that recalibrates the encoder's original input $X$ (a minimal code sketch of this four-stage flow follows the list):
  $X = \mathrm{Conv}_{1\times1}(\sigma(\mathrm{Conv}_{1\times1}(U')) \otimes X)$
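To make the four stages concrete, here is a minimal PyTorch sketch of the equation flow on toy tensors. Shapes are illustrative; the single-channel global projection mirrors the reference implementation in Section 3, but the ad-hoc layers here are for illustration only:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, C, H, W = 2, 16, 32, 32
X = torch.randn(B, C, H, W)  # encoder detail features
G = torch.randn(B, C, H, W)  # decoder semantic features

# 1) Multi-scale fusion U: depthwise + dilated convs on X, channel pooling on G
dw = nn.Conv2d(C, C, 3, padding=1, groups=C)     # DW
dwd = nn.Conv2d(C, C, 3, padding=2, dilation=2)  # DW-D
proj_x = nn.Conv2d(C, C, 1)
proj_g = nn.Conv2d(2, 1, 1)                      # [P_Avg(G); P_Max(G)] -> 1 channel
G_pool = torch.cat([G.mean(1, keepdim=True), G.max(1, keepdim=True)[0]], dim=1)
U = proj_x(dwd(dw(X))) + proj_g(G_pool)          # broadcast add over channels

# 2) Spatial selection: softmax over two projected channels -> per-pixel weights
SW = F.softmax(nn.Conv2d(C, 2, 1)(U), dim=1)
Xp = SW[:, 0:1] * X + X                          # X' = SW1 ⊗ X + X
Gp = SW[:, 1:2] * G + G                          # G' = SW2 ⊗ G + G

# 3) Cross-modulation: X'' = X' ⊗ σ(G'), G'' = G' ⊗ σ(X'), U' = X'' ⊗ G''
Up = (Xp * torch.sigmoid(Gp)) * (Gp * torch.sigmoid(Xp))

# 4) Recalibration of the original X
X_out = nn.Conv2d(C, C, 1)(torch.sigmoid(nn.Conv2d(C, C, 1)(Up)) * X)
print(X_out.shape)  # torch.Size([2, 16, 32, 32])
```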
2.3 Advantages
- Dynamic receptive-field adjustment: the receptive field adapts to the input, highlighting spatially relevant features while suppressing irrelevant background detail, which improves segmentation of organs and lesions of varying size, shape, and structure.
- Effective feature fusion: multi-scale fusion and spatial interaction merge local and global features, so the model weighs fine detail and overall structure together and can precisely segment target structures in complex medical images.
- Improved performance: experiments on several medical datasets (e.g., Synapse and ISIC 2018) show that MSA²Net, which incorporates MASAG, matches or outperforms state-of-the-art methods; on Synapse its Dice similarity coefficient (DSC) is 0.48% higher than the 2D version of D-LKA and 2.21% higher than DAE-Former, confirming the module's effectiveness.
Paper: https://arxiv.org/abs/2407.21640
Code: https://github.com/xmindflow/MSA-2Net
3. MASAG Implementation Code
The implementation of the MASAG module is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F


def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Apply convolution and activation with batch norm folded into the conv weights."""
        return self.act(self.conv(x))


def num_trainable_params(model):
    """Return the number of trainable parameters, in millions."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6


class GlobalExtraction(nn.Module):
    """Global branch: channel-wise average and max pooling of the semantic input, projected to one channel."""

    def __init__(self, dim=None):
        super().__init__()
        self.avgpool = self.globalavgchannelpool
        self.maxpool = self.globalmaxchannelpool
        self.proj = nn.Sequential(
            nn.Conv2d(2, 1, 1, 1),
            nn.BatchNorm2d(1)
        )

    def globalavgchannelpool(self, x):
        return x.mean(1, keepdim=True)

    def globalmaxchannelpool(self, x):
        return x.max(dim=1, keepdim=True)[0]

    def forward(self, x):
        x_ = x.clone()
        x = self.avgpool(x)
        x2 = self.maxpool(x_)
        cat = torch.cat((x, x2), dim=1)  # [P_Avg(G); P_Max(G)]
        return self.proj(cat)


class ContextExtraction(nn.Module):
    """Local branch: depthwise conv followed by a dilated conv to widen the receptive field."""

    def __init__(self, dim, reduction=None):
        super().__init__()
        self.reduction = 1 if reduction is None else 2
        self.dconv = self.DepthWiseConv2dx2(dim)
        self.proj = self.Proj(dim)

    def DepthWiseConv2dx2(self, dim):
        return nn.Sequential(
            nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=1, groups=dim),
            nn.BatchNorm2d(num_features=dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=2, dilation=2),
            nn.BatchNorm2d(num_features=dim),
            nn.ReLU(inplace=True)
        )

    def Proj(self, dim):
        return nn.Sequential(
            nn.Conv2d(in_channels=dim, out_channels=dim // self.reduction, kernel_size=1),
            nn.BatchNorm2d(num_features=dim // self.reduction)
        )

    def forward(self, x):
        x = self.dconv(x)
        x = self.proj(x)
        return x


class MultiscaleFusion(nn.Module):
    """Fuse the local branch on x with the global branch on g (the comprehensive feature map U)."""

    def __init__(self, dim):
        super().__init__()
        self.local = ContextExtraction(dim)
        self.global_ = GlobalExtraction()
        self.bn = nn.BatchNorm2d(num_features=dim)

    def forward(self, x, g):
        x = self.local(x)
        g = self.global_(g)  # B, 1, H, W -- broadcast over channels in the sum
        return self.bn(x + g)


class MultiScaleGatedAttn(nn.Module):
    """MASAG: multi-scale fusion, spatial selection, cross-modulation and recalibration."""

    def __init__(self, dims):
        super().__init__()
        dim = min(dims)
        if dims[0] != dims[1]:
            # project both inputs down to the smaller channel count
            self.conv1 = Conv(dims[0], dim)
            self.conv2 = Conv(dims[1], dim)
        self.multi = MultiscaleFusion(dim)
        self.selection = nn.Conv2d(dim, 2, 1)
        self.proj = nn.Conv2d(dim, dim, 1)
        self.bn = nn.BatchNorm2d(dim)
        self.bn_2 = nn.BatchNorm2d(dim)
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, stride=1)
        )

    def forward(self, inputs):
        x, g = inputs
        if x.size(1) != g.size(1):  # mirrors the guard in __init__
            x = self.conv1(x)
            g = self.conv2(g)
        x_ = x.clone()
        g_ = g.clone()

        multi = self.multi(x, g)  # B, C, H, W

        # spatial selection: softmax over the two projected paths
        multi = self.selection(multi)  # B, 2, H, W
        attention_weights = F.softmax(multi, dim=1)
        # attention_weights = torch.sigmoid(multi)  # alternative gating kept from the source
        A, B = attention_weights.split(1, dim=1)  # each B, 1, H, W

        x_att = A.expand_as(x_) * x_
        g_att = B.expand_as(g_) * g_
        x_att = x_att + x_  # residual connections
        g_att = g_att + g_

        # bidirectional interaction / cross-modulation
        x_sig = torch.sigmoid(x_att)
        g_att_2 = x_sig * g_att
        g_sig = torch.sigmoid(g_att)
        x_att_2 = g_sig * x_att
        interaction = x_att_2 * g_att_2

        # recalibrate the original x with the fused attention map
        projected = torch.sigmoid(self.bn(self.proj(interaction)))
        weighted = projected * x_
        y = self.conv_block(weighted)
        # y = self.bn_2(weighted + y)  # alternative residual kept from the source
        y = self.bn_2(y)
        return y
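As a quick sanity check, here is a minimal sketch (input shapes chosen arbitrarily) that fuses two feature maps with different channel counts down to the smaller width:

```python
if __name__ == "__main__":
    x = torch.randn(1, 128, 40, 40)  # shallow, detail-rich features
    g = torch.randn(1, 256, 40, 40)  # deep, semantic features
    attn = MultiScaleGatedAttn(dims=[128, 256])
    y = attn((x, g))
    print(y.shape)  # torch.Size([1, 128, 40, 40])
    print(f"{num_trainable_params(attn):.3f}M trainable params")
```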
4. Integration Steps
4.1 Step 1
① Create an AddModules folder under the ultralytics/nn/ directory to hold the module code.
② Create MultiScaleGatedAttn.py inside AddModules and paste the code from Section 3 into it.
4.2 Step 2
Create __init__.py inside the AddModules folder (skip this if it already exists) and import the module in it:
from .MultiScaleGatedAttn import *
4.3 Step 3
In ultralytics/nn/tasks.py, the module's class name must be added in two places.
First, import the module at the top of the file, for example:
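A minimal sketch, assuming the package layout from Step 1 (adjust the import path if yours differs):

```python
from ultralytics.nn.AddModules import MultiScaleGatedAttn
```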
Then, add the following branch inside the parse_model function:
elif m in {MultiScaleGatedAttn}:
    c1 = [ch[x] for x in f]
    c2 = min(c1)
    args = [c1]
Here c1 collects the channel counts of the module's two source layers and is passed as the dims argument of MultiScaleGatedAttn, while c2 = min(c1) records the module's output channel count for downstream layers.
5. YAML Model File
5.1 Mid-level Fusion ⭐
📌 This variant applies the MASAG module to YOLOv13's mid-level (feature-level) multimodal fusion stage.
ch: 6
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov13n.yaml' will call yolov13.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # Nano
  s: [0.50, 0.50, 1024] # Small
  l: [1.00, 1.00, 512] # Large
  x: [1.00, 1.50, 512] # Extra Large

backbone:
  # [from, repeats, module, args]
  - [-1, 1, IN, []] # 0
  - [-1, 1, Multiin, [1]] # 1
  - [-2, 1, Multiin, [2]] # 2
  - [1, 1, Conv, [64, 3, 2]] # 3-P1/2
  - [-1, 1, Conv, [128, 3, 2, 1, 2]] # 4-P2/4
  - [-1, 2, DSC3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2, 1, 4]] # 6-P3/8
  - [-1, 2, DSC3k2, [512, True]]
  - [-1, 1, DSConv, [512, 3, 2]] # 8-P4/16
  - [-1, 4, A2C2f, [512, True, 4]]
  - [-1, 1, DSConv, [1024, 3, 2]] # 10-P5/32
  - [-1, 4, A2C2f, [1024, True, 1]] # 11
  - [2, 1, Conv, [64, 3, 2]] # 12-P1/2
  - [-1, 1, Conv, [128, 3, 2, 1, 2]] # 13-P2/4
  - [-1, 2, DSC3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2, 1, 4]] # 15-P3/8
  - [-1, 2, DSC3k2, [512, True]]
  - [-1, 1, DSConv, [512, 3, 2]] # 17-P4/16
  - [-1, 4, A2C2f, [512, True, 4]]
  - [-1, 1, DSConv, [1024, 3, 2]] # 19-P5/32
  - [-1, 4, A2C2f, [1024, True, 1]] # 20
  - [[7, 16], 1, MultiScaleGatedAttn, []] # 21 fuse backbone P3
  - [[9, 18], 1, MultiScaleGatedAttn, []] # 22 fuse backbone P4
  - [[11, 20], 1, MultiScaleGatedAttn, []] # 23 fuse backbone P5

head:
  - [[21, 22, 23], 2, HyperACE, [512, 8, True, True, 0.5, 1, "both"]]
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [24, 1, DownsampleConv, []]
  - [[22, 24], 1, FullPAD_Tunnel, []] # 27
  - [[21, 25], 1, FullPAD_Tunnel, []] # 28
  - [[23, 26], 1, FullPAD_Tunnel, []] # 29
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 27], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, DSC3k2, [512, True]] # 32
  - [[-1, 24], 1, FullPAD_Tunnel, []] # 33
  - [32, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 28], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, DSC3k2, [256, True]] # 36
  - [25, 1, Conv, [256, 1, 1]]
  - [[36, 37], 1, FullPAD_Tunnel, []] # 38
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 33], 1, Concat, [1]] # cat head P4
  - [-1, 2, DSC3k2, [512, True]] # 41
  - [[-1, 24], 1, FullPAD_Tunnel, []]
  - [41, 1, Conv, [512, 3, 2]]
  - [[-1, 29], 1, Concat, [1]] # cat head P5
  - [-1, 2, DSC3k2, [1024, True]] # 45 (P5/32-large)
  - [[-1, 26], 1, FullPAD_Tunnel, []]
  - [[38, 42, 46], 1, Detect, [nc]] # Detect(P3, P4, P5)
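To build and inspect this model, a minimal sketch (the filename is an assumption — save the YAML above under whatever name you pass to `YOLO`; `ch: 6` means the dataloader must supply 6-channel, two-modality input):

```python
from ultralytics import YOLO

# Hypothetical filename: save the mid-fusion YAML above as this file first.
model = YOLO("yolov13n-mid-MASAG.yaml")
model.info()  # prints the per-layer table shown in Section 6
```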
6. Successful Run Output
Printing the network shows that the new fusion layers have been added to the model, which can now be trained.
**YOLOv13-mid-MSGA**:
YOLOv13-mid-MSGA summary: 989 layers, 4,742,503 parameters, 4,742,487 gradients, 13.7 GFLOPs
from n params module arguments
0 -1 1 0 ultralytics.nn.AddModules.multimodal.IN []
1 -1 1 0 ultralytics.nn.AddModules.multimodal.Multiin [1]
2 -2 1 0 ultralytics.nn.AddModules.multimodal.Multiin [2]
3 1 1 464 ultralytics.nn.modules.conv.Conv [3, 16, 3, 2]
4 -1 1 2368 ultralytics.nn.modules.conv.Conv [16, 32, 3, 2, 1, 2]
5 -1 1 5792 ultralytics.nn.modules.block.DSC3k2 [32, 64, 1, False, 0.25]
6 -1 1 9344 ultralytics.nn.modules.conv.Conv [64, 64, 3, 2, 1, 4]
7 -1 1 74368 ultralytics.nn.modules.block.DSC3k2 [64, 128, 1, True]
8 -1 1 17792 ultralytics.nn.modules.conv.DSConv [128, 128, 3, 2]
9 -1 2 180864 ultralytics.nn.AddModules.A2C2f.A2C2f [128, 128, 2, True, 4]
10 -1 1 34432 ultralytics.nn.modules.conv.DSConv [128, 256, 3, 2]
11 -1 2 689408 ultralytics.nn.AddModules.A2C2f.A2C2f [256, 256, 2, True, 1]
12 2 1 464 ultralytics.nn.modules.conv.Conv [3, 16, 3, 2]
13 -1 1 2368 ultralytics.nn.modules.conv.Conv [16, 32, 3, 2, 1, 2]
14 -1 1 5792 ultralytics.nn.modules.block.DSC3k2 [32, 64, 1, False, 0.25]
15 -1 1 9344 ultralytics.nn.modules.conv.Conv [64, 64, 3, 2, 1, 4]
16 -1 1 74368 ultralytics.nn.modules.block.DSC3k2 [64, 128, 1, True]
17 -1 1 17792 ultralytics.nn.modules.conv.DSConv [128, 128, 3, 2]
18 -1 2 180864 ultralytics.nn.AddModules.A2C2f.A2C2f [128, 128, 2, True, 4]
19 -1 1 34432 ultralytics.nn.modules.conv.DSConv [128, 256, 3, 2]
20 -1 2 689408 ultralytics.nn.AddModules.A2C2f.A2C2f [256, 256, 2, True, 1]
21 [7, 16] 1 200199 ultralytics.nn.AddModules.MultiScaleGatedAttn.MultiScaleGatedAttn[[128, 128]]
22 [9, 18] 1 200199 ultralytics.nn.AddModules.MultiScaleGatedAttn.MultiScaleGatedAttn[[128, 128]]
23 [11, 20] 1 793607 ultralytics.nn.AddModules.MultiScaleGatedAttn.MultiScaleGatedAttn[[256, 256]]
24 [21, 22, 23] 1 273536 ultralytics.nn.modules.block.HyperACE [128, 128, 1, 4, True, True, 0.5, 1, 'both']
25 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
26 24 1 33280 ultralytics.nn.modules.block.DownsampleConv [128]
27 [22, 24] 1 1 ultralytics.nn.modules.block.FullPAD_Tunnel []
28 [21, 25] 1 1 ultralytics.nn.modules.block.FullPAD_Tunnel []
29 [23, 26] 1 1 ultralytics.nn.modules.block.FullPAD_Tunnel []
30 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
31 [-1, 27] 1 0 ultralytics.nn.modules.conv.Concat [1]
32 -1 1 115328 ultralytics.nn.modules.block.DSC3k2 [384, 128, 1, True]
33 [-1, 24] 1 1 ultralytics.nn.modules.block.FullPAD_Tunnel []
34 32 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
35 [-1, 28] 1 0 ultralytics.nn.modules.conv.Concat [1]
36 -1 1 35136 ultralytics.nn.modules.block.DSC3k2 [256, 64, 1, True]
37 25 1 8320 ultralytics.nn.modules.conv.Conv [128, 64, 1, 1]
38 [36, 37] 1 1 ultralytics.nn.modules.block.FullPAD_Tunnel []
39 -1 1 36992 ultralytics.nn.modules.conv.Conv [64, 64, 3, 2]
40 [-1, 33] 1 0 ultralytics.nn.modules.conv.Concat [1]
41 -1 1 90752 ultralytics.nn.modules.block.DSC3k2 [192, 128, 1, True]
42 [-1, 24] 1 1 ultralytics.nn.modules.block.FullPAD_Tunnel []
43 41 1 147712 ultralytics.nn.modules.conv.Conv [128, 128, 3, 2]
44 [-1, 29] 1 0 ultralytics.nn.modules.conv.Concat [1]
45 -1 1 345344 ultralytics.nn.modules.block.DSC3k2 [384, 256, 1, True]
46 [-1, 26] 1 1 ultralytics.nn.modules.block.FullPAD_Tunnel []
47 [38, 42, 46] 1 432427 ultralytics.nn.modules.head.Detect [9, [64, 128, 256]]