Memory issues with gradients in torch.func.grad

Hi all,

I am currently running into some memory issues, and I think they are caused by my use of torch.func.grad and torch.func.vmap. The code currently looks like this:

import copy

import torch
from torch.func import functional_call, grad, stack_module_state, vmap
from torch.nn.utils import vector_to_parameters

# Build one model copy per parameter vector in x, then stack their states
models = []
with torch.no_grad(): 
    for i in range(len(x)):
        new_params = x[i]
        vector_to_parameters(new_params, model.parameters())
        model_copy = copy.deepcopy(model)
        models.append(model_copy)
        del model_copy 
    params, buffers = stack_module_state(models)
    base_model = copy.deepcopy(models[0])
del models

def fmodel(params, buffers, mini_batches):
    # functional_call takes the (params, buffers) dicts plus the inputs as a tuple of args
    return functional_call(base_model, (params, buffers), (mini_batches,))

def compute_loss(params, buffers, inputs, labels):
    outputs = fmodel(params, buffers, inputs)
    return -crit(outputs, labels)

batch_indices = torch.tensor(batch_indices)
mini_batches = X[batch_indices]
labels = Y[batch_indices]

ll = vmap(compute_loss)(params, buffers, mini_batches, labels)

ll_grads = vmap(grad(compute_loss))(params, buffers, mini_batches, labels)
concat_grads = torch.cat([value.reshape(len(x), -1) for value in ll_grads.values()], dim=1)
return concat_grads

After I calculate the gradients, I only need their values; I do not need to keep the computation graph around. I have tried wrapping the functions in torch.no_grad(), but that causes the gradient values to come out as 0.
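To make it concrete, what I am after is just the detached gradient values, something roughly like this (same names as above; detaching releases any graph that might be retained through the returned gradients, though it does not stop one from being built during the call):

ll_grads = vmap(grad(compute_loss))(params, buffers, mini_batches, labels)
# keep only the raw values
detached_grads = {name: g.detach() for name, g in ll_grads.items()}
concat_grads = torch.cat([g.reshape(len(x), -1) for g in detached_grads.values()], dim=1)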

If anyone has any idea how to fix this, it would be much appreciated!

Kind regards

I managed to fix these memory issues by changing the following:

models = []
with torch.no_grad(): 
    for i in range(len(x)):
        new_params = x[i]
        vector_to_parameters(new_params, model.parameters())
        model_copy = copy.deepcopy(model)
        models.append(model_copy)
        del model_copy 
    params, buffers = stack_module_state(models)
    base_model = copy.deepcopy(models[0])

We need to call requires_grad_(False) on the stacked parameters returned by stack_module_state. This fixed my memory issues:

models = []
with torch.no_grad(): 
    for i in range(len(x)):
        new_params = x[i]
        vector_to_parameters(new_params, model.parameters())
        model_copy = copy.deepcopy(model)
        models.append(model_copy)
        del model_copy 
    params, buffers = stack_module_state(models)
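    # this is the new part: drop requires_grad on the stacked parameters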
    for param in params.values():
        param.requires_grad_(False)
    base_model = copy.deepcopy(models[0])

This still needs to be cleaned up a lot.
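For anyone who finds this later, here is a minimal, self-contained sketch of the pattern that ended up working for me. The nn.Linear model, crit, and the batch shapes are just placeholders, so substitute your own model and data:

import copy

import torch
import torch.nn as nn
from torch.func import functional_call, grad, stack_module_state, vmap

# Placeholder model, loss, and sizes; replace with your own
model = nn.Linear(4, 2)
crit = nn.CrossEntropyLoss()
num_models, batch_size = 3, 8

# One deep copy per parameter set, then stack their states into dicts of
# batched tensors with a leading "model" dimension
models = [copy.deepcopy(model) for _ in range(num_models)]
params, buffers = stack_module_state(models)

# The change that fixed the memory usage: the stacked parameters
# should not require grad
for p in params.values():
    p.requires_grad_(False)

base_model = copy.deepcopy(models[0])

def compute_loss(params, buffers, inputs, labels):
    outputs = functional_call(base_model, (params, buffers), (inputs,))
    return -crit(outputs, labels)

# One mini-batch per model (leading dimension matches num_models)
inputs = torch.randn(num_models, batch_size, 4)
labels = torch.randint(0, 2, (num_models, batch_size))

# Per-model gradients as a dict of batched tensors, flattened into one matrix
ll_grads = vmap(grad(compute_loss))(params, buffers, inputs, labels)
concat_grads = torch.cat([g.reshape(num_models, -1) for g in ll_grads.values()], dim=1)
print(concat_grads.shape)  # (num_models, total number of parameters per model)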