
[YOLOv10 Multimodal Fusion Improvement] | Using the Deformable Attention Transformer's Deformable Attention to Further Improve CGA Fusion and Dynamically Attend to Target Regions Across Modalities

1. Introduction

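This article documents how to use the DAT module to improve the multimodal fusion part of YOLOv10. It focuses on reusing existing modules to build a second-level improvement on top of an existing multimodal fusion module.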

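DAT stands for Deformable Attention Transformer. Its deformable attention mechanism learns data-dependent attention patterns. This addresses three problems common to existing attention methods: high memory and computation cost, distraction by irrelevant regions, and data-agnostic attention patterns. Methods that offer only fixed attention patterns cannot adapt in this way. DAT can therefore focus better on the relevant regions (here, across the different modalities) and capture more informative features.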

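In this article, DAT is embedded into the CGA Fusion module as a second-level innovation. The goal is to better highlight the important features of each modality and to improve model performance.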



2. Introduction to the Deformable Attention Transformer

Vision Transformer with Deformable Attention

2.1 Motivation

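  • Addressing the problems of existing attention mechanisms
    • Existing Vision Transformers use dense attention. This leads to excessive memory and computation cost, and features can be influenced by irrelevant regions.
    • The sparse attention used by Swin Transformer is data-agnostic, which may limit its ability to model long-range relations.
  • Borrowing the idea of deformable convolutional networks (DCN)
    • In CNNs, DCN learns deformable receptive fields and can selectively attend to more informative regions in a data-dependent way. Its good results motivated the exploration of deformable attention patterns in Vision Transformers.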


2.2 Principle

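  • Data-dependent attention pattern
    • An offset network learns offsets for a set of reference points from the input query features. These offsets determine which important regions of the feature map are attended to.
    • This lets the attention module focus on relevant regions in a data-dependent way and avoid irrelevant ones. It also avoids the risk that a hand-crafted sparse attention pattern drops relevant information.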

2.3 Structure

2.3.1 Reference point generation

  • First, a uniform grid of reference points $p \in \mathbb{R}^{H_G \times W_G \times 2}$ is generated on the feature map. The grid size is obtained by downsampling the input feature-map size by a factor $r$, i.e. $H_G = H / r$ and $W_G = W / r$. The reference points are linearly spaced 2D coordinates normalized to the range $[-1, +1]$.
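
The grid construction in 2.3.1 can be sketched in NumPy as follows. This is a minimal illustration of the description above: a uniform H/r × W/r grid of (y, x) coordinates normalized to [-1, +1]. The function name make_reference_points is hypothetical. The exact normalization used by _get_ref_points in the Section 3 code differs slightly (it starts from pixel centres at 0.5).

```python
import numpy as np

def make_reference_points(H, W, r):
    """Return an (H//r, W//r, 2) grid of (y, x) reference points in [-1, 1]."""
    H_G, W_G = H // r, W // r                 # downsampled grid size
    ys = np.linspace(-1.0, 1.0, H_G)          # linearly spaced row coordinates
    xs = np.linspace(-1.0, 1.0, W_G)          # linearly spaced column coordinates
    ref_y, ref_x = np.meshgrid(ys, xs, indexing="ij")
    return np.stack([ref_y, ref_x], axis=-1)  # last axis holds (y, x)

p = make_reference_points(H=32, W=48, r=8)
print(p.shape)  # (4, 6, 2)
```

The top-left point is (-1, -1) and the bottom-right point is (+1, +1), so the grid always spans the whole feature map regardless of r.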

2.3.2 Offset computation

  • The feature map is linearly projected to query tokens $q = x W_q$. These are fed into a lightweight sub-network $\theta_{\mathrm{offset}}(\cdot)$ that produces the offsets $\Delta p = \theta_{\mathrm{offset}}(q)$. To stabilize training, the amplitude of $\Delta p$ is scaled down.
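
A rough NumPy sketch of the offset step in 2.3.2 follows. It makes three simplifying assumptions:
- r = 1, so there is one offset per query position.
- The sub-network θ_offset is replaced by a single hypothetical linear map W_off. The Section 3 code uses a depthwise conv, LayerNorm, GELU and a 1×1 conv instead.
- The amplitude is bounded as s · tanh(Δp) / (H_G − 1), which mirrors the offset_range_factor branch of that code. With its default offset_range_factor = -1, that branch is skipped and the sampling positions are clamped to [-1, 1] instead.

```python
import numpy as np

def compute_offsets(x, W_q, W_off, s=2.0):
    """x: (H, W, C) feature map -> bounded offsets (H, W, 2) in normalized units."""
    H, W, _ = x.shape
    q = x @ W_q                      # queries q = x W_q, shape (H, W, C)
    raw = q @ W_off                  # toy stand-in for theta_offset, shape (H, W, 2)
    # Bound each offset by s / (H-1) (resp. s / (W-1)) in [-1, 1] units,
    # i.e. at most s/2 of the spacing between neighbouring reference points.
    scale = np.array([1.0 / (H - 1), 1.0 / (W - 1)])
    return s * np.tanh(raw) * scale

rng = np.random.default_rng(0)
H, W, C = 8, 8, 16
x = rng.standard_normal((H, W, C))
W_q = rng.standard_normal((C, C)) / np.sqrt(C)
W_off = rng.standard_normal((C, 2))
dp = compute_offsets(x, W_q, W_off, s=2.0)
print(dp.shape, np.abs(dp).max() <= 2.0 / 7)
```

Because tanh is bounded, the scale factor s directly caps how far a sampling point can drift from its reference point, which is what keeps early training stable.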

2.3.3 Feature sampling and projection

  • Features are sampled at the deformed point locations given by the offsets and projected to keys and values: $\tilde{k} = \tilde{x} W_k$ and $\tilde{v} = \tilde{x} W_v$, where $\tilde{x} = \phi(x; p + \Delta p)$. The sampling function $\phi(\cdot\,; \cdot)$ uses bilinear interpolation.
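
The sampling function φ(·;·) in 2.3.3 is ordinary bilinear interpolation at normalized coordinates. Below is a minimal NumPy version for a single-channel map. It assumes the align_corners=True convention (−1 and +1 map to the centres of the corner pixels), which is what the F.grid_sample(..., align_corners=True) call in the Section 3 code uses. bilinear_sample is a hypothetical helper, not part of the DAT code.

```python
import numpy as np

def bilinear_sample(x, pos):
    """Bilinearly sample a 2D map x (H, W) at normalized (y, x) positions pos (..., 2)."""
    H, W = x.shape
    pos = np.clip(pos, -1.0, 1.0)                # keep sampling points inside the map
    # align_corners=True: -1 -> index 0, +1 -> index H-1 (resp. W-1)
    py = (pos[..., 0] + 1.0) * 0.5 * (H - 1)
    px = (pos[..., 1] + 1.0) * 0.5 * (W - 1)
    y0 = np.floor(py).astype(int)
    x0 = np.floor(px).astype(int)
    y1 = np.minimum(y0 + 1, H - 1)
    x1 = np.minimum(x0 + 1, W - 1)
    wy = py - y0                                 # fractional offset along y
    wx = px - x0                                 # fractional offset along x
    top = x[y0, x0] * (1 - wx) + x[y0, x1] * wx
    bottom = x[y1, x0] * (1 - wx) + x[y1, x1] * wx
    return top * (1 - wy) + bottom * wy

img = np.arange(12, dtype=float).reshape(3, 4)   # values 0..11, row-major
pts = np.array([[-1.0, -1.0], [1.0, 1.0], [0.0, 0.0]])
vals = bilinear_sample(img, pts)
print(vals)
```

The corners return the corner pixels exactly. The centre point falls between four pixels and returns their interpolated value, which is what lets the learned offsets be fractional and differentiable.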

2.3.4 Attention computation

  • Multi-head attention is computed between the queries $q$ and the deformed keys $\tilde{k}$. The output of the $m$-th head is $z^{(m)} = \sigma\left(q^{(m)} \tilde{k}^{(m)\top} / \sqrt{d} + \phi(\hat{B}; R)\right) \tilde{v}^{(m)}$, where:
    • $\sigma$ is the softmax and $d$ is the per-head dimension;
    • $R$ is the relative position offset between queries and deformed points;
    • $\phi(\hat{B}; R)$ is the stronger relative position bias that the deformed points make possible, i.e. the bias table $\hat{B}$ interpolated at $R$.
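
Combining 2.3.3 and 2.3.4, a single attention head reduces to a few matrix products. The NumPy sketch below makes three assumptions: the deformed samples x̃ are already available (for example from bilinear sampling), the projection matrices are random, and the relative position bias φ(B̂; R) is given as a precomputed (HW, N_s) matrix. How that bias is interpolated is omitted.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)      # subtract max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def deformable_head(x, x_tilde, W_q, W_k, W_v, bias):
    """x: (HW, C) query features; x_tilde: (Ns, C) deformed samples; bias: (HW, Ns)."""
    q = x @ W_q                                  # (HW, d)
    k = x_tilde @ W_k                            # (Ns, d) keys from deformed points
    v = x_tilde @ W_v                            # (Ns, d) values from deformed points
    d = q.shape[-1]
    attn = softmax(q @ k.T / np.sqrt(d) + bias)  # sigma(q k^T / sqrt(d) + phi(B; R))
    return attn @ v, attn                        # z: (HW, d), attention: (HW, Ns)

rng = np.random.default_rng(0)
HW, Ns, C, d = 64, 16, 32, 8
x = rng.standard_normal((HW, C))
x_tilde = rng.standard_normal((Ns, C))
W_q, W_k, W_v = (rng.standard_normal((C, d)) for _ in range(3))
bias = np.zeros((HW, Ns))                        # placeholder for phi(B; R)
z, attn = deformable_head(x, x_tilde, W_q, W_k, W_v, bias)
print(z.shape, attn.shape)
```

The attention map has shape HW × N_s rather than HW × HW: every query attends only to the N_s deformed sampling points.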

2.4 Advantages

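  • Flexibility and efficiency
    • The attended regions are determined dynamically from the input data. The model therefore focuses on relevant information and avoids spending computation and attention on irrelevant regions, which improves efficiency.
    • Learning shared offsets gives a deformable attention pattern while keeping linear space complexity. This greatly reduces computational complexity compared with applying the DCN mechanism directly to the attention module.
  • Performance
    • Experiments on several benchmark datasets show that the Deformable Attention Transformer, built on the deformable attention module, outperforms competitive baseline models on image classification, object detection and semantic segmentation. For example, on ImageNet classification it improves Top-1 accuracy over Swin Transformer.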
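
To make the efficiency claim concrete, compare the size of one head's attention map. Dense self-attention stores an HW × HW map, while deformable attention attends only to N_s = H_G · W_G = HW / r² sampled keys. The small helper below (hypothetical, for illustration only) counts attention-map entries per head. The DFAFusion code in Section 3 uses stride = 1 (r = 1), so N_s = HW there and this particular saving does not apply.

```python
def attn_map_entries(H, W, r=None):
    """Entries in one head's attention map: dense if r is None, else deformable with factor r."""
    n_queries = H * W
    n_keys = H * W if r is None else (H // r) * (W // r)
    return n_queries * n_keys

dense = attn_map_entries(80, 80)            # 6400 x 6400
deform = attn_map_entries(80, 80, r=4)      # 6400 x 400
print(dense, deform, dense // deform)       # 40960000 2560000 16
```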

Paper: https://openaccess.thecvf.com/content/CVPR2022/papers/Xia_Vision_Transformer_With_Deformable_Attention_CVPR_2022_paper.pdf
Code: https://github.com/LeapLabTHU/DAT

3. Implementation of DFAFusion

The implementation of DFAFusion is as follows:

import einops
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from torch.nn.init import trunc_normal_  # built into PyTorch; avoids the timm dependency
import torch.nn.functional as F

class LayerNormProxy(nn.Module):
    """Apply LayerNorm over the channel dimension of a (B, C, H, W) tensor."""

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = einops.rearrange(x, 'b c h w -> b h w c')
        x = self.norm(x)
        return einops.rearrange(x, 'b h w c -> b c h w')

class DAttentionBaseline(nn.Module):
    """Deformable multi-head attention, adapted from the DAT code (LeapLabTHU/DAT)."""

    def __init__(
            self, q_size=(224,224), kv_size=(224,224), n_heads=8, n_head_channels=32, n_groups=1,
            attn_drop=0.0, proj_drop=0.0, stride=1,
            offset_range_factor=-1, use_pe=True, dwc_pe=True,
            no_off=False, fixed_pe=False, ksize=9, log_cpb=False
    ):

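        super().__init__()
        # NOTE: DFAFusion passes the channel count `dim` as the first argument, so here
        # `q_size` is a channel count, not a spatial size. Each head gets dim // n_heads
        # channels (dim must be divisible by n_heads), so that nc == dim. The spatial
        # sizes q_h/q_w derived below are only meaningful for the default dwc_pe=True path.
        n_head_channels = q_size // n_heads
        q_size = (q_size, q_size)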

        self.dwc_pe = dwc_pe
        self.n_head_channels = n_head_channels
        self.scale = self.n_head_channels ** -0.5
        self.n_heads = n_heads
        self.q_h, self.q_w = q_size
        # self.kv_h, self.kv_w = kv_size
        self.kv_h, self.kv_w = self.q_h // stride, self.q_w // stride
        self.nc = n_head_channels * n_heads
        self.n_groups = n_groups
        self.n_group_channels = self.nc // self.n_groups
        self.n_group_heads = self.n_heads // self.n_groups
        self.use_pe = use_pe
        self.fixed_pe = fixed_pe
        self.no_off = no_off
        self.offset_range_factor = offset_range_factor
        self.ksize = ksize
        self.log_cpb = log_cpb
        self.stride = stride
        kk = self.ksize
        pad_size = kk // 2 if kk != stride else 0

        self.conv_offset = nn.Sequential(
            nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, pad_size, groups=self.n_group_channels),
            LayerNormProxy(self.n_group_channels),
            nn.GELU(),
            nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False)
        )

        if self.no_off:
            for m in self.conv_offset.parameters():
                m.requires_grad_(False)

        self.proj_q = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )

        self.proj_k = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0)

        self.proj_v = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )
        self.proj_out = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )

        self.proj_drop = nn.Dropout(proj_drop, inplace=True)
        self.attn_drop = nn.Dropout(attn_drop, inplace=True)

        if self.use_pe and not self.no_off:
            if self.dwc_pe:
                self.rpe_table = nn.Conv2d(
                    self.nc, self.nc, kernel_size=3, stride=1, padding=1, groups=self.nc)
            elif self.fixed_pe:
                self.rpe_table = nn.Parameter(
                    torch.zeros(self.n_heads, self.q_h * self.q_w, self.kv_h * self.kv_w)
                )
                trunc_normal_(self.rpe_table, std=0.01)
            elif self.log_cpb:
                # Borrowed from Swin-V2
                self.rpe_table = nn.Sequential(
                    nn.Linear(2, 32, bias=True),
                    nn.ReLU(inplace=True),
                    nn.Linear(32, self.n_group_heads, bias=False)
                )
            else:
                self.rpe_table = nn.Parameter(
                    torch.zeros(self.n_heads, self.q_h * 2 - 1, self.q_w * 2 - 1)
                )
                trunc_normal_(self.rpe_table, std=0.01)
        else:
            self.rpe_table = None

    @torch.no_grad()
    def _get_ref_points(self, H_key, W_key, B, dtype, device):

        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device),
            torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device),
            indexing='ij'
        )
        ref = torch.stack((ref_y, ref_x), -1)
        ref[..., 1].div_(W_key - 1.0).mul_(2.0).sub_(1.0)
        ref[..., 0].div_(H_key - 1.0).mul_(2.0).sub_(1.0)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1)  # B * g H W 2

        return ref

    @torch.no_grad()
    def _get_q_grid(self, H, W, B, dtype, device):

        ref_y, ref_x = torch.meshgrid(
            torch.arange(0, H, dtype=dtype, device=device),
            torch.arange(0, W, dtype=dtype, device=device),
            indexing='ij'
        )
        ref = torch.stack((ref_y, ref_x), -1)
        ref[..., 1].div_(W - 1.0).mul_(2.0).sub_(1.0)
        ref[..., 0].div_(H - 1.0).mul_(2.0).sub_(1.0)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1)  # B * g H W 2

        return ref

    def forward(self, x):
        # x: (B, C, H, W) feature map; the output has the same shape
        B, C, H, W = x.size()
        dtype, device = x.dtype, x.device

        q = self.proj_q(x)
        q_off = einops.rearrange(q, 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=self.n_group_channels)
        offset = self.conv_offset(q_off).contiguous()  # B * g 2 Hg Wg
        Hk, Wk = offset.size(2), offset.size(3)
        n_sample = Hk * Wk

        if self.offset_range_factor >= 0 and not self.no_off:
            offset_range = torch.tensor([1.0 / (Hk - 1.0), 1.0 / (Wk - 1.0)], device=device).reshape(1, 2, 1, 1)
            offset = offset.tanh().mul(offset_range).mul(self.offset_range_factor)

        offset = einops.rearrange(offset, 'b p h w -> b h w p')
        reference = self._get_ref_points(Hk, Wk, B, dtype, device)

        if self.no_off:
            offset = offset.fill_(0.0)

        if self.offset_range_factor >= 0:
            pos = offset + reference
        else:
            pos = (offset + reference).clamp(-1., +1.)

        if self.no_off:
            x_sampled = F.avg_pool2d(x, kernel_size=self.stride, stride=self.stride)
            assert x_sampled.size(2) == Hk and x_sampled.size(3) == Wk, f"Size is {x_sampled.size()}"
        else:
            x_sampled = F.grid_sample(
                input=x.reshape(B * self.n_groups, self.n_group_channels, H, W),
                grid=pos[..., (1, 0)],  # y, x -> x, y
                mode='bilinear', align_corners=True)  # B * g, Cg, Hg, Wg

        x_sampled = x_sampled.reshape(B, C, 1, n_sample)
        q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)

        k = self.proj_k(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
        v = self.proj_v(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)

        attn = torch.einsum('b c m, b c n -> b m n', q, k)  # B * h, HW, Ns
        attn = attn.mul(self.scale)

        if self.use_pe and (not self.no_off):

            if self.dwc_pe:
                residual_lepe = self.rpe_table(q.reshape(B, C, H, W)).reshape(B * self.n_heads, self.n_head_channels,
                                                                              H * W)
            elif self.fixed_pe:
                rpe_table = self.rpe_table
                attn_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
                attn = attn + attn_bias.reshape(B * self.n_heads, H * W, n_sample)
            elif self.log_cpb:
                q_grid = self._get_q_grid(H, W, B, dtype, device)
                displacement = (
                            q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups,
                                                                                                   n_sample,
                                                                                                   2).unsqueeze(1)).mul(
                    4.0)  # d_y, d_x [-8, +8]
                displacement = torch.sign(displacement) * torch.log2(torch.abs(displacement) + 1.0) / np.log2(8.0)
                attn_bias = self.rpe_table(displacement)  # B * g, H * W, n_sample, h_g
                attn = attn + einops.rearrange(attn_bias, 'b m n h -> (b h) m n', h=self.n_group_heads)
            else:
                rpe_table = self.rpe_table
                rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
                q_grid = self._get_q_grid(H, W, B, dtype, device)
                displacement = (
                            q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups,
                                                                                                   n_sample,
                                                                                                   2).unsqueeze(1)).mul(
                    0.5)
                attn_bias = F.grid_sample(
                    input=einops.rearrange(rpe_bias, 'b (g c) h w -> (b g) c h w', c=self.n_group_heads,
                                           g=self.n_groups),
                    grid=displacement[..., (1, 0)],
                    mode='bilinear', align_corners=True)  # B * g, h_g, HW, Ns

                attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
                attn = attn + attn_bias

        attn = F.softmax(attn, dim=2)
        attn = self.attn_drop(attn)

        out = torch.einsum('b m n, b c n -> b c m', attn, v)

        if self.use_pe and self.dwc_pe and not self.no_off:  # residual_lepe only exists in this case
            out = out + residual_lepe
        out = out.reshape(B, C, H, W)

        y = self.proj_drop(self.proj_out(out))

        return y

class PixelAttention_CGA(nn.Module):
    """Pixel attention from CGAFusion: predicts per-pixel, per-channel weights in (0, 1)."""
    def __init__(self, dim):
        super(PixelAttention_CGA, self).__init__()
        self.pa2 = nn.Conv2d(2 * dim, dim, 7, padding=3, padding_mode='reflect' ,groups=dim, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, pattn1):
        B, C, H, W = x.shape
        x = x.unsqueeze(dim=2) # B, C, 1, H, W
        pattn1 = pattn1.unsqueeze(dim=2) # B, C, 1, H, W
        x2 = torch.cat([x, pattn1], dim=2) # B, C, 2, H, W
        x2 = rearrange(x2, 'b c t h w -> b (c t) h w')
        pattn2 = self.pa2(x2)
        pattn2 = self.sigmoid(pattn2)
        return pattn2

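class DFAFusion(nn.Module):
    """CGAFusion-style fusion of two same-shaped feature maps, with DAT producing the first attention map."""
    def __init__(self, dim):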
        super(DFAFusion, self).__init__()
        self.cfam = DAttentionBaseline(dim)
        self.pa = PixelAttention_CGA(dim)
        self.conv = nn.Conv2d(dim, dim, 1, bias=True)
        self.sigmoid = nn.Sigmoid()

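    def forward(self, data):
        x, y = data                    # two feature maps (B, C, H, W), one per modality
        initial = x + y
        pattn1 = self.cfam(initial)    # deformable attention on the summed features
        # PixelAttention_CGA already ends with a sigmoid, so this second sigmoid limits
        # pattn2 to roughly (0.5, 0.731); it is kept unchanged to preserve the behaviour
        pattn2 = self.sigmoid(self.pa(initial, pattn1))
        result = initial + pattn2 * x + (1 - pattn2) * y   # gated blend of the two modalities
        result = self.conv(result)
        return result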
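
The last lines of DFAFusion.forward implement a CGA-style gated blend, initial + w * x + (1 - w) * y, where w is a per-pixel, per-channel weight. The NumPy sketch below reproduces only that blending step. As an assumption for illustration, the DAT and pixel-attention branches are replaced by random pre-activations. It also shows a side effect of applying sigmoid twice (once inside PixelAttention_CGA, once in DFAFusion.forward): w can only fall in about (0.5, 0.731), so x always gets a somewhat larger weight than y.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def gated_blend(x, y, logits, double_sigmoid=True):
    """CGA-style blend; `logits` stands in for the output of PixelAttention_CGA's conv."""
    w = sigmoid(logits)
    if double_sigmoid:                 # mirrors PixelAttention_CGA + the extra sigmoid in DFAFusion
        w = sigmoid(w)
    initial = x + y
    return initial + w * x + (1.0 - w) * y, w

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 8, 8))          # modality 1 features (B, C, H, W)
y = rng.standard_normal((2, 4, 8, 8))          # modality 2 features
logits = 3.0 * rng.standard_normal((2, 4, 8, 8))
out, w = gated_blend(x, y, logits)
print(out.shape, w.min() > 0.5, w.max() < 0.7311)
```

Passing double_sigmoid=False gives the full (0, 1) range. Which variant works better for a given dataset would have to be checked experimentally.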

4. Integration steps

4.1 Modification 1

① Create an AddModules folder under ultralytics/nn/ to hold the module code.

② In the AddModules folder, create DFAFusion.py and paste the code from Section 3 into it.


4.2 Modification 2

AddModules 文件夹下新建 __init__.py (已有则不用新建),在文件内导入模块: from .DFAFusion import *


4.3 Modification 3

ultralytics/nn/modules/tasks.py 文件中,需要在两处位置添加各模块类名称。

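First, import the module (with the package layout created above, for example: from ultralytics.nn.AddModules import *).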


Then register the DFAFusion module in the parse_model function:


        elif m in {DFAFusion}:
            # both inputs have the same channel count; it is passed as DFAFusion's `dim`
            c2 = ch[f[0]]
            args = [c2]
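
The registration branch only has to work out DFAFusion's dim. Both inputs listed in from must have the same channel count, and that count becomes the single constructor argument. This is why the YAML files below leave the argument list empty ([]). The plain-Python sketch below mimics just this bookkeeping outside Ultralytics. Three things are illustrative assumptions: the function resolve_dfafusion_args, the channel list (taken from the Section 6 printout, with layer 0 assumed to be 6 channels), and the extra channel-equality check, which is not in the original snippet.

```python
def resolve_dfafusion_args(f, ch):
    """Mimic `c2 = ch[f[0]]; args = [c2]`, plus a sanity check on both inputs."""
    c_a, c_b = ch[f[0]], ch[f[1]]
    if c_a != c_b:
        raise ValueError(f"DFAFusion inputs need equal channels, got {c_a} and {c_b}")
    c2 = c_a                     # output channels equal input channels
    return c2, [c2]

# Output channels of layers 0..20 of the mid-fusion YOLOv10n model (layer 0 assumed to be 6)
ch = [6, 3, 3, 16, 32, 32, 64, 64, 128, 128, 256, 256,
      16, 32, 32, 64, 64, 128, 128, 256, 256]
print(resolve_dfafusion_args([7, 16], ch))   # (64, [64])
```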


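Finally, in the get_flops function of ultralytics/utils/torch_utils.py, set stride to 640. FLOPs are then profiled on a 640×640 dummy input instead of a small stride-sized one.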



5. YAML model files

5.1 Mid fusion ⭐

📌 This model replaces the Concat fusion in the original mid-fusion model with DFAFusion, which fuses the multimodal features from the two backbones.

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv10 object detection model. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
ch: 6
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov10n.yaml' will call yolov10.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, IN, []]  # 0
  - [-1, 1, Multiin, [1]]  # 1
  - [-2, 1, Multiin, [2]]  # 2

  - [1, 1, Conv, [64, 3, 2]] # 3-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 4-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 6-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, SCDown, [512, 3, 2]] # 8-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, SCDown, [1024, 3, 2]] # 10-P5/32
  - [-1, 3, C2f, [1024, True]]

  - [2, 1, Conv, [64, 3, 2]] # 12-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 13-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 15-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, SCDown, [512, 3, 2]] # 17-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, SCDown, [1024, 3, 2]] # 19-P5/32
  - [-1, 3, C2f, [1024, True]]

  - [[7, 16], 1, DFAFusion, []]  # 21 fuse backbone P3
  - [[9, 18], 1, DFAFusion, []]  # 22 fuse backbone P4
  - [[11, 20], 1, DFAFusion, []]  # 23 fuse backbone P5

  - [-1, 1, SPPF, [1024, 5]] # 24
  - [-1, 1, PSA, [1024]] # 25

# YOLOv10.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 22], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 28

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 21], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 31 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 28], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2f, [512]] # 34 (P4/16-medium)

  - [-1, 1, SCDown, [512, 3, 2]]
  - [[-1, 25], 1, Concat, [1]] # cat head P5
  - [-1, 3, C2fCIB, [1024, True, True]] # 37 (P5/32-large)

  - [[31, 34, 37], 1, v10Detect, [nc]] # Detect(P3, P4, P5)
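
The n scale explains the DFAFusion arguments and repeat counts in the Section 6 printout. Ultralytics' parse_model scales a layer's output channels as make_divisible(min(c2, max_channels) * width, 8), and its repeats as max(round(n * depth), 1) when n > 1. The standalone re-implementation below is written for illustration and is not imported from Ultralytics. It shows that the P3/P4/P5 C2f layers feeding DFAFusion end up with 64, 128 and 256 channels.

```python
import math

def make_divisible(x, divisor=8):
    """Round x up to the nearest multiple of divisor."""
    return math.ceil(x / divisor) * divisor

def scale_layer(c2, repeats, depth=0.33, width=0.25, max_channels=1024):
    """Scaled (channels, repeats) of one YAML entry under the 'n' scale."""
    c2_scaled = make_divisible(min(c2, max_channels) * width, 8)
    n = max(round(repeats * depth), 1) if repeats > 1 else repeats
    return c2_scaled, n

for name, c2, reps in [("P3 C2f", 256, 6), ("P4 C2f", 512, 6), ("P5 C2f", 1024, 3)]:
    print(name, scale_layer(c2, reps))   # (64, 2), (128, 2), (256, 1)
```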

5.2 Mid-to-late fusion ⭐

📌 This model replaces the Concat fusion in the original mid-to-late fusion model with DFAFusion, which fuses the multimodal features in the FPN.

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv10 object detection model. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
ch: 6
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov10n.yaml' will call yolov10.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, IN, []]  # 0
  - [-1, 1, Multiin, [1]]  # 1
  - [-2, 1, Multiin, [2]]  # 2

  - [1, 1, Conv, [64, 3, 2]] # 3-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 4-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 6-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, SCDown, [512, 3, 2]] # 8-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, SCDown, [1024, 3, 2]] # 10-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 12
  - [-1, 1, PSA, [1024]] # 13

  - [2, 1, Conv, [64, 3, 2]] # 14-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 15-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 17-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, SCDown, [512, 3, 2]] # 19-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, SCDown, [1024, 3, 2]] # 21-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 23
  - [-1, 1, PSA, [1024]] # 24

# YOLOv10.0n head
head:
  - [13, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 9], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 27

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 7], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 30 (P3/8-small)

  - [24, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 20], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 33

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 18], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 36 (P3/8-small)

  - [[13, 24], 1, DFAFusion, []]  # 37 fuse P5
  - [[27, 33], 1, DFAFusion, []]  # 38 fuse P4
  - [[30, 36], 1, DFAFusion, []]  # 39 fuse P3

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 38], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2f, [512]] # 42 (P4/16-medium)

  - [-1, 1, SCDown, [512, 3, 2]]
  - [[-1, 37], 1, Concat, [1]] # cat head P5
  - [-1, 3, C2fCIB, [1024, True, True]] # 45 (P5/32-large)

  - [[39, 42, 45], 1, v10Detect, [nc]] # Detect(P3, P4, P5)

5.3 Late fusion ⭐

📌 This model replaces the Concat fusion in the original late-fusion model with DFAFusion, which fuses the multimodal features from the neck.

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv10 object detection model. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
ch: 6
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov10n.yaml' will call yolov10.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, IN, []]  # 0
  - [-1, 1, Multiin, [1]]  # 1
  - [-2, 1, Multiin, [2]]  # 2

  - [1, 1, Conv, [64, 3, 2]] # 3-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 4-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 6-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, SCDown, [512, 3, 2]] # 8-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, SCDown, [1024, 3, 2]] # 10-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 12
  - [-1, 1, PSA, [1024]] # 13

  - [2, 1, Conv, [64, 3, 2]] # 14-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 15-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 17-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, SCDown, [512, 3, 2]] # 19-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, SCDown, [1024, 3, 2]] # 21-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 23
  - [-1, 1, PSA, [1024]] # 24

# YOLOv10.0n head
head:
  - [13, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 9], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 27

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 7], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 30 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 27], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2f, [512]] # 33 (P4/16-medium)

  - [-1, 1, SCDown, [512, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P5
  - [-1, 3, C2fCIB, [1024, True, True]] # 36 (P5/32-large)

  - [24, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 20], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 39

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 18], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 42 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 39], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2f, [512]] # 45 (P4/16-medium)

  - [-1, 1, SCDown, [512, 3, 2]]
  - [[-1, 24], 1, Concat, [1]] # cat head P5
  - [-1, 3, C2fCIB, [1024, True, True]] # 48 (P5/32-large)

  - [[30, 42], 1, DFAFusion, []]  # 49 fuse head P3
  - [[33, 45], 1, DFAFusion, []]  # 50 fuse head P4
  - [[36, 48], 1, DFAFusion, []]  # 51 fuse head P5

  - [[49, 50, 51], 1, v10Detect, [nc]] # Detect(P3, P4, P5)


6. Successful run results

Printing the network shows that the fusion layers have been added to the model and that the model can be trained.

YOLOv10n-mid-DFAFusion

YOLOv10n-mid-DFAFusion summary: 547 layers, 4,011,702 parameters, 4,011,686 gradients, 12.3 GFLOPs

                   from  n    params  module                                       arguments
  0                  -1  1         0  ultralytics.nn.AddModules.multimodal.IN      []
  1                  -1  1         0  ultralytics.nn.AddModules.multimodal.Multiin [1]
  2                  -2  1         0  ultralytics.nn.AddModules.multimodal.Multiin [2]
  3                   1  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
  4                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]
  5                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]
  6                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]
  7                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]
  8                  -1  1      9856  ultralytics.nn.modules.block.SCDown          [64, 128, 3, 2]
  9                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]
 10                  -1  1     36096  ultralytics.nn.modules.block.SCDown          [128, 256, 3, 2]
 11                  -1  1    460288  ultralytics.nn.modules.block.C2f             [256, 256, 1, True]
 12                   2  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
 13                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]
 14                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]
 15                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]
 16                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]
 17                  -1  1      9856  ultralytics.nn.modules.block.SCDown          [64, 128, 3, 2]
 18                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]
 19                  -1  1     36096  ultralytics.nn.modules.block.SCDown          [128, 256, 3, 2]
 20                  -1  1    460288  ultralytics.nn.modules.block.C2f             [256, 256, 1, True]
 21             [7, 16]  1     33280  ultralytics.nn.AddModules.DFAFusion.DFAFusion[64]
 22             [9, 18]  1    107520  ultralytics.nn.AddModules.DFAFusion.DFAFusion[128]
 23            [11, 20]  1    378880  ultralytics.nn.AddModules.DFAFusion.DFAFusion[256]
 24                  -1  1    164608  ultralytics.nn.modules.block.SPPF            [256, 256, 5]
 25                  -1  1    249728  ultralytics.nn.modules.block.PSA             [256, 256]
 26                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 27            [-1, 22]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 28                  -1  1    148224  ultralytics.nn.modules.block.C2f             [384, 128, 1]
 29                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 30            [-1, 21]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 31                  -1  1     37248  ultralytics.nn.modules.block.C2f             [192, 64, 1]
 32                  -1  1     36992  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2]
 33            [-1, 28]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 34                  -1  1    123648  ultralytics.nn.modules.block.C2f             [192, 128, 1]
 35                  -1  1     18048  ultralytics.nn.modules.block.SCDown          [128, 128, 3, 2]
 36            [-1, 25]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 37                  -1  1    282624  ultralytics.nn.modules.block.C2fCIB          [384, 256, 1, True, True]
 38        [31, 34, 37]  1    861718  ultralytics.nn.modules.head.v10Detect        [1, [64, 128, 256]]

YOLOv10n-mid-to-late-DFAFusion

YOLOv10n-mid-to-late-DFAFusion summary: 617 layers, 4,611,510 parameters, 4,611,494 gradients, 13.6 GFLOPs

                   from  n    params  module                                       arguments
  0                  -1  1         0  ultralytics.nn.AddModules.multimodal.IN      []
  1                  -1  1         0  ultralytics.nn.AddModules.multimodal.Multiin [1]
  2                  -2  1         0  ultralytics.nn.AddModules.multimodal.Multiin [2]
  3                   1  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
  4                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]
  5                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]
  6                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]
  7                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]
  8                  -1  1      9856  ultralytics.nn.modules.block.SCDown          [64, 128, 3, 2]
  9                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]
 10                  -1  1     36096  ultralytics.nn.modules.block.SCDown          [128, 256, 3, 2]
 11                  -1  1    460288  ultralytics.nn.modules.block.C2f             [256, 256, 1, True]
 12                  -1  1    164608  ultralytics.nn.modules.block.SPPF            [256, 256, 5]
 13                  -1  1    249728  ultralytics.nn.modules.block.PSA             [256, 256]
 14                   2  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
 15                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]
 16                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]
 17                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]
 18                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]
 19                  -1  1      9856  ultralytics.nn.modules.block.SCDown          [64, 128, 3, 2]
 20                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]
 21                  -1  1     36096  ultralytics.nn.modules.block.SCDown          [128, 256, 3, 2]
 22                  -1  1    460288  ultralytics.nn.modules.block.C2f             [256, 256, 1, True]
 23                  -1  1    164608  ultralytics.nn.modules.block.SPPF            [256, 256, 5]
 24                  -1  1    249728  ultralytics.nn.modules.block.PSA             [256, 256]
 25                  13  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 26             [-1, 9]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 27                  -1  1    148224  ultralytics.nn.modules.block.C2f             [384, 128, 1]
 28                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 29             [-1, 7]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 30                  -1  1     37248  ultralytics.nn.modules.block.C2f             [192, 64, 1]
 31                  24  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 32            [-1, 20]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 33                  -1  1    148224  ultralytics.nn.modules.block.C2f             [384, 128, 1]
 34                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 35            [-1, 18]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 36                  -1  1     37248  ultralytics.nn.modules.block.C2f             [192, 64, 1]
 37            [13, 24]  1    378880  ultralytics.nn.AddModules.DFAFusion.DFAFusion[256]
 38            [27, 33]  1    107520  ultralytics.nn.AddModules.DFAFusion.DFAFusion[128]
 39            [30, 36]  1     33280  ultralytics.nn.AddModules.DFAFusion.DFAFusion[64]
 40                  -1  1     36992  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2]
 41            [-1, 38]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 42                  -1  1    123648  ultralytics.nn.modules.block.C2f             [192, 128, 1]
 43                  -1  1     18048  ultralytics.nn.modules.block.SCDown          [128, 128, 3, 2]
 44            [-1, 37]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 45                  -1  1    282624  ultralytics.nn.modules.block.C2fCIB          [384, 256, 1, True, True]
 46        [39, 42, 45]  1    861718  ultralytics.nn.modules.head.v10Detect        [1, [64, 128, 256]]

YOLOv10n-late-DFAFusion

YOLOv10n-late-DFAFusion summary: 677 layers, 5,072,822 parameters, 5,072,806 gradients, 14.4 GFLOPs

                   from  n    params  module                                       arguments
  0                  -1  1         0  ultralytics.nn.AddModules.multimodal.IN      []
  1                  -1  1         0  ultralytics.nn.AddModules.multimodal.Multiin [1]
  2                  -2  1         0  ultralytics.nn.AddModules.multimodal.Multiin [2]
  3                   1  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
  4                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]
  5                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]
  6                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]
  7                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]
  8                  -1  1      9856  ultralytics.nn.modules.block.SCDown          [64, 128, 3, 2]
  9                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]
 10                  -1  1     36096  ultralytics.nn.modules.block.SCDown          [128, 256, 3, 2]
 11                  -1  1    460288  ultralytics.nn.modules.block.C2f             [256, 256, 1, True]
 12                  -1  1    164608  ultralytics.nn.modules.block.SPPF            [256, 256, 5]
 13                  -1  1    249728  ultralytics.nn.modules.block.PSA             [256, 256]
 14                   2  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
 15                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]
 16                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]
 17                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]
 18                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]
 19                  -1  1      9856  ultralytics.nn.modules.block.SCDown          [64, 128, 3, 2]
 20                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]
 21                  -1  1     36096  ultralytics.nn.modules.block.SCDown          [128, 256, 3, 2]
 22                  -1  1    460288  ultralytics.nn.modules.block.C2f             [256, 256, 1, True]
 23                  -1  1    164608  ultralytics.nn.modules.block.SPPF            [256, 256, 5]
 24                  -1  1    249728  ultralytics.nn.modules.block.PSA             [256, 256]
 25                  13  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 26             [-1, 9]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 27                  -1  1    148224  ultralytics.nn.modules.block.C2f             [384, 128, 1]
 28                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 29             [-1, 7]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 30                  -1  1     37248  ultralytics.nn.modules.block.C2f             [192, 64, 1]
 31                  -1  1     36992  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2]
 32            [-1, 27]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 33                  -1  1    123648  ultralytics.nn.modules.block.C2f             [192, 128, 1]
 34                  -1  1     18048  ultralytics.nn.modules.block.SCDown          [128, 128, 3, 2]
 35            [-1, 13]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 36                  -1  1    282624  ultralytics.nn.modules.block.C2fCIB          [384, 256, 1, True, True]
 37                  24  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 38            [-1, 20]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 39                  -1  1    148224  ultralytics.nn.modules.block.C2f             [384, 128, 1]
 40                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 41            [-1, 18]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 42                  -1  1     37248  ultralytics.nn.modules.block.C2f             [192, 64, 1]
 43                  -1  1     36992  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2]
 44            [-1, 39]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 45                  -1  1    123648  ultralytics.nn.modules.block.C2f             [192, 128, 1]
 46                  -1  1     18048  ultralytics.nn.modules.block.SCDown          [128, 128, 3, 2]
 47            [-1, 24]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 48                  -1  1    282624  ultralytics.nn.modules.block.C2fCIB          [384, 256, 1, True, True]
 49            [30, 42]  1     33280  ultralytics.nn.AddModules.DFAFusion.DFAFusion[64]
 50            [33, 45]  1    107520  ultralytics.nn.AddModules.DFAFusion.DFAFusion[128]
 51            [36, 48]  1    378880  ultralytics.nn.AddModules.DFAFusion.DFAFusion[256]
 52        [49, 50, 51]  1    861718  ultralytics.nn.modules.head.v10Detect        [1, [64, 128, 256]]
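
As a sanity check on the printouts, the DFAFusion parameter counts (33280, 107520 and 378880) can be reproduced by hand. This uses the Section 3 code with its defaults: n_heads=8, n_groups=1, stride=1, ksize=9, dwc_pe=True. The arithmetic below tallies the weights and biases of each sub-layer; it is a sketch and does not import the module itself.

```python
def conv_params(c_in, c_out, k, groups=1, bias=True):
    """Parameter count of nn.Conv2d(c_in, c_out, k, groups=groups, bias=bias)."""
    return c_out * (c_in // groups) * k * k + (c_out if bias else 0)

def dfafusion_params(dim, ksize=9):
    # DAttentionBaseline(dim): nc == dim and n_group_channels == dim (n_groups=1)
    offset_net = (conv_params(dim, dim, ksize, groups=dim)    # depthwise offset conv
                  + 2 * dim                                    # LayerNorm weight + bias
                  + conv_params(dim, 2, 1, bias=False))        # 1x1 conv -> 2 offset channels
    projections = 4 * conv_params(dim, dim, 1)                 # proj_q, proj_k, proj_v, proj_out
    rpe = conv_params(dim, dim, 3, groups=dim)                 # depthwise positional encoding
    dat = offset_net + projections + rpe
    pixel_attn = conv_params(2 * dim, dim, 7, groups=dim)      # PixelAttention_CGA.pa2
    out_conv = conv_params(dim, dim, 1)                        # DFAFusion.conv
    return dat + pixel_attn + out_conv

print([dfafusion_params(d) for d in (64, 128, 256)])   # [33280, 107520, 378880]
```

The four 1×1 projections dominate as dim grows, which is why the P5 fusion layer accounts for most of DFAFusion's parameters.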