
YOLOv11 Improvements for Special-Scene Detection: Retinexformer, a Recent Low-Light Enhancement Network, for Night-Time Object Detection (exclusive first release)

1. Introduction

The improvement introduced in this article is Retinexformer, a low-light image-enhancement network aimed at night-time object detection (well suited as material for a paper). Its main idea is to enhance low-light images with a novel one-stage Retinex-based framework that couples illumination estimation with corruption restoration in order to raise the quality of low-light images. At its core is an illumination-guided transformer, which uses illumination information to guide the modeling of long-range dependencies so that images under different lighting conditions are handled better. You are welcome to subscribe to this column: 3-5 new mechanisms are added every week, and subscribers also get all of my modified files and access to the discussion group.

Subscribe to my column and let's learn YOLO together!

The figure below compares Retinexformer with a range of other image-enhancement networks; the latest version of Retinexformer performs well across all kinds of scenes.



2. How Retinexformer Works

Official paper address: click the link to open

Official code address: click the link to open


Retinexformer's main idea is to enhance low-light images with a novel one-stage Retinex-based framework that couples illumination estimation with corruption restoration in order to raise the quality of low-light images. At its core is an illumination-guided transformer, which uses illumination information to guide the modeling of long-range dependencies so that images under different lighting conditions are handled better. In this way Retinexformer can effectively enhance low-light images while keeping a natural appearance and fine detail.

Its main innovations are as follows:

1. One-stage Retinex-based Framework (ORF): a simple but principled framework that first estimates illumination information to light up the low-light image and then restores the corruptions to produce the enhanced image.

2. Illumination-Guided Transformer (IGT): a transformer that uses the illumination representation to guide the modeling of non-local interactions between regions under different lighting conditions.

3. An illumination-guided self-attention mechanism (IG-MSA): a new self-attention mechanism that uses illumination information as the key cue for modeling long-range dependencies.

With these innovations Retinexformer significantly outperforms existing state-of-the-art methods on multiple benchmarks, and it has also shown practical value for low-light object detection.


The figure above shows the detailed Retinexformer pipeline (summarized as formulas right after this list):

1. Input image and illumination prior: the pipeline starts from a low-light input image, from which an illumination prior L_p is obtained.

2. Illumination estimator: it uses the input image together with the illumination prior to produce an illumination map L, which then guides the lighting-up of the image.

3. Lit-up image and feature extraction: the illumination map L is used to light up the input image, giving the lit-up image I_{lu}; at the same time a light-up feature F_{lu} is extracted.

4. Corruption restorer, the illumination-guided transformer: a stack of illumination-guided attention blocks (IGAB) that use the light-up feature to steer the attention mechanism and progressively restore image quality.

5. Illumination-guided multi-head self-attention (IG-MSA): the key component of IGAB; the self-attention computation is guided by the illumination information so as to capture fine image detail.

6. Final output: after these stages the enhanced image is produced, with markedly better quality and with color distortion and noise well suppressed.
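
Written out as formulas (this is my own paraphrase, read off from the code in Section 3 rather than copied from the paper, with $\odot$ denoting element-wise multiplication):

$$
L_p = \operatorname{mean}_c(I), \qquad (F_{lu},\, L) = \mathcal{E}\big([\,I,\ L_p\,]\big), \qquad
I_{lu} = I \odot L + I, \qquad
I_{en} = I_{lu} + \mathcal{R}(I_{lu},\, F_{lu})
$$

Here $I$ is the low-light input, $\mathcal{E}$ the illumination estimator, $\mathcal{R}$ the corruption restorer (the IGT), and $I_{en}$ the enhanced output. Inside IG-MSA, with the per-head queries, keys and values reshaped to $Q, K, V \in \mathbb{R}^{d \times hw}$ and $Q, K$ L2-normalized along the spatial axis, the values are first modulated by the light-up feature and the attention map is then computed across channels:

$$
V \leftarrow V \odot F_{lu}, \qquad
A = \operatorname{softmax}\big(s \cdot K Q^{\top}\big) \in \mathbb{R}^{d \times d}, \qquad
\mathrm{out} = A\, V
$$

where $s$ is a learnable per-head rescaling factor (self.rescale in the code below).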


3. Core Code of Retinexformer

See Section 4 for how to use this code!

import torch.nn as nn
import torch
import torch.nn.functional as F
from einops import rearrange
import math
import warnings
from torch.nn.init import _calculate_fan_in_and_fan_out

__all__ = ['RetinexFormer']


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    def norm_cdf(x):
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)
    with torch.no_grad():
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)
        tensor.uniform_(2 * l - 1, 2 * u - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2
    variance = scale / denom
    if distribution == "truncated_normal":
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, *args, **kwargs):
        x = self.norm(x)
        return self.fn(x, *args, **kwargs)


class GELU(nn.Module):
    def forward(self, x):
        return F.gelu(x)


def conv(in_channels, out_channels, kernel_size, bias=False, padding=1, stride=1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias, stride=stride)


# input [bs,28,256,310]  output [bs, 28, 256, 256]
def shift_back(inputs, step=2):
    [bs, nC, row, col] = inputs.shape
    down_sample = 256 // row
    step = float(step) / float(down_sample * down_sample)
    out_col = row
    for i in range(nC):
        inputs[:, i, :, :out_col] = \
            inputs[:, i, :, int(step * i):int(step * i) + out_col]
    return inputs[:, :, :, :out_col]


class Illumination_Estimator(nn.Module):
    def __init__(
            self, n_fea_middle, n_fea_in=4, n_fea_out=3):  # attributes in __init__ are internal; forward's arguments are the external inputs
        super(Illumination_Estimator, self).__init__()

        self.conv1 = nn.Conv2d(n_fea_in, n_fea_middle, kernel_size=1, bias=True)

        self.depth_conv = nn.Conv2d(
            n_fea_middle, n_fea_middle, kernel_size=5, padding=2, bias=True, groups=n_fea_in)

        self.conv2 = nn.Conv2d(n_fea_middle, n_fea_out, kernel_size=1, bias=True)

    def forward(self, img):
        # img:      b,c=3,h,w
        # mean_c:   b,c=1,h,w
        # illu_fea: b,c,h,w
        # illu_map: b,c=3,h,w

        mean_c = img.mean(dim=1).unsqueeze(1)
        # stx()
        input = torch.cat([img, mean_c], dim=1)

        x_1 = self.conv1(input)
        illu_fea = self.depth_conv(x_1)
        illu_map = self.conv2(illu_fea)
        return illu_fea, illu_map


class IG_MSA(nn.Module):
    def __init__(
            self,
            dim,
            dim_head=64,
            heads=8,
    ):
        super().__init__()
        self.num_heads = heads
        self.dim_head = dim_head
        self.to_q = nn.Linear(dim, dim_head * heads, bias=False)
        self.to_k = nn.Linear(dim, dim_head * heads, bias=False)
        self.to_v = nn.Linear(dim, dim_head * heads, bias=False)
        self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
        self.proj = nn.Linear(dim_head * heads, dim, bias=True)
        self.pos_emb = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
            GELU(),
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
        )
        self.dim = dim

    def forward(self, x_in, illu_fea_trans):
        """
        x_in: [b,h,w,c]      # input feature
        illu_fea: [b,h,w,c]  # illumination feature, already in channels-last layout
        return out: [b,h,w,c]
        """
        b, h, w, c = x_in.shape
        x = x_in.reshape(b, h * w, c)
        q_inp = self.to_q(x)
        k_inp = self.to_k(x)
        v_inp = self.to_v(x)
        illu_attn = illu_fea_trans  # illu_fea: b,c,h,w -> b,h,w,c
        q, k, v, illu_attn = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads),
                                 (q_inp, k_inp, v_inp, illu_attn.flatten(1, 2)))
        v = v * illu_attn
        # q: b,heads,hw,c
        q = q.transpose(-2, -1)
        k = k.transpose(-2, -1)
        v = v.transpose(-2, -1)
        q = F.normalize(q, dim=-1, p=2)
        k = F.normalize(k, dim=-1, p=2)
        attn = (k @ q.transpose(-2, -1))  # A = K^T*Q
        attn = attn * self.rescale
        attn = attn.softmax(dim=-1)
        x = attn @ v  # b,heads,d,hw
        x = x.permute(0, 3, 1, 2)  # Transpose
        x = x.reshape(b, h * w, self.num_heads * self.dim_head)
        out_c = self.proj(x).view(b, h, w, c)
        out_p = self.pos_emb(v_inp.reshape(b, h, w, c).permute(
            0, 3, 1, 2)).permute(0, 2, 3, 1)
        out = out_c + out_p

        return out


class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim * mult, 1, 1, bias=False),
            GELU(),
            nn.Conv2d(dim * mult, dim * mult, 3, 1, 1,
                      bias=False, groups=dim * mult),
            GELU(),
            nn.Conv2d(dim * mult, dim, 1, 1, bias=False),
        )

    def forward(self, x):
        """
        x: [b,h,w,c]
        return out: [b,h,w,c]
        """
        out = self.net(x.permute(0, 3, 1, 2))
        return out.permute(0, 2, 3, 1)


class IGAB(nn.Module):
    def __init__(
            self,
            dim,
            dim_head=64,
            heads=8,
            num_blocks=2,
    ):
        super().__init__()
        self.blocks = nn.ModuleList([])
        for _ in range(num_blocks):
            self.blocks.append(nn.ModuleList([
                IG_MSA(dim=dim, dim_head=dim_head, heads=heads),
                PreNorm(dim, FeedForward(dim=dim))
            ]))

    def forward(self, x, illu_fea):
        """
        x: [b,c,h,w]
        illu_fea: [b,c,h,w]
        return out: [b,c,h,w]
        """
        x = x.permute(0, 2, 3, 1)
        for (attn, ff) in self.blocks:
            x = attn(x, illu_fea_trans=illu_fea.permute(0, 2, 3, 1)) + x
            x = ff(x) + x
        out = x.permute(0, 3, 1, 2)
        return out


class Denoiser(nn.Module):
    def __init__(self, in_dim=3, out_dim=3, dim=31, level=2, num_blocks=[2, 4, 4]):
        super(Denoiser, self).__init__()
        self.dim = dim
        self.level = level

        # Input projection
        self.embedding = nn.Conv2d(in_dim, self.dim, 3, 1, 1, bias=False)

        # Encoder
        self.encoder_layers = nn.ModuleList([])
        dim_level = dim
        for i in range(level):
            self.encoder_layers.append(nn.ModuleList([
                IGAB(
                    dim=dim_level, num_blocks=num_blocks[i], dim_head=dim, heads=dim_level // dim),
                nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False),
                nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False)
            ]))
            dim_level *= 2

        # Bottleneck
        self.bottleneck = IGAB(
            dim=dim_level, dim_head=dim, heads=dim_level // dim, num_blocks=num_blocks[-1])

        # Decoder
        self.decoder_layers = nn.ModuleList([])
        for i in range(level):
            self.decoder_layers.append(nn.ModuleList([
                nn.ConvTranspose2d(dim_level, dim_level // 2, stride=2,
                                   kernel_size=2, padding=0, output_padding=0),
                nn.Conv2d(dim_level, dim_level // 2, 1, 1, bias=False),
                IGAB(
                    dim=dim_level // 2, num_blocks=num_blocks[level - 1 - i], dim_head=dim,
                    heads=(dim_level // 2) // dim),
            ]))
            dim_level //= 2

        # Output projection
        self.mapping = nn.Conv2d(self.dim, out_dim, 3, 1, 1, bias=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, illu_fea):
        """
        x:        [b,c,h,w]  (x is a feature map here, not the raw image)
        illu_fea: [b,c,h,w]
        return out: [b,c,h,w]
        """
        # Embedding
        fea = self.embedding(x)

        # Encoder
        fea_encoder = []
        illu_fea_list = []
        for (IGAB, FeaDownSample, IlluFeaDownsample) in self.encoder_layers:
            fea = IGAB(fea, illu_fea)  # bchw
            illu_fea_list.append(illu_fea)
            fea_encoder.append(fea)
            fea = FeaDownSample(fea)
            illu_fea = IlluFeaDownsample(illu_fea)

        # Bottleneck
        fea = self.bottleneck(fea, illu_fea)

        # Decoder
        for i, (FeaUpSample, Fution, LeWinBlcok) in enumerate(self.decoder_layers):
            fea = FeaUpSample(fea)
            fea = Fution(
                torch.cat([fea, fea_encoder[self.level - 1 - i]], dim=1))
            illu_fea = illu_fea_list[self.level - 1 - i]
            fea = LeWinBlcok(fea, illu_fea)

        # Mapping
        out = self.mapping(fea) + x

        return out


class RetinexFormer_Single_Stage(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_feat=31, level=2, num_blocks=[1, 1, 1]):
        super(RetinexFormer_Single_Stage, self).__init__()
        self.estimator = Illumination_Estimator(n_feat)
        self.denoiser = Denoiser(in_dim=in_channels, out_dim=out_channels, dim=n_feat, level=level,
                                 num_blocks=num_blocks)  #### the Denoiser is used image-to-image here

    def forward(self, img):
        # img:      b,c=3,h,w
        # illu_fea: b,c,h,w
        # illu_map: b,c=3,h,w
        illu_fea, illu_map = self.estimator(img)
        input_img = img * illu_map + img
        output_img = self.denoiser(input_img, illu_fea)

        return output_img


class RetinexFormer(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_feat=8, stage=1, num_blocks=[1, 2, 2]):
        super(RetinexFormer, self).__init__()
        self.stage = stage

        modules_body = [
            RetinexFormer_Single_Stage(in_channels=in_channels, out_channels=out_channels, n_feat=n_feat, level=2,
                                       num_blocks=num_blocks)
            for _ in range(stage)]

        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out: [b,c,h,w]
        """
        out = self.body(x)

        return out


if __name__ == '__main__':
    # from fvcore.nn import FlopCountAnalysis
    model = RetinexFormer(stage=1, n_feat=40, num_blocks=[1, 2, 2]).cuda()
    inputs = torch.randn((1, 3, 256, 256)).cuda()
    out = model(inputs)
    print(out.size())
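
Note that RetinexFormer maps a 3-channel image to a 3-channel image at the same resolution, which is what allows it to sit as layer 0 in front of the ordinary YOLO backbone (see the YAML in Section 5). If you have no GPU available, a quick CPU sanity check can temporarily replace the CUDA lines in the __main__ block above (my own snippet, not part of the original code):

model = RetinexFormer(stage=1, n_feat=8, num_blocks=[1, 2, 2])  # CPU, small n_feat for speed
x = torch.randn(1, 3, 320, 320)  # spatial size must be divisible by 4 (two 2x downsamplings)
with torch.no_grad():
    y = model(x)
print(x.shape, '->', y.shape)  # expected: torch.Size([1, 3, 320, 320]) -> torch.Size([1, 3, 320, 320])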

4. How to Add Retinexformer

The procedure for adding modules has changed slightly compared with earlier posts; from now on every addition follows this procedure, so that it stays compatible with the files shared in the group.


4.1 Modification 1

The first step, as always, is to create the file: locate the ultralytics/nn/modules folder and create a directory inside it named 'Addmodules' (if you are using the files from the group it already exists, so there is no need to create it!). Then create a new .py file inside that directory and paste the core code above into it.
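For example, the file might be saved as ultralytics/nn/modules/Addmodules/Retinexformer.py; the file name itself is your choice (it is only an assumption used in the snippets below), as long as the import in 4.2 matches it.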


4.2 Modification 2

The second step is to create a new .py file named '__init__.py' in that directory (again, if you are using the group's files it already exists), and import our module inside it as shown in the figure below.
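
Since the figure is not reproduced here, a minimal sketch of that '__init__.py' (assuming the file from 4.1 was named Retinexformer.py) is:

# ultralytics/nn/modules/Addmodules/__init__.py   (path and file name assumed from 4.1)
from .Retinexformer import *  # exposes RetinexFormer through __all__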


4.3 Modification 3

The third step is to open the file 'ultralytics/nn/tasks.py' and import and register our module there (if you are using the group's files, the import is already in place and you can go straight to step 4).
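
The corresponding figure is not reproduced here; the import at the top of tasks.py would look roughly like the line below (the package path is an assumption that follows from the folder created in 4.1):

# near the other module imports at the top of ultralytics/nn/tasks.py
from ultralytics.nn.modules.Addmodules import RetinexFormer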

From today on, all tutorials will follow this unified layout, because I assume everyone is modifying the files shared in the group!!


4.4 Modification 4

Register the module inside parse_model in the way I show (a minimal sketch follows below).

That completes the registration work. The module needs no extra constructor arguments in the config (it is used as a parameter-free mechanism here), so importing and registering it is all that is required.
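
For reference, one minimal way to handle a channel-preserving module with no YAML arguments inside parse_model is a branch along these lines (a sketch of the idea, not the exact code from the group files; it goes next to the other module-dispatch branches inside parse_model's layer loop):

        elif m is RetinexFormer:
            c2 = ch[f]  # RetinexFormer keeps the channel count (3 in, 3 out), so this layer's
                        # output channels equal its input channels; args stays [] so the module
                        # is built with its default constructor arguments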


Disable mixed-precision (AMP) validation!

Open the file 'ultralytics/engine/validator.py', find 'class BaseValidator:', and in its '__call__' method locate the existing line self.args.half = self.device.type != 'cpu' (the one with the comment "force FP16 val during training"); directly below it, add the line self.args.half = False.
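
In context, the relevant lines inside BaseValidator.__call__ end up looking like this (the first line already exists in validator.py; only the second is added):

            self.args.half = self.device.type != 'cpu'  # force FP16 val during training
            self.args.half = False  # added: disable AMP/half-precision validation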

About printing the model's FLOPs!

The GFLOPs computation fails for this model and nothing is printed, so one more file needs to be changed: open 'ultralytics/utils/torch_utils.py', find the function shown below (the 640 marked with a red box in the original screenshot may differ from yours; just locate the function itself), and replace the whole function with the code I give here.

def get_flops(model, imgsz=640):
    """Return a YOLO model's FLOPs."""
    try:
        model = de_parallel(model)
        p = next(model.parameters())
        # stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32  # max stride
        stride = 640
        im = torch.empty((1, 3, stride, stride), device=p.device)  # input image in BCHW format
        flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 if thop else 0  # stride GFLOPs
        imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz]  # expand if int/float
        return flops * imgsz[0] / stride * imgsz[1] / stride  # 640x640 GFLOPs
    except Exception:
        return 0


5. Retinexformer YAML File and Training Log

5.1 Retinexformer YAML file

Training info: YOLO11-Retinexformer summary: 385 layers, 2,624,436 parameters, 2,624,420 gradients, 34.4 GFLOPs

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, RetinexFormer, []] # 0 (low-light enhancement, stride 1, 3 channels in/out)
  - [-1, 1, Conv, [64, 3, 2]] # 1-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 2-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 4-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 6-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 8-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 10
  - [-1, 2, C2PSA, [1024]] # 11

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 7], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 14

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 5], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 17 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 14], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 20 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 11], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 23 (P5/32-large)

  - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)
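
With the YAML saved (for example as 'yolo11-Retinexformer.yaml'; the file name is your choice and only assumed here), training is launched through the usual ultralytics API. A minimal sketch:

from ultralytics import YOLO

# Point the path at wherever you saved the config above.
model = YOLO('yolo11-Retinexformer.yaml')
model.train(data='coco8.yaml',  # replace with the yaml of your own (low-light) detection dataset
            epochs=100,
            imgsz=640,
            batch=16)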

5.2 Retinexformer training screenshots


6. Conclusion

This is the end of the formal content of this article. Let me recommend my YOLOv11 effective-improvement column here: it is newly opened with an average quality score of 98, and going forward I will keep reproducing papers from the latest top conferences as well as adding some of the older improvement mechanisms. If this article helped you, please subscribe to the column and follow along for the upcoming updates~