学习资源站

09-添加SOCA注意力机制_yolov5添加soca注意力机制

YOLOv5改进系列(8)——添加SOCA注意力机制


🚀一、SOCA介绍 

1.1 简介

近年来,深度卷积神经网络(CNN)在单图像超分辨率(SISR)中得到了广泛的研究,并取得了显著的性能。然而,大多数现有的基于CNN的SISR方法主要侧重于更广泛或更深入的架构设计,而忽略了中间层的特征相关性,因此阻碍了CNN的代表能力。为了解决这个问题,论文作者提出了一种二阶注意力网络(SAN),用于更强大的特征表达和特征相关性学习。具体而言,开发了一种新的可训练二阶注意(SOCA)模块通过使用二阶特征统计来自适应地重新调整信道方向特征,以获得更具鉴别性的表示。


1.2 SAN网络

从上图中可以看出SAN的主要由四部分组成:

  • 浅层特征提取(shallow feature extraction)即第一个卷积
  • 非局部增强残差组(NLRG) 提取深度特征(deep feature,DF)
  • 上采样模块(upscale module)
  • 重建模块(reconstruction part)即最后一个卷积

1.3.二阶通道注意力(SOCA) 

以前大多数基于CNN的SR模型都没有考虑功能的相互依赖性。为了利用这些信息,在CNN中引入了SENet,以重新缩放图像SR的信道特征。然而,SENet仅通过全局平均池利用特征的一阶统计,而忽略高于一阶的统计,从而阻碍了网络的辨别能力。另一方面,最近的研究表明特征的二阶统计分布更有利于获得有区分度的表达,如此才诞生了SOCA。

二阶注意力机制(SOCA)能够更好地学习特征之间的联系,此模块通过利用二阶特征的分布自适应的学习特征的内部依赖关系,SOCA的机制是网络能够专注于更有益的信息且能够提高判别学习的能力。此外,原文提出了一种非局部加强残差组结构能进一步结合非局部操作来提取长程的空间上下文信息。通过堆叠非局部残差组,本文的方法能够利用LR图像的信息且能够忽略低频信息。


🚀二、在backbone末端添加SOCA注意力机制方法

2.1 添加顺序 

(1)models/common.py    -->  加入新增的网络结构

(2)     models/yolo.py       -->  设定网络结构的传参细节,将SOCA类名加入其中。(当新的自定义模块中存在输入输出维度时,要使用qw调整输出维度)
(3) models/yolov5*.yaml  -->  新建一个文件夹,如yolov5s_SOCA.yaml,修改现有模型结构配置文件。(当引入新的层时,要修改后续的结构中的from参数)
(4)         train.py                -->  修改‘--cfg’默认参数,训练时指定模型结构配置文件 


2.2 具体添加步骤 

第①步:在common.py中添加SOCA模块

将下面的SOCA代码复制粘贴到common.py文件的末尾

  1. # SOCA moudle 单幅图像超分辨率
  2. from torch.autograd import Function
  3. class Covpool(Function):
  4. @staticmethod
  5. def forward(ctx, input):
  6. x = input
  7. batchSize = x.data.shape[0]
  8. dim = x.data.shape[1]
  9. h = x.data.shape[2]
  10. w = x.data.shape[3]
  11. M = h*w
  12. x = x.reshape(batchSize,dim,M)
  13. I_hat = (-1./M/M)*torch.ones(M,M,device = x.device) + (1./M)*torch.eye(M,M,device = x.device)
  14. I_hat = I_hat.view(1,M,M).repeat(batchSize,1,1).type(x.dtype)
  15. y = x.bmm(I_hat).bmm(x.transpose(1,2))
  16. ctx.save_for_backward(input,I_hat)
  17. return y
  18. @staticmethod
  19. def backward(ctx, grad_output):
  20. input,I_hat = ctx.saved_tensors
  21. x = input
  22. batchSize = x.data.shape[0]
  23. dim = x.data.shape[1]
  24. h = x.data.shape[2]
  25. w = x.data.shape[3]
  26. M = h*w
  27. x = x.reshape(batchSize,dim,M)
  28. grad_input = grad_output + grad_output.transpose(1,2)
  29. grad_input = grad_input.bmm(x).bmm(I_hat)
  30. grad_input = grad_input.reshape(batchSize,dim,h,w)
  31. return grad_input
  32. class Sqrtm(Function):
  33. @staticmethod
  34. def forward(ctx, input, iterN):
  35. x = input
  36. batchSize = x.data.shape[0]
  37. dim = x.data.shape[1]
  38. dtype = x.dtype
  39. I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)
  40. normA = (1.0/3.0)*x.mul(I3).sum(dim=1).sum(dim=1)
  41. A = x.div(normA.view(batchSize,1,1).expand_as(x))
  42. Y = torch.zeros(batchSize, iterN, dim, dim, requires_grad = False, device = x.device)
  43. Z = torch.eye(dim,dim,device = x.device).view(1,dim,dim).repeat(batchSize,iterN,1,1)
  44. if iterN < 2:
  45. ZY = 0.5*(I3 - A)
  46. Y[:,0,:,:] = A.bmm(ZY)
  47. else:
  48. ZY = 0.5*(I3 - A)
  49. Y[:,0,:,:] = A.bmm(ZY)
  50. Z[:,0,:,:] = ZY
  51. for i in range(1, iterN-1):
  52. ZY = 0.5*(I3 - Z[:,i-1,:,:].bmm(Y[:,i-1,:,:]))
  53. Y[:,i,:,:] = Y[:,i-1,:,:].bmm(ZY)
  54. Z[:,i,:,:] = ZY.bmm(Z[:,i-1,:,:])
  55. ZY = 0.5*Y[:,iterN-2,:,:].bmm(I3 - Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:]))
  56. y = ZY*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
  57. ctx.save_for_backward(input, A, ZY, normA, Y, Z)
  58. ctx.iterN = iterN
  59. return y
  60. @staticmethod
  61. def backward(ctx, grad_output, der_sacleTrace=None):
  62. input, A, ZY, normA, Y, Z = ctx.saved_tensors
  63. iterN = ctx.iterN
  64. x = input
  65. batchSize = x.data.shape[0]
  66. dim = x.data.shape[1]
  67. dtype = x.dtype
  68. der_postCom = grad_output*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
  69. der_postComAux = (grad_output*ZY).sum(dim=1).sum(dim=1).div(2*torch.sqrt(normA))
  70. I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)
  71. if iterN < 2:
  72. der_NSiter = 0.5*(der_postCom.bmm(I3 - A) - A.bmm(der_sacleTrace))
  73. else:
  74. dldY = 0.5*(der_postCom.bmm(I3 - Y[:,iterN-2,:,:].bmm(Z[:,iterN-2,:,:])) -
  75. Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:]).bmm(der_postCom))
  76. dldZ = -0.5*Y[:,iterN-2,:,:].bmm(der_postCom).bmm(Y[:,iterN-2,:,:])
  77. for i in range(iterN-3, -1, -1):
  78. YZ = I3 - Y[:,i,:,:].bmm(Z[:,i,:,:])
  79. ZY = Z[:,i,:,:].bmm(Y[:,i,:,:])
  80. dldY_ = 0.5*(dldY.bmm(YZ) -
  81. Z[:,i,:,:].bmm(dldZ).bmm(Z[:,i,:,:]) -
  82. ZY.bmm(dldY))
  83. dldZ_ = 0.5*(YZ.bmm(dldZ) -
  84. Y[:,i,:,:].bmm(dldY).bmm(Y[:,i,:,:]) -
  85. dldZ.bmm(ZY))
  86. dldY = dldY_
  87. dldZ = dldZ_
  88. der_NSiter = 0.5*(dldY.bmm(I3 - A) - dldZ - A.bmm(dldY))
  89. grad_input = der_NSiter.div(normA.view(batchSize,1,1).expand_as(x))
  90. grad_aux = der_NSiter.mul(x).sum(dim=1).sum(dim=1)
  91. for i in range(batchSize):
  92. grad_input[i,:,:] += (der_postComAux[i] \
  93. - grad_aux[i] / (normA[i] * normA[i])) \
  94. *torch.ones(dim,device = x.device).diag()
  95. return grad_input, None
  96. def CovpoolLayer(var):
  97. return Covpool.apply(var)
  98. def SqrtmLayer(var, iterN):
  99. return Sqrtm.apply(var, iterN)
  100. class SOCA(nn.Module):
  101. # second-order Channel attention
  102. def __init__(self, channel, reduction=8):
  103. super(SOCA, self).__init__()
  104. self.max_pool = nn.MaxPool2d(kernel_size=2)
  105. self.conv_du = nn.Sequential(
  106. nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
  107. nn.ReLU(inplace=True),
  108. nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
  109. nn.Sigmoid()
  110. )
  111. def forward(self, x):
  112. batch_size, C, h, w = x.shape # x: NxCxHxW
  113. N = int(h * w)
  114. min_h = min(h, w)
  115. h1 = 1000
  116. w1 = 1000
  117. if h < h1 and w < w1:
  118. x_sub = x
  119. elif h < h1 and w > w1:
  120. W = (w - w1) // 2
  121. x_sub = x[:, :, :, W:(W + w1)]
  122. elif w < w1 and h > h1:
  123. H = (h - h1) // 2
  124. x_sub = x[:, :, H:H + h1, :]
  125. else:
  126. H = (h - h1) // 2
  127. W = (w - w1) // 2
  128. x_sub = x[:, :, H:(H + h1), W:(W + w1)]
  129. cov_mat = CovpoolLayer(x_sub) # Global Covariance pooling layer
  130. cov_mat_sqrt = SqrtmLayer(cov_mat,5) # Matrix square root layer( including pre-norm,Newton-Schulz iter. and post-com. with 5 iteration)
  131. cov_mat_sum = torch.mean(cov_mat_sqrt,1)
  132. cov_mat_sum = cov_mat_sum.view(batch_size,C,1,1)
  133. y_cov = self.conv_du(cov_mat_sum)
  134. return y_cov*x

如下图所示:


第②步:在yolo.py文件里的parse_model函数加入类名

首先找到yolo.py里面parse_model函数的这一行

 然后把刚才加入的类SOCA添加到这个注册表里面:

 或者可以在下面的位置这样加,原理和上面是一样的:

  1. elif m is SOCA:
  2. c1, c2 = ch[f], args[0]
  3. if c2 != no:
  4. c2 = make_divisible(c2 * gw, 8)
  5. args = [c1, *args[1:]]

解释一下这段代码:

这段是一个判断语句,如果模块 m 在SOCA中,那么就将模块m对应的输入通道数输出通道数的值分别赋值给 c1c2,然后对 c2进行与之前相同的处理,接下来,将 c1c2 以及 args[1:] 作为元素,组成新的列表,作为更新后的 args


第③步:创建自定义的yaml文件 

首先在models文件夹下复制yolov5s.yaml 文件,粘贴并重命名为 yolov5s_SOCA.yaml 

接着修改  yolov5s_SOCA.yaml ,将SOCA模块加到我们想添加的位置。

这里我先介绍第一种,第一种是将SOCA模块放在backbone部分的最末端这样可以使注意力机制看到整个backbone部分的特征图,将具有全局视野,类似于一个小transformer结构。

[-1,1,SOCA,[1024]]添加到 SPPF 的下一层。即下图中所示位置: 

 同样的下面的head也得修改:

这里我们要把后面两个Concat的from系数分别由[ − 1 , 14 ] , [ − 1 , 10 ]改为[ − 1 , 15 ][ − 1 , 11 ]

然后将Detect原始的from系数[ 17 , 20 , 23 ]要改为[ 18 , 21 , 24 ] 。

  yolov5s_SOCA.yaml 完整代码:

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. # Parameters
  3. nc: 20 # number of classes
  4. depth_multiple: 0.33 # model depth multiple
  5. width_multiple: 0.50 # layer channel multiple
  6. anchors:
  7. - [10,13, 16,30, 33,23] # P3/8
  8. - [30,61, 62,45, 59,119] # P4/16
  9. - [116,90, 156,198, 373,326] # P5/32
  10. # YOLOv5 v6.0 backbone+SE
  11. backbone:
  12. # [from, number, module, args]
  13. [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
  14. [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
  15. [-1, 3, C3, [128]],
  16. [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
  17. [-1, 6, C3, [256]],
  18. [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
  19. [-1, 9, C3, [512]],
  20. [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
  21. [-1, 3, C3, [1024]],
  22. [-1, 1, SPPF, [1024, 5]], # 10
  23. [-1, 1, SOCA,[1024]],
  24. ]
  25. # YOLOv5 v6.1 head
  26. head:
  27. [[-1, 1, Conv, [512, 1, 1]],
  28. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  29. [[-1, 6], 1, Concat, [1]], # cat backbone P4
  30. [-1, 3, C3, [512, False]], # 14
  31. [-1, 1, Conv, [256, 1, 1]],
  32. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  33. [[-1, 4], 1, Concat, [1]], # cat backbone P3
  34. [-1, 3, C3, [256, False]], # 18 (P3/8-small)
  35. [-1, 1, Conv, [256, 3, 2]],
  36. [[-1, 15], 1, Concat, [1]], # cat head P4
  37. [-1, 3, C3, [512, False]], # 21 (P4/16-medium)
  38. [-1, 1, Conv, [512, 3, 2]],
  39. [[-1, 11], 1, Concat, [1]], # cat head P5
  40. [-1, 3, C3, [1024, False]], # 24 (P5/32-large)
  41. [[18, 21, 24], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  42. ]

第④步:验证是否加入成功

yolo.py 文件里面配置改为我们刚才自定义的yolov5s_SOCA.yaml

 然后我们运行yolo.py 

 这样就添加成功啦~


第⑤步:修改train.py中 ‘--cfg’默认参数

我们先找到 train.py 文件的parse_opt函数,然后将第二行‘--cfg’的 default改为yolov5s_SOCA.yaml,然后就可以开始训练啦~


🚀三、在C3后添加SOCA注意力机制方法

第二种是将SOCA放在backbone部分每个C3模块的后面,这样可以使注意力机制看到局部的特征,每层进行一次注意力,可以分担学习压力。

步骤和方法1相同,只是yaml文件不同。

所以接下来只放修改yaml文件的部分~


第③步:创建自定义的yaml文件 

SOCA模块放在每个C3模块的后面,要注意通道的变化。

如下图所示:

 同样的,下面的head部分也要做相应的修改:

 第二种方法的 yolov5s_SOCA.yaml 完整代码:

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. # Parameters
  3. nc: 80 # number of classes
  4. depth_multiple: 0.33 # model depth multiple
  5. width_multiple: 0.50 # layer channel multiple
  6. anchors:
  7. - [10,13, 16,30, 33,23] # P3/8
  8. - [30,61, 62,45, 59,119] # P4/16
  9. - [116,90, 156,198, 373,326] # P5/32
  10. # YOLOv5 v6.0 backbone
  11. backbone:
  12. # [from, number, module, args]
  13. [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
  14. [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
  15. [-1, 3, C3, [128]],
  16. [-1, 3, SOCA, [128]],
  17. [-1, 1, Conv, [256, 3, 2]], # 4-P3/8
  18. [-1, 6, C3, [256]],
  19. [-1, 3, SOCA, [256]],
  20. [-1, 1, Conv, [512, 3, 2]], # 7-P4/16
  21. [-1, 9, C3, [512]],
  22. [-1, 3, SOCA, [512]],
  23. [-1, 1, Conv, [1024, 3, 2]], # 10-P5/32
  24. [-1, 3, C3, [1024]],
  25. [-1, 3, SOCA, [1024]],
  26. [-1, 1, SPPF, [1024, 5]], # 13
  27. ]
  28. # YOLOv5 v6.0 head
  29. head:
  30. [[-1, 1, Conv, [512, 1, 1]],
  31. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  32. [[-1, 9], 1, Concat, [1]], # cat backbone P4
  33. [-1, 3, C3, [512, False]], # 17
  34. [-1, 1, Conv, [256, 1, 1]],
  35. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  36. [[-1, 6], 1, Concat, [1]], # cat backbone P3
  37. [-1, 3, C3, [256, False]], # 21 (P3/8-small)
  38. [-1, 1, Conv, [256, 3, 2]],
  39. [[-1, 18], 1, Concat, [1]], # cat head P4
  40. [-1, 3, C3, [512, False]], # 24 (P4/16-medium)
  41. [-1, 1, Conv, [512, 3, 2]],
  42. [[-1, 14], 1, Concat, [1]], # cat head P5
  43. [-1, 3, C3, [1024, False]], # 27 (P5/32-large)
  44. [[21, 24, 27], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  45. ]

第④步:验证是否加入成功

同样的方法,我们来运行一下yolo.py 

 OK~收工!