@ptrblck @fmassa
Problem: I need to attach an attribute to a tensor and have it carried across both the forward and backward passes.
For example, we want to store a “local” counter on a tensor and have it propagated through both forward and backward.
The forward pass works as expected, but the backward pass does not; the printed output is:
output local: 1
output local: 2
grad_input local: 1
grad_input local: 1
Given the chaining logic, I expected the last line to be “grad_input local: 2”, since the first layer's backward receives the grad_input produced by the second layer's backward. Another interesting thing: after adding pdb.set_trace() in backward, the backward pass also works as expected.
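As a side check (plain Python semantics, not part of the repro): an attribute set this way lives on one specific tensor object, so any tensor created by a new op is a different object and starts without it:

import torch

a = torch.randn(3, requires_grad=True)
b = a * 2
b.local = 1                           # plain Python attribute on this object
print(hasattr(b, 'local'))            # True
print(hasattr(b.clone(), 'local'))    # False: clone() returns a new tensor object

Full repro: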
import pdb
import torch
from torch import nn
from torch.autograd import Function
class LinearFunction(Function):
    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        if not hasattr(input, 'local'):
            output.local = 1
        else:
            output.local = input.local + 1
        print(f'output local: {output.local}')
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        # pdb.set_trace()  # NOTE: if this line is uncommented, backward works as expected
        if not hasattr(grad_output, 'local'):
            grad_input.local = 1
        else:
            grad_input.local = grad_output.local + 1
        print(f'grad_input local: {grad_input.local}')
        return grad_input, grad_weight, grad_bias


class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super(Linear, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter('bias', None)
        # Not a very smart way to initialize weights
        self.weight.data.uniform_(-0.1, 0.1)
        if self.bias is not None:
            self.bias.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        # See the autograd section for explanation of what happens here.
        return LinearFunction.apply(input, self.weight, self.bias)


if __name__ == '__main__':
    x = torch.randn((2, 4), requires_grad=True)
    y1 = Linear(input_features=4, output_features=3)(x)
    y2 = Linear(input_features=3, output_features=4)(y1)
    out = y2.sum()
    out.backward()
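For comparison, below is a minimal sketch of the one carrier autograd does guarantee between a Function's forward and its own backward: ctx. Stashing non-tensor state on ctx is a documented pattern, though it only keeps the value within a single Function call and does not chain it across layers the way the tensor attribute above is meant to. TaggedIdentity and tag are hypothetical names used for illustration.

import torch
from torch.autograd import Function

class TaggedIdentity(Function):
    @staticmethod
    def forward(ctx, input, tag):
        ctx.tag = tag                  # non-tensor state is safe to stash on ctx
        print(f'forward tag: {ctx.tag}')
        return input.clone()

    @staticmethod
    def backward(ctx, grad_output):
        print(f'backward tag: {ctx.tag}')
        return grad_output, None       # no gradient for the non-tensor tag

x = torch.randn(2, 4, requires_grad=True)
y = TaggedIdentity.apply(TaggedIdentity.apply(x, 1), 2)
y.sum().backward()
# prints: forward tag: 1, forward tag: 2, backward tag: 2, backward tag: 1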