Unable to Learn the XOR Representation Using a 2-Layer Multi-Layer Perceptron (MLP)

Using a PyTorch nn.Sequential model, I’m unable to learn all four cases of the XOR function:

import numpy as np

import torch
from torch import nn
from torch.autograd import Variable
from torch import FloatTensor
from torch import optim

use_cuda = torch.cuda.is_available()

X = xor_input = np.array([[0,0], [0,1], [1,0], [1,1]])
Y = xor_output = np.array([[0,1,1,0]]).T

# Converting the X to PyTorch-able data structure.
X_pt = Variable(FloatTensor(X))
X_pt = X_pt.cuda() if use_cuda else X_pt
# Converting the Y to PyTorch-able data structure.
Y_pt = Variable(FloatTensor(Y), requires_grad=False)
Y_pt = Y_pt.cuda() if use_cuda else Y_pt

input_dim = 2
hidden_dim = 5
output_dim = 1

model = nn.Sequential(nn.Linear(input_dim, hidden_dim),
                      nn.Linear(hidden_dim, output_dim),
                      nn.Sigmoid())
criterion = nn.L1Loss()
learning_rate = 0.03
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
num_epochs = 10000

for _ in range(num_epochs):
    predictions = model(X_pt)
    loss_this_epoch = criterion(predictions, Y_pt)
    loss_this_epoch.backward()
    optimizer.step()
    print([int(_pred > 0.5) for _pred in predictions], list(map(int, Y_pt)), loss_this_epoch.data[0])

After learning:

for _x, _y in zip(X_pt, Y_pt):
    prediction = model(_x)
    print('Input:\t', list(map(int, _x)))
    print('Pred:\t', int(prediction))
    print('Output:\t', int(_y))
    print('######')

[out]:

Input:	 [0, 0]
Pred:	 0
Output:	 0
######
Input:	 [0, 1]
Pred:	 1
Output:	 1
######
Input:	 [1, 0]
Pred:	 0
Output:	 1
######
Input:	 [1, 1]
Pred:	 0
Output:	 0
######

I’ve tried running the same code with a couple of random seeds, but it didn’t manage to learn all four XOR cases.

Without PyTorch, I could easily train a model with hand-defined derivative functions and manually performed backpropagation; see https://www.kaggle.io/svf/2342536/635025ecf1de59b71ea4fa03eb84f9f9/results.html#After-some-enlightenment

Why does the 2-layer MLP in PyTorch fail to learn the XOR representation?

Also asked on https://stackoverflow.com/questions/48619928/unable-to-learn-xor-representation-using-2-layers-of-multi-layered-perceptron-m

I think the main reason your model does not learn the XOR problem is that it’s a linear model.
Try adding a non-linearity between the two linear layers. You could also change the loss to nn.MSELoss, which might just learn better.
Have a look at this chapter from the Deep Learning book.
Goodfellow et al. explain this problem pretty clearly, and they also mention why a linear network won’t be able to learn the representation.

This code should work for you:

model = nn.Sequential(nn.Linear(input_dim, hidden_dim),
                      nn.ReLU(),
                      nn.Linear(hidden_dim, output_dim),
                      nn.Sigmoid())
model.cuda()
criterion = nn.MSELoss()
learning_rate = 1e-3
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
num_epochs = 10000

You might test some seeds, since this problem is sometimes a bit sensitive to the initializations.
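
For example (a minimal sketch, not part of the original reply), you can fix the seeds so the weight initialization is reproducible across runs:

torch.manual_seed(42)  # illustrative seed value
if use_cuda:
    torch.cuda.manual_seed_all(42)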


Actually, what does it mean when I stack 2 linear layers with nn.Sequential(nn.Linear, nn.Linear) in PyTorch? Wouldn’t that make the learning already non-linear?

If I stack 2 linear layers with a sigmoid activation, wouldn’t that make it non-linear? Or is it specific to PyTorch that two stacked linear layers somehow collapse into one and learn a single linear function mapping the input to the output, without a hidden layer?

If we look at https://www.kaggle.com/alvations/xor-with-mlp (at the end of the notebook, without PyTorch), simply stacking two layers of perceptrons and using a non-absolute L1 loss allows non-linearity to set in, and the network learns to represent the XOR function.

Also, worked out by hand, the math checks out for the two-layer perceptron: http://www.aclweb.org/anthology/S16-1148

Wouldn’t the non-linearity come directly from the fully connected layers between the hidden dimension and the input/output? There shouldn’t be a need to force a non-linear activation after the hidden layer.

No, composition of linear functions is still linear. Imagine your first linear function represented by a matrix W, and a second linear function represented by matrix V. Then, f(x) = V * (W * x) = (V * W) * x, which is just a linear function over x.
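
Here is a minimal sketch that checks this numerically (not from the original reply, and written with current PyTorch tensors rather than the Variable API used above):

import torch
from torch import nn

torch.manual_seed(0)

# Two stacked linear layers with no activation in between.
f1 = nn.Linear(2, 5)
f2 = nn.Linear(5, 1)

# Collapse them into a single equivalent linear layer:
# f2(f1(x)) = V(Wx + b) + c = (V W)x + (V b + c)
combined = nn.Linear(2, 1)
with torch.no_grad():
    combined.weight.copy_(f2.weight @ f1.weight)
    combined.bias.copy_(f2.weight @ f1.bias + f2.bias)

x = torch.randn(4, 2)
print(torch.allclose(f2(f1(x)), combined(x), atol=1e-6))  # True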

In fact, in the example you linked, they do:

layer1 = sigmoid(np.dot(layer0, W1))
layer2 = sigmoid(np.dot(layer1, W2))

which introduces non-linearity between the layers.

Ah yes, sigmoid needs to be in between the Linear layers.

But even when I added the sigmoid activation:


from collections import Counter 
import random
random.seed(100)

import torch
from torch import nn
from torch.autograd import Variable
from torch import FloatTensor
from torch import optim
use_cuda = torch.cuda.is_available()


all_results=[]

for _ in range(100):
    hidden_dim = 2

    model = nn.Sequential(nn.Linear(input_dim, hidden_dim),
                          nn.Sigmoid(), # Does the sigmoid have a built-in bias?
                          nn.Linear(hidden_dim, output_dim),
                          nn.Sigmoid())

    criterion = nn.MSELoss()
    learning_rate = 0.03
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    num_epochs = 3000

    for _ in range(num_epochs):
        predictions = model(X_pt)
        loss_this_epoch = criterion(predictions, Y_pt)
        loss_this_epoch.backward()
        optimizer.step()
        ##print([float(_pred) for _pred in predictions], list(map(int, Y_pt)), loss_this_epoch.data[0])

    x_pred = [int(model(_x)) for _x in X_pt]
    y_truth = list([int(_y[0]) for _y in Y_pt])
    all_results.append([x_pred == y_truth, x_pred, loss_this_epoch.data[0]])


tf, outputsss, losses__ = zip(*all_results)
print(Counter(tf))

I’m getting a success rate of only 37 out of 100 runs where the model learns XOR. Any ideas?

My suspicion is that L1Loss is not differentiable. But in the NumPy code, the loss is taken as the plain difference without the absolute value, so it could still be back-propagated.

Reference: http://christopher5106.github.io/deep/learning/2016/09/16/about-loss-functions-multinomial-logistic-logarithm-cross-entropy-square-errors-euclidian-absolute-frobenius-hinge.html

But then again, I get a similarly poor result when I change the criterion to SmoothL1Loss: a 19 out of 100 success rate…

L1Loss is differentiable everywhere except at 0.
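
For illustration (a minimal sketch, not from the original posts, using current PyTorch rather than the old Variable API), the gradient of L1Loss with respect to the prediction is just the sign of the error, scaled by 1/N under the default mean reduction:

import torch
from torch import nn

pred = torch.tensor([0.9, 0.2], requires_grad=True)
target = torch.tensor([1.0, 0.0])

loss = nn.L1Loss()(pred, target)  # mean(|pred - target|)
loss.backward()
print(pred.grad)  # tensor([-0.5000,  0.5000]) == sign(pred - target) / 2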

I notice you don’t reset the gradients on each epoch. You need to run optimizer.zero_grad() before running loss_this_epoch.backward() otherwise the gradients for each epoch just get added to the gradients from the previous epoch.

That messes with the training.

With this architecture:

model = nn.Sequential(nn.Linear(input_dim, hidden_dim),
                      nn.ReLU(), 
                      nn.Linear(hidden_dim, output_dim),
                      nn.Sigmoid())

and training the model with:

for _ in range(num_epochs):
    optimizer.zero_grad()
    predictions = model(X_pt)
    loss_this_epoch = criterion(predictions, Y_pt)
    loss_this_epoch.backward()
    optimizer.step()

leads to a 0 out of 100 success rate in representing XOR =(

I think it is really sensitive to the weight initialisations. Sometimes it succeeds in a few hundred epochs, but most times it converges to a partial solution and gets one case wrong.

If I increase the hidden_dim to 10, then it almost always succeeds, but if I set hidden_dim == 9 then the number of times it fails to learn all four cases increases considerably.

I tested this code…

import numpy as np

import torch
from torch import nn
from torch.autograd import Variable
from torch import FloatTensor
from torch import optim

use_cuda = torch.cuda.is_available()

X = xor_input = np.array([[0,0], [0,1], [1,0], [1,1]])
Y = xor_output = np.array([[0,1,1,0]]).T

# Converting the X to PyTorch-able data structure.
X_pt = Variable(FloatTensor(X))
X_pt = X_pt.cuda() if use_cuda else X_pt

# Converting the Y to PyTorch-able data structure.
Y_pt = Variable(FloatTensor(Y), requires_grad=False)
Y_pt = Y_pt.cuda() if use_cuda else Y_pt

input_dim = 2
hidden_dim = 5
output_dim = 1

model = nn.Sequential(nn.Linear(input_dim, hidden_dim),
                      nn.ReLU(),
                      nn.Linear(hidden_dim, output_dim),
                      nn.Sigmoid())
criterion = nn.L1Loss()
learning_rate = 0.03
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
num_epochs = 1000

try:
    for epoch in range(num_epochs):
        predictions = model(X_pt)
        loss_this_epoch = criterion(predictions, Y_pt)
        if loss_this_epoch.data[0] < 1e-8:
            break
        loss_this_epoch.backward()
        optimizer.step()
        print(epoch, [int(_pred > 0.5) for _pred in predictions], list(map(int, Y_pt)), loss_this_epoch.data[0])
except KeyboardInterrupt:
    pass

for _x, _y in zip(X_pt, Y_pt):
    prediction = model(_x)
    print('Input:\t', list(map(int, _x)))
    print('Pred:\t', int(prediction))
    print('Output:\t', int(_y))
    print('######')

As others have noted, the non-linearity between the two linear layers is essential. The composition of two linear layers is linear, but XOR cannot be computed with a linear function.

To get your network to learn XOR every time, you can do the following:

  1. Change your criterion to MSELoss. Unless you know exactly why you need it, don’t use L1Loss, since it usually requires more careful tuning of hyper-parameters.

  2. Change your optimizer to Adam. SGD requires more careful tuning of hyper-parameters and weight initialization. Adam with a learning rate between 0.001 and 0.0001 generally works well; in this case, I found that 0.001 converges fast. (A sketch of this setup follows below.)
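
A minimal sketch of that configuration, reusing the data and dimensions defined earlier in the thread (the loop below is illustrative, not the exact code I ran):

model = nn.Sequential(nn.Linear(input_dim, hidden_dim),
                      nn.ReLU(),
                      nn.Linear(hidden_dim, output_dim),
                      nn.Sigmoid())
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for _ in range(num_epochs):
    optimizer.zero_grad()               # reset gradients each step
    predictions = model(X_pt)
    loss = criterion(predictions, Y_pt)
    loss.backward()
    optimizer.step()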

Bug…

int(prediction) will equal 0 for prediction > .5 and < 1.

You probably want int(prediction > 0.5) here.
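
To illustrate with current PyTorch (not from the original post):

p = torch.tensor(0.7)   # a "positive" sigmoid output
print(int(p))           # 0 -- int() truncates toward zero
print(int(p > 0.5))     # 1 -- threshold first, then cast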

I ran some tests…
Notes…

  • Each experiment is run for 100 trials
  • Each trial runs for a max of 1000 epochs, because training generally stagnates before then.
  • “Successes” is the number of trials that succeeded in solving XOR
  • Unless otherwise stated, each experiment uses the following settings:
    • hidden_dim 5
    • activation between linear layers ReLU
    • final activation Sigmoid
    • loss L1Loss
    • learning_rate 0.03
    • optimizer SGD

Is MSELoss better?

loss L1Loss
38 / 100 = 38.00% successes
loss MSELoss
78 / 100 = 78.00% successes

Conclusions: MSELoss is better for this model and this data.

Is Adam(lr=0.001) better than SGD(lr=0.03)?

loss L1Loss
learning_rate 0.001
optimizer Adam
22 / 100 = 22.00% successes
2 / 100 = 2.00% all zero predictions
loss MSELoss
learning_rate 0.001
optimizer Adam
16 / 100 = 16.00% successes

Conclusion: for both L1Loss and MSELoss, Adam(lr=0.001) is quite a lot worse than SGD(lr=0.03).

What about using Sigmoid instead of ReLU?

Motivation: ReLU is piecewise linear, so maybe it doesn’t help that much.

activation Sigmoid
loss L1Loss
18 / 100 = 18.00% successes
activation Sigmoid
loss MSELoss
87 / 100 = 87.00% successes

Conclusion: Sigmoid activation between the linear layers improves the model when using MSELoss, but worsens the model when using L1Loss.

What about ELU?

activation ELU
loss L1Loss
18 / 100 = 18.00% successes
activation ELU
loss MSELoss
100 / 100 = 100.00% successes

Wow! ELU activation between the layers is really beneficial when using MSELoss, but somewhat harmful when using L1Loss.

Any other ideas?

For reference: here is my ugly/hacky code

import numpy as np

import argparse

import torch
from torch import nn
from torch.autograd import Variable
from torch import FloatTensor
from torch import optim

use_cuda = torch.cuda.is_available()

X = xor_input = np.array([[0,0], [0,1], [1,0], [1,1]])
Y = xor_output = np.array([[0,1,1,0]]).T

# Converting the X to PyTorch-able data structure.
X_pt = Variable(FloatTensor(X))
X_pt = X_pt.cuda() if use_cuda else X_pt

# Converting the Y to PyTorch-able data structure.
Y_pt = Variable(FloatTensor(Y), requires_grad=False)
Y_pt = Y_pt.cuda() if use_cuda else Y_pt

input_dim = 2
output_dim = 1

parser = argparse.ArgumentParser(description='PyTorch example')
parser.add_argument('--loss', type=str, default="L1Loss")
parser.add_argument('--hidden_dim', type=int, default=5)
parser.add_argument('--learning_rate', type=float, default=0.03)
parser.add_argument('--activation', type=str, default="ReLU")
parser.add_argument('--optimizer', type=str, default="SGD")
args = parser.parse_args()
print(args)

hidden_dim = args.hidden_dim
criterion = getattr(nn, args.loss)()
learning_rate = args.learning_rate
activation_between_linear_layers = getattr(nn, args.activation)
optimizer_class = getattr(optim, args.optimizer)

num_epochs = 1000
num_trials = 100
count_successes = 0
count_all_zeros = 0

try:
    for trial in range(num_trials):
        print("Trial", trial + 1, end="\r")
        model = nn.Sequential(nn.Linear(input_dim, hidden_dim),
                              activation_between_linear_layers(),
                              nn.Linear(hidden_dim, output_dim),
                              nn.Sigmoid())
        optimizer = optimizer_class(model.parameters(), lr=learning_rate)

        try:
            for epoch in range(num_epochs):
                predictions = model(X_pt)
                loss_this_epoch = criterion(predictions, Y_pt)
                if loss_this_epoch.data[0] < 1e-8:
                    break
                loss_this_epoch.backward()
                optimizer.step()
        except KeyboardInterrupt:
            raise KeyboardInterrupt()

        success = 0
        all_zero = True
        for _x, _y in zip(X_pt, Y_pt):
            prediction = model(_x)
            success += 25*(int(prediction > .5)==_y).data[0]
            all_zero = all_zero and int(prediction <= .5)
        count_successes += (success == 100)
        count_all_zeros += all_zero
except KeyboardInterrupt:
    pass
print()
print("hidden_dim", args.hidden_dim)
print("activation", args.activation)
print("loss", args.loss)
print("learning_rate", args.learning_rate)
print("optimizer", args.optimizer)
trial += 1
print("%d / %d = %.2f%% successes" % (count_successes, trial, 100.*count_successes/trial))
print("%d / %d = %.2f%% all zero predictions" % (count_all_zeros, trial, 100.*count_all_zeros/trial))

@jpeg729, you’re missing optimizer.zero_grad() at the start of the training loop. The results will change quite a bit if you add it.

python xor.py --loss MSELoss --learning_rate 0.001 --activation Tanh --optimizer Adam
100 / 100 = 100.00% successes

python xor.py --loss MSELoss --learning_rate 0.001 --activation Sigmoid --optimizer Adam
99 / 100 = 99.00% successes

python xor.py --loss MSELoss --learning_rate 0.001 --activation ReLU --optimizer Adam
77 / 100 = 77.00% successes

Oh. Oops. And I’m the one who pointed out that it was missing in the original code. :thinking:

I tried again with ELU and SGD, and got 2% success. Tuning the learning rate does nothing much, but adding momentum=.9 gives 98% success with SGD.
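
In code, that is just the following one-liner (lr=0.03 as elsewhere in the thread):

optimizer = optim.SGD(model.parameters(), lr=0.03, momentum=0.9)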

I suppose all this goes to show the vast number of ways we can alter our models and get widely differing results.

That’s probably because not zeroing the gradient has the same effect as adding momentum: it retains some of the gradient from the previous training step.
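
A minimal sketch (not from the original post) showing how .grad accumulates across backward() calls unless it is reset:

import torch

w = torch.zeros(1, requires_grad=True)
for step in range(3):
    loss = (w - 1.0).pow(2).sum()
    loss.backward()             # adds to w.grad instead of overwriting it
    print(step, w.grad.item())  # -2.0, -4.0, -6.0 without zero_grad()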

Exactly what I thought. The fact that not zeroing the gradients actually worked suggests that with ELU activation and MSELoss the problem was nicely convex, so that repeatedly taking steps in roughly the same direction nearly always led to the solution.

It’s pretty much overkill, but here’s a sweep over Activation × Loss × Optimizer:


from itertools import product
from collections import Counter 

import time

import random
random.seed(100)

import numpy as np

import torch
from torch import nn
from torch.autograd import Variable
from torch import FloatTensor
from torch import optim
use_cuda = torch.cuda.is_available()

# Activation functions.
from torch.nn import ReLU, ReLU6, ELU, SELU, LeakyReLU
from torch.nn import Hardtanh, Sigmoid, Tanh, LogSigmoid
from torch.nn import Softplus, Softshrink, Tanhshrink, Softmin
from torch.nn import Softmax, LogSoftmax # Softmax2d


# Loss functions.
from torch.nn import L1Loss, MSELoss # NLLLoss, CrossEntropyLoss
from torch.nn import PoissonNLLLoss, KLDivLoss, BCELoss
from torch.nn import BCEWithLogitsLoss, HingeEmbeddingLoss # MarginRankingLoss
from torch.nn import SmoothL1Loss, SoftMarginLoss # MultiLabelMarginLoss, CosineEmbeddingLoss, 
from torch.nn import MultiLabelSoftMarginLoss # MultiMarginLoss, TripletMarginLoss

# Optimizers.
from torch.optim import Adadelta, Adagrad, Adam, Adamax # SparseAdam
from torch.optim import ASGD, RMSprop, Rprop # LBFGS

Activations = [ReLU, ReLU6, ELU, SELU, LeakyReLU, 
                Hardtanh, Sigmoid, Tanh, LogSigmoid,
                Softplus, Softshrink, Tanhshrink, Softmin, 
                Softmax, LogSoftmax]

Criterions = [L1Loss, MSELoss,
              PoissonNLLLoss, KLDivLoss, BCELoss,
              BCEWithLogitsLoss, HingeEmbeddingLoss,
              SmoothL1Loss, SoftMarginLoss,
              MultiLabelSoftMarginLoss]

Optimizers = [Adadelta, Adagrad, Adam, Adamax,
             ASGD, RMSprop, Rprop]

X = xor_input = np.array([[0,0], [0,1], [1,0], [1,1]])
Y = xor_output = np.array([[0,1,1,0]]).T

# Converting the X to PyTorch-able data structure.
X_pt = Variable(FloatTensor(X))
X_pt = X_pt.cuda() if use_cuda else X_pt
# Converting the Y to PyTorch-able data structure.
Y_pt = Variable(FloatTensor(Y), requires_grad=False)
Y_pt = Y_pt.cuda() if use_cuda else Y_pt

# Use FloatTensor.shape to get the shape of the matrix/tensor.
num_data, input_dim = X_pt.shape
num_data, output_dim = Y_pt.shape

learning_rate = 0.03
hidden_dim = 5
num_epochs = 10000
num_experiments = 100



for Activation, Criterion, Optimizer in product(Activations, Criterions, Optimizers):
    all_results=[]
    start = time.time()
    for _ in range(num_experiments):
        model = nn.Sequential(nn.Linear(input_dim, hidden_dim),
                              Activation(), 
                              nn.Linear(hidden_dim, output_dim),
                              nn.Sigmoid())
        model = model.cuda() if use_cuda else model
        criterion = Criterion()
        optimizer = Optimizer(model.parameters(), lr=learning_rate)
        
        for _e in range(num_epochs):
            optimizer.zero_grad()
            predictions = model(X_pt)
            loss_this_epoch = criterion(predictions, Y_pt)
            loss_this_epoch.backward()
            optimizer.step()
            ##print(_e, [float(_pred) for _pred in predictions], list(map(int, Y_pt)), loss_this_epoch.data[0])

        x_pred = [int(model(_x) > 0.5) for _x in X_pt]
        y_truth = list([int(_y[0]) for _y in Y_pt])
        all_results.append([x_pred == y_truth, x_pred, loss_this_epoch.data[0]])

    tf, outputsss, losses__ = zip(*all_results)
    print(Activation, Criterion, Optimizer, Counter(tf), time.time() - start)

I’ve some numbers from the sweep: https://github.com/alvations/ixora/blob/master/xor.output

It seems like nothing gets close to the hand-written NumPy implementation at https://www.kaggle.com/alvations/xor-with-mlp, where I get 100% all the time.

I tried out your code and got your results of non-convergence. I then upped the hidden dimension to 20, and I got convergence 100% of the time. Can you confirm that you see that too, before we try to figure out why it converges with a bigger single hidden layer? (I think it has to do with saddle points, but first I want to know whether you get convergence with a bigger net.)

Also, you don’t need to try out so many losses and activations. For this problem, BCELoss and ReLU will work just fine.
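
A minimal sketch of that suggestion (hidden size as in the post above; the rest reuses the dimensions and data defined earlier in the thread):

hidden_dim = 20  # the wider hidden layer mentioned above
model = nn.Sequential(nn.Linear(input_dim, hidden_dim),
                      nn.ReLU(),
                      nn.Linear(hidden_dim, output_dim),
                      nn.Sigmoid())
criterion = nn.BCELoss()  # targets are already 0/1 and the output is a sigmoid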