PyTorch nn.Embedding takes too much memory

Hello, I’m implementing a deep neural network and use torch’s nn.Embedding to encode the input data. The problem is that the input IDs go up to 120 000, and when I run the model it requires many GB of RAM. Do you have any idea how I can solve this problem?

How are you loading your data? Is the entirety of the input data loaded into memory ahead of time?

If so, consider loading it lazily, or using an iterable-style Dataset or an IterDataPipe from TorchData.
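As a rough illustration of the iterable-style option, here is a minimal sketch that streams samples from a text file one line at a time instead of reading everything up front. The file layout (one sample per line, space-separated integer IDs, equal length per line) and the class name `LineIdDataset` are assumptions made for the example, not something from the original post.

```python
import os
import tempfile

import torch
from torch.utils.data import DataLoader, IterableDataset


class LineIdDataset(IterableDataset):
    """Hypothetical dataset: yields one tensor of IDs per line of a text file."""

    def __init__(self, path):
        super().__init__()
        self.path = path

    def __iter__(self):
        # The file is read lazily, so only the current line is held in memory.
        with open(self.path) as f:
            for line in f:
                if line.strip():
                    ids = [int(tok) for tok in line.split()]
                    yield torch.tensor(ids, dtype=torch.long)


# Small demo file (assumed format: space-separated integer IDs per line).
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
    tmp.write("1 2 3\n4 5 6\n7 8 9\n")
    demo_path = tmp.name

loader = DataLoader(LineIdDataset(demo_path), batch_size=2)
batches = [batch for batch in loader]
os.remove(demo_path)

for batch in batches:
    print(batch.shape)
```

Note that with `num_workers > 0` each worker would iterate the whole file unless the dataset shards the work itself (for example via `torch.utils.data.get_worker_info()`), so this sketch sticks to the default single-process loading.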


It seems you are concerned about the size of this particular embedding layer.
I don’t know how large the embedding_dim is, but for num_embeddings=120000 and embedding_dim=1024, you would only use ~470MB:

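# allocated GPU memory in MB before creating the layer
print(torch.cuda.memory_allocated() / 1024**2)
# 0.0

emb = nn.Embedding(120000, 1024).cuda()
print(torch.cuda.memory_allocated() / 1024**2)
# 468.75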
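The 468.75 figure can also be checked by hand: the weight matrix holds num_embeddings × embedding_dim values, and each float32 value takes 4 bytes. Here is a small sketch of that arithmetic in plain Python. The helper name `embedding_weight_mib` and the `copies` parameter are made up for illustration, and the factor of 4 for training assumes dense gradients plus an Adam-style optimizer that keeps two extra buffers per parameter.

```python
def embedding_weight_mib(num_embeddings, embedding_dim, bytes_per_element=4, copies=1):
    """Estimate the memory of an embedding table in MiB.

    bytes_per_element=4 assumes float32 weights.
    copies=1 counts only the weights; copies=4 is a rough training estimate
    (weights + dense gradient + two Adam moment buffers), which is an assumption.
    """
    return num_embeddings * embedding_dim * bytes_per_element * copies / 1024**2


print(embedding_weight_mib(120_000, 1024))            # 468.75
print(embedding_weight_mib(120_000, 128))             # 58.59375
print(embedding_weight_mib(120_000, 1024, copies=4))  # 1875.0
```

So a table of this size is modest on its own. The number grows linearly with embedding_dim, which is the first value to check if the layer looks too large.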

Thank you guys for your answers! I figured out where the problem was: somewhere in the model I had defined a 120 000 × 10 000 embedding layer, and that caused the OOM error.
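For anyone hitting something similar: a 120 000 × 10 000 float32 table needs roughly 4.5 GiB for the weights alone, before gradients or optimizer state. One possible way to spot such a layer is to list every parameter with its shape and size, largest first. The sketch below does this on a toy model; `TinyModel`, its attribute names, and the scaled-down sizes are hypothetical and only stand in for a real network.

```python
import torch.nn as nn


def parameter_sizes_mib(model):
    """Return (name, shape, size in MiB) for every parameter, largest first."""
    rows = []
    for name, param in model.named_parameters():
        size_mib = param.numel() * param.element_size() / 1024**2
        rows.append((name, tuple(param.shape), size_mib))
    return sorted(rows, key=lambda row: row[2], reverse=True)


class TinyModel(nn.Module):
    """Hypothetical model with one intended and one accidentally wide embedding."""

    def __init__(self):
        super().__init__()
        self.intended_emb = nn.Embedding(1_000, 16)
        self.oversized_emb = nn.Embedding(1_000, 2_000)  # embedding_dim far too large
        self.head = nn.Linear(16, 4)


model = TinyModel()
rows = parameter_sizes_mib(model)
for name, shape, size_mib in rows:
    print(f"{name:22s} {str(shape):14s} {size_mib:8.3f} MiB")
```

The oversized table shows up on the first line of the report, which makes a mistyped dimension easy to catch before the model is moved to the GPU.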