

RT-DETR Improvement Strategies [Backbone] | ICLR 2023: Replacing the Backbone with RevCol, a New Neural Network Design Paradigm

1. Introduction

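This article documents a RevCol-based improvement to the RT-DETR object detector. RevCol is a new neural network design paradigm built from multiple subnetworks (columns) joined by multi-level reversible connections; during forward propagation, features are gradually disentangled while information is preserved. The reversible transformation borrows from reversible neural networks, and the multi-level reversible unit is designed to address two problems: the restrictions that earlier reversible models place on feature-map shapes, and the conflict with the information bottleneck principle. This article applies RevCol to RT-DETR and provides the five model sizes from the original paper (revcol_tiny, revcol_small, revcol_base, revcol_large, and revcol_xlarge) to suit different needs.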



2. RevCol Model Design

2.1 Motivation

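  • Limits of the information bottleneck principle: Conventional supervised networks follow the information bottleneck (IB) principle. As shown in the figure, layers near the input carry more low-level information, while layers near the output are rich in semantic information; in other words, information unrelated to the target is progressively compressed away as it propagates layer by layer. This can hurt downstream performance, especially when the learned features are over-compressed or the learned semantics are irrelevant to the target task, and when there is a domain gap between the source and target tasks.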


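  • The need for disentangled feature learning: The paper proposes building networks that learn disentangled representations. Unlike IB learning, disentangled feature learning aims to embed task-relevant concepts or semantics into several decoupled dimensions while the feature vector as a whole retains roughly as much information as the input, similar to the mechanism in biological cells.

Learning disentangled features is a reasonable goal in computer vision. During ImageNet pre-training, for example, the high-level semantic representation is tuned, while low-level information (such as edge locations) should be preserved in other feature dimensions to serve downstream tasks such as object detection.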

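2.2 Principle

2.2.1 The Core Role of Reversible Transformations

  • Built on reversible neural networks: Reversible transformations play a key role in feature disentangling, and the idea comes from reversible neural networks. Taking RevNet as an example, shown in panel (a) of the figure, the input is partitioned and computed through a reversible mapping. However, RevNet constrains feature dimensions too tightly, and the network is not fully reversible.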


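  • A generalized reversible formulation: RevNet's formulation is extended to a more general form, shown in panel (b) of the figure. Increasing the recursion order m relaxes the constraint on feature-map sizes, so the formulation cooperates better with existing network architectures while the network remains reversible.
  • Multi-level reversible unit: The formulation is rearranged into a multi-column form, shown in panel (c) of the figure. Each column consists of a group of m feature maps and their mother network; this is called the multi-level reversible unit and is the basic building block of RevCol.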
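
For reference, the formulations described in the bullets above can be written out as follows. This is a sketch in the notation of the RevCol paper as best recalled: $x_t$ is the $t$-th feature map, $F_t$ is an arbitrary (not necessarily invertible) network, $m$ is the recursion order, and $\gamma$ is a reversible operation such as channel-wise scaling. Each row gives the forward rule on the left and its inverse on the right.

```latex
\begin{align*}
\text{(a) RevNet:}\quad
  & x_t = F_t(x_{t-1}) + x_{t-2},
  && x_{t-2} = x_t - F_t(x_{t-1}) \\
\text{(b) order } m\text{:}\quad
  & x_t = F_t(x_{t-1}, x_{t-2}, \dots, x_{t-m+1}) + x_{t-m},
  && x_{t-m} = x_t - F_t(x_{t-1}, x_{t-2}, \dots, x_{t-m+1}) \\
\text{simplified:}\quad
  & x_t = F_t(x_{t-1}, x_{t-m+1}) + \gamma\, x_{t-m},
  && x_{t-m} = \gamma^{-1}\bigl[x_t - F_t(x_{t-1}, x_{t-m+1})\bigr]
\end{align*}
```

The simplified row has the same shape as the update used in the Section 3 code, for example `c0 = l0(x, c1) + c0 * alpha0` in the forward pass and `c3_left = (1 / alpha3) * (c3 - oup3)` in the backward pass.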

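2.2.2 Intermediate Supervision

  • Addressing information loss: Although the multi-level reversible unit preserves information across column iterations, the downsampling blocks can still discard information inside a column. Intermediate supervision is proposed to mitigate this.
  • How supervision is applied: Two auxiliary heads are attached to the last-level feature (Level 4) of the earlier columns: a decoder that reconstructs the input image, and a linear classifier. The decoder is trained by minimizing a binary cross-entropy (BCE) reconstruction loss and the linear classifier with a cross-entropy (CE) loss. The two are combined into a compound loss with different weights for different columns, which maximizes a lower bound on the mutual information between the features and the predictions.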
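
A minimal sketch of how such a per-column compound loss could be weighted. The linear schedule, its endpoint values, and the helper names `column_weights` and `compound_loss` are illustrative assumptions, not values taken from the paper or its code; the only property taken from the text above is that each column gets its own pair of weights for the BCE and CE terms.

```python
def column_weights(num_columns, start, end):
    """Linearly interpolate a weight from `start` (first column) to `end` (last column)."""
    if num_columns == 1:
        return [start]
    step = (end - start) / (num_columns - 1)
    return [start + i * step for i in range(num_columns)]


def compound_loss(bce_losses, ce_losses, alphas, betas):
    """Sum alpha_i * BCE_i + beta_i * CE_i over the supervised columns."""
    return sum(a * bce + b * ce
               for a, bce, b, ce in zip(alphas, bce_losses, betas, ce_losses))


# Hypothetical setup: 3 supervised columns. The reconstruction weight decays
# and the classification weight grows towards the later columns (assumed schedule).
alphas = column_weights(3, 1.0, 0.0)
betas = column_weights(3, 0.0, 1.0)
total = compound_loss([0.8, 0.6, 0.4], [2.0, 1.5, 1.0], alphas, betas)
print(alphas, betas, total)
```

In a real training loop the per-column losses would be tensors produced by the decoder and classifier heads; plain floats are used here only to keep the weighting logic visible.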

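2.3 Architecture

2.3.1 Macro Design

  • Multiple subnetworks and reversible connections: As shown in the figure, a RevCol network consists of N subnetworks (columns) with identical structure (but not necessarily identical weights). Each subnetwork receives a copy of the input and produces a prediction. Reversible transformations propagate multi-level features between columns (from low-level to high-level semantic representations), and the last column predicts the final disentangled representation of the input.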


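  • Feature extraction and propagation: The input image is first split into non-overlapping patches by a patch embedding module and then fed into every subnetwork. Four levels of feature maps are extracted from each column for inter-column propagation. Classification uses the Level 4 feature map of the last column; downstream tasks use all four levels of the last column. The reversible connections between columns use a simplified multi-level reversible unit that takes one low-level feature from the current column and one high-level feature from the previous column as input, which preserves reversibility while reducing GPU resource consumption.

2.3.2 Micro Design

  • Modifications based on ConvNeXt: Each column is built from ConvNeXt blocks by default, modified to fit the macro architecture.
    • Fusion module: In each level of the original ConvNeXt, the patch-merging block is modified so that LayerNorm comes after the patch-merging convolution, the channel count is doubled in the patch-merging convolution, and an upsampling block is introduced. The upsampling block consists of a linear channel-mapping layer, LayerNorm, and a feature-map interpolation layer; the linear channel-mapping layer halves the channel count. The outputs of the two blocks are summed and passed to the subsequent residual blocks.
    • Kernel size: The 7×7 convolutions of the original ConvNeXt are changed to 3×3 by default to speed up training. Larger kernels can improve accuracy, but RevCol's multi-column design already enlarges the effective receptive field, which limits the accuracy gain from large kernels.
    • Reversible operation γ: A learnable reversible channel-wise scaling is used as the reversible operation γ. It is applied at every feature summation to keep the feature magnitude in check, which stabilizes training. During training, the absolute value of γ is also truncated (kept at or above 1e-3 in the Section 3 code) to avoid large numerical errors in the reverse computation.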
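
The truncation of γ can be illustrated with plain floats. This is a sketch, not the library code: `clamp_abs`, `forward_sum`, and `inverse_sum` are hypothetical helper names, and the 1e-3 threshold is the value passed to `SubNet._clamp_abs` in the Section 3 code. Because the inverse divides by γ, a value of γ close to zero would amplify any rounding error in the reconstructed feature.

```python
def clamp_abs(gamma, min_abs=1e-3):
    """Keep the sign of gamma but force |gamma| >= min_abs."""
    sign = (gamma > 0) - (gamma < 0)
    return sign * max(abs(gamma), min_abs)


def forward_sum(f_out, prev, gamma):
    """Feature summation: block output plus the scaled feature from the previous column."""
    return f_out + gamma * prev


def inverse_sum(out, f_out, gamma):
    """Recover the previous-column feature; the division is why |gamma| must not vanish."""
    return (out - f_out) / gamma


gamma = clamp_abs(0.5)               # large enough, left unchanged
out = forward_sum(3.0, 2.0, gamma)   # 3.0 + 0.5 * 2.0
print(inverse_sum(out, 3.0, gamma))  # recovers 2.0
print(clamp_abs(1e-6), clamp_abs(-1e-6))
```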

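2.4 Advantages

  • Feature disentangling: In RevCol, the lowest level of every column keeps low-level features, while the highest level of the last column is highly semantic. Information is gradually disentangled as it propagates losslessly between columns: some feature maps become more semantic, others stay low-level. This makes the model more flexible for downstream tasks that depend on both high-level and low-level features. The reversible connections are key to this disentangling mechanism, and in experiments RevCol outperforms models without reversible connections, such as HRNet.
  • Memory savings: Training a conventional network requires a large amount of memory to store forward activations for gradient computation. Because RevCol's inter-column connections are reversible, activations can be reconstructed from the last column back to the first during backpropagation, so only one column's activations need to be kept in memory during training. Experiments show that as the number of columns grows, RevCol keeps roughly O(1) additional memory consumption, whereas the memory consumption of non-reversible architectures grows linearly with the number of columns.
  • A new scaling dimension: In the RevCol architecture, the number of columns becomes a new dimension alongside depth (number of blocks) and width (channels per block). Within a certain range, adding columns has an effect similar to increasing width and depth together, which helps scale the model to larger sizes and larger datasets.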
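
The memory-saving argument rests on being able to rebuild the previous column's features from the current column's outputs. A toy sketch with scalar "features" follows; `toy_level`, `column_forward`, and `column_inverse` are illustrative names, and the update order mirrors `ReverseFunction` in the Section 3 code (levels are computed bottom-up in the forward pass and inverted top-down).

```python
import math


def column_forward(x, prev, levels, alphas):
    """One column: new_i = level_i(lower, upper) + alpha_i * prev_i, bottom level first."""
    c0, c1, c2, c3 = prev
    l0, l1, l2, l3 = levels
    a0, a1, a2, a3 = alphas
    n0 = l0(x, c1) + a0 * c0
    n1 = l1(n0, c2) + a1 * c1
    n2 = l2(n1, c3) + a2 * c2
    n3 = l3(n2, None) + a3 * c3
    return n0, n1, n2, n3


def column_inverse(x, out, levels, alphas):
    """Rebuild the previous column's features from this column's outputs, top level first."""
    n0, n1, n2, n3 = out
    l0, l1, l2, l3 = levels
    a0, a1, a2, a3 = alphas
    c3 = (n3 - l3(n2, None)) / a3
    c2 = (n2 - l2(n1, c3)) / a2
    c1 = (n1 - l1(n0, c2)) / a1
    c0 = (n0 - l0(x, c1)) / a0
    return c0, c1, c2, c3


def toy_level(lower, upper):
    """Stand-in for a Level module; it does not need to be invertible itself."""
    return math.tanh(lower) + (0.0 if upper is None else 0.25 * upper)


levels = (toy_level,) * 4
alphas = (0.5, 0.5, 0.5, 0.5)   # same value as shortcut_scale_init_value
prev = (0.1, 0.2, 0.3, 0.4)     # features coming from the previous column
out = column_forward(1.0, prev, levels, alphas)
rebuilt = column_inverse(1.0, out, levels, alphas)
print(prev, rebuilt)
```

Since `rebuilt` matches `prev` up to floating-point rounding, the previous column's activations do not have to be stored during the forward pass; they can be recomputed on the way back.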

Paper: https://arxiv.org/pdf/2212.11696.pdf
Source code: https://github.com/megvii-research/RevCol

3. RevCol Implementation

The RevCol model is implemented as follows:

# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------
from typing import Tuple, Any, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath, trunc_normal_

__all__ = ['revcol_tiny', 'revcol_small', 'revcol_base', 'revcol_large', 'revcol_xlarge']

class UpSampleConvnext(nn.Module):
    def __init__(self, ratio, inchannel, outchannel):
        super().__init__()
        self.ratio = ratio
        self.channel_reschedule = nn.Sequential(
            # LayerNorm(inchannel, eps=1e-6, data_format="channels_last"),
            nn.Linear(inchannel, outchannel),
            LayerNorm(outchannel, eps=1e-6, data_format="channels_last"))
        self.upsample = nn.Upsample(scale_factor=2 ** ratio, mode='nearest')

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        x = self.channel_reschedule(x)
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        return self.upsample(x)

class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first", elementwise_affine=True):
        super().__init__()
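        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))
        else:
            # Register as None so the channels_last path can still pass them to F.layer_norm
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)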
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            if self.elementwise_affine:
                x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x

class ConvNextBlock(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch
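    Args:
        in_channel (int): Number of input channels.
        hidden_dim (int): Width of the pointwise expansion layer.
        out_channel (int): Number of output channels (must equal in_channel for the residual add).
        kernel_size (int): Depthwise convolution kernel size. Default: 3
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        drop_path (float): Stochastic depth rate. Default: 0.0
    """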

    def __init__(self, in_channel, hidden_dim, out_channel, kernel_size=3, layer_scale_init_value=1e-6, drop_path=0.0):
        super().__init__()
        self.dwconv = nn.Conv2d(in_channel, in_channel, kernel_size=kernel_size, padding=(kernel_size - 1) // 2,
                                groups=in_channel)  # depthwise conv
        self.norm = nn.LayerNorm(in_channel, eps=1e-6)
        self.pwconv1 = nn.Linear(in_channel, hidden_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, out_channel)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((out_channel)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        # print(f"x min: {x.min()}, x max: {x.max()}, input min: {input.min()}, input max: {input.max()}, x mean: {x.mean()}, x var: {x.var()}, ratio: {torch.sum(x>8)/x.numel()}")
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x

class Decoder(nn.Module):
    def __init__(self, depth=[2, 2, 2, 2], dim=[112, 72, 40, 24], block_type=None, kernel_size=3) -> None:
        super().__init__()
        self.depth = depth
        self.dim = dim
        self.block_type = block_type
        self._build_decode_layer(dim, depth, kernel_size)
        self.projback = nn.Sequential(
            nn.Conv2d(
                in_channels=dim[-1],
                out_channels=4 ** 2 * 3, kernel_size=1),
            nn.PixelShuffle(4),
        )

    def _build_decode_layer(self, dim, depth, kernel_size):
        normal_layers = nn.ModuleList()
        upsample_layers = nn.ModuleList()
        proj_layers = nn.ModuleList()

        norm_layer = LayerNorm

        for i in range(1, len(dim)):
            module = [self.block_type(dim[i], dim[i], dim[i], kernel_size) for _ in range(depth[i])]
            normal_layers.append(nn.Sequential(*module))
            upsample_layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
            proj_layers.append(nn.Sequential(
                nn.Conv2d(dim[i - 1], dim[i], 1, 1),
                norm_layer(dim[i]),
                nn.GELU()
            ))
        self.normal_layers = normal_layers
        self.upsample_layers = upsample_layers
        self.proj_layers = proj_layers

    def _forward_stage(self, stage, x):
        x = self.proj_layers[stage](x)
        x = self.upsample_layers[stage](x)
        return self.normal_layers[stage](x)

    def forward(self, c3):
        x = self._forward_stage(0, c3)  # 14
        x = self._forward_stage(1, x)  # 28
        x = self._forward_stage(2, x)  # 56
        x = self.projback(x)
        return x

class SimDecoder(nn.Module):
    def __init__(self, in_channel, encoder_stride) -> None:
        super().__init__()
        self.projback = nn.Sequential(
            LayerNorm(in_channel),
            nn.Conv2d(
                in_channels=in_channel,
                out_channels=encoder_stride ** 2 * 3, kernel_size=1),
            nn.PixelShuffle(encoder_stride),
        )

    def forward(self, c3):
        return self.projback(c3)

def get_gpu_states(fwd_gpu_devices) -> List[torch.Tensor]:
    # Capture the CUDA RNG state of every device used in the forward pass so that
    # the backward pass can replay the same randomness (e.g. DropPath).
    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())

    return fwd_gpu_states

def get_gpu_device(*args):
    fwd_gpu_devices = list(set(arg.get_device() for arg in args
                               if isinstance(arg, torch.Tensor) and arg.is_cuda))
    return fwd_gpu_devices

def set_device_states(fwd_cpu_state, devices, states) -> None:
    torch.set_rng_state(fwd_cpu_state)
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)

def detach_and_grad(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            x = inp.detach()
            x.requires_grad = True
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)

def get_cpu_and_gpu_states(gpu_devices):
    return torch.get_rng_state(), get_gpu_states(gpu_devices)

class ReverseFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_functions, alpha, *args):
        l0, l1, l2, l3 = run_functions
        alpha0, alpha1, alpha2, alpha3 = alpha
        ctx.run_functions = run_functions
        ctx.alpha = alpha
        ctx.preserve_rng_state = True

        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                                   "dtype": torch.get_autocast_gpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}
        ctx.cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
                                   "dtype": torch.get_autocast_cpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}

        assert len(args) == 5
        [x, c0, c1, c2, c3] = args
        # The first column receives the integer 0 in place of feature tensors
        ctx.first_col = isinstance(c0, int)
        with torch.no_grad():
            gpu_devices = get_gpu_device(*args)
            ctx.gpu_devices = gpu_devices
            ctx.cpu_states_0, ctx.gpu_states_0 = get_cpu_and_gpu_states(gpu_devices)
            c0 = l0(x, c1) + c0 * alpha0
            ctx.cpu_states_1, ctx.gpu_states_1 = get_cpu_and_gpu_states(gpu_devices)
            c1 = l1(c0, c2) + c1 * alpha1
            ctx.cpu_states_2, ctx.gpu_states_2 = get_cpu_and_gpu_states(gpu_devices)
            c2 = l2(c1, c3) + c2 * alpha2
            ctx.cpu_states_3, ctx.gpu_states_3 = get_cpu_and_gpu_states(gpu_devices)
            c3 = l3(c2, None) + c3 * alpha3
        ctx.save_for_backward(x, c0, c1, c2, c3)
        return x, c0, c1, c2, c3

    @staticmethod
    def backward(ctx, *grad_outputs):
        x, c0, c1, c2, c3 = ctx.saved_tensors
        l0, l1, l2, l3 = ctx.run_functions
        alpha0, alpha1, alpha2, alpha3 = ctx.alpha
        gx_right, g0_right, g1_right, g2_right, g3_right = grad_outputs
        (x, c0, c1, c2, c3) = detach_and_grad((x, c0, c1, c2, c3))

        with torch.enable_grad(), \
                torch.random.fork_rng(devices=ctx.gpu_devices, enabled=ctx.preserve_rng_state), \
                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
                torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):

            g3_up = g3_right
            g3_left = g3_up * alpha3  ##shortcut
            set_device_states(ctx.cpu_states_3, ctx.gpu_devices, ctx.gpu_states_3)
            oup3 = l3(c2, None)
            torch.autograd.backward(oup3, g3_up, retain_graph=True)
            with torch.no_grad():
                c3_left = (1 / alpha3) * (c3 - oup3)  ## feature reverse
            g2_up = g2_right + c2.grad
            g2_left = g2_up * alpha2  ##shortcut

            (c3_left,) = detach_and_grad((c3_left,))
            set_device_states(ctx.cpu_states_2, ctx.gpu_devices, ctx.gpu_states_2)
            oup2 = l2(c1, c3_left)
            torch.autograd.backward(oup2, g2_up, retain_graph=True)
            c3_left.requires_grad = False
            cout3 = c3_left * alpha3  ##alpha3 update
            torch.autograd.backward(cout3, g3_up)

            with torch.no_grad():
                c2_left = (1 / alpha2) * (c2 - oup2)  ## feature reverse
            g3_left = g3_left + c3_left.grad if c3_left.grad is not None else g3_left
            g1_up = g1_right + c1.grad
            g1_left = g1_up * alpha1  ##shortcut

            (c2_left,) = detach_and_grad((c2_left,))
            set_device_states(ctx.cpu_states_1, ctx.gpu_devices, ctx.gpu_states_1)
            oup1 = l1(c0, c2_left)
            torch.autograd.backward(oup1, g1_up, retain_graph=True)
            c2_left.requires_grad = False
            cout2 = c2_left * alpha2  ##alpha2 update
            torch.autograd.backward(cout2, g2_up)

            with torch.no_grad():
                c1_left = (1 / alpha1) * (c1 - oup1)  ## feature reverse
            g0_up = g0_right + c0.grad
            g0_left = g0_up * alpha0  ##shortcut
            g2_left = g2_left + c2_left.grad if c2_left.grad is not None else g2_left  ## Fusion

            (c1_left,) = detach_and_grad((c1_left,))
            set_device_states(ctx.cpu_states_0, ctx.gpu_devices, ctx.gpu_states_0)
            oup0 = l0(x, c1_left)
            torch.autograd.backward(oup0, g0_up, retain_graph=True)
            c1_left.requires_grad = False
            cout1 = c1_left * alpha1  ##alpha1 update
            torch.autograd.backward(cout1, g1_up)

            with torch.no_grad():
                c0_left = (1 / alpha0) * (c0 - oup0)  ## feature reverse
            gx_up = x.grad  ## Fusion
            g1_left = g1_left + c1_left.grad if c1_left.grad is not None else g1_left  ## Fusion
            c0_left.requires_grad = False
            cout0 = c0_left * alpha0  ##alpha0 update
            torch.autograd.backward(cout0, g0_up)

        if ctx.first_col:
            return None, None, gx_up, None, None, None, None
        else:
            return None, None, gx_up, g0_left, g1_left, g2_left, g3_left

class Fusion(nn.Module):
    def __init__(self, level, channels, first_col) -> None:
        super().__init__()

        self.level = level
        self.first_col = first_col
        self.down = nn.Sequential(
            nn.Conv2d(channels[level - 1], channels[level], kernel_size=2, stride=2),
            LayerNorm(channels[level], eps=1e-6, data_format="channels_first"),
        ) if level in [1, 2, 3] else nn.Identity()
        if not first_col:
            self.up = UpSampleConvnext(1, channels[level + 1], channels[level]) if level in [0, 1, 2] else nn.Identity()

    def forward(self, *args):

        c_down, c_up = args

        if self.first_col:
            x = self.down(c_down)
            return x

        if self.level == 3:
            x = self.down(c_down)
        else:
            x = self.up(c_up) + self.down(c_down)
        return x

class Level(nn.Module):
    def __init__(self, level, channels, layers, kernel_size, first_col, dp_rate=0.0) -> None:
        super().__init__()
        countlayer = sum(layers[:level])
        expansion = 4
        self.fusion = Fusion(level, channels, first_col)
        modules = [ConvNextBlock(channels[level], expansion * channels[level], channels[level], kernel_size=kernel_size,
                                 layer_scale_init_value=1e-6, drop_path=dp_rate[countlayer + i]) for i in
                   range(layers[level])]
        self.blocks = nn.Sequential(*modules)

    def forward(self, *args):
        x = self.fusion(*args)
        x = self.blocks(x)
        return x

class SubNet(nn.Module):
    def __init__(self, channels, layers, kernel_size, first_col, dp_rates, save_memory) -> None:
        super().__init__()
        shortcut_scale_init_value = 0.5
        self.save_memory = save_memory
        self.alpha0 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[0], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha1 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[1], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha2 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[2], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha3 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[3], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None

        self.level0 = Level(0, channels, layers, kernel_size, first_col, dp_rates)

        self.level1 = Level(1, channels, layers, kernel_size, first_col, dp_rates)

        self.level2 = Level(2, channels, layers, kernel_size, first_col, dp_rates)

        self.level3 = Level(3, channels, layers, kernel_size, first_col, dp_rates)

    def _forward_nonreverse(self, *args):
        x, c0, c1, c2, c3 = args

        c0 = (self.alpha0) * c0 + self.level0(x, c1)
        c1 = (self.alpha1) * c1 + self.level1(c0, c2)
        c2 = (self.alpha2) * c2 + self.level2(c1, c3)
        c3 = (self.alpha3) * c3 + self.level3(c2, None)
        return c0, c1, c2, c3

    def _forward_reverse(self, *args):

        local_funs = [self.level0, self.level1, self.level2, self.level3]
        alpha = [self.alpha0, self.alpha1, self.alpha2, self.alpha3]
        _, c0, c1, c2, c3 = ReverseFunction.apply(
            local_funs, alpha, *args)

        return c0, c1, c2, c3

    def forward(self, *args):

        self._clamp_abs(self.alpha0.data, 1e-3)
        self._clamp_abs(self.alpha1.data, 1e-3)
        self._clamp_abs(self.alpha2.data, 1e-3)
        self._clamp_abs(self.alpha3.data, 1e-3)

        if self.save_memory:
            return self._forward_reverse(*args)
        else:
            return self._forward_nonreverse(*args)

    def _clamp_abs(self, data, value):
        with torch.no_grad():
            sign = data.sign()
            data.abs_().clamp_(value)
            data *= sign

class Classifier(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.LayerNorm(in_channels, eps=1e-6),  # final norm layer
            nn.Linear(in_channels, num_classes),
        )

    def forward(self, x):
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

class FullNet(nn.Module):
    def __init__(self, channels=[32, 64, 96, 128], layers=[2, 3, 6, 3], num_subnet=5, kernel_size=3, drop_path=0.0,
                 save_memory=True, inter_supv=True) -> None:
        super().__init__()
        self.num_subnet = num_subnet
        self.inter_supv = inter_supv
        self.channels = channels
        self.layers = layers

        self.stem = nn.Sequential(
            nn.Conv2d(3, channels[0], kernel_size=4, stride=4),
            LayerNorm(channels[0], eps=1e-6, data_format="channels_first")
        )

        dp_rate = [x.item() for x in torch.linspace(0, drop_path, sum(layers))]
        for i in range(num_subnet):
            first_col = True if i == 0 else False
            self.add_module(f'subnet{str(i)}', SubNet(
                channels, layers, kernel_size, first_col, dp_rates=dp_rate, save_memory=save_memory))

        self.apply(self._init_weights)
        self.width_list = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640))]

    def forward(self, x):

        c0, c1, c2, c3 = 0, 0, 0, 0
        x = self.stem(x)
        for i in range(self.num_subnet):
            c0, c1, c2, c3 = getattr(self, f'subnet{str(i)}')(x, c0, c1, c2, c3)
        return [c0, c1, c2, c3]

    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            trunc_normal_(module.weight, std=.02)
            nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Linear):
            trunc_normal_(module.weight, std=.02)
            nn.init.constant_(module.bias, 0)

##-------------------------------------- Tiny -----------------------------------------

def revcol_tiny(save_memory=True, inter_supv=True, drop_path=0.1, kernel_size=3):
    channels = [64, 128, 256, 512]
    layers = [2, 2, 4, 2]
    num_subnet = 4
    return FullNet(channels, layers, num_subnet, drop_path=drop_path, save_memory=save_memory, inter_supv=inter_supv,
                   kernel_size=kernel_size)

##-------------------------------------- Small -----------------------------------------

def revcol_small(save_memory=True, inter_supv=True, drop_path=0.3, kernel_size=3):
    channels = [64, 128, 256, 512]
    layers = [2, 2, 4, 2]
    num_subnet = 8
    return FullNet(channels, layers, num_subnet, drop_path=drop_path, save_memory=save_memory, inter_supv=inter_supv,
                   kernel_size=kernel_size)

##-------------------------------------- Base -----------------------------------------

def revcol_base(save_memory=True, inter_supv=True, drop_path=0.4, kernel_size=3, head_init_scale=None):
    channels = [72, 144, 288, 576]
    layers = [1, 1, 3, 2]
    num_subnet = 16
    return FullNet(channels, layers, num_subnet, drop_path=drop_path, save_memory=save_memory, inter_supv=inter_supv,
                   kernel_size=kernel_size)

##-------------------------------------- Large -----------------------------------------

def revcol_large(save_memory=True, inter_supv=True, drop_path=0.5, kernel_size=3, head_init_scale=None):
    channels = [128, 256, 512, 1024]
    layers = [1, 2, 6, 2]
    num_subnet = 8
    return FullNet(channels, layers, num_subnet, drop_path=drop_path, save_memory=save_memory, inter_supv=inter_supv,
                   kernel_size=kernel_size)

##--------------------------------------Extra-Large -----------------------------------------
def revcol_xlarge(save_memory=True, inter_supv=True, drop_path=0.5, kernel_size=3, head_init_scale=None):
    channels = [224, 448, 896, 1792]
    layers = [1, 2, 6, 2]
    num_subnet = 8
    return FullNet(channels, layers, num_subnet, drop_path=drop_path, save_memory=save_memory, inter_supv=inter_supv,
                   kernel_size=kernel_size)

# model = revcol_xlarge(True)
# # Example input
# input = torch.randn(64, 3, 224, 224)
# output = model(input)
#
# print(len(output))  # 4: one feature map per level (strides 4, 8, 16, 32)
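
To see what the backbone hands to the detector, the output shapes can be worked out from the listing above without running the model: the stem is a 4×4 convolution with stride 4, and levels 1 to 3 each downsample with a 2×2 convolution with stride 2 (`Fusion.down`). The helper name `revcol_output_shapes` is hypothetical and only illustrates this arithmetic.

```python
def revcol_output_shapes(channels, height, width):
    """Return (channels, height, width) for each of the four levels of the last column."""
    shapes = []
    h, w = height // 4, width // 4   # stem: 4x4 conv, stride 4
    for level, ch in enumerate(channels):
        if level > 0:
            h, w = h // 2, w // 2    # Fusion.down: 2x2 conv, stride 2
        shapes.append((ch, h, w))
    return shapes


# revcol_tiny channels with the 640x640 dummy input used to compute width_list
print(revcol_output_shapes([64, 128, 256, 512], 640, 640))
# [(64, 160, 160), (128, 80, 80), (256, 40, 40), (512, 20, 20)]
```

The channel values are what `FullNet.width_list` reports. Input sizes divisible by 32 keep the upsampled and downsampled feature maps the same size where they are summed in `Fusion`.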
 

4. Modification Steps

4.1 Modification 1

① Create a new AddModules folder under ultralytics/nn/ to hold the module code.

② Create RevCol.py in the AddModules folder and paste the code from Section 3 into it.


4.2 Modification 2

In the AddModules folder, create __init__.py (skip this if it already exists) and import the module inside it: from .RevCol import *


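4.3 Modification 3

The new modules need to be registered in ultralytics/nn/tasks.py.

① First, import the modules at the top of the file, for example with from ultralytics.nn.AddModules import *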


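② In the predict function of the BaseModel class, remove the embed parameter in the two places where it appears (in recent Ultralytics releases these are the function signature and the call to _predict_once).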


③ In the _predict_once function of the BaseModel class, replace the code with the following:

    def _predict_once(self, x, profile=False, visualize=False):
        """
        Perform a forward pass through the network.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool):  Print the computation time of each layer if True, defaults to False.
            visualize (bool): Save the feature maps of the model if True, defaults to False.

        Returns:
            (torch.Tensor): The last output of the model.
        """
        y, dt = [], []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
        return x


④ Replace the predict function of the RTDETRDetectionModel class in full:

    def predict(self, x, profile=False, visualize=False, batch=None, augment=False):
        """
        Perform a forward pass through the model.

        Args:
            x (torch.Tensor): The input tensor.
            profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
            visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
            batch (dict, optional): Ground truth data for evaluation. Defaults to None.
            augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.

        Returns:
            (torch.Tensor): Model's output tensor.
        """
        y, dt = [], []  # outputs
        for m in self.model[:-1]:  # except the head part
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            if hasattr(m, 'backbone'):
                x = m(x)
                for _ in range(5 - len(x)):
                    x.insert(0, None)
                for i_idx, i in enumerate(x):
                    if i_idx in self.save:
                        y.append(i)
                    else:
                        y.append(None)
                # for i in x:
                #     if i is not None:
                #         print(i.size())
                x = x[-1]
            else:
                x = m(x)  # run
                y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
        head = self.model[-1]
        x = head([y[j] for j in head.f], batch)  # head inference
        return x
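
The backbone branch in the predict function above can be hard to read in place. The following standalone sketch reproduces only its bookkeeping, using strings in place of tensors; `store_backbone_outputs` is a hypothetical helper, not part of Ultralytics. The backbone returns four feature maps, which are left-padded with None to five entries so that the backbone fills layer slots 0 to 4 of `y`, and only the slots listed in `self.save` are kept.

```python
def store_backbone_outputs(features, save, slots=5):
    """Pad the backbone outputs to `slots` entries and keep only the saved indices."""
    padded = [None] * (slots - len(features)) + list(features)
    y = [feat if idx in save else None for idx, feat in enumerate(padded)]
    return y, padded[-1]   # per-slot outputs, and the tensor passed on as x


# Hypothetical save list: later layers read from slots 2, 3 and 4
y, x = store_backbone_outputs(["c0", "c1", "c2", "c3"], save={2, 3, 4})
print(y)  # [None, None, 'c1', 'c2', 'c3']
print(x)  # c3
```

With this layout a later layer whose from index is 4 receives the deepest backbone feature map, and the first layer after the backbone gets index 5.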


⑤ In the parse_model function, replace the code at the corresponding location with the following:

    if verbose:
        LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<45}{'arguments':<30}")
    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    is_backbone = False
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        try:
            if m == 'node_mode':
                m = d[m]
                if len(args) > 0:
                    if args[0] == 'head_channel':
                        args[0] = int(d[args[0]])
            t = m
            m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m]  # get module
        except:
            pass
        for j, a in enumerate(args):
            if isinstance(a, str):
                with contextlib.suppress(ValueError):
                    try:
                        args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
                    except:
                        args[j] = a


⑥ In the parse_model function, add the following code:

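elif m in {revcol_tiny, revcol_small, revcol_base, revcol_large, revcol_xlarge}:
    m = m(*args)  # build the backbone once; it is used as-is rather than repeated n times
    c2 = m.width_list  # per-level output channels, e.g. [64, 128, 256, 512] for revcol_tiny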


⑦ In the parse_model function, replace the code at the corresponding location with the following:

        if isinstance(c2, list):
            is_backbone = True
            m_ = m
            m_.backbone = True
        else:
            m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
            t = str(m)[8:-2].replace('__main__.', '')  # module type
        
        m_.np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = i + 4 if is_backbone else i, f, t  # attach index, 'from' index, type
        if verbose:
            LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f}  {t:<45}{str(args):<30}')  # print
        save.extend(x % (i + 4 if is_backbone else i) for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        if isinstance(c2, list):
            ch.extend(c2)
            for _ in range(5 - len(ch)):
                ch.insert(0, 0)
        else:
            ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
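
The index and channel bookkeeping in the parse_model snippet above can be checked in isolation. This is a sketch with hypothetical helper names (`backbone_channels`, `layer_index`) that copy the two rules from the snippet: the backbone's channel list is left-padded with zeros to five entries, and once `is_backbone` has been set, every layer index is shifted by 4.

```python
def backbone_channels(width_list, slots=5):
    """Left-pad the backbone's per-level channels with 0 so that len(ch) == slots."""
    ch = list(width_list)
    for _ in range(slots - len(ch)):
        ch.insert(0, 0)
    return ch


def layer_index(i, is_backbone):
    """Index attached to a layer: YAML position i, shifted by 4 after a multi-output backbone."""
    return i + 4 if is_backbone else i


ch = backbone_channels([64, 128, 256, 512])   # width_list of revcol_tiny
print(ch)                     # [0, 64, 128, 256, 512]
print(layer_index(0, True))   # 4: the backbone itself
print(layer_index(1, True))   # 5: first layer after the backbone
```

Because `is_backbone` is never reset in the snippet, the shift applies to every layer that follows the backbone, which is why the from indices in the model YAML have to count the backbone as five layers.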


⑧ In ultralytics/nn/autobackend.py, replace the forward function of the AutoBackend class in full with the following code:

    def forward(self, im, augment=False, visualize=False):
        """
        Runs inference on the YOLOv8 MultiBackend model.

        Args:
            im (torch.Tensor): The image tensor to perform inference on.
            augment (bool): whether to perform data augmentation during inference, defaults to False
            visualize (bool): whether to visualize the output predictions, defaults to False

        Returns:
            (tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True)
        """
        b, ch, h, w = im.shape  # batch, channel, height, width
        if self.fp16 and im.dtype != torch.float16:
            im = im.half()  # to FP16
        if self.nhwc:
            im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)

        if self.pt or self.nn_module:  # PyTorch
            y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
        elif self.jit:  # TorchScript
            y = self.model(im)
        elif self.dnn:  # ONNX OpenCV DNN
            im = im.cpu().numpy()  # torch to numpy
            self.net.setInput(im)
            y = self.net.forward()
        elif self.onnx:  # ONNX Runtime
            im = im.cpu().numpy()  # torch to numpy
            y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
        elif self.xml:  # OpenVINO
            im = im.cpu().numpy()  # FP32
            y = list(self.ov_compiled_model(im).values())
        elif self.engine:  # TensorRT
            if self.dynamic and im.shape != self.bindings['images'].shape:
                i = self.model.get_binding_index('images')
                self.context.set_binding_shape(i, im.shape)  # reshape if dynamic
                self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
                for name in self.output_names:
                    i = self.model.get_binding_index(name)
                    self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
            s = self.bindings['images'].shape
            assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
            self.binding_addrs['images'] = int(im.data_ptr())
            self.context.execute_v2(list(self.binding_addrs.values()))
            y = [self.bindings[x].data for x in sorted(self.output_names)]
        elif self.coreml:  # CoreML
            im = im[0].cpu().numpy()
            im_pil = Image.fromarray((im * 255).astype('uint8'))
            # im = im.resize((192, 320), Image.BILINEAR)
            y = self.model.predict({'image': im_pil})  # coordinates are xywh normalized
            if 'confidence' in y:
                raise TypeError('Ultralytics only supports inference of non-pipelined CoreML models exported with '
                                f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export.")
                # TODO: CoreML NMS inference handling
                # from ultralytics.utils.ops import xywh2xyxy
                # box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
                # conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32)
                # y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
            elif len(y) == 1:  # classification model
                y = list(y.values())
            elif len(y) == 2:  # segmentation model
                y = list(reversed(y.values()))  # reversed for segmentation models (pred, proto)
        elif self.paddle:  # PaddlePaddle
            im = im.cpu().numpy().astype(np.float32)
            self.input_handle.copy_from_cpu(im)
            self.predictor.run()
            y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
        elif self.ncnn:  # ncnn
            mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
            ex = self.net.create_extractor()
            input_names, output_names = self.net.input_names(), self.net.output_names()
            ex.input(input_names[0], mat_in)
            y = []
            for output_name in output_names:
                mat_out = self.pyncnn.Mat()
                ex.extract(output_name, mat_out)
                y.append(np.array(mat_out)[None])
        elif self.triton:  # NVIDIA Triton Inference Server
            im = im.cpu().numpy()  # torch to numpy
            y = self.model(im)
        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
            im = im.cpu().numpy()
            if self.saved_model:  # SavedModel
                y = self.model(im, training=False) if self.keras else self.model(im)
                if not isinstance(y, list):
                    y = [y]
            elif self.pb:  # GraphDef
                y = self.frozen_func(x=self.tf.constant(im))
                if len(y) == 2 and len(self.names) == 999:  # segments and names not defined
                    ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0)  # index of protos, boxes
                    nc = y[ib].shape[1] - y[ip].shape[3] - 4  # y = (1, 160, 160, 32), (1, 116, 8400)
                    self.names = {i: f'class{i}' for i in range(nc)}
            else:  # Lite or Edge TPU
                details = self.input_details[0]
                integer = details['dtype'] in (np.int8, np.int16)  # is TFLite quantized int8 or int16 model
                if integer:
                    scale, zero_point = details['quantization']
                    im = (im / scale + zero_point).astype(details['dtype'])  # de-scale
                self.interpreter.set_tensor(details['index'], im)
                self.interpreter.invoke()
                y = []
                for output in self.output_details:
                    x = self.interpreter.get_tensor(output['index'])
                    if integer:
                        scale, zero_point = output['quantization']
                        x = (x.astype(np.float32) - zero_point) * scale  # re-scale
                    if x.ndim > 2:  # if task is not classification
                        # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
                        # xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
                        x[:, [0, 2]] *= w
                        x[:, [1, 3]] *= h
                    y.append(x)
            # TF segment fixes: export is reversed vs ONNX export and protos are transposed
            if len(y) == 2:  # segment with (det, proto) output order reversed
                if len(y[1].shape) != 4:
                    y = list(reversed(y))  # should be y = (1, 116, 8400), (1, 160, 160, 32)
                y[1] = np.transpose(y[1], (0, 3, 1, 2))  # should be y = (1, 116, 8400), (1, 32, 160, 160)
            y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]

        # for x in y:
        #     print(type(x), len(x)) if isinstance(x, (list, tuple)) else print(type(x), x.shape)  # debug shapes
        if isinstance(y, (list, tuple)):
            return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
        else:
            return self.from_numpy(y)


With that, the code changes are complete. You can now configure the model and start training.


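5. YAML Model File

5.1 Model Modification ⭐

Once the code changes are in place, configure the model's YAML file.

Taking ultralytics/cfg/models/rt-detr/rtdetr-l.yaml as the example, create a model file named rtdetr-RevCol.yaml in the same directory for training on your own dataset.

Copy the contents of rtdetr-l.yaml into rtdetr-RevCol.yaml, then set nc to the number of object classes in your dataset.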
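The copy-and-edit step above can also be scripted. The following is a minimal sketch, not part of the original tutorial: the helper name make_model_yaml is hypothetical, and it assumes the source file has a single top-level nc: line, as rtdetr-l.yaml does. It uses only the Python standard library and leaves any trailing comment on the nc line in place.

```python
import re
import shutil
from pathlib import Path


def make_model_yaml(src, dst, num_classes):
    """Copy src to dst and rewrite the top-level `nc:` value in the copy.

    Hypothetical helper for illustration; the source file is left unchanged.
    """
    src, dst = Path(src), Path(dst)
    shutil.copyfile(src, dst)
    text = dst.read_text(encoding="utf-8")
    # Match the first line that starts with `nc:` followed by an integer.
    # Anything after the number (such as a comment) is preserved.
    new_text, count = re.subn(r"(?m)^nc:\s*\d+", f"nc: {num_classes}", text, count=1)
    if count != 1:
        raise ValueError(f"no top-level 'nc:' entry found in {src}")
    dst.write_text(new_text, encoding="utf-8")
    return dst
```

For example, make_model_yaml("ultralytics/cfg/models/rt-detr/rtdetr-l.yaml", "ultralytics/cfg/models/rt-detr/rtdetr-RevCol.yaml", 1) would create the new file with nc: 1. The backbone and head sections still have to be edited by hand as described next.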

📌 The modification replaces the backbone with revcol_tiny.

# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr

# Parameters
nc: 1  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, revcol_tiny, []]  # 4

head:
  - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 5 input_proj.2
  - [-1, 1, AIFI, [1024, 8]] # 6
  - [-1, 1, Conv, [256, 1, 1]]  # 7, Y5, lateral_convs.0

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 8
  - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 9 input_proj.1
  - [[-2, -1], 1, Concat, [1]] # 10
  - [-1, 3, RepC3, [256]]  # 11, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]]   # 12, Y4, lateral_convs.1

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 13
  - [2, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 14 input_proj.0
  - [[-2, -1], 1, Concat, [1]]  # 15 cat backbone P4
  - [-1, 3, RepC3, [256]]    # X3 (16), fpn_blocks.1

  - [-1, 1, Conv, [256, 3, 2]]   # 17, downsample_convs.0
  - [[-1, 12], 1, Concat, [1]]  # 18 cat Y4
  - [-1, 3, RepC3, [256]]    # F4 (19), pan_blocks.0

  - [-1, 1, Conv, [256, 3, 2]]   # 20, downsample_convs.1
  - [[-1, 7], 1, Concat, [1]]  # 21 cat Y5
  - [-1, 3, RepC3, [256]]    # F5 (22), pan_blocks.1

  - [[16, 19, 22], 1, RTDETRDecoder, [nc]]  # Detect(P3, P4, P5)
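The layer numbers in the comments above matter because the head refers to earlier layers by absolute index (2, 3, 7, 12, 16, 19, 22). The sketch below is an illustration only, not code from the tutorial. It assumes that the backbone's outputs occupy indices 0 to 4, as the `# 4` comment on the backbone row suggests, so the first head row is layer 5. The names HEAD, BACKBONE_SLOTS and resolve are hypothetical. The function turns every relative `from` value into an absolute index and rejects any reference to a layer that has not been built yet.

```python
# Head rows from rtdetr-RevCol.yaml, reduced to (from, module) pairs.
HEAD = [
    (-1, "Conv"),                     # 5
    (-1, "AIFI"),                     # 6
    (-1, "Conv"),                     # 7
    (-1, "nn.Upsample"),              # 8
    (3, "Conv"),                      # 9
    ([-2, -1], "Concat"),             # 10
    (-1, "RepC3"),                    # 11
    (-1, "Conv"),                     # 12
    (-1, "nn.Upsample"),              # 13
    (2, "Conv"),                      # 14
    ([-2, -1], "Concat"),             # 15
    (-1, "RepC3"),                    # 16
    (-1, "Conv"),                     # 17
    ([-1, 12], "Concat"),             # 18
    (-1, "RepC3"),                    # 19
    (-1, "Conv"),                     # 20
    ([-1, 7], "Concat"),              # 21
    (-1, "RepC3"),                    # 22
    ([16, 19, 22], "RTDETRDecoder"),  # 23
]

# Assumption: the single backbone row fills layer indices 0..4.
BACKBONE_SLOTS = 5


def resolve(head, first_index):
    """Map each head row to (module, absolute source indices)."""
    layers = {}
    for offset, (frm, module) in enumerate(head):
        idx = first_index + offset
        refs = frm if isinstance(frm, list) else [frm]
        # Negative values are relative to the current layer.
        sources = [idx + r if r < 0 else r for r in refs]
        for s in sources:
            if s < 0 or s >= idx:
                raise ValueError(f"layer {idx} ({module}) refers to invalid layer {s}")
        layers[idx] = (module, sources)
    return layers


layers = resolve(HEAD, BACKBONE_SLOTS)
print(layers[23])  # ('RTDETRDecoder', [16, 19, 22])
```

Under this assumption the decoder's three inputs resolve to the three RepC3 outputs (layers 16, 19 and 22), which agrees with the X3, F4 and F5 comments in the YAML. If the backbone is swapped for one that exposes a different number of feature levels, the absolute indices in the head would need to be shifted accordingly.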


6. Successful Run Results

Printing the network model shows that the RevCol module has been added and that the model is ready for training.

rtdetr-RevCol

rtdetr-RevColV1 summary: 794 layers, 48,508,771 parameters, 48,508,771 gradients, 134.8 GFLOPs


                   from  n    params  module                                       arguments                     
  0                  -1  1  29942144  revcol_tiny                                  []                            
  1                  -1  1    131584  ultralytics.nn.modules.conv.Conv             [512, 256, 1, 1, None, 1, 1, False]
  2                  -1  1    789760  ultralytics.nn.modules.transformer.AIFI      [256, 1024, 8]                
  3                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]              
  4                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
  5                   3  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1, None, 1, 1, False]
  6            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
  7                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
  8                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]              
  9                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 10                   2  1     33280  ultralytics.nn.modules.conv.Conv             [128, 256, 1, 1, None, 1, 1, False]
 11            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 12                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 13                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 14            [-1, 12]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 15                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 16                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 17             [-1, 7]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 18                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 19        [16, 19, 22]  1   7303907  ultralytics.nn.modules.head.RTDETRDecoder    [1, [256, 256, 256]]          
rtdetr-RevColV1 summary: 794 layers, 48,508,771 parameters, 48,508,771 gradients, 134.8 GFLOPs
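The figures in the table can be cross-checked with a few lines of arithmetic. The sketch below is an illustration rather than part of the original run. It sums the params column and compares the result with the reported total, then reproduces the Conv rows under the assumption that each one is a bias-free convolution followed by BatchNorm, which contributes one learnable scale and one shift per output channel. The helper name conv_bn_params is hypothetical.

```python
# Per-layer parameter counts copied from the table above, in row order.
LAYER_PARAMS = [
    29942144, 131584, 789760, 66048, 0, 66048, 0, 2232320, 66048, 0,
    33280, 0, 2232320, 590336, 0, 2232320, 590336, 0, 2232320, 7303907,
]
REPORTED_TOTAL = 48_508_771  # from the summary line


def conv_bn_params(c1, c2, k):
    """Parameters of a bias-free k x k convolution plus BatchNorm.

    Assumption: the convolution holds c1 * c2 * k * k weights and
    BatchNorm adds a weight and a bias for each of the c2 channels.
    """
    return c1 * c2 * k * k + 2 * c2


print(sum(LAYER_PARAMS) == REPORTED_TOTAL)  # True
print(conv_bn_params(512, 256, 1))          # 131584
print(conv_bn_params(256, 256, 1))          # 66048
print(conv_bn_params(128, 256, 1))          # 33280
print(conv_bn_params(256, 256, 3))          # 590336
```

The per-layer counts add up to the reported 48,508,771 parameters, and the four distinct Conv rows match the formula exactly. The revcol_tiny backbone accounts for 29,942,144 of them, roughly 62% of the model.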