Model parameters are not being updated?

I made a simple example of a CNN layer whose convolutional weights are defined as a linear combination of predefined filters. The goal is to train the coefficients of the linear combination while keeping the predefined filters fixed. This is easy to achieve in TensorFlow using tf.nn.conv2d. Here, however, model.weight is not changing and model.weight.grad is None. Why is that?
Can anyone please have a look? Thanks.

from __future__ import print_function
import torch.nn.functional as F
# from torch.autograd import Variable
# import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class ConvNet(nn.Module):
    def __init__(self, device='cuda'):

        super(ConvNet, self).__init__()
        self.device = device
        self.filter = torch.Tensor(
            [[[0.06, 0,   0],
              [0.1,  0,   0.2],
              [0.06, 0.1, 0]],
             [[0.1,  0,   0],
              [0.2,  0,   0],
              [0.1,  0,   0]]]).to(self.device)
        # print(self.filter.shape)
        self.weight = nn.Parameter(torch.Tensor(1, 1, 2).to(self.device))
        self.bias = nn.Parameter(torch.Tensor(1)).to(self.device)
        nn.init.xavier_uniform_(self.weight)
        self.bias.data.uniform_(-1, 1)
        self.kernel = nn.Parameter(
            torch.einsum("ijk, klm -> ijlm", 
            self.weight, self.filter).to(self.device), 
            requires_grad=False)
        # print(self.kernel.shape)
        self.bn1 = nn.BatchNorm2d(1)
        self.fc = nn.Linear(1*13*13, 10)

    def forward(self, x):
        out = F.conv2d(input=x, weight=self.kernel, bias=self.bias)
        out = self.bn1(out)
        out = nn.ReLU()(out)
        out = nn.MaxPool2d(kernel_size=2, stride=2)(out)
        # print("out.shape: ", out.shape)
        out = out.reshape(out.size(0), -1)
        # print("out.shape: ", out.shape)
        out = self.fc(out)
        return out


def main(argv=None):
    # Device configuration
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Hyper parameters
    num_epochs = 1
    batch_size = 1024
    learning_rate = 0.001
    log_interval = 10

    # MNIST dataset
    train_dataset = datasets.MNIST(
        root='../../data/',
        train=True,
        transform=transforms.ToTensor(),
        download=True)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True)

    model = ConvNet(device=device)
    model.to(device)
    for name, param in model.named_parameters():
        print(name, '\t\t', param.shape)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(1, num_epochs + 1):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            # print(model.weight)
            # print(model.kernel)
            # print(model.weight.grad)
            a = list(model.parameters())[0].clone()
            loss.backward()
            optimizer.step()
            b = list(model.parameters())[0].clone()
            print(torch.equal(a.data, b.data))
            if batch_idx % log_interval == 0:
                _, predicted = torch.max(output.data, 1)
                total = target.size(0)
                correct = (predicted == target).sum().item()
                print('batch Accuracy: {} %'.format(100 * correct / total))


if __name__ == '__main__':
    main()
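The einsum in __init__ is the linear combination itself. A quick standalone check, with made-up coefficients (0.5 and -2.0 are arbitrary choices, not values from the thread), shows the resulting shape and that it equals an explicit weighted sum of the two filters:

```python
import torch

# The two fixed 3x3 filters from the question, shape (2, 3, 3).
filters = torch.tensor(
    [[[0.06, 0.0, 0.0], [0.1, 0.0, 0.2], [0.06, 0.1, 0.0]],
     [[0.1, 0.0, 0.0], [0.2, 0.0, 0.0], [0.1, 0.0, 0.0]]])
# Arbitrary example coefficients, shape (out_channels=1, in_channels=1, num_filters=2).
coeffs = torch.tensor([[[0.5, -2.0]]])

# "ijk,klm->ijlm" sums over k: kernel[i, j] = sum_k coeffs[i, j, k] * filters[k].
kernel = torch.einsum("ijk,klm->ijlm", coeffs, filters)
print(kernel.shape)  # torch.Size([1, 1, 3, 3]): (out_channels, in_channels, kH, kW) for F.conv2d

# The same kernel written as an explicit weighted sum of the filters.
manual = coeffs[0, 0, 0] * filters[0] + coeffs[0, 0, 1] * filters[1]
print(torch.allclose(kernel[0, 0], manual))  # True
```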


I tried to play around with your code to see what might be wrong, but I get an error when running it:

RuntimeError: CUDA error: an illegal memory access was encountered

with PyTorch 1.0.0.dev20190207.
Is the code running fine on your machine?

Thank you so much for your reply. Yes, it runs fine on my Ubuntu 16.04 machine with PyTorch 1.0.0 (CUDA build) and Python 3.5. Also, there is no real need to use the GPU (the device argument), since the same issue occurs on the CPU.

Your code works on the CPU if you specify requires_grad=True for self.kernel, or just remove the requires_grad argument, since nn.Parameter defaults to requires_grad=True.
It’s still strange that your code threw an error on my machine using the GPU.

I understand, but the problem is that when I set requires_grad=True for self.kernel, the linear combination is only used to initialize the kernel. Afterwards, every coordinate of self.kernel starts changing, which is not the behavior I want. I want only the coefficients to be trained: self.filter has to stay fixed while self.weight changes. That is why self.kernel has requires_grad=False.
An easy way to check for the right behavior during training: coordinates (0, 1), (0, 2), (1, 1), and (2, 2) of self.kernel have to stay zero throughout training, since both fixed filters are zero there. Hope that helps.
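Both observations can be reproduced in isolation. The following is a minimal sketch, assuming a toy all-ones input and a stand-in bias (so that backward() has something trainable to reach); it uses the same nn.Parameter(einsum(...)) construction as the question, with the default requires_grad=True:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# The two fixed filters from the question, shape (2, 3, 3).
filters = torch.tensor(
    [[[0.06, 0.0, 0.0], [0.1, 0.0, 0.2], [0.06, 0.1, 0.0]],
     [[0.1, 0.0, 0.0], [0.2, 0.0, 0.0], [0.1, 0.0, 0.0]]])
weight = nn.Parameter(torch.randn(1, 1, 2))  # coefficients we want to train
bias = nn.Parameter(torch.zeros(1))          # something trainable so backward() runs
x = torch.ones(1, 1, 5, 5)                   # constant input gives predictable gradients

# Same construction as in the question's __init__, with requires_grad left at True.
# nn.Parameter wraps a detached copy of its input, so `kernel` is a brand-new leaf.
kernel = nn.Parameter(torch.einsum("ijk,klm->ijlm", weight, filters))
print(kernel.is_leaf, kernel.grad_fn)  # True None

F.conv2d(x, kernel, bias).sum().backward()
print(weight.grad)        # None: the loss has no graph path back to `weight`
print(kernel.grad[0, 0])  # every entry is 9.0, including (0, 1), (0, 2), (1, 1), (2, 2)
```

With requires_grad=False, the kernel receives no gradient either, so nothing in the convolution trains; with requires_grad=True, an optimizer would move every kernel entry, including the ones that should stay zero.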

Now my question is: is it an internal limitation of PyTorch that, by design, you cannot make the weight of nn.Conv2d or F.conv2d non-trainable? Do I need to change the F.conv2d function internally in PyTorch?
Thank you for your help.

I’ve had no problem implementing this idea in TensorFlow. The reasons I want to move to PyTorch are its awesome dataset API and easy multi-GPU training.
I can also share a TensorFlow example if that helps you understand the algorithm better.
I would be most grateful if you could help me solve this issue.

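  1. Remember that PyTorch is based on dynamic computation graphs, whereas TensorFlow is based on static computation graphs. Hence, you need to compute the kernel in forward() rather than in __init__(). The graph is recorded anew on every forward pass, and wrapping a tensor computed once in __init__() in nn.Parameter creates a new leaf tensor with no link back to self.weight.

  2. You do not need to move individual parameters to CUDA. Instead, move the whole model to CUDA.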
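A minimal sketch of both points, assuming only the convolution part of the question's ConvNet needs to change (the class name CombinedFilterConv is hypothetical, and the fixed filters are stored with register_buffer, the approach discussed later in this thread). The random input and squared-output loss in the usage part are placeholders:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CombinedFilterConv(nn.Module):
    """3x3 convolution whose kernel is a trainable linear combination of fixed filters."""

    def __init__(self, filters):
        super().__init__()
        # Fixed filters, shape (K, 3, 3). As a buffer they follow model.to(device)
        # and appear in state_dict(), but not in parameters().
        self.register_buffer("filter", filters)
        # One trainable coefficient per fixed filter, shape (1, 1, K).
        self.weight = nn.Parameter(torch.empty(1, 1, filters.size(0)))
        self.bias = nn.Parameter(torch.zeros(1))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        # Build the kernel on every call so autograd records weight -> kernel -> output.
        kernel = torch.einsum("ijk,klm->ijlm", self.weight, self.filter)  # (1, 1, 3, 3)
        return F.conv2d(x, kernel, self.bias)


# Usage: one SGD step on random data.
torch.manual_seed(0)
filters = torch.tensor(
    [[[0.06, 0.0, 0.0], [0.1, 0.0, 0.2], [0.06, 0.1, 0.0]],
     [[0.1, 0.0, 0.0], [0.2, 0.0, 0.0], [0.1, 0.0, 0.0]]])
layer = CombinedFilterConv(filters)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

before = layer.weight.detach().clone()
loss = layer(torch.randn(4, 1, 28, 28)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

print([name for name, _ in layer.named_parameters()])  # ['weight', 'bias']
print(layer.weight.grad is not None)                    # True
print(torch.equal(before, layer.weight.detach()))       # False if the coefficients moved

# Positions that are zero in every fixed filter stay zero in the rebuilt kernel.
with torch.no_grad():
    kernel = torch.einsum("ijk,klm->ijlm", layer.weight, layer.filter)
print(kernel[0, 0, [0, 0, 1, 2], [1, 2, 1, 2]])
```

Because only the coefficients are parameters, the kernel can never drift away from the span of the fixed filters, which is exactly the zero-coordinate check described earlier in the thread.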


Hi InnovArul! Thank you for your response. Well, it looks like your example is working:

self.weight is changing.
self.weight.grad is not None and changing as well.
On closer inspection, self.kernel itself is also changing as expected.

Thank you so much.

InnovArul and ptrblck, thank you so much for your help.
I just wanted to ask why I should use self.register_buffer("filter", filter) rather than something like self.filter = nn.Parameter(filter, requires_grad=False). Why is the latter not correct?
I have looked at a bunch of posts to learn about register_buffer, but I am new to PyTorch and want to understand it more deeply:
https://discuss.pytorch.org/t/solved-register-parameter-vs-register-buffer-vs-nn-parameter/31953
https://discuss.pytorch.org/t/what-is-the-difference-between-register-buffer-and-register-parameter-of-nn-module/32723/2
https://discuss.pytorch.org/t/use-and-abuse-of-register-buffer/4128

requires_grad=False by default, not moved by module.to(), and not saved in the checkpoint:

self.filter = torch.tensor(...)

requires_grad=True by default, saved in the checkpoint, and returned by module.parameters():

self.filter = nn.Parameter(torch.tensor(...))

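requires_grad=False by default, moved by module.to(), and saved in the checkpoint, but not returned by module.parameters(). Note that register_buffer returns None, so call it on its own rather than assigning its result to self.filter:

self.register_buffer("filter", torch.tensor(...))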
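The three options above can be compared on a toy module (the class name Holder and its attribute names are hypothetical). A dtype conversion stands in for model.to(device) so the sketch runs without a GPU; both go through the same Module.to() path, which converts parameters and buffers but leaves plain tensor attributes alone:

```python
import torch
import torch.nn as nn


class Holder(nn.Module):
    def __init__(self):
        super().__init__()
        self.plain = torch.ones(2)                  # plain tensor attribute
        self.param = nn.Parameter(torch.ones(2))    # trainable parameter
        self.register_buffer("buf", torch.ones(2))  # constant buffer


m = Holder()
print([n for n, _ in m.named_parameters()])  # ['param']
print([n for n, _ in m.named_buffers()])     # ['buf']
print(list(m.state_dict().keys()))           # ['param', 'buf']
print(m.plain.requires_grad, m.param.requires_grad, m.buf.requires_grad)  # False True False

m.to(torch.float64)
print(m.plain.dtype, m.param.dtype, m.buf.dtype)  # torch.float32 torch.float64 torch.float64
```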

If you’re dealing with a constant tensor, you don’t want it to show up in model.parameters(), since otherwise the following line hands the constant tensor to the optimizer:

optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

…this is perhaps not so problematic if you also set requires_grad=False manually, since parameters that have had requires_grad=False from the very start never receive a .grad and are skipped by the optimizer. But it’s nicer to keep buffers separate from parameters.
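A small check of that claim, assuming a recent PyTorch release. The frozen tensor plays the role of a constant filter that was passed to the optimizer anyway:

```python
import torch
import torch.nn as nn

frozen = nn.Parameter(torch.ones(3), requires_grad=False)  # stands in for a constant filter
trainable = nn.Parameter(torch.ones(3))
optimizer = torch.optim.SGD([frozen, trainable], lr=0.1)

loss = (trainable * frozen * 2).sum()  # gradient only flows to `trainable`
loss.backward()
optimizer.step()

print(frozen.grad)  # None: never received a gradient, so step() skipped it
print(frozen)       # still all ones
print(trainable)    # ones - 0.1 * 2, i.e. 0.8 in each entry
```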