Shared weights using permutation view semantics

Hi Ken!

PyTorch tensor indexing should let you use a tensor of your permuted indices
to index into your A and w tensors.

I don’t follow the details of your use case, but here is an example where the
permutation indices are packaged as 1d tensors:

>>> import torch
>>> print (torch.__version__)
1.13.0
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> m = 5
>>> n = 7
>>>
>>> A = torch.randn (m, n)
>>> w = torch.randn (m, n)
>>>
>>> p = torch.randperm (5)
>>> q = torch.randperm (5)
>>>
>>> B = (A[:, p] * w[:, q]).sum (dim = 1)
>>>
>>> B
tensor([-4.3236,  0.2211, -0.5869,  0.6099, -1.7975])
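The transcript above draws p and q with torch.randperm (5). Because A and w have n = 7 columns, that reorders only the first five columns and leaves out the last two. Below is a minimal sketch that assumes the intent is to permute all n columns, so it uses torch.randperm (n) instead. It also assumes w is a learnable weight, so requires_grad=True is an addition of mine and not part of the original example. The sketch checks the indexed expression against explicit loops computing B[i] = sum_j A[i, p[j]] * w[i, q[j]]. It also shows that w[:, q] is a copy (advanced indexing with an index tensor does not return a view), but autograd still sends the gradients back to the original entries of w.

```python
import torch

# Sketch, assuming the goal is to permute all n columns of A and w,
# so the index tensors come from torch.randperm(n) rather than torch.randperm(5).
torch.manual_seed(2023)

m, n = 5, 7
A = torch.randn(m, n)
w = torch.randn(m, n, requires_grad=True)  # assumption: w is a learnable weight

p = torch.randperm(n)  # column order used to read A
q = torch.randperm(n)  # column order used to read w

# Vectorized: indexing with a 1-d index tensor gathers whole columns in that order.
B = (A[:, p] * w[:, q]).sum(dim=1)

# Explicit loops for B[i] = sum_j A[i, p[j]] * w[i, q[j]], used only as a check.
with torch.no_grad():
    B_loop = torch.zeros(m)
    for i in range(m):
        for j in range(n):
            B_loop[i] += A[i, p[j]] * w[i, q[j]]

# Summation order differs, so compare with a small absolute tolerance.
print(torch.allclose(B, B_loop, atol=1e-5))

# w[:, q] is a copy, not a view, yet autograd still routes gradients
# back to the original entries of w.
B.sum().backward()

# Each w[i, q[j]] is multiplied by A[i, p[j]] exactly once, so that is its gradient.
expected_grad = torch.empty(m, n)
expected_grad[:, q] = A[:, p]
print(torch.allclose(w.grad, expected_grad))
```

Both checks should print True. Because q is a permutation, every column of w is used exactly once. This is why the expected gradient can be filled in with a single indexed assignment.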

Best.

K. Frank
