【RT-DETR Multimodal Fusion Improvement】| CFT: Cross-Modality Fusion Transformer | Using Transformer self-attention to solve long-range dependency and global information integration in cross-modal fusion
1. Introduction
This article documents how the CFT module is used to improve the RT-DETR multimodal object detection model.
CFT (Cross-Modality Fusion Transformer) was designed to address insufficient cross-modal feature fusion in traditional multimodal detection: when data from different modalities must be detected jointly, CNN-based methods are constrained by the locality of convolution and struggle to capture long-range dependencies and globally complementary information across modalities, which degrades detection accuracy under complex illumination, occlusion, and similar conditions.
Here, the CFT module concatenates the multimodal feature sequences and automatically learns intra- and inter-modality interaction weights, integrating global context during feature extraction and strengthening the use of complementary features from different modalities, thereby improving detection robustness and accuracy in multimodal scenarios.
2. Introduction to the CFT Module
Cross-Modality Fusion Transformer for Multispectral Object Detection
2.1 Design Motivation
Traditional CNN-based multispectral object detection methods fuse features mainly through convolution, but the local receptive field of convolution limits their ability to capture long-range dependencies and global context, making it difficult to fully exploit the complementarity between the RGB and thermal modalities.
The self-attention mechanism of a Transformer can be viewed as a fully connected graph that learns global dependencies, so it is naturally suited to interactive fusion of cross-modal signals. Transformers had already shown strong multimodal capability in areas such as image segmentation and video retrieval, but no prior work had applied them to multispectral object detection.
The CFT module therefore leverages Transformer self-attention to solve the long-range dependency and global information integration problems in cross-modal fusion.
2.2 Core Principle
The core of the CFT module is to use Transformer self-attention to perform intra-modality and inter-modality feature fusion simultaneously. The main steps are as follows (see the sketch after this list):
- Feature preprocessing: the RGB and thermal feature maps $F_R$ and $F_T$ are flattened into sequences $I_R$ and $I_T$, concatenated into $I \in \mathbb{R}^{2HW \times C}$, and augmented with a learnable positional embedding to preserve spatial information.
- Self-attention: queries (Q), keys (K), and values (V) are obtained by linear projection, and scaled dot-product attention produces a correlation matrix $\alpha$. This matrix naturally decomposes into four blocks: the intra-modality correlations within RGB and within thermal, plus the cross-modality correlations between them, capturing both local features inside each modality and global interactions across modalities.
- Multi-head attention and residual connections: multi-head attention captures relations in multiple subspaces; the output is processed by a multilayer perceptron (MLP) and added back to each modality branch through residual connections, which mitigates vanishing gradients and strengthens the feature representation.
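The following minimal sketch (illustrative only; the tensor sizes and variable names are assumed, not from the paper or its code) shows how concatenating the two modality token sequences lets a single attention matrix cover both intra- and inter-modality interactions:
import torch

B, H, W, C = 2, 8, 8, 64                      # assumed pooled feature-map size
rgb = torch.randn(B, H * W, C)                # flattened RGB tokens, I_R
ir = torch.randn(B, H * W, C)                 # flattened thermal tokens, I_T
tokens = torch.cat([rgb, ir], dim=1)          # I with shape (B, 2HW, C)

# single head, no learned projections, purely for illustration
att = torch.softmax(tokens @ tokens.transpose(1, 2) / C ** 0.5, dim=-1)  # (B, 2HW, 2HW)

n = H * W
att_rr = att[:, :n, :n]   # RGB -> RGB (intra-modality)
att_tt = att[:, n:, n:]   # thermal -> thermal (intra-modality)
att_rt = att[:, :n, n:]   # RGB -> thermal (inter-modality)
att_tr = att[:, n:, :n]   # thermal -> RGB (inter-modality)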
2.3 Module Structure
The structure of the CFT module is designed around efficient fusion and computational optimization, and consists of the following parts:
- Dual-stream backbone: the original paper builds on YOLOv5, constructing separate RGB and thermal feature-extraction branches and embedding CFT modules at intermediate layers to form a cross-modality fusion backbone (CFB).
- Downsampling and upsampling: to reduce computational cost, the feature maps are downsampled to a fixed low resolution (8×8) by adaptive average pooling before entering the Transformer, then bilinearly interpolated back to the original resolution afterwards, balancing performance and efficiency.
- Transformer blocks: 8 stacked Transformer units, each consisting of layer normalization, multi-head attention, and an MLP, integrate cross-modal information layer by layer.
2.4 Advantages
- Global context fusion: self-attention captures long-range dependencies and integrates globally complementary information from RGB and thermal imagery, e.g. RGB texture detail with thermal contours under low light, improving robustness in complex scenes.
- End-to-end automatic fusion: no hand-crafted fusion rules are needed; the Transformer automatically learns intra- and inter-modality interaction weights, simplifying the cross-modal fusion pipeline.
- Significant accuracy gains: on the FLIR, LLVIP, and VEDAI datasets, CFT clearly outperforms baselines such as a two-stream CNN. For example, on VEDAI it improves mAP75 by 18.2% and mAP by 9.2%, demonstrating its effectiveness for small objects and complex scenes.
- Strong generalization: it can be plugged into one-stage and two-stage detectors such as YOLOv3 and Faster R-CNN, yielding gains in all cases and showing cross-framework applicability.
- Optimized computational efficiency: the downsampling strategy keeps the computation under control, maintaining high accuracy while avoiding the large memory footprint of a vanilla Transformer, making it practical for deployment.
Through Transformer self-attention, the CFT module upgrades cross-modal fusion from local convolution to global context modeling, effectively addressing the difficulty of exploiting modality complementarity in multispectral detection.
Paper: https://arxiv.org/abs/2111.00273
Code: https://github.com/DocF/multispectral-object-detection
3. CFT Implementation Code
The implementation of CFT is as follows:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
class SelfAttention(nn.Module):
"""
Multi-head masked self-attention layer
"""
def __init__(self, d_model, d_k, d_v, h, attn_pdrop=.1, resid_pdrop=.1):
'''
:param d_model: Output dimensionality of the model
:param d_k: Dimensionality of queries and keys
:param d_v: Dimensionality of values
:param h: Number of heads
'''
super(SelfAttention, self).__init__()
        assert d_model % h == 0
self.d_model = d_model
self.d_k = d_model // h
self.d_v = d_model // h
self.h = h
# key, query, value projections for all heads
self.que_proj = nn.Linear(d_model, h * self.d_k) # query projection
self.key_proj = nn.Linear(d_model, h * self.d_k) # key projection
self.val_proj = nn.Linear(d_model, h * self.d_v) # value projection
self.out_proj = nn.Linear(h * self.d_v, d_model) # output projection
# regularization
self.attn_drop = nn.Dropout(attn_pdrop)
self.resid_drop = nn.Dropout(resid_pdrop)
self.init_weights()
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)
def forward(self, x, attention_mask=None, attention_weights=None):
'''
Computes Self-Attention
Args:
x (tensor): input (token) dim:(b_s, nx, c),
b_s means batch size
nx means length, for CNN, equals H*W, i.e. the length of feature maps
c means channel, i.e. the channel of feature maps
attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
Return:
output (tensor): dim:(b_s, nx, c)
'''
b_s, nq = x.shape[:2]
nk = x.shape[1]
q = self.que_proj(x).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
k = self.key_proj(x).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) K^T
v = self.val_proj(x).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v)
        # Scaled dot-product attention: Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk)
# weight and mask
if attention_weights is not None:
att = att * attention_weights
if attention_mask is not None:
att = att.masked_fill(attention_mask, -np.inf)
# get attention matrix
att = torch.softmax(att, -1)
att = self.attn_drop(att)
# output
out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v)
out = self.resid_drop(self.out_proj(out)) # (b_s, nq, d_model)
return out
class myTransformerBlock(nn.Module):
""" Transformer block """
def __init__(self, d_model, d_k, d_v, h, block_exp, attn_pdrop, resid_pdrop):
"""
:param d_model: Output dimensionality of the model
:param d_k: Dimensionality of queries and keys
:param d_v: Dimensionality of values
:param h: Number of heads
:param block_exp: Expansion factor for MLP (feed foreword network)
"""
super().__init__()
self.ln_input = nn.LayerNorm(d_model)
self.ln_output = nn.LayerNorm(d_model)
self.sa = SelfAttention(d_model, d_k, d_v, h, attn_pdrop, resid_pdrop)
self.mlp = nn.Sequential(
nn.Linear(d_model, block_exp * d_model),
            nn.GELU(),
nn.Linear(block_exp * d_model, d_model),
nn.Dropout(resid_pdrop),
)
def forward(self, x):
bs, nx, c = x.size()
x = x + self.sa(self.ln_input(x))
x = x + self.mlp(self.ln_output(x))
return x
class CFT(nn.Module):
""" the full GPT language model, with a context size of block_size """
def __init__(self, d_model, h=8, block_exp=4,
n_layer=8, vert_anchors=8, horz_anchors=8,
embd_pdrop=0.1, attn_pdrop=0.1, resid_pdrop=0.1):
super().__init__()
self.n_embd = d_model
self.vert_anchors = vert_anchors
self.horz_anchors = horz_anchors
d_k = d_model
d_v = d_model
# positional embedding parameter (learnable), rgb_fea + ir_fea
self.pos_emb = nn.Parameter(torch.zeros(1, 2 * vert_anchors * horz_anchors, self.n_embd))
# transformer
self.trans_blocks = nn.Sequential(*[myTransformerBlock(d_model, d_k, d_v, h, block_exp, attn_pdrop, resid_pdrop)
for layer in range(n_layer)])
# decoder head
self.ln_f = nn.LayerNorm(self.n_embd)
# regularization
self.drop = nn.Dropout(embd_pdrop)
# avgpool
self.avgpool = nn.AdaptiveAvgPool2d((self.vert_anchors, self.horz_anchors))
# init weights
self.apply(self._init_weights)
@staticmethod
def _init_weights(module):
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
    def forward(self, x):
        """
        Args:
            x (list[Tensor]): [rgb_fea, ir_fea], each of shape (B, C, H, W)
        """
rgb_fea = x[0] # rgb_fea (tensor): dim:(B, C, H, W)
ir_fea = x[1] # ir_fea (tensor): dim:(B, C, H, W)
assert rgb_fea.shape[0] == ir_fea.shape[0]
bs, c, h, w = rgb_fea.shape
# -------------------------------------------------------------------------
# AvgPooling
# -------------------------------------------------------------------------
# AvgPooling for reduce the dimension due to expensive computation
rgb_fea = self.avgpool(rgb_fea)
ir_fea = self.avgpool(ir_fea)
# -------------------------------------------------------------------------
# Transformer
# -------------------------------------------------------------------------
# pad token embeddings along number of tokens dimension
rgb_fea_flat = rgb_fea.view(bs, c, -1) # flatten the feature
ir_fea_flat = ir_fea.view(bs, c, -1) # flatten the feature
token_embeddings = torch.cat([rgb_fea_flat, ir_fea_flat], dim=2) # concat
token_embeddings = token_embeddings.permute(0, 2, 1).contiguous() # dim:(B, 2*H*W, C)
# transformer
x = self.drop(self.pos_emb + token_embeddings) # sum positional embedding and token dim:(B, 2*H*W, C)
x = self.trans_blocks(x) # dim:(B, 2*H*W, C)
# decoder head
x = self.ln_f(x) # dim:(B, 2*H*W, C)
x = x.view(bs, 2, self.vert_anchors, self.horz_anchors, self.n_embd)
x = x.permute(0, 1, 4, 2, 3) # dim:(B, 2, C, H, W)
        # Split the two modality streams back out by slicing; a learned projection might be a more principled alternative
rgb_fea_out = x[:, 0, :, :, :].contiguous().view(bs, self.n_embd, self.vert_anchors, self.horz_anchors)
ir_fea_out = x[:, 1, :, :, :].contiguous().view(bs, self.n_embd, self.vert_anchors, self.horz_anchors)
# -------------------------------------------------------------------------
# Interpolate (or Upsample)
# -------------------------------------------------------------------------
        rgb_fea_out = F.interpolate(rgb_fea_out, size=(h, w), mode='bilinear')
        ir_fea_out = F.interpolate(ir_fea_out, size=(h, w), mode='bilinear')
return rgb_fea_out, ir_fea_out
class Add(nn.Module):
def __init__(self, arg):
super().__init__()
self.arg = arg
def forward(self, x):
        assert len(x) == 2, "Add expects exactly two tensors to sum"
tensor_a, tensor_b = x[0], x[1]
if tensor_a.shape[2:] != tensor_b.shape[2:]:
target_size = tensor_a.shape[2:] if tensor_a.shape[2] >= tensor_b.shape[2] else tensor_b.shape[2:]
tensor_a = F.interpolate(tensor_a, size=target_size, mode='bilinear', align_corners=False)
tensor_b = F.interpolate(tensor_b, size=target_size, mode='bilinear', align_corners=False)
return torch.add(tensor_a, tensor_b)
class Add2(nn.Module):
def __init__(self, c1, index):
super().__init__()
self.index = index
def forward(self, x):
        assert len(x) == 2, "Add2 expects two inputs: the source tensor and the CFT output pair"
src, trans = x[0], x[1]
trans_part = trans[0] if self.index == 0 else trans[1]
if src.shape[2:] != trans_part.shape[2:]:
trans_part = F.interpolate(
trans_part,
size=src.shape[2:],
mode='bilinear',
align_corners=False
)
return torch.add(src, trans_part)
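# --- Quick sanity check (a minimal sketch, not part of the original repository;
# --- the input sizes below are assumed purely for illustration) ---
if __name__ == '__main__':
    rgb_fea = torch.randn(2, 64, 80, 80)   # (B, C, H, W) visible-light features
    ir_fea = torch.randn(2, 64, 80, 80)    # (B, C, H, W) thermal features
    cft = CFT(d_model=64)                  # d_model must match the channel count
    rgb_out, ir_out = cft([rgb_fea, ir_fea])
    print(rgb_out.shape, ir_out.shape)     # both restored to (2, 64, 80, 80)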
4. Integration Steps
4.1 Modification 1
① Create a new folder named AddModules under the ultralytics/nn/ directory to hold the module code.
② Create a file named CFT.py inside the AddModules folder and paste the code from Section 3 into it.
4.2 Modification 2
Create __init__.py inside the AddModules folder (skip this if it already exists) and import the module in that file:
from .CFT import *
4.3 Modification 3
In the ultralytics/nn/modules/tasks.py file, the module class names need to be added in two places.
First, import the modules.
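A minimal sketch of the import line (assuming the AddModules package created above exports the classes through its __init__.py; adjust the path if your project layout differs):
from ultralytics.nn.AddModules import CFT, Add, Add2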
Second, register the Add, Add2, and CFT modules in the parse_model function:
        elif m is Add:
            c2 = ch[f[0]]
            args = [c2]
        elif m is Add2:
            c2 = ch[f[0]]
            args = [c2, args[1]]
        elif m is CFT:
            c2 = ch[f[0]]
            args = [c2]
5. YAML Model Files
5.1 Model Improvement Version 1 ⭐
Create a model file named rtdetr-resnet18-CFT-p234.yaml for training on your own dataset, and copy the content below into it.
📌 In this model, the P2, P3, and P4 stages of the two modality branches in the backbone are fused via cross-modal fusion.
# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-ResNet18 object detection model with P3-P5 outputs.
# Parameters
ch: 6
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
# [depth, width, max_channels]
l: [1.00, 1.00, 1024]
backbone:
# [from, repeats, module, args]
- [-1, 1, IN, []] # 0
- [-1, 1, Multiin, [1]] # 1
- [-2, 1, Multiin, [2]] # 2
# Visible
- [1, 1, ConvNormLayer, [32, 3, 2, 1, 'relu']] # 3-P1
- [-1, 1, ConvNormLayer, [32, 3, 1, 1, 'relu']] # 4
- [-1, 1, ConvNormLayer, [64, 3, 1, 1, 'relu']] # 5
- [-1, 1, nn.MaxPool2d, [3, 2, 1]] # 6-P2
- [-1, 2, Blocks, [64, BasicBlock, 2, False]] # 7
# infrared
- [2, 1, ConvNormLayer, [32, 3, 2, 1, 'relu']] # 8-P1
- [-1, 1, ConvNormLayer, [32, 3, 1, 1, 'relu']] # 9
- [-1, 1, ConvNormLayer, [64, 3, 1, 1, 'relu']] # 10
- [-1, 1, nn.MaxPool2d, [3, 2, 1]] # 11-P2
- [-1, 2, Blocks, [64, BasicBlock, 2, False]] # 12
# transformer fusion
- [ [ 7,12 ], 1, CFT, [ 64 ] ] # 13-P2/4
- [ [ 7,13 ], 1, Add2, [ 64,0 ] ] # 14-P2/4 stream one:x+trans[0]
- [ [ 12,13 ], 1, Add2, [ 64,1 ] ] # 15-P2/4 stream two:x+trans[1]
# Visible
- [14, 2, Blocks, [128, BasicBlock, 3, False]] # 16-P3
# infrared
- [15, 2, Blocks, [128, BasicBlock, 3, False]] # 17-P3
# transformer fusion
- [ [ 16,17 ], 1, CFT, [ 128 ] ] # 18-P3/8
- [ [ 16,18 ], 1, Add2, [ 128,0 ] ] # 19-P3/8 stream one x+trans[0]
- [ [ 17,18 ], 1, Add2, [ 128,1 ] ] # 20-P3/8 stream two x+trans[1]
# Visible
- [19, 2, Blocks, [256, BasicBlock, 4, False]] # 21-P4
# infrared
- [20, 2, Blocks, [256, BasicBlock, 4, False]] # 22-P4
# transformer fusion
- [ [ 21,22 ], 1, CFT, [ 256 ] ] # 23-P3/8
- [ [ 21,23 ], 1, Add2, [ 256,0 ] ] # 24-P3/8 stream one x+trans[0]
- [ [ 22,23 ], 1, Add2, [ 256,1 ] ] # 25-P3/8 stream two x+trans[1]
- [24, 2, Blocks, [512, BasicBlock, 5, False]] # 26-P5
- [25, 2, Blocks, [512, BasicBlock, 5, False]] # 27-P5
- [[19, 20], 1, Concat, [1]] # 28 cat backbone P3
- [[24, 25], 1, Concat, [1]] # 29 cat backbone P4
- [[26, 27], 1, Concat, [1]] # 30 cat backbone P5
head:
- [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 31 input_proj.2
- [-1, 1, AIFI, [1024, 8]]
- [-1, 1, Conv, [256, 1, 1]] # 33, Y5, lateral_convs.0
- [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 34
- [29, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 35 input_proj.1
- [[-2, -1], 1, Concat, [1]]
- [-1, 3, RepC3, [256, 0.5]] # 37, fpn_blocks.0
- [-1, 1, Conv, [256, 1, 1]] # 38, Y4, lateral_convs.1
- [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 39
- [28, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 40 input_proj.0
- [[-2, -1], 1, Concat, [1]] # 41 cat backbone P4
- [-1, 3, RepC3, [256, 0.5]] # X3 (42), fpn_blocks.1
- [-1, 1, Conv, [256, 3, 2]] # 43, downsample_convs.0
- [[-1, 38], 1, Concat, [1]] # 44 cat Y4
- [-1, 3, RepC3, [256, 0.5]] # F4 (45), pan_blocks.0
- [-1, 1, Conv, [256, 3, 2]] # 46, downsample_convs.1
- [[-1, 33], 1, Concat, [1]] # 47 cat Y5
- [-1, 3, RepC3, [256, 0.5]] # F5 (48), pan_blocks.1
- [[42, 45, 48], 1, RTDETRDecoder, [nc, 256, 300, 4, 8, 3]] # Detect(P3, P4, P5)
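With the yaml in place, training follows the usual Ultralytics workflow (a minimal sketch, assuming the multimodal fork keeps the stock RTDETR API and that your dataset yaml supplies the paired RGB+IR inputs expected by ch: 6; the dataset file name below is a placeholder):
from ultralytics import RTDETR

model = RTDETR('rtdetr-resnet18-CFT-p234.yaml')  # build the modified model from the config
model.train(data='your_multimodal_dataset.yaml', epochs=100, imgsz=640)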
5.2 Model Improvement Version 2 ⭐
Create a model file named rtdetr-resnet18-CFT-p2345.yaml for training on your own dataset, and copy the content below into it.
📌 In this model, the P2, P3, P4, and P5 stages of the two modality branches in the backbone are fused via cross-modal fusion.
# YOLOv12 🚀, AGPL-3.0 license
# YOLOv12 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# CFG file for YOLOv12-turbo
# Parameters
ch: 6
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov12n.yaml' will call yolov12.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.50, 0.25, 1024] # summary: 497 layers, 2,553,904 parameters, 2,553,888 gradients, 6.2 GFLOPs
s: [0.50, 0.50, 1024] # summary: 497 layers, 9,127,424 parameters, 9,127,408 gradients, 19.7 GFLOPs
m: [0.50, 1.00, 512] # summary: 533 layers, 19,670,784 parameters, 19,670,768 gradients, 60.4 GFLOPs
l: [1.00, 1.00, 512] # summary: 895 layers, 26,506,496 parameters, 26,506,480 gradients, 83.3 GFLOPs
x: [1.00, 1.50, 512] # summary: 895 layers, 59,414,176 parameters, 59,414,160 gradients, 185.9 GFLOPs
# YOLO12 backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, IN, []] # 0
- [-1, 1, Multiin, [1]] # 1
- [-2, 1, Multiin, [2]] # 2
# Visible
- [1, 1, Conv, [64, 3, 2]] # 3-P1/2
- [-1, 1, Conv, [128, 3, 2, 1, 2]] # 4-P2/4
- [-1, 2, C3k2, [128, False, 0.25]] # 5
# infrared
- [2, 1, Conv, [64, 3, 2]] # 6-P1/2
- [-1, 1, Conv, [128, 3, 2, 1, 2]] # 7-P2/4
- [-1, 2, C3k2, [128, False, 0.25]] # 8
# transformer fusion
- [ [ 5,8 ], 1, CFT, [ 128 ] ] # 9-P2/4
- [ [ 5,9 ], 1, Add2, [ 128,0 ] ] # 10-P2/4 stream one:x+trans[0]
- [ [ 8,9 ], 1, Add2, [ 128,1 ] ] # 11-P2/4 stream two:x+trans[1]
# Visible
- [10, 1, Conv, [256, 3, 2, 1, 4]] # 12-P3/8
- [-1, 2, C3k2, [256, False, 0.25]] # 13
# infrared
- [11, 1, Conv, [256, 3, 2, 1, 4]] # 14-P3/8
- [-1, 2, C3k2, [256, False, 0.25]] # 15
# transformer fusion
- [ [ 13,15 ], 1, CFT, [ 256 ] ] # 16-P3/8
- [ [ 13,16 ], 1, Add2, [ 256,0 ] ] # 17-P3/8 stream one x+trans[0]
- [ [ 15,16 ], 1, Add2, [ 256,1 ] ] # 18-P3/8 stream two x+trans[1]
# Visible
- [17, 1, Conv, [512, 3, 2]] # 19-P4/16
- [-1, 4, A2C2f, [512, True, 4]] # 20
# infrared
- [18, 1, Conv, [512, 3, 2]] # 21-P4/16
- [-1, 4, A2C2f, [512, True, 4]] # 22
# transformer fusion
- [ [ 20,22 ], 1, CFT, [ 512 ] ] # 23-P3/8
- [ [ 20,23 ], 1, Add2, [ 512,0 ] ] # 24-P3/8 stream one x+trans[0]
- [ [ 22,23 ], 1, Add2, [ 512,1 ] ] # 25-P3/8 stream two x+trans[1]
# Visible
- [24, 1, Conv, [1024, 3, 2]] # 26-P5/32
- [-1, 4, A2C2f, [1024, True, 1]] # 27
# infrared
- [25, 1, Conv, [1024, 3, 2]] # 28-P5/32
- [-1, 4, A2C2f, [1024, True, 1]] # 29
# transformer fusion
- [ [ 27,29 ], 1, CFT, [ 1024 ] ] # 30-P5/32
- [ [ 27,30 ], 1, Add2, [ 1024,0 ] ] # 31-P5/32 stream one x+trans[0]
- [ [ 29,30 ], 1, Add2, [ 1024,1 ] ] # 32-P5/32 stream two x+trans[1]
- [ [ 17,18 ], 1, Add, [ 1 ] ] # 33-P3/8 fusion backbone P3
- [ [ 24,25 ], 1, Add, [ 1 ] ] # 34-P4/16 fusion backbone P4
- [ [ 31,32 ], 1, Add, [ 1 ] ] # 35-P5/32 fusion backbone P5
# YOLO12 head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 34], 1, Concat, [1]] # cat backbone P4
- [-1, 2, A2C2f, [512, False, -1]] # 38
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 33], 1, Concat, [1]] # cat backbone P3
- [-1, 2, A2C2f, [256, False, -1]] # 41
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 38], 1, Concat, [1]] # cat head P4
- [-1, 2, A2C2f, [512, False, -1]] # 44
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 35], 1, Concat, [1]] # cat head P5
- [-1, 2, C3k2, [1024, True]] # 47 (P5/32-large)
- [[38, 44, 47], 1, Detect, [nc]] # Detect(P3, P4, P5)
6. Successful Run Results
Printing the network model shows that the fusion layers have been added to the model, and training can proceed.
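The summaries below can be reproduced by building each model from its yaml (a minimal sketch, assuming the stock Ultralytics API; the per-layer table is printed when the model is constructed with verbose output):
from ultralytics import RTDETR

model = RTDETR('rtdetr-resnet18-CFT-p234.yaml')
model.info()  # prints layer count, parameters, gradients, and GFLOPs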
rtdetr-resnet18-CFT-p234 :
rtdetr-resnet18-CFT-p234 summary: 862 layers, 39,894,612 parameters, 39,894,612 gradients, 95.1 GFLOPs
from n params module arguments
0 -1 1 0 ultralytics.nn.AddModules.multimodal.IN []
1 -1 1 0 ultralytics.nn.AddModules.multimodal.Multiin [1]
2 -2 1 0 ultralytics.nn.AddModules.multimodal.Multiin [2]
3 1 1 960 ultralytics.nn.AddModules.ResNet.ConvNormLayer[3, 32, 3, 2, 1, 'relu']
4 -1 1 9312 ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 32, 3, 1, 1, 'relu']
5 -1 1 18624 ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 64, 3, 1, 1, 'relu']
6 -1 1 0 torch.nn.modules.pooling.MaxPool2d [3, 2, 1]
7 -1 2 152512 ultralytics.nn.AddModules.ResNet.Blocks [64, 64, 2, 'BasicBlock', 2, False]
8 2 1 960 ultralytics.nn.AddModules.ResNet.ConvNormLayer[3, 32, 3, 2, 1, 'relu']
9 -1 1 9312 ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 32, 3, 1, 1, 'relu']
10 -1 1 18624 ultralytics.nn.AddModules.ResNet.ConvNormLayer[32, 64, 3, 1, 1, 'relu']
11 -1 1 0 torch.nn.modules.pooling.MaxPool2d [3, 2, 1]
12 -1 2 152512 ultralytics.nn.AddModules.ResNet.Blocks [64, 64, 2, 'BasicBlock', 2, False]
13 [7, 12] 1 408192 ultralytics.nn.AddModules.CFT.CFT [64]
14 [7, 13] 1 0 ultralytics.nn.AddModules.CFT.Add2 [64, 0]
15 [12, 13] 1 0 ultralytics.nn.AddModules.CFT.Add2 [64, 1]
16 14 2 526208 ultralytics.nn.AddModules.ResNet.Blocks [64, 128, 2, 'BasicBlock', 3, False]
17 15 2 526208 ultralytics.nn.AddModules.ResNet.Blocks [64, 128, 2, 'BasicBlock', 3, False]
18 [16, 17] 1 1602816 ultralytics.nn.AddModules.CFT.CFT [128]
19 [16, 18] 1 0 ultralytics.nn.AddModules.CFT.Add2 [128, 0]
20 [17, 18] 1 0 ultralytics.nn.AddModules.CFT.Add2 [128, 1]
21 19 2 2100992 ultralytics.nn.AddModules.ResNet.Blocks [128, 256, 2, 'BasicBlock', 4, False]
22 20 2 2100992 ultralytics.nn.AddModules.ResNet.Blocks [128, 256, 2, 'BasicBlock', 4, False]
23 [21, 22] 1 6351360 ultralytics.nn.AddModules.CFT.CFT [256]
24 [21, 23] 1 0 ultralytics.nn.AddModules.CFT.Add2 [256, 0]
25 [22, 23] 1 0 ultralytics.nn.AddModules.CFT.Add2 [256, 1]
26 24 2 8396288 ultralytics.nn.AddModules.ResNet.Blocks [256, 512, 2, 'BasicBlock', 5, False]
27 25 2 8396288 ultralytics.nn.AddModules.ResNet.Blocks [256, 512, 2, 'BasicBlock', 5, False]
28 [19, 20] 1 0 ultralytics.nn.modules.conv.Concat [1]
29 [24, 25] 1 0 ultralytics.nn.modules.conv.Concat [1]
30 [26, 27] 1 0 ultralytics.nn.modules.conv.Concat [1]
31 -1 1 262656 ultralytics.nn.modules.conv.Conv [1024, 256, 1, 1, None, 1, 1, False]
32 -1 1 789760 ultralytics.nn.modules.transformer.AIFI [256, 1024, 8]
33 -1 1 66048 ultralytics.nn.modules.conv.Conv [256, 256, 1, 1]
34 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
35 29 1 131584 ultralytics.nn.modules.conv.Conv [512, 256, 1, 1, None, 1, 1, False]
36 [-2, -1] 1 0 ultralytics.nn.modules.conv.Concat [1]
37 -1 3 657920 ultralytics.nn.modules.block.RepC3 [512, 256, 3, 0.5]
38 -1 1 66048 ultralytics.nn.modules.conv.Conv [256, 256, 1, 1]
39 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
40 28 1 66048 ultralytics.nn.modules.conv.Conv [256, 256, 1, 1, None, 1, 1, False]
41 [-2, -1] 1 0 ultralytics.nn.modules.conv.Concat [1]
42 -1 3 657920 ultralytics.nn.modules.block.RepC3 [512, 256, 3, 0.5]
43 -1 1 590336 ultralytics.nn.modules.conv.Conv [256, 256, 3, 2]
44 [-1, 38] 1 0 ultralytics.nn.modules.conv.Concat [1]
45 -1 3 657920 ultralytics.nn.modules.block.RepC3 [512, 256, 3, 0.5]
46 -1 1 590336 ultralytics.nn.modules.conv.Conv [256, 256, 3, 2]
47 [-1, 33] 1 0 ultralytics.nn.modules.conv.Concat [1]
48 -1 3 657920 ultralytics.nn.modules.block.RepC3 [512, 256, 3, 0.5]
49 [42, 45, 48] 1 3927956 ultralytics.nn.modules.head.RTDETRDecoder [9, [256, 256, 256], 256, 300, 4, 8, 3]
rtdetr-resnet18-CFT-p234 summary: 862 layers, 39,894,612 parameters, 39,894,612 gradients, 95.1 GFLOPs
YOLOv12-CFT-p2345 :
YOLOv12-CFT-p2345 summary: 1,221 layers, 12,274,243 parameters, 12,274,227 gradients, 865.8 GFLOPs
from n params module arguments
0 -1 1 0 ultralytics.nn.AddModules.multimodal.IN []
1 -1 1 0 ultralytics.nn.AddModules.multimodal.Multiin [1]
2 -2 1 0 ultralytics.nn.AddModules.multimodal.Multiin [2]
3 1 1 464 ultralytics.nn.modules.conv.Conv [3, 16, 3, 2]
4 -1 1 2368 ultralytics.nn.modules.conv.Conv [16, 32, 3, 2, 1, 2]
5 -1 1 1976 ultralytics.nn.modules.block.C3k2 [32, 32, 1, False, 0.25]
6 2 1 464 ultralytics.nn.modules.conv.Conv [3, 16, 3, 2]
7 -1 1 2368 ultralytics.nn.modules.conv.Conv [16, 32, 3, 2, 1, 2]
8 -1 1 1976 ultralytics.nn.modules.block.C3k2 [32, 32, 1, False, 0.25]
9 [5, 8] 1 105792 ultralytics.nn.AddModules.CFT.CFT [32]
10 [5, 9] 1 0 ultralytics.nn.AddModules.CFT.Add2 [32, 0]
11 [8, 9] 1 0 ultralytics.nn.AddModules.CFT.Add2 [32, 1]
12 10 1 4736 ultralytics.nn.modules.conv.Conv [32, 64, 3, 2, 1, 4]
13 -1 1 7664 ultralytics.nn.modules.block.C3k2 [64, 64, 1, False, 0.25]
14 11 1 4736 ultralytics.nn.modules.conv.Conv [32, 64, 3, 2, 1, 4]
15 -1 1 7664 ultralytics.nn.modules.block.C3k2 [64, 64, 1, False, 0.25]
16 [13, 15] 1 408192 ultralytics.nn.AddModules.CFT.CFT [64]
17 [13, 16] 1 0 ultralytics.nn.AddModules.CFT.Add2 [64, 0]
18 [15, 16] 1 0 ultralytics.nn.AddModules.CFT.Add2 [64, 1]
19 17 1 73984 ultralytics.nn.modules.conv.Conv [64, 128, 3, 2]
20 -1 2 180864 ultralytics.nn.AddModules.A2C2f.A2C2f [128, 128, 2, True, 4]
21 18 1 73984 ultralytics.nn.modules.conv.Conv [64, 128, 3, 2]
22 -1 2 180864 ultralytics.nn.AddModules.A2C2f.A2C2f [128, 128, 2, True, 4]
23 [20, 22] 1 1602816 ultralytics.nn.AddModules.CFT.CFT [128]
24 [20, 23] 1 0 ultralytics.nn.AddModules.CFT.Add2 [128, 0]
25 [22, 23] 1 0 ultralytics.nn.AddModules.CFT.Add2 [128, 1]
26 24 1 295424 ultralytics.nn.modules.conv.Conv [128, 256, 3, 2]
27 -1 2 689408 ultralytics.nn.AddModules.A2C2f.A2C2f [256, 256, 2, True, 1]
28 25 1 295424 ultralytics.nn.modules.conv.Conv [128, 256, 3, 2]
29 -1 2 689408 ultralytics.nn.AddModules.A2C2f.A2C2f [256, 256, 2, True, 1]
30 [27, 29] 1 6351360 ultralytics.nn.AddModules.CFT.CFT [256]
31 [27, 30] 1 0 ultralytics.nn.AddModules.CFT.Add2 [256, 0]
32 [29, 30] 1 0 ultralytics.nn.AddModules.CFT.Add2 [256, 1]
33 [17, 18] 1 0 ultralytics.nn.AddModules.CFT.Add [64]
34 [24, 25] 1 0 ultralytics.nn.AddModules.CFT.Add [128]
35 [31, 32] 1 0 ultralytics.nn.AddModules.CFT.Add [256]
36 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
37 [-1, 34] 1 0 ultralytics.nn.modules.conv.Concat [1]
38 -1 1 86912 ultralytics.nn.AddModules.A2C2f.A2C2f [384, 128, 1, False, -1]
39 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
40 [-1, 33] 1 0 ultralytics.nn.modules.conv.Concat [1]
41 -1 1 21952 ultralytics.nn.AddModules.A2C2f.A2C2f [192, 64, 1, False, -1]
42 -1 1 36992 ultralytics.nn.modules.conv.Conv [64, 64, 3, 2]
43 [-1, 38] 1 0 ultralytics.nn.modules.conv.Concat [1]
44 -1 1 74624 ultralytics.nn.AddModules.A2C2f.A2C2f [192, 128, 1, False, -1]
45 -1 1 147712 ultralytics.nn.modules.conv.Conv [128, 128, 3, 2]
46 [-1, 35] 1 0 ultralytics.nn.modules.conv.Concat [1]
47 -1 1 378880 ultralytics.nn.modules.block.C3k2 [384, 256, 1, True]
48 [38, 44, 47] 1 545235 ultralytics.nn.modules.head.Detect [1, [128, 128, 256]]
YOLOv12-CFT-p2345 summary: 1,221 layers, 12,274,243 parameters, 12,274,227 gradients, 865.8 GFLOPs