I am solving a multi-subject, multi-label classification problem with a large number of subjects and labels. To explain my problem, I will use image classification as an example.
Say that within a single image I have M subjects and N different labels for each subject. My final output layer is a softmax that takes an M × N matrix (where N > M). For the problem I am solving, I need the softmax output of the first forward pass (only the first!) to be an identity matrix, i.e., subject i is given label i.

The way I currently force this is by introducing a linear layer before the softmax that produces M * N outputs. I set the weights of this layer to zero and the bias to a flattened, scaled identity (e.g., 10 * torch.eye(M, N); the code below uses 50). Because this layer takes M inputs, its weight matrix alone has M * (M * N) entries, which grows very large as M and N increase. For small values of M and N this idea works perfectly; however, as M and N increase, I run out of memory because of this inefficiency. In my case, M = 600 to 800 and N = 2000 to 3000.
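To put numbers on the memory problem, here is a small pure-Python sketch (no PyTorch needed). It counts the parameters of `nn.Linear(M, M*N)` at the upper end of the quoted sizes and computes the diagonal softmax probability of one row of `scale * eye(M, N)`. The helper names `fc2_param_count` and `diag_softmax_prob` are hypothetical, and float32 storage (4 bytes per parameter) is assumed.

```python
import math


def fc2_param_count(M, N):
    """Parameters of nn.Linear(M, M*N): an (M*N) x M weight plus an M*N bias."""
    return M * (M * N) + M * N


def diag_softmax_prob(scale, N):
    """Softmax probability of the diagonal entry in one row of scale * eye(M, N):
    one logit equals `scale`, the other N - 1 logits are zero."""
    return math.exp(scale) / (math.exp(scale) + (N - 1))


M, N = 800, 3000  # upper end of the sizes quoted above
params = fc2_param_count(M, N)
print(params)                              # 1922400000
print(round(params * 4 / 1e9, 2))          # 7.69 (GB in float32, weights only)
print(round(diag_softmax_prob(10, N), 3))  # 0.88
print(diag_softmax_prob(50, N))            # 1.0
```

At these sizes the layer holds about 1.9 billion parameters (roughly 7.7 GB in float32, before gradients or optimizer state). The softmax helper also shows why a scale of 10 may not be enough here: with N = 3000 the diagonal entry is only about 0.88, whereas a scale of 50 makes it 1.0 to double precision.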
I am wondering if there is a more efficient way to do this.
Here is sample code illustrating the idea:
```python
import torch
import torch.nn as nn


class network(nn.Module):
    def __init__(self, M, N, K, num_features):
        super().__init__()
        self.M = M  # number of subjects
        self.N = N  # number of labels per subject
        self.conv1 = nn.Conv1d(num_features, num_features, kernel_size=1)
        self.conv2 = nn.Conv1d(num_features, num_features, kernel_size=1)
        self.fc1 = nn.Linear(num_features * K, M)
        self.fc2 = nn.Linear(M, M * N)
        # Initialize the weights and bias of fc2 so that the softmax output
        # of the first forward pass is (approximately) the identity matrix:
        # zero weights, and a bias equal to a scaled, flattened identity.
        with torch.no_grad():
            self.fc2.weight.zero_()
            self.fc2.bias.copy_(50 * torch.eye(M, N).flatten())
        self.activation = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # softmax over the N labels of each subject

    def forward(self, image):
        x = self.activation(self.conv1(image))
        x = self.activation(self.conv2(x))
        x = x.reshape(1, -1)
        x = self.fc1(x)
        x = self.fc2(x).reshape(self.M, self.N)
        return self.softmax(x)


M, N, K = 6, 10, 5  # small example sizes; the real ones are M = 600-800, N = 2000-3000
for num_features in [4, 8, 16]:
    image = torch.randn(1, num_features, K)
    net = network(M, N, K, num_features)
    mult_classes = net(image)  # M x N, approximately torch.eye(M, N)
```