Hello. I’m trying to use torch.autograd.Function to define a function with a custom forward and backward pass. I’m following the guide in the PyTorch documentation, which gives the following example:
```python
import torch
from torch.autograd import Function

class Exp(Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result

# Use it by calling the apply method:
x = torch.randn(3, requires_grad=True)
output = Exp.apply(x)
```
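One way to confirm that a custom backward is actually wired up and matches its forward is torch.autograd.gradcheck, which compares the analytical gradient against finite differences. A minimal sketch, assuming double-precision inputs (gradcheck is meant to be used with float64) and repeating the Exp class so the snippet is self-contained:

```python
import torch
from torch.autograd import Function, gradcheck

class Exp(Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        ctx.save_for_backward(result)  # exp's derivative is exp itself
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result

# float64 input that requires grad, so finite differences are accurate
x = torch.randn(5, dtype=torch.double, requires_grad=True)
ok = gradcheck(Exp.apply, (x,))  # raises if the gradients disagree
print(ok)  # True
```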
My code looks something like this:
```python
import torch
from torch.autograd import Function

class Exp(Function):
    @staticmethod
    def forward(ctx, i):
        # I make some changes to the input i, which works as intended
        return i

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = torch.relu(grad_output)
        print(grad_output.shape)
        return grad_output
```
I expected this to print the shape of grad_output every time the function was called. The function is called in the forward method of a torch.nn.Module subclass (which looks very similar to the Linear layer); let’s call this class Cats. Cats is instantiated in the __init__ method of a model.
But the code never prints the grad_output shape. What could be the reason?
Hello! Thanks for the reply.
I did not explicitly call .backward(); however, I did apply the function as described in the PyTorch documentation, by calling Exp.apply(input).
The following code reproduces the problem (assuming the proper imports have been made):
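For reference, Exp.apply(input) only runs forward. Autograd calls a custom backward later, during a backward pass (for example loss.backward() or torch.autograd.grad) on a result that depends on the output, and only if at least one input required grad. A minimal sketch of that timing; the class name PrintShape and the calls list are hypothetical, and forward returns a clone as a stand-in for the real transformation:

```python
import torch
from torch.autograd import Function

calls = []  # hypothetical log of every backward invocation

class PrintShape(Function):
    @staticmethod
    def forward(ctx, i):
        return i.clone()  # stand-in for "some changes to the input"

    @staticmethod
    def backward(ctx, grad_output):
        calls.append(tuple(grad_output.shape))
        print(grad_output.shape)
        return torch.relu(grad_output)

x = torch.randn(4, requires_grad=True)
y = PrintShape.apply(x)
print(len(calls))  # 0: apply alone never triggers backward

y.sum().backward()  # now autograd walks the graph and calls backward
print(calls)  # [(4,)]
```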
```python
import math

import torch
import torch.nn.functional as F
from torch.nn import Parameter, init

class autograd_func(torch.autograd.Function):
    def forward(self, layer):
        alpha = torch.ones(layer.shape[0])
        Alpha = (alpha)
        return layer, Alpha

    def backward(self, grad_output):
        print("checking backward call for autograd_func")
        return grad_output

class custom_linear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(custom_linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        A, B = autograd_func.apply(self.weight.data)
        self.weight.data = A
        # B does get used in the complete implementation, but it does not make a difference
        out = F.linear(input, self.weight, self.bias)
        return out

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)

x = torch.randn([1, 100])
y_star = torch.ones([1, 10])
cust_layer = custom_linear(100, 10)
y = cust_layer(x)
# print(y.shape, y_star.shape)
criterion = torch.nn.MSELoss()
loss = criterion(y, y_star)
loss.backward()
```
OK, so the comments about @staticmethod were correct; you just need a few more changes. For example, change self to ctx, and remove the uses of .data, such as in A,B = autograd_func.apply(self.weight.data). Using .data is discouraged, and the tensor it returns does not require grad, so autograd never records autograd_func in the graph and its backward is never called. Make sure to follow the exact layout of the forward and backward arguments in the example below. Do read through the tutorial (Extending PyTorch — PyTorch 1.11.0 documentation) too.
```python
import torch
from torch.autograd import Function

# Inherit from Function
class LinearFunction(Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias
```
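Applied to the repro code, those suggestions might look roughly like the sketch below. The names AutogradFunc, CustomLinear and backward_calls are hypothetical, the clone in forward is a placeholder for the real transformation, and the transformed weight is used directly in F.linear instead of being written back into self.weight.data (an assumption about what the full implementation needs). Because forward has two outputs, backward receives two gradients; alpha is marked non-differentiable since it carries no gradient here.

```python
import math

import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn import Parameter, init

backward_calls = []  # hypothetical: records grad shapes so we can see backward ran


class AutogradFunc(Function):
    @staticmethod
    def forward(ctx, layer):
        # Placeholder for the real transformation; clone so the output is a new tensor.
        alpha = torch.ones(layer.shape[0], dtype=layer.dtype, device=layer.device)
        ctx.mark_non_differentiable(alpha)  # alpha carries no gradient
        return layer.clone(), alpha

    @staticmethod
    def backward(ctx, grad_layer, grad_alpha):
        # One gradient arrives per forward output; return one per forward input.
        backward_calls.append(tuple(grad_layer.shape))
        return grad_layer


class CustomLinear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = Parameter(torch.empty(out_features, in_features))
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x):
        # Pass the Parameter itself (not .data) so autograd records AutogradFunc.
        w, _alpha = AutogradFunc.apply(self.weight)
        # Use w directly rather than writing it back into self.weight.data.
        return F.linear(x, w)


layer = CustomLinear(100, 10)
x = torch.randn(1, 100)
loss = torch.nn.MSELoss()(layer(x), torch.ones(1, 10))
loss.backward()
print(backward_calls)  # [(10, 100)]
```

Keeping the transformation inside the graph on every forward call is what lets the custom backward run; if the transformed values really must be stored back into the parameter, that is a separate design question this sketch does not address.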
Thanks for the reply and the links!
Changing it to ctx did not work. I was unaware that .data is discouraged. I will rewrite that part of the code and try again, since autograd does not record changes made through .data in the backward graph (correct me if I’m wrong here).
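That matches how autograd behaves: a tensor obtained through .data has requires_grad=False, so when it is passed to a custom Function, apply records no graph node and the custom backward can never run, even after loss.backward(). A minimal sketch contrasting the two calls; Tagged and calls are hypothetical names:

```python
import torch
from torch.autograd import Function

calls = []  # hypothetical log of backward invocations

class Tagged(Function):
    @staticmethod
    def forward(ctx, layer):
        return layer.clone()

    @staticmethod
    def backward(ctx, grad_output):
        calls.append("backward")
        return grad_output

w = torch.nn.Parameter(torch.randn(10, 100))

# Through .data: the input does not require grad, so no node is recorded.
out_data = Tagged.apply(w.data)
print(out_data.requires_grad, out_data.grad_fn)  # False None

# Through the Parameter itself: the node is recorded and backward runs.
out = Tagged.apply(w)
out.sum().backward()
print(calls)  # ['backward']
```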
I will try out your suggestions and update you.
Thank you for the reply!
This did work, yes. A tensor created with torch.randn(..., requires_grad=True) is tracked by autograd. Working through .data is discouraged, and changes made through it are not tracked by autograd. I’ll make the necessary changes and update you.