Speed up collating very large tensors?

I’m working with a dataset where each sample has dimensions 120x1024. I’m hoping to use a large batch size of 4096 (so each batch is 4096x120x1024), but I’m experiencing very slow data loading even with num_workers=20.

Below is an example of my code. I’m getting 4 iterations/second on my machine. This is a bit too slow to finish training in a reasonable amount of time.

I’ve narrowed the bottleneck down to the torch.stack() call in collate, which seems to be about 10x slower than everything else. I believe the data is being copied into a contiguous block, hence the slowdown?

Is there any way to speed up dataloading in my case?

import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler


class MyDataset(Dataset):
    def __init__(self):
        ...

    def __getitem__(self, index):
        # placeholder sample; the real data is loaded here
        return torch.ones([120, 1024])


train_dataset = MyDataset()

def collate(batch):
    # batch is a list of 4096 tensors, each of shape [120, 1024]
    data = torch.stack(batch)
    return data

train_sampler = RandomSampler(train_dataset)

train_dataloader = DataLoader(
    train_dataset, sampler=train_sampler, batch_size=4096,
    num_workers=20, pin_memory=True, collate_fn=collate
)


Note that too many workers might slow down your system, so you should test different values for your current setup.

That being said, to avoid the torch.stack call, you could use a BatchSampler to pass a batch of indices to your Dataset.__getitem__, preallocate the final tensor via torch.empty(batch_size, 120, 1024) and copy the tensors into it.
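A rough sketch of the wiring (assuming Dataset.__getitem__ is changed to accept a list of indices and to return the full preallocated batch):

from torch.utils.data import BatchSampler, DataLoader, RandomSampler

# The sampler now yields lists of 4096 indices; batch_size=None disables the
# DataLoader's automatic batching, so each list is passed directly to
# Dataset.__getitem__ and the returned tensor is used as the batch.
batch_sampler = BatchSampler(RandomSampler(train_dataset), batch_size=4096, drop_last=False)

train_dataloader = DataLoader(
    train_dataset,
    sampler=batch_sampler,
    batch_size=None,
    num_workers=20,
    pin_memory=True,
)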

Thanks @ptrblck! Copying the data into a preallocated tensor helps (it’s now 2x faster), but it still seems to be slowed down by the for-loop copy:

for i, b in enumerate(batch):
    final_tensor[i] = b

I think constructing a (batch_size, 120, 1024) tensor at any point will be slow, so it seems like I actually need to loop over the individual small (1, 120, 1024) tensors? I’m not sure how to use BatchSampler here in a way that would help.

I think the approach should be right, since you would need the loop at some point anyway to load and process each sample, no?

In the standard approach the DataLoader loads each sample one by one and uses the collate_fn to create the batch. You could instead push this loop into __getitem__, load each sample there, and copy the data into the preallocated tensor.
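Something along these lines (just a sketch; MyBatchDataset is a made-up name and the per-sample loading is stubbed out with torch.ones):

import torch
from torch.utils.data import Dataset

class MyBatchDataset(Dataset):
    def __getitem__(self, indices):
        # indices is the full list of sample indices for one batch
        batch = torch.empty(len(indices), 120, 1024)
        for i, idx in enumerate(indices):
            # load/process the sample for idx here; torch.ones is only a stand-in
            batch[i] = torch.ones(120, 1024)
        return batch

    def __len__(self):
        return 100000  # placeholder number of samples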

Let me know if I misunderstood the use case.

If we push the for-loop into the __getitem__ function, does that mean it’s 1 worker doing the loop? Would that be faster than just doing the for-loop in collate?

I was wondering if it is possible to assign each worker a row index of the preallocated tensor and have it copy its data into that row of the tensor in __getitem__?

Yes, but in the default setup, where a single index is used, a single worker also creates the whole batch, so I assume there shouldn’t be a difference (each worker creates its own batch).

Multiple workers do not create the same batch; each worker builds its own batch.
There is a feature request to let multiple workers work on the same batch, but I don’t think it’s ready yet.

Assuming your tensors are uint8 (images with values between 0 and 255 inclusive), it’s much faster to convert to NumPy, stack in NumPy, and then convert back to torch.

Random tensors:

import numpy as np
import torch

tensor_list = [torch.randint(size=(3, 120, 1024), low=0, high=255, dtype=torch.uint8) for _ in range(4096)]

Stacking in torch:

def stack_via_torch1(tensor_list):
    return torch.stack(tensor_list)

def stack_via_torch2(tensor_list):
    stacked_ten = torch.empty((len(tensor_list), *tensor_list[0].shape), dtype=torch.uint8)
    for i, ten in enumerate(tensor_list):
        stacked_ten[i] = ten
    return stacked_ten

%timeit stack_via_torch1(tensor_list)
946 ms ± 1.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit stack_via_torch2(tensor_list)
972 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Stacking via NumPy (~1.8x faster for 4096 tensors):

def stack_via_numpy1(tensor_list):
    return torch.from_numpy(np.array([ten.numpy() for ten in tensor_list]))

def stack_via_numpy2(tensor_list):
    np_data: np.ndarray = np.empty(len(tensor_list), dtype=object)
    for i, ten in enumerate(tensor_list):
        np_data[i] = tensor_list[i]
    return torch.from_numpy(np.stack(np_data))

def stack_via_numpy3(tensor_list):
    np_data: np.ndarray = np.empty((len(tensor_list), *tensor_list[0].shape), dtype=np.uint8)
    for i, ten in enumerate(tensor_list):
        np_data[i] = ten
    return torch.from_numpy(np_data)

def stack_via_numpy4(tensor_list):
    np_data: np.ndarray = np.empty(len(tensor_list), dtype=object)
    for i, ten in enumerate(tensor_list):
        np_data[i] = ten.numpy()
    return torch.from_numpy(np.stack(np_data))

def stack_via_numpy5(tensor_list):
    np_data: np.ndarray = np.empty(len(tensor_list), dtype=object)
    for i, ten in enumerate(tensor_list):
        np_data[i] = tensor_list[i]
    np_empty = np.empty((len(tensor_list), *tensor_list[0].shape), dtype=np.uint8)
    return torch.from_numpy(np.stack(np_data, out=np_empty))


%timeit stack_via_numpy1(tensor_list)
526 ms ± 2.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit stack_via_numpy2(tensor_list)
533 ms ± 2.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit stack_via_numpy3(tensor_list)
533 ms ± 2.23 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit stack_via_numpy4(tensor_list)
528 ms ± 1.31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit stack_via_numpy5(tensor_list)
534 ms ± 1.42 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Thus stacking via NumPy is ~1.8x faster for 4096 tensors.

Additional:

  • I prefer stack_via_numpy1 since it’s the fastest and easy to remember.
  • I observed this level of speedup even for small numbers of tensors (1 to 1024) and up to 65k. An exception was the range of 64 to 256 tensors, where, oddly, torch and NumPy took about the same time.
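If your samples really are uint8, a collate_fn along these lines could be dropped into the original DataLoader (just a sketch; numpy_collate is a made-up name):

import numpy as np
import torch
from torch.utils.data import DataLoader

def numpy_collate(batch):
    # batch is a list of same-shaped uint8 tensors: stack them in NumPy,
    # then wrap the result without an extra copy
    return torch.from_numpy(np.array([t.numpy() for t in batch]))

train_dataloader = DataLoader(
    train_dataset, sampler=train_sampler, batch_size=4096,
    num_workers=20, pin_memory=True, collate_fn=numpy_collate
)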