Tied conv kernels, mixing nn.Module with nn.functional

I’m trying to create a model that has a convolutional and a deconvolutional pathway with tied weights; essentially, the deconv path is the mirror image of the conv path, similar to this paper: Deconvolutional Feature Stacking….

I like being able to use the nn.Module methods and attributes, but I wanted to create the weights directly and pass them to the F.conv3d and F.conv_transpose3d functions, since it seemed less straightforward to keep the weights tied using nn.Conv3d.

This was my attempt, but something is definitely wrong.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_


class Model3d(nn.Module):

    def __init__(self, input_channels=1, n_classes=2, img_size=96, n_feature_maps=(8, 12, 16, 20, 20)):

        super(Model3d, self).__init__()
        self.input_channels = input_channels
        self.n_feat_map = n_feature_maps
        self.n_classes = n_classes
        self.img_size = img_size

        self.weight1 = self.weight_variable([self.n_feat_map[0], input_channels, 3, 3, 3])
        self.bias_c_1 = self.bias_variable(self.n_feat_map[0])
        self.bias_d_5 = self.bias_variable(input_channels)

        self.weight2 = self.weight_variable([self.n_feat_map[1], self.n_feat_map[0], 3, 3, 3])
        self.bias_c_2 = self.bias_variable(self.n_feat_map[1])
        self.bias_d_4 = self.bias_variable(self.n_feat_map[0])

        self.weight3 = self.weight_variable([self.n_feat_map[2], self.n_feat_map[1], 3, 3, 3])
        self.bias_c_3 = self.bias_variable(self.n_feat_map[2])
        self.bias_d_3 = self.bias_variable(self.n_feat_map[1])

        self.weight4 = self.weight_variable([self.n_feat_map[3], self.n_feat_map[2], 3, 3, 3])
        self.bias_c_4 = self.bias_variable(self.n_feat_map[3])
        self.bias_d_2 = self.bias_variable(self.n_feat_map[2])

        self.weight5 = self.weight_variable([self.n_feat_map[4], self.n_feat_map[3], 3, 3, 3])
        self.bias_c_5 = self.bias_variable(self.n_feat_map[4])
        self.bias_d_1 = self.bias_variable(self.n_feat_map[3])

        self.down_classifier_weight = self.weight_variable([self.n_classes, self.n_feat_map[4], 1, 1, 1])

        self.down_classifier_bias = self.bias_variable(self.n_classes)
        self.up_classifier_weight = self.weight_variable(
            [self.n_classes, np.sum(self.n_feat_map) + self.input_channels, 1, 1, 1])
        self.up_classifier_bias = self.bias_variable(self.n_classes)

    def forward(self, x):

        x, ind1 = down_block3d(x, self.weight1, self.bias_c_1)
        x, ind2 = down_block3d(x, self.weight2, self.bias_c_2)
        x, ind3 = down_block3d(x, self.weight3, self.bias_c_3)
        x, ind4 = down_block3d(x, self.weight4, self.bias_c_4)
        down5, ind5 = down_block3d(x, self.weight5, self.bias_c_5)

        up1 = up_block3d(down5, self.weight5, self.bias_d_1, ind5)
        up2 = up_block3d(up1, self.weight4, self.bias_d_2, ind4)
        up3 = up_block3d(up2, self.weight3, self.bias_d_3, ind3)
        up4 = up_block3d(up3, self.weight2, self.bias_d_4, ind2)
        up5 = up_block3d(up4, self.weight1, self.bias_d_5, ind1)

        classifier = classifier3d(down5, self.down_classifier_weight, self.down_classifier_bias)
        to_stack = (down5, up1, up2, up3, up4)

        upsampled = [normalize_and_upsample(i, self.img_size) for i in to_stack] + [up5]
        segmap = classifier3d(torch.cat(upsampled, 1), self.up_classifier_weight, self.up_classifier_bias)

        return classifier, segmap

    def weight_variable(self, shape):
        # register the kernel as a Parameter, then initialize it in place
        var = nn.Parameter(torch.zeros(shape))
        xavier_uniform_(var)
        return var

    def bias_variable(self, size):
        return nn.Parameter(torch.zeros(size))

def down_block3d(input, weights, bias):
    # conv -> relu -> 2x2x2 max pool; pooling indices are kept for unpooling
    x = F.conv3d(input, weights, bias=bias, stride=1, padding=1)
    x = F.relu(x)
    out, indices = F.max_pool3d(x, 2, stride=2, return_indices=True)
    return out, indices

def up_block3d(input, weights, bias, indices):
    x = F.max_unpool3d(input, indices, 2, stride=2)
    # negative padding crops one voxel from each side, so the unpadded
    # transposed conv (which grows each dim by 2) restores the unpooled size
    pad = (-1, -1, -1, -1, -1, -1)
    padded = F.pad(x, pad)
    x = F.conv_transpose3d(padded, weights, bias=bias)
    return F.relu(x)

def classifier3d(input, weights, bias):
    # 1x1x1 conv acting as a per-voxel classifier
    x = F.conv3d(input, weights, bias)
    return F.relu(x)

def normalize_channel_wise(x):
    x = x - torch.mean(x, 1, keepdim=True)
    # note: this produces nan/inf wherever the std is zero
    return x / torch.std(x, 1, keepdim=True)

def normalize_and_upsample(x, new_size):
    x = normalize_channel_wise(x)
    # F.upsample (F.interpolate in newer versions) defaults to nearest mode
    return F.upsample(x, size=new_size)

First of all, in all the examples I have seen, nn.Module subclasses are callable, but mine is not: I have been explicitly calling model.forward(). It also seems that gradients are not being propagated; when I call loss.backward(), all the gradients go to nan. On top of that, the GPU is not being utilized at all during the training loop, even though the memory has been allocated and I have called .cuda() on all the inputs and on the model.

I’m wondering: is there an easy way to fix my approach? Or is there a way to re-use the conv kernels like this with nn.Conv3d and the equivalent transposed op?
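
I imagine the module-based version would amount to assigning the same Parameter to both layers, since nn.ConvTranspose3d stores its weight as (in_channels, out_channels, kD, kH, kW), which matches the nn.Conv3d weight shape when the channel counts are mirrored. Something like this sketch (the class name is just illustrative, and I haven’t verified it):

class TiedPair(nn.Module):

    def __init__(self, c_in, c_out):
        super(TiedPair, self).__init__()
        self.conv = nn.Conv3d(c_in, c_out, 3, padding=1)
        self.deconv = nn.ConvTranspose3d(c_out, c_in, 3)
        # both layers now hold the very same Parameter object, so
        # gradients from either path accumulate into one shared kernel
        self.deconv.weight = self.conv.weight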

What happens if you call the model directly via output = model(x)?
I tried out your code and it seems to run fine.
What PyTorch version are you using?

I’ll have to get back to you on Monday, but I’ve been getting something like Model3d object is not callable. I’m pretty sure I built the latest stable version from source; I’ll update my reply on Monday.
Thanks

UPDATE:
I take back what I said about the model not being callable; I’m calling the model on the inputs now and it’s fine. I have also found out that the nan value in my gradients is always in model.up_classifier_weight, which suggests the loss function I’m using is probably the culprit. The loss values themselves seem normal, though.
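
For reference, I tracked the nan down with a quick loop roughly like this after calling loss.backward():

for name, param in model.named_parameters():
    # report any parameter whose gradient contains a nan
    if param.grad is not None and torch.isnan(param.grad).any():
        print('nan grad in', name)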

I solved the GPU utilization issue by wrapping my model in nn.DataParallel.
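That is, roughly (assuming more than one visible GPU):

# splits each input batch across the visible GPUs and gathers the outputs
model = nn.DataParallel(model).cuda()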

My loss function computes some very large intermediate values, and I wonder if these are leading to gradient explosion. I want to try batchnorm layers to see if I can mitigate this, but I’m not sure how to use nn.functional.batch_norm. How do I make sure the running_mean and running_var arguments are updated during training and frozen during evaluation?

F.batch_norm has a training argument, which defines the mode in which the function operates.

Looking at the docs, it doesn’t seem like the training flag does anything besides check the size of the input. Does one have to manually update running_mean and running_var?

The training flag gets passed in this line of code.
You would have to register the running stats as buffers.
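
A minimal sketch of that approach (the class and attribute names are just illustrative, and the affine parameters are optional):

class FunctionalBN(nn.Module):

    def __init__(self, num_features):
        super(FunctionalBN, self).__init__()
        # learnable affine parameters, updated by the optimizer
        self.bn_weight = nn.Parameter(torch.ones(num_features))
        self.bn_bias = nn.Parameter(torch.zeros(num_features))
        # buffers are saved in the state_dict and moved by .cuda()/.to(),
        # but are not returned by parameters(), so the optimizer skips them
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        # training=True updates running_mean/running_var in place;
        # self.training is toggled by model.train() / model.eval()
        return F.batch_norm(x, self.running_mean, self.running_var,
                            weight=self.bn_weight, bias=self.bn_bias,
                            training=self.training)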