This article shows how to add an SPPF (Spatial Pyramid Pooling - Fast) module to nnU-Net in order to inject global semantic information into the network.
1. SPPF

The SPPF structure is simple; we adopt the version used in YOLOv8 (the same SPPF block originally introduced in YOLOv5), shown in the figure below:
(figure: SPPF structure)
In this article we attach it after the encoder, i.e., to the deepest feature map before it enters the decoder. The reference implementation is as follows:
import torch
import torch.nn as nn
from ultralytics.nn.modules.conv import Conv  # Conv2d + BatchNorm + SiLU block from the Ultralytics codebase


class SPPF(nn.Module):
    """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher."""

    def __init__(self, c1, c2, k=5):
        """
        Initializes the SPPF layer with given input/output channels and kernel size.
        This module is equivalent to SPP(k=(5, 9, 13)).
        """
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)      # 1x1 conv to halve the channels
        self.cv2 = Conv(c_ * 4, c2, 1, 1)  # 1x1 conv to fuse the concatenated maps
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        """Cascade three max-pools and fuse the four concatenated feature maps."""
        y = [self.cv1(x)]
        y.extend(self.m(y[-1]) for _ in range(3))
        return self.cv2(torch.cat(y, 1))
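Stacking three k=5 max-pools gives effective receptive fields of 5, 9, and 13, which is why SPPF matches SPP(k=(5, 9, 13)) at a fraction of the cost. As a quick sanity check, the block keeps the spatial size (stride-1 pooling with "same" padding) and maps c1 input channels to c2 output channels. A minimal sketch, assuming the Ultralytics package is installed so that Conv can be imported:

x = torch.randn(1, 64, 32, 32)  # (batch, c1, H, W)
sppf = SPPF(c1=64, c2=128, k=5)
print(sppf(x).shape)            # torch.Size([1, 128, 32, 32]): H and W unchanged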
2. Adding SPPF to nnU-Net

As covered in the previous tutorial, training your own network with nnU-Net means modifying the dynamic-network-architectures package and then pointing the dataset's plans file at the new network class.
1) Modifying the network architecture

Create a new file sppfunet.py under the architectures directory of dynamic-network-architectures, as shown below:
(figure: sppfunet.py placed in the architectures directory)
Its contents are as follows:
from typing import Union, Type, List, Tuple
import numpy as np  # needed by compute_conv_feature_map_size below
import torch
from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim
from dynamic_network_architectures.building_blocks.plain_conv_encoder import PlainConvEncoder
from dynamic_network_architectures.building_blocks.residual import BasicBlockD, BottleneckD
from dynamic_network_architectures.building_blocks.residual_encoders import ResidualEncoder
from dynamic_network_architectures.building_blocks.unet_decoder import UNetDecoder
from dynamic_network_architectures.building_blocks.unet_residual_decoder import UNetResDecoder
from dynamic_network_architectures.initialization.weight_init import InitWeights_He
from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0
from torch import nn
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.dropout import _DropoutNd
import torch.nn.functional as F
from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op


class SPPFPlainConvUNet(nn.Module):
    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
                 num_classes: int,
                 n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 deep_supervision: bool = False,
                 nonlin_first: bool = False
                 ):
        """
        nonlin_first: if True you get conv -> nonlin -> norm. Else it's conv -> norm -> nonlin
        """
        super().__init__()
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * n_stages
        if isinstance(n_conv_per_stage_decoder, int):
            n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1)
        assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have " \
                                                  f"resolution stages. here: {n_stages}. " \
                                                  f"n_conv_per_stage: {n_conv_per_stage}"
        assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \
                                                                f"as we have resolution stages. here: {n_stages} " \
                                                                f"stages, so it should have {n_stages - 1} entries. " \
                                                                f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}"
        self.encoder = PlainConvEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes, strides,
                                        n_conv_per_stage, conv_bias, norm_op, norm_op_kwargs, dropout_op,
                                        dropout_op_kwargs, nonlin, nonlin_kwargs, return_skips=True,
                                        nonlin_first=nonlin_first)
        self.decoder = UNetDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision,
                                   nonlin_first=nonlin_first)
        # the SPPF sits between encoder and decoder and refines the deepest feature map
        self.sppf = SPPF(
            conv_op=conv_op,
            input_channels=features_per_stage[-1],
            output_channels=features_per_stage[-1],
            kernel_size=3,
            stride=1,
            conv_bias=False,
            norm_op=norm_op,
            norm_op_kwargs=norm_op_kwargs,
            dropout_op=dropout_op,
            dropout_op_kwargs=dropout_op_kwargs,
            nonlin=nonlin,
            nonlin_kwargs=nonlin_kwargs,
            nonlin_first=nonlin_first
        )

    def forward(self, x):
        skips = self.encoder(x)
        skips[-1] = self.sppf(skips[-1])  # refine only the bottleneck feature map
        return self.decoder(skips)

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), \
            "just give the image size without color/feature channels or " \
            "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
            "Give input_size=(x, y(, z))!"
        # note: the SPPF's own feature maps are not included in this estimate
        return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size)

    @staticmethod
    def initialize(module):
        InitWeights_He(1e-2)(module)


class ConvDropoutNormReLU(nn.Module):
    def __init__(self,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False
                 ):
        super(ConvDropoutNormReLU, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        stride = maybe_convert_scalar_to_list(conv_op, stride)
        self.stride = stride
        kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
        if norm_op_kwargs is None:
            norm_op_kwargs = {}
        if nonlin_kwargs is None:
            nonlin_kwargs = {}
        ops = []
        self.conv = conv_op(
            input_channels,
            output_channels,
            kernel_size,
            stride,
            padding=[(i - 1) // 2 for i in kernel_size],  # "same" padding
            dilation=1,
            bias=conv_bias,
        )
        ops.append(self.conv)
        if dropout_op is not None:
            self.dropout = dropout_op(**dropout_op_kwargs)
            ops.append(self.dropout)
        if norm_op is not None:
            self.norm = norm_op(output_channels, **norm_op_kwargs)
            ops.append(self.norm)
        if nonlin is not None:
            self.nonlin = nonlin(**nonlin_kwargs)
            ops.append(self.nonlin)
        if nonlin_first and (norm_op is not None and nonlin is not None):
            ops[-1], ops[-2] = ops[-2], ops[-1]
        self.all_modules = nn.Sequential(*ops)

    def forward(self, x):
        return self.all_modules(x)

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == len(self.stride), \
            "just give the image size without color/feature channels or " \
            "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
            "Give input_size=(x, y(, z))!"
        output_size = [i // j for i, j in zip(input_size, self.stride)]  # we always do same padding
        return np.prod([self.output_channels, *output_size], dtype=np.int64)


class SPPF(nn.Module):
    """Spatial Pyramid Pooling - Fast (SPPF) layer, adapted from YOLOv5 by Glenn Jocher."""

    def __init__(self, conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False):
        """
        Initializes the SPPF layer with given input/output channels and pooling kernel size.
        With kernel_size=5 this module is equivalent to SPP(k=(5, 9, 13)); here the network
        instantiates it with kernel_size=3.
        """
        super().__init__()
        self.cv1 = ConvDropoutNormReLU(conv_op,
                                       input_channels,
                                       input_channels,
                                       1,
                                       1,
                                       conv_bias,
                                       norm_op,
                                       norm_op_kwargs,
                                       dropout_op,
                                       dropout_op_kwargs,
                                       nonlin,
                                       nonlin_kwargs,
                                       nonlin_first)
        self.cv2 = ConvDropoutNormReLU(conv_op,
                                       input_channels * 4,
                                       output_channels,
                                       1,
                                       1,
                                       conv_bias,
                                       norm_op,
                                       norm_op_kwargs,
                                       dropout_op,
                                       dropout_op_kwargs,
                                       nonlin,
                                       nonlin_kwargs,
                                       nonlin_first)
        # pick the pooling op that matches the conv dimensionality (2d or 3d)
        self.is_3d = True
        if conv_op == torch.nn.modules.conv.Conv2d:
            self.m = nn.MaxPool2d(kernel_size, stride=1, padding=kernel_size // 2)
            self.is_3d = False
        elif conv_op == torch.nn.modules.conv.Conv3d:
            self.m = nn.MaxPool3d(kernel_size, stride=1, padding=kernel_size // 2)
        self.kernel_size = kernel_size
        print('using sppf')

    def forward(self, x):
        """Cascade three max-pools and fuse the concatenated feature maps with a 1x1 conv."""
        y = [self.cv1(x)]
        for _ in range(3):
            y.append(self.m(y[-1]))
        x = torch.cat(y, 1)
        return self.cv2(x)
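Before wiring the class into nnU-Net, it helps to instantiate it once by hand. The following is a minimal smoke test (the stage count, feature widths, and patch size are made-up values for illustration, not taken from any plans file):

import torch
from torch import nn
from dynamic_network_architectures.architectures.sppfunet import SPPFPlainConvUNet

net = SPPFPlainConvUNet(
    input_channels=1, n_stages=4, features_per_stage=(32, 64, 128, 256),
    conv_op=nn.Conv3d, kernel_sizes=3, strides=(1, 2, 2, 2),
    n_conv_per_stage=2, num_classes=3, n_conv_per_stage_decoder=2,
    norm_op=nn.InstanceNorm3d, norm_op_kwargs={'eps': 1e-5, 'affine': True},
    nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True},
    deep_supervision=False,
)
x = torch.randn(1, 1, 40, 56, 40)  # (batch, channel, x, y, z)
print(net(x).shape)                # expect torch.Size([1, 3, 40, 56, 40])

You should see "using sppf" printed during construction, confirming that the SPPF branch is active.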
2) Modifying the plans file

With the model code in place, we again use the Task04_Hippocampus dataset from the previous tutorial for validation (if you skipped that tutorial, run the data preparation yourself). Edit the plans file nnUNet\nnUNet_preprocessed\Dataset004_Hippocampus\nnUNetPlans.json and change network_class_name to dynamic_network_architectures.architectures.sppfunet.SPPFPlainConvUNet.
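If you prefer not to edit the JSON by hand, a short script can make the change. This is a minimal sketch; it assumes each entry under the plans' "configurations" key nests the class name under an "architecture" sub-dict (the exact nesting varies between nnU-Net v2 releases, so check your file first):

import json

plans_path = 'nnUNet_preprocessed/Dataset004_Hippocampus/nnUNetPlans.json'
with open(plans_path) as f:
    plans = json.load(f)

new_class = 'dynamic_network_architectures.architectures.sppfunet.SPPFPlainConvUNet'
for name, cfg in plans['configurations'].items():
    # assumed key layout; adjust if your plans file differs
    if 'architecture' in cfg:
        cfg['architecture']['network_class_name'] = new_class
        print(f'patched {name}')

with open(plans_path, 'w') as f:
    json.dump(plans, f, indent=4)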
(figure: network_class_name entry in nnUNetPlans.json)
3. Model training

With the model and the dataset's plans file modified, we can start training. The dataset is still Task04_Hippocampus. Since the code above supports both 2d and 3d models, you can use the following training commands:
nnUNetv2_train 4 2d 0
nnUNetv2_train 4 2d 1
nnUNetv2_train 4 2d 2
nnUNetv2_train 4 2d 3
nnUNetv2_train 4 2d 4
nnUNetv2_train 4 3d_fullres 0
nnUNetv2_train 4 3d_fullres 1
nnUNetv2_train 4 3d_fullres 2
nnUNetv2_train 4 3d_fullres 3
nnUNetv2_train 4 3d_fullres 4

As you can see, the 2d model trains:
(figure: 2d training log)
The 3d_fullres model also trains:
(figure: 3d_fullres training log)
Because nnU-Net training takes a very long time and compute resources were limited, the full training runs were not completed; this article only covers the code changes and verifying that training runs end to end.