Computing loss and gradients for batches of parameters and data using torch.func

Hi all,

Is there a way to efficiently compute the loss and gradient for a batch of parameters and a corresponding batch of data in torch.func?

For example, is there a way to do the following pseudocode using vmap and functional_call (or other methods)?

losses = torch.empty(N)
grads = []  # gradients are pytrees, not scalars, so collect them in a list

for i in range(N):
    # the i-th parameter set, to be inserted into the model/net
    params = all_params[i]

    # the i-th batch of data to be processed for the loss and gradient
    batch = all_data[i]

    # replace the parameters of the neural network model with params,
    # then compute the loss wrt those params and the batch of data
    loss = compute_loss(params, batch)
    losses[i] = loss

    # compute the gradient of the loss wrt the parameters we just inserted
    gradients = compute_grad(loss)
    grads.append(gradients)

# return both the losses and the gradients, one entry per parameter set
return losses, grads
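For concreteness, here is a runnable sketch of what I imagine the vectorized version might look like, using vmap, grad_and_value, and functional_call. The Linear model, the data shapes, and the way I stack the parameter sets are all made-up placeholders; I'm not sure this is the idiomatic approach:

```python
import torch
from torch.func import functional_call, grad_and_value, vmap

# Toy model used purely for illustration (placeholder, not my real net)
net = torch.nn.Linear(3, 1)

def compute_loss(params, batch):
    # functional_call runs the module with `params` in place of its own parameters
    inputs, targets = batch
    preds = functional_call(net, params, (inputs,))
    return torch.nn.functional.mse_loss(preds, targets)

N = 5  # number of (params, data) pairs

# Build N parameter sets by stacking copies of the current parameters
# along a new leading dimension (in practice these would be N distinct sets)
base_params = {k: v.detach() for k, v in net.named_parameters()}
all_params = {k: v.unsqueeze(0).expand(N, *v.shape).clone()
              for k, v in base_params.items()}

# N batches of 8 samples each (placeholder shapes)
all_inputs = torch.randn(N, 8, 3)
all_targets = torch.randn(N, 8, 1)

# grad_and_value gives (gradients, loss) in one pass; vmap maps over the
# leading dimension of every tensor in both the params pytree and the data
batched = vmap(grad_and_value(compute_loss))
grads, losses = batched(all_params, (all_inputs, all_targets))

# losses has shape (N,); each entry of grads has a leading dim of N,
# e.g. grads["weight"] has shape (N, 1, 3)
```

Would something along these lines be the intended way to use torch.func here, or is there a better pattern?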