1. Introduction
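This article presents an original improvement: using Frequency-Adaptive Dilated Convolution (FADC) to upgrade the YOLOv11 detection head. The core idea of FADC is to adjust the dilation rate dynamically according to the local frequency content of the image. This lets the network adapt its receptive field to local changes in image content, which improves performance in detail-rich regions and in regions dense in high-frequency information. This modification is the author's own original work. The figure below shows the accuracy comparison.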
2. How It Works
Official paper: click here to open the official paper.
Official code: click here to open the official code.
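Frequency-Adaptive Dilated Convolution (FADC) aims to improve the performance of dilated convolution in semantic segmentation. The main ideas are summarized below: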
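- Dilated convolution overview: dilated convolution enlarges the receptive field by inserting gaps between consecutive kernel elements. It captures wider context without adding any parameters.
- Problem with fixed dilation rates: conventional methods fix the dilation rate as a global hyperparameter. A single fixed rate can be suboptimal across an image, because different regions contain different frequency components.
- Frequency-Adaptive Dilated Convolution (FADC): FADC adjusts the dilation rate dynamically according to the local frequency content of the image (the paper calls this Adaptive Dilation Rate, AdaDR). The network can then adapt its receptive field to local changes in image content, which improves performance in detail-rich or high-frequency regions.
- Two plug-in modules:
  - Adaptive Kernel (AdaKern): decomposes the convolution weights into a low-frequency part and a high-frequency part (in the code below, the per-kernel mean and the residual), which gives finer control over bandwidth and receptive-field size.
  - Frequency Selection (FreqSelect, the `FrequencySelection` class in the code below): splits features into frequency bands and reweights them spatially, further improving the model's adaptation to different frequency components.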
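Together, these methods aim to improve semantic segmentation by letting convolutional networks capture local variations in image features more effectively.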
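The dilated-convolution overview above can be checked with plain PyTorch. The sketch below shows that a 3x3 kernel with dilation d covers k + (k - 1)(d - 1) pixels per side while its parameter count does not change. It only illustrates ordinary, globally fixed dilation (the baseline that FADC improves on), not FADC itself. The helper `effective_kernel_size` is a hypothetical name used for illustration.

```python
import torch
import torch.nn as nn


def effective_kernel_size(k: int, d: int) -> int:
    """Span in pixels covered by a k x k kernel with dilation d."""
    return k + (k - 1) * (d - 1)


x = torch.randn(1, 8, 32, 32)
for d in (1, 2, 3):
    # padding = d * (k // 2) keeps the output the same size as the input for odd k
    conv = nn.Conv2d(8, 8, kernel_size=3, dilation=d, padding=d, bias=False)
    n_params = sum(p.numel() for p in conv.parameters())
    print(d, effective_kernel_size(3, d), n_params, tuple(conv(x).shape))
# 1 3 576 (1, 8, 32, 32)
# 2 5 576 (1, 8, 32, 32)
# 3 7 576 (1, 8, 32, 32)
```

Because the dilation here is one global hyperparameter, every pixel gets the same receptive field. This is the limitation described in the second point above.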
3. Core Code
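The YOLOv11 detection head uses DWConv. The paper also proposes a frequency-adaptive depthwise convolution (`AdaptiveDilatedDWConv` below) that can replace it directly. See Section 4 for usage. The code depends on mmcv (`ModulatedDeformConv2d`), torch-dct and SciPy.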
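The AdaKern idea from Section 2 appears in the code below as `adaptive_weight_mean` (low-frequency part) and `adaptive_weight - adaptive_weight_mean` (high-frequency part). Each part is rescaled by `2 * sigmoid(...)` attention. The helper `adakern_reweight` is a hypothetical, simplified sketch of that step. It passes the scales in directly and leaves out OmniAttention, the separate channel/filter attentions (`c_att`, `f_att`), the DCT option and the deformable convolution.

```python
import torch


def adakern_reweight(weight, low_scale, high_scale):
    """Hypothetical AdaKern-style reweighting (simplified sketch).

    weight:     (C_out, C_in, k, k) static kernel
    low_scale:  broadcastable to (B, C_out, C_in, 1, 1); scales the low-frequency (mean) part
    high_scale: broadcastable to (B, C_out, C_in, 1, 1); scales the high-frequency (residual) part
    returns:    (B, C_out, C_in, k, k) per-sample kernels
    """
    w = weight.unsqueeze(0)                      # (1, C_out, C_in, k, k)
    w_low = w.mean(dim=(-1, -2), keepdim=True)   # DC component of each k x k kernel
    w_high = w - w_low                           # zero-mean residual
    return w_low * low_scale + w_high * high_scale


B, C_out, C_in, k = 2, 4, 3, 3
weight = torch.randn(C_out, C_in, k, k)

# Scales of the form 2 * sigmoid(logit); a logit of 0 gives a scale of exactly 1
unit = 2 * torch.sigmoid(torch.zeros(B, C_out, 1, 1, 1))
same = adakern_reweight(weight, unit, unit)
print(same.shape)                                  # torch.Size([2, 4, 3, 3, 3])
print(torch.allclose(same[1], weight, atol=1e-6))  # True

# Doubling only the high-frequency scale sharpens each kernel but keeps its mean unchanged
sharp = adakern_reweight(weight, unit, 2 * unit)
print(torch.allclose(sharp.mean(dim=(-1, -2)), same.mean(dim=(-1, -2)), atol=1e-6))  # True
```

Splitting on the kernel mean is a cheap way to separate the "average" (low-pass) behaviour of a kernel from its spatial detail. As a result, the two attentions can widen or narrow the effective bandwidth independently.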
```python
import copy
import math

import numpy as np
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F
import torch_dct as dct
from mmcv.ops.modulated_deform_conv import ModulatedDeformConv2d, modulated_deform_conv2d
from scipy.spatial import distance

from ultralytics.utils.tal import dist2bbox, make_anchors


class OmniAttention(nn.Module):
    """Channel / filter / spatial / kernel attention used to modulate convolution weights."""

    def __init__(self, in_planes, out_planes, kernel_size, groups=1, reduction=0.0625, kernel_num=4, min_channel=16):
        super().__init__()
        attention_channel = max(int(in_planes * reduction), min_channel)
        self.kernel_size = kernel_size
        self.kernel_num = kernel_num
        self.temperature = 1.0

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(in_planes, attention_channel, 1, bias=False)
        self.bn = nn.BatchNorm2d(attention_channel)
        self.relu = nn.ReLU(inplace=True)

        self.channel_fc = nn.Conv2d(attention_channel, in_planes, 1, bias=True)
        self.func_channel = self.get_channel_attention

        if in_planes == groups and in_planes == out_planes:  # depth-wise convolution
            self.func_filter = self.skip
        else:
            self.filter_fc = nn.Conv2d(attention_channel, out_planes, 1, bias=True)
            self.func_filter = self.get_filter_attention

        if kernel_size == 1:  # point-wise convolution
            self.func_spatial = self.skip
        else:
            self.spatial_fc = nn.Conv2d(attention_channel, kernel_size * kernel_size, 1, bias=True)
            self.func_spatial = self.get_spatial_attention

        if kernel_num == 1:
            self.func_kernel = self.skip
        else:
            self.kernel_fc = nn.Conv2d(attention_channel, kernel_num, 1, bias=True)
            self.func_kernel = self.get_kernel_attention

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def update_temperature(self, temperature):
        self.temperature = temperature

    @staticmethod
    def skip(_):
        return 1.0

    def get_channel_attention(self, x):
        return torch.sigmoid(self.channel_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)

    def get_filter_attention(self, x):
        return torch.sigmoid(self.filter_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)

    def get_spatial_attention(self, x):
        spatial_attention = self.spatial_fc(x).view(x.size(0), 1, 1, 1, self.kernel_size, self.kernel_size)
        return torch.sigmoid(spatial_attention / self.temperature)

    def get_kernel_attention(self, x):
        kernel_attention = self.kernel_fc(x).view(x.size(0), -1, 1, 1, 1, 1)
        return F.softmax(kernel_attention / self.temperature, dim=1)

    def forward(self, x):
        x = self.avgpool(x)
        x = self.fc(x)
        # BatchNorm is skipped for 1x1 maps (always the case after global pooling),
        # which avoids BatchNorm errors when training with batch size 1
        if x.shape[3] != 1:
            x = self.bn(x)
        x = self.relu(x)
        return self.func_channel(x), self.func_filter(x), self.func_spatial(x), self.func_kernel(x)


def generate_laplacian_pyramid(input_tensor, num_levels, size_align=True, mode='bilinear'):
    """Build a Laplacian pyramid: num_levels band-pass maps followed by the final low-pass map.

    With size_align=True every level is resized back to the input resolution (H, W).
    """
    pyramid = []
    current_tensor = input_tensor
    _, _, H, W = current_tensor.shape
    for _ in range(num_levels):
        b, _, h, w = current_tensor.shape
        downsampled_tensor = F.interpolate(current_tensor, (h // 2 + h % 2, w // 2 + w % 2), mode=mode,
                                           align_corners=(H % 2) == 1)
        if size_align:
            upsampled_tensor = F.interpolate(downsampled_tensor, (H, W), mode=mode, align_corners=(H % 2) == 1)
            laplacian = F.interpolate(current_tensor, (H, W), mode=mode,
                                      align_corners=(H % 2) == 1) - upsampled_tensor
        else:
            upsampled_tensor = F.interpolate(downsampled_tensor, (h, w), mode=mode, align_corners=(H % 2) == 1)
            laplacian = current_tensor - upsampled_tensor
        pyramid.append(laplacian)
        current_tensor = downsampled_tensor
    if size_align:
        current_tensor = F.interpolate(current_tensor, (H, W), mode=mode, align_corners=(H % 2) == 1)
    pyramid.append(current_tensor)
    return pyramid
```
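`generate_laplacian_pyramid` returns `num_levels` band-pass maps plus one final low-pass map. With `size_align=True`, all levels share the input resolution, and the levels telescope, so their sum reconstructs the input. The sketch below is a simplified re-implementation under stated assumptions: `align_corners` is fixed to `False`, and the name `laplacian_pyramid` is hypothetical. It checks that reconstruction property.

```python
import torch
import torch.nn.functional as F


def laplacian_pyramid(x, num_levels):
    """Simplified sketch of generate_laplacian_pyramid(size_align=True) with align_corners=False."""
    H, W = x.shape[-2:]

    def up(t):
        # Resize back to the input resolution so all levels can be summed directly
        return F.interpolate(t, (H, W), mode='bilinear', align_corners=False)

    bands, cur = [], x
    for _ in range(num_levels):
        h, w = cur.shape[-2:]
        down = F.interpolate(cur, ((h + 1) // 2, (w + 1) // 2), mode='bilinear', align_corners=False)
        bands.append(up(cur) - up(down))  # detail removed by this downsampling step
        cur = down
    bands.append(up(cur))                 # coarsest low-pass level
    return bands


x = torch.randn(1, 3, 33, 40)
bands = laplacian_pyramid(x, num_levels=3)
print(len(bands))                                # 4
print(torch.allclose(sum(bands), x, atol=1e-5))  # True: the levels telescope back to x
```

This is why `FrequencySelection` (lp_type='laplacian') can reweight each level independently. With all weights equal to 1, the output is the input unchanged.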
```python
class FrequencySelection(nn.Module):
    """FreqSelect: split features into frequency bands and reweight each band with a learned spatial map."""

    def __init__(self,
                 in_channels,
                 k_list=(2,),
                 lowfreq_att=True,
                 fs_feat='feat',
                 lp_type='freq',
                 act='sigmoid',
                 spatial='conv',
                 spatial_group=1,
                 spatial_kernel=3,
                 init='zero',
                 global_selection=False):
        super().__init__()
        self.k_list = k_list
        self.lp_list = nn.ModuleList()
        self.freq_weight_conv_list = nn.ModuleList()
        self.fs_feat = fs_feat
        self.lp_type = lp_type
        self.in_channels = in_channels
        if spatial_group > 64:
            spatial_group = in_channels
        self.spatial_group = spatial_group
        self.lowfreq_att = lowfreq_att

        if spatial == 'conv':
            # One spatial weight map per frequency band (+1 for the low-frequency remainder if lowfreq_att)
            _n = len(k_list) + (1 if lowfreq_att else 0)
            for _ in range(_n):
                freq_weight_conv = nn.Conv2d(in_channels=in_channels,
                                             out_channels=self.spatial_group,
                                             stride=1,
                                             kernel_size=spatial_kernel,
                                             groups=self.spatial_group,
                                             padding=spatial_kernel // 2,
                                             bias=True)
                if init == 'zero':
                    # sp_act(0) == 1, so every band initially passes through unchanged
                    freq_weight_conv.weight.data.zero_()
                    freq_weight_conv.bias.data.zero_()
                self.freq_weight_conv_list.append(freq_weight_conv)
        else:
            raise NotImplementedError

        if self.lp_type == 'avgpool':
            for k in k_list:
                # Odd k keeps the spatial size unchanged
                self.lp_list.append(nn.Sequential(
                    nn.ReplicationPad2d(padding=k // 2),
                    nn.AvgPool2d(kernel_size=k, padding=0, stride=1)
                ))
        elif self.lp_type not in ('laplacian', 'freq'):
            raise NotImplementedError

        self.act = act
        self.global_selection = global_selection
        if self.global_selection:
            self.global_selection_conv_real = nn.Conv2d(in_channels, self.spatial_group, kernel_size=1, stride=1,
                                                        padding=0, groups=self.spatial_group, bias=True)
            self.global_selection_conv_imag = nn.Conv2d(in_channels, self.spatial_group, kernel_size=1, stride=1,
                                                        padding=0, groups=self.spatial_group, bias=True)
            if init == 'zero':
                self.global_selection_conv_real.weight.data.zero_()
                self.global_selection_conv_real.bias.data.zero_()
                self.global_selection_conv_imag.weight.data.zero_()
                self.global_selection_conv_imag.bias.data.zero_()

    def sp_act(self, freq_weight):
        if self.act == 'sigmoid':
            freq_weight = freq_weight.sigmoid() * 2
        elif self.act == 'softmax':
            freq_weight = freq_weight.softmax(dim=1) * freq_weight.shape[1]
        else:
            raise NotImplementedError
        return freq_weight

    def forward(self, x, att_feat=None):
        """
        x:        input features
        att_feat: features used to generate the attention maps (defaults to x)
        """
        if att_feat is None:
            att_feat = x
        x_list = []
        if self.lp_type == 'avgpool':
            pre_x = x
            b, _, h, w = x.shape
            for idx, avg in enumerate(self.lp_list):
                low_part = avg(x)
                high_part = pre_x - low_part
                pre_x = low_part
                freq_weight = self.sp_act(self.freq_weight_conv_list[idx](att_feat))
                tmp = (freq_weight.reshape(b, self.spatial_group, -1, h, w)
                       * high_part.reshape(b, self.spatial_group, -1, h, w))
                x_list.append(tmp.reshape(b, -1, h, w))
            if self.lowfreq_att:
                freq_weight = self.freq_weight_conv_list[len(x_list)](att_feat)
                tmp = (freq_weight.reshape(b, self.spatial_group, -1, h, w)
                       * pre_x.reshape(b, self.spatial_group, -1, h, w))
                x_list.append(tmp.reshape(b, -1, h, w))
            else:
                x_list.append(pre_x)
        elif self.lp_type == 'laplacian':
            b, _, h, w = x.shape
            pyramids = generate_laplacian_pyramid(x, len(self.k_list), size_align=True)
            for idx in range(len(self.k_list)):
                high_part = pyramids[idx]
                freq_weight = self.sp_act(self.freq_weight_conv_list[idx](att_feat))
                tmp = (freq_weight.reshape(b, self.spatial_group, -1, h, w)
                       * high_part.reshape(b, self.spatial_group, -1, h, w))
                x_list.append(tmp.reshape(b, -1, h, w))
            if self.lowfreq_att:
                freq_weight = self.freq_weight_conv_list[len(x_list)](att_feat)
                tmp = (freq_weight.reshape(b, self.spatial_group, -1, h, w)
                       * pyramids[-1].reshape(b, self.spatial_group, -1, h, w))
                x_list.append(tmp.reshape(b, -1, h, w))
            else:
                x_list.append(pyramids[-1])
        elif self.lp_type == 'freq':
            pre_x = x.clone()
            b, _, h, w = x.shape
            x_fft = torch.fft.fftshift(torch.fft.fft2(x, norm='ortho'))
            if self.global_selection:
                # Split the complex spectrum into its real and imaginary parts
                x_real = x_fft.real
                x_imag = x_fft.imag
                # Global attention for the real and imaginary parts
                global_att_real = self.sp_act(self.global_selection_conv_real(x_real))
                global_att_real = global_att_real.reshape(b, self.spatial_group, -1, h, w)
                global_att_imag = self.sp_act(self.global_selection_conv_imag(x_imag))
                global_att_imag = global_att_imag.reshape(b, self.spatial_group, -1, h, w)
                # Apply the attention in (b, spatial_group, -1, h, w) layout, then recombine into a complex tensor
                x_real = x_real.reshape(b, self.spatial_group, -1, h, w) * global_att_real
                x_imag = x_imag.reshape(b, self.spatial_group, -1, h, w) * global_att_imag
                x_fft = torch.complex(x_real, x_imag).reshape(b, -1, h, w)
            for idx, freq in enumerate(self.k_list):
                # Centred box low-pass mask keeping about 1/freq of the spectrum along each axis
                mask = torch.zeros_like(x[:, 0:1, :, :])
                mask[:, :, round(h / 2 - h / (2 * freq)):round(h / 2 + h / (2 * freq)),
                     round(w / 2 - w / (2 * freq)):round(w / 2 + w / (2 * freq))] = 1.0
                low_part = torch.fft.ifft2(torch.fft.ifftshift(x_fft * mask), norm='ortho').real
                high_part = pre_x - low_part
                pre_x = low_part
                freq_weight = self.sp_act(self.freq_weight_conv_list[idx](att_feat))
                tmp = (freq_weight.reshape(b, self.spatial_group, -1, h, w)
                       * high_part.reshape(b, self.spatial_group, -1, h, w))
                x_list.append(tmp.reshape(b, -1, h, w))
            if self.lowfreq_att:
                freq_weight = self.freq_weight_conv_list[len(x_list)](att_feat)
                tmp = (freq_weight.reshape(b, self.spatial_group, -1, h, w)
                       * pre_x.reshape(b, self.spatial_group, -1, h, w))
                x_list.append(tmp.reshape(b, -1, h, w))
            else:
                x_list.append(pre_x)
        return sum(x_list)
```
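The `lp_type='freq'` branch of `FrequencySelection` builds nested centred low-pass masks in the shifted FFT domain. Each band is the difference between successive low-pass outputs, and the last entry is the low-frequency remainder. The sketch below reproduces only that decomposition. The learned weights are omitted, which matches the zero-initialised state where every band has weight 1. `fft_band_split` is a hypothetical name, and the shifts are restricted to the spatial axes for clarity.

```python
import torch


def fft_band_split(x, k_list=(2, 4, 8)):
    """Split x into len(k_list) + 1 frequency bands, mirroring lp_type='freq' without learned weights."""
    _, _, h, w = x.shape
    x_fft = torch.fft.fftshift(torch.fft.fft2(x, norm='ortho'), dim=(-2, -1))  # DC moved to the centre
    bands, pre_x = [], x
    for freq in k_list:
        # Centred box low-pass mask covering about 1/freq of the spectrum along each axis
        mask = torch.zeros(1, 1, h, w)
        mask[..., round(h / 2 - h / (2 * freq)):round(h / 2 + h / (2 * freq)),
             round(w / 2 - w / (2 * freq)):round(w / 2 + w / (2 * freq))] = 1.0
        low = torch.fft.ifft2(torch.fft.ifftshift(x_fft * mask, dim=(-2, -1)), norm='ortho').real
        bands.append(pre_x - low)  # band between the previous cut-off and this one
        pre_x = low
    bands.append(pre_x)  # lowest-frequency remainder
    return bands


x = torch.randn(2, 4, 32, 32)
bands = fft_band_split(x)
print(len(bands))                                # 4
print(torch.allclose(sum(bands), x, atol=1e-5))  # True: the bands telescope back to x

# A constant image has only a DC component, so every high-frequency band is numerically zero
flat = torch.full((1, 1, 16, 16), 3.0)
print([round(b.abs().max().item(), 4) for b in fft_band_split(flat)[:-1]])  # [0.0, 0.0, 0.0]
```

In the real module, each high-frequency band is multiplied by its own `2 * sigmoid(conv(att_feat))` map before summation. This lets the network boost or suppress detail per pixel.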
```python
class AdaptiveDilatedConv(ModulatedDeformConv2d):
    """A ModulatedDeformable Conv Encapsulation that acts as normal Conv
    layers.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int): Same as nn.Conv2d, while tuple is not supported.
        padding (int): Same as nn.Conv2d, while tuple is not supported.
        dilation (int): Same as nn.Conv2d, while tuple is not supported.
        groups (int): Same as nn.Conv2d.
        bias (bool or str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
            False.
    """

    _version = 2

    def __init__(self, *args,
                 offset_freq=None,  # deprecated
                 padding_mode='repeat',
                 kernel_decompose='both',
                 conv_type='conv',
                 sp_att=False,
                 pre_fs=True,  # False: the predicted dilation map drives frequency selection
                 epsilon=1e-4,
                 use_zero_dilation=False,
                 use_dct=False,
                 fs_cfg={
                     'k_list': [2, 4, 8],
                     'fs_feat': 'feat',
                     'lowfreq_att': False,
                     'lp_type': 'freq',  # also implemented: 'avgpool', 'laplacian'
                     'act': 'sigmoid',
                     'spatial': 'conv',
                     'spatial_group': 1,
                 },
                 **kwargs):
        super().__init__(*args, **kwargs)
        if padding_mode == 'zero':
            self.PAD = nn.ZeroPad2d(self.kernel_size[0] // 2)
        elif padding_mode == 'repeat':
            self.PAD = nn.ReplicationPad2d(self.kernel_size[0] // 2)
        else:
            self.PAD = nn.Identity()

        self.kernel_decompose = kernel_decompose
        self.use_dct = use_dct
        # AdaKern: separate attentions for the low-frequency (mean) and high-frequency (residual) kernel parts
        if kernel_decompose == 'both':
            self.OMNI_ATT1 = OmniAttention(in_planes=self.in_channels, out_planes=self.out_channels, kernel_size=1,
                                           groups=1, reduction=0.0625, kernel_num=1, min_channel=16)
            self.OMNI_ATT2 = OmniAttention(in_planes=self.in_channels, out_planes=self.out_channels,
                                           kernel_size=self.kernel_size[0] if self.use_dct else 1, groups=1,
                                           reduction=0.0625, kernel_num=1, min_channel=16)
        elif kernel_decompose in ('high', 'low'):
            self.OMNI_ATT = OmniAttention(in_planes=self.in_channels, out_planes=self.out_channels, kernel_size=1,
                                          groups=1, reduction=0.0625, kernel_num=1, min_channel=16)

        conv_pad = self.kernel_size[0] // 2 if isinstance(self.PAD, nn.Identity) else 0
        self.conv_type = conv_type
        if conv_type == 'conv':
            # Predicts one dilation scale per pixel and deform group
            self.conv_offset = nn.Conv2d(self.in_channels, self.deform_groups * 1, kernel_size=self.kernel_size,
                                         stride=self.stride, padding=conv_pad, dilation=1, bias=True)
        else:
            raise NotImplementedError
        self.conv_mask = nn.Conv2d(self.in_channels,
                                   self.deform_groups * 1 * self.kernel_size[0] * self.kernel_size[1],
                                   kernel_size=self.kernel_size, stride=self.stride, padding=conv_pad,
                                   dilation=1, bias=True)
        if sp_att:
            self.conv_mask_mean_level = nn.Conv2d(self.in_channels, self.deform_groups * 1,
                                                  kernel_size=self.kernel_size, stride=self.stride,
                                                  padding=conv_pad, dilation=1, bias=True)
        self.offset_freq = offset_freq

        # Base 3x3 sampling grid as (y, x) pairs: [y0, x0, y1, x1, ..., y8, x8]
        offset = torch.Tensor([-1, -1, -1, 0, -1, 1,
                               0, -1, 0, 0, 0, 1,
                               1, -1, 1, 0, 1, 1])
        self.register_buffer('dilated_offset', offset[None, None, :, None, None])  # (1, 1, 18, 1, 1)

        if fs_cfg is not None:
            if pre_fs:
                self.FS = FrequencySelection(self.in_channels, **fs_cfg)
            else:
                self.FS = FrequencySelection(1, **fs_cfg)  # attention generated from the dilation map
        self.pre_fs = pre_fs
        self.epsilon = epsilon
        self.use_zero_dilation = use_zero_dilation
        self.init_weights()

    def freq_select(self, x):
        # The non-default branches belong to the deprecated offset_freq option and need self.LP,
        # which this class does not define
        if self.offset_freq is None:
            res = x
        elif self.offset_freq in ('FLC_high', 'SLP_high'):
            res = x - self.LP(x)
        elif self.offset_freq in ('FLC_res', 'SLP_res'):
            res = 2 * x - self.LP(x)
        else:
            raise NotImplementedError
        return res

    def init_weights(self):
        super().init_weights()
        if hasattr(self, 'conv_offset') and self.conv_type == 'conv':
            self.conv_offset.weight.data.zero_()
            # Initial effective dilation = 1 + bias * dilation ≈ self.dilation[0]
            self.conv_offset.bias.data.fill_((self.dilation[0] - 1) / self.dilation[0] + self.epsilon)
        if hasattr(self, 'conv_mask'):
            self.conv_mask.weight.data.zero_()
            self.conv_mask.bias.data.zero_()
        if hasattr(self, 'conv_mask_mean_level'):
            self.conv_mask_mean_level.weight.data.zero_()
            self.conv_mask_mean_level.bias.data.zero_()

    def forward(self, x):
        if hasattr(self, 'FS') and self.pre_fs:
            x = self.FS(x)  # frequency selection before the convolution
        if hasattr(self, 'OMNI_ATT1') and hasattr(self, 'OMNI_ATT2'):
            c_att1, f_att1, _, _ = self.OMNI_ATT1(x)
            c_att2, f_att2, spatial_att2, _ = self.OMNI_ATT2(x)
        elif hasattr(self, 'OMNI_ATT'):
            c_att, f_att, _, _ = self.OMNI_ATT(x)

        # Predict the per-pixel dilation scale
        if self.conv_type == 'conv':
            offset = self.conv_offset(self.PAD(self.freq_select(x)))
        elif self.conv_type == 'multifreqband':
            offset = self.conv_offset(self.freq_select(x))
        if self.use_zero_dilation:
            offset = (F.relu(offset + 1, inplace=True) - 1) * self.dilation[0]  # bounded below by -dilation
        else:
            offset = offset.abs() * self.dilation[0]  # non-negative
        if hasattr(self, 'FS') and not self.pre_fs:
            x = self.FS(x, F.interpolate(offset, x.shape[-2:], mode='bilinear',
                                         align_corners=(x.shape[-1] % 2) == 1))

        # Stretch the base grid: (b, G, 1, h, w) * (1, 1, 18, 1, 1) -> (b, G * 18, h, w)
        b, _, h, w = offset.shape
        offset = offset.reshape(b, self.deform_groups, -1, h, w) * self.dilated_offset
        offset = offset.reshape(b, -1, h, w)

        x = self.PAD(x)
        mask = self.conv_mask(x).sigmoid()
        if hasattr(self, 'conv_mask_mean_level'):
            mask_mean_level = torch.sigmoid(self.conv_mask_mean_level(x)).reshape(b, self.deform_groups, -1, h, w)
            mask = mask.reshape(b, self.deform_groups, -1, h, w) * mask_mean_level
        mask = mask.reshape(b, -1, h, w)

        # The offsets already encode the dilation, so the deformable conv itself uses dilation (1, 1)
        padding = ((self.kernel_size[0] // 2, self.kernel_size[1] // 2)
                   if isinstance(self.PAD, nn.Identity) else (0, 0))
        k_h, k_w = self.kernel_size
        if hasattr(self, 'OMNI_ATT1') and hasattr(self, 'OMNI_ATT2'):
            # Fold the batch into the group dimension so that each sample uses its own kernel
            offset = offset.reshape(1, -1, h, w)
            mask = mask.reshape(1, -1, h, w)
            x = x.reshape(1, -1, x.size(-2), x.size(-1))
            adaptive_weight = self.weight.unsqueeze(0).repeat(b, 1, 1, 1, 1)  # (b, c_out, c_in, k, k)
            adaptive_weight_mean = adaptive_weight.mean(dim=(-1, -2), keepdim=True)  # low-frequency part
            adaptive_weight_res = adaptive_weight - adaptive_weight_mean  # high-frequency part
            if self.use_dct:
                # Reweight the high-frequency part in the DCT domain with the spatial attention
                dct_coefficients = dct.dct_2d(adaptive_weight_res)
                spatial_att2 = spatial_att2.reshape(b, 1, 1, k_h, k_w)
                adaptive_weight_res = dct.idct_2d(dct_coefficients * (spatial_att2 * 2))
            adaptive_weight = (adaptive_weight_mean * (c_att1.unsqueeze(1) * 2) * (f_att1.unsqueeze(2) * 2)
                               + adaptive_weight_res * (c_att2.unsqueeze(1) * 2) * (f_att2.unsqueeze(2) * 2))
            adaptive_weight = adaptive_weight.reshape(-1, self.in_channels // self.groups, k_h, k_w)
            bias = self.bias.repeat(b) if self.bias is not None else None
            x = modulated_deform_conv2d(x, offset, mask, adaptive_weight, bias, self.stride, padding, (1, 1),
                                        self.groups * b, self.deform_groups * b)
        elif hasattr(self, 'OMNI_ATT'):
            offset = offset.reshape(1, -1, h, w)
            mask = mask.reshape(1, -1, h, w)
            x = x.reshape(1, -1, x.size(-2), x.size(-1))
            adaptive_weight = self.weight.unsqueeze(0).repeat(b, 1, 1, 1, 1)  # (b, c_out, c_in, k, k)
            adaptive_weight_mean = adaptive_weight.mean(dim=(-1, -2), keepdim=True)
            if self.kernel_decompose == 'high':
                adaptive_weight = adaptive_weight_mean + (adaptive_weight - adaptive_weight_mean) * (
                    c_att.unsqueeze(1) * 2) * (f_att.unsqueeze(2) * 2)
            elif self.kernel_decompose == 'low':
                adaptive_weight = adaptive_weight_mean * (c_att.unsqueeze(1) * 2) * (f_att.unsqueeze(2) * 2) + (
                    adaptive_weight - adaptive_weight_mean)
            adaptive_weight = adaptive_weight.reshape(-1, self.in_channels // self.groups, k_h, k_w)
            bias = self.bias.repeat(b) if self.bias is not None else None
            x = modulated_deform_conv2d(x, offset, mask, adaptive_weight, bias, self.stride, padding, (1, 1),
                                        self.groups * b, self.deform_groups * b)
        else:
            x = modulated_deform_conv2d(x, offset, mask, self.weight, self.bias, self.stride, padding, (1, 1),
                                        self.groups, self.deform_groups)
        return x.reshape(b, -1, h, w)
```
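How does `AdaptiveDilatedConv` turn a single predicted number into a dilation rate? `conv_offset` predicts a scale r per pixel, and `offset.abs() * self.dilation[0] * self.dilated_offset` turns it into offsets along the base 3x3 grid. `modulated_deform_conv2d` is then called with dilation (1, 1), so a sample at grid position g lands at g * (1 + |r| * dilation). In other words, 1 + |r| * dilation acts as a per-pixel dilation rate. The bias is initialised to (d - 1) / d + epsilon, so the starting dilation is about d. The sketch below reproduces only this arithmetic for one pixel and one deform group. It does not call mmcv, and `sampling_positions` is a hypothetical helper.

```python
import torch

# Base 3x3 grid as interleaved (y, x) pairs, same order as dilated_offset above
BASE = torch.tensor([-1., -1., -1., 0., -1., 1.,
                     0., -1., 0., 0., 0., 1.,
                     1., -1., 1., 0., 1., 1.])


def sampling_positions(r, dilation):
    """Hypothetical helper: sampling positions of one output pixel, relative to its centre.

    r is the raw conv_offset output for that pixel; dilation plays the role of self.dilation[0].
    """
    offset = r.abs() * dilation * BASE  # mirrors offset.abs() * self.dilation[0] * self.dilated_offset
    return BASE + offset                # regular grid of a dilation-1 conv plus the learned offset


d, eps = 3, 1e-4
r_init = torch.tensor((d - 1) / d + eps)  # conv_offset output right after init_weights()
print(sampling_positions(r_init, d).view(9, 2)[0])  # corner sample near (-3, -3): effective dilation ~3
print(torch.equal(sampling_positions(torch.tensor(0.0), d), BASE))  # True: r = 0 gives a plain 3x3 conv
```

Because r is predicted per pixel, textured (high-frequency) regions can shrink toward a dense 3x3 grid while smooth regions keep a wide one. A single global dilation cannot express that.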
- class AdaptiveDilatedDWConv(ModulatedDeformConv2d):
- """A ModulatedDeformable Conv Encapsulation that acts as normal Conv
- layers.
- Args:
- in_channels (int): Same as nn.Conv2d.
- out_channels (int): Same as nn.Conv2d.
- kernel_size (int or tuple[int]): Same as nn.Conv2d.
- stride (int): Same as nn.Conv2d, while tuple is not supported.
- padding (int): Same as nn.Conv2d, while tuple is not supported.
- dilation (int): Same as nn.Conv2d, while tuple is not supported.
- groups (int): Same as nn.Conv2d.
- bias (bool or str): If specified as `auto`, it will be decided by the
- norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
- False.
- """
- _version = 2
- def __init__(self, *args,
- offset_freq=None,
- use_BFM=False,
- kernel_decompose='both',
- padding_mode='repeat',
- # padding_mode='zero',
- normal_conv_dim=0,
- pre_fs=True, # False, use dilation
- fs_cfg={
- # 'k_list':[3,5,7,9],
- 'k_list': [2, 4, 8],
- 'fs_feat': 'feat',
- 'lowfreq_att': False,
- # 'lp_type':'freq_eca',
- # 'lp_type':'freq_channel_att',
- # 'lp_type':'freq',
- # 'lp_type':'avgpool',
- 'lp_type': 'freq',
- 'act': 'sigmoid',
- 'spatial': 'conv',
- 'spatial_group': 1,
- },
- **kwargs):
- super().__init__(*args, **kwargs)
- assert self.kernel_size[0] in (3, 7)
- assert self.groups == self.in_channels
- if kernel_decompose == 'both':
- self.OMNI_ATT1 = OmniAttention(in_planes=self.in_channels, out_planes=self.out_channels, kernel_size=1,
- groups=self.in_channels, reduction=0.0625, kernel_num=1, min_channel=16)
- self.OMNI_ATT2 = OmniAttention(in_planes=self.in_channels, out_planes=self.out_channels, kernel_size=1,
- groups=self.in_channels, reduction=0.0625, kernel_num=1, min_channel=16)
- elif kernel_decompose == 'high':
- self.OMNI_ATT = OmniAttention(in_planes=self.in_channels, out_planes=self.out_channels, kernel_size=1,
- groups=self.in_channels, reduction=0.0625, kernel_num=1, min_channel=16)
- elif kernel_decompose == 'low':
- self.OMNI_ATT = OmniAttention(in_planes=self.in_channels, out_planes=self.out_channels, kernel_size=1,
- groups=self.in_channels, reduction=0.0625, kernel_num=1, min_channel=16)
- self.kernel_decompose = kernel_decompose
- self.normal_conv_dim = normal_conv_dim
- if padding_mode == 'zero':
- self.PAD = nn.ZeroPad2d(self.kernel_size[0] // 2)
- elif padding_mode == 'repeat':
- self.PAD = nn.ReplicationPad2d(self.kernel_size[0] // 2)
- else:
- self.PAD = nn.Identity()
- self.conv_offset = nn.Conv2d(
- self.in_channels - self.normal_conv_dim,
- self.deform_groups * 1,
- # self.groups * 1,
- kernel_size=self.kernel_size,
- stride=self.stride,
- padding=self.padding if isinstance(self.PAD, nn.Identity) else 0,
- dilation=1,
- bias=True)
- # self.conv_offset_low = nn.Sequential(
- # nn.AvgPool2d(
- # kernel_size=self.kernel_size,
- # stride=self.stride,
- # padding=1,
- # ),
- # nn.Conv2d(
- # self.in_channels,
- # self.deform_groups * 1,
- # kernel_size=1,
- # stride=1,
- # padding=0,
- # dilation=1,
- # bias=False),
- # )
- self.conv_mask = nn.Sequential(
- nn.Conv2d(
- self.in_channels - self.normal_conv_dim,
- self.in_channels - self.normal_conv_dim,
- kernel_size=self.kernel_size,
- stride=self.stride,
- padding=self.padding if isinstance(self.PAD, nn.Identity) else 0,
- groups=self.in_channels - self.normal_conv_dim,
- dilation=1,
- bias=False),
- nn.Conv2d(
- self.in_channels - self.normal_conv_dim,
- self.deform_groups * 1 * self.kernel_size[0] * self.kernel_size[1],
- kernel_size=1,
- stride=1,
- padding=0,
- groups=1,
- dilation=1,
- bias=True)
- )
- self.offset_freq = offset_freq
- # An offset is like [y0, x0, y1, x1, y2, x2, ⋯, y8, x8]
- if self.kernel_size[0] == 3:
- offset = [-1, -1, -1, 0, -1, 1,
- 0, -1, 0, 0, 0, 1,
- 1, -1, 1, 0, 1, 1]
- elif self.kernel_size[0] == 7:
- offset = [
- -3, -3, -3, -2, -3, -1, -3, 0, -3, 1, -3, 2, -3, 3,
- -2, -3, -2, -2, -2, -1, -2, 0, -2, 1, -2, 2, -2, 3,
- -1, -3, -1, -2, -1, -1, -1, 0, -1, 1, -1, 2, -1, 3,
- 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3,
- 1, -3, 1, -2, 1, -1, 1, 0, 1, 1, 1, 2, 1, 3,
- 2, -3, 2, -2, 2, -1, 2, 0, 2, 1, 2, 2, 2, 3,
- 3, -3, 3, -2, 3, -1, 3, 0, 3, 1, 3, 2, 3, 3,
- ]
- else:
- raise NotImplementedError
- offset = torch.Tensor(offset)
- # offset[0::2] *= self.dilation[0]
- # offset[1::2] *= self.dilation[1]
- # a tuple of two ints – in which case, the first int is used for the height dimension, and the second int for the width dimension
- self.register_buffer('dilated_offset', torch.Tensor(offset[None, None, ..., None, None])) # B, G, 49, 1, 1
- self.init_weights()
- self.use_BFM = use_BFM
- if use_BFM:
- alpha = 8
- BFM = np.zeros((self.in_channels, 1, self.kernel_size[0], self.kernel_size[0]))
- for i in range(self.kernel_size[0]):
- for j in range(self.kernel_size[0]):
- point_1 = (i, j)
- point_2 = (self.kernel_size[0] // 2, self.kernel_size[0] // 2)
- dist = distance.euclidean(point_1, point_2)
- BFM[:, :, i, j] = alpha / (dist + alpha)
- self.register_buffer('BFM', torch.Tensor(BFM))
- if fs_cfg is not None:
- if pre_fs:
- self.FS = FrequencySelection(self.in_channels - self.normal_conv_dim, **fs_cfg)
- else:
- self.FS = FrequencySelection(1, **fs_cfg) # use dilation
- self.pre_fs = pre_fs
- def freq_select(self, x):
- if self.offset_freq is None:
- pass
- elif self.offset_freq in ('FLC_high', 'SLP_high'):
- x - self.LP(x)
- elif self.offset_freq in ('FLC_res', 'SLP_res'):
- 2 * x - self.LP(x)
- else:
- raise NotImplementedError
- return x
- def init_weights(self):
- super().init_weights()
- if hasattr(self, 'conv_offset'):
- self.conv_offset.weight.data.zero_()
- self.conv_offset.bias.data.fill_((self.dilation[0] - 1) / self.dilation[0] + 1e-4)
- # self.conv_offset.bias.data.zero_()
- # if hasattr(self, 'conv_offset_low'):
- # self.conv_offset_low[1].weight.data.zero_()
- if hasattr(self, 'conv_mask'):
- self.conv_mask[1].weight.data.zero_()
- self.conv_mask[1].bias.data.zero_()
- def forward(self, x):
- if self.normal_conv_dim > 0:
- return self.mix_forward(x)
- else:
- return self.ad_forward(x)
- def ad_forward(self, x):
- if hasattr(self, 'FS') and self.pre_fs: x = self.FS(x)
- if hasattr(self, 'OMNI_ATT1') and hasattr(self, 'OMNI_ATT2'):
- c_att1, _, _, _, = self.OMNI_ATT1(x)
- c_att2, _, _, _, = self.OMNI_ATT2(x)
- elif hasattr(self, 'OMNI_ATT'):
- c_att, _, _, _, = self.OMNI_ATT(x)
- x = self.PAD(x)
- offset = self.conv_offset(x)
- offset = F.relu(offset, inplace=True) * self.dilation[0] # ensure > 0
- if hasattr(self, 'FS') and (self.pre_fs == False): x = self.FS(x, offset)
- b, _, h, w = offset.shape
- offset = offset.reshape(b, self.deform_groups, -1, h, w) * self.dilated_offset
- offset = offset.reshape(b, -1, h, w)
- mask = self.conv_mask(x)
- mask = torch.sigmoid(mask)
- if hasattr(self, 'OMNI_ATT1') and hasattr(self, 'OMNI_ATT2'):
- offset = offset.reshape(1, -1, h, w)
- # print(offset.max(), offset.min(), offset.mean())
- mask = mask.reshape(1, -1, h, w)
- x = x.reshape(1, -1, x.size(-2), x.size(-1))
- adaptive_weight = self.weight.unsqueeze(0).repeat(b, 1, 1, 1, 1) # b, out, in, k, k
- adaptive_weight_mean = adaptive_weight.mean(dim=(-1, -2), keepdim=True)
- adaptive_weight = adaptive_weight_mean * (2 * c_att1.unsqueeze(2)) + (
- adaptive_weight - adaptive_weight_mean) * (2 * c_att2.unsqueeze(2))
- adaptive_weight = adaptive_weight.reshape(-1, self.in_channels // self.groups, 3, 3)
- if self.bias is not None:
- bias = self.bias.repeat(b)
- else:
- bias = self.bias
- x = modulated_deform_conv2d(x, offset, mask, adaptive_weight, bias,
- self.stride, self.padding if isinstance(self.PAD, nn.Identity) else 0,
- # padding
- (1, 1), # dilation
- self.groups * b, self.deform_groups * b)
- return x.reshape(b, -1, h, w)
- elif hasattr(self, 'OMNI_ATT'):
- offset = offset.reshape(1, -1, h, w)
- mask = mask.reshape(1, -1, h, w)
- x = x.reshape(1, -1, x.size(-2), x.size(-1))
- adaptive_weight = self.weight.unsqueeze(0).repeat(b, 1, 1, 1, 1) # b, out, in, k, k
- adaptive_weight_mean = adaptive_weight.mean(dim=(-1, -2), keepdim=True)
- if self.kernel_decompose == 'high':
- adaptive_weight = adaptive_weight_mean + (adaptive_weight - adaptive_weight_mean) * (
- 2 * c_att.unsqueeze(2))
- elif self.kernel_decompose == 'low':
- adaptive_weight = adaptive_weight_mean * (2 * c_att.unsqueeze(2)) + (
- adaptive_weight - adaptive_weight_mean)
- adaptive_weight = adaptive_weight.reshape(-1, self.in_channels // self.groups, 3, 3)
- if self.bias is not None:
- bias = self.bias.repeat(b)
- else:
- bias = self.bias
- x = modulated_deform_conv2d(x, offset, mask, adaptive_weight, bias,
- self.stride, self.padding if isinstance(self.PAD, nn.Identity) else 0,
- # padding
- (1, 1), # dilation
- self.groups * b, self.deform_groups * b)
- return x.reshape(b, -1, h, w)
- else:
- return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
- self.stride, self.padding if isinstance(self.PAD, nn.Identity) else 0,
- # padding
- self.dilation, self.groups,
- self.deform_groups)
- def mix_forward(self, x):
- if hasattr(self, 'OMNI_ATT1') and hasattr(self, 'OMNI_ATT2'):
- c_att1, _, _, _, = self.OMNI_ATT1(x)
- c_att2, _, _, _, = self.OMNI_ATT2(x)
- elif hasattr(self, 'OMNI_ATT'):
- c_att, _, _, _, = self.OMNI_ATT(x)
- ori_x = x
- normal_conv_x = ori_x[:, -self.normal_conv_dim:] # ad:normal
- x = ori_x[:, :-self.normal_conv_dim]
- if hasattr(self, 'FS') and self.pre_fs: x = self.FS(x)
- x = self.PAD(x)
- offset = self.conv_offset(x)
- if hasattr(self, 'FS') and (self.pre_fs == False): x = self.FS(x, F.interpolate(offset, x.shape[-2:], mode='bilinear', align_corners=(x.shape[-1] % 2) == 1))
- # if hasattr(self, 'FS') and (self.pre_fs==False): x = self.FS(x, offset)
- # offset = F.relu(offset, inplace=True) * self.dilation[0] # ensure > 0
- offset[offset < 0] = offset[offset < 0].exp() - 1
- b, _, h, w = offset.shape
- offset = offset.reshape(b, self.deform_groups, -1, h, w) * self.dilated_offset
- offset = offset.reshape(b, -1, h, w)
- mask = self.conv_mask(x)
- mask = torch.sigmoid(mask)
- if hasattr(self, 'OMNI_ATT1') and hasattr(self, 'OMNI_ATT2'):
- offset = offset.reshape(1, -1, h, w)
- # print(offset.max(), offset.min(), offset.mean())
- mask = mask.reshape(1, -1, h, w)
- x = x.reshape(1, -1, x.size(-2), x.size(-1))
- adaptive_weight = self.weight.unsqueeze(0).repeat(b, 1, 1, 1, 1) # b, out, in, k, k
- adaptive_weight_mean = adaptive_weight.mean(dim=(-1, -2), keepdim=True)
- adaptive_weight = adaptive_weight_mean * (2 * c_att1.unsqueeze(2)) + (
- adaptive_weight - adaptive_weight_mean) * (2 * c_att2.unsqueeze(2))
- if self.bias is not None:
- bias = self.bias.repeat(b)
- else:
- bias = self.bias
- # adaptive_weight = adaptive_weight.reshape(-1, self.in_channels // self.groups, 3, 3)
- x = modulated_deform_conv2d(x, offset, mask,
- adaptive_weight[:, :-self.normal_conv_dim].reshape(-1, self.in_channels // self.groups, self.kernel_size[0], self.kernel_size[1]),
- bias,
- self.stride, self.padding if isinstance(self.PAD, nn.Identity) else 0,
- # padding
- (1, 1), # dilation
- (self.in_channels - self.normal_conv_dim) * b, self.deform_groups * b)
- x = x.reshape(b, -1, h, w)
- normal_conv_x = F.conv2d(normal_conv_x.reshape(1, -1, h, w),
- adaptive_weight[:, -self.normal_conv_dim:].reshape(-1,
- self.in_channels // self.groups,
- self.kernel_size[0],
- self.kernel_size[1]),
- bias=bias, stride=self.stride, padding=self.padding, dilation=self.dilation,
- groups=self.normal_conv_dim * b)
- normal_conv_x = normal_conv_x.reshape(b, -1, h, w)
- # return torch.cat([normal_conv_x, x], dim=1)
- return torch.cat([x, normal_conv_x], dim=1)
- elif hasattr(self, 'OMNI_ATT'):
- offset = offset.reshape(1, -1, h, w)
- mask = mask.reshape(1, -1, h, w)
- x = x.reshape(1, -1, x.size(-2), x.size(-1))
- adaptive_weight = self.weight.unsqueeze(0).repeat(b, 1, 1, 1, 1) # b, out, in, k, k
- adaptive_weight_mean = adaptive_weight.mean(dim=(-1, -2), keepdim=True)
- if self.bias is not None:
- bias = self.bias.repeat(b)
- else:
- bias = self.bias
- if self.kernel_decompose == 'high':
- adaptive_weight = adaptive_weight_mean + (adaptive_weight - adaptive_weight_mean) * (
- 2 * c_att.unsqueeze(2))
- elif self.kernel_decompose == 'low':
- adaptive_weight = adaptive_weight_mean * (2 * c_att.unsqueeze(2)) + (
- adaptive_weight - adaptive_weight_mean)
- x = modulated_deform_conv2d(x, offset, mask,
- adaptive_weight[:, :-self.normal_conv_dim].reshape(-1, self.in_channels // self.groups, self.kernel_size[0], self.kernel_size[1]),
- bias,
- self.stride, self.padding if isinstance(self.PAD, nn.Identity) else 0,
- # padding
- (1, 1), # dilation
- (self.in_channels - self.normal_conv_dim) * b, self.deform_groups * b)
- x = x.reshape(b, -1, h, w)
- normal_conv_x = F.conv2d(normal_conv_x.reshape(1, -1, h, w),
- adaptive_weight[:, -self.normal_conv_dim:].reshape(-1,
- self.in_channels // self.groups,
- self.kernel_size[0],
- self.kernel_size[1]),
- bias=bias, stride=self.stride, padding=self.padding, dilation=self.dilation,
- groups=self.normal_conv_dim * b)
- normal_conv_x = normal_conv_x.reshape(b, -1, h, w)
- # return torch.cat([normal_conv_x, x], dim=1)
- return torch.cat([x, normal_conv_x], dim=1)
- else:
- return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
- self.stride, self.padding if isinstance(self.PAD, nn.Identity) else 0,
- # padding
- self.dilation, self.groups,
- self.deform_groups)
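- import torch.nn.functional as F  # used by F.relu / F.interpolate / F.conv2d above; harmless if already imported at the top
- def autopad(k, p=None, d=1): # kernel, padding, dilation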
- """Pad to 'same' shape outputs."""
- if d > 1:
- k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
- if p is None:
- p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
- return p
- class Conv(nn.Module):
- """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
- default_act = nn.SiLU() # default activation
- def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
- """Initialize Conv layer with given arguments including activation."""
- super().__init__()
- self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
- self.bn = nn.BatchNorm2d(c2)
- self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
- def forward(self, x):
- """Apply convolution, batch normalization and activation to input tensor."""
- return self.act(self.bn(self.conv(x)))
- def forward_fuse(self, x):
- """Perform transposed convolution of 2D data."""
- return self.act(self.conv(x))
- class DFL(nn.Module):
- """
- Integral module of Distribution Focal Loss (DFL).
- Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
- """
- def __init__(self, c1=16):
- """Initialize a convolutional layer with a given number of input channels."""
- super().__init__()
- self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
- x = torch.arange(c1, dtype=torch.float)
- self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
- self.c1 = c1
- def forward(self, x):
- """Applies a transformer layer on input tensor 'x' and returns a tensor."""
- b, _, a = x.shape # batch, channels, anchors
- return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
- # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
- class ADDWConvHead(nn.Module):
- """YOLOv8 Detect head for detection models."""
- dynamic = False # force grid reconstruction
- export = False # export mode
- end2end = False # end2end
- max_det = 300 # max_det
- shape = None
- anchors = torch.empty(0) # init
- strides = torch.empty(0) # init
- def __init__(self, nc=80, ch=()):
- """Initializes the YOLOv8 detection layer with specified number of classes and channels."""
- super().__init__()
- self.nc = nc # number of classes
- self.nl = len(ch) # number of detection layers
- self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
- self.no = nc + self.reg_max * 4 # number of outputs per anchor
- self.stride = torch.zeros(self.nl) # strides computed during build
- c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels
- self.cv2 = nn.ModuleList(
- nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
- )
- self.cv3 = nn.ModuleList(
- nn.Sequential(
- nn.Sequential(AdaptiveDilatedDWConv(x, x, groups=math.gcd(x, x), kernel_size=3, stride=1, dilation=1), Conv(x, c3, 1)),
- nn.Sequential(AdaptiveDilatedDWConv(c3, c3, groups=math.gcd(c3, c3), kernel_size=3, stride=1, dilation=1), Conv(c3, c3, 1)),
- nn.Conv2d(c3, self.nc, 1),
- )
- for x in ch
- )
- self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
- if self.end2end:
- self.one2one_cv2 = copy.deepcopy(self.cv2)
- self.one2one_cv3 = copy.deepcopy(self.cv3)
- def forward(self, x):
- """Concatenates and returns predicted bounding boxes and class probabilities."""
- if self.end2end:
- return self.forward_end2end(x)
- for i in range(self.nl):
- x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
- if self.training: # Training path
- return x
- y = self._inference(x)
- return y if self.export else (y, x)
- def forward_end2end(self, x):
- """
- Performs forward pass of the v10Detect module.
- Args:
- x (tensor): Input tensor.
- Returns:
- (dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and one2one detections.
- If in training mode, returns a dictionary containing the outputs of one2many and one2one detections separately.
- """
- x_detach = [xi.detach() for xi in x]
- one2one = [
- torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
- ]
- for i in range(self.nl):
- x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
- if self.training: # Training path
- return {"one2many": x, "one2one": one2one}
- y = self._inference(one2one)
- y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
- return y if self.export else (y, {"one2many": x, "one2one": one2one})
- def _inference(self, x):
- """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""
- # Inference path
- shape = x[0].shape # BCHW
- x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
- if self.dynamic or self.shape != shape:
- self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
- self.shape = shape
- if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops
- box = x_cat[:, : self.reg_max * 4]
- cls = x_cat[:, self.reg_max * 4 :]
- else:
- box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
- if self.export and self.format in {"tflite", "edgetpu"}:
- # Precompute normalization factor to increase numerical stability
- # See https://github.com/ultralytics/ultralytics/issues/7371
- grid_h = shape[2]
- grid_w = shape[3]
- grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
- norm = self.strides / (self.stride[0] * grid_size)
- dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
- else:
- dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
- return torch.cat((dbox, cls.sigmoid()), 1)
- def bias_init(self):
- """Initialize Detect() biases, WARNING: requires stride availability."""
- m = self # self.model[-1] # Detect() module
- # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
- # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
- for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
- a[-1].bias.data[:] = 1.0 # box
- b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
- if self.end2end:
- for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride): # from
- a[-1].bias.data[:] = 1.0 # box
- b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
- def decode_bboxes(self, bboxes, anchors):
- """Decode bounding boxes."""
- return dist2bbox(bboxes, anchors, xywh=not self.end2end, dim=1)
- @staticmethod
- def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
- """
- Post-processes YOLO model predictions.
- Args:
- preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
- format [x, y, w, h, class_probs].
- max_det (int): Maximum detections per image.
- nc (int, optional): Number of classes. Default: 80.
- Returns:
- (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
- dimension format [x, y, w, h, max_class_prob, class_index].
- """
- batch_size, anchors, _ = preds.shape # i.e. shape(16,8400,84)
- boxes, scores = preds.split([4, nc], dim=-1)
- index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
- boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
- scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
- scores, index = scores.flatten(1).topk(min(max_det, anchors))
- i = torch.arange(batch_size)[..., None] # batch indices
- return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)
- if __name__ == "__main__":
- # Generating Sample image
- image1 = (1, 64, 160, 160)
- image2 = (1, 128, 80, 80)
- image3 = (1, 256, 40, 40)
- image1 = torch.rand(image1)
- image2 = torch.rand(image2)
- image3 = torch.rand(image3)
- image = [image1, image2, image3]
- channel = (64, 128, 256)
- # Model
- mobilenet_v1 = ADDWConvHead(nc=80, ch=channel)
- out = mobilenet_v1(image)
- print(out)
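The `OMNI_ATT1`/`OMNI_ATT2` branch in `ad_forward` and `mix_forward` is where the AdaKern idea from Section II shows up: each kernel is split into its spatial mean (the low-frequency, DC part) and the zero-mean residual (the high-frequency part), and the two parts are rescaled by separate attention values. The NumPy sketch below reproduces only that reweighting formula for a bank of depthwise 3x3 kernels; the scalar attention values are hypothetical stand-ins for `c_att1`/`c_att2`, which the real module predicts per sample and per channel.

```python
import numpy as np


def adakern_reweight(weight, att_low, att_high):
    """Rescale the low- and high-frequency parts of a kernel bank.

    weight:   array of shape (out, in, k, k)
    att_low:  scale for the kernel mean (scalar or broadcastable to (out, in, 1, 1))
    att_high: scale for the zero-mean residual (same broadcasting rules)
    Same arithmetic as: mean * (2 * c_att1) + (weight - mean) * (2 * c_att2)
    """
    mean = weight.mean(axis=(-1, -2), keepdims=True)  # low-frequency (DC) part
    high = weight - mean                              # high-frequency part, zero-mean per kernel
    return mean * (2 * att_low) + high * (2 * att_high)


rng = np.random.default_rng(0)
w = rng.standard_normal((4, 1, 3, 3))  # 4 depthwise 3x3 kernels

# An attention of 0.5 on both parts leaves the kernel unchanged (because of the factor 2).
print(np.allclose(adakern_reweight(w, 0.5, 0.5), w))  # True

# Zeroing the high-frequency part leaves a flat kernel equal to the mean: a pure low-pass filter.
flat = adakern_reweight(w, 0.5, 0.0)
print(np.allclose(flat, w.mean(axis=(-1, -2), keepdims=True)))  # True
```

With `att_low` fixed at 0.5 the function reduces to the `kernel_decompose == 'high'` branch (only the residual is rescaled); with `att_high` fixed at 0.5 it reduces to the `'low'` branch.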
IV. Integration Tutorial
4.1 Modification 1
First, create a new .py file under the 'ultralytics/nn' directory and paste the code above into it; the file name is up to you.
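The pasted code imports `mmcv.ops` (for `modulated_deform_conv2d`), `torch_dct` and `scipy` on top of PyTorch and Ultralytics, so a missing package only surfaces as an ImportError when the model is built. A minimal pre-flight check could look like the sketch below, assuming the top-level import names listed in `REQUIRED` match your environment; `find_missing` is an illustrative helper, not part of any library.

```python
import importlib.util

# Top-level packages that the head file imports (see the import list at the top of the code).
REQUIRED = ["torch", "numpy", "scipy", "mmcv", "torch_dct", "ultralytics"]


def find_missing(modules):
    """Return the names in `modules` that cannot be found on the current Python path."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = find_missing(REQUIRED)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All dependencies found.")
```

This only checks that the packages are installed; it does not verify that the installed mmcv build actually provides the compiled `mmcv.ops` kernels.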
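4.2 Modification 2
Second, create a file named '__init__.py' in the same directory (if you are using the files provided in the group, it already exists and does not need to be created), then import our detection head in it as shown in the figure below.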
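4.3 Modification 3
Third, open 'ultralytics/nn/tasks.py' and import and register our module there (if you are using the files provided in the group, this is already done and you can go straight to step four).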
4.4 Modification 4
Fourth, in 'ultralytics/nn/tasks.py', find the code shown below and add the detection head to it. A quick way to locate these spots is to press Ctrl+F and search for Detect.
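Step four matters because the model builder has to know which modules are detection heads. Our understanding is that `parse_model` in `tasks.py` appends the output channels of every `from` layer to the YAML args of a registered head, which is how the `ch=(...)` tuple reaches the head's `__init__`. The toy function below is a hypothetical, heavily simplified stand-in (`HEAD_MODULES` and `build_head_args` are illustrative names, not Ultralytics code) that shows only the effect of that membership check.

```python
# Hypothetical, simplified stand-in for the head branch of a YAML model parser.
HEAD_MODULES = {"Detect", "FADDWConvHead"}  # step four adds the new head to a set like this


def build_head_args(module_name, from_idx, args, ch):
    """Return the constructor args a head would receive.

    module_name: module name from the YAML line, e.g. "FADDWConvHead"
    from_idx:    input layer indices, e.g. [16, 19, 22]
    args:        YAML args, e.g. [80] for [nc]
    ch:          output channels of every layer built so far
    """
    args = list(args)
    if module_name in HEAD_MODULES:
        args.append([ch[i] for i in from_idx])  # per-level input channels -> the head's `ch`
    return args


ch = [0] * 23
ch[16], ch[19], ch[22] = 64, 128, 256
print(build_head_args("FADDWConvHead", [16, 19, 22], [80], ch))  # [80, [64, 128, 256]]
```

If layers 16, 19 and 22 output 64, 128 and 256 channels (the values used in the standalone test in the code), a registered head receives `[80, [64, 128, 256]]`; an unregistered one would only receive `[80]` and be built without its input channels.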
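4.5 Modification 5
Same as above.
4.6 Modification 6
Same as above.
4.7 Modification 7
This step is slightly different: we need to add an extra branch of code: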
```python
else:
    return 'detect'
```
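Why the difference? Because `m` here is automatically converted to lowercase while the code runs, so simply adding an `else` branch is the most convenient option. Modifications for segmentation and other tasks will be covered when those tutorials come out.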
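The lowercase issue can be illustrated with a simplified, hypothetical version of the task-guessing logic (the name `guess_task` and its exact branches are assumptions modelled on the description above, not a copy of `tasks.py`). The module name of the last head entry is lowercased, so `FADDWConvHead` becomes `faddwconvhead`, which matches none of the named checks and only reaches `'detect'` through the added `else`.

```python
def guess_task(cfg):
    """Guess the task from a parsed model YAML dict (simplified sketch)."""
    m = cfg["head"][-1][-2].lower()  # module name of the last head layer, lowercased
    if m in {"classify", "classifier", "cls", "fc"}:
        return "classify"
    if "detect" in m:
        return "detect"
    if m == "segment":
        return "segment"
    if m == "pose":
        return "pose"
    if m == "obb":
        return "obb"
    else:
        return "detect"  # the added fallback that catches custom heads such as faddwconvhead


cfg = {"head": [[[16, 19, 22], 1, "FADDWConvHead", ["nc"]]]}
print(guess_task(cfg))  # detect
```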
4.8 Modification 8
Same as above.
That completes the modifications; you can copy the YAML file below and run it.
V. YAML File for the FADDWConvHead Detection Head
Training info for this version: YOLO11-FADDWConvHead summary: 446 layers, 2,667,632 parameters, 2,667,616 gradients, 6.6 GFLOPs
```yaml
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)

  - [[16, 19, 22], 1, FADDWConvHead, [nc]] # Detect(P3, P4, P5)
```
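The channel counts written in the YAML are scaled by the `width` factor before the head is built, which is why the standalone test at the end of the code feeds the head `channel = (64, 128, 256)`. The sketch below assumes the scaling rule `make_divisible(min(c, max_channels) * width, 8)`, which is our understanding of how Ultralytics scales layer widths; `scaled_channels` is an illustrative helper, not a library function.

```python
import math


def make_divisible(x, divisor=8):
    """Round x up to the nearest multiple of divisor."""
    return math.ceil(x / divisor) * divisor


def scaled_channels(c, width, max_channels):
    """Output channels of a layer after compound width scaling (assumed rule)."""
    return make_divisible(min(c, max_channels) * width, 8)


SCALES = {  # [depth, width, max_channels] from the YAML above
    "n": (0.50, 0.25, 1024),
    "s": (0.50, 0.50, 1024),
    "m": (0.50, 1.00, 512),
}

# Layers 16, 19 and 22 (the head inputs) are declared with 256, 512 and 1024 channels.
for name, (_, width, max_ch) in SCALES.items():
    print(name, [scaled_channels(c, width, max_ch) for c in (256, 512, 1024)])
# n [64, 128, 256]
# s [128, 256, 512]
# m [256, 512, 512]
```

For the `n` scale and `nc=80`, the head's `__init__` then computes `c2 = max(16, 64 // 4, 16 * 4) = 64` and `c3 = max(64, min(80, 100)) = 80`.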
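VI. Successful Run Log
Finally, here is a screenshot of a successful run.
VII. Summary
This concludes the main content of this article. I would also like to recommend my YOLOv11 column on improvements that reliably boost accuracy. The column is newly launched and currently has an average quality score of 98; going forward I will reproduce papers from the latest top conferences and add to older improvement mechanisms. If this article helped you, please subscribe to the column and follow for more updates~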