For example,
output_1 = model(inputs)
# ignore the grads of torch.exp()
middle_func_output = torch.exp(output_1)
loss = loss_func(middle_func_output)
loss.backward()
I would like the backward pass to ignore the gradient of torch.exp() and backpropagate as if torch.exp()
were the identity, even though it is still applied in the forward pass.
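From what I've read, this sounds like straight-through behaviour; for the elementwise exp case I think the usual detach trick would look something like this (my own sketch, untested):

# forward value is still torch.exp(output_1), but only the identity
# term output_1 stays in the graph, so the backward gradient is 1
middle_func_output = output_1 + (torch.exp(output_1) - output_1).detach()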
I tried:
with torch.no_grad():
    middle_func_output = torch.exp(output_1)
But this returns the error: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn. (I assume this is because anything computed under no_grad is detached from the graph, so loss ends up with no grad_fn.)
My actual middle_func is not torch.exp(), but a multivariable function.
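Would a custom torch.autograd.Function be the right approach? A sketch of what I have in mind, with middle_func as a placeholder for my real multivariable function:

import torch

def middle_func(x):
    # placeholder for my real multivariable function
    return torch.exp(x)

class PassThroughGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # apply the function in the forward pass as usual
        return middle_func(x)

    @staticmethod
    def backward(ctx, grad_output):
        # act as if the function were the identity in the backward pass
        return grad_output

middle_func_output = PassThroughGrad.apply(output_1)
loss = loss_func(middle_func_output)
loss.backward()  # gradients reach the model as if middle_func were the identity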