I need to gather vectors from a tensor, some of which may be redundant. Here’s a toy example:
tmp = Variable(torch.randn((5,2)))
Variable containing:
0.2994 0.9508
0.9271 0.0781
0.2100 0.0150
-0.4117 0.8177
-1.0631 1.0050
[torch.FloatTensor of size 5x2]
ind = Variable(torch.LongTensor([0, 1, 0, 0, 1, 1, 0, 1])) # redundant indices, longer than tmp's second dim
tmp.gather(1, ind.expand(5,-1))
Variable containing:
0.2994 0.9508 0.2994 0.2994 0.9508 0.9508 0.2994 0.9508
0.9271 0.0781 0.9271 0.9271 0.0781 0.0781 0.9271 0.0781
0.2100 0.0150 0.2100 0.2100 0.0150 0.0150 0.2100 0.0150
-0.4117 0.8177 -0.4117 -0.4117 0.8177 0.8177 -0.4117 0.8177
-1.0631 1.0050 -1.0631 -1.0631 1.0050 1.0050 -1.0631 1.0050
[torch.FloatTensor of size 5x8]
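For reference, here is a minimal sketch of the same operation in post-0.4 PyTorch, where Variable was merged into Tensor (variable names mirror the snippet above):

```python
import torch

# Plain tensors replace Variable in modern PyTorch.
tmp = torch.randn(5, 2)
ind = torch.tensor([0, 1, 0, 0, 1, 1, 0, 1])  # redundant column indices

# expand() turns the 1D index into a (5, 8) view; gather() then picks
# columns of tmp row-wise according to those indices.
out = tmp.gather(1, ind.expand(5, -1))
assert out.shape == (5, 8)
assert torch.equal(out[:, 2], tmp[:, 0])  # index 0 repeats column 0
```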
This works, but I’m trying to reduce memory usage as much as possible. Does gather
copy data by default? If so, is there some other way to accomplish this without copying data?
I can’t get around the issue by removing the redundancies; they are needed because each gathered vector will be handled by different parameters of my network.
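One way to probe the copying behavior empirically (a sketch using modern tensor syntax; the data_ptr comparison is just a quick storage check, not a full aliasing analysis):

```python
import torch

tmp = torch.randn(5, 2)
ind = torch.tensor([0, 1, 0, 0, 1, 1, 0, 1])

view = ind.expand(5, -1)        # expand is a stride trick: no new memory
gathered = tmp.gather(1, view)  # gather allocates a fresh output tensor

# The expanded view shares storage with ind; gather's output does not
# share storage with tmp.
assert view.data_ptr() == ind.data_ptr()
assert gathered.data_ptr() != tmp.data_ptr()

# Mutating the gathered tensor leaves the source untouched.
gathered[0, 0] = 123.0
assert tmp[0, 0].item() != 123.0
```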
Thank you!