How to design a custom layer with a single shared weight and masked connections

I am new to pytorch and I want to write a layer in which the input/output connection
is defined through an adjacency matrix, and all existing edges should
share the same weight. My current implementation is as follows:

import torch
import torch.nn as nn
from torch.autograd import Function

class myFunc(Function):
    def __init__(self):
        super(myFunc, self).__init__()

    def forward(self, input, weight, adj):
        # adj is a boolean adjacency matrix; every existing edge
        # (True entry) gets the same shared scalar weight
        a = torch.zeros_like(adj, dtype=torch.float)
        a[adj] = weight
        return input.mm(a)

class MyLayer(nn.Module):
    def __init__(self, adj):
        super(MyLayer, self).__init__()
        self.adj = adj
        # a single learnable scalar shared by all edges
        self.weight = nn.Parameter(torch.Tensor(1))
        nn.init.normal_(self.weight.data, mean=1, std=0.1)

    def forward(self, input):
        return myFunc().forward(input, self.weight, self.adj)

This approach does not use less memory than having a separate weight per edge.
Since I have a single parameter to learn (one weight for the entire layer), is there a way to implement this more efficiently?

Hi,

You don’t need a custom Function here.
You should just move that forward code into the nn.Module’s forward function.
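
A minimal sketch of what that could look like, assuming adj is a boolean in_features x out_features tensor (registering it as a buffer is just one option, so it moves to the right device with the module):

import torch
import torch.nn as nn

class MyLayer(nn.Module):
    def __init__(self, adj):
        super(MyLayer, self).__init__()
        # keep the mask as a buffer so .to()/.cuda() move it with the module
        self.register_buffer('adj', adj)
        # one learnable scalar shared by every edge
        self.weight = nn.Parameter(torch.Tensor(1))
        nn.init.normal_(self.weight, mean=1, std=0.1)

    def forward(self, input):
        # build the masked weight matrix on the fly; autograd tracks
        # the multiplication, so no custom Function is required
        a = self.adj.float() * self.weight
        return input.mm(a)

# e.g. a random 4 -> 3 connectivity pattern
adj = torch.rand(4, 3) > 0.5
layer = MyLayer(adj)
out = layer(torch.randn(2, 4))  # shape (2, 3)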

I don’t think so - the memory here is consumed preserving the original input tensor for backprop through the mm op, so the “sparsity” of the right-hand matrix makes no difference. Maybe take a look at torch.utils.checkpoint if memory is an issue.
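
For reference, a minimal sketch of checkpointing the layer, assuming layer is the MyLayer instance from above; the input is recomputed during backward instead of being stored, trading compute for memory:

import torch
from torch.utils.checkpoint import checkpoint

# checkpoint needs at least one input that requires grad
x = torch.randn(2, 4, requires_grad=True)
# the forward pass is re-run during backward, so the activations
# saved for the mm op are not kept alive in between
out = checkpoint(layer, x)
out.sum().backward()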
