
YOLOv11 Improvements - Special-Scene Detection: Improving YOLOv11 Low-Light Detection with the Lightweight Low-Illumination Image Enhancement Network IAT (Exclusive First Release)

1. Introduction

The improvement presented in this article is based on the lightweight transformer model Illumination Adaptive Transformer (IAT), designed for image enhancement and exposure correction. Its core idea is to decompose the image signal processor (ISP) pipeline into local and global image components, thereby restoring normally lit sRGB images from low-light or over-/under-exposed inputs. Specifically, IAT uses attention queries to represent and adjust ISP-related parameters such as color correction and gamma correction. With roughly 90k parameters and a processing time of about 0.004 s per image, the model consistently achieves state-of-the-art (SOTA) performance on low-light enhancement and exposure correction benchmarks. Here we attach it to YOLOv11 to improve the model's detection ability in dark scenes; this modification does not interfere with other module improvements.

Welcome to subscribe to my column and learn YOLO together!



2. Basic Principles

Paper: click here to jump to the official paper.

Code: click here to jump to the official code repository.


2.1 How IAT Works

The paper proposes a lightweight transformer model, the Illumination Adaptive Transformer (IAT), for image enhancement and exposure correction. Its core idea is to decompose the image signal processor (ISP) pipeline into local and global image components, thereby restoring normally lit sRGB images from low-light or over-/under-exposed inputs. Specifically, IAT uses attention queries to represent and adjust ISP-related parameters such as color correction and gamma correction. With roughly 90k parameters and a processing time of about 0.004 s per image, the model consistently achieves state-of-the-art (SOTA) performance on low-light enhancement and exposure correction benchmarks.

The basic principles of the Illumination Adaptive Transformer (IAT) are as follows:

1. Lightweight transformer architecture: IAT is designed as a lightweight model with roughly 90,000 parameters, focused on image enhancement and exposure correction. This makes it very efficient in processing speed and resource consumption, suitable for real-time or resource-constrained applications.

2. ISP pipeline decomposition: The core of IAT is to mimic and improve the traditional ISP pipeline. By decomposing the ISP processing into local and global image components, IAT can adapt the visual appearance of an image to specific lighting conditions.

3. Adaptive illumination adjustment: IAT dynamically adjusts its processing strategy according to the lighting of the input image, effectively handling low-light, over-exposed and under-exposed cases and restoring a normally lit sRGB image.

The structure of the Illumination Adaptive Transformer (IAT) consists of two main parts: a local branch and a global branch.

1. Local branch: processes local image features. This branch extracts local features by repeatedly applying the pixel-wise enhancement module (PEM) and further processes them with convolutional layers.

2. Global branch: processes the global image information. It is likewise built from a stack of blocks and convolutional layers, but it operates on the global image content.

3. Parameter generation (black lines): the black lines mark the parameter-generation path, i.e. how the network produces the parameters needed by the ISP pipeline, such as the color matrix and the gamma value.

4. Image processing (yellow lines): the yellow lines mark the actual image-processing path. After the image passes through the local and global branches, the resulting features are used to adjust its color and exposure.

5. Cross attention: this component sits in the global branch and fuses information from the local and global branches so that the color matrix and gamma value can be adjusted more accurately.

6. Final output: the processed features go through a reshape operation and convolutional layers, and the local and global adjustments are applied to the original input image to produce the enhanced output (a simplified sketch of this combination follows the list).
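To make the combination concrete, here is a simplified sketch of how the two branches' outputs are fused. It mirrors the forward pass of the IAT class in Section 3 (the tensors mul, add, color and gamma are produced there by the local and global branches); the function name combine and the demo values are purely illustrative.

import torch

def combine(img_low, mul, add, color, gamma):
    """Simplified IAT fusion: local mul/add, then global color matrix and gamma.
    img_low, mul, add: (B, 3, H, W); color: (B, 3, 3); gamma: (B, 1)."""
    img = img_low * mul + add                      # local branch: pixel-wise enhancement
    img = img.permute(0, 2, 3, 1)                  # (B, H, W, 3) so the color matrix acts on channels
    out = torch.stack([
        torch.clamp(img[i] @ color[i].T, 1e-8, 1.0) ** gamma[i]   # global branch: color correction + gamma
        for i in range(img.shape[0])
    ])
    return out.permute(0, 3, 1, 2)                 # back to (B, 3, H, W)

# identity parameters leave the image (almost) unchanged:
x = torch.rand(1, 3, 64, 64)
out = combine(x, torch.ones_like(x), torch.zeros_like(x), torch.eye(3).unsqueeze(0), torch.ones(1, 1))

With identity parameters (mul = 1, add = 0, color = I, gamma = 1) the image passes through unchanged, which is why the global branch in the code predicts offsets from an identity color matrix and a base gamma of 1.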


2.2 The Core Modules of IAT

The figure below gives an intuitive view of the two core modules of the Illumination Adaptive Transformer (IAT): the Pixel-wise Enhancement Module (PEM) and the Global Prediction Module (GPM).

(a) Pixel-wise Enhancement Module (PEM):
Input: a feature map of size B × C × H × W, where B is the batch size, C the number of channels, and H × W the spatial height and width.
Pipeline:
1. A series of 1x1 convolution layers applies a point-wise linear transformation to the feature map, enhancing or adjusting the characteristics of individual pixels.
2. After each 1x1 convolution, an element-wise multiplication is performed (shown as the yellow circles with multiplication symbols).
3. After these operations, the feature map is reshaped back to its original B × C × H × W shape.

(b) Global Prediction Module (GPM):
Pipeline:
1. The feature map first passes through a fully connected (FC) layer to produce V, the value vectors carrying global information.
2. Another FC layer produces K, the key vectors.
3. K and V are combined with the query Q through a cross-attention mechanism; in the implementation below, Q is a set of learnable query tokens.
4. The result is reshaped into the color correction matrix and the gamma correction value.

The two modules work together: the PEM refines local feature details, while the GPM produces the global adjustment parameters, and together they give finer-grained control over the enhancement. In this way, IAT can make subtle adjustments to images captured under different lighting conditions and achieve excellent enhancement results. A short sketch of how the GPM output is split into these parameters follows.
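One implementation detail visible in the code of Section 3: the GPM's cross attention uses 10 learnable query tokens, and after attention the first token is mapped to the gamma value while the remaining nine form the 3x3 color correction matrix. The minimal sketch below shows just that split; the shapes and the two linear heads mirror the Global_pred class further down, while the random input stands in for the attended tokens.

import torch
import torch.nn as nn

B, C = 2, 64
x = torch.randn(B, 10, C)                  # output of the query cross attention: 10 tokens per image
gamma_linear = nn.Linear(C, 1)             # illustrative heads, mirroring Global_pred
color_linear = nn.Linear(C, 1)
gamma_base = torch.ones(1)                 # base gamma = 1 (identity)
color_base = torch.eye(3)                  # base color matrix = identity

gamma = gamma_linear(x[:, 0:1]).squeeze(-1) + gamma_base                  # (B, 1)
color = color_linear(x[:, 1:]).squeeze(-1).view(-1, 3, 3) + color_base    # (B, 3, 3)
print(gamma.shape, color.shape)            # torch.Size([2, 1]) torch.Size([2, 3, 3])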


3. Core Code

See Section 4 for how to use the core code!

import math
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_, DropPath, to_2tuple

__all__ = ['IAT']


class query_Attention(nn.Module):
    def __init__(self, dim, num_heads=2, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        # 10 learnable query tokens: token 0 -> gamma, tokens 1-9 -> 3x3 color matrix
        self.q = nn.Parameter(torch.ones((1, 10, dim)), requires_grad=True)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        q = self.q.expand(B, -1, -1).view(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, 10, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class query_SABlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = query_Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)
        x = x.flatten(2).transpose(1, 2)
        x = self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class conv_embedding(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(conv_embedding, self).__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(out_channels // 2),
            nn.GELU(),
            nn.Conv2d(out_channels // 2, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        x = self.proj(x)
        return x


class Global_pred(nn.Module):
    def __init__(self, in_channels=3, out_channels=64, num_heads=4, type='exp'):
        super(Global_pred, self).__init__()
        if type == 'exp':
            self.gamma_base = nn.Parameter(torch.ones((1)), requires_grad=False)  # False in exposure correction
        else:
            self.gamma_base = nn.Parameter(torch.ones((1)), requires_grad=True)
        self.color_base = nn.Parameter(torch.eye((3)), requires_grad=True)  # basic color matrix
        # main blocks
        self.conv_large = conv_embedding(in_channels, out_channels)
        self.generator = query_SABlock(dim=out_channels, num_heads=num_heads)
        self.gamma_linear = nn.Linear(out_channels, 1)
        self.color_linear = nn.Linear(out_channels, 1)

        self.apply(self._init_weights)
        for name, p in self.named_parameters():
            if name == 'generator.attn.v.weight':
                nn.init.constant_(p, 0)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x = self.conv_large(x)
        x = self.generator(x)
        gamma, color = x[:, 0].unsqueeze(1), x[:, 1:]
        gamma = self.gamma_linear(gamma).squeeze(-1) + self.gamma_base
        color = self.color_linear(color).squeeze(-1).view(-1, 3, 3) + self.color_base
        return gamma, color


# ResMLP's normalization
class Aff(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # learnable affine parameters
        self.alpha = nn.Parameter(torch.ones([1, 1, dim]))
        self.beta = nn.Parameter(torch.zeros([1, 1, dim]))

    def forward(self, x):
        x = x * self.alpha + self.beta
        return x


# Color Normalization
class Aff_channel(nn.Module):
    def __init__(self, dim, channel_first=True):
        super().__init__()
        # learnable
        self.alpha = nn.Parameter(torch.ones([1, 1, dim]))
        self.beta = nn.Parameter(torch.zeros([1, 1, dim]))
        self.color = nn.Parameter(torch.eye(dim))
        self.channel_first = channel_first

    def forward(self, x):
        if self.channel_first:
            x1 = torch.tensordot(x, self.color, dims=[[-1], [-1]])
            x2 = x1 * self.alpha + self.beta
        else:
            x1 = x * self.alpha + self.beta
            x2 = torch.tensordot(x1, self.color, dims=[[-1], [-1]])
        return x2


class Mlp(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class CMlp(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class CBlock_ln(nn.Module):
    def __init__(self, dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=Aff_channel, init_values=1e-4):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.conv1 = nn.Conv2d(dim, dim, 1)
        self.conv2 = nn.Conv2d(dim, dim, 1)
        self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.gamma_1 = nn.Parameter(init_values * torch.ones((1, dim, 1, 1)), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones((1, dim, 1, 1)), requires_grad=True)
        self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)
        B, C, H, W = x.shape
        norm_x = x.flatten(2).transpose(1, 2)
        norm_x = self.norm1(norm_x)
        norm_x = norm_x.view(B, H, W, C).permute(0, 3, 1, 2)
        x = x + self.drop_path(self.gamma_1 * self.conv2(self.attn(self.conv1(norm_x))))
        norm_x = x.flatten(2).transpose(1, 2)
        norm_x = self.norm2(norm_x)
        norm_x = norm_x.view(B, H, W, C).permute(0, 3, 1, 2)
        x = x + self.drop_path(self.gamma_2 * self.mlp(norm_x))
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


## Layer_norm, Aff_norm, Aff_channel_norm
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: Aff_channel
    """

    def __init__(self, dim, num_heads=2, window_size=8, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=Aff_channel):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        x = x.transpose(1, 2).reshape(B, C, H, W)
        return x


class Local_pred(nn.Module):
    def __init__(self, dim=16, number=4, type='ccc'):
        super(Local_pred, self).__init__()
        # initial convolution
        self.conv1 = nn.Conv2d(3, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks
        block = CBlock_ln(dim)
        block_t = SwinTransformerBlock(dim)  # head number
        if type == 'ccc':
            blocks1 = [CBlock_ln(16, drop_path=0.01), CBlock_ln(16, drop_path=0.05), CBlock_ln(16, drop_path=0.1)]
            blocks2 = [CBlock_ln(16, drop_path=0.01), CBlock_ln(16, drop_path=0.05), CBlock_ln(16, drop_path=0.1)]
        elif type == 'ttt':
            blocks1, blocks2 = [block_t for _ in range(number)], [block_t for _ in range(number)]
        elif type == 'cct':
            blocks1, blocks2 = [block, block, block_t], [block, block, block_t]
        self.mul_blocks = nn.Sequential(*blocks1, nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_blocks = nn.Sequential(*blocks2, nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())

    def forward(self, img):
        img1 = self.relu(self.conv1(img))
        mul = self.mul_blocks(img1)
        add = self.add_blocks(img1)
        return mul, add


# Short Cut Connection on Final Layer
class Local_pred_S(nn.Module):
    def __init__(self, in_dim=3, dim=16, number=4, type='ccc'):
        super(Local_pred_S, self).__init__()
        # initial convolution
        self.conv1 = nn.Conv2d(in_dim, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks
        block = CBlock_ln(dim)
        block_t = SwinTransformerBlock(dim)  # head number
        if type == 'ccc':
            blocks1 = [CBlock_ln(16, drop_path=0.01), CBlock_ln(16, drop_path=0.05), CBlock_ln(16, drop_path=0.1)]
            blocks2 = [CBlock_ln(16, drop_path=0.01), CBlock_ln(16, drop_path=0.05), CBlock_ln(16, drop_path=0.1)]
        elif type == 'ttt':
            blocks1, blocks2 = [block_t for _ in range(number)], [block_t for _ in range(number)]
        elif type == 'cct':
            blocks1, blocks2 = [block, block, block_t], [block, block, block_t]
        self.mul_blocks = nn.Sequential(*blocks1)
        self.add_blocks = nn.Sequential(*blocks2)

        self.mul_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, img):
        img1 = self.relu(self.conv1(img))
        # short cut connection
        mul = self.mul_blocks(img1) + img1
        add = self.add_blocks(img1) + img1
        mul = self.mul_end(mul)
        add = self.add_end(add)
        return mul, add


class IAT(nn.Module):
    def __init__(self, in_dim=3, with_global=True, type='lol'):
        super(IAT, self).__init__()
        # self.local_net = Local_pred()
        self.local_net = Local_pred_S(in_dim=in_dim)
        self.with_global = with_global
        if self.with_global:
            self.global_net = Global_pred(in_channels=in_dim, type=type)

    def apply_color(self, image, ccm):
        # apply the 3x3 color correction matrix to an (H, W, 3) image
        shape = image.shape
        image = image.view(-1, 3)
        image = torch.tensordot(image, ccm, dims=[[-1], [-1]])
        image = image.view(shape)
        return torch.clamp(image, 1e-8, 1.0)

    def forward(self, img_low):
        mul, add = self.local_net(img_low)
        img_high = (img_low.mul(mul)).add(add)

        if not self.with_global:
            return img_high
        else:
            gamma, color = self.global_net(img_low)
            b = img_high.shape[0]
            img_high = img_high.permute(0, 2, 3, 1)  # (B,C,H,W) -> (B,H,W,C)
            img_high = torch.stack(
                [self.apply_color(img_high[i, :, :, :], color[i, :, :]) ** gamma[i, :] for i in range(b)], dim=0)
            img_high = img_high.permute(0, 3, 1, 2)  # (B,H,W,C) -> (B,C,H,W)
            return img_high


if __name__ == "__main__":
    img = torch.randn(1, 3, 640, 640)
    net = IAT()
    img_high = net(img)
    print(img_high.size())
    print('total parameters:', sum(param.numel() for param in net.parameters()))


4. Step-by-Step: Adding the IAT Low-Light Image Enhancement Network

4.1 Modification 1

The first step, as always, is to create the file. Go to the ultralytics/nn/modules folder and create a directory inside it named 'Addmodules' (if you are using the files shared in the group, it already exists and does not need to be created). Then create a new .py file inside it and paste in the core code from Section 3.


4.2 Modification 2

Step two: create a new .py file in that directory named '__init__.py' (already present if you are using the group files), and import our module inside it as shown in the figure below (a reference sketch follows).
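For reference, the import usually looks like the sketch below. It assumes the core code from Section 3 was saved as IAT.py inside the Addmodules folder created in 4.1; adjust the file name and path if yours differ.

# ultralytics/nn/modules/Addmodules/__init__.py  (assumed layout from 4.1)
from .IAT import IAT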


4.3 Modification 3

Step three: open the file 'ultralytics/nn/tasks.py' and import and register our module there (already done if you are using the group files, in which case you can go straight to step four); see the sketch below.
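A minimal sketch of the import to add near the other module imports in tasks.py, assuming the Addmodules package layout from 4.1/4.2; the exact import line depends on your folder layout.

# near the top of ultralytics/nn/tasks.py
from ultralytics.nn.modules.Addmodules import IAT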

From now on, the tutorials will all follow this unified format, because I assume everyone is modifying the files shared in my group!


4.4 Modification 4

Add the entry to parse_model as I do; only the part in the red box needs to be added, and anything you do not have belongs to improvement mechanisms from other articles (an illustrative sketch follows).
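Since the screenshot is not reproduced here, the snippet below sketches the kind of branch that is typically added inside parse_model. The variable names (m, f, ch, c2) follow the Ultralytics code, but the exact insertion point and surrounding branches vary between versions, so treat this as an illustrative assumption rather than the definitive patch.

# inside parse_model, next to the other `elif m is ...` branches:
elif m is IAT:
    c2 = ch[f]  # IAT only enhances the image, so the channel count (3 for the raw input) is unchanged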

That completes the modifications. You can now copy the YAML file below and run it.


5. IAT YAML File and Training Record

5.1 The IAT YAML File

Training info for this version: YOLO11-IAT summary: 435 layers, 2,681,969 parameters, 2,681,953 gradients, 24.4 GFLOPs

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, IAT, []] # 0 (IAT image enhancement, full resolution)
  - [-1, 1, Conv, [64, 3, 2]] # 1-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 2-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 4-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 6-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 8-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 10
  - [-1, 2, C2PSA, [1024]] # 11

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 7], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 14

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 5], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 17 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 14], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 20 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 11], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 23 (P5/32-large)

  - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)


5.2 Training Code

Create a .py file, paste in the code below, fill in your own file paths, and run it.

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO('path/to/yolo11-IAT.yaml')  # path to the IAT yaml from Section 5.1
    # model.load('yolo11n.pt')  # loading pretrain weights
    model.train(data=r'path/to/your/dataset.yaml',  # replace with your dataset yaml path
                # For other tasks, edit 'ultralytics/cfg/default.yaml' and change task to detect, segment, classify or pose
                cache=False,
                imgsz=640,
                epochs=150,
                single_cls=False,  # whether this is single-class detection
                batch=4,
                close_mosaic=10,
                workers=0,
                device='0',
                optimizer='SGD',  # using SGD
                # resume='',  # to resume training, set this to the path of last.pt
                amp=False,  # disable AMP if the training loss becomes NaN
                project='runs/train',
                name='exp',
                )


5.3 Screenshot of the IAT Training Process


6. Summary

That wraps up the main content of this article. I would like to recommend my YOLOv11 effective-improvement column: it is newly opened with an average quality score of 98. Going forward I will reproduce papers from the latest top conferences and also supplement some older improvement mechanisms. If this article helped you, please subscribe to the column and follow the upcoming updates.