I am trying to train a model with a rather complex loss function that does operations using the output of the model. The easiest way to do these operations is to use the internal machinery of a second model to compute the loss. This causes PyTorch to lose the gradients.
Is there a way to have PyTorch track the gradients of the model when the loss uses an auxiliary model to do the loss computations?
I've made a MWE using basic linear models to illustrate the problem. In the example below, L1 represents the model I am actually interested in training, and L2 represents the secondary model I'd like to use to compute the loss. In my actual use case I am not at all interested in training L2; it's just there to be used to compute the loss. It would also be cumbersome, and presumably computationally slower, to reimplement the inner workings of the secondary model.
Is there a way to get PyTorch to track the gradients of L1 when using its output in L2?
```python
import torch
import torch.nn as nn
from torch.nn.utils import skip_init

# data
X1 = torch.tensor([[.5]])
X2 = torch.arange(1, 11.)
y = 2*X2
if len(X2.shape) == 1:  # if X is a 1d tensor make 2d
    X2.unsqueeze_(1)    # shape is now n by 1

# create a basic model
class Linear(nn.Module):
    def __init__(self):
        super().__init__()
        self.l = skip_init(nn.Linear, 1, 1, bias=False)

    def forward(self, x):
        x = self.l(x)
        return x

# loss that uses a secondary model to do operations
def wrapped_loss(L, X=X2, y=y):
    Y = y.unsqueeze(1)
    d = (L(X) - Y).pow(2)
    return d.mean()

# loss that does the operations with the output itself
def loss(b, X=X2, y=y):
    Y = y.unsqueeze(1)
    d = (torch.mm(X, b) - Y).pow(2)
    return d.mean()

# initializations
L1 = Linear()  # thing we are interested in updating
L1.l.weight = nn.Parameter(torch.tensor([[4.]]))
L2 = Linear()  # auxiliary computation; its parameters will be populated
               # from output that depends on L1
opt = torch.optim.SGD(params=L1.parameters(), lr=.1)
epochs = 1

# training
for epoch in range(epochs):
    intermediate_output = 2*L1(X1)
    # put the output into the secondary model that will be
    # used to do operations with the output
    with torch.no_grad():
        L2.l.weight.copy_(intermediate_output)
    wrapped_l = wrapped_loss(L2)  # identical to l = loss(intermediate_output)
    opt.zero_grad()
    wrapped_l.backward()  # compare against using l.backward()
    opt.step()
    print(L1.l.weight.grad)  # currently None, but we'd expect it to be the
                             # same as when we use loss
```
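For reference, here is a minimal sketch of one possible workaround (assuming torch >= 2.0, where `torch.func.functional_call` is available): instead of copying the output into L2's parameter under `no_grad()` (which severs the graph), run L2 functionally with the graph-connected tensor supplied as its weight, so the gradient can flow back to L1:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

# same setup as the MWE above, using plain nn.Linear for brevity
L1 = nn.Linear(1, 1, bias=False)  # model we actually want to train
with torch.no_grad():
    L1.weight.copy_(torch.tensor([[4.]]))
L2 = nn.Linear(1, 1, bias=False)  # auxiliary model used only inside the loss

X1 = torch.tensor([[.5]])
X2 = torch.arange(1, 11.).unsqueeze(1)
Y = 2 * X2

intermediate_output = 2 * L1(X1)  # differentiable function of L1's weight

# run L2 with the graph-connected tensor as its "weight" parameter,
# rather than mutating L2.weight in place
out = functional_call(L2, {"weight": intermediate_output}, (X2,))
wrapped_l = (out - Y).pow(2).mean()
wrapped_l.backward()
print(L1.weight.grad)  # no longer None
```

This keeps L2's own parameters untouched (they are never trained), while the forward pass through L2's machinery stays on L1's autograd graph.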