Custom autograd.Function: must it be static?

Hello,

I noticed that a custom Function can be built in two different ways:

(1)

class Test(torch.autograd.Function):
    # Legacy, instance-based autograd.Function style: the Function is
    # instantiated (``function = Test()``) and per-call state is stashed on
    # ``self``.  Per the reply further down, this style is deprecated and
    # does not support gradients of gradients.
    def __init__(self):
        super(Test,self).__init__()

    def forward(self, x1, x2):
        # ``state`` is an undefined helper in this snippet — presumably some
        # user computation cached for use in backward; TODO confirm.
        self.state = state(x1)
        return torch.arange(8)

    def backward(self, grad_out):
        # NOTE(review): ``grad_input`` is computed but never used below.
        grad_input = grad_out.clone()
        # One gradient returned per forward input (x1, x2).
        return torch.arange(10,18),torch.arange(20,28)

# then use function = Test()

or
(2)

class Test(torch.autograd.Function):
    """New-style (static) custom autograd Function.

    Used as ``function = Test.apply``.  Intermediate state is stored on the
    per-call ``ctx`` object instead of ``self``, which is the documented,
    double-backward-capable way to write custom Functions.
    """

    @staticmethod
    def forward(ctx, x1, x2):
        # ``state`` is an undefined helper in this snippet; its result is
        # cached on ``ctx`` so ``backward`` could read it — TODO confirm
        # against the real helper.
        ctx.state = state(x1)
        return torch.arange(8)

    @staticmethod
    def backward(ctx, grad_out):
        # Fixed: the original cloned ``grad_out`` into a local that was never
        # used; the dead statement is removed.  One gradient is returned per
        # forward input (x1, x2).
        return torch.arange(10,18), torch.arange(20,28)

# then use function = Test.apply

The second option is the only one referred to in the official documentation, but it can be burdensome: if there are options controlling how the function should behave, they must be initialized globally.

I prefer option (1) since I can save an intermediate variable (state) by myself and also initialize the function with some options, but it is not present in the official documentation

Am I safe in using option (1)?

1 Like

Hi,

Option (1) is the old way to define Functions. It does not support gradients of gradients, and its support might be discontinued in the future (not sure when).
The second one is the way to go. Note that you can do exactly the same thing, since you can save arbitrary stuff on the ctx (the same way you would save it on self in (1)), and the apply method that calls forward accepts any parameters, so you can simply pass what you used to pass to __init__() here. That means you don’t need to define options globally — just pass them to the forward method.

7 Likes

Hi,
Could someone please help me in converting this snippet to the newer way of defining custom layers:

class GradReverse(Function):
    """Legacy-style gradient-reversal layer.

    Forward is the identity; backward scales the incoming gradient by
    ``-lambd``.  (Kept in the old instance-based style, since this is the
    snippet being asked about; the reply shows the ``@staticmethod`` port.)
    """

    def __init__(self, lambd):
        # Reversal strength, saved on the instance (legacy style).
        self.lambd = lambd

    def forward(self, x):
        # Identity; ``view_as`` returns a view so the graph stays connected
        # without copying the data.
        return x.view_as(x)

    def backward(self, grad_output):
        # Fixed: a stray markdown fence (```) was fused onto the end of this
        # line, making the snippet a syntax error; it has been removed, along
        # with the commented-out ``@staticmethod`` lines.
        return (grad_output * -self.lambd)

Thanks a lot,
Megh

Hi,

Sure! here you go:

class GradReverse(Function):
    # New-style port of the snippet above: ``lambd`` becomes an extra
    # argument to ``forward`` instead of an ``__init__`` parameter, and
    # per-call state lives on ``ctx``.
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        # Identity forward; ``view_as`` keeps the graph connected.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Two forward inputs -> two return values: the reversed gradient for
        # ``x`` and ``None`` for the non-differentiable ``lambd``.
        return (grad_output * -ctx.lambd), None

And you replace:

GradReverse(lambd)(inp)

with

GradReverse.apply(inp, lambd)
3 Likes

Hi @albanD,
Thanks a lot for your help!
Megh

Hello @albanD, I have converted the custom autograd.Function code to the new style, but I can’t save the torch model — do you have any suggestions, please?

class bin(torch.autograd.Function):
    """bin {0, 1} a real valued tensor."""
    # NOTE(review): the class name shadows the builtin ``bin``; a rename
    # (e.g. ``Binarize``) would be clearer, kept here to match the caller.

    @staticmethod
    def forward(ctx, inputs, threshold=DEFAULT_THRESHOLD):
        # ``DEFAULT_THRESHOLD`` is defined elsewhere in the user's code —
        # not visible in this snippet.
        ctx.threshold = threshold
        outputs = inputs.clone()
        # Hard threshold: values <= threshold become 0, values > threshold
        # become 1; the input tensor itself is left untouched.
        outputs[inputs.le(ctx.threshold)] = 0
        outputs[inputs.gt(ctx.threshold)] = 1
        return outputs

    @staticmethod
    def backward(ctx, gradOutput):
        # Gradient passes through unchanged for ``inputs``; ``None`` for the
        # non-tensor ``threshold`` argument.
        return gradOutput, None

torch.save(ckpt, savename)
File “/home/quannguyen/.local/lib/python3.6/site-packages/torch/serialization.py”, line 372, in save
_save(obj, opened_zipfile, pickle_module, pickle_protocol)
File “/home/quannguyen/.local/lib/python3.6/site-packages/torch/serialization.py”, line 476, in _save
pickler.dump(obj)
TypeError: can’t pickle bin objects

Hello, I want to ask about a problem I have faced. I hope you can help me.

def conv_offset2d(input,
                  offset,
                  weight,
                  stride=1,
                  padding=0,
                  dilation=1,
                  deform_groups=1):
    """Functional wrapper that applies ConvOffset2dFunction to the inputs."""
    # Only 4D (NCHW) inputs are supported.
    if input is not None and input.dim() != 4:
        raise ValueError(
            "Expected 4D tensor as input, got {}D tensor instead.".format(
                input.dim()))
    # NOTE(review): ``parse`` is an external helper not shown here; its
    # results are passed through ``_pair`` again below, which looks
    # redundant — confirm whether one of the two normalizations can go.
    stride=parse(stride)
    padding = parse(padding)
    dilation = parse(dilation)
    # Legacy instantiate-then-call usage; with a new-style Function this
    # would be ``ConvOffset2dFunction.apply(...)`` instead.
    f = ConvOffset2dFunction(
        _pair(stride), _pair(padding), _pair(dilation), deform_groups)
    return f(input, offset, weight)


class ConvOffset2dFunction(Function):
    # NOTE(review): this class is the source of the error discussed below —
    # it mixes the legacy style (state set on ``self`` in ``__init__``) with
    # ``@staticmethod`` forward/backward.  When decorated with
    # ``@staticmethod``, the first argument is the per-call ``ctx`` object,
    # NOT the instance constructed in ``__init__``, so ``self.stride``,
    # ``self.padding`` etc. will not exist at call time.  Either drop the
    # decorators (full legacy style) or move these options to extra
    # ``forward`` arguments stored on ``ctx`` (new style).
    def __init__(self, stride, padding, dilation, deformable_groups=1):
        super(ConvOffset2dFunction, self).__init__()
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.deformable_groups = deformable_groups

    @staticmethod
    def forward(self, input, offset, weight):
        # Keep the tensors needed to compute gradients in backward.
        self.save_for_backward(input, offset, weight)

        # Uninitialized output buffer with the conv output shape.
        output = input.new(*self._output_size(input, weight))

        self.bufs_ = [input.new(), input.new()]  # columns, ones

        # Only the CUDA float path is implemented.
        if not input.is_cuda:
            raise NotImplementedError
        else:
            if isinstance(input, torch.autograd.Variable):
                if not isinstance(input.data, torch.cuda.FloatTensor):
                    raise NotImplementedError
            else:
                if not isinstance(input, torch.cuda.FloatTensor):
                    raise NotImplementedError
            # ``deform_conv`` is an external CUDA extension; fills ``output``
            # in place.
            deform_conv.deform_conv_forward_cuda(
                input, weight, offset, output, self.bufs_[0], self.bufs_[1],
                weight.size(3), weight.size(2), self.stride[1], self.stride[0],
                self.padding[1], self.padding[0], self.dilation[1],
                self.dilation[0], self.deformable_groups)
        return output

    @staticmethod
    def backward(self, grad_output):
        input, offset, weight = self.saved_tensors

        grad_input = grad_offset = grad_weight = None

        if not grad_output.is_cuda:
            raise NotImplementedError
        else:
            if isinstance(grad_output, torch.autograd.Variable):
                if not isinstance(grad_output.data, torch.cuda.FloatTensor):
                    raise NotImplementedError
            else:
                if not isinstance(grad_output, torch.cuda.FloatTensor):
                    raise NotImplementedError
            # Gradients w.r.t. input and offset are computed together by the
            # extension when either is required.
            if self.needs_input_grad[0] or self.needs_input_grad[1]:
                grad_input = input.new(*input.size()).zero_()
                grad_offset = offset.new(*offset.size()).zero_()
                deform_conv.deform_conv_backward_input_cuda(
                    input, offset, grad_output, grad_input,
                    grad_offset, weight, self.bufs_[0], weight.size(3),
                    weight.size(2), self.stride[1], self.stride[0],
                    self.padding[1], self.padding[0], self.dilation[1],
                    self.dilation[0], self.deformable_groups)

            # Gradient w.r.t. the convolution weights.
            if self.needs_input_grad[2]:
                grad_weight = weight.new(*weight.size()).zero_()
                deform_conv.deform_conv_backward_parameters_cuda(
                    input, offset, grad_output,
                    grad_weight, self.bufs_[0], self.bufs_[1], weight.size(3),
                    weight.size(2), self.stride[1], self.stride[0],
                    self.padding[1], self.padding[0], self.dilation[1],
                    self.dilation[0], self.deformable_groups, 1)

        return grad_input, grad_offset, grad_weight

    def _output_size(self, input, weight):
        # Standard conv output-size formula applied per spatial dim.
        # NOTE(review): called from the @staticmethod forward via ``self``,
        # which will not resolve to an instance there — same mismatch as
        # noted on the class.
        channels = weight.size(0)

        output_size = (input.size(0), channels)
        for d in range(input.dim() - 2):
            in_size = input.size(d + 2)
            pad = self.padding[d]
            kernel = self.dilation[d] * (weight.size(d + 2) - 1) + 1
            stride = self.stride[d]
            output_size += ((in_size + (2 * pad) - kernel) // stride + 1, )
        if not all(map(lambda s: s > 0, output_size)):
            raise ValueError(
                "convolution input is too small (output would be {})".format(
                    'x'.join(map(str, output_size))))
        return output_size

@_hu please follow the doc on how to write a custom Function here to avoid this error. :slight_smile: