

YOLOv5 Improvement Series (28): Adding DSConv (ICCV 2023 | Dynamic Snake Convolution for Tubular Structure Segmentation)

🚀 1. Introduction to DSConv


1.1 Overview of DSConv

Background

Tubular structures (e.g., blood vessels, roads) are an important class of structures in clinical, natural, and many other scenarios, and segmenting them precisely underpins the accuracy and efficiency of downstream tasks. However, accurately extracting tubular structures still faces several challenges:

  • Thin and fragile local structure. As shown in the figure below, thin structures occupy only a small fraction of the image and consist of a limited number of pixels. They are also easily disturbed by complex backgrounds, so the model struggles to discriminate subtle changes in the target, which leads to fragmented and broken segmentations.
  • Complex and variable global morphology. As shown in the figure below, thin tubular structures vary widely in shape, even within a single image. The morphology of targets in different regions depends on the number of branches, the locations of bifurcations, the path length, and the position in the image. When the data exhibits morphological features the model has never seen, it tends to overfit the features it has seen and fails to recognize the unseen ones, resulting in weak generalization.

Main contributions of the paper

The paper exploits the slender, continuous nature of tubular structures and uses this prior to strengthen perception at three stages of the network: feature extraction, feature fusion, and the loss constraint. Correspondingly, it designs Dynamic Snake Convolution (DSConv), a multi-view feature fusion strategy, and a continuity topological constraint loss. Both 2D and 3D designs are provided, and the experiments show that the proposed DSCNet delivers better accuracy and continuity on tubular structure segmentation.


1.2 Dynamic Snake Convolution

Goal:

  • The kernel should be free to bend and conform to the structure so it can learn its features
  • At the same time, it should stay constrained so that it does not drift too far away from the target structure

Deformable convolution, by contrast:

  • All offsets that deform a single kernel are learned by the network in one shot
  • The only constraint on these offsets is a range limit, namely the receptive field (extend)
  • How the kernels deform depends solely on the loss gradient propagated back from the end of the network, so the deformation process is essentially unconstrained. (A toy sketch contrasting the two schemes follows this list.)
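To make the contrast concrete, here is a toy sketch of the two offset schemes on a single row of sampling points (the 1-D setting and the variable names are only for illustration, not the paper's code): deformable convolution lets every point shift independently, while DSConv fixes the centre point and accumulates offsets outward point by point, which is exactly the iterative update used later in DSC._coordinate_map_3D.

import torch

k = 9                                    # number of sampling points along one axis
offsets = torch.tanh(torch.randn(k))     # raw offsets squashed to (-1, 1), as DSConv does

base = torch.arange(k, dtype=torch.float)

# Deformable convolution: every point shifts independently within the receptive field.
free_positions = base + offsets

# DSConv: the centre is fixed and each point is chained to its inner neighbour,
# so the kernel can bend like a snake but cannot scatter away from the structure.
snake_positions = base.clone()
center = k // 2
for i in range(1, center):
    snake_positions[center + i] = snake_positions[center + i - 1] + 1 + offsets[center + i]
    snake_positions[center - i] = snake_positions[center - i + 1] - 1 + offsets[center - i]

print(free_positions)
print(snake_positions)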

1.3 Multi-View Feature Fusion Strategy

Goal:

  • A tubular structure is not viewed from a single orientation, so fusing features from multiple views is a natural design choice.

Challenge:

  • Fusing more features increases the network load and introduces redundancy

Approach:

  • A grouped dropout strategy is applied during the feature-fusion training process, which eases memory pressure to some extent and keeps the model from overfitting (a sketch of the idea follows below).
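The post does not show the fusion code itself; as a rough sketch of the idea (the class name, the drop rule, and the mean-based fusion are assumptions made for this illustration, not the paper's implementation), grouped dropout can be pictured as randomly discarding whole feature groups during training before they are fused:

import torch
from torch import nn

class GroupedDropoutFusion(nn.Module):
    """Illustrative only: randomly drop whole feature groups (views) during training."""

    def __init__(self, p=0.3):
        super().__init__()
        self.p = p  # probability of dropping an entire group

    def forward(self, views):                    # views: list of [B, C, H, W] tensors
        if self.training:
            kept = [v for v in views if torch.rand(()) > self.p]
            if not kept:                         # always keep at least one group
                kept = [views[0]]
            views = kept
        return torch.stack(views, 0).mean(0)     # fuse the remaining groups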

1.4 Continuity Topological Constraint Loss

Goal:

  • Build the topological structure of the data and extract the high-dimensional relations inside complex tubular structures, i.e., persistent homology (PH).

Intuition:

  • Suppose there is an abnormal, isolated point near the top of the persistence diagram (the horizontal axis is a feature's birth time, the vertical axis its death time). This indicates a component that only becomes connected to the other components, and thus disappears, at the very end.

Method:

  • The paper adopts the Hausdorff distance (HD), a widely used measure of similarity between point sets that is also very sensitive to outlying points.
# -*- coding: utf-8 -*-
import torch
from torch import nn
from torch.nn.functional import max_pool3d


class crossentry(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y_true, y_pred):
        smooth = 1e-6
        return -torch.mean(y_true * torch.log(y_pred + smooth))


class cross_loss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y_true, y_pred):
        smooth = 1e-6
        return -torch.mean(y_true * torch.log(y_pred + smooth) +
                           (1 - y_true) * torch.log(1 - y_pred + smooth))


'''
Another loss function proposed by us in IEEE Transactions on Image Processing:
Paper: https://ieeexplore.ieee.org/abstract/document/9611074
Code: https://github.com/YaoleiQi/Examinee-Examiner-Network
'''


class Dropoutput_Layer(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y_true, y_pred, alpha=0.4):
        smooth = 1e-6
        w = torch.abs(y_true - y_pred)
        w = torch.round(w + alpha)
        loss_ce = (
            -((torch.sum(w * y_true * torch.log(y_pred + smooth)) /
               torch.sum(w * y_true + smooth)) +
              (torch.sum(w * (1 - y_true) * torch.log(1 - y_pred + smooth)) /
               torch.sum(w * (1 - y_true) + smooth))) / 2)
        return loss_ce
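A minimal usage example of the loss defined above (the shapes are illustrative; y_pred is assumed to already contain probabilities in (0, 1), e.g. after a sigmoid):

criterion = Dropoutput_Layer()
y_pred = torch.sigmoid(torch.randn(2, 1, 64, 64))     # predicted probabilities
y_true = (torch.rand(2, 1, 64, 64) > 0.5).float()     # binary ground truth
loss = criterion(y_true, y_pred)                      # note the (y_true, y_pred) argument order
print(loss.item())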

🚀 2. How to Add DSConv to YOLOv5

2.1 Order of Changes

(1) models/common.py  -->  add the new network module

(2) models/yolo.py  -->  set the argument-parsing details for the new module by adding the DSConv class names to parse_model. (When a custom module takes input/output channel arguments, use gw to scale the output channels.)
(3) models/yolov5*.yaml  -->  create a new config file, e.g. yolov5s_DSConv.yaml, by modifying an existing model structure config. (When a new layer is inserted, the from indices of the later layers must be updated.)
(4) train.py  -->  change the default of the '--cfg' argument so training uses the new model config


2.2 Step-by-Step Instructions

Step ①: Add the DSConv module to common.py

Copy the DSConv code below and paste it at the end of common.py.

# by: 迪菲赫尔曼
import warnings
import torch
from torch import nn

warnings.filterwarnings("ignore")

"""
This code is mainly the deformation process of our DSConv
"""


class DSConv(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, extend_scope, morph,
                 if_offset):
        """
        Dynamic Snake Convolution
        :param in_ch: number of input channels
        :param out_ch: number of output channels
        :param kernel_size: kernel size
        :param extend_scope: expansion scope of the offsets (defaults to 1 for this method)
        :param morph: kernel morphology, either along the x-axis (0) or along the y-axis (1); see the paper for details
        :param if_offset: whether deformation is enabled; if False, this is a standard convolution
        """
        super(DSConv, self).__init__()
        # use the <offset_conv> to learn the deformable offset
        self.offset_conv = nn.Conv2d(in_ch, 2 * kernel_size, 3, padding=1)
        self.bn = nn.BatchNorm2d(2 * kernel_size)
        self.kernel_size = kernel_size
        # two types of the DSConv (along x-axis and y-axis)
        self.dsc_conv_x = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=(kernel_size, 1),
            stride=(kernel_size, 1),
            padding=0,
        )
        self.dsc_conv_y = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=(1, kernel_size),
            stride=(1, kernel_size),
            padding=0,
        )
        # gn: group normalization layer
        self.gn = nn.GroupNorm(out_ch // 4, out_ch)
        self.relu = nn.ReLU(inplace=True)
        # extend_scope: expansion scope of the offsets
        self.extend_scope = extend_scope
        # morph: type of the kernel morphology
        self.morph = morph
        # if_offset: whether deformation is enabled
        self.if_offset = if_offset

    def forward(self, f):
        offset = self.offset_conv(f)
        offset = self.bn(offset)
        # We need a range of deformation between -1 and 1 to mimic the snake's swing
        offset = torch.tanh(offset)
        input_shape = f.shape
        dsc = DSC(input_shape, self.kernel_size, self.extend_scope, self.morph)
        deformed_feature = dsc.deform_conv(f, offset, self.if_offset)
        if self.morph == 0:
            x = self.dsc_conv_x(deformed_feature.type(f.dtype))
            x = self.gn(x)
            x = self.relu(x)
            return x
        else:
            x = self.dsc_conv_y(deformed_feature.type(f.dtype))
            x = self.gn(x)
            x = self.relu(x)
            return x


# Core code; for ease of understanding, the dimensions of input and output are noted next to the code
class DSC(object):
    def __init__(self, input_shape, kernel_size, extend_scope, morph):
        self.num_points = kernel_size
        self.width = input_shape[2]
        self.height = input_shape[3]
        self.morph = morph
        self.extend_scope = extend_scope  # offset (-1 ~ 1) * extend_scope
        # define feature map shape
        """
        B: Batch size  C: Channel  W: Width  H: Height
        """
        self.num_batch = input_shape[0]
        self.num_channels = input_shape[1]

    """
    input: offset [B,2*K,W,H]  K: Kernel size (2*K: 2D image, deformation contains <x_offset> and <y_offset>)
    output_x: [B,1,W,K*H] coordinate map
    output_y: [B,1,K*W,H] coordinate map
    """

    def _coordinate_map_3D(self, offset, if_offset):
        """
        1. Inputs are the offsets (offset) and the deformation flag (if_offset).
        2. Generates 2D coordinate maps from the input feature map shape, kernel size, extend scope and morphology type.
        3. If morph is 0 (along the x-axis), the y coordinate map is built around the centre row; if morph is 1 (along the y-axis), the x coordinate map is built around the centre column.
        4. The coordinate maps are then adjusted according to the offsets and the extend scope.
        5. Returns the generated coordinate maps.
        """
        device = offset.device
        # offset
        y_offset, x_offset = torch.split(offset, self.num_points, dim=1)
        y_center = torch.arange(0, self.width).repeat([self.height])
        y_center = y_center.reshape(self.height, self.width)
        y_center = y_center.permute(1, 0)
        y_center = y_center.reshape([-1, self.width, self.height])
        y_center = y_center.repeat([self.num_points, 1, 1]).float()
        y_center = y_center.unsqueeze(0)
        x_center = torch.arange(0, self.height).repeat([self.width])
        x_center = x_center.reshape(self.width, self.height)
        x_center = x_center.permute(0, 1)
        x_center = x_center.reshape([-1, self.width, self.height])
        x_center = x_center.repeat([self.num_points, 1, 1]).float()
        x_center = x_center.unsqueeze(0)
        if self.morph == 0:
            """
            Initialize the kernel and flatten the kernel
            y: only need 0
            x: -num_points//2 ~ num_points//2 (Determined by the kernel size)
            !!! The related PPT will be submitted later, and the PPT will contain the whole changes of each step
            """
            y = torch.linspace(0, 0, 1)
            x = torch.linspace(
                -int(self.num_points // 2),
                int(self.num_points // 2),
                int(self.num_points),
            )
            y, x = torch.meshgrid(y, x)
            y_spread = y.reshape(-1, 1)
            x_spread = x.reshape(-1, 1)
            y_grid = y_spread.repeat([1, self.width * self.height])
            y_grid = y_grid.reshape([self.num_points, self.width, self.height])
            y_grid = y_grid.unsqueeze(0)  # [B*K*K, W, H]
            x_grid = x_spread.repeat([1, self.width * self.height])
            x_grid = x_grid.reshape([self.num_points, self.width, self.height])
            x_grid = x_grid.unsqueeze(0)  # [B*K*K, W, H]
            y_new = y_center + y_grid
            x_new = x_center + x_grid
            y_new = y_new.repeat(self.num_batch, 1, 1, 1).to(device)
            x_new = x_new.repeat(self.num_batch, 1, 1, 1).to(device)
            y_offset_new = y_offset.detach().clone()
            if if_offset:
                y_offset = y_offset.permute(1, 0, 2, 3)
                y_offset_new = y_offset_new.permute(1, 0, 2, 3)
                center = int(self.num_points // 2)
                # The center position remains unchanged and the rest of the positions begin to swing
                # This part is quite simple. The main idea is that "offset is an iterative process"
                y_offset_new[center] = 0
                for index in range(1, center):
                    y_offset_new[center + index] = (y_offset_new[center + index - 1] + y_offset[center + index])
                    y_offset_new[center - index] = (y_offset_new[center - index + 1] + y_offset[center - index])
                y_offset_new = y_offset_new.permute(1, 0, 2, 3).to(device)
                y_new = y_new.add(y_offset_new.mul(self.extend_scope))
            y_new = y_new.reshape(
                [self.num_batch, self.num_points, 1, self.width, self.height])
            y_new = y_new.permute(0, 3, 1, 4, 2)
            y_new = y_new.reshape([
                self.num_batch, self.num_points * self.width, 1 * self.height
            ])
            x_new = x_new.reshape(
                [self.num_batch, self.num_points, 1, self.width, self.height])
            x_new = x_new.permute(0, 3, 1, 4, 2)
            x_new = x_new.reshape([
                self.num_batch, self.num_points * self.width, 1 * self.height
            ])
            return y_new, x_new
        else:
            """
            Initialize the kernel and flatten the kernel
            y: -num_points//2 ~ num_points//2 (Determined by the kernel size)
            x: only need 0
            """
            y = torch.linspace(
                -int(self.num_points // 2),
                int(self.num_points // 2),
                int(self.num_points),
            )
            x = torch.linspace(0, 0, 1)
            y, x = torch.meshgrid(y, x)
            y_spread = y.reshape(-1, 1)
            x_spread = x.reshape(-1, 1)
            y_grid = y_spread.repeat([1, self.width * self.height])
            y_grid = y_grid.reshape([self.num_points, self.width, self.height])
            y_grid = y_grid.unsqueeze(0)
            x_grid = x_spread.repeat([1, self.width * self.height])
            x_grid = x_grid.reshape([self.num_points, self.width, self.height])
            x_grid = x_grid.unsqueeze(0)
            y_new = y_center + y_grid
            x_new = x_center + x_grid
            y_new = y_new.repeat(self.num_batch, 1, 1, 1)
            x_new = x_new.repeat(self.num_batch, 1, 1, 1)
            y_new = y_new.to(device)
            x_new = x_new.to(device)
            x_offset_new = x_offset.detach().clone()
            if if_offset:
                x_offset = x_offset.permute(1, 0, 2, 3)
                x_offset_new = x_offset_new.permute(1, 0, 2, 3)
                center = int(self.num_points // 2)
                x_offset_new[center] = 0
                for index in range(1, center):
                    x_offset_new[center + index] = (x_offset_new[center + index - 1] + x_offset[center + index])
                    x_offset_new[center - index] = (x_offset_new[center - index + 1] + x_offset[center - index])
                x_offset_new = x_offset_new.permute(1, 0, 2, 3).to(device)
                x_new = x_new.add(x_offset_new.mul(self.extend_scope))
            y_new = y_new.reshape(
                [self.num_batch, 1, self.num_points, self.width, self.height])
            y_new = y_new.permute(0, 3, 1, 4, 2)
            y_new = y_new.reshape([
                self.num_batch, 1 * self.width, self.num_points * self.height
            ])
            x_new = x_new.reshape(
                [self.num_batch, 1, self.num_points, self.width, self.height])
            x_new = x_new.permute(0, 3, 1, 4, 2)
            x_new = x_new.reshape([
                self.num_batch, 1 * self.width, self.num_points * self.height
            ])
            return y_new, x_new

    """
    input: input feature map [N,C,D,W,H]; coordinate map [N,K*D,K*W,K*H]
    output: [N,1,K*D,K*W,K*H] deformed feature map
    """

    def _bilinear_interpolate_3D(self, input_feature, y, x):
        """
        1. Inputs are the input feature map (input_feature), the y coordinate map (y) and the x coordinate map (x).
        2. Samples the deformed features by bilinear interpolation.
        3. Returns the interpolated (deformed) features.
        """
        device = input_feature.device
        y = y.reshape([-1]).float()
        x = x.reshape([-1]).float()
        zero = torch.zeros([]).int()
        max_y = self.width - 1
        max_x = self.height - 1
        # find 8 grid locations
        y0 = torch.floor(y).int()
        y1 = y0 + 1
        x0 = torch.floor(x).int()
        x1 = x0 + 1
        # clip out coordinates exceeding feature map volume
        y0 = torch.clamp(y0, zero, max_y)
        y1 = torch.clamp(y1, zero, max_y)
        x0 = torch.clamp(x0, zero, max_x)
        x1 = torch.clamp(x1, zero, max_x)
        input_feature_flat = input_feature.flatten()
        input_feature_flat = input_feature_flat.reshape(
            self.num_batch, self.num_channels, self.width, self.height)
        input_feature_flat = input_feature_flat.permute(0, 2, 3, 1)
        input_feature_flat = input_feature_flat.reshape(-1, self.num_channels)
        dimension = self.height * self.width
        base = torch.arange(self.num_batch) * dimension
        base = base.reshape([-1, 1]).float()
        repeat = torch.ones([self.num_points * self.width * self.height]).unsqueeze(0)
        repeat = repeat.float()
        base = torch.matmul(base, repeat)
        base = base.reshape([-1])
        base = base.to(device)
        base_y0 = base + y0 * self.height
        base_y1 = base + y1 * self.height
        # top rectangle of the neighbourhood volume
        index_a0 = base_y0 - base + x0
        index_c0 = base_y0 - base + x1
        # bottom rectangle of the neighbourhood volume
        index_a1 = base_y1 - base + x0
        index_c1 = base_y1 - base + x1
        # get 8 grid values
        value_a0 = input_feature_flat[index_a0.type(torch.int64)].to(device)
        value_c0 = input_feature_flat[index_c0.type(torch.int64)].to(device)
        value_a1 = input_feature_flat[index_a1.type(torch.int64)].to(device)
        value_c1 = input_feature_flat[index_c1.type(torch.int64)].to(device)
        # find 8 grid locations
        y0 = torch.floor(y).int()
        y1 = y0 + 1
        x0 = torch.floor(x).int()
        x1 = x0 + 1
        # clip out coordinates exceeding feature map volume
        y0 = torch.clamp(y0, zero, max_y + 1)
        y1 = torch.clamp(y1, zero, max_y + 1)
        x0 = torch.clamp(x0, zero, max_x + 1)
        x1 = torch.clamp(x1, zero, max_x + 1)
        x0_float = x0.float()
        x1_float = x1.float()
        y0_float = y0.float()
        y1_float = y1.float()
        vol_a0 = ((y1_float - y) * (x1_float - x)).unsqueeze(-1).to(device)
        vol_c0 = ((y1_float - y) * (x - x0_float)).unsqueeze(-1).to(device)
        vol_a1 = ((y - y0_float) * (x1_float - x)).unsqueeze(-1).to(device)
        vol_c1 = ((y - y0_float) * (x - x0_float)).unsqueeze(-1).to(device)
        outputs = (value_a0 * vol_a0 + value_c0 * vol_c0 + value_a1 * vol_a1 +
                   value_c1 * vol_c1)
        if self.morph == 0:
            outputs = outputs.reshape([
                self.num_batch,
                self.num_points * self.width,
                1 * self.height,
                self.num_channels,
            ])
            outputs = outputs.permute(0, 3, 1, 2)
        else:
            outputs = outputs.reshape([
                self.num_batch,
                1 * self.width,
                self.num_points * self.height,
                self.num_channels,
            ])
            outputs = outputs.permute(0, 3, 1, 2)
        return outputs

    def deform_conv(self, input, offset, if_offset):
        """
        1. Inputs are the original feature map (input), the offsets (offset) and the deformation flag (if_offset).
        2. Calls _coordinate_map_3D to obtain the coordinate maps.
        3. Calls _bilinear_interpolate_3D to sample the deformed features by bilinear interpolation.
        4. Returns the deformed features.
        """
        y, x = self._coordinate_map_3D(offset, if_offset)
        deformed_feature = self._bilinear_interpolate_3D(input, y, x)
        return deformed_feature


# --------------------------------- YOLOv5-specific part (below) ---------------------------------
class DSConv_Bottleneck(nn.Module):
    # DSConv bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2
        self.snc = DSConv(c2, c2, 3, 1, 1, True)

    def forward(self, x):
        return x + self.snc(self.cv2(self.cv1(x))) if self.add else self.snc(self.cv2(self.cv1(x)))


class DSConv_C3(nn.Module):
    # DSConv Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
        self.m = nn.Sequential(*(DSConv_Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
# --------------------------------- YOLOv5-specific part (above) ---------------------------------
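Before touching the yaml files, you can sanity-check the new modules with a quick shape test (a sketch; it assumes you run it inside the YOLOv5 repo so that Conv from common.py is available):

if __name__ == "__main__":
    x = torch.randn(1, 64, 32, 32)
    m = DSConv(64, 128, kernel_size=3, extend_scope=1, morph=0, if_offset=True)
    print(m(x).shape)     # expected: torch.Size([1, 128, 32, 32])
    c3 = DSConv_C3(64, 64, n=1)
    print(c3(x).shape)    # expected: torch.Size([1, 64, 32, 32])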

Step ②: Modify yolo.py

Next, modify yolo.py: in the parse_model function, find the elif m is nn.BatchNorm2d: statement and add the following code after it:

elif m in (DSConv, DSConv_C3):
    c1, c2 = ch[f], args[0]
    if c2 != nc:
        c2 = make_divisible(c2 * gw, 8)
    args = [c1, c2, *args[1:]]
    if m is DSConv_C3:
        args.insert(2, n)  # number of repeats
        n = 1

As shown in the figure below:
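For reference, here is a worked example of the channel scaling performed by this branch, assuming a yolov5s config (width multiple gw = 0.5) and a yaml entry of [-1, 1, DSConv, [256, 3, 1, 1, True]]:

from utils.general import make_divisible   # YOLOv5 helper used by parse_model

gw = 0.5                                   # width_multiple of yolov5s
c2 = make_divisible(256 * gw, 8)
print(c2)                                  # 128 -> DSConv is built with args [c1, 128, 3, 1, 1, True]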


Step ③: Create a custom yaml file

Option 1: replace layers in the head with DSConv

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.5  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512]],  # 13
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, DSConv, [256, 3, 1, 1, True]],  # 17 (P3/8-small)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, DSConv, [512, 3, 1, 1, True]],  # 20 (P4/16-medium)
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, DSConv, [1024, 3, 1, 1, True]],  # 23 (P5/32-large)
   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

Note one issue here: when swapping in DSConv, its arguments must be adjusted accordingly. The yaml args [256, 3, 1, 1, True] map onto the constructor as out_ch, kernel_size, extend_scope, morph, if_offset (c1 is filled in by parse_model).

An example is shown in the figure below:

If you only change the module name without supplying the extra arguments, you get a missing-argument error:

TypeError: __init__() missing 2 required positional arguments: 'morph' and 'if_offset'


Option 2: replace the C3 modules

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.5  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, DSConv_C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, DSConv_C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, DSConv_C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, DSConv_C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512]],  # 13
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

To replace the C3 modules, simply changing the module name is enough, since DSConv_C3 takes the same yaml arguments as C3.


Step ④: Verify that the module was added successfully

Run yolo.py.
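If you prefer to check programmatically, a small sketch (assuming the new config was saved as models/yolov5s_DSConv.yaml) is to build the model straight from the yaml and run a dummy forward pass:

import torch
from models.yolo import Model

model = Model('models/yolov5s_DSConv.yaml', ch=3, nc=80)  # prints the layer table, including DSConv/DSConv_C3
model.eval()
_ = model(torch.randn(1, 3, 640, 640))                    # forward pass as a smoke test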

Option 1 (DSConv in the head):

Option 2 (DSConv_C3 in the backbone):

And that's it!