Layer initialisation in a multi-subject, multi-label classification problem

Hi Everyone,

I am solving a multi-subject, multi-label classification problem with a large number of subjects and labels. To explain it, I will use image classification as an example.
Say that within a single image I have M subjects and N possible labels for each subject. My final output layer is a softmax applied to a matrix of dimensions M by N (where N > M). For the sake of the problem I am solving, I need the output of the softmax on the first forward pass (only the first!) to be an identity matrix, i.e., subject i is given label i. The way I force this now is to introduce a linear layer before the softmax that gives M * N outputs. I set the weights of this layer to zero and the bias to a flattened, scaled identity (in fact 10 * torch.eye(M, N); the sample code below uses a factor of 50). This layer has M * (M * N) weights, which clearly grows very large as M and N increase. For small values of M and N the idea works perfectly; however, as M and N increase, I run out of memory because of this inefficient approach. In my case, M = 600 to 800 and N = 2000 to 3000.
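
To make the memory growth concrete, here is a rough back-of-the-envelope count for a dense linear layer that maps M inputs to M * N outputs. It is only a sketch: it assumes float32 storage (4 bytes per value), and the helper name linear_param_bytes is made up for illustration.

```python
def linear_param_bytes(in_features, out_features, bytes_per_value=4):
    """Return (weight_bytes, bias_bytes) for a dense linear layer.

    Assumes every value is stored with `bytes_per_value` bytes
    (4 for float32). Hypothetical helper, for illustration only.
    """
    weight_values = in_features * out_features  # full weight matrix
    bias_values = out_features                  # one bias per output
    return weight_values * bytes_per_value, bias_values * bytes_per_value


# Toy size from the sample code, then the sizes mentioned in the post.
for M, N in [(4, 3), (600, 2000), (800, 3000)]:
    weight_bytes, bias_bytes = linear_param_bytes(M, M * N)
    print(f"M={M}, N={N}: weight {weight_bytes / 1e9:.3f} GB, "
          f"bias {bias_bytes / 1e6:.3f} MB")
```

Under these assumptions the weight matrix alone takes about 2.88 GB for M = 600, N = 2000 and about 7.68 GB for M = 800, N = 3000, before counting gradients or optimizer state. The bias, by contrast, stays in the range of a few megabytes, so the weight matrix is what dominates.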

I am wondering whether there is a more memory-efficient way to do this.

Sample code for the idea is as follows:

import torch
import torch.nn as nn

class network(nn.Module):
    def __init__(self, M, N, K, num_features):
        super(network, self).__init__()
        # M subjects, N labels per subject; stored so forward() can reshape
        self.M = M
        self.N = N
        self.conv1 = nn.Conv1d(num_features[0], num_features[1], kernel_size=1)
        self.conv2 = nn.Conv1d(num_features[1], num_features[2], kernel_size=1)

        self.fc1 = nn.Linear(num_features[2] * K, M)
        self.fc2 = nn.Linear(M, M * N)

        # initialize the weights and bias of fc2 to ensure that the output of the softmax
        # is the identity matrix at the first forward pass
        with torch.no_grad():
            self.fc2.weight.zero_()
            self.fc2.bias.copy_(50 * torch.eye(M, N, dtype=torch.float).reshape(-1))

        self.activation = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # normalise over the N labels of each subject

    def forward(self, image):
        x = self.activation(self.conv1(image))
        x = self.activation(self.conv2(x))
        x = x.reshape(1, -1)  # flatten; assumes a batch size of 1
        x = self.fc1(x)
        x = self.fc2(x).reshape(self.M, self.N)
        mult_class = self.softmax(x)

        return mult_class

Main Code

M = 3  # subjects
N = 4  # labels per subject (N > M, as assumed above)
K=6
num_features = [4,8,16]

image = torch.randn(1,num_features[0],K)
net = network(M,N, K, num_features)
mult_classes = net(image)
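
One possible direction, sketched under assumptions rather than offered as a definitive answer: keep the M * N trainable offset (it is small), but replace the dense M -> M * N weight with a factorised head. The network first produces D features per subject, and a single Linear(D, N) shared across subjects maps them to label logits. The names FactorisedHead, D, to_subjects, to_labels and offset are hypothetical, and this changes the model (the shared projection is less expressive than the full fc2), so it is a trade-off and not an equivalent rewrite.

```python
import torch
import torch.nn as nn


class FactorisedHead(nn.Module):
    """Hypothetical stand-in for fc1 + fc2 in the sample code above.

    Assumption: a shared Linear(D, N) applied to D features per subject
    is expressive enough for the task. D is an illustrative choice.
    """

    def __init__(self, in_features, M, N, D=16, scale=50.0):
        super().__init__()
        self.M, self.D = M, D
        # M * D outputs instead of M * N
        self.to_subjects = nn.Linear(in_features, M * D)
        # One projection shared by all subjects: only D * N weights
        self.to_labels = nn.Linear(D, N)
        # Zero-initialise so the first forward pass depends only on the offset
        nn.init.zeros_(self.to_labels.weight)
        nn.init.zeros_(self.to_labels.bias)
        # Trainable M x N offset, playing the role of the fc2 bias
        self.offset = nn.Parameter(scale * torch.eye(M, N))

    def forward(self, x):
        # x: flattened features of shape (1, in_features), batch size 1
        h = self.to_subjects(x).reshape(self.M, self.D)
        logits = self.to_labels(h) + self.offset  # shape (M, N)
        return torch.softmax(logits, dim=1)


# Small usage check with toy sizes
head = FactorisedHead(in_features=96, M=3, N=4)
out = head(torch.randn(1, 96))
print(out.shape)
```

With this layout the largest trainable tensors have in_features * M * D, D * N and M * N entries, so nothing scales with M * M * N. Because to_labels starts at zero, the first output is the softmax of the scaled identity alone, as in the original initialisation.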