学习资源站

YOLOv11改进-特殊场景检测篇-利用图像去雾网络UnfogNet改进yolov11图像雾天检测能力(全网独家首发改进)

一、本文介绍

本文给大家带来的改进机制是利用 UnfogNet超轻量化图像去雾网络 ,我将该网络结合YOLOv11针对图像进行去雾检测(也适用于一些模糊场景),我将该网络结构和YOLOv11的网络进行结合同时该网络的结构的参数量非常的小,我们将其添加到 模型 里增加的计算量和参数量基本可以忽略不计这是非常难得的,因为其也算是一种图像增强算法, 同时本文的内容不影响其它的模块改进可以作为工作量凑近大家的论文里,非常的适用,图像去雾检测为群友最近提出的需要的改进

欢迎大家订阅我的专栏一起学习YOLO!

37b7a5829b134f58843683922e978d45.png



二、原理介绍

87c8c6fc14e74e77a80422130004af69.png

官方论文地址: 官方论文地址点击即可跳转

官方代码地址: 官方代码地址点击即可跳转

本文的 图像去雾 算法 的原理很简单,下面为其代码,大家可以看到其主要有Conv组成,并没有涉及到过多的复杂结构这里不再赘述了。

  1. import torch
  2. import torch.nn as nn
  3. import math
  4. class unfog_net(nn.Module):
  5. def __init__(self):
  6. super(unfog_net, self).__init__()
  7. self.relu = nn.ReLU(inplace=True)
  8. self.e_conv1 = nn.Conv2d(3, 3, 1, 1, 0, bias=True)
  9. self.e_conv2 = nn.Conv2d(3, 3, 3, 1, 1, bias=True)
  10. self.e_conv3 = nn.Conv2d(6, 3, 5, 1, 2, bias=True)
  11. self.e_conv4 = nn.Conv2d(6, 3, 7, 1, 3, bias=True)
  12. self.e_conv5 = nn.Conv2d(12, 3, 3, 1, 1, bias=True)
  13. def forward(self, x):
  14. x1 = self.relu(self.e_conv1(x))
  15. x2 = self.relu(self.e_conv2(x1))
  16. concat1 = torch.cat((x1, x2), 1)
  17. x3 = self.relu(self.e_conv3(concat1))
  18. concat2 = torch.cat((x2, x3), 1)
  19. x4 = self.relu(self.e_conv4(concat2))
  20. concat3 = torch.cat((x1, x2, x3, x4), 1)
  21. x5 = self.relu(self.e_conv5(concat3))
  22. clean_image = self.relu((x5 * x) - x5 + 1)
  23. return clean_image
  24. if __name__ == "__main__":
  25. # Generating Sample image
  26. image_size = (1, 3, 640, 640)
  27. image = torch.rand(*image_size)
  28. out = unfog_net()
  29. out = out(image)
  30. print(out.size())

这段代码定义了一个 神经网络模型 unfog_net,用于图像去雾。这个模型是基于PyTorch框架的nn.Module构建的。下面是对代码的逐行分析:

1. 初始化(__init__方法):

  • 通过调用super,这个类继承了nn.Module的所有属性和方法。
  • 初始化了ReLU激活函数。
  • 定义了五个卷积层e_conv1到e_conv5。每个卷积层的参数配置了输入通道数、输出通道数、卷积核大小、步长(stride)、填充(padding),以及是否使用偏置(bias)。

2. 前向传播(forward方法):

  • 输入图像x首先通过第一个卷积层e_conv1和ReLU激活函数,产生特征图x1。
  • x1然后通过第二个卷积层e_conv2和ReLU激活函数,产生特征图x2。
  • 特征图x1和x2被拼接(concatenate)成concat1,然后通过第三个卷积层e_conv3和ReLU激活函数,产生特征图x3。
  • 类似地,x2和x3被拼接成concat2,通过第四个卷积层e_conv4和ReLU激活函数,产生特征图x4。
  • 最后,x1、x2、x3和x4被一起拼接成concat3,通过第五个卷积层e_conv5和ReLU激活函数,产生特征图x5。
  • 最终输出的清晰图像是通过一个操作计算得出的:(x5 * x) - x5 + 1,这里x5与输入图像x相乘,然后从结果中减去x5,最后加上1。这个操作意在利用卷积层提取的特征来恢复清晰图像。

这个网络利用了多尺度的特征提取策略,通过不同大小的卷积核来捕捉图像的不同细节,并通过特征拼接来增强模型的表达能力。最终的输出操作是一个简化的形式,目的是结合输入图像和网络提取的特征来去除雾霾。


三、核心代码

核心代码的使用方式看章节四!

  1. import torch
  2. import torch.nn as nn
  3. __all__ = ['unfog_net']
  4. class unfog_net(nn.Module):
  5. def __init__(self, args):
  6. super(unfog_net, self).__init__()
  7. self.relu = nn.ReLU(inplace=True)
  8. self.e_conv1 = nn.Conv2d(3, 3, 1, 1, 0, bias=True)
  9. self.e_conv2 = nn.Conv2d(3, 3, 3, 1, 1, bias=True)
  10. self.e_conv3 = nn.Conv2d(6, 3, 5, 1, 2, bias=True)
  11. self.e_conv4 = nn.Conv2d(6, 3, 7, 1, 3, bias=True)
  12. self.e_conv5 = nn.Conv2d(12, 3, 3, 1, 1, bias=True)
  13. def forward(self, x):
  14. x1 = self.relu(self.e_conv1(x))
  15. x2 = self.relu(self.e_conv2(x1))
  16. concat1 = torch.cat((x1, x2), 1)
  17. x3 = self.relu(self.e_conv3(concat1))
  18. concat2 = torch.cat((x2, x3), 1)
  19. x4 = self.relu(self.e_conv4(concat2))
  20. concat3 = torch.cat((x1, x2, x3, x4), 1)
  21. x5 = self.relu(self.e_conv5(concat3))
  22. clean_image = self.relu((x5 * x) - x5 + 1)
  23. return clean_image
  24. if __name__ == "__main__":
  25. # Generating Sample image
  26. image_size = (1, 3, 640, 640)
  27. image = torch.rand(*image_size)
  28. out = unfog_net()
  29. out = out(image)
  30. print(out.size())

四、添加方式

本文的网络结构为无参数网络结构所以我们的注册方式很简单。

4.1 修改一

第一还是建立文件,我们找到如下 ultralytics /nn/modules文件夹下建立一个目录名字呢就是'Addmodules'文件夹( 用群内的文件的话已经有了无需新建) !然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可。

c7e83550cf094b1fa4757a7860dab47c.png


4.2 修改二

第二步我们在该目录下创建一个新的py文件名字为'__init__.py'( 用群内的文件的话已经有了无需新建) ,然后在其内部导入我们的检测头如下图所示。

c2e78a0c394349c9b8fa552f9890497b.png


4.3 修改三

第三步我门中到如下文件'ultralytics/nn/tasks.py'进行导入和注册我们的模块( 用群内的文件的话已经有了无需重新导入直接开始第四步即可)

从今天开始以后的教程就都统一成这个样子了,因为我默认大家用了我群内的文件来进行修改!!

67b28bda87e44d3285f0241acd165256.png


打印计算量的问题!

计算的GFLOPs计算异常不打印,所以需要额外修改一处, 我们找到如下文件'ultralytics/utils/torch_utils.py'文件内有如下的代码按照如下的图片进行修改,有一个get_flops的 函数 我们直接用我给的代码全部替换!

  1. def get_flops(model, imgsz=640):
  2. """Return a YOLO model's FLOPs."""
  3. if not thop:
  4. return 0.0 # if not installed return 0.0 GFLOPs
  5. try:
  6. model = de_parallel(model)
  7. p = next(model.parameters())
  8. if not isinstance(imgsz, list):
  9. imgsz = [imgsz, imgsz] # expand if int/float
  10. try:
  11. # Use stride size for input tensor
  12. stride = 640
  13. im = torch.empty((1, 3, stride, stride), device=p.device) # input image in BCHW format
  14. flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # stride GFLOPs
  15. return flops * imgsz[0] / stride * imgsz[1] / stride # imgsz GFLOPs
  16. except Exception:
  17. # Use actual image size for input tensor (i.e. required for RTDETR models)
  18. im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format
  19. return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # imgsz GFLOPs
  20. except Exception:
  21. return 0.0

到此就修改完成了,大家可以复制下面的yaml文件运行,无需修改parse_model方法。


五、 UnfogNet 的yaml文件和运行记录

5.1 UnfogNet 的yaml文件

此版本训练信息:YOLO11-UnfogNet summary: 326 layers, 2,592,576 parameters, 2,592,560 gradients, 7.9 GFLOPs

  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. # YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
  3. # Parameters
  4. nc: 80 # number of classes
  5. scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  6. # [depth, width, max_channels]
  7. n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  8. s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  9. m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  10. l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  11. x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs
  12. # YOLO11n backbone
  13. backbone:
  14. # [from, repeats, module, args]
  15. - [-1, 1, unfog_net, []] # 0-P1/2
  16. - [-1, 1, Conv, [64, 3, 2]] # 1-P1/2
  17. - [-1, 1, Conv, [128, 3, 2]] # 2-P2/4
  18. - [-1, 2, C3k2, [256, False, 0.25]]
  19. - [-1, 1, Conv, [256, 3, 2]] # 4-P3/8
  20. - [-1, 2, C3k2, [512, False, 0.25]]
  21. - [-1, 1, Conv, [512, 3, 2]] # 6-P4/16
  22. - [-1, 2, C3k2, [512, True]]
  23. - [-1, 1, Conv, [1024, 3, 2]] # 8-P5/32
  24. - [-1, 2, C3k2, [1024, True]]
  25. - [-1, 1, SPPF, [1024, 5]] # 10
  26. - [-1, 2, C2PSA, [1024]] # 11
  27. # YOLO11n head
  28. head:
  29. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  30. - [[-1, 7], 1, Concat, [1]] # cat backbone P4
  31. - [-1, 2, C3k2, [512, False]] # 14
  32. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  33. - [[-1, 5], 1, Concat, [1]] # cat backbone P3
  34. - [-1, 2, C3k2, [256, False]] # 17 (P3/8-small)
  35. - [-1, 1, Conv, [256, 3, 2]]
  36. - [[-1, 14], 1, Concat, [1]] # cat head P4
  37. - [-1, 2, C3k2, [512, False]] # 20 (P4/16-medium)
  38. - [-1, 1, Conv, [512, 3, 2]]
  39. - [[-1, 11], 1, Concat, [1]] # cat head P5
  40. - [-1, 2, C3k2, [1024, True]] # 23 (P5/32-large)
  41. - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)


5.2 训练代码

大家可以创建一个py文件将我给的代码复制粘贴进去,配置好自己的文件路径即可运行。

  1. import warnings
  2. warnings.filterwarnings('ignore')
  3. from ultralytics import YOLO
  4. if __name__ == '__main__':
  5. model = YOLO('ultralytics/cfg/models/v8/yolov8-C2f-FasterBlock.yaml')
  6. # model.load('yolov8n.pt') # loading pretrain weights
  7. model.train(data=r'替换数据集yaml文件地址',
  8. # 如果大家任务是其它的'ultralytics/cfg/default.yaml'找到这里修改task可以改成detect, segment, classify, pose
  9. cache=False,
  10. imgsz=640,
  11. epochs=150,
  12. single_cls=False, # 是否是单类别检测
  13. batch=4,
  14. close_mosaic=10,
  15. workers=0,
  16. device='0',
  17. optimizer='SGD', # using SGD
  18. # resume='', # 如过想续训就设置last.pt的地址
  19. amp=False, # 如果出现训练损失为Nan可以关闭amp
  20. project='runs/train',
  21. name='exp',
  22. )


5.3 训练过程截图


五、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv11改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~