
YOLOv11 Improvements - Convolution Series - KANConv under the 2024 Kolmogorov-Arnold Network Architecture (Nine KANConv2d Variants with Different Activation Functions)

I. Introduction

The improvement introduced in this post is the 2024 Convolutional Kolmogorov-Arnold Network (Convolutional KANs), an architecture that integrates the nonlinear, learnable activation functions of Kolmogorov-Arnold Networks (KANs) into the convolutional layer, replacing the linear transformation of a traditional convolutional neural network (CNN). Compared with a standard CNN, a KANConv layer introduces more parameters, because every kernel element carries its own learnable function; this lets it capture spatial relationships in the data more flexibly. In experiments, KANConv layers often reach higher accuracy than traditional convolutions on tasks such as image recognition, especially when the architecture is carefully tuned. I have also put together as many as nine KANConv2d variants with different types of activation functions, so there should be one that suits your dataset.



II. How It Works

Official paper: click here to jump to it

Official code: click here to jump to it


The paper proposes Convolutional Kolmogorov-Arnold Networks (Convolutional KANs), a new architecture that combines the nonlinear activation functions of Kolmogorov-Arnold Networks (KANs) with convolutional neural networks (CNNs). By replacing the fixed activations and linear convolution operations of a traditional CNN with learnable spline functions, Convolutional KANs, according to the paper, substantially reduce the parameter count while keeping model accuracy. This offers a new route to optimizing neural-network architectures and improving parameter efficiency.

1. Background and Motivation

  • Progress in deep learning, especially in computer vision, has relied on increasingly complex network architectures. Traditional CNNs process images with linear convolution operations, which limits them when the data has complex nonlinear structure.
  • Kolmogorov-Arnold Networks (KANs) are a new architecture that uses the Kolmogorov-Arnold theorem to represent a multivariate function as a composition of univariate functions, using learnable splines to increase the model's expressive power (see the formula below).
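
For reference, the Kolmogorov-Arnold representation theorem behind this idea can be written as follows (standard textbook form, not the paper's exact notation): any continuous function of n variables admits the decomposition

    f(x_1, \ldots, x_n) = \sum_{q=1}^{2n+1} \Phi_q\left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right)

where every \Phi_q and \phi_{q,p} is a univariate function. KANs make these univariate functions learnable (spline-parameterized) and stack several such layers.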

2. How KANs Are Applied

  • KANs replace the linear weight matrices of a conventional network with learnable spline functions, letting the network keep or improve accuracy with fewer parameters.
  • Because the splines are learnable, KANs are more flexible and expressive on complex data (a common parameterization is sketched below).
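
Concretely, the per-edge learnable function used by the KANLinear implementation later in this post (and, up to notation, in the original KAN paper) combines a fixed base activation with a learnable B-spline expansion:

    \phi(x) = w_{\text{base}} \, \mathrm{SiLU}(x) + \sum_i c_i \, B_i(x)

so each scalar weight of an ordinary linear layer becomes one base weight plus a small set of trainable spline coefficients c_i.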

3. Implementing Convolutional KANs

  • Convolutional KANs combine the strengths of KANs with the convolutional architecture by applying KAN spline functions inside the convolution operation.
  • The kernel of a Convolutional KAN consists of learnable nonlinear functions (for example B-splines) rather than the fixed weights of a traditional CNN. These functions adjust their shape during training to fit the nonlinear structure of the data (a minimal sketch of this pattern follows the list).
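
To make this concrete: the core code in Section III does not write a new convolution kernel. Each KANConv2d unfolds the input into k*k patches and pushes every patch through a KAN layer, so the per-patch mapping becomes a sum of learnable univariate functions instead of a dot product. Below is a minimal sketch of that pattern with an ordinary nn.Linear standing in for the KAN layer; swapping the linear layer for KANLinear (or any of the other layers in Section III) gives exactly the KANConv2d modules used later.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class UnfoldConv2d(nn.Module):
        """Convolution written as unfold + a per-patch mapping.

        Replacing `self.mapping` (an ordinary nn.Linear here) with a KAN layer
        such as KANLinear turns this into a KANConv2d; this is the pattern that
        all nine modules in Section III follow.
        """
        def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1):
            super().__init__()
            self.kernel_size, self.stride, self.padding = kernel_size, stride, padding
            self.out_channels = out_channels
            self.mapping = nn.Linear(in_channels * kernel_size * kernel_size, out_channels)

        def forward(self, x):
            b, c, h, w = x.shape
            # (b, c*k*k, L): every column is one flattened k*k patch
            patches = F.unfold(x, self.kernel_size, stride=self.stride, padding=self.padding)
            patches = patches.transpose(1, 2)                            # (b, L, c*k*k)
            out = self.mapping(patches.reshape(-1, patches.size(-1)))    # (b*L, out_channels)
            out_h = (h + 2 * self.padding - self.kernel_size) // self.stride + 1
            out_w = (w + 2 * self.padding - self.kernel_size) // self.stride + 1
            return out.reshape(b, -1, self.out_channels).transpose(1, 2).reshape(
                b, self.out_channels, out_h, out_w)

    x = torch.rand(1, 16, 32, 32)
    print(UnfoldConv2d(16, 32, 3)(x).shape)                              # torch.Size([1, 32, 32, 32])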

4. Parameter Efficiency

  • The paper reports that Convolutional KANs significantly reduce the number of parameters (in my own experiments the count almost always increases, which does not match the paper's description). For example, on the MNIST and Fashion-MNIST datasets (seeing these two datasets already made me somewhat skeptical of the results and of how effective the structure really is), Convolutional KANs use about half the parameters of a traditional CNN at essentially the same accuracy.
  • Fewer parameters not only help training speed and memory efficiency but may also improve generalization (a rough count for the YOLO setting is worked out below).
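
A quick back-of-the-envelope count explains why the parameter count grows rather than shrinks once these layers are dropped into YOLO (my own calculation, assuming the KANLinear defaults grid_size=5 and spline_order=3 used in the code below): a 3x3 convolution with 64 input and 64 output channels has 36,864 weights, while the corresponding KANConv2d holds a base weight of the same size, a spline weight with 8 coefficients per connection, and a spline scaler, roughly ten times as many parameters.

    # back-of-the-envelope parameter count (assumes KANLinear defaults grid_size=5, spline_order=3)
    cin, cout, k = 64, 64, 3
    grid_size, spline_order = 5, 3
    conv_params = cout * cin * k * k                                  # plain Conv2d, no bias: 36,864
    in_feat = cin * k * k                                             # 576 inputs to the per-patch KAN layer
    kan_params = (cout * in_feat                                      # base_weight
                  + cout * in_feat * (grid_size + spline_order)       # spline_weight
                  + cout * in_feat)                                   # spline_scaler
    print(conv_params, kan_params, kan_params / conv_params)          # 36864 368640 10.0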

5. Experimental Results

  • In experiments on MNIST and Fashion-MNIST, Convolutional KANs reach accuracy similar to a traditional CNN with a markedly smaller parameter count.
  • With fewer parameters, Convolutional KANs even outperform the traditional convolutional models in some cases, which the authors take as evidence of their advantage in handling nonlinear relationships.

Summary: Convolutional Kolmogorov-Arnold Networks (Convolutional KANs) are proposed as an innovative alternative to the standard convolutional neural networks (CNNs) that have transformed computer vision. Integrating the nonlinear activation functions introduced by Kolmogorov-Arnold Networks (KANs) into the convolution operation yields a new network layer. Throughout the paper the authors validate Convolutional KANs empirically on the MNIST and Fashion-MNIST benchmarks, reporting that the new approach keeps a similar accuracy while halving the parameter count. This large reduction in parameters opens a new avenue for optimizing neural-network architectures.


III. Core Code

The core code runs to about 1,200 lines and contains nine versions of KANConv2d. Two of them come with caveats: 'WavKANConv2d' needs a large amount of GPU memory (it will not run on my 8 GB card), and 'KANConv2d' requires torch 1.9 or later.
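
If you are not sure whether your environment can run 'KANConv2d', a quick check is enough; torch.linalg.lstsq, which it relies on, was added in torch 1.9:

    import torch
    print(torch.__version__)                      # KANConv2d needs >= 1.9
    print(hasattr(torch.linalg, "lstsq"))         # True means the required op is available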

  1. from functools import lru_cache
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from einops import einsum
  6. import numpy as np
  7. import math
  8. from torch.autograd import Function
  9. __all__ = ['RBFKANConv2d', 'ReLUKANConv2d', 'KANConv2d', 'FasterKANConv2d', 'WavKANConv2d', 'ChebyKANConv2d', 'JacobiKANConv2d', 'FastKANConv2d', 'GRAMKANConv2d']
  10. # ################################################ Utilities ###########################################################
  11. class RadialBasisFunction(nn.Module):
  12. def __init__(
  13. self,
  14. grid_min: float = -2.,
  15. grid_max: float = 2.,
  16. num_grids: int = 8,
  17. denominator: float = None, # larger denominators lead to smoother basis
  18. ):
  19. super().__init__()
  20. grid = torch.linspace(grid_min, grid_max, num_grids)
  21. self.grid = torch.nn.Parameter(grid, requires_grad=False)
  22. self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1)
  23. def forward(self, x):
  24. return torch.exp(-((x[..., None] - self.grid) / self.denominator) ** 2)
  25. class RSWAFFunction(Function):
  26. @staticmethod
  27. def forward(ctx, input, grid, inv_denominator, train_grid, train_inv_denominator):
  28. # Compute the forward pass
  29. # print('\n')
  30. # print(f"Forward pass - grid: {(grid[0].item(),grid[-1].item())}, inv_denominator: {inv_denominator.item()}")
  31. # print(f"grid.shape: {grid.shape }")
  32. # print(f"grid: {(grid[0],grid[-1]) }")
  33. # print(f"inv_denominator.shape: {inv_denominator.shape }")
  34. # print(f"inv_denominator: {inv_denominator }")
  35. diff = (input[..., None] - grid)
  36. diff_mul = diff.mul(inv_denominator)
  37. tanh_diff = torch.tanh(diff)
  38. tanh_diff_deriviative = -tanh_diff.mul(tanh_diff) + 1 # sech^2(x) = 1 - tanh^2(x)
  39. # Save tensors for backward pass
  40. ctx.save_for_backward(input, tanh_diff, tanh_diff_deriviative, diff, inv_denominator)
  41. ctx.train_grid = train_grid
  42. ctx.train_inv_denominator = train_inv_denominator
  43. return tanh_diff_deriviative
  44. @staticmethod
  45. def backward(ctx, grad_output):
  46. # Retrieve saved tensors
  47. input, tanh_diff, tanh_diff_deriviative, diff, inv_denominator = ctx.saved_tensors
  48. grad_grid = None
  49. grad_inv_denominator = None
  50. # print(f"tanh_diff_deriviative shape: {tanh_diff_deriviative.shape }")
  51. # print(f"tanh_diff shape: {tanh_diff.shape }")
  52. # print(f"grad_output shape: {grad_output.shape }")
  53. # Compute the backward pass for the input
  54. grad_input = -2 * tanh_diff * tanh_diff_deriviative * grad_output
  55. # print(f"Backward pass 1 - grad_input: {(grad_input.min().item(), grad_input.max().item())}")
  56. # print(f"grad_input shape: {grad_input.shape }")
  57. # print(f"grad_input.sum(dim=-1): {grad_input.sum(dim=-1).shape}")
  58. grad_input = grad_input.sum(dim=-1).mul(inv_denominator)
  59. # print(f"Backward pass 2 - grad_input: {(grad_input.min().item(), grad_input.max().item())}")
  60. # print(f"grad_input: {grad_input}")
  61. # print(f"grad_input shape: {grad_input.shape }")
  62. # Compute the backward pass for grid
  63. if ctx.train_grid:
  64. # print('\n')
  65. # print(f"grad_grid shape: {grad_grid.shape }")
  66. grad_grid = -inv_denominator * grad_output.sum(dim=0).sum(
  67. dim=0) # -(inv_denominator * grad_output * tanh_diff_deriviative).sum(dim=0) #-inv_denominator * grad_output.sum(dim=0).sum(dim=0)
  68. # print(f"Backward pass - grad_grid: {(grad_grid[0].item(),grad_grid[-1].item())}")
  69. # print(f"grad_grid.shape: {grad_grid.shape }")
  70. # print(f"grad_grid: {(grad_grid[0],grad_grid[-1]) }")
  71. # print(f"inv_denominator shape: {inv_denominator.shape }")
  72. # print(f"grad_grid shape: {grad_grid.shape }")
  73. # Compute the backward pass for inv_denominator
  74. if ctx.train_inv_denominator:
  75. grad_inv_denominator = (
  76. grad_output * diff).sum() # (grad_output * diff * tanh_diff_deriviative).sum() #(grad_output* diff).sum()
  77. # print(f"Backward pass - grad_inv_denominator: {grad_inv_denominator.item()}")
  78. # print(f"diff shape: {diff.shape }")
  79. # print(f"grad_inv_denominator shape: {grad_inv_denominator.shape }")
  80. # print(f"grad_inv_denominator : {grad_inv_denominator }")
  81. return grad_input, grad_grid, grad_inv_denominator, None, None # same number as tensors or parameters
  82. class ReflectionalSwitchFunction(nn.Module):
  83. def __init__(
  84. self,
  85. grid_min: float = -1.2,
  86. grid_max: float = 0.2,
  87. num_grids: int = 8,
  88. exponent: int = 2,
  89. inv_denominator: float = 0.5,
  90. train_grid: bool = False,
  91. train_inv_denominator: bool = False,
  92. ):
  93. super().__init__()
  94. grid = torch.linspace(grid_min, grid_max, num_grids)
  95. self.train_grid = torch.tensor(train_grid, dtype=torch.bool)
  96. self.train_inv_denominator = torch.tensor(train_inv_denominator, dtype=torch.bool)
  97. self.grid = torch.nn.Parameter(grid, requires_grad=train_grid)
  98. # print(f"grid initial shape: {self.grid.shape }")
  99. self.inv_denominator = torch.nn.Parameter(torch.tensor(inv_denominator, dtype=torch.float32),
  100. requires_grad=train_inv_denominator) # Cache the inverse of the denominator
  101. def forward(self, x):
  102. return RSWAFFunction.apply(x, self.grid, self.inv_denominator, self.train_grid, self.train_inv_denominator)
  103. # #################################### KAN activation layers ##############################################################
  104. class KANLayer(nn.Module):
  105. def __init__(self, input_features, output_features, grid_size=5, spline_order=3, base_activation=nn.GELU,
  106. grid_range=[-1, 1]):
  107. super(KANLayer, self).__init__()
  108. self.input_features = input_features
  109. self.output_features = output_features
  110. # The number of points in the grid for the spline interpolation.
  111. self.grid_size = grid_size
  112. # The order of the spline used in the interpolation.
  113. self.spline_order = spline_order
  114. # Activation function used for the initial transformation of the input.
  115. self.base_activation = base_activation()
  116. # The range of values over which the grid for spline interpolation is defined.
  117. self.grid_range = grid_range
  118. # Initialize the base weights with random values for the linear transformation.
  119. self.base_weight = nn.Parameter(torch.randn(output_features, input_features))
  120. # Initialize the spline weights with random values for the spline transformation.
  121. self.spline_weight = nn.Parameter(torch.randn(output_features, input_features, grid_size + spline_order))
  122. # Add a layer normalization for stabilizing the output of this layer.
  123. self.layer_norm = nn.LayerNorm(output_features)
  124. # Add a PReLU activation for this layer to provide a learnable non-linearity.
  125. self.prelu = nn.PReLU()
  126. # Compute the grid values based on the specified range and grid size.
  127. h = (self.grid_range[1] - self.grid_range[0]) / grid_size
  128. self.grid = torch.linspace(
  129. self.grid_range[0] - h * spline_order,
  130. self.grid_range[1] + h * spline_order,
  131. grid_size + 2 * spline_order + 1,
  132. dtype=torch.float32
  133. ).expand(input_features, -1).contiguous()
  134. # Initialize the weights using Kaiming uniform distribution for better initial values.
  135. nn.init.kaiming_uniform_(self.base_weight, nonlinearity='linear')
  136. nn.init.kaiming_uniform_(self.spline_weight, nonlinearity='linear')
  137. def forward(self, x):
  138. # Process each layer using the defined base weights, spline weights, norms, and activations.
  139. grid = self.grid.to(x.device)
  140. # Move the input tensor to the device where the weights are located.
  141. # Perform the base linear transformation followed by the activation function.
  142. base_output = F.linear(self.base_activation(x), self.base_weight)
  143. x_uns = x.unsqueeze(-1) # Expand dimensions for spline operations.
  144. # Compute the basis for the spline using intervals and input values.
  145. bases = ((x_uns >= grid[:, :-1]) & (x_uns < grid[:, 1:])).to(x.dtype).to(x.device)
  146. # Compute the spline basis over multiple orders.
  147. for k in range(1, self.spline_order + 1):
  148. left_intervals = grid[:, :-(k + 1)]
  149. right_intervals = grid[:, k:-1]
  150. delta = torch.where(right_intervals == left_intervals, torch.ones_like(right_intervals),
  151. right_intervals - left_intervals)
  152. bases = ((x_uns - left_intervals) / delta * bases[:, :, :-1]) + \
  153. ((grid[:, k + 1:] - x_uns) / (grid[:, k + 1:] - grid[:, 1:(-k)]) * bases[:, :, 1:])
  154. bases = bases.contiguous()
  155. # Compute the spline transformation and combine it with the base transformation.
  156. spline_output = F.linear(bases.view(x.size(0), -1), self.spline_weight.view(self.spline_weight.size(0), -1))
  157. # Apply layer normalization and PReLU activation to the combined output.
  158. x = self.prelu(self.layer_norm(base_output + spline_output))
  159. return x
  160. class SplineLinear(nn.Linear):
  161. def __init__(self, in_features: int, out_features: int, init_scale: float = 0.1, **kw) -> None:
  162. self.init_scale = init_scale
  163. super().__init__(in_features, out_features, bias=False, **kw)
  164. def reset_parameters(self) -> None:
  165. nn.init.trunc_normal_(self.weight, mean=0, std=self.init_scale)
  166. class FastKANLayer(nn.Module):
  167. def __init__(
  168. self,
  169. input_dim: int,
  170. output_dim: int,
  171. grid_min: float = -2.,
  172. grid_max: float = 2.,
  173. num_grids: int = 8,
  174. use_base_update: bool = True,
  175. base_activation=nn.SiLU,
  176. spline_weight_init_scale: float = 0.1,
  177. ) -> None:
  178. super().__init__()
  179. self.layernorm = nn.LayerNorm(input_dim)
  180. self.rbf = RadialBasisFunction(grid_min, grid_max, num_grids)
  181. self.spline_linear = SplineLinear(input_dim * num_grids, output_dim, spline_weight_init_scale)
  182. self.use_base_update = use_base_update
  183. if use_base_update:
  184. self.base_activation = base_activation()
  185. self.base_linear = nn.Linear(input_dim, output_dim)
  186. def forward(self, x, time_benchmark=False):
  187. if not time_benchmark:
  188. spline_basis = self.rbf(self.layernorm(x))
  189. else:
  190. spline_basis = self.rbf(x)
  191. ret = self.spline_linear(spline_basis.view(*spline_basis.shape[:-2], -1))
  192. if self.use_base_update:
  193. base = self.base_linear(self.base_activation(x))
  194. ret = ret + base
  195. return ret
  196. # This is inspired by Kolmogorov-Arnold Networks but using Chebyshev polynomials instead of splines coefficients
  197. class ChebyKANLayer(nn.Module):
  198. def __init__(self, input_dim, output_dim, degree):
  199. super(ChebyKANLayer, self).__init__()
  200. self.inputdim = input_dim
  201. self.outdim = output_dim
  202. self.degree = degree
  203. self.cheby_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1))
  204. nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1)))
  205. self.register_buffer("arange", torch.arange(0, degree + 1, 1))
  206. def forward(self, x):
  207. # Since Chebyshev polynomial is defined in [-1, 1]
  208. # We need to normalize x to [-1, 1] using tanh
  209. # x = torch.tanh(x)
  210. x = torch.clamp(x, -1.0, 1.0)
  211. # View and repeat input degree + 1 times
  212. x = x.view((-1, self.inputdim, 1)).expand(
  213. -1, -1, self.degree + 1
  214. ) # shape = (batch_size, inputdim, self.degree + 1)
  215. # Apply acos
  216. x = x.acos()
  217. # Multiply by arange [0 .. degree]
  218. x *= self.arange
  219. # Apply cos
  220. x = x.cos()
  221. # Compute the Chebyshev interpolation
  222. y = torch.einsum(
  223. "bid,iod->bo", x, self.cheby_coeffs
  224. ) # shape = (batch_size, outdim)
  225. y = y.view(-1, self.outdim)
  226. return y
  227. class GRAMLayer(nn.Module):
  228. def __init__(self, in_channels, out_channels, degree=3, act=nn.SiLU):
  229. super(GRAMLayer, self).__init__()
  230. self.in_channels = in_channels
  231. self.out_channels = out_channels
  232. self.degrees = degree
  233. self.act = act()
  234. self.norm = nn.LayerNorm(out_channels).to(dtype=torch.float32)
  235. self.beta_weights = nn.Parameter(torch.zeros(degree + 1, dtype=torch.float32))
  236. self.grams_basis_weights = nn.Parameter(
  237. torch.zeros(in_channels, out_channels, degree + 1, dtype=torch.float32)
  238. )
  239. self.base_weights = nn.Parameter(
  240. torch.zeros(out_channels, in_channels, dtype=torch.float32)
  241. )
  242. self.init_weights()
  243. def init_weights(self):
  244. nn.init.normal_(
  245. self.beta_weights,
  246. mean=0.0,
  247. std=1.0 / (self.in_channels * (self.degrees + 1.0)),
  248. )
  249. nn.init.xavier_uniform_(self.grams_basis_weights)
  250. nn.init.xavier_uniform_(self.base_weights)
  251. def beta(self, n, m):
  252. return (
  253. ((m + n) * (m - n) * n ** 2) / (m ** 2 / (4.0 * n ** 2 - 1.0))
  254. ) * self.beta_weights[n]
  255. @lru_cache(maxsize=128)
  256. def gram_poly(self, x, degree):
  257. p0 = x.new_ones(x.size())
  258. if degree == 0:
  259. return p0.unsqueeze(-1)
  260. p1 = x
  261. grams_basis = [p0, p1]
  262. for i in range(2, degree + 1):
  263. p2 = x * p1 - self.beta(i - 1, i) * p0
  264. grams_basis.append(p2)
  265. p0, p1 = p1, p2
  266. return torch.stack(grams_basis, dim=-1)
  267. def forward(self, x):
  268. basis = F.linear(self.act(x), self.base_weights)
  269. x = torch.tanh(x).contiguous()
  270. grams_basis = self.act(self.gram_poly(x, self.degrees))
  271. y = einsum(
  272. grams_basis,
  273. self.grams_basis_weights,
  274. "b l d, l o d -> b o",
  275. )
  276. y = self.act(self.norm(y + basis))
  277. y = y.view(-1, self.out_channels)
  278. return y
  279. class WavKANLayer(nn.Module):
  280. def __init__(self, in_features, out_features, wavelet_type='mexican_hat'):
  281. super(WavKANLayer, self).__init__()
  282. self.in_features = in_features
  283. self.out_features = out_features
  284. self.wavelet_type = wavelet_type
  285. # Parameters for wavelet transformation
  286. self.scale = nn.Parameter(torch.ones(out_features, in_features))
  287. self.translation = nn.Parameter(torch.zeros(out_features, in_features))
  288. # Linear weights for combining outputs
  289. # self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
  290. self.weight1 = nn.Parameter(torch.Tensor(out_features,
  291. in_features)) # not used; you may want to use it for weighting the base activation and adding it, as in the Spl-KAN paper
  292. self.wavelet_weights = nn.Parameter(torch.Tensor(out_features, in_features))
  293. nn.init.kaiming_uniform_(self.wavelet_weights, a=math.sqrt(5))
  294. nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
  295. # Base activation function #not used for this experiment
  296. self.base_activation = nn.SiLU()
  297. # Batch normalization
  298. self.bn = nn.BatchNorm1d(out_features)
  299. def wavelet_transform(self, x):
  300. if x.dim() == 2:
  301. x_expanded = x.unsqueeze(1)
  302. else:
  303. x_expanded = x
  304. translation_expanded = self.translation.unsqueeze(0).expand(x.size(0), -1, -1)
  305. scale_expanded = self.scale.unsqueeze(0).expand(x.size(0), -1, -1)
  306. x_scaled = (x_expanded - translation_expanded) / scale_expanded
  307. # Implementation of different wavelet types
  308. if self.wavelet_type == 'mexican_hat':
  309. term1 = ((x_scaled ** 2) - 1)
  310. term2 = torch.exp(-0.5 * x_scaled ** 2)
  311. wavelet = (2 / (math.sqrt(3) * math.pi ** 0.25)) * term1 * term2
  312. elif self.wavelet_type == 'morlet':
  313. omega0 = 5.0 # Central frequency
  314. real = torch.cos(omega0 * x_scaled)
  315. envelope = torch.exp(-0.5 * x_scaled ** 2)
  316. wavelet = envelope * real
  317. elif self.wavelet_type == 'dog':
  318. # Implementing Derivative of Gaussian Wavelet
  319. wavelet = -x_scaled * torch.exp(-0.5 * x_scaled ** 2)
  320. elif self.wavelet_type == 'meyer':
  321. # Implement Meyer Wavelet here
  322. # Constants for the Meyer wavelet transition boundaries
  323. v = torch.abs(x_scaled)
  324. pi = math.pi
  325. def meyer_aux(v):
  326. return torch.where(v <= 1 / 2, torch.ones_like(v),
  327. torch.where(v >= 1, torch.zeros_like(v), torch.cos(pi / 2 * nu(2 * v - 1))))
  328. def nu(t):
  329. return t ** 4 * (35 - 84 * t + 70 * t ** 2 - 20 * t ** 3)
  330. # Meyer wavelet calculation using the auxiliary function
  331. wavelet = torch.sin(pi * v) * meyer_aux(v)
  332. elif self.wavelet_type == 'shannon':
  333. # Windowing the sinc function to limit its support
  334. pi = math.pi
  335. sinc = torch.sinc(x_scaled / pi) # sinc(x) = sin(pi*x) / (pi*x)
  336. # Applying a Hamming window to limit the infinite support of the sinc function
  337. window = torch.hamming_window(x_scaled.size(-1), periodic=False, dtype=x_scaled.dtype,
  338. device=x_scaled.device)
  339. # Shannon wavelet is the product of the sinc function and the window
  340. wavelet = sinc * window
  341. # You can try many more wavelet types ...
  342. else:
  343. raise ValueError("Unsupported wavelet type")
  344. wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)
  345. wavelet_output = wavelet_weighted.sum(dim=2)
  346. return wavelet_output
  347. def forward(self, x):
  348. wavelet_output = self.wavelet_transform(x)
  349. # You may want to test variants like Spl-KAN
  350. # wav_output = F.linear(wavelet_output, self.weight)
  351. base_output = F.linear(self.base_activation(x), self.weight1)
  352. # base_output = F.linear(x, self.weight1)
  353. combined_output = wavelet_output + base_output
  354. # Apply batch normalization
  355. return self.bn(combined_output)
  356. class JacobiKANLayer(nn.Module):
  357. def __init__(self, input_dim, output_dim, degree, a=1.0, b=1.0, act=nn.SiLU):
  358. super(JacobiKANLayer, self).__init__()
  359. self.inputdim = input_dim
  360. self.outdim = output_dim
  361. self.a = a
  362. self.b = b
  363. self.degree = degree
  364. self.act = act()
  365. self.norm = nn.LayerNorm(output_dim,).to(dtype=torch.float32)
  366. self.base_weights = nn.Parameter(
  367. torch.zeros(output_dim, input_dim, dtype=torch.float32)
  368. )
  369. self.jacobi_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1))
  370. nn.init.normal_(self.jacobi_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1)))
  371. nn.init.xavier_uniform_(self.base_weights)
  372. def forward(self, x):
  373. x = torch.reshape(x, (-1, self.inputdim)) # shape = (batch_size, inputdim)
  374. basis = F.linear(self.act(x), self.base_weights)
  375. # Since Jacobian polynomial is defined in [-1, 1]
  376. # We need to normalize x to [-1, 1] using tanh
  377. x = torch.tanh(x)
  378. # Initialize Jacobian polynomial tensors
  379. jacobi = torch.ones(x.shape[0], self.inputdim, self.degree + 1, device=x.device)
  380. if self.degree > 0: ## degree = 0: jacobi[:, :, 0] = 1 (already initialized) ; degree = 1: jacobi[:, :, 1] = x ; d
  381. jacobi[:, :, 1] = ((self.a - self.b) + (self.a + self.b + 2) * x) / 2
  382. for i in range(2, self.degree + 1):
  383. theta_k = (2 * i + self.a + self.b) * (2 * i + self.a + self.b - 1) / (2 * i * (i + self.a + self.b))
  384. theta_k1 = (2 * i + self.a + self.b - 1) * (self.a * self.a - self.b * self.b) / (
  385. 2 * i * (i + self.a + self.b) * (2 * i + self.a + self.b - 2))
  386. theta_k2 = (i + self.a - 1) * (i + self.b - 1) * (2 * i + self.a + self.b) / (
  387. i * (i + self.a + self.b) * (2 * i + self.a + self.b - 2))
  388. jacobi[:, :, i] = (theta_k * x + theta_k1) * jacobi[:, :, i - 1].clone() - theta_k2 * jacobi[:, :,
  389. i - 2].clone()
  390. # Compute the Jacobian interpolation
  391. y = torch.einsum('bid,iod->bo', jacobi, self.jacobi_coeffs) # shape = (batch_size, outdim)
  392. y = y.view(-1, self.outdim)
  393. y = self.act(self.norm(y + basis))
  394. return y
  395. class ReLUKANLayer(nn.Module):
  396. def __init__(self, input_size: int, g: int, k: int, output_size: int, train_ab: bool = True):
  397. super().__init__()
  398. self.g, self.k, self.r = g, k, 4 * g * g / ((k + 1) * (k + 1))
  399. self.input_size, self.output_size = input_size, output_size
  400. phase_low = np.arange(-k, g) / g
  401. phase_height = phase_low + (k + 1) / g
  402. self.phase_low = nn.Parameter(torch.Tensor(np.array([phase_low for i in range(input_size)])),
  403. requires_grad=train_ab)
  404. self.phase_height = nn.Parameter(torch.Tensor(np.array([phase_height for i in range(input_size)])),
  405. requires_grad=train_ab)
  406. self.equal_size_conv = nn.Conv2d(1, output_size, (g + k, input_size))
  407. def forward(self, x):
  408. # Expand dimensions of x to match the shape of self.phase_low
  409. x_expanded = x.unsqueeze(2).expand(-1, -1, self.phase_low.size(1))
  410. # Perform the subtraction with broadcasting
  411. x1 = torch.relu(x_expanded - self.phase_low)
  412. x2 = torch.relu(self.phase_height - x_expanded)
  413. # Continue with the rest of the operations
  414. x = x1 * x2 * self.r
  415. x = x * x
  416. x = x.reshape((len(x), 1, self.g + self.k, self.input_size))
  417. x = self.equal_size_conv(x)
  418. # x = x.reshape((len(x), self.output_size, 1))
  419. x = x.reshape((len(x), self.output_size))
  420. return x
  421. class SplineLinear_fstr(nn.Linear):
  422. def __init__(self, in_features: int, out_features: int, init_scale: float = 0.1, **kw) -> None:
  423. self.init_scale = init_scale
  424. super().__init__(in_features, out_features, bias=False, **kw)
  425. def reset_parameters(self) -> None:
  426. nn.init.xavier_uniform_(self.weight) # Using Xavier Uniform initialization
  427. class FasterKANLayer(nn.Module):
  428. def __init__(
  429. self,
  430. input_dim: int,
  431. output_dim: int,
  432. grid_min: float = -1.2,
  433. grid_max: float = 0.2,
  434. num_grids: int = 8,
  435. exponent: int = 2,
  436. inv_denominator: float = 0.5,
  437. train_grid: bool = False,
  438. train_inv_denominator: bool = False,
  439. # use_base_update: bool = True,
  440. base_activation=F.silu,
  441. spline_weight_init_scale: float = 0.667,
  442. ) -> None:
  443. super().__init__()
  444. self.layernorm = nn.LayerNorm(input_dim)
  445. self.rbf = ReflectionalSwitchFunction(grid_min, grid_max, num_grids, exponent, inv_denominator, train_grid,
  446. train_inv_denominator)
  447. self.spline_linear = SplineLinear_fstr(input_dim * num_grids, output_dim, spline_weight_init_scale)
  448. # self.use_base_update = use_base_update
  449. # if use_base_update:
  450. # self.base_activation = base_activation
  451. # self.base_linear = nn.Linear(input_dim, output_dim)
  452. def forward(self, x):
  453. # print("Shape before LayerNorm:", x.shape) # Debugging line to check the input shape
  454. x = self.layernorm(x)
  455. # print("Shape After LayerNorm:", x.shape)
  456. spline_basis = self.rbf(x).view(x.shape[0], -1)
  457. # print("spline_basis:", spline_basis.shape)
  458. # print("-------------------------")
  459. # ret = 0
  460. ret = self.spline_linear(spline_basis)
  461. # print("spline_basis.shape[:-2]:", spline_basis.shape[:-2])
  462. # print("*spline_basis.shape[:-2]:", *spline_basis.shape[:-2])
  463. # print("spline_basis.view(*spline_basis.shape[:-2], -1):", spline_basis.view(*spline_basis.shape[:-2], -1).shape)
  464. # print("ret:", ret.shape)
  465. # print("-------------------------")
  466. # if self.use_base_update:
  467. # base = self.base_linear(self.base_activation(x))
  468. # print("self.base_activation(x):", self.base_activation(x).shape)
  469. # print("base:", base.shape)
  470. # print("@@@@@@@@@")
  471. # ret += base
  472. return ret
  473. class RBFLinear(nn.Module):
  474. def __init__(self, in_features, out_features, grid_min=-2., grid_max=2., num_grids=8, spline_weight_init_scale=0.1):
  475. super().__init__()
  476. self.grid_min = grid_min
  477. self.grid_max = grid_max
  478. self.num_grids = num_grids
  479. self.grid = nn.Parameter(torch.linspace(grid_min, grid_max, num_grids), requires_grad=False)
  480. self.spline_weight = nn.Parameter(torch.randn(in_features * num_grids, out_features) * spline_weight_init_scale)
  481. def forward(self, x):
  482. x = x.unsqueeze(-1)
  483. basis = torch.exp(-((x - self.grid) / ((self.grid_max - self.grid_min) / (self.num_grids - 1))) ** 2)
  484. return basis.reshape(basis.size(0), -1).matmul(self.spline_weight)
  485. class RBFKANLayer(nn.Module):
  486. def __init__(self, input_dim, output_dim, grid_min=-2., grid_max=2., num_grids=8, use_base_update=True,
  487. base_activation=nn.SiLU(), spline_weight_init_scale=0.1):
  488. super().__init__()
  489. self.input_dim = input_dim
  490. self.output_dim = output_dim
  491. self.use_base_update = use_base_update
  492. self.base_activation = base_activation
  493. self.spline_weight_init_scale = spline_weight_init_scale
  494. self.rbf_linear = RBFLinear(input_dim, output_dim, grid_min, grid_max, num_grids, spline_weight_init_scale)
  495. self.base_linear = nn.Linear(input_dim, output_dim) if use_base_update else None
  496. def forward(self, x):
  497. ret = self.rbf_linear(x)
  498. if self.use_base_update:
  499. base = self.base_linear(self.base_activation(x))
  500. ret = ret + base
  501. return ret
  502. class KANLinear(torch.nn.Module):
  503. def __init__(
  504. self,
  505. in_features,
  506. out_features,
  507. grid_size=5,
  508. spline_order=3,
  509. scale_noise=0.1,
  510. scale_base=1.0,
  511. scale_spline=1.0,
  512. enable_standalone_scale_spline=True,
  513. base_activation=torch.nn.SiLU,
  514. grid_eps=0.02,
  515. grid_range=[-1, 1],
  516. ):
  517. super(KANLinear, self).__init__()
  518. self.in_features = in_features
  519. self.out_features = out_features
  520. self.grid_size = grid_size
  521. self.spline_order = spline_order
  522. h = (grid_range[1] - grid_range[0]) / grid_size
  523. grid = (
  524. (
  525. torch.arange(-spline_order, grid_size + spline_order + 1) * h
  526. + grid_range[0]
  527. )
  528. .expand(in_features, -1)
  529. .contiguous()
  530. )
  531. self.register_buffer("grid", grid)
  532. self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
  533. self.spline_weight = torch.nn.Parameter(
  534. torch.Tensor(out_features, in_features, grid_size + spline_order)
  535. )
  536. if enable_standalone_scale_spline:
  537. self.spline_scaler = torch.nn.Parameter(
  538. torch.Tensor(out_features, in_features)
  539. )
  540. self.scale_noise = scale_noise
  541. self.scale_base = scale_base
  542. self.scale_spline = scale_spline
  543. self.enable_standalone_scale_spline = enable_standalone_scale_spline
  544. self.base_activation = base_activation()
  545. self.grid_eps = grid_eps
  546. self.reset_parameters()
  547. def reset_parameters(self):
  548. torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
  549. with torch.no_grad():
  550. noise = (
  551. (
  552. torch.rand(self.grid_size + 1, self.in_features, self.out_features)
  553. - 1 / 2
  554. )
  555. * self.scale_noise
  556. / self.grid_size
  557. )
  558. self.spline_weight.data.copy_(
  559. (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
  560. * self.curve2coeff(
  561. self.grid.T[self.spline_order: -self.spline_order],
  562. noise,
  563. )
  564. )
  565. if self.enable_standalone_scale_spline:
  566. # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
  567. torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)
  568. def b_splines(self, x: torch.Tensor):
  569. """
  570. Compute the B-spline bases for the given input tensor.
  571. Args:
  572. x (torch.Tensor): Input tensor of shape (batch_size, in_features).
  573. Returns:
  574. torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
  575. """
  576. assert x.dim() == 2 and x.size(1) == self.in_features
  577. grid: torch.Tensor = (
  578. self.grid
  579. ) # (in_features, grid_size + 2 * spline_order + 1)
  580. x = x.unsqueeze(-1)
  581. bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
  582. for k in range(1, self.spline_order + 1):
  583. bases = (
  584. (x - grid[:, : -(k + 1)])
  585. / (grid[:, k:-1] - grid[:, : -(k + 1)])
  586. * bases[:, :, :-1]
  587. ) + (
  588. (grid[:, k + 1:] - x)
  589. / (grid[:, k + 1:] - grid[:, 1:(-k)])
  590. * bases[:, :, 1:]
  591. )
  592. assert bases.size() == (
  593. x.size(0),
  594. self.in_features,
  595. self.grid_size + self.spline_order,
  596. )
  597. return bases.contiguous()
  598. def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
  599. """
  600. Compute the coefficients of the curve that interpolates the given points.
  601. Args:
  602. x (torch.Tensor): Input tensor of shape (batch_size, in_features).
  603. y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).
  604. Returns:
  605. torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
  606. """
  607. assert x.dim() == 2 and x.size(1) == self.in_features
  608. assert y.size() == (x.size(0), self.in_features, self.out_features)
  609. A = self.b_splines(x).transpose(
  610. 0, 1
  611. ) # (in_features, batch_size, grid_size + spline_order)
  612. B = y.transpose(0, 1) # (in_features, batch_size, out_features)
  613. # Note from CSDN author Snu77: torch.linalg.lstsq here requires torch 1.9 or later.
  614. solution = torch.linalg.lstsq(
  615. A, B
  616. ).solution # (in_features, grid_size + spline_order, out_features)
  617. result = solution.permute(
  618. 2, 0, 1
  619. ) # (out_features, in_features, grid_size + spline_order)
  620. assert result.size() == (
  621. self.out_features,
  622. self.in_features,
  623. self.grid_size + self.spline_order,
  624. )
  625. return result.contiguous()
  626. @property
  627. def scaled_spline_weight(self):
  628. return self.spline_weight * (
  629. self.spline_scaler.unsqueeze(-1)
  630. if self.enable_standalone_scale_spline
  631. else 1.0
  632. )
  633. def forward(self, x: torch.Tensor):
  634. assert x.dim() == 2 and x.size(1) == self.in_features
  635. base_output = F.linear(self.base_activation(x), self.base_weight)
  636. spline_output = F.linear(
  637. self.b_splines(x).view(x.size(0), -1),
  638. self.scaled_spline_weight.view(self.out_features, -1),
  639. )
  640. return base_output + spline_output
  641. @torch.no_grad()
  642. def update_grid(self, x: torch.Tensor, margin=0.01):
  643. assert x.dim() == 2 and x.size(1) == self.in_features
  644. batch = x.size(0)
  645. splines = self.b_splines(x) # (batch, in, coeff)
  646. splines = splines.permute(1, 0, 2) # (in, batch, coeff)
  647. orig_coeff = self.scaled_spline_weight # (out, in, coeff)
  648. orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out)
  649. unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out)
  650. unreduced_spline_output = unreduced_spline_output.permute(
  651. 1, 0, 2
  652. ) # (batch, in, out)
  653. # sort each channel individually to collect data distribution
  654. x_sorted = torch.sort(x, dim=0)[0]
  655. grid_adaptive = x_sorted[
  656. torch.linspace(
  657. 0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
  658. )
  659. ]
  660. uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
  661. grid_uniform = (
  662. torch.arange(
  663. self.grid_size + 1, dtype=torch.float32, device=x.device
  664. ).unsqueeze(1)
  665. * uniform_step
  666. + x_sorted[0]
  667. - margin
  668. )
  669. grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
  670. grid = torch.concatenate(
  671. [
  672. grid[:1]
  673. - uniform_step
  674. * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
  675. grid,
  676. grid[-1:]
  677. + uniform_step
  678. * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
  679. ],
  680. dim=0,
  681. )
  682. self.grid.copy_(grid.T)
  683. self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
  684. def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
  685. """
  686. Compute the regularization loss.
  687. This is a dumb simulation of the original L1 regularization as stated in the
  688. paper, since the original one requires computing absolutes and entropy from the
  689. expanded (batch, in_features, out_features) intermediate tensor, which is hidden
  690. behind the F.linear function if we want an memory efficient implementation.
  691. The L1 regularization is now computed as mean absolute value of the spline
  692. weights. The authors implementation also includes this term in addition to the
  693. sample-based regularization.
  694. """
  695. l1_fake = self.spline_weight.abs().mean(-1)
  696. regularization_loss_activation = l1_fake.sum()
  697. p = l1_fake / regularization_loss_activation
  698. regularization_loss_entropy = -torch.sum(p * p.log())
  699. return (
  700. regularize_activation * regularization_loss_activation
  701. + regularize_entropy * regularization_loss_entropy
  702. )
  703. # """""""""""""""""""""""""""""""""""""""""""""正式代码"""""""""""""""""""""""""""""""""""""""""""
  704. class KANConv2d(nn.Module):
  705. def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1):
  706. super(KANConv2d, self).__init__()
  707. self.in_channels = in_channels
  708. self.out_channels = out_channels
  709. self.kernel_size = kernel_size
  710. self.stride = stride
  711. self.padding = padding
  712. self.kanlayer = KANLinear(in_channels * kernel_size * kernel_size, out_channels)
  713. def forward(self, x):
  714. batch_size, in_channels, height, width = x.size()
  715. assert in_channels == self.in_channels
  716. # Apply unfold to get sliding local blocks
  717. x_unfold = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)
  718. x_unfold = x_unfold.transpose(1, 2)
  719. x_unfold = x_unfold.reshape(batch_size * x_unfold.size(1), -1)
  720. out_unfold = self.kanlayer(x_unfold)
  721. # Reshape and transpose to get the final output
  722. out_unfold = out_unfold.reshape(batch_size, -1, out_unfold.size(1))
  723. out = out_unfold.transpose(1, 2)
  724. out_height = (height + 2 * self.padding - self.kernel_size) // self.stride + 1
  725. out_width = (width + 2 * self.padding - self.kernel_size) // self.stride + 1
  726. out = out.reshape(batch_size, self.out_channels, out_height, out_width)
  727. return out
  728. class ChebyKANConv2d(nn.Module):
  729. def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1, degree=4):
  730. super(ChebyKANConv2d, self).__init__()
  731. self.in_channels = in_channels
  732. self.out_channels = out_channels
  733. self.kernel_size = kernel_size
  734. self.stride = stride
  735. self.padding = padding
  736. self.kanlayer = ChebyKANLayer(in_channels * kernel_size * kernel_size, out_channels, degree=degree)
  737. def forward(self, x):
  738. batch_size, in_channels, height, width = x.size()
  739. assert in_channels == self.in_channels
  740. # Apply unfold to get sliding local blocks
  741. x_unfold = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)
  742. x_unfold = x_unfold.transpose(1, 2)
  743. x_unfold = x_unfold.reshape(batch_size * x_unfold.size(1), -1)
  744. out_unfold = self.kanlayer(x_unfold)
  745. # Reshape and transpose to get the final output
  746. out_unfold = out_unfold.reshape(batch_size, -1, out_unfold.size(1))
  747. out = out_unfold.transpose(1, 2)
  748. out_height = (height + 2 * self.padding - self.kernel_size) // self.stride + 1
  749. out_width = (width + 2 * self.padding - self.kernel_size) // self.stride + 1
  750. out = out.reshape(batch_size, self.out_channels, out_height, out_width)
  751. return out
  752. class FastKANConv2d(nn.Module):
  753. def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1):
  754. super(FastKANConv2d, self).__init__()
  755. self.in_channels = in_channels
  756. self.out_channels = out_channels
  757. self.kernel_size = kernel_size
  758. self.stride = stride
  759. self.padding = padding
  760. self.kanlayer = FastKANLayer(in_channels * kernel_size * kernel_size, out_channels)
  761. def forward(self, x):
  762. batch_size, in_channels, height, width = x.size()
  763. assert in_channels == self.in_channels
  764. # Apply unfold to get sliding local blocks
  765. x_unfold = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)
  766. x_unfold = x_unfold.transpose(1, 2)
  767. x_unfold = x_unfold.reshape(batch_size * x_unfold.size(1), -1)
  768. out_unfold = self.kanlayer(x_unfold)
  769. # Reshape and transpose to get the final output
  770. out_unfold = out_unfold.reshape(batch_size, -1, out_unfold.size(1))
  771. out = out_unfold.transpose(1, 2)
  772. out_height = (height + 2 * self.padding - self.kernel_size) // self.stride + 1
  773. out_width = (width + 2 * self.padding - self.kernel_size) // self.stride + 1
  774. out = out.reshape(batch_size, self.out_channels, out_height, out_width)
  775. return out
  776. class GRAMKANConv2d(nn.Module):
  777. def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1):
  778. super(GRAMKANConv2d, self).__init__()
  779. self.in_channels = in_channels
  780. self.out_channels = out_channels
  781. self.kernel_size = kernel_size
  782. self.stride = stride
  783. self.padding = padding
  784. self.kanlayer = GRAMLayer(in_channels * kernel_size * kernel_size, out_channels)
  785. def forward(self, x):
  786. batch_size, in_channels, height, width = x.size()
  787. assert in_channels == self.in_channels
  788. # Apply unfold to get sliding local blocks
  789. x_unfold = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)
  790. x_unfold = x_unfold.transpose(1, 2)
  791. x_unfold = x_unfold.reshape(batch_size * x_unfold.size(1), -1)
  792. out_unfold = self.kanlayer(x_unfold)
  793. # Reshape and transpose to get the final output
  794. out_unfold = out_unfold.reshape(batch_size, -1, out_unfold.size(1))
  795. out = out_unfold.transpose(1, 2)
  796. out_height = (height + 2 * self.padding - self.kernel_size) // self.stride + 1
  797. out_width = (width + 2 * self.padding - self.kernel_size) // self.stride + 1
  798. out = out.reshape(batch_size, self.out_channels, out_height, out_width)
  799. return out
  800. class WavKANConv2d(nn.Module):
  801. def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1, wavelet_type='mexican_hat'):
  802. super(WavKANConv2d, self).__init__()
  803. self.in_channels = in_channels
  804. self.out_channels = out_channels
  805. self.kernel_size = kernel_size
  806. self.stride = stride
  807. self.padding = padding
  808. self.kanlayer = WavKANLayer(in_channels * kernel_size * kernel_size, out_channels, wavelet_type=wavelet_type)
  809. def forward(self, x):
  810. batch_size, in_channels, height, width = x.size()
  811. assert in_channels == self.in_channels
  812. # Apply unfold to get sliding local blocks
  813. x_unfold = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)
  814. x_unfold = x_unfold.transpose(1, 2)
  815. x_unfold = x_unfold.reshape(batch_size * x_unfold.size(1), -1)
  816. out_unfold = self.kanlayer(x_unfold)
  817. # Reshape and transpose to get the final output
  818. out_unfold = out_unfold.reshape(batch_size, -1, out_unfold.size(1))
  819. out = out_unfold.transpose(1, 2)
  820. out_height = (height + 2 * self.padding - self.kernel_size) // self.stride + 1
  821. out_width = (width + 2 * self.padding - self.kernel_size) // self.stride + 1
  822. out = out.reshape(batch_size, self.out_channels, out_height, out_width)
  823. return out
  824. class JacobiKANConv2d(nn.Module):
  825. def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1,degree=4):
  826. super(JacobiKANConv2d, self).__init__()
  827. self.in_channels = in_channels
  828. self.out_channels = out_channels
  829. self.kernel_size = kernel_size
  830. self.stride = stride
  831. self.padding = padding
  832. self.kanlayer = JacobiKANLayer(in_channels * kernel_size * kernel_size, out_channels, degree=degree)
  833. def forward(self, x):
  834. batch_size, in_channels, height, width = x.size()
  835. assert in_channels == self.in_channels
  836. # Apply unfold to get sliding local blocks
  837. x_unfold = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)
  838. x_unfold = x_unfold.transpose(1, 2)
  839. x_unfold = x_unfold.reshape(batch_size * x_unfold.size(1), -1)
  840. out_unfold = self.kanlayer(x_unfold)
  841. # Reshape and transpose to get the final output
  842. out_unfold = out_unfold.reshape(batch_size, -1, out_unfold.size(1))
  843. out = out_unfold.transpose(1, 2)
  844. out_height = (height + 2 * self.padding - self.kernel_size) // self.stride + 1
  845. out_width = (width + 2 * self.padding - self.kernel_size) // self.stride + 1
  846. out = out.reshape(batch_size, self.out_channels, out_height, out_width)
  847. return out
  848. class ReLUKANConv2d(nn.Module):
  849. def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1):
  850. super(ReLUKANConv2d, self).__init__()
  851. self.in_channels = in_channels
  852. self.out_channels = out_channels
  853. self.kernel_size = kernel_size
  854. self.stride = stride
  855. self.padding = padding
  856. self.kanlayer = ReLUKANLayer(in_channels * kernel_size * kernel_size, 5, 3, out_channels)
  857. def forward(self, x):
  858. batch_size, in_channels, height, width = x.size()
  859. assert in_channels == self.in_channels
  860. # Apply unfold to get sliding local blocks
  861. x_unfold = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)
  862. x_unfold = x_unfold.transpose(1, 2)
  863. x_unfold = x_unfold.reshape(batch_size * x_unfold.size(1), -1)
  864. out_unfold = self.kanlayer(x_unfold)
  865. # Reshape and transpose to get the final output
  866. out_unfold = out_unfold.reshape(batch_size, -1, out_unfold.size(1))
  867. out = out_unfold.transpose(1, 2)
  868. out_height = (height + 2 * self.padding - self.kernel_size) // self.stride + 1
  869. out_width = (width + 2 * self.padding - self.kernel_size) // self.stride + 1
  870. out = out.reshape(batch_size, self.out_channels, out_height, out_width)
  871. return out
  872. class FasterKANConv2d(nn.Module):
  873. def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1):
  874. super(FasterKANConv2d, self).__init__()
  875. self.in_channels = in_channels
  876. self.out_channels = out_channels
  877. self.kernel_size = kernel_size
  878. self.stride = stride
  879. self.padding = padding
  880. self.kanlayer = FasterKANLayer(in_channels * kernel_size * kernel_size, out_channels)
  881. def forward(self, x):
  882. batch_size, in_channels, height, width = x.size()
  883. assert in_channels == self.in_channels
  884. # Apply unfold to get sliding local blocks
  885. x_unfold = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)
  886. x_unfold = x_unfold.transpose(1, 2)
  887. x_unfold = x_unfold.reshape(batch_size * x_unfold.size(1), -1)
  888. out_unfold = self.kanlayer(x_unfold)
  889. # Reshape and transpose to get the final output
  890. out_unfold = out_unfold.reshape(batch_size, -1, out_unfold.size(1))
  891. out = out_unfold.transpose(1, 2)
  892. out_height = (height + 2 * self.padding - self.kernel_size) // self.stride + 1
  893. out_width = (width + 2 * self.padding - self.kernel_size) // self.stride + 1
  894. out = out.reshape(batch_size, self.out_channels, out_height, out_width)
  895. return out
  896. class RBFKANConv2d(nn.Module):
  897. def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1):
  898. super(RBFKANConv2d, self).__init__()
  899. self.in_channels = in_channels
  900. self.out_channels = out_channels
  901. self.kernel_size = kernel_size
  902. self.stride = stride
  903. self.padding = padding
  904. self.kanlayer = RBFKANLayer(in_channels * kernel_size * kernel_size, out_channels)
  905. def forward(self, x):
  906. batch_size, in_channels, height, width = x.size()
  907. assert in_channels == self.in_channels
  908. # Apply unfold to get sliding local blocks
  909. x_unfold = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)
  910. x_unfold = x_unfold.transpose(1, 2)
  911. x_unfold = x_unfold.reshape(batch_size * x_unfold.size(1), -1)
  912. out_unfold = self.kanlayer(x_unfold)
  913. # Reshape and transpose to get the final output
  914. out_unfold = out_unfold.reshape(batch_size, -1, out_unfold.size(1))
  915. out = out_unfold.transpose(1, 2)
  916. out_height = (height + 2 * self.padding - self.kernel_size) // self.stride + 1
  917. out_width = (width + 2 * self.padding - self.kernel_size) // self.stride + 1
  918. out = out.reshape(batch_size, self.out_channels, out_height, out_width)
  919. return out
  920. if __name__ == "__main__":
  921. # Generating Sample image
  922. image_size = (1, 64, 240, 240)
  923. image = torch.rand(*image_size)
  924. # KANConv2d requires torch 1.9 or later.
  925. Convs = ['RBFKANConv2d', 'ReLUKANConv2d','FasterKANConv2d', 'ChebyKANConv2d', 'JacobiKANConv2d', 'FastKANConv2d', 'GRAMKANConv2d']
  926. qu = ['WavKANConv2d'] # needs a large amount of GPU memory
  927. e = ['KANConv2d'] # requires torch 1.9 or later
  928. with torch.no_grad():
  929. for i in range(len(Convs)):
  930. model = eval(Convs[i])
  931. # Model
  932. conv = model(64, 64, kernel_size=3, stride=1, padding=1)
  933. out = conv(image)
  934. print(out.size())


IV. Step-by-Step: Adding the Nine KANConv2d Variants

4.1 Step One

First, go to the directory 'ultralytics/nn' and create a new .py file there; name it however you like, then paste the core code above into it.


4.2 Step Two

Next, make sure there is a file named '__init__.py' in the same directory (if you are using the files from the group it already exists, so there is no need to create it), and import our new module inside it, as shown in the sketch below.
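
For example, assuming the file from step one is called 'kan_conv.py' (the name is your choice; adjust the import accordingly), the line added to 'ultralytics/nn/__init__.py' might look like this:

    # ultralytics/nn/__init__.py  (sketch; 'kan_conv' is whatever file name you chose in step one)
    from .kan_conv import (RBFKANConv2d, ReLUKANConv2d, KANConv2d, FasterKANConv2d,
                           WavKANConv2d, ChebyKANConv2d, JacobiKANConv2d,
                           FastKANConv2d, GRAMKANConv2d)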


4.3 Step Three

Third, open 'ultralytics/nn/tasks.py' and import and register our module there (if you are using the files from the group this is already done; skip straight to step four). A sketch of the import is shown below.
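
A sketch of the corresponding import near the top of 'ultralytics/nn/tasks.py' (again assuming the file from step one is named 'kan_conv.py'):

    # ultralytics/nn/tasks.py  (sketch)
    from ultralytics.nn.kan_conv import (RBFKANConv2d, ReLUKANConv2d, KANConv2d,
                                         FasterKANConv2d, WavKANConv2d, ChebyKANConv2d,
                                         JacobiKANConv2d, FastKANConv2d, GRAMKANConv2d)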


4.4 Step Four

Find the parse_model method (search with Ctrl+F or scroll to it), locate the branch shown below, and add the new module names by following the example.
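
The idea is to add the nine class names to the branch of parse_model that already handles Conv-style modules, i.e. modules whose first YAML argument is the output channel count. A sketch of that branch (paraphrased; the exact list of built-in modules varies between Ultralytics versions):

    # inside parse_model() in ultralytics/nn/tasks.py (sketch)
    if m in (Conv, C3k2, SPPF, C2PSA,                      # ...existing Ultralytics modules...
             RBFKANConv2d, ReLUKANConv2d, KANConv2d, FasterKANConv2d, WavKANConv2d,
             ChebyKANConv2d, JacobiKANConv2d, FastKANConv2d, GRAMKANConv2d):
        c1, c2 = ch[f], args[0]
        if c2 != nc:
            c2 = make_divisible(min(c2, max_channels) * width, 8)
        args = [c1, c2, *args[1:]]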



With that, registration is complete. All that remains is to copy one of the YAML files below and run it.


V. YAML Files for the Nine KANConv2d Variants

The nine KANConv2d variants are: WavKANConv2d, RBFKANConv2d, KANConv2d, ReLUKANConv2d, FasterKANConv2d, ChebyKANConv2d, JacobiKANConv2d, FastKANConv2d, and GRAMKANConv2d.

5.1 YAML file for RBFKANConv2d

Training info for this version: YOLO11-RBFKANConv2d summary: 326 layers, 7,935,471 parameters, 7,935,407 gradients, 6.4 GFLOPs

  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. # YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
  3. # Parameters
  4. nc: 80 # number of classes
  5. scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  6. # [depth, width, max_channels]
  7. n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  8. s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  9. m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  10. l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  11. x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs
  12. # YOLO11n backbone
  13. backbone:
  14. # [from, repeats, module, args]
  15. - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  16. - [-1, 1, RBFKANConv2d, [128, 3, 2]] # 1-P2/4
  17. - [-1, 2, C3k2, [256, False, 0.25]]
  18. - [-1, 1, RBFKANConv2d, [256, 3, 2]] # 3-P3/8
  19. - [-1, 2, C3k2, [512, False, 0.25]]
  20. - [-1, 1, RBFKANConv2d, [512, 3, 2]] # 5-P4/16
  21. - [-1, 2, C3k2, [512, True]]
  22. - [-1, 1, RBFKANConv2d, [1024, 3, 2]] # 7-P5/32
  23. - [-1, 2, C3k2, [1024, True]]
  24. - [-1, 1, SPPF, [1024, 5]] # 9
  25. - [-1, 2, C2PSA, [1024]] # 10
  26. # YOLO11n head
  27. head:
  28. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  29. - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  30. - [-1, 2, C3k2, [512, False]] # 13
  31. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  32. - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  33. - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  34. - [-1, 1, RBFKANConv2d, [256, 3, 2]]
  35. - [[-1, 13], 1, Concat, [1]] # cat head P4
  36. - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  37. - [-1, 1, RBFKANConv2d, [512, 3, 2]]
  38. - [[-1, 10], 1, Concat, [1]] # cat head P5
  39. - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  40. - [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)
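
Once the YAML above is saved, for example as 'yolo11-RBFKANConv2d.yaml' (the file name is up to you), training runs like any other Ultralytics model:

    from ultralytics import YOLO

    model = YOLO("yolo11-RBFKANConv2d.yaml")                     # build the model from the modified yaml
    model.train(data="coco128.yaml", epochs=100, imgsz=640)      # coco128.yaml is just an example dataset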


5.2 YAML file for KANConv2d

My local torch version is too old for this Conv (it requires torch >= 1.9), so I have no training info for this version.

  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. # YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
  3. # Parameters
  4. nc: 80 # number of classes
  5. scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  6. # [depth, width, max_channels]
  7. n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  8. s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  9. m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  10. l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  11. x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs
  12. # YOLO11n backbone
  13. backbone:
  14. # [from, repeats, module, args]
  15. - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  16. - [-1, 1, KANConv2d, [128, 3, 2]] # 1-P2/4
  17. - [-1, 2, C3k2, [256, False, 0.25]]
  18. - [-1, 1, KANConv2d, [256, 3, 2]] # 3-P3/8
  19. - [-1, 2, C3k2, [512, False, 0.25]]
  20. - [-1, 1, KANConv2d, [512, 3, 2]] # 5-P4/16
  21. - [-1, 2, C3k2, [512, True]]
  22. - [-1, 1, KANConv2d, [1024, 3, 2]] # 7-P5/32
  23. - [-1, 2, C3k2, [1024, True]]
  24. - [-1, 1, SPPF, [1024, 5]] # 9
  25. - [-1, 2, C2PSA, [1024]] # 10
  26. # YOLO11n head
  27. head:
  28. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  29. - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  30. - [-1, 2, C3k2, [512, False]] # 13
  31. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  32. - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  33. - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  34. - [-1, 1, KANConv2d, [256, 3, 2]]
  35. - [[-1, 13], 1, Concat, [1]] # cat head P4
  36. - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  37. - [-1, 1, KANConv2d, [512, 3, 2]]
  38. - [[-1, 10], 1, Concat, [1]] # cat head P5
  39. - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  40. - [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)


5.3 YAML file for ReLUKANConv2d

Training info for this version: YOLO11-ReLUKANConv2d summary: 319 layers, 7,343,295 parameters, 7,343,279 gradients, 18.0 GFLOPs

  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. # YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
  3. # Parameters
  4. nc: 80 # number of classes
  5. scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  6. # [depth, width, max_channels]
  7. n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  8. s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  9. m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  10. l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  11. x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs
  12. # YOLO11n backbone
  13. backbone:
  14. # [from, repeats, module, args]
  15. - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  16. - [-1, 1, ReLUKANConv2d, [128, 3, 2]] # 1-P2/4
  17. - [-1, 2, C3k2, [256, False, 0.25]]
  18. - [-1, 1, ReLUKANConv2d, [256, 3, 2]] # 3-P3/8
  19. - [-1, 2, C3k2, [512, False, 0.25]]
  20. - [-1, 1, ReLUKANConv2d, [512, 3, 2]] # 5-P4/16
  21. - [-1, 2, C3k2, [512, True]]
  22. - [-1, 1, ReLUKANConv2d, [1024, 3, 2]] # 7-P5/32
  23. - [-1, 2, C3k2, [1024, True]]
  24. - [-1, 1, SPPF, [1024, 5]] # 9
  25. - [-1, 2, C2PSA, [1024]] # 10
  26. # YOLO11n head
  27. head:
  28. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  29. - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  30. - [-1, 2, C3k2, [512, False]] # 13
  31. - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  32. - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  33. - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  34. - [-1, 1, ReLUKANConv2d, [256, 3, 2]]
  35. - [[-1, 13], 1, Concat, [1]] # cat head P4
  36. - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  37. - [-1, 1, ReLUKANConv2d, [512, 3, 2]]
  38. - [[-1, 10], 1, Concat, [1]] # cat head P5
  39. - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  40. - [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)


5.4 YAML file for FasterKANConv2d

Training info for this version: YOLO11-FasterKANConv2d summary: 331 layers, 7,276,149 parameters, 7,276,079 gradients, 4.9 GFLOPs

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, FasterKANConv2d, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, FasterKANConv2d, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, FasterKANConv2d, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, FasterKANConv2d, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1, 1, FasterKANConv2d, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  - [-1, 1, FasterKANConv2d, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  - [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)
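
For reference, a backbone entry such as - [-1, 1, FasterKANConv2d, [128, 3, 2]] follows the usual Ultralytics convention: the first argument is the output channel count (scaled by the width multiple of the chosen scale), followed by kernel size and stride, while the input channels come from the previous layer. The minimal sketch below shows roughly what that entry resolves to at the 'n' scale (width 0.25); the import path kan_conv and the (in_channels, out_channels, kernel_size, stride) constructor signature are assumptions inferred from how the yaml uses the module, not part of the original code.

import torch
from kan_conv import FasterKANConv2d  # hypothetical file name holding the Section-3 code

# "- [-1, 1, FasterKANConv2d, [128, 3, 2]]" at the 'n' scale (width 0.25) roughly resolves to:
layer = FasterKANConv2d(16, 32, 3, 2)  # in = 64*0.25 from the previous Conv, out = 128*0.25, k = 3, s = 2
x = torch.randn(1, 16, 320, 320)       # layer 0 has already halved the 640x640 input once
print(layer(x).shape)                  # torch.Size([1, 32, 160, 160]) if the layer autopads like Conv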


5.5 YAML file for ChebyKANConv2d

Training info for this version: YOLO11-ChebyKANConv2d summary: 313 layers, 5,262,111 parameters, 5,262,095 gradients, 4.8 GFLOPs

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, ChebyKANConv2d, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, ChebyKANConv2d, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, ChebyKANConv2d, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, ChebyKANConv2d, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1, 1, ChebyKANConv2d, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  - [-1, 1, ChebyKANConv2d, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  - [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)


5.6 YAML file for JacobiKANConv2d

Training info for this version: YOLO11-JacobiKANConv2d summary: 325 layers, 5,931,615 parameters, 5,931,599 gradients, 4.8 GFLOPs

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, JacobiKANConv2d, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, JacobiKANConv2d, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, JacobiKANConv2d, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, JacobiKANConv2d, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1, 1, JacobiKANConv2d, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  - [-1, 1, JacobiKANConv2d, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  - [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)


5.7 YAML file for FastKANConv2d

Training info for this version: YOLO11-FastKANConv2d summary: 343 layers, 7,944,975 parameters, 7,944,911 gradients, 6.5 GFLOPs

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, FastKANConv2d, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, FastKANConv2d, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, FastKANConv2d, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, FastKANConv2d, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1, 1, FastKANConv2d, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  - [-1, 1, FastKANConv2d, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  - [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)


5.8 YAML file for GRAMKANConv2d

Training info for this version: YOLO11-GRAMKANConv2d summary: 325 layers, 5,263,479 parameters, 5,263,463 gradients, 4.8 GFLOPs

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, GRAMKANConv2d, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, GRAMKANConv2d, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, GRAMKANConv2d, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, GRAMKANConv2d, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1, 1, GRAMKANConv2d, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  - [-1, 1, GRAMKANConv2d, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  - [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)


5.9 YAML file for WavKANConv2d

This variant pairs wavelet-based convolution with the KAN activation functions, but it needs an especially large amount of GPU memory, so decide for yourself whether it is worth trying on your hardware (a short memory-check sketch follows the yaml below).

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, WavKANConv2d, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, WavKANConv2d, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, WavKANConv2d, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, WavKANConv2d, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1, 1, WavKANConv2d, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  - [-1, 1, WavKANConv2d, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  - [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)
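
Since the main obstacle with WavKANConv2d is GPU memory, a quick way to gauge whether it will fit on your card is to run one forward/backward pass through a single layer and read the peak allocation. This is only a rough sketch: it assumes the Section-3 code is saved as kan_conv.py and that WavKANConv2d takes (in_channels, out_channels, kernel_size, stride), as the yaml entries suggest; both the file name and the signature are assumptions.

import torch
from kan_conv import WavKANConv2d  # hypothetical file name holding the Section-3 code

layer = WavKANConv2d(64, 128, 3, 2).cuda()       # first WavKANConv2d layer of the 'l'-scale backbone (assumed signature)
x = torch.randn(1, 64, 320, 320, device='cuda')  # roughly the feature map a 640x640 image reaches at that layer

torch.cuda.reset_peak_memory_stats()
y = layer(x)
y.sum().backward()
print(f'peak GPU memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.1f} MiB')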


5.10 Training code

Create a .py file, paste the code below into it, fill in your own file paths, and it is ready to run.

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    # Point this at whichever KANConv yaml from Section 5 you saved (file name is up to you)
    model = YOLO('ultralytics/cfg/models/11/yolo11-FasterKANConv2d.yaml')
    # model.load('yolo11n.pt')  # optionally load pretrained weights
    model.train(data=r'path/to/your/dataset.yaml',  # replace with the path to your dataset yaml
                # For other tasks, change 'task' in 'ultralytics/cfg/default.yaml' to detect, segment, classify or pose
                cache=False,
                imgsz=640,
                epochs=150,
                single_cls=False,  # whether this is single-class detection
                batch=4,
                close_mosaic=10,
                workers=0,
                device='0',
                optimizer='SGD',  # using SGD
                # resume='',  # to resume training, set this to the path of last.pt
                amp=False,  # disable AMP if the training loss becomes NaN
                project='runs/train',
                name='exp',
                )
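
Before launching a full training run, it can be worth checking that each modified yaml builds at all and that the printed summary (layers / parameters / GFLOPs) is in line with the numbers listed for each variant above. A minimal sketch, assuming the yaml files from Section 5 were saved under ultralytics/cfg/models/11/ with the names used below (the paths are illustrative):

from ultralytics import YOLO

# Assumed file names/locations; use whatever paths you actually saved the yaml files to.
for name in ('FasterKANConv2d', 'ChebyKANConv2d', 'JacobiKANConv2d',
             'FastKANConv2d', 'GRAMKANConv2d'):
    model = YOLO(f'ultralytics/cfg/models/11/yolo11-{name}.yaml')  # build the model from the yaml
    model.info()  # prints the "layers, parameters, gradients, GFLOPs" summary for comparison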


5.11 Training process screenshot


六、Summary

This concludes the main content of this post. If you found it helpful, consider subscribing to my YOLOv11 effective-improvement column (newly opened, with an average quality score of 98); I will keep reproducing papers from the latest top conferences there and will also supplement some of the older improvement mechanisms, so follow along for future updates.