To train an ensemble of networks, I am using functorch's combine_state_for_ensemble. The code snippet looks as follows:
import torch
from torch import vmap
import functorch
import torchopt
import torch.nn.functional as F
import matplotlib.pyplot as plt
from functorch import combine_state_for_ensemble
# Toy 1-D regression data: y = x^2 + noise
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size())
# Simple one-hidden-layer regression network
class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x
# Create three independent networks and combine them into one functional
# model with stacked parameters and buffers
ensembles = [Net(n_feature=1, n_hidden=10, n_output=1) for _ in range(3)]
model, params, buffers = combine_state_for_ensemble(ensembles)
import ipdb; ipdb.set_trace(context=21)
optimizer = torchopt.FuncOptimizer(torchopt.adam(lr=0.2))
loss_func = torch.nn.MSELoss()
for t in range(200):
    # vmap the functional model over the stacked parameters/buffers,
    # sharing the same input x across all ensemble members
    pred_mapping = vmap(model, in_dims=(0, 0, None))
    pred = pred_mapping(params, buffers, x)
    loss = loss_func(pred, y)
    if t % 5 == 0:
        print(f"The loss in iteration {t} is {loss}")
        import ipdb; ipdb.set_trace(context=21)
        print(f"First params set is: {params[0]}")
    params = optimizer.step(loss, params)
When running this code, I get the error “RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn”, and while debugging I noticed that the parameters returned by combine_state_for_ensemble have their grad attribute set to None. How should I handle backpropagation in this case?
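For reference, this is roughly how I inspected the stacked parameters at the ipdb breakpoint (a minimal sketch; the shape comment reflects my understanding of what combine_state_for_ensemble returns):

# Inspect the stacked parameters returned by combine_state_for_ensemble.
# Each entry should have shape (num_models, *original_parameter_shape).
for i, p in enumerate(params):
    print(i, tuple(p.shape), p.requires_grad, p.grad)

As far as I can tell, every entry comes back with requires_grad=False and grad=None, which matches the error message above.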