学习资源站

RT-DETR改进策略【Neck】 ACMMM 2024 WFU:小波特征上采样 通过小波变换的频率分解与跨尺度融合机制,解决传统上采样过程中的混叠和细节丢失问题-

RT-DETR改进策略【Neck】| ACMMM 2024 WFU:小波特征上采样 | 通过小波变换的频率分解与跨尺度融合机制,解决传统上采样过程中的混叠和细节丢失问题

一、本文介绍

本文记录的是 利用WFU对RT-DETR的颈部网络进行改进的方法研究

YOLOv10 采用传统的 最近邻插值 的方法进行上采样无法有效地分离和融合不同频率的特征分量,导致高频细节模糊或低频结构失真,从而影响模型在多尺度目标检测中的精度。 WFU 通过 小波变换的多尺度分解与动态频率融合 的方式进行上采样, 先将输入特征分解为低频结构和高频细节分量,再分别通过残差块增强高频信息、跨尺度串联优化低频结构,最终通过逆小波变换实现特征重构,能够更精准地保留边缘纹理细节并强化语义结构的连贯性。



二、WFU介绍

Efficient Face Super-Resolution via Wavelet-based Feature Enhancement Network

2.1 设计出发点

在传统的编码器-解码器结构中,解码器通常需要通过上采样将不同尺度的特征图对齐后进行融合。然而,直接融合可能导致高频和低频特征的混叠,影响面部细节的重建质量。现有的方法(如残差 concatenation)虽然能传递信息,但未充分考虑不同频率特征的特性,导致细节恢复不清晰。

为解决这一问题,WFU模块利用小波变换的多尺度分析能力,将不同尺度的特征分解为高频和低频分量,分别进行处理和融合,以避免混叠并增强细节。

2.2 结构原理

在这里插入图片描述

2.2.1 特征分解与对齐

对于来自编码器的较大尺度特征 F s F_s F s (如 R H 4 × W 4 × 4 C \mathbb{R}^{\frac{H}{4} \times \frac{W}{4} \times 4C} R 4 H × 4 W × 4 C )和来自解码器的较小尺度特征 F s + 1 F_{s+1} F s + 1 (如 R H 8 × W 8 × 4 C \mathbb{R}^{\frac{H}{8} \times \frac{W}{8} \times 4C} R 8 H × 8 W × 4 C ),首先对 $ F_s $ 应用小波变换(WT),分解为四个子带:

  • 低频分量 A L L s A_{LL}^s A LL s (捕获整体结构)
  • 高频分量 H L R s H_{LR}^s H L R s V R L s V_{RL}^s V R L s D R R s D_{RR}^s D RR s (捕获边缘、纹理等细节)
    分解后,所有子带的尺度与 F s + 1 F_{s+1} F s + 1 一致( R H 8 × W 8 × 4 C \mathbb{R}^{\frac{H}{8} \times \frac{W}{8} \times 4C} R 8 H × 8 W × 4 C ),便于跨尺度融合。

2.2.2 频率分量处理

  • 低频融合 :假设 F s + 1 F_{s+1} F s + 1 主要包含低频信息,将其与 A L L s A_{LL}^s A LL s 串联,作为增强后的低频子带,强化整体结构的连贯性。
  • 高频增强 :对三个高频分量 H L R s H_{LR}^s H L R s V R L s V_{RL}^s V R L s D R R s D_{RR}^s D RR s ,通过残差块进一步提取细节特征,抑制噪声并增强边缘响应。

2.2.3 逆小波变换与输出

将处理后的低频和高频分量通过逆小波变换(IWT)重构,生成上采样后的特征 F s ′ F_s' F s
F s ′ = IWT ( Concat ( A L L s , F s + 1 ) , R ( H L R s , V R L s , D R R s ) ) F_s' = \text{IWT}\left( \text{Concat}(A_{LL}^s, F_{s+1}), \mathcal{R}(H_{LR}^s, V_{RL}^s, D_{RR}^s) \right) F s = IWT ( Concat ( A LL s , F s + 1 ) , R ( H L R s , V R L s , D RR s ) )
其中, R \mathcal{R} R 表示残差块操作。通过这种方式,WFU模块实现了跨尺度的频率特征分离与融合,避免了直接上采样导致的混叠问题。

2.3 优势

  1. 高效的跨尺度特征融合 :传统方法通过插值或卷积直接上采样,容易丢失高频细节或引入伪影。WFU利用小波变换的多分辨率特性,将不同尺度的特征分解为频率分量,分别处理后再融合,确保低频结构和高频细节的准确传递。

  2. 抑制频率混叠与细节增强 :小波变换的无损分解特性避免了下采样和上采样过程中的信息丢失。高频分量的独立处理(如残差块增强)有效保留了面部边缘和纹理(如眼睛睫毛、皮肤纹路),提升了重建图像的真实性。

总结

WFU模块通过小波变换的频率分解与跨尺度融合机制,解决了传统上采样过程中的混叠和细节丢失问题,实现了高效、高保真的面部细节重建。其轻量化设计和强泛化能力使其成为提升人脸超分辨率模型性能的关键组件,尤其在平衡计算效率与重建质量方面表现突出。

论文: https://arxiv.org/pdf/2407.19768
源码: https://github.com/PRIS-CV/WFEN

三、WFU的实现代码

WFU模块 的实现代码如下:

import torch
from torch import nn
import torch.nn.functional as F

def autopad(k, p=None, d=1):
    """
    Pads kernel to 'same' output shape, adjusting for optional dilation; returns padding size.
    `k`: kernel, `p`: padding, `d`: dilation.
    """
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p

class Conv(nn.Module):
    # Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
    default_act = nn.SiLU()  # default activation
 
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initializes a standard convolution layer with optional batch normalization and activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
 
    def forward(self, x):
        """Applies a convolution followed by batch normalization and an activation function to the input tensor `x`."""
        return self.act(self.bn(self.conv(x)))
 
    def forward_fuse(self, x):
        """Applies a fused convolution and activation function to the input tensor `x`."""
        return self.act(self.conv(x))

class HaarWavelet(nn.Module):
    def __init__(self, in_channels, grad=False):
        super(HaarWavelet, self).__init__()
        self.in_channels = in_channels

        self.haar_weights = torch.ones(4, 1, 2, 2)
        #h
        self.haar_weights[1, 0, 0, 1] = -1
        self.haar_weights[1, 0, 1, 1] = -1
        #v
        self.haar_weights[2, 0, 1, 0] = -1
        self.haar_weights[2, 0, 1, 1] = -1
        #d
        self.haar_weights[3, 0, 1, 0] = -1
        self.haar_weights[3, 0, 0, 1] = -1

        self.haar_weights = torch.cat([self.haar_weights] * self.in_channels, 0)
        self.haar_weights = nn.Parameter(self.haar_weights)
        self.haar_weights.requires_grad = grad

    def forward(self, x, rev=False):
        if not rev:
            out = F.conv2d(x, self.haar_weights, bias=None, stride=2, groups=self.in_channels) / 4.0
            out = out.reshape([x.shape[0], self.in_channels, 4, x.shape[2] // 2, x.shape[3] // 2])
            out = torch.transpose(out, 1, 2)
            out = out.reshape([x.shape[0], self.in_channels * 4, x.shape[2] // 2, x.shape[3] // 2])
            return out
        else:
            out = x.reshape([x.shape[0], 4, self.in_channels, x.shape[2], x.shape[3]])
            out = torch.transpose(out, 1, 2)
            out = out.reshape([x.shape[0], self.in_channels * 4, x.shape[2], x.shape[3]])
            return F.conv_transpose2d(out, self.haar_weights, bias=None, stride=2, groups = self.in_channels)

class WFU(nn.Module):
    def __init__(self, chn):
        super(WFU, self).__init__()
        dim_big, dim_small = chn
        self.dim = dim_big
        self.HaarWavelet = HaarWavelet(dim_big, grad=False)
        self.InverseHaarWavelet = HaarWavelet(dim_big, grad=False)
        self.RB = nn.Sequential(
            # nn.Conv2d(dim_big, dim_big, kernel_size=3, padding=1),
            # nn.ReLU(),
            Conv(dim_big, dim_big, 3),
            nn.Conv2d(dim_big, dim_big, kernel_size=3, padding=1),
        )

        self.channel_tranformation = nn.Sequential(
            # nn.Conv2d(dim_big+dim_small, dim_big+dim_small // 1, kernel_size=1, padding=0),
            # nn.ReLU(),
            Conv(dim_big+dim_small, dim_big+dim_small // 1, 1),
            nn.Conv2d(dim_big+dim_small // 1, dim_big*3, kernel_size=1, padding=0),
        )

    def forward(self, x):
        x_big, x_small = x
        haar = self.HaarWavelet(x_big, rev=False)
        a = haar.narrow(1, 0, self.dim)
        h = haar.narrow(1, self.dim, self.dim)
        v = haar.narrow(1, self.dim*2, self.dim) 
        d = haar.narrow(1, self.dim*3, self.dim)

        hvd = self.RB(h + v + d)
        a_ = self.channel_tranformation(torch.cat([x_small, a], dim=1))
        out = self.InverseHaarWavelet(torch.cat([hvd, a_], dim=1), rev=True)
        return out

四、添加步骤

4.1 修改一

① 在 ultralytics/nn/ 目录下新建 AddModules 文件夹用于存放模块代码

② 在 AddModules 文件夹下新建 WFU.py ,将 第三节 中的代码粘贴到此处

在这里插入图片描述

4.2 修改二

AddModules 文件夹下新建 __init__.py (已有则不用新建),在文件内导入模块: from .WFU import *

在这里插入图片描述

4.3 修改三

ultralytics/nn/modules/tasks.py 文件中,需要在两处位置添加各模块类名称。

首先:导入模块

在这里插入图片描述

然后,在 parse_model函数 中添加如下代码:

        elif m in {WFU}:
            c1 = [ch[x] for x in f]
            c2 = c1[0]
            args = [c1]

在这里插入图片描述


五、yaml模型文件

5.1 模型改进版本⭐

此处以 ultralytics/cfg/models/rt-detr/rtdetr-l.yaml 为例,在同目录下创建一个用于自己数据集训练的模型文件 rtdetr-l-WFU.yaml

rtdetr-l.yaml 中的内容复制到 rtdetr-l-WFU.yaml 文件下,修改 nc 数量等于自己数据中目标的数量。

📌 模型的修改方法是将 颈部网络 中的 上采样 替换成 WFU模块

# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr

# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, HGStem, [32, 48]] # 0-P2/4
  - [-1, 6, HGBlock, [48, 128, 3]] # stage 1

  - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
  - [-1, 6, HGBlock, [128, 256, 3]] # stage 2

  - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P4/16
  - [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut
  - [-1, 6, HGBlock, [192, 1024, 5, True, True]]
  - [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3

  - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P5/32
  - [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4

head:
  - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 10 input_proj.2
  - [-1, 1, AIFI, [1024, 8]] # 11
  - [-1, 1, Conv, [256, 1, 1]]  # 12, Y5, lateral_convs.0

  - [7, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 13 input_proj.1
  - [[-1, -2], 1, WFU, []] # 14
  - [-1, 3, RepC3, [256, 0.5]]  # 15, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]]   # 16, Y4, lateral_convs.1

  - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 17 input_proj.0
  - [[-1, -2], 1, WFU, []]  # 18 cat backbone P4
  - [-1, 3, RepC3, [256, 0.5]]    # X19 (17), fpn_blocks.1

  - [-1, 1, Conv, [256, 3, 2]]   # 20, downsample_convs.0
  - [[-1, 16], 1, Concat, [1]]  # 21 cat Y4
  - [-1, 3, RepC3, [256, 0.5]]    # F4 (22), pan_blocks.0

  - [-1, 1, Conv, [256, 3, 2]]   # 23, downsample_convs.1
  - [[-1, 12], 1, Concat, [1]]  # 24 cat Y5
  - [-1, 3, RepC3, [256, 0.5]]    # F5 (25), pan_blocks.1

  - [[19, 22, 25], 1, RTDETRDecoder, [nc, 256, 300, 4, 8, 3]]  # Detect(P3, P4, P5)


六、成功运行结果

打印网络模型可以看到 WFU 已经加入到模型中,并可以进行训练了。

rtdetr-l-WFU

rtdetr-l-WFU summary: 637 layers, 26,814,420 parameters, 26,798,036 gradients, 79.4 GFLOPs

                   from  n    params  module                                       arguments                     
  0                  -1  1     25248  ultralytics.nn.modules.block.HGStem          [3, 32, 48]                   
  1                  -1  6    155072  ultralytics.nn.modules.block.HGBlock         [48, 48, 128, 3, 6]           
  2                  -1  1      1408  ultralytics.nn.modules.conv.DWConv           [128, 128, 3, 2, 1, False]    
  3                  -1  6   1034496  ultralytics.nn.modules.block.HGBlock         [128, 128, 256, 3, 6]         
  4                  -1  1      5632  ultralytics.nn.modules.conv.DWConv           [256, 512, 3, 2, 1, False]    
  5                  -1  6   1695360  ultralytics.nn.modules.block.HGBlock         [512, 192, 1024, 5, 6, True, False]
  6                  -1  6   2055808  ultralytics.nn.modules.block.HGBlock         [1024, 192, 1024, 5, 6, True, True]
  7                  -1  6   2055808  ultralytics.nn.modules.block.HGBlock         [1024, 192, 1024, 5, 6, True, True]
  8                  -1  1     11264  ultralytics.nn.modules.conv.DWConv           [1024, 1024, 3, 2, 1, False]  
  9                  -1  6   6708480  ultralytics.nn.modules.block.HGBlock         [1024, 384, 2048, 5, 6, True, False]
 10                  -1  1    524800  ultralytics.nn.modules.conv.Conv             [2048, 256, 1, 1, None, 1, 1, False]
 11                  -1  1    789760  ultralytics.nn.modules.transformer.AIFI      [256, 1024, 8]                
 12                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]              
 13                   7  1    262656  ultralytics.nn.modules.conv.Conv             [1024, 256, 1, 1, None, 1, 1, False]
 14            [-1, -2]  1   1845760  ultralytics.nn.AddModules.WFU.WFU            [[256, 256]]                  
 15                  -1  3    592384  ultralytics.nn.modules.block.RepC3           [256, 256, 3, 0.5]            
 16                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]              
 17                   3  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1, None, 1, 1, False]
 18            [-1, -2]  1   1845760  ultralytics.nn.AddModules.WFU.WFU            [[256, 256]]                  
 19                  -1  3    592384  ultralytics.nn.modules.block.RepC3           [256, 256, 3, 0.5]            
 20                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 21            [-1, 16]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 22                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]            
 23                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 24            [-1, 12]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 25                  -1  3    657920  ultralytics.nn.modules.block.RepC3           [512, 256, 3, 0.5]            
 26        [19, 22, 25]  1   3917684  ultralytics.nn.modules.head.RTDETRDecoder    [1, [256, 256, 256], 256, 300, 4, 8, 3]
rtdetr-l-WFU summary: 637 layers, 26,814,420 parameters, 26,798,036 gradients, 79.4 GFLOPs