本文介绍在nnunet中引入PolarizedSelfAttention注意力机制模块,PolarizedSelfAttention包含了Channel Attention和Spatial Attention机制,该注意力机制有并行和串行两种方式,本文介绍在nnunet中引入并行PolarizedSelfAttention模块。
一、PolarizedSelfAttention
PolarizedSelfAttention论文地址:Polarized Self-Attention: Towards High-quality Pixel-wise Regression
PolarizedSelfAttention的结构很简单,如下图:
Channel Attention的输入经过2个1x1的卷积层得到大小为C/2xHxW和1xHxW的2个特征图,这2个特征图经过维度转换后分别变为C/2xHW和HWx1x1,HWx1x1的特征图经过softmax后得到HWx1x1的权重,C/2xHW的特征图和HWx1x1的权重经过矩阵乘法得到大小为C/2x1x1的特征图,C/2x1x1的特征图依次经过卷积层Conv(C/2→C)、归一化层LayerNorm和Sigmoid层,变为Cx1x1的权重,将Cx1x1的权重与Channel Attention的输入相乘即可得到Channel Attention的输出。Spatial Attention的输入经过2个1x1的卷积层得到2个大小均为C/2xHxW的特征图,第一个C/2xHxW特征图经过全局平均池化后变为C/2x1x1的特征图;C/2x1x1的特征图和第二个C/2xHxW特征图经过维度转换后分别变为1xC/2和C/2xHW,1xC/2的特征图经过softmax后得到1xC/2的权重,1xC/2的权重和C/2xHW的特征图经过矩阵乘法得到大小为1xHW的特征图,1xHW的特征图经过维度变换Reshape和Sigmoid层,变为1xHxW的权重,将1xHxW的权重与Spatial Attention的输入相乘即可得到Spatial Attention的输出。在并行方式下,将Channel Attention和Spatial Attention的输出相加即可得到PSA的输出。
import torch
from torch import nn


def kaiming_init(module, a=0, mode='fan_out', nonlinearity='relu', bias=0):
    """Kaiming-normal initialisation of ``module.weight``; ``module.bias`` (if any) is set to ``bias``.

    The snippet below calls ``kaiming_init`` but never imported it (it is the
    helper from ``mmcv.cnn``), so building PSA_p / PSA_s raised ``NameError``.
    This torch-only version follows that helper's 'normal' distribution path.
    """
    nn.init.kaiming_normal_(module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)


class PSA_p(nn.Module):
    """Polarized self-attention, parallel layout (2D feature maps).

    ``forward(x)`` returns ``x * channel_mask + x * spatial_mask``, where
    channel_mask is (N, planes, 1, 1) and spatial_mask is (N, 1, H', W').

    NOTE(review): both masks multiply the *input* ``x``, so in general this only
    runs with ``planes == inplanes`` and ``stride == 1``.
    NOTE(review): compared with the textual description of PSA, the channel
    branch here has no LayerNorm after ``conv_up``. The spatial branch applies
    softmax over H*W *after* the matmul, instead of over the pooled C/2 query.
    This is kept unchanged on purpose.
    """

    def __init__(self, inplanes, planes, kernel_size=1, stride=1):
        super(PSA_p, self).__init__()

        self.inplanes = inplanes
        self.inter_planes = planes // 2
        self.planes = planes
        self.kernel_size = kernel_size  # stored only; every conv below is 1x1
        self.stride = stride
        self.padding = (kernel_size - 1) // 2  # stored only; the convs use padding=0

        # channel-only branch ("right"): 1-channel query, planes/2-channel value
        self.conv_q_right = nn.Conv2d(self.inplanes, 1, kernel_size=1, stride=stride, padding=0, bias=False)
        self.conv_v_right = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0,
                                      bias=False)
        self.conv_up = nn.Conv2d(self.inter_planes, self.planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.softmax_right = nn.Softmax(dim=2)
        self.sigmoid = nn.Sigmoid()

        # spatial-only branch ("left")
        self.conv_q_left = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0,
                                     bias=False)  # g
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_v_left = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0,
                                     bias=False)  # theta
        self.softmax_left = nn.Softmax(dim=2)

        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-init (fan_in) the query/value convs and flag them as initialised."""
        kaiming_init(self.conv_q_right, mode='fan_in')
        kaiming_init(self.conv_v_right, mode='fan_in')
        kaiming_init(self.conv_q_left, mode='fan_in')
        kaiming_init(self.conv_v_left, mode='fan_in')

        self.conv_q_right.inited = True
        self.conv_v_right.inited = True
        self.conv_q_left.inited = True
        self.conv_v_left.inited = True

    def spatial_pool(self, x):
        """Channel-only attention: returns ``x`` scaled by a per-channel sigmoid mask."""
        input_x = self.conv_v_right(x)
        batch, channel, height, width = input_x.size()
        # [N, IC, H*W]
        input_x = input_x.view(batch, channel, height * width)
        # [N, 1, H, W]
        context_mask = self.conv_q_right(x)
        # [N, 1, H*W]
        context_mask = context_mask.view(batch, 1, height * width)
        # [N, 1, H*W], softmax over the spatial positions
        context_mask = self.softmax_right(context_mask)
        # [N, IC, 1]
        context = torch.matmul(input_x, context_mask.transpose(1, 2))
        # [N, IC, 1, 1]
        context = context.unsqueeze(-1)
        # [N, OC, 1, 1]
        context = self.conv_up(context)
        # [N, OC, 1, 1]
        mask_ch = self.sigmoid(context)
        return x * mask_ch

    def channel_pool(self, x):
        """Spatial-only attention: returns ``x`` scaled by a per-pixel sigmoid mask."""
        # [N, IC, H, W]
        g_x = self.conv_q_left(x)
        batch, channel, height, width = g_x.size()
        # [N, IC, 1, 1]
        avg_x = self.avg_pool(g_x)
        batch, channel, avg_x_h, avg_x_w = avg_x.size()
        # [N, 1, IC]
        avg_x = avg_x.view(batch, channel, avg_x_h * avg_x_w).permute(0, 2, 1)
        # [N, IC, H*W]
        theta_x = self.conv_v_left(x).view(batch, self.inter_planes, height * width)
        # [N, 1, H*W]
        context = torch.matmul(avg_x, theta_x)
        # [N, 1, H*W], softmax over the spatial positions
        context = self.softmax_left(context)
        # [N, 1, H, W]
        context = context.view(batch, 1, height, width)
        # [N, 1, H, W]
        mask_sp = self.sigmoid(context)
        return x * mask_sp

    def forward(self, x):
        # parallel layout: both branches see the same input and are summed
        # [N, C, H, W]
        context_channel = self.spatial_pool(x)
        # [N, C, H, W]
        context_spatial = self.channel_pool(x)
        # [N, C, H, W]
        out = context_spatial + context_channel
        return out


class PSA_s(nn.Module):
    """Polarized self-attention, sequential layout (2D feature maps).

    ``forward(x)`` applies the channel-only branch and then the spatial-only
    branch to its result.

    NOTE(review): as in PSA_p, the masks multiply the branch input, so in general
    ``planes == inplanes`` and ``stride == 1`` are required. Here the softmax of the
    spatial branch is applied to the value map over H*W.
    """

    def __init__(self, inplanes, planes, kernel_size=1, stride=1):
        super(PSA_s, self).__init__()

        self.inplanes = inplanes
        self.inter_planes = planes // 2
        self.planes = planes
        self.kernel_size = kernel_size  # stored only; every conv below is 1x1
        self.stride = stride
        self.padding = (kernel_size - 1) // 2  # stored only; the convs use padding=0
        ratio = 4  # bottleneck reduction inside conv_up

        # channel-only branch ("right")
        self.conv_q_right = nn.Conv2d(self.inplanes, 1, kernel_size=1, stride=stride, padding=0, bias=False)
        self.conv_v_right = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0,
                                      bias=False)
        # bottleneck conv -> LayerNorm -> ReLU -> conv instead of a single 1x1 conv
        self.conv_up = nn.Sequential(
            nn.Conv2d(self.inter_planes, self.inter_planes // ratio, kernel_size=1),
            nn.LayerNorm([self.inter_planes // ratio, 1, 1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.inter_planes // ratio, self.planes, kernel_size=1)
        )
        self.softmax_right = nn.Softmax(dim=2)
        self.sigmoid = nn.Sigmoid()

        # spatial-only branch ("left")
        self.conv_q_left = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0,
                                     bias=False)  # g
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_v_left = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0,
                                     bias=False)  # theta
        self.softmax_left = nn.Softmax(dim=2)

        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-init (fan_in) the query/value convs and flag them as initialised."""
        kaiming_init(self.conv_q_right, mode='fan_in')
        kaiming_init(self.conv_v_right, mode='fan_in')
        kaiming_init(self.conv_q_left, mode='fan_in')
        kaiming_init(self.conv_v_left, mode='fan_in')

        self.conv_q_right.inited = True
        self.conv_v_right.inited = True
        self.conv_q_left.inited = True
        self.conv_v_left.inited = True

    def spatial_pool(self, x):
        """Channel-only attention: returns ``x`` scaled by a per-channel sigmoid mask."""
        input_x = self.conv_v_right(x)
        batch, channel, height, width = input_x.size()
        # [N, IC, H*W]
        input_x = input_x.view(batch, channel, height * width)
        # [N, 1, H, W]
        context_mask = self.conv_q_right(x)
        # [N, 1, H*W]
        context_mask = context_mask.view(batch, 1, height * width)
        # [N, 1, H*W], softmax over the spatial positions
        context_mask = self.softmax_right(context_mask)
        # [N, IC, 1]
        context = torch.matmul(input_x, context_mask.transpose(1, 2))
        # [N, IC, 1, 1]
        context = context.unsqueeze(-1)
        # [N, OC, 1, 1]
        context = self.conv_up(context)
        # [N, OC, 1, 1]
        mask_ch = self.sigmoid(context)
        return x * mask_ch

    def channel_pool(self, x):
        """Spatial-only attention: returns ``x`` scaled by a per-pixel sigmoid mask."""
        # [N, IC, H, W]
        g_x = self.conv_q_left(x)
        batch, channel, height, width = g_x.size()
        # [N, IC, 1, 1]
        avg_x = self.avg_pool(g_x)
        batch, channel, avg_x_h, avg_x_w = avg_x.size()
        # [N, 1, IC]
        avg_x = avg_x.view(batch, channel, avg_x_h * avg_x_w).permute(0, 2, 1)
        # [N, IC, H*W]
        theta_x = self.conv_v_left(x).view(batch, self.inter_planes, height * width)
        # [N, IC, H*W], softmax over the spatial positions
        theta_x = self.softmax_left(theta_x)
        # [N, 1, H*W]
        context = torch.matmul(avg_x, theta_x)
        # [N, 1, H, W]
        context = context.view(batch, 1, height, width)
        # [N, 1, H, W]
        mask_sp = self.sigmoid(context)
        return x * mask_sp

    def forward(self, x):
        # sequential layout: channel-only branch first, then spatial-only on its output
        # [N, C, H, W]
        out = self.spatial_pool(x)
        # [N, C, H, W]
        out = self.channel_pool(out)
        return out

# End of the PSA_p / PSA_s reference snippet. The tutorial continues with
# section "二、nnunet加入PSA" (adding PSA to nnU-Net).
之前的教程已经提到过,要在nnunet中训练自己的网络,需要在dynamic-network-architectures中修改网络结构,并修改数据集的plans配置文件。
1、网络结构修改
在dynamic-network-architectures的architectures目录下新建psaunet.py,如下图:
代码内容如下:
from typing import Union, Type, List, Tuple

import numpy as np
import torch
from torch import nn
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.dropout import _DropoutNd

from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim
from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op
from dynamic_network_architectures.building_blocks.helper import get_matching_convtransp
from dynamic_network_architectures.initialization.weight_init import InitWeights_He


class PSAPlainConvUNet(nn.Module):
    """Plain conv U-Net whose conv stages end with a polarized-self-attention block (see PSA)."""

    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
                 num_classes: int,
                 n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 deep_supervision: bool = False,
                 nonlin_first: bool = False
                 ):
        """
        nonlin_first: if True you get conv -> nonlin -> norm. Else it's conv -> norm -> nonlin
        """
        super().__init__()
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * n_stages
        if isinstance(n_conv_per_stage_decoder, int):
            n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1)
        assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have " \
                                                  f"resolution stages. here: {n_stages}. " \
                                                  f"n_conv_per_stage: {n_conv_per_stage}"
        assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \
                                                                f"as we have resolution stages. here: {n_stages} " \
                                                                f"stages, so it should have {n_stages - 1} entries. " \
                                                                f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}"
        self.encoder = PSAPlainConvEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes,
                                           strides, n_conv_per_stage, conv_bias, norm_op, norm_op_kwargs,
                                           dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs,
                                           return_skips=True, nonlin_first=nonlin_first)
        self.decoder = PSAUNetDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision,
                                      nonlin_first=nonlin_first)
        print('using psa unet...')

    def forward(self, x):
        skips = self.encoder(x)
        return self.decoder(skips)

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), \
            "just give the image size without color/feature channels or " \
            "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
            "Give input_size=(x, y(, z))!"
        return self.encoder.compute_conv_feature_map_size(input_size) + \
            self.decoder.compute_conv_feature_map_size(input_size)

    @staticmethod
    def initialize(module):
        InitWeights_He(1e-2)(module)


class PSAPlainConvEncoder(nn.Module):
    """Encoder: one PSAStackedConvBlocks per resolution stage (optionally preceded by pooling)."""

    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 return_skips: bool = False,
                 nonlin_first: bool = False,
                 pool: str = 'conv'
                 ):
        super().__init__()
        if isinstance(kernel_sizes, int):
            kernel_sizes = [kernel_sizes] * n_stages
        if isinstance(features_per_stage, int):
            features_per_stage = [features_per_stage] * n_stages
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * n_stages
        if isinstance(strides, int):
            strides = [strides] * n_stages
        assert len(kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)"
        assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \
                                         "Important: first entry is recommended to be 1, else we run strided conv drectly on the input"

        stages = []
        for s in range(n_stages):
            stage_modules = []
            if pool == 'max' or pool == 'avg':
                # downsample with a pooling op; the convs of this stage then run with stride 1
                if (isinstance(strides[s], int) and strides[s] != 1) or \
                        isinstance(strides[s], (tuple, list)) and any([i != 1 for i in strides[s]]):
                    stage_modules.append(get_matching_pool_op(conv_op, pool_type=pool)(kernel_size=strides[s],
                                                                                      stride=strides[s]))
                conv_stride = 1
            elif pool == 'conv':
                # downsample with the first (strided) conv of the stage
                conv_stride = strides[s]
            else:
                raise RuntimeError()
            stage_modules.append(PSAStackedConvBlocks(
                n_conv_per_stage[s], conv_op, input_channels, features_per_stage[s], kernel_sizes[s], conv_stride,
                conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs,
                nonlin_first
            ))
            stages.append(nn.Sequential(*stage_modules))
            input_channels = features_per_stage[s]

        self.stages = nn.Sequential(*stages)
        self.output_channels = features_per_stage
        self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides]
        self.return_skips = return_skips

        # we store some things that a potential decoder needs
        self.conv_op = conv_op
        self.norm_op = norm_op
        self.norm_op_kwargs = norm_op_kwargs
        self.nonlin = nonlin
        self.nonlin_kwargs = nonlin_kwargs
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        self.conv_bias = conv_bias
        self.kernel_sizes = kernel_sizes

    def forward(self, x):
        ret = []
        for s in self.stages:
            x = s(x)
            ret.append(x)
        if self.return_skips:
            return ret
        else:
            return ret[-1]

    def compute_conv_feature_map_size(self, input_size):
        output = np.int64(0)
        for s in range(len(self.stages)):
            if isinstance(self.stages[s], nn.Sequential):
                for sq in self.stages[s]:
                    if hasattr(sq, 'compute_conv_feature_map_size'):
                        output += self.stages[s][-1].compute_conv_feature_map_size(input_size)
            else:
                output += self.stages[s].compute_conv_feature_map_size(input_size)
            input_size = [i // j for i, j in zip(input_size, self.strides[s])]
        return output


class PSAUNetDecoder(nn.Module):
    def __init__(self,
                 encoder: Union[PSAPlainConvEncoder],
                 num_classes: int,
                 n_conv_per_stage: Union[int, Tuple[int, ...], List[int]],
                 deep_supervision,
                 nonlin_first: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 conv_bias: bool = None
                 ):
        """
        This class needs the skips of the encoder as input in its forward.

        the encoder goes all the way to the bottleneck, so that's where the decoder picks up. stages in the decoder
        are sorted by order of computation, so the first stage has the lowest resolution and takes the bottleneck
        features and the lowest skip as inputs
        the decoder has two (three) parts in each stage:
        1) conv transpose to upsample the feature maps of the stage below it (or the bottleneck in case of the first stage)
        2) n_conv_per_stage conv blocks to let the two inputs get to know each other and merge
        3) (optional if deep_supervision=True) a segmentation output Todo: enable upsample logits?
        :param encoder:
        :param num_classes:
        :param n_conv_per_stage:
        :param deep_supervision:
        """
        super().__init__()
        self.deep_supervision = deep_supervision
        self.encoder = encoder
        self.num_classes = num_classes
        n_stages_encoder = len(encoder.output_channels)
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1)
        assert len(n_conv_per_stage) == n_stages_encoder - 1, "n_conv_per_stage must have as many entries as we have " \
                                                              "resolution stages - 1 (n_stages in encoder - 1), " \
                                                              "here: %d" % n_stages_encoder

        transpconv_op = get_matching_convtransp(conv_op=encoder.conv_op)
        # unspecified settings are inherited from the encoder
        conv_bias = encoder.conv_bias if conv_bias is None else conv_bias
        norm_op = encoder.norm_op if norm_op is None else norm_op
        norm_op_kwargs = encoder.norm_op_kwargs if norm_op_kwargs is None else norm_op_kwargs
        dropout_op = encoder.dropout_op if dropout_op is None else dropout_op
        dropout_op_kwargs = encoder.dropout_op_kwargs if dropout_op_kwargs is None else dropout_op_kwargs
        nonlin = encoder.nonlin if nonlin is None else nonlin
        nonlin_kwargs = encoder.nonlin_kwargs if nonlin_kwargs is None else nonlin_kwargs

        # we start with the bottleneck and work out way up
        stages = []
        transpconvs = []
        seg_layers = []
        for s in range(1, n_stages_encoder):
            input_features_below = encoder.output_channels[-s]
            input_features_skip = encoder.output_channels[-(s + 1)]
            stride_for_transpconv = encoder.strides[-s]
            transpconvs.append(transpconv_op(
                input_features_below, input_features_skip, stride_for_transpconv, stride_for_transpconv,
                bias=conv_bias
            ))
            # input features to conv is 2x input_features_skip (concat input_features_skip with transpconv output)
            stages.append(PSAStackedConvBlocks(
                n_conv_per_stage[s - 1], encoder.conv_op, 2 * input_features_skip, input_features_skip,
                encoder.kernel_sizes[-(s + 1)], 1, conv_bias, norm_op, norm_op_kwargs, dropout_op,
                dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
            ))

            # we always build the deep supervision outputs so that we can always load parameters. If we don't do this
            # then a model trained with deep_supervision=True could not easily be loaded at inference time where
            # deep supervision is not needed. It's just a convenience thing
            seg_layers.append(encoder.conv_op(input_features_skip, num_classes, 1, 1, 0, bias=True))

        self.stages = nn.ModuleList(stages)
        self.transpconvs = nn.ModuleList(transpconvs)
        self.seg_layers = nn.ModuleList(seg_layers)

    def forward(self, skips):
        """
        we expect to get the skips in the order they were computed, so the bottleneck should be the last entry
        :param skips:
        :return:
        """
        lres_input = skips[-1]
        seg_outputs = []
        for s in range(len(self.stages)):
            x = self.transpconvs[s](lres_input)
            x = torch.cat((x, skips[-(s + 2)]), 1)
            x = self.stages[s](x)
            if self.deep_supervision:
                seg_outputs.append(self.seg_layers[s](x))
            elif s == (len(self.stages) - 1):
                seg_outputs.append(self.seg_layers[-1](x))
            lres_input = x

        # invert seg outputs so that the largest segmentation prediction is returned first
        seg_outputs = seg_outputs[::-1]

        if not self.deep_supervision:
            r = seg_outputs[0]
        else:
            r = seg_outputs
        return r

    def compute_conv_feature_map_size(self, input_size):
        """
        IMPORTANT: input_size is the input_size of the encoder!
        :param input_size:
        :return:
        """
        # first we need to compute the skip sizes. Skip bottleneck because all output feature maps of our ops will at
        # least have the size of the skip above that (therefore -1)
        skip_sizes = []
        for s in range(len(self.encoder.strides) - 1):
            skip_sizes.append([i // j for i, j in zip(input_size, self.encoder.strides[s])])
            input_size = skip_sizes[-1]
        assert len(skip_sizes) == len(self.stages)

        # our ops are the other way around, so let's match things up
        output = np.int64(0)
        for s in range(len(self.stages)):
            # conv blocks
            output += self.stages[s].compute_conv_feature_map_size(skip_sizes[-(s + 1)])
            # trans conv
            output += np.prod([self.encoder.output_channels[-(s + 2)], *skip_sizes[-(s + 1)]], dtype=np.int64)
            # segmentation
            if self.deep_supervision or (s == (len(self.stages) - 1)):
                output += np.prod([self.num_classes, *skip_sizes[-(s + 1)]], dtype=np.int64)
        return output


class PSAStackedConvBlocks(nn.Module):
    def __init__(self,
                 num_convs: int,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: Union[int, List[int], Tuple[int, ...]],
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 initial_stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False
                 ):
        """
        Stack of ``num_convs`` convs: (num_convs - 1) ConvDropoutNormReLU blocks followed by one PSA block
        (conv -> dropout -> norm -> polarized self-attention). A final nonlinearity is applied to the output.

        :param conv_op:
        :param num_convs: total number of convs. With num_convs == 1 the stage is a single PSA block that also
                          carries initial_stride.
        :param input_channels:
        :param output_channels: can be int or a list/tuple of int. If list/tuple are provided, each entry is for
        one conv. The length of the list/tuple must then naturally be num_convs
        :param kernel_size:
        :param initial_stride: stride of the first conv of the stack
        :param conv_bias:
        :param norm_op:
        :param norm_op_kwargs:
        :param dropout_op:
        :param dropout_op_kwargs:
        :param nonlin: final activation; None means no activation
        :param nonlin_kwargs:
        """
        super().__init__()
        if not isinstance(output_channels, (tuple, list)):
            output_channels = [output_channels] * num_convs

        if num_convs == 1:
            # A single conv per stage: the PSA block itself performs the (possibly strided) conv. The old layout
            # indexed output_channels[-2] here (IndexError) and would have built two convs anyway.
            blocks = [
                PSA(
                    conv_op, input_channels, output_channels[0], kernel_size, initial_stride, conv_bias,
                    norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
                )
            ]
        else:
            blocks = [
                ConvDropoutNormReLU(
                    conv_op, input_channels, output_channels[0], kernel_size, initial_stride, conv_bias, norm_op,
                    norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
                ),
                *[
                    ConvDropoutNormReLU(
                        conv_op, output_channels[i - 1], output_channels[i], kernel_size, 1, conv_bias, norm_op,
                        norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
                    )
                    for i in range(1, num_convs - 1)
                ],
                # last conv of the stack carries the attention
                PSA(
                    conv_op, output_channels[-2], output_channels[-1], kernel_size, 1, conv_bias, norm_op,
                    norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
                ),
            ]
        self.convs = nn.Sequential(*blocks)

        # PSA has no activation of its own, so one is applied after the whole stack
        if nonlin is None:
            self.act = nn.Identity()
        else:
            self.act = nonlin(**(nonlin_kwargs if nonlin_kwargs is not None else {}))

        self.output_channels = output_channels[-1]
        self.initial_stride = maybe_convert_scalar_to_list(conv_op, initial_stride)

    def forward(self, x):
        out = self.convs(x)
        out = self.act(out)
        return out

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == len(self.initial_stride), "just give the image size without color/feature channels or " \
                                                            "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                            "Give input_size=(x, y(, z))!"
        output = self.convs[0].compute_conv_feature_map_size(input_size)
        size_after_stride = [i // j for i, j in zip(input_size, self.initial_stride)]
        for b in self.convs[1:]:
            output += b.compute_conv_feature_map_size(size_after_stride)
        return output


class ConvDropoutNormReLU(nn.Module):
    """conv -> (dropout) -> (norm) -> (nonlin); norm and nonlin are swapped when nonlin_first is set."""

    def __init__(self,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False
                 ):
        super(ConvDropoutNormReLU, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels

        stride = maybe_convert_scalar_to_list(conv_op, stride)
        self.stride = stride

        kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
        if norm_op_kwargs is None:
            norm_op_kwargs = {}
        if nonlin_kwargs is None:
            nonlin_kwargs = {}
        if dropout_op_kwargs is None:
            dropout_op_kwargs = {}

        ops = []

        self.conv = conv_op(
            input_channels,
            output_channels,
            kernel_size,
            stride,
            padding=[(i - 1) // 2 for i in kernel_size],
            dilation=1,
            bias=conv_bias,
        )
        ops.append(self.conv)

        if dropout_op is not None:
            self.dropout = dropout_op(**dropout_op_kwargs)
            ops.append(self.dropout)

        if norm_op is not None:
            self.norm = norm_op(output_channels, **norm_op_kwargs)
            ops.append(self.norm)

        if nonlin is not None:
            self.nonlin = nonlin(**nonlin_kwargs)
            ops.append(self.nonlin)

        if nonlin_first and (norm_op is not None and nonlin is not None):
            ops[-1], ops[-2] = ops[-2], ops[-1]

        self.all_modules = nn.Sequential(*ops)

    def forward(self, x):
        return self.all_modules(x)

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \
                                                    "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                    "Give input_size=(x, y(, z))!"
        output_size = [i // j for i, j in zip(input_size, self.stride)]  # we always do same padding
        return np.prod([self.output_channels, *output_size], dtype=np.int64)


class ConvDropoutNorm(nn.Module):
    """conv -> (dropout) -> (norm), without a nonlinearity (nonlin arguments are accepted but unused)."""

    def __init__(self,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False
                 ):
        super(ConvDropoutNorm, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels

        stride = maybe_convert_scalar_to_list(conv_op, stride)
        self.stride = stride

        kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
        if norm_op_kwargs is None:
            norm_op_kwargs = {}
        if nonlin_kwargs is None:
            nonlin_kwargs = {}
        if dropout_op_kwargs is None:
            dropout_op_kwargs = {}

        ops = []

        self.conv = conv_op(
            input_channels,
            output_channels,
            kernel_size,
            stride,
            padding=[(i - 1) // 2 for i in kernel_size],
            dilation=1,
            bias=conv_bias,
        )
        ops.append(self.conv)

        if dropout_op is not None:
            self.dropout = dropout_op(**dropout_op_kwargs)
            ops.append(self.dropout)

        if norm_op is not None:
            self.norm = norm_op(output_channels, **norm_op_kwargs)
            ops.append(self.norm)

        self.all_modules = nn.Sequential(*ops)

    def forward(self, x):
        return self.all_modules(x)

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \
                                                    "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                    "Give input_size=(x, y(, z))!"
        output_size = [i // j for i, j in zip(input_size, self.stride)]  # we always do same padding
        return np.prod([self.output_channels, *output_size], dtype=np.int64)


class PSA(nn.Module):
    """conv -> (dropout) -> (norm) -> parallel polarized self-attention.

    The previous version attached the RTMDet ``ChannelAttention`` instead of
    ``PolarizedSelfAttention``. It also computed ``self.ca(x) * x``, although
    ``ChannelAttention`` already returns ``x * weight``, which squared the features.
    """

    def __init__(self,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False
                 ):
        super(PSA, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels

        stride = maybe_convert_scalar_to_list(conv_op, stride)
        self.stride = stride

        kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
        if norm_op_kwargs is None:
            norm_op_kwargs = {}
        if nonlin_kwargs is None:
            nonlin_kwargs = {}
        if dropout_op_kwargs is None:
            dropout_op_kwargs = {}

        ops = []

        self.conv = conv_op(
            input_channels,
            output_channels,
            kernel_size,
            stride,
            padding=[(i - 1) // 2 for i in kernel_size],
            dilation=1,
            bias=conv_bias,
        )
        ops.append(self.conv)

        if dropout_op is not None:
            self.dropout = dropout_op(**dropout_op_kwargs)
            ops.append(self.dropout)

        if norm_op is not None:
            self.norm = norm_op(output_channels, **norm_op_kwargs)
            ops.append(self.norm)

        self.all_modules = nn.Sequential(*ops)
        self.psa = PolarizedSelfAttention(conv_op=conv_op, channels=output_channels)

    def forward(self, x):
        x = self.all_modules(x)
        # PolarizedSelfAttention already returns the re-weighted features
        # (channel branch + spatial branch); no extra multiplication by x.
        return self.psa(x)

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \
                                                    "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                    "Give input_size=(x, y(, z))!"
        output_size = [i // j for i, j in zip(input_size, self.stride)]  # we always do same padding
        return np.prod([self.output_channels, *output_size], dtype=np.int64)


class ChannelAttention(nn.Module):
    """Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""

    def __init__(self, conv_op, channels: int) -> None:
        """Initializes the class and sets the basic configurations and instance variables required."""
        super().__init__()
        if conv_op == torch.nn.modules.conv.Conv2d:
            self.pool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        elif conv_op == torch.nn.modules.conv.Conv3d:
            self.pool = nn.AdaptiveAvgPool3d(1)
            self.fc = nn.Conv3d(channels, channels, 1, 1, 0, bias=True)
        self.act = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Returns x re-weighted by sigmoid(fc(global_avg_pool(x))) (already multiplied by x)."""
        return x * self.act(self.fc(self.pool(x)))


class PolarizedSelfAttention(nn.Module):
    """Parallel polarized self-attention for 2D (Conv2d) or 3D (Conv3d) feature maps.

    ``forward(x)`` returns ``x * channel_weight + x * spatial_weight``, where
    channel_weight has shape (b, c, 1, 1[, 1]) and spatial_weight has shape
    (b, 1, *spatial). ``x.shape[1]`` must equal ``channels``.
    """

    def __init__(self, conv_op, channels: int):
        super().__init__()
        self.is_3d = True
        if conv_op == torch.nn.modules.conv.Conv2d:
            self.is_3d = False
            conv_nd, kernel = nn.Conv2d, (1, 1)
            pool = nn.AdaptiveAvgPool2d((1, 1))
        elif conv_op == torch.nn.modules.conv.Conv3d:
            conv_nd, kernel = nn.Conv3d, (1, 1, 1)
            pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        else:
            raise NotImplementedError(f'PolarizedSelfAttention only supports Conv2d and Conv3d, got {conv_op}')

        # channel-only branch
        self.ch_wv = conv_nd(channels, channels // 2, kernel_size=kernel)
        self.ch_wq = conv_nd(channels, 1, kernel_size=kernel)
        self.ch_wz = conv_nd(channels // 2, channels, kernel_size=kernel)
        # spatial-only branch
        self.sp_wv = conv_nd(channels, channels // 2, kernel_size=kernel)
        self.sp_wq = conv_nd(channels, channels // 2, kernel_size=kernel)
        self.agp = pool

        self.softmax_channel = nn.Softmax(1)
        self.softmax_spatial = nn.Softmax(-1)
        self.ln = nn.LayerNorm(channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c = x.shape[0], x.shape[1]
        spatial_shape = x.shape[2:]            # (h, w) or (h, w, d)
        singleton = (1,) * len(spatial_shape)  # (1, 1) or (1, 1, 1)

        # Channel-only self-attention
        channel_wv = self.ch_wv(x).reshape(b, c // 2, -1)                   # b, c//2, N  (N = h*w[*d])
        channel_wq = self.softmax_channel(self.ch_wq(x).reshape(b, -1, 1))  # b, N, 1 (softmax over N)
        channel_wz = torch.matmul(channel_wv, channel_wq).reshape(b, c // 2, *singleton)  # b, c//2, 1, 1[, 1]
        channel_wz = self.ch_wz(channel_wz).reshape(b, c, 1).permute(0, 2, 1)             # b, 1, c
        channel_weight = self.sigmoid(self.ln(channel_wz)).permute(0, 2, 1).reshape(b, c, *singleton)
        channel_out = channel_weight * x

        # Spatial-only self-attention
        spatial_wv = self.sp_wv(x).reshape(b, c // 2, -1)           # b, c//2, N
        spatial_wq = self.agp(self.sp_wq(x)).reshape(b, 1, c // 2)  # b, 1, c//2
        spatial_wq = self.softmax_spatial(spatial_wq)               # softmax over c//2
        spatial_wz = torch.matmul(spatial_wq, spatial_wv)           # b, 1, N
        spatial_weight = self.sigmoid(spatial_wz.reshape(b, 1, *spatial_shape))  # b, 1, h, w[, d]
        spatial_out = spatial_weight * x

        return spatial_out + channel_out

# ----------------------------------------------------------------------------
# End of psaunet.py. The tutorial continues with section "2、配置文件修改"
# (configuration file changes).
在完成了模型修改后,还是用上个教程的Task04_Hippocampus数据集来验证(如果没做上个教程的,自行完成数据处理),编辑nnUNet\nnUNet_preprocessed\Dataset004_Hippocampus\nnUNetPlans.json这个配置文件,进行以下改动,把network_class_name改成dynamic_network_architectures.architectures.psaunet.PSAPlainConvUNet,如下图:
三、模型训练
完成了模型和数据集配置文件的修改后,开始训练模型,使用的数据集还是Task04_Hippocampus,以上的代码支持2d和3d模型,可以使用以下的训练命令:
nnUNetv2_train 4 2d 0
nnUNetv2_train 4 2d 1
nnUNetv2_train 4 2d 2
nnUNetv2_train 4 2d 3
nnUNetv2_train 4 2d 4
nnUNetv2_train 4 3d_fullres 0
nnUNetv2_train 4 3d_fullres 1
nnUNetv2_train 4 3d_fullres 2
nnUNetv2_train 4 3d_fullres 3
nnUNetv2_train 4 3d_fullres 4
可以看到,2d模型训练起来了:
3d_fullres也训练一下:
因为nnunet训练非常的久,实验资源有限,没有完成全部训练,只完成了代码修改及跑通。




