Should create_graph always be True when accumulating gradients in autograd.grad()

I am trying to modify the following implementation of Meta-SGD training to support minibatches through gradient accumulation:

    Y_pred = model(X)
    loss = loss_fn(Y_pred, Y)

    grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)

    adapted_state_dict = model.cloned_state_dict()
    adapted_params = OrderedDict()
    for (key, val), grad in zip(model.named_parameters(), grads):
        task_lr = model.task_lr[key]
        adapted_params[key] = val - task_lr * grad
        adapted_state_dict[key] = adapted_params[key]

    return adapted_state_dict
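
To make the requirement concrete, here is a minimal self-contained sketch of why `create_graph=True` matters (toy `nn.Linear` model and random data; `task_lr` is a hand-built dict standing in for `model.task_lr`, and the functional forward pass with the adapted parameters is written out by hand):

```python
import torch
import torch.nn as nn

# Toy stand-ins: one linear layer plus a trainable per-parameter learning rate.
torch.manual_seed(0)
model = nn.Linear(3, 1)
loss_fn = nn.MSELoss()
task_lr = {k: torch.full_like(v, 0.01).requires_grad_(True)
           for k, v in model.named_parameters()}

X, Y = torch.randn(5, 3), torch.randn(5, 1)
loss = loss_fn(model(X), Y)

# create_graph=True records the gradient computation itself, so the adapted
# parameters below are differentiable functions of task_lr.
grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
adapted = {k: v - task_lr[k] * g
           for (k, v), g in zip(model.named_parameters(), grads)}

# Functional forward pass with the adapted parameters (the meta/query loss).
Xq, Yq = torch.randn(5, 3), torch.randn(5, 1)
meta_loss = loss_fn(Xq @ adapted['weight'].t() + adapted['bias'], Yq)
meta_loss.backward()

# The per-parameter learning rates received gradients:
print(task_lr['weight'].grad is not None)  # True
```

Without `create_graph=True` the returned `grads` are detached from the graph, and `meta_loss.backward()` would fail to reach `task_lr`.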

Every parameter in the model has its own learning rate, and these learning rates need to be tuned, so the gradients must carry their computation graph. To simulate a larger batch size through gradient accumulation I used the following code, but I think it is wrong because its memory requirements are extremely high:

    grads = None
    for xb, yb in data.train_dl:
        ypred = model(xb)
        loss = loss_fn(ypred, yb)
        if grads is None:
            grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
        else:
            new_grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
            grads = [g + ng for g, ng in zip(grads, new_grads)]

I think this code builds a separate gradient graph for every minibatch. If I pass create_graph=True only for the first call to autograd.grad and retain_graph=True for the subsequent calls, will that work, or is there a better way?
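
Concretely, the variant I am asking about would look like this (self-contained sketch: a toy model and a list of random minibatches stand in for my real model and data.train_dl):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
loss_fn = nn.MSELoss()
# Toy stand-in for data.train_dl: a few (xb, yb) minibatches.
batches = [(torch.randn(4, 3), torch.randn(4, 1)) for _ in range(3)]

grads = None
for xb, yb in batches:
    loss = loss_fn(model(xb), yb)
    if grads is None:
        # First minibatch: build the higher-order graph.
        grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    else:
        # Subsequent minibatches: is retain_graph=True enough here, or do
        # these gradients also need create_graph=True to stay differentiable?
        new_grads = torch.autograd.grad(loss, model.parameters(), retain_graph=True)
        grads = [g + ng for g, ng in zip(grads, new_grads)]
```

This runs, but I am not sure the gradients from the later minibatches remain differentiable with respect to the parameters, which is what the Meta-SGD update needs.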

Also, the behaviour of loss.backward() and autograd.grad() seems to differ with respect to retain_graph. backward() can be called multiple times even without retain_graph=True, but autograd.grad() can only be called once. For example, this code works when executed multiple times:

    outputs = model(inputs)
    loss = loss_fn(outputs, labels)
    loss.backward()

But this code works only once:

    outputs = model(inputs)
    loss = loss_fn(outputs, labels)
    grads = torch.autograd.grad(loss, model.parameters())

Finally, backward() is also much faster than autograd.grad() in my runs.