Hi,
I think using torch.tensor would break it, because torch.tensor has no grad_fn. What you could do is replace the torch.tensor call with:
new_output = torch.stack([tensor.sum() for tensor in torch.split(output, 10, dim=1)])
That should keep the gradients flowing, as torch.stack has a grad_fn whereas torch.tensor doesn't!
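
For illustration, here's a minimal sketch (the (4, 30) shape and the dummy output tensor are my own assumptions, not taken from your code) showing that the stacked result carries a grad_fn and backpropagates fine:

```python
import torch

# Hypothetical stand-in for your model output: requires grad so the graph is tracked.
output = torch.randn(4, 30, requires_grad=True)

# Split into chunks of 10 along dim=1, sum each chunk, and stack the scalar sums.
new_output = torch.stack([chunk.sum() for chunk in torch.split(output, 10, dim=1)])

print(new_output.grad_fn)    # <StackBackward0 ...> -> still part of the autograd graph
new_output.sum().backward()  # gradients flow back to `output`
print(output.grad.shape)     # torch.Size([4, 30])
```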