Should create_graph always be True when accumulating gradients in autograd.grad()?

I am trying to modify the following implementation of Meta-SGD training to support minibatches through gradient accumulation:

    # Forward pass on the current task's data
    Y_pred = model(X)
    loss = loss_fn(Y_pred, Y)
    model.zero_grad()

    # create_graph=True keeps the gradient graph so the per-parameter
    # learning rates can be trained through the adaptation step
    grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)

    # Adapted parameters: one SGD step with a learnable per-parameter
    # learning rate (the Meta-SGD inner update)
    adapted_state_dict = model.cloned_state_dict()
    adapted_params = OrderedDict()
    for (key, val), grad in zip(model.named_parameters(), grads):
        task_lr = model.task_lr[key]
        adapted_params[key] = val - task_lr * grad
        adapted_state_dict[key] = adapted_params[key]

    return adapted_state_dict
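For reference, model.task_lr in the snippet above holds one learnable learning rate per parameter (the Meta-SGD inner-loop learning rates); a hypothetical way to set it up, which may well differ from the actual implementation, would be:

    from collections import OrderedDict
    import torch

    # Hypothetical per-parameter learnable learning rates; the real code may
    # register these differently (e.g. in an nn.ParameterDict) so that the
    # meta-optimizer can update them.
    model.task_lr = OrderedDict(
        (key, torch.nn.Parameter(1e-3 * torch.ones_like(val)))
        for key, val in model.named_parameters()
    )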

Every parameter in the model has its own learning rate, and these learning rates also need to be trained, so a gradient graph is required. To simulate a larger batch size using gradient accumulation, I used the following code, but I think it is wrong because the memory requirements are extremely high:

    grads = []
    for xb, yb in data.train_dl:
        ypred = model(xb)
        loss = loss_fn(ypred, yb)
        batch_grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
        if not grads:
            grads = list(batch_grads)
        else:
            grads = [grads[i] + batch_grads[i] for i in range(len(grads))]

I think this code is creating a separate gradient graph for every minibatch. If I pass create_graph=True for the first call to autograd.grad and only retain_graph=True for the subsequent calls, will it work or is there a better way?
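One alternative I can think of is to accumulate the loss over the minibatches instead of the gradients, and call autograd.grad only once at the end, along the lines of the sketch below (assuming loss_fn returns a mean over the batch, so I divide by the number of minibatches). I am not sure this actually saves memory, though, since the forward graph of every minibatch still has to stay alive until the single autograd.grad call:

    total_loss = 0.0
    n_batches = 0
    for xb, yb in data.train_dl:
        ypred = model(xb)
        # Each batch's forward graph stays alive because total_loss references it
        total_loss = total_loss + loss_fn(ypred, yb)
        n_batches += 1

    # Single call, with create_graph=True so the per-parameter learning
    # rates can still be trained through the accumulated gradients
    grads = torch.autograd.grad(total_loss / n_batches,
                                model.parameters(), create_graph=True)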

Also, the behaviour of loss.backward() and autograd.grad seems to differ with respect to retain_graph. backward() can be called multiple times even without retain_graph=True, but autograd.grad can only be called once. For example, this code works when executed multiple times:

    outputs = model(inputs)
    loss = loss_fn(outputs, labels)
    loss.backward()

But this code works only once:

    outputs = model(inputs)
    loss = loss_fn(outputs, labels)
    grads = torch.autograd.grad(loss, model.parameters())

backward() is also much faster than autograd.grad().
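Here is a self-contained toy example of the retain_graph behaviour I am describing (x, w, and loss are placeholders, not my actual model):

    import torch

    x = torch.randn(8, 3)
    w = torch.randn(3, 1, requires_grad=True)
    loss = (x @ w).pow(2).sum()

    # First call keeps the graph so it can be differentiated again
    g1 = torch.autograd.grad(loss, [w], retain_graph=True)

    # This second call on the same graph only succeeds because of
    # retain_graph=True above; without it, autograd raises
    # "Trying to backward through the graph a second time"
    g2 = torch.autograd.grad(loss, [w])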

Hi, I also met the same issue. Setting create_graph=True is usually used for calculating second derivatives. I found that if I set create_graph=True for every batch, memory is exhausted very quickly, which indicates that a separate graph is created for each batch. If you try to avoid this by setting create_graph=True only for the first batch and retain_graph=True (with create_graph=False) for the following batches, the second derivative will be wrong, because it is then computed using only the first batch of data.
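To make the failure concrete, this is roughly the pattern I mean (a hypothetical accumulation loop, with train_dl, model, and loss_fn standing in for the real objects): only the first batch's gradients carry a graph back to the parameters, so any second derivative taken through the accumulated grads reflects only that batch.

    grads = None
    for i, (xb, yb) in enumerate(train_dl):
        loss = loss_fn(model(xb), yb)
        # Only batch 0 builds a differentiable grad graph; later batches
        # return plain tensors with no connection back to the parameters
        batch_grads = torch.autograd.grad(loss, model.parameters(),
                                          create_graph=(i == 0))
        if grads is None:
            grads = list(batch_grads)
        else:
            grads = [g + bg for g, bg in zip(grads, batch_grads)]

    # Differentiating through `grads` treats the later batches as constants,
    # so the second derivative is effectively computed from the first batch only.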