Hi all,
I’m quite new to torch and am struggling to understand why my code is raising an error. I’m working on a seq2seq-type model; here is the relevant code:
def forward(self, x, y, is_eval=False, weighted=False):
    batch_size = x.size(0)
    n = x.size(1)
    y_pred = torch.zeros_like(y).cuda()
    encoder_outputs, _ = self.encoder(self.embed(x))
    decoder_outputs = torch.zeros(batch_size, n, self.decoder.output_size).cuda()
    dec_h = None
    for i in range(n):
        # Update the decoder
        embedded = self.embed(y_pred[:, i-1:i]) if i > 0 else torch.ones_like(y_pred[:, 0:1])
        decoder_outputs[:, i:i+1, :], dec_h = self.decoder(embedded, dec_h)
        prev_edge = torch.ones((batch_size, 1, 1)).cuda()  # This tracks s^(t)_i,j-1
        edge_h = None
        for j in range(i+1):
            theta, edge_h = self.edge_level(prev_edge, edge_h)
            edge_mlp_input = torch.cat((x[:, i:i+1, j:j+1],
                                        theta,
                                        encoder_outputs[:, i:i+1, :],
                                        decoder_outputs[:, i:i+1, :]), dim=2)
            edge_prob = self.activation(self.edge_mlp(edge_mlp_input))
            if is_eval:  # sample the sigmoid
                y_pred[:, i:i+1, j:j+1] = sample_vector(edge_prob) if not weighted else edge_prob
                prev_edge = y_pred[:, i:i+1, j:j+1]
            else:
                y_pred[:, i:i+1, j:j+1] = edge_prob
                prev_edge = edge_prob
    return y_pred
The Encoder and Decoder are just standard GRU RNNs from torch.
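For reference, here is a minimal sketch of stepping a GRU one time step at a time starting from dec_h = None. It assumes the decoder wraps a single-layer, unidirectional nn.GRU built with batch_first=True; the post does not say how it is configured. The function name decode_steps and the sizes are hypothetical. Note that with batch_first=True only the output is batch-first, while the hidden state stays (num_layers, batch, hidden). The sketch collects per-step outputs in a list and concatenates them rather than writing into a preallocated tensor in place.

```python
import torch
import torch.nn as nn


def decode_steps(batch: int = 4, steps: int = 3, in_size: int = 5, hidden: int = 8):
    """Run a GRU one time step at a time, starting from h = None (an all-zero state)."""
    torch.manual_seed(0)
    # Assumption: single-layer, unidirectional GRU with batch_first=True.
    gru = nn.GRU(input_size=in_size, hidden_size=hidden, batch_first=True)
    h = None  # None makes nn.GRU use a zero initial hidden state
    outputs = []
    for _ in range(steps):
        step_input = torch.randn(batch, 1, in_size)  # (batch, seq_len=1, features)
        out, h = gru(step_input, h)                  # out: (batch, 1, hidden); h: (1, batch, hidden)
        outputs.append(out)
    # Concatenate along the time axis instead of assigning into a preallocated tensor
    return torch.cat(outputs, dim=1), h


decoder_outputs, h_n = decode_steps()
print(decoder_outputs.shape, h_n.shape)  # torch.Size([4, 3, 8]) torch.Size([1, 4, 8])
```

For a single-layer unidirectional GRU, the output at the last step equals the final hidden state, so either can be fed to later layers.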
I’m training with a very standard training loop of the kind found in the torch examples (so is_eval is False, and the error is not coming from the is_eval branch of the forward pass), and I get the following error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [180, 1, 14]], which is output 0 of SliceBackward, is at version 106; expected version 92 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
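Following the hint in the message, a minimal sketch of enabling anomaly detection around a backward pass. It uses a toy exp-then-in-place example rather than the model above, and the function name backward_error_message is hypothetical. With detection on, PyTorch still raises the same RuntimeError, but it also warns with the traceback of the forward operation whose saved tensor was modified, which should point at the responsible line in forward.

```python
import torch


def backward_error_message() -> str:
    """Return the autograd error raised by a toy in-place modification, or "" if none."""
    x = torch.ones(3, requires_grad=True)
    # Scoped alternative to calling torch.autograd.set_detect_anomaly(True) globally.
    with torch.autograd.detect_anomaly():
        y = x.exp()   # exp saves its output y for use in the backward pass
        y.add_(1)     # in-place update bumps y's version counter
        try:
            y.sum().backward()
        except RuntimeError as err:
            return str(err)
    return ""


print(backward_error_message())
```

Anomaly detection slows autograd down noticeably, so it is best enabled only while debugging.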
Judging by its size, the tensor the error refers to is the y_pred[:, i-1:i] slice, but I can’t understand why what I’m doing causes a problem, since it doesn’t look to me like I’m modifying it. Is there something I haven’t grasped here?
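A small standalone sketch that appears to reproduce the same behaviour on CPU with toy sizes. Here hypothetical nn.Linear layers stand in for self.embed and self.edge_mlp, and run and shared_version_demo are illustrative names. The likely explanation is that y_pred[:, i-1:i] is a view that shares y_pred's version counter. Each later y_pred[:, i:i+1, j:j+1] = ... assignment is an in-place write into y_pred, so it also counts as modifying the slice that the embedding layer saved for its backward pass. The sketch also suggests that cloning the slice before using it avoids the error. `_version` is an internal attribute, read here only for illustration.

```python
import torch
import torch.nn as nn


def shared_version_demo() -> tuple:
    """A slice is a view: writing to another part of the base bumps the slice's version too."""
    base = torch.zeros(2, 3)
    row0 = base[:, 0:1]             # view of base, never written to directly
    before = row0._version          # _version is internal, used only for illustration
    base[:, 1:2] = 1.0              # in-place write to a *different* slice of base
    return before, row0._version


def run(clone_slice: bool) -> None:
    """Toy decoding loop: read row i-1 of y_pred, then write row i of y_pred in place."""
    torch.manual_seed(0)
    batch, n = 4, 3
    embed = nn.Linear(n, 5)   # hypothetical stand-in for self.embed
    head = nn.Linear(5, 1)    # hypothetical stand-in for self.edge_mlp
    y_pred = torch.zeros(batch, n, n)
    for i in range(n):
        if i > 0:
            prev_row = y_pred[:, i - 1:i]      # (batch, 1, n) view sharing y_pred's version counter
            if clone_slice:
                prev_row = prev_row.clone()    # separate storage; gradients still flow through clone
            h = embed(prev_row.squeeze(1))     # Linear saves its input to compute the weight gradient
        else:
            h = embed(torch.ones(batch, n))
        for j in range(n):
            # Slice assignment is an in-place copy into y_pred, bumping its version counter
            y_pred[:, i:i + 1, j:j + 1] = torch.sigmoid(head(h)).view(batch, 1, 1)
    y_pred.sum().backward()


print(shared_version_demo())  # the second number is larger than the first
run(clone_slice=True)         # backward succeeds
try:
    run(clone_slice=False)
except RuntimeError as err:
    print("without clone:", str(err).split(":")[0])
```

Another commonly used option is to collect each step's probabilities in a Python list and build y_pred with torch.cat or torch.stack at the end. That way, nothing autograd has saved is ever written to in place.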
Many thanks,
Jase