Datatype behaviour when passing in Variables/Parameters to functions

Hello everyone,
I’m trying to build a custom module layer that itself uses a custom autograd function. Inside that function, it would be nice if I could reuse existing functions. As a simplified example, I wrapped a Linear layer inside my function and tried to pass its weight in as a parameter from the “surrounding” module.

However, when I first ran the code below, I hit my first problem. When I pass a Variable into the module linLayer, the types are as expected (Variable and Parameter), BUT when I check the types inside the function, they have changed to Tensors. Why is that? This results in the error “TypeError: cannot assign ‘torch.FloatTensor’ as parameter ‘weight’ (torch.nn.Parameter or None expected)”.

import torch

class linFct(torch.autograd.Function):
    def forward(self, fctDataIn, fctWeight):
        self.save_for_backward(fctDataIn, fctWeight)
        print("function")
        print("data: ", type(fctDataIn))
        print("weight: ", type(fctWeight))
        tmpLin = torch.nn.Linear(3, 2, bias=False)
        # The TypeError is raised here: fctWeight arrives as a plain FloatTensor, not a Parameter
        tmpLin.weight = fctWeight
        return tmpLin(fctDataIn)

    def backward(self, grad_output):
        fctDataIn, fctWeight = self.saved_tensors
        tmpLin = torch.nn.Linear(3, 2, bias=False)
        tmpLin.weight = fctWeight
        tmpLin.zero_grad()
        tmpLin(fctDataIn).backward(grad_output)
        grad_fctDataIn = fctDataIn.grad.data
        grad_fctWeight = fctWeight.grad.data
        return grad_fctDataIn, grad_fctWeight


class linLayer(torch.nn.Module):
    def __init__(self):
        super(linLayer, self).__init__()
        self.wParam = torch.nn.Parameter(torch.randn(2, 3))
        self.fct = linFct()

    def forward(self, layerIn):
        print("layer")
        print("data: ", type(layerIn))
        print("weight: ", type(self.wParam))
        return self.fct(layerIn, self.wParam)

x = torch.autograd.Variable(torch.randn(2, 3), requires_grad=True)
fct = linLayer()
print("forward...")
y = fct(x)
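The error itself comes from nn.Module, not from autograd: a Module only accepts an nn.Parameter (or None) when you assign to an attribute already registered as a parameter, such as `weight`. The sketch below reproduces that in isolation and shows `torch.nn.functional.linear`, which takes plain tensors and therefore needs no Module or Parameter at all. It assumes PyTorch 0.4 or later, where Variable and Tensor are merged; the names `lin`, `w`, `x` and `err_msg` are illustrative only.

```python
import torch
import torch.nn.functional as F

lin = torch.nn.Linear(3, 2, bias=False)
w = torch.randn(2, 3)  # a plain tensor, like fctWeight inside forward

# nn.Module.__setattr__ rejects a plain tensor for a registered parameter name
err_msg = ""
try:
    lin.weight = w
except TypeError as e:
    err_msg = str(e)
print(err_msg)

# F.linear works directly on tensors: it computes x @ w.T (no bias here)
x = torch.randn(4, 3)
out = F.linear(x, w)
print(out.shape)  # torch.Size([4, 2])
```

Inside a custom function, calling the functional form avoids building a throwaway nn.Linear just to borrow its computation.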

See points 2 and 3 here. Functions written in this style unwrap Variables into Tensors, and backward should return Tensors.
http://pytorch.org/docs/notes/extending.html#extending-torch-autograd
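For comparison, here is a minimal sketch of the same bias-free linear layer written as a new-style Function with static `forward`/`backward` and a `ctx` object, assuming PyTorch 0.4 or later (recent releases no longer accept the instance-method style used in the question). It swaps the wrapped nn.Linear for direct tensor math, so backward returns tensors computed from the chain rule instead of calling `.backward()` internally. The class names `LinFct`/`LinLayer` are just tidied versions of the question's names.

```python
import torch

class LinFct(torch.autograd.Function):
    """Bias-free linear map y = x @ W^T as a static-method Function."""

    @staticmethod
    def forward(ctx, x, weight):
        # Autograd does not record ops in here; backward below supplies the gradients
        ctx.save_for_backward(x, weight)
        return x.mm(weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_x = grad_output.mm(weight)      # dL/dx = dL/dy @ W
        grad_weight = grad_output.t().mm(x)  # dL/dW = (dL/dy)^T @ x
        # One gradient per forward input, in the same order
        return grad_x, grad_weight

class LinLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w_param = torch.nn.Parameter(torch.randn(2, 3))

    def forward(self, x):
        # Call through .apply instead of instantiating the Function
        return LinFct.apply(x, self.w_param)

layer = LinLayer()
x = torch.randn(4, 3, requires_grad=True)
y = layer(x)
y.sum().backward()
print(x.grad.shape, layer.w_param.grad.shape)

# Check the hand-written backward against finite differences (double precision)
xd = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
wd = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(LinFct.apply, (xd, wd))
print(ok)
```

Because the Parameter is passed to `.apply` as an ordinary input, its gradient ends up in `layer.w_param.grad` without any attribute assignment inside the Function.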

Thank you for the reference, @smth!