学习资源站

26-添加LSKNet注意力机制(大选择性卷积核的领域首次探索)

YOLOv5改进系列(25)——添加LSKNet注意力机制(大选择性卷积核的领域首次探索)

🚀 一、LSKNet介绍

1.1 LSKNet简介

LSKNet是指Large Selective Kernel Network的缩写,翻译为大型选择性核网络。这是南开大学在ICCV2023会议上新提出的目标检测旋转算法,基本原理就是通过一系列Depth-wise 卷积核和空间选择机制来动态调整目标的感受野,从而允许模型适应不同背景的目标检测

这篇论文着重于遥感目标的特征提取,原始遥感目标检测的不足:

  1. 精确识别遥感影像中的目标通常需要充分的背景信息支持。有限的背景区域可能会对模型的识别效果产生影响,例如在背景信息稀缺的情况下,模型可能会错误地将十字路口误认为是道路。
  2. 不同类型的目标对背景信息的需求是多样的。以足球场为例,其可通过明显的球场边界线进行区分,因此所需的背景信息较少。然而,十字路口与道路相似,容易受到树木和其他遮挡物的影响,因此为了进行准确的识别,需要提供足够广泛的背景信息范围。

经验证,LSKNet虽然结构简单,但能够获得优异的检测性能,在HRSC2016、DOTA-v1.0、FAIR1M-v1.0三个典型数据集上都取得了SOTA。 


1.2 LSKNet网络结构

本文改进方法: 

  1. 通过使用卷积核大小为5*5的普通卷积卷积核大小为7*7的膨胀卷积,相结合来达到卷积核大小为23*23的大核卷积的效果,从而使网络在进行特征提取的时候获得超大的感受野,进而,获得丰富的上下文信息,也就是背景信息。这一点是在目前较为流行的主干特征提取网络(resnet系列,swin-transform)中是没有的。 
  2. 通过使用普通二维卷积Sigmoid函数设计了一种模块选择网络,使得网络可以针对不同的检测目标动态的选取不同感受野大小的卷积从而获得最好的检测效果。例如,针对自身特征就比较明显的对象(足球场)就选择感受野较小的卷积来提取特征。针对自身特征不太明显的对象(船只,交叉路口)则选取感觉野较大的卷积来提取特征,利用目标周围的背景信息来提高这类目标识别的准确率。

具体方法:

(1)首先,LSKNet通过使用普通卷积膨胀卷积,分别生成两个不同的特征图。

(2)接着,通过应用卷积核大小为1*1的卷积,将两者的通道数调整为相同的大小。再将它们叠加在一起,形成对应的特征图。

(3)然后,对该特征图进行平均池化和最大池化操作,将两者堆叠在一起。通过卷积和Sigmoid操作,获得针对不同卷积核大小的选择权重

(4)最后,将得到的权重与之前提到的特征图相乘和相加,再与最初的输入X相乘,从而得到最终的输出Y


🚀二、LSKNet源码讲解 

LSKNet的代码如下: 

  1. import torch
  2. import torch.nn as nn
  3. from torch.nn.modules.utils import _pair as to_2tuple
  4. from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
  5. trunc_normal_init)
  6. from ..builder import ROTATED_BACKBONES
  7. from mmcv.runner import BaseModule
  8. from timm.models.layers import DropPath, to_2tuple, trunc_normal_
  9. import math
  10. from functools import partial
  11. import warnings
  12. from mmcv.cnn import build_norm_layer
  13. # 1. Mlp 模块
  14. class Mlp(nn.Module):
  15. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
  16. super().__init__()
  17. out_features = out_features or in_features
  18. hidden_features = hidden_features or in_features
  19. self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
  20. self.dwconv = DWConv(hidden_features)
  21. self.act = act_layer()
  22. self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
  23. self.drop = nn.Dropout(drop)
  24. def forward(self, x):
  25. x = self.fc1(x)
  26. x = self.dwconv(x)
  27. x = self.act(x)
  28. x = self.drop(x)
  29. x = self.fc2(x)
  30. x = self.drop(x)
  31. return x
  32. # 2. LSKblock 模块
  33. class LSKblock(nn.Module):
  34. def __init__(self, dim):
  35. super().__init__()
  36. self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
  37. self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
  38. self.conv1 = nn.Conv2d(dim, dim//2, 1)
  39. self.conv2 = nn.Conv2d(dim, dim//2, 1)
  40. self.conv_squeeze = nn.Conv2d(2, 2, 7, padding=3)
  41. self.conv = nn.Conv2d(dim//2, dim, 1)
  42. def forward(self, x):
  43. attn1 = self.conv0(x)
  44. attn2 = self.conv_spatial(attn1)
  45. attn1 = self.conv1(attn1)
  46. attn2 = self.conv2(attn2)
  47. attn = torch.cat([attn1, attn2], dim=1)
  48. avg_attn = torch.mean(attn, dim=1, keepdim=True)
  49. max_attn, _ = torch.max(attn, dim=1, keepdim=True)
  50. agg = torch.cat([avg_attn, max_attn], dim=1)
  51. sig = self.conv_squeeze(agg).sigmoid()
  52. attn = attn1 * sig[:,0,:,:].unsqueeze(1) + attn2 * sig[:,1,:,:].unsqueeze(1)
  53. attn = self.conv(attn)
  54. return x * attn
  55. # 3. Attention 模块
  56. class Attention(nn.Module):
  57. def __init__(self, d_model):
  58. super().__init__()
  59. self.proj_1 = nn.Conv2d(d_model, d_model, 1)
  60. self.activation = nn.GELU()
  61. self.spatial_gating_unit = LSKblock(d_model)
  62. self.proj_2 = nn.Conv2d(d_model, d_model, 1)
  63. def forward(self, x):
  64. shorcut = x.clone()
  65. x = self.proj_1(x)
  66. x = self.activation(x)
  67. x = self.spatial_gating_unit(x)
  68. x = self.proj_2(x)
  69. x = x + shorcut
  70. return x
  71. # 4. Block 模块
  72. class Block(nn.Module):
  73. def __init__(self, dim, mlp_ratio=4., drop=0.,drop_path=0., act_layer=nn.GELU, norm_cfg=None):
  74. super().__init__()
  75. if norm_cfg:
  76. self.norm1 = build_norm_layer(norm_cfg, dim)[1]
  77. self.norm2 = build_norm_layer(norm_cfg, dim)[1]
  78. else:
  79. self.norm1 = nn.BatchNorm2d(dim)
  80. self.norm2 = nn.BatchNorm2d(dim)
  81. self.attn = Attention(dim)
  82. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  83. mlp_hidden_dim = int(dim * mlp_ratio)
  84. self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
  85. layer_scale_init_value = 1e-2
  86. self.layer_scale_1 = nn.Parameter(
  87. layer_scale_init_value * torch.ones((dim)), requires_grad=True)
  88. self.layer_scale_2 = nn.Parameter(
  89. layer_scale_init_value * torch.ones((dim)), requires_grad=True)
  90. def forward(self, x):
  91. x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))
  92. x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
  93. return x
  94. # 5. OverlapPatchEmbed 模块
  95. class OverlapPatchEmbed(nn.Module):
  96. """ Image to Patch Embedding
  97. """
  98. def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768, norm_cfg=None):
  99. super().__init__()
  100. patch_size = to_2tuple(patch_size)
  101. self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
  102. padding=(patch_size[0] // 2, patch_size[1] // 2))
  103. if norm_cfg:
  104. self.norm = build_norm_layer(norm_cfg, embed_dim)[1]
  105. else:
  106. self.norm = nn.BatchNorm2d(embed_dim)
  107. def forward(self, x):
  108. x = self.proj(x)
  109. _, _, H, W = x.shape
  110. x = self.norm(x)
  111. return x, H, W
  112. # 6. LSKNet 模块
  113. @ROTATED_BACKBONES.register_module()
  114. class LSKNet(BaseModule):
  115. def __init__(self, img_size=224, in_chans=3, embed_dims=[64, 128, 256, 512],
  116. mlp_ratios=[8, 8, 4, 4], drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
  117. depths=[3, 4, 6, 3], num_stages=4,
  118. pretrained=None,
  119. init_cfg=None,
  120. norm_cfg=None):
  121. super().__init__(init_cfg=init_cfg)
  122. assert not (init_cfg and pretrained), \
  123. 'init_cfg and pretrained cannot be set at the same time'
  124. if isinstance(pretrained, str):
  125. warnings.warn('DeprecationWarning: pretrained is deprecated, '
  126. 'please use "init_cfg" instead')
  127. self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
  128. elif pretrained is not None:
  129. raise TypeError('pretrained must be a str or None')
  130. self.depths = depths
  131. self.num_stages = num_stages
  132. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
  133. cur = 0
  134. for i in range(num_stages):
  135. patch_embed = OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
  136. patch_size=7 if i == 0 else 3,
  137. stride=4 if i == 0 else 2,
  138. in_chans=in_chans if i == 0 else embed_dims[i - 1],
  139. embed_dim=embed_dims[i], norm_cfg=norm_cfg)
  140. block = nn.ModuleList([Block(
  141. dim=embed_dims[i], mlp_ratio=mlp_ratios[i], drop=drop_rate, drop_path=dpr[cur + j],norm_cfg=norm_cfg)
  142. for j in range(depths[i])])
  143. norm = norm_layer(embed_dims[i])
  144. cur += depths[i]
  145. setattr(self, f"patch_embed{i + 1}", patch_embed)
  146. setattr(self, f"block{i + 1}", block)
  147. setattr(self, f"norm{i + 1}", norm)
  148. def init_weights(self):
  149. print('init cfg', self.init_cfg)
  150. if self.init_cfg is None:
  151. for m in self.modules():
  152. if isinstance(m, nn.Linear):
  153. trunc_normal_init(m, std=.02, bias=0.)
  154. elif isinstance(m, nn.LayerNorm):
  155. constant_init(m, val=1.0, bias=0.)
  156. elif isinstance(m, nn.Conv2d):
  157. fan_out = m.kernel_size[0] * m.kernel_size[
  158. 1] * m.out_channels
  159. fan_out //= m.groups
  160. normal_init(
  161. m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
  162. else:
  163. super(LSKNet, self).init_weights()
  164. def freeze_patch_emb(self):
  165. self.patch_embed1.requires_grad = False
  166. @torch.jit.ignore
  167. def no_weight_decay(self):
  168. return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better
  169. def get_classifier(self):
  170. return self.head
  171. def reset_classifier(self, num_classes, global_pool=''):
  172. self.num_classes = num_classes
  173. self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
  174. def forward_features(self, x):
  175. B = x.shape[0]
  176. outs = []
  177. for i in range(self.num_stages):
  178. patch_embed = getattr(self, f"patch_embed{i + 1}")
  179. block = getattr(self, f"block{i + 1}")
  180. norm = getattr(self, f"norm{i + 1}")
  181. x, H, W = patch_embed(x)
  182. for blk in block:
  183. x = blk(x)
  184. x = x.flatten(2).transpose(1, 2)
  185. x = norm(x)
  186. x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
  187. outs.append(x)
  188. return outs
  189. def forward(self, x):
  190. x = self.forward_features(x)
  191. # x = self.head(x)
  192. return x
  193. # 7. DWConv 模块
  194. class DWConv(nn.Module):
  195. def __init__(self, dim=768):
  196. super(DWConv, self).__init__()
  197. self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
  198. def forward(self, x):
  199. x = self.dwconv(x)
  200. return x
  201. # 8. _conv_filter 函数
  202. def _conv_filter(state_dict, patch_size=16):
  203. """ convert patch embedding weight from manual patchify + linear proj to conv"""
  204. out_dict = {}
  205. for k, v in state_dict.items():
  206. if 'patch_embed.proj.weight' in k:
  207. v = v.reshape((v.shape[0], 3, patch_size, patch_size))
  208. out_dict[k] = v
  209. return out_dict

一共有七个模块和一个函数:

  1. Mlp 模块

  2. LSKblock 模块

  3. Attention 模块

  4. Block 模块

  5. OverlapPatchEmbed 模块

  6. LSKNet 模块

  7. DWConv 模块

  8. _conv_filter 函数

这里我们重点讲一下 LSKblock 模块:

 主要包括了:

  1. self.conv0: 一个5x5的卷积层,使用depthwise卷积(groups=dim),用于从输入特征图中提取信息。

  2. self.conv_spatial: 一个7x7的卷积层,用于处理输入特征图的空间信息,具有较大的padding和dilation,dilation=3 表示膨胀卷积。

  3. self.conv1和self.conv2 一个1x1的卷积层,分别用于处理conv0和conv_spatial的输出,将通道数减半。

  4. self.conv_squeeze 一个7x7的卷积层,用于对avg_attn和max_attn进行通道方向的融合,通过sigmoid激活函数得到注意力权重。

  5. self.conv: 一个1x1的卷积层,用于将融合后的注意力权重与输入特征图进行融合。


🚀三、具体添加方法

 3.1 添加顺序 

(1)models/common.py    -->  加入新增的网络结构

(2)     models/yolo.py       -->  设定网络结构的传参细节,将LSKblock类名加入其中。(当新的自定义模块中存在输入输出维度时,要使用qw调整输出维度)
(3) models/yolov5*.yaml  -->  新建一个文件夹,如yolov5s_LSK.yaml,修改现有模型结构配置文件。(当引入新的层时,要修改后续的结构中的from参数)
(4)         train.py                -->  修改‘--cfg’默认参数,训练时指定模型结构配置文件 


3.2 具体添加步骤  

第①步:在common.py中添加LSK模块 

将下面的LSK代码复制粘贴到common.py文件的末尾 

  1. # LSKNet
  2. class LSKblock(nn.Module):
  3. def __init__(self, dim):
  4. super().__init__()
  5. # 一个5x5的卷积层,groups=dim 表示使用分组卷积,其中 dim 是输入通道数。
  6. self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
  7. # 一个7x7的卷积层,使用了步幅1、padding为9、groups=dim 表示分组卷积、dilation=3 表示膨胀卷积。
  8. self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
  9. # 一个1x1的卷积层,将 dim 通道减半。
  10. self.conv1 = nn.Conv2d(dim, dim // 2, 1)
  11. # 一个1x1的卷积层,将 dim 通道减半。
  12. self.conv2 = nn.Conv2d(dim, dim // 2, 1)
  13. # 一个7x7的卷积层,padding为3,用于对两个路径的信息进行压缩。
  14. self.conv_squeeze = nn.Conv2d(2, 2, 7, padding=3)
  15. # 一个1x1的卷积层,将 dim//2 通道的信息还原为 dim。
  16. self.conv = nn.Conv2d(dim // 2, dim, 1)
  17. def forward(self, x):
  18. # 使用 self.conv0 和 self.conv_spatial 对输入特征图进行不同的卷积操作,得到 attn1 和 attn2。
  19. attn1 = self.conv0(x)
  20. attn2 = self.conv_spatial(attn1)
  21. # 使用 self.conv1 和 self.conv2 对 attn1 和 attn2 进行进一步的卷积操作。
  22. attn1 = self.conv1(attn1)
  23. attn2 = self.conv2(attn2)
  24. # 将 attn1 和 attn2 沿着通道维度拼接在一起,得到 attn。
  25. attn = torch.cat([attn1, attn2], dim=1)
  26. # 计算 attn 在通道维度上的平均值和最大值,得到 avg_attn 和 max_attn。
  27. avg_attn = torch.mean(attn, dim=1, keepdim=True)
  28. max_attn, _ = torch.max(attn, dim=1, keepdim=True)
  29. # 将 avg_attn 和 max_attn 沿着通道维度拼接在一起,得到 agg。
  30. agg = torch.cat([avg_attn, max_attn], dim=1)
  31. # 对 agg 进行压缩,通过 self.conv_squeeze 进行sigmoid激活,得到注意力权重 sig。
  32. sig = self.conv_squeeze(agg).sigmoid()
  33. # 使用 sig 对 attn1 和 attn2 进行加权融合,得到最终的注意力结果 attn。
  34. attn = attn1 * sig[:, 0, :, :].unsqueeze(1) + attn2 * sig[:, 1, :, :].unsqueeze(1)
  35. # 将 attn 经过 self.conv 进行进一步处理,得到最终的输出,返回 x * attn。
  36. attn = self.conv(attn)
  37. return x * attn

第②步:修改yolo.py文件

再来修改yolo.py,在parse_model函数中找到 elif m is Concat: 语句,在其后面加上下面代码:

  1. elif m is LSKblock:
  2. c1 = ch[f]
  3. args = [c1, *args[0:]]

如下图所示: 


第③步:创建自定义的yaml文件  

yaml文件配置完整代码如下:

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. # Parameters
  3. nc: 80 # number of classes
  4. depth_multiple: 0.33 # model depth multiple
  5. width_multiple: 1 # layer channel multiple
  6. anchors:
  7. - [10,13, 16,30, 33,23] # P3/8
  8. - [30,61, 62,45, 59,119] # P4/16
  9. - [116,90, 156,198, 373,326] # P5/32
  10. # YOLOv5 v6.1 backbone
  11. backbone:
  12. # [from, number, module, args]
  13. [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
  14. [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
  15. [-1, 3, C3, [128]],
  16. [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
  17. [-1, 6, C3, [256]],
  18. [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
  19. [-1, 9, C3, [512]],
  20. [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
  21. [-1, 3, C3, [1024]],
  22. [-1, 1, LSKblock, []],
  23. [-1, 1, SPPF, [1024, 5]], # 10
  24. ]
  25. # YOLOv5 v6.1 head
  26. head:
  27. [[-1, 1, Conv, [512, 1, 1]],
  28. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  29. [[-1, 6], 1, Concat, [1]], # cat backbone P4
  30. [-1, 3, C3, [512, False]], # 14
  31. [-1, 1, Conv, [256, 1, 1]],
  32. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  33. [[-1, 4], 1, Concat, [1]], # cat backbone P3
  34. [-1, 3, C3, [256, False]], # 18 (P3/8-small)
  35. [-1, 1, Conv, [256, 3, 2]],
  36. [[-1, 15], 1, Concat, [1]], # cat head P4
  37. [-1, 3, C3, [512, False]], # 21 (P4/16-medium)
  38. [-1, 1, Conv, [512, 3, 2]],
  39. [[-1, 11], 1, Concat, [1]], # cat head P5
  40. [-1, 3, C3, [1024, False]], # 24 (P5/32-large)
  41. [[18, 21, 24], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  42. ]

第④步:验证是否加入成功

运行yolo.py

这样就OK啦!