Is there a way I can call a function in the middle of a series of operations but ignore that function's gradient?

For example,

output_1 = model(inputs)
# ignore the grads of torch.exp()
middle_func_output = torch.exp(output_1)
loss = loss_func(middle_func_output)
loss.backward()

I would like to ignore the gradient contribution of torch.exp() and treat backprop as if it were the identity: the gradient arriving at middle_func_output should be passed back to output_1 unchanged in the backward pass, even though torch.exp() is still applied in the forward pass.
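In equation form, writing x for output_1, y for middle_func_output and L for the loss, the request amounts to replacing the chain-rule factor of exp with the identity (often called a "straight-through" gradient). This is a sketch of the intent, not something autograd does by default:

```latex
y = \exp(x), \qquad
\underbrace{\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \odot \exp(x)}_{\text{default autograd}}
\quad\longrightarrow\quad
\underbrace{\frac{\partial L}{\partial x} := \frac{\partial L}{\partial y}}_{\text{desired}}
```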

I tried:

with torch.no_grad():
    middle_func_output = torch.exp(output_1)

But this raises the error: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn.
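The error occurs because torch.no_grad() does more than skip one function's gradient: it stops autograd from recording anything, so the result has no grad_fn and the graph back to the model is cut. The minimal reproduction below uses a random tensor x as a stand-in for output_1 (an assumption; no model is involved):

```python
import torch

# Stand-in for output_1: a leaf tensor that requires grad.
x = torch.randn(3, requires_grad=True)

with torch.no_grad():
    # Computed outside autograd: no graph node is recorded for exp.
    y = torch.exp(x)

print(y.requires_grad, y.grad_fn)  # False None

loss = y.sum()  # loss does not require grad either, since y does not
error_message = None
try:
    loss.backward()
except RuntimeError as err:
    # Same "element 0 of tensors does not require grad ..." error as above.
    error_message = str(err)
    print("RuntimeError:", error_message)
```

So the forward value is correct, but there is nothing left for backward() to propagate through, and x never receives a gradient.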

My actual middle_func is not torch.exp() but a multivariable function.
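For a multivariable function, one option is a custom torch.autograd.Function that runs the real function in forward() and returns the incoming gradient unchanged for every tensor input in backward(). This is a minimal sketch under loud assumptions: StraightThrough and middle_func are hypothetical names, every input is assumed to have the same shape as the output, and "identity gradient for each input" is only one possible reading of "ignore the gradient" when there are several inputs:

```python
import torch


class StraightThrough(torch.autograd.Function):
    """Hypothetical helper: apply fn in the forward pass, identity gradient in backward.

    Assumes every tensor input has the same shape as fn's output, so the
    incoming gradient can be handed to each input as-is.
    """

    @staticmethod
    def forward(ctx, fn, *inputs):
        # Remember how many tensor inputs need a gradient slot in backward().
        ctx.num_inputs = len(inputs)
        # Operations inside forward() are not recorded by autograd.
        return fn(*inputs)

    @staticmethod
    def backward(ctx, grad_output):
        # None for the non-tensor `fn` argument, then the unchanged gradient
        # for each tensor input (straight-through).
        return (None,) + (grad_output,) * ctx.num_inputs


def middle_func(a, b):
    # Hypothetical stand-in for the real multivariable middle_func.
    return torch.exp(a) * b


a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)

out = StraightThrough.apply(middle_func, a, b)  # forward value: exp(a) * b
out.sum().backward()

print(a.grad)  # all ones: d(sum)/d(out) passed straight through
print(b.grad)  # all ones as well
```

If only some inputs should receive the pass-through gradient, backward() can return None for the others instead.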

You could use an approach similar to this one posted by @tom, where x_backward corresponds to the tensor that should be used during backpropagation and x_forward to the tensor that is used only in the forward pass.
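Assuming the linked post describes the common .detach() trick (the link itself is not reproduced here), a minimal sketch looks like this. The helper name straight_through is hypothetical, and a random tensor stands in for output_1:

```python
import torch


def straight_through(x_forward, x_backward):
    # Forward value: x_backward + (x_forward - x_backward) == x_forward.
    # Backward: the difference is detached, so gradients flow only into
    # x_backward, with an identity Jacobian.
    return x_backward + (x_forward - x_backward).detach()


output_1 = torch.randn(5, requires_grad=True)  # stand-in for model(inputs)

# exp is applied in the forward pass but ignored in the backward pass.
middle_func_output = straight_through(torch.exp(output_1), output_1)
loss = middle_func_output.sum()
loss.backward()

print(output_1.grad)  # all ones, as if exp were the identity
```

Because only tensor arithmetic and .detach() are involved, the same line works unchanged for a multivariable middle_func, as long as its output has the same shape as the tensor passed as x_backward.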