How to preserve gradients when using external libraries?

I have a GAN that returns a predicted torch.Tensor. To guide this network, I have a loss function which is the sum of a binary cross-entropy loss and the Wasserstein distance. However, to calculate the Wasserstein distance, I am using the scipy.stats.wasserstein_distance function from the SciPy library. As you might know, this function requires two NumPy arrays as input. So, to use this function, I convert my predicted tensor and ground-truth tensor to NumPy arrays as follows:

pred_np = pred_tensor.detach().cpu().clone().numpy().ravel()
target_np = target_tensor.detach().cpu().clone().numpy().ravel()

W_loss = wasserstein_distance(pred_np, target_np)

Then, the total loss is obtained by adding W_loss to the BCE loss. I am not showing this part because it is not directly related to my question.

My concern is that I am detaching the tensors from the computation graph, so I suppose that when the optimizer updates the model parameters, W_loss will not be taken into account. I am a bit of a newbie, so I hope my question is clear, and I appreciate any answers in advance.
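To illustrate the concern, here is a minimal check along the lines of the code above (the shapes are arbitrary and the names just mirror my actual tensors): scipy.stats.wasserstein_distance returns a plain Python float with no grad_fn, so adding it to the BCE loss only shifts the loss by a constant.

import torch
from scipy.stats import wasserstein_distance

pred_tensor = torch.rand(8, requires_grad=True)
target_tensor = torch.rand(8)

W_loss = wasserstein_distance(
    pred_tensor.detach().cpu().numpy().ravel(),
    target_tensor.detach().cpu().numpy().ravel(),
)
print(type(W_loss))  # <class 'float'> -- no autograd graph attached
# Adding this constant to the BCE loss does not change the gradients
# that .backward() produces for the model parameters.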

Why not look for implementations of the Wasserstein distance in PyTorch instead?
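For the 1-D case that scipy.stats.wasserstein_distance handles, you can also write a differentiable version directly in PyTorch. A minimal sketch, assuming both tensors hold the same number of equally weighted samples (wasserstein_1d and bce_loss below are just placeholder names for your own pieces):

import torch

def wasserstein_1d(pred, target):
    # 1-D Wasserstein-1 distance between two equally sized, equally
    # weighted samples: sort both and take the mean absolute difference
    # of the sorted values. torch.sort is differentiable with respect
    # to the values, so gradients flow back to `pred`.
    pred_sorted, _ = torch.sort(pred.reshape(-1))
    target_sorted, _ = torch.sort(target.reshape(-1))
    return torch.mean(torch.abs(pred_sorted - target_sorted))

# Usage: keep everything as tensors so W_loss stays in the graph.
# `bce_loss` stands for whatever BCE term you already compute.
W_loss = wasserstein_1d(pred_tensor, target_tensor)
total_loss = bce_loss + W_loss
total_loss.backward()

Since nothing is detached or converted to NumPy here, W_loss contributes to the gradients of the generator parameters as intended.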
