
YOLOv11 Improvements - Detection Head Series - Improving the YOLOv11 Segmentation Head with Dynamic Snake Convolution (DySnakeConv) (Exclusive First Release, Segment)

I. Improvements in This Article

The latest improvement presented in this article is a detection head that I have optimized specifically for segmentation. In segmentation, the hardest part is arguably delineating edges. Dynamic Snake Convolution (DySnakeConv) adaptively focuses on slender, winding local structures and accurately captures the features of tubular structures. The core idea of this convolution is to enhance perception with a dynamically shaped kernel, optimizing feature extraction for tubular structures, so fusing it into the YOLOv11 segmentation head is a very good fit. The head in this article also supports object detection, but I designed it mainly for readers working on segmentation.

You are welcome to subscribe to my column and learn YOLO together!



II. Framework and Principles of DySnakeConv

Paper code: official Dynamic Snake Convolution code download link
Paper: [Free] Dynamic Snake Convolution (DynamicSnakeConvolution) resource - CSDN Library

Background -> Dynamic Snake Convolution originates from clinical medicine: clearly delineating blood vessels is a key prerequisite for computational-fluid-dynamics studies and helps radiologists diagnose and localize lesions. In remote sensing, complete road segmentation provides a solid foundation for path planning. In either field, these structures share the common traits of being slender and tortuous, which makes them hard to capture because they occupy only a small fraction of the image. There is therefore an urgent need to improve the perception of thin tubular structures, and it is against this background that the authors proposed Dynamic Snake Convolution.

Principle -> The figure above (from the paper) shows a 3D cardiac vessel dataset and a 2D remote-sensing road dataset. Both datasets aim to extract tubular structures, but the task is challenging because of fragile local structures and complex overall morphology. Standard convolution kernels are designed to extract local features. Building on them, deformable convolution kernels were designed to enrich their applicability and adapt to the geometric deformation of different targets. However, because of the challenges mentioned above, it remains difficult to focus effectively on thin tubular structures.

The task remains challenging because of the following difficulties:

  1. Thin and fragile local structures: as shown in the figure above, thin structures account for only a small fraction of the whole image, and because they consist of a limited number of pixels, they are easily disturbed by complex backgrounds, making it hard for the model to precisely discriminate subtle changes in the target. As a result, the model may struggle to distinguish these structures, and the segmentation results may break apart.

  2. Complex and variable overall morphology: the figures above show the complex and variable morphology of thin tubular structures, even within the same image. Targets in different regions exhibit morphological variation in the number of branches, bifurcation locations, path length, and so on. When the data presents previously unseen morphological structures, the model may overfit the features it has already seen, leading to weak generalization to new morphologies.

To address these obstacles, the following solutions were proposed, comprising a tube-aware convolution kernel, a multi-view feature fusion strategy, and a topological-continuity-constrained loss function. Specifically:

1. To address the challenge that thin, fragile local structures occupy a small proportion of the image and are hard to focus on, Dynamic Snake Convolution (DSConv) is proposed: it adaptively focuses on the slender, curved local features of tubular structures and enhances perception of their geometry. Unlike deformable convolution, DSConv accounts for the snake-like morphology of tubular structures and complements the free learning process with a constraint, enhancing perception of tubular structures in a targeted way (a small sketch of this constrained offset accumulation follows after the summary below).

2. To address the challenge of complex and variable overall morphology, a multi-view feature fusion strategy is proposed. In this method, multiple morphological kernel templates are generated based on DSConv to observe the structural features of the target from different viewpoints, and efficient feature fusion is achieved by summarizing the typical important features.

3. To address the tendency of tubular-structure segmentation to break apart, a topological continuity-constrained loss function (TCLoss) based on Persistent Homology (PH) is proposed. PH captures the process by which topological features appear and disappear, and can obtain sufficient topological information from noisy high-dimensional data; the associated Betti numbers are one way to describe the connectivity of a topological space. Unlike other methods, TCLoss combines PH with point-set similarity to guide the network toward fracture regions with abnormal pixel/voxel distributions, imposing a continuity constraint from a topological perspective.

Summary: to overcome these challenges, the DSCNet framework is proposed, consisting of the tube-aware convolution kernel, the multi-view feature fusion strategy, and the topological-continuity-constrained loss function. DSConv enhances the perception of slender curved features, the multi-view feature fusion strategy improves handling of complex overall morphology, and TCLoss enforces a continuity constraint from the topological perspective based on persistent homology.
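To make the "constrained free learning" of point 1 concrete, here is a minimal, self-contained sketch (my own illustration in plain PyTorch, simplified to a single 1-D kernel of size 9, not the authors' full implementation) of how DSConv accumulates the predicted offsets outward from the kernel centre, so that each sampling point can only move relative to its neighbour and the kernel bends like a snake instead of scattering freely:

import torch

# Simplified illustration of DSConv's cumulative offset constraint (not the paper's full code).
kernel_size = 9
center = kernel_size // 2
raw_offset = torch.tanh(torch.randn(kernel_size))  # per-point offsets squashed to (-1, 1), as in DSConv

acc = torch.zeros(kernel_size)  # accumulated offsets; the centre point stays fixed at 0
for i in range(1, center + 1):
    acc[center + i] = acc[center + i - 1] + raw_offset[center + i]  # grow to one side of the centre
    acc[center - i] = acc[center - i + 1] + raw_offset[center - i]  # grow to the other side

# Each point is displaced relative to its neighbour, so adjacent sampling positions stay connected.
sampling_positions = torch.arange(kernel_size, dtype=torch.float) - center + acc
print(sampling_positions)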

III. Core Code of DySnakeConv

The head below can be used for both segmentation and object detection, but the modification tutorials differ. I will not cover the detection-only usage here; just follow any of my earlier detection-head tutorials, as the procedure is identical except for the name. This article focuses on readers doing segmentation.

import copy
import math

import torch
import torch.nn as nn

from ultralytics.utils.checks import check_version

__all__ = ['DSDConvSegment', 'DSDConvHead']

TORCH_1_10 = check_version(torch.__version__, '1.10.0')


def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchors from features."""
    anchor_points, stride_tensor = [], []
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device
    for i, stride in enumerate(strides):
        _, _, h, w = feats[i].shape
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
        sy, sx = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)


def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform distance (ltrb) to box (xywh or xyxy)."""
    lt, rb = distance.chunk(2, dim)
    x1y1 = anchor_points - lt
    x2y2 = anchor_points + rb
    if xywh:
        c_xy = (x1y1 + x2y2) / 2
        wh = x2y2 - x1y1
        return torch.cat((c_xy, wh), dim)  # xywh bbox
    return torch.cat((x1y1, x2y2), dim)  # xyxy bbox


class DFL(nn.Module):
    """
    Integral module of Distribution Focal Loss (DFL).

    Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
    """

    def __init__(self, c1=16):
        """Initialize a convolutional layer with a given number of input channels."""
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        x = torch.arange(c1, dtype=torch.float)
        self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
        self.c1 = c1

    def forward(self, x):
        """Apply the DFL integral to input tensor 'x' and return the decoded box distances."""
        b, c, a = x.shape  # batch, channels, anchors
        return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
        # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)


class Proto(nn.Module):
    """YOLOv8 mask Proto module for segmentation models."""

    def __init__(self, c1, c_=256, c2=32):
        """
        Initialize the YOLOv8 mask Proto module with the specified number of protos and masks.

        Input arguments are ch_in, number of protos, number of masks.
        """
        super().__init__()
        self.cv1 = Conv(c1, c_, k=3)
        self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True)  # nn.Upsample(scale_factor=2, mode='nearest')
        self.cv2 = Conv(c_, c_, k=3)
        self.cv3 = Conv(c_, c2)

    def forward(self, x):
        """Perform a forward pass through the layers using an upsampled input image."""
        return self.cv3(self.cv2(self.upsample(self.cv1(x))))


def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to the input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Apply convolution and activation without batch normalization (fused inference path)."""
        return self.act(self.conv(x))
class DSConv(nn.Module):
    def __init__(self, in_ch, out_ch, morph, kernel_size=3, if_offset=True, extend_scope=1):
        """
        The Dynamic Snake Convolution.
        :param in_ch: input channels
        :param out_ch: output channels
        :param kernel_size: the size of the kernel
        :param extend_scope: the range to expand (default 1 for this method)
        :param morph: the morphology of the convolution kernel, mainly divided into two types,
                      along the x-axis (0) and along the y-axis (1) (see the paper for details)
        :param if_offset: whether deformation is required; if False, this is a standard convolution kernel
        """
        super(DSConv, self).__init__()
        # use the <offset_conv> to learn the deformable offset
        self.offset_conv = nn.Conv2d(in_ch, 2 * kernel_size, 3, padding=1)
        self.bn = nn.BatchNorm2d(2 * kernel_size)
        self.kernel_size = kernel_size

        # two types of the DSConv (along x-axis and y-axis)
        self.dsc_conv_x = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=(kernel_size, 1),
            stride=(kernel_size, 1),
            padding=0,
        )
        self.dsc_conv_y = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=(1, kernel_size),
            stride=(1, kernel_size),
            padding=0,
        )

        self.gn = nn.GroupNorm(out_ch // 4, out_ch)
        self.act = Conv.default_act

        self.extend_scope = extend_scope
        self.morph = morph
        self.if_offset = if_offset

    def forward(self, f):
        offset = self.offset_conv(f)
        offset = self.bn(offset)
        # we need a range of deformation between -1 and 1 to mimic the snake's swing
        offset = torch.tanh(offset)
        input_shape = f.shape
        dsc = DSC(input_shape, self.kernel_size, self.extend_scope, self.morph)
        deformed_feature = dsc.deform_conv(f, offset, self.if_offset)
        if self.morph == 0:
            x = self.dsc_conv_x(deformed_feature.type(f.dtype))
            x = self.gn(x)
            x = self.act(x)
            return x
        else:
            x = self.dsc_conv_y(deformed_feature.type(f.dtype))
            x = self.gn(x)
            x = self.act(x)
            return x


# Core code. For ease of understanding, the input and output dimensions are noted next to the code.
class DSC(object):
    def __init__(self, input_shape, kernel_size, extend_scope, morph):
        self.num_points = kernel_size
        self.width = input_shape[2]
        self.height = input_shape[3]
        self.morph = morph
        self.extend_scope = extend_scope  # offset (-1 ~ 1) * extend_scope

        # define feature map shape
        """
        B: Batch size  C: Channel  W: Width  H: Height
        """
        self.num_batch = input_shape[0]
        self.num_channels = input_shape[1]

    """
    input: offset [B, 2*K, W, H]  K: kernel size (2*K: 2D image, deformation contains <x_offset> and <y_offset>)
    output_x: [B, 1, W, K*H] coordinate map
    output_y: [B, 1, K*W, H] coordinate map
    """

    def _coordinate_map_3D(self, offset, if_offset):
        device = offset.device
        # offset
        y_offset, x_offset = torch.split(offset, self.num_points, dim=1)

        y_center = torch.arange(0, self.width).repeat([self.height])
        y_center = y_center.reshape(self.height, self.width)
        y_center = y_center.permute(1, 0)
        y_center = y_center.reshape([-1, self.width, self.height])
        y_center = y_center.repeat([self.num_points, 1, 1]).float()
        y_center = y_center.unsqueeze(0)

        x_center = torch.arange(0, self.height).repeat([self.width])
        x_center = x_center.reshape(self.width, self.height)
        x_center = x_center.permute(0, 1)
        x_center = x_center.reshape([-1, self.width, self.height])
        x_center = x_center.repeat([self.num_points, 1, 1]).float()
        x_center = x_center.unsqueeze(0)

        if self.morph == 0:
            """
            Initialize the kernel and flatten the kernel
            y: only need 0
            x: -num_points//2 ~ num_points//2 (determined by the kernel size)
            !!! The related PPT will be submitted later, and the PPT will contain the whole changes of each step
            """
            y = torch.linspace(0, 0, 1)
            x = torch.linspace(
                -int(self.num_points // 2),
                int(self.num_points // 2),
                int(self.num_points),
            )

            y, x = torch.meshgrid(y, x)
            y_spread = y.reshape(-1, 1)
            x_spread = x.reshape(-1, 1)

            y_grid = y_spread.repeat([1, self.width * self.height])
            y_grid = y_grid.reshape([self.num_points, self.width, self.height])
            y_grid = y_grid.unsqueeze(0)  # [B*K*K, W, H]

            x_grid = x_spread.repeat([1, self.width * self.height])
            x_grid = x_grid.reshape([self.num_points, self.width, self.height])
            x_grid = x_grid.unsqueeze(0)  # [B*K*K, W, H]

            y_new = y_center + y_grid
            x_new = x_center + x_grid

            y_new = y_new.repeat(self.num_batch, 1, 1, 1).to(device)
            x_new = x_new.repeat(self.num_batch, 1, 1, 1).to(device)

            y_offset_new = y_offset.detach().clone()

            if if_offset:
                y_offset = y_offset.permute(1, 0, 2, 3)
                y_offset_new = y_offset_new.permute(1, 0, 2, 3)
                center = int(self.num_points // 2)

                # The center position remains unchanged and the rest of the positions begin to swing.
                # This part is quite simple; the main idea is that "the offset is an iterative process".
                y_offset_new[center] = 0
                for index in range(1, center):
                    y_offset_new[center + index] = (y_offset_new[center + index - 1] + y_offset[center + index])
                    y_offset_new[center - index] = (y_offset_new[center - index + 1] + y_offset[center - index])
                y_offset_new = y_offset_new.permute(1, 0, 2, 3).to(device)
                y_new = y_new.add(y_offset_new.mul(self.extend_scope))

            y_new = y_new.reshape(
                [self.num_batch, self.num_points, 1, self.width, self.height])
            y_new = y_new.permute(0, 3, 1, 4, 2)
            y_new = y_new.reshape([
                self.num_batch, self.num_points * self.width, 1 * self.height
            ])
            x_new = x_new.reshape(
                [self.num_batch, self.num_points, 1, self.width, self.height])
            x_new = x_new.permute(0, 3, 1, 4, 2)
            x_new = x_new.reshape([
                self.num_batch, self.num_points * self.width, 1 * self.height
            ])
            return y_new, x_new

        else:
            """
            Initialize the kernel and flatten the kernel
            y: -num_points//2 ~ num_points//2 (determined by the kernel size)
            x: only need 0
            """
            y = torch.linspace(
                -int(self.num_points // 2),
                int(self.num_points // 2),
                int(self.num_points),
            )
            x = torch.linspace(0, 0, 1)

            y, x = torch.meshgrid(y, x)
            y_spread = y.reshape(-1, 1)
            x_spread = x.reshape(-1, 1)

            y_grid = y_spread.repeat([1, self.width * self.height])
            y_grid = y_grid.reshape([self.num_points, self.width, self.height])
            y_grid = y_grid.unsqueeze(0)

            x_grid = x_spread.repeat([1, self.width * self.height])
            x_grid = x_grid.reshape([self.num_points, self.width, self.height])
            x_grid = x_grid.unsqueeze(0)

            y_new = y_center + y_grid
            x_new = x_center + x_grid

            y_new = y_new.repeat(self.num_batch, 1, 1, 1)
            x_new = x_new.repeat(self.num_batch, 1, 1, 1)

            y_new = y_new.to(device)
            x_new = x_new.to(device)
            x_offset_new = x_offset.detach().clone()

            if if_offset:
                x_offset = x_offset.permute(1, 0, 2, 3)
                x_offset_new = x_offset_new.permute(1, 0, 2, 3)
                center = int(self.num_points // 2)
                x_offset_new[center] = 0
                for index in range(1, center):
                    x_offset_new[center + index] = (x_offset_new[center + index - 1] + x_offset[center + index])
                    x_offset_new[center - index] = (x_offset_new[center - index + 1] + x_offset[center - index])
                x_offset_new = x_offset_new.permute(1, 0, 2, 3).to(device)
                x_new = x_new.add(x_offset_new.mul(self.extend_scope))

            y_new = y_new.reshape(
                [self.num_batch, 1, self.num_points, self.width, self.height])
            y_new = y_new.permute(0, 3, 1, 4, 2)
            y_new = y_new.reshape([
                self.num_batch, 1 * self.width, self.num_points * self.height
            ])
            x_new = x_new.reshape(
                [self.num_batch, 1, self.num_points, self.width, self.height])
            x_new = x_new.permute(0, 3, 1, 4, 2)
            x_new = x_new.reshape([
                self.num_batch, 1 * self.width, self.num_points * self.height
            ])
            return y_new, x_new

    """
    input: input feature map [N, C, D, W, H]; coordinate map [N, K*D, K*W, K*H]
    output: [N, 1, K*D, K*W, K*H] deformed feature map
    """

    def _bilinear_interpolate_3D(self, input_feature, y, x):
        device = input_feature.device
        y = y.reshape([-1]).float()
        x = x.reshape([-1]).float()

        zero = torch.zeros([]).int()
        max_y = self.width - 1
        max_x = self.height - 1

        # find the 8 grid locations
        y0 = torch.floor(y).int()
        y1 = y0 + 1
        x0 = torch.floor(x).int()
        x1 = x0 + 1

        # clip coordinates that exceed the feature map volume
        y0 = torch.clamp(y0, zero, max_y)
        y1 = torch.clamp(y1, zero, max_y)
        x0 = torch.clamp(x0, zero, max_x)
        x1 = torch.clamp(x1, zero, max_x)

        input_feature_flat = input_feature.flatten()
        input_feature_flat = input_feature_flat.reshape(
            self.num_batch, self.num_channels, self.width, self.height)
        input_feature_flat = input_feature_flat.permute(0, 2, 3, 1)
        input_feature_flat = input_feature_flat.reshape(-1, self.num_channels)
        dimension = self.height * self.width

        base = torch.arange(self.num_batch) * dimension
        base = base.reshape([-1, 1]).float()

        repeat = torch.ones([self.num_points * self.width * self.height
                             ]).unsqueeze(0)
        repeat = repeat.float()

        base = torch.matmul(base, repeat)
        base = base.reshape([-1])

        base = base.to(device)

        base_y0 = base + y0 * self.height
        base_y1 = base + y1 * self.height

        # top rectangle of the neighbourhood volume
        index_a0 = base_y0 - base + x0
        index_c0 = base_y0 - base + x1

        # bottom rectangle of the neighbourhood volume
        index_a1 = base_y1 - base + x0
        index_c1 = base_y1 - base + x1

        # get the 8 grid values
        value_a0 = input_feature_flat[index_a0.type(torch.int64)].to(device)
        value_c0 = input_feature_flat[index_c0.type(torch.int64)].to(device)
        value_a1 = input_feature_flat[index_a1.type(torch.int64)].to(device)
        value_c1 = input_feature_flat[index_c1.type(torch.int64)].to(device)

        # find the 8 grid locations
        y0 = torch.floor(y).int()
        y1 = y0 + 1
        x0 = torch.floor(x).int()
        x1 = x0 + 1

        # clip coordinates that exceed the feature map volume
        y0 = torch.clamp(y0, zero, max_y + 1)
        y1 = torch.clamp(y1, zero, max_y + 1)
        x0 = torch.clamp(x0, zero, max_x + 1)
        x1 = torch.clamp(x1, zero, max_x + 1)

        x0_float = x0.float()
        x1_float = x1.float()
        y0_float = y0.float()
        y1_float = y1.float()

        vol_a0 = ((y1_float - y) * (x1_float - x)).unsqueeze(-1).to(device)
        vol_c0 = ((y1_float - y) * (x - x0_float)).unsqueeze(-1).to(device)
        vol_a1 = ((y - y0_float) * (x1_float - x)).unsqueeze(-1).to(device)
        vol_c1 = ((y - y0_float) * (x - x0_float)).unsqueeze(-1).to(device)

        outputs = (value_a0 * vol_a0 + value_c0 * vol_c0 + value_a1 * vol_a1 +
                   value_c1 * vol_c1)

        if self.morph == 0:
            outputs = outputs.reshape([
                self.num_batch,
                self.num_points * self.width,
                1 * self.height,
                self.num_channels,
            ])
            outputs = outputs.permute(0, 3, 1, 2)
        else:
            outputs = outputs.reshape([
                self.num_batch,
                1 * self.width,
                self.num_points * self.height,
                self.num_channels,
            ])
            outputs = outputs.permute(0, 3, 1, 2)
        return outputs

    def deform_conv(self, input, offset, if_offset):
        y, x = self._coordinate_map_3D(offset, if_offset)
        deformed_feature = self._bilinear_interpolate_3D(input, y, x)
        return deformed_feature
class DWConv(Conv):
    """Depth-wise convolution."""

    def __init__(self, c1, c2, k=1, s=1, d=1, act=True):  # ch_in, ch_out, kernel, stride, dilation, activation
        """Initialize depth-wise convolution with the given parameters."""
        super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)


class DSDConvHead(nn.Module):
    """YOLO detect head with DSConv for detection models. CSDNSnu77"""

    dynamic = False  # force grid reconstruction
    export = False  # export mode
    end2end = False  # end2end
    max_det = 300  # max_det
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=()):
        """Initialize the YOLO detection layer with the specified number of classes and channels."""
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
        # self.DySnakeConv = nn.ModuleList(nn.Sequential(DSConv(x, x, 0), DSConv(x, x, 0)) for x in ch)  # DySnakeConv
        # morph along the y-axis fits the vast majority of cases.
        self.cv2 = nn.ModuleList(
            nn.Sequential(DSConv(x, c2, 1, 3), DSConv(c2, c2, 1, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        )
        # use only one DSConv to assist bounding-box regression.
        # self.cv2 = nn.ModuleList(
        #     nn.Sequential(DSConv(x, c2, 1, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        # )
        self.cv3 = nn.ModuleList(
            nn.Sequential(
                nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
                nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
                nn.Conv2d(c3, self.nc, 1),
            )
            for x in ch
        )
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

        if self.end2end:
            self.one2one_cv2 = copy.deepcopy(self.cv2)
            self.one2one_cv3 = copy.deepcopy(self.cv3)

    def forward(self, x):
        """Concatenate and return predicted bounding boxes and class probabilities."""
        if self.end2end:
            return self.forward_end2end(x)
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:  # Training path
            return x
        y = self._inference(x)
        return y if self.export else (y, x)

    def forward_end2end(self, x):
        """
        Perform the forward pass of the v10Detect module.

        Args:
            x (tensor): Input tensor.

        Returns:
            (dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and one2one detections.
                If in training mode, returns a dictionary containing the outputs of one2many and one2one detections separately.
        """
        x_detach = [xi.detach() for xi in x]
        one2one = [
            torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
        ]
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:  # Training path
            return {"one2many": x, "one2one": one2one}
        y = self._inference(one2one)
        y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
        return y if self.export else (y, {"one2many": x, "one2one": one2one})

    def _inference(self, x):
        """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""
        # Inference path
        shape = x[0].shape  # BCHW
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
            box = x_cat[:, : self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4 :]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)

        if self.export and self.format in {"tflite", "edgetpu"}:
            # Precompute normalization factor to increase numerical stability
            # See https://github.com/ultralytics/ultralytics/issues/7371
            grid_h = shape[2]
            grid_w = shape[3]
            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
            norm = self.strides / (self.stride[0] * grid_size)
            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
        else:
            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides

        return torch.cat((dbox, cls.sigmoid()), 1)

    def bias_init(self):
        """Initialize Detect() biases. WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
        if self.end2end:
            for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride):  # from
                a[-1].bias.data[:] = 1.0  # box
                b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

    def decode_bboxes(self, bboxes, anchors):
        """Decode bounding boxes."""
        return dist2bbox(bboxes, anchors, xywh=not self.end2end, dim=1)

    @staticmethod
    def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
        """
        Post-process YOLO model predictions.

        Args:
            preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
                format [x, y, w, h, class_probs].
            max_det (int): Maximum detections per image.
            nc (int, optional): Number of classes. Default: 80.

        Returns:
            (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
                dimension format [x, y, w, h, max_class_prob, class_index].
        """
        batch_size, anchors, _ = preds.shape  # i.e. shape(16, 8400, 84)
        boxes, scores = preds.split([4, nc], dim=-1)
        index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
        boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
        scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
        scores, index = scores.flatten(1).topk(min(max_det, anchors))
        i = torch.arange(batch_size)[..., None]  # batch indices
        return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)


class DSDConvSegment(DSDConvHead):
    """YOLO segment head with DSConv for segmentation models."""

    def __init__(self, nc=80, nm=32, npr=256, ch=()):
        """Initialize the model attributes such as the number of masks, prototypes, and the convolution layers."""
        super().__init__(nc, ch)
        self.nm = nm  # number of masks
        self.npr = npr  # number of protos
        self.proto = Proto(ch[0], self.npr, self.nm)  # protos

        c4 = max(ch[0] // 4, self.nm)
        self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)

    def forward(self, x):
        """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
        p = self.proto(x[0])  # mask protos
        bs = p.shape[0]  # batch size

        mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)  # mask coefficients
        x = DSDConvHead.forward(self, x)
        if self.training:
            return x, mc, p
        return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
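A quick sanity check I find useful (my own snippet, appended to the bottom of the same file; the channel and resolution values are example numbers): build the segmentation head on three dummy feature maps and run a forward pass in training mode.

if __name__ == '__main__':
    # Dummy P3/P4/P5 feature maps for a 640x640 input (example channel sizes).
    ch = (64, 128, 256)
    feats = [torch.randn(1, c, s, s) for c, s in zip(ch, (80, 40, 20))]

    head = DSDConvSegment(nc=80, nm=32, npr=256, ch=ch)
    head.train()  # the training path returns (per-level outputs, mask coefficients, prototypes)
    out, mc, p = head(feats)
    print([o.shape for o in out])  # per-level detect outputs, e.g. [1, 144, 80, 80], ...
    print(mc.shape, p.shape)       # mask coefficients [1, 32, 8400] and prototypes [1, 32, 160, 160]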


IV. Segmentation Head Modification Tutorial


4.1 Modification 1

Step 1: create the file. Locate the ultralytics/nn folder and create a new directory inside it named 'Addmodules'. Then create a new .py file inside that directory and copy the core code above into it.


4.2 Modification 2

Step 2: in the same directory, create a new .py file named '__init__.py' and import our detection head inside it, as shown in the figure below.
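As an illustration, if the file created in 4.1 were named, say, DSDConvHead.py (the file name is your choice; it is assumed here), the '__init__.py' only needs to re-export the two heads:

# ultralytics/nn/Addmodules/__init__.py
# (the module name DSDConvHead below is a placeholder for whatever you named the file in 4.1)
from .DSDConvHead import DSDConvSegment, DSDConvHead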


4.3 Modification 3

Step 3: go to the file 'ultralytics/nn/tasks.py' and import and register our module there!

From today onward, the tutorials will all follow this unified format, because I assume everyone is using the files from my group to make the modifications!!
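The import itself is a single line near the top of 'ultralytics/nn/tasks.py', next to the framework's other module imports; a sketch, assuming the Addmodules package created in 4.1/4.2:

# near the top of ultralytics/nn/tasks.py
from ultralytics.nn.Addmodules import DSDConvSegment, DSDConvHead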


4.4 Modification 4

Add the entries as I do. Some of the detection heads shown may not exist in your file; ignore those and use the surrounding code to locate the right spot. A rough sketch of this registration is given after 4.5 below.


4.5 修改五

按照我下面的添加,分割的检测头此处添加两个请注意!
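Since the screenshots are not reproduced here, the following is only a rough sketch of what the relevant branch of parse_model() in tasks.py typically looks like after the edits of 4.4 and 4.5; the exact set of head classes and the line positions vary between ultralytics versions, so treat this as an assumption to check against your own file rather than a drop-in patch:

# fragment of parse_model() in ultralytics/nn/tasks.py (sketch; surrounding code varies by version)
elif m in {Detect, Segment, Pose, OBB, DSDConvHead, DSDConvSegment}:  # 4.4: add both new heads to this set
    args.append([ch[x] for x in f])
    if m in {Segment, DSDConvSegment}:  # 4.5: segmentation-type heads also need the proto-channel scaling
        args[2] = make_divisible(min(args[2], max_channels) * width, 8)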


4.9 Modification 9

Note that this was originally an == comparison and now becomes in. Also be aware that the framework converts m to all lower case here, so the name we add must be written entirely in lower case as well!!!
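Concretely, in the task-guessing helper inside tasks.py the exact-match == check on the head name becomes an in check, and because the framework lower-cases the module name before comparing, the new head must be added in all lower case. A sketch, with the surrounding code assumed from a recent ultralytics version:

# fragment of guess_model_task() / cfg2task() in ultralytics/nn/tasks.py (sketch)
m = cfg["head"][-1][-2].lower()  # the last head module name, lower-cased by the framework
if m in {"segment", "dsdconvsegment"}:  # was: if m == "segment"
    return "segment"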


4.10 Modification 10

Follow my modification. This is the last step; afterwards, copy the yaml file and run it!!!


V. yaml Files for Segmentation and Object Detection

5.1 Segmentation yaml file

Training info: YOLO11-DSConvSegment summary: 374 layers, 2,755,359 parameters, 2,755,343 gradients, 9.5 GFLOPs

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 instance segmentation model with P3-P5 outputs. For usage examples see https://docs.ultralytics.com/tasks/segment

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  - [[16, 19, 22], 1, DSDConvSegment, [nc, 32, 256]] # Segment(P3, P4, P5)

Segmentation training code. Note that segmentation dataset annotations are special; a detection-format dataset will throw an error (see the label-format example after the training script below)!

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO('yolo11-DSConvSegment.yaml')  # to resume training, replace the yaml path here with the path to last.pt; note that if you set 200 epochs and the model has already trained 200 epochs, it cannot be resumed.
    # How to switch model scales: the yaml name above can be changed, e.g. yolov11s.yaml uses the s scale.
    # Likewise, if an improved yaml is named yolov11-XXX.yaml, to use another scale change the name above to yolov11l-XXX.yaml (change the name passed to YOLO, not the config file itself)!
    # model.load('yolov11n.pt')  # whether to load pretrained weights; for research I do not recommend loading them, otherwise it is hard to improve accuracy
    model.train(data=r"C:\Users\Administrator\Desktop\20240521\YOLOv8.2\SpotGEO2YOLO\data.yaml",
                # for other tasks, find 'ultralytics/cfg/default.yaml' and change task there; it can be detect, segment, classify, or pose
                task='segment',
                cache=False,
                imgsz=640,
                epochs=100,
                single_cls=False,  # whether this is single-class detection
                batch=4,
                close_mosaic=0,
                workers=0,
                device='0',
                optimizer='SGD',  # use the SGD optimizer; the default is auto, but I recommend fixing it
                # resume=,  # set True here to resume training
                amp=True,  # if the training loss becomes NaN, you can disable amp
                project='runs/train',
                name='exp',
                )
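As a reminder of why a detection-format dataset fails here: YOLO segmentation labels store one polygon per object instead of a box, so each line of a label .txt looks roughly like the following (made-up example values, coordinates normalized to 0-1):

# class_id x1 y1 x2 y2 x3 y3 ...  (polygon vertices, normalized to image width/height)
0 0.412 0.531 0.430 0.540 0.455 0.562 0.447 0.601 0.418 0.588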

5.2 Object detection yaml file

I did not provide a detection tutorial above; I have already provided many for previous detection heads, so just follow any of them - only the name differs.

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  - [[16, 19, 22], 1, DSDConvHead, [nc]] # Detect(P3, P4, P5)
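For completeness, a minimal detection training sketch (assuming you save the config above as 'yolo11-DSDConvHead.yaml'; both the yaml and dataset file names here are placeholders):

from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO('yolo11-DSDConvHead.yaml')  # placeholder name for the detection yaml above
    model.train(data='your-dataset.yaml',    # placeholder path to your detection data.yaml
                task='detect',
                imgsz=640,
                epochs=100,
                batch=4,
                workers=0,
                device='0')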


VI. Summary

This concludes the formal content of this article. Here I recommend my YOLOv11 effective-improvement column. The column is newly opened with an average quality score of 98; I will keep reproducing papers from the latest top conferences and will also supplement some older improvement mechanisms. If this article helped you, please subscribe to the column and follow the upcoming updates~