List to group results in forward method

Hi all!

I have a quick “validation” question. Given n parallel model heads, which do not share parameters and should be trained on different data, is it OK to append their outputs to a list and then call torch.cat on the list?
Here are a few lines of code to better illustrate the problem:

import torch
import torch.nn as nn

class DenseNet(nn.Module):
    def __init__(self):
        super(DenseNet, self).__init__()
        self.dense_list = nn.ModuleList([nn.Linear(16, 16) for _ in range(3)])

    def forward(self, x):
        tcn_out = []  # collect each head's output
        for i, dense_i in enumerate(self.dense_list):
            # each head sees only its own slice of the input
            o = dense_i(x[:, i, :])
            tcn_out.append(o)
        return torch.cat(tcn_out, dim=1)

x = torch.rand(32, 3, 16)  # (batch, heads, features)

dense_net = DenseNet()
out = dense_net(x)
print(out.shape)
# torch.Size([32, 48])

My concern is possible confusion during backpropagation due to mixing tensors and a Python list.

Thanks for your time!

torch.cat should work and will not confuse the tensors during the backward pass, as seen here:

import torch
import torch.nn as nn

# three independent heads; freeze the first and the last one so we can
# see which weights actually receive gradients
modules = [nn.Linear(1, 1, bias=False) for _ in range(3)]
modules[0].weight.requires_grad = False
modules[2].weight.requires_grad = False

x = torch.randn(1, 3, 1)

out = []
for idx, module in enumerate(modules):
    o = module(x[:, idx])
    out.append(o)
out = torch.cat(out, dim=1)
print(out.shape)
# torch.Size([1, 3])
out.mean().backward()

for module in modules:
    print(module.weight.grad)
# None
# tensor([[-0.0284]])
# None
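
For completeness, here is a minimal sketch (same setup as above) showing that the backward of torch.cat routes each slice of the incoming gradient back to the tensor it came from, so a loss that uses only one head's output slice produces a nonzero gradient only for that head:

import torch
import torch.nn as nn

modules = [nn.Linear(1, 1, bias=False) for _ in range(3)]
x = torch.randn(1, 3, 1)

out = torch.cat([m(x[:, i]) for i, m in enumerate(modules)], dim=1)
# build the loss from the middle slice only
out[:, 1].mean().backward()

for module in modules:
    print(module.weight.grad)
# tensor([[0.]])
# tensor([[...]])  # only the middle head receives a nonzero gradient
# tensor([[0.]])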