Hi all,

I am currently running into memory issues, and I think they are caused by my use of `torch.func.grad` and `torch.func.vmap`. The code currently looks like this:

```
models = []
with torch.no_grad():
    for i in range(len(x)):
        # Load the i-th parameter vector into the template model, then snapshot it
        vector_to_parameters(x[i], model.parameters())
        models.append(copy.deepcopy(model))

params, buffers = stack_module_state(models)
# 'meta' device keeps only the module structure, not a real copy of the weights
base_model = copy.deepcopy(models[0]).to('meta')
del models

def fmodel(params, buffers, mini_batches):
    # functional_call expects params and buffers together, and the inputs as a tuple
    return functional_call(base_model, (params, buffers), (mini_batches,))

def compute_loss(params, buffers, inputs, labels):
    outputs = fmodel(params, buffers, inputs)
    return -crit(outputs, labels)

batch_indices = torch.tensor(batch_indices)
mini_batches = X[batch_indices]
labels = Y[batch_indices]

ll = vmap(compute_loss)(params, buffers, mini_batches, labels)
ll_grads = vmap(grad(compute_loss))(params, buffers, mini_batches, labels)

# The grads are already tensors, so no torch.tensor() wrapping is needed;
# detach defensively and flatten each parameter's grads per model
concat_grads = torch.cat([v.detach().view(len(x), -1) for v in ll_grads.values()], dim=1)
return concat_grads
```
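For context, the last line is meant to flatten the dict of per-parameter gradients that `vmap(grad(...))` returns into a single `(n_models, n_params)` matrix. On a made-up gradient dict (shapes chosen just for illustration) the flattening works like this:

```
import torch

n_models = 3
# Hypothetical gradient dict, shaped like the output of vmap(grad(...))
# for an ensemble of 3 models with a (1, 4) weight and a (1,) bias:
ll_grads = {
    "weight": torch.randn(n_models, 1, 4),
    "bias": torch.randn(n_models, 1),
}

# Flatten each parameter's gradients to (n_models, -1), then concatenate
concat_grads = torch.cat(
    [v.view(n_models, -1) for v in ll_grads.values()], dim=1
)
print(concat_grads.shape)  # torch.Size([3, 5])
```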

After I calculate the gradients, I only need their values, not the computation graph. I have tried wrapping the functions in `torch.no_grad()`, but that causes the gradient values to go to 0.

If anyone has any idea how to fix this, it would be much appreciated!

Kind regards