Speeding up large tensor operations

I have a large 4-dimensional tensor, self.succ_feats, and would frequently like to access elements in this tensor by passing in a list, cords, which corresponds to a set of elements in this tensor. This list of coordinates is also quite large. How can I speed up this operation, as well as any of the subsequent tensor operations?

For reference, self.succ_feats has shape [70, 10, 10, 400] and cords has shape [100000, 2, 2], so the resulting tensor selected_succ_feats has shape [100000, 70, 2, 400], which is both extremely slow to process and usually too large to store all at once.

selected_succ_feats =[self.succ_feats[:,x,y].double() for x,y in cords.long()]
selected_succ_feats = torch.stack(selected_succ_feats) #shape = [100000, 70, 2, 400]
vs = self.linear1(selected_succ_feats) #shape = [100000, 70, 2, 1]
del selected_succ_feats
v_pi_approx = torch.sum(torch.mul(self.softmax(vs),vs),dim = 1) #shape = [100000, 2, 1]
v_pi_approx = torch.squeeze(v_pi_approx)  #shape = [100000, 2]

Welcome to the forums!

A list comprehension can be passed directly into torch.stack(). Also, do dtype conversions such as .double() and .long() once, on the whole tensors, before the list comprehension. Calling .double() on each slice inside the comprehension means 100,000 separate small copies, which slows things down immensely.
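Going one step further, the Python loop can be removed entirely. This is a sketch using PyTorch advanced indexing (not part of the original reply): indexing dims 1 and 2 with two [N, 2] index tensors selects all pairs in one operation, giving [70, N, 2, 400], and a permute restores the [N, 70, 2, 400] layout. The sizes below are shrunk so the equivalence check runs quickly.

```python
import torch

# Small stand-in sizes; the real tensors are [70, 10, 10, 400] and [100000, 2, 2].
succ_feats = torch.rand(70, 10, 10, 400)
coords = torch.randint(0, 10, (200, 2, 2))

# Original approach: one indexing op per coordinate set, then stack.
looped = torch.stack([succ_feats[:, x, y] for x, y in coords])  # [N, 70, 2, 400]

# Vectorized approach: coords[:, 0] holds the two x indices of every set and
# coords[:, 1] the two y indices, both of shape [N, 2]. Indexing dims 1 and 2
# with them at once produces [70, N, 2, 400].
xs, ys = coords[:, 0], coords[:, 1]
vectorized = succ_feats[:, xs, ys].permute(1, 0, 2, 3)  # [N, 70, 2, 400]

print(vectorized.shape)
print(torch.equal(looped, vectorized))
```

Both paths only copy existing values, so the results should match exactly, not just approximately.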

import torch
import torch.nn as nn
import torch.nn.functional as F
import time

device = torch.device("cpu")  # the timings discussed below were measured on CPU
start_time = time.time()

succ_feats = torch.rand((70, 10, 10, 400), device = device)
coords = torch.randint(0,10, (100000, 2, 2), device = device)
linear1 = nn.Linear(400, 1).to(device)
print("Tensors made", time.time()-start_time)
selected_succ_feats = torch.stack([succ_feats[:,x,y] for x,y in coords]) #shape = [100000, 70, 2, 400]
print("Selection done", time.time()-start_time)
vs = linear1(selected_succ_feats.float()) #shape = [100000, 70, 2, 1]
print("Linear done", time.time()-start_time)
del selected_succ_feats
# softmax over the 70 entries in dim 1, matching the sum below
v_pi_approx = torch.sum(torch.mul(F.softmax(vs, dim=1), vs), dim=1) #shape = [100000, 2, 1]
v_pi_approx = torch.squeeze(v_pi_approx, -1)  #shape = [100000, 2]

print("Completed", time.time()-start_time)
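A larger saving than any dtype change is possible here, though it is a rearrangement not suggested in the thread. nn.Linear is applied independently to every vector along the last dimension, so projecting first and indexing afterwards gives the same result as indexing first. Projecting the 70 × 10 × 10 = 7,000 feature vectors once yields a [70, 10, 10, 1] tensor, and the huge [N, 70, 2, 400] tensor is never built. The helper names gather_then_project and project_then_gather are hypothetical, and the trick only applies while linear1 is the first operation after the selection.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

succ_feats = torch.rand(70, 10, 10, 400)
coords = torch.randint(0, 10, (200, 2, 2))  # small N for a quick check
linear1 = nn.Linear(400, 1)

def gather_then_project(succ_feats, coords, linear1):
    # Builds the full [N, 70, 2, 400] tensor before projecting (the original order).
    selected = succ_feats[:, coords[:, 0], coords[:, 1]].permute(1, 0, 2, 3)
    vs = linear1(selected)                                           # [N, 70, 2, 1]
    return torch.sum(F.softmax(vs, dim=1) * vs, dim=1).squeeze(-1)  # [N, 2]

def project_then_gather(succ_feats, coords, linear1):
    # Project every 400-dim feature vector once, then index the tiny result.
    projected = linear1(succ_feats)                                  # [70, 10, 10, 1]
    vs = projected[:, coords[:, 0], coords[:, 1]].permute(1, 0, 2, 3)  # [N, 70, 2, 1]
    return torch.sum(F.softmax(vs, dim=1) * vs, dim=1).squeeze(-1)  # [N, 2]

with torch.no_grad():
    a = gather_then_project(succ_feats, coords, linear1)
    b = project_then_gather(succ_feats, coords, linear1)

print(a.shape)
print(torch.allclose(a, b, atol=1e-5))
```

The two results can differ by float rounding only, because the matrix multiplies are batched differently. Gradients still flow through the indexing, so the reordered version can also be used during training.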

In fact, you might try different data types to see what works best for your needs. For example, float (float32) or half (float16) instead of double may be more than sufficient for your purposes, and would give you a big speedup with less memory. My guess is that you're running out of RAM and the system is swapping to disk, which slows things down considerably because of all the extra reads and writes.

On a CPU with 64 GB of RAM, the code above took under a minute. When I switched succ_feats to double, it started swapping, which made the run time unpredictable.
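The arithmetic behind this is easy to check. The sketch below counts only selected_succ_feats itself, not the linear layer's output or any autograd buffers, so real usage is higher:

```python
import math
import torch

# Shape of selected_succ_feats from the question.
shape = (100_000, 70, 2, 400)
numel = math.prod(shape)  # 5.6 billion elements

sizes_gb = {}
for dtype in (torch.float64, torch.float32, torch.float16):
    bytes_per_element = torch.empty(0, dtype=dtype).element_size()
    sizes_gb[dtype] = numel * bytes_per_element / 1e9
    print(f"{dtype}: {sizes_gb[dtype]:.1f} GB")
```

This prints 44.8 GB for float64, 22.4 GB for float32 and 11.2 GB for float16. A 22.4 GB tensor fits in 64 GB of RAM, while 44.8 GB plus everything else plausibly does not, which is consistent with the swapping described above.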

Ok, thanks! Do you think there might be any benefit to trying to use PyTorch's DataLoader here?

Yes. If you can split the work into batches, that should eliminate the memory overflows. You could use a DataLoader for that; you just need to define __init__, __getitem__ and __len__ in your Dataset subclass. For example:

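import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, device):
        self.succ_feats = torch.rand((70, 10, 10, 400), device = device)
        self.coords = torch.randint(0,10, (100000, 2, 2), device = device)

    def __len__(self):
        # number of coordinate sets; size(0) is already an int
        return self.coords.size(0)

    def __getitem__(self, idx):
        x, y = self.coords[idx]
        item = self.succ_feats[ : , x, y]  # shape [70, 2, 400]
        return item

# Each batch is collated to shape [batch_size, 70, 2, 400]
loader = DataLoader(CustomDataset(device="cpu"), batch_size=1000)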
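Because __getitem__ above still indexes one coordinate set at a time, the DataLoader does 100,000 small indexing operations per pass. When all the data already lives in one tensor, a simpler alternative (swapped in here, not from the thread) is to split coords into chunks with torch.split and process each chunk with vectorized indexing. v_pi_approx_batched and its batch_size default are hypothetical; tune batch_size to your memory. Memory only stays bounded under torch.no_grad() or if you call backward per chunk, since otherwise autograd keeps every chunk's activations alive.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def v_pi_approx_batched(succ_feats, coords, linear1, batch_size=10_000):
    """Hypothetical helper: only a [batch_size, 70, 2, 400] slice exists at a time."""
    outputs = []
    for chunk in torch.split(coords, batch_size):            # each chunk: [B, 2, 2]
        selected = succ_feats[:, chunk[:, 0], chunk[:, 1]]   # [70, B, 2, 400]
        selected = selected.permute(1, 0, 2, 3)              # [B, 70, 2, 400]
        vs = linear1(selected)                               # [B, 70, 2, 1]
        v = torch.sum(F.softmax(vs, dim=1) * vs, dim=1)      # [B, 2, 1]
        outputs.append(v.squeeze(-1))                        # [B, 2]
    return torch.cat(outputs)                                # [N, 2]

succ_feats = torch.rand(70, 10, 10, 400)
coords = torch.randint(0, 10, (300, 2, 2))
linear1 = nn.Linear(400, 1)

with torch.no_grad():
    result = v_pi_approx_batched(succ_feats, coords, linear1, batch_size=64)
print(result.shape)
```

The last chunk may be smaller than batch_size (here 300 = 4 × 64 + 44); torch.split handles that automatically.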