
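[RT-DETR Multimodal Fusion Improvement] | PSFM: a Profound Semantic Fusion Module that introduces cross-modal cross-attention to dynamically model the global semantic dependencies between features of different modalities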


1. Introduction

This article documents how to use the PSFM module to improve the multimodal fusion part of RT-DETR.

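The PSFM (Profound Semantic Fusion Module) introduces a cross-modal cross-attention mechanism in the deep layers of the feature extraction network to dynamically model the global semantic dependencies between infrared and visible features. Guided by visible-light semantics and supplemented by infrared thermal signals, the module uses bidirectional attention to align "semantic categories" precisely with "thermal target locations", while capturing long-range semantic correlations to make the fused features more discriminative and improve scene understanding. It supplies the detection head with high-level semantic representations that include global context, with the aim of improving detection accuracy and the robustness of semantic reasoning in complex scenes.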



2. The PSFM Module

Rethinking the necessity of image fusion in high-level vision tasks: A practical infrared and visible image fusion network based on progressive semantic injection and scene fidelity

2.1 Design Goals

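  • Bridge the cross-modal gap in high-level semantic fusion: infrared thermal signals and visible-light semantic categories differ by modality (for example, an infrared "hot spot" corresponds to a "pedestrian" in the visible image), so a cross-modal mapping has to be established through semantic interaction.
  • Strengthen global semantic consistency: keep the semantic labels (such as categories) of targets in the fused features consistent with the visible image while retaining the spatial locations of infrared targets, avoiding semantic confusion (such as misclassifying a vehicle's thermal signature as a pedestrian).
  • Support semantic reasoning in complex scenes: model global context to capture long-range dependencies (such as the spatial relation between pedestrians and the road) and improve the model's understanding of complex scenes.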

2.2 Structure and Principle: Global Semantic Interaction via Cross-Attention

2.2.1 Core Components and Workflow

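As shown in the figure, the PSFM architecture performs semantic fusion of deep features with a cross-modal cross-attention mechanism. The main steps are as follows: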

[Figure: PSFM module architecture]

  1. Feature projection and reshaping
  • The deep infrared and visible features are each projected to produce the Key and Value matrices required by the attention mechanism:
    $$K_{x}^{i} = \text{Reshape}\left(\text{Conv}_{K}^{x}\left(\hat{\mathcal{F}}_{x}^{i}\right)\right), \quad V_{x}^{i} = \text{Reshape}\left(\text{Conv}_{V}^{x}\left(\hat{\mathcal{F}}_{x}^{i}\right)\right)$$
    (Here $x \in \{\text{ir}, \text{vi}\}$; $\text{Conv}_{K}^{x}$ and $\text{Conv}_{V}^{x}$ are 3×3 convolutions, and $\text{Reshape}$ flattens a feature map into an $HW \times C$ matrix for the attention computation.)
  2. Cross-modal attention computation
  • Using the visible features as the Query ($Q^{i}$), compute the attention matrix $\mathcal{A}_{\text{ir}}^{i}$ between $Q^{i}$ and the infrared key matrix $K_{\text{ir}}^{i}$:
    $$\mathcal{A}_{\text{ir}}^{i} = \text{Softmax}\left(Q^{i} \cdot (K_{\text{ir}}^{i})^{\text{T}}\right)$$
    ($\mathcal{A}_{\text{ir}}^{i}$ expresses how strongly the visible features depend on the infrared features; a larger value means the region needs more infrared semantic information.)
  • Symmetrically, compute the attention matrix $\mathcal{A}_{\text{vi}}^{i}$ of the infrared features over the visible features, which gives bidirectional semantic interaction.
  3. Global semantic feature aggregation
  • Use the attention matrices to take a weighted aggregation of the cross-modal Value matrices, producing features that carry global context:
    $$\text{Attn}_{\text{ir}}^{i} = \mathcal{A}_{\text{ir}}^{i} \cdot V_{\text{ir}}^{i}, \quad \text{Attn}_{\text{vi}}^{i} = \mathcal{A}_{\text{vi}}^{i} \cdot V_{\text{vi}}^{i}$$
  • Add the aggregated features to the original features, concatenate along the channel dimension, and apply a convolution to produce the final fused feature:
    $$\mathcal{F}_{fu}^{i} = \text{Conv}\left(\mathcal{C}\left(\mathcal{F}_{\text{vi}}^{i} + \text{Reshape}(\text{Attn}_{\text{ir}}^{i}),\ \mathcal{F}_{\text{ir}}^{i} + \text{Reshape}(\text{Attn}_{\text{vi}}^{i})\right)\right)$$
    (Here $\mathcal{C}$ denotes channel concatenation; the residual connections preserve the semantic information of the original features and prevent the attention from distorting them.)
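To make steps 1 to 3 concrete, here is a minimal PyTorch sketch that follows the formulas above. It is not the PSFusion implementation listed in Section 3. The class name `BiCrossAttentionFusion`, the 1×1 query projections, using the same input map for the keys/values and for the residuals (the formulas distinguish $\hat{\mathcal{F}}$ and $\mathcal{F}$), and the absence of a $1/\sqrt{C}$ scaling (the formula above has none) are all assumptions of this sketch.

```python
import torch
import torch.nn as nn


def conv3x3(c_in, c_out):
    return nn.Conv2d(c_in, c_out, 3, padding=1)  # padding=1 keeps H x W unchanged


def to_tokens(x):
    return x.flatten(2).transpose(1, 2)  # B x C x H x W -> B x HW x C (the Reshape in the formulas)


class BiCrossAttentionFusion(nn.Module):
    """Hypothetical sketch: bidirectional cross-attention, residual add, channel concat, fuse conv."""

    def __init__(self, c):
        super().__init__()
        # Step 1: per-modality 3x3 K/V projections; the 1x1 query projections are an assumption
        self.k = nn.ModuleDict({m: conv3x3(c, c) for m in ("vi", "ir")})
        self.v = nn.ModuleDict({m: conv3x3(c, c) for m in ("vi", "ir")})
        self.q = nn.ModuleDict({m: nn.Conv2d(c, c, 1) for m in ("vi", "ir")})
        self.fuse = conv3x3(2 * c, c)

    def forward(self, f_vi, f_ir):
        b, c, h, w = f_vi.shape
        feats = {"vi": f_vi, "ir": f_ir}
        k = {m: to_tokens(self.k[m](x)) for m, x in feats.items()}
        v = {m: to_tokens(self.v[m](x)) for m, x in feats.items()}
        q = {m: to_tokens(self.q[m](x)) for m, x in feats.items()}
        # Step 2: A_ir = Softmax(Q_vi K_ir^T), A_vi = Softmax(Q_ir K_vi^T), each B x HW x HW
        a_ir = torch.softmax(q["vi"] @ k["ir"].transpose(1, 2), dim=-1)
        a_vi = torch.softmax(q["ir"] @ k["vi"].transpose(1, 2), dim=-1)
        # Step 3: Attn = A V, reshape back to maps, residual add, channel concat, fuse conv
        attn_ir = (a_ir @ v["ir"]).transpose(1, 2).reshape(b, c, h, w)
        attn_vi = (a_vi @ v["vi"]).transpose(1, 2).reshape(b, c, h, w)
        fused = self.fuse(torch.cat([f_vi + attn_ir, f_ir + attn_vi], dim=1))
        return fused, a_ir, a_vi


torch.manual_seed(0)
model = BiCrossAttentionFusion(16)
f_vi, f_ir = torch.randn(2, 16, 10, 12), torch.randn(2, 16, 10, 12)
fused, a_ir, a_vi = model(f_vi, f_ir)
print(fused.shape, a_ir.shape)  # torch.Size([2, 16, 10, 12]) torch.Size([2, 120, 120])
```

Each row of an attention matrix sums to 1 over all 120 positions of the other modality, which is what lets every location draw on global context.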

2.3 Key Technical Features

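  • Bidirectional cross-modal interaction: attention is computed in both directions between the visible and infrared features (the query comes from the visible features in one direction and from the infrared features in the other), so visible semantics guide the alignment of infrared targets and infrared thermal signals reinforce visible semantics.
  • Global context modeling: attention can capture semantic dependencies between any two positions of the feature map (for example, between a distant pedestrian and a nearby vehicle), addressing the limited ability of conventional convolutional networks to model long-range dependencies.
  • Lightweight design: the module is built from only a few convolution layers and matrix multiplications, so the parameter increase is modest (see the per-layer PSFM parameter counts in the model summaries in Section 6), which makes it suitable for efficient semantic fusion of deep features.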
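The "global context modeling" point can be checked numerically. The sketch below is illustrative only: it uses plain, unprojected self-attention rather than PSFM itself. It changes only the bottom-right pixel of a random 16×16 feature map and measures how much the output at the top-left pixel moves, once for a single 3×3 convolution and once for attention over all H×W positions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C, H, W = 16, 16, 16
conv = nn.Conv2d(C, C, 3, padding=1)  # local operator: 3x3 receptive field


def global_attention(x):
    """Plain (unprojected) self-attention over all H*W positions of a B x C x H x W map."""
    t = x.flatten(2).transpose(1, 2)                  # B x HW x C
    a = torch.softmax(t @ t.transpose(1, 2), dim=-1)  # B x HW x HW: every position scores every other
    return (a @ t).transpose(1, 2).reshape_as(x)


x = torch.randn(1, C, H, W)
x_far = x.clone()
x_far[..., -1, -1] = 4 * x[..., 0, 0]  # change only the bottom-right pixel

with torch.no_grad():
    conv_delta = (conv(x_far) - conv(x))[..., 0, 0].abs().max().item()
    attn_delta = (global_attention(x_far) - global_attention(x))[..., 0, 0].abs().max().item()
print(f"conv change at (0,0): {conv_delta:.2e}, attention change at (0,0): {attn_delta:.2e}")
```

The convolution output at (0, 0) does not change because (15, 15) lies outside its 3×3 receptive field, whereas the attention output does. A deep CNN can widen its receptive field by stacking layers; attention gets it in a single layer at the cost of an HW×HW matrix.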

Paper: https://www.sciencedirect.com/science/article/abs/pii/S1566253523001860
Source code: https://github.com/Linfeng-Tang/PSFusion

3. PSFM Implementation Code

The PSFM implementation code is as follows:

import math
import torch.nn as nn
import torch

def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p

class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Apply convolution and activation without batch normalization (used after Conv+BN fusion)."""
        return self.act(self.conv(x))

class DWConv(Conv):
    """Depth-wise convolution."""

    def __init__(self, c1, c2, k=1, s=1, d=1, act=True):  # ch_in, ch_out, kernel, stride, dilation, activation
        """Initialize Depth-wise convolution with given parameters."""
        super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)

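class DSConv(nn.Module):
    """Depthwise separable convolution: k x k depthwise conv followed by a 1x1 pointwise conv."""

    def __init__(self, c1, c2, k=3, s=1, d=1, act=True) -> None:
        super().__init__()
        self.dwconv = DWConv(c1, c1, k, s, d, act)  # per-channel spatial filtering
        self.pwconv = Conv(c1, c2, 1, act=act)  # 1x1 conv mixes channels and sets the output width

    def forward(self, x):
        return self.pwconv(self.dwconv(x))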

class GEFM(nn.Module):
    """Cross-attention fusion of two modalities; the query is shared and built from both inputs."""

    def __init__(self, in_C, out_C):
        super(GEFM, self).__init__()
        self.RGB_K = DSConv(out_C, out_C, 3)
        self.RGB_V = DSConv(out_C, out_C, 3)
        self.Q = DSConv(in_C, out_C, 3)  # in_C = 2 * out_C: query from the concatenated RGB + infrared features
        self.INF_K = DSConv(out_C, out_C, 3)
        self.INF_V = DSConv(out_C, out_C, 3)
        self.Second_reduce = DSConv(in_C, out_C, 3)  # fuses the two refined branches back to out_C channels
        self.gamma1 = nn.Parameter(torch.zeros(1))  # learnable attention scales, initialized to 0
        self.gamma2 = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, y):
        # x: RGB (visible) features, y: infrared features, both B x C x H x W
        Q = self.Q(torch.cat([x, y], dim=1))

        # RGB branch: attention over RGB keys/values; the result is added to the infrared features y
        RGB_K = self.RGB_K(x)
        RGB_V = self.RGB_V(x)
        m_batchsize, C, height, width = RGB_V.size()
        RGB_V = RGB_V.view(m_batchsize, -1, width * height)  # B x C x N, N = H * W
        RGB_K = RGB_K.view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x C
        RGB_Q = Q.view(m_batchsize, -1, width * height)  # B x C x N
        RGB_mask = torch.bmm(RGB_K, RGB_Q)  # B x N x N scores
        RGB_mask = self.softmax(RGB_mask)
        RGB_refine = torch.bmm(RGB_V, RGB_mask.permute(0, 2, 1))  # B x C x N
        RGB_refine = RGB_refine.view(m_batchsize, -1, height, width)
        RGB_refine = self.gamma1 * RGB_refine + y

        # Infrared branch: same computation with infrared keys/values; the result is added to the RGB features x
        INF_K = self.INF_K(y)
        INF_V = self.INF_V(y)
        INF_V = INF_V.view(m_batchsize, -1, width * height)
        INF_K = INF_K.view(m_batchsize, -1, width * height).permute(0, 2, 1)
        INF_Q = Q.view(m_batchsize, -1, width * height)
        INF_mask = torch.bmm(INF_K, INF_Q)
        INF_mask = self.softmax(INF_mask)
        INF_refine = torch.bmm(INF_V, INF_mask.permute(0, 2, 1))
        INF_refine = INF_refine.view(m_batchsize, -1, height, width)
        INF_refine = self.gamma2 * INF_refine + x

        # Concatenate both refined branches (2 * out_C channels) and reduce to out_C
        out = self.Second_reduce(torch.cat([RGB_refine, INF_refine], dim=1))
        return out

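class DenseLayer(nn.Module):
    """1x1 reduction to out_C // down_factor channels, k densely connected DSConvs, then a DSConv fusing input + last output."""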
    def __init__(self, in_C, out_C, down_factor=4, k=2):
        super(DenseLayer, self).__init__()
        self.k = k
        self.down_factor = down_factor
        mid_C = out_C // self.down_factor

        self.down = nn.Conv2d(in_C, mid_C, 1)

        self.denseblock = nn.ModuleList()
        for i in range(1, self.k + 1):
            self.denseblock.append(DSConv(mid_C * i, mid_C, 3))

        self.fuse = DSConv(in_C + mid_C, out_C, 3)

    def forward(self, in_feat):
        down_feats = self.down(in_feat)
        out_feats = []
        for i in self.denseblock:
            feats = i(torch.cat((*out_feats, down_feats), dim=1))
            out_feats.append(feats)

        feats = torch.cat((in_feat, feats), dim=1)
        return self.fuse(feats)

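class PSFM(nn.Module):
    """Profound Semantic Fusion Module: refine each modality with a DenseLayer, then fuse them with GEFM."""

    def __init__(self, Channel):
        super(PSFM, self).__init__()
        self.RGBobj = DenseLayer(Channel, Channel)
        self.Infobj = DenseLayer(Channel, Channel)
        self.obj_fuse = GEFM(Channel * 2, Channel)

    def forward(self, data):
        rgb, inf = data  # the two `from` layers in the YAML: RGB branch first, infrared branch second
        rgb_sum = self.RGBobj(rgb)
        Inf_sum = self.Infobj(inf)
        out = self.obj_fuse(rgb_sum, Inf_sum)
        return out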
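The attention lines in `GEFM.forward` are easy to misread because of the transposes. The sketch below replaces the DSConv outputs with random tensors and checks the RGB-branch lines against an einsum formulation. It shows that output position i scores the projected `K` feature at i against the shared `Q` feature at every position j, normalizes over j, and averages `V` over j; relative to the usual naming, `K` plays the query role here. Note also that `Q` is a single projection of the concatenated RGB and infrared features (`self.Q(torch.cat([x, y], dim=1))`), whereas the description in Section 2.2 takes the query from one modality.

```python
import torch

torch.manual_seed(0)
B, C, H, W = 2, 8, 4, 5
N = H * W
K, Q, V = (torch.randn(B, C, H, W) for _ in range(3))

# The attention lines of GEFM.forward (RGB branch), with the conv outputs replaced by random tensors
k = K.view(B, -1, N).permute(0, 2, 1)          # B x N x C
q = Q.view(B, -1, N)                           # B x C x N
mask = torch.softmax(torch.bmm(k, q), dim=-1)  # B x N x N, mask[b, i, j] = softmax over j of K_i . Q_j
refine = torch.bmm(V.view(B, -1, N), mask.permute(0, 2, 1)).view(B, -1, H, W)

# The same computation in einsum form: output position i averages V over positions j
weights = torch.softmax(torch.einsum("bci,bcj->bij", K.view(B, C, N), Q.view(B, C, N)), dim=-1)
reference = torch.einsum("bij,bcj->bci", weights, V.view(B, C, N)).reshape(B, C, H, W)
print(torch.allclose(refine, reference, atol=1e-5))  # True
```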

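4. Integration Steps

4.1 Modification 1

① Create an AddModules folder under the ultralytics/nn/ directory to hold the module code.

② Create PSFM.py in the AddModules folder and paste the code from Section 3 into it.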


4.2 Modification 2

Create __init__.py in the AddModules folder (skip this if it already exists) and import the module in it: from .PSFM import *


4.3 Modification 3

In ultralytics/nn/tasks.py, the module class name needs to be added in two places.

First: import the module.


Second: register the PSFM module in the parse_model function.


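        elif m in {PSFM}:
            c2 = ch[f[0]]  # output width = channels of the first input (both inputs must share this width)
            args = [c2]  # becomes PSFM(Channel=c2)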
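What this branch computes can be written as a standalone function. `psfm_args` is hypothetical (not an Ultralytics function), and `ch`, which is a list of per-layer output channels inside `parse_model`, is modeled here as a dict. `PSFM(Channel)` expects both inputs to have `Channel` channels, so the output width and the only constructor argument are taken from the first input; the equality check is an extra safeguard added in this sketch. The channel values match the mid-fusion summary in Section 6 (PSFM arguments [128], [256], [512]).

```python
def psfm_args(f, ch):
    """Hypothetical stand-in for the PSFM branch in parse_model: returns (c2, args)."""
    c_first, c_second = ch[f[0]], ch[f[1]]
    # Extra check, not in the original branch: GEFM adds one branch's output to the other's features,
    # so both inputs must have the same number of channels
    if c_first != c_second:
        raise ValueError(f"PSFM inputs need equal channels, got {c_first} and {c_second}")
    c2 = c_first
    return c2, [c2]  # args = [c2] -> PSFM(Channel=c2)


# Output channels of the backbone layers used by the mid-fusion YAML (from the Section 6 summary)
ch = {8: 128, 16: 128, 9: 256, 17: 256, 10: 512, 18: 512}
for f in ([8, 16], [9, 17], [10, 18]):
    print(f, psfm_args(f, ch))
```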


Finally, in ultralytics/utils/torch_utils.py, set the stride in the get_flops function to 640.
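The reason for this change is not stated above. One plausible explanation, offered only as an assumption: if `get_flops` profiles a small dummy input (for example 32×32) and scales the result by image area, that extrapolation is exact for convolutions, whose cost is linear in H×W, but not for attention layers such as GEFM, whose N×N matrices grow with (H×W)². The arithmetic below compares both cases for a 256-channel feature map at stride 8; the helper names are hypothetical.

```python
def conv3x3_flops(side, c):
    """FLOPs (2 per multiply-add) of a c -> c 3x3 conv on a side x side map: linear in side**2."""
    return 2 * side * side * c * c * 9


def attention_flops(side, c):
    """FLOPs of the two N x N matmuls (scores, aggregation) with N = side**2: quadratic in side**2."""
    n = side * side
    return 2 * 2 * n * n * c


def area_scaling_ratio(flops_fn, small_px=32, full_px=640, stride=8, c=256):
    """True FLOPs at full_px divided by the estimate 'profile at small_px, then scale by area'."""
    estimate = flops_fn(small_px // stride, c) * (full_px / small_px) ** 2
    return flops_fn(full_px // stride, c) / estimate


conv_ratio = area_scaling_ratio(conv3x3_flops)
attn_ratio = area_scaling_ratio(attention_flops)
print(conv_ratio, attn_ratio)  # 1.0 400.0
```

Under this assumption, area scaling would underestimate the attention cost 400-fold at this level, while profiling directly at 640 needs no extrapolation.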



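5. YAML Model Files

5.1 Mid-Level Fusion ⭐

📌 This model replaces the Concat fusion in the original mid-level fusion with PSFM to fuse the multimodal information from the backbone.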

# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-ResNet18 object detection model with P3-P5 outputs.

# Parameters
ch: 6
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, IN, []]  # 0
  - [-1, 1, Multiin, [1]]  # 1
  - [-2, 1, Multiin, [2]]  # 2

  - [1, 1, ConvNormLayer, [32, 3, 2, 1, 'relu']] # 3-P1
  - [-1, 1, ConvNormLayer, [32, 3, 1, 1, 'relu']] # 4
  - [-1, 1, ConvNormLayer, [64, 3, 1, 1, 'relu']] # 5
  - [-1, 1, nn.MaxPool2d, [3, 2, 1]] # 6-P2

  - [-1, 2, Blocks, [64,  BasicBlock, 2, False]] # 7
  - [-1, 2, Blocks, [128, BasicBlock, 3, False]] # 8-P3
  - [-1, 2, Blocks, [256, BasicBlock, 4, False]] # 9-P4
  - [-1, 2, Blocks, [512, BasicBlock, 5, False]] # 10-P5

  - [2, 1, ConvNormLayer, [32, 3, 2, 1, 'relu']] # 11-P1
  - [-1, 1, ConvNormLayer, [32, 3, 1, 1, 'relu']] # 12
  - [-1, 1, ConvNormLayer, [64, 3, 1, 1, 'relu']] # 13
  - [-1, 1, nn.MaxPool2d, [3, 2, 1]] # 14-P2

  - [-1, 2, Blocks, [64,  BasicBlock, 2, False]] # 15
  - [-1, 2, Blocks, [128, BasicBlock, 3, False]] # 16-P3
  - [-1, 2, Blocks, [256, BasicBlock, 4, False]] # 17-P4
  - [-1, 2, Blocks, [512, BasicBlock, 5, False]] # 18-P5

  - [[8, 16], 1, PSFM, []]  # 19 cat backbone P3
  - [[9, 17], 1, PSFM, []]  # 20 cat backbone P4
  - [[10, 18], 1, PSFM, []]  # 21 cat backbone P5

head:
  - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 22 input_proj.2
  - [-1, 1, AIFI, [1024, 8]]
  - [-1, 1, Conv, [256, 1, 1]]  # 24, Y5, lateral_convs.0

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 25
  - [20, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 26 input_proj.1
  - [[-2, -1], 1, Concat, [1]]
  - [-1, 3, RepC3, [256, 0.5]]  # 28, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]]  # 29, Y4, lateral_convs.1

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 30
  - [19, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 31 input_proj.0
  - [[-2, -1], 1, Concat, [1]]  # 32 cat backbone P4
  - [-1, 3, RepC3, [256, 0.5]]  # X3 (33), fpn_blocks.1

  - [-1, 1, Conv, [256, 3, 2]]  # 34, downsample_convs.0
  - [[-1, 29], 1, Concat, [1]]  # 35 cat Y4
  - [-1, 3, RepC3, [256, 0.5]]  # F4 (36), pan_blocks.0

  - [-1, 1, Conv, [256, 3, 2]]  # 37, downsample_convs.1
  - [[-1, 24], 1, Concat, [1]]  # 38 cat Y5
  - [-1, 3, RepC3, [256, 0.5]]  # F5 (39), pan_blocks.1

  - [[33, 36, 39], 1, RTDETRDecoder, [nc, 256, 300, 4, 8, 3]]  # Detect(P3, P4, P5)
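A practical consequence of the GEFM code is memory: each PSFM materializes two N×N softmax maps (`RGB_mask` and `INF_mask`), where N = H×W of its input. In this mid-fusion model the first PSFM (layer 19) works on the P3 maps at stride 8, so a 640×640 input gives N = 80×80 = 6400; each of the three variants in this section has one PSFM at P3. The sketch estimates the size of those maps per image in fp32. It is a lower bound that ignores pre-softmax scores, activations kept for backpropagation, and gradients; `attention_map_bytes` is a hypothetical helper.

```python
def attention_map_bytes(img_size, stride, batch=1, bytes_per_elem=4, maps_per_psfm=2):
    """Size of the N x N softmax maps (RGB_mask and INF_mask) held by one GEFM, N = (img_size/stride)**2."""
    n = (img_size // stride) ** 2
    return batch * maps_per_psfm * n * n * bytes_per_elem


for level, stride in (("P3", 8), ("P4", 16), ("P5", 32)):
    n = (640 // stride) ** 2
    mib = attention_map_bytes(640, stride) / 2**20
    print(f"{level}: N = {n}, attention maps = {mib:.1f} MiB per image (fp32)")
```

This prints about 312.5 MiB for P3, 19.5 MiB for P4, and 1.2 MiB for P5, so the P3 module dominates and scales linearly with batch size.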

5.2 Mid-to-Late Fusion ⭐

📌 This model replaces the Concat fusion in the original mid-to-late fusion with PSFM to fuse the multimodal information from the FPN.

# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-ResNet18 object detection model with P3-P5 outputs.

# Parameters
ch: 6
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, IN, []]  # 0
  - [-1, 1, Multiin, [1]]  # 1
  - [-2, 1, Multiin, [2]]  # 2

  - [1, 1, ConvNormLayer, [32, 3, 2, 1, 'relu']] # 3-P1
  - [-1, 1, ConvNormLayer, [32, 3, 1, 1, 'relu']] # 4
  - [-1, 1, ConvNormLayer, [64, 3, 1, 1, 'relu']] # 5
  - [-1, 1, nn.MaxPool2d, [3, 2, 1]] # 6-P2

  - [-1, 2, Blocks, [64,  BasicBlock, 2, False]] # 7
  - [-1, 2, Blocks, [128, BasicBlock, 3, False]] # 8-P3
  - [-1, 2, Blocks, [256, BasicBlock, 4, False]] # 9-P4
  - [-1, 2, Blocks, [512, BasicBlock, 5, False]] # 10-P5

  - [2, 1, ConvNormLayer, [32, 3, 2, 1, 'relu']] # 11-P1
  - [-1, 1, ConvNormLayer, [32, 3, 1, 1, 'relu']] # 12
  - [-1, 1, ConvNormLayer, [64, 3, 1, 1, 'relu']] # 13
  - [-1, 1, nn.MaxPool2d, [3, 2, 1]] # 14-P2

  - [-1, 2, Blocks, [64,  BasicBlock, 2, False]] # 15
  - [-1, 2, Blocks, [128, BasicBlock, 3, False]] # 16-P3
  - [-1, 2, Blocks, [256, BasicBlock, 4, False]] # 17-P4
  - [-1, 2, Blocks, [512, BasicBlock, 5, False]] # 18-P5

head:
  - [10, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 19 input_proj.2
  - [-1, 1, AIFI, [1024, 8]]
  - [-1, 1, Conv, [256, 1, 1]]  # 21, Y5, lateral_convs.0

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 22
  - [9, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 23 input_proj.1
  - [[-2, -1], 1, Concat, [1]]
  - [-1, 3, RepC3, [256, 0.5]]  # 25, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]]  # 26, Y4, lateral_convs.1

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 27
  - [8, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 28 input_proj.0
  - [[-2, -1], 1, Concat, [1]]  # 29 cat backbone P4
  - [-1, 3, RepC3, [256, 0.5]]  # X3 (30), fpn_blocks.1

  - [18, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 31 input_proj.2
  - [-1, 1, AIFI, [1024, 8]]
  - [-1, 1, Conv, [256, 1, 1]]  # 33, Y5, lateral_convs.0

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 34
  - [17, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 35 input_proj.1
  - [[-2, -1], 1, Concat, [1]]
  - [-1, 3, RepC3, [256, 0.5]]  # 37, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]]  # 38, Y4, lateral_convs.1

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 39
  - [16, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 40 input_proj.0
  - [[-2, -1], 1, Concat, [1]]  # 41 cat backbone P4
  - [-1, 3, RepC3, [256, 0.5]]  # X3 (42), fpn_blocks.1

  - [[21, 33], 1, PSFM, []]  # 43 fuse Y5 (P5)
  - [[26, 38], 1, PSFM, []]  # 44 fuse Y4 (P4)
  - [[30, 42], 1, PSFM, []]  # 45 fuse X3 (P3)

  - [-1, 1, Conv, [256, 3, 2]]  # 46, downsample_convs.0
  - [[-1, 44], 1, Concat, [1]]  # 47 cat Y4
  - [-1, 3, RepC3, [256, 0.5]]  # F4 (48), pan_blocks.0

  - [-1, 1, Conv, [256, 3, 2]]  # 49, downsample_convs.1
  - [[-1, 43], 1, Concat, [1]]  # 50 cat Y5
  - [-1, 3, RepC3, [256, 0.5]]  # F5 (51), pan_blocks.1

  - [[45, 48, 51], 1, RTDETRDecoder, [nc, 256, 300, 4, 8, 3]]  # Detect(P3, P4, P5)
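The index bookkeeping in this YAML is easy to get wrong, so here is a small consistency check (a sketch, not part of the original workflow). Using the Y5 / Y4 / X3 labels in the layer comments, it confirms that every PSFM row pairs features of the same level from the two streams (43: P5, 44: P4, 45: P3) and that, after the stride-2 convs and Concats, the decoder receives layers [45, 48, 51] in P3, P4, P5 order.

```python
# Level of each layer referenced by the PSFM rows, read from the Y5 / Y4 / X3 comments in the YAML above
levels = {21: "P5", 26: "P4", 30: "P3", 33: "P5", 38: "P4", 42: "P3"}
psfm_rows = {43: (21, 33), 44: (26, 38), 45: (30, 42)}
NEXT = {"P3": "P4", "P4": "P5"}  # a stride-2 Conv moves one level down the pyramid

for idx, (a, b) in psfm_rows.items():
    assert levels[a] == levels[b], f"layer {idx} mixes {levels[a]} and {levels[b]}"
    levels[idx] = levels[a]

# PAN tail: Conv(stride 2) -> Concat with a PSFM output -> RepC3
for conv_idx, concat_with in ((46, 44), (49, 43)):
    levels[conv_idx] = NEXT[levels[conv_idx - 1]]
    assert levels[conv_idx] == levels[concat_with], "Concat inputs must share a resolution"
    levels[conv_idx + 1] = levels[conv_idx + 2] = levels[conv_idx]

decoder_inputs = [levels[i] for i in (45, 48, 51)]
print({i: levels[i] for i in (43, 44, 45)}, decoder_inputs)
```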

5.3 Late Fusion ⭐

📌 This model replaces the Concat fusion in the original late fusion with PSFM to fuse the multimodal information from the neck.

# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-ResNet18 object detection model with P3-P5 outputs.

# Parameters
ch: 6
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, IN, []]  # 0
  - [-1, 1, Multiin, [1]]  # 1
  - [-2, 1, Multiin, [2]]  # 2

  - [1, 1, ConvNormLayer, [32, 3, 2, 1, 'relu']] # 3-P1
  - [-1, 1, ConvNormLayer, [32, 3, 1, 1, 'relu']] # 4
  - [-1, 1, ConvNormLayer, [64, 3, 1, 1, 'relu']] # 5
  - [-1, 1, nn.MaxPool2d, [3, 2, 1]] # 6-P2

  - [-1, 2, Blocks, [64,  BasicBlock, 2, False]] # 7
  - [-1, 2, Blocks, [128, BasicBlock, 3, False]] # 8-P3
  - [-1, 2, Blocks, [256, BasicBlock, 4, False]] # 9-P4
  - [-1, 2, Blocks, [512, BasicBlock, 5, False]] # 10-P5

  - [2, 1, ConvNormLayer, [32, 3, 2, 1, 'relu']] # 11-P1
  - [-1, 1, ConvNormLayer, [32, 3, 1, 1, 'relu']] # 12
  - [-1, 1, ConvNormLayer, [64, 3, 1, 1, 'relu']] # 13
  - [-1, 1, nn.MaxPool2d, [3, 2, 1]] # 14-P2

  - [-1, 2, Blocks, [64,  BasicBlock, 2, False]] # 15
  - [-1, 2, Blocks, [128, BasicBlock, 3, False]] # 16-P3
  - [-1, 2, Blocks, [256, BasicBlock, 4, False]] # 17-P4
  - [-1, 2, Blocks, [512, BasicBlock, 5, False]] # 18-P5

head:
  - [10, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 19 input_proj.2
  - [-1, 1, AIFI, [1024, 8]]
  - [-1, 1, Conv, [256, 1, 1]]  # 21, Y5, lateral_convs.0

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 22
  - [9, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 23 input_proj.1
  - [[-2, -1], 1, Concat, [1]]
  - [-1, 3, RepC3, [256, 0.5]]  # 25, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]]  # 26, Y4, lateral_convs.1

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 27
  - [8, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 28 input_proj.0
  - [[-2, -1], 1, Concat, [1]]  # 29 cat backbone P4
  - [-1, 3, RepC3, [256, 0.5]]  # X3 (30), fpn_blocks.1

  - [-1, 1, Conv, [256, 3, 2]]  # 31, downsample_convs.0
  - [[-1, 26], 1, Concat, [1]]  # 32 cat Y4
  - [-1, 3, RepC3, [256, 0.5]]  # F4 (33), pan_blocks.0

  - [-1, 1, Conv, [256, 3, 2]]  # 34, downsample_convs.1
  - [[-1, 21], 1, Concat, [1]]  # 35 cat Y5
  - [-1, 3, RepC3, [256, 0.5]]  # F5 (36), pan_blocks.1

  - [18, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 37 input_proj.2
  - [-1, 1, AIFI, [1024, 8]]
  - [-1, 1, Conv, [256, 1, 1]]  # 39, Y5, lateral_convs.0

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 40
  - [17, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 41 input_proj.1
  - [[-2, -1], 1, Concat, [1]]
  - [-1, 3, RepC3, [256, 0.5]]  # 43, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]]  # 44, Y4, lateral_convs.1

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 45
  - [16, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 46 input_proj.0
  - [[-2, -1], 1, Concat, [1]]  # 47 cat backbone P4
  - [-1, 3, RepC3, [256, 0.5]]  # X3 (48), fpn_blocks.1

  - [-1, 1, Conv, [256, 3, 2]]  # 49, downsample_convs.0
  - [[-1, 44], 1, Concat, [1]]  # 50 cat Y4
  - [-1, 3, RepC3, [256, 0.5]]  # F4 (51), pan_blocks.0

  - [-1, 1, Conv, [256, 3, 2]]  # 52, downsample_convs.1
  - [[-1, 39], 1, Concat, [1]]  # 53 cat Y5
  - [-1, 3, RepC3, [256, 0.5]]  # F5 (54), pan_blocks.1

  - [[30, 48], 1, PSFM, []]  # 55 fuse neck P3 outputs
  - [[33, 51], 1, PSFM, []]  # 56 fuse neck P4 outputs
  - [[36, 54], 1, PSFM, []]  # 57 fuse neck P5 outputs

  - [[55, 56, 57], 1, RTDETRDecoder, [nc, 256, 300, 4, 8, 3]]  # Detect(P3, P4, P5)


6. Successful Run Results

Printing the network shows that the fusion layers have been added to the model and that it can be trained.

rtdetr-resnet18-mid-PSFM


                   from  n    params  module                                       arguments
  0                  -1  1         0  ultralytics.nn.AddModules.multimodal.IN      []
  1                  -1  1         0  ultralytics.nn.AddModules.multimodal.Multiin [1]
  2                  -2  1         0  ultralytics.nn.AddModules.multimodal.Multiin [2]
  3                   1  1       960  ultralytics.nn.AddModules.ResNet.ConvNormLayer[3, 32, 3, 2, 1, 'relu']
  4                  -1  1      9312  ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 32, 3, 1, 1, 'relu']
  5                  -1  1     18624  ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 64, 3, 1, 1, 'relu']
  6                  -1  1         0  torch.nn.modules.pooling.MaxPool2d           [3, 2, 1]
  7                  -1  2    152512  ultralytics.nn.AddModules.ResNet.Blocks      [64, 64, 2, 'BasicBlock', 2, False]
  8                  -1  2    526208  ultralytics.nn.AddModules.ResNet.Blocks      [64, 128, 2, 'BasicBlock', 3, False]
  9                  -1  2   2100992  ultralytics.nn.AddModules.ResNet.Blocks      [128, 256, 2, 'BasicBlock', 4, False]
 10                  -1  2   8396288  ultralytics.nn.AddModules.ResNet.Blocks      [256, 512, 2, 'BasicBlock', 5, False]
 11                   2  1       960  ultralytics.nn.AddModules.ResNet.ConvNormLayer[3, 32, 3, 2, 1, 'relu']
 12                  -1  1      9312  ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 32, 3, 1, 1, 'relu']
 13                  -1  1     18624  ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 64, 3, 1, 1, 'relu']
 14                  -1  1         0  torch.nn.modules.pooling.MaxPool2d           [3, 2, 1]
 15                  -1  2    152512  ultralytics.nn.AddModules.ResNet.Blocks      [64, 64, 2, 'BasicBlock', 2, False]
 16                  -1  2    526208  ultralytics.nn.AddModules.ResNet.Blocks      [64, 128, 2, 'BasicBlock', 3, False]
 17                  -1  2   2100992  ultralytics.nn.AddModules.ResNet.Blocks      [128, 256, 2, 'BasicBlock', 4, False]
 18                  -1  2   8396288  ultralytics.nn.AddModules.ResNet.Blocks      [256, 512, 2, 'BasicBlock', 5, False]
 19             [8, 16]  1    205634  ultralytics.nn.AddModules.PSFM.PSFM          [128]
 20             [9, 17]  1    784002  ultralytics.nn.AddModules.PSFM.PSFM          [256]
 21            [10, 18]  1   3058946  ultralytics.nn.AddModules.PSFM.PSFM          [512]
 22                  -1  1    131584  ultralytics.nn.modules.conv.Conv             [512, 256, 1, 1, None, 1, 1, False]
 23                  -1  1    789760  ultralytics.nn.modules.transformer.AIFI      [256, 1024, 8]
 24                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]
 25                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 26                  20  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1, None, 1, 1, False]
 27            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 28                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 29                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]
 30                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 31                  19  1     33280  ultralytics.nn.modules.conv.Conv             [128, 256, 1, 1, None, 1, 1, False]
 32            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 33                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 34                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]
 35            [-1, 29]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 36                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 37                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]
 38            [-1, 24]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 39                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 40        [33, 36, 39]  1   3927956  ultralytics.nn.modules.head.RTDETRDecoder    [9, [256, 256, 256], 256, 300, 4, 8, 3]
rtdetr-resnet18-mid-PSFM summary: 758 layers, 35,351,450 parameters, 35,351,450 gradients, 100.5 GFLOPs

rtdetr-resnet18-mid-to-late-PSFM


                   from  n    params  module                                       arguments
  0                  -1  1         0  ultralytics.nn.AddModules.multimodal.IN      []
  1                  -1  1         0  ultralytics.nn.AddModules.multimodal.Multiin [1]
  2                  -2  1         0  ultralytics.nn.AddModules.multimodal.Multiin [2]
  3                   1  1       960  ultralytics.nn.AddModules.ResNet.ConvNormLayer[3, 32, 3, 2, 1, 'relu']
  4                  -1  1      9312  ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 32, 3, 1, 1, 'relu']
  5                  -1  1     18624  ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 64, 3, 1, 1, 'relu']
  6                  -1  1         0  torch.nn.modules.pooling.MaxPool2d           [3, 2, 1]
  7                  -1  2    152512  ultralytics.nn.AddModules.ResNet.Blocks      [64, 64, 2, 'BasicBlock', 2, False]
  8                  -1  2    526208  ultralytics.nn.AddModules.ResNet.Blocks      [64, 128, 2, 'BasicBlock', 3, False]
  9                  -1  2   2100992  ultralytics.nn.AddModules.ResNet.Blocks      [128, 256, 2, 'BasicBlock', 4, False]
 10                  -1  2   8396288  ultralytics.nn.AddModules.ResNet.Blocks      [256, 512, 2, 'BasicBlock', 5, False]
 11                   2  1       960  ultralytics.nn.AddModules.ResNet.ConvNormLayer[3, 32, 3, 2, 1, 'relu']
 12                  -1  1      9312  ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 32, 3, 1, 1, 'relu']
 13                  -1  1     18624  ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 64, 3, 1, 1, 'relu']
 14                  -1  1         0  torch.nn.modules.pooling.MaxPool2d           [3, 2, 1]
 15                  -1  2    152512  ultralytics.nn.AddModules.ResNet.Blocks      [64, 64, 2, 'BasicBlock', 2, False]
 16                  -1  2    526208  ultralytics.nn.AddModules.ResNet.Blocks      [64, 128, 2, 'BasicBlock', 3, False]
 17                  -1  2   2100992  ultralytics.nn.AddModules.ResNet.Blocks      [128, 256, 2, 'BasicBlock', 4, False]
 18                  -1  2   8396288  ultralytics.nn.AddModules.ResNet.Blocks      [256, 512, 2, 'BasicBlock', 5, False]
 19                  10  1    131584  ultralytics.nn.modules.conv.Conv             [512, 256, 1, 1, None, 1, 1, False]
 20                  -1  1    789760  ultralytics.nn.modules.transformer.AIFI      [256, 1024, 8]
 21                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]
 22                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 23                   9  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1, None, 1, 1, False]
 24            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 25                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 26                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]
 27                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 28                   8  1     33280  ultralytics.nn.modules.conv.Conv             [128, 256, 1, 1, None, 1, 1, False]
 29            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 30                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 31                  18  1    131584  ultralytics.nn.modules.conv.Conv             [512, 256, 1, 1, None, 1, 1, False]
 32                  -1  1    789760  ultralytics.nn.modules.transformer.AIFI      [256, 1024, 8]
 33                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]
 34                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 35                  17  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1, None, 1, 1, False]
 36            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 37                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 38                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]
 39                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 40                  16  1     33280  ultralytics.nn.modules.conv.Conv             [128, 256, 1, 1, None, 1, 1, False]
 41            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 42                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 43            [21, 33]  1    784002  ultralytics.nn.AddModules.PSFM.PSFM          [256]
 44            [26, 38]  1    784002  ultralytics.nn.AddModules.PSFM.PSFM          [256]
 45            [30, 42]  1    784002  ultralytics.nn.AddModules.PSFM.PSFM          [256]
 46                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]
 47            [-1, 44]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 48                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 49                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]
 50            [-1, 43]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 51                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 52        [45, 48, 51]  1   3927956  ultralytics.nn.modules.head.RTDETRDecoder    [9, [256, 256, 256], 256, 300, 4, 8, 3]
rtdetr-resnet18-mid-to-late-PSFM summary: 866 layers, 36,123,482 parameters, 36,123,482 gradients, 118.2 GFLOPs

rtdetr-resnet18-late-PSFM


                   from  n    params  module                                       arguments
  0                  -1  1         0  ultralytics.nn.AddModules.multimodal.IN      []
  1                  -1  1         0  ultralytics.nn.AddModules.multimodal.Multiin [1]
  2                  -2  1         0  ultralytics.nn.AddModules.multimodal.Multiin [2]
  3                   1  1       960  ultralytics.nn.AddModules.ResNet.ConvNormLayer[3, 32, 3, 2, 1, 'relu']
  4                  -1  1      9312  ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 32, 3, 1, 1, 'relu']
  5                  -1  1     18624  ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 64, 3, 1, 1, 'relu']
  6                  -1  1         0  torch.nn.modules.pooling.MaxPool2d           [3, 2, 1]
  7                  -1  2    152512  ultralytics.nn.AddModules.ResNet.Blocks      [64, 64, 2, 'BasicBlock', 2, False]
  8                  -1  2    526208  ultralytics.nn.AddModules.ResNet.Blocks      [64, 128, 2, 'BasicBlock', 3, False]
  9                  -1  2   2100992  ultralytics.nn.AddModules.ResNet.Blocks      [128, 256, 2, 'BasicBlock', 4, False]
 10                  -1  2   8396288  ultralytics.nn.AddModules.ResNet.Blocks      [256, 512, 2, 'BasicBlock', 5, False]
 11                   2  1       960  ultralytics.nn.AddModules.ResNet.ConvNormLayer[3, 32, 3, 2, 1, 'relu']
 12                  -1  1      9312  ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 32, 3, 1, 1, 'relu']
 13                  -1  1     18624  ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 64, 3, 1, 1, 'relu']
 14                  -1  1         0  torch.nn.modules.pooling.MaxPool2d           [3, 2, 1]
 15                  -1  2    152512  ultralytics.nn.AddModules.ResNet.Blocks      [64, 64, 2, 'BasicBlock', 2, False]
 16                  -1  2    526208  ultralytics.nn.AddModules.ResNet.Blocks      [64, 128, 2, 'BasicBlock', 3, False]
 17                  -1  2   2100992  ultralytics.nn.AddModules.ResNet.Blocks      [128, 256, 2, 'BasicBlock', 4, False]
 18                  -1  2   8396288  ultralytics.nn.AddModules.ResNet.Blocks      [256, 512, 2, 'BasicBlock', 5, False]
 19                  10  1    131584  ultralytics.nn.modules.conv.Conv             [512, 256, 1, 1, None, 1, 1, False]
 20                  -1  1    789760  ultralytics.nn.modules.transformer.AIFI      [256, 1024, 8]
 21                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]
 22                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 23                   9  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1, None, 1, 1, False]
 24            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 25                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 26                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]
 27                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 28                   8  1     33280  ultralytics.nn.modules.conv.Conv             [128, 256, 1, 1, None, 1, 1, False]
 29            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 30                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 31                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]
 32            [-1, 26]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 33                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 34                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]
 35            [-1, 21]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 36                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 37                  18  1    131584  ultralytics.nn.modules.conv.Conv             [512, 256, 1, 1, None, 1, 1, False]
 38                  -1  1    789760  ultralytics.nn.modules.transformer.AIFI      [256, 1024, 8]
 39                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]
 40                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 41                  17  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1, None, 1, 1, False]
 42            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 43                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 44                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]
 45                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 46                  16  1     33280  ultralytics.nn.modules.conv.Conv             [128, 256, 1, 1, None, 1, 1, False]
 47            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 48                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 49                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]
 50            [-1, 44]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 51                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 52                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]
 53            [-1, 39]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 54                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]
 55            [30, 48]  1    784002  ultralytics.nn.AddModules.PSFM.PSFM          [256]
 56            [33, 51]  1    784002  ultralytics.nn.AddModules.PSFM.PSFM          [256]
 57            [36, 54]  1    784002  ultralytics.nn.AddModules.PSFM.PSFM          [256]
 58        [55, 56, 57]  1   3927956  ultralytics.nn.modules.head.RTDETRDecoder    [9, [256, 256, 256], 256, 300, 4, 8, 3]
rtdetr-resnet18-late-PSFM summary: 950 layers, 38,619,994 parameters, 38,619,994 gradients, 123.2 GFLOPs
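The PSFM overhead can be read directly from the three summaries above. The snippet below only does arithmetic on the printed parameter counts (nothing is re-measured). The PSFM layers account for about 11.5% of the parameters in the mid variant, where the 512-channel P5 module alone has 3,058,946 parameters, and about 6.5% and 6.1% in the mid-to-late and late variants, where all three modules run at 256 channels.

```python
# Numbers copied from the three model summaries above
summaries = {
    "mid": {"total": 35_351_450, "psfm": [205_634, 784_002, 3_058_946]},
    "mid-to-late": {"total": 36_123_482, "psfm": [784_002] * 3},
    "late": {"total": 38_619_994, "psfm": [784_002] * 3},
}

shares = {}
for name, s in summaries.items():
    psfm = sum(s["psfm"])
    shares[name] = psfm / s["total"]
    print(f"{name}: PSFM params {psfm:,} = {100 * shares[name]:.2f}% of {s['total']:,}")
```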