Beginner advice on a CNN

Hi all, I am new to neural networks and Pytorch and have a problem that I hope someone can help me with.

After trying the standard MNIST digit problem, I have been working on this chess-positions dataset Chess Positions | Kaggle

First, I trained a CNN to identify chess pieces on individual squares using a balanced dataset of each piece type (including empty squares). This got me up to 99% accuracy, but that is not sufficient to robustly identify all the pieces on a chess board, since 0.99^64 = 0.53.
I improved things by submitting each board as a batch of 64 images of the individual squares from the same board. In this way, I reasoned that the network could better distinguish pieces, especially light from dark, because each batch has the same board style (the dataset contains a mix of both chess piece styles and chess board styles). This worked and I got to 0.9999 accuracy on individual squares, which gave the expected 0.9999^64 = 0.99 accuracy on entire boards.
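
For reference, something like the following slices one board into a batch of 64 square images (just a sketch, assuming 400x400 RGB board tensors with 50x50 squares; the function name is only illustrative):

import torch

def board_to_square_batch(board):
    # board: (3, 400, 400); cut into an 8x8 grid of 50x50 tiles
    squares = board.unfold(1, 50, 50).unfold(2, 50, 50)  # (3, 8, 8, 50, 50)
    squares = squares.permute(1, 2, 0, 3, 4)             # (8, 8, 3, 50, 50)
    return squares.reshape(64, 3, 50, 50)                # batch of 64 squares, row-major order

batch = board_to_square_batch(torch.rand(3, 400, 400))
print(batch.shape)  # torch.Size([64, 3, 50, 50])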

However, to play around further, I want to make a network that takes in the entire board but splits the input into the individual squares, runs them all through the same convolutional layers, and then uses one or two final linear layers to combine the outputs together and correct for any errors due to the different light/dark backgrounds.

This is the model

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Assuming that we are on a CUDA machine, this should print a CUDA device:

print(device)

class Net2(nn.Module):
    def __init__(self,device):
        super(Net2, self).__init__()
        
        self.device=device
        self.conv=nn.Sequential(
            nn.Conv2d(3,32,5),
            nn.ReLU(),
            nn.MaxPool2d(2),
    
            nn.Conv2d(32,64,5),
            nn.ReLU(),
            nn.MaxPool2d(2),
    
            nn.Dropout2d(0.25),
            nn.Flatten(),
    
            nn.Linear(5184,128),
            nn.ReLU(),
            nn.Dropout(0.5)
    
            )

        # fully connected layer that outputs the logits for our 13 labels for each of the 64 squares
        self.fc2 = nn.Linear(128*64, 13*64)
    
    def forward(self, x):
        
        rows=torch.split(x,50,dim=2)#tuple of 8 Nx3x50x400 tensors
        
        out=torch.empty((rows[0].shape[0],128*64),dtype=torch.float32,device=self.device)

        for i in range(8):
            squares=torch.split(rows[i],50,dim=3)#tuple of 8 Nx3x50x50 tensors
            for j in range(8):
                s=8*i+j
                out[:,s*128:(s+1)*128]=self.conv(squares[j])
                
        out = self.fc2(out)
    
        output=out.reshape(-1,13,64)

        return output

my_nn2 = Net2(device).to(device)
print(my_nn2)

and I use

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(my_nn2.parameters(), lr=0.001, momentum=0.9)

I train it like this

n_epochs=10
accuracy_train=np.zeros(n_epochs,)
batch_size=100
N_samples=x_train.shape[0]
all_predictions=np.ndarray((N_samples,64),dtype='int64')

for epoch in range(n_epochs):  # loop over the dataset multiple times

    running_loss = 0.0
    for i in range(0,N_samples,batch_size):
        
        batch=torch.tensor(x_train[i:(i+batch_size),:,:,:],dtype=torch.float32).to(device)/255
        labels=torch.tensor(y_train[i:(i+batch_size),:],dtype=torch.long).to(device)
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = my_nn2(batch)#batch_sizex13x64
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        
        # print statistics
        running_loss += loss.item()
        if i % 2000 == 0 and i > 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 20:.3f}')  # 20 batches of 100 samples per report
            running_loss = 0.0
            

    # after each epoch, run again on the training data
    my_nn2.eval()   # disable dropout while evaluating
    with torch.no_grad():
        for i in range(0,N_samples,batch_size):
            batch=torch.tensor(x_train[i:(i+batch_size),:,:,:],dtype=torch.float32).to(device)/255
            outputs2 = my_nn2(batch)
            _, batch_predictions = torch.max(outputs2, 1)   # class with the highest logit per square
            all_predictions[i:(i+batch_size),:]=batch_predictions.cpu().numpy()

        accuracy_train[epoch]=np.mean(np.all(y_train==all_predictions,axis=1))
    my_nn2.train()  # back to training mode for the next epoch
        
    
    print(f'[{epoch + 1}] training accuracy: {accuracy_train[epoch]:.4f}')

print('Finished Training')

The labels are integers from 0 to 12, representing the six light pieces, the six dark pieces, and an empty square.

My problem is that training does not seem to be working. The model converges to predicting entirely empty boards. There are only 5-15 pieces per board, so the majority of squares are empty, but this was not a problem with my previous network that took in individual squares.

Does anyone have any suggestions or see any mistakes in my code?

Thanks
Sean


Hi Sean!

I could be wrong, but this part about submitting the entire board as a batch
doesn’t make a lot of sense to me. When you run a batch through a model,
the individual samples in the batch don’t interact with one another. So when
the model is processing one square (as part of the batch) to determine, say,
whether it’s a light piece or dark piece, the model isn’t looking at other pieces
from the same board – both light and dark – for comparison (even though
they are in the same batch).

You don’t say how or whether you were batching squares before you started
using your board-as-batch scheme. Is it possible that the improvement you
saw was just from training with a batch size of 64 and not that the batch
elements were all from the same board?

As I understand it, the various boards have different backgrounds and the
idea is that training where something is known about the difference between
a light square and dark square could make it easier to identify the specific
pieces. This is plausible. However, the scheme you suggest below doesn’t
seem promising to me (some comments below).

Some ideas that come to mind:

Plan A: Train on pairs of squares. Each row of eight squares consists of
four pairs of adjacent squares. Now each sample will consist of a light
square and a dark square (with the image being a rectangle that is twice
as wide as it is high), letting the model learn the difference between the
two. The disadvantage is that you will now have 169 (= 13**2) classes.
This will make training harder and also make it harder to compensate for
class imbalance.
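
(As a rough sketch of the pairing, assuming a (3, 400, 400) board tensor with 50x50 squares and an (8, 8) tensor of per-square labels in 0..12; the function name is only a placeholder:)

import torch

def board_to_pairs(board, labels):
    squares = board.unfold(1, 50, 50).unfold(2, 50, 50)        # (3, 8, 8, 50, 50)
    pairs, pair_labels = [], []
    for row in range(8):
        for col in range(0, 8, 2):                             # four adjacent pairs per row
            left = squares[:, row, col]                        # (3, 50, 50)
            right = squares[:, row, col + 1]                   # (3, 50, 50)
            pairs.append(torch.cat([left, right], dim=2))      # (3, 50, 100), twice as wide as high
            pair_labels.append(labels[row, col] * 13 + labels[row, col + 1])  # one of 169 classes
    return torch.stack(pairs), torch.stack(pair_labels)        # (32, 3, 50, 100), (32,)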

Plan B: Do the following preprocessing: For each square you submit for
classification, also submit an empty opposite-color square (snipped from
the same board the sample square comes from). Let’s say that your board
images are RGB so they have three channels. Then I would build a
six-channel sample where the first three channels are the square you
actually wish to classify and the last three channels are from the empty,
opposite-color square (from the same board) that you are providing as
reference.
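
(A minimal sketch of building such a sample; how you pick the reference square is up to you:)

import torch

def make_six_channel_sample(square, reference_empty):
    # square, reference_empty: (3, 50, 50) crops from the same board,
    # the reference being an empty square of the opposite color
    return torch.cat([square, reference_empty], dim=0)   # (6, 50, 50)

# The first convolution of the classifier would then take six input channels,
# e.g. nn.Conv2d(6, 32, 5) instead of nn.Conv2d(3, 32, 5).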

At this point, you are passing 64 length-128 feature vectors to self.fc2, one
feature vector for each square in the board. However, all self.fc2 sees is a
single, flat, length-8192 feature vector and it has no idea which features
should get grouped together into a set of 128 that correspond to a single
square. “Learning” this grouping seems hard to me, so I’m not surprised
that this scheme doesn’t work well.

This doesn’t surprise me. If you buy my argument, above, that your model
is prevented from learning about individual squares because the features
from the individual squares are all mixed together, then about the only thing
your model can learn is that there are more empty squares than squares
with pieces. So the best it can do is predict empty boards.

Note, not only is the individual-square problem easier, but you said, above,
that you trained on a balanced dataset. So that model would not have any
particular tendency to predict empty squares.

Best.

K. Frank

Hi K. Frank,

Thanks for your reply.
Before I started doing the board-as-batch scheme, I collected the pieces from each board (taking at most one of each piece type) until I built up 2000 images of each piece type and the empty space. So this data was balanced. However, I didn't shuffle it, so the dataset does fall into groups of a few pieces coming from the same board. The batch size was 100.

The second board-as-batch attempt used all squares from each of (actually a subset of) the boards, so a batch size of 64. It was therefore not balanced - empty squares occur much more frequently than occupied ones - yet it gave higher accuracy. I chose the number of boards such that approximately the same total number of pieces would be expected in the dataset as in the previous approach.

While the model is processing one square, it is true that it is not looking at other pieces; however, the model tries to optimise for the whole batch with the same model parameters. So it is trying to tune the model to classify these pieces and squares, which all have fixed styles. By repeatedly training on batches like this, I thought the model would learn to ignore what differs between the batches (the two styles) and rather focus on optimising what the batches have in common, i.e. the piece and square types. I could indeed be wrong here, and it could be due to the differing batch sizes (actual or effective, in the sense that the board-as-batch model sees only 5-15 pieces per batch, so overall had more updates).

Is the number of updates (calls to the optimiser) the important thing for accuracy? I understand that filling the GPU memory with a large batch is more efficient, but does it result in more epochs/data being required to get sufficient updates?

I like your plan B. I was not thinking in that way. I’ll have to try it. Thanks for the suggestion.

This doesn’t surprise me. If you buy my argument, above, that your model
is prevented from learning about individual squares because the features
from the individual squares are all mixed together, then about the only thing
your model can learn is that there are more empty squares than squares
with pieces. So the best it can do is predict empty boards.

Regarding this last layer, I was thinking that the 128 outputs from each square would get mixed to improve the classification and compensate for overall darker or lighter chess sets. I thought the model would learn which features to intermix to best match the 64 output vectors to the 64 target labels. For interest’s sake, how would you go about it?

One thing I was thinking about during this, and a motivation for going to a model that acts on boards, is accuracy. To get 99.9% accuracy on boards (this is a synthetic dataset, so it should be possible) I would need 99.998% accuracy on squares. Such high accuracy means almost all batches would be perfectly predicted, and so the model would mostly not be updated. It would take longer and longer to get incremental improvements in accuracy. This should of course be well known.

Do people ever subset the training dataset after each epoch to increase the fraction of entries that are not perfectly classified and thereby improve convergence to very high accuracies? Is there also a danger of the loss getting so small that numerical accuracy (32-bit floats) becomes an issue?

Thanks for your help.
Sean

P.S. My notebook with the first two approaches is on kaggle Identify the layout of a chess board | Kaggle

Hi Sean!

This is in a sense true, but, again, there is no “cross-talk” between the
samples in a batch.

The point is that training is a sequence of optimization steps for a sequence
of batches. To first approximation, taking an optimization step for the first
halves of batches A and B combined into a single batch and then taking
an optimization step for the second halves of the batches is pretty much
the same as taking a step for batch A and then taking a step for batch B.
And to the extent that they are different, my intuition tells me that taking
steps with two more diversified batches will be better than taking steps with
two batches that are different from one another, but that are individually more
homogeneous.

It matters, but it’s definitely not the only (or most) important thing. There are
a lot of cross-currents in how batch size and optimizer steps interact in how
they affect training, but, crudely, taking ten steps with a batch size of one and
a learning rate of, say, 0.001 is about the same as taking one step with a
batch size of ten and a learning rate of 0.01. Arguably, the former would
be a little better (but there are a lot of cross-currents).
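
(A toy illustration of what I mean, using a one-parameter quadratic loss rather than a real model; the agreement is only approximate because the gradients change a little between the small steps:)

import torch

targets = torch.arange(10, dtype=torch.float32)   # ten "samples"

# ten steps, batch size one, lr = 0.001
w = torch.tensor(5.0)
for t in targets:
    w = w - 0.001 * (w - t)          # gradient of 0.5 * (w - t)**2
print(w)                             # tensor(4.9951)

# one step, batch size ten (mean loss), lr = 0.01
w = torch.tensor(5.0)
w = w - 0.01 * (w - targets.mean())  # gradient of the mean of 0.5 * (w - t)**2
print(w)                             # tensor(4.9950)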

To first approximation (assuming you adjust your learning rate appropriately)
the number of epochs (epoch = once through your entire training dataset)
you train for is what matters.

As a practical matter, the likely significant speed-up you get by keeping your
gpu pipelines as full as possible will greatly outweigh any minor disadvantage
you might get by using the larger batch size. Even if you need a few more
epochs to get the same result, the clock time it takes per epoch is likely to
be much less if you are keeping your gpus full.

In principle, your model could learn this “intermixing.” (Note that at the end
of your convolutional section, you have very useful information about this
intermixing, namely, which sets of 128 features group together to describe
single, individual squares. But in the connection between your convolutional
section and the last layer, you throw that very useful information away.)

Train for a long time, potentially with more data. In general, the harder the
learning task, the more training and more data you need.

Also, you would probably want (at least) two fully-connected (Linear) layers
separated by non-linear “activations.” The non-linear activation layers are
essential to how neural networks work (because, in isolation, linear layers
are only capable of “learning” linear relationships).
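
(As a sketch of the kind of head I have in mind, keeping the 64 feature vectors grouped as an (N, 64, 128) tensor; the name BoardHead and the hidden size are just placeholders:)

import torch
import torch.nn as nn

class BoardHead(nn.Module):
    def __init__(self, n_squares=64, feat=128, hidden=32, n_classes=13):
        super().__init__()
        self.per_square = nn.Linear(feat, hidden)                        # acts on the last dim, so shared across squares
        self.mix = nn.Linear(n_squares * hidden, n_squares * n_classes)  # mixes information between squares
        self.n_squares = n_squares
        self.n_classes = n_classes

    def forward(self, feats):                        # feats: (N, 64, 128)
        h = torch.relu(self.per_square(feats))       # (N, 64, hidden)
        out = self.mix(h.flatten(1))                 # (N, 64 * 13)
        return out.reshape(-1, self.n_classes, self.n_squares)  # (N, 13, 64) for CrossEntropyLoss

head = BoardHead()
print(head(torch.randn(2, 64, 128)).shape)           # torch.Size([2, 13, 64])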

This is not exactly right. Your model doesn’t predict specific pieces for each
square. Instead it predicts (unnormalized log-probabilities that correspond
to) probabilities for which piece is on the square. When you compute your
accuracy, you take the highest probability to get a “hard,” non-probabilistic
prediction for which piece is on the square.

A (probabilistic) prediction of 94% for “white knight” and 0.5% for the other
twelve choices would yield the same hard prediction of “white knight” as a
prediction of 40% “white knight” and 5% for each of the others. In each
case you would get 100% accuracy (assuming that “white knight” is correct),
but the first probabilistic prediction is much better. Your loss function can
tell the two apart and further training – even after you’ve reached 100%
accuracy – will improve your model.
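
(You can check this directly by feeding log-probabilities in as logits; a small example, with class 0 standing in for “white knight”:)

import torch
import torch.nn.functional as F

confident = torch.tensor([[0.94] + [0.005] * 12]).log()   # 94% for class 0
hesitant  = torch.tensor([[0.40] + [0.05] * 12]).log()    # 40% for class 0
target = torch.tensor([0])

print(F.cross_entropy(confident, target))   # ~0.062
print(F.cross_entropy(hesitant, target))    # ~0.916
# Both argmax to class 0, so the "hard" accuracy is the same,
# but the loss still prefers the confident prediction.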

If I understand what you are proposing, yes, this is done sometimes. It is
called “hard mining,” where you train more heavily on samples that your
model has had a hard time classifying.
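
(A rough sketch of the selection step, reusing the y_train and all_predictions arrays from your training loop above; the exact criterion is up to you:)

import numpy as np

# keep only the boards that were not perfectly predicted in the last evaluation pass
hard_idx = np.where(~np.all(y_train == all_predictions, axis=1))[0]
x_hard, y_hard = x_train[hard_idx], y_train[hard_idx]
# ...then run some extra training batches on (x_hard, y_hard)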

No. With single precision, your loss can go down to about 1.e-38 before
it underflows to zero. CrossEntropyLoss is a logarithmic loss. A “perfect”
prediction of a 100% probability for the correct class corresponds to a
loss of -log(1.0) = 0.0. An ever-so-slightly-incorrect prediction of a
probability of 1 - 1.e-38 would correspond to the numerically non-zero
(that is, hasn’t underflowed to zero yet) loss of 1.e-38. So, in practice,
your predictions would have to be more accurate than you could ever hope
for them to be before you would have any risk of your loss underflowing
to zero.

Best.

K. Frank

Hi K. Frank,

Thanks a lot for your clarifications and corrections. I very much appreciate it.

I have another follow-up question, if I may. How does your plan B work for testing or actually using the model, since, given an unknown board, I don’t know a priori which squares are empty?

Best wishes
Sean

Hi Sean!

Yes, you would have to do some pre-processing to find an empty square.

Plan B1: Use your high-accuracy, individual-square classifier to find an
empty square. (You can simply use the location of the square in the board
to ensure that it is of opposite color.) Presumably your individual-square
classifier is especially accurate in distinguishing empty from non-empty
squares.

Plan B2: I believe that simply computing the standard deviation of pixel
values across an individual square (or across a cropped square, if the
square has a border) would give you a highly-accurate empty-square
detector without any machine-learning shenanigans. Use that to find an
empty square to pair with the square you are actually classifying.
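
(Something as simple as the following sketch; the function name and threshold are only placeholders, and the threshold would need tuning on your data:)

import numpy as np

def looks_empty(square_img, threshold=10.0):
    # square_img: an H x W x 3 uint8 crop of a single square; an empty square
    # is (nearly) uniform, so the standard deviation of its pixels is small
    return float(np.std(square_img.astype(np.float32))) < threshold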

Best.

K. Frank


OK. I thought maybe there was a way to train the model with the empty square but then use the model without it.

Thanks
Sean