学习资源站

【YOLOv8多模态融合改进】_PSFM,深层语义融合模块引入跨模态交叉注意力机制,动态建模不同模态特征的全局语义依赖关系_profoundsemanticfusionmodule-

【YOLOv8多模态融合改进】| PSFM,深层语义融合模块 引入跨模态交叉注意力机制,动态建模不同模态特征的全局语义依赖关系

一、本文介绍

本文记录的是利用 PSFM 模块改进 YOLOv8 的多模态融合部分

PSFM模块(Profound Semantic Fusion Module,深层语义融合模块) 通过在特征提取网络的深层引入 跨模态交叉注意力机制,动态建模红外与可见光特征的全局语义依赖关系 。该模块以可见光语义为引导、红外热信号为补充,通过 双向注意力 计算实现“语义类别”与“热目标位置”的精准对齐,同时 捕捉长距离语义关联 增强融合特征的判别性与场景理解能力 ,为检测头 提供包含全局上下文的高层语义表示 ,从而提升模型在复杂场景下的目标检测准确率与语义推理鲁棒性。



二、PSFM模块介绍

Rethinking the necessity of image fusion in high-level vision tasks: A practical infrared and visible image fusion network based on progressive semantic injection and scene fidelity

2.1 设计目标

  • 解决高层语义融合的跨模态鸿沟 :红外图像的热信号与可见光图像的语义类别存在模态差异(如红外“热斑”对应可见光“行人”),需通过语义交互建立跨模态映射。
  • 增强全局语义一致性 :确保融合特征中目标的语义标签(如类别)与可见光图像一致,同时保留红外目标的空间位置信息,避免语义混淆(如将车辆热信号误判为行人)。
  • 支持复杂场景的语义推理 :通过全局上下文建模,捕捉长距离依赖关系(如“行人-道路”的空间关系),提升模型对复杂场景的理解能力。

2.2 结构原理:基于交叉注意力的全局语义交互

2.2.1 核心组件与流程

PSFM模块的架构如图所示,基于 跨模态交叉注意力机制 (Cross-Attention)实现深层特征的语义融合,主要步骤如下:

在这里插入图片描述

  1. 特征投影与维度变换
  • 对红外和可见光的深层特征分别进行投影,生成注意力机制所需的 键(Key) 值(Value) 矩阵:
    K x i = Reshape ( Conv K x ( F ^ x i ) ) , V x i = Reshape ( Conv V x ( F ^ x i ) ) K_{x}^{i} = \text{Reshape}\left(\text{Conv}_{K}^{x}\left(\hat{\mathcal{F}}_{x}^{i}\right)\right), \quad V_{x}^{i} = \text{Reshape}\left(\text{Conv}_{V}^{x}\left(\hat{\mathcal{F}}_{x}^{i}\right)\right) K x i = Reshape ( Conv K x ( F ^ x i ) ) , V x i = Reshape ( Conv V x ( F ^ x i ) )
    (其中 x ∈ { ir , vi } x \in \{\text{ir}, \text{vi}\} x { ir , vi } Conv K x \text{Conv}_{K}^{x} Conv K x Conv V x \text{Conv}_{V}^{x} Conv V x 为3×3卷积, Reshape \text{Reshape} Reshape 将特征图展开为 H W × C HW \times C H W × C 的矩阵,便于注意力计算)。
  1. 跨模态注意力计算
  • 以可见光特征为 查询(Query, Q) ,计算其与红外特征的键矩阵 K ir i K_{\text{ir}}^{i} K ir i 的注意力矩阵 A ir i \mathcal{A}_{\text{ir}}^{i} A ir i
    A ir i = Softmax ( Q i ⋅ ( K ir i ) T ) \mathcal{A}_{\text{ir}}^{i} = \text{Softmax}\left(Q^{i} \cdot (K_{\text{ir}}^{i})^{\text{T}}\right) A ir i = Softmax ( Q i ( K ir i ) T )
    A ir i \mathcal{A}_{\text{ir}}^{i} A ir i 表示可见光特征对红外特征的依赖程度,数值越大表明该区域越需要红外语义信息)。
  • 同理,计算红外特征对可见光特征的注意力矩阵 A vi i \mathcal{A}_{\text{vi}}^{i} A vi i ,实现双向语义交互。
  1. 全局语义特征聚合
  • 根据注意力矩阵加权聚合跨模态的 值矩阵(Value) ,生成包含全局上下文的特征:
    Attn ir i = A ir i ⋅ V ir i , Attn vi i = A vi i ⋅ V vi i \text{Attn}_{\text{ir}}^{i} = \mathcal{A}_{\text{ir}}^{i} \cdot V_{\text{ir}}^{i}, \quad \text{Attn}_{\text{vi}}^{i} = \mathcal{A}_{\text{vi}}^{i} \cdot V_{\text{vi}}^{i} Attn ir i = A ir i V ir i , Attn vi i = A vi i V vi i
  • 将聚合后的特征与原始特征相加,并在通道维度拼接,通过卷积层生成最终的融合特征:
    F f u i = Conv ( C ( F vi i + Reshape ( Attn ir i ) , F ir i + Reshape ( Attn vi i ) ) ) \mathcal{F}_{fu}^{i} = \text{Conv}\left(\mathcal{C}\left(\mathcal{F}_{\text{vi}}^{i} + \text{Reshape}(\text{Attn}_{\text{ir}}^{i}), \mathcal{F}_{\text{ir}}^{i} + \text{Reshape}(\text{Attn}_{\text{vi}}^{i})\right)\right) F f u i = Conv ( C ( F vi i + Reshape ( Attn ir i ) , F ir i + Reshape ( Attn vi i ) ) )
    C \mathcal{C} C 表示通道拼接,通过残差连接保留原始特征的语义信息,避免注意力机制导致的特征失真)。

3. 关键技术特点

  • 双向跨模态交互 :通过可见光与红外特征的双向注意力计算(Q分别来自可见光和红外),实现“以可见光语义为引导对齐红外目标”和“以红外热信号增强可见光语义”的双向优化。
  • 全局上下文建模 :注意力机制可捕捉特征图中任意位置的语义依赖关系(如远处行人与近处车辆的关联),解决传统卷积神经网络对长距离依赖建模能力不足的问题。
  • 轻量化设计 :仅通过少量卷积层和矩阵运算实现,参数增量小于5%,适用于深层特征的高效语义融合。

论文: https://www.sciencedirect.com/science/article/abs/pii/S1566253523001860
源码: https://github.com/Linfeng-Tang/PSFusion

三、PSFM的实现代码

PSFM 的实现代码如下:

import math
import torch.nn as nn
import torch

def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p

class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Perform transposed convolution of 2D data."""
        return self.act(self.conv(x))

class DWConv(Conv):
    """Depth-wise convolution."""

    def __init__(self, c1, c2, k=1, s=1, d=1, act=True):  # ch_in, ch_out, kernel, stride, dilation, activation
        """Initialize Depth-wise convolution with given parameters."""
        super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)

class DSConv(nn.Module):
    """Depthwise Separable Convolution"""
    def __init__(self, c1, c2, k=1, s=1, d=1, act=True) -> None:
        super().__init__()

        self.dwconv = DWConv(c1, c1, 3)
        self.pwconv = Conv(c1, c2, 1)

    def forward(self, x):
        return self.pwconv(self.dwconv(x))

class GEFM(nn.Module):
    def __init__(self, in_C, out_C):
        super(GEFM, self).__init__()
        self.RGB_K= DSConv(out_C, out_C, 3)
        self.RGB_V = DSConv(out_C, out_C, 3)
        self.Q = DSConv(in_C, out_C, 3)
        self.INF_K= DSConv(out_C, out_C, 3)
        self.INF_V = DSConv(out_C, out_C, 3)
        self.Second_reduce = DSConv(in_C, out_C, 3)
        self.gamma1 = nn.Parameter(torch.zeros(1))
        self.gamma2 = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, y):
        Q = self.Q(torch.cat([x,y], dim=1))
        RGB_K = self.RGB_K(x)
        RGB_V = self.RGB_V(x)
        m_batchsize, C, height, width = RGB_V.size()
        RGB_V = RGB_V.view(m_batchsize, -1, width*height)
        RGB_K = RGB_K.view(m_batchsize, -1, width*height).permute(0, 2, 1)
        RGB_Q = Q.view(m_batchsize, -1, width*height)
        RGB_mask = torch.bmm(RGB_K, RGB_Q)
        RGB_mask = self.softmax(RGB_mask)
        RGB_refine = torch.bmm(RGB_V, RGB_mask.permute(0, 2, 1))
        RGB_refine = RGB_refine.view(m_batchsize, -1, height,width)
        RGB_refine = self.gamma1*RGB_refine+y

        INF_K = self.INF_K(y)
        INF_V = self.INF_V(y)
        INF_V = INF_V.view(m_batchsize, -1, width*height)
        INF_K = INF_K.view(m_batchsize, -1, width*height).permute(0, 2, 1)
        INF_Q = Q.view(m_batchsize, -1, width*height)
        INF_mask = torch.bmm(INF_K, INF_Q)
        INF_mask = self.softmax(INF_mask)
        INF_refine = torch.bmm(INF_V, INF_mask.permute(0, 2, 1))
        INF_refine = INF_refine.view(m_batchsize, -1, height,width)
        INF_refine = self.gamma2 * INF_refine + x

        out = self.Second_reduce(torch.cat([RGB_refine, INF_refine], dim=1))
        return out

class DenseLayer(nn.Module):
    def __init__(self, in_C, out_C, down_factor=4, k=2):
        super(DenseLayer, self).__init__()
        self.k = k
        self.down_factor = down_factor
        mid_C = out_C // self.down_factor

        self.down = nn.Conv2d(in_C, mid_C, 1)

        self.denseblock = nn.ModuleList()
        for i in range(1, self.k + 1):
            self.denseblock.append(DSConv(mid_C * i, mid_C, 3))

        self.fuse = DSConv(in_C + mid_C, out_C, 3)

    def forward(self, in_feat):
        down_feats = self.down(in_feat)
        out_feats = []
        for i in self.denseblock:
            feats = i(torch.cat((*out_feats, down_feats), dim=1))
            out_feats.append(feats)

        feats = torch.cat((in_feat, feats), dim=1)
        return self.fuse(feats)

class PSFM(nn.Module):
    def __init__(self, Channel):
        super(PSFM, self).__init__()
        self.RGBobj = DenseLayer(Channel, Channel)
        self.Infobj = DenseLayer(Channel, Channel)
        self.obj_fuse = GEFM(Channel * 2, Channel)

    def forward(self, data):
        rgb, depth = data
        rgb_sum = self.RGBobj(rgb)
        Inf_sum = self.Infobj(depth)
        out = self.obj_fuse(rgb_sum,Inf_sum)
        return out

四、融合步骤

4.1 修改一

① 在 ultralytics/nn/ 目录下新建 AddModules 文件夹用于存放模块代码

② 在 AddModules 文件夹下新建 PSFM.py ,将 第三节 中的代码粘贴到此处

在这里插入图片描述

4.2 修改二

AddModules 文件夹下新建 __init__.py (已有则不用新建),在文件内导入模块: from .PSFM import *

在这里插入图片描述

4.3 修改三

ultralytics/nn/modules/tasks.py 文件中,需要在两处位置添加各模块类名称。

首先:导入模块

在这里插入图片描述

其次:在 parse_model函数 中注册 PSFM 模块

在这里插入图片描述

        elif m in {PSFM}:
            c2 = ch[f[0]]
            args = [c2]

在这里插入图片描述

最后将 ultralytics/utils/torch_utils.py 中的 get_flops 函数中的 stride 指定为 640

在这里插入图片描述


五、yaml模型文件

5.1 中期融合⭐

📌 此模型的修方法是将原本的中期融合中的Concat融合部分换成PSFM,融合骨干部分的多模态信息。

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
ch: 6
nc: 1  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
   n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
   s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
   m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
   l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
   x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, IN, []]  # 0
  - [-1, 1, Multiin, [1]]  # 1
  - [-2, 1, Multiin, [2]]  # 2

  - [1, 1, Conv, [64, 3, 2]] # 3-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 4-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 6-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]] # 8-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 10-P5/32
  - [-1, 3, C2f, [1024, True]]

  - [2, 1, Conv, [64, 3, 2]] # 12-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 13-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 15-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]] # 17-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 19-P5/32
  - [-1, 3, C2f, [1024, True]]

  - [[7, 16], 1, PSFM, []]  # 21 cat backbone P3
  - [[9, 18], 1, PSFM, []]  # 22 cat backbone P4
  - [[11, 20], 1, PSFM, []]  # 23 cat backbone P5

  - [-1, 1, SPPF, [1024, 5]] # 24

 # YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 22], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 27

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 21], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 30 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 27], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 33 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 24], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 36 (P5/32-large)

  - [[30, 33, 36], 1, Detect, [nc]]  # Detect(P3, P4, P5)

5.2 中-后期融合⭐

📌 此模型的修方法是将原本的中-后期融合中的Concat融合部分换成PSFM,融合FPN部分的多模态信息。

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
ch: 6
nc: 1  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
   n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
   s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
   m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
   l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
   x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, IN, []]  # 0
  - [-1, 1, Multiin, [1]]  # 1
  - [-2, 1, Multiin, [2]]  # 2

  - [1, 1, Conv, [64, 3, 2]] # 3-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 4-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 6-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]] # 8-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 10-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 12

  - [2, 1, Conv, [64, 3, 2]] # 13-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 14-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 16-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]] # 18-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 20-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 22

 # YOLOv8.0n head
head:
  - [12, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 9], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 25

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 7], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 28 (P3/8-small)

  - [22, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 19], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 31

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 17], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 34 (P3/8-small)

  - [ [ 12, 22 ], 1, PSFM, [] ]  # cat head P3  35
  - [ [ 25, 31 ], 1, PSFM, [] ]  # cat head P4  36
  - [ [ 28, 34 ], 1, PSFM, [] ]  # cat head P5  37

  - [37, 1, Conv, [256, 3, 2]]
  - [[-1, 36], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 40 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 35], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 43 (P5/32-large)

  - [[37, 40, 43], 1, Detect, [nc]]  # Detect(P3, P4, P5)

5.3 后期融合⭐

📌 此模型的修方法是将原本的后期融合中的Concat融合部分换成PSFM,融合颈部部分的多模态信息。

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
ch: 6
nc: 1  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
   n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
   s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
   m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
   l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
   x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, IN, []]  # 0
  - [-1, 1, Multiin, [1]]  # 1
  - [-2, 1, Multiin, [2]]  # 2

  - [1, 1, Conv, [64, 3, 2]] # 3-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 4-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 6-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]] # 8-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 10-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 12

  - [2, 1, Conv, [64, 3, 2]] # 13-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 14-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 16-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]] # 18-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 20-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 22

 # YOLOv8.0n head
head:
  - [12, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 9], 1, Concat, [1] ]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 25

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[ -1, 7], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 28 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 25], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 31 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 34 (P5/32-large)

  - [22, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 19], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 37

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[ -1, 17 ], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 40 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 37], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 43 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 22], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 46 (P5/32-large)

  - [[28, 40], 1, PSFM, []]  # cat head P5  47
  - [[31, 43], 1, PSFM, []]  # cat head P5  48
  - [[34, 46], 1, PSFM, []]  # cat head P5  49

  - [[47, 48, 49], 1, Detect, [nc]]  # Detect(P3, P4, P5)


六、成功运行结果

打印网络模型可以看到不同的融合层已经加入到模型中,并可以进行训练了。

YOLOv8-mid-PSFM

YOLOv8-mid-PSFM summary: 621 layers, 4,844,313 parameters, 4,844,297 gradients, 12.6 GFLOPs

                   from  n    params  module                                       arguments
  0                  -1  1         0  ultralytics.nn.AddModules.multimodal.IN      []
  1                  -1  1         0  ultralytics.nn.AddModules.multimodal.Multiin [1]
  2                  -2  1         0  ultralytics.nn.AddModules.multimodal.Multiin [2]
  3                   1  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
  4                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]
  5                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]
  6                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]
  7                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]
  8                  -1  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]
  9                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]
 10                  -1  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]
 11                  -1  1    460288  ultralytics.nn.modules.block.C2f             [256, 256, 1, True]
 12                   2  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
 13                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]
 14                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]
 15                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]
 16                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]
 17                  -1  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]
 18                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]
 19                  -1  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]
 20                  -1  1    460288  ultralytics.nn.modules.block.C2f             [256, 256, 1, True]
 21             [7, 16]  1     56226  ultralytics.nn.AddModules.PSFM.PSFM          [64]
 22             [9, 18]  1    205634  ultralytics.nn.AddModules.PSFM.PSFM          [128]
 23            [11, 20]  1    784002  ultralytics.nn.AddModules.PSFM.PSFM          [256]
 24                  -1  1    164608  ultralytics.nn.modules.block.SPPF            [256, 256, 5]
 25                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 26            [-1, 22]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 27                  -1  1    148224  ultralytics.nn.modules.block.C2f             [384, 128, 1]
 28                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 29            [-1, 21]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 30                  -1  1     37248  ultralytics.nn.modules.block.C2f             [192, 64, 1]
 31                  -1  1     36992  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2]
 32            [-1, 27]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 33                  -1  1    123648  ultralytics.nn.modules.block.C2f             [192, 128, 1]
 34                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]
 35            [-1, 24]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 36                  -1  1    493056  ultralytics.nn.modules.block.C2f             [384, 256, 1]
 37        [30, 33, 36]  1    430867  ultralytics.nn.modules.head.Detect           [1, [64, 128, 256]]
YOLOv8-mid-PSFM summary: 621 layers, 4,844,313 parameters, 4,844,297 gradients, 12.6 GFLOPs

YOLOv8-mid-to-late-PSFM

YOLOv8-mid-to-late-PSFM summary: 663 layers, 5,194,393 parameters, 5,194,377 gradients, 13.7 GFLOPs

                   from  n    params  module                                       arguments
  0                  -1  1         0  ultralytics.nn.AddModules.multimodal.IN      []
  1                  -1  1         0  ultralytics.nn.AddModules.multimodal.Multiin [1]
  2                  -2  1         0  ultralytics.nn.AddModules.multimodal.Multiin [2]
  3                   1  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
  4                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]
  5                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]
  6                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]
  7                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]
  8                  -1  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]
  9                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]
 10                  -1  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]
 11                  -1  1    460288  ultralytics.nn.modules.block.C2f             [256, 256, 1, True]
 12                  -1  1    164608  ultralytics.nn.modules.block.SPPF            [256, 256, 5]
 13                   2  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
 14                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]
 15                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]
 16                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]
 17                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]
 18                  -1  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]
 19                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]
 20                  -1  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]
 21                  -1  1    460288  ultralytics.nn.modules.block.C2f             [256, 256, 1, True]
 22                  -1  1    164608  ultralytics.nn.modules.block.SPPF            [256, 256, 5]
 23                  12  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 24             [-1, 9]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 25                  -1  1    148224  ultralytics.nn.modules.block.C2f             [384, 128, 1]
 26                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 27             [-1, 7]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 28                  -1  1     37248  ultralytics.nn.modules.block.C2f             [192, 64, 1]
 29                  22  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 30            [-1, 19]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 31                  -1  1    148224  ultralytics.nn.modules.block.C2f             [384, 128, 1]
 32                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 33            [-1, 17]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 34                  -1  1     37248  ultralytics.nn.modules.block.C2f             [192, 64, 1]
 35            [12, 22]  1    784002  ultralytics.nn.AddModules.PSFM.PSFM          [256]
 36            [25, 31]  1    205634  ultralytics.nn.AddModules.PSFM.PSFM          [128]
 37            [28, 34]  1     56226  ultralytics.nn.AddModules.PSFM.PSFM          [64]
 38                  37  1     36992  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2]
 39            [-1, 36]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 40                  -1  1    123648  ultralytics.nn.modules.block.C2f             [192, 128, 1]
 41                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]
 42            [-1, 35]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 43                  -1  1    493056  ultralytics.nn.modules.block.C2f             [384, 256, 1]
 44        [37, 40, 43]  1    430867  ultralytics.nn.modules.head.Detect           [1, [64, 128, 256]]
YOLOv8-mid-to-late-PSFM summary: 663 layers, 5,194,393 parameters, 5,194,377 gradients, 13.7 GFLOPs

YOLOv8-late-PSFM

YOLOv8-late-PSFM summary: 701 layers, 5,995,801 parameters, 5,995,785 gradients, 14.7 GFLOPs

                   from  n    params  module                                       arguments
  0                  -1  1         0  ultralytics.nn.AddModules.multimodal.IN      []
  1                  -1  1         0  ultralytics.nn.AddModules.multimodal.Multiin [1]
  2                  -2  1         0  ultralytics.nn.AddModules.multimodal.Multiin [2]
  3                   1  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
  4                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]
  5                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]
  6                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]
  7                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]
  8                  -1  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]
  9                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]
 10                  -1  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]
 11                  -1  1    460288  ultralytics.nn.modules.block.C2f             [256, 256, 1, True]
 12                  -1  1    164608  ultralytics.nn.modules.block.SPPF            [256, 256, 5]
 13                   2  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
 14                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]
 15                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]
 16                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]
 17                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]
 18                  -1  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]
 19                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]
 20                  -1  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]
 21                  -1  1    460288  ultralytics.nn.modules.block.C2f             [256, 256, 1, True]
 22                  -1  1    164608  ultralytics.nn.modules.block.SPPF            [256, 256, 5]
 23                  12  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 24             [-1, 9]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 25                  -1  1    148224  ultralytics.nn.modules.block.C2f             [384, 128, 1]
 26                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 27             [-1, 7]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 28                  -1  1     37248  ultralytics.nn.modules.block.C2f             [192, 64, 1]
 29                  -1  1     36992  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2]
 30            [-1, 25]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 31                  -1  1    123648  ultralytics.nn.modules.block.C2f             [192, 128, 1]
 32                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]
 33            [-1, 12]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 34                  -1  1    493056  ultralytics.nn.modules.block.C2f             [384, 256, 1]
 35                  22  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 36            [-1, 19]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 37                  -1  1    148224  ultralytics.nn.modules.block.C2f             [384, 128, 1]
 38                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 39            [-1, 17]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 40                  -1  1     37248  ultralytics.nn.modules.block.C2f             [192, 64, 1]
 41                  -1  1     36992  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2]
 42            [-1, 37]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 43                  -1  1    123648  ultralytics.nn.modules.block.C2f             [192, 128, 1]
 44                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]
 45            [-1, 22]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 46                  -1  1    493056  ultralytics.nn.modules.block.C2f             [384, 256, 1]
 47            [28, 40]  1     56226  ultralytics.nn.AddModules.PSFM.PSFM          [64]
 48            [31, 43]  1    205634  ultralytics.nn.AddModules.PSFM.PSFM          [128]
 49            [34, 46]  1    784002  ultralytics.nn.AddModules.PSFM.PSFM          [256]
 50        [47, 48, 49]  1    430867  ultralytics.nn.modules.head.Detect           [1, [64, 128, 256]]
YOLOv8-late-PSFM summary: 701 layers, 5,995,801 parameters, 5,995,785 gradients, 14.7 GFLOPs