I’m trying to build a custom module layer that itself uses a custom autograd function. Inside this function, it would be nice if I could reuse existing modules. As a simplified example, I wrapped a Linear layer inside my function and tried to pass its weights in as a parameter from the “surrounding” module.

I originally asked this as a follow-up question, but I think this topic will be easier to find for related issues when it is posted as a “stand-alone problem”.
In the forward pass everything seems to work out fine, but the backward computation inside the `linFct.backward()` method never seems to terminate. More precisely, after the gradient is passed via
`tmpLin(tmpDataVar).backward(grad_output)`, nothing more seems to happen once the Variable’s backward method calls the execution_engine.
```python
import torch


class linFct(torch.autograd.Function):
    def forward(self, fctDataIn, fctWeight):
        self.save_for_backward(fctDataIn, fctWeight)
        tmpDataVar = torch.autograd.Variable(fctDataIn)
        tmpWeightParam = torch.nn.Parameter(fctWeight)
        tmpLin = torch.nn.Linear(3, 2, bias=False)
        tmpLin.weight = tmpWeightParam
        outFct = tmpLin(tmpDataVar)
        return outFct.data

    def backward(self, grad_output):
        fctDataIn, fctWeight = self.saved_tensors
        tmpDataVar = torch.autograd.Variable(fctDataIn, requires_grad=True)
        tmpWeightParam = torch.nn.Parameter(fctWeight)
        tmpLin = torch.nn.Linear(3, 2, bias=False)
        tmpLin.weight = tmpWeightParam
        tmpLin.zero_grad()
        print(tmpDataVar.data)
        print(tmpWeightParam.data)
        print(grad_output)
        print("still here...")
        tmpLin(tmpDataVar).backward(grad_output)
        print("cannot reach this :( ")
        grad_fctDataIn = tmpDataVar.grad.data
        grad_fctWeight = tmpWeightParam.grad.data
        print(grad_fctDataIn)
        print(grad_fctWeight)
        return grad_fctDataIn, grad_fctWeight


class linLayer(torch.nn.Module):
    def __init__(self):
        super(linLayer, self).__init__()
        self.wParam = torch.nn.Parameter(torch.randn(2, 3))
        self.fct = linFct()

    def forward(self, layerIn):
        return self.fct(layerIn, self.wParam)


x = torch.autograd.Variable(torch.randn(2, 3), requires_grad=True)
fct = linLayer()
print("forward...")
y = fct(x)
fct.zero_grad()
print("backward...")
fct(x).backward(torch.randn(2, 2))
print(x.grad.data)
print(fct.wParam.grad.data)
```
I assume this is not how autograd is meant to be used, and that it is perhaps not allowed to run backward computations while autograd is already traversing the backward graph. I would be grateful for any advice on how to call backward inside a self-implemented backward method. Thank you!
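In case it helps to show what I’m ultimately after: here is a sketch of a workaround I imagine could avoid the nested backward call, by computing the Linear layer’s gradients by hand inside the Function instead of calling `backward()` within `backward()`. This uses the newer static-method `Function` style with a `ctx` argument; the class name `manualLinFct` and the manual gradient formulas are my own illustration, not something from the original code.

```python
import torch


class manualLinFct(torch.autograd.Function):
    # Same computation as torch.nn.Linear(3, 2, bias=False): out = x @ W^T,
    # but with the gradients written out manually instead of re-running
    # autograd inside backward().
    @staticmethod
    def forward(ctx, fctDataIn, fctWeight):
        ctx.save_for_backward(fctDataIn, fctWeight)
        return fctDataIn.mm(fctWeight.t())

    @staticmethod
    def backward(ctx, grad_output):
        fctDataIn, fctWeight = ctx.saved_tensors
        grad_fctDataIn = grad_output.mm(fctWeight)      # dL/dx = dL/dout @ W
        grad_fctWeight = grad_output.t().mm(fctDataIn)  # dL/dW = dL/dout^T @ x
        return grad_fctDataIn, grad_fctWeight


x = torch.randn(2, 3, requires_grad=True)
w = torch.nn.Parameter(torch.randn(2, 3))
out = manualLinFct.apply(x, w)
out.backward(torch.randn(2, 2))
print(x.grad.shape, w.grad.shape)
```

Since `backward` here only does plain tensor math, nothing re-enters the autograd engine, so the hang from my original version should not occur.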