学习资源站

【RT-DETR多模态融合改进】_CVPR2024MFM(ModulationFusionModule,调制融合模块)-动态特征加权融合,突出关键特征抑制冗余_动态特征加权融合方法-

【RT-DETR多模态融合改进】| CVPR 2024 MFM(Modulation Fusion Module,调制融合模块):动态特征加权融合,突出关键特征抑制冗余

一、本文介绍

本文记录的是利用 DCMPNet中的 MFM 模块改进 RT-DETR 的多模态融合部分

MFM 模块通过 动态调制特征融合 过程,实现了 对多尺度、跨层级特征的智能聚合 。将其应用于 YOLOv12 的改进过程中,针对目标检测中 边界特征 语义信息 的互补性需求, 缓解网络中浅层细节与深层语义融合不足的问题



二、MFM介绍

Depth Information Assisted Collaborative Mutual Promotion Network for Single Image Dehazing

2.1 设计出发点

在图像去雾网络中,不同层级和类型的特征包含着互补的信息(如浅层的纹理细节与深层的语义结构)。

传统的特征融合方法(如简单相加或拼接)难以动态适应不同特征的重要性差异,可能导致关键信息被稀释或次要信息过度增强。 MFM模块的核心目标是通过动态调整特征融合权重,增强网络对关键特征的敏感度,提升特征表示能力 ,从而优化去雾结果的细节恢复和结构一致性。

2.2 模块结构

  1. 输入特征
    接收来自不同路径的特征图(如编码器的输出特征与解码器的中间特征),例如:

    • F ^ l e g m 1 \hat{F}_{legm}^{1} F ^ l e g m 1 :来自编码器的局部-全局特征融合结果
    • F r c 1 F_{rc}^{1} F rc 1 :经过3×3卷积处理的浅层特征。
  2. 权重生成组件

    • 全局平均池化(GAP) :对输入特征进行全局上下文感知,压缩空间维度以提取全局统计信息。
    • 多层感知机(MLP) :通过非线性变换生成初步的权重向量。
    • Softmax归一化 :将权重向量归一化为概率分布,得到系数矩阵 A r , c 1 A_{r,c}^{1} A r , c 1 ,表示各通道/空间位置特征在融合中的重要性。
  3. 特征调制与融合

    • 特征加权 :利用系数矩阵 A r , c 1 A_{r,c}^{1} A r , c 1 对输入特征进行逐元素相乘( ⊙ \odot ),突出关键特征并抑制冗余信息:
      F ~ r c 1 = A r , c 1 ⊙ F ^ l e g m 1 + A r , c 1 ⊙ F r c 1 \tilde{F}_{rc}^{1} = A_{r,c}^{1} \odot \hat{F}_{legm}^{1} + A_{r,c}^{1} \odot F_{rc}^{1} F ~ rc 1 = A r , c 1 F ^ l e g m 1 + A r , c 1 F rc 1
    • 特征拼接与卷积 :将调制后的特征拼接后,通过卷积层进一步融合跨通道信息,输出最终的融合特征。

在这里插入图片描述

2.3 模块特点

  1. 动态特征加权
    通过自适应学习的权重矩阵,MFM模块能够根据输入内容动态调整不同特征的贡献度。例如,在去雾任务中,针对雾霾残留较多的区域,模块会增强对应的深层语义特征;而对于纹理丰富的细节区域,则强化浅层的局部特征。

  2. 跨层级特征交互
    融合编码器的深层语义特征与解码器的浅层细节特征,缓解传统U型网络中浅层特征在跨层传输时的“稀释”问题,提升图像结构的稳定性和细节的清晰度。

论文: https://openaccess.thecvf.com/content/CVPR2024/papers/Zhang_Depth_Information_Assisted_Collaborative_Mutual_Promotion_Network_for_Single_Image_CVPR_2024_paper.pdf
源码: https://github.com/zhoushen1/DCMPNet

三、MFM的实现代码

MFM模块 的实现代码如下:

import torch
import torch.nn as nn

def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p

class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Perform transposed convolution of 2D data."""
        return self.act(self.conv(x))

class MFM(nn.Module):
    def __init__(self, inc, dim, reduction=8):
        super(MFM, self).__init__()

        self.height = len(inc)
        d = max(int(dim/reduction), 4)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.mlp = nn.Sequential(
            nn.Conv2d(dim, d, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(d, dim * self.height, 1, bias=False)
        )

        self.softmax = nn.Softmax(dim=1)

        self.conv1x1 = nn.ModuleList([])
        for i in inc:
            if i != dim:
                self.conv1x1.append(Conv(i, dim, 1))
            else:
                self.conv1x1.append(nn.Identity())

    def forward(self, in_feats_):
        in_feats = []
        for idx, layer in enumerate(self.conv1x1):
            in_feats.append(layer(in_feats_[idx]))

        B, C, H, W = in_feats[0].shape

        in_feats = torch.cat(in_feats, dim=1)
        in_feats = in_feats.view(B, self.height, C, H, W)

        feats_sum = torch.sum(in_feats, dim=1)
        attn = self.mlp(self.avg_pool(feats_sum))
        attn = self.softmax(attn.view(B, self.height, C, 1, 1))

        out = torch.sum(in_feats*attn, dim=1)
        return out

四、添加步骤

4.1 修改一

① 在 ultralytics/nn/ 目录下新建 AddModules 文件夹用于存放模块代码

② 在 AddModules 文件夹下新建 MFM.py ,将 第三节 中的代码粘贴到此处

在这里插入图片描述

4.2 修改二

AddModules 文件夹下新建 __init__.py (已有则不用新建),在文件内导入模块: from .MFM import *

在这里插入图片描述

4.3 修改三

ultralytics/nn/modules/tasks.py 文件中,需要在两处位置添加各模块类名称。

首先:导入模块

在这里插入图片描述

然后,在 parse_model函数 中添加如下代码:

        elif m in {MFM}:
            if args[0] == 'head_channel':
                args[0] = d[args[0]]
            c1 = [ch[x] for x in f]
            c2 = make_divisible(min(args[0], max_channels) * width, 8)
            args = [c1, c2, *args[1:]]

在这里插入图片描述


五、yaml模型文件

5.1 中期融合⭐

📌 此模型的修方法是将MFM模块应用到RT-DETR的中期融合中。

# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-ResNet50 object detection model with P3-P5 outputs.

# Parameters
ch: 6
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, IN, []]  # 0
  - [-1, 1, Multiin, [1]]  # 1
  - [-2, 1, Multiin, [2]]  # 2

  - [1, 1, ConvNormLayer, [32, 3, 2, 1, 'relu']] # 3-P1
  - [-1, 1, ConvNormLayer, [32, 3, 1, 1, 'relu']] # 4
  - [-1, 1, ConvNormLayer, [64, 3, 1, 1, 'relu']] # 5
  - [-1, 1, nn.MaxPool2d, [3, 2, 1]] # 6-P2

  - [-1, 2, Blocks, [64,  BasicBlock, 2, False]] # 7
  - [-1, 2, Blocks, [128, BasicBlock, 3, False]] # 8-P3
  - [-1, 2, Blocks, [256, BasicBlock, 4, False]] # 9-P4
  - [-1, 2, Blocks, [512, BasicBlock, 5, False]] # 10-P5

  - [2, 1, ConvNormLayer, [32, 3, 2, 1, 'relu']] # 11-P1
  - [-1, 1, ConvNormLayer, [32, 3, 1, 1, 'relu']] # 12
  - [-1, 1, ConvNormLayer, [64, 3, 1, 1, 'relu']] # 13
  - [-1, 1, nn.MaxPool2d, [3, 2, 1]] # 14-P2

  - [-1, 2, Blocks, [64,  BasicBlock, 2, False]] # 15
  - [-1, 2, Blocks, [128, BasicBlock, 3, False]] # 16-P3
  - [-1, 2, Blocks, [256, BasicBlock, 4, False]] # 17-P4
  - [-1, 2, Blocks, [512, BasicBlock, 5, False]] # 18-P5

  - [[8, 16], 1, MFM, [128]]  # 19 cat backbone P3
  - [[9, 17], 1, MFM, [256]]  # 20 cat backbone P4
  - [[10, 18], 1, MFM, [512]]  # 21 cat backbone P5

head:
  - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 22 input_proj.2
  - [-1, 1, AIFI, [1024, 8]]
  - [-1, 1, Conv, [256, 1, 1]]  # 24, Y5, lateral_convs.0

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 25
  - [20, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 26 input_proj.1
  - [[-2, -1], 1, Concat, [1]]
  - [-1, 3, RepC3, [256, 0.5]]  # 28, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]]  # 29, Y4, lateral_convs.1

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 30
  - [19, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 31 input_proj.0
  - [[-2, -1], 1, Concat, [1]]  # 32 cat backbone P4
  - [-1, 3, RepC3, [256, 0.5]]  # X3 (33), fpn_blocks.1

  - [-1, 1, Conv, [256, 3, 2]]  # 34, downsample_convs.0
  - [[-1, 29], 1, Concat, [1]]  # 35 cat Y4
  - [-1, 3, RepC3, [256, 0.5]]  # F4 (36), pan_blocks.0

  - [-1, 1, Conv, [256, 3, 2]]  # 37, downsample_convs.1
  - [[-1, 24], 1, Concat, [1]]  # 38 cat Y5
  - [-1, 3, RepC3, [256, 0.5]]  # F5 (39), pan_blocks.1

  - [[33, 36, 39], 1, RTDETRDecoder, [nc, 256, 300, 4, 8, 3]]  # Detect(P3, P4, P5)


六、成功运行结果

打印网络模型可以看到不同的融合层已经加入到模型中,并可以进行训练了。

rtdetr-resnet18-mid-MFM

rtdetr-resnet18-mid-MFM summary: 508 layers, 31,431,892 parameters, 31,431,892 gradients, 92.3 GFLOPs

                   from  n    params  module                                       arguments
  0                  -1  1         0  ultralytics.nn.AddModules.multimodal.IN      []
  1                  -1  1         0  ultralytics.nn.AddModules.multimodal.Multiin [1]
  2                  -2  1         0  ultralytics.nn.AddModules.multimodal.Multiin [2]
  3                   1  1       960  ultralytics.nn.AddModules.ResNet.ConvNormLayer[3, 32, 3, 2, 1, 'relu']
  4                  -1  1      9312  ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 32, 3, 1, 1, 'relu']
  5                  -1  1     18624  ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 64, 3, 1, 1, 'relu']
  6                  -1  1         0  torch.nn.modules.pooling.MaxPool2d           [3, 2, 1]
  7                  -1  2    152512  ultralytics.nn.AddModules.ResNet.Blocks      [64, 64, 2, 'BasicBlock', 2, False]
  8                  -1  2    526208  ultralytics.nn.AddModules.ResNet.Blocks      [64, 128, 2, 'BasicBlock', 3, False]
  9                  -1  2   2100992  ultralytics.nn.AddModules.ResNet.Blocks      [128, 256, 2, 'BasicBlock', 4, False]
 10                  -1  2   8396288  ultralytics.nn.AddModules.ResNet.Blocks      [256, 512, 2, 'BasicBlock', 5, False]
 11                   2  1       960  ultralytics.nn.AddModules.ResNet.ConvNormLayer[3, 32, 3, 2, 1, 'relu']
 12                  -1  1      9312  ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 32, 3, 1, 1, 'relu']
 13                  -1  1     18624  ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 64, 3, 1, 1, 'relu']
 14                  -1  1         0  torch.nn.modules.pooling.MaxPool2d           [3, 2, 1]
 15                  -1  2    152512  ultralytics.nn.AddModules.ResNet.Blocks      [64, 64, 2, 'BasicBlock', 2, False]
 16                  -1  2    526208  ultralytics.nn.AddModules.ResNet.Blocks      [64, 128, 2, 'BasicBlock', 3, False]
 17                  -1  2   2100992  ultralytics.nn.AddModules.ResNet.Blocks      [128, 256, 2, 'BasicBlock', 4, False]
 18                  -1  2   8396288  ultralytics.nn.AddModules.ResNet.Blocks      [256, 512, 2, 'BasicBlock', 5, False]
 19             [8, 16]  1      6144  ultralytics.nn.AddModules.MFM.MFM            [[128, 128], 128]
 20             [9, 17]  1     24576  ultralytics.nn.AddModules.MFM.MFM            [[256, 256], 256]
 21            [10, 18]  1     98304  ultralytics.nn.AddModules.MFM.MFM            [[512, 512], 512]
 22                  -1  1    131584  ultralytics.nn.modules.conv.Conv             [512, 256, 1, 1, None, 1, 1, False]
 23                  -1  1    789760  ultralytics.nn.modules.transformer.AIFI      [256, 1024, 8]
 24                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]
 25                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 26                  20  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1, None, 1, 1, False]
 27            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 28                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 29                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]
 30                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 31                  19  1     33280  ultralytics.nn.modules.conv.Conv             [128, 256, 1, 1, None, 1, 1, False]
 32            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 33                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 34                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]
 35            [-1, 29]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 36                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 37                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]
 38            [-1, 24]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 39                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 40        [33, 36, 39]  1   3927956  ultralytics.nn.modules.head.RTDETRDecoder    [9, [256, 256, 256], 256, 300, 4, 8, 3]
rtdetr-resnet18-mid-MFM summary: 508 layers, 31,431,892 parameters, 31,431,892 gradients, 92.3 GFLOPs