Hi all,
Is there a way to efficiently compute the loss and gradient for a batch of parameter sets and a corresponding batch of data using torch.func?
For example, is there a way to express the following pseudocode using vmap and functional_call (or other methods)?
losses = torch.empty(N)
grads = []
for i in range(N):
    # params will be inserted into the model/net
    params = all_params[i]
    # get the batch of data to be processed for this loss and gradient
    batch = all_data[i]
    # replace the parameters of the neural network model with params,
    # then compute the loss wrt those params and the batch of data
    loss = compute_loss(params, batch)
    losses[i] = loss
    # compute the gradient of the loss wrt the parameters we just inserted
    gradients = compute_grad(loss)
    # save both the gradients and the loss so we end up with N of each
    grads.append(gradients)
return losses, grads
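
For reference, here is a minimal sketch of how I believe this pattern can be written with torch.func: stack_module_state stacks the per-model parameters along a new leading N dimension, functional_call runs a base module with a given parameter dict, and vmap over grad_and_value produces all N losses and gradients in one vectorized call. The nn.Linear models, the dimensions, and the MSE loss below are placeholders for illustration, not part of the original question:

import torch
import torch.nn as nn
from torch.func import functional_call, grad_and_value, stack_module_state, vmap

# Placeholder setup: N independent models and N corresponding data batches
N, batch_size, in_dim, out_dim = 8, 16, 4, 2
models = [nn.Linear(in_dim, out_dim) for _ in range(N)]
# stack_module_state stacks each parameter/buffer along a new leading N dim
all_params, all_buffers = stack_module_state(models)

base_model = models[0]  # used only as a template for functional_call

all_inputs = torch.randn(N, batch_size, in_dim)
all_targets = torch.randn(N, batch_size, out_dim)

def compute_loss(params, buffers, inputs, targets):
    # run the base model with the given parameters instead of its own
    preds = functional_call(base_model, (params, buffers), (inputs,))
    return nn.functional.mse_loss(preds, targets)

# grad_and_value returns (grads, loss) for a single parameter set;
# vmap maps it over the leading N dimension of every argument
grads, losses = vmap(grad_and_value(compute_loss))(
    all_params, all_buffers, all_inputs, all_targets
)

# losses has shape (N,); grads is a dict of tensors, each with leading dim N

This removes the Python loop entirely: under vmap, the N forward/backward passes are batched rather than run sequentially. If only the gradients are needed, torch.func.grad can replace grad_and_value.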