Learning with externally constructed parameters

I have an external function that constructs a tensor from a set of parameters. I then pass this tensor and the parameters to a torch.nn module during initialization. The problem is that the parameters are not being learned. I have included a minimal test case below where the module is a softmax classifier; the (externally) constructed tensor is the parameter tensor of the softmax.

import torch
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from collections import OrderedDict


train_dataset = dsets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = dsets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)


batch_size = 100
n_iters = 3000
epochs = n_iters / (len(train_dataset) / batch_size)
input_dim = 784
output_dim = 10
lr_rate = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


def create_parameters(input_dim, output_dim):
    parameters = OrderedDict()
    for i in range(input_dim):
        parameters[f"parameter_{i}"] = torch.nn.Parameter(torch.rand(  size=(output_dim,1), requires_grad=True, device=device))
    return parameters

def create_bias(output_dim):
    bias = torch.nn.Parameter(torch.rand(size=(output_dim,), requires_grad=True, device=device))
    return bias
    
def construct_logit_parameter(parameters):
    p = torch.hstack(tuple(parameters.values()))
    return p

ps = create_parameters(input_dim, output_dim)
bias = create_bias(output_dim)
p = construct_logit_parameter(ps)

class SoftMax(torch.nn.Module):
    def __init__(self, parameters, bias, logit_parameter):
        super(SoftMax, self).__init__()
        self.params = torch.nn.ParameterDict(parameters)
        self.logit_parameter = logit_parameter
        self.bias = bias

    def forward(self, x):
        outputs = x @ self.logit_parameter.T + self.bias
        return outputs


model = SoftMax(ps, bias, p)
criterion = torch.nn.CrossEntropyLoss() 
optimizer = torch.optim.Adam(model.parameters(), lr=lr_rate)


iter = 0
for epoch in range(int(epochs)):
    for i, (images, labels) in enumerate(train_loader):
        images = images.view(-1, 28 * 28).to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        iter += 1
        if iter % 500 == 0:
            ps = list(model.parameters())
            print(ps[1].grad)
            print(ps[0].grad)
            # calculate accuracy on the test set
            correct = 0
            total = 0
            for images, labels in test_loader:
                images = images.view(-1, 28 * 28).to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels.to(device)).sum()
            accuracy = 100 * correct / total
            print("Iteration: {}. Loss: {}. Accuracy: {}.".format(iter, loss.item(), accuracy))

Inspecting the gradient (simply checking .grad for the parameters in model.parameters()) shows me that the bias parameter has a non-zero gradient but all the other ones don’t. How do I fix this? My guess is that something is detached from the computation graph somewhere but I am not sure where and what.

Note: yes, I know I could just construct the parameters inside the nn.Module, but the code base I am working on makes this a little bit tricky, and I figured it might be easier to solve it this way.

In the module, you're registering a tensor, logit_parameter, i.e. a dependent entity; the parameters that create it are invisible to the optimizer (they are not in model.parameters()).

Aren't the parameters that construct self.logit_parameter being registered with self.params = torch.nn.ParameterDict(parameters)?
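
One quick way to check what the optimizer actually sees is to print the registered parameter names, something like this (just an illustrative check):

for name, _ in model.named_parameters():
    print(name)
# expect 'params.parameter_0' ... 'params.parameter_783' plus 'bias';
# logit_parameter itself does not appear, since it is a plain tensor, not an nn.Parameter

That listing shows the individual parameters and the bias, which is why I assumed they are registered.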

But if I understand you correctly, this is not enough, because the module does not know that the registered self.params are the same parameters that constructed self.logit_parameter. Basically, my understanding now is that the module knows about self.logit_parameter and about the self.params that constructed it, but it does not know anything about how self.logit_parameter was constructed (the computation graph is not visible to the module).

Maybe, to formulate my initial question more clearly: is there a way to register the computation graph that constructs self.logit_parameter without wrapping the graph construction in a module of its own?

Sorry, I missed the use of ParameterDict. The problem then is that you’re doing hstack once, and not in forward(), so it is not in the graph (zero_grad() before the first iteration resets logit_parameter.grad_fn, and it is never re-linked to hstack after that).

So, you’ll have to either do the stacking in forward() or use a stacked parameter, without dictionaries.
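
For the first option, here is a minimal sketch of what that could look like, keeping the existing ParameterDict (the class name is just for illustration):

class SoftMaxStacked(torch.nn.Module):
    def __init__(self, parameters, bias):
        super().__init__()
        self.params = torch.nn.ParameterDict(parameters)
        self.bias = bias

    def forward(self, x):
        # rebuild the weight matrix from the registered parameters on every call,
        # so the hstack is part of each iteration's graph and the updated values are used
        logit_parameter = torch.hstack(tuple(self.params.values()))
        return x @ logit_parameter.T + self.bias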

Yes, indeed, that's what I thought too: the construction of the tensor has to happen in the forward pass. However, what I don't understand yet is where exactly the .grad_fn gets "lost". You mentioned that calling zero_grad() resets logit_parameter.grad_fn, but the documentation only says that zero_grad() sets the gradients to zero; it does not mention touching .grad_fn. I guess it's happening somewhere else?

Yeah, I'm wrong again, it is not reset. Your approach seems valid then; it looks like some zero gradient values are mathematically correct there. Try this:

npa = list(model.named_parameters())
for k, v in npa:
    if v.grad.mean() != 0:
        print(k)

So, within the learning loop I did something like this:

count = 0
for k, v in model.named_parameters():
    if not torch.isclose(v.grad.mean(), torch.tensor(0.0, device=device)):
        count += 1
print(count)

count = 0
for k, v in model.named_parameters():
    if v.grad.mean() != 0:
        count += 1
print(count)

Turns out none of the means are exactly zero, but all of them are very close to zero.

At this stage, I somehow got the feeling that I am forgetting about something silly that is happening under the hood. I mean, there has to be a reason why constructing the tensor in the forward function makes it work, while doing the same externally or in __init__ results in near-zero gradients.

Despite such a small mean, the values themselves are dispersed fine (except for the always-black pixels). Some statistical property gives this effect (the CLT, I guess).
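
One way to see that is to look at the spread of each gradient rather than just its mean, e.g. (purely illustrative):

for name, p in model.named_parameters():
    if p.grad is not None:
        # the mean can be close to zero even when individual entries clearly are not
        print(name, p.grad.mean().item(), p.grad.std().item(), p.grad.abs().max().item())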

Ah damn it, I was completely stuck in the TensorFlow mindset, thinking in terms of computation graphs that are passed around. Basically, if you pass a pre-constructed tensor to a PyTorch module, that tensor is never re-evaluated with the updated parameters. In PyTorch one has to re-construct the parameter tensor in every forward pass, from the updated parameters, otherwise you will always be using the same one. Kind of obvious, now.
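
A tiny standalone example of that behaviour (unrelated to the model above, just illustrating eager evaluation):

a = torch.nn.Parameter(torch.tensor(1.0))
b = 2 * a          # built once, from the current value of a
with torch.no_grad():
    a += 1.0       # stands in for an optimizer step updating a
print(b)           # still tensor(2., ...): b is not re-evaluated from the updated a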

Anyways, thanks for the help @googlebot!

Oh, and what does CLT mean?

Lol, right, I initially had the impression that this wouldn't work, but I skipped over the obvious reason: logit_parameter not being re-evaluated. And I was referring to the Central Limit Theorem.