[YOLOv12 Multimodal Fusion Improvement] | BMVC 2024 MASAG Module (Multi-Scale Adaptive Spatial Attention Gate): Dynamic Receptive Fields and Spatial Attention to Improve Multi-Scale Fusion Accuracy

1. Introduction

This article documents how the MASAG module from $MSA^{2}Net$ is used to improve the multimodal fusion stage of YOLOv12.

The MASAG (Multi-Scale Adaptive Spatial Attention Gate) module dynamically modulates spatial attention weights and multi-scale receptive fields to intelligently aggregate the local details and global semantics carried by feature maps from different levels. Applied to YOLOv12, it targets the complementary nature of shallow boundary features and deep semantic information in object detection, strengthening the feature representation of multi-scale targets and improving detection accuracy and boundary localization in complex scenes.



2. MASAG Overview

MSA^2Net: Multi-scale Adaptive Attention-guided Network for Medical Image Segmentation

2.1 Design Motivation

  1. Traditional convolutional neural network (CNN) based methods struggle to capture long-range dependencies and are limited by static receptive fields.
  2. Transformer-based methods fall short in capturing local representations and context, and are also computationally inefficient.
  3. Although many improved methods exist, most focus only on the encoder or decoder and leave the original skip-connection design essentially unchanged.

A new module is therefore needed to optimize the skip connections, fuse local and global features more effectively, and improve medical image segmentation accuracy.

2.2 Structure

  • Multi-scale feature fusion: the module combines local and global context extraction. Depthwise and dilated convolutions widen the spatial extent of the encoder's high-resolution spatial details $X$, while channel pooling gathers broad contextual information from the decoder's semantic features $G$; the two are combined into a joint feature map: $U = Conv_{1\times 1}(DW\text{-}D(DW(X))) + Conv_{1\times 1}([P_{Avg}(G); P_{Max}(G)])$
  • Spatial selection: the fused feature map $U$ is projected onto two channels, and a channel-wise softmax produces spatial selection weights that yield the spatially selected $X'$ and $G'$; two residual connections improve gradient flow and feature reuse: $SW_{i} = S(Conv_{1\times 1}(U)),\ \forall i \in [1,2]$, with $X' = SW_{1} \otimes X + X,\ G' = SW_{2} \otimes G + G$
  • Spatial interaction and cross-modulation: $X'$ and $G'$ are dynamically enhanced so that $X'$ absorbs the global context of $G'$ and $G'$ absorbs the detailed context of $X'$, and the two are then fused by multiplication: $X'' = X' \otimes \sigma(G'),\ G'' = G' \otimes \sigma(X'),\ U' = X'' \otimes G''$
  • Recalibration: the fused feature map from the interaction step is passed through a pointwise convolution and a sigmoid activation to generate an attention map that recalibrates the encoder's original input $X$: $X = Conv_{1\times 1}(\sigma(Conv_{1\times 1}(U')) \otimes X)$
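
To connect these equations with tensor operations before the full reference code in Section 3, the following is a minimal, illustrative PyTorch sketch; the layer names (dw, dw_d, proj_local, proj_global, select, proj_u, proj_out) are placeholders of my own choosing, not the names used in the official implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

C, H, W = 64, 32, 32
X = torch.randn(1, C, H, W)  # encoder stream: high-resolution spatial details
G = torch.randn(1, C, H, W)  # decoder stream: semantic context

dw = nn.Conv2d(C, C, 3, padding=1, groups=C)      # depthwise conv (DW)
dw_d = nn.Conv2d(C, C, 3, padding=2, dilation=2)  # dilated conv (DW-D)
proj_local = nn.Conv2d(C, C, 1)
proj_global = nn.Conv2d(2, 1, 1)
select = nn.Conv2d(C, 2, 1)
proj_u = nn.Conv2d(C, C, 1)
proj_out = nn.Conv2d(C, C, 1)

# U = Conv1x1(DW-D(DW(X))) + Conv1x1([P_Avg(G); P_Max(G)])
local_ctx = proj_local(dw_d(dw(X)))
global_ctx = proj_global(torch.cat([G.mean(1, keepdim=True),
                                    G.max(1, keepdim=True)[0]], dim=1))
U = local_ctx + global_ctx  # the 1-channel global map broadcasts over all channels

# SW_i = softmax(Conv1x1(U));  X' = SW_1 * X + X,  G' = SW_2 * G + G
SW = F.softmax(select(U), dim=1)  # [1, 2, H, W]
X1 = SW[:, 0:1] * X + X
G1 = SW[:, 1:2] * G + G

# X'' = X' * sigmoid(G'),  G'' = G' * sigmoid(X'),  U' = X'' * G''
U2 = (X1 * torch.sigmoid(G1)) * (G1 * torch.sigmoid(X1))

# X = Conv1x1(sigmoid(Conv1x1(U')) * X)
out = proj_out(torch.sigmoid(proj_u(U2)) * X)
print(out.shape)  # torch.Size([1, 64, 32, 32])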


2.3 Advantages

  • Dynamic receptive field adjustment: the module adapts its receptive field to the input, highlighting spatially relevant features and suppressing irrelevant background details, which improves segmentation accuracy for organs and lesions of varying size, shape, and structure.
  • Effective feature fusion: through multi-scale fusion and spatial interaction, local and global features are fused effectively, so the model weighs both fine detail and overall structure and can precisely segment target structures in complex medical images.
  • Improved model performance: experiments on several medical datasets (such as Synapse and ISIC 2018) show that the $MSA^{2}Net$ model containing the MASAG module matches or surpasses current state-of-the-art methods; on the Synapse dataset, its Dice similarity coefficient (DSC) is 0.48% higher than the 2D version of D-LKA and 2.21% higher than DAE-Former, demonstrating the module's contribution to model performance.

Paper: https://arxiv.org/abs/2407.21640
Code: https://github.com/xmindflow/MSA-2Net

3. MASAG Implementation Code

The implementation code of the MASAG module is as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F

def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p

class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Apply convolution and activation without batch normalization (used after conv-bn fusion)."""
        return self.act(self.conv(x))

def num_trainable_params(model):
    nums = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
    return nums

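# Global context branch (P_Avg and P_Max in the paper): channel-wise average and max
# pooling of the semantic stream, concatenated and projected to a single spatial map.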
class GlobalExtraction(nn.Module):
  def __init__(self,dim = None):
    super().__init__()
    self.avgpool = self.globalavgchannelpool
    self.maxpool = self.globalmaxchannelpool
    self.proj = nn.Sequential(
        nn.Conv2d(2, 1, 1,1),
        nn.BatchNorm2d(1)
    )
  def globalavgchannelpool(self, x):
    x = x.mean(1, keepdim = True)
    return x

  def globalmaxchannelpool(self, x):
    x = x.max(dim = 1, keepdim=True)[0]
    return x

  def forward(self, x):
    x_ = x.clone()
    x = self.avgpool(x)
    x2 = self.maxpool(x_)

    cat = torch.cat((x,x2), dim = 1)

    proj = self.proj(cat)
    return proj

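# Local context branch: a depthwise 3x3 convolution followed by a dilated 3x3 convolution
# widens the receptive field over the high-resolution stream, then a 1x1 projection.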
class ContextExtraction(nn.Module):
  def __init__(self, dim, reduction = None):
    super().__init__()
    self.reduction = 1 if reduction is None else 2

    self.dconv = self.DepthWiseConv2dx2(dim)
    self.proj = self.Proj(dim)

  def DepthWiseConv2dx2(self, dim):
    dconv = nn.Sequential(
        nn.Conv2d(in_channels = dim,
              out_channels = dim,
              kernel_size = 3,
              padding = 1,
              groups = dim),
        nn.BatchNorm2d(num_features = dim),
        nn.ReLU(inplace = True),
        nn.Conv2d(in_channels = dim,
              out_channels = dim,
              kernel_size = 3,
              padding = 2,
              dilation = 2),
        nn.BatchNorm2d(num_features = dim),
        nn.ReLU(inplace = True)
    )
    return dconv

  def Proj(self, dim):
    proj = nn.Sequential(
        nn.Conv2d(in_channels = dim,
              out_channels = dim //self.reduction,
              kernel_size = 1
              ),
        nn.BatchNorm2d(num_features = dim//self.reduction)
    )
    return proj
  def forward(self,x):
    x = self.dconv(x)
    x = self.proj(x)
    return x

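# Multi-scale fusion (the U term): local context of X plus global context of G,
# combined by broadcast addition and batch-normalized.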
class MultiscaleFusion(nn.Module):
  def __init__(self, dim):
    super().__init__()
    self.local= ContextExtraction(dim)
    self.global_ = GlobalExtraction()
    self.bn = nn.BatchNorm2d(num_features=dim)

  def forward(self, x, g,):
    x = self.local(x)
    g = self.global_(g)

    fuse = self.bn(x + g)
    return fuse

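# MASAG gate: fuses the two streams, derives softmax spatial-selection weights,
# applies bidirectional sigmoid cross-modulation, and recalibrates the X stream.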
class MultiScaleGatedAttn(nn.Module):
    # Version 1
  def __init__(self, dims):
    super().__init__()
    dim = min(dims)
    if dims[0] != dims[1]:
        self.conv1 = Conv(dims[0], dim)
        self.conv2 = Conv(dims[1], dim)
    self.multi = MultiscaleFusion(dim)
    self.selection = nn.Conv2d(dim, 2,1)
    self.proj = nn.Conv2d(dim, dim,1)
    self.bn = nn.BatchNorm2d(dim)
    self.bn_2 = nn.BatchNorm2d(dim)
    self.conv_block = nn.Sequential(
        nn.Conv2d(in_channels=dim, out_channels=dim,
                  kernel_size=1, stride=1))

  def forward(self, inputs):
    x, g = inputs
    if x.size(1) != g.size(1):
        x = self.conv1(x)
        g = self.conv2(g)
    x_ = x.clone()
    g_ = g.clone()

    #stacked = torch.stack((x_, g_), dim = 1) # B, 2, C, H, W

    multi = self.multi(x, g) # B, C, H, W

    ### Option 2 ###
    multi = self.selection(multi) # B, num_path, H, W

    attention_weights = F.softmax(multi, dim=1)  # Shape: [B, 2, H, W]
    #attention_weights = torch.sigmoid(multi)
    A, B = attention_weights.split(1, dim=1)  # Each will have shape [B, 1, H, W]

    x_att = A.expand_as(x_) * x_  # Using expand_as to match the channel dimensions
    g_att = B.expand_as(g_) * g_

    x_att = x_att + x_
    g_att = g_att + g_
    ## Bidirectional Interaction

    x_sig = torch.sigmoid(x_att)
    g_att_2 = x_sig * g_att

    g_sig = torch.sigmoid(g_att)
    x_att_2 = g_sig * x_att

    interaction = x_att_2 * g_att_2

    projected = torch.sigmoid(self.bn(self.proj(interaction)))

    weighted = projected * x_

    y = self.conv_block(weighted)

    #y = self.bn_2(weighted + y)
    y = self.bn_2(y)
    return y
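
A minimal smoke test (my own example, not part of the original code) can confirm the module runs: two same-size feature maps go in, and the output keeps the smaller of the two channel counts, matching dim = min(dims) in the constructor.

if __name__ == "__main__":
    x = torch.randn(1, 128, 40, 40)  # feature map from one modality branch
    g = torch.randn(1, 128, 40, 40)  # corresponding feature map from the other modality branch
    attn = MultiScaleGatedAttn(dims=[128, 128])
    y = attn([x, g])
    print(y.shape)  # torch.Size([1, 128, 40, 40])
    print(f"{num_trainable_params(attn):.3f}M trainable parameters")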

4. Integration Steps

4.1 Modification 1

① Create an AddModules folder under the ultralytics/nn/ directory to hold the module code.

② Create MultiScaleGatedAttn.py inside the AddModules folder and paste the code from Section 3 into it.


4.2 Modification 2

Create __init__.py under the AddModules folder (skip this if it already exists) and import the module in it: from .MultiScaleGatedAttn import *

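For reference, a minimal AddModules/__init__.py only needs this line (if the file already exists, keep its other imports alongside):

# ultralytics/nn/AddModules/__init__.py
from .MultiScaleGatedAttn import *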

4.3 Modification 3

In the ultralytics/nn/tasks.py file, the module class name needs to be added in two places.

First, import the module:

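A minimal sketch of that import, assuming the __init__.py from Section 4.2 re-exports the class:

# near the other module imports at the top of ultralytics/nn/tasks.py
from ultralytics.nn.AddModules import MultiScaleGatedAttn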

Then add the following branch to the if/elif module-dispatch chain inside the parse_model function:

        elif m in {MultiScaleGatedAttn}:
            c1 = [ch[x] for x in f]  # input channels of each source layer listed in `f`
            c2 = min(c1)             # output channels = min of the inputs, matching dim = min(dims)
            args = [c1]              # the channel list is passed as the module's `dims` argument



5. YAML Model File

5.1 Mid-level Fusion ⭐

📌 This modification applies the MASAG module to the mid-level fusion stage of YOLOv12.

# YOLOv12 🚀, AGPL-3.0 license
# YOLOv12 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# CFG file for YOLOv12-turbo

# Parameters
ch: 6 # input channels (two stacked 3-channel modalities)
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov12n.yaml' will call yolov12.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 497 layers, 2,553,904 parameters, 2,553,888 gradients, 6.2 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 497 layers, 9,127,424 parameters, 9,127,408 gradients, 19.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 533 layers, 19,670,784 parameters, 19,670,768 gradients, 60.4 GFLOPs
  l: [1.00, 1.00, 512] # summary: 895 layers, 26,506,496 parameters, 26,506,480 gradients, 83.3 GFLOPs
  x: [1.00, 1.50, 512] # summary: 895 layers, 59,414,176 parameters, 59,414,160 gradients, 185.9 GFLOPs

# YOLO12 backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, IN, []]  # 0
  - [-1, 1, Multiin, [1]]  # 1
  - [-2, 1, Multiin, [2]]  # 2

  - [1, 1, Conv,  [64, 3, 2]] # 3-P1/2
  - [-1, 1, Conv,  [128, 3, 2, 1, 2]] # 4-P2/4
  - [-1, 2, C3k2,  [256, False, 0.25]]
  - [-1, 1, Conv,  [256, 3, 2, 1, 4]] # 6-P3/8
  - [-1, 2, C3k2,  [512, False, 0.25]]
  - [-1, 1, Conv,  [512, 3, 2]] # 8-P4/16
  - [-1, 4, A2C2f, [512, True, 4]]
  - [-1, 1, Conv,  [1024, 3, 2]] # 10-P5/32
  - [-1, 4, A2C2f, [1024, True, 1]] # 11

  - [2, 1, Conv,  [64, 3, 2]] # 12-P1/2
  - [-1, 1, Conv,  [128, 3, 2, 1, 2]] # 13-P2/4
  - [-1, 2, C3k2,  [256, False, 0.25]]
  - [-1, 1, Conv,  [256, 3, 2, 1, 4]] # 15-P3/8
  - [-1, 2, C3k2,  [512, False, 0.25]]
  - [-1, 1, Conv,  [512, 3, 2]] # 17-P4/16
  - [-1, 4, A2C2f, [512, True, 4]]
  - [-1, 1, Conv,  [1024, 3, 2]] # 19-P5/32
  - [-1, 4, A2C2f, [1024, True, 1]] # 20

  - [[7, 16], 1, MultiScaleGatedAttn, []]  # 21 cat backbone P3
  - [[9, 18], 1, MultiScaleGatedAttn, []]  # 22 cat backbone P4
  - [[11, 20], 1, MultiScaleGatedAttn, []]  # 23 cat backbone P5

# YOLO12 head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 22], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, A2C2f, [512, False, -1]] # 26

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 21], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, A2C2f, [256, False, -1]] # 29

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 26], 1, Concat, [1]] # cat head P4
  - [-1, 2, A2C2f, [512, False, -1]] # 32

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 23], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 35 (P5/32-large)

  - [[29, 32, 35], 1, Detect, [nc]] # Detect(P3, P4, P5)
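
A hedged example of loading this configuration to reproduce the printout in Section 6; the file name yolov12-mid-MSGA.yaml is an assumption (use whatever path you saved the config to), and actual training additionally needs a dataset yaml prepared for the paired dual-modality input:

from ultralytics import YOLO

# Build the model from the modified config; the file name here is only an example.
model = YOLO("yolov12-mid-MSGA.yaml")
model.info()  # prints a layer table like the one in Section 6

# Training would also need a multimodal dataset config (the name below is hypothetical):
# model.train(data="multimodal_dataset.yaml", epochs=100, imgsz=640)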


6. Successful Run Results

Printing the network model shows that the new fusion layers have been added to the model, and it can now be trained.

**YOLOv12-mid-MSGA**:

YOLOv12-mid-MSGA summary: 782 layers, 5,090,600 parameters, 5,090,584 gradients, 12.8 GFLOPs

                   from  n    params  module                                       arguments
  0                  -1  1         0  ultralytics.nn.AddModules.multimodal.IN      []
  1                  -1  1         0  ultralytics.nn.AddModules.multimodal.Multiin [1]
  2                  -2  1         0  ultralytics.nn.AddModules.multimodal.Multiin [2]
  3                   1  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
  4                  -1  1      2368  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2, 1, 2]
  5                  -1  1      6640  ultralytics.nn.modules.block.C3k2            [32, 64, 1, False, 0.25]
  6                  -1  1      9344  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2, 1, 4]
  7                  -1  1     26080  ultralytics.nn.modules.block.C3k2            [64, 128, 1, False, 0.25]
  8                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]
  9                  -1  2    180864  ultralytics.nn.AddModules.A2C2f.A2C2f        [128, 128, 2, True, 4]
 10                  -1  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]
 11                  -1  2    689408  ultralytics.nn.AddModules.A2C2f.A2C2f        [256, 256, 2, True, 1]
 12                   2  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
 13                  -1  1      2368  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2, 1, 2]
 14                  -1  1      6640  ultralytics.nn.modules.block.C3k2            [32, 64, 1, False, 0.25]
 15                  -1  1      9344  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2, 1, 4]
 16                  -1  1     26080  ultralytics.nn.modules.block.C3k2            [64, 128, 1, False, 0.25]
 17                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]
 18                  -1  2    180864  ultralytics.nn.AddModules.A2C2f.A2C2f        [128, 128, 2, True, 4]
 19                  -1  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]
 20                  -1  2    689408  ultralytics.nn.AddModules.A2C2f.A2C2f        [256, 256, 2, True, 1]
 21             [7, 16]  1    200199  ultralytics.nn.AddModules.MultiScaleGatedAttn.MultiScaleGatedAttn[[128, 128]]
 22             [9, 18]  1    200199  ultralytics.nn.AddModules.MultiScaleGatedAttn.MultiScaleGatedAttn[[128, 128]]
 23            [11, 20]  1    793607  ultralytics.nn.AddModules.MultiScaleGatedAttn.MultiScaleGatedAttn[[256, 256]]
 24                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 25            [-1, 22]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 26                  -1  1     86912  ultralytics.nn.AddModules.A2C2f.A2C2f        [384, 128, 1, False, -1]
 27                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 28            [-1, 21]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 29                  -1  1     24000  ultralytics.nn.AddModules.A2C2f.A2C2f        [256, 64, 1, False, -1]
 30                  -1  1     36992  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2]
 31            [-1, 26]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 32                  -1  1     74624  ultralytics.nn.AddModules.A2C2f.A2C2f        [192, 128, 1, False, -1]
 33                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]
 34            [-1, 23]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 35                  -1  1    378880  ultralytics.nn.modules.block.C3k2            [384, 256, 1, True]
 36        [29, 32, 35]  1    430867  ultralytics.nn.modules.head.Detect           [1, [64, 128, 256]]
YOLOv12-mid-MSGA summary: 782 layers, 5,090,600 parameters, 5,090,584 gradients, 12.8 GFLOPs