Depthwise 1D convolution with shared filter

I want to apply a depthwise Conv1d to text, and all of the filters should share the same weights. Is there a way to do this?

Could you explain the use case a bit more? Repeating a kernel would yield the same outputs, so I’m not sure if I understand the question correctly.

Hi, @ptrblck! Thanks for taking an interest in this question.

I’m doing a multi-label classification task, and the label space contains about 8,900 labels. The classifier needs to predict which labels the input text corresponds to (generally, an input text corresponds to 5–10 labels).

As described in this paper: https://arxiv.org/pdf/1909.11386.pdf, I’m going to apply a per-label mask to the input embedding [B, L, D] (batch size, input length, embedding dimension). After masking, the embedding becomes a 4-D tensor [B, L, T, D], where T is the number of labels. Then I will apply a convolution.

The original paper suggests that all label embeddings share the same convolution layer, which means every label’s embedding should be convolved with the same weights. For simplicity, we could flatten the 4-D tensor along the embedding dimension so that it has the shape [B, L, T*D], which is suitable for a depthwise convolution; a small sketch of this step is below.
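
Just to make the shapes concrete, here is a minimal sketch of the masking and flattening step with made-up sizes (B, L, T, D and the random mask are only placeholders for illustration):

import torch

B, L, T, D = 2, 16, 4, 8            # toy sizes; T would be ~8900 in my task

x = torch.randn(B, L, D)            # input embeddings
mask = torch.rand(B, L, T)          # per-label masks (however they are computed)

# broadcast the mask over the embedding dimension -> [B, L, T, D]
masked = x.unsqueeze(2) * mask.unsqueeze(-1)

# flatten the label dimension into the channels and move channels first for Conv1d
flat = masked.reshape(B, L, T * D).permute(0, 2, 1)
print(flat.shape)                   # torch.Size([2, 32, 16])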

However, if we directly use a grouped 1-D convolution, there will be a unique filter bank for each label embedding, i.e. 8,900 different filter banks in total, which would be a disaster for GPU memory. I’m wondering whether there is a way to make the filters share the same parameters.
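
For a rough sense of the scale (the sizes below are only assumptions for illustration), the weight of a grouped nn.Conv1d with one filter bank per label grows linearly with the number of labels, while a shared filter bank does not:

T, D, filter_maps, kernel_size = 8900, 100, 50, 5   # hypothetical sizes

# grouped nn.Conv1d(T*D, filter_maps*T, kernel_size, groups=T): one filter bank per label
per_label_weights = (filter_maps * T) * D * kernel_size
# a single filter bank shared by every label
shared_weights = filter_maps * D * kernel_size

print(per_label_weights)   # 222500000 weights (~890 MB in float32)
print(shared_weights)      # 25000 weights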

Does your input shape [B, L, T*D] correspond to a channels-last memory format, i.e. would T*D represent the channels? If so, you would have to permute the data, but each kernel would still use all input channels in the default layout.
Could you post the input shape and the desired output shape (with a description of what the temporal and channel dimensions would be), please?

Yes, T*D represents the channels, so the convolution input would have the shape [B, T*D, L]. When initializing the convolution layer, it would be layer = nn.Conv1d(T*D, filter_maps*T, kernel_size, groups=T).

This setup:

T, D, filter_maps = 2, 3, 4
kernel_size = 5
layer = nn.Conv1d(T*D, filter_maps*T, kernel_size, groups=T)
print(layer.weight.shape)
> torch.Size([8, 3, 5])

would use 8 filters (defined by filter_maps*T), each of which uses 3 input channels (defined by T*D and the groups). Could you explain a bit more which filters should now be shared?
Would you like to use a single output filter only?


The code you wrote above is exactly what I mean. Sorry for not explaining clearly.
Well, as you can see, with groups=T there are T groups of filters. What I want is for all the filters to be initialized with the same weights (i.e. to share the weights), and for gradient descent to keep updating them to the same values.

You could certainly initialize all 8 filters to the same values, either by directly setting them:

T, D, filter_maps = 2, 3, 4
kernel_size = 5
layer = nn.Conv1d(T*D, filter_maps*T, kernel_size, groups=T)

with torch.no_grad():
    # copy the first filter into all 8 (= filter_maps*T) filters
    ref = layer.weight[0:1]
    layer.weight.copy_(ref.repeat(8, 1, 1))

print(layer.weight)

or by using the functional API and stacking the filters.
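
For example, here is a rough sketch of that functional approach (reusing the toy sizes from above; the key point is that a single parameter is repeated inside the forward pass):

import torch
import torch.nn.functional as F

T, D, filter_maps, kernel_size = 2, 3, 4, 5

# a single shared filter bank and bias, registered once
weight = torch.nn.Parameter(torch.randn(filter_maps, D, kernel_size))
bias = torch.nn.Parameter(torch.randn(filter_maps))

x = torch.randn(2, T * D, 24)
out = F.conv1d(x, weight.repeat(T, 1, 1), bias.repeat(T), groups=T)
print(out.shape)   # torch.Size([2, 8, 20])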

Unfortunately, just initializing the filters to the same values won’t work directly, since you are using 2 groups: even though all filters start out equal, they use different input channels and would thus also create different outputs and gradients:

x = torch.randn(2, 6, 24)   # batch of 2, T*D = 6 channels, sequence length 24
out = layer(x)
out.mean().backward()
print(layer.weight.grad)
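
For example, comparing the filters of the first and the second group after this backward call should show that their gradients already differ (continuing the code above):

# filters 0..3 belong to the first group, filters 4..7 to the second group
grad_group0 = layer.weight.grad[:filter_maps]
grad_group1 = layer.weight.grad[filter_maps:]
print(torch.allclose(grad_group0, grad_group1))   # False (almost surely, with random inputs)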

Hey @ptrblck, I came up with an idea: we can initialize the weights and biases of the different groups with the same values and keep updating them to the same values during training as well.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # some other codes (these set self.embed_size, self.num_filter_maps,
        # self.kernel_size, self.label_space, ...)

        # initialize the shared weights: a single filter bank + bias that is
        # repeated across all groups in forward()
        np.random.seed(123)

        a = np.random.randn(self.num_filter_maps, self.embed_size, self.kernel_size)
        self.cnn_weight = nn.Parameter(torch.from_numpy(a).float())

        b = np.random.randn(self.num_filter_maps)
        self.cnn_bias = nn.Parameter(torch.from_numpy(b).float())

        # linear output
        self.fc = nn.Linear(self.num_filter_maps, self.label_space)
        xavier_uniform_(self.fc.weight)

    def forward(self, x):
        # x: token ids of shape [batch_size, max_len]
        batch_size = x.shape[0]
        max_len = x.shape[1]
        with torch.no_grad():
            lengths = torch.count_nonzero(x, dim=-1).cpu()

        # some other codes: embedding lookup and per-label masking, producing
        # `masked` with shape [batch_size, 50 * self.embed_size, max_len]

        # repeat the shared filter/bias for all 50 groups; repeat() is
        # differentiable, so the gradients of every copy flow back into
        # self.cnn_weight / self.cnn_bias
        out = F.conv1d(masked,
                       self.cnn_weight.repeat(50, 1, 1),
                       self.cnn_bias.repeat(50),
                       padding=self.kernel_size // 2,
                       groups=50)

        # some other codes: pooling over `out`, then self.fc

        return something

and then during training we only update self.cnn_weight and self.cnn_bias instead of a full CNN layer’s worth of parameters.
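
A quick sanity check of the mechanics with toy sizes (the numbers here are only placeholders): because repeat() is differentiable, the backward pass sums the gradients of all 50 copies into the single shared parameter, so one optimizer step keeps every copy identical.

import torch
import torch.nn.functional as F

T, D, filter_maps, kernel_size = 50, 8, 4, 5    # toy sizes

cnn_weight = torch.nn.Parameter(torch.randn(filter_maps, D, kernel_size))
cnn_bias = torch.nn.Parameter(torch.randn(filter_maps))

x = torch.randn(2, T * D, 32)
out = F.conv1d(x, cnn_weight.repeat(T, 1, 1), cnn_bias.repeat(T),
               padding=kernel_size // 2, groups=T)
out.mean().backward()

# one gradient tensor for the shared filter bank, accumulated over all 50 groups
print(cnn_weight.grad.shape)   # torch.Size([4, 8, 5])
print(cnn_bias.grad.shape)     # torch.Size([4])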