
YOLOv11 Improvements – Detection Head: FADWCHead, an Original Adaptive-Dilation DWConv Detection Head for YOLOv11 (2024 Exclusive)

1. Introduction

This article presents the author's original improvement: using Frequency-Adaptive Dilated Convolution (FADC) to improve the YOLOv11 detection head. The core idea of FADC is to dynamically adjust the dilation rate according to the local frequency content of the image. This lets the network adapt its receptive field to local variation in image content, improving performance in detail-rich regions dense with high-frequency information. The figure below shows the accuracy comparison.



2. Principle

Official paper: click here to jump to the official paper

Official code: click here to jump to the official code


Frequency-Adaptive Dilated Convolution (FADC) aims to improve the performance of dilated convolution in semantic segmentation. Its main ideas are summarized below:

  1. Dilated convolution overview: dilated convolution expands the receptive field by inserting gaps between consecutive kernel elements, capturing wider context without adding any parameters (see the sketch after this list).

  2. The problem with fixed dilation rates: conventional dilated convolution fixes the dilation rate as a global hyperparameter, which is sub-optimal across different regions of an image, because different parts of an image carry different frequency content.

  3. Frequency-adaptive dilated convolution (FADC): the core idea of FADC is to adjust the dilation rate dynamically according to the local frequency content of the image. This lets the network adapt its receptive field to local variation in image content, improving performance in detail-rich, high-frequency regions.

  4. Two plug-in modules:

    • Adaptive kernel (AdaKern): decomposes the convolution weights into low-frequency and high-frequency components, improving control over the bandwidth and the receptive-field size.
    • Frequency selection (FreqSelect): reweights the low- and high-frequency components of the feature map (implemented as FrequencySelection in the code below), further improving the model's ability to adapt to different frequency content.

Together, these methods aim to improve semantic segmentation by letting convolutional networks better capture local variation in image features.
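To make items 1 and 2 concrete, here is a minimal PyTorch sketch of my own (not from the paper): raising the dilation rate grows the receptive field without adding a single parameter, and FADC's contribution is to choose this rate per location instead of fixing it globally.

import torch
import torch.nn as nn

x = torch.rand(1, 16, 32, 32)
conv_d1 = nn.Conv2d(16, 16, kernel_size=3, padding=1, dilation=1)  # 3x3 receptive field
conv_d2 = nn.Conv2d(16, 16, kernel_size=3, padding=2, dilation=2)  # 5x5 receptive field

print(sum(p.numel() for p in conv_d1.parameters()))  # 2320
print(sum(p.numel() for p in conv_d2.parameters()))  # 2320 -- identical parameter count
print(conv_d2(x).shape)                              # torch.Size([1, 16, 32, 32]), same resolution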


3. Core Code

The YOLOv11 detection head uses DWConv, and the frequency-selective depth-wise convolution proposed in this paper is a natural drop-in replacement for it. See Section 4 for how to use the code below.
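For reference, the stock YOLO11 Detect head builds its classification branch from plain depth-wise 3x3 convolutions, roughly like this (paraphrased from ultralytics/nn/modules/head.py; the exact code differs across versions). The FADDWConvHead defined below keeps this structure and swaps each DWConv for an AdaptiveDilatedDWConv:

self.cv3 = nn.ModuleList(
    nn.Sequential(
        nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
        nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
        nn.Conv2d(c3, self.nc, 1),
    )
    for x in ch
)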

import copy
import math

import numpy as np
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F
import torch_dct as dct
from mmcv.ops.modulated_deform_conv import ModulatedDeformConv2d, modulated_deform_conv2d
from scipy.spatial import distance

from ultralytics.utils.tal import dist2bbox, make_anchors


class OmniAttention(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, groups=1, reduction=0.0625, kernel_num=4, min_channel=16):
        super(OmniAttention, self).__init__()
        attention_channel = max(int(in_planes * reduction), min_channel)
        self.kernel_size = kernel_size
        self.kernel_num = kernel_num
        self.temperature = 1.0
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(in_planes, attention_channel, 1, bias=False)
        self.bn = nn.BatchNorm2d(attention_channel)
        self.relu = nn.ReLU(inplace=True)
        self.channel_fc = nn.Conv2d(attention_channel, in_planes, 1, bias=True)
        self.func_channel = self.get_channel_attention
        if in_planes == groups and in_planes == out_planes:  # depth-wise convolution
            self.func_filter = self.skip
        else:
            self.filter_fc = nn.Conv2d(attention_channel, out_planes, 1, bias=True)
            self.func_filter = self.get_filter_attention
        if kernel_size == 1:  # point-wise convolution
            self.func_spatial = self.skip
        else:
            self.spatial_fc = nn.Conv2d(attention_channel, kernel_size * kernel_size, 1, bias=True)
            self.func_spatial = self.get_spatial_attention
        if kernel_num == 1:
            self.func_kernel = self.skip
        else:
            self.kernel_fc = nn.Conv2d(attention_channel, kernel_num, 1, bias=True)
            self.func_kernel = self.get_kernel_attention
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def update_temperature(self, temperature):
        self.temperature = temperature

    @staticmethod
    def skip(_):
        return 1.0

    def get_channel_attention(self, x):
        channel_attention = torch.sigmoid(self.channel_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)
        return channel_attention

    def get_filter_attention(self, x):
        filter_attention = torch.sigmoid(self.filter_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)
        return filter_attention

    def get_spatial_attention(self, x):
        spatial_attention = self.spatial_fc(x).view(x.size(0), 1, 1, 1, self.kernel_size, self.kernel_size)
        spatial_attention = torch.sigmoid(spatial_attention / self.temperature)
        return spatial_attention

    def get_kernel_attention(self, x):
        kernel_attention = self.kernel_fc(x).view(x.size(0), -1, 1, 1, 1, 1)
        kernel_attention = F.softmax(kernel_attention / self.temperature, dim=1)
        return kernel_attention

    def forward(self, x):
        x = self.avgpool(x)
        x = self.fc(x)
        if x.shape[3] != 1:  # skip BN when the pooled map is 1x1
            x = self.bn(x)
        x = self.relu(x)
        return self.func_channel(x), self.func_filter(x), self.func_spatial(x), self.func_kernel(x)


def generate_laplacian_pyramid(input_tensor, num_levels, size_align=True, mode='bilinear'):
    pyramid = []
    current_tensor = input_tensor
    _, _, H, W = current_tensor.shape
    for _ in range(num_levels):
        b, _, h, w = current_tensor.shape
        downsampled_tensor = F.interpolate(current_tensor, (h // 2 + h % 2, w // 2 + w % 2), mode=mode,
                                           align_corners=(H % 2) == 1)
        if size_align:
            upsampled_tensor = F.interpolate(downsampled_tensor, (H, W), mode=mode, align_corners=(H % 2) == 1)
            laplacian = F.interpolate(current_tensor, (H, W), mode=mode, align_corners=(H % 2) == 1) - upsampled_tensor
        else:
            upsampled_tensor = F.interpolate(downsampled_tensor, (h, w), mode=mode, align_corners=(H % 2) == 1)
            laplacian = current_tensor - upsampled_tensor
        pyramid.append(laplacian)
        current_tensor = downsampled_tensor
    if size_align:
        current_tensor = F.interpolate(current_tensor, (H, W), mode=mode, align_corners=(H % 2) == 1)
    pyramid.append(current_tensor)
    return pyramid


class FrequencySelection(nn.Module):
    def __init__(self,
                 in_channels,
                 k_list=[2],
                 lowfreq_att=True,
                 fs_feat='feat',
                 lp_type='freq',
                 act='sigmoid',
                 spatial='conv',
                 spatial_group=1,
                 spatial_kernel=3,
                 init='zero',
                 global_selection=False,
                 ):
        super().__init__()
        self.k_list = k_list
        self.lp_list = nn.ModuleList()
        self.freq_weight_conv_list = nn.ModuleList()
        self.fs_feat = fs_feat
        self.lp_type = lp_type
        self.in_channels = in_channels
        if spatial_group > 64:
            spatial_group = in_channels
        self.spatial_group = spatial_group
        self.lowfreq_att = lowfreq_att
        if spatial == 'conv':
            _n = len(k_list)
            if lowfreq_att:
                _n += 1
            for i in range(_n):
                freq_weight_conv = nn.Conv2d(in_channels=in_channels,
                                             out_channels=self.spatial_group,
                                             stride=1,
                                             kernel_size=spatial_kernel,
                                             groups=self.spatial_group,
                                             padding=spatial_kernel // 2,
                                             bias=True)
                if init == 'zero':
                    freq_weight_conv.weight.data.zero_()
                    freq_weight_conv.bias.data.zero_()
                self.freq_weight_conv_list.append(freq_weight_conv)
        else:
            raise NotImplementedError
        if self.lp_type == 'avgpool':
            for k in k_list:
                self.lp_list.append(nn.Sequential(
                    nn.ReplicationPad2d(padding=k // 2),
                    nn.AvgPool2d(kernel_size=k, padding=0, stride=1)
                ))
        elif self.lp_type == 'laplacian':
            pass
        elif self.lp_type == 'freq':
            pass
        else:
            raise NotImplementedError
        self.act = act
        self.global_selection = global_selection
        if self.global_selection:
            self.global_selection_conv_real = nn.Conv2d(in_channels=in_channels,
                                                        out_channels=self.spatial_group,
                                                        stride=1,
                                                        kernel_size=1,
                                                        groups=self.spatial_group,
                                                        padding=0,
                                                        bias=True)
            self.global_selection_conv_imag = nn.Conv2d(in_channels=in_channels,
                                                        out_channels=self.spatial_group,
                                                        stride=1,
                                                        kernel_size=1,
                                                        groups=self.spatial_group,
                                                        padding=0,
                                                        bias=True)
            if init == 'zero':
                self.global_selection_conv_real.weight.data.zero_()
                self.global_selection_conv_real.bias.data.zero_()
                self.global_selection_conv_imag.weight.data.zero_()
                self.global_selection_conv_imag.bias.data.zero_()

    def sp_act(self, freq_weight):
        if self.act == 'sigmoid':
            freq_weight = freq_weight.sigmoid() * 2
        elif self.act == 'softmax':
            freq_weight = freq_weight.softmax(dim=1) * freq_weight.shape[1]
        else:
            raise NotImplementedError
        return freq_weight

    def forward(self, x, att_feat=None):
        """att_feat: feature used to generate the attention weights."""
        if att_feat is None:
            att_feat = x
        x_list = []
        if self.lp_type == 'avgpool':
            pre_x = x
            b, _, h, w = x.shape
            for idx, avg in enumerate(self.lp_list):
                low_part = avg(x)
                high_part = pre_x - low_part
                pre_x = low_part
                freq_weight = self.freq_weight_conv_list[idx](att_feat)
                freq_weight = self.sp_act(freq_weight)
                tmp = freq_weight.reshape(b, self.spatial_group, -1, h, w) * high_part.reshape(
                    b, self.spatial_group, -1, h, w)
                x_list.append(tmp.reshape(b, -1, h, w))
            if self.lowfreq_att:
                freq_weight = self.freq_weight_conv_list[len(x_list)](att_feat)
                tmp = freq_weight.reshape(b, self.spatial_group, -1, h, w) * pre_x.reshape(
                    b, self.spatial_group, -1, h, w)
                x_list.append(tmp.reshape(b, -1, h, w))
            else:
                x_list.append(pre_x)
        elif self.lp_type == 'laplacian':
            b, _, h, w = x.shape
            pyramids = generate_laplacian_pyramid(x, len(self.k_list), size_align=True)
            for idx, avg in enumerate(self.k_list):
                high_part = pyramids[idx]
                freq_weight = self.freq_weight_conv_list[idx](att_feat)
                freq_weight = self.sp_act(freq_weight)
                tmp = freq_weight.reshape(b, self.spatial_group, -1, h, w) * high_part.reshape(
                    b, self.spatial_group, -1, h, w)
                x_list.append(tmp.reshape(b, -1, h, w))
            if self.lowfreq_att:
                freq_weight = self.freq_weight_conv_list[len(x_list)](att_feat)
                tmp = freq_weight.reshape(b, self.spatial_group, -1, h, w) * pyramids[-1].reshape(
                    b, self.spatial_group, -1, h, w)
                x_list.append(tmp.reshape(b, -1, h, w))
            else:
                x_list.append(pyramids[-1])
        elif self.lp_type == 'freq':
            pre_x = x.clone()
            b, _, h, w = x.shape
            x_fft = torch.fft.fftshift(torch.fft.fft2(x, norm='ortho'))
            if self.global_selection:
                # split the complex spectrum into real and imaginary parts
                x_real = x_fft.real
                x_imag = x_fft.imag
                # global attention over the real part
                global_att_real = self.global_selection_conv_real(x_real)
                global_att_real = self.sp_act(global_att_real).reshape(b, self.spatial_group, -1, h, w)
                # global attention over the imaginary part
                global_att_imag = self.global_selection_conv_imag(x_imag)
                global_att_imag = self.sp_act(global_att_imag).reshape(b, self.spatial_group, -1, h, w)
                # reshape to (b, spatial_group, -1, h, w)
                x_real = x_real.reshape(b, self.spatial_group, -1, h, w)
                x_imag = x_imag.reshape(b, self.spatial_group, -1, h, w)
                # apply the real- and imaginary-part attentions separately
                x_fft_real_updated = x_real * global_att_real
                x_fft_imag_updated = x_imag * global_att_imag
                # recombine into a complex tensor and restore the (b, -1, h, w) layout
                x_fft_updated = torch.complex(x_fft_real_updated, x_fft_imag_updated)
                x_fft = x_fft_updated.reshape(b, -1, h, w)
            for idx, freq in enumerate(self.k_list):
                # a centered box mask keeps frequencies below h/(2*freq), i.e. the low-pass band
                mask = torch.zeros_like(x[:, 0:1, :, :], device=x.device)
                mask[:, :, round(h / 2 - h / (2 * freq)):round(h / 2 + h / (2 * freq)),
                     round(w / 2 - w / (2 * freq)):round(w / 2 + w / (2 * freq))] = 1.0
                low_part = torch.fft.ifft2(torch.fft.ifftshift(x_fft * mask), norm='ortho').real
                high_part = pre_x - low_part
                pre_x = low_part
                freq_weight = self.freq_weight_conv_list[idx](att_feat)
                freq_weight = self.sp_act(freq_weight)
                tmp = freq_weight.reshape(b, self.spatial_group, -1, h, w) * high_part.reshape(
                    b, self.spatial_group, -1, h, w)
                x_list.append(tmp.reshape(b, -1, h, w))
            if self.lowfreq_att:
                freq_weight = self.freq_weight_conv_list[len(x_list)](att_feat)
                tmp = freq_weight.reshape(b, self.spatial_group, -1, h, w) * pre_x.reshape(
                    b, self.spatial_group, -1, h, w)
                x_list.append(tmp.reshape(b, -1, h, w))
            else:
                x_list.append(pre_x)
        x = sum(x_list)
        return x


class AdaptiveDilatedConv(ModulatedDeformConv2d):
    """A ModulatedDeformConv2d encapsulation that acts like a normal Conv layer.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int): Same as nn.Conv2d, while tuple is not supported.
        padding (int): Same as nn.Conv2d, while tuple is not supported.
        dilation (int): Same as nn.Conv2d, while tuple is not supported.
        groups (int): Same as nn.Conv2d.
        bias (bool or str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
            False.
    """

    _version = 2

    def __init__(self, *args,
                 offset_freq=None,  # deprecated
                 padding_mode='repeat',
                 kernel_decompose='both',
                 conv_type='conv',
                 sp_att=False,
                 pre_fs=True,  # False: condition frequency selection on the offsets
                 epsilon=1e-4,
                 use_zero_dilation=False,
                 use_dct=False,
                 fs_cfg={
                     'k_list': [2, 4, 8],
                     'fs_feat': 'feat',
                     'lowfreq_att': False,
                     'lp_type': 'freq',
                     'act': 'sigmoid',
                     'spatial': 'conv',
                     'spatial_group': 1,
                 },
                 **kwargs):
        super().__init__(*args, **kwargs)
        if padding_mode == 'zero':
            self.PAD = nn.ZeroPad2d(self.kernel_size[0] // 2)
        elif padding_mode == 'repeat':
            self.PAD = nn.ReplicationPad2d(self.kernel_size[0] // 2)
        else:
            self.PAD = nn.Identity()
        self.kernel_decompose = kernel_decompose
        self.use_dct = use_dct
        if kernel_decompose == 'both':
            self.OMNI_ATT1 = OmniAttention(in_planes=self.in_channels, out_planes=self.out_channels, kernel_size=1,
                                           groups=1, reduction=0.0625, kernel_num=1, min_channel=16)
            self.OMNI_ATT2 = OmniAttention(in_planes=self.in_channels, out_planes=self.out_channels,
                                           kernel_size=self.kernel_size[0] if self.use_dct else 1, groups=1,
                                           reduction=0.0625, kernel_num=1, min_channel=16)
        elif kernel_decompose == 'high':
            self.OMNI_ATT = OmniAttention(in_planes=self.in_channels, out_planes=self.out_channels, kernel_size=1,
                                          groups=1, reduction=0.0625, kernel_num=1, min_channel=16)
        elif kernel_decompose == 'low':
            self.OMNI_ATT = OmniAttention(in_planes=self.in_channels, out_planes=self.out_channels, kernel_size=1,
                                          groups=1, reduction=0.0625, kernel_num=1, min_channel=16)
        self.conv_type = conv_type
        if conv_type == 'conv':
            self.conv_offset = nn.Conv2d(
                self.in_channels,
                self.deform_groups * 1,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.kernel_size[0] // 2 if isinstance(self.PAD, nn.Identity) else 0,
                dilation=1,
                bias=True)
        else:
            raise NotImplementedError
        self.conv_mask = nn.Conv2d(
            self.in_channels,
            self.deform_groups * 1 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.kernel_size[0] // 2 if isinstance(self.PAD, nn.Identity) else 0,
            dilation=1,
            bias=True)
        if sp_att:
            self.conv_mask_mean_level = nn.Conv2d(
                self.in_channels,
                self.deform_groups * 1,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.kernel_size[0] // 2 if isinstance(self.PAD, nn.Identity) else 0,
                dilation=1,
                bias=True)
        self.offset_freq = offset_freq
        # An offset is laid out as [y0, x0, y1, x1, ..., y8, x8]
        offset = [-1, -1, -1, 0, -1, 1,
                  0, -1, 0, 0, 0, 1,
                  1, -1, 1, 0, 1, 1]
        offset = torch.Tensor(offset)
        self.register_buffer('dilated_offset', torch.Tensor(offset[None, None, ..., None, None]))  # B, G, 18, 1, 1
        if fs_cfg is not None:
            if pre_fs:
                self.FS = FrequencySelection(self.in_channels, **fs_cfg)
            else:
                self.FS = FrequencySelection(1, **fs_cfg)  # conditioned on the predicted dilation
        self.pre_fs = pre_fs
        self.epsilon = epsilon
        self.use_zero_dilation = use_zero_dilation
        self.init_weights()

    def freq_select(self, x):
        if self.offset_freq is None:
            res = x
        elif self.offset_freq in ('FLC_high', 'SLP_high'):
            res = x - self.LP(x)
        elif self.offset_freq in ('FLC_res', 'SLP_res'):
            res = 2 * x - self.LP(x)
        else:
            raise NotImplementedError
        return res

    def init_weights(self):
        super().init_weights()
        if hasattr(self, 'conv_offset'):
            if self.conv_type == 'conv':
                self.conv_offset.weight.data.zero_()
                self.conv_offset.bias.data.fill_((self.dilation[0] - 1) / self.dilation[0] + self.epsilon)
        if hasattr(self, 'conv_mask'):
            self.conv_mask.weight.data.zero_()
            self.conv_mask.bias.data.zero_()
        if hasattr(self, 'conv_mask_mean_level'):
            # note: the upstream code zeroed conv_mask a second time here; zero the level-mask conv instead
            self.conv_mask_mean_level.weight.data.zero_()
            self.conv_mask_mean_level.bias.data.zero_()

    def forward(self, x):
        if hasattr(self, 'FS') and self.pre_fs:
            x = self.FS(x)
        if hasattr(self, 'OMNI_ATT1') and hasattr(self, 'OMNI_ATT2'):
            c_att1, f_att1, _, _ = self.OMNI_ATT1(x)
            c_att2, f_att2, spatial_att2, _ = self.OMNI_ATT2(x)
        elif hasattr(self, 'OMNI_ATT'):
            c_att, f_att, _, _ = self.OMNI_ATT(x)
        if self.conv_type == 'conv':
            offset = self.conv_offset(self.PAD(self.freq_select(x)))
        elif self.conv_type == 'multifreqband':
            offset = self.conv_offset(self.freq_select(x))
        if self.use_zero_dilation:
            offset = (F.relu(offset + 1, inplace=True) - 1) * self.dilation[0]  # allow dilation 0
        else:
            offset = offset.abs() * self.dilation[0]  # ensure > 0
        if hasattr(self, 'FS') and not self.pre_fs:
            x = self.FS(x, F.interpolate(offset, x.shape[-2:], mode='bilinear',
                                         align_corners=(x.shape[-1] % 2) == 1))
        b, _, h, w = offset.shape
        offset = offset.reshape(b, self.deform_groups, -1, h, w) * self.dilated_offset
        offset = offset.reshape(b, -1, h, w)
        x = self.PAD(x)
        mask = self.conv_mask(x)
        mask = mask.sigmoid()
        if hasattr(self, 'conv_mask_mean_level'):
            mask_mean_level = torch.sigmoid(self.conv_mask_mean_level(x)).reshape(b, self.deform_groups, -1, h, w)
            mask = mask * mask_mean_level
            mask = mask.reshape(b, -1, h, w)
        if hasattr(self, 'OMNI_ATT1') and hasattr(self, 'OMNI_ATT2'):
            offset = offset.reshape(1, -1, h, w)
            mask = mask.reshape(1, -1, h, w)
            x = x.reshape(1, -1, x.size(-2), x.size(-1))
            adaptive_weight = self.weight.unsqueeze(0).repeat(b, 1, 1, 1, 1)  # b, c_out, c_in, k, k
            adaptive_weight_mean = adaptive_weight.mean(dim=(-1, -2), keepdim=True)
            adaptive_weight_res = adaptive_weight - adaptive_weight_mean
            b, c_out, c_in, k, k = adaptive_weight.shape
            if self.use_dct:
                dct_coefficients = dct.dct_2d(adaptive_weight_res)
                spatial_att2 = spatial_att2.reshape(b, 1, 1, k, k)
                dct_coefficients = dct_coefficients * (spatial_att2 * 2)
                adaptive_weight_res = dct.idct_2d(dct_coefficients)
            # AdaKern: reweight the low-frequency (mean) and high-frequency (residual) kernel parts separately
            adaptive_weight = (adaptive_weight_mean * (c_att1.unsqueeze(1) * 2) * (f_att1.unsqueeze(2) * 2)
                               + adaptive_weight_res * (c_att2.unsqueeze(1) * 2) * (f_att2.unsqueeze(2) * 2))
            adaptive_weight = adaptive_weight.reshape(-1, self.in_channels // self.groups, 3, 3)
            if self.bias is not None:
                bias = self.bias.repeat(b)
            else:
                bias = self.bias
            x = modulated_deform_conv2d(x, offset, mask, adaptive_weight, bias,
                                        self.stride,
                                        (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
                                        if isinstance(self.PAD, nn.Identity) else (0, 0),  # padding
                                        (1, 1),  # dilation
                                        self.groups * b, self.deform_groups * b)
        elif hasattr(self, 'OMNI_ATT'):
            offset = offset.reshape(1, -1, h, w)
            mask = mask.reshape(1, -1, h, w)
            x = x.reshape(1, -1, x.size(-2), x.size(-1))
            adaptive_weight = self.weight.unsqueeze(0).repeat(b, 1, 1, 1, 1)  # b, c_out, c_in, k, k
            adaptive_weight_mean = adaptive_weight.mean(dim=(-1, -2), keepdim=True)
            if self.kernel_decompose == 'high':
                adaptive_weight = adaptive_weight_mean + (adaptive_weight - adaptive_weight_mean) * (
                    c_att.unsqueeze(1) * 2) * (f_att.unsqueeze(2) * 2)
            elif self.kernel_decompose == 'low':
                adaptive_weight = adaptive_weight_mean * (c_att.unsqueeze(1) * 2) * (f_att.unsqueeze(2) * 2) + (
                    adaptive_weight - adaptive_weight_mean)
            adaptive_weight = adaptive_weight.reshape(-1, self.in_channels // self.groups, 3, 3)
            if self.bias is not None:
                bias = self.bias.repeat(b)
            else:
                bias = self.bias
            x = modulated_deform_conv2d(x, offset, mask, adaptive_weight, bias,
                                        self.stride,
                                        (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
                                        if isinstance(self.PAD, nn.Identity) else (0, 0),  # padding
                                        (1, 1),  # dilation
                                        self.groups * b, self.deform_groups * b)
        else:
            x = modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
                                        self.stride,
                                        (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
                                        if isinstance(self.PAD, nn.Identity) else (0, 0),  # padding
                                        (1, 1),  # dilation
                                        self.groups, self.deform_groups)
        return x.reshape(b, -1, h, w)


class AdaptiveDilatedDWConv(ModulatedDeformConv2d):
    """A depth-wise ModulatedDeformConv2d encapsulation that acts like a normal DWConv layer.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int): Same as nn.Conv2d, while tuple is not supported.
        padding (int): Same as nn.Conv2d, while tuple is not supported.
        dilation (int): Same as nn.Conv2d, while tuple is not supported.
        groups (int): Same as nn.Conv2d.
        bias (bool or str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
            False.
    """

    _version = 2

    def __init__(self, *args,
                 offset_freq=None,
                 use_BFM=False,
                 kernel_decompose='both',
                 padding_mode='repeat',
                 normal_conv_dim=0,
                 pre_fs=True,  # False: condition frequency selection on the offsets
                 fs_cfg={
                     'k_list': [2, 4, 8],
                     'fs_feat': 'feat',
                     'lowfreq_att': False,
                     'lp_type': 'freq',
                     'act': 'sigmoid',
                     'spatial': 'conv',
                     'spatial_group': 1,
                 },
                 **kwargs):
        super().__init__(*args, **kwargs)
        assert self.kernel_size[0] in (3, 7)
        assert self.groups == self.in_channels
        if kernel_decompose == 'both':
            self.OMNI_ATT1 = OmniAttention(in_planes=self.in_channels, out_planes=self.out_channels, kernel_size=1,
                                           groups=self.in_channels, reduction=0.0625, kernel_num=1, min_channel=16)
            self.OMNI_ATT2 = OmniAttention(in_planes=self.in_channels, out_planes=self.out_channels, kernel_size=1,
                                           groups=self.in_channels, reduction=0.0625, kernel_num=1, min_channel=16)
        elif kernel_decompose == 'high':
            self.OMNI_ATT = OmniAttention(in_planes=self.in_channels, out_planes=self.out_channels, kernel_size=1,
                                          groups=self.in_channels, reduction=0.0625, kernel_num=1, min_channel=16)
        elif kernel_decompose == 'low':
            self.OMNI_ATT = OmniAttention(in_planes=self.in_channels, out_planes=self.out_channels, kernel_size=1,
                                          groups=self.in_channels, reduction=0.0625, kernel_num=1, min_channel=16)
        self.kernel_decompose = kernel_decompose
        self.normal_conv_dim = normal_conv_dim
        if padding_mode == 'zero':
            self.PAD = nn.ZeroPad2d(self.kernel_size[0] // 2)
        elif padding_mode == 'repeat':
            self.PAD = nn.ReplicationPad2d(self.kernel_size[0] // 2)
        else:
            self.PAD = nn.Identity()
        self.conv_offset = nn.Conv2d(
            self.in_channels - self.normal_conv_dim,
            self.deform_groups * 1,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding if isinstance(self.PAD, nn.Identity) else 0,
            dilation=1,
            bias=True)
        self.conv_mask = nn.Sequential(
            nn.Conv2d(
                self.in_channels - self.normal_conv_dim,
                self.in_channels - self.normal_conv_dim,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding if isinstance(self.PAD, nn.Identity) else 0,
                groups=self.in_channels - self.normal_conv_dim,
                dilation=1,
                bias=False),
            nn.Conv2d(
                self.in_channels - self.normal_conv_dim,
                self.deform_groups * 1 * self.kernel_size[0] * self.kernel_size[1],
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                dilation=1,
                bias=True)
        )
        self.offset_freq = offset_freq
        # An offset is laid out as [y0, x0, y1, x1, ..., y8, x8]
        if self.kernel_size[0] == 3:
            offset = [-1, -1, -1, 0, -1, 1,
                      0, -1, 0, 0, 0, 1,
                      1, -1, 1, 0, 1, 1]
        elif self.kernel_size[0] == 7:
            offset = [
                -3, -3, -3, -2, -3, -1, -3, 0, -3, 1, -3, 2, -3, 3,
                -2, -3, -2, -2, -2, -1, -2, 0, -2, 1, -2, 2, -2, 3,
                -1, -3, -1, -2, -1, -1, -1, 0, -1, 1, -1, 2, -1, 3,
                0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3,
                1, -3, 1, -2, 1, -1, 1, 0, 1, 1, 1, 2, 1, 3,
                2, -3, 2, -2, 2, -1, 2, 0, 2, 1, 2, 2, 2, 3,
                3, -3, 3, -2, 3, -1, 3, 0, 3, 1, 3, 2, 3, 3,
            ]
        else:
            raise NotImplementedError
        offset = torch.Tensor(offset)
        self.register_buffer('dilated_offset', torch.Tensor(offset[None, None, ..., None, None]))  # B, G, 18/98, 1, 1
        self.init_weights()
        self.use_BFM = use_BFM
        if use_BFM:
            alpha = 8
            BFM = np.zeros((self.in_channels, 1, self.kernel_size[0], self.kernel_size[0]))
            for i in range(self.kernel_size[0]):
                for j in range(self.kernel_size[0]):
                    point_1 = (i, j)
                    point_2 = (self.kernel_size[0] // 2, self.kernel_size[0] // 2)
                    dist = distance.euclidean(point_1, point_2)
                    BFM[:, :, i, j] = alpha / (dist + alpha)
            self.register_buffer('BFM', torch.Tensor(BFM))
        if fs_cfg is not None:
            if pre_fs:
                self.FS = FrequencySelection(self.in_channels - self.normal_conv_dim, **fs_cfg)
            else:
                self.FS = FrequencySelection(1, **fs_cfg)  # conditioned on the predicted dilation
        self.pre_fs = pre_fs

    def freq_select(self, x):
        if self.offset_freq is None:
            pass
        elif self.offset_freq in ('FLC_high', 'SLP_high'):
            x = x - self.LP(x)  # note: the upstream code dropped this assignment
        elif self.offset_freq in ('FLC_res', 'SLP_res'):
            x = 2 * x - self.LP(x)  # note: the upstream code dropped this assignment
        else:
            raise NotImplementedError
        return x

    def init_weights(self):
        super().init_weights()
        if hasattr(self, 'conv_offset'):
            self.conv_offset.weight.data.zero_()
            self.conv_offset.bias.data.fill_((self.dilation[0] - 1) / self.dilation[0] + 1e-4)
        if hasattr(self, 'conv_mask'):
            self.conv_mask[1].weight.data.zero_()
            self.conv_mask[1].bias.data.zero_()

    def forward(self, x):
        if self.normal_conv_dim > 0:
            return self.mix_forward(x)
        else:
            return self.ad_forward(x)

    def ad_forward(self, x):
        if hasattr(self, 'FS') and self.pre_fs:
            x = self.FS(x)
        if hasattr(self, 'OMNI_ATT1') and hasattr(self, 'OMNI_ATT2'):
            c_att1, _, _, _ = self.OMNI_ATT1(x)
            c_att2, _, _, _ = self.OMNI_ATT2(x)
        elif hasattr(self, 'OMNI_ATT'):
            c_att, _, _, _ = self.OMNI_ATT(x)
        x = self.PAD(x)
        offset = self.conv_offset(x)
        offset = F.relu(offset, inplace=True) * self.dilation[0]  # ensure > 0
        if hasattr(self, 'FS') and not self.pre_fs:
            x = self.FS(x, offset)
        b, _, h, w = offset.shape
        offset = offset.reshape(b, self.deform_groups, -1, h, w) * self.dilated_offset
        offset = offset.reshape(b, -1, h, w)
        mask = self.conv_mask(x)
        mask = torch.sigmoid(mask)
        if hasattr(self, 'OMNI_ATT1') and hasattr(self, 'OMNI_ATT2'):
            offset = offset.reshape(1, -1, h, w)
            mask = mask.reshape(1, -1, h, w)
            x = x.reshape(1, -1, x.size(-2), x.size(-1))
            adaptive_weight = self.weight.unsqueeze(0).repeat(b, 1, 1, 1, 1)  # b, out, in, k, k
            adaptive_weight_mean = adaptive_weight.mean(dim=(-1, -2), keepdim=True)
            # AdaKern: reweight the low-frequency (mean) and high-frequency (residual) kernel parts
            adaptive_weight = adaptive_weight_mean * (2 * c_att1.unsqueeze(2)) + (
                adaptive_weight - adaptive_weight_mean) * (2 * c_att2.unsqueeze(2))
            adaptive_weight = adaptive_weight.reshape(-1, self.in_channels // self.groups, 3, 3)
            if self.bias is not None:
                bias = self.bias.repeat(b)
            else:
                bias = self.bias
            x = modulated_deform_conv2d(x, offset, mask, adaptive_weight, bias,
                                        self.stride,
                                        self.padding if isinstance(self.PAD, nn.Identity) else 0,  # padding
                                        (1, 1),  # dilation
                                        self.groups * b, self.deform_groups * b)
            return x.reshape(b, -1, h, w)
        elif hasattr(self, 'OMNI_ATT'):
            offset = offset.reshape(1, -1, h, w)
            mask = mask.reshape(1, -1, h, w)
            x = x.reshape(1, -1, x.size(-2), x.size(-1))
            adaptive_weight = self.weight.unsqueeze(0).repeat(b, 1, 1, 1, 1)  # b, out, in, k, k
            adaptive_weight_mean = adaptive_weight.mean(dim=(-1, -2), keepdim=True)
            if self.kernel_decompose == 'high':
                adaptive_weight = adaptive_weight_mean + (adaptive_weight - adaptive_weight_mean) * (
                    2 * c_att.unsqueeze(2))
            elif self.kernel_decompose == 'low':
                adaptive_weight = adaptive_weight_mean * (2 * c_att.unsqueeze(2)) + (
                    adaptive_weight - adaptive_weight_mean)
            adaptive_weight = adaptive_weight.reshape(-1, self.in_channels // self.groups, 3, 3)
            if self.bias is not None:
                bias = self.bias.repeat(b)
            else:
                bias = self.bias
            x = modulated_deform_conv2d(x, offset, mask, adaptive_weight, bias,
                                        self.stride,
                                        self.padding if isinstance(self.PAD, nn.Identity) else 0,  # padding
                                        (1, 1),  # dilation
                                        self.groups * b, self.deform_groups * b)
            return x.reshape(b, -1, h, w)
        else:
            return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
                                           self.stride,
                                           self.padding if isinstance(self.PAD, nn.Identity) else 0,  # padding
                                           self.dilation, self.groups,
                                           self.deform_groups)

    def mix_forward(self, x):
        if hasattr(self, 'OMNI_ATT1') and hasattr(self, 'OMNI_ATT2'):
            c_att1, _, _, _ = self.OMNI_ATT1(x)
            c_att2, _, _, _ = self.OMNI_ATT2(x)
        elif hasattr(self, 'OMNI_ATT'):
            c_att, _, _, _ = self.OMNI_ATT(x)
        ori_x = x
        normal_conv_x = ori_x[:, -self.normal_conv_dim:]  # channels handled by a normal conv
        x = ori_x[:, :-self.normal_conv_dim]  # channels handled by the adaptive-dilation conv
        if hasattr(self, 'FS') and self.pre_fs:
            x = self.FS(x)
        x = self.PAD(x)
        offset = self.conv_offset(x)
        if hasattr(self, 'FS') and not self.pre_fs:
            x = self.FS(x, F.interpolate(offset, x.shape[-2:], mode='bilinear',
                                         align_corners=(x.shape[-1] % 2) == 1))
        offset[offset < 0] = offset[offset < 0].exp() - 1  # smooth mapping keeps offsets > -1
        b, _, h, w = offset.shape
        offset = offset.reshape(b, self.deform_groups, -1, h, w) * self.dilated_offset
        offset = offset.reshape(b, -1, h, w)
        mask = self.conv_mask(x)
        mask = torch.sigmoid(mask)
        if hasattr(self, 'OMNI_ATT1') and hasattr(self, 'OMNI_ATT2'):
            offset = offset.reshape(1, -1, h, w)
            mask = mask.reshape(1, -1, h, w)
            x = x.reshape(1, -1, x.size(-2), x.size(-1))
            adaptive_weight = self.weight.unsqueeze(0).repeat(b, 1, 1, 1, 1)  # b, out, in, k, k
            adaptive_weight_mean = adaptive_weight.mean(dim=(-1, -2), keepdim=True)
            adaptive_weight = adaptive_weight_mean * (2 * c_att1.unsqueeze(2)) + (
                adaptive_weight - adaptive_weight_mean) * (2 * c_att2.unsqueeze(2))
            if self.bias is not None:
                bias = self.bias.repeat(b)
            else:
                bias = self.bias
            x = modulated_deform_conv2d(
                x, offset, mask,
                adaptive_weight[:, :-self.normal_conv_dim].reshape(
                    -1, self.in_channels // self.groups, self.kernel_size[0], self.kernel_size[1]),
                bias,
                self.stride,
                self.padding if isinstance(self.PAD, nn.Identity) else 0,  # padding
                (1, 1),  # dilation
                (self.in_channels - self.normal_conv_dim) * b, self.deform_groups * b)
            x = x.reshape(b, -1, h, w)
            normal_conv_x = F.conv2d(
                normal_conv_x.reshape(1, -1, h, w),
                adaptive_weight[:, -self.normal_conv_dim:].reshape(
                    -1, self.in_channels // self.groups, self.kernel_size[0], self.kernel_size[1]),
                bias=bias, stride=self.stride, padding=self.padding, dilation=self.dilation,
                groups=self.normal_conv_dim * b)
            normal_conv_x = normal_conv_x.reshape(b, -1, h, w)
            return torch.cat([x, normal_conv_x], dim=1)
        elif hasattr(self, 'OMNI_ATT'):
            offset = offset.reshape(1, -1, h, w)
            mask = mask.reshape(1, -1, h, w)
            x = x.reshape(1, -1, x.size(-2), x.size(-1))
            adaptive_weight = self.weight.unsqueeze(0).repeat(b, 1, 1, 1, 1)  # b, out, in, k, k
            adaptive_weight_mean = adaptive_weight.mean(dim=(-1, -2), keepdim=True)
            if self.bias is not None:
                bias = self.bias.repeat(b)
            else:
                bias = self.bias
            if self.kernel_decompose == 'high':
                adaptive_weight = adaptive_weight_mean + (adaptive_weight - adaptive_weight_mean) * (
                    2 * c_att.unsqueeze(2))
            elif self.kernel_decompose == 'low':
                adaptive_weight = adaptive_weight_mean * (2 * c_att.unsqueeze(2)) + (
                    adaptive_weight - adaptive_weight_mean)
            x = modulated_deform_conv2d(
                x, offset, mask,
                adaptive_weight[:, :-self.normal_conv_dim].reshape(
                    -1, self.in_channels // self.groups, self.kernel_size[0], self.kernel_size[1]),
                bias,
                self.stride,
                self.padding if isinstance(self.PAD, nn.Identity) else 0,  # padding
                (1, 1),  # dilation
                (self.in_channels - self.normal_conv_dim) * b, self.deform_groups * b)
            x = x.reshape(b, -1, h, w)
            normal_conv_x = F.conv2d(
                normal_conv_x.reshape(1, -1, h, w),
                adaptive_weight[:, -self.normal_conv_dim:].reshape(
                    -1, self.in_channels // self.groups, self.kernel_size[0], self.kernel_size[1]),
                bias=bias, stride=self.stride, padding=self.padding, dilation=self.dilation,
                groups=self.normal_conv_dim * b)
            normal_conv_x = normal_conv_x.reshape(b, -1, h, w)
            return torch.cat([x, normal_conv_x], dim=1)
        else:
            return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
                                           self.stride,
                                           self.padding if isinstance(self.PAD, nn.Identity) else 0,  # padding
                                           self.dilation, self.groups,
                                           self.deform_groups)


def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Apply convolution and activation without batch normalization (fused inference path)."""
        return self.act(self.conv(x))


class DFL(nn.Module):
    """
    Integral module of Distribution Focal Loss (DFL).

    Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
    """

    def __init__(self, c1=16):
        """Initialize a convolutional layer with a given number of input channels."""
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        x = torch.arange(c1, dtype=torch.float)
        self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
        self.c1 = c1

    def forward(self, x):
        """Apply the DFL integral (softmax expectation over distance bins) and return decoded distances."""
        b, _, a = x.shape  # batch, channels, anchors
        return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)


class FADDWConvHead(nn.Module):
    """YOLO detection head with AdaptiveDilatedDWConv (frequency-adaptive DWConv) in the classification branch.

    Note: the class is named FADDWConvHead so that it matches the yaml in Section 5; register it in
    tasks.py under this exact name (an earlier draft of this code called it ADDWConvHead).
    """

    dynamic = False  # force grid reconstruction
    export = False  # export mode
    format = None  # export format, set by the exporter
    end2end = False  # end2end
    max_det = 300  # max_det
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=()):
        """Initializes the detection layer with the specified number of classes and channels."""
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        )
        self.cv3 = nn.ModuleList(
            nn.Sequential(
                nn.Sequential(AdaptiveDilatedDWConv(x, x, kernel_size=3, stride=1, dilation=1, groups=x),
                              Conv(x, c3, 1)),
                nn.Sequential(AdaptiveDilatedDWConv(c3, c3, kernel_size=3, stride=1, dilation=1, groups=c3),
                              Conv(c3, c3, 1)),
                nn.Conv2d(c3, self.nc, 1),
            )
            for x in ch
        )
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
        if self.end2end:
            self.one2one_cv2 = copy.deepcopy(self.cv2)
            self.one2one_cv3 = copy.deepcopy(self.cv3)

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        if self.end2end:
            return self.forward_end2end(x)
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:  # Training path
            return x
        y = self._inference(x)
        return y if self.export else (y, x)

    def forward_end2end(self, x):
        """
        Performs forward pass of the end-to-end (NMS-free) head.

        Args:
            x (tensor): Input tensor.

        Returns:
            (dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many
                and one2one detections. If in training mode, returns a dictionary containing the outputs of one2many
                and one2one detections separately.
        """
        x_detach = [xi.detach() for xi in x]
        one2one = [
            torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
        ]
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:  # Training path
            return {"one2many": x, "one2one": one2one}
        y = self._inference(one2one)
        y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
        return y if self.export else (y, {"one2many": x, "one2one": one2one})

    def _inference(self, x):
        """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""
        shape = x[0].shape  # BCHW
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape
        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
            box = x_cat[:, : self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4 :]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
        if self.export and self.format in {"tflite", "edgetpu"}:
            # Precompute normalization factor to increase numerical stability
            # See https://github.com/ultralytics/ultralytics/issues/7371
            grid_h = shape[2]
            grid_w = shape[3]
            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
            norm = self.strides / (self.stride[0] * grid_size)
            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
        else:
            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
        return torch.cat((dbox, cls.sigmoid()), 1)

    def bias_init(self):
        """Initialize detection head biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
        if self.end2end:
            for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride):  # from
                a[-1].bias.data[:] = 1.0  # box
                b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

    def decode_bboxes(self, bboxes, anchors):
        """Decode bounding boxes."""
        return dist2bbox(bboxes, anchors, xywh=not self.end2end, dim=1)

    @staticmethod
    def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
        """
        Post-processes YOLO model predictions.

        Args:
            preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
                format [x, y, w, h, class_probs].
            max_det (int): Maximum detections per image.
            nc (int, optional): Number of classes. Default: 80.

        Returns:
            (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
                dimension format [x, y, w, h, max_class_prob, class_index].
        """
        batch_size, anchors, _ = preds.shape  # i.e. shape(16, 8400, 84)
        boxes, scores = preds.split([4, nc], dim=-1)
        index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
        boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
        scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
        scores, index = scores.flatten(1).topk(min(max_det, anchors))
        i = torch.arange(batch_size)[..., None]  # batch indices
        return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)


if __name__ == "__main__":
    # Generate sample multi-scale feature maps (P3/P4/P5 resolutions)
    image1 = torch.rand(1, 64, 160, 160)
    image2 = torch.rand(1, 128, 80, 80)
    image3 = torch.rand(1, 256, 40, 40)
    image = [image1, image2, image3]
    channel = (64, 128, 256)

    # Model
    head = FADDWConvHead(nc=80, ch=channel)
    out = head(image)
    print(out)
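As a quick standalone check (my own snippet, not part of the module file above), AdaptiveDilatedDWConv can also be exercised on its own as a drop-in for a 3x3 depth-wise convolution. mmcv with the deformable-conv ops and torch_dct must be installed, and depending on your mmcv build a CUDA device may be required:

x = torch.rand(2, 64, 40, 40)
dw = AdaptiveDilatedDWConv(64, 64, kernel_size=3, stride=1, groups=64)
print(dw(x).shape)  # torch.Size([2, 64, 40, 40]) -- same output shape as a plain DWConv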


4. How to Add It

4.1 Modification 1

First, create a new .py file under the 'ultralytics/nn' directory and paste the code above into it; the exact filename is up to you.


4.2 Modification 2

Second, create a new file named '__init__.py' in the same directory (if you are using the files from the group, it already exists and there is no need to create it), then import our detection head inside it, as shown below.
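A sketch of the import line, assuming the file from step 4.1 was named FADWCHead.py (adjust the module name to whatever you chose); if the directory already has an __init__.py, just append this line:

from .FADWCHead import FADDWConvHead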


4.3 Modification 3

Third, open 'ultralytics/nn/tasks.py' and import and register our module there (if you are using the group's files, this is already done; skip straight to step four).
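The import in tasks.py looks like this (again assuming the filename FADWCHead.py from step 4.1):

from ultralytics.nn.FADWCHead import FADDWConvHead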



4.4 Modification 4

Fourth, still in 'ultralytics/nn/tasks.py', find the code below and add the detection head to it. A quick way to locate it: press Ctrl+F and search for 'Detect'.
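For orientation, the relevant branch inside parse_model() looks roughly like the sketch below; the exact set of head classes differs across ultralytics versions, so treat this as an illustration and simply add FADDWConvHead to whatever list your version has:

elif m in {Detect, Segment, Pose, OBB, FADDWConvHead}:
    args.append([ch[x] for x in f])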



4.5 Modification 5

Same as the previous step.


4.6 Modification 6

Same as the previous step.


4.7 Modification 7

This step is slightly different: we need to add one line of code.

else:
    return 'detect'

Why is this one different? Because the variable m here holds the module name lower-cased automatically during execution, a plain else branch is the simplest fix. When future tutorials cover segmentation or other task heads, I will provide the corresponding changes then.
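For context, the task-guessing helper in tasks.py looks roughly like the sketch below (names and branches vary across ultralytics versions, so this is only an illustration). Because m holds the lower-cased name of the last head module in the yaml, the added final else lets a custom head such as FADDWConvHead fall through to 'detect':

def cfg2task(cfg):
    m = cfg["head"][-1][-2].lower()  # name of the output module, lower-cased
    if m in {"classify", "classifier", "cls", "fc"}:
        return "classify"
    if "detect" in m:
        return "detect"
    if m == "segment":
        return "segment"
    if m == "pose":
        return "pose"
    else:
        return "detect"  # the added line: fall back to detection for custom heads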



4.8 Modification 8

Same as the previous step.



That completes all the modifications. You can now copy the yaml file below and run it.


5. YAML File for the FADDWConvHead Detection Head

Training info for this version: YOLO11-FADDWConvHead summary: 446 layers, 2,667,632 parameters, 2,667,616 gradients, 6.6 GFLOPs

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  - [[16, 19, 22], 1, FADDWConvHead, [nc]] # Detect(P3, P4, P5)
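A minimal training sketch: save the config above under any filename (the name 'yolo11-FADDWConvHead.yaml' here is my assumption) and point YOLO at it:

from ultralytics import YOLO

model = YOLO('yolo11-FADDWConvHead.yaml')  # builds the model from the yaml above
model.train(data='coco128.yaml', epochs=100, imgsz=640)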


6. Run Record

Finally, here are screenshots of a successful run.




7. Summary

This concludes the article. I also recommend my YOLOv11 improvement column, newly opened with an average quality score of 98. I will keep reproducing papers from the latest top conferences and backfilling older improvement mechanisms. If this article helped you, please subscribe to the column and follow for future updates.