I have a large 4-dimensional tensor,
self.succ_feats, and frequently need to select elements from it by passing in a coordinate tensor,
cords. This coordinate tensor is also quite large. How can I speed up this selection, as well as the subsequent tensor operations?
self.succ_feats is of shape
[70, 10, 10, 400], and
cords is of shape
[100000, 2, 2], meaning that the resulting tensor
selected_succ_feats is of shape
[100000, 70, 2, 400], which is both extremely slow to build and usually too large to store all at once (100000 · 70 · 2 · 400 = 5.6 billion elements, roughly 44.8 GB in float64).
selected_succ_feats = [self.succ_feats[:, x, y].double() for x, y in cords.long()]
selected_succ_feats = torch.stack(selected_succ_feats)  # shape = [100000, 70, 2, 400]
vs = self.linear1(selected_succ_feats)  # shape = [100000, 70, 2, 1]
del selected_succ_feats
v_pi_approx = torch.sum(torch.mul(self.softmax(vs), vs), dim=1)  # shape = [100000, 2, 1]
v_pi_approx = torch.squeeze(v_pi_approx)  # shape = [100000, 2]
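One likely direction is to replace the Python-level loop with a single advanced-indexing call, and to run the linear layer over chunks of coordinates so the full [100000, 70, 2, 400] tensor is never materialized. The sketch below assumes `succ_feats`, `cords`, `linear1`, and `softmax` stand in for the question's attributes (with `softmax` taken over the 70-dim axis, matching the `dim=1` sum), and uses shrunken sizes so it runs quickly:

```python
import torch

# Stand-ins for the question's attributes, with N shrunk from 100000 to 1000.
N = 1000
succ_feats = torch.randn(70, 10, 10, 400)
cords = torch.randint(0, 10, (N, 2, 2))
linear1 = torch.nn.Linear(400, 1).double()
softmax = torch.nn.Softmax(dim=1)  # assumed: softmax over the 70-dim axis

idx = cords.long()
x, y = idx[:, 0], idx[:, 1]  # each [N, 2], matching `for x, y in cords.long()`

# Advanced indexing on dims 1 and 2 yields [70, N, 2, 400]; permuting
# reproduces the loop's [N, 70, 2, 400] layout in one vectorized call.
selected = succ_feats[:, x, y].permute(1, 0, 2, 3).double()

# Chunked version: only `chunk` coordinates' worth of activations live at once.
chunk = 256
outs = []
for start in range(0, N, chunk):
    xb, yb = x[start:start + chunk], y[start:start + chunk]
    sel = succ_feats[:, xb, yb].permute(1, 0, 2, 3).double()  # [B, 70, 2, 400]
    vs = linear1(sel)                                         # [B, 70, 2, 1]
    v = torch.sum(softmax(vs) * vs, dim=1)                    # [B, 2, 1]
    outs.append(v.squeeze(-1))                                # [B, 2]
v_pi_approx = torch.cat(outs)                                 # [N, 2]
```

The chunk size trades peak memory for per-batch overhead, so it is worth tuning; the gather itself produces the same values as the stacked loop, only without N separate indexing calls.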