Loss stays the same no matter which network I use

I am trying to classify the images in my dataset into 2 classes.
First of all, my dataset was highly imbalanced (0.97 vs. 0.03), and when I trained the network, the test predictions were the same for every item ([~0.97 for the first class, ~0.03 for the second class]). The loss was around 0.6 every epoch, regardless of the learning rate (I tried values from 1e-5 to 3).
Then I balanced my classes (50/50), but the loss stayed exactly the same: around 0.6 every epoch, regardless of the learning rate (from 1e-5 to 3). The test results became [~0.5 for the first class, ~0.5 for the second class].
Then I tried training a second network: the first one was convolutional, the second one was fully connected. The results were EXACTLY the same.
What could be the problem here?
Here is the code:

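import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

class Net(nn.Module):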
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 30, 5, 2)
        self.conv2 = nn.Conv2d(30, 50, 5, 1)
        # self.conv3 = nn.Conv2d(30, 50, 5, 1)
        self.dropout = nn.Dropout2d()
        self.fc1 = nn.Linear(5*5*50, 100)
        self.fc2 = nn.Linear(100, 2)
        

    def forward(self, x):
        x = F.relu(self.conv1(x))
        
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        #print(x.shape)
        #x = F.relu(self.conv3(x))
        #x = F.max_pool2d(x, 2, 2)
        # print(x.shape)
        x = x.view(-1, 5*5*50)
        x = F.relu(self.fc1(x))
        # x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        #print(x.data.max(), x.data.min() )
        return F.sigmoid(x)

class Net2(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()       
        self.fc1 = nn.Linear(64*64, 8*8)
        self.fc2 = nn.Linear(64, 2)
        

    def forward(self, x):
        x = x.view(-1, 64*64)
        x = F.relu(self.fc1(x))
        # x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        #print(x.data.max(), x.data.min() )
        return F.sigmoid(x)

net = Net2().cuda()

class PillDataset(Dataset):
    def __init__(self):
        self.X = []
        self.y = []
        self.transform = transforms.Compose([
            transforms.RandomAffine(degrees=[-40, 40], translate=[0.1, 0.1]),
            transforms.RandomHorizontalFlip(p=0.6),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.Resize(64),  # transforms.Scale is deprecated; Resize replaces it
            transforms.ToTensor()])

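        filenames = os.listdir('./correct/')
        # replace=False avoids picking the same file twice (np.random.choice samples with replacement by default)
        filenames = np.random.choice(filenames, 1000, replace=False)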
        for fn in filenames:
            self.X.append('./correct/'+fn)
            #self.y.append([1, 0])
            self.y.append(1)
        filenames = os.listdir('./defect/')
        for fn in filenames:
            self.X.append('./defect/'+fn)
            #self.y.append([0, 1])
            self.y.append(2)
    def __getitem__(self, index):
        image = Image.open(self.X[index])
        # image = cv2.Laplacian(image,cv2.CV_64F)
        image = self.transform(image)
        label = torch.tensor(self.y[index])
        return image, label

    def __len__(self):
        return len(self.X)


full_dataset = PillDataset()
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_set, test_set = torch.utils.data.dataset.random_split(full_dataset, [train_size, test_size])

train_loader = DataLoader(train_set,
                          batch_size=256,
                          shuffle=True,
                          num_workers=0,
                          pin_memory=True)

optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9)
criterion = nn.CrossEntropyLoss()

def train(epoch):
    running_loss = 0.0
    net.train()
    
    for batch_idx, (data, target) in enumerate(train_loader):
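        # non_blocking replaces the old async argument (async is a reserved keyword since Python 3.7)
        data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)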
        
        optimizer.zero_grad()
        output = net(data)

        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Note: this prints only the loss of the first batch in each epoch
        if batch_idx == 0:
            print('Train Epoch: {} [{}]'.format(epoch, running_loss))
            running_loss = 0.0

for epoch in range(1, 30):
    train(epoch)

There is a small mismatch between the output non-linearity (F.sigmoid) and the criterion (nn.CrossEntropyLoss).

If you would like to use nn.CrossEntropyLoss, you would have to pass the raw logits (without any non-linearity) to the criterion, as internally F.log_softmax will be applied.
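To illustrate the mismatch, here is a minimal sketch with made-up logits (the values are hypothetical): because nn.CrossEntropyLoss applies F.log_softmax internally, feeding it sigmoid outputs squeezes both inputs into (0, 1), so even a perfectly confident prediction cannot push the loss below log(1 + e^-1), roughly 0.313.

```python
import math

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
target = torch.tensor([0])

# Hypothetical, very confident raw logits for class 0
logits = torch.tensor([[20.0, -20.0]])

# Raw logits: the loss can get arbitrarily close to 0
loss_raw = criterion(logits, target)

# Same logits passed through a sigmoid first: both values now lie in (0, 1),
# so the gap that log_softmax sees is smaller than 1
loss_sigmoid = criterion(torch.sigmoid(logits), target)

# Lowest loss reachable with sigmoid outputs (gap of exactly 1)
floor = math.log(1 + math.exp(-1))

print(loss_raw.item())            # close to 0
print(loss_sigmoid.item(), floor)  # both about 0.3133
```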

Alternatively, you could use one output neuron, apply F.sigmoid, and use nn.BCELoss as your criterion.
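A minimal sketch of the single-output alternative, using random stand-in features (hypothetical, in place of the output of fc1). nn.BCELoss expects float targets with the same shape as the output; nn.BCEWithLogitsLoss, shown for comparison, takes the raw logit and applies the sigmoid internally, which is numerically more stable.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical batch: 4 samples with 100 features each (e.g. the output of fc1)
features = torch.randn(4, 100)
# Binary targets as floats, same shape as the model output: [batch_size, 1]
target = torch.tensor([[0.0], [1.0], [1.0], [0.0]])

fc_out = nn.Linear(100, 1)  # one output neuron instead of two
logits = fc_out(features)   # shape [4, 1]

# Option 1: sigmoid + nn.BCELoss
loss_bce = nn.BCELoss()(torch.sigmoid(logits), target)

# Option 2: raw logits + nn.BCEWithLogitsLoss (sigmoid fused into the loss)
loss_bce_logits = nn.BCEWithLogitsLoss()(logits, target)

print(loss_bce.item(), loss_bce_logits.item())  # the two values match
```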

Thanks for your answer!
I changed the network to return F.log_softmax, but nothing changed; everything stayed the same.

If you use F.log_softmax(x, 1), then you would have to use nn.NLLLoss as your criterion.
For a standard multi-class classification use case, you would have these options:

  • raw logits (no output non-linearity) + nn.CrossEntropyLoss
  • F.log_softmax + nn.NLLLoss
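A quick check, with random logits and targets as stand-ins, that both options compute the same loss:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 2)          # raw model outputs: [batch_size, nb_classes]
target = torch.randint(0, 2, (8,))  # class indices in [0, nb_classes-1], dtype torch.long

# Option 1: raw logits + nn.CrossEntropyLoss
loss_ce = nn.CrossEntropyLoss()(logits, target)

# Option 2: F.log_softmax + nn.NLLLoss
loss_nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), target)

print(loss_ce.item(), loss_nll.item())  # identical up to floating-point rounding
```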

Let me know if this fixes your training.

I used F.log_softmax(x, 1) and changed the class labels to 1 and -1, and now everything works! Thanks a lot!

Good to hear it’s working. However, the labels should contain class indices in the range [0, nb_classes-1] for nn.NLLLoss. Are you using nn.NLLLoss as your criterion?
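One way to catch this kind of label issue early is to validate the targets before they reach the criterion. The helper below is hypothetical (not part of PyTorch) and assumes targets of shape [batch_size] holding torch.long class indices:

```python
import torch


def check_targets(target: torch.Tensor, nb_classes: int) -> None:
    """Hypothetical helper: validate targets for nn.NLLLoss / nn.CrossEntropyLoss."""
    assert target.dim() == 1, f"expected shape [batch_size], got {tuple(target.shape)}"
    assert target.dtype == torch.long, f"expected torch.long, got {target.dtype}"
    assert target.min() >= 0 and target.max() <= nb_classes - 1, (
        f"class indices must be in [0, {nb_classes - 1}], got {target.unique().tolist()}"
    )


# Labels 0 and 1 pass; labels such as 1 and -1 (or 1 and 2) fail the range check
check_targets(torch.tensor([0, 1, 1, 0]), nb_classes=2)
```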

I am still using CrossEntropyLoss…
I will switch to NLLLoss now!

Now it looks like this:

from torch.autograd import Variable

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 30, 5, 2)
        self.conv2 = nn.Conv2d(30, 50, 5, 1)
        # self.conv3 = nn.Conv2d(30, 50, 5, 1)
        self.dropout = nn.Dropout2d()
        self.fc1 = nn.Linear(5*5*50, 100)
        self.fc2 = nn.Linear(100, 2)
        self.sm = nn.LogSoftmax(dim=1)  # unused: forward() calls F.log_softmax(x, dim=1) directly

    def forward(self, x):
        x = F.relu(self.conv1(x))
        
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        #print(x.shape)
        #x = F.relu(self.conv3(x))
        #x = F.max_pool2d(x, 2, 2)
        # print(x.shape)
        x = x.view(-1, 5*5*50)
        x = F.relu(self.fc1(x))
        # x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        #print(x.data.max(), x.data.min() )
        return F.log_softmax(x, dim=1)

class PillDataset(Dataset):
    def __init__(self):
        self.X = []
        self.y = []
        self.transform = transforms.Compose([
            transforms.RandomAffine(degrees=[-40, 40], translate=[0.1, 0.1]),
            transforms.RandomHorizontalFlip(p=0.6),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.Resize(64),  # transforms.Scale is deprecated; Resize replaces it
            transforms.ToTensor()])

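        filenames = os.listdir('./correct/')
        # replace=False avoids picking the same file twice (np.random.choice samples with replacement by default)
        filenames = np.random.choice(filenames, 1000, replace=False)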
        for fn in filenames:
            self.X.append('./correct/'+fn)
            self.y.append([1, 0])
            # self.y.append(1)
        filenames = os.listdir('./defect/')
        for fn in filenames:
            self.X.append('./defect/'+fn)
            self.y.append([0, 1])
            #self.y.append(-1)
    def __getitem__(self, index):
        image = Image.open(self.X[index])
        # image = cv2.Laplacian(image,cv2.CV_64F)
        image = self.transform(image)
        label = torch.tensor(self.y[index])
        return image, label

    def __len__(self):
        return len(self.X)

def train(epoch):
    running_loss = 0.0
    net.train()
    
    for batch_idx, (data, target) in enumerate(train_loader):
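        # non_blocking replaces the old async argument (async is a reserved keyword since Python 3.7)
        data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)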
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = net(data)

        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx == 0:
            print('Train Epoch: {} [{}]'.format(epoch, running_loss))
            running_loss = 0.0

And it shows:
RuntimeError: multi-target not supported at c:\programdata\miniconda3\conda-bld\pytorch_1524543037166\work\aten\src\thcunn\generic/ClassNLLCriterion.cu:16

Your target tensor should be in shape [batch_size] containing class indices in the range [0, nb_classes-1].
Could you print the shape and values of a target tensor?

Yes. The values are rows of [0, 1] and [1, 0], with shape torch.Size([256, 2]).

Currently you are passing a one-hot encoded target, which won’t work.
Your target should have the shape [256] and just contain the class indices 0 and 1.
Try to call target = torch.argmax(target, 1) and pass it to your criterion.
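For example, with a small hypothetical batch of one-hot targets in the same format as the Dataset above ([1, 0] for a correct pill, [0, 1] for a defect):

```python
import torch

# Hypothetical one-hot targets, shape [batch_size, nb_classes]
one_hot = torch.tensor([[1, 0], [0, 1], [0, 1], [1, 0]])
print(one_hot.shape)  # torch.Size([4, 2])

# Index of the 1 in each row -> class indices, shape [batch_size]
target = torch.argmax(one_hot, 1)
print(target)        # tensor([0, 1, 1, 0])
print(target.shape)  # torch.Size([4])
```

torch.argmax returns a torch.long tensor, which is the dtype nn.NLLLoss expects for class indices.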

I changed the DataLoader and now each class is encoded as 0 or 1.
But now the problem is back: the loss stays at 0.6 the whole time.

Could you try to play around with the hyperparameters?
Also, what kind of data are you using?

I successfully managed to fit a random dataset using your model:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 30, 5, 2)
        self.conv2 = nn.Conv2d(30, 50, 5, 1)
        self.fc1 = nn.Linear(5*5*50, 100)
        self.fc2 = nn.Linear(100, 2)

    def forward(self, x):
        x = F.relu(self.conv1(x))        
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

x = torch.randn(100, 1, 64, 64)
target = torch.empty(100, dtype=torch.long).random_(2)
model = Net()

criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(100):
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    
    print('Epoch {}, loss {}'.format(epoch, loss.item()))

The dataset consists of images of pills, where some of them are defective (have cracks) and some are intact. The task I want to solve is to detect whether a pill is broken.
I tried tuning the hyperparameters with SGD, and it didn’t help. I will try Adam now, thanks!

(I hoped this network would work, because the task is quite similar to MNIST.)

I am just not sure that I implemented the Dataset correctly, because it’s the first time I have written one. Could that be the problem?

The Dataset looks alright.
It might help if you try to overfit a very small sample of your data to see if there might be other bugs in the code. E.g. you could start with 5 intact and 5 defective pill images and see if the network successfully learns them. If that’s not the case, we might look for other possible bugs.
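A sketch of that sanity check. To keep it self-contained, random tensors stand in for the pill images (hypothetical data), and per_class_subset is a hypothetical helper, not a PyTorch function. With the real data you would pass your PillDataset and its label list (self.y) instead, assuming the labels are already 0 and 1; turning off the random augmentations for this check may also make the result easier to read.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, TensorDataset

torch.manual_seed(0)


class Net(nn.Module):
    # Same architecture as the model above (input: [N, 1, 64, 64])
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 30, 5, 2)
        self.conv2 = nn.Conv2d(30, 50, 5, 1)
        self.fc1 = nn.Linear(5 * 5 * 50, 100)
        self.fc2 = nn.Linear(100, 2)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)


def per_class_subset(dataset, labels, n_per_class):
    """Hypothetical helper: keep the first n_per_class samples of every class."""
    indices = []
    for cls in sorted(set(labels)):
        cls_idx = [i for i, y in enumerate(labels) if y == cls]
        indices.extend(cls_idx[:n_per_class])
    return Subset(dataset, indices)


# Stand-in data: random "images" with labels 0 (intact) and 1 (defect)
images = torch.randn(40, 1, 64, 64)
labels = [0] * 20 + [1] * 20
full_dataset = TensorDataset(images, torch.tensor(labels))

small_set = per_class_subset(full_dataset, labels, n_per_class=5)
loader = DataLoader(small_set, batch_size=10, shuffle=True)

model = Net()
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(300):
    for data, target in loader:
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()

# If the training pipeline is correct, the loss on these 10 samples should approach 0
print('final loss {:.4f}'.format(loss.item()))
```

If the loss on such a tiny subset does not go down, the problem is more likely in the data pipeline or training code than in the model capacity.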

Also, as a small side note: Variables are deprecated since 0.4.0, so you can just remove them and use tensors instead.
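A sketch of the training function from the thread with the Variable wrapper removed, assuming PyTorch 1.2 or newer (for nn.Flatten). The model and data are hypothetical stand-ins so the snippet runs on its own; it also uses non_blocking in place of the removed async argument and averages the loss over the epoch instead of reporting only the first batch.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hypothetical stand-ins for the model and data used in the thread
model = nn.Sequential(nn.Flatten(), nn.Linear(64 * 64, 2), nn.LogSoftmax(dim=1)).to(device)
dataset = TensorDataset(torch.randn(16, 1, 64, 64), torch.randint(0, 2, (16,)))
loader = DataLoader(dataset, batch_size=8)
criterion = nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)


def train(epoch):
    model.train()
    running_loss = 0.0
    for data, target in loader:
        # Plain tensors: no Variable wrapper is needed since PyTorch 0.4.0.
        # non_blocking replaces the old async argument (async is a Python keyword since 3.7).
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Average over all batches instead of printing only the first one
    avg_loss = running_loss / len(loader)
    print('Train Epoch: {} [{:.4f}]'.format(epoch, avg_loss))
    return avg_loss
```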

Thanks a lot for your help!
I will try to find out what’s wrong here.
