I have a large 4-dimensional tensor, `self.succ_feats`, and would frequently like to access elements in this tensor by passing in a list, `cords`, which corresponds to a set of elements in this tensor. This list of coordinates is also quite large. How can I speed up this operation, as well as the subsequent tensor operations?
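The following is a minimal sketch of removing the Python-level loop. It assumes that `cords[n, 0]` holds the two row indices and `cords[n, 1]` the two column indices, which is how `for x, y in cords` unpacks each `[2, 2]` entry. Passing both index tensors at once uses PyTorch's advanced indexing. The function names are hypothetical. This removes the per-element Python overhead, but it still materializes the full gathered tensor.

```python
import torch

def gather_cells_loop(succ_feats, cords):
    # Original approach: one Python-level index operation per coordinate set.
    return torch.stack([succ_feats[:, x, y] for x, y in cords.long()])

def gather_cells_vectorized(succ_feats, cords):
    # succ_feats: [C, H, W, F]; cords: [N, 2, 2] (assumed: row 0 = x, row 1 = y)
    cords = cords.long()
    rows = cords[:, 0]  # [N, 2]
    cols = cords[:, 1]  # [N, 2]
    # Adjacent advanced indices on dims 1 and 2 are replaced in place by the
    # broadcast index shape [N, 2], so the result has shape [C, N, 2, F].
    out = succ_feats[:, rows, cols]
    # Move the batch dimension to the front to match the loop: [N, C, 2, F].
    return out.permute(1, 0, 2, 3)
```

The `.double()` conversion from the original code can then be applied once to the gathered result instead of to every slice.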
`self.succ_feats` has shape `[70, 10, 10, 400]` and `cords` has shape `[100000, 2, 2]`, so the resulting tensor `selected_succ_feats` has shape `[100000, 70, 2, 400]`. That is 5.6 billion elements, or about 45 GB in float64, which is extremely slow to process and usually too large to store all at once.
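If the gathered features really must be formed, one option is to process `cords` in chunks so that only one slice is in memory at a time. This sketch uses a hypothetical helper `process_in_chunks` and a caller-supplied `fn`, which is assumed to reduce each `[n, 70, 2, 400]` chunk to something small (for example, the `linear1`/softmax pipeline). Each chunk holds `chunk_size × 70 × 2 × 400` elements, so `chunk_size=1000` is about 448 MB in float64.

```python
import torch

def process_in_chunks(succ_feats, cords, fn, chunk_size=1000):
    # Gather and reduce at most chunk_size coordinate sets at a time,
    # then concatenate the (small) per-chunk results along the batch dim.
    outputs = []
    for start in range(0, cords.shape[0], chunk_size):
        chunk = cords[start:start + chunk_size].long()
        # Vectorized gather for this chunk: [C, n, 2, F] -> [n, C, 2, F]
        feats = succ_feats[:, chunk[:, 0], chunk[:, 1]].permute(1, 0, 2, 3)
        outputs.append(fn(feats))
    return torch.cat(outputs, dim=0)
```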
```python
# Gather the features for every coordinate set, one Python-level index at a time
selected_succ_feats = [self.succ_feats[:, x, y].double() for x, y in cords.long()]
selected_succ_feats = torch.stack(selected_succ_feats)  # shape = [100000, 70, 2, 400]
vs = self.linear1(selected_succ_feats)  # shape = [100000, 70, 2, 1]
del selected_succ_feats
# Softmax-weighted sum over the 70-dim axis
v_pi_approx = torch.sum(torch.mul(self.softmax(vs), vs), dim=1)  # shape = [100000, 2, 1]
v_pi_approx = torch.squeeze(v_pi_approx, -1)  # shape = [100000, 2]
```
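Assuming `self.linear1` is an `nn.Linear(400, 1)` (as the shape comments suggest), the gather can be moved after the linear layer entirely. `nn.Linear` acts independently on the last dimension, so applying it to `self.succ_feats` first and then indexing gives the same values. The linear layer then runs on 70 × 10 × 10 = 7,000 vectors instead of 14 million, and the gathered tensor is `[100000, 70, 2, 1]` (14 million elements) rather than 5.6 billion. The sketch below compares both orderings; the function names are hypothetical, and `softmax` stands in for whatever `self.softmax` is.

```python
import torch
import torch.nn as nn

def v_pi_original(succ_feats, cords, linear1, softmax):
    # Mirrors the question's code: gather first, then apply linear1.
    sel = torch.stack([succ_feats[:, x, y].double() for x, y in cords.long()])
    vs = linear1(sel)                                        # [N, 70, 2, 1]
    return torch.sum(softmax(vs) * vs, dim=1).squeeze(-1)    # [N, 2]

def v_pi_linear_first(succ_feats, cords, linear1, softmax):
    # linear1 only touches the last dim, so it commutes with indexing over
    # the other dims: apply it once to the small [70, 10, 10, 400] tensor.
    scores = linear1(succ_feats.double())                    # [70, 10, 10, 1]
    cords = cords.long()
    vs = scores[:, cords[:, 0], cords[:, 1]]                 # [70, N, 2, 1]
    vs = vs.permute(1, 0, 2, 3)                              # [N, 70, 2, 1]
    return torch.sum(softmax(vs) * vs, dim=1).squeeze(-1)    # [N, 2]
```

Softmax and the sum still see the same `[N, 70, 2, 1]` tensor in both versions; only the linear layer is reordered.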