Problem with back-propagation when using the concatenation of parameters from a set of separate models

Hi,

I would like to concatenate the parameters of a set of models and forward them through a network (in order to get an output, calculate a loss, and back-propagate), but it seems that the computation graph is somehow "broken": even though no errors are raised, training does not happen and the parameters are not updated.

import torch
import torch.nn as nn


# A simple model
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()        
        self.params = nn.Parameter(data=torch.randn(18, 512))
        
    def forward(self):
        return self.params


# A list of N Model objects
N = 10
device = 'cuda'
models = [Model().to(device) for _ in range(N)]


# I need to forward the parameters of all N models through a subsequent network, calculate a loss, and back-propagate
params = torch.cat([m().unsqueeze(0) for m in models], dim=0)  # torch.Size([10, 18, 512])

y = some_network(params)  # some_network is a placeholder for the actual downstream network

# Calculate the loss (some_loss is a placeholder as well)
loss = some_loss(y)

# Back-propagate
loss.backward()

Obviously, I could define the parameters as

params = nn.Parameter(data=torch.randn(N, 18, 512))

and do the rest without issues, but I have reasons to need each (18, 512)-dimensional parameter tensor to be the parameters of a separate model. Any insight into how this could be done?
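
For reference, here is a rough sketch of the training step I have in mind, with the optimizer built over the parameters of all N models (some_network, some_loss, and num_steps are placeholders, not my actual modules):

import itertools

# Give the optimizer the parameters of every sub-model
optimizer = torch.optim.Adam(
    itertools.chain.from_iterable(m.parameters() for m in models), lr=1e-3
)

for step in range(num_steps):
    optimizer.zero_grad()
    # Re-concatenate inside the loop so a fresh graph is built each step
    params = torch.cat([m().unsqueeze(0) for m in models], dim=0)
    y = some_network(params)
    loss = some_loss(y)
    loss.backward()
    optimizer.step()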

Thank you!

Your code seems to work for me:

# A simple model
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()        
        self.params = nn.Parameter(data=torch.randn(18, 512))
        
    def forward(self):
        return self.params


class some_network(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.params = params
        
    def forward(self, x):
        x = x * self.params
        return x

# A list of N Model objects
N = 10
device = 'cuda'
models = [Model().to(device) for _ in range(N)]


# I need to forward the parameters of all N models through a subsequent network, calculate a loss, and back-propagate
params = torch.cat([m().unsqueeze(0) for m in models], dim=0)  # torch.Size([10, 18, 512])


model = some_network(params)
x = torch.randn(1, 18, 512, device="cuda")
out = model(x)
loss = out.mean()

# Back-propagate
loss.backward()

print([m.params.grad.abs().sum() for m in models])
# [tensor(0.0812, device='cuda:0'), tensor(0.0812, device='cuda:0'), tensor(0.0812, device='cuda:0'), tensor(0.0812, device='cuda:0'), tensor(0.0812, device='cuda:0'), tensor(0.0812, device='cuda:0'), tensor(0.0812, device='cuda:0'), tensor(0.0812, device='cuda:0'), tensor(0.0812, device='cuda:0'), tensor(0.0812, device='cuda:0')]

Note that I've implemented the missing parts of your code myself, since the posted snippet is incomplete and not executable.

@ptrblck thanks for the reply! Indeed, some parts were incomplete, but my problem is not an error during backprop; rather, training does not work as expected (the parameters are not updated). It's difficult to give the precise implementation here, but the missing parts include a GAN generator and CLIP's ViT.

So you don't find anything fundamentally wrong with this approach, right? I mean, working with the concatenation of parameters taken from multiple different model objects. Thanks again for your time.

In your current code I don't think wrapping the parameters in a Model is necessary, but I don't see any issues with it.
Concatenating parameters should also be fine, since it's a differentiable operation and properly allows Autograd to backpropagate the gradients to the original parameters, as seen in my example.
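
To make that explicit, here is a quick sanity check reusing the placeholder some_network and the models list from above (the SGD learning rate is arbitrary), showing that an optimizer built over the parameters of all sub-models updates them after the backward pass:

import itertools

optimizer = torch.optim.SGD(
    itertools.chain.from_iterable(m.parameters() for m in models), lr=0.1
)

optimizer.zero_grad()
# Re-concatenate so a fresh graph is built for this step
params = torch.cat([m().unsqueeze(0) for m in models], dim=0)
model = some_network(params)
out = model(torch.randn(1, 18, 512, device="cuda"))
loss = out.mean()

before = models[0].params.detach().clone()
loss.backward()
optimizer.step()
print(torch.equal(before, models[0].params))  # False, i.e. the sub-model's parameters changed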

@ptrblck I see, many thanks!