Suppose I define a new function called foo. This function involves some new operations, so I define the autograd.Function myself. In the backward, I use autograd to compute the gradient automatically.
import torch

class foo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, params):
        out = torch.sum(x0 * params)
        ctx.save_for_backward(x0, params)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        x0, params = ctx.saved_tensors
        # Re-enable grad so the forward expression can be rebuilt and
        # differentiated by autograd inside backward.
        with torch.set_grad_enabled(True):
            out = torch.sum(x0 * params)
            # torch.autograd.grad returns a tuple, so unpack the single element.
            gradient, = torch.autograd.grad(outputs=out,
                                            inputs=params,
                                            grad_outputs=grad_output,
                                            allow_unused=True,
                                            retain_graph=True)
            dfdx, = torch.autograd.grad(outputs=out,
                                        inputs=x0,
                                        grad_outputs=grad_output,
                                        allow_unused=True,
                                        retain_graph=True)
        # backward must return one gradient per forward input, in order (x0, params).
        return dfdx, gradient
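As a sanity check (not part of the model itself), a backward written this way can be compared against numerical gradients with torch.autograd.gradcheck; the tensor shapes below are arbitrary and only for illustration.

# Numerical check of foo's backward; gradcheck prefers double precision.
x0 = torch.randn(5, dtype=torch.double, requires_grad=True)
params = torch.randn(5, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(foo.apply, (x0, params)))  # expect True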
Now I want to use foo inside a model:
import torch.nn as nn

class model(nn.Module):
    def __init__(self):
        super().__init__()
        # placeholder weight so self.parameters() is non-empty; the real model has its own
        self.weight = nn.Parameter(torch.randn(2, 3))

    def forward(self, x):
        parameters = self.parameters()
        new_parameters = process_parameters(parameters)  # reshape the weights for foo
        func = foo.apply
        out = func(x, new_parameters)
        return out
In the model class, I have a process_parameters function that reshapes the weights to a different dimension.
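For illustration only, process_parameters could be as simple as the sketch below; my real function reshapes the weights differently, but like this sketch it only applies differentiable tensor ops to the parameters.

def process_parameters(parameters):
    # Hypothetical placeholder: flatten every parameter into one 1-D tensor.
    # reshape/cat are autograd ops, so the result stays connected to the
    # original nn.Parameter objects in the graph.
    return torch.cat([p.reshape(-1) for p in parameters])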
Here is my question: how does PyTorch update the weights by gradient descent in this setup? How do I know that the model class has the new weights? Does PyTorch realize that new_parameters actually comes from self.parameters(), and does it update self.parameters() according to the new values from new_parameters?
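To make the question concrete, here is the kind of training loop I have in mind (the optimizer, data, and loss below are placeholders, with shapes matching the sketches above): does optimizer.step() here update the original self.parameters(), even though foo only ever sees new_parameters?

# Hypothetical training loop, only to make the question concrete.
m = model()
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
x = torch.randn(6, requires_grad=True)  # requires_grad because foo.backward also differentiates w.r.t. x0
target = torch.tensor(0.0)
for _ in range(10):
    optimizer.zero_grad()
    loss = (m(x) - target) ** 2  # placeholder loss
    loss.backward()
    optimizer.step()  # is m.weight (i.e. self.parameters()) updated here?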