RT-DETR Improvement Strategy [RT-DETR and Mamba] | Replacing the Backbone: Mamba-RT-DETR-T, a Current Hot Topic for Publications
1. About This Article
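This article describes how to use Mamba-YOLO to optimize the RT-DETR object detection network. Mamba-YOLO is an object detection model based on state space models (SSMs). It is designed to overcome the limitations of conventional detectors in handling complex scenes and long-range dependencies, and it is currently a hot topic for publications. The design and advantages of the modules in the Mamba-YOLO architecture are presented in three parts. This part covers the Simple Stem module, and the Mamba-RT-DETR-T network configuration is given at the end.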
Mamba YOLO: SSMs-Based YOLO For Object Detection
2. The Simple Stem Module
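The Simple Stem module is an important component of Mamba-YOLO. Its main role is to process the input image at the very start of the model and prepare it for subsequent feature extraction and object detection. A detailed description of the Simple Stem module follows.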
2.1 Design Background
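Modern Vision Transformers (ViTs) usually use segmented patches (patchify) as their initial module. A convolution splits the image into non-overlapping patches. However, this approach limits the optimization capability of ViTs and therefore affects overall performance. To strike a balance between performance and efficiency, Mamba-YOLO proposes the Simple Stem module.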
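The difference between the two kinds of stem can be made concrete with receptive-field arithmetic. The pure-Python sketch below assumes the ViT-style patchify stem is a single 4×4 convolution with stride 4, which is a common choice in hierarchical ViTs and is used here only as an assumption. `receptive_field` is a hypothetical helper based on the standard recurrence r ← r + (k − 1)·j, j ← j·s.

```python
def receptive_field(layers):
    """Return (receptive field, cumulative stride) of a stack of conv layers.

    `layers` is a list of (kernel_size, stride) pairs, applied in order.
    Standard recurrence: r += (k - 1) * jump; jump *= s.
    """
    r, jump = 1, 1
    for k, s in layers:
        r += (k - 1) * jump
        jump *= s
    return r, jump

# Assumed ViT-style patchify stem: one 4x4 conv with stride 4.
patchify = [(4, 4)]
# Simple Stem as described: two 3x3 convs, each with stride 2.
simple_stem = [(3, 2), (3, 2)]

for name, layers in [("patchify", patchify), ("simple_stem", simple_stem)]:
    rf, stride = receptive_field(layers)
    # Neighbouring output positions share input pixels when rf exceeds the stride.
    print(f"{name}: receptive field={rf}, total stride={stride}, overlap={rf - stride}")
# patchify: receptive field=4, total stride=4, overlap=0
# simple_stem: receptive field=7, total stride=4, overlap=3
```

Both stems downsample by 4. Each Simple Stem output, however, sees a 7×7 input window, and neighbouring windows overlap by 3 pixels. The patchify windows, by contrast, are disjoint.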
2.2 Design Structure
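Instead of the patchify approach of conventional ViTs, the Simple Stem module uses two convolutions, each with a stride of 2 and a kernel size of 3.
This design is comparatively simple. It avoids an elaborate image-splitting step while still performing effective initial feature extraction and downsampling of the input image. The two consecutive convolutional layers reduce the image resolution by a factor of 4 in each dimension while retaining a useful amount of feature information, which gives the subsequent layers a suitable input scale.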
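The arithmetic behind this downsampling can be checked without PyTorch. The sketch below reuses the `autopad` rule from Section 3, simplified to integer kernels (padding = k // 2), together with the standard convolution output-size formula. `conv_out`, `simple_stem_shape` and `simple_stem_params` are hypothetical helpers. The parameter count assumes bias-free convolutions, each followed by a BatchNorm with 2 learnable parameters per channel, as in the `SimpleStem` class. For `inp=3, embed_dim=32` it reproduces the 5136 parameters printed for layer 0 in Section 6.

```python
def autopad(k, p=None, d=1):
    """Simplified (int kernels only) version of the 'same' padding rule from Section 3."""
    if d > 1:
        k = d * (k - 1) + 1
    return k // 2 if p is None else p

def conv_out(size, k, s):
    """Output length of a convolution along one axis: floor((n + 2p - k) / s) + 1."""
    p = autopad(k)
    return (size + 2 * p - k) // s + 1

def simple_stem_shape(h, w, ks=3):
    """Spatial size after SimpleStem's two stride-2 convolutions."""
    for _ in range(2):
        h, w = conv_out(h, ks, 2), conv_out(w, ks, 2)
    return h, w

def simple_stem_params(inp, embed_dim, ks=3):
    """Learnable parameters: two bias-free convs plus two BatchNorms (weight + bias each)."""
    hidden = embed_dim // 2
    conv1 = inp * hidden * ks * ks
    conv2 = hidden * embed_dim * ks * ks
    bn = 2 * hidden + 2 * embed_dim
    return conv1 + conv2 + bn

print(simple_stem_shape(640, 640))  # (160, 160)
print(simple_stem_params(3, 32))    # 5136
```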
2.3 Advantages
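- Balancing performance and efficiency: Compared with the patchify stems of conventional ViTs, the Simple Stem keeps computational cost low without discarding too much information, which improves the overall efficiency of the model. It processes image data quickly at the start of the network, so later layers can carry out feature learning and object detection more efficiently. The result is a good balance between performance and efficiency.
- Better feature representation: The two 3×3, stride-2 convolutions capture local image features effectively while preserving spatial information to a certain extent. This representation helps later layers understand the image content and supports accurate object detection.
- Improved adaptability: The Simple Stem helps Mamba-YOLO adapt to object detection in different scenes. Its simple yet effective structure processes a wide variety of input images quickly. It supplies stable, useful initial features for both simple and complex scenes, which improves the model's generalization in practice.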
Paper: https://arxiv.org/pdf/2406.05835
Code: https://github.com/HZAI-ZJNU/Mamba-YOLO
3. Implementation Code for the Mamba-YOLO Modules
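Sections 3 and 4 are identical in content and steps across the three articles Mamba-RT-DETR-T, Mamba-RT-DETR-B and Mamba-RT-DETR-L, so you only need to follow one of them.
The implementation code is as follows: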
import torch
import math
from functools import partial
from typing import Callable, Any
import torch.nn as nn
from einops import rearrange, repeat
from timm.layers import DropPath

DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"

try:
    # CUDA extensions built from the selective_scan package (Section 4.2).
    # Only selective_scan_cuda_core is actually used below.
    import selective_scan_cuda_core
    import selective_scan_cuda_oflex
    import selective_scan_cuda_ndstate
    # import selective_scan_cuda_nrow
    import selective_scan_cuda
except ImportError:
    # Failures are silenced here; a missing extension only surfaces later,
    # as a NameError when SelectiveScanCore.forward runs.
    pass

__all__ = ("VSSBlock_YOLO", "SimpleStem", "VisionClueMerge", "XSSBlock")


class LayerNorm2d(nn.Module):
    """LayerNorm over the channel dimension of a (B, C, H, W) tensor."""

    def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True):
        super().__init__()
        self.norm = nn.LayerNorm(normalized_shape, eps, elementwise_affine)

    def forward(self, x):
        x = rearrange(x, 'b c h w -> b h w c').contiguous()
        x = self.norm(x)
        x = rearrange(x, 'b h w c -> b c h w').contiguous()
        return x


def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p
# Cross Scan
class CrossScan(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        xs = x.new_empty((B, 4, C, H * W))
        xs[:, 0] = x.flatten(2, 3)
        xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        return xs

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        # out: (b, k, d, l)
        B, C, H, W = ctx.shape
        L = H * W
        ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        return y.view(B, -1, H, W)


class CrossMerge(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, D, H, W = ys.shape
        ctx.shape = (H, W)
        ys = ys.view(B, K, D, -1)
        ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        return y

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        # B, D, L = x.shape
        # out: (b, k, d, l)
        H, W = ctx.shape
        B, C, L = x.shape
        xs = x.new_empty((B, 4, C, L))
        xs[:, 0] = x
        xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        xs = xs.view(B, 4, C, H, W)
        return xs, None, None
class SelectiveScanCore(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1,
                oflex=True):
        # all in float
        if u.stride(-1) != 1:
            u = u.contiguous()
        if delta.stride(-1) != 1:
            delta = delta.contiguous()
        if D is not None and D.stride(-1) != 1:
            D = D.contiguous()
        if B.stride(-1) != 1:
            B = B.contiguous()
        if C.stride(-1) != 1:
            C = C.contiguous()
        if B.dim() == 3:
            B = B.unsqueeze(dim=1)
            ctx.squeeze_B = True
        if C.dim() == 3:
            C = C.unsqueeze(dim=1)
            ctx.squeeze_C = True
        ctx.delta_softplus = delta_softplus
        ctx.backnrows = backnrows
        out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1)
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout, *args):
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd(
            u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
        )
        return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)
def cross_selective_scan(
        x: torch.Tensor = None,
        x_proj_weight: torch.Tensor = None,
        x_proj_bias: torch.Tensor = None,
        dt_projs_weight: torch.Tensor = None,
        dt_projs_bias: torch.Tensor = None,
        A_logs: torch.Tensor = None,
        Ds: torch.Tensor = None,
        out_norm: torch.nn.Module = None,
        out_norm_shape="v0",
        nrows=-1,
        backnrows=-1,
        delta_softplus=True,
        to_dtype=True,
        force_fp32=False,
        ssoflex=True,
        SelectiveScan=None,
        scan_mode_type='default'
):
    B, D, H, W = x.shape
    D, N = A_logs.shape
    K, D, R = dt_projs_weight.shape
    L = H * W

    def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True):
        return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, backnrows, ssoflex)

    xs = CrossScan.apply(x)
    x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight)
    if x_proj_bias is not None:
        x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1)
    dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
    dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_projs_weight)
    xs = xs.view(B, -1, L)
    dts = dts.contiguous().view(B, -1, L)
    As = -torch.exp(A_logs.to(torch.float))
    Bs = Bs.contiguous()
    Cs = Cs.contiguous()
    Ds = Ds.to(torch.float)
    delta_bias = dt_projs_bias.view(-1).to(torch.float)
    if force_fp32:
        xs = xs.to(torch.float)
        dts = dts.to(torch.float)
        Bs = Bs.to(torch.float)
        Cs = Cs.to(torch.float)
    ys: torch.Tensor = selective_scan(
        xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
    ).view(B, K, -1, H, W)
    y: torch.Tensor = CrossMerge.apply(ys)
    if out_norm_shape in ["v1"]:
        y = out_norm(y.view(B, -1, H, W)).permute(0, 2, 3, 1)
    else:
        y = y.transpose(dim0=1, dim1=2).contiguous()
        y = out_norm(y).view(B, H, W, -1)
    return (y.to(x.dtype) if to_dtype else y)
class SS2D(nn.Module):
    def __init__(
            self,
            d_model=96,
            d_state=16,
            ssm_ratio=2.0,
            ssm_rank_ratio=2.0,
            dt_rank="auto",
            act_layer=nn.SiLU,
            d_conv=3,
            conv_bias=True,
            dropout=0.0,
            bias=False,
            forward_type="v2",
            **kwargs,
    ):
        """
        ssm_rank_ratio would be used in the future...
        """
        factory_kwargs = {"device": None, "dtype": None}
        super().__init__()
        d_expand = int(ssm_ratio * d_model)
        d_inner = int(min(ssm_rank_ratio, ssm_ratio) * d_model) if ssm_rank_ratio > 0 else d_expand
        self.dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
        self.d_state = math.ceil(d_model / 6) if d_state == "auto" else d_state
        self.d_conv = d_conv
        self.K = 4

        def checkpostfix(tag, value):
            ret = value[-len(tag):] == tag
            if ret:
                value = value[:-len(tag)]
            return ret, value

        self.disable_force32, forward_type = checkpostfix("no32", forward_type)
        self.disable_z, forward_type = checkpostfix("noz", forward_type)
        self.disable_z_act, forward_type = checkpostfix("nozact", forward_type)
        self.out_norm = nn.LayerNorm(d_inner)
        FORWARD_TYPES = dict(
            v2=partial(self.forward_corev2, force_fp32=None, SelectiveScan=SelectiveScanCore),
        )
        self.forward_core = FORWARD_TYPES.get(forward_type, FORWARD_TYPES.get("v2", None))
        d_proj = d_expand if self.disable_z else (d_expand * 2)
        self.in_proj = nn.Conv2d(d_model, d_proj, kernel_size=1, stride=1, groups=1, bias=bias, **factory_kwargs)
        self.act: nn.Module = nn.GELU()
        if self.d_conv > 1:
            self.conv2d = nn.Conv2d(
                in_channels=d_expand,
                out_channels=d_expand,
                groups=d_expand,
                bias=conv_bias,
                kernel_size=d_conv,
                padding=(d_conv - 1) // 2,
                **factory_kwargs,
            )
        self.ssm_low_rank = False
        if d_inner < d_expand:
            self.ssm_low_rank = True
            self.in_rank = nn.Conv2d(d_expand, d_inner, kernel_size=1, bias=False, **factory_kwargs)
            self.out_rank = nn.Linear(d_inner, d_expand, bias=False, **factory_kwargs)
        self.x_proj = [
            nn.Linear(d_inner, (self.dt_rank + self.d_state * 2), bias=False,
                      **factory_kwargs)
            for _ in range(self.K)
        ]
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K, N, inner)
        del self.x_proj
        self.out_proj = nn.Conv2d(d_expand, d_model, kernel_size=1, stride=1, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
        self.Ds = nn.Parameter(torch.ones((self.K * d_inner)))
        self.A_logs = nn.Parameter(
            torch.zeros((self.K * d_inner, self.d_state)))
        self.dt_projs_weight = nn.Parameter(torch.randn((self.K, d_inner, self.dt_rank)))
        self.dt_projs_bias = nn.Parameter(torch.randn((self.K, d_inner)))

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
                **factory_kwargs):
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
        dt_init_std = dt_rank ** -0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
        A = repeat(
            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_inner,
        ).contiguous()
        A_log = torch.log(A)
        if copies > 0:
            A_log = repeat(A_log, "d n -> r d n", r=copies)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def D_init(d_inner, copies=-1, device=None, merge=True):
        D = torch.ones(d_inner, device=device)
        if copies > 0:
            D = repeat(D, "n1 -> r n1", r=copies)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)
        D._no_weight_decay = True
        return D

    def forward_corev2(self, x: torch.Tensor, channel_first=False, SelectiveScan=SelectiveScanCore,
                       cross_selective_scan=cross_selective_scan, force_fp32=None):
        force_fp32 = (self.training and (not self.disable_force32)) if force_fp32 is None else force_fp32
        if not channel_first:
            x = x.permute(0, 3, 1, 2).contiguous()
        if self.ssm_low_rank:
            x = self.in_rank(x)
        x = cross_selective_scan(
            x, self.x_proj_weight, None, self.dt_projs_weight, self.dt_projs_bias,
            self.A_logs, self.Ds,
            out_norm=getattr(self, "out_norm", None),
            out_norm_shape=getattr(self, "out_norm_shape", "v0"),
            delta_softplus=True, force_fp32=force_fp32,
            SelectiveScan=SelectiveScan, ssoflex=self.training,  # output fp32
        )
        if self.ssm_low_rank:
            x = self.out_rank(x)
        return x

    def forward(self, x: torch.Tensor, **kwargs):
        x = self.in_proj(x)
        if not self.disable_z:
            x, z = x.chunk(2, dim=1)
            if not self.disable_z_act:
                z1 = self.act(z)
        if self.d_conv > 0:
            x = self.conv2d(x)
        x = self.act(x)
        y = self.forward_core(x, channel_first=(self.d_conv > 1))
        y = y.permute(0, 3, 1, 2).contiguous()
        if not self.disable_z:
            y = y * z1
        out = self.dropout(self.out_proj(y))
        return out
class RGBlock(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,
                 channels_first=False):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        hidden_features = int(2 * hidden_features / 3)
        self.fc1 = nn.Conv2d(in_features, hidden_features * 2, kernel_size=1)
        self.dwconv = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, bias=True,
                                groups=hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x, v = self.fc1(x).chunk(2, dim=1)
        x = self.act(self.dwconv(x) + x) * v
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class LSBlock(nn.Module):
    def __init__(self, in_features, hidden_features=None, act_layer=nn.GELU, drop=0):
        super().__init__()
        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=3, padding=3 // 2, groups=hidden_features)
        self.norm = nn.BatchNorm2d(hidden_features)
        self.fc2 = nn.Conv2d(hidden_features, hidden_features, kernel_size=1, padding=0)
        self.act = act_layer()
        self.fc3 = nn.Conv2d(hidden_features, in_features, kernel_size=1, padding=0)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        input = x
        x = self.fc1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.act(x)
        x = self.fc3(x)
        x = input + self.drop(x)
        return x
class XSSBlock(nn.Module):
    def __init__(
            self,
            in_channels: int = 0,
            hidden_dim: int = 0,
            n: int = 1,
            mlp_ratio=4.0,
            drop_path: float = 0,
            norm_layer: Callable[..., torch.nn.Module] = partial(LayerNorm2d, eps=1e-6),
            ssm_d_state: int = 16,
            ssm_ratio=2.0,
            ssm_rank_ratio=2.0,
            ssm_dt_rank: Any = "auto",
            ssm_act_layer=nn.SiLU,
            ssm_conv: int = 3,
            ssm_conv_bias=True,
            ssm_drop_rate: float = 0,
            ssm_init="v0",
            forward_type="v2",
            mlp_act_layer=nn.GELU,
            mlp_drop_rate: float = 0.0,
            use_checkpoint: bool = False,
            post_norm: bool = False,
            **kwargs,
    ):
        super().__init__()
        self.in_proj = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU()
        ) if in_channels != hidden_dim else nn.Identity()
        self.hidden_dim = hidden_dim
        self.norm = norm_layer(hidden_dim)
        self.ss2d = nn.Sequential(*(SS2D(d_model=self.hidden_dim,
                                         d_state=ssm_d_state,
                                         ssm_ratio=ssm_ratio,
                                         ssm_rank_ratio=ssm_rank_ratio,
                                         dt_rank=ssm_dt_rank,
                                         act_layer=ssm_act_layer,
                                         d_conv=ssm_conv,
                                         conv_bias=ssm_conv_bias,
                                         dropout=ssm_drop_rate, ) for _ in range(n)))
        self.drop_path = DropPath(drop_path)
        self.lsblock = LSBlock(hidden_dim, hidden_dim)
        self.mlp_branch = mlp_ratio > 0
        if self.mlp_branch:
            self.norm2 = norm_layer(hidden_dim)
            mlp_hidden_dim = int(hidden_dim * mlp_ratio)
            self.mlp = RGBlock(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer,
                               drop=mlp_drop_rate)

    def forward(self, input):
        input = self.in_proj(input)
        X1 = self.lsblock(input)
        input = input + self.drop_path(self.ss2d(self.norm(X1)))
        if self.mlp_branch:
            input = input + self.drop_path(self.mlp(self.norm2(input)))
        return input
class VSSBlock_YOLO(nn.Module):
    def __init__(
            self,
            in_channels: int = 0,
            hidden_dim: int = 0,
            drop_path: float = 0,
            norm_layer: Callable[..., torch.nn.Module] = partial(LayerNorm2d, eps=1e-6),
            ssm_d_state: int = 16,
            ssm_ratio=2.0,
            ssm_rank_ratio=2.0,
            ssm_dt_rank: Any = "auto",
            ssm_act_layer=nn.SiLU,
            ssm_conv: int = 3,
            ssm_conv_bias=True,
            ssm_drop_rate: float = 0,
            ssm_init="v0",
            forward_type="v2",
            mlp_ratio=4.0,
            mlp_act_layer=nn.GELU,
            mlp_drop_rate: float = 0.0,
            use_checkpoint: bool = False,
            post_norm: bool = False,
            **kwargs,
    ):
        super().__init__()
        self.ssm_branch = ssm_ratio > 0
        self.mlp_branch = mlp_ratio > 0
        self.use_checkpoint = use_checkpoint
        self.post_norm = post_norm

        # proj
        self.proj_conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU()
        )

        if self.ssm_branch:
            self.norm = norm_layer(hidden_dim)
            self.op = SS2D(
                d_model=hidden_dim,
                d_state=ssm_d_state,
                ssm_ratio=ssm_ratio,
                ssm_rank_ratio=ssm_rank_ratio,
                dt_rank=ssm_dt_rank,
                act_layer=ssm_act_layer,
                d_conv=ssm_conv,
                conv_bias=ssm_conv_bias,
                dropout=ssm_drop_rate,
                # bias=False,
                # dt_min=0.001,
                # dt_max=0.1,
                # dt_init="random",
                # dt_scale="random",
                # dt_init_floor=1e-4,
                initialize=ssm_init,
                forward_type=forward_type,
            )

        self.drop_path = DropPath(drop_path)
        self.lsblock = LSBlock(hidden_dim, hidden_dim)
        if self.mlp_branch:
            self.norm2 = norm_layer(hidden_dim)
            mlp_hidden_dim = int(hidden_dim * mlp_ratio)
            self.mlp = RGBlock(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer,
                               drop=mlp_drop_rate, channels_first=False)

    def forward(self, input: torch.Tensor):
        input = self.proj_conv(input)
        X1 = self.lsblock(input)
        x = input + self.drop_path(self.op(self.norm(X1)))
        if self.mlp_branch:
            x = x + self.drop_path(self.mlp(self.norm2(x)))  # FFN
        return x
class SimpleStem(nn.Module):
    """Two 3x3, stride-2 conv-BN-activation stages: (B, inp, H, W) -> (B, embed_dim, H/4, W/4)."""

    def __init__(self, inp, embed_dim, ks=3):
        super().__init__()
        self.hidden_dims = embed_dim // 2
        self.conv = nn.Sequential(
            # Stage 1: inp -> embed_dim // 2 channels, resolution halved
            nn.Conv2d(inp, self.hidden_dims, kernel_size=ks, stride=2, padding=autopad(ks, d=1), bias=False),
            nn.BatchNorm2d(self.hidden_dims),
            nn.GELU(),
            # Stage 2: embed_dim // 2 -> embed_dim channels, resolution halved again
            nn.Conv2d(self.hidden_dims, embed_dim, kernel_size=ks, stride=2, padding=autopad(ks, d=1), bias=False),
            nn.BatchNorm2d(embed_dim),
            nn.SiLU(),
        )

    def forward(self, x):
        return self.conv(x)


class VisionClueMerge(nn.Module):
    """Space-to-depth downsampling: (B, dim, H, W) -> (B, out_dim, H/2, W/2) for even H and W."""

    def __init__(self, dim, out_dim):
        super().__init__()
        self.hidden = int(dim * 4)
        self.pw_linear = nn.Sequential(
            nn.Conv2d(self.hidden, out_dim, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(out_dim),
            nn.SiLU()
        )

    def forward(self, x):
        # Stack the four pixels of every 2x2 neighbourhood along the channel axis
        y = torch.cat([
            x[..., ::2, ::2],
            x[..., 1::2, ::2],
            x[..., ::2, 1::2],
            x[..., 1::2, 1::2]
        ], dim=1)
        return self.pw_linear(y)
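The core idea of `CrossScan`/`CrossMerge` is to unfold a 2-D feature map into four 1-D sequences: row-major, column-major, and both of those reversed. The selective scan therefore sees every position from four directions, and the four outputs are then folded back and summed. The pure-Python sketch below mimics that index bookkeeping on nested lists for a single channel. `cross_scan` and `cross_merge` are hypothetical illustrations, not the CUDA path used above.

```python
def cross_scan(grid):
    """Unfold an H x W grid into the 4 sequences built by CrossScan.forward."""
    h, w = len(grid), len(grid[0])
    row_major = [grid[i][j] for i in range(h) for j in range(w)]  # x.flatten(2, 3)
    col_major = [grid[i][j] for j in range(w) for i in range(h)]  # transpose, then flatten
    return [row_major, col_major, row_major[::-1], col_major[::-1]]  # plus both flipped

def cross_merge(seqs, h, w):
    """Fold 4 sequences back onto an H x W grid and sum them, as CrossMerge.forward does."""
    row_major, col_major = seqs[0], seqs[1]
    row_rev, col_rev = seqs[2][::-1], seqs[3][::-1]  # undo the flips
    out = [[0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            out[i][j] = (row_major[i * w + j] + row_rev[i * w + j]
                         + col_major[j * h + i] + col_rev[j * h + i])
    return out

grid = [[1, 2, 3],
        [4, 5, 6]]
seqs = cross_scan(grid)
print(seqs[1])                  # [1, 4, 2, 5, 3, 6]
print(cross_merge(seqs, 2, 3))  # [[4, 8, 12], [16, 20, 24]]
```

In the real module, the selective scan processes the four sequences before they are merged, so each direction contributes different values. Merging unmodified sequences simply returns 4× the input, which confirms that the scan and merge index mappings are consistent with each other.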
4. Steps to Add the Modules
4.1 Base Environment
Requirements:
- Linux
- NVIDIA GPU
- PyTorch 1.12+
- CUDA 11.6+
My environment:
- Linux
- NVIDIA GPU
- PyTorch 2.0.0
- CUDA 11.8
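The PyTorch and CUDA versions here must match each other. Most problems that appear later during installation are caused by a version mismatch or by network issues.
In addition, the official implementation was built on Linux. Windows users can try it out.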
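Before building the extensions, it helps to confirm which PyTorch and CUDA versions are actually installed. The helper below is a hypothetical sketch and is not part of Mamba-YOLO. `meets_minimum` compares dotted version strings against the minimums listed above (PyTorch 1.12, CUDA 11.6). `report_torch` reads `torch.__version__`, `torch.version.cuda` (the CUDA version PyTorch was built with) and `torch.cuda.is_available()` when PyTorch is importable. The toolkit reported by `nvcc --version` is what compiles the extensions, and it should match `torch.version.cuda`.

```python
import re

MIN_TORCH = (1, 12)
MIN_CUDA = (11, 6)

def parse_version(text):
    """Turn '2.0.0+cu118' or '11.8' into a tuple of ints, e.g. (2, 0, 0)."""
    return tuple(int(part) for part in re.findall(r"\d+", text.split("+")[0]))

def meets_minimum(version, minimum):
    """True if the version string is at least `minimum` (a tuple of ints)."""
    return parse_version(version)[:len(minimum)] >= minimum

def report_torch():
    """Print the installed PyTorch / CUDA versions, if PyTorch is available."""
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
        return None
    cuda = torch.version.cuda  # None for CPU-only builds
    print("torch", torch.__version__, "| built with CUDA", cuda,
          "| GPU available:", torch.cuda.is_available())
    ok = (meets_minimum(torch.__version__, MIN_TORCH)
          and cuda is not None
          and meets_minimum(cuda, MIN_CUDA))
    print("meets the stated minimums:", ok)
    return ok

if __name__ == "__main__":
    report_torch()
```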
4.2 Installation and Compilation
The following modules must be installed:
1️⃣ mmcv
pip install mmcv
If this fails, install it with mim instead.
First install openmim: pip install -U openmim
Then install mmcv: mim install mmcv
2️⃣ causal-conv1d
pip install causal-conv1d
Building wheels for collected packages: causal-conv1d
Building wheel for causal-conv1d (setup.py) … -
If the build hangs at this point, the cause is a network problem, and you can try installing from a local copy…
3️⃣ mamba-ssm
pip install mamba-ssm
Building wheels for collected packages: mamba-ssm
Building wheel for mamba-ssm (setup.py) … -
Same as above.
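4️⃣ Compile Mamba
The mamba project package has been uploaded to the group. Download it, extract it, and place it under ultralytics/nn/AddModules/mamba.
cd into ultralytics/nn/AddModules/mamba and run:
python setup.py install
If an error appears here, it usually means the versions do not match.
Other errors are handled the same way. Before recompiling, run the following in the same directory:
python setup.py clean --all
Then compile again, and you are done.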
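5️⃣ Compile Mamba-YOLO
The related project package selective_scan has been uploaded to the group. Download it, extract it, and place it under ultralytics/nn/AddModules/selective_scan.
cd into ultralytics/nn/AddModules/selective_scan and run:
python setup.py install
The remaining steps are the same as in the previous step.
Once everything has installed successfully, the environment setup is complete.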
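The module code in Section 3 wraps its extension imports in a `try`/`except` that silently passes, so a failed build usually shows up only later as a `NameError` inside `SelectiveScanCore`. A quick check after installation can catch this earlier. The sketch below uses `importlib.import_module` from the standard library to report which modules actually import. The names in `REQUIRED` are assumptions based on the packages above: `mmcv`, `causal_conv1d` and `mamba_ssm` are the import names of the pip packages, and `selective_scan_cuda_core` is the extension imported in Section 3.

```python
import importlib

# Assumed import names for the packages installed in Section 4.2.
REQUIRED = ["torch", "mmcv", "causal_conv1d", "mamba_ssm", "selective_scan_cuda_core"]

def check_imports(names):
    """Try to import each module; map its name to 'ok' or to the failure message."""
    status = {}
    for name in names:
        try:
            importlib.import_module(name)
            status[name] = "ok"
        except Exception as exc:  # ImportError, or an error raised while the module initializes
            # A build against mismatched versions often fails here with an
            # ImportError mentioning an undefined symbol.
            status[name] = f"FAILED: {exc}"
    return status

if __name__ == "__main__":
    for name, result in check_imports(REQUIRED).items():
        print(f"{name:28s} {result}")
```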
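4.3 Code Configuration
1️⃣ In the ultralytics/nn/ directory, create a new AddModules folder to hold the module code.
2️⃣ In the AddModules folder, create mamba_yolo.py and paste the code from Section 3 into it.
3️⃣ In the AddModules folder, create __init__.py (skip this if it already exists) and import the modules in it:
from .mamba_yolo import *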
4️⃣ In ultralytics/nn/tasks.py, add the module class names in two places.
First, import the modules.
Second, register SimpleStem, VisionClueMerge, VSSBlock_YOLO and XSSBlock in the parse_model function.
In the DetectionModel class, add the following code:
try:
    # Probe the output strides with a dummy forward pass on the CPU first
    m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))])  # forward on CPU
except RuntimeError:
    # If the CPU pass fails (e.g. because the selective-scan kernel requires CUDA), retry on the GPU
    try:
        self.model.to(torch.device('cuda'))
        m.stride = torch.tensor([s / x.shape[-2] for x in _forward(
            torch.zeros(1, ch, s, s).to(torch.device('cuda')))])  # forward on CUDA
    except RuntimeError as error:
        raise error
Then comment out the original line:
# m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))])  # forward
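The stride expression passes a dummy `s × s` image through the model and divides `s` by the height of each output feature map. The pure-Python sketch below reproduces that arithmetic for this backbone, using the downsampling factors noted in the YAML comments (SimpleStem /4, then each VisionClueMerge /2). `feature_heights` and `strides` are hypothetical helpers, and `s = 640` is only an example input size.

```python
def feature_heights(s):
    """Heights of the P3/P4/P5 maps for an s x s input, per the YAML downsampling comments."""
    h = s // 4             # layer 0: SimpleStem, P2/4
    heights = []
    for _ in range(3):     # layers 2, 4, 6: each VisionClueMerge halves the resolution
        h //= 2
        heights.append(h)  # P3/8, P4/16, P5/32 (the head outputs keep these sizes)
    return heights

def strides(s):
    """Mirror of `s / x.shape[-2]` for each feature map."""
    return [s / h for h in feature_heights(s)]

print(feature_heights(640))  # [80, 40, 20]
print(strides(640))          # [8.0, 16.0, 32.0]
```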
5. YAML Model File
5.1 Improved Model Version
📌 Create a new model file rtdetr-mamba-T.yaml with the following structure:
nc: 1 # number of classes
scales: # [depth, width, max_channels]
  T: [0.33, 0.25, 1024] # Mamba-YOLOv8-T summary: 6.1M parameters, 14.3GFLOPs

# Mamba-YOLO backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, SimpleStem, [128, 3]] # 0-P2/4
  - [-1, 2, VSSBlock_YOLO, [128]] # 1
  - [-1, 1, VisionClueMerge, [256]] # 2 p3/8
  - [-1, 2, VSSBlock_YOLO, [256]] # 3
  - [-1, 1, VisionClueMerge, [512]] # 4 p4/16
  - [-1, 2, VSSBlock_YOLO, [512]] # 5
  - [-1, 1, VisionClueMerge, [1024]] # 6 p5/32
  - [-1, 2, VSSBlock_YOLO, [1024]] # 7
  - [-1, 1, SPPF, [1024, 5]] # 8
  - [-1, 2, C2PSA, [1024]] # 9

# Mamba-YOLO PAFPN
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 5], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, XSSBlock, [512]] # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 3], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, XSSBlock, [256]] # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]] # cat head P4
  - [-1, 2, XSSBlock, [512]] # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]] # cat head P5
  - [-1, 2, XSSBlock, [1024]] # 21 (P5/32-large)

  - [[15, 18, 21], 1, RTDETRDecoder, [nc, 256, 300, 4, 8, 3]] # Detect(P3, P4, P5)
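To check this YAML by hand, the sketch below walks layers 0 to 21 and tracks each layer's output channels, following simplified `parse_model` rules. Channel-producing modules emit the width-scaled `c2`, `Upsample` keeps its input channels, and `Concat` sums the channels of its inputs. `output_channels` is a hypothetical helper that only covers the module types used in this file, with modules written as plain strings. The decoder argument names follow the assumed constructor order of ultralytics' `RTDETRDecoder` (`nc, ch, hd, nq, ndp, nh, ndl`, ...). In the printed summary, the list of input channels appears as the second argument.

```python
import math

def make_divisible(x, divisor=8):
    """Round x up to the nearest multiple of divisor."""
    return math.ceil(x / divisor) * divisor

WIDTH, MAX_CH = 0.25, 1024  # T scale: [depth, width, max_channels] = [0.33, 0.25, 1024]

# (from, module, first YAML arg) for layers 0-21 of rtdetr-mamba-T.yaml
LAYERS = [
    (-1, "SimpleStem", 128), (-1, "VSSBlock_YOLO", 128),
    (-1, "VisionClueMerge", 256), (-1, "VSSBlock_YOLO", 256),
    (-1, "VisionClueMerge", 512), (-1, "VSSBlock_YOLO", 512),
    (-1, "VisionClueMerge", 1024), (-1, "VSSBlock_YOLO", 1024),
    (-1, "SPPF", 1024), (-1, "C2PSA", 1024),
    (-1, "Upsample", None), ([-1, 5], "Concat", None), (-1, "XSSBlock", 512),
    (-1, "Upsample", None), ([-1, 3], "Concat", None), (-1, "XSSBlock", 256),
    (-1, "Conv", 256), ([-1, 12], "Concat", None), (-1, "XSSBlock", 512),
    (-1, "Conv", 512), ([-1, 9], "Concat", None), (-1, "XSSBlock", 1024),
]

def output_channels(layers, in_ch=3):
    """Track each layer's output channels; list index == layer index."""
    ch = []
    for f, module, c2 in layers:
        prev = ch[-1] if ch else in_ch
        if module == "Upsample":
            ch.append(prev)                                     # resolution changes, channels do not
        elif module == "Concat":
            ch.append(sum(ch[i] for i in f))                    # -1 is the previous layer
        else:
            ch.append(make_divisible(min(c2, MAX_CH) * WIDTH))  # width-scaled c2
    return ch

ch = output_channels(LAYERS)
decoder_in = [ch[i] for i in (15, 18, 21)]
print(decoder_in)  # [64, 128, 256]

# Assumed RTDETRDecoder argument order; values from the YAML plus the resolved input channels.
names = ["nc", "ch", "hd", "nq", "ndp", "nh", "ndl"]
print(dict(zip(names, [1, decoder_in, 256, 300, 4, 8, 3])))
```

The intermediate values match the summary in Section 6. For example, layer 11 concatenates 256 + 128 = 384 channels, and layer 14 concatenates 128 + 64 = 192.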
6. Successful Run Results
The printed network model shows that the Mamba-related modules have been added to the model and that training can proceed.
rtdetr-mamba-T:
rtdetr-mamba-T summary: 426 layers, 8,607,604 parameters, 8,607,604 gradients
from n params module arguments
0 -1 1 5136 ultralytics.nn.AddModules.mamba_yolo.SimpleStem[3, 32, 3]
1 -1 1 33692 ultralytics.nn.AddModules.mamba_yolo.VSSBlock_YOLO[32, 32]
2 -1 1 8384 ultralytics.nn.AddModules.mamba_yolo.VisionClueMerge[32, 64]
3 -1 1 104184 ultralytics.nn.AddModules.mamba_yolo.VSSBlock_YOLO[64, 64]
4 -1 1 33152 ultralytics.nn.AddModules.mamba_yolo.VisionClueMerge[64, 128]
5 -1 1 355964 ultralytics.nn.AddModules.mamba_yolo.VSSBlock_YOLO[128, 128]
6 -1 1 131840 ultralytics.nn.AddModules.mamba_yolo.VisionClueMerge[128, 256]
7 -1 1 1301496 ultralytics.nn.AddModules.mamba_yolo.VSSBlock_YOLO[256, 256]
8 -1 1 164608 ultralytics.nn.modules.block.SPPF [256, 256, 5]
9 -1 1 249728 ultralytics.nn.modules.block.C2PSA [256, 256, 1]
10 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
11 [-1, 5] 1 0 ultralytics.nn.modules.conv.Concat [1]
12 -1 1 388604 ultralytics.nn.AddModules.mamba_yolo.XSSBlock[384, 128]
13 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
14 [-1, 3] 1 0 ultralytics.nn.modules.conv.Concat [1]
15 -1 1 112312 ultralytics.nn.AddModules.mamba_yolo.XSSBlock[192, 64]
16 -1 1 36992 ultralytics.nn.modules.conv.Conv [64, 64, 3, 2]
17 [-1, 12] 1 0 ultralytics.nn.modules.conv.Concat [1]
18 -1 1 364028 ultralytics.nn.AddModules.mamba_yolo.XSSBlock[192, 128]
19 -1 1 147712 ultralytics.nn.modules.conv.Conv [128, 128, 3, 2]
20 [-1, 9] 1 0 ultralytics.nn.modules.conv.Concat [1]
21 -1 1 1334008 ultralytics.nn.AddModules.mamba_yolo.XSSBlock[384, 256]
22 [15, 18, 21] 1 3835764 ultralytics.nn.modules.head.RTDETRDecoder [1, [64, 128, 256], 256, 300, 4, 8, 3]
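The parameter column of this summary can be cross-checked by hand. The sketch below counts the learnable parameters of `VisionClueMerge`, `VSSBlock_YOLO` and `XSSBlock` (with `n = 1`) directly from the constructor code in Section 3. It assumes the default arguments (`ssm_d_state=16`, `ssm_ratio=2.0`, `ssm_conv=3`, `mlp_ratio=4.0`) and counts two parameters per channel for every LayerNorm/BatchNorm. The helper names are hypothetical. The expected totals are the values printed in the summary.

```python
import math

def conv(cin, cout, k=1, groups=1, bias=True):
    """Parameters of a Conv2d layer."""
    return cin // groups * cout * k * k + (cout if bias else 0)

def norm(c):
    """Learnable weight + bias of a LayerNorm or BatchNorm over c channels."""
    return 2 * c

def vision_clue_merge(dim, out_dim):
    return conv(4 * dim, out_dim) + norm(out_dim)  # 1x1 conv on the 2x2 space-to-depth stack + BN

def ss2d(d_model, d_state=16, ssm_ratio=2.0, d_conv=3, K=4):
    d_inner = int(ssm_ratio * d_model)  # equals d_expand for the default ratios
    dt_rank = math.ceil(d_model / 16)
    return (norm(d_inner)                                      # out_norm
            + conv(d_model, 2 * d_inner, bias=False)           # in_proj (x and z halves)
            + conv(d_inner, d_inner, d_conv, groups=d_inner)   # depthwise conv2d
            + K * (dt_rank + 2 * d_state) * d_inner            # x_proj_weight
            + conv(d_inner, d_model, bias=False)               # out_proj
            + K * d_inner                                      # Ds
            + K * d_inner * d_state                            # A_logs
            + K * d_inner * dt_rank + K * d_inner)             # dt_projs weight + bias

def ls_block(c):
    return conv(c, c, 3, groups=c) + norm(c) + conv(c, c) + conv(c, c)

def rg_block(c, mlp_ratio=4.0):
    hidden = int(2 * int(c * mlp_ratio) / 3)
    return conv(c, 2 * hidden) + conv(hidden, hidden, 3, groups=hidden) + conv(hidden, c)

def vss_block_yolo(cin, hidden):
    proj = conv(cin, hidden) + norm(hidden)  # proj_conv: 1x1 conv with bias + BN
    return proj + norm(hidden) + ss2d(hidden) + ls_block(hidden) + norm(hidden) + rg_block(hidden)

def xss_block(cin, hidden):
    # in_proj (bias-free 1x1 conv + BN) only exists when the channel count changes
    proj = (conv(cin, hidden, bias=False) + norm(hidden)) if cin != hidden else 0
    return proj + norm(hidden) + ss2d(hidden) + ls_block(hidden) + norm(hidden) + rg_block(hidden)

print(vision_clue_merge(32, 64))  # 8384
print(vss_block_yolo(32, 32))     # 33692
print(vss_block_yolo(64, 64))     # 104184
print(xss_block(384, 128))        # 388604
```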