Integrating a new loss function into autograd

I’m trying to integrate the SpatialWeightedClassNLLCriterion.c written by @kiranvaidhya into PyTorch. Its updateOutput and updateGradInput take an extra argument, weight_map. We have it working in legacy by adding an extra class, SpatialWeightedClassNLLCriterion, to the torch.legacy.nn module; a rough sketch of that wrapper is below.
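For context, the legacy wrapper looks roughly like this (a hypothetical reconstruction, not the actual class; the optional extra THNN arguments such as sizeAverage or class weights are omitted, and the backend call is assumed to take weight_map right after target, as in the autograd code further below):

import torch
from torch.legacy.nn.Criterion import Criterion

class SpatialWeightedClassNLLCriterion(Criterion):
    def __init__(self):
        super(SpatialWeightedClassNLLCriterion, self).__init__()
        self.output_tensor = torch.zeros(1)

    def updateOutput(self, input, target, weight_map):
        # assumed C signature: (state, input, target, weight_map, output, ...)
        self._backend.SpatialWeightedClassNLLCriterion_updateOutput(
            self._backend.library_state, input, target, weight_map, self.output_tensor)
        self.output = self.output_tensor[0]
        return self.output

    def updateGradInput(self, input, target, weight_map):
        self.gradInput.resize_as_(input).zero_()
        self._backend.SpatialWeightedClassNLLCriterion_updateGradInput(
            self._backend.library_state, input, target, weight_map, self.gradInput)
        return self.gradInput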

To use it with autograd, I added a new function, _make_function_weighted_class_criterion, similar to _make_function_class_criterion in nn/_functions/thnn/auto.py, to allow the additional argument weight_map. The forward() pass now works fine, but there are problems with backward(). I got the following error when I executed the backward() function:

Traceback (most recent call last):
  File "/workspace/test.py", line 22, in <module>
    loss.backward()
  File "/opt/pytorch/torch/autograd/variable.py", line 151, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
  File "/opt/pytorch/torch/autograd/__init__.py", line 98, in backward
    variables, grad_variables, retain_graph)
RuntimeError: WeightedNLLLoss returned an invalid number of gradient tensors (expected 3, but got 2)


  • The _make_function_weighted_class_criterion definition can be seen below.
def _make_function_weighted_class_criterion(class_name, update_output, update_grad_input, acc_grad_parameters):
    weight_arg_idx = -1

    for i, arg in enumerate(update_output.arguments):
        if arg.name.startswith('weight') and arg.name != 'weight_map':
            weight_arg_idx = i
            break

    buffers_idx = []
    additional_arg_idx = 0
    for arg in update_output.arguments[5:]:
        if not arg.name.startswith('weight') and arg.type == 'THTensor*':
            buffers_idx.append(additional_arg_idx)
        additional_arg_idx += 1

    def __init__(self, *args, **kwargs):
        Function.__init__(self)
        self.weight = kwargs.get('weight')
        self.additional_args = list(args)

    def forward(self, input, target, weight_map):
        self._backend = type2backend[type(input)]
        self.save_for_backward(input, target, weight_map)
        if weight_arg_idx >= 0:
            insert_idx = weight_arg_idx - 5  # state, input, target, weight_map, output
            self.additional_args.insert(insert_idx, self.weight)
        for idx in buffers_idx:
            self.additional_args.insert(idx, input.new(1))
        output = input.new(1)
        getattr(self._backend, update_output.name)(self._backend.library_state, input, target, weight_map,
                                                   output, *self.additional_args)
        return output

    def backward(self, grad_output):
        input, target, weight_map = self.saved_tensors
        grad_input = grad_output.new().resize_as_(input).zero_()
        getattr(self._backend, update_grad_input.name)(self._backend.library_state, input, target, weight_map,
                                                       grad_input, *self.additional_args)
        grad_output_expanded = grad_output.view(*repeat(1, grad_input.dim()))
        grad_input.mul_(grad_output_expanded.expand_as(grad_input))
        return grad_input, None

    return type(class_name, (Function,), dict(__init__=__init__, forward=forward, backward=backward))
  • Could someone point me in the right direction? Am I missing anything in the backward() function? A minimal reconstruction of the call that triggers the error is shown below.
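For reference, the failing call looks roughly like this (a reconstruction, not the exact test.py; the shapes, the 3-class setup, and the import path of the generated class are assumptions):

import torch
from torch.autograd import Variable
# assumed import path; use wherever _make_function_weighted_class_criterion registers the class
from torch.nn._functions.thnn import WeightedNLLLoss

# assumed spatial shapes: log-probabilities (N, C, H, W), targets (N, H, W), per-pixel weights (N, H, W)
input = Variable(torch.randn(2, 3, 4, 4), requires_grad=True)
target = Variable((torch.rand(2, 4, 4) * 3).long())
weight_map = Variable(torch.rand(2, 4, 4))

loss = WeightedNLLLoss()(input, target, weight_map)  # forward() works
loss.backward()  # RuntimeError: expected 3 gradient tensors, but got 2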

Hello,

When forward takes three parameters, autograd expects backward to return three gradients, even if weight_map is not intended to be differentiable. Thus you would want to return another None.
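Concretely, your backward would stay as posted except for the return statement (a sketch of just that change):

    def backward(self, grad_output):
        input, target, weight_map = self.saved_tensors
        grad_input = grad_output.new().resize_as_(input).zero_()
        getattr(self._backend, update_grad_input.name)(self._backend.library_state, input, target, weight_map,
                                                       grad_input, *self.additional_args)
        grad_output_expanded = grad_output.view(*repeat(1, grad_input.dim()))
        grad_input.mul_(grad_output_expanded.expand_as(grad_input))
        # one gradient (or None) for each of forward()'s inputs: input, target, weight_map
        return grad_input, None, None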

Is there a reason not to just write a proper Python Function subclass, instead of going through the auto-generation machinery? A rough sketch is below.
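Something along these lines, for example (a sketch only: it reuses the backend call from your snippet, assumes the THNN functions are named after your C file, and drops the optional extra arguments):

import torch
from torch.autograd import Function
from torch._thnn import type2backend

class SpatialWeightedNLLLoss(Function):  # hypothetical hand-written version
    def __init__(self, *additional_args):
        super(SpatialWeightedNLLLoss, self).__init__()
        self.additional_args = list(additional_args)

    def forward(self, input, target, weight_map):
        self._backend = type2backend[type(input)]
        self.save_for_backward(input, target, weight_map)
        output = input.new(1)
        # assumed C signature: (state, input, target, weight_map, output, ...)
        self._backend.SpatialWeightedClassNLLCriterion_updateOutput(
            self._backend.library_state, input, target, weight_map, output, *self.additional_args)
        return output

    def backward(self, grad_output):
        input, target, weight_map = self.saved_tensors
        grad_input = grad_output.new().resize_as_(input).zero_()
        self._backend.SpatialWeightedClassNLLCriterion_updateGradInput(
            self._backend.library_state, input, target, weight_map, grad_input, *self.additional_args)
        grad_input.mul_(grad_output.view(*([1] * grad_input.dim())).expand_as(grad_input))
        return grad_input, None, None  # one gradient per forward() input

That keeps the whole criterion in one readable place instead of threading it through the argument-index bookkeeping in auto.py.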

Best regards

Thomas


Thanks @tom!

We implemented it in C since it gives a performance boost.