

RT-DETR Improvement Strategy [Model Lightweighting] | Replacing the Backbone with RepViT (CVPR 2024), a Lightweight Vision Transformers Architecture

1. Introduction

This article documents a lightweight improvement to RT-DETR based on RepViT. RepViT's network structure borrows the design philosophy of ViTs: it separates the token mixer from the channel mixer to cut inference-time computation and memory cost, reduces the expansion ratio while increasing width to lower latency, and doubles the channels to compensate for the large drop in parameters, improving accuracy. For the backbone replacement, this article provides the five variants from the original paper (repvit_m0_9, repvit_m1_0, repvit_m1_1, repvit_m1_5 and repvit_m2_3) to cover different requirements.

Model      Params   FLOPs          Inference speed
rtdetr-l   32.8 M   108.0 GFLOPs   11.6 ms
Improved   23.2 M   74.4 GFLOPs    11.5 ms


2. RepViT Architecture in Detail

2.1 Motivation

In computer vision, designing lightweight models is essential for deploying vision models on resource-constrained mobile devices. In recent years, lightweight Vision Transformers (ViTs) have shown superior performance and low latency on mobile devices, yet the significant differences between ViTs and lightweight Convolutional Neural Networks (CNNs) in block structure and in macro and micro design have not been fully studied. This work revisits the efficient design of lightweight CNNs from a ViT perspective, aiming to explore better model architectures for mobile devices, and proposes the RepViT model.

2.2 Principles

2.2.1 Borrowing the design ideas of ViTs

  • Block design
    • Separating the token mixer and the channel mixer: an important design feature of lightweight ViT blocks is that the token mixer and the channel mixer are separate. Research finds that the effectiveness of ViTs mainly comes from this general token-mixer/channel-mixer architecture (the MetaFormer architecture). In MobileNetV3-L the original block design couples the two; by moving the DW convolution and the optional Squeeze-and-Excitation (SE) layer, the two are successfully separated, and structural reparameterization is used to strengthen learning while cutting inference-time computation and memory cost and lowering latency (a small fusion sketch follows this list). The resulting block is named the RepViT block.
    • Reducing the expansion ratio and increasing width: in ViTs the channel mixer usually has a large expansion ratio, consuming a lot of computation. In the RepViT block the channel-mixer expansion ratio of every stage is set to 2, which lowers latency, and the channels of each stage are doubled to make up for the large reduction in parameters, improving accuracy.
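
Below is a minimal sketch of the structural reparameterization idea (my own simplification with BatchNorm omitted; the full version with BN folding is RepVGGDW.fuse_self in Section 3): at inference, the depth-wise 3×3 branch, the depth-wise 1×1 branch and the identity shortcut are folded into a single 3×3 depth-wise convolution, so the multi-branch training structure costs nothing at deploy time.

import torch
import torch.nn.functional as F

C = 8
x = torch.randn(1, C, 32, 32)
conv3 = torch.nn.Conv2d(C, C, 3, 1, 1, groups=C, bias=True)  # depth-wise 3x3 branch
conv1 = torch.nn.Conv2d(C, C, 1, 1, 0, groups=C, bias=True)  # depth-wise 1x1 branch

y_train = conv3(x) + conv1(x) + x  # training-time multi-branch output

# Fold the 1x1 kernel and the identity shortcut into the 3x3 kernel
w = conv3.weight + F.pad(conv1.weight, [1, 1, 1, 1])      # 1x1 weight placed at the kernel centre
w = w + F.pad(torch.ones(C, 1, 1, 1), [1, 1, 1, 1])       # identity = centred weight of 1 per channel
b = conv3.bias + conv1.bias

y_deploy = F.conv2d(x, w, b, stride=1, padding=1, groups=C)  # single fused depth-wise 3x3
print(torch.allclose(y_train, y_deploy, atol=1e-5))          # True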


2.2.2 Macro design

  • Early convolutions for the stem: ViTs usually use a patchify operation as the stem, which tends to optimize poorly and is sensitive to the training recipe, while MobileNetV3-L uses a complex stem that creates a latency bottleneck and limits representational power. The study adopts early convolutions instead, stacking two stride-2 3×3 convolutions as the stem, which reduces latency and improves accuracy (see the sketch after this list).
  • Deeper downsampling layers: ViTs perform spatial downsampling through a separate patch-merging layer, which helps increase network depth and reduce information loss, whereas MobileNetV3-L downsamples only through an inverted bottleneck block and may lack sufficient depth. The study uses a DW convolution plus a 1×1 convolution for spatial downsampling and channel adjustment, prepends a RepViT block to further deepen the downsampling layer, and adds an FFN module to retain more latent information, improving accuracy while lowering latency.
  • A simple classifier: the classifier of lightweight ViTs usually consists of a global average pooling layer and a linear layer, which is latency-friendly, while MobileNetV3-L uses a complex classifier that adds latency. Since the last stage has more channels after the RepViT block redesign, the study switches to the simple classifier; accuracy drops slightly but latency decreases.
  • Overall stage ratio: the study adjusts the block-count ratio across stages to 1:1:7:1 and increases network depth, improving accuracy while lowering latency.
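
A minimal sketch of the early-convolution stem (assuming the Conv-BN-GELU layout used by patch_embed in the Section 3 code, written here with plain PyTorch layers): two stride-2 3×3 convolutions reduce the input resolution by 4×.

import torch
import torch.nn as nn

stem = nn.Sequential(
    nn.Conv2d(3, 24, 3, 2, 1), nn.BatchNorm2d(24), nn.GELU(),   # 640 -> 320
    nn.Conv2d(24, 48, 3, 2, 1), nn.BatchNorm2d(48),             # 320 -> 160
)
print(stem(torch.randn(1, 3, 640, 640)).shape)  # torch.Size([1, 48, 160, 160])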

2.2.3 Micro design

  • Kernel size selection: CNN performance and latency are affected by the convolution kernel size. Large-kernel convolutions can bring performance gains but are unfriendly to mobile devices. MobileNetV3-L mainly uses 3×3 convolutions; the study keeps 3×3 convolutions in all modules, maintaining accuracy while lowering latency.
  • Squeeze-and-Excitation (SE) layer placement: SE layers can compensate for the limitations of convolution, but the way some blocks in MobileNetV3-L use them is suboptimal. The study adopts a cross-block strategy, placing an SE layer in the 1st, 3rd, 5th, ... block of each stage, obtaining the largest accuracy gain for the smallest latency increase (see the sketch after this list).
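
A purely illustrative sketch of the cross-block SE placement (my own simplified blocks, not the paper's code): within one stage, only every other block carries an SE layer.

import torch.nn as nn
from timm.models.layers import SqueezeExcite

def make_stage(dim, num_blocks):
    """Alternate SE layers across the blocks of one stage (1st, 3rd, 5th, ... blocks use SE)."""
    blocks = []
    for i in range(num_blocks):
        use_se = (i % 2 == 0)
        blocks.append(nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, groups=dim),               # depth-wise token mixer
            SqueezeExcite(dim, 0.25) if use_se else nn.Identity(),  # SE only on alternating blocks
            nn.Conv2d(dim, dim, 1),                                 # point-wise channel mixer (simplified)
        ))
    return nn.Sequential(*blocks)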


2.3 Structure

The RepViT model is a new family of purely lightweight CNNs. Its structure follows the ViT-like MetaFormer structure and is built entirely from reparameterized convolutions. It has multiple variants, such as RepViT-M0.9/M1.0/M1.1/M1.5/M2.3, which differ in the number of channels and the number of blocks in each stage.
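
A small usage sketch (assuming the Section 3 code is saved as RepVit.py, as described in Section 4, and is importable from the working directory): the variants differ only in width and depth, and the per-scale output channels are exposed through the .channel attribute that parse_model reads later.

import torch
from RepVit import repvit_m0_9, repvit_m1_1  # assumed import path; matches the file created in Section 4

for build in (repvit_m0_9, repvit_m1_1):
    model = build()
    print(build.__name__, model.channel)  # per-scale (P2..P5) output channel widths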

2.4 Advantages

  1. Superior performance
    • In ImageNet-1K image classification, RepViT reaches state-of-the-art performance across model sizes.

For example, RepViT-M1.0 achieves over 80% top-1 accuracy on an iPhone 12 at a latency of 1.0 ms, the first time a lightweight model has reached this level. Even without knowledge distillation, it clearly outperforms competing models.

  2. Low latency
    • RepViT shows good latency across a range of vision tasks.

For example, in object detection and instance segmentation, the RepViT-M1.1 backbone has lower latency than the EfficientFormer-L1 backbone at a similar model size; in semantic segmentation, RepViT-M1.5 cuts latency by nearly 50% compared with EfficientFormerV2-S2 while performing better.

  3. Suited to mobile devices
    • RepViT's design fully accounts for the resource constraints of mobile devices. By borrowing the efficient architectural designs of ViTs to optimize a lightweight CNN, it achieves good performance and latency on mobile devices and offers a better model choice for mobile vision tasks.

Paper: https://arxiv.org/pdf/2307.09283
Code: https://github.com/THU-MIG/RepViT

3. RepViT Module Implementation Code

The implementation code of the RepViT module is as follows:

 
import torch.nn as nn
import numpy as np
from timm.models.layers import SqueezeExcite
import torch

__all__ = ['repvit_m0_9', 'repvit_m1_0', 'repvit_m1_1', 'repvit_m1_5', 'repvit_m2_3']

def replace_batchnorm(net):
    for child_name, child in net.named_children():
        if hasattr(child, 'fuse_self'):
            fused = child.fuse_self()
            setattr(net, child_name, fused)
            replace_batchnorm(fused)
        elif isinstance(child, torch.nn.BatchNorm2d):
            setattr(net, child_name, torch.nn.Identity())
        else:
            replace_batchnorm(child)

def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

class Conv2d_BN(torch.nn.Sequential):
    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1, resolution=-10000):
        super().__init__()
        self.add_module('c', torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', torch.nn.BatchNorm2d(b))
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)

    @torch.no_grad()
    def fuse_self(self):
        c, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps)**0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / \
            (bn.running_var + bn.eps)**0.5
        m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
            0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups,
            device=c.weight.device)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m

class Residual(torch.nn.Module):
    def __init__(self, m, drop=0.):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        if self.training and self.drop > 0:
            return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,
                                              device=x.device).ge_(self.drop).div(1 - self.drop).detach()
        else:
            return x + self.m(x)
    
    @torch.no_grad()
    def fuse_self(self):
        if isinstance(self.m, Conv2d_BN):
            m = self.m.fuse_self()
            assert(m.groups == m.in_channels)
            identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
            identity = torch.nn.functional.pad(identity, [1,1,1,1])
            m.weight += identity.to(m.weight.device)
            return m
        elif isinstance(self.m, torch.nn.Conv2d):
            m = self.m
            assert(m.groups != m.in_channels)
            identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
            identity = torch.nn.functional.pad(identity, [1,1,1,1])
            m.weight += identity.to(m.weight.device)
            return m
        else:
            return self

class RepVGGDW(torch.nn.Module):
    def __init__(self, ed) -> None:
        super().__init__()
        self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
        self.conv1 = torch.nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
        self.dim = ed
        self.bn = torch.nn.BatchNorm2d(ed)
    
    def forward(self, x):
        return self.bn((self.conv(x) + self.conv1(x)) + x)
    
    @torch.no_grad()
    def fuse_self(self):
        conv = self.conv.fuse_self()
        conv1 = self.conv1
        
        conv_w = conv.weight
        conv_b = conv.bias
        conv1_w = conv1.weight
        conv1_b = conv1.bias
        
        conv1_w = torch.nn.functional.pad(conv1_w, [1,1,1,1])

        identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device), [1,1,1,1])

        final_conv_w = conv_w + conv1_w + identity
        final_conv_b = conv_b + conv1_b

        conv.weight.data.copy_(final_conv_w)
        conv.bias.data.copy_(final_conv_b)

        bn = self.bn
        w = bn.weight / (bn.running_var + bn.eps)**0.5
        w = conv.weight * w[:, None, None, None]
        b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / \
            (bn.running_var + bn.eps)**0.5
        conv.weight.data.copy_(w)
        conv.bias.data.copy_(b)
        return conv

class RepViTBlock(nn.Module):
    def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
        super(RepViTBlock, self).__init__()
        assert stride in [1, 2]

        self.identity = stride == 1 and inp == oup
        assert(hidden_dim == 2 * inp)

        if stride == 2:
            self.token_mixer = nn.Sequential(
                Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp),
                SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
                Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
            )
            self.channel_mixer = Residual(nn.Sequential(
                    # pw
                    Conv2d_BN(oup, 2 * oup, 1, 1, 0),
                    nn.GELU() if use_hs else nn.GELU(),  # note: GELU is used whether or not use_hs is set
                    # pw-linear
                    Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
                ))
        else:
            assert(self.identity)
            self.token_mixer = nn.Sequential(
                RepVGGDW(inp),
                SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
            )
            self.channel_mixer = Residual(nn.Sequential(
                    # pw
                    Conv2d_BN(inp, hidden_dim, 1, 1, 0),
                    nn.GELU() if use_hs else nn.GELU(),  # note: GELU is used whether or not use_hs is set
                    # pw-linear
                    Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
                ))

    def forward(self, x):
        return self.channel_mixer(self.token_mixer(x))

class RepViT(nn.Module):
    def __init__(self, cfgs):
        super(RepViT, self).__init__()
        # setting of inverted residual blocks
        self.cfgs = cfgs

        # building first layer
        input_channel = self.cfgs[0][2]
        patch_embed = torch.nn.Sequential(Conv2d_BN(3, input_channel // 2, 3, 2, 1), torch.nn.GELU(),
                           Conv2d_BN(input_channel // 2, input_channel, 3, 2, 1))
        layers = [patch_embed]
        # building inverted residual blocks
        block = RepViTBlock
        for k, t, c, use_se, use_hs, s in self.cfgs:
            output_channel = _make_divisible(c, 8)
            exp_size = _make_divisible(input_channel * t, 8)
            layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs))
            input_channel = output_channel
        self.features = nn.ModuleList(layers)
        self.channel = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640))]
        
    def forward(self, x):
        input_size = x.size(2)
        scale = [4, 8, 16, 32]
        features = [None, None, None, None]
        for f in self.features:
            x = f(x)
            if input_size // x.size(2) in scale:
                features[scale.index(input_size // x.size(2))] = x
        return features
    
    def switch_to_deploy(self):
        replace_batchnorm(self)

def update_weight(model_dict, weight_dict):
    idx, temp_dict = 0, {}
    for k, v in weight_dict.items():
        # k = k[9:]
        if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
            temp_dict[k] = v
            idx += 1
    model_dict.update(temp_dict)
    print(f'loading weights... {idx}/{len(model_dict)} items')
    return model_dict

def repvit_m0_9(weights=''):
    """
    Constructs a RepViT-M0.9 model
    """
    cfgs = [
        # k, t, c, SE, HS, s 
        [3,   2,  48, 1, 0, 1],
        [3,   2,  48, 0, 0, 1],
        [3,   2,  48, 0, 0, 1],
        [3,   2,  96, 0, 0, 2],
        [3,   2,  96, 1, 0, 1],
        [3,   2,  96, 0, 0, 1],
        [3,   2,  96, 0, 0, 1],
        [3,   2,  192, 0, 1, 2],
        [3,   2,  192, 1, 1, 1],
        [3,   2,  192, 0, 1, 1],
        [3,   2,  192, 1, 1, 1],
        [3,   2, 192, 0, 1, 1],
        [3,   2, 192, 1, 1, 1],
        [3,   2, 192, 0, 1, 1],
        [3,   2, 192, 1, 1, 1],
        [3,   2, 192, 0, 1, 1],
        [3,   2, 192, 1, 1, 1],
        [3,   2, 192, 0, 1, 1],
        [3,   2, 192, 1, 1, 1],
        [3,   2, 192, 0, 1, 1],
        [3,   2, 192, 1, 1, 1],
        [3,   2, 192, 0, 1, 1],
        [3,   2, 192, 0, 1, 1],
        [3,   2, 384, 0, 1, 2],
        [3,   2, 384, 1, 1, 1],
        [3,   2, 384, 0, 1, 1]
    ]
    model = RepViT(cfgs)
    if weights:
        model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)['model']))
    return model

def repvit_m1_0(weights=''):
    """
    Constructs a RepViT-M1.0 model
    """
    cfgs = [
        # k, t, c, SE, HS, s 
        [3,   2,  56, 1, 0, 1],
        [3,   2,  56, 0, 0, 1],
        [3,   2,  56, 0, 0, 1],
        [3,   2,  112, 0, 0, 2],
        [3,   2,  112, 1, 0, 1],
        [3,   2,  112, 0, 0, 1],
        [3,   2,  112, 0, 0, 1],
        [3,   2,  224, 0, 1, 2],
        [3,   2,  224, 1, 1, 1],
        [3,   2,  224, 0, 1, 1],
        [3,   2,  224, 1, 1, 1],
        [3,   2, 224, 0, 1, 1],
        [3,   2, 224, 1, 1, 1],
        [3,   2, 224, 0, 1, 1],
        [3,   2, 224, 1, 1, 1],
        [3,   2, 224, 0, 1, 1],
        [3,   2, 224, 1, 1, 1],
        [3,   2, 224, 0, 1, 1],
        [3,   2, 224, 1, 1, 1],
        [3,   2, 224, 0, 1, 1],
        [3,   2, 224, 1, 1, 1],
        [3,   2, 224, 0, 1, 1],
        [3,   2, 224, 0, 1, 1],
        [3,   2, 448, 0, 1, 2],
        [3,   2, 448, 1, 1, 1],
        [3,   2, 448, 0, 1, 1]
    ]
    model = RepViT(cfgs)
    if weights:
        model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)['model']))
    return model

def repvit_m1_1(weights=''):
    """
    Constructs a RepViT-M1.1 model
    """
    cfgs = [
        # k, t, c, SE, HS, s 
        [3,   2,  64, 1, 0, 1],
        [3,   2,  64, 0, 0, 1],
        [3,   2,  64, 0, 0, 1],
        [3,   2,  128, 0, 0, 2],
        [3,   2,  128, 1, 0, 1],
        [3,   2,  128, 0, 0, 1],
        [3,   2,  128, 0, 0, 1],
        [3,   2,  256, 0, 1, 2],
        [3,   2,  256, 1, 1, 1],
        [3,   2,  256, 0, 1, 1],
        [3,   2,  256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 512, 0, 1, 2],
        [3,   2, 512, 1, 1, 1],
        [3,   2, 512, 0, 1, 1]
    ]
    model = RepViT(cfgs)
    if weights:
        model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)['model']))
    return model

def repvit_m1_5(weights=''):
    """
    Constructs a RepViT-M1.5 model
    """
    cfgs = [
        # k, t, c, SE, HS, s 
        [3,   2,  64, 1, 0, 1],
        [3,   2,  64, 0, 0, 1],
        [3,   2,  64, 1, 0, 1],
        [3,   2,  64, 0, 0, 1],
        [3,   2,  64, 0, 0, 1],
        [3,   2,  128, 0, 0, 2],
        [3,   2,  128, 1, 0, 1],
        [3,   2,  128, 0, 0, 1],
        [3,   2,  128, 1, 0, 1],
        [3,   2,  128, 0, 0, 1],
        [3,   2,  128, 0, 0, 1],
        [3,   2,  256, 0, 1, 2],
        [3,   2,  256, 1, 1, 1],
        [3,   2,  256, 0, 1, 1],
        [3,   2,  256, 1, 1, 1],
        [3,   2,  256, 0, 1, 1],
        [3,   2,  256, 1, 1, 1],
        [3,   2,  256, 0, 1, 1],
        [3,   2,  256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 1, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 256, 0, 1, 1],
        [3,   2, 512, 0, 1, 2],
        [3,   2, 512, 1, 1, 1],
        [3,   2, 512, 0, 1, 1],
        [3,   2, 512, 1, 1, 1],
        [3,   2, 512, 0, 1, 1]
    ]
    model = RepViT(cfgs)
    if weights:
        model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)['model']))
    return model

def repvit_m2_3(weights=''):
    """
    Constructs a RepViT-M2.3 model
    """
    cfgs = [
        # k, t, c, SE, HS, s 
        [3,   2,  80, 1, 0, 1],
        [3,   2,  80, 0, 0, 1],
        [3,   2,  80, 1, 0, 1],
        [3,   2,  80, 0, 0, 1],
        [3,   2,  80, 1, 0, 1],
        [3,   2,  80, 0, 0, 1],
        [3,   2,  80, 0, 0, 1],
        [3,   2,  160, 0, 0, 2],
        [3,   2,  160, 1, 0, 1],
        [3,   2,  160, 0, 0, 1],
        [3,   2,  160, 1, 0, 1],
        [3,   2,  160, 0, 0, 1],
        [3,   2,  160, 1, 0, 1],
        [3,   2,  160, 0, 0, 1],
        [3,   2,  160, 0, 0, 1],
        [3,   2,  320, 0, 1, 2],
        [3,   2,  320, 1, 1, 1],
        [3,   2,  320, 0, 1, 1],
        [3,   2,  320, 1, 1, 1],
        [3,   2,  320, 0, 1, 1],
        [3,   2,  320, 1, 1, 1],
        [3,   2,  320, 0, 1, 1],
        [3,   2,  320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 1, 1, 1],
        [3,   2, 320, 0, 1, 1],
        # [3,   2, 320, 1, 1, 1],
        # [3,   2, 320, 0, 1, 1],
        [3,   2, 320, 0, 1, 1],
        [3,   2, 640, 0, 1, 2],
        [3,   2, 640, 1, 1, 1],
        [3,   2, 640, 0, 1, 1],
        # [3,   2, 640, 1, 1, 1],
        # [3,   2, 640, 0, 1, 1]
    ]
    model = RepViT(cfgs)
    if weights:
        model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)['model']))
    return model

if __name__ == '__main__':
    model = repvit_m2_3('repvit_m2_3_distill_450e.pth')
    inputs = torch.randn((1, 3, 640, 640))
    res = model(inputs)
    for i in res:
        print(i.size())

4. Modification Steps

4.1 Modification 1

① In the ultralytics/nn/ directory, create a new folder named AddModules to hold the module code.

② In the AddModules folder, create a new file RepVit.py and paste the code from Section 3 into it.


4.2 Modification 2

In the AddModules folder, create an __init__.py file (skip this if it already exists) and import the module inside it: from .RepVit import *


4.3 Modification 3

In the ultralytics/nn/modules/tasks.py file, the new module classes need to be registered.

① First, import the modules.

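A minimal sketch of the import (an assumption on my part that the AddModules package from 4.1/4.2 re-exports the builders through its __init__.py); place it alongside the other imports at the top of tasks.py:

from ultralytics.nn.AddModules import *  # makes repvit_m0_9 ... repvit_m2_3 visible to parse_model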

② In the predict function of the BaseModel class, remove the embed parameter in the two places where it appears:

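A hedged sketch of the change (assuming a recent Ultralytics version in which BaseModel.predict accepts an embed argument and forwards it to _predict_once); both occurrences of embed are dropped:

    def predict(self, x, profile=False, visualize=False, augment=False):  # 'embed=None' removed from the signature
        """Perform a forward pass through the model."""
        if augment:
            return self._predict_augment(x)
        return self._predict_once(x, profile, visualize)  # 'embed' no longer passed on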

③ Replace the _predict_once function of the BaseModel class with the following code:

    def _predict_once(self, x, profile=False, visualize=False):
        """
        Perform a forward pass through the network.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool):  Print the computation time of each layer if True, defaults to False.
            visualize (bool): Save the feature maps of the model if True, defaults to False.

        Returns:
            (torch.Tensor): The last output of the model.
        """
        y, dt = [], []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
        return x


④ Completely replace the predict function of the RTDETRDetectionModel class with the following:

    def predict(self, x, profile=False, visualize=False, batch=None, augment=False):
        """
        Perform a forward pass through the model.

        Args:
            x (torch.Tensor): The input tensor.
            profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
            visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
            batch (dict, optional): Ground truth data for evaluation. Defaults to None.
            augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.

        Returns:
            (torch.Tensor): Model's output tensor.
        """
        y, dt = [], []  # outputs
        for m in self.model[:-1]:  # except the head part
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            if hasattr(m, 'backbone'):
                x = m(x)
                for _ in range(5 - len(x)):
                    x.insert(0, None)
                for i_idx, i in enumerate(x):
                    if i_idx in self.save:
                        y.append(i)
                    else:
                        y.append(None)
                # for i in x:
                #     if i is not None:
                #         print(i.size())
                x = x[-1]
            else:
                x = m(x)  # run
                y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
        head = self.model[-1]
        x = head([y[j] for j in head.f], batch)  # head inference
        return x


⑤ In the parse_model function, replace the code at the corresponding location with the following:

    if verbose:
        LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<45}{'arguments':<30}")
    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    is_backbone = False
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        try:
            if m == 'node_mode':
                m = d[m]
                if len(args) > 0:
                    if args[0] == 'head_channel':
                        args[0] = int(d[args[0]])
            t = m
            m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m]  # get module
        except:
            pass
        for j, a in enumerate(args):
            if isinstance(a, str):
                with contextlib.suppress(ValueError):
                    try:
                        args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
                    except:
                        args[j] = a


⑥ In the parse_model function, add the following code:

elif m in {repvit_m0_9, repvit_m1_0, repvit_m1_1, repvit_m1_5, repvit_m2_3,}:
    m = m(*args)
    c2 = m.channel


⑦ In the parse_model function, replace the code at the corresponding location with the following:

        if isinstance(c2, list):
            is_backbone = True
            m_ = m
            m_.backbone = True
        else:
            m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
            t = str(m)[8:-2].replace('__main__.', '')  # module type
        
        m_.np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = i + 4 if is_backbone else i, f, t  # attach index, 'from' index, type
        if verbose:
            LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f}  {t:<45}{str(args):<30}')  # print
        save.extend(x % (i + 4 if is_backbone else i) for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        if isinstance(c2, list):
            ch.extend(c2)
            for _ in range(5 - len(ch)):
                ch.insert(0, 0)
        else:
            ch.append(c2)
    return nn.Sequential(*layers), sorted(save)


⑧ In the forward function of the AutoBackend class in ultralytics/nn/autobackend.py, completely replace the code with the following:

    def forward(self, im, augment=False, visualize=False):
        """
        Runs inference on the YOLOv8 MultiBackend model.

        Args:
            im (torch.Tensor): The image tensor to perform inference on.
            augment (bool): whether to perform data augmentation during inference, defaults to False
            visualize (bool): whether to visualize the output predictions, defaults to False

        Returns:
            (tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True)
        """
        b, ch, h, w = im.shape  # batch, channel, height, width
        if self.fp16 and im.dtype != torch.float16:
            im = im.half()  # to FP16
        if self.nhwc:
            im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)

        if self.pt or self.nn_module:  # PyTorch
            y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
        elif self.jit:  # TorchScript
            y = self.model(im)
        elif self.dnn:  # ONNX OpenCV DNN
            im = im.cpu().numpy()  # torch to numpy
            self.net.setInput(im)
            y = self.net.forward()
        elif self.onnx:  # ONNX Runtime
            im = im.cpu().numpy()  # torch to numpy
            y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
        elif self.xml:  # OpenVINO
            im = im.cpu().numpy()  # FP32
            y = list(self.ov_compiled_model(im).values())
        elif self.engine:  # TensorRT
            if self.dynamic and im.shape != self.bindings['images'].shape:
                i = self.model.get_binding_index('images')
                self.context.set_binding_shape(i, im.shape)  # reshape if dynamic
                self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
                for name in self.output_names:
                    i = self.model.get_binding_index(name)
                    self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
            s = self.bindings['images'].shape
            assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
            self.binding_addrs['images'] = int(im.data_ptr())
            self.context.execute_v2(list(self.binding_addrs.values()))
            y = [self.bindings[x].data for x in sorted(self.output_names)]
        elif self.coreml:  # CoreML
            im = im[0].cpu().numpy()
            im_pil = Image.fromarray((im * 255).astype('uint8'))
            # im = im.resize((192, 320), Image.BILINEAR)
            y = self.model.predict({'image': im_pil})  # coordinates are xywh normalized
            if 'confidence' in y:
                raise TypeError('Ultralytics only supports inference of non-pipelined CoreML models exported with '
                                f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export.")
                # TODO: CoreML NMS inference handling
                # from ultralytics.utils.ops import xywh2xyxy
                # box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
                # conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32)
                # y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
            elif len(y) == 1:  # classification model
                y = list(y.values())
            elif len(y) == 2:  # segmentation model
                y = list(reversed(y.values()))  # reversed for segmentation models (pred, proto)
        elif self.paddle:  # PaddlePaddle
            im = im.cpu().numpy().astype(np.float32)
            self.input_handle.copy_from_cpu(im)
            self.predictor.run()
            y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
        elif self.ncnn:  # ncnn
            mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
            ex = self.net.create_extractor()
            input_names, output_names = self.net.input_names(), self.net.output_names()
            ex.input(input_names[0], mat_in)
            y = []
            for output_name in output_names:
                mat_out = self.pyncnn.Mat()
                ex.extract(output_name, mat_out)
                y.append(np.array(mat_out)[None])
        elif self.triton:  # NVIDIA Triton Inference Server
            im = im.cpu().numpy()  # torch to numpy
            y = self.model(im)
        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
            im = im.cpu().numpy()
            if self.saved_model:  # SavedModel
                y = self.model(im, training=False) if self.keras else self.model(im)
                if not isinstance(y, list):
                    y = [y]
            elif self.pb:  # GraphDef
                y = self.frozen_func(x=self.tf.constant(im))
                if len(y) == 2 and len(self.names) == 999:  # segments and names not defined
                    ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0)  # index of protos, boxes
                    nc = y[ib].shape[1] - y[ip].shape[3] - 4  # y = (1, 160, 160, 32), (1, 116, 8400)
                    self.names = {i: f'class{i}' for i in range(nc)}
            else:  # Lite or Edge TPU
                details = self.input_details[0]
                integer = details['dtype'] in (np.int8, np.int16)  # is TFLite quantized int8 or int16 model
                if integer:
                    scale, zero_point = details['quantization']
                    im = (im / scale + zero_point).astype(details['dtype'])  # de-scale
                self.interpreter.set_tensor(details['index'], im)
                self.interpreter.invoke()
                y = []
                for output in self.output_details:
                    x = self.interpreter.get_tensor(output['index'])
                    if integer:
                        scale, zero_point = output['quantization']
                        x = (x.astype(np.float32) - zero_point) * scale  # re-scale
                    if x.ndim > 2:  # if task is not classification
                        # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
                        # xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
                        x[:, [0, 2]] *= w
                        x[:, [1, 3]] *= h
                    y.append(x)
            # TF segment fixes: export is reversed vs ONNX export and protos are transposed
            if len(y) == 2:  # segment with (det, proto) output order reversed
                if len(y[1].shape) != 4:
                    y = list(reversed(y))  # should be y = (1, 116, 8400), (1, 160, 160, 32)
                y[1] = np.transpose(y[1], (0, 3, 1, 2))  # should be y = (1, 116, 8400), (1, 32, 160, 160)
            y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]

        # for x in y:
        #     print(type(x), len(x)) if isinstance(x, (list, tuple)) else print(type(x), x.shape)  # debug shapes
        if isinstance(y, (list, tuple)):
            return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
        else:
            return self.from_numpy(y)


At this point the modifications are complete, and you can configure the model and start training.


5. YAML Model File

5.1 Model Improvement ⭐

Once the code is in place, configure the model's YAML file.

Taking ultralytics/cfg/models/rt-detr/rtdetr-l.yaml as an example, create a model file for training on your own dataset in the same directory, named rtdetr-RepVit.yaml.

Copy the contents of rtdetr-l.yaml into rtdetr-RepVit.yaml and set nc to the number of target classes in your own dataset.

📌 The model is modified by replacing the backbone with repvit_m0_9.

# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr

# Parameters
nc: 1  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, repvit_m0_9, []]  # 4

head:
  - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 5 input_proj.2
  - [-1, 1, AIFI, [1024, 8]] # 6
  - [-1, 1, Conv, [256, 1, 1]]  # 7, Y5, lateral_convs.0

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 8
  - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 9 input_proj.1
  - [[-2, -1], 1, Concat, [1]] # 10
  - [-1, 3, RepC3, [256]]  # 11, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]]   # 12, Y4, lateral_convs.1

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 13
  - [2, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 14 input_proj.0
  - [[-2, -1], 1, Concat, [1]]  # 15 cat backbone P4
  - [-1, 3, RepC3, [256]]    # X3 (16), fpn_blocks.1

  - [-1, 1, Conv, [256, 3, 2]]   # 17, downsample_convs.0
  - [[-1, 12], 1, Concat, [1]]  # 18 cat Y4
  - [-1, 3, RepC3, [256]]    # F4 (19), pan_blocks.0

  - [-1, 1, Conv, [256, 3, 2]]   # 20, downsample_convs.1
  - [[-1, 7], 1, Concat, [1]]  # 21 cat Y5
  - [-1, 3, RepC3, [256]]    # F5 (22), pan_blocks.1

  - [[16, 19, 22], 1, RTDETRDecoder, [nc]]  # Detect(P3, P4, P5)
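
With the YAML file in place, training can be launched through the standard Ultralytics API. A minimal usage sketch (the dataset file coco128.yaml and the hyper-parameters below are placeholders; substitute your own):

from ultralytics import RTDETR

model = RTDETR('ultralytics/cfg/models/rt-detr/rtdetr-RepVit.yaml')  # build the improved model from the new YAML
model.train(data='coco128.yaml', epochs=100, imgsz=640, batch=4)     # placeholder dataset and hyper-parameters
model.val()                                                          # evaluate on the validation split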


6. Successful Run Results

Printing the network model shows that the RepViT module has been added, and training can now proceed.

rtdetr-RepVit


                   from  n    params  module                                       arguments                     
  0                  -1  1   4717792  repvit_m0_9                                  []                            
  1                  -1  1     98816  ultralytics.nn.modules.conv.Conv             [384, 256, 1, 1, None, 1, 1, False]
  2                  -1  1    789760  ultralytics.nn.modules.transformer.AIFI      [256, 1024, 8]                
  3                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]              
  4                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
  5                   3  1     49664  ultralytics.nn.modules.conv.Conv             [192, 256, 1, 1, None, 1, 1, False]
  6            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
  7                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
  8                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]              
  9                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 10                   2  1     25088  ultralytics.nn.modules.conv.Conv             [96, 256, 1, 1, None, 1, 1, False]
 11            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 12                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 13                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 14            [-1, 12]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 15                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 16                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 17             [-1, 7]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 18                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 19        [16, 19, 22]  1   7303907  ultralytics.nn.modules.head.RTDETRDecoder    [1, [256, 256, 256]]          
rtdetr-RepVit summary: 898 layers, 23,227,075 parameters, 23,227,075 gradients, 74.4 GFLOPs