torch.autograd.Function returns error asking for (deprecated) Variable when passed ParameterList

Hi,

I’m trying to pass a ParameterList through the forward and backward passes of a torch.autograd.Function, but the error I get back asks for a Variable.

As an example, let’s borrow and adjust some code from the documentation to reproduce the behavior:

import torch
from torch.autograd import Function
from numpy import flip
import numpy as np
from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
import torch.nn as nn
from torch.autograd import Variable

class ScipyConv2dFunction(Function):
  @staticmethod
  def forward(ctx, input, filter, bias, params0, params1):
    # detach so we can cast to NumPy
    input, filter, bias = input.detach(), filter.detach(), bias.detach()
    # Creating new lists from the list of params in a basic way. We add 0.01
    # to each value in the list to represent a change between the input
    # list and the output list. In practice the values of the list may change
    # much more depending on the functionality of the forward pass. For purposes
    # of brevity we keep such a change to be only "+ 0.01"
    param_new0=[]
    for param in params0:
      param_new0.append(param.detach().numpy()[0] + 0.01)
    param_new1=[]
    for param in params1:
      param_new1.append(param.detach().numpy()[0] + 0.01)

    result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
    result += bias.numpy()

    param_new0 = torch.as_tensor(param_new0).to(torch.float64)
    param_new1 = torch.as_tensor(param_new1).to(torch.float64)
    ctx.save_for_backward(input, filter, bias, param_new0, param_new1)
    return torch.as_tensor(result, dtype=input.dtype)

  @staticmethod
  def backward(ctx, grad_output):
    grad_output = grad_output.detach()
    input, filter, bias, param_new0, param_new1 = ctx.saved_tensors
    grad_output = grad_output.numpy()
    # Converting the parameter lists to NumPy
    param_new0 = param_new0.numpy()
    param_new1 = param_new1.numpy()
    # Making a small change in the lists to represent a change due to backprop.
    # In practice this change may be much larger, however for purposes of
    # brevity we keep such a change to be only "- 0.05"
    param_new0 = param_new0 - 0.05
    param_new1 = param_new1 - 0.05
    grad_bias = np.sum(grad_output, keepdims=True)
    grad_input = convolve2d(grad_output, filter.numpy(), mode='full')
    # the previous line can be expressed equivalently as:
    # grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
    grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
    # returning our gradients including our parameter lists
    return torch.from_numpy(grad_input), torch.from_numpy(grad_filter).to(torch.float), torch.from_numpy(grad_bias).to(torch.float), torch.from_numpy(param_new0).to(torch.float), torch.from_numpy(param_new1).to(torch.float) 


class ScipyConv2d(Module):
  def __init__(self, filter_width, filter_height):
    super(ScipyConv2d, self).__init__()
    self.filter = Parameter(torch.randn(filter_width, filter_height))
    self.bias = Parameter(torch.randn(1, 1))
    # Defining new ParameterLists
    self.params0 = nn.ParameterList([Parameter(torch.randn(1)), Parameter(torch.randn(1))])
    self.params1 = nn.ParameterList([Parameter(torch.randn(1)), Parameter(torch.randn(1))])

  def forward(self, input):
    return ScipyConv2dFunction.apply(input, self.filter, self.bias, self.params0, self.params1)
    #return ScipyConv2dFunction.apply(input, self.filter, self.bias, Variable(self.params0), Variable(self.params1))

module = ScipyConv2d(3, 3)
print("Filter and bias: ", list(module.parameters()))
input = torch.randn(10, 10, requires_grad=True)
output = module(input)
print("Output from the convolution: ", output)
output.backward(torch.randn(8, 8))
print("Gradient for the input map: ", input.grad)

The Traceback reads:

RuntimeError                              Traceback (most recent call last)

<ipython-input-13-e294f85b0ebd> in <module>()
     79 output = module(input)
     80 print("Output from the convolution: ", output)
---> 81 output.backward(torch.randn(8, 8))
     82 print("Gradient for the input map: ", input.grad)

/usr/local/lib/python3.7/dist-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    305                 create_graph=create_graph,
    306                 inputs=inputs)
--> 307         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    308 
    309     def register_hook(self, hook):

/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    154     Variable._execution_engine.run_backward(
    155         tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 156         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
    157 
    158 

RuntimeError: function ScipyConv2dFunctionBackward returned a gradient different than None at position 4, but the corresponding forward input was not a Variable

This error is a bit strange since Variable is deprecated.

If I follow the error’s advice and make the change:

    #return ScipyConv2dFunction.apply(input, self.filter, self.bias, self.params0, self.params1)
    return ScipyConv2dFunction.apply(input, self.filter, self.bias, Variable(self.params0), Variable(self.params1))

We get the error:

TypeError                                 Traceback (most recent call last)

<ipython-input-14-d575e588a022> in <module>()
     77 print("Filter and bias: ", list(module.parameters()))
     78 input = torch.randn(10, 10, requires_grad=True)
---> 79 output = module(input)
     80 print("Output from the convolution: ", output)
     81 output.backward(torch.randn(8, 8))

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-14-d575e588a022> in forward(self, input)
     72   def forward(self, input):
     73     #return ScipyConv2dFunction.apply(input, self.filter, self.bias, self.params0, self.params1)
---> 74     return ScipyConv2dFunction.apply(input, self.filter, self.bias, Variable(self.params0), Variable(self.params1))
     75 
     76 module = ScipyConv2d(3, 3)

TypeError: Variable data has to be a tensor, but got ParameterList

At this point I’m not sure what to do, since the ParameterList should be properly registered.

Any ideas on how to solve this error and successfully update the Parameters in ParameterList (and ideally not using Variable since it’s deprecated)? Also why does the error mention Variable even though it’s deprecated?

For context, I have multiple (>4) ParameterLists in my code, each containing 100+ different parameters, so writing out and passing each Parameter individually doesn’t make sense (hence the two ParameterLists in the code above, which better reflect how my code actually looks).

The error is a bit misleading as it doesn’t expect the deprecated Variable class exposed in Python but the internal one.
It’s raised because a gradient is expected for each tensor input argument of the forward. Currently you are trying to return a tensor gradient for an nn.ParameterList, which won’t work.
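One possible way around this, sketched under assumptions rather than taken from the original code: unpack the ParameterList with `*` so that each Parameter reaches `apply` as its own tensor argument, and return one gradient per argument from `backward`. The `ScaleBySum` op below is a hypothetical toy (y = x * sum of all parameters), chosen only because its gradients are easy to check by hand.

```python
import torch
import torch.nn as nn
from torch.autograd import Function


class ScaleBySum(Function):
    """Hypothetical toy op: y = x * (p_0 + p_1 + ...), each p_i passed as its own tensor."""

    @staticmethod
    def forward(ctx, x, *params):
        # Every element of `params` is a separate tensor argument, so autograd tracks each one.
        total = torch.stack(list(params)).sum()
        ctx.save_for_backward(x, total)
        ctx.num_params = len(params)
        return x * total

    @staticmethod
    def backward(ctx, grad_output):
        x, total = ctx.saved_tensors
        grad_x = grad_output * total
        # dy/dp_i = x for every i, so each shape-(1,) parameter receives sum(grad_output * x).
        grad_p = (grad_output * x).sum().reshape(1)
        # One return value per forward input: x first, then one gradient per parameter.
        return (grad_x, *[grad_p.clone() for _ in range(ctx.num_params)])


params = nn.ParameterList([nn.Parameter(torch.ones(1)),
                           nn.Parameter(torch.full((1,), 2.0))])
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Unpack the list so the Function sees individual tensors, not the container.
y = ScaleBySum.apply(x, *params)
y.sum().backward()

print(x.grad)                    # gradient w.r.t. the input
print([p.grad for p in params])  # one gradient per parameter
```

With 100+ parameters per list, the same unpacking still works; `backward` only has to return the gradients in the same order the tensors were passed in.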

@ptrblck :

The error is a bit misleading as it doesn’t expect the deprecated Variable class exposed in Python but the internal one.

OK- got it.

It’s raised because a gradient is expected for each tensor input argument of the forward. Currently you are trying to return a tensor gradient for an nn.ParameterList, which won’t work.

OK, so ParameterList doesn’t act like a tensor (I was hoping it would act like a tensor of tensors). Any suggestions on what the correct implementation is? I’m wondering if what we’re seeing is expected behavior or if it’s just functionality that hasn’t been created yet…

I think both. It’s an expected error, as the method expects gradients for tensor inputs and None for other objects, and most likely more valid types weren’t requested yet, but @albanD can correct me.
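To illustrate the "gradients for tensor inputs and None for other objects" rule, here is a minimal sketch with a hypothetical `Scale` function (not from the original post) that takes one tensor and one plain Python float. `backward` returns a gradient for the tensor and None for the float.

```python
import torch
from torch.autograd import Function


class Scale(Function):
    """Hypothetical example: multiply a tensor by a plain Python number."""

    @staticmethod
    def forward(ctx, x, factor):
        # `factor` is not a tensor, so it can be stored directly on ctx.
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input: a gradient for `x`,
        # and None for the non-tensor `factor`.
        return grad_output * ctx.factor, None


x = torch.tensor([1.0, 2.0], requires_grad=True)
y = Scale.apply(x, 3.0)
y.sum().backward()
print(x.grad)
```

Returning a tensor instead of None in the second position would trigger the same kind of RuntimeError shown in the traceback above.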

I’m also not familiar with your use case and don’t know why gradients should be returned for parameter inputs as I would expect to see gradients for input activations.


most likely more valid types weren’t requested yet

Good to know. @albanD - can you confirm that torch.autograd.Function doesn’t support ParameterList yet? Or is there something wrong in the posted code?

don’t know why gradients should be returned for parameter inputs as I would expect to see gradients for input activations

Do you mean that you don’t see the gradient for correlate2d? Looking at the code it’s my understanding that the gradient for that function is reflected in backward and given by the gradient of the inputs as generated by convolve2d.

As mentioned in a GitHub issue, this is expected: autograd.Function never opens up data structures.
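A small sketch of what "never opens up data structures" means in practice, using two hypothetical functions (`SumList` and `SumArgs`, invented for illustration). When the tensors arrive inside a Python list, autograd sees no tensor inputs, so the output is not connected to the graph. When the same tensors are passed as separate arguments, they are tracked.

```python
import torch
from torch.autograd import Function


class SumList(Function):
    """Hypothetical: receives the tensors wrapped in a single Python list."""

    @staticmethod
    def forward(ctx, tensors):
        return torch.stack(list(tensors)).sum()

    @staticmethod
    def backward(ctx, grad_output):
        # The only forward input is a non-tensor container, so only None is valid here.
        return None


class SumArgs(Function):
    """Hypothetical: receives every tensor as its own argument."""

    @staticmethod
    def forward(ctx, *tensors):
        ctx.n = len(tensors)
        return torch.stack(tensors).sum()

    @staticmethod
    def backward(ctx, grad_output):
        # d(sum)/d(t_i) = 1, so each shape-(1,) input receives grad_output itself.
        return tuple(grad_output.reshape(1).clone() for _ in range(ctx.n))


a = torch.ones(1, requires_grad=True)
b = torch.ones(1, requires_grad=True)

out_list = SumList.apply([a, b])  # container is not inspected
out_args = SumArgs.apply(a, b)    # individual tensors are tracked

print(out_list.requires_grad, out_args.requires_grad)
out_args.backward()
print(a.grad, b.grad)
```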

@albanD - finally found the GitHub issue on this. Thanks for the heads up about it and for explaining it in the issue.

Hey!
Sorry about that. Since the issue was opened at the same time as this post, I thought the two were by the same author. My bad! Hope you didn’t lose too much time finding it.