The current recommendation for nested tensors is to create them with the torch.jagged layout rather than the torch.strided layout. However, I have noticed a considerable drawback to this: if I store a nested tensor in a Dataset for use with a DataLoader, the dataset using the jagged layout is massively slower. For example:
import torch

# dataset holding per-item noise standard deviations in a nested tensor
class TestDatasetNested(torch.utils.data.Dataset):
    def __init__(self, N=100, dl=100, layout=torch.jagged):
        self.N = N
        self.dl = dl
        # each item gets a variable number (1 to 4) of noise std deviations
        nnoise = torch.randint(1, high=5, size=(self.N,))
        sigmas = [torch.rand(n) for n in nnoise]
        self.sigmas = torch.nested.nested_tensor(sigmas, layout=layout)

    def __getitem__(self, i):
        # index the nested tensor to get this item's std deviations
        sigmas = self.sigmas[i % self.N]
        return torch.cat(
            [sigma * torch.randn((self.dl, 1)) for sigma in sigmas], dim=-1
        ).sum(dim=-1)

    def __len__(self):
        return self.N

# create datasets with jagged and strided nested tensor layouts
dataset_jagged = TestDatasetNested(layout=torch.jagged)
dataset_strided = TestDatasetNested(layout=torch.strided)

# create a dataloader for each case
dl_jagged = torch.utils.data.DataLoader(dataset=dataset_jagged, batch_size=10)
dl_strided = torch.utils.data.DataLoader(dataset=dataset_strided, batch_size=10)

# run n full passes over a dataloader
def nepochs(dl, n=100):
    for _ in range(n):
        for _ in dl:
            pass
%timeit nepochs(dl_jagged)
4.57 s ± 17.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit nepochs(dl_strided)
322 ms ± 3.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
So, in this case using the jagged layout is ~14 times slower. Does anyone know why this is the case and whether it can be improved?
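
One possible workaround, assuming the bottleneck is the per-item indexing of the jagged nested tensor in __getitem__ (which, as far as I understand, dispatches through a Python tensor subclass), would be to unbind the nested tensor once in __init__ and index the resulting tuple of regular strided tensors instead. The TestDatasetUnbound variant below is a hypothetical sketch along those lines, not something I have benchmarked:

import torch

# hypothetical variant: unbind the jagged nested tensor once up front, so
# __getitem__ only ever indexes a plain Python tuple of strided tensors
class TestDatasetUnbound(torch.utils.data.Dataset):
    def __init__(self, N=100, dl=100, layout=torch.jagged):
        self.N = N
        self.dl = dl
        nnoise = torch.randint(1, high=5, size=(self.N,))
        sigmas = [torch.rand(n) for n in nnoise]
        # keep the nested tensor around for any batched ops that need it
        self.sigmas_nested = torch.nested.nested_tensor(sigmas, layout=layout)
        # unbind() returns the constituents as regular strided tensors
        self.sigmas = self.sigmas_nested.unbind()

    def __getitem__(self, i):
        sigmas = self.sigmas[i % self.N]
        return torch.cat(
            [sigma * torch.randn((self.dl, 1)) for sigma in sigmas], dim=-1
        ).sum(dim=-1)

    def __len__(self):
        return self.N

Even if something like that sidesteps the slowdown, I would still like to understand why indexing the jagged layout in a hot loop is so much more expensive than the strided layout.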