Loss function that uses the network output to compute an index into an array, then computes the loss from the selected values

In my neural network (an RNN), I define the loss function so that the output of the network determines a (binary) index. That index is then used to extract the required elements from an array, and those elements are used to calculate an MSELoss.

However, the program fails because p.grad is None for the network parameters, most likely because the graph is breaking somewhere. What is wrong with the loss function as defined?

Framework: PyTorch

The code is as follows:
Neural Network:

class RNN(nn.Module):
  def __init__(self):
    super(RNN, self).__init__()
    self.hidden_size = 8
    # self.input_size = 2
    self.h2o = nn.Linear(self.hidden_size, 1)
    self.h2h = nn.Linear(self.hidden_size, self.hidden_size)
    self.sigmoid = nn.Sigmoid()
  def forward(self,hidden):
    output = self.h2o(hidden)
    output = self.sigmoid(output)
    hidden = self.h2h(hidden)
    return output, hidden
  def init_hidden(self):
    return torch.zeros(1, self.hidden_size)

Loss function, train step, and training:

rnn = RNN()
criterion = nn.MSELoss()

def loss_function(previous, output, index):
  code = 2*(output > 0.5).long()
  current = Q_m2[code:code+2, index]
  return criterion(current, previous), current

def train_step():
  hidden = rnn.init_hidden()
  rnn.zero_grad()
  # Q_m2.requires_grad = True
  # Q_m2.create_graph = True 
  loss = 0
  previous = Q_m[0:2, 0]
  for i in range(1, samples):
    output, hidden = rnn(hidden)
    l, previous = loss_function(previous, output, i)
    loss+=l
  loss.backward()
  # Q_m2.retain_grad()
  for p in rnn.parameters():
    p.data.add_(p.grad.data, alpha=-0.05)
  return output, loss.item()/(samples - 1)

def training(epochs):
  running_loss = 0
  for i in range(epochs):
    output, loss = train_step()
    print(f'Epoch Number: {i+1}, Loss: {loss}')
    running_loss +=loss

Q_m2:

Q_m = np.zeros((4, samples))
for i in range(samples):
  Q_m[:,i] = q_x(U_m[:,i])
Q_m = torch.FloatTensor(Q_m)
Q_m2 = Q_m
Q_m2.requires_grad = True
Q_m2.create_graph = True

Error:

<ipython-input-36-feefd257c97a> in train_step()
     21   # Q_m2.retain_grad()
     22   for p in rnn.parameters():
---> 23     p.data.add_(p.grad.data, alpha=-0.05)
     24   return output, loss.item()/(samples - 1)
     25 

AttributeError: 'NoneType' object has no attribute 'data'

Hi Aditya!

I haven’t looked at your code or error messages in any detail.

But, yes, your “graph is breaking.”

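code, as a function of output, is not (usefully) differentiable: the
comparison output > 0.5 is a step function whose gradient is zero
almost everywhere, and the integer tensor it produces carries no
autograd history. So, regardless of what you subsequently do with
code, you won’t be able to backpropagate through it.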
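
To see where the graph breaks, here is a minimal sketch (not your actual network; the tensor w is a hypothetical stand-in for a trainable parameter). It only checks whether autograd is still tracking each intermediate result:

```python
import torch

# Hypothetical stand-in for a trainable parameter of the network.
w = torch.tensor([0.3], requires_grad=True)

# The sigmoid output is still attached to the autograd graph.
output = torch.sigmoid(w)

# The comparison returns a bool tensor, and .long() gives an integer tensor.
# Neither can carry gradients, so the graph ends here.
code = 2 * (output > 0.5).long()

print(output.requires_grad, output.grad_fn is not None)
print(code.requires_grad, code.grad_fn is None)
```

Anything computed only from code (such as a slice of Q_m2) no longer depends on the network parameters as far as autograd is concerned, which would be consistent with p.grad staying None after backward().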

I have no idea whether this would make sense for your use case,
but one possibility would run as follows:

As I read it, code is calculated to be either 0 or 2. You could
instead interpret output (processed appropriately, as necessary)
to be the probability that code should be 0 vs. 2, and then use
that probability to form a weighted average of the 0 and 2 entries
in your Q_m2 array.

This will be differentiable and you will be able to backpropagate (but
I’m not saying that it would make sense …).
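
A rough sketch of that weighted-average idea, under some assumptions of mine: Q is random stand-in data for Q_m2, soft_loss is a hypothetical replacement for loss_function, a single Linear layer stands in for the RNN, and the sigmoid output is read directly as the probability that code should be 2.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
samples = 5
Q = torch.randn(4, samples)  # stand-in for Q_m2; plain data, no requires_grad

def soft_loss(previous, output, index):
    # Interpret output as the probability p that code should be 2
    # (so 1 - p is the probability that code should be 0).
    p = output.squeeze()
    # Weighted average of the code == 0 rows and the code == 2 rows.
    current = (1 - p) * Q[0:2, index] + p * Q[2:4, index]
    return F.mse_loss(current, previous), current

h2o = torch.nn.Linear(8, 1)   # stand-in for the output layer of the RNN
hidden = torch.ones(1, 8)     # hypothetical fixed hidden state

loss = 0
previous = Q[0:2, 0]
for i in range(1, samples):
    output = torch.sigmoid(h2o(hidden))
    l, previous = soft_loss(previous, output, i)
    loss = loss + l
loss.backward()

# The parameters now receive gradients instead of None.
print(h2o.weight.grad.shape, h2o.bias.grad.shape)
```

Note that Q does not need requires_grad = True for this to work; the gradient reaches the parameters through p. Whether a soft average of the two candidate entries is meaningful for your problem is a separate question.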

Best.

K. Frank
