RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn (but I already calculated the grad_fn!)

When I run:

import torch
import torch.nn as nn
from torch.autograd import Variable

class ParetoSetModel(torch.nn.Module):
    def __init__(self, n_dim, n_obj, problem):
        super(ParetoSetModel, self).__init__()
        self.n_dim = n_dim
        self.n_obj = n_obj
        self.problem = problem
        self.fc1 = nn.Linear(self.n_obj, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, self.n_dim * 5)
        self.shape = (10, n_dim, 5)
        # self.fc3 = nn.Linear(256, self.n_dim)

    def forward(self, pref):
        x = torch.relu(self.fc1(pref))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        x = x.view(self.shape)
        x = torch.softmax(x, dim=2)
        x = torch.argmax(x, dim=2)
        return x

psmodel(pref_vec).backward(tch_grad)

I get the error: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In my code, the output of psmodel(pref_vec) is:
Out[1]:
tensor([[3, 0, 4, 1, 0, 4],
        [4, 4, 4, 0, 0, 4],
        [3, 0, 4, 1, 0, 4],
        [4, 4, 4, 0, 0, 4],
        [4, 4, 4, 0, 0, 4],
        [4, 4, 4, 0, 0, 4],
        [4, 0, 4, 0, 0, 4],
        [4, 4, 4, 0, 0, 4],
        [3, 0, 4, 1, 0, 4],
        [3, 0, 4, 1, 0, 4]])

The tch_grad variable is:
Out[2]:
tensor([[-5.7283e-01,  7.1886e-01, -1.8480e-02, -3.8009e-01, -1.0146e-01,
         -8.2423e-04],
        [-2.0862e-01,  6.5151e-01, -7.2875e-01, -3.0513e-02, -1.4981e-03,
         -3.9053e-04],
        [-5.7287e-01,  7.1879e-01, -1.8521e-02, -3.8016e-01, -1.0147e-01,
         -8.8013e-04],
        [-2.0861e-01,  6.5153e-01, -7.2873e-01, -3.0524e-02, -1.4975e-03,
         -4.4674e-04],
        [-2.0861e-01,  6.5153e-01, -7.2873e-01, -3.0526e-02, -1.4974e-03,
         -4.5370e-04],
        [-2.0861e-01,  6.5154e-01, -7.2873e-01, -3.0530e-02, -1.4972e-03,
         -4.7680e-04],
        [ 2.9420e-01,  2.7046e-01, -8.6619e-01, -3.0001e-01,  9.6886e-04,
         -1.8034e-03],
        [-2.0863e-01,  6.5151e-01, -7.2875e-01, -3.0512e-02, -1.4981e-03,
         -3.8361e-04],
        [-5.7287e-01,  7.1880e-01, -1.8519e-02, -3.8015e-01, -1.0147e-01,
         -8.7734e-04],
        [-5.7288e-01,  7.1877e-01, -1.8532e-02, -3.8017e-01, -1.0147e-01,
         -8.9424e-04]], dtype=torch.float64, grad_fn=<...>)

torch.argmax is not differentiable, so remove it from the model. It returns integer indices (a LongTensor) with no grad_fn, which cuts the autograd graph at that point: the model's output then neither requires grad nor has a grad_fn, and calling .backward() on it raises exactly this error. During training, return the softmax output (or a differentiable relaxation of argmax) and apply the hard argmax only at inference time.
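Below is a minimal sketch of the problem and of one possible workaround, a "soft argmax" (the expected index under the softmax distribution). The shapes here are illustrative, not taken from your model:

import torch

# argmax cuts the autograd graph: it returns integer indices with no grad_fn
logits = torch.randn(2, 3, requires_grad=True)
hard = torch.argmax(logits, dim=1)
print(hard.requires_grad, hard.grad_fn)    # False None -> backward() would fail here

# Differentiable alternative: the expected index under softmax ("soft argmax")
probs = torch.softmax(logits, dim=1)            # (2, 3), differentiable
indices = torch.arange(3, dtype=probs.dtype)    # candidate indices 0, 1, 2
soft = (probs * indices).sum(dim=1)             # (2,), float, keeps a grad_fn
soft.sum().backward()
print(logits.grad.shape)                        # torch.Size([2, 3]) -- gradients flow

In your forward, that would mean replacing x = torch.argmax(x, dim=2) with the weighted sum over dim=2 (or simply returning the softmax probabilities). torch.nn.functional.gumbel_softmax(logits, hard=True) is another common straight-through option if you need hard one-hot outputs in the forward pass; which relaxation fits best depends on how the output feeds into your problem, so treat this as a sketch rather than a drop-in fix.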