Replace multiple linear networks with convolutional layers

Given

import torch
import torch.nn as nn

nof_samples = 1
nof_features = 2
input = torch.rand(nof_samples, nof_features)

I have two feature networks (one for each feature), which look like this:

feature_networks = nn.ModuleList()
for _ in range(nof_features):
    feature_networks.append(
        nn.Sequential(
            nn.Linear(in_features=1, out_features=3),
            nn.Linear(in_features=3, out_features=3),
            nn.Linear(in_features=3, out_features=1)
        ))  

output = torch.zeros(nof_samples, nof_features)
for feature_index in range(nof_features):  # Iterate over features
    output[:, feature_index] = feature_networks[feature_index](input[:, feature_index])

output # shape of [1, 2]
# nof trainable params: 44

I would like to replace it with something like the following to get rid of the (I believe) time-consuming for-loop. With groups=nof_features, each group of output channels only sees its own input feature, so the grouped convolutions should mimic the separate per-feature layers:

conv1 = nn.Conv1d(in_channels=nof_features, out_channels=nof_features*3, kernel_size=1, groups=nof_features)
conv2 = nn.Conv1d(in_channels=nof_features, out_channels=nof_features*3, kernel_size=3, groups=nof_features)
conv3 = nn.Conv1d(in_channels=nof_features, out_channels=nof_features, kernel_size=3, groups=nof_features)

outputv2 = conv1(input.unsqueeze(2)).view(nof_samples, nof_features, 3)
outputv2 = conv2(outputv2).view(nof_samples, nof_features, 3)
outputv2 = conv3(outputv2).squeeze(2)

outputv2 # shape of [1, 2]
# nof trainable params: 44
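
As a quick sanity check, the parameter counts in the comments above can be reproduced with something like:

nof_params_linear = sum(p.numel() for net in feature_networks for p in net.parameters())
nof_params_conv = sum(p.numel() for conv in (conv1, conv2, conv3) for p in conv.parameters())
print(nof_params_linear, nof_params_conv)  # 44 44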

Are output and outputv2 the same (assuming the weights are the same)? Or do I have to do it differently?

Yes, your approach seems to be valid based on this quick check, which loads the trainable parameters of the linear layers into the grouped convolutions:

import torch
import torch.nn as nn

nof_samples = 1
nof_features = 2
input = torch.rand(nof_samples, nof_features)

bias = True

feature_networks = nn.ModuleList()
for _ in range(nof_features):
    feature_networks.append(
        nn.Sequential(
            nn.Linear(in_features=1, out_features=3, bias=bias),
            nn.Linear(in_features=3, out_features=3, bias=bias),
            nn.Linear(in_features=3, out_features=1, bias=bias)
        ))  

output = torch.zeros(nof_samples, nof_features)
for feature_index in range(nof_features):  # Iterate over features
    output[:, feature_index] = feature_networks[feature_index](input[:, feature_index])


conv1 = nn.Conv1d(in_channels=nof_features, out_channels=nof_features*3, kernel_size=1, groups=nof_features, bias=bias)
conv2 = nn.Conv1d(in_channels=nof_features, out_channels=nof_features*3, kernel_size=3, groups=nof_features, bias=bias)
conv3 = nn.Conv1d(in_channels=nof_features, out_channels=nof_features, kernel_size=3, groups=nof_features, bias=bias)

with torch.no_grad():
    # For each layer, stack the two per-feature weight matrices along the
    # output-channel dimension (feature 0 first, then feature 1), matching
    # the channel layout of the grouped convolutions; unsqueeze(1) adds the
    # in_channels/groups dimension expected by nn.Conv1d.
    for layer_index, conv in enumerate([conv1, conv2, conv3]):
        weight = torch.cat([net[layer_index].weight for net in feature_networks])
        print(weight.shape, conv.weight.shape)
        conv.weight.copy_(weight.unsqueeze(1))

        # conv.bias is None when bias=False (the attribute always exists,
        # so a hasattr check would always pass)
        if conv.bias is not None:
            bias_values = torch.cat([net[layer_index].bias for net in feature_networks])
            print(bias_values.shape, conv.bias.shape)
            conv.bias.copy_(bias_values)


outputv2 = conv1(input.unsqueeze(2)).view(nof_samples, nof_features, 3)
outputv2 = conv2(outputv2).view(nof_samples, nof_features, 3)
outputv2 = conv3(outputv2).squeeze(2)


print(output - outputv2)
# tensor([[ 0.0000e+00, -2.9802e-08]], grad_fn=<SubBackward0>)

I haven’t checked the reshape operations in detail, but since the numerical mismatch is tiny (on the order of float32 precision), I assume the approach is valid.
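
To make the reshape argument more explicit, here is a minimal sanity check (my own addition, reusing the modules from above after the weight copy): the .view moves each group’s three output channels into the row of the corresponding feature, so row feature_index should match the output of that feature’s first Linear layer.

with torch.no_grad():
    x = torch.rand(nof_samples, nof_features)
    # [1, 6, 1] -> [1, 2, 3]: channels 0-2 (group 0) become row 0,
    # channels 3-5 (group 1) become row 1
    hidden_v2 = conv1(x.unsqueeze(2)).view(nof_samples, nof_features, 3)
    for feature_index in range(nof_features):
        hidden_ref = feature_networks[feature_index][0](x[:, feature_index:feature_index + 1])
        print(torch.allclose(hidden_ref, hidden_v2[:, feature_index]))  # True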

Thank you for your help!

Next, I will have to find out whether it makes sense to replace the linear layers in a neural additive model with convolutional layers as proposed :smile: The forward pass is definitely faster, but I need more epochs to converge - I believe this is because the backward pass through the grouped convolutions is not as specific to the individual feature networks.
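
For anyone who wants to reproduce the forward-pass comparison, a rough timing sketch (the sizes and iteration count below are arbitrary choices, not from an actual benchmark) could look like this:

import time

import torch
import torch.nn as nn

nof_samples, nof_features = 256, 64  # arbitrary example sizes
x = torch.rand(nof_samples, nof_features)

feature_networks = nn.ModuleList([
    nn.Sequential(nn.Linear(1, 3), nn.Linear(3, 3), nn.Linear(3, 1))
    for _ in range(nof_features)])

conv1 = nn.Conv1d(nof_features, nof_features * 3, kernel_size=1, groups=nof_features)
conv2 = nn.Conv1d(nof_features, nof_features * 3, kernel_size=3, groups=nof_features)
conv3 = nn.Conv1d(nof_features, nof_features, kernel_size=3, groups=nof_features)

def loop_forward(x):
    out = torch.zeros_like(x)
    for i in range(nof_features):
        # keep the trailing dim so the Linear layers also accept batches > 1
        out[:, i] = feature_networks[i](x[:, i:i + 1]).squeeze(1)
    return out

def conv_forward(x):
    n = x.shape[0]
    h = conv1(x.unsqueeze(2)).view(n, nof_features, 3)
    h = conv2(h).view(n, nof_features, 3)
    return conv3(h).squeeze(2)

with torch.no_grad():
    for name, forward in [("loop", loop_forward), ("conv", conv_forward)]:
        start = time.perf_counter()
        for _ in range(100):
            forward(x)
        print(name, time.perf_counter() - start)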