Saving tensor with torch.save uses too much memory

Hi all!
I'm trying to manage the training set for my CNN more efficiently.
The training data is a tensor of shape [54K, 2, 8, 1320, 14] (targets have the same shape), and I use a batch size of 50, so a mini-batch has shape [50, 2, 8, 1320, 14] as it enters a Conv3d layer (2 input channels). The whole set is about 250 GB (125 GB each for data and targets), which is too big for the RAM to hold, so currently it is split into 5 data–target pairs of 10,800 samples each, i.e. 50 GB per pair (25 GB for data and 25 GB for targets).
I want to save it in separate files of 25 samples each, so I'll be able to load 2 random files to construct a mini-batch without overloading my memory.
When I try to save 25 samples with torch.save(training_set.data[:25], 'training_data_00.pt'), it gives me a file 'training_data_00.pt' of size 25 GB (the same as 10,800 samples), even though it has the correct shape [25, 2, 8, 1320, 14].
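Roughly, this is what I run (a simplified sketch; how training_set is built is omitted and the path is just a placeholder):

```python
import torch

# training_set.data holds one of the 5 pairs, shape [10800, 2, 8, 1320, 14]
chunk = training_set.data[:25]              # first 25 samples
print(chunk.shape)                          # the shape looks right
torch.save(chunk, 'training_data_00.pt')    # but the file on disk is still ~25 GB
```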
what am I doing wrong? do you have any solution to it?
Thanks


Hi,

Selecting a subset of a tensor does not actually create a new tensor in most cases; it returns a view that just looks at part of the original one's storage.
When saving such a view, the whole original storage is saved.
You can save training_set.data[:25].clone() to save only the part you want: the clone operation forces the creation of a new, smaller tensor that contains just your data.
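Here is a small standalone sketch of the difference (the tensor and sizes are made up, not your data; untyped_storage() is the PyTorch 2.x API):

```python
import io
import torch

big = torch.randn(10800, 1000)           # stand-in for the full data tensor
view = big[:25]                          # a view: shares big's storage
copy = big[:25].clone()                  # a real copy: owns only 25 rows

def saved_bytes(t):
    buf = io.BytesIO()
    torch.save(t, buf)                   # serialize to memory just to measure the size
    return buf.tell()

print(view.untyped_storage().nbytes())   # same as big's whole storage
print(copy.untyped_storage().nbytes())   # only 25 rows' worth of bytes
print(saved_bytes(view))                 # roughly the size of the whole tensor
print(saved_bytes(copy))                 # roughly the size of 25 rows
```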


Seems to work perfectly!
It is still strange that when I loaded the file I got back only the subset of the dataset, even though it had stored the whole set.
Thanks!!

Thanks for the answer!

I knew that the slice op does not create a new tensor, but this still seems like strange behavior to me. Could you explain in more detail why PyTorch is designed to save the original, whole tensor? Is it for memory efficiency?

Hey!

This is a side effect of the fact that we support serializing Tensors that share memory, and that memory sharing is preserved by serialization.
That way, if your model has tied weights, or weights that look at the same data (even partially), then serialization will preserve that property of your model. This is one of the major differences between .pt serialization and safetensors, for example.
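A tiny illustration of what gets preserved (a standalone sketch, not related to the data above):

```python
import torch

base = torch.zeros(6)
a = base[:4]                    # two overlapping views into one storage
b = base[2:]

torch.save({'a': a, 'b': b}, 'shared.pt')
loaded = torch.load('shared.pt')

loaded['a'][2] = 42.0           # write through one loaded view...
print(loaded['b'][0])           # ...and it is visible through the other: tensor(42.)
```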
