Losing gradients when using auxiliary model to compute loss

Overview

I am trying to train a model whose loss function is rather complex and operates on the model's output. The easiest way to do these operations is to use the internal machinery of a second model to compute the loss. However, this causes PyTorch to lose the gradients.

Is there a way to have PyTorch track the gradients of the model when the loss uses an auxiliary model to do the loss computations?

Minimal Example

I’ve made a MWE using basic linear models to illustrate the problem. In the example below, L1 represents the model I am actually interested in training and L2 represents the secondary model I’d like to use to compute the loss. In my actual use case, I am not at all interested in training L2; it’s just there to compute the loss. It would also be cumbersome, and presumably computationally slower, to reimplement the inner workings of the secondary model.

Is there a way to get PyTorch to track the gradients of L1 when using its output in L2?

import torch
import torch.nn as nn
from torch.nn.utils import skip_init

# data
X1 = torch.tensor([[.5]])
X2 = torch.arange(1, 11.)
y = 2*X2
if len(X2.shape) == 1: # if X is a 1d tensor make 2d
    X2.unsqueeze_(1) # shape is now n by 1

# create a basic model
class Linear(nn.Module):
    def __init__(self):
        super().__init__()
        self.l = skip_init(nn.Linear, 1, 1, bias = False)

    def forward(self, x):
        x = self.l(x)
        return x

#  loss that uses a secondary model to do operations
def wrapped_loss(L, X = X2, y = y):
    Y = y.unsqueeze(1)
    d = (L(X)-Y).pow(2)
    return d.mean()

# loss that does the operations with the output itself
def loss(b, X = X2, y = y):
    Y = y.unsqueeze(1)
    d = (torch.mm(X, b) - Y).pow(2)
    return d.mean()

# initializations 
L1 = Linear() # thing we are interested in updating
L1.l.weight = nn.Parameter(torch.tensor([[4.]]))
L2 = Linear() # auxiliary computation; the parameters will be populated from output that depends on L1
opt = torch.optim.SGD(params = L1.parameters(), lr = .1)
epochs = 1

# training
for epoch in range(epochs):   
    intermediate_output = 2*L1(X1)
    
    # put output into the secondary model that will 
    # be used to do operations with the output.
    with torch.no_grad():
        L2.l.weight.copy_(intermediate_output) 

    wrapped_l = wrapped_loss(L2) # this is identical to l = loss(intermediate_output)

    opt.zero_grad()
    wrapped_l.backward() # compare against using l.backward()
    opt.step()

    print(L1.l.weight.grad) # this is currently None but we'd expect it to be the same as when we use loss


The tensor is detached since you are copying the intermediate activation into the weight in a no_grad context.
torch.func.functional_call should work as seen here:

import torch
import torch.nn as nn

# create a basic model
class Linear(nn.Module):
    def __init__(self):
        super().__init__()
        self.l = nn.Linear(1, 1, bias = False)

    def forward(self, x):
        x = self.l(x)
        return x

#  loss that uses a secondary model to do operations
def wrapped_loss(L, intermediate_output, X, y):
    Y = y.unsqueeze(1)
    d = (torch.func.functional_call(L.l, {"weight": intermediate_output}, X)-Y).pow(2)
    return d.mean()


# data
X1 = torch.tensor([[.5]])
X2 = torch.arange(1, 11.)
y = 2*X2
if len(X2.shape) == 1:
    X2.unsqueeze_(1)

# initializations 
L1 = Linear() # thing we are interested in updating
L1.l.weight = nn.Parameter(torch.tensor([[4.]]))
L2 = Linear() # auxiliary computation; the parameters will be populated from output that depends on L1
opt = torch.optim.SGD(params = L1.parameters(), lr = .1)
epochs = 5

# training
for epoch in range(epochs):   
    intermediate_output = 2*L1(X1)
    wrapped_l = wrapped_loss(L2, intermediate_output, X2, y)

    opt.zero_grad()
    wrapped_l.backward()
    opt.step()

    print(L1.l.weight.grad)
    
# tensor([[154.]])
# tensor([[-1031.8000]])
# tensor([[6913.0605]])
# tensor([[-46317.5078]])
# tensor([[310327.3125]])
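To check the explanation above (the no_grad copy cuts the graph, functional_call does not), here is a minimal sketch that runs both variants once on the thread's toy data, assuming PyTorch 2.x and using plain nn.Linear modules instead of the skip_init wrapper. With the weight at 4, the effective weight 2 * 4 * 0.5 is 4, so the gradient should be 2 * (4 - 2) * mean(x^2) = 4 * 38.5 = 154, the first value printed above.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

X1 = torch.tensor([[0.5]])
X2 = torch.arange(1, 11.).unsqueeze(1)   # shape (10, 1)
y = 2 * X2                               # targets, shape (10, 1)

def make_l1():
    # L1 from the thread: 1 -> 1 linear map, no bias, weight fixed to 4
    l1 = nn.Linear(1, 1, bias=False)
    with torch.no_grad():
        l1.weight.fill_(4.0)
    return l1

L2 = nn.Linear(1, 1, bias=False)         # auxiliary model, only its forward is used

# Variant 1 (the question): copy L1's output into L2's weight under no_grad
L1 = make_l1()
with torch.no_grad():
    L2.weight.copy_(2 * L1(X1))
((L2(X2) - y) ** 2).mean().backward()
grad_copy = L1.weight.grad               # None: L1 was never recorded in the graph

# Variant 2 (the answer): pass L1's output as L2's weight via functional_call
L1 = make_l1()
w = 2 * L1(X1)                           # shape (1, 1), depends on L1.weight
pred = functional_call(L2, {"weight": w}, (X2,))
((pred - y) ** 2).mean().backward()
grad_func = L1.weight.grad               # approximately tensor([[154.]])
print(grad_copy, grad_func)
```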

Is there a way to use your solution (calling torch.func.functional_call(L2, params_dict, X2)) over a set of parameters? In my case, instead of one set of parameters {"weight": intermediate_output}, the loss depends on computations using the machinery of the L2 model instantiated with m different sets of parameters: list_of_params = [{"weight": intermediate_output_1}, ..., {"weight": intermediate_output_m}].

The naive solution is to use a for loop:

losses = torch.tensor([(torch.func.functional_call(L2, param_dict, X2)-Y2).pow(2) for param_dict in list_of_params ])
loss = losses.mean()

I’m guessing this is computationally inefficient. Is there a vectorized version for doing this, something like a map-reduce?

Thanks! This thread is very helpful!
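A minimal sketch of one vectorized option, assuming PyTorch 2.x with torch.func: stack the m weights along a new leading dimension and let torch.func.vmap map a single-set loss over it. The trainable tensor theta is a hypothetical stand-in for whatever produces intermediate_output_1 ... intermediate_output_m, and the loop variant is kept only to show that both give the same loss and gradient.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, vmap

X = torch.arange(1, 11.).unsqueeze(1)      # shape (10, 1)
Y = 2 * X
L2 = nn.Linear(1, 1, bias=False)           # only its forward pass is used

# Hypothetical source of the m weights: they depend on a trainable tensor theta
theta = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
weights = theta.view(-1, 1, 1)             # shape (m, 1, 1): one (1, 1) weight per set

def per_set_loss(w):
    # Evaluate L2 with weight w and return its mean squared error
    pred = functional_call(L2, {"weight": w}, (X,))
    return ((pred - Y) ** 2).mean()

# Loop version: torch.stack (not torch.tensor) keeps every loss in the graph
loop_loss = torch.stack([per_set_loss(w) for w in weights]).mean()

# Vectorized version: vmap maps per_set_loss over the leading dimension
vmap_loss = vmap(per_set_loss)(weights).mean()

(g_loop,) = torch.autograd.grad(loop_loss, theta)
(g_vmap,) = torch.autograd.grad(vmap_loss, theta)
print(vmap_loss, g_vmap)
```

Each per-set loss here is (w - 2)^2 * 38.5, so with theta = [1, 2, 3] the mean loss is 77/3 and the gradient is (77/3) * (theta - 2).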

functional_call accepts multiple dictionaries, as seen in the example in the docs, so did you try that approach?
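For reference, a minimal sketch of how the multiple-dictionary form behaves, assuming PyTorch 2.x: the dictionaries are merged into one set of replacement tensors (for example, parameters in one dict and buffers in another), so each key may appear in only one of them. It does not evaluate the module once per dictionary.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

m = nn.Linear(2, 1)
x = torch.tensor([[1.0, 2.0]])

# Disjoint keys: the two dicts are merged into one set of replacements
out = functional_call(m, ({"weight": torch.ones(1, 2)}, {"bias": torch.zeros(1)}), (x,))
# out is tensor([[3.]]) because 1*1 + 1*2 + 0 = 3

# Overlapping keys: both dicts supply "weight", which functional_call rejects
try:
    functional_call(m, ({"weight": torch.ones(1, 2)}, {"weight": torch.zeros(1, 2)}), (x,))
    overlap_error = None
except ValueError as e:
    overlap_error = str(e)
print(out, overlap_error)
```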


Yes. Here are three approaches and the errors they throw, but first, some background notation.

M2 is a simple logistic regression model (an nn.Module) with the following state dict format:
{ 'linear.weight': <1 by 2 tensor>, 'linear.bias': <tensor of shape [1]> }
In short, it has 3 parameters.

Now I’ll make a list of parameter dictionaries. These come from the outputs of a different model, M1. I need M2 to be evaluated on dataset X for each set of parameters in the list in order to compute the loss function for training M1 (so M1’s gradients need to be tracked through the computation).

The output of M1 is a tensor B of shape 2 by 3 (two sets of parameters for M2).
e.g.:

B = torch.tensor([[1.,2,3], [4,5,6] ]) # output of M1

param_dicts = [ { 'linear.weight': B[0,:-1].unsqueeze(0), 'linear.bias': B[0,-1:] },
                { 'linear.weight': B[1,:-1].unsqueeze(0), 'linear.bias': B[1,-1:] } ]
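One way to avoid the list entirely is a single dictionary whose tensors carry a leading dimension of size m = B.shape[0], evaluated with torch.func.vmap. This is a minimal sketch, assuming PyTorch 2.x; LogReg is a hypothetical stand-in for M2 with the same state dict keys, and X is random placeholder data.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, vmap

torch.manual_seed(0)

class LogReg(nn.Module):
    # Hypothetical stand-in for M2: keys 'linear.weight' (1, 2) and 'linear.bias' (1,)
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))

M2 = LogReg()
X = torch.randn(5, 2)                                          # placeholder data: 5 samples
B = torch.tensor([[1., 2, 3], [4, 5, 6]], requires_grad=True)  # stands in for M1's output

# One dict, every tensor with a leading "parameter set" dimension of size m
batched = {
    "linear.weight": B[:, :-1].unsqueeze(1),   # (m, 1, 2)
    "linear.bias": B[:, -1:],                  # (m, 1)
}

def run_m2(params):
    # Evaluate M2 on X with a single parameter set
    return functional_call(M2, params, (X,))

preds = vmap(run_m2)(batched)                  # (m, 5, 1), still depends on B

# Same result as looping over the list-of-dicts version from the post
looped = torch.stack([
    run_m2({"linear.weight": B[i, :-1].unsqueeze(0), "linear.bias": B[i, -1:]})
    for i in range(B.shape[0])
])
print(preds.shape)
```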

Errant try 1:
Passing the list to functional_call fails:
x = functional_call(M2, param_dicts, X)  # fails with “ValueError: [‘linear.weight’, ‘linear.bias’] appeared in multiple dictionaries; behavior of functional call is ambiguous”


Errant try 2:
Using a list comprehension works partially: it computes M2(params, X) for each set, but it doesn’t track M1’s gradients:

opt = torch.optim.RMSprop( params = M1.parameters(), lr = learning_rate, momentum = momentum)

x = torch.tensor([ ((torch.func.functional_call(M2, d, X) - Y).pow(2)).mean() for d in param_dicts])  # no errors!
loss = x.mean() 

# now step M1's params based on that computation... error:
opt.zero_grad()
loss.backward() # <--- error here
opt.step()
scheduler.step(loss)

The error is “RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn”.
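The error is consistent with how torch.tensor treats a list of tensors: it builds a brand-new tensor from their values, so the result has no grad_fn and is cut off from M1. torch.stack is a differentiable operation and keeps the connection. A minimal sketch, with a hypothetical scalar w standing in for M1's parameters:

```python
import torch

w = torch.tensor(3.0, requires_grad=True)
parts = [w * 2, w * 5]          # two scalar "losses" that depend on w

# torch.tensor copies the values into a new tensor with no autograd history
cut = torch.tensor(parts)
try:
    cut.mean().backward()
    cut_error = None
except RuntimeError as e:
    cut_error = str(e)          # the same "does not require grad" error as above

# torch.stack is recorded by autograd, so the result stays connected to w
kept = torch.stack(parts)
kept.mean().backward()          # d/dw of mean(2w, 5w) is 3.5
print(cut_error, w.grad)
```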


Errant try 3: using torch.func.grad (see the PyTorch 2.0 documentation) to explicitly track the gradients of M1:

def compute_loss(m1_state_dict): 
    global X # training features for M2
    global Y # training targets for M2
    global M1 # first model 
    global M2
    global batch_size
    global dim 
    samples = random_normal_samples(batch_size, dim) # input data for the NF

    B, log_det_T = functional_call(M1, m1_state_dict, samples)   # do I need torch.func.functional_call here? I want to track M1's gradients
    param_dicts = [ {'linear.weight': B[i,:-1].unsqueeze(0), 'linear.bias': B[i,-1:]}\
            for i in range(B.shape[0])]
    l2_losses = torch.tensor([ ((torch.func.functional_call(M2, d, X) - Y).pow(2)).mean()\
                                for d in param_dicts]) # shape = [batchsize]    
    loss = (-l2_losses + log_det_T.squeeze()).mean()
    return loss

m1_state_dict = M1.state_dict()
grad(compute_loss)(m1_state_dict) # <--- error here:
---
RuntimeError: unwrapped_count > 0 INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/TensorWrapper.cpp":202, please report a bug to PyTorch. Should have at least one dead wrapper

Overall, I need to train M1 with a loss that depends on M1’s output, where M2 uses some of M1’s output as its parameters, and I’m stuck. Thanks!
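Putting the pieces together, a minimal end-to-end sketch under stated assumptions: PyTorch 2.x; M1 is a hypothetical nn.Linear whose output rows are the (w1, w2, b) sets, without the log_det_T term (a term like that would presumably be added to the returned loss in the same way); LogReg is a hypothetical stand-in for M2; all data are random placeholders. The loss never passes through torch.tensor, so loss.backward() reaches M1, and torch.func.grad is used only at the end as an optional cross-check.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, vmap, grad

torch.manual_seed(0)

class LogReg(nn.Module):
    # Hypothetical stand-in for M2, with the same state dict keys as in the thread
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))

M1 = nn.Linear(4, 3)                   # hypothetical M1: each output row is one (w1, w2, b) set
M2 = LogReg()
M2.requires_grad_(False)               # M2 only lends its forward pass; it is never trained
X = torch.randn(32, 2)                 # placeholder data for M2
Y = (X.sum(dim=1, keepdim=True) > 0).float()

def compute_loss(m1_params, samples):
    B = functional_call(M1, m1_params, (samples,))            # (m, 3)
    m2_params = {"linear.weight": B[:, :-1].unsqueeze(1),     # (m, 1, 2)
                 "linear.bias": B[:, -1:]}                    # (m, 1)

    def mse(p):
        # Loss of M2 evaluated with one parameter set p
        return ((functional_call(M2, p, (X,)) - Y) ** 2).mean()

    return vmap(mse)(m2_params).mean()                        # no torch.tensor(...), graph intact

opt = torch.optim.RMSprop(M1.parameters(), lr=1e-2)
for step in range(3):
    samples = torch.randn(16, 4)                              # m = 16 parameter sets per step
    loss = compute_loss(dict(M1.named_parameters()), samples)
    opt.zero_grad()
    loss.backward()                                           # gradients flow back into M1
    opt.step()

# Optional cross-check: torch.func.grad should match ordinary .backward()
samples = torch.randn(16, 4)
detached = {k: v.detach() for k, v in M1.named_parameters()}
func_grads = grad(compute_loss)(detached, samples)
opt.zero_grad()
compute_loss(dict(M1.named_parameters()), samples).backward()
```

The vmap call replaces the Python loop over parameter sets; for small m, a loop that combines the per-set losses with torch.stack should give the same gradients.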

@ptrblck should I create a new question for this? I’m unsure whether I should expect a response here or need to start a new thread.