Dynamic shapes and PyTorch

Hello,

I’m curious about how training is done on datasets with dynamic shapes. For instance, in point cloud segmentation tasks, each point cloud may have a different number of points.

I’m immediately confused about how people handle these shapes, since even stacking two tensors of unequal size raises an error. The DataLoader class similarly expects every item in a batch to have an identical shape, so you cannot even batch such a dataset with the default collate function.

From looking online, people suggest workarounds such as using a custom collate_fn to pad the points, but this seems like a massive issue if your dataset has large size deviations. For instance, what if one point cloud has 1,000 vertices and another has 1 million? What if you have a batch size of 256 where most samples have ~1,000 points and the final one has 1 million? Suddenly you need to allocate memory and compute gradients for 256 million vertices, not to mention the additionally extracted features.
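A quick back-of-envelope makes the scale concrete (assuming float32 xyz coordinates only):

```python
# Worst case above: 256 samples all padded to 1 million points,
# float32 xyz coordinates only.
batch, points, coords, bytes_per_float = 256, 1_000_000, 3, 4
print(batch * points * coords * bytes_per_float / 1e9)  # ~3.1 GB of raw
# input alone, before extracted features, activations, or gradients.
```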

Another ‘solution’ would be to use a batch size of 1, but this brings its own problems. Perhaps using a batch size of 1 and accumulating gradients is the answer? Even so, this significantly reduces throughput.
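To be concrete about what I mean by accumulating, something along these lines (the model, data, and accumulation count are just placeholders):

```python
import torch

# Rough sketch of gradient accumulation with batch size 1.
model = torch.nn.Linear(3, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = torch.nn.CrossEntropyLoss()
accum_steps = 32                                        # effective batch size

optimizer.zero_grad()
for step in range(320):
    n = torch.randint(100, 1000, ()).item()             # varying cloud size
    points = torch.randn(n, 3)                          # one sample per step
    label = torch.randint(8, (n,))                      # per-point labels
    loss = loss_fn(model(points), label) / accum_steps  # scale so grads average
    loss.backward()                                     # grads accumulate in .grad
    if (step + 1) % accum_steps == 0:
        optimizer.step()                                # one update per 32 samples
        optimizer.zero_grad()
```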

The main question: I would like to know how one generally deals with a dataset whose samples have varying, dynamic shapes. Any information would be great.

Thank you :slight_smile:


You might be interested in NestedTensors (see How to apply vmap on a heterogeneous tensor - #2 by soulitzer), which support representing tensors with varying shapes (e.g., you can have your collate_fn create a nested tensor). A nested tensor behaves as if you had a batch of padded data, but without requiring any extra memory or compute.
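For instance, a collate_fn along these lines (the shapes and the (points, label) sample layout are just for illustration):

```python
import torch

# Two "point clouds" with different numbers of points.
pc1 = torch.randn(1_000, 3)
pc2 = torch.randn(2_500, 3)

# torch.stack([pc1, pc2]) would raise a RuntimeError, but a nested
# tensor holds both without padding:
batch = torch.nested.nested_tensor([pc1, pc2])
print(batch.is_nested)                    # True
print([t.shape for t in batch.unbind()])  # [(1000, 3), (2500, 3)]

# A collate_fn for samples that are (points, label) pairs:
def nested_collate(samples):
    points, labels = zip(*samples)
    return torch.nested.nested_tensor(list(points)), torch.tensor(labels)

# If an op needs dense input, you can still materialize a padded view:
padded = torch.nested.to_padded_tensor(batch, padding=0.0)  # shape (2, 2500, 3)
```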

Hi Mark!

The approach I generally take is to group together samples of approximately the same size and then pad the smaller samples up to the size of the largest sample in that group. (If padding somehow confuses or breaks your network, then you would have to group together samples of equal size.) It will depend on the nature of your dataset, but for many use cases it is possible to know the sizes of your samples in advance of batching.
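A rough sketch of that bucketing idea (the sampler name and (points, label) sample layout are illustrative; it assumes per-sample sizes are precomputed):

```python
import torch
from torch.utils.data import DataLoader, Sampler

class BucketBatchSampler(Sampler):
    """Yield batches of indices whose samples are close in size."""
    def __init__(self, sizes, batch_size):
        self.batch_size = batch_size
        # Sort indices by size so each batch spans a narrow size range.
        self.order = sorted(range(len(sizes)), key=lambda i: sizes[i])

    def __iter__(self):
        for start in range(0, len(self.order), self.batch_size):
            yield self.order[start:start + self.batch_size]

    def __len__(self):
        return (len(self.order) + self.batch_size - 1) // self.batch_size

def pad_collate(samples):
    # Pad each (n_i, 3) cloud up to the largest n_i in this batch only.
    points, labels = zip(*samples)
    max_n = max(p.shape[0] for p in points)
    padded = torch.stack(
        [torch.nn.functional.pad(p, (0, 0, 0, max_n - p.shape[0])) for p in points]
    )
    return padded, torch.tensor(labels)

# Usage, with `sizes` precomputed from the dataset:
# loader = DataLoader(dataset,
#                     batch_sampler=BucketBatchSampler(sizes, batch_size=32),
#                     collate_fn=pad_collate)
```

Because padding only goes up to the largest sample in each bucket, the wasted memory stays small even when the dataset as a whole spans 1,000 to 1 million points.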

Best.

K. Frank