Backpropagation through combine_state_for_ensemble

To train an ensemble, I am using combine_state_for_ensemble from functorch. My code looks as follows:

import torch
from torch import vmap
import functorch
import torchopt
import torch.nn.functional as F
import matplotlib.pyplot as plt
from functorch import combine_state_for_ensemble
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size())

class Net(torch.nn.Module):

    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden) 
        self.predict = torch.nn.Linear(n_hidden, n_output) 

    def forward(self, x):
        x = F.relu(self.hidden(x))     
        x = self.predict(x) 
        return x

ensembles = [Net(n_feature=1, n_hidden=10, n_output=1) for _ in range(3)]
model, params, buffers = combine_state_for_ensemble(ensembles)
import ipdb; ipdb.set_trace(context=21)
optimizer = torchopt.FuncOptimizer(torchopt.adam(lr=0.2)) 
loss_func = torch.nn.MSELoss() 

for t in range(200):
    pred_mapping = vmap(model, in_dims=(0, 0, None))
    pred = pred_mapping(params, buffers, x) 
    loss = loss_func(pred, y)
    if t % 5 == 0:
        print(f"The loss in iteration {t} is {loss}")
        import ipdb; ipdb.set_trace(context=21)
        print(f"First params set is: {params[0]}")
    params = optimizer.step(loss, params)

When I run the code, I get the error “RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn”. While debugging, I noticed that the parameters returned by combine_state_for_ensemble have grad set to None. How should I handle backpropagation in this case?

@ptrblck could you please help here?

Generally, the way to debug these cases is to check your tensors as they go through the forward pass and see whether each one still has requires_grad set to True. If you find that your tensors stop requiring grad at some point, you probably did something at that step that autograd could not track.
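The check described above can be sketched roughly as follows. This is a minimal, torch-only illustration and not the code from the question: the stacked weights are built by hand with torch.stack(...).detach(), which is only an assumption about how an ensemble helper might end up returning tensors that do not require grad, and ensemble_forward is a hypothetical stand-in for the vmapped model.

```python
import torch

torch.manual_seed(0)

# Three tiny linear models stand in for the ensemble members.
models = [torch.nn.Linear(1, 1) for _ in range(3)]

# ASSUMPTION: stacking followed by detach() imitates how an ensemble helper
# might build its stacked parameters; this is not functorch's actual code.
stacked_w = torch.stack([m.weight for m in models]).detach()  # shape (3, 1, 1)
stacked_b = torch.stack([m.bias for m in models]).detach()    # shape (3, 1)

x = torch.linspace(-1, 1, 100).unsqueeze(1)  # shape (100, 1)
y = x.pow(2)


def ensemble_forward(w, b, inputs):
    # Hypothetical helper: applies every member's linear layer at once.
    # The result has shape (3, 100, 1).
    return torch.einsum("nd,eod->eno", inputs, w) + b.unsqueeze(1)


# Step 1: inspect the tensors as they move through the forward pass.
pred = ensemble_forward(stacked_w, stacked_b, x)
flags_before = (stacked_w.requires_grad, pred.requires_grad, pred.grad_fn is not None)
print("before:", flags_before)

# Step 2: mark the stacked leaf tensors as requiring grad and run forward again.
stacked_w.requires_grad_()
stacked_b.requires_grad_()
pred = ensemble_forward(stacked_w, stacked_b, x)
flags_after = (stacked_w.requires_grad, pred.requires_grad, pred.grad_fn is not None)
print("after:", flags_after)

# Step 3: backward now succeeds and fills in .grad on the stacked parameters.
loss = ((pred - y) ** 2).mean()
loss.backward()
print("grad shape:", tuple(stacked_w.grad.shape))
```

If the parameters returned by combine_state_for_ensemble turn out to have requires_grad set to False, the analogous change in the question's code would be to call requires_grad_() on each tensor in params before the training loop. Treat that as a suggestion to verify against your functorch and torchopt versions, not a confirmed fix.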