学习资源站

YOLOv8独家原创改进:大核卷积涨点系列 _ Shift-ConvNets,具有大核效应的小卷积核 _ 2024年最新论文

💡💡💡本文独家改进:大的卷积核设计成为使卷积神经网络(CNNs)再次强大的理想解决方案,Shift-ConvNets稀疏/移位操作让小卷积核也能达到大卷积核效果,创新十足实现涨点,助力YOLOv8

💡💡💡在多个私有数据集和公开数据集VisDrone2019、PASCAL VOC实现涨点

收录

YOLOv8原创自研

💡💡💡全网独家首发创新(原创),适合paper !!!

💡💡💡 2024年计算机视觉顶会创新点适用于Yolov5、Yolov7、Yolov8等各个Yolo系列,专栏文章提供每一步步骤和源码,轻松带你上手魔改网络 !!!

💡💡💡重点:通过本专栏的阅读,后续你也可以设计魔改网络,在网络不同位置(Backbone、head、detect、loss等)进行魔改,实现创新!!!

1.Shift-ConvNets原理介绍

摘要:近年来的研究表明,视觉变压器(ViTs)的卓越性能得益于大的接受野。因此,大卷积核设计成为卷积神经网络(cnn)再次伟大的理想解决方案。然而,典型的大卷积核是对硬件不友好的运算符,导致各种硬件平台的兼容性降低。因此,简单地扩大卷积核的大小是不明智的。在本文中,我们揭示了小卷积核和卷积操作可以实现大核大小的关闭效果。然后,我们提出了一种移位算子,确保cnn在稀疏机制的帮助下捕获远程依赖关系,同时保持硬件友好。实验结果表明,我们的移位算子显著提高了常规CNN的准确率,同时显著降低了计算需求。在ImageNet-1k上,我们的移位增强CNN模型优于最先进的模型。 

论文:https://arxiv.org/pdf/2401.12736.pdf

 

大卷积核是提高cnn性能的核心要素。同时,其固有的问题也不容忽视:

1. 大卷积核的核大小过大,超出了常规算子的优化范围。各种软件和硬件平台的有限优化阻碍了它的广泛应用。

2. 增加卷积核的大小似乎具有递减的边际收益。

3. 转换器具有稀疏关注的能力,而不必关注所有输入令牌。CNN有可能引入这个功能吗?

4. 大卷积核的计算成本是其应用的一个障碍。即使是参数较少的异构卷积,如可学习的扩展卷积和DCN,也不是硬件友好的算子。

基于大卷积核的持续演化,我们提出了移位算子。改进的模块结构如图4(a)所示。具体来说,我们将大卷积核转换成一组正常的小卷积核,然后对每个卷积结果应用移位操作。

图4所示。(a)将一个M × N卷积核分解为k N × N卷积核,并使用移位操作来完成等效大卷积核运算(对于SLaK部分阶段,M=51, N=5);(b)整体模块结构。 

图5所示。大卷积核的框架。(a) Shift操作,向多个方向移动特征,实现局部特征变化;(b)在xvolution中提出特征关注方法,通过转移特征建立局部特征依赖关系来近似全局特征,在cnn中廉价地使用非局部特征关注方法;(c)我们提出的移位方法,沿着一维移动特征以与小卷积核大小的网格对齐,以等效地表示大卷积核,并通过去除一些移位来建立稀疏依赖关系。

图6所示。大卷积核的框架。(a) SLaK在大卷积核中的应用。它使用两个51 × 5卷积核进行水平和垂直卷积。最后加上5×5卷积的结果。(SLaK使用细粒度稀疏性来减少参数计数);(b)我们提出的方法将一个大的51 × 5卷积核分成11个标准的5 × 5卷积核。然后我们使用特征移位和加法来达到与51 × 5大卷积核相同的效果。同时,我们利用粗粒度稀疏性实现大卷积,探索局部特征之间的稀疏相关性。

 

 

2.如何将Shift-Conv加入到YOLOv8

2.1 新建ultralytics/nn/conv/shift_wiseConv.py

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from ultralytics.nn.modules import C3, Bottleneck , C2f

__all__ = ['ReparamLargeKernelConv']

def get_conv2d(
        in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias
):
    # return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)
    try:
        paddings = (kernel_size[0] // 2, kernel_size[1] // 2)
    except Exception as e:
        paddings = padding
    return nn.Conv2d(
        in_channels, out_channels, kernel_size, stride, paddings, dilation, groups, bias
    )

def get_bn(channels):
    return nn.BatchNorm2d(channels)

class Mask(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.weight = torch.nn.Parameter(data=torch.Tensor(*size), requires_grad=True)
        self.weight.data.uniform_(-1, 1)

    def forward(self, x):
        w = torch.sigmoid(self.weight)
        masked_wt = w.mul(x)
        return masked_wt

def conv_bn_ori(
        in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1, bn=True
):
    if padding is None:
        padding = kernel_size // 2
    result = nn.Sequential()
    result.add_module(
        "conv",
        get_conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
        ),
    )

    if bn:
        result.add_module("bn", get_bn(out_channels))
    return result

class LoRAConvsByWeight(nn.Module):
    '''
    merge LoRA1 LoRA2
    shuffle channel by weights rather index
    '''

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 big_kernel, small_kernel,
                 stride=1, group=1,
                 bn=True, use_small_conv=True):
        super().__init__()
        self.kernels = (small_kernel, big_kernel)
        self.stride = stride
        self.small_conv = use_small_conv
        # add same padding for vertical and horizon axis. should delete it accordingly
        padding, after_padding_index, index = self.shift(self.kernels)
        self.pad = padding, after_padding_index, index
        self.nk = math.ceil(big_kernel / small_kernel)
        out_n = out_channels * self.nk
        self.split_convs = nn.Conv2d(in_channels, out_n,
                                     kernel_size=small_kernel, stride=stride,
                                     padding=padding, groups=group,
                                     bias=False)

        self.lora1 = Mask((1, out_n, 1, 1))
        self.lora2 = Mask((1, out_n, 1, 1))
        self.use_bn = bn

        if bn:
            self.bn_lora1 = get_bn(out_channels)
            self.bn_lora2 = get_bn(out_channels)
        else:
            self.bn_lora1 = None
            self.bn_lora2 = None

    def forward(self, inputs):
        out = self.split_convs(inputs)
        # split output
        *_, ori_h, ori_w = inputs.shape
        lora1_x = self.forward_lora(self.lora1(out), ori_h, ori_w, VH='H', bn=self.bn_lora1)
        lora2_x = self.forward_lora(self.lora2(out), ori_h, ori_w, VH='W', bn=self.bn_lora2)
        x = lora1_x + lora2_x
        return x

    def forward_lora(self, out, ori_h, ori_w, VH='H', bn=None):
        # shift along the index of every group
        b, c, h, w = out.shape
        out = torch.split(out.reshape(b, -1, self.nk, h, w), 1, 2)  # ※※※※※※※※※※※
        x = 0
        for i in range(self.nk):
            outi = self.rearrange_data(out[i], i, ori_h, ori_w, VH)
            x = x + outi
        if self.use_bn:
            x = bn(x)
        return x

    def rearrange_data(self, x, idx, ori_h, ori_w, VH):
        padding, _, index = self.pad
        x = x.squeeze(2)  # ※※※※※※※
        *_, h, w = x.shape
        k = min(self.kernels)
        ori_k = max(self.kernels)
        ori_p = ori_k // 2
        stride = self.stride
        # need to calculate start point after conv
        # how many windows shift from real start window index
        if (idx + 1) >= index:
            pad_l = 0
            s = (idx + 1 - index) * (k // stride)
        else:
            pad_l = (index - 1 - idx) * (k // stride)
            s = 0
        if VH == 'H':
            # assume add sufficient padding for origin conv
            suppose_len = (ori_w + 2 * ori_p - ori_k) // stride + 1
            pad_r = 0 if (s + suppose_len) <= (w + pad_l) else s + suppose_len - w - pad_l
            new_pad = (pad_l, pad_r, 0, 0)
            dim = 3
            # e = w + pad_l + pad_r - s - suppose_len
        else:
            # assume add sufficient padding for origin conv
            suppose_len = (ori_h + 2 * ori_p - ori_k) // stride + 1
            pad_r = 0 if (s + suppose_len) <= (h + pad_l) else s + suppose_len - h - pad_l
            new_pad = (0, 0, pad_l, pad_r)
            dim = 2
            # e = h + pad_l + pad_r - s - suppose_len
        # print('new_pad', new_pad)
        if len(set(new_pad)) > 1:
            x = F.pad(x, new_pad)
        # split_list = [s, suppose_len, e]
        # padding on v direction
        if padding * 2 + 1 != k:
            pad = padding - k // 2
            if VH == 'H':  # horizonal
                x = torch.narrow(x, 2, pad, h - 2 * pad)
            else:  # vertical
                x = torch.narrow(x, 3, pad, w - 2 * pad)

        xs = torch.narrow(x, dim, s, suppose_len)
        return xs

    def shift(self, kernels):
        '''
        We assume the conv does not change the feature map size, so padding = bigger_kernel_size//2. Otherwise,
        you may configure padding as you wish, and change the padding of small_conv accordingly.
        '''
        mink, maxk = min(kernels), max(kernels)
        mid_p = maxk // 2
        # 1. new window size is mink. middle point index in the window
        offset_idx_left = mid_p % mink
        offset_idx_right = (math.ceil(maxk / mink) * mink - mid_p - 1) % mink
        # 2. padding
        padding = offset_idx_left % mink
        while padding < offset_idx_right:
            padding += mink
        # 3. make sure last pixel can be scan by min window
        while padding < (mink - 1):
            padding += mink
        # 4. index of windows start point of middle point
        after_padding_index = padding - offset_idx_left
        index = math.ceil((mid_p + 1) / mink)
        real_start_idx = index - after_padding_index // mink
        # 5. output:padding how to padding input in v&h direction;
        # after_padding_index: middle point of original kernel will located in which window
        # real_start_idx: start window index after padding in original kernel along long side
        return padding, after_padding_index, real_start_idx

def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1, bn=True, use_small_conv=True):
    if isinstance(kernel_size, int) or len(set(kernel_size)) == 1:
        return conv_bn_ori(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups,
            dilation,
            bn)
    else:
        big_kernel, small_kernel = kernel_size
        return LoRAConvsByWeight(in_channels, out_channels, bn=bn,
                                 big_kernel=big_kernel, small_kernel=small_kernel,
                                 group=groups, stride=stride,
                                 use_small_conv=use_small_conv)

def fuse_bn(conv, bn):
    kernel = conv.weight
    running_mean = bn.running_mean
    running_var = bn.running_var
    gamma = bn.weight
    beta = bn.bias
    eps = bn.eps
    std = (running_var + eps).sqrt()
    t = (gamma / std).reshape(-1, 1, 1, 1)
    return kernel * t, beta - running_mean * gamma / std

class ReparamLargeKernelConv(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            small_kernel=5,
            stride=1,
            groups=1,
            small_kernel_merged=False,
            Decom=True,
            bn=True,
    ):
        super(ReparamLargeKernelConv, self).__init__()
        self.kernel_size = kernel_size
        self.small_kernel = small_kernel
        self.Decom = Decom
        # We assume the conv does not change the feature map size, so padding = k//2. Otherwise, you may configure padding as you wish, and change the padding of small_conv accordingly.
        padding = kernel_size // 2
        if small_kernel_merged:  # cpp版本的conv,加快速度
            self.lkb_reparam = get_conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=1,
                groups=groups,
                bias=True,
            )
        else:
            if self.Decom:
                self.LoRA = conv_bn(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=(kernel_size, small_kernel),
                    stride=stride,
                    padding=padding,
                    dilation=1,
                    groups=groups,
                    bn=bn
                )
            else:
                self.lkb_origin = conv_bn(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    dilation=1,
                    groups=groups,
                    bn=bn,
                )

            if (small_kernel is not None) and small_kernel < kernel_size:
                self.small_conv = conv_bn(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=small_kernel,
                    stride=stride,
                    padding=small_kernel // 2,
                    groups=groups,
                    dilation=1,
                    bn=bn,
                )

        self.bn = get_bn(out_channels)
        self.act = nn.SiLU()

    def forward(self, inputs):
        if hasattr(self, "lkb_reparam"):
            out = self.lkb_reparam(inputs)
        elif self.Decom:
            # out = self.LoRA1(inputs) + self.LoRA2(inputs)
            out = self.LoRA(inputs)
            if hasattr(self, "small_conv"):
                out += self.small_conv(inputs)
        else:
            out = self.lkb_origin(inputs)
            if hasattr(self, "small_conv"):
                out += self.small_conv(inputs)
        return self.act(self.bn(out))

    def get_equivalent_kernel_bias(self):
        eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
        if hasattr(self, "small_conv"):
            small_k, small_b = fuse_bn(self.small_conv.conv, self.small_conv.bn)
            eq_b += small_b
            #   add to the central part
            eq_k += nn.functional.pad(
                small_k, [(self.kernel_size - self.small_kernel) // 2] * 4
            )
        return eq_k, eq_b

    def switch_to_deploy(self):
        if hasattr(self, 'lkb_origin'):
            eq_k, eq_b = self.get_equivalent_kernel_bias()
            self.lkb_reparam = get_conv2d(
                in_channels=self.lkb_origin.conv.in_channels,
                out_channels=self.lkb_origin.conv.out_channels,
                kernel_size=self.lkb_origin.conv.kernel_size,
                stride=self.lkb_origin.conv.stride,
                padding=self.lkb_origin.conv.padding,
                dilation=self.lkb_origin.conv.dilation,
                groups=self.lkb_origin.conv.groups,
                bias=True,
            )
            self.lkb_reparam.weight.data = eq_k
            self.lkb_reparam.bias.data = eq_b
            self.__delattr__("lkb_origin")
            if hasattr(self, "small_conv"):
                self.__delattr__("small_conv")

class Bottleneck_SWC(Bottleneck):
    """Standard bottleneck with DilatedReparamBlock."""

    def __init__(self, c1, c2, kernel_size, shortcut=True, g=1, k=(3, 3), e=0.5):  # ch_in, ch_out, shortcut, groups, kernels, expand
        super().__init__(c1, c2, shortcut, g, k, e)
        c_ = int(c2 * e)  # hidden channels
        self.cv2 = ReparamLargeKernelConv(c2, c2, kernel_size, groups=(c2 // 16))

class C3_SWC(C3):
    def __init__(self, c1, c2, n=1, kernel_size=13, shortcut=False, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*(Bottleneck_SWC(c_, c_, kernel_size, shortcut, g, k=(1, 3), e=1.0) for _ in range(n)))

class C2f_SWC(C2f):
    def __init__(self, c1, c2, n=1, kernel_size=13, shortcut=False, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        self.m = nn.ModuleList(Bottleneck_SWC(self.c, self.c, kernel_size, shortcut, g, k=(3, 3), e=1.0) for _ in range(n))

2.2  注册ultralytics/nn/tasks.py

1)C2f_SWC进行定义

from ultralytics.nn.conv.shift_wiseConv import C2f_SWC

2)修改 def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)

        if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
                 BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3,C2f_SWC):
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)

            args = [c1, c2, *args[1:]]
            if m in (BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x, RepC3,C2f_SWC):
                args.insert(2, n)  # number of repeats
                n = 1

2.3 yolov8_C2f_shift_wiseConv.yaml

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f_SWC, [128, 11, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f_SWC, [256, 9, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f_SWC, [512, 7, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f_SWC, [1024, 7, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[15, 18, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)