Keeping only part of model parameters on GPU

Hi everyone,

I have a specific use-case wherein I want to dynamically move model parameters from CPU to GPU, and once the weight updates are done (on the GPU), transfer them back to the CPU.

To be specific, I am dealing with a classification problem where the number of classes is huge. Let's say we have a linear setup, i.e. we want to learn a weight matrix of size [num_classes x feature_size]. The problem is that we can't afford to keep this matrix entirely on the GPU. Hence, what I want is to initialize everything on the CPU, keep loading parts of the matrix onto the GPU, and update the corresponding slices.
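
For a rough sense of scale (the numbers below are made up, just to illustrate why the full matrix plus optimizer state doesn't fit on a single GPU):

num_classes = 5_000_000
feature_size = 1_024
bytes_per_float = 4  # float32

weights_gb = num_classes * feature_size * bytes_per_float / 1024**3
adam_state_gb = 2 * weights_gb  # Adam keeps exp_avg and exp_avg_sq per parameter

print(f"weights: ~{weights_gb:.1f} GB, plus Adam state: ~{adam_state_gb:.1f} GB")
# weights: ~19.1 GB, plus Adam state: ~38.1 GB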

With this approach, I'll have to declare the weight parameters on the CPU (along with their optimizers), and when I want to load part of the matrix onto the GPU, I'll have to call .cuda(), which, as the docs suggest (LINK), needs to be called before optimizer construction. In any case, on trying this approach, as expected, the model weights don't update.

Any help would be greatly appreciated.

Thanks,
Noveen Sachdeva

Hi,

Indeed, you will need to keep the nn.Parameter on the CPU and only perform the computation on the GPU in a differentiable manner. That way, the gradients will automatically accumulate back onto the CPU parameters.
Something like:

import torch

# Keep the full matrix as a CPU leaf that requires grad
full_params = torch.rand(num_classes, feature_size, device="cpu", requires_grad=True)

loss = 0.
for slice in slices:
  param_subset = full_params[slice]    # differentiable slice
  param_subset = param_subset.cuda()   # differentiable copy to the GPU
  # Compute that part of the loss from param_subset
  local_loss = your_fn(param_subset)
  loss += local_loss
loss.backward()
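
As a quick self-contained check that the slice + .cuda() copy is differentiable (shapes and loss made up):

import torch

full_params = torch.rand(4, 3, requires_grad=True)  # CPU leaf

part = full_params[0:2].cuda()   # differentiable slice + copy to the GPU
part.pow(2).sum().backward()     # stand-in for the real per-slice loss

print(full_params.grad.device)   # cpu
print(full_params.grad[2:])      # zeros: rows outside the slice get no gradient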

Does that match what you want?

Thanks for the prompt response, @albanD.

As for your solution, I have some concerns. Namely:

  • Wouldn’t doing loss += local_loss keep storing the computation graph for each slice? Isn’t this is infeasible because it would store all the intermediary variables/nodes? Just a side note, in my use-case, mathematically I can afford to do a backward call every time I compute the loss function for a slice. So I don’t need to keep adding the loss and can call .backward() for each slice.
  • As for doing param_subset = param_subset.cuda(), you load the slice on the GPU, but never take it off/to the CPU. Should it matter? Does PyTorch automatically delete GPU variables once they go out of scope?
  • Also you didn’t talk about optimizers here. The main problem as per my experiments lies in the opimizer. As per the docs (Linked again), I need to keep the parameters that are being optimized on the same hardware while declaring the optimizer class and calling optimizer.step(). I have attached the main part of my code for reference.
weight_matrices = []; bias_matrices = []; optimizers = []
for slice in slices:
    weight_matrices.append(
        nn.Parameter(
            torch.FloatTensor(
                num_features, slice_classes
            ).uniform_()
        )
    )
    bias_matrices.append(
        nn.Parameter(
            torch.FloatTensor(
                1, slice_classes
            ).uniform_()
        )
    )
    optimizers.append(
        torch.optim.Adam(
            [ weight_matrices[-1], bias_matrices[-1] ], 
            lr = 0.01
        )
    )

# Notice I've initialized everything on the CPU

for x, y in train_data:
    # x (some feature vector batch), y (indices batch) already on GPU
    for slice in slices:
        w = weight_matrices[slice]
        bias = bias_matrices[slice]
        optimizer = optimizers[slice]

        # Transferring parameters to GPU
        w = w.cuda()
        bias = bias.cuda()

        # Forward Pass
        scores = (x @ w) + bias  # x: [batch, num_features], w: [num_features, slice_classes]

        # Backward Pass
        # Some pointwise loss
        # Independent of scores for classes outside the slice
        loss = slice_loss(scores, y) 
        loss.backward()
        optimizer.step()

        # Do I need to manually call .cpu() on w and bias?

Running this code doesn’t update the weight and bias parameters. I feel this could be because of two potential problems:

  • Doing w = w.cuda() and bias = bias.cuda() creates two non-leaf variables, which don't pass the gradients back and hence don't update w and bias. (See LINK for an existing discussion thread; a minimal snippet illustrating this follows the list.)
  • I've initialized w and bias (for all slices) on the CPU and I'm loading them one-by-one onto the GPU. Should this be a problem for the optimizer?

Apologies for the detailed question. Thank you for your time, appreciate it!

-Noveen.

Does PyTorch automatically delete GPU variables once they go out of scope?

Yes! You would get a lot of memory leaks otherwise :smiley:
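
You can check it with torch.cuda.memory_allocated(), for example (a minimal sketch with made-up sizes):

import torch

print(torch.cuda.memory_allocated())  # 0

def f():
    tmp = torch.rand(1000, 1000, device="cuda")
    print(torch.cuda.memory_allocated())  # roughly 4 MB while tmp is alive

f()
print(torch.cuda.memory_allocated())  # 0 again: tmp was freed when it went out of scope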

nn.Parameter

Note that nn.Parameters are only important if you work inside an nn.Module and want them to be found by mod.parameters(). Otherwise, there is no need for them.
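
In other words, a plain leaf tensor with requires_grad=True is enough for the optimizer. A minimal sketch:

import torch

w = torch.rand(10, 5, requires_grad=True)  # plain leaf tensor, no nn.Parameter
opt = torch.optim.Adam([w], lr=0.01)

w.sum().backward()  # stand-in loss, just to populate w.grad
opt.step()          # updates w just fine

# nn.Parameter mainly matters inside an nn.Module,
# so that mod.parameters() finds the tensor automatically.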

Here is an updated example with optimizer:

import torch
from torch import optim

full_params = torch.rand(num_classes, feature_size, requires_grad=True, device="cpu")
opt = optim.SGD([full_params,], lr=1)

def acc_grads(full_params, slice):
  param_subset = full_params[slice]
  param_subset = param_subset.cuda()
  # Compute that part of the loss
  local_loss = your_fn(param_subset)
  # Accumulate into full_params.grad
  local_loss.backward()
  # Use .item() to drop gradient history and get a python number
  return local_loss.item()

loss = 0.
# reset the gradients
opt.zero_grad()
for slice in slices:
  # Nothing outside `acc_grads` is on the GPU, so all the GPU memory will be freed
  # after it returns (use torch.cuda.memory_allocated() to see the used memory)
  partial_loss = acc_grads(full_params, slice)
  loss += partial_loss
opt.step()
print("Did a step for a total loss of {}".format(loss))

Note that with this, the gradients only ever live on the CPU and the optimizer works with CPU tensors.
Also, because we send the tensor to CUDA after slicing it, only the sliced gradient exists on the GPU during the backward; the full gradient tensor only ever exists on the CPU.
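
A quick way to convince yourself of this after the loop above (nothing else is on the GPU in this example):

print(full_params.device)             # cpu
print(full_params.grad.device)        # cpu: the full gradient never touches the GPU
print(torch.cuda.memory_allocated())  # 0: nothing from acc_grads is still allocated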

Does that match what you want better?
