Trainable & non-trainable parameters in custom model

Hello!

I want to create a custom model using PyTorch, where I need to multiply the inputs with a matrix containing both trainable and non-trainable parameters (I'm looking to implement a trainable Kalman filter, with free and fixed parameters). Furthermore, the matrix contains the same parameter in more than one entry.

However, I am struggling (maybe too much!) to train this… any workaround?

import torch

class CustomModel(torch.nn.Module):
    def __init__(self, w0):
        super(CustomModel, self).__init__()
        self.w = torch.nn.Parameter(torch.tensor([w0], dtype=torch.float32))

        # self.matrix = torch.tensor([[self.w, -1.], [-self.w, -1.]], requires_grad=True)
        # --> this computes d(COST)/d(matrix) instead of d(COST)/dw --> BAD

        self.matrix_trainable = self.w * torch.tensor([[0., 1.], [-1., 0.]], dtype=torch.float32)
        self.matrix = self.matrix_trainable - torch.eye(2)

    def forward(self, x):
        return self.matrix.matmul(x)

def loss(pred, y):
    return torch.mean((pred - y)**2)


my_model = CustomModel(w0=0.01)
optimizer = torch.optim.Adam(lr=0.01, params=my_model.parameters())

device = torch.device("cpu")
x = torch.ones(2).to(device)
y = torch.tensor(data=[2.,0.], dtype=torch.float32).to(device)


for k in range(10):
    optimizer.zero_grad()
    my_model.zero_grad()  # redundant with optimizer.zero_grad() here, but harmless
    pred = my_model(x)
    cost = loss(pred, y)
    cost.backward()
    optimizer.step()

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

PyTorch version: 2.0.1

Thanks a lot!!

Matías

You are creating self.matrix_trainable in a differentiable way in __init__, which causes the issue: the computation graph built there stays alive across iterations, so the second call to backward() tries to traverse it again after its saved intermediate tensors have already been freed.
Create matrix_trainable in the forward for each new input and it should work:

    def forward(self, x):
        matrix_trainable = self.w * torch.tensor([[0., 1.], [-1., 0.]], dtype=torch.float32)
        matrix = matrix_trainable - torch.eye(2)
        return matrix.matmul(x)

Thanks a lot!
Is this the most efficient solution available in PyTorch?
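
A minimal sketch of one possible refinement (an assumption on my part, not from the replies above): the constant tensors can be registered as buffers in __init__, so they are allocated once, move with .to(device), and are never seen by the optimizer, while only the cheap parameter-dependent product is rebuilt in forward:

import torch

class CustomModel(torch.nn.Module):
    def __init__(self, w0):
        super().__init__()
        self.w = torch.nn.Parameter(torch.tensor([w0], dtype=torch.float32))
        # Constant tensors registered as buffers: created once, moved by
        # .to(device), saved in the state_dict, and excluded from parameters().
        self.register_buffer("pattern", torch.tensor([[0., 1.], [-1., 0.]]))
        self.register_buffer("eye", torch.eye(2))

    def forward(self, x):
        # Only this product is rebuilt per call, so the graph through
        # self.w is fresh on every forward pass.
        matrix = self.w * self.pattern - self.eye
        return matrix.matmul(x)

Functionally this matches the accepted fix; it just avoids re-allocating the two constant tensors on every call.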