ronda
June 5, 2021, 3:12pm
Hi, I have a question,
import torch

class MyFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # return input.detach()
        return torch.tensor([1.0], requires_grad=False)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

input = torch.tensor([2.], requires_grad=True)
output = MyFunc.apply(input)
print(output.requires_grad)  # True
I don’t understand why output.requires_grad is True when the tensor returned from forward was created with requires_grad=False.
I am new to PyTorch, so I hope this isn’t a bother.
Thanks
This is pretty tricky indeed, but this behavior is noted in the docs:
By default, all the output Tensors that are of differentiable type will be set to require gradient and have all autograd metadata set for them. If you don’t want them to require gradients, you can use the mark_non_differentiable method mentioned above. For output Tensors that are not of differentiable type (integer types for example), they won’t be marked as requiring gradients.
See: Extending PyTorch — PyTorch 2.1 documentation
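For example, to make the output in the snippet above stop requiring grad, mark it inside forward. A minimal sketch using ctx.mark_non_differentiable:

import torch

class MyFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        out = torch.tensor([1.0])
        # Tell autograd this output is non-differentiable, so it will not
        # be marked as requiring grad and gets no gradient edge.
        ctx.mark_non_differentiable(out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

input = torch.tensor([2.], requires_grad=True)
output = MyFunc.apply(input)
print(output.requires_grad)  # False

With the output marked, output.grad_fn is also None, since no backward edge is created for it.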
ronda
June 9, 2021, 4:07pm
Thank you very much for your answer. May I ask where I can view the source code for this? I would like to understand some of the details.
Check here for the entry point of Python custom functions once you call apply:
PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
{
  HANDLE_TH_ERRORS
  RECORD_FUNCTION(
      ((PyTypeObject*)cls)->tp_name,
      std::vector<c10::IValue>(),
      at::sequence_number::peek());

  THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));
  if (!backward_cls) return nullptr;
  THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, nullptr));
  // ...
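You can observe from Python what this entry point sets up: apply builds a backward node from the function's _backward_cls and attaches it to the wrapped outputs. A small self-contained check, reusing the shape of the question's example:

import torch

class MyFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return torch.tensor([1.0])

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

output = MyFunc.apply(torch.tensor([2.], requires_grad=True))
print(output.requires_grad)  # True: a float output gets autograd metadata
print(output.grad_fn)        # <torch.autograd.function.MyFuncBackward object at 0x...>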
The function below, _wrap_outputs, is where the grad_fn is set on the outputs: set_gradient_edge is called for every output of differentiable type (i.e., whose scalar type is floating-point or complex).
std::vector<c10::optional<Variable>> _wrap_outputs(
    const variable_list &input_vars,
    const std::unordered_set<at::TensorImpl*> &non_differentiable,
    const std::unordered_set<at::TensorImpl*> &dirty_inputs,
    const at::ArrayRef<c10::optional<Variable>> raw_outputs,
    const std::shared_ptr<Node> &cdata) {
  std::unordered_set<at::TensorImpl*> inputs;
  inputs.reserve(input_vars.size());
  for (auto& var : input_vars) {
    inputs.emplace(var.unsafeGetTensorImpl());
  }
  // ...
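And as the quoted docs say, outputs of non-differentiable scalar types never get a gradient edge here, even without any explicit marking. A minimal sketch to confirm:

import torch

class IntFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Integer scalar types are not differentiable, so _wrap_outputs
        # will not call set_gradient_edge on this output.
        return torch.tensor([1], dtype=torch.int64)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

output = IntFunc.apply(torch.tensor([2.], requires_grad=True))
print(output.requires_grad)  # False: integer outputs never require grad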