This article shows how to add an ASPP module to nnU-Net to bring in more global semantic context.
1. ASPP
The ASPP (Atrous Spatial Pyramid Pooling) module comes from the DeepLab line of work; this article integrates the DeepLabv3 version into nnU-Net.
Paper: Rethinking Atrous Convolution for Semantic Image Segmentation
Reference implementation (PaddleSeg): https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.9/paddleseg/models/layers/pyramid_pool.py
The ASPP module is structurally simple: several dilated convolutions (plus an optional image-level pooling branch) run in parallel on the same feature map and their outputs are concatenated:

The reference code is as follows:
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddleseg.models import layers


class ASPPModule(nn.Layer):
    """
    Atrous Spatial Pyramid Pooling.

    Args:
        aspp_ratios (tuple): The dilation rates used in the ASPP module.
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        align_corners (bool): An argument of F.interpolate. It should be set to False
            when the output size of the feature map is even, e.g. 1024x512; otherwise
            it is True, e.g. 769x769.
        use_sep_conv (bool, optional): Whether to use separable conv in the ASPP module. Default: False.
        image_pooling (bool, optional): Whether to augment with image-level features. Default: False.
    """

    def __init__(self,
                 aspp_ratios,
                 in_channels,
                 out_channels,
                 align_corners,
                 use_sep_conv=False,
                 image_pooling=False,
                 data_format='NCHW'):
        super().__init__()
        self.align_corners = align_corners
        self.data_format = data_format
        self.aspp_blocks = nn.LayerList()

        # one parallel branch per dilation rate: a 1x1 conv for ratio 1, a dilated 3x3 conv otherwise
        for ratio in aspp_ratios:
            if use_sep_conv and ratio > 1:
                conv_func = layers.SeparableConvBNReLU
            else:
                conv_func = layers.ConvBNReLU
            block = conv_func(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1 if ratio == 1 else 3,
                dilation=ratio,
                padding=0 if ratio == 1 else ratio,
                data_format=data_format)
            self.aspp_blocks.append(block)

        out_size = len(self.aspp_blocks)

        if image_pooling:
            # image-level pooling branch: global average pool + 1x1 conv
            self.global_avg_pool = nn.Sequential(
                nn.AdaptiveAvgPool2D(
                    output_size=(1, 1), data_format=data_format),
                layers.ConvBNReLU(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    bias_attr=False,
                    data_format=data_format))
            out_size += 1
        self.image_pooling = image_pooling

        # 1x1 conv that fuses the concatenated branch outputs
        self.conv_bn_relu = layers.ConvBNReLU(
            in_channels=out_channels * out_size,
            out_channels=out_channels,
            kernel_size=1,
            data_format=data_format)
        self.dropout = nn.Dropout(p=0.1)  # drop rate

    def forward(self, x):
        outputs = []
        if self.data_format == 'NCHW':
            interpolate_shape = paddle.shape(x)[2:]
            axis = 1
        else:
            interpolate_shape = paddle.shape(x)[1:3]
            axis = -1
        for block in self.aspp_blocks:
            y = block(x)
            outputs.append(y)

        if self.image_pooling:
            # upsample the pooled features back to the input resolution
            img_avg = self.global_avg_pool(x)
            img_avg = F.interpolate(
                img_avg,
                interpolate_shape,
                mode='bilinear',
                align_corners=self.align_corners,
                data_format=self.data_format)
            outputs.append(img_avg)

        x = paddle.concat(outputs, axis=axis)
        x = self.conv_bn_relu(x)
        x = self.dropout(x)
        return x
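Before moving on, the module can be sanity-checked on a random tensor. This is a minimal sketch assuming paddleseg is installed; the rates (1, 6, 12, 18) are the DeepLabv3 defaults at output stride 16:

# exercise the ASPP module on a random feature map
aspp = ASPPModule(aspp_ratios=(1, 6, 12, 18),
                  in_channels=2048,
                  out_channels=256,
                  align_corners=False,
                  image_pooling=True)
x = paddle.randn([2, 2048, 32, 32])
print(aspp(x).shape)  # [2, 256, 32, 32] -- spatial size kept, channels fused to 256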
2. Adding ASPP to nnU-Net
As covered in the previous tutorial, a custom network for nnU-Net is implemented by adding it to dynamic-network-architectures and then pointing the dataset's plans file at the new class.
2.1 Modifying the network architecture
Create a new file asppunet.py under the architectures directory of dynamic-network-architectures.

The full file is listed below. Note: nnU-Net downsamples until the smallest feature map is roughly 8x8, so the dilation rates must be kept much smaller than the usual DeepLab values (a 3x3 conv with dilation r spans 2r+1 pixels, so rates like 12 or 18 would overshoot the bottleneck). You can tune them via the aspp_ratios argument passed to ASPPModule in ASPPPlainConvUNet.__init__:
from typing import Union, Type, List, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim
from dynamic_network_architectures.building_blocks.plain_conv_encoder import PlainConvEncoder
from dynamic_network_architectures.building_blocks.residual import BasicBlockD, BottleneckD
from dynamic_network_architectures.building_blocks.residual_encoders import ResidualEncoder
from dynamic_network_architectures.building_blocks.unet_decoder import UNetDecoder
from dynamic_network_architectures.building_blocks.unet_residual_decoder import UNetResDecoder
from dynamic_network_architectures.initialization.weight_init import InitWeights_He
from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0
from torch import nn
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.dropout import _DropoutNd
from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op
class ASPPPlainConvUNet(nn.Module):
    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
                 num_classes: int,
                 n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 deep_supervision: bool = False,
                 nonlin_first: bool = False
                 ):
        """
        nonlin_first: if True you get conv -> nonlin -> norm. Else it's conv -> norm -> nonlin
        """
        super().__init__()
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * n_stages
        if isinstance(n_conv_per_stage_decoder, int):
            n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1)
        assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have " \
                                                  f"resolution stages. here: {n_stages}. " \
                                                  f"n_conv_per_stage: {n_conv_per_stage}"
        assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \
                                                                f"as we have resolution stages. here: {n_stages} " \
                                                                f"stages, so it should have {n_stages - 1} entries. " \
                                                                f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}"
        self.encoder = PlainConvEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes, strides,
                                        n_conv_per_stage, conv_bias, norm_op, norm_op_kwargs, dropout_op,
                                        dropout_op_kwargs, nonlin, nonlin_kwargs, return_skips=True,
                                        nonlin_first=nonlin_first)
        self.decoder = UNetDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision,
                                   nonlin_first=nonlin_first)
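        # ASPP sits on the deepest encoder output (the bottleneck). Keep the
        # dilation rates small: nnU-Net shrinks the bottleneck to roughly 8x8.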
        self.aspp = ASPPModule(
            conv_op=conv_op,
            input_channels=features_per_stage[-1],
            output_channels=features_per_stage[-1],
            aspp_ratios=(1, 2, 3, 6),
        )
        print('............using asppunet......................')
    def forward(self, x):
        skips = self.encoder(x)
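        # refine only the deepest skip with ASPP; the decoder is unchanged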
        skips[-1] = self.aspp(skips[-1])
        return self.decoder(skips)
    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), "just give the image size without color/feature channels or " \
                                                                                "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                                                "Give input_size=(x, y(, z))!"
        return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size)
    @staticmethod
    def initialize(module):
        InitWeights_He(1e-2)(module)
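
# NOTE: ConvDropoutNormReLU below is not actually used by ASPPModule; it appears to
# have been copied over from dynamic_network_architectures' building blocks and is
# kept here as in the original listing.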
class ConvDropoutNormReLU(nn.Module):
    def __init__(self,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False
                 ):
        super(ConvDropoutNormReLU, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        stride = maybe_convert_scalar_to_list(conv_op, stride)
        self.stride = stride
        kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
        if norm_op_kwargs is None:
            norm_op_kwargs = {}
        if nonlin_kwargs is None:
            nonlin_kwargs = {}

        ops = []
        self.conv = conv_op(
            input_channels,
            output_channels,
            kernel_size,
            stride,
            padding=[(i - 1) // 2 for i in kernel_size],
            dilation=1,
            bias=conv_bias,
        )
        ops.append(self.conv)

        if dropout_op is not None:
            self.dropout = dropout_op(**dropout_op_kwargs)
            ops.append(self.dropout)
        if norm_op is not None:
            self.norm = norm_op(output_channels, **norm_op_kwargs)
            ops.append(self.norm)
        if nonlin is not None:
            self.nonlin = nonlin(**nonlin_kwargs)
            ops.append(self.nonlin)
        if nonlin_first and (norm_op is not None and nonlin is not None):
            ops[-1], ops[-2] = ops[-2], ops[-1]

        self.all_modules = nn.Sequential(*ops)

    def forward(self, x):
        return self.all_modules(x)

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \
                                                    "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                    "Give input_size=(x, y(, z))!"
        output_size = [i // j for i, j in zip(input_size, self.stride)]  # we always do same padding
        return np.prod([self.output_channels, *output_size], dtype=np.int64)
class ASPPModule(nn.Module):
    def __init__(self, conv_op, input_channels, output_channels=1024, aspp_ratios=(1, 2, 3, 4, 5)):
        super().__init__()
        self.is_3d = True
        if conv_op == torch.nn.modules.conv.Conv2d:
            self.is_3d = False
        self.aspp_blocks = nn.ModuleList()
        # one parallel branch per dilation rate: a 1x1 conv for ratio 1, a dilated 3x3 conv otherwise
        for ratio in aspp_ratios:
            if self.is_3d:
                block = nn.Sequential(
                    nn.Conv3d(input_channels, out_channels=output_channels, kernel_size=1 if ratio == 1 else 3,
                              dilation=ratio, padding=0 if ratio == 1 else ratio),
                    nn.BatchNorm3d(output_channels),
                    nn.ReLU(),
                )
            else:
                block = nn.Sequential(
                    nn.Conv2d(input_channels, out_channels=output_channels, kernel_size=1 if ratio == 1 else 3,
                              dilation=ratio, padding=0 if ratio == 1 else ratio),
                    nn.BatchNorm2d(output_channels),
                    nn.ReLU(),
                )
            self.aspp_blocks.append(block)
        if self.is_3d:
            # 1x1 conv that fuses the concatenated branch outputs
            # (+1 accounts for the image-level pooling branch)
            self.conv_bn_relu = nn.Sequential(
                nn.Conv3d(output_channels * (len(aspp_ratios) + 1), out_channels=output_channels, kernel_size=1),
                nn.BatchNorm3d(output_channels),
                nn.ReLU(),
            )
            # image-level pooling branch: global average pool + 1x1 conv
            self.avg_pool = nn.Sequential(
                nn.AdaptiveAvgPool3d(output_size=(1, 1, 1)),
                nn.Conv3d(input_channels, out_channels=output_channels, kernel_size=1),
                nn.BatchNorm3d(output_channels),
                nn.ReLU(),
            )
        else:
            self.conv_bn_relu = nn.Sequential(
                nn.Conv2d(output_channels * (len(aspp_ratios) + 1), out_channels=output_channels, kernel_size=1),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(),
            )
            self.avg_pool = nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Conv2d(input_channels, out_channels=output_channels, kernel_size=1),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(),
            )
    def forward(self, x):
        outputs = []
        # image-level pooling branch, upsampled back to the input resolution
        if self.is_3d:
            b, c, h, w, d = x.shape
            p = self.avg_pool(x)
            pool = F.interpolate(p, size=(h, w, d), mode='trilinear', align_corners=False)
        else:
            b, c, h, w = x.shape
            p = self.avg_pool(x)
            pool = F.interpolate(p, size=(h, w), mode='bilinear', align_corners=False)
        outputs.append(pool)
        for block in self.aspp_blocks:
            outputs.append(block(x))
        # concatenate all branches along the channel axis and fuse with a 1x1 conv
        x = torch.cat(outputs, dim=1)
        x = self.conv_bn_relu(x)
        return x
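Before wiring the class into nnU-Net, a quick shape check helps. This is a minimal sketch: the channel count (320) and the 8x8x8 bottleneck are illustrative values, not taken from any particular nnU-Net plan:

if __name__ == '__main__':
    # instantiate the 3D variant and push a random bottleneck-sized tensor through it
    aspp = ASPPModule(conv_op=nn.Conv3d, input_channels=320, output_channels=320,
                      aspp_ratios=(1, 2, 3, 6))
    x = torch.randn(2, 320, 8, 8, 8)
    print(aspp(x).shape)  # torch.Size([2, 320, 8, 8, 8])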
2.2 Modifying the plans file
With the model code in place, we again use the Task04_Hippocampus dataset from the previous tutorial for validation (if you skipped that tutorial, run the data preparation yourself). Edit the configuration file nnUNet\nnUNet_preprocessed\Dataset004_Hippocampus\nnUNetPlans.json and change network_class_name to dynamic_network_architectures.architectures.asppunet.ASPPPlainConvUNet.
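The relevant excerpt should look roughly like the following; note this is an assumption about the file layout, since the exact nesting depends on your nnU-Net version (in recent releases the key sits inside the architecture block of each configuration), and the key must be changed for every configuration you plan to train, e.g. both 2d and 3d_fullres:

"configurations": {
    "2d": {
        ...
        "architecture": {
            "network_class_name": "dynamic_network_architectures.architectures.asppunet.ASPPPlainConvUNet",
            ...
        }
    }
}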

3. Training the model
With the model and the dataset plans file modified, we can start training. The dataset is still Task04_Hippocampus, and the code above supports both 2d and 3d models, so the following training commands can be used:
nnUNetv2_train 4 2d 0
nnUNetv2_train 4 2d 1
nnUNetv2_train 4 2d 2
nnUNetv2_train 4 2d 3
nnUNetv2_train 4 2d 4
nnUNetv2_train 4 3d_fullres 0
nnUNetv2_train 4 3d_fullres 1
nnUNetv2_train 4 3d_fullres 2
nnUNetv2_train 4 3d_fullres 3
nnUNetv2_train 4 3d_fullres 4

As you can see, the 2d model trains (the '............using asppunet......................' line printed by ASPPPlainConvUNet.__init__ confirms the custom class is in use):

The 3d_fullres configuration trains as well:

Because nnU-Net training takes a long time and compute resources were limited, full training was not carried out; this article only verifies that the modified code runs end to end.