Hello,
Here is the problem I am running into: basically, I tried to use torch.autograd.grad to compute gradients, but it returned None.
import torch
import torch.nn as nn

class BatchNorm2dMul(nn.Module):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True):
        super(BatchNorm2dMul, self).__init__()
        self.bn = nn.BatchNorm2d(num_features, eps=eps, momentum=momentum, affine=False,
                                 track_running_stats=track_running_stats)
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        self.affine = affine

    def forward(self, x):
        bn_out = self.bn(x)
        if self.affine:
            out = self.gamma[None, :, None, None] * bn_out + self.beta[None, :, None, None]
        else:
            out = bn_out  # without affine, just pass the normalized output through
        return out, bn_out
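# (Aside, not part of the model: a minimal sanity check of this layer on its
#  own, with toy shapes. Since out is computed as gamma * bn_out + beta, I
#  would expect torch.autograd.grad w.r.t. bn_out to be non-None here.)
_layer = BatchNorm2dMul(num_features=3)
_out, _bn_out = _layer(torch.randn(2, 3, 8, 8))
_g = torch.autograd.grad(_out.sum(), _bn_out)[0]
print(_g.shape)  # expected: torch.Size([2, 3, 8, 8]), not None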
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_p):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = BatchNorm2dMul(num_features=out_channels)
        self.lReLu = nn.LeakyReLU()
        self.dp = nn.Dropout(dropout_p)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = BatchNorm2dMul(num_features=out_channels)

    def forward(self, x):
        bn_outputs = []
        out, bn_out = self.bn1(self.conv1(x))
        bn_outputs.append(bn_out)
        out = self.lReLu(out)
        out = self.dp(out)
        out, bn_out = self.bn2(self.conv2(out))
        bn_outputs.append(bn_out)
        out = self.lReLu(out)
        return out, bn_outputs
class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_p):
        super(DownBlock, self).__init__()  # required before registering submodules
        self.maxpool = nn.MaxPool2d(2)
        self.convBlock = ConvBlock(in_channels, out_channels, dropout_p)

    def forward(self, x):
        x = self.maxpool(x)
        bn_outputs = []
        out, bn_output = self.convBlock(x)
        bn_outputs.extend(bn_output)
        return out, bn_outputs
class UpBlock(nn.Module):
    def __init__(self, in_channels1, in_channels2, out_channels, dropout_p,
                 bilinear=True):
        super(UpBlock, self).__init__()
        self.bilinear = bilinear
        if bilinear:
            self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1)
            self.up = nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(
                in_channels1, in_channels2, kernel_size=2, stride=2)
        self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p)

    def forward(self, x1, x2):
        if self.bilinear:
            x1 = self.conv1x1(x1)
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        bn_outputs = []
        out, bn_output = self.conv(x)
        bn_outputs.extend(bn_output)
        return out, bn_outputs
class Encoder(nn.Module):
    def __init__(self, params):
        super(Encoder, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        self.bilinear = self.params['bilinear']
        self.dropout = self.params['dropout']
        assert (len(self.ft_chns) == 5)
        self.in_conv = ConvBlock(
            self.in_chns, self.ft_chns[0], self.dropout[0])
        self.down1 = DownBlock(
            self.ft_chns[0], self.ft_chns[1], self.dropout[1])
        self.down2 = DownBlock(
            self.ft_chns[1], self.ft_chns[2], self.dropout[2])
        self.down3 = DownBlock(
            self.ft_chns[2], self.ft_chns[3], self.dropout[3])
        self.down4 = DownBlock(
            self.ft_chns[3], self.ft_chns[4], self.dropout[4])

    def forward(self, x):
        all_bn_outputs = []
        x0, bn_outputs0 = self.in_conv(x)
        all_bn_outputs.extend(bn_outputs0)
        x1, bn_outputs1 = self.down1(x0)
        all_bn_outputs.extend(bn_outputs1)
        x2, bn_outputs2 = self.down2(x1)
        all_bn_outputs.extend(bn_outputs2)
        x3, bn_outputs3 = self.down3(x2)
        all_bn_outputs.extend(bn_outputs3)
        x4, bn_outputs4 = self.down4(x3)
        all_bn_outputs.extend(bn_outputs4)
        return [x0, x1, x2, x3, x4], all_bn_outputs
class Decoder(nn.Module):
    def __init__(self, params):
        super(Decoder, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        self.bilinear = self.params['bilinear']
        assert (len(self.ft_chns) == 5)
        self.up1 = UpBlock(
            self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0)
        self.up2 = UpBlock(
            self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0)
        self.up3 = UpBlock(
            self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0)
        self.up4 = UpBlock(
            self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0)
        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,
                                  kernel_size=3, padding=1)

    def forward(self, feature, all_bn_outputs):
        x0 = feature[0]
        x1 = feature[1]
        x2 = feature[2]
        x3 = feature[3]
        x4 = feature[4]
        feature_map = [x4]
        x, output = self.up1(x4, x3)
        all_bn_outputs.extend(output)
        feature_map.append(x)
        x, output = self.up2(x, x2)
        all_bn_outputs.extend(output)
        feature_map.append(x)
        x, output = self.up3(x, x1)
        all_bn_outputs.extend(output)
        feature_map.append(x)
        x, output = self.up4(x, x0)
        all_bn_outputs.extend(output)
        feature_map.append(x)
        output = self.out_conv(x)
        return output, feature_map, all_bn_outputs
class UNet(nn.Module):
    def __init__(self, in_chns, class_num, train_encoder=True, train_decoder=True, unfreeze_seg=True):
        super(UNet, self).__init__()
        params = {'in_chns': in_chns,
                  'feature_chns': [16, 32, 64, 128, 256],
                  'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
                  'class_num': class_num,
                  'bilinear': False,
                  'acti_func': 'relu'}
        self.encoder = Encoder(params)
        self.decoder = Decoder(params)
        self.train_encoder = train_encoder
        self.train_decoder = train_decoder
        if not train_encoder:
            for p in self.encoder.parameters():
                p.requires_grad = False
        if not train_decoder:
            for p in self.decoder.parameters():
                p.requires_grad = False
        if not unfreeze_seg:
            # freeze everything except the final segmentation head (out_conv)
            out_conv_param_ids = {id(p) for p in self.decoder.out_conv.parameters()}
            for p in self.encoder.parameters():
                p.requires_grad = False
            for p in self.decoder.parameters():
                if id(p) not in out_conv_param_ids:
                    p.requires_grad = False

    def forward(self, x):
        feature, all_bn_outputs = self.encoder(x)
        output, feature_map, all_bn_outputs = self.decoder(feature, all_bn_outputs)
        return output, feature[-1], feature_map, all_bn_outputs
Basically, I am adding a BatchNorm2dMul class as a substitute for the normal BatchNorm2d layer. When I use this class to build a U-Net model and try to compute the gradient with respect to all_bn_outputs, as shown below:
from torch.nn import CrossEntropyLoss

model = UNet(in_chns=1, class_num=4,
             train_encoder=True, train_decoder=True, unfreeze_seg=True).cuda()
pred_l, _, _, all_bn_outputs = model(train_l_data)
loss_ce = CrossEntropyLoss()(pred_l, train_l_label.long())
loss_grads = torch.autograd.grad(outputs=loss_ce, inputs=all_bn_outputs,
                                 create_graph=True, allow_unused=True)
for grd in loss_grads:
    print(grd)
the result is 18 Nones. As far as I can tell, every tensor in all_bn_outputs is part of the computation graph of loss_ce, so I am confused here. Please take some time to look, and thank you for your help!
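For completeness, this is the kind of quick check I would add right after the forward pass (same model and variables as in the snippet above, nothing new defined) to confirm the tensors at least look like they are in the graph:

pred_l, _, _, all_bn_outputs = model(train_l_data)
for i, t in enumerate(all_bn_outputs):
    # I would expect requires_grad=True and a non-None grad_fn for every entry
    print(i, t.requires_grad, t.grad_fn)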