学习资源站

【YOLOv12多模态融合改进】_CFT-跨模态融合Transformer_利用Transformer的自注意力机制,解决跨模态融合中的长距离依赖和全局信息整合问题-

【YOLOv12多模态融合改进】| CFT:跨模态融合Transformer | 利用Transformer的自注意力机制,解决跨模态融合中的长距离依赖和全局信息整合问题

一、本文介绍

本文记录的是 利用 CFT 模块改进 YOLOv12 的多模态目标检测网络模型

CFT(Cross-Modality Fusion Transformer) 的设计出发点在于 解决传统多模态检测中跨模态特征融合不充分的问题 ,即当 不同模态 数据需协同检测时,基于CNN的方法因局部卷积的局限性, 难以捕捉长距离依赖和全局模态间的互补信息 ,导致复杂光照、遮挡等场景下检测精度不足。

本文利用 CFT 模块,将多模态特征序列拼接后自动学习模态内与模态间的交互权重, 在特征提取阶段整合全局上下文信息,增强对不同模态互补特征的利用能力 ,从而提升模型在多模态场景下的检测鲁棒性与准确性。



二、CFT模块介绍

Cross-Modality Fusion Transformer for Multispectral Object Detection

2.1 设计出发点

传统基于CNN的多光谱目标检测方法主要通过卷积操作进行特征融合,但卷积的局部感受野限制了其捕捉长距离依赖和全局上下文信息的能力,难以充分挖掘RGB与热成像(Thermal)模态间的互补性。

Transformer的自注意力机制可视为全连接图,能学习全局依赖关系,自然适用于跨模态信号的交互融合。现有研究中,Transformer在图像分割、视频检索等领域已展现多模态处理优势,但尚未有工作将其应用于多光谱目标检测。

因此,CFT模块旨在利用Transformer的自注意力机制, 解决跨模态融合中的长距离依赖和全局信息整合问题

2.2 核心原理

CFT模块的核心是通过Transformer的自注意力机制,同时实现 模态内(Intra-modality) 模态间(Inter-modality) 的特征融合。具体原理如下:

  1. 特征预处理 :将RGB和热成像的特征图 F R F_R F R F T F_T F T 展平为序列 I R I_R I R I T I_T I T ,拼接为 I ∈ R 2 H W × C I \in \mathbb{R}^{2HW \times C} I R 2 H W × C ,并添加可学习的位置嵌入以保留空间信息。
  2. 自注意力机制 :通过查询(Q)、键(K)、值(V)的投影和缩放点积计算注意力权重,形成相关矩阵 α \alpha α 。该矩阵自然分解为四个块:RGB模态内、热成像模态内的相关性,以及两者之间的跨模态相关性,从而捕捉不同模态内部的局部特征和跨模态的全局交互。
  3. 多头注意力与残差连接 :采用多头注意力机制捕捉多子空间关系,输出经多层感知机(MLP)处理后,通过残差连接回原模态分支,避免梯度消失并增强特征表示。

2.3 模块结构

CFT模块的结构设计围绕高效融合和计算优化展开,具体包括以下部分:

  1. 双流 backbone :原文以YOLOv5为基础,构建RGB和热成像双分支特征提取网络,在中间层嵌入CFT模块,形成跨模态融合骨干网络(CFB)。
  2. 下采样与上采样 :为降低计算复杂度,在进入Transformer前通过全局平均池化将特征图下采样至固定低分辨率(8×8),处理后通过双线性插值上采样回原分辨率,平衡性能与效率。
  3. Transformer块 :包含8个重复的Transformer单元,每个单元由层归一化、多头注意力和MLP组成,逐层整合跨模态信息。

在这里插入图片描述

2.4 优势

  1. 全局上下文融合 :通过自注意力机制捕捉长距离依赖,整合RGB与热成像的全局互补信息,例如RGB的纹理细节与热成像的低光照轮廓,提升复杂场景下的检测鲁棒性。
  2. 端到端自动融合 :无需手动设计复杂的融合规则,Transformer自动学习模态内和模态间的交互权重,简化跨模态融合流程。
  3. 性能提升显著 :在FLIR、LLVIP、VEDAI等数据集上,CFT相比基线方法(如双流CNN)显著提升检测精度。例如,在VEDAI数据集上,mAP75提升18.2%,mAP提升9.2%,证明其对小目标和复杂场景的有效性。
  4. 泛化能力强 :可灵活集成到YOLOv3、Faster R-CNN等单阶段/两阶段检测器,均实现性能增益,表明其跨框架的通用性。
  5. 计算效率优化 :通过下采样策略控制计算量,在保持高精度的同时,避免传统Transformer的高内存占用问题,适合实际部署。

CFT模块通过Transformer的自注意力机制,将跨模态融合从局部卷积升级为全局上下文建模,有效解决了多光谱检测中模态互补性挖掘的难题。

论文: https://arxiv.org/abs/2111.00273
源码: https://github.com/DocF/multispectral-object-detection

三、CFT的实现代码

CFT 的实现代码如下:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

class SelfAttention(nn.Module):
    """
     Multi-head masked self-attention layer
    """

    def __init__(self, d_model, d_k, d_v, h, attn_pdrop=.1, resid_pdrop=.1):
        '''
        :param d_model: Output dimensionality of the model
        :param d_k: Dimensionality of queries and keys
        :param d_v: Dimensionality of values
        :param h: Number of heads
        '''
        super(SelfAttention, self).__init__()
        assert d_k % h == 0
        self.d_model = d_model
        self.d_k = d_model // h
        self.d_v = d_model // h
        self.h = h

        # key, query, value projections for all heads
        self.que_proj = nn.Linear(d_model, h * self.d_k)  # query projection
        self.key_proj = nn.Linear(d_model, h * self.d_k)  # key projection
        self.val_proj = nn.Linear(d_model, h * self.d_v)  # value projection
        self.out_proj = nn.Linear(h * self.d_v, d_model)  # output projection

        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x, attention_mask=None, attention_weights=None):
        '''
        Computes Self-Attention
        Args:
            x (tensor): input (token) dim:(b_s, nx, c),
                b_s means batch size
                nx means length, for CNN, equals H*W, i.e. the length of feature maps
                c means channel, i.e. the channel of feature maps
            attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
            attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
        Return:
            output (tensor): dim:(b_s, nx, c)
        '''

        b_s, nq = x.shape[:2]
        nk = x.shape[1]
        q = self.que_proj(x).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = self.key_proj(x).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk) K^T
        v = self.val_proj(x).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        # Self-Attention
        #  :math:`(\text(Attention(Q,K,V) = Softmax((Q*K^T)/\sqrt(d_k))`
        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)

        # weight and mask
        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            att = att.masked_fill(attention_mask, -np.inf)

        # get attention matrix
        att = torch.softmax(att, -1)
        att = self.attn_drop(att)

        # output
        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.resid_drop(self.out_proj(out))  # (b_s, nq, d_model)

        return out

class myTransformerBlock(nn.Module):
    """ Transformer block """

    def __init__(self, d_model, d_k, d_v, h, block_exp, attn_pdrop, resid_pdrop):
        """
        :param d_model: Output dimensionality of the model
        :param d_k: Dimensionality of queries and keys
        :param d_v: Dimensionality of values
        :param h: Number of heads
        :param block_exp: Expansion factor for MLP (feed foreword network)

        """
        super().__init__()
        self.ln_input = nn.LayerNorm(d_model)
        self.ln_output = nn.LayerNorm(d_model)
        self.sa = SelfAttention(d_model, d_k, d_v, h, attn_pdrop, resid_pdrop)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, block_exp * d_model),
            # nn.SiLU(),  # changed from GELU
            nn.GELU(),  # changed from GELU
            nn.Linear(block_exp * d_model, d_model),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x):
        bs, nx, c = x.size()

        x = x + self.sa(self.ln_input(x))
        x = x + self.mlp(self.ln_output(x))

        return x

class CFT(nn.Module):
    """  the full GPT language model, with a context size of block_size """

    def __init__(self, d_model, h=8, block_exp=4,
                 n_layer=8, vert_anchors=8, horz_anchors=8,
                 embd_pdrop=0.1, attn_pdrop=0.1, resid_pdrop=0.1):
        super().__init__()

        self.n_embd = d_model
        self.vert_anchors = vert_anchors
        self.horz_anchors = horz_anchors

        d_k = d_model
        d_v = d_model

        # positional embedding parameter (learnable), rgb_fea + ir_fea
        self.pos_emb = nn.Parameter(torch.zeros(1, 2 * vert_anchors * horz_anchors, self.n_embd))

        # transformer
        self.trans_blocks = nn.Sequential(*[myTransformerBlock(d_model, d_k, d_v, h, block_exp, attn_pdrop, resid_pdrop)
                                            for layer in range(n_layer)])

        # decoder head
        self.ln_f = nn.LayerNorm(self.n_embd)

        # regularization
        self.drop = nn.Dropout(embd_pdrop)

        # avgpool
        self.avgpool = nn.AdaptiveAvgPool2d((self.vert_anchors, self.horz_anchors))

        # init weights
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
    def forward(self, x):
        """
        Args:
            x (tuple?)

        """
        rgb_fea = x[0]  # rgb_fea (tensor): dim:(B, C, H, W)
        ir_fea = x[1]   # ir_fea (tensor): dim:(B, C, H, W)
        assert rgb_fea.shape[0] == ir_fea.shape[0]
        bs, c, h, w = rgb_fea.shape

        # -------------------------------------------------------------------------
        # AvgPooling
        # -------------------------------------------------------------------------
        # AvgPooling for reduce the dimension due to expensive computation
        rgb_fea = self.avgpool(rgb_fea)
        ir_fea = self.avgpool(ir_fea)

        # -------------------------------------------------------------------------
        # Transformer
        # -------------------------------------------------------------------------
        # pad token embeddings along number of tokens dimension
        rgb_fea_flat = rgb_fea.view(bs, c, -1)  # flatten the feature
        ir_fea_flat = ir_fea.view(bs, c, -1)  # flatten the feature
        token_embeddings = torch.cat([rgb_fea_flat, ir_fea_flat], dim=2)  # concat
        token_embeddings = token_embeddings.permute(0, 2, 1).contiguous()  # dim:(B, 2*H*W, C)

        # transformer
        x = self.drop(self.pos_emb + token_embeddings)  # sum positional embedding and token    dim:(B, 2*H*W, C)
        x = self.trans_blocks(x)  # dim:(B, 2*H*W, C)

        # decoder head
        x = self.ln_f(x)  # dim:(B, 2*H*W, C)
        x = x.view(bs, 2, self.vert_anchors, self.horz_anchors, self.n_embd)
        x = x.permute(0, 1, 4, 2, 3)  # dim:(B, 2, C, H, W)

        # 这样截取的方式, 是否采用映射的方式更加合理
        rgb_fea_out = x[:, 0, :, :, :].contiguous().view(bs, self.n_embd, self.vert_anchors, self.horz_anchors)
        ir_fea_out = x[:, 1, :, :, :].contiguous().view(bs, self.n_embd, self.vert_anchors, self.horz_anchors)

        # -------------------------------------------------------------------------
        # Interpolate (or Upsample)
        # -------------------------------------------------------------------------
        rgb_fea_out = F.interpolate(rgb_fea_out, size=([h, w]), mode='bilinear')
        ir_fea_out = F.interpolate(ir_fea_out, size=([h, w]), mode='bilinear')

        return rgb_fea_out, ir_fea_out

class Add(nn.Module):
        def __init__(self, arg):
            super().__init__()
            self.arg = arg

        def forward(self, x):
            assert len(x) == 2, "输入必须包含两个待相加的张量"
            tensor_a, tensor_b = x[0], x[1]

            if tensor_a.shape[2:] != tensor_b.shape[2:]:
                target_size = tensor_a.shape[2:] if tensor_a.shape[2] >= tensor_b.shape[2] else tensor_b.shape[2:]
                tensor_a = F.interpolate(tensor_a, size=target_size, mode='bilinear', align_corners=False)
                tensor_b = F.interpolate(tensor_b, size=target_size, mode='bilinear', align_corners=False)

            return torch.add(tensor_a, tensor_b)

class Add2(nn.Module):
        def __init__(self, c1, index):
            super().__init__()
            self.index = index

        def forward(self, x):
            assert len(x) == 2, "输入必须包含两个张量"
            src, trans = x[0], x[1]

            trans_part = trans[0] if self.index == 0 else trans[1]

            if src.shape[2:] != trans_part.shape[2:]:
                trans_part = F.interpolate(
                    trans_part,
                    size=src.shape[2:],
                    mode='bilinear',
                    align_corners=False
                )

            return torch.add(src, trans_part)

四、融合步骤

5.1 修改一

① 在 ultralytics/nn/ 目录下新建 AddModules 文件夹用于存放模块代码

② 在 AddModules 文件夹下新建 CFT.py ,将 第三节 中的代码粘贴到此处

在这里插入图片描述

5.2 修改二

AddModules 文件夹下新建 __init__.py (已有则不用新建),在文件内导入模块: from .CFT import *

在这里插入图片描述

5.3 修改三

ultralytics/nn/modules/tasks.py 文件中,需要在两处位置添加各模块类名称。

首先:导入模块

在这里插入图片描述

其次:在 parse_model函数 中注册 Add Add2 CFT 模块

在这里插入图片描述

        elif m is Add:
            c2 = ch[f[0]]
            args = [c2]
        elif m is Add2:
            c2 = ch[f[0]]
            args = [c2, args[1]]
        elif m is CFT:
            c2 = ch[f[0]]
            args = [c2]

在这里插入图片描述


五、yaml模型文件

5.1 模型改进版本1 ⭐

此处以 ultralytics/cfg/models/v12/yolov12.yaml 为例,在同目录下创建一个用于自己数据集训练的模型文件 yolov12-CFT-p234.yaml

将下方内容复制到 yolov12-CFT-p234.yaml 文件下。

📌 此模型的修方法是将骨干网络中的,不同模态之间的P2, P3, P4进行跨模态融合。

需要注意的是在骨干中,我将每个conv后面的C3k2或者A2C2f都设置了相同的通道数,也可调回原本的。

# YOLOv12 🚀, AGPL-3.0 license
# YOLOv12 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# CFG file for YOLOv12-turbo

# Parameters
ch: 6
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov12n.yaml' will call yolov12.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 497 layers, 2,553,904 parameters, 2,553,888 gradients, 6.2 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 497 layers, 9,127,424 parameters, 9,127,408 gradients, 19.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 533 layers, 19,670,784 parameters, 19,670,768 gradients, 60.4 GFLOPs
  l: [1.00, 1.00, 512] # summary: 895 layers, 26,506,496 parameters, 26,506,480 gradients, 83.3 GFLOPs
  x: [1.00, 1.50, 512] # summary: 895 layers, 59,414,176 parameters, 59,414,160 gradients, 185.9 GFLOPs

# YOLO12 backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, IN, []]  # 0
  - [-1, 1, Multiin, [1]]  # 1
  - [-2, 1, Multiin, [2]]  # 2

    # Visible
  - [1, 1, Conv,  [64, 3, 2]] # 3-P1/2
  - [-1, 1, Conv,  [128, 3, 2, 1, 2]] # 4-P2/4
  - [-1, 2, C3k2,  [128, False, 0.25]] # 5
    # infrared
  - [2, 1, Conv,  [64, 3, 2]] # 6-P1/2
  - [-1, 1, Conv,  [128, 3, 2, 1, 2]] # 7-P2/4
  - [-1, 2, C3k2,  [128, False, 0.25]] # 8

  # transformer fusion
  - [ [ 5,8 ], 1, CFT, [ 128 ] ] # 9-P2/4
  - [ [ 5,9 ], 1, Add2, [ 128,0 ] ]  # 10-P2/4 stream one:x+trans[0]
  - [ [ 8,9 ], 1, Add2, [ 128,1 ] ]  # 11-P2/4 stream two:x+trans[1]

    # Visible
  - [10, 1, Conv,  [256, 3, 2, 1, 4]] # 12-P3/8
  - [-1, 2, C3k2,  [256, False, 0.25]] # 13
    # infrared
  - [11, 1, Conv,  [256, 3, 2, 1, 4]] # 14-P3/8
  - [-1, 2, C3k2,  [256, False, 0.25]] # 15

  # transformer fusion
  - [ [ 13,15 ], 1, CFT, [ 256 ] ]   # 16-P3/8
  - [ [ 13,16 ], 1, Add2, [ 256,0 ] ]    # 17-P3/8 stream one x+trans[0]
  - [ [ 15,16 ], 1, Add2, [ 256,1 ] ]    # 18-P3/8 stream two x+trans[1]

    # Visible
  - [17, 1, Conv,  [512, 3, 2]] # 19-P4/16
  - [-1, 4, A2C2f, [512, True, 4]] # 20
    # infrared
  - [18, 1, Conv,  [512, 3, 2]] # 21-P4/16
  - [-1, 4, A2C2f, [512, True, 4]] # 22

  # transformer fusion
  - [ [ 20,22 ], 1, CFT, [ 512 ] ]   # 23-P3/8
  - [ [ 20,23 ], 1, Add2, [ 512,0 ] ]    # 24-P3/8 stream one x+trans[0]
  - [ [ 22,23 ], 1, Add2, [ 512,1 ] ]    # 25-P3/8 stream two x+trans[1]

    # Visible
  - [24, 1, Conv,  [1024, 3, 2]] # 26-P5/32
  - [-1, 4, A2C2f, [1024, True, 1]] # 27
    # infrared
  - [25, 1, Conv,  [1024, 3, 2]] # 28-P5/32
  - [-1, 4, A2C2f, [1024, True, 1]] # 29

  - [ [ 17,18 ], 1, Add, [ 1 ] ]   # 30-P3/8 fusion backbone P3
  - [ [ 24,25 ], 1, Add, [ 1 ] ]   # 31-P4/16 fusion backbone P4
  - [ [ 27,29 ], 1, Add, [ 1 ] ]   # 32-P5/32 fusion backbone P5

# YOLO12 head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 31], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, A2C2f, [512, False, -1]] # 35

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 30], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, A2C2f, [256, False, -1]] # 38

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 35], 1, Concat, [1]] # cat head P4
  - [-1, 2, A2C2f, [512, False, -1]] # 41

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 32], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 44 (P5/32-large)

  - [[35, 38, 44], 1, Detect, [nc]] # Detect(P3, P4, P5)

5.2 模型改进版本2⭐

此处以 ultralytics/cfg/models/v12/yolov12.yaml 为例,在同目录下创建一个用于自己数据集训练的模型文件 yolov12-CFT-p2345.yaml

将下方内容复制到 yolov12-CFT-p2345.yaml 文件下。

📌 此模型的修方法是将骨干网络中的,不同模态之间的P2, P3, P4,P5进行跨模态融合。

需要注意的是在骨干中,我将每个conv后面的C3k2或者A2C2f都设置了相同的通道数,也可调回原本的。

# YOLOv12 🚀, AGPL-3.0 license
# YOLOv12 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# CFG file for YOLOv12-turbo

# Parameters
ch: 6
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov12n.yaml' will call yolov12.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 497 layers, 2,553,904 parameters, 2,553,888 gradients, 6.2 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 497 layers, 9,127,424 parameters, 9,127,408 gradients, 19.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 533 layers, 19,670,784 parameters, 19,670,768 gradients, 60.4 GFLOPs
  l: [1.00, 1.00, 512] # summary: 895 layers, 26,506,496 parameters, 26,506,480 gradients, 83.3 GFLOPs
  x: [1.00, 1.50, 512] # summary: 895 layers, 59,414,176 parameters, 59,414,160 gradients, 185.9 GFLOPs

# YOLO12 backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, IN, []]  # 0
  - [-1, 1, Multiin, [1]]  # 1
  - [-2, 1, Multiin, [2]]  # 2

    # Visible
  - [1, 1, Conv,  [64, 3, 2]] # 3-P1/2
  - [-1, 1, Conv,  [128, 3, 2, 1, 2]] # 4-P2/4
  - [-1, 2, C3k2,  [128, False, 0.25]] # 5
    # infrared
  - [2, 1, Conv,  [64, 3, 2]] # 6-P1/2
  - [-1, 1, Conv,  [128, 3, 2, 1, 2]] # 7-P2/4
  - [-1, 2, C3k2,  [128, False, 0.25]] # 8

  # transformer fusion
  - [ [ 5,8 ], 1, CFT, [ 128 ] ] # 9-P2/4
  - [ [ 5,9 ], 1, Add2, [ 128,0 ] ]  # 10-P2/4 stream one:x+trans[0]
  - [ [ 8,9 ], 1, Add2, [ 128,1 ] ]  # 11-P2/4 stream two:x+trans[1]

    # Visible
  - [10, 1, Conv,  [256, 3, 2, 1, 4]] # 12-P3/8
  - [-1, 2, C3k2,  [256, False, 0.25]] # 13
    # infrared
  - [11, 1, Conv,  [256, 3, 2, 1, 4]] # 14-P3/8
  - [-1, 2, C3k2,  [256, False, 0.25]] # 15

  # transformer fusion
  - [ [ 13,15 ], 1, CFT, [ 256 ] ]   # 16-P3/8
  - [ [ 13,16 ], 1, Add2, [ 256,0 ] ]    # 17-P3/8 stream one x+trans[0]
  - [ [ 15,16 ], 1, Add2, [ 256,1 ] ]    # 18-P3/8 stream two x+trans[1]

    # Visible
  - [17, 1, Conv,  [512, 3, 2]] # 19-P4/16
  - [-1, 4, A2C2f, [512, True, 4]] # 20
    # infrared
  - [18, 1, Conv,  [512, 3, 2]] # 21-P4/16
  - [-1, 4, A2C2f, [512, True, 4]] # 22

  # transformer fusion
  - [ [ 20,22 ], 1, CFT, [ 512 ] ]   # 23-P3/8
  - [ [ 20,23 ], 1, Add2, [ 512,0 ] ]    # 24-P3/8 stream one x+trans[0]
  - [ [ 22,23 ], 1, Add2, [ 512,1 ] ]    # 25-P3/8 stream two x+trans[1]

    # Visible
  - [24, 1, Conv,  [1024, 3, 2]] # 26-P5/32
  - [-1, 4, A2C2f, [1024, True, 1]] # 27
    # infrared
  - [25, 1, Conv,  [1024, 3, 2]] # 28-P5/32
  - [-1, 4, A2C2f, [1024, True, 1]] # 29

  # transformer fusion
  - [ [ 27,29 ], 1, CFT, [ 1024 ] ]    # 30-P5/32
  - [ [ 27,30 ], 1, Add2, [ 1024,0 ] ]    # 31-P5/32 stream one x+trans[0]
  - [ [ 29,30 ], 1, Add2, [ 1024,1 ] ]    # 32-P5/32 stream two x+trans[1]

  - [ [ 17,18 ], 1, Add, [ 1 ] ]   # 33-P3/8 fusion backbone P3
  - [ [ 24,25 ], 1, Add, [ 1 ] ]   # 34-P4/16 fusion backbone P4
  - [ [ 31,32 ], 1, Add, [ 1 ] ]   # 35-P5/32 fusion backbone P5

# YOLO12 head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 34], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, A2C2f, [512, False, -1]] # 38

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 33], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, A2C2f, [256, False, -1]] # 41

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 38], 1, Concat, [1]] # cat head P4
  - [-1, 2, A2C2f, [512, False, -1]] # 44

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 35], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 47 (P5/32-large)

  - [[38, 44, 47], 1, Detect, [nc]] # Detect(P3, P4, P5)


六、成功运行结果

打印网络模型可以看到不同的融合层已经加入到模型中,并可以进行训练了。

YOLOv12-CFT-p234

YOLOv12-CFT-p234 summary: 1,094 layers, 5,877,123 parameters, 5,877,107 gradients, 220.6 GFLOPs

                   from  n    params  module                                       arguments
  0                  -1  1         0  ultralytics.nn.AddModules.multimodal.IN      []
  1                  -1  1         0  ultralytics.nn.AddModules.multimodal.Multiin [1]
  2                  -2  1         0  ultralytics.nn.AddModules.multimodal.Multiin [2]
  3                   1  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
  4                  -1  1      2368  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2, 1, 2]
  5                  -1  1      1976  ultralytics.nn.modules.block.C3k2            [32, 32, 1, False, 0.25]
  6                   2  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
  7                  -1  1      2368  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2, 1, 2]
  8                  -1  1      1976  ultralytics.nn.modules.block.C3k2            [32, 32, 1, False, 0.25]
  9              [5, 8]  1    105792  ultralytics.nn.AddModules.CFT.CFT            [32]
 10              [5, 9]  1         0  ultralytics.nn.AddModules.CFT.Add2           [32, 0]
 11              [8, 9]  1         0  ultralytics.nn.AddModules.CFT.Add2           [32, 1]
 12                  10  1      4736  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2, 1, 4]
 13                  -1  1      7664  ultralytics.nn.modules.block.C3k2            [64, 64, 1, False, 0.25]
 14                  11  1      4736  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2, 1, 4]
 15                  -1  1      7664  ultralytics.nn.modules.block.C3k2            [64, 64, 1, False, 0.25]
 16            [13, 15]  1    408192  ultralytics.nn.AddModules.CFT.CFT            [64]
 17            [13, 16]  1         0  ultralytics.nn.AddModules.CFT.Add2           [64, 0]
 18            [15, 16]  1         0  ultralytics.nn.AddModules.CFT.Add2           [64, 1]
 19                  17  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]
 20                  -1  2    180864  ultralytics.nn.AddModules.A2C2f.A2C2f        [128, 128, 2, True, 4]
 21                  18  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]
 22                  -1  2    180864  ultralytics.nn.AddModules.A2C2f.A2C2f        [128, 128, 2, True, 4]
 23            [20, 22]  1   1602816  ultralytics.nn.AddModules.CFT.CFT            [128]
 24            [20, 23]  1         0  ultralytics.nn.AddModules.CFT.Add2           [128, 0]
 25            [22, 23]  1         0  ultralytics.nn.AddModules.CFT.Add2           [128, 1]
 26                  24  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]
 27                  -1  2    689408  ultralytics.nn.AddModules.A2C2f.A2C2f        [256, 256, 2, True, 1]
 28                  25  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]
 29                  -1  2    689408  ultralytics.nn.AddModules.A2C2f.A2C2f        [256, 256, 2, True, 1]
 30            [17, 18]  1         0  ultralytics.nn.AddModules.CFT.Add            [64]
 31            [24, 25]  1         0  ultralytics.nn.AddModules.CFT.Add            [128]
 32            [27, 29]  1         0  ultralytics.nn.AddModules.CFT.Add            [256]
 33                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 34            [-1, 31]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 35                  -1  1     86912  ultralytics.nn.AddModules.A2C2f.A2C2f        [384, 128, 1, False, -1]
 36                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 37            [-1, 30]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 38                  -1  1     21952  ultralytics.nn.AddModules.A2C2f.A2C2f        [192, 64, 1, False, -1]
 39                  -1  1     36992  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2]
 40            [-1, 35]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 41                  -1  1     74624  ultralytics.nn.AddModules.A2C2f.A2C2f        [192, 128, 1, False, -1]
 42                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]
 43            [-1, 32]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 44                  -1  1    378880  ultralytics.nn.modules.block.C3k2            [384, 256, 1, True]
 45        [35, 38, 44]  1    499475  ultralytics.nn.modules.head.Detect           [1, [128, 64, 256]]
YOLOv12-CFT-p234 summary: 1,094 layers, 5,877,123 parameters, 5,877,107 gradients, 220.6 GFLOPs

YOLOv12-CFT-p2345

YOLOv12-CFT-p2345 summary: 1,221 layers, 12,274,243 parameters, 12,274,227 gradients, 865.8 GFLOPs

                   from  n    params  module                                       arguments
  0                  -1  1         0  ultralytics.nn.AddModules.multimodal.IN      []
  1                  -1  1         0  ultralytics.nn.AddModules.multimodal.Multiin [1]
  2                  -2  1         0  ultralytics.nn.AddModules.multimodal.Multiin [2]
  3                   1  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
  4                  -1  1      2368  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2, 1, 2]
  5                  -1  1      1976  ultralytics.nn.modules.block.C3k2            [32, 32, 1, False, 0.25]
  6                   2  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
  7                  -1  1      2368  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2, 1, 2]
  8                  -1  1      1976  ultralytics.nn.modules.block.C3k2            [32, 32, 1, False, 0.25]
  9              [5, 8]  1    105792  ultralytics.nn.AddModules.CFT.CFT            [32]
 10              [5, 9]  1         0  ultralytics.nn.AddModules.CFT.Add2           [32, 0]
 11              [8, 9]  1         0  ultralytics.nn.AddModules.CFT.Add2           [32, 1]
 12                  10  1      4736  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2, 1, 4]
 13                  -1  1      7664  ultralytics.nn.modules.block.C3k2            [64, 64, 1, False, 0.25]
 14                  11  1      4736  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2, 1, 4]
 15                  -1  1      7664  ultralytics.nn.modules.block.C3k2            [64, 64, 1, False, 0.25]
 16            [13, 15]  1    408192  ultralytics.nn.AddModules.CFT.CFT            [64]
 17            [13, 16]  1         0  ultralytics.nn.AddModules.CFT.Add2           [64, 0]
 18            [15, 16]  1         0  ultralytics.nn.AddModules.CFT.Add2           [64, 1]
 19                  17  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]
 20                  -1  2    180864  ultralytics.nn.AddModules.A2C2f.A2C2f        [128, 128, 2, True, 4]
 21                  18  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]
 22                  -1  2    180864  ultralytics.nn.AddModules.A2C2f.A2C2f        [128, 128, 2, True, 4]
 23            [20, 22]  1   1602816  ultralytics.nn.AddModules.CFT.CFT            [128]
 24            [20, 23]  1         0  ultralytics.nn.AddModules.CFT.Add2           [128, 0]
 25            [22, 23]  1         0  ultralytics.nn.AddModules.CFT.Add2           [128, 1]
 26                  24  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]
 27                  -1  2    689408  ultralytics.nn.AddModules.A2C2f.A2C2f        [256, 256, 2, True, 1]
 28                  25  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]
 29                  -1  2    689408  ultralytics.nn.AddModules.A2C2f.A2C2f        [256, 256, 2, True, 1]
 30            [27, 29]  1   6351360  ultralytics.nn.AddModules.CFT.CFT            [256]
 31            [27, 30]  1         0  ultralytics.nn.AddModules.CFT.Add2           [256, 0]
 32            [29, 30]  1         0  ultralytics.nn.AddModules.CFT.Add2           [256, 1]
 33            [17, 18]  1         0  ultralytics.nn.AddModules.CFT.Add            [64]
 34            [24, 25]  1         0  ultralytics.nn.AddModules.CFT.Add            [128]
 35            [31, 32]  1         0  ultralytics.nn.AddModules.CFT.Add            [256]
 36                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 37            [-1, 34]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 38                  -1  1     86912  ultralytics.nn.AddModules.A2C2f.A2C2f        [384, 128, 1, False, -1]
 39                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 40            [-1, 33]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 41                  -1  1     21952  ultralytics.nn.AddModules.A2C2f.A2C2f        [192, 64, 1, False, -1]
 42                  -1  1     36992  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2]
 43            [-1, 38]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 44                  -1  1     74624  ultralytics.nn.AddModules.A2C2f.A2C2f        [192, 128, 1, False, -1]
 45                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]
 46            [-1, 35]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 47                  -1  1    378880  ultralytics.nn.modules.block.C3k2            [384, 256, 1, True]
 48        [38, 44, 47]  1    545235  ultralytics.nn.modules.head.Detect           [1, [128, 128, 256]]
YOLOv12-CFT-p2345 summary: 1,221 layers, 12,274,243 parameters, 12,274,227 gradients, 865.8 GFLOPs