【YOLOv8 Multimodal Fusion Improvement】| Using the Deformable Attention Transformer (deformable attention) for a secondary improvement of CGA Fusion, dynamically attending to target regions across modalities
1. Introduction
This article records how the DAT module is used to improve the multimodal fusion part of YOLOv8, focusing on how existing modules can be reused for a secondary improvement of the fusion stage.
DAT stands for Deformable Attention Transformer. Through a deformable attention mechanism with data-dependent attention patterns, it overcomes the main problems of common attention methods: high memory and computation cost, interference from irrelevant regions, and data-agnostic behavior. Compared with methods that only provide fixed attention patterns, it can better focus on the relevant regions across modalities and capture more informative features.
Here it is applied inside the CGA Fusion module as a secondary innovation, to better highlight the important features of each modality and improve model performance.
2. Deformable Attention Transformer
Paper title: Vision Transformer with Deformable Attention
2.1 Motivation
- Addressing the problems of existing attention mechanisms
  - Existing Vision Transformers use dense attention, which leads to excessive memory and computation cost, and the features may be affected by irrelevant regions.
  - The sparse attention adopted by the Swin Transformer is data-agnostic, which may limit its ability to model long-range relations.
- Drawing on the idea of Deformable Convolutional Networks (DCN)
  - In CNNs, DCN learns deformable receptive fields and can selectively attend to more informative regions in a data-dependent way, achieving strong results; this inspired exploring deformable attention patterns in Vision Transformers.
2.2 Principle
- Data-dependent attention pattern
  - An offset network learns offsets for the reference points from the input query features, determining the important regions of the feature map that should be attended to.
  - This lets the attention module focus on relevant regions in a data-dependent way, avoiding attention to irrelevant regions while also overcoming the loss of relevant information that hand-designed sparse attention patterns may cause.
2.3 Structure
2.3.1 Reference point generation
- First, a uniform grid of reference points $p \in \mathbb{R}^{H_G \times W_G \times 2}$ is generated on the feature map. The grid size is obtained by downsampling the input feature map size by a factor $r$, i.e. $H_G = H/r$, $W_G = W/r$. The reference point values are linearly spaced 2D coordinates, normalized to the range $[-1, +1]$.
2.3.2 Offset computation
- The feature map is linearly projected to obtain the query tokens $q = x W_q$, which are fed into a lightweight sub-network $\theta_{\mathrm{offset}}(\cdot)$ to generate the offsets $\Delta p = \theta_{\mathrm{offset}}(q)$. To stabilize training, the magnitude of $\Delta p$ is scaled.
2.3.3 Feature sampling and projection
- Features are sampled at the deformed point locations and used as the keys and values, i.e. $\tilde{k} = \tilde{x} W_k$, $\tilde{v} = \tilde{x} W_v$, where $\tilde{x} = \phi(x;\, p + \Delta p)$ and the sampling function $\phi(\cdot;\cdot)$ is bilinear interpolation.
2.3.4 Attention computation
- Multi-head attention is computed between the queries $q$ and the deformed keys $\tilde{k}$. The output of attention head $m$ is
$$z^{(m)}=\sigma\!\left(q^{(m)} \tilde{k}^{(m)\top} / \sqrt{d}+\phi(\hat{B};\,R)\right)\tilde{v}^{(m)},$$
where $R$ denotes the relative position offsets and $\phi(\hat{B};\,R)$ is the stronger relative position bias provided by the deformed points.
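To make the flow above concrete, here is a minimal single-head, single-group sketch of deformable attention in PyTorch (reference grid → offset network → bilinear sampling → attention). It omits the relative position bias term $\phi(\hat{B};R)$ and all multi-head/multi-group bookkeeping; names such as TinyDeformAttn are illustrative only, and the full implementation used in this article is given in Section 3.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyDeformAttn(nn.Module):
    def __init__(self, dim, r=2):
        super().__init__()
        self.proj_q = nn.Conv2d(dim, dim, 1)
        self.proj_k = nn.Conv2d(dim, dim, 1)
        self.proj_v = nn.Conv2d(dim, dim, 1)
        # offset network theta_offset(.): one 2D offset per reference point
        self.offset = nn.Conv2d(dim, 2, 3, stride=r, padding=1)
        self.scale = dim ** -0.5

    def forward(self, x):
        B, C, H, W = x.shape
        q = self.proj_q(x)
        dp = self.offset(q).tanh()                       # Delta p, squashed into [-1, 1]
        Hg, Wg = dp.shape[-2:]
        # uniform reference grid p, normalized to [-1, +1], stored as (y, x)
        ys = torch.linspace(-1, 1, Hg, device=x.device)
        xs = torch.linspace(-1, 1, Wg, device=x.device)
        ref = torch.stack(torch.meshgrid(ys, xs, indexing='ij'), dim=-1)
        pos = (ref.unsqueeze(0) + dp.permute(0, 2, 3, 1)).clamp(-1, 1)
        # sample deformed features x~ = phi(x; p + Delta p) by bilinear interpolation
        x_s = F.grid_sample(x, pos[..., (1, 0)], mode='bilinear', align_corners=True)
        k = self.proj_k(x_s).flatten(2)                  # B x C x Ns
        v = self.proj_v(x_s).flatten(2)
        attn = (q.flatten(2).transpose(1, 2) @ k) * self.scale   # B x HW x Ns
        out = attn.softmax(-1) @ v.transpose(1, 2)               # B x HW x C
        return out.transpose(1, 2).reshape(B, C, H, W)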
2.4 Advantages
- Flexibility and efficiency
  - The attended regions are determined dynamically from the input data, so the model focuses on relevant information and avoids computing attention over irrelevant regions, improving efficiency.
  - By learning offsets shared across queries, a deformable attention pattern is obtained with linear space complexity, greatly reducing computational complexity compared with applying the DCN mechanism directly to the attention module.
- Performance
  - Experiments on multiple benchmarks show that the Deformable Attention Transformer built on this deformable attention module outperforms competitive baselines on image classification, object detection and semantic segmentation; on ImageNet classification, for example, it clearly improves Top-1 accuracy over the Swin Transformer.
Paper: https://openaccess.thecvf.com/content/CVPR2022/papers/Xia_Vision_Transformer_With_Deformable_Attention_CVPR_2022_paper.pdf
Code: https://github.com/LeapLabTHU/DAT
3. DFAFusion Implementation Code
The implementation of DFAFusion is as follows:
import einops
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from timm.models.layers import trunc_normal_
import torch.nn.functional as F
class LayerNormProxy(nn.Module):
def __init__(self, dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
def forward(self, x):
x = einops.rearrange(x, 'b c h w -> b h w c')
x = self.norm(x)
return einops.rearrange(x, 'b h w c -> b c h w')
class DAttentionBaseline(nn.Module):
def __init__(
            self, q_size=224, kv_size=224, n_heads=8, n_head_channels=32, n_groups=1,
attn_drop=0.0, proj_drop=0.0, stride=1,
offset_range_factor=-1, use_pe=True, dwc_pe=True,
no_off=False, fixed_pe=False, ksize=9, log_cpb=False
):
super().__init__()
        n_head_channels = int(q_size / 8)  # the single positional argument is the channel dim; split it across 8 heads
        q_size = (q_size, q_size)  # nominal spatial size; the actual H, W are read from the input in forward()
self.dwc_pe = dwc_pe
self.n_head_channels = n_head_channels
self.scale = self.n_head_channels ** -0.5
self.n_heads = n_heads
self.q_h, self.q_w = q_size
# self.kv_h, self.kv_w = kv_size
self.kv_h, self.kv_w = self.q_h // stride, self.q_w // stride
self.nc = n_head_channels * n_heads
self.n_groups = n_groups
self.n_group_channels = self.nc // self.n_groups
self.n_group_heads = self.n_heads // self.n_groups
self.use_pe = use_pe
self.fixed_pe = fixed_pe
self.no_off = no_off
self.offset_range_factor = offset_range_factor
self.ksize = ksize
self.log_cpb = log_cpb
self.stride = stride
kk = self.ksize
pad_size = kk // 2 if kk != stride else 0
self.conv_offset = nn.Sequential(
nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, pad_size, groups=self.n_group_channels),
LayerNormProxy(self.n_group_channels),
nn.GELU(),
nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False)
)
if self.no_off:
for m in self.conv_offset.parameters():
m.requires_grad_(False)
self.proj_q = nn.Conv2d(
self.nc, self.nc,
kernel_size=1, stride=1, padding=0
)
self.proj_k = nn.Conv2d(
self.nc, self.nc,
kernel_size=1, stride=1, padding=0)
self.proj_v = nn.Conv2d(
self.nc, self.nc,
kernel_size=1, stride=1, padding=0
)
self.proj_out = nn.Conv2d(
self.nc, self.nc,
kernel_size=1, stride=1, padding=0
)
self.proj_drop = nn.Dropout(proj_drop, inplace=True)
self.attn_drop = nn.Dropout(attn_drop, inplace=True)
if self.use_pe and not self.no_off:
if self.dwc_pe:
self.rpe_table = nn.Conv2d(
self.nc, self.nc, kernel_size=3, stride=1, padding=1, groups=self.nc)
elif self.fixed_pe:
self.rpe_table = nn.Parameter(
torch.zeros(self.n_heads, self.q_h * self.q_w, self.kv_h * self.kv_w)
)
trunc_normal_(self.rpe_table, std=0.01)
elif self.log_cpb:
# Borrowed from Swin-V2
self.rpe_table = nn.Sequential(
nn.Linear(2, 32, bias=True),
nn.ReLU(inplace=True),
nn.Linear(32, self.n_group_heads, bias=False)
)
else:
self.rpe_table = nn.Parameter(
torch.zeros(self.n_heads, self.q_h * 2 - 1, self.q_w * 2 - 1)
)
trunc_normal_(self.rpe_table, std=0.01)
else:
self.rpe_table = None
@torch.no_grad()
def _get_ref_points(self, H_key, W_key, B, dtype, device):
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device),
torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device),
indexing='ij'
)
ref = torch.stack((ref_y, ref_x), -1)
ref[..., 1].div_(W_key - 1.0).mul_(2.0).sub_(1.0)
ref[..., 0].div_(H_key - 1.0).mul_(2.0).sub_(1.0)
ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1) # B * g H W 2
return ref
@torch.no_grad()
def _get_q_grid(self, H, W, B, dtype, device):
ref_y, ref_x = torch.meshgrid(
torch.arange(0, H, dtype=dtype, device=device),
torch.arange(0, W, dtype=dtype, device=device),
indexing='ij'
)
ref = torch.stack((ref_y, ref_x), -1)
ref[..., 1].div_(W - 1.0).mul_(2.0).sub_(1.0)
ref[..., 0].div_(H - 1.0).mul_(2.0).sub_(1.0)
ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1) # B * g H W 2
return ref
def forward(self, x):
x = x
B, C, H, W = x.size()
dtype, device = x.dtype, x.device
q = self.proj_q(x)
q_off = einops.rearrange(q, 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=self.n_group_channels)
offset = self.conv_offset(q_off).contiguous() # B * g 2 Hg Wg
Hk, Wk = offset.size(2), offset.size(3)
n_sample = Hk * Wk
if self.offset_range_factor >= 0 and not self.no_off:
offset_range = torch.tensor([1.0 / (Hk - 1.0), 1.0 / (Wk - 1.0)], device=device).reshape(1, 2, 1, 1)
offset = offset.tanh().mul(offset_range).mul(self.offset_range_factor)
offset = einops.rearrange(offset, 'b p h w -> b h w p')
reference = self._get_ref_points(Hk, Wk, B, dtype, device)
if self.no_off:
offset = offset.fill_(0.0)
if self.offset_range_factor >= 0:
pos = offset + reference
else:
pos = (offset + reference).clamp(-1., +1.)
if self.no_off:
x_sampled = F.avg_pool2d(x, kernel_size=self.stride, stride=self.stride)
assert x_sampled.size(2) == Hk and x_sampled.size(3) == Wk, f"Size is {x_sampled.size()}"
else:
x_sampled = F.grid_sample(
input=x.reshape(B * self.n_groups, self.n_group_channels, H, W),
grid=pos[..., (1, 0)], # y, x -> x, y
mode='bilinear', align_corners=True) # B * g, Cg, Hg, Wg
x_sampled = x_sampled.reshape(B, C, 1, n_sample)
q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)
k = self.proj_k(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
v = self.proj_v(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
attn = torch.einsum('b c m, b c n -> b m n', q, k) # B * h, HW, Ns
attn = attn.mul(self.scale)
if self.use_pe and (not self.no_off):
if self.dwc_pe:
residual_lepe = self.rpe_table(q.reshape(B, C, H, W)).reshape(B * self.n_heads, self.n_head_channels,
H * W)
elif self.fixed_pe:
rpe_table = self.rpe_table
attn_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
attn = attn + attn_bias.reshape(B * self.n_heads, H * W, n_sample)
elif self.log_cpb:
q_grid = self._get_q_grid(H, W, B, dtype, device)
displacement = (
q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups,
n_sample,
2).unsqueeze(1)).mul(
4.0) # d_y, d_x [-8, +8]
displacement = torch.sign(displacement) * torch.log2(torch.abs(displacement) + 1.0) / np.log2(8.0)
attn_bias = self.rpe_table(displacement) # B * g, H * W, n_sample, h_g
attn = attn + einops.rearrange(attn_bias, 'b m n h -> (b h) m n', h=self.n_group_heads)
else:
rpe_table = self.rpe_table
rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
q_grid = self._get_q_grid(H, W, B, dtype, device)
displacement = (
q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups,
n_sample,
2).unsqueeze(1)).mul(
0.5)
attn_bias = F.grid_sample(
input=einops.rearrange(rpe_bias, 'b (g c) h w -> (b g) c h w', c=self.n_group_heads,
g=self.n_groups),
grid=displacement[..., (1, 0)],
mode='bilinear', align_corners=True) # B * g, h_g, HW, Ns
attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
attn = attn + attn_bias
attn = F.softmax(attn, dim=2)
attn = self.attn_drop(attn)
out = torch.einsum('b m n, b c n -> b c m', attn, v)
if self.use_pe and self.dwc_pe:
out = out + residual_lepe
out = out.reshape(B, C, H, W)
y = self.proj_drop(self.proj_out(out))
h, w = pos.reshape(B, self.n_groups, Hk, Wk, 2), reference.reshape(B, self.n_groups, Hk, Wk, 2)
return y
class PixelAttention_CGA(nn.Module):
def __init__(self, dim):
super(PixelAttention_CGA, self).__init__()
        self.pa2 = nn.Conv2d(2 * dim, dim, 7, padding=3, padding_mode='reflect', groups=dim, bias=True)
self.sigmoid = nn.Sigmoid()
def forward(self, x, pattn1):
B, C, H, W = x.shape
x = x.unsqueeze(dim=2) # B, C, 1, H, W
pattn1 = pattn1.unsqueeze(dim=2) # B, C, 1, H, W
x2 = torch.cat([x, pattn1], dim=2) # B, C, 2, H, W
x2 = rearrange(x2, 'b c t h w -> b (c t) h w')
pattn2 = self.pa2(x2)
pattn2 = self.sigmoid(pattn2)
return pattn2
class DFAFusion(nn.Module):
def __init__(self, dim):
super(DFAFusion, self).__init__()
self.cfam = DAttentionBaseline(dim)
self.pa = PixelAttention_CGA(dim)
self.conv = nn.Conv2d(dim, dim, 1, bias=True)
self.sigmoid = nn.Sigmoid()
    def forward(self, data):
        x, y = data                                       # features from the two modalities
        initial = x + y                                   # coarse additive fusion
        pattn1 = self.cfam(initial)                       # deformable-attention guidance map
        pattn2 = self.sigmoid(self.pa(initial, pattn1))   # pixel-wise fusion weights
        result = initial + pattn2 * x + (1 - pattn2) * y  # complementary weighted fusion
        result = self.conv(result)
        return result
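As a quick sanity check, something like the following can be appended to the end of DFAFusion.py; the shapes are illustrative (two same-sized 64-channel feature maps from the two modalities), not values taken from the article.
if __name__ == '__main__':
    # two same-sized feature maps, e.g. from the visible and infrared branches
    rgb = torch.randn(1, 64, 80, 80)
    ir = torch.randn(1, 64, 80, 80)
    fusion = DFAFusion(64)
    out = fusion((rgb, ir))
    print(out.shape)  # expected: torch.Size([1, 64, 80, 80])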
4. Integration Steps
4.1 Step 1
① Create an AddModules folder under the ultralytics/nn/ directory to hold the module code.
② Create DFAFusion.py inside the AddModules folder and paste the code from Section 3 into it.
4.2 Step 2
Create __init__.py in the AddModules folder (skip this if it already exists) and import the module in it:
from .DFAFusion import *
4.3 Step 3
In the ultralytics/nn/modules/tasks.py file, the module class name needs to be added in two places.
First, import the module.
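A hedged example of the import (assuming the AddModules package created in 4.1/4.2, placed alongside the other module imports at the top of tasks.py):
from ultralytics.nn.AddModules import DFAFusion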
Then, register the DFAFusion module inside the parse_model function:
elif m in {DFAFusion}:
    c2 = ch[f[0]]   # output channels follow the first input branch
    args = [c2]     # passed to DFAFusion(dim) as the fusion channel width
Finally, in ultralytics/utils/torch_utils.py, set the stride used inside the get_flops function to 640.
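The body of get_flops differs across Ultralytics versions, so the following is only a sketch of the intended change: wherever the function derives the dummy-input stride from the model, pin it to 640 so GFLOPs are reported for a full 640×640 input.
# inside get_flops() in ultralytics/utils/torch_utils.py (surrounding lines are approximate)
# stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32  # original (approximate)
stride = 640  # report GFLOPs at the real 640x640 input size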
5. YAML Model Files
5.1 Mid-level fusion ⭐
📌 This model is modified by replacing the Concat fusion in the original mid-level fusion with DFAFusion, fusing the multimodal information from the backbone.
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Parameters
ch: 6
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, IN, []] # 0
- [-1, 1, Multiin, [1]] # 1
- [-2, 1, Multiin, [2]] # 2
- [1, 1, Conv, [64, 3, 2]] # 3-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 4-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 6-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 8-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 10-P5/32
- [-1, 3, C2f, [1024, True]]
- [2, 1, Conv, [64, 3, 2]] # 12-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 13-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 15-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 17-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 19-P5/32
- [-1, 3, C2f, [1024, True]]
- [[7, 16], 1, DFAFusion, []] # 21 cat backbone P3
- [[9, 18], 1, DFAFusion, []] # 22 cat backbone P4
- [[11, 20], 1, DFAFusion, []] # 23 cat backbone P5
- [-1, 1, SPPF, [1024, 5]] # 24
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 22], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 27
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 21], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 30 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 27], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] # 33 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 24], 1, Concat, [1]] # cat head P5
- [-1, 3, C2f, [1024]] # 36 (P5/32-large)
- [[30, 33, 36], 1, Detect, [nc]] # Detect(P3, P4, P5)
5.2 Mid-to-late fusion ⭐
📌 This model is modified by replacing the Concat fusion in the original mid-to-late fusion with DFAFusion, fusing the multimodal information in the FPN.
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Parameters
ch: 6
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, IN, []] # 0
- [-1, 1, Multiin, [1]] # 1
- [-2, 1, Multiin, [2]] # 2
- [1, 1, Conv, [64, 3, 2]] # 3-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 4-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 6-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 8-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 10-P5/32
- [-1, 3, C2f, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 12
- [2, 1, Conv, [64, 3, 2]] # 13-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 14-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 16-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 18-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 20-P5/32
- [-1, 3, C2f, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 22
# YOLOv8.0n head
head:
- [12, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 9], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 25
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 7], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 28 (P3/8-small)
- [22, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 19], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 31
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 17], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 34 (P3/8-small)
  - [[12, 22], 1, DFAFusion, []] # 35 fuse P5 (SPPF outputs)
  - [[25, 31], 1, DFAFusion, []] # 36 fuse P4
  - [[28, 34], 1, DFAFusion, []] # 37 fuse P3
- [37, 1, Conv, [256, 3, 2]]
- [[-1, 36], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] # 40 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 35], 1, Concat, [1]] # cat head P5
- [-1, 3, C2f, [1024]] # 43 (P5/32-large)
- [[37, 40, 43], 1, Detect, [nc]] # Detect(P3, P4, P5)
5.3 Late fusion ⭐
📌 This model is modified by replacing the Concat fusion in the original late fusion with DFAFusion, fusing the multimodal information in the neck.
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Parameters
ch: 6
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, IN, []] # 0
- [-1, 1, Multiin, [1]] # 1
- [-2, 1, Multiin, [2]] # 2
- [1, 1, Conv, [64, 3, 2]] # 3-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 4-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 6-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 8-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 10-P5/32
- [-1, 3, C2f, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 12
- [2, 1, Conv, [64, 3, 2]] # 13-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 14-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 16-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 18-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 20-P5/32
- [-1, 3, C2f, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 22
# YOLOv8.0n head
head:
- [12, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 9], 1, Concat, [1] ] # cat backbone P4
- [-1, 3, C2f, [512]] # 25
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[ -1, 7], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 28 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 25], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] # 31 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 12], 1, Concat, [1]] # cat head P5
- [-1, 3, C2f, [1024]] # 34 (P5/32-large)
- [22, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 19], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 37
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[ -1, 17 ], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 40 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 37], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] # 43 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 22], 1, Concat, [1]] # cat head P5
- [-1, 3, C2f, [1024]] # 46 (P5/32-large)
  - [[28, 40], 1, DFAFusion, []] # 47 fuse head P3
  - [[31, 43], 1, DFAFusion, []] # 48 fuse head P4
  - [[34, 46], 1, DFAFusion, []] # 49 fuse head P5
- [[47, 48, 49], 1, Detect, [nc]] # Detect(P3, P4, P5)
6. Successful Run Results
Printing the network shows that the fusion layers have been added to the model, and training can proceed.
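As a hedged example (the yaml path below is a hypothetical name for the config saved in Section 5.1; actual training additionally requires a 6-channel multimodal dataset), the model can be built and its layer table printed with the standard Ultralytics API:
from ultralytics import YOLO

model = YOLO('yolov8n-mid-DFAFusion.yaml')  # hypothetical filename for the 5.1 config
model.info(verbose=True)                    # prints a layer table like the ones below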
YOLOv8-mid-DFAFusion:
YOLOv8-mid-DFAFusion summary: 401 layers, 4,318,131 parameters, 4,318,115 gradients, 11.1 GFLOPs
from n params module arguments
0 -1 1 0 ultralytics.nn.AddModules.multimodal.IN []
1 -1 1 0 ultralytics.nn.AddModules.multimodal.Multiin [1]
2 -2 1 0 ultralytics.nn.AddModules.multimodal.Multiin [2]
3 1 1 464 ultralytics.nn.modules.conv.Conv [3, 16, 3, 2]
4 -1 1 4672 ultralytics.nn.modules.conv.Conv [16, 32, 3, 2]
5 -1 1 7360 ultralytics.nn.modules.block.C2f [32, 32, 1, True]
6 -1 1 18560 ultralytics.nn.modules.conv.Conv [32, 64, 3, 2]
7 -1 2 49664 ultralytics.nn.modules.block.C2f [64, 64, 2, True]
8 -1 1 73984 ultralytics.nn.modules.conv.Conv [64, 128, 3, 2]
9 -1 2 197632 ultralytics.nn.modules.block.C2f [128, 128, 2, True]
10 -1 1 295424 ultralytics.nn.modules.conv.Conv [128, 256, 3, 2]
11 -1 1 460288 ultralytics.nn.modules.block.C2f [256, 256, 1, True]
12 2 1 464 ultralytics.nn.modules.conv.Conv [3, 16, 3, 2]
13 -1 1 4672 ultralytics.nn.modules.conv.Conv [16, 32, 3, 2]
14 -1 1 7360 ultralytics.nn.modules.block.C2f [32, 32, 1, True]
15 -1 1 18560 ultralytics.nn.modules.conv.Conv [32, 64, 3, 2]
16 -1 2 49664 ultralytics.nn.modules.block.C2f [64, 64, 2, True]
17 -1 1 73984 ultralytics.nn.modules.conv.Conv [64, 128, 3, 2]
18 -1 2 197632 ultralytics.nn.modules.block.C2f [128, 128, 2, True]
19 -1 1 295424 ultralytics.nn.modules.conv.Conv [128, 256, 3, 2]
20 -1 1 460288 ultralytics.nn.modules.block.C2f [256, 256, 1, True]
21 [7, 16] 1 33280 ultralytics.nn.AddModules.DFAFusion.DFAFusion[64]
22 [9, 18] 1 107520 ultralytics.nn.AddModules.DFAFusion.DFAFusion[128]
23 [11, 20] 1 378880 ultralytics.nn.AddModules.DFAFusion.DFAFusion[256]
24 -1 1 164608 ultralytics.nn.modules.block.SPPF [256, 256, 5]
25 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
26 [-1, 22] 1 0 ultralytics.nn.modules.conv.Concat [1]
27 -1 1 148224 ultralytics.nn.modules.block.C2f [384, 128, 1]
28 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
29 [-1, 21] 1 0 ultralytics.nn.modules.conv.Concat [1]
30 -1 1 37248 ultralytics.nn.modules.block.C2f [192, 64, 1]
31 -1 1 36992 ultralytics.nn.modules.conv.Conv [64, 64, 3, 2]
32 [-1, 27] 1 0 ultralytics.nn.modules.conv.Concat [1]
33 -1 1 123648 ultralytics.nn.modules.block.C2f [192, 128, 1]
34 -1 1 147712 ultralytics.nn.modules.conv.Conv [128, 128, 3, 2]
35 [-1, 24] 1 0 ultralytics.nn.modules.conv.Concat [1]
36 -1 1 493056 ultralytics.nn.modules.block.C2f [384, 256, 1]
37 [30, 33, 36] 1 430867 ultralytics.nn.modules.head.Detect [1, [64, 128, 256]]
YOLOv8-mid-DFAFusion summary: 401 layers, 4,318,131 parameters, 4,318,115 gradients, 11.1 GFLOPs
YOLOv8-mid-to-late-DFAFusion:
YOLOv8-mid-to-late-DFAFusion summary: 443 layers, 4,668,211 parameters, 4,668,195 gradients, 12.2 GFLOPs
from n params module arguments
0 -1 1 0 ultralytics.nn.AddModules.multimodal.IN []
1 -1 1 0 ultralytics.nn.AddModules.multimodal.Multiin [1]
2 -2 1 0 ultralytics.nn.AddModules.multimodal.Multiin [2]
3 1 1 464 ultralytics.nn.modules.conv.Conv [3, 16, 3, 2]
4 -1 1 4672 ultralytics.nn.modules.conv.Conv [16, 32, 3, 2]
5 -1 1 7360 ultralytics.nn.modules.block.C2f [32, 32, 1, True]
6 -1 1 18560 ultralytics.nn.modules.conv.Conv [32, 64, 3, 2]
7 -1 2 49664 ultralytics.nn.modules.block.C2f [64, 64, 2, True]
8 -1 1 73984 ultralytics.nn.modules.conv.Conv [64, 128, 3, 2]
9 -1 2 197632 ultralytics.nn.modules.block.C2f [128, 128, 2, True]
10 -1 1 295424 ultralytics.nn.modules.conv.Conv [128, 256, 3, 2]
11 -1 1 460288 ultralytics.nn.modules.block.C2f [256, 256, 1, True]
12 -1 1 164608 ultralytics.nn.modules.block.SPPF [256, 256, 5]
13 2 1 464 ultralytics.nn.modules.conv.Conv [3, 16, 3, 2]
14 -1 1 4672 ultralytics.nn.modules.conv.Conv [16, 32, 3, 2]
15 -1 1 7360 ultralytics.nn.modules.block.C2f [32, 32, 1, True]
16 -1 1 18560 ultralytics.nn.modules.conv.Conv [32, 64, 3, 2]
17 -1 2 49664 ultralytics.nn.modules.block.C2f [64, 64, 2, True]
18 -1 1 73984 ultralytics.nn.modules.conv.Conv [64, 128, 3, 2]
19 -1 2 197632 ultralytics.nn.modules.block.C2f [128, 128, 2, True]
20 -1 1 295424 ultralytics.nn.modules.conv.Conv [128, 256, 3, 2]
21 -1 1 460288 ultralytics.nn.modules.block.C2f [256, 256, 1, True]
22 -1 1 164608 ultralytics.nn.modules.block.SPPF [256, 256, 5]
23 12 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
24 [-1, 9] 1 0 ultralytics.nn.modules.conv.Concat [1]
25 -1 1 148224 ultralytics.nn.modules.block.C2f [384, 128, 1]
26 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
27 [-1, 7] 1 0 ultralytics.nn.modules.conv.Concat [1]
28 -1 1 37248 ultralytics.nn.modules.block.C2f [192, 64, 1]
29 22 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
30 [-1, 19] 1 0 ultralytics.nn.modules.conv.Concat [1]
31 -1 1 148224 ultralytics.nn.modules.block.C2f [384, 128, 1]
32 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
33 [-1, 17] 1 0 ultralytics.nn.modules.conv.Concat [1]
34 -1 1 37248 ultralytics.nn.modules.block.C2f [192, 64, 1]
35 [12, 22] 1 378880 ultralytics.nn.AddModules.DFAFusion.DFAFusion[256]
36 [25, 31] 1 107520 ultralytics.nn.AddModules.DFAFusion.DFAFusion[128]
37 [28, 34] 1 33280 ultralytics.nn.AddModules.DFAFusion.DFAFusion[64]
38 37 1 36992 ultralytics.nn.modules.conv.Conv [64, 64, 3, 2]
39 [-1, 36] 1 0 ultralytics.nn.modules.conv.Concat [1]
40 -1 1 123648 ultralytics.nn.modules.block.C2f [192, 128, 1]
41 -1 1 147712 ultralytics.nn.modules.conv.Conv [128, 128, 3, 2]
42 [-1, 35] 1 0 ultralytics.nn.modules.conv.Concat [1]
43 -1 1 493056 ultralytics.nn.modules.block.C2f [384, 256, 1]
44 [37, 40, 43] 1 430867 ultralytics.nn.modules.head.Detect [1, [64, 128, 256]]
YOLOv8-mid-to-late-DFAFusion summary: 443 layers, 4,668,211 parameters, 4,668,195 gradients, 12.2 GFLOPs
YOLOv8-late-DFAFusion:
YOLOv8-late-DFAFusion summary: 481 layers, 5,469,619 parameters, 5,469,603 gradients, 13.2 GFLOPs
from n params module arguments
0 -1 1 0 ultralytics.nn.AddModules.multimodal.IN []
1 -1 1 0 ultralytics.nn.AddModules.multimodal.Multiin [1]
2 -2 1 0 ultralytics.nn.AddModules.multimodal.Multiin [2]
3 1 1 464 ultralytics.nn.modules.conv.Conv [3, 16, 3, 2]
4 -1 1 4672 ultralytics.nn.modules.conv.Conv [16, 32, 3, 2]
5 -1 1 7360 ultralytics.nn.modules.block.C2f [32, 32, 1, True]
6 -1 1 18560 ultralytics.nn.modules.conv.Conv [32, 64, 3, 2]
7 -1 2 49664 ultralytics.nn.modules.block.C2f [64, 64, 2, True]
8 -1 1 73984 ultralytics.nn.modules.conv.Conv [64, 128, 3, 2]
9 -1 2 197632 ultralytics.nn.modules.block.C2f [128, 128, 2, True]
10 -1 1 295424 ultralytics.nn.modules.conv.Conv [128, 256, 3, 2]
11 -1 1 460288 ultralytics.nn.modules.block.C2f [256, 256, 1, True]
12 -1 1 164608 ultralytics.nn.modules.block.SPPF [256, 256, 5]
13 2 1 464 ultralytics.nn.modules.conv.Conv [3, 16, 3, 2]
14 -1 1 4672 ultralytics.nn.modules.conv.Conv [16, 32, 3, 2]
15 -1 1 7360 ultralytics.nn.modules.block.C2f [32, 32, 1, True]
16 -1 1 18560 ultralytics.nn.modules.conv.Conv [32, 64, 3, 2]
17 -1 2 49664 ultralytics.nn.modules.block.C2f [64, 64, 2, True]
18 -1 1 73984 ultralytics.nn.modules.conv.Conv [64, 128, 3, 2]
19 -1 2 197632 ultralytics.nn.modules.block.C2f [128, 128, 2, True]
20 -1 1 295424 ultralytics.nn.modules.conv.Conv [128, 256, 3, 2]
21 -1 1 460288 ultralytics.nn.modules.block.C2f [256, 256, 1, True]
22 -1 1 164608 ultralytics.nn.modules.block.SPPF [256, 256, 5]
23 12 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
24 [-1, 9] 1 0 ultralytics.nn.modules.conv.Concat [1]
25 -1 1 148224 ultralytics.nn.modules.block.C2f [384, 128, 1]
26 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
27 [-1, 7] 1 0 ultralytics.nn.modules.conv.Concat [1]
28 -1 1 37248 ultralytics.nn.modules.block.C2f [192, 64, 1]
29 -1 1 36992 ultralytics.nn.modules.conv.Conv [64, 64, 3, 2]
30 [-1, 25] 1 0 ultralytics.nn.modules.conv.Concat [1]
31 -1 1 123648 ultralytics.nn.modules.block.C2f [192, 128, 1]
32 -1 1 147712 ultralytics.nn.modules.conv.Conv [128, 128, 3, 2]
33 [-1, 12] 1 0 ultralytics.nn.modules.conv.Concat [1]
34 -1 1 493056 ultralytics.nn.modules.block.C2f [384, 256, 1]
35 22 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
36 [-1, 19] 1 0 ultralytics.nn.modules.conv.Concat [1]
37 -1 1 148224 ultralytics.nn.modules.block.C2f [384, 128, 1]
38 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
39 [-1, 17] 1 0 ultralytics.nn.modules.conv.Concat [1]
40 -1 1 37248 ultralytics.nn.modules.block.C2f [192, 64, 1]
41 -1 1 36992 ultralytics.nn.modules.conv.Conv [64, 64, 3, 2]
42 [-1, 37] 1 0 ultralytics.nn.modules.conv.Concat [1]
43 -1 1 123648 ultralytics.nn.modules.block.C2f [192, 128, 1]
44 -1 1 147712 ultralytics.nn.modules.conv.Conv [128, 128, 3, 2]
45 [-1, 22] 1 0 ultralytics.nn.modules.conv.Concat [1]
46 -1 1 493056 ultralytics.nn.modules.block.C2f [384, 256, 1]
47 [28, 40] 1 33280 ultralytics.nn.AddModules.DFAFusion.DFAFusion[64]
48 [31, 43] 1 107520 ultralytics.nn.AddModules.DFAFusion.DFAFusion[128]
49 [34, 46] 1 378880 ultralytics.nn.AddModules.DFAFusion.DFAFusion[256]
50 [47, 48, 49] 1 430867 ultralytics.nn.modules.head.Detect [1, [64, 128, 256]]
YOLOv8-late-DFAFusion summary: 481 layers, 5,469,619 parameters, 5,469,603 gradients, 13.2 GFLOPs