Hi,
I’m trying to extend the `torch.autograd.Function` class, but when I call `backward()`, execution never reaches my `backward` method. Any help would be appreciated.
My forward method:
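```python
import torch

def forward(ctx, input):
    _index = torch.Tensor().cuda()
    one = torch.ones(1).cuda()
    # Treat a 1-D input as a batch containing a single row
    if len(input.size()) == 1:
        input = torch.unsqueeze(input, 0)
    output = torch.zeros(input.size()).cuda()
    # Sample one index per row (constants.epsilon comes from my own constants module)
    _index = torch.multinomial(input + constants.epsilon, 1, False)
    # Build a one-hot tensor by writing 1 at each sampled index
    output.scatter_(1, _index, torch.unsqueeze(one.repeat(_index.size()[0]), 1))
    ctx.mark_dirty(input)
    ctx.save_for_backward(input)
    return _index.float()
```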
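For comparison, here is a minimal sketch of a custom `torch.autograd.Function` whose `backward` does get called. It assumes the current static-method Function API: `forward` and `backward` are both `@staticmethod`s, the function is invoked through `.apply(...)` rather than by instantiating the class, and at least one input has `requires_grad=True`. It also leaves out `ctx.mark_dirty`, which is meant for inputs that `forward` modifies in place (the input here is never modified in place). The class name `SampleOneHot`, the `1e-10` epsilon, and the straight-through gradient in `backward` are hypothetical choices for illustration, not the gradient the original code is meant to have. It runs on CPU, so no GPU is needed.

```python
import torch


class SampleOneHot(torch.autograd.Function):
    """Hypothetical example: sample one category per row, return a one-hot tensor.

    The backward pass is a straight-through estimator (the incoming gradient
    is passed back unchanged), chosen only for illustration.
    """

    @staticmethod
    def forward(ctx, probs):
        # Treat a 1-D input as a batch containing a single row.
        squeeze = probs.dim() == 1
        if squeeze:
            probs = probs.unsqueeze(0)
        # Draw one index per row; the input itself is not modified in place,
        # so ctx.mark_dirty is not needed.
        idx = torch.multinomial(probs + 1e-10, 1)
        one_hot = torch.zeros_like(probs).scatter_(1, idx, 1.0)
        ctx.squeeze = squeeze
        # Return a tensor with the same shape as the original input.
        return one_hot.squeeze(0) if squeeze else one_hot

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: grad_output already has the input's shape,
        # so it is returned as the gradient w.r.t. probs.
        return grad_output


# The input must require grad, and the Function is called via .apply.
probs = torch.tensor([0.2, 0.3, 0.5], requires_grad=True)
out = SampleOneHot.apply(probs)
out.sum().backward()
print(probs.grad)  # tensor([1., 1., 1.])
```

If `backward` is still never reached, it is worth checking that the output actually feeds into the tensor on which `backward()` is called and that the input has `requires_grad=True`; otherwise autograd has no reason to call the custom `backward`.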