I am using PyTorch and am still quite new to the library. I have a relation between my input and output given by y = ax + b, where a and b are sampled from some distribution (say Uniform), that is, they are random. I would like to train a network to predict x upon seeing y and a. I am employing a network, named probability_network, built from nn.Linear layers. There are N (say 10) classes to choose from for x.
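For reference, here is a minimal sketch of how my training data could be generated. The Uniform(0, 1) ranges and the elementwise product are my assumptions; x is looked up from a class-to-vector table like the mapping_tensor described further below.

import torch

# Hypothetical data generator; the Uniform(0, 1) ranges below are assumptions.
def generate_batch(batch_size, mapping_tensor):
    # Pick a random class index per sample, then look up its 4-length vector.
    classes = torch.randint(0, 10, (batch_size,))
    x = mapping_tensor[classes]       # shape: (batch_size, 4)
    a = torch.rand(batch_size, 4)     # a ~ Uniform(0, 1), elementwise
    b = torch.rand(batch_size, 4)     # b ~ Uniform(0, 1), elementwise
    y = a * x + b                     # y = ax + b, elementwise product
    return y, a, classes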
import torch
import torch.nn as nn

class ProbabilityNetwork(nn.Module):
    def __init__(self):
        super(ProbabilityNetwork, self).__init__()
        self.fc1 = nn.Linear(8, 76)    # input: y and a concatenated (4 + 4)
        self.fc2 = nn.Linear(76, 150)
        self.fc3 = nn.Linear(150, 75)
        self.fc4 = nn.Linear(75, 14)
        self.fc5 = nn.Linear(14, 10)   # output: one score per class
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, inputs):
        out = self.relu(self.fc1(inputs))
        out = self.relu(self.fc2(out))
        out = self.relu(self.fc3(out))
        out = self.fc4(out)            # no activation after fc4, as in my original code
        return self.softmax(self.fc5(out))
probability_network = ProbabilityNetwork()
Upon seeing y, the loss function should help the network predict an x that minimizes ||y - ax||^2. All the quantities in y = ax + b are vectors (each of length 4, in this example). I have already tried the following loss function.
prob_values = probability_network(torch.cat([y, a], dim=1))  # shape: (batch_size, 10)
x_hat = mapping_tensor[torch.argmax(prob_values, dim=1)]     # map the most probable class to its 4-length vector; mapping_tensor has shape (10, 4)
mse_loss = nn.MSELoss()
loss = mse_loss(y, a * x_hat)
As an example, the mapping_tensor could contain the binary representations of the values from 0 (0000) to 9 (1001). The reason I need the binary representation of the class is that I need a vector x for the loss ||y - ax||^2: here x is a length-4 vector, whereas the output of the neural network is a length-10 vector of class probabilities.
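For concreteness, here is a small sketch of how such a mapping_tensor could be built; the most-significant-bit-first ordering is my assumption.

import torch

# Hypothetical construction: row i holds the 4-bit binary representation of i,
# most significant bit first, e.g. row 9 is [1, 0, 0, 1].
mapping_tensor = torch.tensor(
    [[(i >> bit) & 1 for bit in (3, 2, 1, 0)] for i in range(10)],
    dtype=torch.float32,
)  # shape: (10, 4)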
The above setup doesn't work. Half of the bits in the predicted class (written out in binary) are always wrong, which implies the network is confused during training.
Further, this is not an unachievable problem. A solution to the above loss function exists (with errors, of course, but the errors are far less than 50%); it is simply computationally intensive. I am trying to check whether the network can somehow learn to make this prediction at lower complexity. Any help is appreciated. Thanks.
Furthermore, from an optimization standpoint, this loss function is the best solution that I know of, so changing the loss function would only lead to poorer results.
Another way to look at the problem is as follows. Suppose the network sees y and a. The network then computes ||y - ax||^2 for each of the 10 possible classes x and chooses the class with the smallest value. My question is: what loss function can I use to make the network train in this fashion? A sketch of that exhaustive search follows below.