Keeping only part of model parameters on GPU

Thanks for the prompt response, @albanD.

As for your solution, I have some concerns. Namely:

  • Wouldn’t doing loss += local_loss keep the computation graph for every slice around? Isn’t that infeasible, since all the intermediate variables/nodes would stay in memory? As a side note, in my use case I can mathematically afford to call .backward() every time I compute the loss for a slice, so I don’t need to accumulate the loss and can simply call .backward() per slice (a toy comparison is sketched right after my code below).
  • As for doing param_subset = param_subset.cuda(), you move the slice onto the GPU but never move it back to the CPU. Does that matter? Does PyTorch automatically free GPU tensors once they go out of scope?
  • Also, you didn’t mention optimizers here. In my experiments the main problem lies in the optimizer: as per the docs (linked again), the parameters being optimized need to stay on the same device from the moment the optimizer is constructed through every optimizer.step() call. I have attached the main part of my code for reference.
import torch
import torch.nn as nn

weight_matrices = []; bias_matrices = []; optimizers = []
for slice in slices:
    weight_matrices.append(
        nn.Parameter(
            torch.FloatTensor(
                num_features, slice_classes
            ).uniform_()
        )
    )
    bias_matrices.append(
        nn.Parameter(
            torch.FloatTensor(
                1, slice_classes
            ).uniform_()
        )
    )
    optimizers.append(
        torch.optim.Adam(
            [ weight_matrices[-1], bias_matrices[-1] ], 
            lr = 0.01
        )
    )

# Notice I've initialized everything on the CPU

for x, y in train_data:
    # x (some feature vector batch), y (indices batch) already on GPU
    for slice in slices:
        w = weight_matrices[slice]
        bias = bias_matrices[slice]
        optimizer = optimizers[slice]

        # Transferring parameters to GPU
        w = w.cuda()
        bias = bias.cuda()

        # Forward Pass
        scores = (x @ w) + bias

        # Backward Pass
        # Some pointwise loss
        # Independent of scores for classes outside the slice
        loss = slice_loss(scores, y) 
        loss.backward()
        optimizer.step()

        # Do I need to manually call .cpu() on w and bias?
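
To make the first bullet above concrete, here is a toy comparison (a standalone sketch, not taken from my actual code; the tensors and shapes are made up) of accumulating the loss versus calling .backward() once per slice. In the second variant each slice’s graph is freed as soon as its backward call returns, while the gradients still accumulate in w.grad:

import torch

w = torch.randn(10, 5, requires_grad=True)    # stand-in for one parameter
xs = [torch.randn(3, 10) for _ in range(4)]   # stand-ins for the slices

# (a) accumulate the loss, single backward:
# all four graphs stay alive until the final backward call
total = sum((x @ w).pow(2).mean() for x in xs)
total.backward()

w.grad = None

# (b) backward per slice:
# each graph is freed right after its backward call,
# and w.grad ends up with the same accumulated gradient as in (a)
for x in xs:
    (x @ w).pow(2).mean().backward()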

Running the attached code doesn’t update the weight and bias parameters. I suspect this is because of one of two potential problems:

  • Doing w = w.cuda() and bias = bias.cuda() creates two non-leaf variables, which don’t get their gradients populated, and hence w and bias are never updated. (See LINK for an existing discussion thread; a minimal standalone check is sketched after this list.)
  • I’ve initialized w and bias (for all slices) on the CPU and I’m moving them onto the GPU one by one. Could this be a problem for the optimizer?
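
For reference on the first point, here is the kind of minimal standalone check I have in mind (the shapes and names are made up, and I’m assuming my understanding of the leaf/non-leaf distinction is correct):

import torch
import torch.nn as nn

w = nn.Parameter(torch.randn(4, 3))   # CPU leaf held by the optimizer
optimizer = torch.optim.Adam([w], lr=0.01)

w_gpu = w.cuda()                      # result of an op -> non-leaf copy on the GPU
print(w.is_leaf, w_gpu.is_leaf)       # True False

x = torch.randn(2, 4, device="cuda")
loss = (x @ w_gpu).sum()
loss.backward()

print(w_gpu.grad)                     # None: gradients of non-leaf tensors are not retained
print(w.grad is not None)             # True: the gradient flows back through the copy to the CPU leaf
optimizer.step()                      # steps the CPU leaf; the GPU copy is untouched

If that mental model is right, then in my training loop the optimizer would be updating the CPU copies while the forward pass uses a fresh GPU copy on every iteration, and that is part of what I would like to confirm.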

Apologies for the detailed question, and thank you for your time. I appreciate it!

-Noveen.
