I am trying to implement a custom autograd function for a Gaussian CRF layer, and a weird error occurs during apply(): forward() takes exactly 3 arguments (2 given). As I use the new-style, static-method-based forward and backward, I have no idea what is wrong in the current implementation. Can somebody point out the mistake?
The function code:
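import numpy as np
import torch
from scipy.linalg import cho_factor, cho_solve
from torch.autograd import Variable

class G_CRF_layer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, unary, pairwise):
        positive_mul = 1e-3
        # Cholesky-factorize the regularized pairwise matrix, then solve for x
        decomposition = cho_factor(pairwise.numpy() + positive_mul * np.eye(pairwise.size(0)))
        x = cho_solve(decomposition, unary.numpy())
        x = torch.from_numpy(x).float()
        # Keep what backward() needs
        ctx.data_for_backward = pairwise, x, decomposition
        return x

    @staticmethod
    def backward(ctx, grad_output):
        positive_mul = 1e-3
        pairwise, x, decomposition = ctx.data_for_backward
        # Gradient w.r.t. unary: reuse the factorization on the incoming gradient
        unary_grad = torch.from_numpy(cho_solve(decomposition, grad_output.data.numpy())).float()
        # Gradient w.r.t. pairwise: negative outer product of unary_grad and x
        pairwise_grad = (- unary_grad * x.view(1,-1))
        return Variable(unary_grad), Variable(pairwise_grad)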
And the apply function call:
A, B = np.array([[1,1],[1,2]]), np.array([3,5])
pairwise = Variable(torch.from_numpy(A).double(), requires_grad=True)
unary = Variable(torch.from_numpy(B).double().view(-1,1), requires_grad=True)
input = unary, pairwise
G_CRF_layer.apply(input)
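
A minimal sketch of the likely cause, assuming apply() passes its positional arguments straight through to forward() after ctx: G_CRF_layer.apply(input) hands over a single tuple, so forward() receives ctx plus one argument instead of ctx plus two. The Scale class below is a hypothetical stand-in for the layer (a plain matrix product, not the CRF solve), used only to show the call convention. The exact wording of the error depends on the Python version.

```python
import torch


class Scale(torch.autograd.Function):
    """Hypothetical two-input Function; stands in for G_CRF_layer."""

    @staticmethod
    def forward(ctx, unary, pairwise):
        ctx.save_for_backward(unary, pairwise)
        return pairwise @ unary

    @staticmethod
    def backward(ctx, grad_output):
        unary, pairwise = ctx.saved_tensors
        # One gradient per forward() input, in the same order
        return pairwise.t() @ grad_output, grad_output @ unary.t()


pairwise = torch.tensor([[1.0, 1.0], [1.0, 2.0]], dtype=torch.double, requires_grad=True)
unary = torch.tensor([[3.0], [5.0]], dtype=torch.double, requires_grad=True)
args = (unary, pairwise)

# Passing the tuple itself: forward() gets ctx and ONE argument, so it fails
tuple_call_failed = False
try:
    Scale.apply(args)
except TypeError as err:
    tuple_call_failed = True
    print("tuple passed as a single argument:", err)

# Unpacking the tuple: forward() gets ctx, unary, pairwise
out = Scale.apply(*args)
out.sum().backward()
```

If this is indeed the problem, the equivalent change in the question's code would be G_CRF_layer.apply(unary, pairwise) or G_CRF_layer.apply(*input).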