How do I reimplement the following logic in torch 1.7+?

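The logic is flawed, and PyTorch now properly raises an error: the second backward call uses stale forward activations, i.e. activations that were computed before the parameters they depend on were updated in place (for example by an optimizer step between the two backward calls), as described here.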
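The original code isn't shown, so the following is only a minimal sketch under assumptions: a hypothetical toy model, two hypothetical losses that share one forward pass, and plain SGD. It contrasts the pattern that now fails with two common ways to restructure it. Which fix is appropriate depends on whether the second loss is meant to use the old parameters (accumulate, then step once) or the updated ones (recompute the forward pass).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: a small two-layer model and two losses that
# share the same forward activations.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(16, 4)


def stale_update():
    """Pattern that newer PyTorch rejects: backward through stale activations."""
    opt.zero_grad()
    out = model(x)
    loss1 = out.pow(2).mean()
    loss1.backward(retain_graph=True)
    opt.step()  # updates the weights in place, bumping their version counters
    loss2 = out.abs().mean()  # still built on activations from before the step
    loss2.backward()  # raises RuntimeError: a saved weight was modified in place


def fix_accumulate():
    """Fix 1: run all backward passes first, then step once."""
    opt.zero_grad()
    out = model(x)
    loss1 = out.pow(2).mean()
    loss2 = out.abs().mean()
    # Same gradients as loss1.backward(retain_graph=True); loss2.backward()
    (loss1 + loss2).backward()
    opt.step()


def fix_recompute():
    """Fix 2: recompute the forward pass after each step."""
    opt.zero_grad()
    loss1 = model(x).pow(2).mean()
    loss1.backward()
    opt.step()

    opt.zero_grad()
    loss2 = model(x).abs().mean()  # fresh activations from the updated weights
    loss2.backward()
    opt.step()


try:
    stale_update()
except RuntimeError as err:
    print("stale pattern rejected:", type(err).__name__)

fix_accumulate()
fix_recompute()
```

Note that the two fixes are not numerically equivalent: the first computes both gradients at the same parameter values, while the second takes two separate optimizer steps.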