RT-DETR改进策略【Neck】| ACMMM 2024 WFU:小波特征上采样 | 通过小波变换的频率分解与跨尺度融合机制,解决传统上采样过程中的混叠和细节丢失问题
一、本文介绍
本文记录的是 利用WFU对RT-DETR的颈部网络进行改进的方法研究 。
YOLOv10
采用传统的
最近邻插值
的方法进行上采样无法有效地分离和融合不同频率的特征分量,导致高频细节模糊或低频结构失真,从而影响模型在多尺度目标检测中的精度。
WFU
通过
小波变换的多尺度分解与动态频率融合
的方式进行上采样,
先将输入特征分解为低频结构和高频细节分量,再分别通过残差块增强高频信息、跨尺度串联优化低频结构,最终通过逆小波变换实现特征重构,能够更精准地保留边缘纹理细节并强化语义结构的连贯性。
二、WFU介绍
Efficient Face Super-Resolution via Wavelet-based Feature Enhancement Network
2.1 设计出发点
在传统的编码器-解码器结构中,解码器通常需要通过上采样将不同尺度的特征图对齐后进行融合。然而,直接融合可能导致高频和低频特征的混叠,影响面部细节的重建质量。现有的方法(如残差 concatenation)虽然能传递信息,但未充分考虑不同频率特征的特性,导致细节恢复不清晰。
为解决这一问题,WFU模块利用小波变换的多尺度分析能力,将不同尺度的特征分解为高频和低频分量,分别进行处理和融合,以避免混叠并增强细节。
2.2 结构原理
2.2.1 特征分解与对齐
对于来自编码器的较大尺度特征 F s F_s F s (如 R H 4 × W 4 × 4 C \mathbb{R}^{\frac{H}{4} \times \frac{W}{4} \times 4C} R 4 H × 4 W × 4 C )和来自解码器的较小尺度特征 F s + 1 F_{s+1} F s + 1 (如 R H 8 × W 8 × 4 C \mathbb{R}^{\frac{H}{8} \times \frac{W}{8} \times 4C} R 8 H × 8 W × 4 C ),首先对 $ F_s $ 应用小波变换(WT),分解为四个子带:
- 低频分量 A L L s A_{LL}^s A LL s (捕获整体结构)
-
高频分量
H
L
R
s
H_{LR}^s
H
L
R
s
、
V
R
L
s
V_{RL}^s
V
R
L
s
、
D
R
R
s
D_{RR}^s
D
RR
s
(捕获边缘、纹理等细节)
分解后,所有子带的尺度与 F s + 1 F_{s+1} F s + 1 一致( R H 8 × W 8 × 4 C \mathbb{R}^{\frac{H}{8} \times \frac{W}{8} \times 4C} R 8 H × 8 W × 4 C ),便于跨尺度融合。
2.2.2 频率分量处理
- 低频融合 :假设 F s + 1 F_{s+1} F s + 1 主要包含低频信息,将其与 A L L s A_{LL}^s A LL s 串联,作为增强后的低频子带,强化整体结构的连贯性。
- 高频增强 :对三个高频分量 H L R s H_{LR}^s H L R s 、 V R L s V_{RL}^s V R L s 、 D R R s D_{RR}^s D RR s ,通过残差块进一步提取细节特征,抑制噪声并增强边缘响应。
2.2.3 逆小波变换与输出
将处理后的低频和高频分量通过逆小波变换(IWT)重构,生成上采样后的特征
F
s
′
F_s'
F
s
′
:
F
s
′
=
IWT
(
Concat
(
A
L
L
s
,
F
s
+
1
)
,
R
(
H
L
R
s
,
V
R
L
s
,
D
R
R
s
)
)
F_s' = \text{IWT}\left( \text{Concat}(A_{LL}^s, F_{s+1}), \mathcal{R}(H_{LR}^s, V_{RL}^s, D_{RR}^s) \right)
F
s
′
=
IWT
(
Concat
(
A
LL
s
,
F
s
+
1
)
,
R
(
H
L
R
s
,
V
R
L
s
,
D
RR
s
)
)
其中,
R
\mathcal{R}
R
表示残差块操作。通过这种方式,WFU模块实现了跨尺度的频率特征分离与融合,避免了直接上采样导致的混叠问题。
2.3 优势
-
高效的跨尺度特征融合 :传统方法通过插值或卷积直接上采样,容易丢失高频细节或引入伪影。WFU利用小波变换的多分辨率特性,将不同尺度的特征分解为频率分量,分别处理后再融合,确保低频结构和高频细节的准确传递。
-
抑制频率混叠与细节增强 :小波变换的无损分解特性避免了下采样和上采样过程中的信息丢失。高频分量的独立处理(如残差块增强)有效保留了面部边缘和纹理(如眼睛睫毛、皮肤纹路),提升了重建图像的真实性。
总结
WFU模块通过小波变换的频率分解与跨尺度融合机制,解决了传统上采样过程中的混叠和细节丢失问题,实现了高效、高保真的面部细节重建。其轻量化设计和强泛化能力使其成为提升人脸超分辨率模型性能的关键组件,尤其在平衡计算效率与重建质量方面表现突出。
论文: https://arxiv.org/pdf/2407.19768
源码: https://github.com/PRIS-CV/WFEN
三、WFU的实现代码
WFU模块
的实现代码如下:
import torch
from torch import nn
import torch.nn.functional as F
def autopad(k, p=None, d=1):
"""
Pads kernel to 'same' output shape, adjusting for optional dilation; returns padding size.
`k`: kernel, `p`: padding, `d`: dilation.
"""
if d > 1:
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
return p
class Conv(nn.Module):
# Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
"""Initializes a standard convolution layer with optional batch normalization and activation."""
super().__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
self.bn = nn.BatchNorm2d(c2)
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
def forward(self, x):
"""Applies a convolution followed by batch normalization and an activation function to the input tensor `x`."""
return self.act(self.bn(self.conv(x)))
def forward_fuse(self, x):
"""Applies a fused convolution and activation function to the input tensor `x`."""
return self.act(self.conv(x))
class HaarWavelet(nn.Module):
def __init__(self, in_channels, grad=False):
super(HaarWavelet, self).__init__()
self.in_channels = in_channels
self.haar_weights = torch.ones(4, 1, 2, 2)
#h
self.haar_weights[1, 0, 0, 1] = -1
self.haar_weights[1, 0, 1, 1] = -1
#v
self.haar_weights[2, 0, 1, 0] = -1
self.haar_weights[2, 0, 1, 1] = -1
#d
self.haar_weights[3, 0, 1, 0] = -1
self.haar_weights[3, 0, 0, 1] = -1
self.haar_weights = torch.cat([self.haar_weights] * self.in_channels, 0)
self.haar_weights = nn.Parameter(self.haar_weights)
self.haar_weights.requires_grad = grad
def forward(self, x, rev=False):
if not rev:
out = F.conv2d(x, self.haar_weights, bias=None, stride=2, groups=self.in_channels) / 4.0
out = out.reshape([x.shape[0], self.in_channels, 4, x.shape[2] // 2, x.shape[3] // 2])
out = torch.transpose(out, 1, 2)
out = out.reshape([x.shape[0], self.in_channels * 4, x.shape[2] // 2, x.shape[3] // 2])
return out
else:
out = x.reshape([x.shape[0], 4, self.in_channels, x.shape[2], x.shape[3]])
out = torch.transpose(out, 1, 2)
out = out.reshape([x.shape[0], self.in_channels * 4, x.shape[2], x.shape[3]])
return F.conv_transpose2d(out, self.haar_weights, bias=None, stride=2, groups = self.in_channels)
class WFU(nn.Module):
def __init__(self, chn):
super(WFU, self).__init__()
dim_big, dim_small = chn
self.dim = dim_big
self.HaarWavelet = HaarWavelet(dim_big, grad=False)
self.InverseHaarWavelet = HaarWavelet(dim_big, grad=False)
self.RB = nn.Sequential(
# nn.Conv2d(dim_big, dim_big, kernel_size=3, padding=1),
# nn.ReLU(),
Conv(dim_big, dim_big, 3),
nn.Conv2d(dim_big, dim_big, kernel_size=3, padding=1),
)
self.channel_tranformation = nn.Sequential(
# nn.Conv2d(dim_big+dim_small, dim_big+dim_small // 1, kernel_size=1, padding=0),
# nn.ReLU(),
Conv(dim_big+dim_small, dim_big+dim_small // 1, 1),
nn.Conv2d(dim_big+dim_small // 1, dim_big*3, kernel_size=1, padding=0),
)
def forward(self, x):
x_big, x_small = x
haar = self.HaarWavelet(x_big, rev=False)
a = haar.narrow(1, 0, self.dim)
h = haar.narrow(1, self.dim, self.dim)
v = haar.narrow(1, self.dim*2, self.dim)
d = haar.narrow(1, self.dim*3, self.dim)
hvd = self.RB(h + v + d)
a_ = self.channel_tranformation(torch.cat([x_small, a], dim=1))
out = self.InverseHaarWavelet(torch.cat([hvd, a_], dim=1), rev=True)
return out
四、添加步骤
4.1 修改一
① 在
ultralytics/nn/
目录下新建
AddModules
文件夹用于存放模块代码
② 在
AddModules
文件夹下新建
WFU.py
,将
第三节
中的代码粘贴到此处
4.2 修改二
在
AddModules
文件夹下新建
__init__.py
(已有则不用新建),在文件内导入模块:
from .WFU import *
4.3 修改三
在
ultralytics/nn/modules/tasks.py
文件中,需要在两处位置添加各模块类名称。
首先:导入模块
然后,在
parse_model函数
中添加如下代码:
elif m in {WFU}:
c1 = [ch[x] for x in f]
c2 = c1[0]
args = [c1]
五、yaml模型文件
5.1 模型改进版本⭐
此处以
ultralytics/cfg/models/rt-detr/rtdetr-l.yaml
为例,在同目录下创建一个用于自己数据集训练的模型文件
rtdetr-l-WFU.yaml
。
将
rtdetr-l.yaml
中的内容复制到
rtdetr-l-WFU.yaml
文件下,修改
nc
数量等于自己数据中目标的数量。
📌 模型的修改方法是将
颈部网络
中的
上采样
替换成
WFU模块
。
# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
# [depth, width, max_channels]
l: [1.00, 1.00, 1024]
backbone:
# [from, repeats, module, args]
- [-1, 1, HGStem, [32, 48]] # 0-P2/4
- [-1, 6, HGBlock, [48, 128, 3]] # stage 1
- [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
- [-1, 6, HGBlock, [128, 256, 3]] # stage 2
- [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P4/16
- [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut
- [-1, 6, HGBlock, [192, 1024, 5, True, True]]
- [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3
- [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P5/32
- [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4
head:
- [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2
- [-1, 1, AIFI, [1024, 8]] # 11
- [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0
- [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 13 input_proj.1
- [[-1, -2], 1, WFU, []] # 14
- [-1, 3, RepC3, [256, 0.5]] # 15, fpn_blocks.0
- [-1, 1, Conv, [256, 1, 1]] # 16, Y4, lateral_convs.1
- [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 17 input_proj.0
- [[-1, -2], 1, WFU, []] # 18 cat backbone P4
- [-1, 3, RepC3, [256, 0.5]] # X19 (17), fpn_blocks.1
- [-1, 1, Conv, [256, 3, 2]] # 20, downsample_convs.0
- [[-1, 16], 1, Concat, [1]] # 21 cat Y4
- [-1, 3, RepC3, [256, 0.5]] # F4 (22), pan_blocks.0
- [-1, 1, Conv, [256, 3, 2]] # 23, downsample_convs.1
- [[-1, 12], 1, Concat, [1]] # 24 cat Y5
- [-1, 3, RepC3, [256, 0.5]] # F5 (25), pan_blocks.1
- [[19, 22, 25], 1, RTDETRDecoder, [nc, 256, 300, 4, 8, 3]] # Detect(P3, P4, P5)
六、成功运行结果
打印网络模型可以看到
WFU
已经加入到模型中,并可以进行训练了。
rtdetr-l-WFU :
rtdetr-l-WFU summary: 637 layers, 26,814,420 parameters, 26,798,036 gradients, 79.4 GFLOPs
from n params module arguments
0 -1 1 25248 ultralytics.nn.modules.block.HGStem [3, 32, 48]
1 -1 6 155072 ultralytics.nn.modules.block.HGBlock [48, 48, 128, 3, 6]
2 -1 1 1408 ultralytics.nn.modules.conv.DWConv [128, 128, 3, 2, 1, False]
3 -1 6 1034496 ultralytics.nn.modules.block.HGBlock [128, 128, 256, 3, 6]
4 -1 1 5632 ultralytics.nn.modules.conv.DWConv [256, 512, 3, 2, 1, False]
5 -1 6 1695360 ultralytics.nn.modules.block.HGBlock [512, 192, 1024, 5, 6, True, False]
6 -1 6 2055808 ultralytics.nn.modules.block.HGBlock [1024, 192, 1024, 5, 6, True, True]
7 -1 6 2055808 ultralytics.nn.modules.block.HGBlock [1024, 192, 1024, 5, 6, True, True]
8 -1 1 11264 ultralytics.nn.modules.conv.DWConv [1024, 1024, 3, 2, 1, False]
9 -1 6 6708480 ultralytics.nn.modules.block.HGBlock [1024, 384, 2048, 5, 6, True, False]
10 -1 1 524800 ultralytics.nn.modules.conv.Conv [2048, 256, 1, 1, None, 1, 1, False]
11 -1 1 789760 ultralytics.nn.modules.transformer.AIFI [256, 1024, 8]
12 -1 1 66048 ultralytics.nn.modules.conv.Conv [256, 256, 1, 1]
13 7 1 262656 ultralytics.nn.modules.conv.Conv [1024, 256, 1, 1, None, 1, 1, False]
14 [-1, -2] 1 1845760 ultralytics.nn.AddModules.WFU.WFU [[256, 256]]
15 -1 3 592384 ultralytics.nn.modules.block.RepC3 [256, 256, 3, 0.5]
16 -1 1 66048 ultralytics.nn.modules.conv.Conv [256, 256, 1, 1]
17 3 1 66048 ultralytics.nn.modules.conv.Conv [256, 256, 1, 1, None, 1, 1, False]
18 [-1, -2] 1 1845760 ultralytics.nn.AddModules.WFU.WFU [[256, 256]]
19 -1 3 592384 ultralytics.nn.modules.block.RepC3 [256, 256, 3, 0.5]
20 -1 1 590336 ultralytics.nn.modules.conv.Conv [256, 256, 3, 2]
21 [-1, 16] 1 0 ultralytics.nn.modules.conv.Concat [1]
22 -1 3 657920 ultralytics.nn.modules.block.RepC3 [512, 256, 3, 0.5]
23 -1 1 590336 ultralytics.nn.modules.conv.Conv [256, 256, 3, 2]
24 [-1, 12] 1 0 ultralytics.nn.modules.conv.Concat [1]
25 -1 3 657920 ultralytics.nn.modules.block.RepC3 [512, 256, 3, 0.5]
26 [19, 22, 25] 1 3917684 ultralytics.nn.modules.head.RTDETRDecoder [1, [256, 256, 256], 256, 300, 4, 8, 3]
rtdetr-l-WFU summary: 637 layers, 26,814,420 parameters, 26,798,036 gradients, 79.4 GFLOPs