Why can't I use an embedding table to get around a large GPU memory requirement?

Suppose I have data that requires a large amount of GPU memory (e.g. 80,000 tensors of shape 7 x 7 x 1024). I was hoping to get around this with a fixed-size embedding table (let's assume it's already been learned somehow). That is, if I use an embedding table with 100 entries, each token being 1024-dim, then my understanding is that all I need to fit into memory is a 100 x 1024 tensor plus 80,000 sets of 7 x 7 look-up indices, since all embeddings with the same index should just be views of the corresponding entry in the look-up table.
However, I find that this is not the case - i.e., when I call nn.Embedding on the indices, the memory usage jumps sharply as if I am actually storing an 80,000 x 7 x 7 x 1024 tensor, rather than just views of the same underlying entries in the look-up table. What am I missing?


import torch
import torch.nn as nn
num_tokens = 100
codebook_dim = 1024
h = w = 7
batch_size = 80000

device = "cuda"

# embedding table (frozen)
codebook = nn.Embedding(num_tokens, codebook_dim).to(device)
codebook.weight.requires_grad = False

indices = torch.randint(num_tokens, (batch_size, h, w)).to(device)  # look-up indices
embeddings = codebook(indices.view(-1, h * w))  # after look-up, this now uses a huge amount of memory
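For what it's worth, the same behaviour reproduces at small scale on CPU, so it doesn't seem CUDA-specific. This minimal sketch (sizes scaled way down from the ones above) shows that the lookup allocates a fresh output tensor rather than returning views into the table:

```python
import torch
import torch.nn as nn

# Scaled-down version of the setup above: 100 tokens, 8-dim codes
codebook = nn.Embedding(100, 8)
idx = torch.randint(100, (4, 3))

out = codebook(idx)  # the look-up gathers (copies) the selected rows

# The output owns its own storage, separate from the weight table,
# so memory scales with the total number of looked-up tokens.
print(out.data_ptr() == codebook.weight.data_ptr())  # False
print(out.shape)  # torch.Size([4, 3, 8])
```

So even repeated indices produce independent copies of the corresponding rows, which would explain the memory jump.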