I have a custom autograd Function that I originally wrote under PyTorch 1.1.0.
A newer PyTorch (1.10.0) now says the function must be used with static methods.
I have read the tutorial but still don't know how to make the function keep a variable (here self.V) across the forward and backward passes. Could anyone help?
import torch.nn.functional as F
from torch import autograd

class Exclusive(autograd.Function):
    def __init__(self, V):
        super(Exclusive, self).__init__()
        self.V = V  # memory bank shared between forward and backward

    def forward(self, inputs, targets):
        self.save_for_backward(inputs, targets)
        outputs = inputs.mm(self.V.t())
        return outputs

    def backward(self, grad_outputs):
        inputs, targets = self.saved_tensors
        grad_inputs = grad_outputs.mm(self.V) if self.needs_input_grad[0] else None
        # update the memory bank in place during the backward pass
        for x, y in zip(inputs, targets):
            self.V[y] = F.normalize((self.V[y] + x) / 2, p=2, dim=0)
        return grad_inputs, None
As you’ve recognized, because pytorch now requires the forward() and backward() functions of an autograd.Function to be static, you can no
longer store any state (your self.V) in an instance of your Exclusive.
For this reason, pytorch has added a context argument (ctx) to forward()
and backward() that is used to save information from the forward() call
for use in the backward() call.
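Here is a minimal sketch of how your Exclusive could be rewritten in the new style. The extra V argument to forward() and the ctx.V attribute are my additions (tensors that only need to survive until backward(), but aren't part of the gradient bookkeeping, can be stashed as plain attributes on ctx); the rest follows your original logic:

import torch.nn.functional as F
from torch import autograd

class Exclusive(autograd.Function):
    @staticmethod
    def forward(ctx, inputs, targets, V):
        # tensors that feed the gradient computation go through save_for_backward()
        ctx.save_for_backward(inputs, targets)
        # V is mutated in place in backward(), so stash it as a plain
        # attribute on ctx rather than passing it through save_for_backward()
        ctx.V = V
        return inputs.mm(V.t())

    @staticmethod
    def backward(ctx, grad_outputs):
        inputs, targets = ctx.saved_tensors
        V = ctx.V
        grad_inputs = grad_outputs.mm(V) if ctx.needs_input_grad[0] else None
        # same in-place memory-bank update as in your original code
        for x, y in zip(inputs, targets):
            V[y] = F.normalize((V[y] + x) / 2, p=2, dim=0)
        # backward() must return one value per forward() argument
        # (inputs, targets, V), so return None for targets and V
        return grad_inputs, None, None

Because you can no longer construct Exclusive(V), you invoke the function through apply() and pass V in on each call, for example outputs = Exclusive.apply(inputs, targets, V). A natural home for V itself is an attribute (or registered buffer) of whatever nn.Module wraps this call.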