Hi Everyone,
I am solving a multi-subject, multi-label classification problem with a large number of subjects and labels. To explain my problem, I will use image classification as an example.
Say that within a single image I have M subjects and N possible labels for each subject. My final output layer is a softmax applied to an M by N matrix (where N > M), one row per subject. For the problem I am solving, I need the first forward output of the softmax (only the first!!) to be (approximately) an identity matrix, i.e., subject i is given label i.
The way I currently force this is by introducing a linear layer before the softmax that produces M * N outputs. I set the weights of this layer to zero and the bias to a flattened, scaled identity (e.g. 10 * torch.eye(M, N); the sample code below uses 50). Because this layer maps M inputs to M * N outputs, its weight matrix alone has M * M * N entries, which grows very large as M and N increase. For small values of M and N this works perfectly; however, as M and N increase, I run out of memory. In my case, for example, M = 600 to 800 and N = 2000 to 3000.
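To make the memory problem concrete, here is a minimal back-of-the-envelope sketch that counts the parameters of `nn.Linear(M, M*N)`. The helper name `fc2_footprint` is hypothetical, and it assumes float32 (4 bytes per parameter). It counts only the weight and bias; gradients and optimizer state (for example, Adam keeps two extra state tensors per parameter) add further multiples on top.

```python
def fc2_footprint(M: int, N: int, bytes_per_param: int = 4) -> dict:
    """Parameter count and memory of nn.Linear(M, M*N): weight + bias only."""
    weight_params = (M * N) * M  # weight shape is (out_features, in_features) = (M*N, M)
    bias_params = M * N          # one bias entry per output feature
    total = weight_params + bias_params
    return {
        "weight_params": weight_params,
        "bias_params": bias_params,
        "total_params": total,
        "gib": total * bytes_per_param / 2**30,  # bytes converted to GiB
    }


for M, N in [(4, 3), (600, 2000), (800, 3000)]:
    f = fc2_footprint(M, N)
    print(f"M={M}, N={N}: {f['total_params']:,} params, ~{f['gib']:.2f} GiB in float32")
```

For M = 800 and N = 3000 this comes to about 1.92 billion parameters, roughly 7.2 GiB for the weight and bias alone, and nearly all of it is the weight matrix rather than the identity-shaped bias.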
I am wondering if there is a more efficient way to do this.
Here is sample code illustrating the idea:
```python
import torch
import torch.nn as nn


class network(nn.Module):
    def __init__(self, M, N, K, num_features):
        super(network, self).__init__()
        # M subjects
        # N labels per subject
        self.M, self.N = M, N
        self.conv1 = nn.Conv1d(num_features[0], num_features[1], kernel_size=1)
        self.conv2 = nn.Conv1d(num_features[1], num_features[2], kernel_size=1)
        self.fc1 = nn.Linear(num_features[2] * K, M)
        # weight shape is (M*N, M): this layer is the memory bottleneck
        self.fc2 = nn.Linear(M, M * N)
        # initialize the weights and bias of fc2 to ensure that the output of the softmax
        # is (approximately) the identity matrix at the first forward pass
        with torch.no_grad():
            self.fc2.weight.zero_()
            self.fc2.bias.copy_(50 * torch.eye(M, N, dtype=torch.float).flatten())
        self.activation = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # softmax over the N labels of each subject

    def forward(self, image):
        x = self.activation(self.conv1(image))
        x = self.activation(self.conv2(x))
        x = x.reshape(1, -1)  # (1, num_features[2], K) -> (1, num_features[2] * K)
        x = self.fc1(x)
        x = self.fc2(x).reshape(self.M, self.N)
        mult_class = self.softmax(x)  # shape (M, N)
        return mult_class
```
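Main code:

```python
# N > M, as in the problem description (with N < M, the extra rows of eye(M, N) are all zero)
M = 3
N = 4
K = 6
num_features = [4, 8, 16]
image = torch.randn(1, num_features[0], K)
net = network(M, N, K, num_features)
mult_classes = net(image)
```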
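Because fc2's weights start at zero, the first forward pass ignores the image and returns the row-wise softmax of the bias s * eye(M, N). Each row i (with M <= N) holds one logit equal to s and N - 1 logits equal to 0, so its diagonal entry is e^s / (e^s + N - 1) = 1 / (1 + (N - 1) e^(-s)). How close the output is to an identity therefore depends on both the scale s and N. The sketch below (hypothetical helper names, pure Python using only `math`) evaluates that closed form and solves for the smallest s reaching a target diagonal probability p, i.e. s = ln((N - 1) p / (1 - p)).

```python
import math


def first_pass_diag(s: float, N: int) -> float:
    """Diagonal entry of one row of softmax(s * eye(M, N), dim=1).

    One logit equals s and N - 1 logits equal 0, so the diagonal probability is
    e^s / (e^s + N - 1), written as 1 / (1 + (N - 1) * e^-s) to avoid overflow.
    """
    return 1.0 / (1.0 + (N - 1) * math.exp(-s))


def min_scale(N: int, p: float) -> float:
    """Smallest scale s with first_pass_diag(s, N) >= p, for 0 < p < 1.

    Solving 1 / (1 + (N - 1) e^-s) = p for s gives s = ln((N - 1) p / (1 - p)).
    """
    return math.log((N - 1) * p / (1.0 - p))


for N in (3, 2000, 3000):
    for s in (10.0, 50.0):
        print(f"N={N}, s={s}: diagonal = {first_pass_diag(s, N):.6f}")
    print(f"N={N}: scale needed for a 0.999 diagonal = {min_scale(N, 0.999):.2f}")
```

With s = 10 and N = 3000 the diagonal is only about 0.88, whereas s = 50 gives a diagonal of 1.0 in double precision; the required scale grows only logarithmically with N.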