
YOLOv5 Improvement Series (20): Adding the BiFormer Attention Mechanism (CVPR 2023 | A Small-Object Detection Booster)

🚀 1. Introduction to BiFormer


1.1 Overview

The BiFormer module proposes a dynamic, sparse attention mechanism. First, at a coarse level, it filters out most of the irrelevant keys and values, keeping only a small routed subset of relevant ones. Then, fine-grained token-to-token attention is applied within that small subset of relevant keys and values.
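In code, the two levels look roughly like the following (a conceptual toy in PyTorch, not the implementation added in Step ① below; all shapes and names are illustrative):

```python
import torch

# Toy bi-level routing: coarse region filtering first, fine attention second.
# q, k, v: (num_regions, tokens_per_region, dim) -- illustrative shapes only.
q, k, v = (torch.randn(16, 49, 64) for _ in range(3))
topk = 4

# Level 1: region-to-region affinity from region-averaged queries and keys;
# keep only the topk most relevant regions for each query region.
affinity = q.mean(1) @ k.mean(1).T                   # (16, 16)
_, idx = torch.topk(affinity, k=topk, dim=-1)        # (16, topk)

# Level 2: ordinary token-to-token attention, restricted to the routed regions.
k_sel = k[idx].flatten(1, 2)                         # (16, topk*49, 64)
v_sel = v[idx].flatten(1, 2)                         # (16, topk*49, 64)
attn = torch.softmax((q * 64 ** -0.5) @ k_sel.transpose(-2, -1), dim=-1)
out = attn @ v_sel                                   # (16, 49, 64)
```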

1.2 Model Structure

The overall structure (see the sketch after this list):

  1. Stage 1 uses an overlapped patch embedding (i = 1) to reduce the input spatial resolution.
  2. Stages 2 to 4 use patch merging modules (i = 2, 3, 4) to further reduce resolution and increase the channel count.
  3. Each stage then applies Ni consecutive BiFormer blocks to transform the input features.
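As a rough, runnable skeleton of that layout (Conv2d layers and Identity blocks stand in for the real patch embedding, patch merging, and BiFormer blocks; dims and depths are illustrative, loosely following the tiny variant):

```python
import torch
import torch.nn as nn

# Four-stage skeleton: stage 1 embeds patches (overall stride 4), stages 2-4
# merge patches (stride 2, widening channels), each followed by N_i blocks.
dims, depths = [64, 128, 256, 512], [2, 2, 8, 2]
layers = [nn.Conv2d(3, dims[0], 7, stride=4, padding=3)]                    # ~patch embedding
layers += [nn.Identity() for _ in range(depths[0])]                         # ~BiFormer blocks
for i in range(1, 4):
    layers.append(nn.Conv2d(dims[i - 1], dims[i], 3, stride=2, padding=1))  # ~patch merging
    layers += [nn.Identity() for _ in range(depths[i])]                     # ~BiFormer blocks
net = nn.Sequential(*layers)
print(net(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 512, 7, 7])
```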

🚀 2. How to Add It

Step ①: Add the BiFormer module to common.py

Copy and paste the following code at the end of the common.py file:

```python
# ------------------ Code adapted from @迪菲赫尔曼 ------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Tuple
from einops import rearrange


class TopkRouting(nn.Module):
    """
    Differentiable topk routing with scaling.
    Args:
        qk_dim: int, feature dimension of query and key
        topk: int, the 'topk'
        qk_scale: int or None, temperature (multiply) of softmax activation
        param_routing: bool, whether to incorporate learnable params in the routing unit
        diff_routing: bool, whether to make routing differentiable
    """

    def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):
        super().__init__()
        self.topk = topk
        self.qk_dim = qk_dim
        self.scale = qk_scale or qk_dim ** -0.5
        self.diff_routing = diff_routing
        self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity()
        self.routing_act = nn.Softmax(dim=-1)

    def forward(self, query: Tensor, key: Tensor) -> Tuple[Tensor]:
        if not self.diff_routing:
            query, key = query.detach(), key.detach()
        query_hat, key_hat = self.emb(query), self.emb(key)
        attn_logit = (query_hat * self.scale) @ key_hat.transpose(-2, -1)
        topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1)
        r_weight = self.routing_act(topk_attn_logit)
        return r_weight, topk_index


class KVGather(nn.Module):
    """Gathers the key/value pairs of the topk routed windows."""

    def __init__(self, mul_weight='none'):
        super().__init__()
        assert mul_weight in ['none', 'soft', 'hard']
        self.mul_weight = mul_weight

    def forward(self, r_idx: Tensor, r_weight: Tensor, kv: Tensor):
        n, p2, w2, c_kv = kv.size()
        topk = r_idx.size(-1)
        # select the kv of the topk windows for every query window
        topk_kv = torch.gather(kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1),
                               dim=2,
                               index=r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv))
        if self.mul_weight == 'soft':
            topk_kv = r_weight.view(n, p2, topk, 1, 1) * topk_kv
        elif self.mul_weight == 'hard':
            raise NotImplementedError('differentiable hard routing TBA')
        return topk_kv


class QKVLinear(nn.Module):
    """Single linear projection producing q and the concatenated kv."""

    def __init__(self, dim, qk_dim, bias=True):
        super().__init__()
        self.dim = dim
        self.qk_dim = qk_dim
        self.qkv = nn.Linear(dim, qk_dim + qk_dim + dim, bias=bias)

    def forward(self, x):
        q, kv = self.qkv(x).split([self.qk_dim, self.qk_dim + self.dim], dim=-1)
        return q, kv


class BiLevelRoutingAttention(nn.Module):
    """
    n_win: number of windows per side (so the actual number of windows is n_win * n_win)
    kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of key/values per window
                (as with n_win, the actual number is kv_per_win * kv_per_win)
    topk: topk for window filtering
    param_attention: 'qkvo' -- linear for q, k, v and o; 'none' -- param-free attention
    param_routing: extra linear for routing
    diff_routing: whether to make routing differentiable
    soft_routing: whether to multiply by soft routing weights
    """

    def __init__(self, dim, n_win=7, num_heads=8, qk_dim=None, qk_scale=None,
                 kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',
                 topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False,
                 side_dwconv=3,
                 auto_pad=True):
        super().__init__()
        self.dim = dim
        self.n_win = n_win  # Wh, Ww
        self.num_heads = num_heads
        self.qk_dim = qk_dim or dim
        assert self.qk_dim % num_heads == 0 and self.dim % num_heads == 0, 'qk_dim and dim must be divisible by num_heads!'
        self.scale = qk_scale or self.qk_dim ** -0.5
        # local context enhancement: depthwise conv acting as positional encoding on v
        self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv // 2,
                              groups=dim) if side_dwconv > 0 else \
            lambda x: torch.zeros_like(x)
        self.topk = topk
        self.param_routing = param_routing
        self.diff_routing = diff_routing
        self.soft_routing = soft_routing
        # router
        assert not (self.param_routing and not self.diff_routing)
        self.router = TopkRouting(qk_dim=self.qk_dim,
                                  qk_scale=self.scale,
                                  topk=self.topk,
                                  diff_routing=self.diff_routing,
                                  param_routing=self.param_routing)
        if self.soft_routing:  # soft routing, always differentiable (if no detach)
            mul_weight = 'soft'
        elif self.diff_routing:  # hard differentiable routing
            mul_weight = 'hard'
        else:  # hard non-differentiable routing
            mul_weight = 'none'
        self.kv_gather = KVGather(mul_weight=mul_weight)
        # qkv mapping (shared by both global routing and local attention)
        self.param_attention = param_attention
        if self.param_attention == 'qkvo':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Linear(dim, dim)
        elif self.param_attention == 'qkv':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Identity()
        else:
            raise ValueError(f'param_attention mode {self.param_attention} is not supported!')
        self.kv_downsample_mode = kv_downsample_mode
        self.kv_per_win = kv_per_win
        self.kv_downsample_ratio = kv_downsample_ratio
        self.kv_downsample_kernel = kv_downsample_kernel
        if self.kv_downsample_mode == 'ada_avgpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'ada_maxpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'maxpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'avgpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'identity':  # no kv downsampling
            self.kv_down = nn.Identity()
        elif self.kv_downsample_mode == 'fracpool':
            raise NotImplementedError('fracpool policy is not implemented yet!')
        elif self.kv_downsample_mode == 'conv':
            raise NotImplementedError('conv policy is not implemented yet!')
        else:
            raise ValueError(f'kv_downsample_mode {self.kv_downsample_mode} is not supported!')
        self.attn_act = nn.Softmax(dim=-1)
        self.auto_pad = auto_pad

    def forward(self, x, ret_attn_mask=False):
        """
        x: NCHW tensor
        Return:
            NCHW tensor
        """
        x = rearrange(x, "n c h w -> n h w c")
        if self.auto_pad:
            N, H_in, W_in, C = x.size()
            pad_l = pad_t = 0
            pad_r = (self.n_win - W_in % self.n_win) % self.n_win
            pad_b = (self.n_win - H_in % self.n_win) % self.n_win
            x = F.pad(x, (0, 0,  # dim=-1
                          pad_l, pad_r,  # dim=-2
                          pad_t, pad_b))  # dim=-3
            _, H, W, _ = x.size()  # padded size
        else:
            N, H, W, C = x.size()
            assert H % self.n_win == 0 and W % self.n_win == 0
        # partition the feature map into n_win * n_win windows
        x = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win)
        q, kv = self.qkv(x)
        q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c')
        kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))
        kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)
        q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean(
            [2, 3])  # window-wise qk, (n, p^2, c_qk), (n, p^2, c_qk)
        lepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)', j=self.n_win,
                                   i=self.n_win).contiguous())
        lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win)
        r_weight, r_idx = self.router(q_win, k_win)  # both are (n, p^2, topk) tensors
        kv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix)  # (n, p^2, topk, h_kv*w_kv, c_qk+c_v)
        k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1)
        k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)',
                              m=self.num_heads)  # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_kq//m) transpose here?
        v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c',
                              m=self.num_heads)  # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_v//m)
        q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c',
                          m=self.num_heads)  # to BMLC tensor (n*p^2, m, w^2, c_qk//m)
        attn_weight = (q_pix * self.scale) @ k_pix_sel  # (n*p^2, m, w^2, c) @ (n*p^2, m, c, topk*h_kv*w_kv) -> (n*p^2, m, w^2, topk*h_kv*w_kv)
        attn_weight = self.attn_act(attn_weight)
        out = attn_weight @ v_pix_sel  # (n*p^2, m, w^2, topk*h_kv*w_kv) @ (n*p^2, m, topk*h_kv*w_kv, c) -> (n*p^2, m, w^2, c)
        out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)', j=self.n_win, i=self.n_win,
                        h=H // self.n_win, w=W // self.n_win)
        out = out + lepe
        out = self.wo(out)
        # remove auto-padding, if any
        if self.auto_pad and (pad_r > 0 or pad_b > 0):
            out = out[:, :H_in, :W_in, :].contiguous()
        if ret_attn_mask:
            return out, r_weight, r_idx, attn_weight
        else:
            return rearrange(out, "n h w c -> n c h w")
```
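As a quick sanity check that the pasted code works, you can run something like this in a session where the classes above are defined (the shapes here are arbitrary; auto_pad handles sizes not divisible by n_win):

```python
import torch

attn = BiLevelRoutingAttention(dim=64, n_win=7, num_heads=8, topk=4)
x = torch.randn(1, 64, 80, 80)  # NCHW, as the module expects
print(attn(x).shape)            # -> torch.Size([1, 64, 80, 80])
```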

If you haven't installed the einops package yet, see the earlier post on it → (in short: `pip install einops`).



Step ②: Modify yolo.py

Next, modify yolo.py: in the parse_model function, find the `elif m is Concat:` statement and add the following code after that branch:

```python
elif m in [BiLevelRoutingAttention]:
    c2 = ch[f]
    args = [c2, *args[0:]]
```
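This branch tells parse_model that the module keeps the channel count unchanged (c2 = ch[f]) and prepends the input channel count to the yaml args, so it becomes the dim argument when the module is instantiated. In context it sits next to the existing Concat branch roughly like this (surrounding lines may differ slightly between YOLOv5 versions):

```python
elif m is Concat:
    c2 = sum(ch[x] for x in f)
# new branch for BiFormer:
elif m in [BiLevelRoutingAttention]:
    c2 = ch[f]              # attention does not change the number of channels
    args = [c2, *args[0:]]  # prepend input channels: dim = c2
```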



Step ③: Create a custom yaml file

The complete yaml configuration is as follows:

```yaml
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
   [-1, 1, BiLevelRoutingAttention, []],  # 18

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 21 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 24 (P5/32-large)

   [[17, 21, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
```
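Note that inserting BiLevelRoutingAttention as layer 18 shifts every subsequent layer index up by one compared with the stock yolov5s.yaml, which is why Detect takes its inputs from [17, 21, 24] instead of the usual [17, 20, 23]. If you insert the module somewhere else, adjust the Concat and Detect indices accordingly.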

Step ④: Verify that the module was added successfully

Run yolo.py.
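For example, from the repository root you can run something like `python models/yolo.py --cfg yolov5s_biformer.yaml` (use whatever filename you gave the config in Step ③). If everything is hooked up correctly, the printed model summary lists BiLevelRoutingAttention as layer 18.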

And that's it!


PS:

(1) This attention mechanism gained points reliably on a classmate's dataset. After adding it, I tried several insertion positions: placed at the small-object (P3) detection layer, it gained 0.2 over the stock yolov5s, while every other position lost points. That suggests it really is friendlier to small-object detection.

To quote the original paper:

"As BiFormer attends to a small subset of relevant tokens in a query-adaptive manner, without distraction from other irrelevant ones, it enjoys both good performance and high computational efficiency."

(2) More than one person's experiments show that adding this module does not add many parameters, yet the model inexplicably becomes too heavy to run on a local machine; running it in the cloud works fine.