Weights cannot be updated

I use “torch.autograd.Function” to construct my own autograd function. “eta” is my weight tensor, but when I print all the parameters, I find that eta is not updated at all. And if I run

print(list(model.parameters())[0].grad)

I will get None.
The sample code is below:

class ANN_Pytorch(torch.nn.Module):
    def __init__(self, D_in, H1, H2, D_out, cutoff, beta):
        super(ANN_Pytorch, self).__init__()
        self.w0 = Variable(((15 - 2.5) * torch.rand(1, D_in, device=device, requires_grad=True) + 2.5)).double().unsqueeze(0).permute(2, 0, 1)
        self.eta = nn.Parameter(self.w0)
        self.linear1 = torch.nn.Linear(D_in, H1)
        self.linear2 = torch.nn.Linear(H1, H2)
        self.linear3 = torch.nn.Linear(H2, D_out)


    def forward(self, X):
        output = torch.zeros(X.shape[0], X.shape[1])
        for i in range(X.shape[0]):
            output[i,:] = RigidityExp.apply(sparse2tensor(X[i,:].reshape(1,-1)).double(), self.eta, beta, cutoff)
        z1 = self.linear1(output)
        f1 = F.sigmoid(z1)
        z2 = self.linear2(f1)
        f2 = F.sigmoid(z2)
        y_hat = self.linear3(f2)
        return y_hat
#================================Construct Model================================
D_in, H1, H2, D_out, cutoff, beta = X_train_Distance.shape[1], 40, 20, ytrain.shape[1], 12, 2.5

model = ANN_Pytorch(D_in, H1, H2, D_out, cutoff, beta).cuda()

criterion = torch.nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr = 1e-2, momentum = 0.9, weight_decay = 1e-2)
epochs = 300

#===================================Training Model==============================
for epoch in range(epochs):
    model.train()
    y_pred  = model(X_train_Distance)
    loss = criterion(y_pred, ytrain)
    print(epoch, loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(list(model.parameters())[0].grad)
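
For reference, a quick way to see which parameters actually receive gradients after loss.backward() (using the model above) would be:

for name, param in model.named_parameters():
    # after backward(), parameters that are still connected to the graph show a non-None grad
    grad_sum = None if param.grad is None else param.grad.abs().sum().item()
    print(name, param.requires_grad, grad_sum)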

Could you post the code of RigidityExp, since self.eta is only used there?
I guess you might be detaching it somewhere in your function.
A simple dummy model without your custom function works fine:

class ANN_Pytorch(nn.Module):
    def __init__(self, D_in, H1, H2, D_out, cutoff, beta):
        super(ANN_Pytorch, self).__init__()
        self.w0 = torch.rand(1, D_in, requires_grad=True)
        self.eta = nn.Parameter(self.w0)
        self.linear1 = nn.Linear(D_in, H1)
        self.linear2 = nn.Linear(H1, H2)
        self.linear3 = nn.Linear(H2, D_out)


    def forward(self, x):
        x = x + self.eta
        z1 = self.linear1(x)
        f1 = torch.sigmoid(z1)
        z2 = self.linear2(f1)
        f2 = torch.sigmoid(z2)
        y_hat = self.linear3(f2)
        return y_hat


model = ANN_Pytorch(5, 10, 10, 10, 0, 0)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr = 1e-2, momentum = 0.9, weight_decay = 1e-2)
epochs = 300

data = torch.randn(1, 5)
ytrain = torch.randn(1, 10)
#===================================Trianing Model==============================
for epoch in range(epochs):
    model.train()
    y_pred  = model(data)
    loss = criterion(y_pred, ytrain)
    optimizer.zero_grad()
    loss.backward()
    print(model.eta.grad)
    optimizer.step()
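
For reference, this is the kind of silent detach I mean; it is a made-up example, not your code, but it shows how eta.grad stays None even though the forward and backward passes run fine:

import torch
import torch.nn as nn

class DetachedModel(nn.Module):
    def __init__(self):
        super(DetachedModel, self).__init__()
        self.eta = nn.Parameter(torch.rand(1, 5))

    def forward(self, x):
        # .detach() (or rebuilding a tensor from the parameter's values)
        # cuts the graph, so no gradient can ever reach self.eta
        return x + self.eta.detach()

detached_model = DetachedModel()
x = torch.randn(1, 5, requires_grad=True)
detached_model(x).sum().backward()
print(detached_model.eta.grad)
> None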

Thanks for your reply! Here is my RigidityExp function:

class RigidityExp(torch.autograd.Function):
 
    @staticmethod
    def forward(ctx, input, eta, beta, cutoff):
        """
        Input:
            - input  : (D_in, N, M) matrix, which will be the output of the sparse2tensor function.
            - eta    : (D_in, 1, 1) matrix.
            - beta   :  constant
            - cutoff :  constant
        Output:
            - output : A (1, D_in) matrix
        """
        ctx.save_for_backward(input, eta)
        ctx.beta =  beta
        ctx.cutoff = cutoff
        X = torch.exp(-(input/eta).pow(beta))
        M = torch.zeros_like(X)
        Y = torch.where(input < cutoff, X, M)
        output = torch.sum(torch.sum(Y, dim =-1), dim=-1).unsqueeze(0)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, eta = ctx.saved_tensors
        exponent = (input/eta).pow(ctx.beta)
        X = torch.exp(-exponent)
        M = torch.zeros_like(X)
        M0 = torch.ones_like(X)
        Y = torch.where(input < ctx.cutoff, X, M)
        input0 = torch.where(input != 0, input, M0)
        Z = (ctx.beta * Y * exponent) / eta
        Z0 = -(ctx.beta * Y * exponent)/ input0

        matrix_eta = torch.sum(torch.sum(Z, dim =-1), dim=-1).unsqueeze(0)
        matrix_eta0 = torch.sum(torch.sum(Z0, dim =-1), dim=-1).unsqueeze(0)

        grad_input = (grad_output.unsqueeze(0).permute(2,0,1) * matrix_eta0)
        grad_eta = (grad_output * matrix_eta).unsqueeze(0).permute(2,0,1)
        return grad_input, grad_eta, None, None
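
For reference, the derivatives implemented in backward are, for each entry r of input inside the cutoff:

d/d(eta) exp(-(r/eta)^beta) =  beta * (r/eta)^beta * exp(-(r/eta)^beta) / eta
d/d(r)   exp(-(r/eta)^beta) = -beta * (r/eta)^beta * exp(-(r/eta)^beta) / r

which is what Z and Z0 compute above (entries outside the cutoff are zeroed out via Y).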

from torch.autograd import gradcheck
rigidity = RigidityExp.apply
a = ((15 - 2.5) * torch.rand(1, 36, device = device,requires_grad=True) + 2.5).double().unsqueeze(0).permute(2, 1, 0)
input = (sparse2tensor(X_test_Distance[0,:].reshape(1,-1)).double(), a, 2.5, 12)
test = gradcheck(rigidity, input, eps=1e-3, atol=1e-3)
print(test)

Thanks for the code.
Note that you are currently using a global beta and cutoff in ANN_Pytorch.forward, since you didn’t set them as attributes.
Also Variables are deprecated since PyTorch 0.4.0, so you can just use torch.tensors instead.
After fixing these small issues and using some random arguments, I get valid gradients for eta:

class ANN_Pytorch(torch.nn.Module):
    def __init__(self, D_in, H1, H2, D_out, cutoff, beta, device):
        super(ANN_Pytorch, self).__init__()
        self.w0 = ((15 - 2.5) * torch.rand(1, D_in, device=device, requires_grad=True) + 2.5).double().unsqueeze(0).permute(2, 0, 1)
        self.eta = nn.Parameter(self.w0)
        self.linear1 = torch.nn.Linear(D_in, H1)
        self.linear2 = torch.nn.Linear(H1, H2)
        self.linear3 = torch.nn.Linear(H2, D_out)
        self.beta = beta
        self.cutoff = cutoff


    def forward(self, X):
        output = torch.zeros(X.shape[0], X.shape[1])
        for i in range(X.shape[0]):
            output[i,:] = RigidityExp.apply(X[i,:].reshape(1,-1).double(), self.eta, self.beta, self.cutoff)
        z1 = self.linear1(output)
        f1 = F.sigmoid(z1)
        z2 = self.linear2(f1)
        f2 = F.sigmoid(z2)
        y_hat = self.linear3(f2)
        return y_hat

model = ANN_Pytorch(5, 10, 10, 10, 1, 1, 'cpu')
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr = 100, momentum = 0.9, weight_decay = 1e-2)
data = torch.randn(1, 5)
ytrain = torch.randn(1, 10)

model.train()
y_pred  = model(data)
loss = criterion(y_pred, ytrain)
optimizer.zero_grad()
loss.backward()
print(model.eta.grad)
> tensor([[[ 7.7592511500e-04]],

        [[-7.3158718795e-05]],

        [[-1.6567345016e-04]],

        [[ 3.0475339896e-04]],

        [[-1.9967394943e-04]]], dtype=torch.float64)
optimizer.step()

Thank you so much for your reply! I changed beta and cutoff to self.beta and self.cutoff, but I still get “None” for model.eta.grad. I think the reason is that in the forward pass I apply a function to X:

output[i,:] = RigidityExp.apply(sparse2tensor(X[i,:].reshape(1,-1)).double(), self.eta, self.beta, self.cutoff)

and the sparse2tensor function is:

def sparse2tensor(X):
    '''
    Input: X has shape(1, D)
           For Example, input X should be X_test_Distance[0,:].reshape(1,-1)
    '''
    D_in = X.shape[1]
    sparse =  stack_coo_matrix(X)
    matrix = sparse.toarray()
    D_total = matrix.shape[1]
    matrix = matrix.reshape(int(D_in), -1, int(D_total/D_in))
    tensor = torch.tensor(matrix, device=device, dtype=dtype)
    return tensor

Do you think this is the reason I cannot get the correct result? Or do you have any other ideas?

Ah yeah, I missed that function. Thanks for pointing it out!
Could you also post the definition of stack_coo_matrix and, if possible, sample values for X?

Sure.

def pad_matrix(X):
    '''
    Input: X has shape(1, D)
           For Example, input X should be X_test_Distance[0,:].reshape(1,-1)
    '''
    width = []
    height = []
    for j in range(X.shape[1]):
        h = X[0,j].shape[0]
        w = X[0,j].shape[1]
        width.append(w)
        height.append(h)

    w_max_idx = np.argmax(width)
    h_max_idx = np.argmax(height)
    w_max = width[w_max_idx]
    h_max = height[h_max_idx]

    for j in range(X.shape[1]):
        X[0,j] = np.pad(X[0,j], ((0,h_max - X[0,j].shape[0]),(0, w_max - X[0,j].shape[1])), 'constant')
    return X

def stack_coo_matrix(X):
    '''
    Input: X has shape(1, D)
           For Example, input X should be X_test_Distance[0,:].reshape(1,-1)
           Example code: D = stack_coo_matrix(X_test_Distance[2,:].reshape(1,-1))
    '''
    Temp = pad_matrix(X)
    output = coo_matrix(Temp[0,0])
    # print(output)
    for i in range(X.shape[1] - 1):
        output = hstack([output, coo_matrix(Temp[0,i+1])])
    return output

def sparse2tensor(X):
    '''
    Input: X has shape(1, D)
           For Example, input X should be X_test_Distance[0,:].reshape(1,-1)
    '''
    D_in = X.shape[1]
    sparse =  stack_coo_matrix(X)
    matrix = sparse.toarray()
    D_total = matrix.shape[1]
    matrix = matrix.reshape(int(D_in), -1, int(D_total/D_in))
    tensor = torch.tensor(matrix, device=device, dtype=dtype)
    return tensor

And I shared a sample X on Google Drive.
https://drive.google.com/file/d/1oF7vr7WnOf6Fu9tlrsEyK4zPG4QrWOVg/view?usp=sharing
Thank you so much!

Thanks for the code. Unfortunately, I cannot debug it, as stack_coo_matrix relies on undefined names (coo_matrix, hstack) and the npy file cannot be decoded.

Anyway, I assume the sparse2tensor method just creates a dense tensor from some kind of sparse format.
Could you post the shape of this dense tensor and, for debugging purposes, pass it to the method directly?
I doubt that sparse2tensor can somehow make the updates of self.eta fail, as it is just a transformation of your input as far as I can see.
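
To illustrate the last point, here is a toy sketch (made-up shapes, with a simple stand-in for your RigidityExp): the input is built from numpy inside forward, just like in sparse2tensor, and the parameter still gets a valid gradient, since the gradient only has to flow through the parameter itself, not through the input:

import numpy as np
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.eta = nn.Parameter(torch.rand(3, 1, 1).double())

    def forward(self, x_np):
        # mimics sparse2tensor: builds a fresh tensor from a numpy array (no grad history)
        x = torch.tensor(x_np, dtype=torch.float64)
        # stand-in for RigidityExp: eta is used directly, so it stays in the graph
        return torch.exp(-(x / self.eta)).sum()

toy_model = ToyModel()
loss = toy_model(np.random.rand(3, 4, 5))
loss.backward()
print(toy_model.eta.grad)  # valid (non-None) gradients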

Sorry, I reattached the .npy file.
https://drive.google.com/open?id=1Vfi_3Pg46jPxTw1RpxJnbbqKjd-oZOLP

'''
X_test_Distance is a matrix, and each entry is still a matrix with a different shape. (Please see the picture I attached.)
'''
X_test_Distance = np.load('X_test_Distance.npy')[0:5]
A = sparse2tensor(X_test_Distance[0,:].reshape(1,-1))
print(X_test_Distance.shape)
print(A.shape)
>>(5,36)
>>torch.Size([36, 1804, 5])

And for the gradcheck part: if I use .double() for the input, I get True when I check the gradients, and I get an error if I do not convert the input and the parameter to double.

from torch.autograd import gradcheck
Rigidity = RigidityExp.apply
a = ((15 - 2.5) * torch.rand(1, 36,device = device,requires_grad=True) + 2.5).double().unsqueeze(0).permute(2, 1, 0)
input = (sparse2tensor(X_test_Distance[0,:].reshape(1,-1)).double(), a, 2.5, 12)
test = gradcheck(Rigidity, input, eps=1e-3, atol=1e-3)
print(test)
>>True

And

from torch.autograd import gradcheck
Rigidity = RigidityExp.apply
a = ((15 - 2.5) * torch.rand(1, 36,device = device,requires_grad=True) + 2.5).unsqueeze(0).permute(2, 1, 0)
input = (sparse2tensor(X_test_Distance[0,:].reshape(1,-1)), a, 2.5, 12)
test = gradcheck(Rigidity, input, eps=1e-3, atol=1e-3)
print(test)
>>/home/name/anaconda3/lib/python3.6/site-packages/torch/autograd/gradcheck.py:170: UserWarning: At least one of the inputs that requires gradient is not of double precision floating point. This check will likely fail if all the inputs are not of double precision floating point. 
  'At least one of the inputs that requires gradient '
Traceback (most recent call last):
  File "check_grad.py", line 229, in <module>
    test = gradcheck(linear, input, eps=1e-3, atol=1e-3)
  File "/home/name/anaconda3/lib/python3.6/site-packages/torch/autograd/gradcheck.py", line 205, in gradcheck
    'numerical:%s\nanalytical:%s\n' % (i, j, n, a))
  File "/home/name/anaconda3/lib/python3.6/site-packages/torch/autograd/gradcheck.py", line 185, in fail_test
    raise RuntimeError(msg)
RuntimeError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[0.2441, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.7324, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.2441,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])
analytical:tensor([[4.5538e-01, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 7.2516e-01, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 3.0459e-01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 1.8696e-04,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])

gradcheck is supposed to work with doubles and will likely fail for lower-precision types.
I’m not sure how to debug your code further, as it seems the gradients are correct.
What exactly are your doubts about your code?


Hi ptrblck,

Thanks for your reply!

I think one possible reason is that PyTorch cannot track the gradients once I call a function inside the “for loop”. For example, I call the RigidityExp function in the for loop as below:

for i in range(X.shape[0]):
            output[i,:] = RigidityExp.apply(X[i,:].reshape(1,-1).double(), self.eta, self.beta, self.cutoff)
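
To test this guess in isolation, I could try a minimal version of that loop (with a toy function standing in for RigidityExp and made-up shapes), for example:

import torch
import torch.nn as nn

eta = nn.Parameter(torch.rand(1, 4))
X = torch.rand(3, 4)

output = torch.zeros(X.shape[0], X.shape[1])
for i in range(X.shape[0]):
    # toy stand-in for RigidityExp.apply(...)
    output[i, :] = torch.exp(-(X[i, :].reshape(1, -1) / eta))

output.sum().backward()
print(eta.grad)
# a non-None result here would mean the indexed assignment in the loop
# is not what breaks the gradient tracking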

I’m new to PyTorch, so that’s just my guess. My dataset does not have a regular shape, so I will try to customize my dataset and see if that works.

Anyway, I really appreciate your help! I will let you know if the customized dataset works.

Best,
qflm