学习资源站

YOLOv11改进-Conv篇-手把手教你添加动态蛇形卷积DynamicSnakeConvolution(辅助C3k2进行特征提取)

一、本文介绍

动态蛇形卷积的灵感来源于对管状结构的特殊性的观察和理解,在分割拓扑管状结构、血管和道路等类型的管状结构时,任务的复杂性增加,因为这些结构的局部结构可能非常细长和迂回,而整体形态也可能多变。
因此为了应对这个挑战,作者研究团队注意到了 管状结构的特殊性 ,并提出了动态蛇形卷积(Dynamic Snake Convolution)这个方法。动态蛇形卷积通过自适应地聚焦于细长和迂回的局部结构,准确地捕捉管状结构的特征。这种卷积方法的核心思想是, 通过动态形状的卷积核来增强感知能力,针对管状结构的特征提取进行优化。

总之动态蛇形卷积是一种针对管状结构分割任务的创新方法, 在许多模型上添加针对一些数据集都能够有效的涨点 其具有重要性和广泛的应用领域。


二、动态蛇形卷积背景和原理

论文代码地址: 动态蛇形卷积官方代码下载地址
论文地址: 【免费】动态蛇形卷积(DynamicSnakeConvolution)资源-CSDN文库

背景-> 动态蛇形卷积(Dynamic Snake Convolution)来源于临床医学,清晰勾画血管是计算流体力学研究的关键前提,并能协助放射科医师进行诊断和定位病变。在遥感应用中,完整的道路分割为路径规划提供了坚实的基础。无论是哪个领域,这些结构都具有细长和曲折的共同特征,使得它们很难在图像中捕捉到,因为它们在图像中的比例很小。因此, 迫切需要提升对细长管状结构的感知能力 所以在这一背景下作者提出了动态蛇形卷积(Dynamic Snake Convolution)。

原理-> 上图展示了一个 三维心脏血管数据集 和一个 二维远程道路数据集 。这两个数据集旨在提取管状结构,但由于 脆弱的局部结构和复杂的整体形态 ,这个任务面临着挑战。标准的 卷积核 旨在提取局部特征。基于此,设计了可变形卷积核以丰富它们的应用,并适应不同目标的几何变形。然而,由于前面提到的挑战,有效地聚焦于细小的管状结构是困难的。

由于以下困难,这仍然是一个具有挑战性的任务:

  1. 细小而脆弱的局部结构: 如上面的图所示,细小的结构仅占整体图像的一小部分,并且由于像素组成有限,这些结构容易受到复杂背景的干扰,使模型难以精确地区分目标的细微变化。因此,模型可能难以区分这些结构,导致分割结果出现断裂。

  2. 复杂而多变的整体形态: 上面的图片展示了细小管状结构的复杂和多变形态,即使在同一图像中也如此。不同区域中的目标呈现出形态上的变化,包括分支数量、分叉位置和路径长度等。当数据呈现出前所未见的形态结构时,模型可能会过度拟合已经见过的特征,导致在新的形态结构下泛化能力较弱。

为了应对上述障碍,提出了如下解决方案, 其中包括管状感知卷积核、多视角特征融合策略和拓扑连续性约束损失函数 。具体如下:

1. 针对细小且脆弱的局部结构所占比例小且难以聚焦的挑战 ,提出了动态蛇形卷积,通过自适应地聚焦于管状结构的细长曲线局部特征,增强对几何结构的感知。与可变形卷积不同,DSConv考虑到管状结构的蛇形形态,并通过约束补充自由学习过程,有针对性地增强对管状结构的感知。

2. 针对复杂和多变的整体形态的挑战 ,提出了一种多视角特征融合策略。在该方法中,基于DSConv生成多个形态学卷积核模板,从不同角度观察目标的结构特征,并通过总结典型的重要特征实现高效的特征融合。

3. 针对管状结构分割容易出现断裂的问题 ,提出了基于持久同调(Persistent Homology,PH)的拓扑连续性约束损失函数(TCLoss)。PH是一种从出现到消失的拓扑特征响应过程,能够从嘈杂的高维数据中获取足够的拓扑信息。相关的贝蒂数是描述拓扑空间连通性的一种方式。与其他方法不同, TCLoss将PH与点集相似性相结合 ,引导网络关注具有 异常 像素/体素分布的断裂区域,从拓扑角度实现连续性约束。

总结: 为了克服挑战,提出了DSCNet框架,包括管状感知卷积核、多视角特征融合策略和拓扑连续性约束损失函数。DSConv增强了对细长曲线特征的感知,多视角特征融合策略提高了对复杂整体形态的处理能力,而TCLoss基于持久同调实现了从拓扑角度的连续性约束。


三、动态蛇形卷积的优势

为了提高对管状结构的 性能 ,已经提出了各种方法,根据管状结构的形态设计了特定的网络架构和模块。具体如下:

1. 基于卷积核设计的方法: 著名的扩张卷积(dilated convolution)和可变形卷积(deformable convolution)等方法被提出来处理 卷积神经网络 中固有的几何变换限制,并在复杂的检测和分割任务中取得了出色的表现。这些方法还被设计用于动态感知对象的几何特征,以适应具有可变形态的结构。例如,DUNet。

2. 基于形态学的方法: 一些方法专注于利用形态学信息来处理管状结构。例如,形态学重建网络(Morphological Reconstruction Network)利用形态学重建操作来重建管状结构,从而实现更准确的分割。另外,形态学操作如开运算和闭运算也被广泛应用于处理管状结构。

3. 基于拓扑学的方法: 拓扑学方法被用来处理管状结构的拓扑特征。例如,基于持久同调(Persistent Homology)的方法可以从高维数据中获取拓扑信息,并用于分析管状结构的连通性和形态特征。

总结: 为了处理管状结构,已经提出了多种方法。这些方法包括基于卷积核设计的方法、基于形态学的方法和基于拓扑学的方法。这些方法的目标是通过设计适应管状结构形态的网络架构和模块,提高对管状结构的检测和分割性能。

优势-> 以上所述的方法都只是从单一的角度去分析,DSConv提出了一种多角度特征融合策略,从多个角度补充对重要特征的关注。在这个策略中,基于动态蛇形卷积(DSConv)生成多个形态学卷积核模板,从多个角度观察目标的结构特征,并通过总结关键的标准特征实现特征融合,从而提高我们模型的性能。


四、实验和结果

4.1 数据集

使用了三个数据集来验证我们的框架,其中包括两个公开数据集和一个内部数据集。在2D方面,评估了DRIVE视网膜数据集和马萨诸塞道路数据集。在3D方面,使用了一个名为Cardiac CCTA Data的数据集。


4.2 实验

进行了比较实验和消融研究,以证明DSCN的优势。与经典的分割网络U-Net 和2021年提出的用于血管分割的CS2-Net 进行比较,以验证准确性。为了验证网络设计性能,将2022年提出的用于视网膜血管分割的DCU-Net 进行了比较。为了验证特征融合的优势,将2021年提出的用于医学 图像分割 的Transunet 进行了比较。为了验证损失函数约束,将2021年提出的clDice和基于Wasserstein距离的TCLoss LWTC进行了比较。这些模型在相同的数据集上进行训练,并进行了精确的实现,通过以下指标进行评估。所有指标都是针对每个图像进行计算并求平均。

1. 体积得分: 使用平均Dice系数(Dice)、相对Dice系数(RDice)、中心线Dice(clDice)、准确度(ACC)和AUC来评估结果的性能。
2. 拓扑错误: 计算基于拓扑的得分,包括Betti数β0和β1的Betti错误。同时,为了客观验证冠状动脉分割的连续性,使用直到第一个错误的重叠(OF)来评估提取的中心线的完整性。
3. 距离错误: Hausdorff距离(HD)也被广泛用于描述两组点之间的相似性,推荐用于评估薄管状结构的相似性。


4.3 实验结果

在下面的表格中展示了DSCNet方法在每个指标上的优势,结果表明提出的DSCNet在2D和3D数据集上取得了更好的结果。

在DRIVE数据集上的评估中,DSCNet在分割准确性和拓扑连续性方面均优于其他模型。在下面的表格中,与其他方法相比,DSCNet在体积准确性方面取得了最佳的分割结果,Dice系数为82.06%,RDice系数为90.17%,clDice系数为82.07%,准确度为96.87%,AUC为90.27%。同时,从拓扑的角度来看,与其他方法相比,DSCNet在拓扑连续性上取得了最好的结果,β0错误为0.998,β1错误为0.803。结果显示,DSCNet方法更好地捕捉了薄管状结构的特征,并展现出更准确的分割性能和更连续的 拓扑结构 。正如表格1中第6行到第12行所示,在引入TCLoss后,不同的模型在分割的拓扑连续性方面均有所改善。结果表明,TCLoss能够准确地约束模型关注失去拓扑连续性的薄管状结构。在ROADS数据集上的评估中,DSCNet同样取得了最佳结果。如表格1所示,与其他方法相比,提出的带有TCLoss的DSCNet在分割结果上取得了最佳的效果,Dice系数为78.21%,RDice系数为85.85%,clDice系数为87.64%。与经典的分割网络UNet的结果相比,DSCNet的方法在Dice系数、RDice系数和clDice系数上分别改善了最多1.31%、1.78%和0.77%。结果显示,与其他模型相比,DSCNet的模型在结构复杂且形态多变的道路数据集上也表现良好。

在CORONARY数据集上的评估中,验证了DSCNet在3D薄管状结构分割任务上仍然取得了最佳结果。如下面的表格所示,与其他方法相比,提出的DSCNet在分割结果上取得了最佳的效果,Dice系数为80.27%,RDice系数为86.37%,clDice系数为85.26%。与经典的分割网络UNet的结果相比,DSCNet方法在Dice系数、RDice系数和clDice系数上分别改善了最多3.40%、1.89%和3.83%。同时,使用OF指标来评估分割的连续性。使用DSCNet的方法,LAD的OF指标提升了6.00%,LCX的OF指标提升了3.78%,而RCA的OF指标提升了3.30%


4.4 有效性展示

DSCNet和TCLoss在各个方面都具有决定性的视觉优势。

(1) 为了展示DSCNet的有效性下面的图片中。从左到右,第三到第五列展示了不同网络在分割准确性方面的表现。由于DSConv能够自适应地感知关键特征,DSCNet的方法在分割结果上表现出色。与其他方法相比,DSCNet的方法能够更好地捕捉和保留薄管状结构的细节。

(2) 为了验证DSCNet的TCLoss的有效性,第六列展示了在没有使用TCLoss的情况下的分割结果。可以看出,没有TCLoss的方法在拓扑连续性方面存在明显的问题,而DSCNet的方法能够通过TCLoss准确地约束分割结果的拓扑结构,使得分割结果更加连续。

(3) 在第七列和第八列中,展示了DSCNet在不同数据集上的分割结果。可以看到,DSCNet在DRIVE和ROADS数据集上都能取得准确且连续的分割结果,进一步证明了DSCNet的通用性和鲁棒性。

总的来说,从图6可以清楚地看到我们的DSCNet和TCLoss在分割准确性和拓扑连续性方面的显著优势,这进一步证明了我们方法的有效性和优越性。

DSConv能够动态地适应管状结构的形状,并且注意力能够很好地适配目标。

(1) 适应管状结构的形状。下面的图片中的顶部显示了卷积核的位置和形状。可视化结果显示,DSConv能够很好地适应管状结构并保持形状,而可变形卷积则在目标外部游走。

(2) 关注管状结构的位置。下面的图片的底部显示了给定点的注意力热力图。结果显示,DSConv最亮的区域集中在管状结构上,这表示DSConv对管状结构更加敏感。

这些结果表明,我们的DSConv能够有效地适应和关注管状结构,从而使得DSCNet能够更好地捕捉和分割这些结构。


五、核心代码

使用方式看章节六


  1. import torch
  2. import torch.nn as nn
  3. __all__ = ['C3k2_DSConv']
  4. def autopad(k, p=None, d=1): # kernel, padding, dilation
  5. """Pad to 'same' shape outputs."""
  6. if d > 1:
  7. k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
  8. if p is None:
  9. p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
  10. return p
  11. class Conv(nn.Module):
  12. """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
  13. default_act = nn.SiLU() # default activation
  14. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
  15. """Initialize Conv layer with given arguments including activation."""
  16. super().__init__()
  17. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
  18. self.bn = nn.BatchNorm2d(c2)
  19. self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
  20. def forward(self, x):
  21. """Apply convolution, batch normalization and activation to input tensor."""
  22. return self.act(self.bn(self.conv(x)))
  23. def forward_fuse(self, x):
  24. """Perform transposed convolution of 2D data."""
  25. return self.act(self.conv(x))
  26. class DySnakeConv(nn.Module):
  27. def __init__(self, inc, ouc, k=3) -> None:
  28. super().__init__()
  29. self.conv_0 = Conv(inc, ouc, k)
  30. self.conv_x = DSConv(inc, ouc, 0, k)
  31. self.conv_y = DSConv(inc, ouc, 1, k)
  32. def forward(self, x):
  33. return torch.cat([self.conv_0(x), self.conv_x(x), self.conv_y(x)], dim=1)
  34. class DSConv(nn.Module):
  35. def __init__(self, in_ch, out_ch, morph, kernel_size=3, if_offset=True, extend_scope=1):
  36. """
  37. The Dynamic Snake Convolution
  38. :param in_ch: input channel
  39. :param out_ch: output channel
  40. :param kernel_size: the size of kernel
  41. :param extend_scope: the range to expand (default 1 for this method)
  42. :param morph: the morphology of the convolution kernel is mainly divided into two types
  43. along the x-axis (0) and the y-axis (1) (see the paper for details)
  44. :param if_offset: whether deformation is required, if it is False, it is the standard convolution kernel
  45. """
  46. super(DSConv, self).__init__()
  47. # use the <offset_conv> to learn the deformable offset
  48. self.offset_conv = nn.Conv2d(in_ch, 2 * kernel_size, 3, padding=1)
  49. self.bn = nn.BatchNorm2d(2 * kernel_size)
  50. self.kernel_size = kernel_size
  51. # two types of the DSConv (along x-axis and y-axis)
  52. self.dsc_conv_x = nn.Conv2d(
  53. in_ch,
  54. out_ch,
  55. kernel_size=(kernel_size, 1),
  56. stride=(kernel_size, 1),
  57. padding=0,
  58. )
  59. self.dsc_conv_y = nn.Conv2d(
  60. in_ch,
  61. out_ch,
  62. kernel_size=(1, kernel_size),
  63. stride=(1, kernel_size),
  64. padding=0,
  65. )
  66. self.gn = nn.GroupNorm(out_ch // 4, out_ch)
  67. self.act = Conv.default_act
  68. self.extend_scope = extend_scope
  69. self.morph = morph
  70. self.if_offset = if_offset
  71. def forward(self, f):
  72. offset = self.offset_conv(f)
  73. offset = self.bn(offset)
  74. # We need a range of deformation between -1 and 1 to mimic the snake's swing
  75. offset = torch.tanh(offset)
  76. input_shape = f.shape
  77. dsc = DSC(input_shape, self.kernel_size, self.extend_scope, self.morph)
  78. deformed_feature = dsc.deform_conv(f, offset, self.if_offset)
  79. if self.morph == 0:
  80. x = self.dsc_conv_x(deformed_feature.type(f.dtype))
  81. x = self.gn(x)
  82. x = self.act(x)
  83. return x
  84. else:
  85. x = self.dsc_conv_y(deformed_feature.type(f.dtype))
  86. x = self.gn(x)
  87. x = self.act(x)
  88. return x
  89. # Core code, for ease of understanding, we mark the dimensions of input and output next to the code
  90. class DSC(object):
  91. def __init__(self, input_shape, kernel_size, extend_scope, morph):
  92. self.num_points = kernel_size
  93. self.width = input_shape[2]
  94. self.height = input_shape[3]
  95. self.morph = morph
  96. self.extend_scope = extend_scope # offset (-1 ~ 1) * extend_scope
  97. # define feature map shape
  98. """
  99. B: Batch size C: Channel W: Width H: Height
  100. """
  101. self.num_batch = input_shape[0]
  102. self.num_channels = input_shape[1]
  103. """
  104. input: offset [B,2*K,W,H] K: Kernel size (2*K: 2D image, deformation contains <x_offset> and <y_offset>)
  105. output_x: [B,1,W,K*H] coordinate map
  106. output_y: [B,1,K*W,H] coordinate map
  107. """
  108. def _coordinate_map_3D(self, offset, if_offset):
  109. device = offset.device
  110. # offset
  111. y_offset, x_offset = torch.split(offset, self.num_points, dim=1)
  112. y_center = torch.arange(0, self.width).repeat([self.height])
  113. y_center = y_center.reshape(self.height, self.width)
  114. y_center = y_center.permute(1, 0)
  115. y_center = y_center.reshape([-1, self.width, self.height])
  116. y_center = y_center.repeat([self.num_points, 1, 1]).float()
  117. y_center = y_center.unsqueeze(0)
  118. x_center = torch.arange(0, self.height).repeat([self.width])
  119. x_center = x_center.reshape(self.width, self.height)
  120. x_center = x_center.permute(0, 1)
  121. x_center = x_center.reshape([-1, self.width, self.height])
  122. x_center = x_center.repeat([self.num_points, 1, 1]).float()
  123. x_center = x_center.unsqueeze(0)
  124. if self.morph == 0:
  125. """
  126. Initialize the kernel and flatten the kernel
  127. y: only need 0
  128. x: -num_points//2 ~ num_points//2 (Determined by the kernel size)
  129. !!! The related PPT will be submitted later, and the PPT will contain the whole changes of each step
  130. """
  131. y = torch.linspace(0, 0, 1)
  132. x = torch.linspace(
  133. -int(self.num_points // 2),
  134. int(self.num_points // 2),
  135. int(self.num_points),
  136. )
  137. y, x = torch.meshgrid(y, x)
  138. y_spread = y.reshape(-1, 1)
  139. x_spread = x.reshape(-1, 1)
  140. y_grid = y_spread.repeat([1, self.width * self.height])
  141. y_grid = y_grid.reshape([self.num_points, self.width, self.height])
  142. y_grid = y_grid.unsqueeze(0) # [B*K*K, W,H]
  143. x_grid = x_spread.repeat([1, self.width * self.height])
  144. x_grid = x_grid.reshape([self.num_points, self.width, self.height])
  145. x_grid = x_grid.unsqueeze(0) # [B*K*K, W,H]
  146. y_new = y_center + y_grid
  147. x_new = x_center + x_grid
  148. y_new = y_new.repeat(self.num_batch, 1, 1, 1).to(device)
  149. x_new = x_new.repeat(self.num_batch, 1, 1, 1).to(device)
  150. y_offset_new = y_offset.detach().clone()
  151. if if_offset:
  152. y_offset = y_offset.permute(1, 0, 2, 3)
  153. y_offset_new = y_offset_new.permute(1, 0, 2, 3)
  154. center = int(self.num_points // 2)
  155. # The center position remains unchanged and the rest of the positions begin to swing
  156. # This part is quite simple. The main idea is that "offset is an iterative process"
  157. y_offset_new[center] = 0
  158. for index in range(1, center):
  159. y_offset_new[center + index] = (y_offset_new[center + index - 1] + y_offset[center + index])
  160. y_offset_new[center - index] = (y_offset_new[center - index + 1] + y_offset[center - index])
  161. y_offset_new = y_offset_new.permute(1, 0, 2, 3).to(device)
  162. y_new = y_new.add(y_offset_new.mul(self.extend_scope))
  163. y_new = y_new.reshape(
  164. [self.num_batch, self.num_points, 1, self.width, self.height])
  165. y_new = y_new.permute(0, 3, 1, 4, 2)
  166. y_new = y_new.reshape([
  167. self.num_batch, self.num_points * self.width, 1 * self.height
  168. ])
  169. x_new = x_new.reshape(
  170. [self.num_batch, self.num_points, 1, self.width, self.height])
  171. x_new = x_new.permute(0, 3, 1, 4, 2)
  172. x_new = x_new.reshape([
  173. self.num_batch, self.num_points * self.width, 1 * self.height
  174. ])
  175. return y_new, x_new
  176. else:
  177. """
  178. Initialize the kernel and flatten the kernel
  179. y: -num_points//2 ~ num_points//2 (Determined by the kernel size)
  180. x: only need 0
  181. """
  182. y = torch.linspace(
  183. -int(self.num_points // 2),
  184. int(self.num_points // 2),
  185. int(self.num_points),
  186. )
  187. x = torch.linspace(0, 0, 1)
  188. y, x = torch.meshgrid(y, x)
  189. y_spread = y.reshape(-1, 1)
  190. x_spread = x.reshape(-1, 1)
  191. y_grid = y_spread.repeat([1, self.width * self.height])
  192. y_grid = y_grid.reshape([self.num_points, self.width, self.height])
  193. y_grid = y_grid.unsqueeze(0)
  194. x_grid = x_spread.repeat([1, self.width * self.height])
  195. x_grid = x_grid.reshape([self.num_points, self.width, self.height])
  196. x_grid = x_grid.unsqueeze(0)
  197. y_new = y_center + y_grid
  198. x_new = x_center + x_grid
  199. y_new = y_new.repeat(self.num_batch, 1, 1, 1)
  200. x_new = x_new.repeat(self.num_batch, 1, 1, 1)
  201. y_new = y_new.to(device)
  202. x_new = x_new.to(device)
  203. x_offset_new = x_offset.detach().clone()
  204. if if_offset:
  205. x_offset = x_offset.permute(1, 0, 2, 3)
  206. x_offset_new = x_offset_new.permute(1, 0, 2, 3)
  207. center = int(self.num_points // 2)
  208. x_offset_new[center] = 0
  209. for index in range(1, center):
  210. x_offset_new[center + index] = (x_offset_new[center + index - 1] + x_offset[center + index])
  211. x_offset_new[center - index] = (x_offset_new[center - index + 1] + x_offset[center - index])
  212. x_offset_new = x_offset_new.permute(1, 0, 2, 3).to(device)
  213. x_new = x_new.add(x_offset_new.mul(self.extend_scope))
  214. y_new = y_new.reshape(
  215. [self.num_batch, 1, self.num_points, self.width, self.height])
  216. y_new = y_new.permute(0, 3, 1, 4, 2)
  217. y_new = y_new.reshape([
  218. self.num_batch, 1 * self.width, self.num_points * self.height
  219. ])
  220. x_new = x_new.reshape(
  221. [self.num_batch, 1, self.num_points, self.width, self.height])
  222. x_new = x_new.permute(0, 3, 1, 4, 2)
  223. x_new = x_new.reshape([
  224. self.num_batch, 1 * self.width, self.num_points * self.height
  225. ])
  226. return y_new, x_new
  227. """
  228. input: input feature map [N,C,D,W,H];coordinate map [N,K*D,K*W,K*H]
  229. output: [N,1,K*D,K*W,K*H] deformed feature map
  230. """
  231. def _bilinear_interpolate_3D(self, input_feature, y, x):
  232. device = input_feature.device
  233. y = y.reshape([-1]).float()
  234. x = x.reshape([-1]).float()
  235. zero = torch.zeros([]).int()
  236. max_y = self.width - 1
  237. max_x = self.height - 1
  238. # find 8 grid locations
  239. y0 = torch.floor(y).int()
  240. y1 = y0 + 1
  241. x0 = torch.floor(x).int()
  242. x1 = x0 + 1
  243. # clip out coordinates exceeding feature map volume
  244. y0 = torch.clamp(y0, zero, max_y)
  245. y1 = torch.clamp(y1, zero, max_y)
  246. x0 = torch.clamp(x0, zero, max_x)
  247. x1 = torch.clamp(x1, zero, max_x)
  248. input_feature_flat = input_feature.flatten()
  249. input_feature_flat = input_feature_flat.reshape(
  250. self.num_batch, self.num_channels, self.width, self.height)
  251. input_feature_flat = input_feature_flat.permute(0, 2, 3, 1)
  252. input_feature_flat = input_feature_flat.reshape(-1, self.num_channels)
  253. dimension = self.height * self.width
  254. base = torch.arange(self.num_batch) * dimension
  255. base = base.reshape([-1, 1]).float()
  256. repeat = torch.ones([self.num_points * self.width * self.height
  257. ]).unsqueeze(0)
  258. repeat = repeat.float()
  259. base = torch.matmul(base, repeat)
  260. base = base.reshape([-1])
  261. base = base.to(device)
  262. base_y0 = base + y0 * self.height
  263. base_y1 = base + y1 * self.height
  264. # top rectangle of the neighbourhood volume
  265. index_a0 = base_y0 - base + x0
  266. index_c0 = base_y0 - base + x1
  267. # bottom rectangle of the neighbourhood volume
  268. index_a1 = base_y1 - base + x0
  269. index_c1 = base_y1 - base + x1
  270. # get 8 grid values
  271. value_a0 = input_feature_flat[index_a0.type(torch.int64)].to(device)
  272. value_c0 = input_feature_flat[index_c0.type(torch.int64)].to(device)
  273. value_a1 = input_feature_flat[index_a1.type(torch.int64)].to(device)
  274. value_c1 = input_feature_flat[index_c1.type(torch.int64)].to(device)
  275. # find 8 grid locations
  276. y0 = torch.floor(y).int()
  277. y1 = y0 + 1
  278. x0 = torch.floor(x).int()
  279. x1 = x0 + 1
  280. # clip out coordinates exceeding feature map volume
  281. y0 = torch.clamp(y0, zero, max_y + 1)
  282. y1 = torch.clamp(y1, zero, max_y + 1)
  283. x0 = torch.clamp(x0, zero, max_x + 1)
  284. x1 = torch.clamp(x1, zero, max_x + 1)
  285. x0_float = x0.float()
  286. x1_float = x1.float()
  287. y0_float = y0.float()
  288. y1_float = y1.float()
  289. vol_a0 = ((y1_float - y) * (x1_float - x)).unsqueeze(-1).to(device)
  290. vol_c0 = ((y1_float - y) * (x - x0_float)).unsqueeze(-1).to(device)
  291. vol_a1 = ((y - y0_float) * (x1_float - x)).unsqueeze(-1).to(device)
  292. vol_c1 = ((y - y0_float) * (x - x0_float)).unsqueeze(-1).to(device)
  293. outputs = (value_a0 * vol_a0 + value_c0 * vol_c0 + value_a1 * vol_a1 +
  294. value_c1 * vol_c1)
  295. if self.morph == 0:
  296. outputs = outputs.reshape([
  297. self.num_batch,
  298. self.num_points * self.width,
  299. 1 * self.height,
  300. self.num_channels,
  301. ])
  302. outputs = outputs.permute(0, 3, 1, 2)
  303. else:
  304. outputs = outputs.reshape([
  305. self.num_batch,
  306. 1 * self.width,
  307. self.num_points * self.height,
  308. self.num_channels,
  309. ])
  310. outputs = outputs.permute(0, 3, 1, 2)
  311. return outputs
  312. def deform_conv(self, input, offset, if_offset):
  313. y, x = self._coordinate_map_3D(offset, if_offset)
  314. deformed_feature = self._bilinear_interpolate_3D(input, y, x)
  315. return deformed_feature
  316. class Bottleneck(nn.Module):
  317. """Standard bottleneck."""
  318. def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
  319. """Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, and
  320. expansion.
  321. """
  322. super().__init__()
  323. c_ = int(c2 * e) # hidden channels
  324. self.cv1 = Conv(c1, c_, k[0], 1)
  325. self.cv2 = Conv(c_, c2, k[1], 1, g=g)
  326. self.add = shortcut and c1 == c2
  327. def forward(self, x):
  328. """'forward()' applies the YOLO FPN to input data."""
  329. return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
  330. class Bottleneck_DySnakeConv(Bottleneck):
  331. """Standard bottleneck with DySnakeConv."""
  332. def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, groups, kernels, expand
  333. super().__init__(c1, c2, shortcut, g, k, e)
  334. c_ = int(c2 * e) # hidden channels
  335. self.cv2 = DySnakeConv(c_, c2, k[1])
  336. self.cv3 = Conv(c2 * 3, c2, k=1)
  337. def forward(self, x):
  338. """'forward()' applies the YOLOv5 FPN to input data."""
  339. return x + self.cv3(self.cv2(self.cv1(x))) if self.add else self.cv3(self.cv2(self.cv1(x)))
  340. class C2f(nn.Module):
  341. """Faster Implementation of CSP Bottleneck with 2 convolutions."""
  342. def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
  343. """Initializes a CSP bottleneck with 2 convolutions and n Bottleneck blocks for faster processing."""
  344. super().__init__()
  345. self.c = int(c2 * e) # hidden channels
  346. self.cv1 = Conv(c1, 2 * self.c, 1, 1)
  347. self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2)
  348. self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
  349. def forward(self, x):
  350. """Forward pass through C2f layer."""
  351. y = list(self.cv1(x).chunk(2, 1))
  352. y.extend(m(y[-1]) for m in self.m)
  353. return self.cv2(torch.cat(y, 1))
  354. def forward_split(self, x):
  355. """Forward pass using split() instead of chunk()."""
  356. y = list(self.cv1(x).split((self.c, self.c), 1))
  357. y.extend(m(y[-1]) for m in self.m)
  358. return self.cv2(torch.cat(y, 1))
  359. class C3(nn.Module):
  360. """CSP Bottleneck with 3 convolutions."""
  361. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  362. """Initialize the CSP Bottleneck with given channels, number, shortcut, groups, and expansion values."""
  363. super().__init__()
  364. c_ = int(c2 * e) # hidden channels
  365. self.cv1 = Conv(c1, c_, 1, 1)
  366. self.cv2 = Conv(c1, c_, 1, 1)
  367. self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
  368. self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n)))
  369. def forward(self, x):
  370. """Forward pass through the CSP bottleneck with 2 convolutions."""
  371. return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
  372. class C3k_DSConv(C3):
  373. """C3k is a CSP bottleneck module with customizable kernel sizes for feature extraction in neural networks."""
  374. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, k=3):
  375. """Initializes the C3k module with specified channels, number of layers, and configurations."""
  376. super().__init__(c1, c2, n, shortcut, g, e)
  377. c_ = int(c2 * e) # hidden channels
  378. # self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n)))
  379. self.m = nn.Sequential(*(Bottleneck_DySnakeConv(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n)))
  380. class C3k2_DSConv(C2f):
  381. """Faster Implementation of CSP Bottleneck with 2 convolutions."""
  382. def __init__(self, c1, c2, n=1, c3k=False, e=0.5, g=1, shortcut=True):
  383. """Initializes the C3k2 module, a faster CSP Bottleneck with 2 convolutions and optional C3k blocks."""
  384. super().__init__(c1, c2, n, shortcut, g, e)
  385. self.m = nn.ModuleList(
  386. C3k_DSConv(self.c, self.c, 2, shortcut, g) if c3k else Bottleneck(self.c, self.c, shortcut, g) for _ in range(n)
  387. )
  388. # 在特征提取时用DSConv,在辅助特征融合时换回原先的Bottleneck
  389. if __name__ == "__main__":
  390. # Generating Sample image
  391. image_size = (1, 64, 240, 240)
  392. image = torch.rand(*image_size)
  393. # Model
  394. mobilenet_v1 = C3k2_DSConv(64, 64, c3k=True)
  395. out = mobilenet_v1(image)
  396. print(out.size())

六、需要改动代码的地方


6.1 修改一

第一还是建立文件,我们找到如下 ultralytics /nn文件夹下建立一个目录名字呢就是'Addmodules'文件夹( 用群内的文件的话已经有了无需新建) !然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可。


6.2 修改二

第二步我们在该目录下创建一个新的py文件名字为'__init__.py'( 用群内的文件的话已经有了无需新建) ,然后在其内部导入我们的检测头如下图所示。


6.3 修改三

第三步我门中到如下文件'ultralytics/nn/tasks.py'进行导入和注册我们的模块( 用群内的文件的话已经有了无需重新导入直接开始第四步即可)

从今天开始以后的教程就都统一成这个样子了,因为我默认大家用了我群内的文件来进行修改!!


6.4 修改四

按照我的添加在parse_model里添加即可。


七、DSConv的yaml文件和运行记录

7.1 DSConv的yaml文件

此版本的训练信息:YOLO11-C3k2-DSConv summary: 416 layers, 2,905,271 parameters, 2,905,255 gradients, 6.7 GFLOPs

改进说明在特征提取时用DSConv,在辅助特征融合时换回原先的Bottleneck

  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
  3. # Parameters
  4. nc: 80 # number of classes
  5. scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  6. # [depth, width, max_channels]
  7. n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
  8. s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
  9. m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
  10. l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  11. x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
  12. # YOLOv8.0n backbone
  13. backbone:
  14. # [from, repeats, module, args]
  15. - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  16. - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  17. - [-1, 3, C2f_DSConv, [128, True]]
  18. - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  19. - [-1, 6, C2f_DSConv, [256, True]]
  20. - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  21. - [-1, 6, C2f_DSConv, [512, True]]
  22. - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  23. - [-1, 3, C2f_DSConv, [1024, True]]
  24. - [-1, 1, SPPF, [1024, 5]] # 9
  25. # YOLOv8.0n head
  26. head:
  27. - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  28. - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  29. - [-1, 3, C2f, [512]] # 12
  30. - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  31. - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  32. - [-1, 3, C2f_DSConv, [256]] # 15 (P3/8-small)
  33. - [-1, 1, Conv, [256, 3, 2]]
  34. - [[-1, 12], 1, Concat, [1]] # cat head P4
  35. - [-1, 3, C2f_DSConv, [512]] # 18 (P4/16-medium)
  36. - [-1, 1, Conv, [512, 3, 2]]
  37. - [[-1, 9], 1, Concat, [1]] # cat head P5
  38. - [-1, 3, C2f_DSConv, [1024]] # 21 (P5/32-large)
  39. - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)


7.2 DSConv的训练过程截图

下面是添加了 DSConv 的训练截图。


八、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv11改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~