Let’s say I have data “x” with shape [32, 3, 32, 32], and I split it channel-wise (along the second index) into 3 parts, each with shape [32, 1, 32, 32]. My current model has 3 depth-wise linear layers at the input, each of which takes a copy of the full data x. Now I want each of the 3 depth-wise layers to take one slice of the data, without having to hard-code the number of input channels. Let’s say self.cin1 = x1.shape[1], self.cin2 = x2.shape[1] and self.cin3 = x3.shape[1], where x1, x2, x3 are the 3 data splits. I split the data in the forward function as:

def forward(self, x):
    x1, x2, x3 = data_splitter(x)
    # one layer per slice (a single self.linear would share weights across all three)
    l1 = self.linear1(x1)
    l2 = self.linear2(x2)
    l3 = self.linear3(x3)
    …
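One possible way to avoid hard-coding the channel counts is PyTorch's lazy modules (e.g. nn.LazyConv2d), which leave in_channels unset and infer it from the first input they see. The sketch below is only an illustration under assumptions: it models the "depth-wise linear" layers as 1×1 convolutions, uses torch.chunk in place of the question's data_splitter, and SplitBranchNet, num_splits and out_channels are hypothetical names. A true depth-wise conv would need groups equal to the (not yet known) input channel count, so ordinary 1×1 convs are used instead.

```python
import torch
import torch.nn as nn


class SplitBranchNet(nn.Module):
    """Hypothetical model: one branch per channel slice; each branch
    infers its own number of input channels on the first forward pass."""

    def __init__(self, num_splits=3, out_channels=8):
        super().__init__()
        self.num_splits = num_splits
        # LazyConv2d defers in_channels until the first call, so each
        # branch adapts to whatever slice width it receives.
        # A 1x1 kernel is an assumed stand-in for a "depth-wise linear" layer.
        self.branches = nn.ModuleList(
            [nn.LazyConv2d(out_channels, kernel_size=1) for _ in range(num_splits)]
        )

    def forward(self, x):
        # Split along dim=1 (channels): [32, 3, 32, 32] -> three [32, 1, 32, 32]
        slices = torch.chunk(x, self.num_splits, dim=1)
        outs = [branch(s) for branch, s in zip(self.branches, slices)]
        # Concatenate the branch outputs back along the channel dimension.
        return torch.cat(outs, dim=1)


model = SplitBranchNet()
x = torch.randn(32, 3, 32, 32)
y = model(x)
print(y.shape)                                   # torch.Size([32, 24, 32, 32])
print([b.in_channels for b in model.branches])   # [1, 1, 1]
```

If the slices have unequal widths, torch.split(x, [c1, c2, c3], dim=1) can replace torch.chunk. Because lazy parameters are only created on the first call, it is usually advisable to run one dummy batch through the model before constructing the optimizer.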