MSE loss backward() fails when a numpy-processed tensor is used as the prediction tensor

I’m sorry, I’m new to PyTorch, so maybe it’s a stupid question :joy:

I want to train a model whose output is a spherical-coordinate tensor named y_pred with shape [B, C, H, W]. I need to transform this spherical tensor into a Cartesian tensor, and numpy has a convenient function, np.apply_along_axis, which can map my own coordinate-transformation function (named np_spherical2cartesian) over every 1-D slice along a specific axis.

So my process is:


import numpy as np
import torch
import torch.nn as nn

# 1. move the pytorch tensor to the cpu and convert it to a numpy array
# 2. transform the coordinates in the numpy array
pred_car = [np.apply_along_axis(np_spherical2cartesian, 0, y_pred[i].cpu().detach().numpy())
            for i in range(bs)]  # bs is the batch size

# 3. convert back from numpy to a pytorch tensor for the mse calculation
pred_car = [torch.from_numpy(i) for i in pred_car]
pred_car = torch.stack(pred_car)

criterion = nn.MSELoss()
loss = criterion(pred_car, target_car)

I can print the loss value, but I get this error when I run loss.backward():

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

It seems that the gradient information of the post-processed tensor has disappeared. Is there any way to solve this?

Thank you guys very much!!

Hi Ko!

This is to be expected. Calling y_pred[i].cpu().detach()
“breaks the computation graph” and pytorch can no longer
backpropagate through this part of the processing.
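You can see this directly on a toy tensor (a minimal sketch, independent of your model):

import torch

x = torch.randn(3, requires_grad=True)
y = 2 * x                 # tracked by autograd
print(y.grad_fn)          # <MulBackward0 object at 0x...>

z = y.detach()            # cut off from the graph
print(z.grad_fn)          # None
print(z.requires_grad)    # False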

When you run your forward pass through your model and loss
function, pytorch’s autograd facility remembers what functions
were called (the “computation graph”) so that when you then
backpropagate, it can calculate the gradients of the loss function.

But the functions you call in your forward pass must be
autograd-aware, and numpy is not.
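The numpy round trip in your code illustrates this: torch.from_numpy() produces a brand-new leaf tensor with no grad_fn, so nothing upstream of it can receive gradients (again, a minimal sketch):

import numpy as np
import torch

x = torch.randn(3, requires_grad=True)
a = np.sin(x.detach().numpy())      # numpy is unaware of autograd
t = torch.from_numpy(a)             # new leaf tensor, no history
print(t.grad_fn, t.requires_grad)   # None False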

You either have to rewrite your numpy processing using pytorch
tensor operations (in which case you get autograd “for free”), or
you have to package this processing (and you can use numpy
inside it) as a pytorch torch.autograd.Function and provide your
own backward() method for your Function. Sketches of both
approaches follow.
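Here is a minimal sketch of the first approach. It assumes y_pred has shape [B, 3, H, W] with the channels ordered (r, theta, phi) — adjust the formulas to match whatever convention your np_spherical2cartesian uses:

import torch

def spherical2cartesian(sph):
    # sph: [B, 3, H, W], channels assumed to be (r, theta, phi)
    r, theta, phi = sph.unbind(dim=1)
    x = r * torch.sin(theta) * torch.cos(phi)
    y = r * torch.sin(theta) * torch.sin(phi)
    z = r * torch.cos(theta)
    return torch.stack((x, y, z), dim=1)    # [B, 3, H, W]

pred_car = spherical2cartesian(y_pred)      # stays on the computation graph
loss = torch.nn.MSELoss()(pred_car, target_car)
loss.backward()                             # works now

And a sketch of the second approach. The forward() can use numpy freely, but you must then write backward() yourself; np_sph2cart_vjp below is a hypothetical helper you would implement to compute the vector-Jacobian product of your transform (numpy is fine inside it, too):

import numpy as np
import torch

class Spherical2Cartesian(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        ctx.save_for_backward(inp)
        # numpy is fine here because we supply backward() ourselves
        out = np.apply_along_axis(np_spherical2cartesian, 1,
                                  inp.detach().cpu().numpy())
        return torch.from_numpy(out).to(inp)

    @staticmethod
    def backward(ctx, grad_output):
        inp, = ctx.saved_tensors
        # hypothetical helper: computes the gradient with respect to
        # the input, given the gradient with respect to the output
        grad_np = np_sph2cart_vjp(inp.detach().cpu().numpy(),
                                  grad_output.detach().cpu().numpy())
        return torch.from_numpy(grad_np).to(inp)

pred_car = Spherical2Cartesian.apply(y_pred)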

Good luck.

K. Frank