How can I optimise the gradient itself? Something like `grad.backward()`?

Calling `.backward()` directly on a gradient is currently not supported. See this thread for more discussion: How to implement gradient penalty in PyTorch
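
The thread linked above discusses the usual workaround: compute the gradient with `torch.autograd.grad(..., create_graph=True)`, which keeps the gradient in the autograd graph so a penalty built from it can be optimised with an ordinary `.backward()` call. Below is a minimal sketch of a WGAN-GP style gradient penalty along those lines; it assumes a PyTorch version with double-backward support, and the names `critic`, `real`, `fake`, and `lambda_gp` are placeholders for illustration, not part of any specific codebase.

```python
import torch

def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    # Random interpolation coefficients, broadcast over all non-batch dims
    alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    interpolated = (alpha * real + (1 - alpha) * fake).requires_grad_(True)

    # Critic scores on the interpolated samples
    scores = critic(interpolated)

    # Gradient of the scores w.r.t. the interpolated inputs.
    # create_graph=True keeps this gradient in the autograd graph,
    # so the penalty below can itself be backpropagated through.
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Penalise deviation of the per-sample gradient norm from 1
    grads = grads.view(grads.size(0), -1)
    return lambda_gp * ((grads.norm(2, dim=1) - 1) ** 2).mean()
```

In a critic update this term is simply added to the loss, e.g. `d_loss = critic(fake).mean() - critic(real).mean() + gradient_penalty(critic, real, fake)`, followed by the usual `d_loss.backward()` and optimiser step.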