Retain_graph and training with symmetric matrices

I’m trying to train a model that uses a symmetric matrix. I implemented it via:

T = nn.Parameter(torch.Tensor(STATE_DIM, STATE_DIM))
self.Q = torch.mm(T, T.t())

but when I train, I get an error saying I need retain_graph=True. If I set it, the error goes away, but my loss value doesn’t decrease during optimizer steps.

Is there a way to enforce symmetry that doesn’t require retain_graph? If not, how do I get my loss to go down when using retain_graph=True?


The error about needing retain_graph=True is likely unrelated to your symmetric matrix.

I tried the following without getting any error.

import torch
import torch.nn as nn
T = nn.Parameter(torch.Tensor(5,5))
Q = torch.mm(T, T.t())
out = torch.nn.functional.linear(torch.rand(2,5), Q)

If your model is a recurrent model, you should probably detach the hidden state between batches. See the thread "Adding new hidden layer to LSTM".
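A minimal sketch of detaching the hidden state between batches, assuming a plain `nn.RNN` with a hypothetical linear head; the sizes `SEQ_LEN`, `BATCH`, `IN_DIM`, `HID_DIM` are made up for illustration and are not from the thread. The `detach()` call keeps each batch's backward pass inside that batch's own graph:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration
SEQ_LEN, BATCH, IN_DIM, HID_DIM = 10, 4, 3, 8

rnn = nn.RNN(IN_DIM, HID_DIM)   # input shape: (seq_len, batch, input_size)
head = nn.Linear(HID_DIM, 1)    # hypothetical output layer
optimizer = torch.optim.Adam(list(rnn.parameters()) + list(head.parameters()), lr=1e-2)

hidden = torch.zeros(1, BATCH, HID_DIM)  # (num_layers, batch, hidden_size)
losses = []
for step in range(3):
    x = torch.randn(SEQ_LEN, BATCH, IN_DIM)
    y = torch.randn(SEQ_LEN, BATCH, 1)
    # Cut the graph here: gradients stop at the start of this batch,
    # so backward() never walks into the graph freed in the previous step.
    hidden = hidden.detach()
    out, hidden = rnn(x, hidden)
    loss = nn.functional.mse_loss(head(out), y)
    optimizer.zero_grad()
    loss.backward()  # without the detach() above, this raises a RuntimeError on step 2
    optimizer.step()
    losses.append(loss.item())
```

The carried-over `hidden` still holds its values; only its link to the previous batch's graph is dropped.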

If your model is not recurrent, then something else odd is happening. Can you provide a small example that reproduces the error?

It looks like the for loop is causing the problem: the first iteration works fine, but the second one raises the error.

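import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.T = nn.Parameter(torch.randn(5, 5))
    # Q is computed once here, so every forward pass reuses this one graph
    self.Q = torch.mm(self.T, self.T.t())
  def forward(self, x0):
    # batched quadratic form x0^T Q x0: (32, 5) -> (32, 1)
    rhat = x0.unsqueeze(1).matmul(self.Q).matmul(x0.unsqueeze(2)).squeeze(2)
    return rhat

model = Net()
optimizer = optim.Adam(model.parameters(), lr=1e-2)

for i in range(2):
  r = torch.from_numpy(np.random.randn(32, 1).astype(np.float32))
  x0 = torch.from_numpy(np.random.randn(32, 5).astype(np.float32))
  rhat = model(x0)
  loss = F.mse_loss(rhat, r)
  optimizer.zero_grad()
  loss.backward()  # RuntimeError on the second iteration
  optimizer.step()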

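You should recalculate Q on every forward pass. Otherwise the part of the graph that computes Q from T is freed by the first loss.backward(), and the second iteration can’t backpropagate through it back to T. This also explains why the loss didn’t decrease with retain_graph=True: Q was computed once from the initial T, so the optimizer’s updates to T never changed Q.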
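A small sketch that reproduces the failure without the `Net` class, assuming a hypothetical loss `quad_loss` (the mean squared quadratic form over a random batch) in place of the thread's MSE loss. Reusing a `Q` built once fails on the second `backward()`, while rebuilding it each time does not:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T = nn.Parameter(torch.randn(5, 5))
x = torch.randn(32, 5)

def quad_loss(Q):
    # Hypothetical loss: mean of (x_i^T Q x_i)^2 over the 32 rows of x
    return ((x @ Q) * x).sum(dim=1).pow(2).mean()

# 1) Q built once and reused, as in the original Net.__init__
Q_cached = torch.mm(T, T.t())
quad_loss(Q_cached).backward()       # frees the mm node that produced Q_cached
try:
    quad_loss(Q_cached).backward()   # has to traverse that freed node again
    cached_ok = True
except RuntimeError as err:
    cached_ok = False
    print("cached Q:", err)

# 2) Q rebuilt from T on every pass, as in the fixed forward()
T.grad = None
for _ in range(2):
    Q_fresh = torch.mm(T, T.t())     # new graph each time
    quad_loss(Q_fresh).backward()    # each backward frees only its own graph
fresh_ok = T.grad is not None
```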

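import torch
import torch.nn as nn

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.T = nn.Parameter(torch.randn(5, 5))
  def forward(self, x0):
    # rebuild Q from the current T on every call, so each pass gets a fresh graph
    Q = torch.mm(self.T, self.T.t())
    rhat = x0.unsqueeze(1).matmul(Q).matmul(x0.unsqueeze(2)).squeeze(2)
    return rhat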
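To check the fix end to end, here is a sketch of a training loop around the corrected `Net`, assuming a hypothetical target generated from a fixed random symmetric matrix `Q_true` instead of the random targets in the original loop (random targets give the model nothing to learn). No `retain_graph` is needed:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.T = nn.Parameter(torch.randn(5, 5))
    def forward(self, x0):
        Q = torch.mm(self.T, self.T.t())  # rebuilt (and symmetric) on every call
        return x0.unsqueeze(1).matmul(Q).matmul(x0.unsqueeze(2)).squeeze(2)

# Hypothetical target: a fixed symmetric matrix the model should recover
A = torch.randn(5, 5)
Q_true = A @ A.t()
x0 = torch.randn(32, 5)
r = x0.unsqueeze(1).matmul(Q_true).matmul(x0.unsqueeze(2)).squeeze(2)

model = Net()
optimizer = optim.Adam(model.parameters(), lr=1e-2)
losses = []
for i in range(500):
    loss = F.mse_loss(model(x0), r)
    optimizer.zero_grad()
    loss.backward()   # fresh graph every iteration, so no retain_graph
    optimizer.step()
    losses.append(loss.item())

Q_learned = torch.mm(model.T, model.T.t()).detach()
```

Because `Q` is derived from `T` inside `forward()`, each optimizer update to `T` shows up in the next forward pass, which is what lets the loss go down.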

Doh! Absolutely. Thank you!
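Another option, assuming PyTorch 1.9 or later, is `torch.nn.utils.parametrize.register_parametrization`, which recomputes a constrained tensor from an unconstrained one each time it is accessed, so the "rebuild Q in forward" step happens automatically. A sketch, where `SymmetricPSD` is a hypothetical helper that maps a raw matrix `X` to `X @ X.T` as in the thread, and the data are random placeholders:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize

torch.manual_seed(0)

class SymmetricPSD(nn.Module):
    # Hypothetical helper: maps an unconstrained square matrix X to X @ X^T
    def forward(self, X):
        return X @ X.transpose(-1, -2)

class Net(nn.Module):
    def __init__(self, dim=5):
        super().__init__()
        self.Q = nn.Parameter(torch.randn(dim, dim))  # raw, unconstrained values
    def forward(self, x0):
        # Once parametrized, reading self.Q evaluates SymmetricPSD on the raw tensor
        return x0.unsqueeze(1).matmul(self.Q).matmul(x0.unsqueeze(2)).squeeze(2)

model = Net()
parametrize.register_parametrization(model, "Q", SymmetricPSD())
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

x0 = torch.randn(32, 5)   # placeholder inputs
r = torch.randn(32, 1)    # placeholder targets
for i in range(3):
    loss = F.mse_loss(model(x0), r)
    optimizer.zero_grad()
    loss.backward()  # each access to model.Q built a fresh graph
    optimizer.step()

Q = model.Q.detach()
```

The optimizer updates the raw tensor (stored as `model.parametrizations.Q.original`), and every read of `model.Q` returns a symmetric matrix derived from its current value.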