RT-DETR Improvement Strategies [RT-DETR and Mamba] | Replacing the Backbone: Mamba-RT-DETR-B, a Current Publication Hotspot

1. Introduction

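This article documents how to use Mamba-YOLO to improve the RT-DETR object detection network. Mamba-YOLO is an object detection model built on state space models (SSMs). It aims to address the limitations of conventional detectors in complex scenes and in modeling long-range dependencies, and it is currently a popular topic for publications. The material is split into three articles, each covering the design and advantages of the modules in Mamba-YOLO. This article covers the ODSSBlock module and, at the end, configures the Mamba-RT-DETR-B network structure.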


Mamba YOLO: SSMs-Based YOLO For Object Detection


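2. The ODSSBlock Module

ODSSBlock (Object Detection State Space Block) is the core module of the Mamba-YOLO model and is central to its detection performance. It performs deep processing of the input features to learn richer and more effective representations, which in turn improves detection accuracy.

2.1 Design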

  1. Input preprocessing: the input features first pass through a 1×1 convolution, batch normalization, and an activation function. This step adjusts the channel dimension and distribution of the features and introduces a nonlinearity, so the network can learn more complex feature relationships. As a formula: $Z^{l-2}=\hat{\Phi}(BN(Conv_{1\times 1}(Z^{l-3})))$, where $\hat{\Phi}$ denotes the SiLU activation function.

  2. Layer Normalization and Residual Linking: following the Transformer-block style, Layer Normalization normalizes the features to speed up training and convergence. A residual connection keeps information flowing when blocks are stacked deeply, mitigates vanishing or exploding gradients, and lets the network learn deep feature representations efficiently. The computation is: $Z^{l-1}=SS2D(LN(LS(Z^{l-2})))+Z^{l-2}$

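  3. SS2D operation: SS2D (2D-Selective-Scan) is the key operation in ODSSBlock. It consists of three steps: Scan Expansion, the S6 Block, and Scan Merge.

    • Scan expansion: the input image is unfolded into a set of sequences. Viewed from the diagonal, the image is scanned along four symmetric directions (top to bottom, bottom to top, left to right, right to left), and each direction yields one sequence. This layout covers every region of the input and gives the subsequent feature extraction a richer, multi-directional view of the image.
    • S6 block: performs feature extraction on the sequences produced by scan expansion. This is the step in SS2D where features are actually extracted.
    • Scan merge: takes the sequences from the different directions after S6 processing and merges them into an output of the same size as the input. This fuses the features extracted along each direction and moves from local to global features.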

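  4. LocalSpatial Block (LS Block): the LS Block strengthens the capture of local features. It first applies a depthwise separable convolution to the input to extract local spatial information at reduced compute and parameter cost. A 1×1 convolution then mixes channel information, changing the number of channels without changing the spatial dimensions, and a nonlinear GELU activation enriches the feature representation. Finally, the original input is fused with the processed features through a residual connection. As a formula: $F^{l}=Conv_{1\times 1}(\Phi(Conv_{1\times 1}(F^{l-1})))\oplus F^{l-2}$, where $F^{l}$ is the output feature, $F^{l-2}$ is the block input, $F^{l-1}$ is the feature after the depthwise convolution, and $\Phi$ denotes the activation function (GELU).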

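  5. ResGated Block (RG Block): the RG Block aims to improve model performance at low computational cost. By combining a gating mechanism with a depthwise-convolution residual connection, it captures pixel-level local dependencies while passing global dependencies and features to every pixel. This makes the model more sensitive to fine-grained image features and strengthens its expressiveness.

    The block creates two branches from the input, each implemented as a fully connected layer in the form of a 1×1 convolution. One branch uses a 3×3 depthwise convolution (DW-Conv) as a positional encoding module, with a residual connection so that gradients can flow back. A nonlinear GELU serves as the activation function. The two branches are merged by element-wise multiplication, a 1×1 convolution then fuses channel information, and the result is added to the original input through a residual connection. The output feature $X^{l}$ is computed as: $X^{l}=Conv_{1\times 1}(X_{1}^{l-1}\odot\Phi(DWConv_{3\times 3}(X_{2}^{l-1})\oplus X_{2}^{l-1}))\oplus X^{l-2}$, where $\odot$ denotes element-wise multiplication and $\Phi$ denotes the activation function (GELU).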

  6. Final output: after the steps above, ODSSBlock adds the RG Block output to the RG Block's input $Z^{l-1}$ through a residual connection to obtain the final output feature: $Z^{l}=RG(LN(Z^{l-1}))+Z^{l-1}$
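
The scan expansion and scan merge steps of item 3 can be illustrated on a tiny grid. The sketch below is only meant to show the ordering logic: it uses plain Python lists instead of tensors, puts an identity in place of the S6 block, and the helper names `scan_expand` and `scan_merge` are made up for this illustration (they follow the same ordering as `CrossScan` and `CrossMerge` in section 3).

```python
def scan_expand(grid):
    """Unfold an H x W grid into four 1-D sequences, one per scan direction."""
    h, w = len(grid), len(grid[0])
    row_major = [grid[i][j] for i in range(h) for j in range(w)]  # row by row
    col_major = [grid[i][j] for j in range(w) for i in range(h)]  # column by column
    return [row_major, col_major, row_major[::-1], col_major[::-1]]


def scan_merge(seqs, h, w):
    """Map the four sequences back onto the grid and sum them per position."""
    row_major, col_major, row_rev, col_rev = seqs
    # Undo the two reversals first, then add them to the forward scans.
    rows = [a + b for a, b in zip(row_major, row_rev[::-1])]
    cols = [a + b for a, b in zip(col_major, col_rev[::-1])]
    # rows is indexed by i * w + j, cols by j * h + i.
    return [[rows[i * w + j] + cols[j * h + i] for j in range(w)] for i in range(h)]


grid = [[0, 1, 2],
        [3, 4, 5]]
seqs = scan_expand(grid)
names = ["row-major", "column-major", "row-major reversed", "column-major reversed"]
for name, seq in zip(names, seqs):
    print(f"{name:>22}: {seq}")

# With an identity in place of the S6 block, each position is visited four times,
# so the merged grid is simply four times the input.
merged = scan_merge(seqs, h=2, w=3)
print(merged)
```

In the real module each of the four sequences goes through the selective scan before merging, so the four contributions per pixel differ and carry context from different directions.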

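2.2 Advantages

  1. Efficient feature processing: ODSSBlock combines input preprocessing, Layer Normalization, residual connections, the SS2D operation, and the LS Block and RG Block to process input features efficiently. The network stays stable to train when blocks are stacked deeply while still extracting both local and global image features, giving the detector rich feature information.
  2. Stronger local feature capture: through depthwise separable convolution and a residual connection, the LS Block extracts local spatial information at lower computational cost, which improves detection of objects at different scales and makes the model more robust to scale changes.
  3. Global feature integration and expressiveness: the RG Block captures pixel-level local dependencies while integrating global feature information, which strengthens the model's expressiveness and helps localization and recognition accuracy.
  4. Multi-directional feature fusion: the SS2D operation fuses features scanned along multiple directions, which lets the model understand objects and their context more completely and improves detection in complex scenes.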

Paper: https://arxiv.org/pdf/2406.05835
Source code: https://github.com/HZAI-ZJNU/Mamba-YOLO

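3. Implementation of the Mamba-YOLO Modules

Sections 3 and 4 are identical across the three articles (Mamba-RT-DETR-T, Mamba-RT-DETR-B, Mamba-RT-DETR-L), so following any one of them is enough for configuration.

The implementation is as follows: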

import torch
import math
from functools import partial
from typing import Callable, Any

import torch.nn as nn
from einops import rearrange, repeat
from timm.layers import DropPath

DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"
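try:
    # Compiled selective-scan CUDA extensions (built in section 4.2).
    import selective_scan_cuda_core
    import selective_scan_cuda_oflex
    import selective_scan_cuda_ndstate
    # import selective_scan_cuda_nrow
    import selective_scan_cuda
except ImportError:
    # Not every extension is built in every setup; SelectiveScanCore below
    # only calls selective_scan_cuda_core.
    pass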

__all__ = ("VSSBlock_YOLO", "SimpleStem", "VisionClueMerge", "XSSBlock")

class LayerNorm2d(nn.Module):

    def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True):
        super().__init__()
        self.norm = nn.LayerNorm(normalized_shape, eps, elementwise_affine)

    def forward(self, x):
        x = rearrange(x, 'b c h w -> b h w c').contiguous()
        x = self.norm(x)
        x = rearrange(x, 'b h w c -> b c h w').contiguous()
        return x

def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p

# Cross Scan
class CrossScan(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        xs = x.new_empty((B, 4, C, H * W))
        xs[:, 0] = x.flatten(2, 3)
        xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        return xs

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        # out: (b, k, d, l)
        B, C, H, W = ctx.shape
        L = H * W
        ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        return y.view(B, -1, H, W)

class CrossMerge(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, D, H, W = ys.shape
        ctx.shape = (H, W)
        ys = ys.view(B, K, D, -1)
        ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        return y

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        # B, D, L = x.shape
        # out: (b, k, d, l)
        H, W = ctx.shape
        B, C, L = x.shape
        xs = x.new_empty((B, 4, C, L))
        xs[:, 0] = x
        xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        xs = xs.view(B, 4, C, H, W)
        return xs, None, None

class SelectiveScanCore(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1,
                oflex=True):
        # all in float
        if u.stride(-1) != 1:
            u = u.contiguous()
        if delta.stride(-1) != 1:
            delta = delta.contiguous()
        if D is not None and D.stride(-1) != 1:
            D = D.contiguous()
        if B.stride(-1) != 1:
            B = B.contiguous()
        if C.stride(-1) != 1:
            C = C.contiguous()
        if B.dim() == 3:
            B = B.unsqueeze(dim=1)
            ctx.squeeze_B = True
        if C.dim() == 3:
            C = C.unsqueeze(dim=1)
            ctx.squeeze_C = True
        ctx.delta_softplus = delta_softplus
        ctx.backnrows = backnrows
        out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1)
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout, *args):
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd(
            u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
        )
        return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)

def cross_selective_scan(
        x: torch.Tensor = None,
        x_proj_weight: torch.Tensor = None,
        x_proj_bias: torch.Tensor = None,
        dt_projs_weight: torch.Tensor = None,
        dt_projs_bias: torch.Tensor = None,
        A_logs: torch.Tensor = None,
        Ds: torch.Tensor = None,
        out_norm: torch.nn.Module = None,
        out_norm_shape="v0",
        nrows=-1,  
        backnrows=-1, 
        delta_softplus=True,
        to_dtype=True,
        force_fp32=False,  
        ssoflex=True,
        SelectiveScan=None,
        scan_mode_type='default'
):
   
    B, D, H, W = x.shape
    D, N = A_logs.shape
    K, D, R = dt_projs_weight.shape
    L = H * W

    def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True):
        return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, backnrows, ssoflex)

    xs = CrossScan.apply(x)

    x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight)
    if x_proj_bias is not None:
        x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1)
    dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
    dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_projs_weight)
    xs = xs.view(B, -1, L)
    dts = dts.contiguous().view(B, -1, L)
    As = -torch.exp(A_logs.to(torch.float))
    Bs = Bs.contiguous()
    Cs = Cs.contiguous()
    Ds = Ds.to(torch.float)  
    delta_bias = dt_projs_bias.view(-1).to(torch.float)

    if force_fp32:
        xs = xs.to(torch.float)
        dts = dts.to(torch.float)
        Bs = Bs.to(torch.float)
        Cs = Cs.to(torch.float)

    ys: torch.Tensor = selective_scan(
        xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
    ).view(B, K, -1, H, W)

    y: torch.Tensor = CrossMerge.apply(ys)

    if out_norm_shape in ["v1"]:  
        y = out_norm(y.view(B, -1, H, W)).permute(0, 2, 3, 1) 
    else: 
        y = y.transpose(dim0=1, dim1=2).contiguous() 
        y = out_norm(y).view(B, H, W, -1)

    return (y.to(x.dtype) if to_dtype else y)

class SS2D(nn.Module):
    def __init__(
            self,
            d_model=96,
            d_state=16,
            ssm_ratio=2.0,
            ssm_rank_ratio=2.0,
            dt_rank="auto",
            act_layer=nn.SiLU,
            d_conv=3, 
            conv_bias=True,
            dropout=0.0,
            bias=False,
            forward_type="v2",
            **kwargs,
    ):
        """
        ssm_rank_ratio would be used in the future...
        """
        factory_kwargs = {"device": None, "dtype": None}
        super().__init__()
        d_expand = int(ssm_ratio * d_model)
        d_inner = int(min(ssm_rank_ratio, ssm_ratio) * d_model) if ssm_rank_ratio > 0 else d_expand
        self.dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
        self.d_state = math.ceil(d_model / 6) if d_state == "auto" else d_state  
        self.d_conv = d_conv
        self.K = 4

        def checkpostfix(tag, value):
            ret = value[-len(tag):] == tag
            if ret:
                value = value[:-len(tag)]
            return ret, value

        self.disable_force32, forward_type = checkpostfix("no32", forward_type)
        self.disable_z, forward_type = checkpostfix("noz", forward_type)
        self.disable_z_act, forward_type = checkpostfix("nozact", forward_type)

        self.out_norm = nn.LayerNorm(d_inner)
        FORWARD_TYPES = dict(
            v2=partial(self.forward_corev2, force_fp32=None, SelectiveScan=SelectiveScanCore),
        )
        self.forward_core = FORWARD_TYPES.get(forward_type, FORWARD_TYPES.get("v2", None))
        d_proj = d_expand if self.disable_z else (d_expand * 2)
        self.in_proj = nn.Conv2d(d_model, d_proj, kernel_size=1, stride=1, groups=1, bias=bias, **factory_kwargs)
        self.act: nn.Module = nn.GELU()

        if self.d_conv > 1:
            self.conv2d = nn.Conv2d(
                in_channels=d_expand,
                out_channels=d_expand,
                groups=d_expand,
                bias=conv_bias,
                kernel_size=d_conv,
                padding=(d_conv - 1) // 2,
                **factory_kwargs,
            )
        self.ssm_low_rank = False
        if d_inner < d_expand:
            self.ssm_low_rank = True
            self.in_rank = nn.Conv2d(d_expand, d_inner, kernel_size=1, bias=False, **factory_kwargs)
            self.out_rank = nn.Linear(d_inner, d_expand, bias=False, **factory_kwargs)

        self.x_proj = [
            nn.Linear(d_inner, (self.dt_rank + self.d_state * 2), bias=False,
                      **factory_kwargs)
            for _ in range(self.K)
        ]
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K, N, inner)
        del self.x_proj

        self.out_proj = nn.Conv2d(d_expand, d_model, kernel_size=1, stride=1, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()

        self.Ds = nn.Parameter(torch.ones((self.K * d_inner)))
        self.A_logs = nn.Parameter(
            torch.zeros((self.K * d_inner, self.d_state)))  
        self.dt_projs_weight = nn.Parameter(torch.randn((self.K, d_inner, self.dt_rank)))
        self.dt_projs_bias = nn.Parameter(torch.randn((self.K, d_inner)))

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
                **factory_kwargs):
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)

        dt_init_std = dt_rank ** -0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)

        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
        A = repeat(
            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_inner,
        ).contiguous()
        A_log = torch.log(A) 
        if copies > 0:
            A_log = repeat(A_log, "d n -> r d n", r=copies)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def D_init(d_inner, copies=-1, device=None, merge=True):
        D = torch.ones(d_inner, device=device)
        if copies > 0:
            D = repeat(D, "n1 -> r n1", r=copies)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D) 
        D._no_weight_decay = True
        return D

    def forward_corev2(self, x: torch.Tensor, channel_first=False, SelectiveScan=SelectiveScanCore,
                       cross_selective_scan=cross_selective_scan, force_fp32=None):
        force_fp32 = (self.training and (not self.disable_force32)) if force_fp32 is None else force_fp32
        if not channel_first:
            x = x.permute(0, 3, 1, 2).contiguous()
        if self.ssm_low_rank:
            x = self.in_rank(x)
        x = cross_selective_scan(
            x, self.x_proj_weight, None, self.dt_projs_weight, self.dt_projs_bias,
            self.A_logs, self.Ds,
            out_norm=getattr(self, "out_norm", None),
            out_norm_shape=getattr(self, "out_norm_shape", "v0"),
            delta_softplus=True, force_fp32=force_fp32,
            SelectiveScan=SelectiveScan, ssoflex=self.training,  # output fp32
        )
        if self.ssm_low_rank:
            x = self.out_rank(x)
        return x

    def forward(self, x: torch.Tensor, **kwargs):
        x = self.in_proj(x)
        if not self.disable_z:
            x, z = x.chunk(2, dim=1)
            # Gate branch: skip the activation only when "nozact" was requested,
            # so z1 is always defined before it is used below.
            z1 = z if self.disable_z_act else self.act(z)
        # self.conv2d is only created when d_conv > 1 (see __init__).
        if self.d_conv > 1:
            x = self.conv2d(x)
        x = self.act(x)
        y = self.forward_core(x, channel_first=(self.d_conv > 1))
        y = y.permute(0, 3, 1, 2).contiguous()
        if not self.disable_z:
            y = y * z1
        out = self.dropout(self.out_proj(y))
        return out

class RGBlock(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,
                 channels_first=False):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        hidden_features = int(2 * hidden_features / 3)
        self.fc1 = nn.Conv2d(in_features, hidden_features * 2, kernel_size=1)
        self.dwconv = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, bias=True,
                                groups=hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x, v = self.fc1(x).chunk(2, dim=1)
        x = self.act(self.dwconv(x) + x) * v
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class LSBlock(nn.Module):
    def __init__(self, in_features, hidden_features=None, act_layer=nn.GELU, drop=0):
        super().__init__()
        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=3, padding=3 // 2, groups=hidden_features)
        self.norm = nn.BatchNorm2d(hidden_features)
        self.fc2 = nn.Conv2d(hidden_features, hidden_features, kernel_size=1, padding=0)
        self.act = act_layer()
        self.fc3 = nn.Conv2d(hidden_features, in_features, kernel_size=1, padding=0)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        input = x
        x = self.fc1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.act(x)
        x = self.fc3(x)
        x = input + self.drop(x)
        return x

class XSSBlock(nn.Module):
    def __init__(
            self,
            in_channels: int = 0,
            hidden_dim: int = 0,
            n: int = 1,
            mlp_ratio=4.0,
            drop_path: float = 0,
            norm_layer: Callable[..., torch.nn.Module] = partial(LayerNorm2d, eps=1e-6),
            ssm_d_state: int = 16,
            ssm_ratio=2.0,
            ssm_rank_ratio=2.0,
            ssm_dt_rank: Any = "auto",
            ssm_act_layer=nn.SiLU,
            ssm_conv: int = 3,
            ssm_conv_bias=True,
            ssm_drop_rate: float = 0,
            ssm_init="v0",
            forward_type="v2",
            mlp_act_layer=nn.GELU,
            mlp_drop_rate: float = 0.0,
            use_checkpoint: bool = False,
            post_norm: bool = False,
            **kwargs,
    ):
        super().__init__()

        self.in_proj = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU()
        ) if in_channels != hidden_dim else nn.Identity()
        self.hidden_dim = hidden_dim
        self.norm = norm_layer(hidden_dim)
        self.ss2d = nn.Sequential(*(SS2D(d_model=self.hidden_dim,
                                         d_state=ssm_d_state,
                                         ssm_ratio=ssm_ratio,
                                         ssm_rank_ratio=ssm_rank_ratio,
                                         dt_rank=ssm_dt_rank,
                                         act_layer=ssm_act_layer,
                                         d_conv=ssm_conv,
                                         conv_bias=ssm_conv_bias,
                                         dropout=ssm_drop_rate, ) for _ in range(n)))
        self.drop_path = DropPath(drop_path)
        self.lsblock = LSBlock(hidden_dim, hidden_dim)
        self.mlp_branch = mlp_ratio > 0
        if self.mlp_branch:
            self.norm2 = norm_layer(hidden_dim)
            mlp_hidden_dim = int(hidden_dim * mlp_ratio)
            self.mlp = RGBlock(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer,
                               drop=mlp_drop_rate)

    def forward(self, input):
        input = self.in_proj(input)
        X1 = self.lsblock(input)
        input = input + self.drop_path(self.ss2d(self.norm(X1)))
        if self.mlp_branch:
            input = input + self.drop_path(self.mlp(self.norm2(input)))
        return input

class VSSBlock_YOLO(nn.Module):
    def __init__(
            self,
            in_channels: int = 0,
            hidden_dim: int = 0,
            drop_path: float = 0,
            norm_layer: Callable[..., torch.nn.Module] = partial(LayerNorm2d, eps=1e-6),
            ssm_d_state: int = 16,
            ssm_ratio=2.0,
            ssm_rank_ratio=2.0,
            ssm_dt_rank: Any = "auto",
            ssm_act_layer=nn.SiLU,
            ssm_conv: int = 3,
            ssm_conv_bias=True,
            ssm_drop_rate: float = 0,
            ssm_init="v0",
            forward_type="v2",
            mlp_ratio=4.0,
            mlp_act_layer=nn.GELU,
            mlp_drop_rate: float = 0.0,
            use_checkpoint: bool = False,
            post_norm: bool = False,
            **kwargs,
    ):
        super().__init__()
        self.ssm_branch = ssm_ratio > 0
        self.mlp_branch = mlp_ratio > 0
        self.use_checkpoint = use_checkpoint
        self.post_norm = post_norm

        # proj
        self.proj_conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU()
        )

        if self.ssm_branch:
            self.norm = norm_layer(hidden_dim)
            self.op = SS2D(
                d_model=hidden_dim,
                d_state=ssm_d_state,
                ssm_ratio=ssm_ratio,
                ssm_rank_ratio=ssm_rank_ratio,
                dt_rank=ssm_dt_rank,
                act_layer=ssm_act_layer,
                d_conv=ssm_conv,
                conv_bias=ssm_conv_bias,
                dropout=ssm_drop_rate,
                # bias=False,
                # dt_min=0.001,
                # dt_max=0.1,
                # dt_init="random",
                # dt_scale="random",
                # dt_init_floor=1e-4,
                initialize=ssm_init,
                forward_type=forward_type,
            )

        self.drop_path = DropPath(drop_path)
        self.lsblock = LSBlock(hidden_dim, hidden_dim)
        if self.mlp_branch:
            self.norm2 = norm_layer(hidden_dim)
            mlp_hidden_dim = int(hidden_dim * mlp_ratio)
            self.mlp = RGBlock(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer,
                               drop=mlp_drop_rate, channels_first=False)

    def forward(self, input: torch.Tensor):
        input = self.proj_conv(input)
        X1 = self.lsblock(input)
        x = input + self.drop_path(self.op(self.norm(X1)))
        if self.mlp_branch:
            x = x + self.drop_path(self.mlp(self.norm2(x)))  # FFN
        return x

class SimpleStem(nn.Module):
    def __init__(self, inp, embed_dim, ks=3):
        super().__init__()
        self.hidden_dims = embed_dim // 2
        self.conv = nn.Sequential(
            nn.Conv2d(inp, self.hidden_dims, kernel_size=ks, stride=2, padding=autopad(ks, d=1), bias=False),
            nn.BatchNorm2d(self.hidden_dims),
            nn.GELU(),
            nn.Conv2d(self.hidden_dims, embed_dim, kernel_size=ks, stride=2, padding=autopad(ks, d=1), bias=False),
            nn.BatchNorm2d(embed_dim),
            nn.SiLU(),
        )

    def forward(self, x):
        return self.conv(x)

class VisionClueMerge(nn.Module):
    def __init__(self, dim, out_dim):
        super().__init__()
        self.hidden = int(dim * 4)

        self.pw_linear = nn.Sequential(
            nn.Conv2d(self.hidden, out_dim, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(out_dim),
            nn.SiLU()
        )

    def forward(self, x):
        y = torch.cat([
            x[..., ::2, ::2],
            x[..., 1::2, ::2],
            x[..., ::2, 1::2],
            x[..., 1::2, 1::2]
        ], dim=1)
        return self.pw_linear(y)
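
The downsampling performed by `SimpleStem` and `VisionClueMerge` can be checked by hand with the standard convolution output-size formula. The sketch below is plain arithmetic, not part of the module file; the 640-pixel input size and the helper names are assumptions chosen for illustration.

```python
def conv_out(size, k, s, p):
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * p - k) // s + 1


def simple_stem_out(size, ks=3):
    """Two stride-2 convolutions with padding ks // 2, as in SimpleStem."""
    p = ks // 2  # what autopad(ks) returns for an integer kernel with dilation 1
    return conv_out(conv_out(size, ks, 2, p), ks, 2, p)


def clue_merge_out(size):
    """x[..., ::2, ::2] keeps every second row and column."""
    return len(range(0, size, 2))


size = 640  # assumed input resolution, for illustration only
sizes = [simple_stem_out(size)]
for _ in range(3):  # one VisionClueMerge per later backbone stage
    sizes.append(clue_merge_out(sizes[-1]))
strides = [size // s for s in sizes]
print(sizes)    # feature-map sizes after the stem and each merge
print(strides)  # corresponding strides relative to the input
```

Note that `VisionClueMerge` concatenates four strided slices, so it expects even height and width; with odd sizes the slices would have different shapes.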


4. Integration Steps

4.1 Base Environment

Requirements:

  • Linux
  • NVIDIA GPU
  • PyTorch 1.12+
  • CUDA 11.6+

Author's environment:

  • Linux
  • NVIDIA GPU
  • PyTorch 2.0.0
  • CUDA 11.8

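The PyTorch and CUDA versions must match: most problems during the installation below come from a version mismatch or from network issues. Also, the official implementation was developed on Linux; Windows users can give it a try.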
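
Before compiling anything, it can help to print the versions that need to line up. The following is a minimal sketch; the function name `describe_env` is made up here, and it simply reports `None` if PyTorch is not installed in the current interpreter.

```python
import importlib
import platform


def describe_env():
    """Collect the version facts that need to match before compiling."""
    info = {"os": platform.system(), "torch": None, "torch_cuda": None, "gpu_available": False}
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        return info  # PyTorch is not installed in this interpreter
    info["torch"] = torch.__version__
    # CUDA version PyTorch was built with; None for CPU-only builds.
    info["torch_cuda"] = torch.version.cuda
    info["gpu_available"] = torch.cuda.is_available()
    return info


env = describe_env()
for key, value in env.items():
    print(f"{key:>14}: {value}")
```

The `torch_cuda` value is the one to compare against the local CUDA toolkit (for example, the version reported by `nvcc --version`), since the extensions in section 4.2 are compiled with that toolkit.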

4.2 Installation and Compilation

The following modules must be installed:

1️⃣ mmcv

pip install mmcv

If this fails, install it with mim:

First install openmim: pip install -U openmim
Then install mmcv: mim install mmcv

2️⃣ causal-conv1d

pip install causal-conv1d

Building wheels for collected packages: causal-conv1d
Building wheel for causal-conv1d (setup.py) … -

If the build hangs at this point, it is a network problem; try a local install instead…

3️⃣ mamba-ssm

pip install mamba-ssm

Building wheels for collected packages: mamba-ssm
Building wheel for mamba-ssm (setup.py) … -

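Same as above.

4️⃣ Compile Mamba

The mamba project package has been uploaded to the author's group chat. Download and extract it, then place it under ultralytics/nn/AddModules/mamba.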

cd into ultralytics/nn/AddModules/mamba and run:

python setup.py install

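If the build fails with an error like the one in the original post's screenshot (not reproduced here), the PyTorch and CUDA versions do not match.

Other build errors are handled the same way. Before recompiling, run the following in the same directory: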

python setup.py clean --all

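Then compile again.

5️⃣ Compile Mamba-YOLO
The related selective_scan project package has also been uploaded to the author's group chat. Download and extract it, then place it under ultralytics/nn/AddModules/selective_scan.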

cd into ultralytics/nn/AddModules/selective_scan and run:

python setup.py install

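The remaining steps are the same as in the previous step.

Once everything has installed successfully, the environment setup is complete.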
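
A quick way to confirm the result is to ask Python whether it can locate each module. This is a minimal sketch: the list combines the import names of the pip packages above with two of the compiled extensions imported in section 3, and `find_installed` is a name made up for this example.

```python
import importlib.util

# Import names to look for: the three pip packages from this section,
# plus two compiled extensions that the code in section 3 imports.
MODULES = ["mmcv", "causal_conv1d", "mamba_ssm", "selective_scan_cuda_core", "selective_scan_cuda"]


def find_installed(names):
    """Return {name: True/False} depending on whether Python can locate the module."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


status = find_installed(MODULES)
for name, found in status.items():
    print(f"{name:<28} {'found' if found else 'MISSING'}")
```

`find_spec` only locates a module without loading it, so an extension built against the wrong PyTorch or CUDA version can be reported as found and still fail at import time.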

4.3 Code Configuration

1️⃣ Create an AddModules folder under ultralytics/nn/ to hold the module code.

2️⃣ Create mamba_yolo.py in the AddModules folder and paste the code from section 3 into it.

3️⃣ Create __init__.py in the AddModules folder (skip this if it already exists) and import the module in it: from .mamba_yolo import *

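4️⃣ In ultralytics/nn/tasks.py, the module class names must be added in two places.

First, import the modules (the original post showed this step as a screenshot).

Second, register the modules in the parse_model function: SimpleStem, VisionClueMerge, VSSBlock_YOLO, XSSBlock.

Then, in the DetectionModel class, add the following code. It first tries the stride-probing forward pass on the CPU and retries on CUDA if that raises a RuntimeError, presumably because the selective-scan kernels from section 3 only run on the GPU: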

try:
   m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))])  # forward on CPU
except RuntimeError:
   try:
       self.model.to(torch.device('cuda'))
       m.stride = torch.tensor([s / x.shape[-2] for x in _forward(
           torch.zeros(1, ch, s, s).to(torch.device('cuda')))])  # forward on CUDA
   except RuntimeError as error:
       raise error

and comment out this line:

# m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))])  # forward
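
The `from .mamba_yolo import *` line in step 3 only pulls in the names listed in `__all__`, which is why exactly the four classes registered in parse_model are visible afterwards. The sketch below demonstrates this behaviour with a throwaway in-memory module; the module name `demo_mamba_yolo` and the empty placeholder classes are made up for the demonstration and do not touch the real file.

```python
import sys
import types

# Build a throwaway module that mimics the export list of mamba_yolo.py.
demo = types.ModuleType("demo_mamba_yolo")
exec(
    "__all__ = ('VSSBlock_YOLO', 'SimpleStem', 'VisionClueMerge', 'XSSBlock')\n"
    "class VSSBlock_YOLO: pass\n"
    "class SimpleStem: pass\n"
    "class VisionClueMerge: pass\n"
    "class XSSBlock: pass\n"
    "class LSBlock: pass  # helper class, deliberately not listed in __all__\n",
    demo.__dict__,
)
sys.modules["demo_mamba_yolo"] = demo

# A star import copies only the names in __all__ into the target namespace.
namespace = {}
exec("from demo_mamba_yolo import *", namespace)
exported = sorted(name for name in namespace if not name.startswith("__"))
print(exported)
print("LSBlock" in namespace)
```

Helper classes such as `LSBlock` or `RGBlock` therefore stay internal to mamba_yolo.py unless they are added to `__all__`.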


5. YAML Model File

5.1 Improved Model Version

📌 Create a new model file rtdetr-mamba-B.yaml with the following structure:

nc: 1  # number of classes
scales:   # [depth, width, max_channels]
  B: [0.33, 0.50, 1024]  # B scale (depth 0.33, width 0.50)

# Mamba-YOLO backbone 
backbone:
  # [from, repeats, module, args]
  - [-1, 1, SimpleStem, [128, 3]]   # 0-P2/4
  - [-1, 2, VSSBlock_YOLO, [128]]               # 1
  - [-1, 1, VisionClueMerge, [256]]      # 2 p3/8
  - [-1, 2, VSSBlock_YOLO, [256]]              # 3
  - [-1, 1, VisionClueMerge, [512]]      # 4 p4/16
  - [-1, 2, VSSBlock_YOLO, [512]]              # 5
  - [-1, 1, VisionClueMerge, [1024]]      # 6 p5/32
  - [-1, 2, VSSBlock_YOLO, [1024]]              # 7
  - [-1, 1, SPPF, [1024, 5]]               # 8
  - [-1, 2, C2PSA, [1024]]               # 9

# Mamba-YOLO PAFPN
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 5], 1, Concat, [1]]  # cat backbone P4
  - [-1, 2, XSSBlock, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 3], 1, Concat, [1]]  # cat backbone P3
  - [-1, 2, XSSBlock, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 2, XSSBlock, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 2, XSSBlock, [1024]]  # 21 (P5/32-large)

  - [[15, 18, 21], 1, RTDETRDecoder, [nc, 256, 300, 4, 8, 3]]  # RTDETRDecoder(P3, P4, P5), as in the printed summary in section 6
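
The channel counts in this file are not the ones that appear in the printed model, because the `scales` entry rescales them. The sketch below restates, as far as I understand it, the scaling rule that Ultralytics' parse_model applies; the helper names `scaled_channels` and `scaled_repeats` are made up for this example.

```python
import math


def make_divisible(x, divisor=8):
    """Round x up to the nearest multiple of divisor."""
    return math.ceil(x / divisor) * divisor


def scaled_channels(c, width, max_channels):
    """Output channels after applying the width multiplier."""
    return make_divisible(min(c, max_channels) * width, 8)


def scaled_repeats(n, depth):
    """Repeat count after applying the depth multiplier (never below 1)."""
    return max(round(n * depth), 1) if n > 1 else n


depth, width, max_channels = 0.33, 0.50, 1024  # the B entry under `scales`
for c in (128, 256, 512, 1024):
    print(c, "->", scaled_channels(c, width, max_channels))
print("repeats 2 ->", scaled_repeats(2, depth))
```

With these values, the 128/256/512/1024 channels in the YAML become 64/128/256/512 and every `2` in the repeats column becomes 1, which is what the model summary in section 6 shows.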


6. Successful Run

Printing the network shows that the Mamba-YOLO modules have been added to the model, which is now ready for training.

rtdetr-mamba-B

rtdetr-mamba-B summary: 426 layers, 22,019,428 parameters, 22,019,428 gradients

                   from  n    params  module                                       arguments                     
  0                  -1  1     19488  ultralytics.nn.AddModules.mamba_yolo.SimpleStem[3, 64, 3]                    
  1                  -1  1    104184  ultralytics.nn.AddModules.mamba_yolo.VSSBlock_YOLO[64, 64]                      
  2                  -1  1     33152  ultralytics.nn.AddModules.mamba_yolo.VisionClueMerge[64, 128]                     
  3                  -1  1    355964  ultralytics.nn.AddModules.mamba_yolo.VSSBlock_YOLO[128, 128]                    
  4                  -1  1    131840  ultralytics.nn.AddModules.mamba_yolo.VisionClueMerge[128, 256]                    
  5                  -1  1   1301496  ultralytics.nn.AddModules.mamba_yolo.VSSBlock_YOLO[256, 256]                    
  6                  -1  1    525824  ultralytics.nn.AddModules.mamba_yolo.VisionClueMerge[256, 512]                    
  7                  -1  1   4962812  ultralytics.nn.AddModules.mamba_yolo.VSSBlock_YOLO[512, 512]                    
  8                  -1  1    656896  ultralytics.nn.modules.block.SPPF            [512, 512, 5]                 
  9                  -1  1    990976  ultralytics.nn.modules.block.C2PSA           [512, 512, 1]                 
 10                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 11             [-1, 5]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 12                  -1  1   1432312  ultralytics.nn.AddModules.mamba_yolo.XSSBlock[768, 256]                    
 13                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 14             [-1, 3]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 15                  -1  1    388604  ultralytics.nn.AddModules.mamba_yolo.XSSBlock[384, 128]                    
 16                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]              
 17            [-1, 12]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 18                  -1  1   1334008  ultralytics.nn.AddModules.mamba_yolo.XSSBlock[384, 256]                    
 19                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 20             [-1, 9]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 21                  -1  1   5093372  ultralytics.nn.AddModules.mamba_yolo.XSSBlock[768, 512]                    
 22        [15, 18, 21]  1   3950452  ultralytics.nn.modules.head.RTDETRDecoder    [1, [128, 256, 512], 256, 300, 4, 8, 3]
rtdetr-mamba-B summary: 426 layers, 22,019,428 parameters, 22,019,428 gradients