Weight vector in PyTorch


(Shani Gamrian) #1

I have a 4x4 matrix (let's say it consists of v1, v2, v3, v4) and I want to learn 4 parameters (a1, a2, a3, a4) that sum to 1, and multiply them with the matrix in order to learn which of the vectors are more important (a normalized weight vector). What is the best way to do that in PyTorch?


(Thomas V) #2

So the v-matrix is fixed?
You could use an nn.Parameter for the unnormalized a, apply softmax, and then broadcast: v_weighted = v * a.softmax(0).unsqueeze(0) if you want to weight the columns of v, or .unsqueeze(1) if you want to weight the rows. v would be a Variable (or a buffer if you want to have it saved in the state_dict).
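A minimal sketch of that broadcasting, assuming a recent PyTorch release (plain tensors instead of Variable). The 4x4 `v` and the starting logits in `a` are made up for illustration; `a` holds unnormalized values and softmax turns them into positive weights that sum to 1.

```python
import torch

# Hypothetical 4x4 matrix; row i plays the role of v_{i+1}.
v = torch.arange(16, dtype=torch.float32).reshape(4, 4)

# Unnormalized weights (logits), learnable because they are an nn.Parameter.
a = torch.nn.Parameter(torch.tensor([0.0, 1.0, 2.0, 3.0]))
w = a.softmax(0)  # shape (4,), positive, sums to 1

# unsqueeze(0) -> shape (1, 4): column j of v is scaled by w[j]
cols_weighted = v * w.unsqueeze(0)

# unsqueeze(1) -> shape (4, 1): row i of v is scaled by w[i]
rows_weighted = v * w.unsqueeze(1)
```

Because only the softmax output is used, the optimizer can move `a` freely while the weights stay on the simplex.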

Best regards

Thomas


(Shani Gamrian) #3

Good idea!
Unfortunately the values of a don’t change. Here is the relevant part of my code (I’m using PyTorch 0.1.12):

self.weight_vector = Variable(torch.FloatTensor(1, 4), requires_grad=True)

def forward(self, v):  # v.size() = (4, 36, 256)
    norm_weight_vec = F.softmax(self.weight_vector)  # (1, 4), normalized over the 4 entries
    v = v * (norm_weight_vec.transpose(0, 1).unsqueeze(2)).expand_as(v)  # scale each v[i] by its weight
    v = torch.sum(v, dim=0)  # result size: (1, 36, 256)
    ....
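For comparison, the same forward computation can be sketched in a recent PyTorch release (an assumption; the post uses 0.1.12). Broadcasting replaces `expand_as`, and because current reductions drop the summed dimension by default, `keepdim=True` is passed to keep the (1, 36, 256) shape from the comment above. The random inputs are placeholders.

```python
import torch
import torch.nn.functional as F

v = torch.randn(4, 36, 256)        # placeholder input, same shape as in the post
weight_vector = torch.randn(1, 4)  # placeholder logits, shape (1, 4)

norm_weight_vec = F.softmax(weight_vector, dim=1)  # normalize over the 4 entries

# (1, 4) -> (4, 1) -> (4, 1, 1); broadcasting stretches it to (4, 36, 256)
scale = norm_weight_vec.transpose(0, 1).unsqueeze(2)
weighted = v * scale

# Weighted sum over the 4 slices, keeping a leading dimension of size 1
out = torch.sum(weighted, dim=0, keepdim=True)  # (1, 36, 256)
```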

(Thomas V) #4

I think you want to make weight_vector an nn.Parameter of the module so that it shows up in m.parameters() (and is therefore passed to, and updated by, the optimizer).
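A minimal sketch of the difference, assuming a recent PyTorch release where Variable has been merged into Tensor. The class names `WeightedSum` and `Unregistered`, the zero initialization, and the toy loss are hypothetical, chosen only to show that a registered parameter is seen and updated by the optimizer while a plain tensor attribute is not.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightedSum(nn.Module):
    def __init__(self, n=4):
        super().__init__()
        # nn.Parameter is registered automatically and appears in self.parameters()
        self.weight_vector = nn.Parameter(torch.zeros(1, n))

    def forward(self, v):  # v: (n, 36, 256)
        w = F.softmax(self.weight_vector, dim=1)
        return (v * w.transpose(0, 1).unsqueeze(2)).sum(dim=0, keepdim=True)

class Unregistered(nn.Module):
    def __init__(self, n=4):
        super().__init__()
        # A plain tensor attribute is NOT registered, so an optimizer built
        # from self.parameters() never sees it, even with requires_grad=True
        self.weight_vector = torch.zeros(1, n, requires_grad=True)

torch.manual_seed(0)
m = WeightedSum()
print(len(list(m.parameters())))               # 1
print(len(list(Unregistered().parameters())))  # 0

opt = torch.optim.SGD(m.parameters(), lr=0.1)
v = torch.randn(4, 36, 256)
before = m.weight_vector.detach().clone()

loss = m(v).pow(2).mean()  # arbitrary toy loss, only used to produce gradients
opt.zero_grad()
loss.backward()
opt.step()

# The registered logits have been updated by the optimizer step
print(torch.equal(before, m.weight_vector.detach()))
```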

Best regards

Thomas