Calculating gradients w.r.t. inputs & backward

Hi, I’m trying to implement a contractive autoencoder with PyTorch, and I’m having trouble calculating gradients w.r.t. the inputs and then calling backward on top of them.
My output ‘hn_enc’ is a 10-dimensional tensor and my input ‘x’ is a 24-dimensional tensor.
I have a model called ‘csi_enc’, used for the forward pass, whose parameters are also trained with an additional loss, but for the moment I want to focus on the gradients w.r.t. ‘x’.
The goal is to compute the sum of the squared gradients of ‘hn_enc’ w.r.t. ‘x’ and to use it as a loss function in the backward pass.
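Note that this essentially means computing the Jacobian matrix (writing h for ‘hn_enc’)

$$
J = \begin{bmatrix}
\frac{\partial h_1}{\partial x_1} & \frac{\partial h_1}{\partial x_2} & \cdots & \frac{\partial h_1}{\partial x_{24}} \\
\vdots & \vdots & \ddots & \vdots \\
\frac{\partial h_{10}}{\partial x_1} & \frac{\partial h_{10}}{\partial x_2} & \cdots & \frac{\partial h_{10}}{\partial x_{24}}
\end{bmatrix}
$$

and using the sum of its squared entries as a loss function:

$$
\mathcal{L}_{\text{contractive}} = \sum_{i=1}^{10} \sum_{j=1}^{24} \left( \frac{\partial h_i}{\partial x_j} \right)^2 .
$$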
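A minimal sketch of that penalty, assuming ‘hn_enc’ keeps its 10 features on the last dimension, ‘x’ has shape (batch, 24), and each sample’s code depends only on its own input (so summing column q over the batch yields row q of every per-sample Jacobian at once). The small ‘nn.Sequential’ encoder and the helper ‘contractive_penalty’ are hypothetical stand-ins, not the actual ‘csi_enc’. Each ‘torch.autograd.grad(..., create_graph=True)’ call returns one Jacobian row as a fresh tensor that stays differentiable w.r.t. the encoder parameters:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for csi_enc: maps 24 input features to a 10-dim code.
encoder = nn.Sequential(nn.Linear(24, 10), nn.Tanh())


def contractive_penalty(h, x):
    """Sum over the batch of the squared Jacobian entries dh/dx.

    Assumes the last dimension of h holds the code features and that each
    sample's code depends only on that sample's input.
    """
    penalty = h.new_zeros(())
    for q in range(h.shape[-1]):
        # create_graph=True keeps the result differentiable w.r.t. the
        # encoder parameters (and implies retain_graph=True).
        (row_q,) = torch.autograd.grad(h[..., q].sum(), x, create_graph=True)
        penalty = penalty + row_q.pow(2).sum()
    return penalty


x = torch.randn(8, 24, requires_grad=True)
h = encoder(x)
loss = contractive_penalty(h, x)
loss.backward()  # gradients now reach the encoder's weight and bias
```

Because ‘torch.autograd.grad’ returns the gradients instead of writing them into ‘x.grad’, nothing has to be zeroed between iterations.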

So I’m looping over the entries of ‘hn_enc’ and, for each one, computing the gradients w.r.t. ‘x’ as follows:

```python
import torch
import torch.optim as optim

enc_optimizer = optim.Adam(csi_enc.parameters(), lr=opt_lr_adam, eps=1e-8)
x.requires_grad = True
z_enc, hn_enc, cn_enc = csi_enc(x, x.shape[0], 1, False)  # forward pass
contractive_loss = 0.0
for q in range(10):
    # d(hn_enc entry q)/dx is written into x.grad (assumes the 10 features are on the last dim)
    hn_enc[..., q].sum().backward(retain_graph=True)
    x.grad.requires_grad = True
    contractive_loss = contractive_loss + torch.mean(x.grad**2)
    enc_optimizer.zero_grad()  # resets the model parameter gradients before the next x.grad computation
```

Note that I’m using ‘retain_graph=True’ so that the graph isn’t freed after the first backward call.
Note that I’m calling ‘enc_optimizer.zero_grad()’ so that the gradients of the model weights don’t accumulate.
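One thing worth checking, sketched here with a hypothetical ‘nn.Linear(24, 10)’ in place of ‘csi_enc’: a plain ‘.backward()’ call (without ‘create_graph=True’) stores in ‘x.grad’ a tensor with no autograd history, and setting ‘requires_grad = True’ on it only makes it a new leaf. A loss built from ‘x.grad’ therefore has no path back to the encoder weights, even with a single iteration. Asking for the input gradient with ‘create_graph=True’ keeps that path:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.Linear(24, 10)  # hypothetical stand-in for csi_enc

x = torch.randn(4, 24, requires_grad=True)
h = encoder(x)

# What the loop does: backward() without create_graph fills x.grad with a
# plain tensor that has no history linking it to the encoder's parameters.
h[:, 0].sum().backward(retain_graph=True)
x.grad.requires_grad_(True)  # makes x.grad a new leaf, nothing more
loss_from_x_grad = (x.grad ** 2).mean()
(w_grad_bad,) = torch.autograd.grad(loss_from_x_grad, encoder.weight, allow_unused=True)

# With create_graph=True the input gradient stays connected to the weights.
(g,) = torch.autograd.grad(h[:, 0].sum(), x, create_graph=True)
(w_grad_good,) = torch.autograd.grad((g ** 2).mean(), encoder.weight)

print(x.grad.grad_fn)     # None
print(w_grad_bad)         # None: this penalty never reaches the weights
print(w_grad_good.shape)  # torch.Size([10, 24])
```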
Now I want to update the ‘csi_enc’ parameters w.r.t. ‘contractive_loss’. But calling:

```python
contractive_loss.backward()
```
results in:

“one of the variables needed for gradient computation has been modified by an inplace operation”

When I call ‘contractive_loss.backward()’ with only a single iteration of the loop (i.e., computing only the gradients of the first ‘hn_enc’ entry w.r.t. ‘x’), it works fine. So I’m doing something wrong in the iterative summation over ‘x.grad’, but I can’t figure out how to do it properly.
Any help would be appreciated!
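The error message points at a likely cause, which this sketch tries to reproduce with a hypothetical ‘nn.Sequential(nn.Linear(24, 10), nn.Tanh())’ encoder: from the second iteration on, ‘.backward()’ adds the new gradient into the existing ‘x.grad’ in place, while ‘x.grad ** 2’ from the previous iteration has saved that same tensor for its own backward pass. A single iteration never triggers that in-place update. ‘torch.autograd.grad’ returns a fresh tensor on every call and leaves ‘x.grad’ alone, so the same loop written with it backpropagates cleanly. The helper names below are illustrative only:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.Sequential(nn.Linear(24, 10), nn.Tanh())  # hypothetical stand-in


def penalty_from_x_grad(x, n_rows):
    """Mimics the loop in the question: reads x.grad after each backward()."""
    h = encoder(x)
    penalty = 0.0
    for q in range(n_rows):
        h[:, q].sum().backward(retain_graph=True)  # accumulates into x.grad
        x.grad.requires_grad_(True)
        penalty = penalty + (x.grad ** 2).mean()  # saves x.grad for backward
    return penalty


def penalty_from_autograd_grad(x, n_rows):
    """Same sum, but each Jacobian row is a fresh, differentiable tensor."""
    h = encoder(x)
    penalty = 0.0
    for q in range(n_rows):
        (row_q,) = torch.autograd.grad(h[:, q].sum(), x, create_graph=True)
        penalty = penalty + (row_q ** 2).mean()
    return penalty


# One row: x.grad is never updated in place, so backward() runs.
penalty_from_x_grad(torch.randn(4, 24, requires_grad=True), n_rows=1).backward()

# Two rows: the second backward() modifies the saved x.grad in place.
try:
    penalty_from_x_grad(torch.randn(4, 24, requires_grad=True), n_rows=2).backward()
    inplace_error = False
except RuntimeError as err:
    inplace_error = "inplace operation" in str(err)
print(inplace_error)

# The autograd.grad version handles all 10 rows and reaches the weights.
encoder.zero_grad()
penalty_from_autograd_grad(torch.randn(4, 24, requires_grad=True), n_rows=10).backward()
print(encoder[0].weight.grad.abs().sum() > 0)
```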