Hi,
I’m trying to customize layer following the tutorial Extending Pytorch. I created a similar customized function and nn.Module like LinearFunction and Linear in the example. However, when I use the my customized layer to replace the FC layer in the official mnist example, it shows no parameters for this layers. (While the testing step will work, the training process shows an error: ValueError: optimizer got an empty parameter list).
My customized function is:
class LinearDecomp(Function):
# Note that both forward and backward are @staticmethods
@staticmethod
def forward(ctx, input, coefs, dictionary, bias=None):
weight = torch.mm(dictionary, coefs).cuda() # reconstruct the weight
ctx.save_for_backward(input, weight, dictionary, coefs, bias)
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
# This function has only a single output, so it gets only one gradient
@staticmethod
def backward(ctx, grad_output):
input, weight, coefs, dictionary, bias = ctx.saved_variables
grad_input = grad_input = grad_coefs = grad_bias = None
grad_weight = grad_output.t().mm(input) # do not output
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
# if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input) # do not output grad_weight
if ctx.needs_input_grad[2]:
grad_coefs = dictionary.t().mm(grad_weight)
if ctx.needs_input_grad[3]:
grad_dictionary = grad_weight.t().mm(grad_coefs.t())
if bias is not None and ctx.needs_input_grad[4]:
grad_bias = grad_output.sum(0).squeeze(0)
return grad_input, grad_coefs, grad_dictionary, grad_bias
The layer is defined as:
class FCDecomp(nn.Module):
def __init__(self, coefs, dictionary, bias_val, input_features, output_features, bias=True):
super(FCDecomp, self).__init__()
self.dictionary = nn.Parameter(dictionary, requires_grad=False).cuda()
self.coefs = nn.Parameter(coefs, requires_grad=True).cuda()
if bias:
self.bias = nn.Parameter(bias_val, requires_grad=True).cuda()
else:
self.register_parameter('bias', None)
def forward(self, input):
return LinearDecomp.apply(input, self.coefs, self.dictionary, self.bias)
Could somebody please help me on this?