Most efficient way to store and load training embeddings that don't fit in GPU memory

I’m training a collaborative filtering model where embeddings for users and products are learned. I have too many users to fit the entire embedding matrix into GPU memory. What is the best way to represent the embedding matrices in terms of training speed? I guess the naive solution would be to store the embeddings in CPU RAM, then move each batch onto the GPU, do the update, and write them into CPU RAM again. Is there a more efficient method?
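By “naive” I mean something like this (just a rough sketch; the sizes, the loss function, and the plain-SGD update are placeholders):

import torch

num_users, emb_dim = 1_000_000, 64
user_emb = torch.randn(num_users, emb_dim)        # full table stays in CPU RAM
device = torch.device('cuda')
lr = 0.01

def train_step(user_idx, loss_fn):
    rows = user_emb[user_idx].to(device)          # copy only this batch's rows to the GPU
    rows.requires_grad_(True)
    loss = loss_fn(rows)                          # placeholder for the actual model/loss
    loss.backward()
    with torch.no_grad():
        rows -= lr * rows.grad                    # plain SGD update on the GPU copy
    user_emb[user_idx] = rows.detach().cpu()      # write the updated rows back to CPU RAM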


If you have multiple GPUs, you might split the embedding matrix onto different devices.
If that’s not the case, then your approach of keeping the embedding layer on the CPU might be your best option.
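E.g. a rough sketch of row-wise sharding across two devices (the sizes and the lookup helper are made up, not a library API):

import torch
import torch.nn as nn

half, dim = 5_000_000, 64
emb0 = nn.Embedding(half, dim).to('cuda:0')       # rows [0, half)
emb1 = nn.Embedding(half, dim).to('cuda:1')       # rows [half, 2 * half)

def lookup(idx):                                  # idx: CPU LongTensor of global user ids
    mask = idx < half                             # which shard each row lives on
    out = torch.empty(len(idx), dim, device='cuda:0')
    out[mask.to('cuda:0')] = emb0(idx[mask].to('cuda:0'))
    out[(~mask).to('cuda:0')] = emb1((idx[~mask] - half).to('cuda:1')).to('cuda:0')
    return out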

Thanks for the advice @ptrblck. I have a couple of questions in that case:

  1. What’s the running time of updating a batch of embedding vectors within my embedding matrix at the end of each training batch? Is it O(batch_size) or O(matrix_size)?
  2. Would it be faster to load a “chunk” (maybe 100 batches’ worth) of embeddings onto the GPU in each DataLoader iteration and run a for-loop over the batches in the chunk for training? (Assuming you’re ok with some embeddings in the chunk being stale.)

  1. Embedding is similar to a lookup table, so I would assume the more input indices you are passing, the longer it’ll take.

  2. I don’t understand the question. Would you like to load chunks of input indices or the embedding weight matrix in the DataLoader?

For 2. I meant leaving the embedding matrix on the CPU, but loading as many rows/embedding vectors as you can fit into GPU memory as a chunk.
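In code, roughly (a sketch with made-up sizes):

import torch

device = torch.device('cuda')
user_emb = torch.randn(1_000_000, 64)             # full table in CPU RAM
chunk_rows = torch.arange(200_000)                # as many rows as fit in GPU memory
chunk = user_emb[chunk_rows].to(device)           # stage one chunk on the GPU

for step in range(100):                           # ~100 batches that only touch this chunk
    local_idx = torch.randint(0, len(chunk_rows), (1024,), device=device)
    batch = chunk[local_idx]                      # lookups stay on the GPU
    # ... forward / backward / update on the GPU copy here ...

user_emb[chunk_rows] = chunk.cpu()                # write the chunk back once; rows may be stale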

I just realized, are optimizers like Adam that require storing a history of gradient updates not feasible in this case since the embedding matrix is stored in CPU? It seems like the only optimizer I could use here is SGD (without momentum)?

You could use two optimizers, one for the CPU parameters and the other one for the GPU parameters.
The internal states of Adam should be parameter-specific, so the two optimizers shouldn’t create a bad interaction.

Is there an example of what this looks like? I’m having a hard time understanding how to allow the gradients to flow from GPU to CPU.

Autograd will make sure to track all to(), cuda(), and cpu() operations.

Here is a small example:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.emb = nn.Embedding(10, 100)
        self.fc = nn.Linear(100, 2)
        
    def forward(self, x):
        x = self.emb(x)
        x = x.to(self.fc.weight.device)
        x = self.fc(x)
        return x

model = MyModel()
model.fc.cuda()  # only the linear layer moves to the GPU; the embedding stays on the CPU

x = torch.randint(0, 10, (20,))
out = model(x)
print(out.device)
> cuda:0
out.mean().backward()

for name, param in model.named_parameters():
    print('{}, param.grad.abs().sum() {}, param.device {}, grad.device {}'.format(
        name, param.grad.abs().sum(), param.device, param.grad.device))

> emb.weight, param.grad.abs().sum() 3.2752888202667236, param.device cpu, grad.device cpu
fc.weight, param.grad.abs().sum() 31.677473068237305, param.device cuda:0, grad.device cuda:0
fc.bias, param.grad.abs().sum() 1.0, param.device cuda:0, grad.device cuda:0
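
And for the two optimizers mentioned before, a minimal sketch on top of this model (the learning rates are placeholders):

cpu_params = [p for p in model.parameters() if not p.is_cuda]
gpu_params = [p for p in model.parameters() if p.is_cuda]

# each optimizer keeps its Adam state on the same device as its parameters
opt_cpu = torch.optim.Adam(cpu_params, lr=1e-3)
opt_gpu = torch.optim.Adam(gpu_params, lr=1e-3)

out = model(x)
out.mean().backward()
opt_cpu.step()
opt_gpu.step()
opt_cpu.zero_grad()
opt_gpu.zero_grad()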

Thanks @ptrblck I’m going to try this today.

Seeing as I need to compute gradients for the nn.Embedding on the CPU, I’m worried about the runtime.

I’m guessing I should be using SparseAdam with the CPU embedding matrix? Would this provide a speedup versus regular Adam? I think theoretically SparseAdam would not need to compute the gradient on the full nn.Embedding matrix, so it should be roughly size_of_matrix/batch_size faster than Adam, right (minus some overhead from the sparse computation)?
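
Something like this is what I have in mind (just a sketch, assuming the embedding is created with sparse=True so it produces sparse gradients):

# e.g. reusing the model from above, but with nn.Embedding(10, 100, sparse=True)
model = MyModel()
model.fc.cuda()

# SparseAdam only applies updates to the rows that appear in the current batch
opt_emb = torch.optim.SparseAdam(model.emb.parameters(), lr=1e-3)
opt_fc = torch.optim.Adam(model.fc.parameters(), lr=1e-3)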

I don’t know how large the performance difference between SparseAdam and Adam is for this kind of workload, unfortunately. :confused:

Let us know how your experiments go and if you are seeing a major bottleneck in the embedding layer.

@ptrblck I benchmarked the performance and got the following results on my dataset:

  • CPU embeddings, SparseAdam: 15 iterations/second

  • CPU embeddings, Adam: 5 seconds/iteration (i.e. 0.2 iterations/second)

  • GPU embeddings, SparseAdam: 90 iterations/second

Using CPU embeddings with regular Adam is a 450x slowdown. The model becomes untrainable at this point. Using CPU embeddings with SparseAdam is a 6x slowdown. While this is not good, it is manageable if it allows us to train a much larger embedding matrix.

Unfortunately the speed is dependent on the embedding size. If I 4x my embedding size, time/iteration doubles. So I get ~7.5 iterations/second with CPU embeddings + SparseAdam when I 4x embedding size. This limits the usefulness of this solution.

In conclusion, I think there’s a sweet spot where this solution is helpful: when the embedding matrix can’t fit in GPU memory, but is not so big that it makes CPU training unbearable. We’ll keep testing to see if the slowdown is bearable for our particular application.

One interesting thing to note for anyone who stumbles upon this post: it’s 1.5x faster (in iterations/second) to move the embedding vectors to the GPU right after the lookup than to wait until after concatenating them. E.g. for neural collaborative filtering:

u = u.to(device)  # move the user embedding rows to the GPU right after the CPU lookup
i = i.to(device)  # same for the item embedding rows
h = torch.cat([u, i], dim=-1)  # the concatenation then runs on the GPU

is faster than:

u = ...
i = ...

h = torch.cat([u, i], dim=-1)  # concatenate on the CPU...
h = h.to(device)               # ...and move the result afterwards