How to best create a model that shares weights between different dimensions of the input features

I tried multiple ways of implementing this. One way that works looks like this:

import torch

class DynamicWeightSharingModel(torch.nn.Module):
    def __init__(self, ranges: dict):
        super().__init__()
        dim = len(ranges)
        self.w = torch.nn.Parameter(torch.zeros((dim,)))
        self.ranges = ranges
        # one all-ones buffer per feature group, sized to that group
        for name in self.ranges:
            self.register_buffer(name, torch.ones((self.ranges[name],)))

    def forward(self, x):
        # rebuild the per-feature weight vector from the group weights
        shared_weights = []
        for idx, name in enumerate(self.ranges):
            shared_weights.append(self.w[idx] * getattr(self, name))
        shared_weights = torch.concat(shared_weights, dim=0)
        return x * shared_weights

The idea is the following: let x be a 3d feature vector x=[x1,x2,x3]; I want to be able to multiply x1 and x2 by w1 and x3 by w2. Unfortunately, this requires building shared_weights sequentially in every forward pass. So maybe there is a better way?
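The intended mapping from group weights to per-feature weights can be checked by hand; group names and values here are just illustrative:

```python
import torch

# w1 is shared by x1 and x2; w2 applies to x3
w = torch.tensor([2.0, 3.0])          # [w1, w2]
ranges = {"group1": 2, "group2": 1}   # group sizes

# build the per-feature weight vector [w1, w1, w2]
shared_weights = torch.concat(
    [w[i].expand(n) for i, n in enumerate(ranges.values())]
)
x = torch.tensor([1.0, 1.0, 1.0])     # [x1, x2, x3]
print(x * shared_weights)             # tensor([2., 2., 3.])
```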
I tried moving the creation of shared_weights into the init function, but then calling loss.backward() a second time raises an error, since the precomputed tensor belongs to a graph that has already been freed.
I also came up with another implementation where I add an extra dimension to the input, grouping the feature dimensions that are supposed to be multiplied by the same weight. So e.g. let x=[x1,x2,x3,x4] => [[x1,x2],[x3,x4]], and then use broadcasting with w=[w1,w2].
Unfortunately this requires all groups to have the same size, unlike in the example above.
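A minimal sketch of that equal-size variant, assuming the grouping is [[x1,x2],[x3,x4]] with illustrative values:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])   # [x1, x2, x3, x4]
w = torch.tensor([10.0, 100.0])          # [w1, w2]

# reshape into (n_groups, group_size); only works when all groups are equal
grouped = x.view(2, 2)                    # [[x1, x2], [x3, x4]]
res = grouped * w[:, None]                # broadcast one weight over each group
print(res.flatten())                      # tensor([ 10.,  20., 300., 400.])
```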

How about initializing just one unified shared weight vector, but keeping a different mask buffer per group? In forward(), you select the mask corresponding to the input and multiply it by the weight.


I will try that tomorrow, thank you.

Here is what I came up with, implementing the mask idea from huahuanZ:

import torch

class DynamicWeightSharingModel(torch.nn.Module):
    def __init__(self, ranges: dict, dim: int):
        super().__init__()
        self.w = torch.nn.Parameter(torch.zeros((len(ranges),)))
        self.ranges = ranges
        # build a (dim, n_groups) boolean mask: column i selects group i's features
        start = 0
        masks = []
        for name in self.ranges:
            mask = torch.zeros((dim, 1), dtype=torch.bool)
            mask[start:start + self.ranges[name]] = True
            start += self.ranges[name]
            masks.append(mask)
        self.register_buffer("masks", torch.concat(masks, dim=1))

    def forward(self, x):
        # replicate the input once per group: (batch, dim, n_groups)
        res = x[:, :, None].expand(-1, -1, self.masks.shape[1])
        # weight each copy, zero out features outside each group, collapse groups
        res = res * self.w * self.masks
        return res.sum(dim=2)

It uses expand to create a view that replicates the input once per mask and stacks the copies; the result is then multiplied by the weight vector and the masks are applied. The downside is that this requires as many multiplications with w as there are masks, and it might not be very memory efficient, since I assume the result of (res * self.w) will be materialized in memory first. So is there a better way still?

I found that the most efficient way to implement this is using torch.repeat_interleave.
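A sketch of how that could look: repeat_interleave expands each group weight to its group size in one vectorized call, so forward() reduces to a single elementwise multiply. The class name and group sizes are illustrative:

```python
import torch

class RepeatInterleaveModel(torch.nn.Module):
    def __init__(self, ranges: dict):
        super().__init__()
        self.w = torch.nn.Parameter(torch.zeros((len(ranges),)))
        # how many features each weight covers, e.g. {"a": 2, "b": 1} -> [2, 1]
        self.register_buffer("repeats", torch.tensor(list(ranges.values())))

    def forward(self, x):
        # expand [w1, w2] to [w1, w1, w2] in a single op, then multiply
        return x * torch.repeat_interleave(self.w, self.repeats)

model = RepeatInterleaveModel({"a": 2, "b": 1})
with torch.no_grad():
    model.w.copy_(torch.tensor([2.0, 3.0]))  # [w1, w2]
print(model(torch.ones(3)).detach())         # tensor([2., 2., 3.])
```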