一、本文改进
本文给大家带来的最新改进机制是一种我进行优化的专用于分割的检测头,在分割的过程中,最困难的无非就是边缘的检测, 动态蛇形卷积 (Dynamic Snake Convolution) 通过自适应地聚焦于细长和迂回的局部结构,准确地捕捉管状结构的特征。这种卷积方法的核心思想是, 通过动态形状的卷积核来增强感知能力,针对管状结构的特征提取进行优化, 所以将这个卷积针对于YOLOv11的分割头进行融合是非常合适的,当然本文的检测头也支持用于目标检测,但是我将其设计出来是主要为了分割的读者使用的。
欢迎大家订阅我的专栏一起学习YOLO!
二、DySnakeConv的框架原理
论文代码地址:
动态蛇形卷积官方代码下载地址
论文地址:
【免费】动态蛇形卷积(DynamicSnakeConvolution)资源-CSDN文库
背景-> 动态蛇形卷积(Dynamic Snake Convolution)来源于临床医学,清晰勾画血管是计算流体力学研究的关键前提,并能协助放射科医师进行诊断和定位病变。在遥感应用中,完整的道路分割为路径规划提供了坚实的基础。无论是哪个领域,这些结构都具有细长和曲折的共同特征,使得它们很难在图像中捕捉到,因为它们在图像中的比例很小。因此, 迫切需要提升对细长管状结构的感知能力 ,所以在这一背景下作者提出了动态蛇形卷积(Dynamic Snake Convolution)。
原理-> 上图展示了一个 三维心脏血管数据集 和一个 二维远程道路数据集 。这两个数据集旨在提取管状结构,但由于 脆弱的局部结构和复杂的整体形态 ,这个任务面临着挑战。标准的 卷积核 旨在提取局部特征。基于此,设计了可变形卷积核以丰富它们的应用,并适应不同目标的几何变形。然而,由于前面提到的挑战,有效地聚焦于细小的管状结构是困难的。
由于以下困难,这仍然是一个具有挑战性的任务:
-
细小而脆弱的局部结构: 如上面的图所示,细小的结构仅占整体图像的一小部分,并且由于像素组成有限,这些结构容易受到复杂背景的干扰,使模型难以精确地区分目标的细微变化。因此,模型可能难以区分这些结构,导致分割结果出现断裂。
-
复杂而多变的整体形态: 上面的图片展示了细小管状结构的复杂和多变形态,即使在同一图像中也如此。不同区域中的目标呈现出形态上的变化,包括分支数量、分叉位置和路径长度等。当数据呈现出前所未见的形态结构时,模型可能会过度拟合已经见过的特征,导致在新的形态结构下泛化能力较弱。
为了应对上述障碍,提出了如下解决方案, 其中包括管状感知卷积核、多视角特征融合策略和拓扑连续性约束损失函数 。具体如下:
1. 针对细小且脆弱的局部结构所占比例小且难以聚焦的挑战 ,提出了动态蛇形卷积,通过自适应地聚焦于管状结构的细长曲线局部特征,增强对几何结构的感知。与可变形卷积不同,DSConv考虑到管状结构的蛇形形态,并通过约束补充自由学习过程,有针对性地增强对管状结构的感知。
2. 针对复杂和多变的整体形态的挑战 ,提出了一种多视角特征融合策略。在该方法中,基于DSConv生成多个形态学卷积核 模板 ,从不同角度观察目标的结构特征,并通过总结典型的重要特征实现高效的特征融合。
3. 针对管状结构分割容易出现断裂的问题 ,提出了基于持久同调(Persistent Homology,PH)的拓扑连续性约束 损失函数 (TCLoss)。PH是一种从出现到消失的拓扑特征响应过程,能够从嘈杂的高维数据中获取足够的拓扑信息。相关的贝蒂数是描述拓扑空间连通性的一种方式。与其他方法不同, TCLoss将PH与点集相似性相结合 ,引导网络关注具有异常像素/体素分布的断裂区域,从拓扑角度实现连续性约束。
总结:为了克服挑战,提出了DSCNet框架,包括管状感知卷积核、多视角特征融合策略和拓扑连续性约束损失函数。DSConv增强了对细长曲线特征的感知,多视角特征融合策略提高了对复杂整体形态的处理能力,而TCLoss基于持久同调实现了从拓扑角度的连续性约束。
三、DySnakerConv的核心代码
下面的检测头可以用于分割和目标检测,但是其修改教程有差别。目标检测的检测头我就不讲了,大家看我之前的检测头如何添加的就可以,都是一样的只是换一个名字,本文主要针对于分割的读者。
- import copy
- import torch
- import torch.nn as nn
- import math
- from ultralytics.utils.checks import check_version
- __all__ = ['DSDConvSegment', 'DSDConvHead']
- TORCH_1_10 = check_version(torch.__version__, '1.10.0')
- def make_anchors(feats, strides, grid_cell_offset=0.5):
- """Generate anchors from features."""
- anchor_points, stride_tensor = [], []
- assert feats is not None
- dtype, device = feats[0].dtype, feats[0].device
- for i, stride in enumerate(strides):
- _, _, h, w = feats[i].shape
- sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
- sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
- sy, sx = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
- anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
- stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
- return torch.cat(anchor_points), torch.cat(stride_tensor)
- def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
- """Transform distance(ltrb) to box(xywh or xyxy)."""
- lt, rb = distance.chunk(2, dim)
- x1y1 = anchor_points - lt
- x2y2 = anchor_points + rb
- if xywh:
- c_xy = (x1y1 + x2y2) / 2
- wh = x2y2 - x1y1
- return torch.cat((c_xy, wh), dim) # xywh bbox
- return torch.cat((x1y1, x2y2), dim) # xyxy bbox
- class DFL(nn.Module):
- """
- Integral module of Distribution Focal Loss (DFL).
- Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
- """
- def __init__(self, c1=16):
- """Initialize a convolutional layer with a given number of input channels."""
- super().__init__()
- self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
- x = torch.arange(c1, dtype=torch.float)
- self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
- self.c1 = c1
- def forward(self, x):
- """Applies a transformer layer on input tensor 'x' and returns a tensor."""
- b, c, a = x.shape # batch, channels, anchors
- return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
- # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
- class Proto(nn.Module):
- """YOLOv8 mask Proto module for segmentation models."""
- def __init__(self, c1, c_=256, c2=32):
- """
- Initializes the YOLOv8 mask Proto module with specified number of protos and masks.
- Input arguments are ch_in, number of protos, number of masks.
- """
- super().__init__()
- self.cv1 = Conv(c1, c_, k=3)
- self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) # nn.Upsample(scale_factor=2, mode='nearest')
- self.cv2 = Conv(c_, c_, k=3)
- self.cv3 = Conv(c_, c2)
- def forward(self, x):
- """Performs a forward pass through layers using an upsampled input image."""
- return self.cv3(self.cv2(self.upsample(self.cv1(x))))
- def autopad(k, p=None, d=1): # kernel, padding, dilation
- """Pad to 'same' shape outputs."""
- if d > 1:
- k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
- if p is None:
- p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
- return p
- class Conv(nn.Module):
- """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
- default_act = nn.SiLU() # default activation
- def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
- """Initialize Conv layer with given arguments including activation."""
- super().__init__()
- self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
- self.bn = nn.BatchNorm2d(c2)
- self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
- def forward(self, x):
- """Apply convolution, batch normalization and activation to input tensor."""
- return self.act(self.bn(self.conv(x)))
- def forward_fuse(self, x):
- """Perform transposed convolution of 2D data."""
- return self.act(self.conv(x))
- class DSConv(nn.Module):
- def __init__(self, in_ch, out_ch, morph, kernel_size=3, if_offset=True, extend_scope=1):
- """
- The Dynamic Snake Convolution
- :param in_ch: input channel
- :param out_ch: output channel
- :param kernel_size: the size of kernel
- :param extend_scope: the range to expand (default 1 for this method)
- :param morph: the morphology of the convolution kernel is mainly divided into two types
- along the x-axis (0) and the y-axis (1) (see the paper for details)
- :param if_offset: whether deformation is required, if it is False, it is the standard convolution kernel
- """
- super(DSConv, self).__init__()
- # use the <offset_conv> to learn the deformable offset
- self.offset_conv = nn.Conv2d(in_ch, 2 * kernel_size, 3, padding=1)
- self.bn = nn.BatchNorm2d(2 * kernel_size)
- self.kernel_size = kernel_size
- # two types of the DSConv (along x-axis and y-axis)
- self.dsc_conv_x = nn.Conv2d(
- in_ch,
- out_ch,
- kernel_size=(kernel_size, 1),
- stride=(kernel_size, 1),
- padding=0,
- )
- self.dsc_conv_y = nn.Conv2d(
- in_ch,
- out_ch,
- kernel_size=(1, kernel_size),
- stride=(1, kernel_size),
- padding=0,
- )
- self.gn = nn.GroupNorm(out_ch // 4, out_ch)
- self.act = Conv.default_act
- self.extend_scope = extend_scope
- self.morph = morph
- self.if_offset = if_offset
- def forward(self, f):
- offset = self.offset_conv(f)
- offset = self.bn(offset)
- # We need a range of deformation between -1 and 1 to mimic the snake's swing
- offset = torch.tanh(offset)
- input_shape = f.shape
- dsc = DSC(input_shape, self.kernel_size, self.extend_scope, self.morph)
- deformed_feature = dsc.deform_conv(f, offset, self.if_offset)
- if self.morph == 0:
- x = self.dsc_conv_x(deformed_feature.type(f.dtype))
- x = self.gn(x)
- x = self.act(x)
- return x
- else:
- x = self.dsc_conv_y(deformed_feature.type(f.dtype))
- x = self.gn(x)
- x = self.act(x)
- return x
- # Core code, for ease of understanding, we mark the dimensions of input and output next to the code
- class DSC(object):
- def __init__(self, input_shape, kernel_size, extend_scope, morph):
- self.num_points = kernel_size
- self.width = input_shape[2]
- self.height = input_shape[3]
- self.morph = morph
- self.extend_scope = extend_scope # offset (-1 ~ 1) * extend_scope
- # define feature map shape
- """
- B: Batch size C: Channel W: Width H: Height
- """
- self.num_batch = input_shape[0]
- self.num_channels = input_shape[1]
- """
- input: offset [B,2*K,W,H] K: Kernel size (2*K: 2D image, deformation contains <x_offset> and <y_offset>)
- output_x: [B,1,W,K*H] coordinate map
- output_y: [B,1,K*W,H] coordinate map
- """
- def _coordinate_map_3D(self, offset, if_offset):
- device = offset.device
- # offset
- y_offset, x_offset = torch.split(offset, self.num_points, dim=1)
- y_center = torch.arange(0, self.width).repeat([self.height])
- y_center = y_center.reshape(self.height, self.width)
- y_center = y_center.permute(1, 0)
- y_center = y_center.reshape([-1, self.width, self.height])
- y_center = y_center.repeat([self.num_points, 1, 1]).float()
- y_center = y_center.unsqueeze(0)
- x_center = torch.arange(0, self.height).repeat([self.width])
- x_center = x_center.reshape(self.width, self.height)
- x_center = x_center.permute(0, 1)
- x_center = x_center.reshape([-1, self.width, self.height])
- x_center = x_center.repeat([self.num_points, 1, 1]).float()
- x_center = x_center.unsqueeze(0)
- if self.morph == 0:
- """
- Initialize the kernel and flatten the kernel
- y: only need 0
- x: -num_points//2 ~ num_points//2 (Determined by the kernel size)
- !!! The related PPT will be submitted later, and the PPT will contain the whole changes of each step
- """
- y = torch.linspace(0, 0, 1)
- x = torch.linspace(
- -int(self.num_points // 2),
- int(self.num_points // 2),
- int(self.num_points),
- )
- y, x = torch.meshgrid(y, x)
- y_spread = y.reshape(-1, 1)
- x_spread = x.reshape(-1, 1)
- y_grid = y_spread.repeat([1, self.width * self.height])
- y_grid = y_grid.reshape([self.num_points, self.width, self.height])
- y_grid = y_grid.unsqueeze(0) # [B*K*K, W,H]
- x_grid = x_spread.repeat([1, self.width * self.height])
- x_grid = x_grid.reshape([self.num_points, self.width, self.height])
- x_grid = x_grid.unsqueeze(0) # [B*K*K, W,H]
- y_new = y_center + y_grid
- x_new = x_center + x_grid
- y_new = y_new.repeat(self.num_batch, 1, 1, 1).to(device)
- x_new = x_new.repeat(self.num_batch, 1, 1, 1).to(device)
- y_offset_new = y_offset.detach().clone()
- if if_offset:
- y_offset = y_offset.permute(1, 0, 2, 3)
- y_offset_new = y_offset_new.permute(1, 0, 2, 3)
- center = int(self.num_points // 2)
- # The center position remains unchanged and the rest of the positions begin to swing
- # This part is quite simple. The main idea is that "offset is an iterative process"
- y_offset_new[center] = 0
- for index in range(1, center):
- y_offset_new[center + index] = (y_offset_new[center + index - 1] + y_offset[center + index])
- y_offset_new[center - index] = (y_offset_new[center - index + 1] + y_offset[center - index])
- y_offset_new = y_offset_new.permute(1, 0, 2, 3).to(device)
- y_new = y_new.add(y_offset_new.mul(self.extend_scope))
- y_new = y_new.reshape(
- [self.num_batch, self.num_points, 1, self.width, self.height])
- y_new = y_new.permute(0, 3, 1, 4, 2)
- y_new = y_new.reshape([
- self.num_batch, self.num_points * self.width, 1 * self.height
- ])
- x_new = x_new.reshape(
- [self.num_batch, self.num_points, 1, self.width, self.height])
- x_new = x_new.permute(0, 3, 1, 4, 2)
- x_new = x_new.reshape([
- self.num_batch, self.num_points * self.width, 1 * self.height
- ])
- return y_new, x_new
- else:
- """
- Initialize the kernel and flatten the kernel
- y: -num_points//2 ~ num_points//2 (Determined by the kernel size)
- x: only need 0
- """
- y = torch.linspace(
- -int(self.num_points // 2),
- int(self.num_points // 2),
- int(self.num_points),
- )
- x = torch.linspace(0, 0, 1)
- y, x = torch.meshgrid(y, x)
- y_spread = y.reshape(-1, 1)
- x_spread = x.reshape(-1, 1)
- y_grid = y_spread.repeat([1, self.width * self.height])
- y_grid = y_grid.reshape([self.num_points, self.width, self.height])
- y_grid = y_grid.unsqueeze(0)
- x_grid = x_spread.repeat([1, self.width * self.height])
- x_grid = x_grid.reshape([self.num_points, self.width, self.height])
- x_grid = x_grid.unsqueeze(0)
- y_new = y_center + y_grid
- x_new = x_center + x_grid
- y_new = y_new.repeat(self.num_batch, 1, 1, 1)
- x_new = x_new.repeat(self.num_batch, 1, 1, 1)
- y_new = y_new.to(device)
- x_new = x_new.to(device)
- x_offset_new = x_offset.detach().clone()
- if if_offset:
- x_offset = x_offset.permute(1, 0, 2, 3)
- x_offset_new = x_offset_new.permute(1, 0, 2, 3)
- center = int(self.num_points // 2)
- x_offset_new[center] = 0
- for index in range(1, center):
- x_offset_new[center + index] = (x_offset_new[center + index - 1] + x_offset[center + index])
- x_offset_new[center - index] = (x_offset_new[center - index + 1] + x_offset[center - index])
- x_offset_new = x_offset_new.permute(1, 0, 2, 3).to(device)
- x_new = x_new.add(x_offset_new.mul(self.extend_scope))
- y_new = y_new.reshape(
- [self.num_batch, 1, self.num_points, self.width, self.height])
- y_new = y_new.permute(0, 3, 1, 4, 2)
- y_new = y_new.reshape([
- self.num_batch, 1 * self.width, self.num_points * self.height
- ])
- x_new = x_new.reshape(
- [self.num_batch, 1, self.num_points, self.width, self.height])
- x_new = x_new.permute(0, 3, 1, 4, 2)
- x_new = x_new.reshape([
- self.num_batch, 1 * self.width, self.num_points * self.height
- ])
- return y_new, x_new
- """
- input: input feature map [N,C,D,W,H];coordinate map [N,K*D,K*W,K*H]
- output: [N,1,K*D,K*W,K*H] deformed feature map
- """
- def _bilinear_interpolate_3D(self, input_feature, y, x):
- device = input_feature.device
- y = y.reshape([-1]).float()
- x = x.reshape([-1]).float()
- zero = torch.zeros([]).int()
- max_y = self.width - 1
- max_x = self.height - 1
- # find 8 grid locations
- y0 = torch.floor(y).int()
- y1 = y0 + 1
- x0 = torch.floor(x).int()
- x1 = x0 + 1
- # clip out coordinates exceeding feature map volume
- y0 = torch.clamp(y0, zero, max_y)
- y1 = torch.clamp(y1, zero, max_y)
- x0 = torch.clamp(x0, zero, max_x)
- x1 = torch.clamp(x1, zero, max_x)
- input_feature_flat = input_feature.flatten()
- input_feature_flat = input_feature_flat.reshape(
- self.num_batch, self.num_channels, self.width, self.height)
- input_feature_flat = input_feature_flat.permute(0, 2, 3, 1)
- input_feature_flat = input_feature_flat.reshape(-1, self.num_channels)
- dimension = self.height * self.width
- base = torch.arange(self.num_batch) * dimension
- base = base.reshape([-1, 1]).float()
- repeat = torch.ones([self.num_points * self.width * self.height
- ]).unsqueeze(0)
- repeat = repeat.float()
- base = torch.matmul(base, repeat)
- base = base.reshape([-1])
- base = base.to(device)
- base_y0 = base + y0 * self.height
- base_y1 = base + y1 * self.height
- # top rectangle of the neighbourhood volume
- index_a0 = base_y0 - base + x0
- index_c0 = base_y0 - base + x1
- # bottom rectangle of the neighbourhood volume
- index_a1 = base_y1 - base + x0
- index_c1 = base_y1 - base + x1
- # get 8 grid values
- value_a0 = input_feature_flat[index_a0.type(torch.int64)].to(device)
- value_c0 = input_feature_flat[index_c0.type(torch.int64)].to(device)
- value_a1 = input_feature_flat[index_a1.type(torch.int64)].to(device)
- value_c1 = input_feature_flat[index_c1.type(torch.int64)].to(device)
- # find 8 grid locations
- y0 = torch.floor(y).int()
- y1 = y0 + 1
- x0 = torch.floor(x).int()
- x1 = x0 + 1
- # clip out coordinates exceeding feature map volume
- y0 = torch.clamp(y0, zero, max_y + 1)
- y1 = torch.clamp(y1, zero, max_y + 1)
- x0 = torch.clamp(x0, zero, max_x + 1)
- x1 = torch.clamp(x1, zero, max_x + 1)
- x0_float = x0.float()
- x1_float = x1.float()
- y0_float = y0.float()
- y1_float = y1.float()
- vol_a0 = ((y1_float - y) * (x1_float - x)).unsqueeze(-1).to(device)
- vol_c0 = ((y1_float - y) * (x - x0_float)).unsqueeze(-1).to(device)
- vol_a1 = ((y - y0_float) * (x1_float - x)).unsqueeze(-1).to(device)
- vol_c1 = ((y - y0_float) * (x - x0_float)).unsqueeze(-1).to(device)
- outputs = (value_a0 * vol_a0 + value_c0 * vol_c0 + value_a1 * vol_a1 +
- value_c1 * vol_c1)
- if self.morph == 0:
- outputs = outputs.reshape([
- self.num_batch,
- self.num_points * self.width,
- 1 * self.height,
- self.num_channels,
- ])
- outputs = outputs.permute(0, 3, 1, 2)
- else:
- outputs = outputs.reshape([
- self.num_batch,
- 1 * self.width,
- self.num_points * self.height,
- self.num_channels,
- ])
- outputs = outputs.permute(0, 3, 1, 2)
- return outputs
- def deform_conv(self, input, offset, if_offset):
- y, x = self._coordinate_map_3D(offset, if_offset)
- deformed_feature = self._bilinear_interpolate_3D(input, y, x)
- return deformed_feature
- class DWConv(Conv):
- """Depth-wise convolution."""
- def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation
- """Initialize Depth-wise convolution with given parameters."""
- super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
- class DSDConvHead(nn.Module):
- """YOLOv8 Detect head for detection models. CSDNSnu77"""
- dynamic = False # force grid reconstruction
- export = False # export mode
- end2end = False # end2end
- max_det = 300 # max_det
- shape = None
- anchors = torch.empty(0) # init
- strides = torch.empty(0) # init
- def __init__(self, nc=80, ch=()):
- """Initializes the YOLOv8 detection layer with specified number of classes and channels."""
- super().__init__()
- self.nc = nc # number of classes
- self.nl = len(ch) # number of detection layers
- self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
- self.no = nc + self.reg_max * 4 # number of outputs per anchor
- self.stride = torch.zeros(self.nl) # strides computed during build
- c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels
- # self.DySnakeConv = nn.ModuleList(nn.Sequential(DSConv(x, x, 0), DSConv(x, x, 0)) for x in ch) # DySnakeConv
- # morph沿y轴进行更符合绝大多数可能.
- self.cv2 = nn.ModuleList(
- nn.Sequential(DSConv(x, c2,1, 3), DSConv(c2, c2,1, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
- )
- # 仅使用一个DSConv辅助边界框回归.
- # self.cv2 = nn.ModuleList(
- # nn.Sequential(DSConv(x, c2,1, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
- # )
- self.cv3 = nn.ModuleList(
- nn.Sequential(
- nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
- nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
- nn.Conv2d(c3, self.nc, 1),
- )
- for x in ch
- )
- self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
- if self.end2end:
- self.one2one_cv2 = copy.deepcopy(self.cv2)
- self.one2one_cv3 = copy.deepcopy(self.cv3)
- def forward(self, x):
- """Concatenates and returns predicted bounding boxes and class probabilities."""
- if self.end2end:
- return self.forward_end2end(x)
- for i in range(self.nl):
- x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
- if self.training: # Training path
- return x
- y = self._inference(x)
- return y if self.export else (y, x)
- def forward_end2end(self, x):
- """
- Performs forward pass of the v10Detect module.
- Args:
- x (tensor): Input tensor.
- Returns:
- (dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and one2one detections.
- If in training mode, returns a dictionary containing the outputs of one2many and one2one detections separately.
- """
- x_detach = [xi.detach() for xi in x]
- one2one = [
- torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
- ]
- for i in range(self.nl):
- x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
- if self.training: # Training path
- return {"one2many": x, "one2one": one2one}
- y = self._inference(one2one)
- y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
- return y if self.export else (y, {"one2many": x, "one2one": one2one})
- def _inference(self, x):
- """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""
- # Inference path
- shape = x[0].shape # BCHW
- x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
- if self.dynamic or self.shape != shape:
- self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
- self.shape = shape
- if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops
- box = x_cat[:, : self.reg_max * 4]
- cls = x_cat[:, self.reg_max * 4 :]
- else:
- box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
- if self.export and self.format in {"tflite", "edgetpu"}:
- # Precompute normalization factor to increase numerical stability
- # See https://github.com/ultralytics/ultralytics/issues/7371
- grid_h = shape[2]
- grid_w = shape[3]
- grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
- norm = self.strides / (self.stride[0] * grid_size)
- dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
- else:
- dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
- return torch.cat((dbox, cls.sigmoid()), 1)
- def bias_init(self):
- """Initialize Detect() biases, WARNING: requires stride availability."""
- m = self # self.model[-1] # Detect() module
- # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
- # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
- for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
- a[-1].bias.data[:] = 1.0 # box
- b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
- if self.end2end:
- for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride): # from
- a[-1].bias.data[:] = 1.0 # box
- b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
- def decode_bboxes(self, bboxes, anchors):
- """Decode bounding boxes."""
- return dist2bbox(bboxes, anchors, xywh=not self.end2end, dim=1)
- @staticmethod
- def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
- """
- Post-processes YOLO model predictions.
- Args:
- preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
- format [x, y, w, h, class_probs].
- max_det (int): Maximum detections per image.
- nc (int, optional): Number of classes. Default: 80.
- Returns:
- (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
- dimension format [x, y, w, h, max_class_prob, class_index].
- """
- batch_size, anchors, _ = preds.shape # i.e. shape(16,8400,84)
- boxes, scores = preds.split([4, nc], dim=-1)
- index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
- boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
- scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
- scores, index = scores.flatten(1).topk(min(max_det, anchors))
- i = torch.arange(batch_size)[..., None] # batch indices
- return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)
- class DSDConvSegment(DSDConvHead):
- """YOLOv8 Segment head for segmentation models."""
- def __init__(self, nc=80, nm=32, npr=256, ch=()):
- """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers."""
- super().__init__(nc, ch)
- self.nm = nm # number of masks
- self.npr = npr # number of protos
- self.proto = Proto(ch[0], self.npr, self.nm) # protos
- c4 = max(ch[0] // 4, self.nm)
- self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
- def forward(self, x):
- """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
- p = self.proto(x[0]) # mask protos
- bs = p.shape[0] # batch size
- mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
- x = DSDConvHead.forward(self, x)
- if self.training:
- return x, mc, p
- return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
四、分割检测头修改教程
4.1 修改一
第一还是建立文件,我们找到如下 ultralytics /nn文件夹下建立一个目录名字呢就是'Addmodules'文件夹!然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可。
4.2 修改二
第二步我们在该目录下创建一个新的py文件名字为'__init__.py',然后在其内部导入我们的检测头如下图所示。
4.3 修改三
第三步我门中到如下文件'ultralytics/nn/tasks.py'进行导入和注册我们的模块!
从今天开始以后的教程就都统一成这个样子了,因为我默认大家用了我群内的文件来进行修改!!
4.4 修改四
按照我的进行添加即可,当然其中有些检测头你们的文件中可能没有,无需理会,主要看其周围的代码一直来寻找即可!
4.5 修改五
按照我下面的添加,分割的检测头此处添加两个请注意!
4.9 修改九
此处请注意原先是一个==号,然后现在变成in 然后需要额外注意的是此处的m系统会给转化成全小写,所以我们的名字也要变成全是小写的!!!
4.10 修改10
按照我的修改,此处为最后一步后面复制yaml文件运行即可!!!
五、 分割和目标检测的yaml文件
5.1 分割的yaml文件
训练信息:YOLO11-DSConvSegment summary: 374 layers, 2,755,359 parameters, 2,755,343 gradients, 9.5 GFLOPs
- # Ultralytics YOLO 🚀, AGPL-3.0 license
- # YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
- # Parameters
- nc: 80 # number of classes
- scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
- # [depth, width, max_channels]
- n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
- s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
- m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
- l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
- x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs
- # YOLO11n backbone
- backbone:
- # [from, repeats, module, args]
- - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- - [-1, 2, C3k2, [256, False, 0.25]]
- - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- - [-1, 2, C3k2, [512, False, 0.25]]
- - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- - [-1, 2, C3k2, [512, True]]
- - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- - [-1, 2, C3k2, [1024, True]]
- - [-1, 1, SPPF, [1024, 5]] # 9
- - [-1, 2, C2PSA, [1024]] # 10
- # YOLO11n head
- head:
- - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- - [[-1, 6], 1, Concat, [1]] # cat backbone P4
- - [-1, 2, C3k2, [512, False]] # 13
- - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- - [[-1, 4], 1, Concat, [1]] # cat backbone P3
- - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
- - [-1, 1, Conv, [256, 3, 2]]
- - [[-1, 13], 1, Concat, [1]] # cat head P4
- - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
- - [-1, 1, Conv, [512, 3, 2]]
- - [[-1, 10], 1, Concat, [1]] # cat head P5
- - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
- - [[16, 19, 22], 1, DSDConvSegment, [nc, 32, 256]] # Detect(P3, P4, P5)
分割的训练代码,分割的数据集标注比较特殊目标检测数据集会报错大家注意!
- import warnings
- warnings.filterwarnings('ignore')
- from ultralytics import YOLO
- if __name__ == '__main__':
- model = YOLO('yolo11-DSConvSegment.yaml') # 续训yaml文件的地方改为lats.pt的地址,需要注意的是如果你设置训练200轮次模型训练了200轮次是没有办法进行续训的.
- # 如何切换模型版本, 上面的ymal文件可以改为 yolov11s.yaml就是使用的v11s,
- # 类似某个改进的yaml文件名称为yolov11-XXX.yaml那么如果想使用其它版本就把上面的名称改为yolov11l-XXX.yaml即可(改的是上面YOLO中间的名字不是配置文件的)!
- # model.load('yolov11n.pt') # 是否加载预训练权重,科研不建议大家加载否则很难提升精度
- model.train(data=r"C:\Users\Administrator\Desktop\20240521\YOLOv8.2\SpotGEO2YOLO\data.yaml",
- # 如果大家任务是其它的'ultralytics/cfg/default.yaml'找到这里修改task可以改成detect, segment, classify, pose
- task='segment',
- cache=False,
- imgsz=640,
- epochs=100,
- single_cls=False, # 是否是单类别检测
- batch=4,
- close_mosaic=0,
- workers=0,
- device='0',
- optimizer='SGD', # using SGD 优化器 默认为auto建议大家使用固定的.
- # resume=, # 续训的话这里填写True
- amp=True, # 如果出现训练损失为Nan可以关闭amp
- project='runs/train',
- name='exp',
- )
5.2 目标检测的yaml文件
目标检测的我上面没有提供教程,之前的检测头提供过很多了,大家直接随便找一个就行就是名字不一样了。
- # Ultralytics YOLO 🚀, AGPL-3.0 license
- # YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
- # Parameters
- nc: 80 # number of classes
- scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
- # [depth, width, max_channels]
- n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
- s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
- m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
- l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
- x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs
- # YOLO11n backbone
- backbone:
- # [from, repeats, module, args]
- - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- - [-1, 2, C3k2, [256, False, 0.25]]
- - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- - [-1, 2, C3k2, [512, False, 0.25]]
- - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- - [-1, 2, C3k2, [512, True]]
- - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- - [-1, 2, C3k2, [1024, True]]
- - [-1, 1, SPPF, [1024, 5]] # 9
- - [-1, 2, C2PSA, [1024]] # 10
- # YOLO11n head
- head:
- - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- - [[-1, 6], 1, Concat, [1]] # cat backbone P4
- - [-1, 2, C3k2, [512, False]] # 13
- - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- - [[-1, 4], 1, Concat, [1]] # cat backbone P3
- - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
- - [-1, 1, Conv, [256, 3, 2]]
- - [[-1, 13], 1, Concat, [1]] # cat head P4
- - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
- - [-1, 1, Conv, [512, 3, 2]]
- - [[-1, 10], 1, Concat, [1]] # cat head P5
- - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
- - [[16, 19, 22], 1, DSDConvHead, [nc]] # Detect(P3, P4, P5)
六、本文总结
到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv11改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~